1use std::fs;
2use std::path::Path;
3
4use crate::routines::output::NPResult;
5use crate::routines::settings::Settings;
6use crate::structs::psi::Psi;
7use crate::structs::theta::Theta;
8use anyhow::Context;
9use anyhow::Result;
10use faer_ext::IntoNdarray;
11use ndarray::parallel::prelude::{IntoParallelIterator, ParallelIterator};
12use ndarray::{Array, ArrayBase, Dim, OwnedRepr};
13use npag::*;
14use npod::NPOD;
15use pharmsol::prelude::{data::Data, simulator::Equation};
16use pharmsol::{Predictions, Subject};
17use postprob::POSTPROB;
18use serde::{Deserialize, Serialize};
19
20pub mod npag;
21pub mod npod;
22pub mod postprob;
23
24#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
25pub enum Algorithm {
26 NPAG,
27 NPOD,
28 POSTPROB,
29}
30
31pub trait Algorithms<E: Equation + Send + 'static>: Sync + Send + 'static {
32 fn new(config: Settings, equation: E, data: Data) -> Result<Box<Self>>
33 where
34 Self: Sized;
35 fn validate_psi(&mut self) -> Result<()> {
36 let mut nan_count = 0;
38 let mut inf_count = 0;
39
40 let psi = self.psi().matrix().as_ref().into_ndarray();
41 for i in 0..psi.nrows() {
43 for j in 0..self.psi().matrix().ncols() {
44 let val = psi.get((i, j)).unwrap();
45 if val.is_nan() {
46 nan_count += 1;
47 } else if val.is_infinite() {
49 inf_count += 1;
50 }
52 }
53 }
54
55 if nan_count + inf_count > 0 {
56 tracing::warn!(
57 "Psi matrix contains {} NaN, {} Infinite values of {} total values",
58 nan_count,
59 inf_count,
60 psi.ncols() * psi.nrows()
61 );
62 }
63
64 let (_, col) = psi.dim();
65 let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
66 let plam = psi.dot(&ecol);
67 let w = 1. / &plam;
68
69 let indices: Vec<usize> = w
71 .iter()
72 .enumerate()
73 .filter(|(_, x)| x.is_nan() || x.is_infinite())
74 .map(|(i, _)| i)
75 .collect::<Vec<_>>();
76
77 if !indices.is_empty() {
78 let subject: Vec<&Subject> = self.data().subjects();
79 let zero_probability_subjects: Vec<&String> =
80 indices.iter().map(|&i| subject[i].id()).collect();
81
82 tracing::error!(
83 "{}/{} subjects have zero probability given the model",
84 indices.len(),
85 psi.nrows()
86 );
87
88 for index in &indices {
90 tracing::debug!("Subject with zero probability: {}", subject[*index].id());
91
92 let error_model = self.settings().errormodels().clone();
93
94 let spp_results: Vec<_> = self
96 .theta()
97 .matrix()
98 .row_iter()
99 .enumerate()
100 .collect::<Vec<_>>()
101 .into_par_iter()
102 .map(|(i, spp)| {
103 let support_point: Vec<f64> = spp.iter().copied().collect();
104 let (pred, ll) = self
105 .equation()
106 .simulate_subject(subject[*index], &support_point, Some(&error_model))
107 .unwrap(); (i, support_point, pred.get_predictions(), ll)
109 })
110 .collect();
111
112 let mut nan_ll = 0;
114 let mut inf_pos_ll = 0;
115 let mut inf_neg_ll = 0;
116 let mut zero_ll = 0;
117 let mut valid_ll = 0;
118
119 for (_, _, _, ll) in &spp_results {
120 match ll {
121 Some(ll_val) if ll_val.is_nan() => nan_ll += 1,
122 Some(ll_val) if ll_val.is_infinite() && ll_val.is_sign_positive() => {
123 inf_pos_ll += 1
124 }
125 Some(ll_val) if ll_val.is_infinite() && ll_val.is_sign_negative() => {
126 inf_neg_ll += 1
127 }
128 Some(ll_val) if *ll_val == 0.0 => zero_ll += 1,
129 Some(_) => valid_ll += 1,
130 None => nan_ll += 1,
131 }
132 }
133
134 tracing::debug!(
135 "\tLikelihood analysis for subject {} ({} support points):",
136 subject[*index].id(),
137 spp_results.len()
138 );
139 tracing::debug!(
140 "\tNaN likelihoods: {} ({:.1}%)",
141 nan_ll,
142 100.0 * nan_ll as f64 / spp_results.len() as f64
143 );
144 tracing::debug!(
145 "\t+Inf likelihoods: {} ({:.1}%)",
146 inf_pos_ll,
147 100.0 * inf_pos_ll as f64 / spp_results.len() as f64
148 );
149 tracing::debug!(
150 "\t-Inf likelihoods: {} ({:.1}%)",
151 inf_neg_ll,
152 100.0 * inf_neg_ll as f64 / spp_results.len() as f64
153 );
154 tracing::debug!(
155 "\tZero likelihoods: {} ({:.1}%)",
156 zero_ll,
157 100.0 * zero_ll as f64 / spp_results.len() as f64
158 );
159 tracing::debug!(
160 "\tValid likelihoods: {} ({:.1}%)",
161 valid_ll,
162 100.0 * valid_ll as f64 / spp_results.len() as f64
163 );
164
165 let mut sorted_results = spp_results;
167 sorted_results.sort_by(|a, b| {
168 b.3.unwrap_or(f64::NEG_INFINITY)
169 .partial_cmp(&a.3.unwrap_or(f64::NEG_INFINITY))
170 .unwrap_or(std::cmp::Ordering::Equal)
171 });
172 let take = 3;
173
174 tracing::debug!("Top {} most likely support points:", take);
175 for (i, support_point, preds, ll) in sorted_results.iter().take(take) {
176 tracing::debug!("\tSupport point #{}: {:?}", i, support_point);
177 tracing::debug!("\t\tLog-likelihood: {:?}", ll);
178
179 let times = preds.iter().map(|x| x.time()).collect::<Vec<f64>>();
180 let observations = preds
181 .iter()
182 .map(|x| x.observation())
183 .collect::<Vec<Option<f64>>>();
184 let predictions = preds.iter().map(|x| x.prediction()).collect::<Vec<f64>>();
185 let outeqs = preds.iter().map(|x| x.outeq()).collect::<Vec<usize>>();
186 let states = preds
187 .iter()
188 .map(|x| x.state().clone())
189 .collect::<Vec<Vec<f64>>>();
190
191 tracing::debug!("\t\tTimes: {:?}", times);
192 tracing::debug!("\t\tObservations: {:?}", observations);
193 tracing::debug!("\t\tPredictions: {:?}", predictions);
194 tracing::debug!("\t\tOuteqs: {:?}", outeqs);
195 tracing::debug!("\t\tStates: {:?}", states);
196 }
197 tracing::debug!("=====================");
198 }
199
200 return Err(anyhow::anyhow!(
201 "The probability of {}/{} subjects is zero given the model. Affected subjects: {:?}",
202 indices.len(),
203 self.psi().matrix().nrows(),
204 zero_probability_subjects
205 ));
206 }
207
208 Ok(())
209 }
210
211 fn settings(&self) -> &Settings;
212 fn equation(&self) -> &E;
214 fn data(&self) -> &Data;
216 fn get_prior(&self) -> Theta;
217 fn increment_cycle(&mut self) -> usize;
219 fn cycle(&self) -> usize;
221 fn set_theta(&mut self, theta: Theta);
223 fn theta(&self) -> Θ
225 fn psi(&self) -> Ψ
227 fn likelihood(&self) -> f64;
229 fn n2ll(&self) -> f64 {
231 -2.0 * self.likelihood()
232 }
233 fn status(&self) -> &Status;
235 fn set_status(&mut self, status: Status);
237 fn evaluation(&mut self) -> Result<Status>;
239
240 fn log_cycle_state(&mut self);
242
243 fn initialize(&mut self) -> Result<()> {
245 if Path::new("stop").exists() {
247 tracing::info!("Removing existing stop file prior to run");
248 fs::remove_file("stop").context("Unable to remove previous stop file")?;
249 }
250 self.set_status(Status::Continue);
251 self.set_theta(self.get_prior());
252 Ok(())
253 }
254 fn estimation(&mut self) -> Result<()>;
255 fn condensation(&mut self) -> Result<()>;
261
262 fn optimizations(&mut self) -> Result<()>;
267
268 fn expansion(&mut self) -> Result<()>;
273
274 fn next_cycle(&mut self) -> Result<Status> {
280 let cycle = self.increment_cycle();
281
282 if cycle > 1 {
283 self.expansion()?;
284 }
285
286 let span = tracing::info_span!("", "{}", format!("Cycle {}", self.cycle()));
287 let _enter = span.enter();
288 self.estimation()?;
289 self.condensation()?;
290 self.optimizations()?;
291 self.evaluation()
292 }
293
294 fn fit(&mut self) -> Result<NPResult<E>> {
300 self.initialize().unwrap();
301 loop {
302 match self.next_cycle()? {
303 Status::Continue => continue,
304 Status::Stop(_) => break,
305 }
306 }
307 Ok(self.into_npresult())
308 }
309
310 #[allow(clippy::wrong_self_convention)]
311 fn into_npresult(&self) -> NPResult<E>;
312}
313
314pub fn dispatch_algorithm<E: Equation + Send + 'static>(
315 settings: Settings,
316 equation: E,
317 data: Data,
318) -> Result<Box<dyn Algorithms<E>>> {
319 match settings.config().algorithm {
320 Algorithm::NPAG => Ok(NPAG::new(settings, equation, data)?),
321 Algorithm::NPOD => Ok(NPOD::new(settings, equation, data)?),
322 Algorithm::POSTPROB => Ok(POSTPROB::new(settings, equation, data)?),
323 }
324}
325
326#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
328pub enum Status {
329 Continue,
330 Stop(StopReason),
331}
332
333impl std::fmt::Display for Status {
334 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
335 match self {
336 Status::Continue => write!(f, "Continue"),
337 Status::Stop(s) => write!(f, "Stop: {:?}", s),
338 }
339 }
340}
341
342#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
343
344pub enum StopReason {
345 Converged,
346 MaxCycles,
347 Stopped,
348 Completed,
349}