1use std::fs;
2use std::path::Path;
3
4use crate::routines::output::NPResult;
5use crate::routines::settings::Settings;
6use crate::structs::psi::Psi;
7use crate::structs::theta::Theta;
8use anyhow::Context;
9use anyhow::Result;
10use faer_ext::IntoNdarray;
11use ndarray::parallel::prelude::{IntoParallelIterator, ParallelIterator};
12use ndarray::{Array, ArrayBase, Dim, OwnedRepr};
13use npag::*;
14use npod::NPOD;
15use pharmsol::prelude::{data::Data, simulator::Equation};
16use pharmsol::{Predictions, Subject};
17use postprob::POSTPROB;
18use serde::{Deserialize, Serialize};
19
20pub mod npag;
21pub mod npod;
22pub mod postprob;
23
/// The supported estimation algorithms.
///
/// A variant is selected via [`Settings`] and turned into a concrete
/// implementation by [`dispatch_algorithm`]; each variant corresponds to one
/// of the submodules declared below (`npag`, `npod`, `postprob`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum Algorithm {
    // Implemented by `npag::NPAG`
    NPAG,
    // Implemented by `npod::NPOD`
    NPOD,
    // Implemented by `postprob::POSTPROB`
    POSTPROB,
}
30
31pub trait Algorithms<E: Equation>: Sync {
32 fn new(config: Settings, equation: E, data: Data) -> Result<Box<Self>>
33 where
34 Self: Sized;
35 fn validate_psi(&mut self) -> Result<()> {
36 let mut nan_count = 0;
38 let mut inf_count = 0;
39
40 let psi = self.psi().matrix().as_ref().into_ndarray();
41 for i in 0..psi.nrows() {
43 for j in 0..self.psi().matrix().ncols() {
44 let val = psi.get((i, j)).unwrap();
45 if val.is_nan() {
46 nan_count += 1;
47 } else if val.is_infinite() {
49 inf_count += 1;
50 }
52 }
53 }
54
55 if nan_count + inf_count > 0 {
56 tracing::warn!(
57 "Psi matrix contains {} NaN, {} Infinite values of {} total values",
58 nan_count,
59 inf_count,
60 psi.ncols() * psi.nrows()
61 );
62 }
63
64 let (_, col) = psi.dim();
65 let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
66 let plam = psi.dot(&ecol);
67 let w = 1. / &plam;
68
69 let indices: Vec<usize> = w
71 .iter()
72 .enumerate()
73 .filter(|(_, x)| x.is_nan() || x.is_infinite())
74 .map(|(i, _)| i)
75 .collect::<Vec<_>>();
76
77 if !indices.is_empty() {
78 let subject: Vec<&Subject> = self.get_data().get_subjects();
79 let zero_probability_subjects: Vec<&String> =
80 indices.iter().map(|&i| subject[i].id()).collect();
81
82 tracing::error!(
83 "{}/{} subjects have zero probability given the model",
84 indices.len(),
85 psi.nrows()
86 );
87
88 for index in &indices {
90 tracing::debug!("Subject with zero probability: {}", subject[*index].id());
91
92 let error_model = self.get_settings().error.clone().into();
93
94 let spp_results: Vec<_> = self
96 .get_theta()
97 .matrix()
98 .row_iter()
99 .enumerate()
100 .collect::<Vec<_>>()
101 .into_par_iter()
102 .map(|(i, spp)| {
103 let support_point: Vec<f64> = spp.iter().copied().collect();
104 let (pred, ll) = self
105 .equation()
106 .simulate_subject(subject[*index], &support_point, Some(&error_model))
107 .unwrap(); (i, support_point, pred.get_predictions(), ll)
109 })
110 .collect();
111
112 let mut nan_ll = 0;
114 let mut inf_pos_ll = 0;
115 let mut inf_neg_ll = 0;
116 let mut zero_ll = 0;
117 let mut valid_ll = 0;
118
119 for (_, _, _, ll) in &spp_results {
120 match ll {
121 Some(ll_val) if ll_val.is_nan() => nan_ll += 1,
122 Some(ll_val) if ll_val.is_infinite() && ll_val.is_sign_positive() => {
123 inf_pos_ll += 1
124 }
125 Some(ll_val) if ll_val.is_infinite() && ll_val.is_sign_negative() => {
126 inf_neg_ll += 1
127 }
128 Some(ll_val) if *ll_val == 0.0 => zero_ll += 1,
129 Some(_) => valid_ll += 1,
130 None => nan_ll += 1,
131 }
132 }
133
134 tracing::debug!(
135 "\tLikelihood analysis for subject {} ({} support points):",
136 subject[*index].id(),
137 spp_results.len()
138 );
139 tracing::debug!(
140 "\tNaN likelihoods: {} ({:.1}%)",
141 nan_ll,
142 100.0 * nan_ll as f64 / spp_results.len() as f64
143 );
144 tracing::debug!(
145 "\t+Inf likelihoods: {} ({:.1}%)",
146 inf_pos_ll,
147 100.0 * inf_pos_ll as f64 / spp_results.len() as f64
148 );
149 tracing::debug!(
150 "\t-Inf likelihoods: {} ({:.1}%)",
151 inf_neg_ll,
152 100.0 * inf_neg_ll as f64 / spp_results.len() as f64
153 );
154 tracing::debug!(
155 "\tZero likelihoods: {} ({:.1}%)",
156 zero_ll,
157 100.0 * zero_ll as f64 / spp_results.len() as f64
158 );
159 tracing::debug!(
160 "\tValid likelihoods: {} ({:.1}%)",
161 valid_ll,
162 100.0 * valid_ll as f64 / spp_results.len() as f64
163 );
164
165 let mut sorted_results = spp_results;
167 sorted_results.sort_by(|a, b| {
168 b.3.unwrap_or(f64::NEG_INFINITY)
169 .partial_cmp(&a.3.unwrap_or(f64::NEG_INFINITY))
170 .unwrap_or(std::cmp::Ordering::Equal)
171 });
172 let take = 3;
173
174 tracing::debug!("Top {} most likely support points:", take);
175 for (i, support_point, preds, ll) in sorted_results.iter().take(take) {
176 tracing::debug!("\tSupport point #{}: {:?}", i, support_point);
177 tracing::debug!("\t\tLog-likelihood: {:?}", ll);
178
179 let times = preds.iter().map(|x| x.time()).collect::<Vec<f64>>();
180 let observations = preds.iter().map(|x| x.observation()).collect::<Vec<f64>>();
181 let predictions = preds.iter().map(|x| x.prediction()).collect::<Vec<f64>>();
182 let outeqs = preds.iter().map(|x| x.outeq()).collect::<Vec<usize>>();
183 let states = preds
184 .iter()
185 .map(|x| x.state().clone())
186 .collect::<Vec<Vec<f64>>>();
187
188 tracing::debug!("\t\tTimes: {:?}", times);
189 tracing::debug!("\t\tObservations: {:?}", observations);
190 tracing::debug!("\t\tPredictions: {:?}", predictions);
191 tracing::debug!("\t\tOuteqs: {:?}", outeqs);
192 tracing::debug!("\t\tStates: {:?}", states);
193 }
194 tracing::debug!("=====================");
195 }
196
197 return Err(anyhow::anyhow!(
198 "The probability of {}/{} subjects is zero given the model. Affected subjects: {:?}",
199 indices.len(),
200 self.psi().matrix().nrows(),
201 zero_probability_subjects
202 ));
203 }
204
205 Ok(())
206 }
207 fn get_settings(&self) -> &Settings;
208 fn equation(&self) -> &E;
209 fn get_data(&self) -> &Data;
210 fn get_prior(&self) -> Theta;
211 fn inc_cycle(&mut self) -> usize;
212 fn get_cycle(&self) -> usize;
213 fn set_theta(&mut self, theta: Theta);
214 fn get_theta(&self) -> Θ
215 fn psi(&self) -> Ψ
216 fn likelihood(&self) -> f64;
217 fn n2ll(&self) -> f64 {
218 -2.0 * self.likelihood()
219 }
220 fn convergence_evaluation(&mut self);
221 fn converged(&self) -> bool;
222 fn initialize(&mut self) -> Result<()> {
223 if Path::new("stop").exists() {
225 tracing::info!("Removing existing stop file prior to run");
226 fs::remove_file("stop").context("Unable to remove previous stop file")?;
227 }
228 self.set_theta(self.get_prior());
229 Ok(())
230 }
231 fn evaluation(&mut self) -> Result<()>;
232 fn condensation(&mut self) -> Result<()>;
233 fn optimizations(&mut self) -> Result<()>;
234 fn logs(&self);
235 fn expansion(&mut self) -> Result<()>;
236 fn next_cycle(&mut self) -> Result<bool> {
237 if self.inc_cycle() > 1 {
238 self.expansion()?;
239 }
240 let span = tracing::info_span!("", "{}", format!("Cycle {}", self.get_cycle()));
241 let _enter = span.enter();
242 self.evaluation()?;
243 self.condensation()?;
244 self.optimizations()?;
245 self.logs();
246 self.convergence_evaluation();
247 Ok(self.converged())
248 }
249 fn fit(&mut self) -> Result<NPResult<E>> {
250 self.initialize().unwrap();
251 while !self.next_cycle()? {}
252 Ok(self.into_npresult())
253 }
254
255 #[allow(clippy::wrong_self_convention)]
256 fn into_npresult(&self) -> NPResult<E>;
257}
258
259pub fn dispatch_algorithm<E: Equation>(
260 settings: Settings,
261 equation: E,
262 data: Data,
263) -> Result<Box<dyn Algorithms<E>>> {
264 match settings.config().algorithm {
265 Algorithm::NPAG => Ok(NPAG::new(settings, equation, data)?),
266 Algorithm::NPOD => Ok(NPOD::new(settings, equation, data)?),
267 Algorithm::POSTPROB => Ok(POSTPROB::new(settings, equation, data)?),
268 }
269}