pmcore/algorithms/
mod.rs

1use std::fs;
2use std::path::Path;
3
4use crate::routines::output::NPResult;
5use crate::routines::settings::Settings;
6use crate::structs::psi::Psi;
7use crate::structs::theta::Theta;
8use anyhow::Context;
9use anyhow::Result;
10use faer_ext::IntoNdarray;
11use ndarray::parallel::prelude::{IntoParallelIterator, ParallelIterator};
12use ndarray::{Array, ArrayBase, Dim, OwnedRepr};
13use npag::*;
14use npod::NPOD;
15use pharmsol::prelude::{data::Data, simulator::Equation};
16use pharmsol::{Predictions, Subject};
17use postprob::POSTPROB;
18use serde::{Deserialize, Serialize};
19
20pub mod npag;
21pub mod npod;
22pub mod postprob;
23
24#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
25pub enum Algorithm {
26    NPAG,
27    NPOD,
28    POSTPROB,
29}
30
31pub trait Algorithms<E: Equation>: Sync {
32    fn new(config: Settings, equation: E, data: Data) -> Result<Box<Self>>
33    where
34        Self: Sized;
35    fn validate_psi(&mut self) -> Result<()> {
36        // Count problematic values in psi
37        let mut nan_count = 0;
38        let mut inf_count = 0;
39
40        let psi = self.psi().matrix().as_ref().into_ndarray();
41        // First coerce all NaN and infinite in psi to 0.0
42        for i in 0..psi.nrows() {
43            for j in 0..self.psi().matrix().ncols() {
44                let val = psi.get((i, j)).unwrap();
45                if val.is_nan() {
46                    nan_count += 1;
47                    // *val = 0.0;
48                } else if val.is_infinite() {
49                    inf_count += 1;
50                    // *val = 0.0;
51                }
52            }
53        }
54
55        if nan_count + inf_count > 0 {
56            tracing::warn!(
57                "Psi matrix contains {} NaN, {} Infinite values of {} total values",
58                nan_count,
59                inf_count,
60                psi.ncols() * psi.nrows()
61            );
62        }
63
64        let (_, col) = psi.dim();
65        let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
66        let plam = psi.dot(&ecol);
67        let w = 1. / &plam;
68
69        // Get the index of each element in `w` that is NaN or infinite
70        let indices: Vec<usize> = w
71            .iter()
72            .enumerate()
73            .filter(|(_, x)| x.is_nan() || x.is_infinite())
74            .map(|(i, _)| i)
75            .collect::<Vec<_>>();
76
77        if !indices.is_empty() {
78            let subject: Vec<&Subject> = self.get_data().get_subjects();
79            let zero_probability_subjects: Vec<&String> =
80                indices.iter().map(|&i| subject[i].id()).collect();
81
82            tracing::error!(
83                "{}/{} subjects have zero probability given the model",
84                indices.len(),
85                psi.nrows()
86            );
87
88            // For each problematic subject
89            for index in &indices {
90                tracing::debug!("Subject with zero probability: {}", subject[*index].id());
91
92                let error_model = self.get_settings().error.clone().into();
93
94                // Simulate all support points in parallel
95                let spp_results: Vec<_> = self
96                    .get_theta()
97                    .matrix()
98                    .row_iter()
99                    .enumerate()
100                    .collect::<Vec<_>>()
101                    .into_par_iter()
102                    .map(|(i, spp)| {
103                        let support_point: Vec<f64> = spp.iter().copied().collect();
104                        let (pred, ll) = self
105                            .equation()
106                            .simulate_subject(subject[*index], &support_point, Some(&error_model))
107                            .unwrap(); //TODO: Handle error
108                        (i, support_point, pred.get_predictions(), ll)
109                    })
110                    .collect();
111
112                // Count problematic likelihoods for this subject
113                let mut nan_ll = 0;
114                let mut inf_pos_ll = 0;
115                let mut inf_neg_ll = 0;
116                let mut zero_ll = 0;
117                let mut valid_ll = 0;
118
119                for (_, _, _, ll) in &spp_results {
120                    match ll {
121                        Some(ll_val) if ll_val.is_nan() => nan_ll += 1,
122                        Some(ll_val) if ll_val.is_infinite() && ll_val.is_sign_positive() => {
123                            inf_pos_ll += 1
124                        }
125                        Some(ll_val) if ll_val.is_infinite() && ll_val.is_sign_negative() => {
126                            inf_neg_ll += 1
127                        }
128                        Some(ll_val) if *ll_val == 0.0 => zero_ll += 1,
129                        Some(_) => valid_ll += 1,
130                        None => nan_ll += 1,
131                    }
132                }
133
134                tracing::debug!(
135                    "\tLikelihood analysis for subject {} ({} support points):",
136                    subject[*index].id(),
137                    spp_results.len()
138                );
139                tracing::debug!(
140                    "\tNaN likelihoods: {} ({:.1}%)",
141                    nan_ll,
142                    100.0 * nan_ll as f64 / spp_results.len() as f64
143                );
144                tracing::debug!(
145                    "\t+Inf likelihoods: {} ({:.1}%)",
146                    inf_pos_ll,
147                    100.0 * inf_pos_ll as f64 / spp_results.len() as f64
148                );
149                tracing::debug!(
150                    "\t-Inf likelihoods: {} ({:.1}%)",
151                    inf_neg_ll,
152                    100.0 * inf_neg_ll as f64 / spp_results.len() as f64
153                );
154                tracing::debug!(
155                    "\tZero likelihoods: {} ({:.1}%)",
156                    zero_ll,
157                    100.0 * zero_ll as f64 / spp_results.len() as f64
158                );
159                tracing::debug!(
160                    "\tValid likelihoods: {} ({:.1}%)",
161                    valid_ll,
162                    100.0 * valid_ll as f64 / spp_results.len() as f64
163                );
164
165                // Sort and show top 10 most likely support points
166                let mut sorted_results = spp_results;
167                sorted_results.sort_by(|a, b| {
168                    b.3.unwrap_or(f64::NEG_INFINITY)
169                        .partial_cmp(&a.3.unwrap_or(f64::NEG_INFINITY))
170                        .unwrap_or(std::cmp::Ordering::Equal)
171                });
172                let take = 3;
173
174                tracing::debug!("Top {} most likely support points:", take);
175                for (i, support_point, preds, ll) in sorted_results.iter().take(take) {
176                    tracing::debug!("\tSupport point #{}: {:?}", i, support_point);
177                    tracing::debug!("\t\tLog-likelihood: {:?}", ll);
178
179                    let times = preds.iter().map(|x| x.time()).collect::<Vec<f64>>();
180                    let observations = preds.iter().map(|x| x.observation()).collect::<Vec<f64>>();
181                    let predictions = preds.iter().map(|x| x.prediction()).collect::<Vec<f64>>();
182                    let outeqs = preds.iter().map(|x| x.outeq()).collect::<Vec<usize>>();
183                    let states = preds
184                        .iter()
185                        .map(|x| x.state().clone())
186                        .collect::<Vec<Vec<f64>>>();
187
188                    tracing::debug!("\t\tTimes: {:?}", times);
189                    tracing::debug!("\t\tObservations: {:?}", observations);
190                    tracing::debug!("\t\tPredictions: {:?}", predictions);
191                    tracing::debug!("\t\tOuteqs: {:?}", outeqs);
192                    tracing::debug!("\t\tStates: {:?}", states);
193                }
194                tracing::debug!("=====================");
195            }
196
197            return Err(anyhow::anyhow!(
198                    "The probability of {}/{} subjects is zero given the model. Affected subjects: {:?}",
199                    indices.len(),
200                    self.psi().matrix().nrows(),
201                    zero_probability_subjects
202                ));
203        }
204
205        Ok(())
206    }
207    fn get_settings(&self) -> &Settings;
208    fn equation(&self) -> &E;
209    fn get_data(&self) -> &Data;
210    fn get_prior(&self) -> Theta;
211    fn inc_cycle(&mut self) -> usize;
212    fn get_cycle(&self) -> usize;
213    fn set_theta(&mut self, theta: Theta);
214    fn get_theta(&self) -> &Theta;
215    fn psi(&self) -> &Psi;
216    fn likelihood(&self) -> f64;
217    fn n2ll(&self) -> f64 {
218        -2.0 * self.likelihood()
219    }
220    fn convergence_evaluation(&mut self);
221    fn converged(&self) -> bool;
222    fn initialize(&mut self) -> Result<()> {
223        // If a stop file exists in the current directory, remove it
224        if Path::new("stop").exists() {
225            tracing::info!("Removing existing stop file prior to run");
226            fs::remove_file("stop").context("Unable to remove previous stop file")?;
227        }
228        self.set_theta(self.get_prior());
229        Ok(())
230    }
231    fn evaluation(&mut self) -> Result<()>;
232    fn condensation(&mut self) -> Result<()>;
233    fn optimizations(&mut self) -> Result<()>;
234    fn logs(&self);
235    fn expansion(&mut self) -> Result<()>;
236    fn next_cycle(&mut self) -> Result<bool> {
237        if self.inc_cycle() > 1 {
238            self.expansion()?;
239        }
240        let span = tracing::info_span!("", "{}", format!("Cycle {}", self.get_cycle()));
241        let _enter = span.enter();
242        self.evaluation()?;
243        self.condensation()?;
244        self.optimizations()?;
245        self.logs();
246        self.convergence_evaluation();
247        Ok(self.converged())
248    }
249    fn fit(&mut self) -> Result<NPResult<E>> {
250        self.initialize().unwrap();
251        while !self.next_cycle()? {}
252        Ok(self.into_npresult())
253    }
254
255    #[allow(clippy::wrong_self_convention)]
256    fn into_npresult(&self) -> NPResult<E>;
257}
258
259pub fn dispatch_algorithm<E: Equation>(
260    settings: Settings,
261    equation: E,
262    data: Data,
263) -> Result<Box<dyn Algorithms<E>>> {
264    match settings.config().algorithm {
265        Algorithm::NPAG => Ok(NPAG::new(settings, equation, data)?),
266        Algorithm::NPOD => Ok(NPOD::new(settings, equation, data)?),
267        Algorithm::POSTPROB => Ok(POSTPROB::new(settings, equation, data)?),
268    }
269}