pmcore/bestdose/predictions.rs
1//! Stage 3: Prediction calculations
2//!
3//! Handles final prediction calculations with optimal doses, including:
4//! - Dense time grid generation for AUC calculations
5//! - Trapezoidal AUC integration
6//! - Concentration-time predictions
7//!
8//! # AUC Calculation Method
9//!
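//! For the AUC targets ([`Target::AUCFromZero`](crate::bestdose::Target::AUCFromZero) and
//! [`Target::AUCFromLastDose`](crate::bestdose::Target::AUCFromLastDose)):
//!
//! 1. **Dense Time Grid**: Generate points every `idelta` minutes, plus the observation times
//! 2. **Simulation**: Run the model at all dense time points
//! 3. **Trapezoidal Integration**: Accumulate the AUC over consecutive grid points, starting
//!    from time 0 (`AUCFromZero`) or from the last dose before each observation (`AUCFromLastDose`):
//! ```text
//! AUC(t[n]) = Σᵢ₌₁ⁿ (C[i] + C[i-1])/2 × (t[i] - t[i-1])
//! ```
//! 4. **Extraction**: Extract AUC values at target observation times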
19//!
20//! # Key Functions
21//!
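//! - [`calculate_dense_times`]: Generate time grid for numerical integration
//! - [`calculate_auc_at_times`]: Cumulative trapezoidal AUC from the start of the grid
//! - [`calculate_interval_auc_per_observation`]: Trapezoidal AUC from the last dose before each observation
//! - [`find_last_dose_time_before`]: Time of the most recent dose before an observation
//! - [`calculate_final_predictions`]: Final predictions with optimal doses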
25//!
26//! # See Also
27//!
//! - Configuration: `settings.predictions().idelta` (in minutes) controls time grid resolution
29
30use anyhow::Result;
31use faer::Mat;
32
33use crate::bestdose::types::{BestDoseProblem, Target};
34use crate::routines::output::posterior::Posterior;
35use crate::routines::output::predictions::NPPredictions;
36use crate::structs::weights::Weights;
37use pharmsol::prelude::*;
38use pharmsol::Equation;
39
40/// Find the time of the last dose (bolus or infusion) before a given observation time
41///
42/// Returns the time of the most recent dose event that occurred before `obs_time`.
43/// If no dose exists before the observation time, returns 0.0.
44///
45/// # Arguments
46/// * `subject` - Subject containing dose events
47/// * `obs_time` - Observation time to find the preceding dose for
48///
49/// # Returns
50/// Time of the last dose before `obs_time`, or 0.0 if none exists
51///
52/// # Example
53/// ```rust,ignore
54/// let subject = Subject::builder("patient")
55/// .bolus(0.0, 100.0, 0)
56/// .bolus(12.0, 50.0, 0)
57/// .observation(18.0, 5.0, 0)
58/// .build();
59///
60/// let last_dose_time = find_last_dose_time_before(&subject, 18.0);
61/// assert_eq!(last_dose_time, 12.0);
62/// ```
63pub fn find_last_dose_time_before(subject: &Subject, obs_time: f64) -> f64 {
64 let mut last_dose_time = 0.0;
65
66 for occasion in subject.occasions() {
67 for event in occasion.events() {
68 let event_time = match event {
69 Event::Bolus(b) => Some(b.time()),
70 Event::Infusion(i) => Some(i.time()),
71 Event::Observation(_) => None,
72 };
73
74 if let Some(t) = event_time {
75 if t < obs_time && t > last_dose_time {
76 last_dose_time = t;
77 }
78 }
79 }
80 }
81
82 last_dose_time
83}
84
/// Generate dense time grid for AUC calculations
///
/// Creates a sorted grid (in hours) containing:
/// - every time in `obs_times` (always included, even outside `[start_time, end_time]`)
/// - the regular points `start_time + k * idelta / 60` (k = 0, 1, 2, ...) that are `<= end_time`
/// - `end_time` itself
///
/// Times closer than `1e-10` to the previously kept time are treated as duplicates and dropped,
/// keeping the smallest time of each cluster.
///
/// # Arguments
/// * `start_time` - Start of time range (hours)
/// * `end_time` - End of time range (hours)
/// * `obs_times` - Required observation times (always included)
/// * `idelta` - Time step for dense grid (minutes). With `0`, or a non-finite range, only
///   `start_time` (when `start_time <= end_time`), `end_time` and `obs_times` are included.
///
/// # Returns
/// Sorted, unique time vector suitable for AUC calculation
pub fn calculate_dense_times(
    start_time: f64,
    end_time: f64,
    obs_times: &[f64],
    idelta: usize,
) -> Vec<f64> {
    const TOLERANCE: f64 = 1e-10;
    let idelta_hours = (idelta as f64) / 60.0;

    let mut times = Vec::with_capacity(obs_times.len() + 2);
    times.extend_from_slice(obs_times);

    if start_time <= end_time {
        if idelta_hours > 0.0 && start_time.is_finite() && end_time.is_finite() {
            // Compute each point from its index instead of by repeated addition so rounding
            // error does not accumulate along long grids.
            times.extend(
                (0_u64..)
                    .map(|k| start_time + k as f64 * idelta_hours)
                    .take_while(|&t| t <= end_time),
            );
        } else {
            // A zero step (or an infinite range) would never reach `end_time`; fall back to
            // the start point only instead of looping forever.
            times.push(start_time);
        }
    }

    // The end of the range is always part of the grid; exact duplicates are removed below.
    times.push(end_time);

    // Total ordering: unlike `partial_cmp(..).unwrap()` this cannot panic on NaN.
    times.sort_by(f64::total_cmp);

    // Remove duplicates with tolerance, comparing each time with the last one kept.
    times.dedup_by(|current, kept| (*current - *kept).abs() <= TOLERANCE);

    times
}
141
/// Calculate cumulative AUC at target times using trapezoidal rule
///
/// Integrates `dense_predictions` over `dense_times` with the trapezoidal rule, starting at
/// `dense_times[0]`, and returns the cumulative AUC at each of `target_times`.
///
/// Each target is matched independently to the grid point within `1e-10` of it:
/// - a target at the first grid point gets an AUC of 0.0;
/// - repeated targets each get a value.
///
/// A target with no matching grid point is skipped. The result is therefore shorter than
/// `target_times` only when the grid does not contain every target, which violates the
/// precondition below.
///
/// # Arguments
/// * `dense_times` - Dense time grid, sorted ascending (must include all `target_times`)
/// * `dense_predictions` - Concentration predictions at `dense_times`
/// * `target_times` - Observation times where AUC should be extracted
///
/// # Returns
/// Vector of AUC values at `target_times`, in the order of `target_times`
///
/// # Panics
/// If `dense_times` and `dense_predictions` have different lengths.
pub fn calculate_auc_at_times(
    dense_times: &[f64],
    dense_predictions: &[f64],
    target_times: &[f64],
) -> Vec<f64> {
    assert_eq!(dense_times.len(), dense_predictions.len());
    const TOLERANCE: f64 = 1e-10;

    // cumulative[i] = AUC from dense_times[0] to dense_times[i]; cumulative[0] = 0.0
    let cumulative: Vec<f64> = std::iter::once(0.0)
        .chain(
            dense_times
                .windows(2)
                .zip(dense_predictions.windows(2))
                .scan(0.0, |auc, (t, c)| {
                    *auc += (c[1] + c[0]) / 2.0 * (t[1] - t[0]);
                    Some(*auc)
                }),
        )
        .collect();

    target_times
        .iter()
        .filter_map(|&target| {
            // First grid point above `target - TOLERANCE`; it matches if it is within tolerance.
            let idx = dense_times.partition_point(|&t| t <= target - TOLERANCE);
            dense_times
                .get(idx)
                .filter(|&&t| (t - target).abs() < TOLERANCE)
                .map(|_| cumulative[idx])
        })
        .collect()
}
184
185/// Calculate interval AUC for each observation independently
186///
187/// For each observation at time t_i, calculates AUC from the last dose before t_i to t_i.
188/// This is useful for calculating dosing interval AUC (AUCτ) in steady-state scenarios.
189///
190/// # Arguments
191/// * `subject` - Subject with doses and observations
192/// * `dense_times` - Complete dense time grid covering all observations
193/// * `dense_predictions` - Concentration predictions at `dense_times`
194/// * `obs_times` - Observation times where interval AUC should be calculated
195///
196/// # Returns
197/// Vector of interval AUC values, one per observation
198///
199/// # Algorithm
200///
201/// For each observation time:
202/// 1. Find the most recent dose (bolus or infusion) before that observation
203/// 2. Locate that dose time in the dense grid
204/// 3. Apply trapezoidal rule from dose time to observation time
205/// 4. Return the interval AUC
206///
207/// # Example
208///
209/// ```rust,ignore
210/// let subject = Subject::builder("patient")
211/// .bolus(0.0, 100.0, 0) // First dose
212/// .bolus(12.0, 100.0, 0) // Second dose
213/// .observation(24.0, 200.0, 0) // Want AUC from t=12 to t=24
214/// .build();
215///
216/// // Dense grid from 0 to 24 hours
217/// let dense_times = vec![0.0, 1.0, 2.0, ..., 24.0];
218/// let dense_predictions = simulate_at_dense_times(...);
219/// let obs_times = vec![24.0];
220///
221/// let interval_aucs = calculate_interval_auc_per_observation(
222/// &subject, &dense_times, &dense_predictions, &obs_times
223/// );
224/// // interval_aucs[0] contains AUC from 12.0 to 24.0
225/// ```
226pub fn calculate_interval_auc_per_observation(
227 subject: &Subject,
228 dense_times: &[f64],
229 dense_predictions: &[f64],
230 obs_times: &[f64],
231) -> Vec<f64> {
232 assert_eq!(dense_times.len(), dense_predictions.len());
233
234 let mut interval_aucs = Vec::with_capacity(obs_times.len());
235 let tolerance = 1e-10;
236
237 for &obs_time in obs_times {
238 // Find the last dose time before this observation
239 let last_dose_time = find_last_dose_time_before(subject, obs_time);
240
241 // Find indices in dense_times that span [last_dose_time, obs_time]
242 let start_idx = dense_times
243 .iter()
244 .position(|&t| (t - last_dose_time).abs() < tolerance || t > last_dose_time)
245 .unwrap_or(0);
246
247 let end_idx = dense_times
248 .iter()
249 .position(|&t| (t - obs_time).abs() < tolerance || t > obs_time)
250 .unwrap_or(dense_times.len() - 1);
251
252 // Calculate AUC for this interval using trapezoidal rule
253 let mut auc = 0.0;
254 for i in (start_idx + 1)..=end_idx.min(dense_times.len() - 1) {
255 let dt = dense_times[i] - dense_times[i - 1];
256 let avg_conc = (dense_predictions[i] + dense_predictions[i - 1]) / 2.0;
257 auc += avg_conc * dt;
258 }
259
260 interval_aucs.push(auc);
261 }
262
263 interval_aucs
264}
265
266/// Calculate predictions for optimal doses
267///
268/// This generates the final NPPredictions structure with the optimal doses
269/// and appropriate weights (posterior or uniform depending on which optimization won).
270pub fn calculate_final_predictions(
271 problem: &BestDoseProblem,
272 optimal_doses: &[f64],
273 weights: &Weights,
274) -> Result<(NPPredictions, Option<Vec<(f64, f64)>>)> {
275 // Validate optimal_doses length matches total dose count (fixed + optimizable)
276 let expected_total_doses = problem
277 .target
278 .occasions()
279 .iter()
280 .flat_map(|occ| occ.events())
281 .filter(|event| matches!(event, Event::Bolus(_) | Event::Infusion(_)))
282 .count();
283
284 if optimal_doses.len() != expected_total_doses {
285 return Err(anyhow::anyhow!(
286 "Dose count mismatch in predictions: received {} optimal doses but expected {} total doses",
287 optimal_doses.len(),
288 expected_total_doses
289 ));
290 }
291
292 // Build subject with optimal doses
293 let mut target_with_optimal = problem.target.clone();
294 let mut dose_number = 0;
295
296 for occasion in target_with_optimal.iter_mut() {
297 for event in occasion.iter_mut() {
298 match event {
299 Event::Bolus(bolus) => {
300 bolus.set_amount(optimal_doses[dose_number]);
301 dose_number += 1;
302 }
303 Event::Infusion(infusion) => {
304 infusion.set_amount(optimal_doses[dose_number]);
305 dose_number += 1;
306 }
307 Event::Observation(_) => {}
308 }
309 }
310 }
311
312 // Create posterior matrix for predictions
313 let posterior_matrix = Mat::from_fn(1, weights.weights().nrows(), |_row, col| {
314 *weights.weights().get(col)
315 });
316 let posterior = Posterior::from(posterior_matrix);
317
318 // Calculate concentration predictions
319 let concentration_preds = NPPredictions::calculate(
320 &problem.eq,
321 &Data::new(vec![target_with_optimal.clone()]),
322 problem.theta.clone(),
323 weights,
324 &posterior,
325 0.0,
326 0.0,
327 )?;
328
329 // Calculate AUC predictions if in AUC mode
330 let auc_predictions = match problem.target_type {
331 Target::Concentration => None,
332 Target::AUCFromZero | Target::AUCFromLastDose => {
333 let obs_times: Vec<f64> = target_with_optimal
334 .occasions()
335 .iter()
336 .flat_map(|occ| occ.events())
337 .filter_map(|event| match event {
338 Event::Observation(obs) => Some(obs.time()),
339 _ => None,
340 })
341 .collect();
342
343 let idelta = problem.settings.predictions().idelta;
344 let start_time = 0.0;
345 let end_time = obs_times.last().copied().unwrap_or(0.0);
346 let dense_times =
347 calculate_dense_times(start_time, end_time, &obs_times, idelta as usize);
348
349 let subject_id = target_with_optimal.id().to_string();
350 let mut builder = Subject::builder(&subject_id);
351
352 // Copy all dose events from target_with_optimal (which already has optimal doses set)
353 for occasion in target_with_optimal.occasions() {
354 for event in occasion.events() {
355 match event {
356 Event::Bolus(bolus) => {
357 builder = builder.bolus(bolus.time(), bolus.amount(), bolus.input());
358 }
359 Event::Infusion(infusion) => {
360 builder = builder.infusion(
361 infusion.time(),
362 infusion.amount(),
363 infusion.input(),
364 infusion.duration(),
365 );
366 }
367 Event::Observation(_) => {}
368 }
369 }
370 }
371
372 // Collect observations with (time, outeq) pairs to preserve original order
373 let obs_time_outeq: Vec<(f64, usize)> = target_with_optimal
374 .occasions()
375 .iter()
376 .flat_map(|occ| occ.events())
377 .filter_map(|event| match event {
378 Event::Observation(obs) => Some((obs.time(), obs.outeq())),
379 _ => None,
380 })
381 .collect();
382
383 let mut unique_outeqs: Vec<usize> =
384 obs_time_outeq.iter().map(|(_, outeq)| *outeq).collect();
385 unique_outeqs.sort_unstable();
386 unique_outeqs.dedup();
387
388 // Add observations at dense times for each outeq
389 for outeq in unique_outeqs.iter() {
390 for &t in &dense_times {
391 builder = builder.missing_observation(t, *outeq);
392 }
393 }
394
395 let dense_subject = builder.build();
396
397 // Initialize AUC storage per outeq
398 let mut outeq_mean_aucs: std::collections::HashMap<usize, Vec<f64>> =
399 std::collections::HashMap::new();
400 for outeq in unique_outeqs.iter() {
401 let outeq_obs_times: Vec<f64> = obs_time_outeq
402 .iter()
403 .filter(|(_, o)| *o == *outeq)
404 .map(|(t, _)| *t)
405 .collect();
406 outeq_mean_aucs.insert(*outeq, vec![0.0; outeq_obs_times.len()]);
407 }
408
409 // Calculate AUC for each support point and accumulate weighted means
410 for (row, weight) in problem.theta.matrix().row_iter().zip(weights.iter()) {
411 let spp = row.iter().copied().collect::<Vec<f64>>();
412 let pred = problem.eq.simulate_subject(&dense_subject, &spp, None)?;
413 let dense_predictions_with_outeq = pred.0.predictions();
414
415 // Group predictions by outeq
416 let mut outeq_predictions: std::collections::HashMap<usize, Vec<f64>> =
417 std::collections::HashMap::new();
418
419 for prediction in dense_predictions_with_outeq {
420 outeq_predictions
421 .entry(prediction.outeq())
422 .or_default()
423 .push(prediction.prediction());
424 }
425
426 // Calculate AUC for each outeq separately based on mode
427 for &outeq in unique_outeqs.iter() {
428 let outeq_preds = outeq_predictions.get(&outeq).ok_or_else(|| {
429 anyhow::anyhow!("Missing predictions for outeq {}", outeq)
430 })?;
431
432 // Get observation times for this outeq only
433 let outeq_obs_times: Vec<f64> = obs_time_outeq
434 .iter()
435 .filter(|(_, o)| *o == outeq)
436 .map(|(t, _)| *t)
437 .collect();
438
439 // Calculate AUC at observation times for this outeq
440 let aucs = match problem.target_type {
441 Target::AUCFromZero => {
442 calculate_auc_at_times(&dense_times, outeq_preds, &outeq_obs_times)
443 }
444 Target::AUCFromLastDose => calculate_interval_auc_per_observation(
445 &target_with_optimal,
446 &dense_times,
447 outeq_preds,
448 &outeq_obs_times,
449 ),
450 Target::Concentration => unreachable!(),
451 };
452
453 // Accumulate weighted AUCs
454 let mean_aucs = outeq_mean_aucs.get_mut(&outeq).unwrap();
455 for (i, &auc) in aucs.iter().enumerate() {
456 mean_aucs[i] += weight * auc;
457 }
458 }
459 }
460
461 // Build final AUC vector in original observation order
462 let mut result_aucs = Vec::with_capacity(obs_time_outeq.len());
463 let mut outeq_counters: std::collections::HashMap<usize, usize> =
464 std::collections::HashMap::new();
465
466 for (_, outeq) in obs_time_outeq.iter() {
467 let aucs = outeq_mean_aucs
468 .get(outeq)
469 .ok_or_else(|| anyhow::anyhow!("Missing AUC for outeq {}", outeq))?;
470
471 let counter = outeq_counters.entry(*outeq).or_insert(0);
472 if *counter < aucs.len() {
473 result_aucs.push(aucs[*counter]);
474 *counter += 1;
475 } else {
476 return Err(anyhow::anyhow!(
477 "AUC index out of bounds for outeq {}",
478 outeq
479 ));
480 }
481 }
482
483 Some(
484 obs_time_outeq
485 .iter()
486 .map(|(t, _)| *t)
487 .zip(result_aucs)
488 .collect(),
489 )
490 }
491 };
492
493 Ok((concentration_preds, auc_predictions))
494}