1use crate::routines::initialization::sample_space;
2use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult};
3use crate::structs::weights::Weights;
4use crate::{
5 algorithms::Status,
6 prelude::{
7 algorithms::Algorithms,
8 routines::{
9 evaluation::{ipm::burke, qr},
10 settings::Settings,
11 },
12 },
13 structs::{
14 psi::{calculate_psi, Psi},
15 theta::Theta,
16 },
17};
18use pharmsol::SppOptimizer;
19
20use anyhow::bail;
21use anyhow::Result;
22use faer_ext::IntoNdarray;
23use pharmsol::{prelude::ErrorModel, ErrorModels};
24use pharmsol::{
25 prelude::{data::Data, simulator::Equation},
26 Subject,
27};
28
29use ndarray::{
30 parallel::prelude::{IntoParallelRefMutIterator, ParallelIterator},
31 Array, Array1, ArrayBase, Dim, OwnedRepr,
32};
33
34const THETA_F: f64 = 1e-2;
35const THETA_D: f64 = 1e-4;
36
37pub struct NPOD<E: Equation> {
38 equation: E,
39 psi: Psi,
40 theta: Theta,
41 lambda: Weights,
42 w: Weights,
43 last_objf: f64,
44 objf: f64,
45 cycle: usize,
46 gamma_delta: Vec<f64>,
47 error_models: ErrorModels,
48 converged: bool,
49 status: Status,
50 cycle_log: CycleLog,
51 data: Data,
52 settings: Settings,
53}
54
55impl<E: Equation> Algorithms<E> for NPOD<E> {
56 fn new(settings: Settings, equation: E, data: Data) -> Result<Box<Self>, anyhow::Error> {
57 Ok(Box::new(Self {
58 equation,
59 psi: Psi::new(),
60 theta: Theta::new(),
61 lambda: Weights::default(),
62 w: Weights::default(),
63 last_objf: -1e30,
64 objf: f64::NEG_INFINITY,
65 cycle: 0,
66 gamma_delta: vec![0.1; settings.errormodels().len()],
67 error_models: settings.errormodels().clone(),
68 converged: false,
69 status: Status::Starting,
70 cycle_log: CycleLog::new(),
71 settings,
72 data,
73 }))
74 }
75 fn into_npresult(&self) -> NPResult<E> {
76 NPResult::new(
77 self.equation.clone(),
78 self.data.clone(),
79 self.theta.clone(),
80 self.psi.clone(),
81 self.w.clone(),
82 -2. * self.objf,
83 self.cycle,
84 self.status.clone(),
85 self.settings.clone(),
86 self.cycle_log.clone(),
87 )
88 }
89
90 fn equation(&self) -> &E {
91 &self.equation
92 }
93
94 fn get_settings(&self) -> &Settings {
95 &self.settings
96 }
97
98 fn get_data(&self) -> &Data {
99 &self.data
100 }
101
102 fn get_prior(&self) -> Theta {
103 sample_space(&self.settings).unwrap()
104 }
105
106 fn inc_cycle(&mut self) -> usize {
107 self.cycle += 1;
108 self.cycle
109 }
110
111 fn get_cycle(&self) -> usize {
112 self.cycle
113 }
114
115 fn set_theta(&mut self, theta: Theta) {
116 self.theta = theta;
117 }
118
119 fn theta(&self) -> &Theta {
120 &self.theta
121 }
122
123 fn psi(&self) -> &Psi {
124 &self.psi
125 }
126
127 fn likelihood(&self) -> f64 {
128 self.objf
129 }
130
131 fn set_status(&mut self, status: Status) {
132 self.status = status;
133 }
134
135 fn status(&self) -> &Status {
136 &self.status
137 }
138
139 fn convergence_evaluation(&mut self) {
140 if (self.last_objf - self.objf).abs() <= THETA_F {
141 tracing::info!("Objective function convergence reached");
142 self.converged = true;
143 self.status = Status::Converged;
144 }
145
146 if self.cycle >= self.settings.config().cycles {
148 tracing::warn!("Maximum number of cycles reached");
149 self.converged = true;
150 self.status = Status::MaxCycles;
151 }
152
153 if std::path::Path::new("stop").exists() {
155 tracing::warn!("Stopfile detected - breaking");
156 self.converged = true;
157 self.status = Status::ManualStop;
158 }
159
160 let state = NPCycle::new(
162 self.cycle,
163 -2. * self.objf,
164 self.error_models.clone(),
165 self.theta.clone(),
166 self.theta.nspp(),
167 (self.last_objf - self.objf).abs(),
168 self.status.clone(),
169 );
170
171 self.cycle_log.push(state);
173 self.last_objf = self.objf;
174 }
175
176 fn converged(&self) -> bool {
177 self.converged
178 }
179
180 fn evaluation(&mut self) -> Result<()> {
181 let error_model: ErrorModels = self.error_models.clone();
182
183 self.psi = calculate_psi(
184 &self.equation,
185 &self.data,
186 &self.theta,
187 &error_model,
188 self.cycle == 1 && self.settings.config().progress,
189 self.cycle != 1,
190 )?;
191
192 if let Err(err) = self.validate_psi() {
193 bail!(err);
194 }
195
196 (self.lambda, _) = match burke(&self.psi) {
197 Ok((lambda, objf)) => (lambda, objf),
198 Err(err) => {
199 bail!(err);
200 }
201 };
202 Ok(())
203 }
204
205 fn condensation(&mut self) -> Result<()> {
206 let max_lambda = self
207 .lambda
208 .iter()
209 .fold(f64::NEG_INFINITY, |acc, x| x.max(acc));
210
211 let mut keep = Vec::<usize>::new();
212 for (index, lam) in self.lambda.iter().enumerate() {
213 if lam > max_lambda / 1000_f64 {
214 keep.push(index);
215 }
216 }
217 if self.psi.matrix().ncols() != keep.len() {
218 tracing::debug!(
219 "Lambda (max/1000) dropped {} support point(s)",
220 self.psi.matrix().ncols() - keep.len(),
221 );
222 }
223
224 self.theta.filter_indices(keep.as_slice());
225 self.psi.filter_column_indices(keep.as_slice());
226
227 let (r, perm) = qr::qrd(&self.psi)?;
229
230 let mut keep = Vec::<usize>::new();
231
232 let keep_n = self.psi.matrix().ncols().min(self.psi.matrix().nrows());
234 for i in 0..keep_n {
235 let test = r.col(i).norm_l2();
236 let r_diag_val = r.get(i, i);
237 let ratio = r_diag_val / test;
238 if ratio.abs() >= 1e-8 {
239 keep.push(*perm.get(i).unwrap());
240 }
241 }
242
243 if self.psi.matrix().ncols() != keep.len() {
245 tracing::debug!(
246 "QR decomposition dropped {} support point(s)",
247 self.psi.matrix().ncols() - keep.len(),
248 );
249 }
250
251 self.theta.filter_indices(keep.as_slice());
252 self.psi.filter_column_indices(keep.as_slice());
253
254 (self.lambda, self.objf) = match burke(&self.psi) {
255 Ok((lambda, objf)) => (lambda, objf),
256 Err(err) => {
257 return Err(anyhow::anyhow!("Error in IPM: {:?}", err));
258 }
259 };
260 self.w = self.lambda.clone();
261 Ok(())
262 }
263
264 fn optimizations(&mut self) -> Result<()> {
265 self.error_models
266 .clone()
267 .iter_mut()
268 .filter_map(|(outeq, em)| {
269 if *em == ErrorModel::None || em.is_factor_fixed().unwrap_or(true) {
270 None
271 } else {
272 Some((outeq, em))
273 }
274 })
275 .try_for_each(|(outeq, em)| -> Result<()> {
276 let gamma_up = em.factor()? * (1.0 + self.gamma_delta[outeq]);
279 let gamma_down = em.factor()? / (1.0 + self.gamma_delta[outeq]);
280
281 let mut error_model_up = self.error_models.clone();
282 error_model_up.set_factor(outeq, gamma_up)?;
283
284 let mut error_model_down = self.error_models.clone();
285 error_model_down.set_factor(outeq, gamma_down)?;
286
287 let psi_up = calculate_psi(
288 &self.equation,
289 &self.data,
290 &self.theta,
291 &error_model_up,
292 false,
293 true,
294 )?;
295 let psi_down = calculate_psi(
296 &self.equation,
297 &self.data,
298 &self.theta,
299 &error_model_down,
300 false,
301 true,
302 )?;
303
304 let (lambda_up, objf_up) = match burke(&psi_up) {
305 Ok((lambda, objf)) => (lambda, objf),
306 Err(err) => {
307 return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
309 }
310 };
311 let (lambda_down, objf_down) = match burke(&psi_down) {
312 Ok((lambda, objf)) => (lambda, objf),
313 Err(err) => {
314 return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
317 }
319 };
320 if objf_up > self.objf {
321 self.error_models.set_factor(outeq, gamma_up)?;
322 self.objf = objf_up;
323 self.gamma_delta[outeq] *= 4.;
324 self.lambda = lambda_up;
325 self.psi = psi_up;
326 }
327 if objf_down > self.objf {
328 self.error_models.set_factor(outeq, gamma_down)?;
329 self.objf = objf_down;
330 self.gamma_delta[outeq] *= 4.;
331 self.lambda = lambda_down;
332 self.psi = psi_down;
333 }
334 self.gamma_delta[outeq] *= 0.5;
335 if self.gamma_delta[outeq] <= 0.01 {
336 self.gamma_delta[outeq] = 0.1;
337 }
338 Ok(())
339 })?;
340
341 Ok(())
342 }
343
344 fn logs(&self) {
345 tracing::info!("Objective function = {:.4}", -2.0 * self.objf);
346 tracing::debug!("Support points: {}", self.theta.nspp());
347 self.error_models.iter().for_each(|(outeq, em)| {
348 if ErrorModel::None == *em {
349 return;
350 }
351 tracing::debug!(
352 "Error model for outeq {}: {:.16}",
353 outeq,
354 em.factor().unwrap_or_default()
355 );
356 });
357 if self.last_objf > self.objf + 1e-4 {
359 tracing::warn!(
360 "Objective function decreased from {:.4} to {:.4} (delta = {})",
361 -2.0 * self.last_objf,
362 -2.0 * self.objf,
363 -2.0 * self.last_objf - -2.0 * self.objf
364 );
365 }
366 }
367
368 fn expansion(&mut self) -> Result<()> {
369 let psi = self.psi().matrix().as_ref().into_ndarray().to_owned();
371 let w: Array1<f64> = self.w.clone().iter().collect();
372 let pyl = psi.dot(&w);
373
374 let error_model: ErrorModels = self.error_models.clone();
376
377 let mut candididate_points: Vec<Array1<f64>> = Vec::default();
378 for spp in self.theta.matrix().row_iter() {
379 let candidate: Vec<f64> = spp.iter().cloned().collect();
380 let spp = Array1::from(candidate);
381 candididate_points.push(spp.to_owned());
382 }
383 candididate_points.par_iter_mut().for_each(|spp| {
384 let optimizer = SppOptimizer::new(&self.equation, &self.data, &error_model, &pyl);
385 let candidate_point = optimizer.optimize_point(spp.to_owned()).unwrap();
386 *spp = candidate_point;
387 });
393 for cp in candididate_points {
394 self.theta.suggest_point(cp.to_vec().as_slice(), THETA_D)?;
395 }
396 Ok(())
397 }
398}
399
400impl<E: Equation> NPOD<E> {
401 fn validate_psi(&mut self) -> Result<()> {
402 let mut psi = self.psi().matrix().as_ref().into_ndarray().to_owned();
403 if psi.iter().any(|x| x.is_nan() || x.is_infinite()) {
405 tracing::warn!("Psi contains NaN or Inf values, coercing to 0.0");
406 for i in 0..psi.nrows() {
407 for j in 0..psi.ncols() {
408 let val = psi.get_mut((i, j)).unwrap();
409 if val.is_nan() || val.is_infinite() {
410 *val = 0.0;
411 }
412 }
413 }
414 }
415
416 let (_, col) = psi.dim();
418 let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
419 let plam = psi.dot(&ecol);
420 let w = 1. / &plam;
421
422 let indices: Vec<usize> = w
424 .iter()
425 .enumerate()
426 .filter(|(_, x)| x.is_nan() || x.is_infinite())
427 .map(|(i, _)| i)
428 .collect::<Vec<_>>();
429
430 if !indices.is_empty() {
432 let subject: Vec<&Subject> = self.data.subjects();
433 let zero_probability_subjects: Vec<&String> =
434 indices.iter().map(|&i| subject[i].id()).collect();
435
436 return Err(anyhow::anyhow!(
437 "The probability of one or more subjects, given the model, is zero. The following subjects have zero probability: {:?}", zero_probability_subjects
438 ));
439 }
440
441 Ok(())
442 }
443}