// pmcore/algorithms/mod.rs

1use std::fs;
2use std::path::Path;
3
4use crate::routines::output::NPResult;
5use crate::routines::settings::Settings;
6use crate::structs::psi::Psi;
7use crate::structs::theta::Theta;
8use anyhow::Context;
9use anyhow::Result;
10use faer_ext::IntoNdarray;
11use ndarray::parallel::prelude::{IntoParallelIterator, ParallelIterator};
12use ndarray::{Array, ArrayBase, Dim, OwnedRepr};
13use npag::*;
14use npod::NPOD;
15use pharmsol::prelude::{data::Data, simulator::Equation};
16use pharmsol::{Predictions, Subject};
17use postprob::POSTPROB;
18use serde::{Deserialize, Serialize};
19
20pub mod npag;
21pub mod npod;
22pub mod postprob;
23
24#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
25pub enum Algorithm {
26    NPAG,
27    NPOD,
28    POSTPROB,
29}
30
31pub trait Algorithms<E: Equation>: Sync {
32    fn new(config: Settings, equation: E, data: Data) -> Result<Box<Self>>
33    where
34        Self: Sized;
35    fn validate_psi(&mut self) -> Result<()> {
36        // Count problematic values in psi
37        let mut nan_count = 0;
38        let mut inf_count = 0;
39
40        let psi = self.psi().matrix().as_ref().into_ndarray();
41        // First coerce all NaN and infinite in psi to 0.0
42        for i in 0..psi.nrows() {
43            for j in 0..self.psi().matrix().ncols() {
44                let val = psi.get((i, j)).unwrap();
45                if val.is_nan() {
46                    nan_count += 1;
47                    // *val = 0.0;
48                } else if val.is_infinite() {
49                    inf_count += 1;
50                    // *val = 0.0;
51                }
52            }
53        }
54
55        if nan_count + inf_count > 0 {
56            tracing::warn!(
57                "Psi matrix contains {} NaN, {} Infinite values of {} total values",
58                nan_count,
59                inf_count,
60                psi.ncols() * psi.nrows()
61            );
62        }
63
64        let (_, col) = psi.dim();
65        let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
66        let plam = psi.dot(&ecol);
67        let w = 1. / &plam;
68
69        // Get the index of each element in `w` that is NaN or infinite
70        let indices: Vec<usize> = w
71            .iter()
72            .enumerate()
73            .filter(|(_, x)| x.is_nan() || x.is_infinite())
74            .map(|(i, _)| i)
75            .collect::<Vec<_>>();
76
77        if !indices.is_empty() {
78            let subject: Vec<&Subject> = self.get_data().subjects();
79            let zero_probability_subjects: Vec<&String> =
80                indices.iter().map(|&i| subject[i].id()).collect();
81
82            tracing::error!(
83                "{}/{} subjects have zero probability given the model",
84                indices.len(),
85                psi.nrows()
86            );
87
88            // For each problematic subject
89            for index in &indices {
90                tracing::debug!("Subject with zero probability: {}", subject[*index].id());
91
92                let error_model = self.get_settings().errormodels().clone();
93
94                // Simulate all support points in parallel
95                let spp_results: Vec<_> = self
96                    .theta()
97                    .matrix()
98                    .row_iter()
99                    .enumerate()
100                    .collect::<Vec<_>>()
101                    .into_par_iter()
102                    .map(|(i, spp)| {
103                        let support_point: Vec<f64> = spp.iter().copied().collect();
104                        let (pred, ll) = self
105                            .equation()
106                            .simulate_subject(subject[*index], &support_point, Some(&error_model))
107                            .unwrap(); //TODO: Handle error
108                        (i, support_point, pred.get_predictions(), ll)
109                    })
110                    .collect();
111
112                // Count problematic likelihoods for this subject
113                let mut nan_ll = 0;
114                let mut inf_pos_ll = 0;
115                let mut inf_neg_ll = 0;
116                let mut zero_ll = 0;
117                let mut valid_ll = 0;
118
119                for (_, _, _, ll) in &spp_results {
120                    match ll {
121                        Some(ll_val) if ll_val.is_nan() => nan_ll += 1,
122                        Some(ll_val) if ll_val.is_infinite() && ll_val.is_sign_positive() => {
123                            inf_pos_ll += 1
124                        }
125                        Some(ll_val) if ll_val.is_infinite() && ll_val.is_sign_negative() => {
126                            inf_neg_ll += 1
127                        }
128                        Some(ll_val) if *ll_val == 0.0 => zero_ll += 1,
129                        Some(_) => valid_ll += 1,
130                        None => nan_ll += 1,
131                    }
132                }
133
134                tracing::debug!(
135                    "\tLikelihood analysis for subject {} ({} support points):",
136                    subject[*index].id(),
137                    spp_results.len()
138                );
139                tracing::debug!(
140                    "\tNaN likelihoods: {} ({:.1}%)",
141                    nan_ll,
142                    100.0 * nan_ll as f64 / spp_results.len() as f64
143                );
144                tracing::debug!(
145                    "\t+Inf likelihoods: {} ({:.1}%)",
146                    inf_pos_ll,
147                    100.0 * inf_pos_ll as f64 / spp_results.len() as f64
148                );
149                tracing::debug!(
150                    "\t-Inf likelihoods: {} ({:.1}%)",
151                    inf_neg_ll,
152                    100.0 * inf_neg_ll as f64 / spp_results.len() as f64
153                );
154                tracing::debug!(
155                    "\tZero likelihoods: {} ({:.1}%)",
156                    zero_ll,
157                    100.0 * zero_ll as f64 / spp_results.len() as f64
158                );
159                tracing::debug!(
160                    "\tValid likelihoods: {} ({:.1}%)",
161                    valid_ll,
162                    100.0 * valid_ll as f64 / spp_results.len() as f64
163                );
164
165                // Sort and show top 10 most likely support points
166                let mut sorted_results = spp_results;
167                sorted_results.sort_by(|a, b| {
168                    b.3.unwrap_or(f64::NEG_INFINITY)
169                        .partial_cmp(&a.3.unwrap_or(f64::NEG_INFINITY))
170                        .unwrap_or(std::cmp::Ordering::Equal)
171                });
172                let take = 3;
173
174                tracing::debug!("Top {} most likely support points:", take);
175                for (i, support_point, preds, ll) in sorted_results.iter().take(take) {
176                    tracing::debug!("\tSupport point #{}: {:?}", i, support_point);
177                    tracing::debug!("\t\tLog-likelihood: {:?}", ll);
178
179                    let times = preds.iter().map(|x| x.time()).collect::<Vec<f64>>();
180                    let observations = preds
181                        .iter()
182                        .map(|x| x.observation())
183                        .collect::<Vec<Option<f64>>>();
184                    let predictions = preds.iter().map(|x| x.prediction()).collect::<Vec<f64>>();
185                    let outeqs = preds.iter().map(|x| x.outeq()).collect::<Vec<usize>>();
186                    let states = preds
187                        .iter()
188                        .map(|x| x.state().clone())
189                        .collect::<Vec<Vec<f64>>>();
190
191                    tracing::debug!("\t\tTimes: {:?}", times);
192                    tracing::debug!("\t\tObservations: {:?}", observations);
193                    tracing::debug!("\t\tPredictions: {:?}", predictions);
194                    tracing::debug!("\t\tOuteqs: {:?}", outeqs);
195                    tracing::debug!("\t\tStates: {:?}", states);
196                }
197                tracing::debug!("=====================");
198            }
199
200            return Err(anyhow::anyhow!(
201                    "The probability of {}/{} subjects is zero given the model. Affected subjects: {:?}",
202                    indices.len(),
203                    self.psi().matrix().nrows(),
204                    zero_probability_subjects
205                ));
206        }
207
208        Ok(())
209    }
210    fn get_settings(&self) -> &Settings;
211    fn equation(&self) -> &E;
212    fn get_data(&self) -> &Data;
213    fn get_prior(&self) -> Theta;
214    fn inc_cycle(&mut self) -> usize;
215    fn get_cycle(&self) -> usize;
216    fn set_theta(&mut self, theta: Theta);
217    fn theta(&self) -> &Theta;
218    fn psi(&self) -> &Psi;
219    fn likelihood(&self) -> f64;
220    fn n2ll(&self) -> f64 {
221        -2.0 * self.likelihood()
222    }
223    fn status(&self) -> &Status;
224    fn set_status(&mut self, status: Status);
225    fn convergence_evaluation(&mut self);
226    fn converged(&self) -> bool;
227    fn initialize(&mut self) -> Result<()> {
228        // If a stop file exists in the current directory, remove it
229        if Path::new("stop").exists() {
230            tracing::info!("Removing existing stop file prior to run");
231            fs::remove_file("stop").context("Unable to remove previous stop file")?;
232        }
233        self.set_status(Status::InProgress);
234        self.set_theta(self.get_prior());
235        Ok(())
236    }
237    fn evaluation(&mut self) -> Result<()>;
238    fn condensation(&mut self) -> Result<()>;
239    fn optimizations(&mut self) -> Result<()>;
240    fn logs(&self);
241    fn expansion(&mut self) -> Result<()>;
242    fn next_cycle(&mut self) -> Result<bool> {
243        if self.inc_cycle() > 1 {
244            self.expansion()?;
245        }
246        let span = tracing::info_span!("", "{}", format!("Cycle {}", self.get_cycle()));
247        let _enter = span.enter();
248        self.evaluation()?;
249        self.condensation()?;
250        self.optimizations()?;
251        self.logs();
252        self.convergence_evaluation();
253        Ok(self.converged())
254    }
255    fn fit(&mut self) -> Result<NPResult<E>> {
256        self.initialize().unwrap();
257        while !self.next_cycle()? {}
258        Ok(self.into_npresult())
259    }
260
261    #[allow(clippy::wrong_self_convention)]
262    fn into_npresult(&self) -> NPResult<E>;
263}
264
265pub fn dispatch_algorithm<E: Equation>(
266    settings: Settings,
267    equation: E,
268    data: Data,
269) -> Result<Box<dyn Algorithms<E>>> {
270    match settings.config().algorithm {
271        Algorithm::NPAG => Ok(NPAG::new(settings, equation, data)?),
272        Algorithm::NPOD => Ok(NPOD::new(settings, equation, data)?),
273        Algorithm::POSTPROB => Ok(POSTPROB::new(settings, equation, data)?),
274    }
275}
276
277/// Represents the status of the algorithm
278#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
279pub enum Status {
280    /// Algorithm is starting up
281    Starting,
282    /// Algorithm has converged to a solution
283    Converged,
284    /// Algorithm stopped due to reaching maximum cycles
285    MaxCycles,
286    /// Algorithm is currently running
287    InProgress,
288    /// Algorithm was manually stopped by user
289    ManualStop,
290    /// Other status with custom message
291    Other(String),
292}
293
294impl std::fmt::Display for Status {
295    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
296        match self {
297            Status::Starting => write!(f, "Starting"),
298            Status::Converged => write!(f, "Converged"),
299            Status::MaxCycles => write!(f, "Maximum cycles reached"),
300            Status::InProgress => write!(f, "In progress"),
301            Status::ManualStop => write!(f, "Manual stop requested"),
302            Status::Other(msg) => write!(f, "{}", msg),
303        }
304    }
305}