//! pmcore/algorithms/mod.rs

1use std::fs;
2use std::path::Path;
3
4use crate::routines::output::NPResult;
5use crate::routines::settings::Settings;
6use crate::structs::psi::Psi;
7use crate::structs::theta::Theta;
8use anyhow::Context;
9use anyhow::Result;
10use faer_ext::IntoNdarray;
11use ndarray::parallel::prelude::{IntoParallelIterator, ParallelIterator};
12use ndarray::{Array, ArrayBase, Dim, OwnedRepr};
13use npag::*;
14use npod::NPOD;
15use pharmsol::prelude::{data::Data, simulator::Equation};
16use pharmsol::{Predictions, Subject};
17use postprob::POSTPROB;
18use serde::{Deserialize, Serialize};
19
20pub mod npag;
21pub mod npod;
22pub mod postprob;
23
/// Selects which estimation algorithm to run; see [dispatch_algorithm]
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum Algorithm {
    /// Dispatches to [npag::NPAG]
    NPAG,
    /// Dispatches to [npod::NPOD]
    NPOD,
    /// Dispatches to [postprob::POSTPROB]
    POSTPROB,
}
30
31pub trait Algorithms<E: Equation + Send + 'static>: Sync + Send + 'static {
32    fn new(config: Settings, equation: E, data: Data) -> Result<Box<Self>>
33    where
34        Self: Sized;
35    fn validate_psi(&mut self) -> Result<()> {
36        // Count problematic values in psi
37        let mut nan_count = 0;
38        let mut inf_count = 0;
39
40        let psi = self.psi().matrix().as_ref().into_ndarray();
41        // First coerce all NaN and infinite in psi to 0.0
42        for i in 0..psi.nrows() {
43            for j in 0..self.psi().matrix().ncols() {
44                let val = psi.get((i, j)).unwrap();
45                if val.is_nan() {
46                    nan_count += 1;
47                    // *val = 0.0;
48                } else if val.is_infinite() {
49                    inf_count += 1;
50                    // *val = 0.0;
51                }
52            }
53        }
54
55        if nan_count + inf_count > 0 {
56            tracing::warn!(
57                "Psi matrix contains {} NaN, {} Infinite values of {} total values",
58                nan_count,
59                inf_count,
60                psi.ncols() * psi.nrows()
61            );
62        }
63
64        let (_, col) = psi.dim();
65        let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
66        let plam = psi.dot(&ecol);
67        let w = 1. / &plam;
68
69        // Get the index of each element in `w` that is NaN or infinite
70        let indices: Vec<usize> = w
71            .iter()
72            .enumerate()
73            .filter(|(_, x)| x.is_nan() || x.is_infinite())
74            .map(|(i, _)| i)
75            .collect::<Vec<_>>();
76
77        if !indices.is_empty() {
78            let subject: Vec<&Subject> = self.data().subjects();
79            let zero_probability_subjects: Vec<&String> =
80                indices.iter().map(|&i| subject[i].id()).collect();
81
82            tracing::error!(
83                "{}/{} subjects have zero probability given the model",
84                indices.len(),
85                psi.nrows()
86            );
87
88            // For each problematic subject
89            for index in &indices {
90                tracing::debug!("Subject with zero probability: {}", subject[*index].id());
91
92                let error_model = self.settings().errormodels().clone();
93
94                // Simulate all support points in parallel
95                let spp_results: Vec<_> = self
96                    .theta()
97                    .matrix()
98                    .row_iter()
99                    .enumerate()
100                    .collect::<Vec<_>>()
101                    .into_par_iter()
102                    .map(|(i, spp)| {
103                        let support_point: Vec<f64> = spp.iter().copied().collect();
104                        let (pred, ll) = self
105                            .equation()
106                            .simulate_subject(subject[*index], &support_point, Some(&error_model))
107                            .unwrap(); //TODO: Handle error
108                        (i, support_point, pred.get_predictions(), ll)
109                    })
110                    .collect();
111
112                // Count problematic likelihoods for this subject
113                let mut nan_ll = 0;
114                let mut inf_pos_ll = 0;
115                let mut inf_neg_ll = 0;
116                let mut zero_ll = 0;
117                let mut valid_ll = 0;
118
119                for (_, _, _, ll) in &spp_results {
120                    match ll {
121                        Some(ll_val) if ll_val.is_nan() => nan_ll += 1,
122                        Some(ll_val) if ll_val.is_infinite() && ll_val.is_sign_positive() => {
123                            inf_pos_ll += 1
124                        }
125                        Some(ll_val) if ll_val.is_infinite() && ll_val.is_sign_negative() => {
126                            inf_neg_ll += 1
127                        }
128                        Some(ll_val) if *ll_val == 0.0 => zero_ll += 1,
129                        Some(_) => valid_ll += 1,
130                        None => nan_ll += 1,
131                    }
132                }
133
134                tracing::debug!(
135                    "\tLikelihood analysis for subject {} ({} support points):",
136                    subject[*index].id(),
137                    spp_results.len()
138                );
139                tracing::debug!(
140                    "\tNaN likelihoods: {} ({:.1}%)",
141                    nan_ll,
142                    100.0 * nan_ll as f64 / spp_results.len() as f64
143                );
144                tracing::debug!(
145                    "\t+Inf likelihoods: {} ({:.1}%)",
146                    inf_pos_ll,
147                    100.0 * inf_pos_ll as f64 / spp_results.len() as f64
148                );
149                tracing::debug!(
150                    "\t-Inf likelihoods: {} ({:.1}%)",
151                    inf_neg_ll,
152                    100.0 * inf_neg_ll as f64 / spp_results.len() as f64
153                );
154                tracing::debug!(
155                    "\tZero likelihoods: {} ({:.1}%)",
156                    zero_ll,
157                    100.0 * zero_ll as f64 / spp_results.len() as f64
158                );
159                tracing::debug!(
160                    "\tValid likelihoods: {} ({:.1}%)",
161                    valid_ll,
162                    100.0 * valid_ll as f64 / spp_results.len() as f64
163                );
164
165                // Sort and show top 10 most likely support points
166                let mut sorted_results = spp_results;
167                sorted_results.sort_by(|a, b| {
168                    b.3.unwrap_or(f64::NEG_INFINITY)
169                        .partial_cmp(&a.3.unwrap_or(f64::NEG_INFINITY))
170                        .unwrap_or(std::cmp::Ordering::Equal)
171                });
172                let take = 3;
173
174                tracing::debug!("Top {} most likely support points:", take);
175                for (i, support_point, preds, ll) in sorted_results.iter().take(take) {
176                    tracing::debug!("\tSupport point #{}: {:?}", i, support_point);
177                    tracing::debug!("\t\tLog-likelihood: {:?}", ll);
178
179                    let times = preds.iter().map(|x| x.time()).collect::<Vec<f64>>();
180                    let observations = preds
181                        .iter()
182                        .map(|x| x.observation())
183                        .collect::<Vec<Option<f64>>>();
184                    let predictions = preds.iter().map(|x| x.prediction()).collect::<Vec<f64>>();
185                    let outeqs = preds.iter().map(|x| x.outeq()).collect::<Vec<usize>>();
186                    let states = preds
187                        .iter()
188                        .map(|x| x.state().clone())
189                        .collect::<Vec<Vec<f64>>>();
190
191                    tracing::debug!("\t\tTimes: {:?}", times);
192                    tracing::debug!("\t\tObservations: {:?}", observations);
193                    tracing::debug!("\t\tPredictions: {:?}", predictions);
194                    tracing::debug!("\t\tOuteqs: {:?}", outeqs);
195                    tracing::debug!("\t\tStates: {:?}", states);
196                }
197                tracing::debug!("=====================");
198            }
199
200            return Err(anyhow::anyhow!(
201                    "The probability of {}/{} subjects is zero given the model. Affected subjects: {:?}",
202                    indices.len(),
203                    self.psi().matrix().nrows(),
204                    zero_probability_subjects
205                ));
206        }
207
208        Ok(())
209    }
210
211    fn settings(&self) -> &Settings;
212    /// Get the equation used in the algorithm
213    fn equation(&self) -> &E;
214    /// Get the data used in the algorithm
215    fn data(&self) -> &Data;
216    fn get_prior(&self) -> Theta;
217    /// Increment the cycle counter and return the new value
218    fn increment_cycle(&mut self) -> usize;
219    /// Get the current cycle number
220    fn cycle(&self) -> usize;
221    /// Set the current [Theta]
222    fn set_theta(&mut self, theta: Theta);
223    /// Get the current [Theta]
224    fn theta(&self) -> &Theta;
225    /// Get the current [Psi]
226    fn psi(&self) -> &Psi;
227    /// Get the current likelihood
228    fn likelihood(&self) -> f64;
229    /// Get the current negative two log-likelihood
230    fn n2ll(&self) -> f64 {
231        -2.0 * self.likelihood()
232    }
233    /// Get the current [Status] of the algorithm
234    fn status(&self) -> &Status;
235    /// Set the current [Status] of the algorithm
236    fn set_status(&mut self, status: Status);
237    /// Evaluate convergence criteria and update status
238    fn evaluation(&mut self) -> Result<Status>;
239
240    /// Create and log a cycle state with the current algorithm state
241    fn log_cycle_state(&mut self);
242
243    /// Initialize the algorithm, setting up initial [Theta] and [Status]
244    fn initialize(&mut self) -> Result<()> {
245        // If a stop file exists in the current directory, remove it
246        if Path::new("stop").exists() {
247            tracing::info!("Removing existing stop file prior to run");
248            fs::remove_file("stop").context("Unable to remove previous stop file")?;
249        }
250        self.set_status(Status::Continue);
251        self.set_theta(self.get_prior());
252        Ok(())
253    }
254    fn estimation(&mut self) -> Result<()>;
255    /// Performs condensation of [Theta] and updates [Psi]
256    ///
257    /// This step reduces the number of support points in [Theta] based on the current weights,
258    /// and updates the [Psi] matrix accordingly to reflect the new set of support points.
259    /// It is typically performed after the estimation step in each cycle of the algorithm.
260    fn condensation(&mut self) -> Result<()>;
261
262    /// Performs optimizations on the current [ErrorModels] and updates [Psi] accordingly
263    ///
264    /// This step refines the error model parameters to better fit the data,
265    /// and subsequently updates the [Psi] matrix to reflect these changes.
266    fn optimizations(&mut self) -> Result<()>;
267
268    /// Performs expansion of [Theta]
269    ///
270    /// This step increases the number of support points in [Theta] based on the current distribution,
271    /// allowing for exploration of the parameter space.
272    fn expansion(&mut self) -> Result<()>;
273
274    /// Proceed to the next cycle of the algorithm
275    ///
276    /// This method increments the cycle counter, performs expansion if necessary,
277    /// and then runs the estimation, condensation, optimization, logging, and evaluation steps
278    /// in sequence. It returns the current [Status] of the algorithm after completing these steps.
279    fn next_cycle(&mut self) -> Result<Status> {
280        let cycle = self.increment_cycle();
281
282        if cycle > 1 {
283            self.expansion()?;
284        }
285
286        let span = tracing::info_span!("", "{}", format!("Cycle {}", self.cycle()));
287        let _enter = span.enter();
288        self.estimation()?;
289        self.condensation()?;
290        self.optimizations()?;
291        self.evaluation()
292    }
293
294    /// Fit the model until convergence or stopping criteria are met
295    ///
296    /// This method runs the full fitting process, starting with initialization,
297    /// followed by iterative cycles of estimation, condensation, optimization, and evaluation
298    /// until the algorithm converges or meets a stopping criteria.
299    fn fit(&mut self) -> Result<NPResult<E>> {
300        self.initialize().unwrap();
301        loop {
302            match self.next_cycle()? {
303                Status::Continue => continue,
304                Status::Stop(_) => break,
305            }
306        }
307        Ok(self.into_npresult())
308    }
309
310    #[allow(clippy::wrong_self_convention)]
311    fn into_npresult(&self) -> NPResult<E>;
312}
313
314pub fn dispatch_algorithm<E: Equation + Send + 'static>(
315    settings: Settings,
316    equation: E,
317    data: Data,
318) -> Result<Box<dyn Algorithms<E>>> {
319    match settings.config().algorithm {
320        Algorithm::NPAG => Ok(NPAG::new(settings, equation, data)?),
321        Algorithm::NPOD => Ok(NPOD::new(settings, equation, data)?),
322        Algorithm::POSTPROB => Ok(POSTPROB::new(settings, equation, data)?),
323    }
324}
325
/// Represents the status/result of the algorithm
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Status {
    /// The algorithm should proceed with another cycle
    Continue,
    /// The algorithm has stopped for the given [StopReason]
    Stop(StopReason),
}
332
333impl std::fmt::Display for Status {
334    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
335        match self {
336            Status::Continue => write!(f, "Continue"),
337            Status::Stop(s) => write!(f, "Stop: {:?}", s),
338        }
339    }
340}
341
342#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
343
344pub enum StopReason {
345    Converged,
346    MaxCycles,
347    Stopped,
348    Completed,
349}