pmcore/routines/output/predictions.rs

1use anyhow::{bail, Result};
2use pharmsol::{prelude::simulator::Prediction, Data, Predictions as PredTrait};
3use serde::{Deserialize, Serialize};
4
5use crate::{
6    routines::output::{posterior::Posterior, weighted_median},
7    structs::{theta::Theta, weights::Weights},
8};
9
10// Structure for the output
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct NPPredictionRow {
13    /// The subject ID
14    id: String,
15    /// The time of the prediction
16    time: f64,
17    /// The output equation number
18    outeq: usize,
19    /// The occasion of the prediction
20    block: usize,
21    /// The observed value, if any
22    obs: Option<f64>,
23    /// The population mean prediction
24    pop_mean: f64,
25    /// The population median prediction
26    pop_median: f64,
27    /// The posterior mean prediction
28    post_mean: f64,
29    /// The posterior median prediction
30    post_median: f64,
31}
32
33impl NPPredictionRow {
34    pub fn id(&self) -> &str {
35        &self.id
36    }
37    pub fn time(&self) -> f64 {
38        self.time
39    }
40    pub fn outeq(&self) -> usize {
41        self.outeq
42    }
43    pub fn block(&self) -> usize {
44        self.block
45    }
46    pub fn obs(&self) -> Option<f64> {
47        self.obs
48    }
49    pub fn pop_mean(&self) -> f64 {
50        self.pop_mean
51    }
52    pub fn pop_median(&self) -> f64 {
53        self.pop_median
54    }
55    pub fn post_mean(&self) -> f64 {
56        self.post_mean
57    }
58    pub fn post_median(&self) -> f64 {
59        self.post_median
60    }
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct NPPredictions {
65    predictions: Vec<NPPredictionRow>,
66}
67
68impl IntoIterator for NPPredictions {
69    type Item = NPPredictionRow;
70    type IntoIter = std::vec::IntoIter<NPPredictionRow>;
71
72    fn into_iter(self) -> Self::IntoIter {
73        self.predictions.into_iter()
74    }
75}
76
77impl Default for NPPredictions {
78    fn default() -> Self {
79        NPPredictions::new()
80    }
81}
82
83impl NPPredictions {
84    pub fn new() -> Self {
85        NPPredictions {
86            predictions: Vec::new(),
87        }
88    }
89
90    /// Add a [NPPredictionRow] to the predictions
91    pub fn add(&mut self, row: NPPredictionRow) {
92        self.predictions.push(row);
93    }
94
95    /// Get a reference to the predictions
96    pub fn predictions(&self) -> &[NPPredictionRow] {
97        &self.predictions
98    }
99
100    /// Calculate the population and posterior predictions
101    ///
102    /// # Arguments
103    /// * `equation` - The equation to use for simulation
104    /// * `data` - The data to use for simulation
105    /// * `theta` - The theta values for the simulation
106    /// * `w` - The weights for the simulation
107    /// * `posterior` - The posterior probabilities for the simulation
108    /// * `idelta` - The delta for the simulation
109    /// * `tad` - The time after dose for the simulation
110    /// # Returns
111    /// A Result containing the NPPredictions or an error
112    pub fn calculate(
113        equation: &impl pharmsol::prelude::simulator::Equation,
114        data: &Data,
115        theta: Theta,
116        w: &Weights,
117        posterior: &Posterior,
118        idelta: f64,
119        tad: f64,
120    ) -> Result<Self> {
121        // Create a new NPPredictions instance
122        let mut container = NPPredictions::new();
123
124        // Expand data
125        let data = data.clone().expand(idelta, tad);
126        let subjects = data.subjects();
127
128        if subjects.len() != posterior.matrix().nrows() {
129            bail!("Number of subjects and number of posterior means do not match");
130        };
131
132        // Iterate over each subject and then each support point
133        for subject in subjects.iter().enumerate() {
134            let (subject_index, subject) = subject;
135
136            // Container for predictions for this subject
137            // This will hold predictions for each support point
138            // The outer vector is for each support point
139            // The inner vector is for the vector of predictions for that support point
140            let mut predictions: Vec<Vec<Prediction>> = Vec::new();
141
142            // And each support points
143            for spp in theta.matrix().row_iter() {
144                // Simulate the subject with the current support point
145                let spp_values = spp.iter().cloned().collect::<Vec<f64>>();
146                let pred = equation
147                    .simulate_subject(subject, &spp_values, None)?
148                    .0
149                    .get_predictions();
150                predictions.push(pred);
151            }
152
153            if predictions.is_empty() {
154                continue; // Skip this subject if no predictions are available
155            }
156
157            // Calculate population mean using
158            let mut pop_mean: Vec<f64> = vec![0.0; predictions.first().unwrap().len()];
159            for outer_pred in predictions.iter().enumerate() {
160                let (i, outer_pred) = outer_pred;
161                for inner_pred in outer_pred.iter().enumerate() {
162                    let (j, pred) = inner_pred;
163                    pop_mean[j] += pred.prediction() * w[i];
164                }
165            }
166
167            // Calculate population median
168            let mut pop_median: Vec<f64> = Vec::new();
169            for j in 0..predictions.first().unwrap().len() {
170                let mut values: Vec<f64> = Vec::new();
171                let mut weights: Vec<f64> = Vec::new();
172
173                for (i, outer_pred) in predictions.iter().enumerate() {
174                    values.push(outer_pred[j].prediction());
175                    weights.push(w[i]);
176                }
177
178                let median_val = weighted_median(&values, &weights);
179                pop_median.push(median_val);
180            }
181
182            // Calculate posterior mean
183            let mut posterior_mean: Vec<f64> = vec![0.0; predictions.first().unwrap().len()];
184            for outer_pred in predictions.iter().enumerate() {
185                let (i, outer_pred) = outer_pred;
186                for inner_pred in outer_pred.iter().enumerate() {
187                    let (j, pred) = inner_pred;
188                    posterior_mean[j] += pred.prediction() * posterior.matrix()[(subject_index, i)];
189                }
190            }
191
192            // Calculate posterior median
193            let mut posterior_median: Vec<f64> = Vec::new();
194            for j in 0..predictions.first().unwrap().len() {
195                let mut values: Vec<f64> = Vec::new();
196                let mut weights: Vec<f64> = Vec::new();
197
198                for (i, outer_pred) in predictions.iter().enumerate() {
199                    values.push(outer_pred[j].prediction());
200                    weights.push(posterior.matrix()[(subject_index, i)]);
201                }
202
203                let median_val = weighted_median(&values, &weights);
204                posterior_median.push(median_val);
205            }
206
207            // Iterate over the aggregated predictions (one row per timepoint per subject)
208            // Use the first support point predictions to get time, outeq, block, and obs info
209            if let Some(first_spp_preds) = predictions.first() {
210                for (j, p) in first_spp_preds.iter().enumerate() {
211                    let row = NPPredictionRow {
212                        id: subject.id().clone(),
213                        time: p.time(),
214                        outeq: p.outeq(),
215                        block: p.occasion(),
216                        obs: p.observation(),
217                        pop_mean: pop_mean[j],
218                        pop_median: pop_median[j],
219                        post_mean: posterior_mean[j],
220                        post_median: posterior_median[j],
221                    };
222                    container.add(row);
223                }
224            }
225        }
226
227        Ok(container)
228    }
229}