pmcore/bestdose/
cost.rs

1//! Cost function calculation for BestDose optimization
2//!
3//! Implements the hybrid cost function that balances patient-specific performance
4//! (variance) with population-level robustness (bias). Also enforces dose range
5//! constraints through penalty-based bounds checking.
6//!
7//! # Cost Function
8//!
9//! ```text
10//! Cost = {
11//!   (1-λ) × Variance + λ × Bias²,  if doses within bounds
12//!   1e12 + violation² × 1e6,        if any dose violates bounds
13//! }
14//! ```
15//!
16//! ## Variance Term (Patient-Specific)
17//!
18//! Expected squared prediction error using posterior weights:
19//! ```text
20//! Variance = Σᵢ posterior_weight[i] × Σⱼ (target[j] - pred[i,j])²
21//! ```
22//!
23//! - Weighted by patient-specific posterior probabilities
24//! - Minimizes expected error for this specific patient
25//! - Emphasizes parameter values compatible with patient history
26//!
27//! ## Bias Term (Population-Level)
28//!
29//! Squared deviation from population mean prediction using prior weights:
30//! ```text
31//! Bias² = Σⱼ (target[j] - population_mean[j])²
32//! where population_mean[j] = Σᵢ prior_weight[i] × pred[i,j]
33//! ```
34//!
35//! - Weighted by population prior probabilities
36//! - Minimizes deviation from population-typical behavior
37//! - Provides robustness when patient history is limited
38//!
39//! ## Bias Weight Parameter (λ)
40//!
41//! - `λ = 0.0`: Pure personalization (minimize variance only)
42//! - `λ = 0.5`: Balanced hybrid approach
43//! - `λ = 1.0`: Pure population (minimize bias only)
44//!
45//! # Implementation Notes
46//!
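//! The cost function handles concentration and AUC targets:
//! - **Concentration**: Simulates the model directly at the observation times
//! - **AUC from zero**: Simulates on a dense time grid and calculates cumulative AUC from time 0 via the trapezoidal rule
//! - **AUC from last dose**: Simulates on the same dense time grid and calculates, for each observation, the interval AUC since the last dose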
50//!
51//! See [`calculate_cost`] for the main implementation.
52
53use anyhow::Result;
54
55use crate::bestdose::predictions::{
56    calculate_auc_at_times, calculate_dense_times, calculate_interval_auc_per_observation,
57};
58use crate::bestdose::types::{BestDoseProblem, Target};
59use pharmsol::prelude::*;
60use pharmsol::Equation;
61
62/// Calculate cost function for a candidate dose regimen
63///
64/// This is the core objective function minimized by the Nelder-Mead optimizer during
65/// Stage 2 of the BestDose algorithm.
66///
67/// # Arguments
68///
69/// * `problem` - The [`BestDoseProblem`] containing all necessary data
70/// * `candidate_doses` - Dose amounts to evaluate (only for optimizable doses)
71///
72/// # Returns
73///
74/// The cost value `(1-λ) × Variance + λ × Bias²` for the candidate doses.
75/// Lower cost indicates better match to targets.
76///
77/// # Dose Masking
78///
79/// Only doses with `amount == 0.0` in the target subject are considered optimizable.
80/// Doses with non-zero amounts remain fixed at their specified values.
81///
82/// The `candidate_doses` parameter contains only the optimizable doses, which are
83/// substituted into the target subject before simulation
84///
85/// # Cost Function Details
86///
87/// ## Variance Term
88///
89/// Expected squared prediction error using posterior weights:
90/// ```text
91/// Variance = Σᵢ P(θᵢ|data) × Σⱼ (target[j] - pred[i,j])²
92/// ```
93///
94/// For each support point θᵢ:
95/// 1. Simulate model with candidate doses
96/// 2. Calculate squared error at each observation time j
97/// 3. Weight by posterior probability P(θᵢ|data)
98///
99/// ## Bias Term
100///
101/// Squared deviation from population mean:
102/// ```text
103/// Bias² = Σⱼ (target[j] - E[pred[j]])²
104/// where E[pred[j]] = Σᵢ P(θᵢ) × pred[i,j]  (prior weights)
105/// ```
106///
107/// The population mean uses **prior weights**, not posterior weights, to represent
108/// population-typical behavior independent of patient-specific data.
109///
110/// ## Target Types
111///
112/// - **Concentration** ([`Target::Concentration`]):
113///   Predictions are concentrations at observation times
114///
115/// - **AUC** ([`Target::AUC`]):
116///   Predictions are cumulative AUC values calculated via trapezoidal rule
117///   on a dense time grid (controlled by `settings.predictions().idelta`)
118///
119/// # Example
120///
121/// ```rust,ignore
122/// // Internal use by optimizer
123/// let cost = calculate_cost(&problem, &[100.0, 150.0])?;
124/// ```
125///
126/// # Errors
127///
128/// Returns error if:
129/// - Model simulation fails
130/// - Prediction length doesn't match observation count
131/// - AUC calculation fails (for AUC targets)
132pub fn calculate_cost(problem: &BestDoseProblem, candidate_doses: &[f64]) -> Result<f64> {
133    // Validate candidate_doses length matches expected optimizable dose count
134    let expected_optimizable = problem
135        .target
136        .occasions()
137        .iter()
138        .flat_map(|occ| occ.events())
139        .filter(|event| match event {
140            Event::Bolus(b) => b.amount() == 0.0,
141            Event::Infusion(inf) => inf.amount() == 0.0,
142            _ => false,
143        })
144        .count();
145
146    if candidate_doses.len() != expected_optimizable {
147        return Err(anyhow::anyhow!(
148            "Dose count mismatch: received {} candidate doses but expected {} optimizable doses",
149            candidate_doses.len(),
150            expected_optimizable
151        ));
152    }
153
154    // Check bounds and return penalty if violated
155    // This constrains the Nelder-Mead optimizer to search within the specified DoseRange
156    let min_dose = problem.doserange.min;
157    let max_dose = problem.doserange.max;
158
159    for &dose in candidate_doses {
160        if dose < min_dose || dose > max_dose {
161            // Return a large penalty cost to push the optimizer back into bounds
162            // The penalty grows quadratically with distance from the nearest bound
163            let violation = if dose < min_dose {
164                min_dose - dose
165            } else {
166                dose - max_dose
167            };
168            return Ok(1e12 + violation.powi(2) * 1e6);
169        }
170    }
171
172    // Build target subject with candidate doses
173    let mut target_subject = problem.target.clone();
174    let mut optimizable_dose_number = 0; // Index into candidate_doses
175
176    for occasion in target_subject.iter_mut() {
177        for event in occasion.iter_mut() {
178            match event {
179                Event::Bolus(bolus) => {
180                    // Only update if this dose is optimizable (amount == 0)
181                    if bolus.amount() == 0.0 {
182                        bolus.set_amount(candidate_doses[optimizable_dose_number]);
183                        optimizable_dose_number += 1;
184                    }
185                    // If not optimizable (amount > 0), keep original amount
186                }
187                Event::Infusion(infusion) => {
188                    // Only update if this dose is optimizable (amount == 0)
189                    if infusion.amount() == 0.0 {
190                        infusion.set_amount(candidate_doses[optimizable_dose_number]);
191                        optimizable_dose_number += 1;
192                    }
193                    // If not optimizable (amount > 0), keep original amount
194                }
195                Event::Observation(_) => {}
196            }
197        }
198    }
199
200    // Extract target values and observation times
201    let obs_times: Vec<f64> = target_subject
202        .occasions()
203        .iter()
204        .flat_map(|occ| occ.events())
205        .filter_map(|event| match event {
206            Event::Observation(obs) => Some(obs.time()),
207            _ => None,
208        })
209        .collect();
210
211    // Validate that target has observations
212    if obs_times.is_empty() {
213        return Err(anyhow::anyhow!(
214            "Target subject has no observations. At least one observation is required for dose optimization."
215        ));
216    }
217
218    let obs_vec: Vec<f64> = target_subject
219        .occasions()
220        .iter()
221        .flat_map(|occ| occ.events())
222        .filter_map(|event| match event {
223            Event::Observation(obs) => obs.value(),
224            _ => None,
225        })
226        .collect();
227
228    let n_obs = obs_vec.len();
229
230    // Accumulators
231    let mut variance = 0.0_f64; // Expected squared error E[(target - pred)²]
232    let mut y_bar = vec![0.0_f64; n_obs]; // Population mean predictions
233
234    // Calculate variance (using posterior weights) and population mean (using prior weights)
235
236    for ((row, post_prob), prior_prob) in problem
237        .theta
238        .matrix()
239        .row_iter()
240        .zip(problem.posterior.iter()) // Posterior from NPAGFULL11 (patient-specific)
241        .zip(problem.population_weights.iter())
242    // Prior (population)
243    {
244        let spp = row.iter().copied().collect::<Vec<f64>>();
245
246        // Get predictions based on target type
247        let preds_i: Vec<f64> = match problem.target_type {
248            Target::Concentration => {
249                // Simulate at observation times only
250                let pred = problem.eq.simulate_subject(&target_subject, &spp, None)?;
251                pred.0.flat_predictions()
252            }
253            Target::AUCFromZero => {
254                // For AUC: simulate at dense time grid and calculate cumulative AUC
255                let idelta = problem.settings.predictions().idelta;
256                let start_time = 0.0; // Future starts at 0
257                let end_time = obs_times.last().copied().unwrap_or(0.0);
258
259                // Generate dense time grid
260                let dense_times =
261                    calculate_dense_times(start_time, end_time, &obs_times, idelta as usize);
262
263                // Create temporary subject with dense time points for simulation
264                let subject_id = target_subject.id().to_string();
265                let mut builder = Subject::builder(&subject_id);
266
267                // Add all doses from original subject
268                for occasion in target_subject.occasions() {
269                    for event in occasion.events() {
270                        match event {
271                            Event::Bolus(bolus) => {
272                                builder =
273                                    builder.bolus(bolus.time(), bolus.amount(), bolus.input());
274                            }
275                            Event::Infusion(infusion) => {
276                                builder = builder.infusion(
277                                    infusion.time(),
278                                    infusion.amount(),
279                                    infusion.input(),
280                                    infusion.duration(),
281                                );
282                            }
283                            Event::Observation(_) => {} // Skip original observations
284                        }
285                    }
286                }
287
288                // Collect observations with (time, outeq) pairs to preserve original order
289                let obs_time_outeq: Vec<(f64, usize)> = target_subject
290                    .occasions()
291                    .iter()
292                    .flat_map(|occ| occ.events())
293                    .filter_map(|event| match event {
294                        Event::Observation(obs) => Some((obs.time(), obs.outeq())),
295                        _ => None,
296                    })
297                    .collect();
298
299                let mut unique_outeqs: Vec<usize> =
300                    obs_time_outeq.iter().map(|(_, outeq)| *outeq).collect();
301                unique_outeqs.sort();
302                unique_outeqs.dedup();
303
304                // Add observations at dense times (with dummy values for timing only)
305                for outeq in unique_outeqs.iter() {
306                    for &t in &dense_times {
307                        builder = builder.missing_observation(t, *outeq);
308                    }
309                }
310
311                let dense_subject = builder.build();
312
313                // Simulate at dense times
314                let pred = problem.eq.simulate_subject(&dense_subject, &spp, None)?;
315                let dense_predictions_with_outeq = pred.0.predictions();
316
317                // Group predictions by outeq using the Prediction struct
318                let mut outeq_predictions: std::collections::HashMap<usize, Vec<f64>> =
319                    std::collections::HashMap::new();
320
321                for prediction in dense_predictions_with_outeq {
322                    outeq_predictions
323                        .entry(prediction.outeq())
324                        .or_default()
325                        .push(prediction.prediction());
326                }
327
328                // Calculate AUC for each outeq separately
329                let mut outeq_aucs: std::collections::HashMap<usize, Vec<f64>> =
330                    std::collections::HashMap::new();
331
332                for &outeq in unique_outeqs.iter() {
333                    let outeq_preds = outeq_predictions.get(&outeq).ok_or_else(|| {
334                        anyhow::anyhow!("Missing predictions for outeq {}", outeq)
335                    })?;
336
337                    // Get observation times for this outeq only
338                    let outeq_obs_times: Vec<f64> = obs_time_outeq
339                        .iter()
340                        .filter(|(_, o)| *o == outeq)
341                        .map(|(t, _)| *t)
342                        .collect();
343
344                    // Calculate AUC at observation times for this outeq
345                    let aucs = calculate_auc_at_times(&dense_times, outeq_preds, &outeq_obs_times);
346                    outeq_aucs.insert(outeq, aucs);
347                }
348
349                // Build final AUC vector in original observation order
350                let mut result_aucs = Vec::with_capacity(obs_time_outeq.len());
351                let mut outeq_counters: std::collections::HashMap<usize, usize> =
352                    std::collections::HashMap::new();
353
354                for (_, outeq) in obs_time_outeq.iter() {
355                    let aucs = outeq_aucs
356                        .get(outeq)
357                        .ok_or_else(|| anyhow::anyhow!("Missing AUC for outeq {}", outeq))?;
358
359                    let counter = outeq_counters.entry(*outeq).or_insert(0);
360                    if *counter < aucs.len() {
361                        result_aucs.push(aucs[*counter]);
362                        *counter += 1;
363                    } else {
364                        return Err(anyhow::anyhow!(
365                            "AUC index out of bounds for outeq {}",
366                            outeq
367                        ));
368                    }
369                }
370
371                result_aucs
372            }
373            Target::AUCFromLastDose => {
374                // For interval AUC: simulate at dense time grid and calculate AUC from last dose
375                let idelta = problem.settings.predictions().idelta;
376                let end_time = obs_times.last().copied().unwrap_or(0.0);
377
378                // Generate dense time grid from 0 to end_time (need full grid for intervals)
379                let dense_times = calculate_dense_times(0.0, end_time, &obs_times, idelta as usize);
380
381                // Create temporary subject with dense time points for simulation
382                let subject_id = target_subject.id().to_string();
383                let mut builder = Subject::builder(&subject_id);
384
385                // Add all doses from original subject
386                for occasion in target_subject.occasions() {
387                    for event in occasion.events() {
388                        match event {
389                            Event::Bolus(bolus) => {
390                                builder =
391                                    builder.bolus(bolus.time(), bolus.amount(), bolus.input());
392                            }
393                            Event::Infusion(infusion) => {
394                                builder = builder.infusion(
395                                    infusion.time(),
396                                    infusion.amount(),
397                                    infusion.input(),
398                                    infusion.duration(),
399                                );
400                            }
401                            Event::Observation(_) => {} // Skip original observations
402                        }
403                    }
404                }
405
406                // Collect observations with (time, outeq) pairs to preserve original order
407                let obs_time_outeq: Vec<(f64, usize)> = target_subject
408                    .occasions()
409                    .iter()
410                    .flat_map(|occ| occ.events())
411                    .filter_map(|event| match event {
412                        Event::Observation(obs) => Some((obs.time(), obs.outeq())),
413                        _ => None,
414                    })
415                    .collect();
416
417                let mut unique_outeqs: Vec<usize> =
418                    obs_time_outeq.iter().map(|(_, outeq)| *outeq).collect();
419                unique_outeqs.sort();
420                unique_outeqs.dedup();
421
422                // Add observations at dense times
423                for outeq in unique_outeqs.iter() {
424                    for &t in &dense_times {
425                        builder = builder.missing_observation(t, *outeq);
426                    }
427                }
428
429                let dense_subject = builder.build();
430
431                // Simulate at dense times
432                let pred = problem.eq.simulate_subject(&dense_subject, &spp, None)?;
433                let dense_predictions_with_outeq = pred.0.predictions();
434
435                // Group predictions by outeq
436                let mut outeq_predictions: std::collections::HashMap<usize, Vec<f64>> =
437                    std::collections::HashMap::new();
438
439                for prediction in dense_predictions_with_outeq {
440                    outeq_predictions
441                        .entry(prediction.outeq())
442                        .or_default()
443                        .push(prediction.prediction());
444                }
445
446                // Calculate interval AUC for each outeq separately
447                let mut outeq_aucs: std::collections::HashMap<usize, Vec<f64>> =
448                    std::collections::HashMap::new();
449
450                for &outeq in unique_outeqs.iter() {
451                    let outeq_preds = outeq_predictions.get(&outeq).ok_or_else(|| {
452                        anyhow::anyhow!("Missing predictions for outeq {}", outeq)
453                    })?;
454
455                    // Get observation times for this outeq only
456                    let outeq_obs_times: Vec<f64> = obs_time_outeq
457                        .iter()
458                        .filter(|(_, o)| *o == outeq)
459                        .map(|(t, _)| *t)
460                        .collect();
461
462                    // Calculate interval AUC at observation times for this outeq
463                    let aucs = calculate_interval_auc_per_observation(
464                        &target_subject,
465                        &dense_times,
466                        outeq_preds,
467                        &outeq_obs_times,
468                    );
469                    outeq_aucs.insert(outeq, aucs);
470                }
471
472                // Build final AUC vector in original observation order
473                let mut result_aucs = Vec::with_capacity(obs_time_outeq.len());
474                let mut outeq_counters: std::collections::HashMap<usize, usize> =
475                    std::collections::HashMap::new();
476
477                for (_, outeq) in obs_time_outeq.iter() {
478                    let aucs = outeq_aucs
479                        .get(outeq)
480                        .ok_or_else(|| anyhow::anyhow!("Missing AUC for outeq {}", outeq))?;
481
482                    let counter = outeq_counters.entry(*outeq).or_insert(0);
483                    if *counter < aucs.len() {
484                        result_aucs.push(aucs[*counter]);
485                        *counter += 1;
486                    } else {
487                        return Err(anyhow::anyhow!(
488                            "AUC index out of bounds for outeq {}",
489                            outeq
490                        ));
491                    }
492                }
493
494                result_aucs
495            }
496        };
497
498        if preds_i.len() != n_obs {
499            return Err(anyhow::anyhow!(
500                "prediction length ({}) != observation length ({})",
501                preds_i.len(),
502                n_obs
503            ));
504        }
505
506        // Calculate variance term: weighted by POSTERIOR probability
507        let mut sumsq_i = 0.0_f64;
508        for (j, &obs_val) in obs_vec.iter().enumerate() {
509            let pj = preds_i[j];
510            let se = (obs_val - pj).powi(2);
511            sumsq_i += se;
512            // Calculate population mean using PRIOR probabilities
513            y_bar[j] += prior_prob * pj;
514        }
515
516        variance += post_prob * sumsq_i; // Weighted by posterior
517    }
518
519    // Calculate bias term: squared difference from population mean
520    let mut bias = 0.0_f64;
521    for (j, &obs_val) in obs_vec.iter().enumerate() {
522        bias += (obs_val - y_bar[j]).powi(2);
523    }
524
525    // Final cost: (1-λ)×Variance + λ×Bias²
526    // λ=0: Full personalization (minimize variance)
527    // λ=1: Population-based (minimize bias from population)
528    let cost = (1.0 - problem.bias_weight) * variance + problem.bias_weight * bias;
529
530    Ok(cost)
531}