pmcore/bestdose/cost.rs
1//! Cost function calculation for BestDose optimization
2//!
3//! Implements the hybrid cost function that balances patient-specific performance
4//! (variance) with population-level robustness (bias). Also enforces dose range
5//! constraints through penalty-based bounds checking.
6//!
7//! # Cost Function
8//!
9//! ```text
10//! Cost = {
11//! (1-λ) × Variance + λ × Bias², if doses within bounds
12//! 1e12 + violation² × 1e6, if any dose violates bounds
13//! }
14//! ```
15//!
16//! ## Variance Term (Patient-Specific)
17//!
18//! Expected squared prediction error using posterior weights:
19//! ```text
20//! Variance = Σᵢ posterior_weight[i] × Σⱼ (target[j] - pred[i,j])²
21//! ```
22//!
23//! - Weighted by patient-specific posterior probabilities
24//! - Minimizes expected error for this specific patient
25//! - Emphasizes parameter values compatible with patient history
26//!
27//! ## Bias Term (Population-Level)
28//!
29//! Squared deviation from population mean prediction using prior weights:
30//! ```text
31//! Bias² = Σⱼ (target[j] - population_mean[j])²
32//! where population_mean[j] = Σᵢ prior_weight[i] × pred[i,j]
33//! ```
34//!
35//! - Weighted by population prior probabilities
36//! - Minimizes deviation from population-typical behavior
37//! - Provides robustness when patient history is limited
38//!
39//! ## Bias Weight Parameter (λ)
40//!
41//! - `λ = 0.0`: Pure personalization (minimize variance only)
42//! - `λ = 0.5`: Balanced hybrid approach
43//! - `λ = 1.0`: Pure population (minimize bias only)
44//!
45//! # Implementation Notes
46//!
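//! The cost function handles concentration targets and two kinds of AUC targets:
//! - **Concentration**: Simulates the model directly at the observation times
//! - **AUC from zero**: Simulates on a dense time grid and computes cumulative AUC via the
//!   trapezoidal rule
//! - **AUC from last dose**: Simulates on the same dense grid and computes, for each
//!   observation, the AUC over the interval since the most recent dose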
50//!
51//! See [`calculate_cost`] for the main implementation.
52
53use anyhow::Result;
54
55use crate::bestdose::predictions::{
56 calculate_auc_at_times, calculate_dense_times, calculate_interval_auc_per_observation,
57};
58use crate::bestdose::types::{BestDoseProblem, Target};
59use pharmsol::prelude::*;
60use pharmsol::Equation;
61
62/// Calculate cost function for a candidate dose regimen
63///
64/// This is the core objective function minimized by the Nelder-Mead optimizer during
65/// Stage 2 of the BestDose algorithm.
66///
67/// # Arguments
68///
69/// * `problem` - The [`BestDoseProblem`] containing all necessary data
70/// * `candidate_doses` - Dose amounts to evaluate (only for optimizable doses)
71///
72/// # Returns
73///
74/// The cost value `(1-λ) × Variance + λ × Bias²` for the candidate doses.
75/// Lower cost indicates better match to targets.
76///
77/// # Dose Masking
78///
79/// Only doses with `amount == 0.0` in the target subject are considered optimizable.
80/// Doses with non-zero amounts remain fixed at their specified values.
81///
82/// The `candidate_doses` parameter contains only the optimizable doses, which are
83/// substituted into the target subject before simulation
84///
85/// # Cost Function Details
86///
87/// ## Variance Term
88///
89/// Expected squared prediction error using posterior weights:
90/// ```text
91/// Variance = Σᵢ P(θᵢ|data) × Σⱼ (target[j] - pred[i,j])²
92/// ```
93///
94/// For each support point θᵢ:
95/// 1. Simulate model with candidate doses
96/// 2. Calculate squared error at each observation time j
97/// 3. Weight by posterior probability P(θᵢ|data)
98///
99/// ## Bias Term
100///
101/// Squared deviation from population mean:
102/// ```text
103/// Bias² = Σⱼ (target[j] - E[pred[j]])²
104/// where E[pred[j]] = Σᵢ P(θᵢ) × pred[i,j] (prior weights)
105/// ```
106///
107/// The population mean uses **prior weights**, not posterior weights, to represent
108/// population-typical behavior independent of patient-specific data.
109///
110/// ## Target Types
111///
112/// - **Concentration** ([`Target::Concentration`]):
113/// Predictions are concentrations at observation times
114///
115/// - **AUC** ([`Target::AUC`]):
116/// Predictions are cumulative AUC values calculated via trapezoidal rule
117/// on a dense time grid (controlled by `settings.predictions().idelta`)
118///
119/// # Example
120///
121/// ```rust,ignore
122/// // Internal use by optimizer
123/// let cost = calculate_cost(&problem, &[100.0, 150.0])?;
124/// ```
125///
126/// # Errors
127///
128/// Returns error if:
129/// - Model simulation fails
130/// - Prediction length doesn't match observation count
131/// - AUC calculation fails (for AUC targets)
132pub fn calculate_cost(problem: &BestDoseProblem, candidate_doses: &[f64]) -> Result<f64> {
133 // Validate candidate_doses length matches expected optimizable dose count
134 let expected_optimizable = problem
135 .target
136 .occasions()
137 .iter()
138 .flat_map(|occ| occ.events())
139 .filter(|event| match event {
140 Event::Bolus(b) => b.amount() == 0.0,
141 Event::Infusion(inf) => inf.amount() == 0.0,
142 _ => false,
143 })
144 .count();
145
146 if candidate_doses.len() != expected_optimizable {
147 return Err(anyhow::anyhow!(
148 "Dose count mismatch: received {} candidate doses but expected {} optimizable doses",
149 candidate_doses.len(),
150 expected_optimizable
151 ));
152 }
153
154 // Check bounds and return penalty if violated
155 // This constrains the Nelder-Mead optimizer to search within the specified DoseRange
156 let min_dose = problem.doserange.min;
157 let max_dose = problem.doserange.max;
158
159 for &dose in candidate_doses {
160 if dose < min_dose || dose > max_dose {
161 // Return a large penalty cost to push the optimizer back into bounds
162 // The penalty grows quadratically with distance from the nearest bound
163 let violation = if dose < min_dose {
164 min_dose - dose
165 } else {
166 dose - max_dose
167 };
168 return Ok(1e12 + violation.powi(2) * 1e6);
169 }
170 }
171
172 // Build target subject with candidate doses
173 let mut target_subject = problem.target.clone();
174 let mut optimizable_dose_number = 0; // Index into candidate_doses
175
176 for occasion in target_subject.iter_mut() {
177 for event in occasion.iter_mut() {
178 match event {
179 Event::Bolus(bolus) => {
180 // Only update if this dose is optimizable (amount == 0)
181 if bolus.amount() == 0.0 {
182 bolus.set_amount(candidate_doses[optimizable_dose_number]);
183 optimizable_dose_number += 1;
184 }
185 // If not optimizable (amount > 0), keep original amount
186 }
187 Event::Infusion(infusion) => {
188 // Only update if this dose is optimizable (amount == 0)
189 if infusion.amount() == 0.0 {
190 infusion.set_amount(candidate_doses[optimizable_dose_number]);
191 optimizable_dose_number += 1;
192 }
193 // If not optimizable (amount > 0), keep original amount
194 }
195 Event::Observation(_) => {}
196 }
197 }
198 }
199
200 // Extract target values and observation times
201 let obs_times: Vec<f64> = target_subject
202 .occasions()
203 .iter()
204 .flat_map(|occ| occ.events())
205 .filter_map(|event| match event {
206 Event::Observation(obs) => Some(obs.time()),
207 _ => None,
208 })
209 .collect();
210
211 // Validate that target has observations
212 if obs_times.is_empty() {
213 return Err(anyhow::anyhow!(
214 "Target subject has no observations. At least one observation is required for dose optimization."
215 ));
216 }
217
218 let obs_vec: Vec<f64> = target_subject
219 .occasions()
220 .iter()
221 .flat_map(|occ| occ.events())
222 .filter_map(|event| match event {
223 Event::Observation(obs) => obs.value(),
224 _ => None,
225 })
226 .collect();
227
228 let n_obs = obs_vec.len();
229
230 // Accumulators
231 let mut variance = 0.0_f64; // Expected squared error E[(target - pred)²]
232 let mut y_bar = vec![0.0_f64; n_obs]; // Population mean predictions
233
234 // Calculate variance (using posterior weights) and population mean (using prior weights)
235
236 for ((row, post_prob), prior_prob) in problem
237 .theta
238 .matrix()
239 .row_iter()
240 .zip(problem.posterior.iter()) // Posterior from NPAGFULL11 (patient-specific)
241 .zip(problem.population_weights.iter())
242 // Prior (population)
243 {
244 let spp = row.iter().copied().collect::<Vec<f64>>();
245
246 // Get predictions based on target type
247 let preds_i: Vec<f64> = match problem.target_type {
248 Target::Concentration => {
249 // Simulate at observation times only
250 let pred = problem.eq.simulate_subject(&target_subject, &spp, None)?;
251 pred.0.flat_predictions()
252 }
253 Target::AUCFromZero => {
254 // For AUC: simulate at dense time grid and calculate cumulative AUC
255 let idelta = problem.settings.predictions().idelta;
256 let start_time = 0.0; // Future starts at 0
257 let end_time = obs_times.last().copied().unwrap_or(0.0);
258
259 // Generate dense time grid
260 let dense_times =
261 calculate_dense_times(start_time, end_time, &obs_times, idelta as usize);
262
263 // Create temporary subject with dense time points for simulation
264 let subject_id = target_subject.id().to_string();
265 let mut builder = Subject::builder(&subject_id);
266
267 // Add all doses from original subject
268 for occasion in target_subject.occasions() {
269 for event in occasion.events() {
270 match event {
271 Event::Bolus(bolus) => {
272 builder =
273 builder.bolus(bolus.time(), bolus.amount(), bolus.input());
274 }
275 Event::Infusion(infusion) => {
276 builder = builder.infusion(
277 infusion.time(),
278 infusion.amount(),
279 infusion.input(),
280 infusion.duration(),
281 );
282 }
283 Event::Observation(_) => {} // Skip original observations
284 }
285 }
286 }
287
288 // Collect observations with (time, outeq) pairs to preserve original order
289 let obs_time_outeq: Vec<(f64, usize)> = target_subject
290 .occasions()
291 .iter()
292 .flat_map(|occ| occ.events())
293 .filter_map(|event| match event {
294 Event::Observation(obs) => Some((obs.time(), obs.outeq())),
295 _ => None,
296 })
297 .collect();
298
299 let mut unique_outeqs: Vec<usize> =
300 obs_time_outeq.iter().map(|(_, outeq)| *outeq).collect();
301 unique_outeqs.sort();
302 unique_outeqs.dedup();
303
304 // Add observations at dense times (with dummy values for timing only)
305 for outeq in unique_outeqs.iter() {
306 for &t in &dense_times {
307 builder = builder.missing_observation(t, *outeq);
308 }
309 }
310
311 let dense_subject = builder.build();
312
313 // Simulate at dense times
314 let pred = problem.eq.simulate_subject(&dense_subject, &spp, None)?;
315 let dense_predictions_with_outeq = pred.0.predictions();
316
317 // Group predictions by outeq using the Prediction struct
318 let mut outeq_predictions: std::collections::HashMap<usize, Vec<f64>> =
319 std::collections::HashMap::new();
320
321 for prediction in dense_predictions_with_outeq {
322 outeq_predictions
323 .entry(prediction.outeq())
324 .or_default()
325 .push(prediction.prediction());
326 }
327
328 // Calculate AUC for each outeq separately
329 let mut outeq_aucs: std::collections::HashMap<usize, Vec<f64>> =
330 std::collections::HashMap::new();
331
332 for &outeq in unique_outeqs.iter() {
333 let outeq_preds = outeq_predictions.get(&outeq).ok_or_else(|| {
334 anyhow::anyhow!("Missing predictions for outeq {}", outeq)
335 })?;
336
337 // Get observation times for this outeq only
338 let outeq_obs_times: Vec<f64> = obs_time_outeq
339 .iter()
340 .filter(|(_, o)| *o == outeq)
341 .map(|(t, _)| *t)
342 .collect();
343
344 // Calculate AUC at observation times for this outeq
345 let aucs = calculate_auc_at_times(&dense_times, outeq_preds, &outeq_obs_times);
346 outeq_aucs.insert(outeq, aucs);
347 }
348
349 // Build final AUC vector in original observation order
350 let mut result_aucs = Vec::with_capacity(obs_time_outeq.len());
351 let mut outeq_counters: std::collections::HashMap<usize, usize> =
352 std::collections::HashMap::new();
353
354 for (_, outeq) in obs_time_outeq.iter() {
355 let aucs = outeq_aucs
356 .get(outeq)
357 .ok_or_else(|| anyhow::anyhow!("Missing AUC for outeq {}", outeq))?;
358
359 let counter = outeq_counters.entry(*outeq).or_insert(0);
360 if *counter < aucs.len() {
361 result_aucs.push(aucs[*counter]);
362 *counter += 1;
363 } else {
364 return Err(anyhow::anyhow!(
365 "AUC index out of bounds for outeq {}",
366 outeq
367 ));
368 }
369 }
370
371 result_aucs
372 }
373 Target::AUCFromLastDose => {
374 // For interval AUC: simulate at dense time grid and calculate AUC from last dose
375 let idelta = problem.settings.predictions().idelta;
376 let end_time = obs_times.last().copied().unwrap_or(0.0);
377
378 // Generate dense time grid from 0 to end_time (need full grid for intervals)
379 let dense_times = calculate_dense_times(0.0, end_time, &obs_times, idelta as usize);
380
381 // Create temporary subject with dense time points for simulation
382 let subject_id = target_subject.id().to_string();
383 let mut builder = Subject::builder(&subject_id);
384
385 // Add all doses from original subject
386 for occasion in target_subject.occasions() {
387 for event in occasion.events() {
388 match event {
389 Event::Bolus(bolus) => {
390 builder =
391 builder.bolus(bolus.time(), bolus.amount(), bolus.input());
392 }
393 Event::Infusion(infusion) => {
394 builder = builder.infusion(
395 infusion.time(),
396 infusion.amount(),
397 infusion.input(),
398 infusion.duration(),
399 );
400 }
401 Event::Observation(_) => {} // Skip original observations
402 }
403 }
404 }
405
406 // Collect observations with (time, outeq) pairs to preserve original order
407 let obs_time_outeq: Vec<(f64, usize)> = target_subject
408 .occasions()
409 .iter()
410 .flat_map(|occ| occ.events())
411 .filter_map(|event| match event {
412 Event::Observation(obs) => Some((obs.time(), obs.outeq())),
413 _ => None,
414 })
415 .collect();
416
417 let mut unique_outeqs: Vec<usize> =
418 obs_time_outeq.iter().map(|(_, outeq)| *outeq).collect();
419 unique_outeqs.sort();
420 unique_outeqs.dedup();
421
422 // Add observations at dense times
423 for outeq in unique_outeqs.iter() {
424 for &t in &dense_times {
425 builder = builder.missing_observation(t, *outeq);
426 }
427 }
428
429 let dense_subject = builder.build();
430
431 // Simulate at dense times
432 let pred = problem.eq.simulate_subject(&dense_subject, &spp, None)?;
433 let dense_predictions_with_outeq = pred.0.predictions();
434
435 // Group predictions by outeq
436 let mut outeq_predictions: std::collections::HashMap<usize, Vec<f64>> =
437 std::collections::HashMap::new();
438
439 for prediction in dense_predictions_with_outeq {
440 outeq_predictions
441 .entry(prediction.outeq())
442 .or_default()
443 .push(prediction.prediction());
444 }
445
446 // Calculate interval AUC for each outeq separately
447 let mut outeq_aucs: std::collections::HashMap<usize, Vec<f64>> =
448 std::collections::HashMap::new();
449
450 for &outeq in unique_outeqs.iter() {
451 let outeq_preds = outeq_predictions.get(&outeq).ok_or_else(|| {
452 anyhow::anyhow!("Missing predictions for outeq {}", outeq)
453 })?;
454
455 // Get observation times for this outeq only
456 let outeq_obs_times: Vec<f64> = obs_time_outeq
457 .iter()
458 .filter(|(_, o)| *o == outeq)
459 .map(|(t, _)| *t)
460 .collect();
461
462 // Calculate interval AUC at observation times for this outeq
463 let aucs = calculate_interval_auc_per_observation(
464 &target_subject,
465 &dense_times,
466 outeq_preds,
467 &outeq_obs_times,
468 );
469 outeq_aucs.insert(outeq, aucs);
470 }
471
472 // Build final AUC vector in original observation order
473 let mut result_aucs = Vec::with_capacity(obs_time_outeq.len());
474 let mut outeq_counters: std::collections::HashMap<usize, usize> =
475 std::collections::HashMap::new();
476
477 for (_, outeq) in obs_time_outeq.iter() {
478 let aucs = outeq_aucs
479 .get(outeq)
480 .ok_or_else(|| anyhow::anyhow!("Missing AUC for outeq {}", outeq))?;
481
482 let counter = outeq_counters.entry(*outeq).or_insert(0);
483 if *counter < aucs.len() {
484 result_aucs.push(aucs[*counter]);
485 *counter += 1;
486 } else {
487 return Err(anyhow::anyhow!(
488 "AUC index out of bounds for outeq {}",
489 outeq
490 ));
491 }
492 }
493
494 result_aucs
495 }
496 };
497
498 if preds_i.len() != n_obs {
499 return Err(anyhow::anyhow!(
500 "prediction length ({}) != observation length ({})",
501 preds_i.len(),
502 n_obs
503 ));
504 }
505
506 // Calculate variance term: weighted by POSTERIOR probability
507 let mut sumsq_i = 0.0_f64;
508 for (j, &obs_val) in obs_vec.iter().enumerate() {
509 let pj = preds_i[j];
510 let se = (obs_val - pj).powi(2);
511 sumsq_i += se;
512 // Calculate population mean using PRIOR probabilities
513 y_bar[j] += prior_prob * pj;
514 }
515
516 variance += post_prob * sumsq_i; // Weighted by posterior
517 }
518
519 // Calculate bias term: squared difference from population mean
520 let mut bias = 0.0_f64;
521 for (j, &obs_val) in obs_vec.iter().enumerate() {
522 bias += (obs_val - y_bar[j]).powi(2);
523 }
524
525 // Final cost: (1-λ)×Variance + λ×Bias²
526 // λ=0: Full personalization (minimize variance)
527 // λ=1: Population-based (minimize bias from population)
528 let cost = (1.0 - problem.bias_weight) * variance + problem.bias_weight * bias;
529
530 Ok(cost)
531}