1use crate::{
2 algorithms::Status,
3 prelude::{
4 algorithms::Algorithms,
5 routines::{
6 evaluation::{ipm::burke, qr},
7 output::{CycleLog, NPCycle, NPResult},
8 settings::Settings,
9 },
10 },
11 structs::{
12 psi::{calculate_psi, Psi},
13 theta::Theta,
14 },
15};
16use anyhow::bail;
17use anyhow::Result;
18use faer::Col;
19use faer_ext::IntoNdarray;
20use pharmsol::{prelude::ErrorModel, ErrorModels};
21use pharmsol::{
22 prelude::{data::Data, simulator::Equation},
23 Subject,
24};
25
26use ndarray::{
27 parallel::prelude::{IntoParallelRefMutIterator, ParallelIterator},
28 Array, Array1, ArrayBase, Dim, OwnedRepr,
29};
30
31use crate::routines::{initialization, optimization::SppOptimizer};
32
/// Convergence tolerance on the objective function: a cycle is flagged as
/// converged when the absolute change in `objf` since the previous cycle is
/// at most this value.
const THETA_F: f64 = 1e-2;
/// Threshold passed to `Theta::suggest_point` when optimized candidate points
/// are offered to the support-point set (presumably a minimum-distance
/// criterion; confirm against `Theta`).
const THETA_D: f64 = 1e-4;
35
36pub struct NPOD<E: Equation> {
37 equation: E,
38 psi: Psi,
39 theta: Theta,
40 lambda: Col<f64>,
41 w: Col<f64>,
42 last_objf: f64,
43 objf: f64,
44 cycle: usize,
45 gamma_delta: Vec<f64>,
46 error_models: ErrorModels,
47 converged: bool,
48 status: Status,
49 cycle_log: CycleLog,
50 data: Data,
51 settings: Settings,
52}
53
54impl<E: Equation> Algorithms<E> for NPOD<E> {
55 fn new(settings: Settings, equation: E, data: Data) -> Result<Box<Self>, anyhow::Error> {
56 Ok(Box::new(Self {
57 equation,
58 psi: Psi::new(),
59 theta: Theta::new(),
60 lambda: Col::zeros(0),
61 w: Col::zeros(0),
62 last_objf: -1e30,
63 objf: f64::NEG_INFINITY,
64 cycle: 0,
65 gamma_delta: vec![0.1; settings.errormodels().len()],
66 error_models: settings.errormodels().clone().into(),
67 converged: false,
68 status: Status::Starting,
69 cycle_log: CycleLog::new(),
70 settings,
71 data,
72 }))
73 }
74 fn into_npresult(&self) -> NPResult<E> {
75 NPResult::new(
76 self.equation.clone(),
77 self.data.clone(),
78 self.theta.clone(),
79 self.psi.clone(),
80 self.w.clone(),
81 -2. * self.objf,
82 self.cycle,
83 self.status.clone(),
84 self.settings.clone(),
85 self.cycle_log.clone(),
86 )
87 }
88
89 fn equation(&self) -> &E {
90 &self.equation
91 }
92
93 fn get_settings(&self) -> &Settings {
94 &self.settings
95 }
96
97 fn get_data(&self) -> &Data {
98 &self.data
99 }
100
101 fn get_prior(&self) -> Theta {
102 initialization::sample_space(&self.settings).unwrap()
103 }
104
105 fn inc_cycle(&mut self) -> usize {
106 self.cycle += 1;
107 self.cycle
108 }
109
110 fn get_cycle(&self) -> usize {
111 self.cycle
112 }
113
114 fn set_theta(&mut self, theta: Theta) {
115 self.theta = theta;
116 }
117
118 fn theta(&self) -> &Theta {
119 &self.theta
120 }
121
122 fn psi(&self) -> &Psi {
123 &self.psi
124 }
125
126 fn likelihood(&self) -> f64 {
127 self.objf
128 }
129
130 fn set_status(&mut self, status: Status) {
131 self.status = status;
132 }
133
134 fn status(&self) -> &Status {
135 &self.status
136 }
137
138 fn convergence_evaluation(&mut self) {
139 if (self.last_objf - self.objf).abs() <= THETA_F {
140 tracing::info!("Objective function convergence reached");
141 self.converged = true;
142 self.status = Status::Converged;
143 }
144
145 if self.cycle >= self.settings.config().cycles {
147 tracing::warn!("Maximum number of cycles reached");
148 self.converged = true;
149 self.status = Status::MaxCycles;
150 }
151
152 if std::path::Path::new("stop").exists() {
154 tracing::warn!("Stopfile detected - breaking");
155 self.converged = true;
156 self.status = Status::ManualStop;
157 }
158
159 let state = NPCycle {
161 cycle: self.cycle,
162 objf: -2. * self.objf,
163 delta_objf: (self.last_objf - self.objf).abs(),
164 nspp: self.theta.nspp(),
165 theta: self.theta.clone(),
166 error_models: self.error_models.clone(),
167 status: self.status.clone(),
168 };
169
170 self.cycle_log.push(state);
172 self.last_objf = self.objf;
173 }
174
175 fn converged(&self) -> bool {
176 self.converged
177 }
178
179 fn evaluation(&mut self) -> Result<()> {
180 let error_model: ErrorModels = self.error_models.clone();
181
182 self.psi = calculate_psi(
183 &self.equation,
184 &self.data,
185 &self.theta,
186 &error_model,
187 self.cycle == 1 && self.settings.config().progress,
188 self.cycle != 1,
189 )?;
190
191 if let Err(err) = self.validate_psi() {
192 bail!(err);
193 }
194
195 (self.lambda, _) = match burke(&self.psi) {
196 Ok((lambda, objf)) => (lambda, objf),
197 Err(err) => {
198 bail!(err);
199 }
200 };
201 Ok(())
202 }
203
204 fn condensation(&mut self) -> Result<()> {
205 let max_lambda = self
206 .lambda
207 .iter()
208 .fold(f64::NEG_INFINITY, |acc, &x| x.max(acc));
209
210 let mut keep = Vec::<usize>::new();
211 for (index, lam) in self.lambda.iter().enumerate() {
212 if *lam > max_lambda / 1000_f64 {
213 keep.push(index);
214 }
215 }
216 if self.psi.matrix().ncols() != keep.len() {
217 tracing::debug!(
218 "Lambda (max/1000) dropped {} support point(s)",
219 self.psi.matrix().ncols() - keep.len(),
220 );
221 }
222
223 self.theta.filter_indices(keep.as_slice());
224 self.psi.filter_column_indices(keep.as_slice());
225
226 let (r, perm) = qr::qrd(&self.psi)?;
228
229 let mut keep = Vec::<usize>::new();
230
231 let keep_n = self.psi.matrix().ncols().min(self.psi.matrix().nrows());
233 for i in 0..keep_n {
234 let test = r.col(i).norm_l2();
235 let r_diag_val = r.get(i, i);
236 let ratio = r_diag_val / test;
237 if ratio.abs() >= 1e-8 {
238 keep.push(*perm.get(i).unwrap());
239 }
240 }
241
242 if self.psi.matrix().ncols() != keep.len() {
244 tracing::debug!(
245 "QR decomposition dropped {} support point(s)",
246 self.psi.matrix().ncols() - keep.len(),
247 );
248 }
249
250 self.theta.filter_indices(keep.as_slice());
251 self.psi.filter_column_indices(keep.as_slice());
252
253 (self.lambda, self.objf) = match burke(&self.psi) {
254 Ok((lambda, objf)) => (lambda, objf),
255 Err(err) => {
256 return Err(anyhow::anyhow!("Error in IPM: {:?}", err));
257 }
258 };
259 self.w = self.lambda.clone();
260 Ok(())
261 }
262
263 fn optimizations(&mut self) -> Result<()> {
264 self.error_models
265 .clone()
266 .iter_mut()
267 .filter_map(|(outeq, em)| match em {
268 ErrorModel::None => None,
269 _ => Some((outeq, em)),
270 })
271 .try_for_each(|(outeq, em)| -> Result<()> {
272 let gamma_up = em.scalar()? * (1.0 + self.gamma_delta[outeq]);
275 let gamma_down = em.scalar()? / (1.0 + self.gamma_delta[outeq]);
276
277 let mut error_model_up = self.error_models.clone();
278 error_model_up.set_scalar(outeq, gamma_up)?;
279
280 let mut error_model_down = self.error_models.clone();
281 error_model_down.set_scalar(outeq, gamma_down)?;
282
283 let psi_up = calculate_psi(
284 &self.equation,
285 &self.data,
286 &self.theta,
287 &error_model_up,
288 false,
289 true,
290 )?;
291 let psi_down = calculate_psi(
292 &self.equation,
293 &self.data,
294 &self.theta,
295 &error_model_down,
296 false,
297 true,
298 )?;
299
300 let (lambda_up, objf_up) = match burke(&psi_up) {
301 Ok((lambda, objf)) => (lambda, objf),
302 Err(err) => {
303 return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
305 }
306 };
307 let (lambda_down, objf_down) = match burke(&psi_down) {
308 Ok((lambda, objf)) => (lambda, objf),
309 Err(err) => {
310 return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
313 }
315 };
316 if objf_up > self.objf {
317 self.error_models.set_scalar(outeq, gamma_up)?;
318 self.objf = objf_up;
319 self.gamma_delta[outeq] *= 4.;
320 self.lambda = lambda_up;
321 self.psi = psi_up;
322 }
323 if objf_down > self.objf {
324 self.error_models.set_scalar(outeq, gamma_down)?;
325 self.objf = objf_down;
326 self.gamma_delta[outeq] *= 4.;
327 self.lambda = lambda_down;
328 self.psi = psi_down;
329 }
330 self.gamma_delta[outeq] *= 0.5;
331 if self.gamma_delta[outeq] <= 0.01 {
332 self.gamma_delta[outeq] = 0.1;
333 }
334 Ok(())
335 })?;
336
337 Ok(())
338 }
339
340 fn logs(&self) {
341 tracing::info!("Objective function = {:.4}", -2.0 * self.objf);
342 tracing::debug!("Support points: {}", self.theta.nspp());
343 self.error_models.iter().for_each(|(outeq, em)| {
344 if ErrorModel::None == *em {
345 return;
346 }
347 tracing::debug!(
348 "Error model for outeq {}: {:.16}",
349 outeq,
350 em.scalar().unwrap_or_default()
351 );
352 });
353 if self.last_objf > self.objf + 1e-4 {
355 tracing::warn!(
356 "Objective function decreased from {:.4} to {:.4} (delta = {})",
357 -2.0 * self.last_objf,
358 -2.0 * self.objf,
359 -2.0 * self.last_objf - -2.0 * self.objf
360 );
361 }
362 }
363
364 fn expansion(&mut self) -> Result<()> {
365 let psi = self.psi().matrix().as_ref().into_ndarray().to_owned();
367 let w: Array1<f64> = self.w.clone().iter().cloned().collect();
368 let pyl = psi.dot(&w);
369
370 let error_model: ErrorModels = self.error_models.clone();
372
373 let mut candididate_points: Vec<Array1<f64>> = Vec::default();
374 for spp in self.theta.matrix().row_iter() {
375 let candidate: Vec<f64> = spp.iter().cloned().collect();
376 let spp = Array1::from(candidate);
377 candididate_points.push(spp.to_owned());
378 }
379 candididate_points.par_iter_mut().for_each(|spp| {
380 let optimizer = SppOptimizer::new(&self.equation, &self.data, &error_model, &pyl);
381 let candidate_point = optimizer.optimize_point(spp.to_owned()).unwrap();
382 *spp = candidate_point;
383 });
389 for cp in candididate_points {
390 self.theta.suggest_point(cp.to_vec().as_slice(), THETA_D);
391 }
392 Ok(())
393 }
394}
395
396impl<E: Equation> NPOD<E> {
397 fn validate_psi(&mut self) -> Result<()> {
398 let mut psi = self.psi().matrix().as_ref().into_ndarray().to_owned();
399 if psi.iter().any(|x| x.is_nan() || x.is_infinite()) {
401 tracing::warn!("Psi contains NaN or Inf values, coercing to 0.0");
402 for i in 0..psi.nrows() {
403 for j in 0..psi.ncols() {
404 let val = psi.get_mut((i, j)).unwrap();
405 if val.is_nan() || val.is_infinite() {
406 *val = 0.0;
407 }
408 }
409 }
410 }
411
412 let (_, col) = psi.dim();
414 let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
415 let plam = psi.dot(&ecol);
416 let w = 1. / &plam;
417
418 let indices: Vec<usize> = w
420 .iter()
421 .enumerate()
422 .filter(|(_, x)| x.is_nan() || x.is_infinite())
423 .map(|(i, _)| i)
424 .collect::<Vec<_>>();
425
426 if !indices.is_empty() {
428 let subject: Vec<&Subject> = self.data.get_subjects();
429 let zero_probability_subjects: Vec<&String> =
430 indices.iter().map(|&i| subject[i].id()).collect();
431
432 return Err(anyhow::anyhow!(
433 "The probability of one or more subjects, given the model, is zero. The following subjects have zero probability: {:?}", zero_probability_subjects
434 ));
435 }
436
437 Ok(())
438 }
439}