1use std::fs;
2use std::path::Path;
3
4use crate::routines::output::NPResult;
5use crate::routines::settings::Settings;
6use crate::structs::psi::Psi;
7use crate::structs::theta::Theta;
8use anyhow::Context;
9use anyhow::Result;
10use faer_ext::IntoNdarray;
11use ndarray::parallel::prelude::{IntoParallelIterator, ParallelIterator};
12use ndarray::{Array, ArrayBase, Dim, OwnedRepr};
13use npag::*;
14use npod::NPOD;
15use pharmsol::prelude::{data::Data, simulator::Equation};
16use pharmsol::{Predictions, Subject};
17use postprob::POSTPROB;
18use serde::{Deserialize, Serialize};
19
20pub mod npag;
21pub mod npod;
22pub mod postprob;
23
24#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
25pub enum Algorithm {
26 NPAG,
27 NPOD,
28 POSTPROB,
29}
30
31pub trait Algorithms<E: Equation>: Sync {
32 fn new(config: Settings, equation: E, data: Data) -> Result<Box<Self>>
33 where
34 Self: Sized;
35 fn validate_psi(&mut self) -> Result<()> {
36 let mut nan_count = 0;
38 let mut inf_count = 0;
39
40 let psi = self.psi().matrix().as_ref().into_ndarray();
41 for i in 0..psi.nrows() {
43 for j in 0..self.psi().matrix().ncols() {
44 let val = psi.get((i, j)).unwrap();
45 if val.is_nan() {
46 nan_count += 1;
47 } else if val.is_infinite() {
49 inf_count += 1;
50 }
52 }
53 }
54
55 if nan_count + inf_count > 0 {
56 tracing::warn!(
57 "Psi matrix contains {} NaN, {} Infinite values of {} total values",
58 nan_count,
59 inf_count,
60 psi.ncols() * psi.nrows()
61 );
62 }
63
64 let (_, col) = psi.dim();
65 let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
66 let plam = psi.dot(&ecol);
67 let w = 1. / &plam;
68
69 let indices: Vec<usize> = w
71 .iter()
72 .enumerate()
73 .filter(|(_, x)| x.is_nan() || x.is_infinite())
74 .map(|(i, _)| i)
75 .collect::<Vec<_>>();
76
77 if !indices.is_empty() {
78 let subject: Vec<&Subject> = self.get_data().subjects();
79 let zero_probability_subjects: Vec<&String> =
80 indices.iter().map(|&i| subject[i].id()).collect();
81
82 tracing::error!(
83 "{}/{} subjects have zero probability given the model",
84 indices.len(),
85 psi.nrows()
86 );
87
88 for index in &indices {
90 tracing::debug!("Subject with zero probability: {}", subject[*index].id());
91
92 let error_model = self.get_settings().errormodels().clone();
93
94 let spp_results: Vec<_> = self
96 .theta()
97 .matrix()
98 .row_iter()
99 .enumerate()
100 .collect::<Vec<_>>()
101 .into_par_iter()
102 .map(|(i, spp)| {
103 let support_point: Vec<f64> = spp.iter().copied().collect();
104 let (pred, ll) = self
105 .equation()
106 .simulate_subject(subject[*index], &support_point, Some(&error_model))
107 .unwrap(); (i, support_point, pred.get_predictions(), ll)
109 })
110 .collect();
111
112 let mut nan_ll = 0;
114 let mut inf_pos_ll = 0;
115 let mut inf_neg_ll = 0;
116 let mut zero_ll = 0;
117 let mut valid_ll = 0;
118
119 for (_, _, _, ll) in &spp_results {
120 match ll {
121 Some(ll_val) if ll_val.is_nan() => nan_ll += 1,
122 Some(ll_val) if ll_val.is_infinite() && ll_val.is_sign_positive() => {
123 inf_pos_ll += 1
124 }
125 Some(ll_val) if ll_val.is_infinite() && ll_val.is_sign_negative() => {
126 inf_neg_ll += 1
127 }
128 Some(ll_val) if *ll_val == 0.0 => zero_ll += 1,
129 Some(_) => valid_ll += 1,
130 None => nan_ll += 1,
131 }
132 }
133
134 tracing::debug!(
135 "\tLikelihood analysis for subject {} ({} support points):",
136 subject[*index].id(),
137 spp_results.len()
138 );
139 tracing::debug!(
140 "\tNaN likelihoods: {} ({:.1}%)",
141 nan_ll,
142 100.0 * nan_ll as f64 / spp_results.len() as f64
143 );
144 tracing::debug!(
145 "\t+Inf likelihoods: {} ({:.1}%)",
146 inf_pos_ll,
147 100.0 * inf_pos_ll as f64 / spp_results.len() as f64
148 );
149 tracing::debug!(
150 "\t-Inf likelihoods: {} ({:.1}%)",
151 inf_neg_ll,
152 100.0 * inf_neg_ll as f64 / spp_results.len() as f64
153 );
154 tracing::debug!(
155 "\tZero likelihoods: {} ({:.1}%)",
156 zero_ll,
157 100.0 * zero_ll as f64 / spp_results.len() as f64
158 );
159 tracing::debug!(
160 "\tValid likelihoods: {} ({:.1}%)",
161 valid_ll,
162 100.0 * valid_ll as f64 / spp_results.len() as f64
163 );
164
165 let mut sorted_results = spp_results;
167 sorted_results.sort_by(|a, b| {
168 b.3.unwrap_or(f64::NEG_INFINITY)
169 .partial_cmp(&a.3.unwrap_or(f64::NEG_INFINITY))
170 .unwrap_or(std::cmp::Ordering::Equal)
171 });
172 let take = 3;
173
174 tracing::debug!("Top {} most likely support points:", take);
175 for (i, support_point, preds, ll) in sorted_results.iter().take(take) {
176 tracing::debug!("\tSupport point #{}: {:?}", i, support_point);
177 tracing::debug!("\t\tLog-likelihood: {:?}", ll);
178
179 let times = preds.iter().map(|x| x.time()).collect::<Vec<f64>>();
180 let observations = preds
181 .iter()
182 .map(|x| x.observation())
183 .collect::<Vec<Option<f64>>>();
184 let predictions = preds.iter().map(|x| x.prediction()).collect::<Vec<f64>>();
185 let outeqs = preds.iter().map(|x| x.outeq()).collect::<Vec<usize>>();
186 let states = preds
187 .iter()
188 .map(|x| x.state().clone())
189 .collect::<Vec<Vec<f64>>>();
190
191 tracing::debug!("\t\tTimes: {:?}", times);
192 tracing::debug!("\t\tObservations: {:?}", observations);
193 tracing::debug!("\t\tPredictions: {:?}", predictions);
194 tracing::debug!("\t\tOuteqs: {:?}", outeqs);
195 tracing::debug!("\t\tStates: {:?}", states);
196 }
197 tracing::debug!("=====================");
198 }
199
200 return Err(anyhow::anyhow!(
201 "The probability of {}/{} subjects is zero given the model. Affected subjects: {:?}",
202 indices.len(),
203 self.psi().matrix().nrows(),
204 zero_probability_subjects
205 ));
206 }
207
208 Ok(())
209 }
210 fn get_settings(&self) -> &Settings;
211 fn equation(&self) -> &E;
212 fn get_data(&self) -> &Data;
213 fn get_prior(&self) -> Theta;
214 fn inc_cycle(&mut self) -> usize;
215 fn get_cycle(&self) -> usize;
216 fn set_theta(&mut self, theta: Theta);
217 fn theta(&self) -> Θ
218 fn psi(&self) -> Ψ
219 fn likelihood(&self) -> f64;
220 fn n2ll(&self) -> f64 {
221 -2.0 * self.likelihood()
222 }
223 fn status(&self) -> &Status;
224 fn set_status(&mut self, status: Status);
225 fn convergence_evaluation(&mut self);
226 fn converged(&self) -> bool;
227 fn initialize(&mut self) -> Result<()> {
228 if Path::new("stop").exists() {
230 tracing::info!("Removing existing stop file prior to run");
231 fs::remove_file("stop").context("Unable to remove previous stop file")?;
232 }
233 self.set_status(Status::InProgress);
234 self.set_theta(self.get_prior());
235 Ok(())
236 }
237 fn evaluation(&mut self) -> Result<()>;
238 fn condensation(&mut self) -> Result<()>;
239 fn optimizations(&mut self) -> Result<()>;
240 fn logs(&self);
241 fn expansion(&mut self) -> Result<()>;
242 fn next_cycle(&mut self) -> Result<bool> {
243 if self.inc_cycle() > 1 {
244 self.expansion()?;
245 }
246 let span = tracing::info_span!("", "{}", format!("Cycle {}", self.get_cycle()));
247 let _enter = span.enter();
248 self.evaluation()?;
249 self.condensation()?;
250 self.optimizations()?;
251 self.logs();
252 self.convergence_evaluation();
253 Ok(self.converged())
254 }
255 fn fit(&mut self) -> Result<NPResult<E>> {
256 self.initialize().unwrap();
257 while !self.next_cycle()? {}
258 Ok(self.into_npresult())
259 }
260
261 #[allow(clippy::wrong_self_convention)]
262 fn into_npresult(&self) -> NPResult<E>;
263}
264
265pub fn dispatch_algorithm<E: Equation>(
266 settings: Settings,
267 equation: E,
268 data: Data,
269) -> Result<Box<dyn Algorithms<E>>> {
270 match settings.config().algorithm {
271 Algorithm::NPAG => Ok(NPAG::new(settings, equation, data)?),
272 Algorithm::NPOD => Ok(NPOD::new(settings, equation, data)?),
273 Algorithm::POSTPROB => Ok(POSTPROB::new(settings, equation, data)?),
274 }
275}
276
277#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
279pub enum Status {
280 Starting,
282 Converged,
284 MaxCycles,
286 InProgress,
288 ManualStop,
290 Other(String),
292}
293
294impl std::fmt::Display for Status {
295 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
296 match self {
297 Status::Starting => write!(f, "Starting"),
298 Status::Converged => write!(f, "Converged"),
299 Status::MaxCycles => write!(f, "Maximum cycles reached"),
300 Status::InProgress => write!(f, "In progress"),
301 Status::ManualStop => write!(f, "Manual stop requested"),
302 Status::Other(msg) => write!(f, "{}", msg),
303 }
304 }
305}