// pmcore/routines/output/predictions.rs

use anyhow::{bail, Context, Result};
use pharmsol::{prelude::simulator::Prediction, Censor, Data, Predictions as PredTrait};
use serde::{Deserialize, Serialize};

use crate::{
    routines::output::{posterior::Posterior, weighted_median},
    structs::{theta::Theta, weights::Weights},
};

10/// Container for the multiple model estimated predictions
11///
12/// Each row contains the predictions for a single time point for a single subject
13/// It includes the population and posterior mean and median predictions
14/// These are defined by the mean and median of the prediction for each model, weighted by the population or posterior weights
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct NPPredictionRow {
17    /// The subject ID
18    id: String,
19    /// The time of the prediction
20    time: f64,
21    /// The output equation number
22    outeq: usize,
23    /// The occasion of the prediction
24    block: usize,
25    /// The observed value, if any
26    obs: Option<f64>,
27    /// Censored observation flag
28    cens: Censor,
29    /// The population mean prediction
30    pop_mean: f64,
31    /// The population median prediction
32    pop_median: f64,
33    /// The posterior mean prediction
34    post_mean: f64,
35    /// The posterior median prediction
36    post_median: f64,
37}
38
39impl NPPredictionRow {
40    pub fn id(&self) -> &str {
41        &self.id
42    }
43    pub fn time(&self) -> f64 {
44        self.time
45    }
46    pub fn outeq(&self) -> usize {
47        self.outeq
48    }
49    pub fn block(&self) -> usize {
50        self.block
51    }
52    pub fn obs(&self) -> Option<f64> {
53        self.obs
54    }
55    pub fn pop_mean(&self) -> f64 {
56        self.pop_mean
57    }
58    pub fn pop_median(&self) -> f64 {
59        self.pop_median
60    }
61    pub fn post_mean(&self) -> f64 {
62        self.post_mean
63    }
64    pub fn post_median(&self) -> f64 {
65        self.post_median
66    }
67
68    pub fn censoring(&self) -> Censor {
69        self.cens
70    }
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct NPPredictions {
75    predictions: Vec<NPPredictionRow>,
76}
77
78impl IntoIterator for NPPredictions {
79    type Item = NPPredictionRow;
80    type IntoIter = std::vec::IntoIter<NPPredictionRow>;
81
82    fn into_iter(self) -> Self::IntoIter {
83        self.predictions.into_iter()
84    }
85}
86
87impl Default for NPPredictions {
88    fn default() -> Self {
89        NPPredictions::new()
90    }
91}
92
93impl NPPredictions {
94    pub fn new() -> Self {
95        NPPredictions {
96            predictions: Vec::new(),
97        }
98    }
99
100    /// Add a [NPPredictionRow] to the predictions
101    pub fn add(&mut self, row: NPPredictionRow) {
102        self.predictions.push(row);
103    }
104
105    /// Get a reference to the predictions
106    pub fn predictions(&self) -> &[NPPredictionRow] {
107        &self.predictions
108    }
109
110    /// Calculate the population and posterior predictions
111    ///
112    /// # Arguments
113    /// * `equation` - The equation to use for simulation
114    /// * `data` - The data to use for simulation
115    /// * `theta` - The theta values for the simulation
116    /// * `w` - The weights for the simulation
117    /// * `posterior` - The posterior probabilities for the simulation
118    /// * `idelta` - The delta for the simulation
119    /// * `tad` - The time after dose for the simulation
120    /// # Returns
121    /// A Result containing the NPPredictions or an error
122    pub fn calculate(
123        equation: &impl pharmsol::prelude::simulator::Equation,
124        data: &Data,
125        theta: &Theta,
126        w: &Weights,
127        posterior: &Posterior,
128        idelta: f64,
129        tad: f64,
130    ) -> Result<Self> {
131        // Create a new NPPredictions instance
132        let mut container = NPPredictions::new();
133
134        // Expand data
135        let data = data.clone().expand(idelta, tad);
136        let subjects = data.subjects();
137
138        if subjects.len() != posterior.matrix().nrows() {
139            bail!("Number of subjects and number of posterior means do not match");
140        };
141
142        // Iterate over each subject and then each support point
143        for subject in subjects.iter().enumerate() {
144            let (subject_index, subject) = subject;
145
146            // Container for predictions for this subject
147            // This will hold predictions for each support point
148            // The outer vector is for each support point
149            // The inner vector is for the vector of predictions for that support point
150            let mut predictions: Vec<Vec<Prediction>> = Vec::new();
151
152            // And each support points
153            for spp in theta.matrix().row_iter() {
154                // Simulate the subject with the current support point
155                let spp_values = spp.iter().cloned().collect::<Vec<f64>>();
156                let pred = equation
157                    .simulate_subject(subject, &spp_values, None)?
158                    .0
159                    .get_predictions();
160                predictions.push(pred);
161            }
162
163            if predictions.is_empty() {
164                continue; // Skip this subject if no predictions are available
165            }
166
167            // Calculate population mean using
168            let mut pop_mean: Vec<f64> = vec![0.0; predictions.first().unwrap().len()];
169            for outer_pred in predictions.iter().enumerate() {
170                let (i, outer_pred) = outer_pred;
171                for inner_pred in outer_pred.iter().enumerate() {
172                    let (j, pred) = inner_pred;
173                    pop_mean[j] += pred.prediction() * w[i];
174                }
175            }
176
177            // Calculate population median
178            let mut pop_median: Vec<f64> = Vec::new();
179            for j in 0..predictions.first().unwrap().len() {
180                let mut values: Vec<f64> = Vec::new();
181                let mut weights: Vec<f64> = Vec::new();
182
183                for (i, outer_pred) in predictions.iter().enumerate() {
184                    values.push(outer_pred[j].prediction());
185                    weights.push(w[i]);
186                }
187
188                let median_val = weighted_median(&values, &weights);
189                pop_median.push(median_val);
190            }
191
192            // Calculate posterior mean
193            let mut posterior_mean: Vec<f64> = vec![0.0; predictions.first().unwrap().len()];
194            for outer_pred in predictions.iter().enumerate() {
195                let (i, outer_pred) = outer_pred;
196                for inner_pred in outer_pred.iter().enumerate() {
197                    let (j, pred) = inner_pred;
198                    posterior_mean[j] += pred.prediction() * posterior.matrix()[(subject_index, i)];
199                }
200            }
201
202            // Calculate posterior median
203            let mut posterior_median: Vec<f64> = Vec::new();
204            for j in 0..predictions.first().unwrap().len() {
205                let mut values: Vec<f64> = Vec::new();
206                let mut weights: Vec<f64> = Vec::new();
207
208                for (i, outer_pred) in predictions.iter().enumerate() {
209                    values.push(outer_pred[j].prediction());
210                    weights.push(posterior.matrix()[(subject_index, i)]);
211                }
212
213                let median_val = weighted_median(&values, &weights);
214                posterior_median.push(median_val);
215            }
216
217            // Iterate over the aggregated predictions (one row per timepoint per subject)
218            // Use the first support point predictions to get time, outeq, block, and obs info
219            if let Some(first_spp_preds) = predictions.first() {
220                for (j, p) in first_spp_preds.iter().enumerate() {
221                    let row = NPPredictionRow {
222                        id: subject.id().clone(),
223                        time: p.time(),
224                        outeq: p.outeq(),
225                        block: p.occasion(),
226                        obs: p.observation(),
227                        cens: p.censoring(),
228                        pop_mean: pop_mean[j],
229                        pop_median: pop_median[j],
230                        post_mean: posterior_mean[j],
231                        post_median: posterior_median[j],
232                    };
233                    container.add(row);
234                }
235            }
236        }
237
238        Ok(container)
239    }
240}