pmcore/routines/output/predictions.rs

1use anyhow::{bail, Result};
2use pharmsol::{prelude::simulator::Prediction, Censor, Data, Predictions as PredTrait};
3use serde::{Deserialize, Serialize};
4
5use crate::{
6    routines::output::{posterior::Posterior, weighted_median},
7    structs::{theta::Theta, weights::Weights},
8};
9
10// Structure for the output
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct NPPredictionRow {
13    /// The subject ID
14    id: String,
15    /// The time of the prediction
16    time: f64,
17    /// The output equation number
18    outeq: usize,
19    /// The occasion of the prediction
20    block: usize,
21    /// The observed value, if any
22    obs: Option<f64>,
23    /// Censored observation flag
24    censoring: Censor,
25    /// The population mean prediction
26    pop_mean: f64,
27    /// The population median prediction
28    pop_median: f64,
29    /// The posterior mean prediction
30    post_mean: f64,
31    /// The posterior median prediction
32    post_median: f64,
33}
34
35impl NPPredictionRow {
36    pub fn id(&self) -> &str {
37        &self.id
38    }
39    pub fn time(&self) -> f64 {
40        self.time
41    }
42    pub fn outeq(&self) -> usize {
43        self.outeq
44    }
45    pub fn block(&self) -> usize {
46        self.block
47    }
48    pub fn obs(&self) -> Option<f64> {
49        self.obs
50    }
51    pub fn pop_mean(&self) -> f64 {
52        self.pop_mean
53    }
54    pub fn pop_median(&self) -> f64 {
55        self.pop_median
56    }
57    pub fn post_mean(&self) -> f64 {
58        self.post_mean
59    }
60    pub fn post_median(&self) -> f64 {
61        self.post_median
62    }
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct NPPredictions {
67    predictions: Vec<NPPredictionRow>,
68}
69
70impl IntoIterator for NPPredictions {
71    type Item = NPPredictionRow;
72    type IntoIter = std::vec::IntoIter<NPPredictionRow>;
73
74    fn into_iter(self) -> Self::IntoIter {
75        self.predictions.into_iter()
76    }
77}
78
79impl Default for NPPredictions {
80    fn default() -> Self {
81        NPPredictions::new()
82    }
83}
84
85impl NPPredictions {
86    pub fn new() -> Self {
87        NPPredictions {
88            predictions: Vec::new(),
89        }
90    }
91
92    /// Add a [NPPredictionRow] to the predictions
93    pub fn add(&mut self, row: NPPredictionRow) {
94        self.predictions.push(row);
95    }
96
97    /// Get a reference to the predictions
98    pub fn predictions(&self) -> &[NPPredictionRow] {
99        &self.predictions
100    }
101
102    /// Calculate the population and posterior predictions
103    ///
104    /// # Arguments
105    /// * `equation` - The equation to use for simulation
106    /// * `data` - The data to use for simulation
107    /// * `theta` - The theta values for the simulation
108    /// * `w` - The weights for the simulation
109    /// * `posterior` - The posterior probabilities for the simulation
110    /// * `idelta` - The delta for the simulation
111    /// * `tad` - The time after dose for the simulation
112    /// # Returns
113    /// A Result containing the NPPredictions or an error
114    pub fn calculate(
115        equation: &impl pharmsol::prelude::simulator::Equation,
116        data: &Data,
117        theta: Theta,
118        w: &Weights,
119        posterior: &Posterior,
120        idelta: f64,
121        tad: f64,
122    ) -> Result<Self> {
123        // Create a new NPPredictions instance
124        let mut container = NPPredictions::new();
125
126        // Expand data
127        let data = data.clone().expand(idelta, tad);
128        let subjects = data.subjects();
129
130        if subjects.len() != posterior.matrix().nrows() {
131            bail!("Number of subjects and number of posterior means do not match");
132        };
133
134        // Iterate over each subject and then each support point
135        for subject in subjects.iter().enumerate() {
136            let (subject_index, subject) = subject;
137
138            // Container for predictions for this subject
139            // This will hold predictions for each support point
140            // The outer vector is for each support point
141            // The inner vector is for the vector of predictions for that support point
142            let mut predictions: Vec<Vec<Prediction>> = Vec::new();
143
144            // And each support points
145            for spp in theta.matrix().row_iter() {
146                // Simulate the subject with the current support point
147                let spp_values = spp.iter().cloned().collect::<Vec<f64>>();
148                let pred = equation
149                    .simulate_subject(subject, &spp_values, None)?
150                    .0
151                    .get_predictions();
152                predictions.push(pred);
153            }
154
155            if predictions.is_empty() {
156                continue; // Skip this subject if no predictions are available
157            }
158
159            // Calculate population mean using
160            let mut pop_mean: Vec<f64> = vec![0.0; predictions.first().unwrap().len()];
161            for outer_pred in predictions.iter().enumerate() {
162                let (i, outer_pred) = outer_pred;
163                for inner_pred in outer_pred.iter().enumerate() {
164                    let (j, pred) = inner_pred;
165                    pop_mean[j] += pred.prediction() * w[i];
166                }
167            }
168
169            // Calculate population median
170            let mut pop_median: Vec<f64> = Vec::new();
171            for j in 0..predictions.first().unwrap().len() {
172                let mut values: Vec<f64> = Vec::new();
173                let mut weights: Vec<f64> = Vec::new();
174
175                for (i, outer_pred) in predictions.iter().enumerate() {
176                    values.push(outer_pred[j].prediction());
177                    weights.push(w[i]);
178                }
179
180                let median_val = weighted_median(&values, &weights);
181                pop_median.push(median_val);
182            }
183
184            // Calculate posterior mean
185            let mut posterior_mean: Vec<f64> = vec![0.0; predictions.first().unwrap().len()];
186            for outer_pred in predictions.iter().enumerate() {
187                let (i, outer_pred) = outer_pred;
188                for inner_pred in outer_pred.iter().enumerate() {
189                    let (j, pred) = inner_pred;
190                    posterior_mean[j] += pred.prediction() * posterior.matrix()[(subject_index, i)];
191                }
192            }
193
194            // Calculate posterior median
195            let mut posterior_median: Vec<f64> = Vec::new();
196            for j in 0..predictions.first().unwrap().len() {
197                let mut values: Vec<f64> = Vec::new();
198                let mut weights: Vec<f64> = Vec::new();
199
200                for (i, outer_pred) in predictions.iter().enumerate() {
201                    values.push(outer_pred[j].prediction());
202                    weights.push(posterior.matrix()[(subject_index, i)]);
203                }
204
205                let median_val = weighted_median(&values, &weights);
206                posterior_median.push(median_val);
207            }
208
209            // Iterate over the aggregated predictions (one row per timepoint per subject)
210            // Use the first support point predictions to get time, outeq, block, and obs info
211            if let Some(first_spp_preds) = predictions.first() {
212                for (j, p) in first_spp_preds.iter().enumerate() {
213                    let row = NPPredictionRow {
214                        id: subject.id().clone(),
215                        time: p.time(),
216                        outeq: p.outeq(),
217                        block: p.occasion(),
218                        obs: p.observation(),
219                        censoring: p.censoring(),
220                        pop_mean: pop_mean[j],
221                        pop_median: pop_median[j],
222                        post_mean: posterior_mean[j],
223                        post_median: posterior_median[j],
224                    };
225                    container.add(row);
226                }
227            }
228        }
229
230        Ok(container)
231    }
232}