pmcore/bestdose/
predictions.rs

1//! Stage 3: Prediction calculations
2//!
3//! Handles final prediction calculations with optimal doses, including:
4//! - Dense time grid generation for AUC calculations
5//! - Trapezoidal AUC integration
6//! - Concentration-time predictions
7//!
8//! # AUC Calculation Method
9//!
//! For [`Target::AUCFromZero`](crate::bestdose::Target::AUCFromZero) and [`Target::AUCFromLastDose`](crate::bestdose::Target::AUCFromLastDose) targets:
11//!
12//! 1. **Dense Time Grid**: Generate points at `idelta` intervals plus observation times
13//! 2. **Simulation**: Run model at all dense time points
14//! 3. **Trapezoidal Integration**: Calculate cumulative AUC:
15//!    ```text
16//!    AUC(t) = Σᵢ₌₁ⁿ (C[i] + C[i-1])/2 × (t[i] - t[i-1])
17//!    ```
18//! 4. **Extraction**: Extract AUC values at target observation times
19//!
20//! # Key Functions
21//!
22//! - [`calculate_dense_times`]: Generate time grid for numerical integration
23//! - [`calculate_auc_at_times`]: Trapezoidal AUC calculation
24//! - [`calculate_final_predictions`]: Final predictions with optimal doses
25//!
26//! # See Also
27//!
28//! - Configuration: `settings.predictions().idelta` controls time grid resolution
29
30use anyhow::Result;
31use faer::Mat;
32
33use crate::bestdose::types::{BestDoseProblem, Target};
34use crate::routines::output::posterior::Posterior;
35use crate::routines::output::predictions::NPPredictions;
36use crate::structs::weights::Weights;
37use pharmsol::prelude::*;
38use pharmsol::Equation;
39
40/// Find the time of the last dose (bolus or infusion) before a given observation time
41///
42/// Returns the time of the most recent dose event that occurred before `obs_time`.
43/// If no dose exists before the observation time, returns 0.0.
44///
45/// # Arguments
46/// * `subject` - Subject containing dose events
47/// * `obs_time` - Observation time to find the preceding dose for
48///
49/// # Returns
50/// Time of the last dose before `obs_time`, or 0.0 if none exists
51///
52/// # Example
53/// ```rust,ignore
54/// let subject = Subject::builder("patient")
55///     .bolus(0.0, 100.0, 0)
56///     .bolus(12.0, 50.0, 0)
57///     .observation(18.0, 5.0, 0)
58///     .build();
59///
60/// let last_dose_time = find_last_dose_time_before(&subject, 18.0);
61/// assert_eq!(last_dose_time, 12.0);
62/// ```
63pub fn find_last_dose_time_before(subject: &Subject, obs_time: f64) -> f64 {
64    let mut last_dose_time = 0.0;
65
66    for occasion in subject.occasions() {
67        for event in occasion.events() {
68            let event_time = match event {
69                Event::Bolus(b) => Some(b.time()),
70                Event::Infusion(i) => Some(i.time()),
71                Event::Observation(_) => None,
72            };
73
74            if let Some(t) = event_time {
75                if t < obs_time && t > last_dose_time {
76                    last_dose_time = t;
77                }
78            }
79        }
80    }
81
82    last_dose_time
83}
84
/// Generate dense time grid for AUC calculations
///
/// Creates a grid with:
/// - Observation times from the target
/// - Regular points `start_time + k * idelta` (for `k = 0, 1, 2, ...`) up to `end_time`
/// - `end_time` itself
/// - All times sorted ascending and deduplicated within a tolerance of `1e-10`
///
/// Grid points are computed by multiplication rather than by repeatedly
/// adding the step, so rounding error does not accumulate along the grid.
///
/// # Arguments
/// * `start_time` - Start of time range
/// * `end_time` - End of time range
/// * `obs_times` - Required observation times (always included)
/// * `idelta` - Time step for dense grid (minutes); converted to hours internally
///
/// # Returns
/// Sorted, unique time vector suitable for AUC calculation
///
/// # Edge cases
/// * `idelta == 0` produces no intermediate points: only `start_time`,
///   `end_time` and the observation times are returned.
/// * If `start_time > end_time`, or either bound is not finite, the regular
///   grid is skipped.
/// * NaN times never panic; they are dropped by the deduplication pass.
pub fn calculate_dense_times(
    start_time: f64,
    end_time: f64,
    obs_times: &[f64],
    idelta: usize,
) -> Vec<f64> {
    // Two times closer together than this are treated as the same grid point.
    const TOLERANCE: f64 = 1e-10;

    // `idelta` is expressed in minutes, the grid in hours.
    let step_hours = (idelta as f64) / 60.0;

    let mut times = Vec::with_capacity(obs_times.len() + 2);
    times.extend_from_slice(obs_times);

    // Regular grid. The guards keep the generator finite: a zero step or a
    // non-finite bound would otherwise never reach the stopping condition.
    if start_time.is_finite() && end_time.is_finite() && start_time <= end_time {
        if step_hours > 0.0 {
            times.extend(
                (0u64..)
                    .map(|k| start_time + (k as f64) * step_hours)
                    .take_while(|&t| t <= end_time),
            );
        } else {
            times.push(start_time);
        }
    }

    // Ensure end time is included
    if !times.contains(&end_time) {
        times.push(end_time);
    }

    // `total_cmp` is a total order, so NaN cannot cause a panic here.
    times.sort_by(|a, b| a.total_cmp(b));

    // Remove duplicates with tolerance. Comparisons involving NaN are false,
    // so NaN entries are never pushed.
    let mut unique_times = Vec::with_capacity(times.len());
    let mut last_time = f64::NEG_INFINITY;
    for &t in &times {
        if (t - last_time).abs() > TOLERANCE {
            unique_times.push(t);
            last_time = t;
        }
    }

    unique_times
}
141
/// Calculate cumulative AUC at target times using trapezoidal rule
///
/// Takes dense concentration predictions and calculates cumulative AUC
/// from the first time point. AUC values at target observation times
/// are extracted and returned.
///
/// A target that coincides with the first grid point receives an AUC of
/// `0.0`, and repeated target times each receive the same cumulative value.
///
/// # Arguments
/// * `dense_times` - Dense time grid in ascending order (must include all `target_times`)
/// * `dense_predictions` - Concentration predictions at `dense_times`
/// * `target_times` - Observation times where AUC should be extracted, in ascending order
///
/// # Returns
/// Vector of AUC values at `target_times`. Targets are matched to grid points
/// within a tolerance of `1e-10`; a target that is absent from the grid (or
/// out of order) stops the matching, so the result is then shorter than
/// `target_times`.
///
/// # Panics
/// Panics if `dense_times` and `dense_predictions` differ in length.
pub fn calculate_auc_at_times(
    dense_times: &[f64],
    dense_predictions: &[f64],
    target_times: &[f64],
) -> Vec<f64> {
    assert_eq!(dense_times.len(), dense_predictions.len());

    // Grid points and targets closer than this are considered identical.
    const TOLERANCE: f64 = 1e-10;

    let mut target_aucs = Vec::with_capacity(target_times.len());
    let mut pending_targets = target_times.iter().copied().peekable();
    let mut auc = 0.0;

    for (i, &time) in dense_times.iter().enumerate() {
        // Update cumulative AUC using trapezoidal rule. The first grid point
        // has no preceding interval, so its cumulative AUC stays 0.
        if i > 0 {
            let dt = time - dense_times[i - 1];
            let avg_conc = (dense_predictions[i] + dense_predictions[i - 1]) / 2.0;
            auc += avg_conc * dt;
        }

        // Emit the current AUC for every pending target located at this grid
        // point; a loop (not a single check) so duplicated targets all match.
        while pending_targets
            .next_if(|&target| (time - target).abs() < TOLERANCE)
            .is_some()
        {
            target_aucs.push(auc);
        }
    }

    target_aucs
}
184
185/// Calculate interval AUC for each observation independently
186///
187/// For each observation at time t_i, calculates AUC from the last dose before t_i to t_i.
188/// This is useful for calculating dosing interval AUC (AUCτ) in steady-state scenarios.
189///
190/// # Arguments
191/// * `subject` - Subject with doses and observations
192/// * `dense_times` - Complete dense time grid covering all observations
193/// * `dense_predictions` - Concentration predictions at `dense_times`
194/// * `obs_times` - Observation times where interval AUC should be calculated
195///
196/// # Returns
197/// Vector of interval AUC values, one per observation
198///
199/// # Algorithm
200///
201/// For each observation time:
202/// 1. Find the most recent dose (bolus or infusion) before that observation
203/// 2. Locate that dose time in the dense grid
204/// 3. Apply trapezoidal rule from dose time to observation time
205/// 4. Return the interval AUC
206///
207/// # Example
208///
209/// ```rust,ignore
210/// let subject = Subject::builder("patient")
211///     .bolus(0.0, 100.0, 0)      // First dose
212///     .bolus(12.0, 100.0, 0)     // Second dose
213///     .observation(24.0, 200.0, 0)  // Want AUC from t=12 to t=24
214///     .build();
215///
216/// // Dense grid from 0 to 24 hours
217/// let dense_times = vec![0.0, 1.0, 2.0, ..., 24.0];
218/// let dense_predictions = simulate_at_dense_times(...);
219/// let obs_times = vec![24.0];
220///
221/// let interval_aucs = calculate_interval_auc_per_observation(
222///     &subject, &dense_times, &dense_predictions, &obs_times
223/// );
224/// // interval_aucs[0] contains AUC from 12.0 to 24.0
225/// ```
226pub fn calculate_interval_auc_per_observation(
227    subject: &Subject,
228    dense_times: &[f64],
229    dense_predictions: &[f64],
230    obs_times: &[f64],
231) -> Vec<f64> {
232    assert_eq!(dense_times.len(), dense_predictions.len());
233
234    let mut interval_aucs = Vec::with_capacity(obs_times.len());
235    let tolerance = 1e-10;
236
237    for &obs_time in obs_times {
238        // Find the last dose time before this observation
239        let last_dose_time = find_last_dose_time_before(subject, obs_time);
240
241        // Find indices in dense_times that span [last_dose_time, obs_time]
242        let start_idx = dense_times
243            .iter()
244            .position(|&t| (t - last_dose_time).abs() < tolerance || t > last_dose_time)
245            .unwrap_or(0);
246
247        let end_idx = dense_times
248            .iter()
249            .position(|&t| (t - obs_time).abs() < tolerance || t > obs_time)
250            .unwrap_or(dense_times.len() - 1);
251
252        // Calculate AUC for this interval using trapezoidal rule
253        let mut auc = 0.0;
254        for i in (start_idx + 1)..=end_idx.min(dense_times.len() - 1) {
255            let dt = dense_times[i] - dense_times[i - 1];
256            let avg_conc = (dense_predictions[i] + dense_predictions[i - 1]) / 2.0;
257            auc += avg_conc * dt;
258        }
259
260        interval_aucs.push(auc);
261    }
262
263    interval_aucs
264}
265
266/// Calculate predictions for optimal doses
267///
268/// This generates the final NPPredictions structure with the optimal doses
269/// and appropriate weights (posterior or uniform depending on which optimization won).
270pub fn calculate_final_predictions(
271    problem: &BestDoseProblem,
272    optimal_doses: &[f64],
273    weights: &Weights,
274) -> Result<(NPPredictions, Option<Vec<(f64, f64)>>)> {
275    // Validate optimal_doses length matches total dose count (fixed + optimizable)
276    let expected_total_doses = problem
277        .target
278        .occasions()
279        .iter()
280        .flat_map(|occ| occ.events())
281        .filter(|event| matches!(event, Event::Bolus(_) | Event::Infusion(_)))
282        .count();
283
284    if optimal_doses.len() != expected_total_doses {
285        return Err(anyhow::anyhow!(
286            "Dose count mismatch in predictions: received {} optimal doses but expected {} total doses",
287            optimal_doses.len(),
288            expected_total_doses
289        ));
290    }
291
292    // Build subject with optimal doses
293    let mut target_with_optimal = problem.target.clone();
294    let mut dose_number = 0;
295
296    for occasion in target_with_optimal.iter_mut() {
297        for event in occasion.iter_mut() {
298            match event {
299                Event::Bolus(bolus) => {
300                    bolus.set_amount(optimal_doses[dose_number]);
301                    dose_number += 1;
302                }
303                Event::Infusion(infusion) => {
304                    infusion.set_amount(optimal_doses[dose_number]);
305                    dose_number += 1;
306                }
307                Event::Observation(_) => {}
308            }
309        }
310    }
311
312    // Create posterior matrix for predictions
313    let posterior_matrix = Mat::from_fn(1, weights.weights().nrows(), |_row, col| {
314        *weights.weights().get(col)
315    });
316    let posterior = Posterior::from(posterior_matrix);
317
318    // Calculate concentration predictions
319    let concentration_preds = NPPredictions::calculate(
320        &problem.eq,
321        &Data::new(vec![target_with_optimal.clone()]),
322        problem.theta.clone(),
323        weights,
324        &posterior,
325        0.0,
326        0.0,
327    )?;
328
329    // Calculate AUC predictions if in AUC mode
330    let auc_predictions = match problem.target_type {
331        Target::Concentration => None,
332        Target::AUCFromZero | Target::AUCFromLastDose => {
333            let obs_times: Vec<f64> = target_with_optimal
334                .occasions()
335                .iter()
336                .flat_map(|occ| occ.events())
337                .filter_map(|event| match event {
338                    Event::Observation(obs) => Some(obs.time()),
339                    _ => None,
340                })
341                .collect();
342
343            let idelta = problem.settings.predictions().idelta;
344            let start_time = 0.0;
345            let end_time = obs_times.last().copied().unwrap_or(0.0);
346            let dense_times =
347                calculate_dense_times(start_time, end_time, &obs_times, idelta as usize);
348
349            let subject_id = target_with_optimal.id().to_string();
350            let mut builder = Subject::builder(&subject_id);
351
352            // Copy all dose events from target_with_optimal (which already has optimal doses set)
353            for occasion in target_with_optimal.occasions() {
354                for event in occasion.events() {
355                    match event {
356                        Event::Bolus(bolus) => {
357                            builder = builder.bolus(bolus.time(), bolus.amount(), bolus.input());
358                        }
359                        Event::Infusion(infusion) => {
360                            builder = builder.infusion(
361                                infusion.time(),
362                                infusion.amount(),
363                                infusion.input(),
364                                infusion.duration(),
365                            );
366                        }
367                        Event::Observation(_) => {}
368                    }
369                }
370            }
371
372            // Collect observations with (time, outeq) pairs to preserve original order
373            let obs_time_outeq: Vec<(f64, usize)> = target_with_optimal
374                .occasions()
375                .iter()
376                .flat_map(|occ| occ.events())
377                .filter_map(|event| match event {
378                    Event::Observation(obs) => Some((obs.time(), obs.outeq())),
379                    _ => None,
380                })
381                .collect();
382
383            let mut unique_outeqs: Vec<usize> =
384                obs_time_outeq.iter().map(|(_, outeq)| *outeq).collect();
385            unique_outeqs.sort_unstable();
386            unique_outeqs.dedup();
387
388            // Add observations at dense times for each outeq
389            for outeq in unique_outeqs.iter() {
390                for &t in &dense_times {
391                    builder = builder.missing_observation(t, *outeq);
392                }
393            }
394
395            let dense_subject = builder.build();
396
397            // Initialize AUC storage per outeq
398            let mut outeq_mean_aucs: std::collections::HashMap<usize, Vec<f64>> =
399                std::collections::HashMap::new();
400            for outeq in unique_outeqs.iter() {
401                let outeq_obs_times: Vec<f64> = obs_time_outeq
402                    .iter()
403                    .filter(|(_, o)| *o == *outeq)
404                    .map(|(t, _)| *t)
405                    .collect();
406                outeq_mean_aucs.insert(*outeq, vec![0.0; outeq_obs_times.len()]);
407            }
408
409            // Calculate AUC for each support point and accumulate weighted means
410            for (row, weight) in problem.theta.matrix().row_iter().zip(weights.iter()) {
411                let spp = row.iter().copied().collect::<Vec<f64>>();
412                let pred = problem.eq.simulate_subject(&dense_subject, &spp, None)?;
413                let dense_predictions_with_outeq = pred.0.predictions();
414
415                // Group predictions by outeq
416                let mut outeq_predictions: std::collections::HashMap<usize, Vec<f64>> =
417                    std::collections::HashMap::new();
418
419                for prediction in dense_predictions_with_outeq {
420                    outeq_predictions
421                        .entry(prediction.outeq())
422                        .or_default()
423                        .push(prediction.prediction());
424                }
425
426                // Calculate AUC for each outeq separately based on mode
427                for &outeq in unique_outeqs.iter() {
428                    let outeq_preds = outeq_predictions.get(&outeq).ok_or_else(|| {
429                        anyhow::anyhow!("Missing predictions for outeq {}", outeq)
430                    })?;
431
432                    // Get observation times for this outeq only
433                    let outeq_obs_times: Vec<f64> = obs_time_outeq
434                        .iter()
435                        .filter(|(_, o)| *o == outeq)
436                        .map(|(t, _)| *t)
437                        .collect();
438
439                    // Calculate AUC at observation times for this outeq
440                    let aucs = match problem.target_type {
441                        Target::AUCFromZero => {
442                            calculate_auc_at_times(&dense_times, outeq_preds, &outeq_obs_times)
443                        }
444                        Target::AUCFromLastDose => calculate_interval_auc_per_observation(
445                            &target_with_optimal,
446                            &dense_times,
447                            outeq_preds,
448                            &outeq_obs_times,
449                        ),
450                        Target::Concentration => unreachable!(),
451                    };
452
453                    // Accumulate weighted AUCs
454                    let mean_aucs = outeq_mean_aucs.get_mut(&outeq).unwrap();
455                    for (i, &auc) in aucs.iter().enumerate() {
456                        mean_aucs[i] += weight * auc;
457                    }
458                }
459            }
460
461            // Build final AUC vector in original observation order
462            let mut result_aucs = Vec::with_capacity(obs_time_outeq.len());
463            let mut outeq_counters: std::collections::HashMap<usize, usize> =
464                std::collections::HashMap::new();
465
466            for (_, outeq) in obs_time_outeq.iter() {
467                let aucs = outeq_mean_aucs
468                    .get(outeq)
469                    .ok_or_else(|| anyhow::anyhow!("Missing AUC for outeq {}", outeq))?;
470
471                let counter = outeq_counters.entry(*outeq).or_insert(0);
472                if *counter < aucs.len() {
473                    result_aucs.push(aucs[*counter]);
474                    *counter += 1;
475                } else {
476                    return Err(anyhow::anyhow!(
477                        "AUC index out of bounds for outeq {}",
478                        outeq
479                    ));
480                }
481            }
482
483            Some(
484                obs_time_outeq
485                    .iter()
486                    .map(|(t, _)| *t)
487                    .zip(result_aucs)
488                    .collect(),
489            )
490        }
491    };
492
493    Ok((concentration_preds, auc_predictions))
494}