1use crate::algorithms::StopReason;
2use crate::routines::initialization::sample_space;
3use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult};
4use crate::structs::weights::Weights;
5use crate::{
6 algorithms::Status,
7 prelude::{
8 algorithms::Algorithms,
9 routines::{
10 estimation::{ipm::burke, qr},
11 settings::Settings,
12 },
13 },
14 structs::{
15 psi::{calculate_psi, Psi},
16 theta::Theta,
17 },
18};
19use pharmsol::SppOptimizer;
20
21use anyhow::bail;
22use anyhow::Result;
23use faer_ext::IntoNdarray;
24use pharmsol::{prelude::ErrorModel, ErrorModels};
25use pharmsol::{
26 prelude::{data::Data, simulator::Equation},
27 Subject,
28};
29
30use ndarray::{
31 parallel::prelude::{IntoParallelRefMutIterator, ParallelIterator},
32 Array, Array1, ArrayBase, Dim, OwnedRepr,
33};
34
/// Convergence tolerance: the run stops once the objective changes by at most this much
/// between two consecutive evaluations.
const THETA_F: f64 = 0.01;
/// Distance threshold handed to `Theta::suggest_point` when candidate support points are
/// added during expansion (presumably the minimum allowed distance to existing points).
const THETA_D: f64 = 0.0001;
37
38pub struct NPOD<E: Equation + Send + 'static> {
39 equation: E,
40 psi: Psi,
41 theta: Theta,
42 lambda: Weights,
43 w: Weights,
44 last_objf: f64,
45 objf: f64,
46 cycle: usize,
47 gamma_delta: Vec<f64>,
48 error_models: ErrorModels,
49 converged: bool,
50 status: Status,
51 cycle_log: CycleLog,
52 data: Data,
53 settings: Settings,
54}
55
56impl<E: Equation + Send + 'static> Algorithms<E> for NPOD<E> {
57 fn new(settings: Settings, equation: E, data: Data) -> Result<Box<Self>, anyhow::Error> {
58 Ok(Box::new(Self {
59 equation,
60 psi: Psi::new(),
61 theta: Theta::new(),
62 lambda: Weights::default(),
63 w: Weights::default(),
64 last_objf: -1e30,
65 objf: f64::NEG_INFINITY,
66 cycle: 0,
67 gamma_delta: vec![0.1; settings.errormodels().len()],
68 error_models: settings.errormodels().clone(),
69 converged: false,
70 status: Status::Continue,
71 cycle_log: CycleLog::new(),
72 settings,
73 data,
74 }))
75 }
76 fn into_npresult(&self) -> NPResult<E> {
77 NPResult::new(
78 self.equation.clone(),
79 self.data.clone(),
80 self.theta.clone(),
81 self.psi.clone(),
82 self.w.clone(),
83 -2. * self.objf,
84 self.cycle,
85 self.status.clone(),
86 self.settings.clone(),
87 self.cycle_log.clone(),
88 )
89 }
90
91 fn equation(&self) -> &E {
92 &self.equation
93 }
94
95 fn settings(&self) -> &Settings {
96 &self.settings
97 }
98
99 fn data(&self) -> &Data {
100 &self.data
101 }
102
103 fn get_prior(&self) -> Theta {
104 sample_space(&self.settings).unwrap()
105 }
106
107 fn increment_cycle(&mut self) -> usize {
108 self.cycle += 1;
109 self.cycle
110 }
111
112 fn cycle(&self) -> usize {
113 self.cycle
114 }
115
116 fn set_theta(&mut self, theta: Theta) {
117 self.theta = theta;
118 }
119
120 fn theta(&self) -> &Theta {
121 &self.theta
122 }
123
124 fn psi(&self) -> &Psi {
125 &self.psi
126 }
127
128 fn likelihood(&self) -> f64 {
129 self.objf
130 }
131
132 fn set_status(&mut self, status: Status) {
133 self.status = status;
134 }
135
136 fn status(&self) -> &Status {
137 &self.status
138 }
139
140 fn log_cycle_state(&mut self) {
141 let state = NPCycle::new(
142 self.cycle,
143 -2. * self.objf,
144 self.error_models.clone(),
145 self.theta.clone(),
146 self.theta.nspp(),
147 (self.last_objf - self.objf).abs(),
148 self.status.clone(),
149 );
150 self.cycle_log.push(state);
151 self.last_objf = self.objf;
152 }
153
154 fn evaluation(&mut self) -> Result<Status> {
155 tracing::info!("Objective function = {:.4}", -2.0 * self.objf);
156 tracing::debug!("Support points: {}", self.theta.nspp());
157 self.error_models.iter().for_each(|(outeq, em)| {
158 if ErrorModel::None == *em {
159 return;
160 }
161 tracing::debug!(
162 "Error model for outeq {}: {:.16}",
163 outeq,
164 em.factor().unwrap_or_default()
165 );
166 });
167 if self.last_objf > self.objf + 1e-4 {
169 tracing::warn!(
170 "Objective function decreased from {:.4} to {:.4} (delta = {})",
171 -2.0 * self.last_objf,
172 -2.0 * self.objf,
173 -2.0 * self.last_objf - -2.0 * self.objf
174 );
175 }
176
177 if (self.last_objf - self.objf).abs() <= THETA_F {
178 tracing::info!("Objective function convergence reached");
179 self.converged = true;
180 self.set_status(Status::Stop(StopReason::Converged));
181 self.log_cycle_state();
182 return Ok(self.status.clone());
183 }
184
185 if self.cycle >= self.settings.config().cycles {
187 tracing::warn!("Maximum number of cycles reached");
188 self.converged = true;
189 self.set_status(Status::Stop(StopReason::MaxCycles));
190 self.log_cycle_state();
191 return Ok(self.status.clone());
192 }
193
194 if std::path::Path::new("stop").exists() {
196 tracing::warn!("Stopfile detected - breaking");
197 self.converged = true;
198 self.set_status(Status::Stop(StopReason::Stopped));
199 self.log_cycle_state();
200 return Ok(self.status.clone());
201 }
202
203 self.status = Status::Continue;
205 self.log_cycle_state();
206 Ok(self.status.clone())
207 }
208
209 fn estimation(&mut self) -> Result<()> {
210 let error_model: ErrorModels = self.error_models.clone();
211
212 self.psi = calculate_psi(
213 &self.equation,
214 &self.data,
215 &self.theta,
216 &error_model,
217 self.cycle == 1 && self.settings.config().progress,
218 self.cycle != 1,
219 )?;
220
221 if let Err(err) = self.validate_psi() {
222 bail!(err);
223 }
224
225 (self.lambda, _) = match burke(&self.psi) {
226 Ok((lambda, objf)) => (lambda, objf),
227 Err(err) => {
228 bail!(err);
229 }
230 };
231 Ok(())
232 }
233
234 fn condensation(&mut self) -> Result<()> {
235 let max_lambda = self
236 .lambda
237 .iter()
238 .fold(f64::NEG_INFINITY, |acc, x| x.max(acc));
239
240 let mut keep = Vec::<usize>::new();
241 for (index, lam) in self.lambda.iter().enumerate() {
242 if lam > max_lambda / 1000_f64 {
243 keep.push(index);
244 }
245 }
246 if self.psi.matrix().ncols() != keep.len() {
247 tracing::debug!(
248 "Lambda (max/1000) dropped {} support point(s)",
249 self.psi.matrix().ncols() - keep.len(),
250 );
251 }
252
253 self.theta.filter_indices(keep.as_slice());
254 self.psi.filter_column_indices(keep.as_slice());
255
256 let (r, perm) = qr::qrd(&self.psi)?;
258
259 let mut keep = Vec::<usize>::new();
260
261 let keep_n = self.psi.matrix().ncols().min(self.psi.matrix().nrows());
263 for i in 0..keep_n {
264 let test = r.col(i).norm_l2();
265 let r_diag_val = r.get(i, i);
266 let ratio = r_diag_val / test;
267 if ratio.abs() >= 1e-8 {
268 keep.push(*perm.get(i).unwrap());
269 }
270 }
271
272 if self.psi.matrix().ncols() != keep.len() {
274 tracing::debug!(
275 "QR decomposition dropped {} support point(s)",
276 self.psi.matrix().ncols() - keep.len(),
277 );
278 }
279
280 self.theta.filter_indices(keep.as_slice());
281 self.psi.filter_column_indices(keep.as_slice());
282
283 (self.lambda, self.objf) = match burke(&self.psi) {
284 Ok((lambda, objf)) => (lambda, objf),
285 Err(err) => {
286 return Err(anyhow::anyhow!("Error in IPM: {:?}", err));
287 }
288 };
289 self.w = self.lambda.clone();
290 Ok(())
291 }
292
293 fn optimizations(&mut self) -> Result<()> {
294 self.error_models
295 .clone()
296 .iter_mut()
297 .filter_map(|(outeq, em)| {
298 if *em == ErrorModel::None || em.is_factor_fixed().unwrap_or(true) {
299 None
300 } else {
301 Some((outeq, em))
302 }
303 })
304 .try_for_each(|(outeq, em)| -> Result<()> {
305 let gamma_up = em.factor()? * (1.0 + self.gamma_delta[outeq]);
308 let gamma_down = em.factor()? / (1.0 + self.gamma_delta[outeq]);
309
310 let mut error_model_up = self.error_models.clone();
311 error_model_up.set_factor(outeq, gamma_up)?;
312
313 let mut error_model_down = self.error_models.clone();
314 error_model_down.set_factor(outeq, gamma_down)?;
315
316 let psi_up = calculate_psi(
317 &self.equation,
318 &self.data,
319 &self.theta,
320 &error_model_up,
321 false,
322 true,
323 )?;
324 let psi_down = calculate_psi(
325 &self.equation,
326 &self.data,
327 &self.theta,
328 &error_model_down,
329 false,
330 true,
331 )?;
332
333 let (lambda_up, objf_up) = match burke(&psi_up) {
334 Ok((lambda, objf)) => (lambda, objf),
335 Err(err) => {
336 return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
338 }
339 };
340 let (lambda_down, objf_down) = match burke(&psi_down) {
341 Ok((lambda, objf)) => (lambda, objf),
342 Err(err) => {
343 return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
346 }
348 };
349 if objf_up > self.objf {
350 self.error_models.set_factor(outeq, gamma_up)?;
351 self.objf = objf_up;
352 self.gamma_delta[outeq] *= 4.;
353 self.lambda = lambda_up;
354 self.psi = psi_up;
355 }
356 if objf_down > self.objf {
357 self.error_models.set_factor(outeq, gamma_down)?;
358 self.objf = objf_down;
359 self.gamma_delta[outeq] *= 4.;
360 self.lambda = lambda_down;
361 self.psi = psi_down;
362 }
363 self.gamma_delta[outeq] *= 0.5;
364 if self.gamma_delta[outeq] <= 0.01 {
365 self.gamma_delta[outeq] = 0.1;
366 }
367 Ok(())
368 })?;
369
370 Ok(())
371 }
372
373 fn expansion(&mut self) -> Result<()> {
374 let psi = self.psi().matrix().as_ref().into_ndarray().to_owned();
376 let w: Array1<f64> = self.w.clone().iter().collect();
377 let pyl = psi.dot(&w);
378
379 let error_model: ErrorModels = self.error_models.clone();
381
382 let mut candididate_points: Vec<Array1<f64>> = Vec::default();
383 for spp in self.theta.matrix().row_iter() {
384 let candidate: Vec<f64> = spp.iter().cloned().collect();
385 let spp = Array1::from(candidate);
386 candididate_points.push(spp.to_owned());
387 }
388 candididate_points.par_iter_mut().for_each(|spp| {
389 let optimizer = SppOptimizer::new(&self.equation, &self.data, &error_model, &pyl);
390 let candidate_point = optimizer.optimize_point(spp.to_owned()).unwrap();
391 *spp = candidate_point;
392 });
398 for cp in candididate_points {
399 self.theta.suggest_point(cp.to_vec().as_slice(), THETA_D)?;
400 }
401 Ok(())
402 }
403}
404
405impl<E: Equation + Send + 'static> NPOD<E> {
406 fn validate_psi(&mut self) -> Result<()> {
407 let mut psi = self.psi().matrix().as_ref().into_ndarray().to_owned();
408 if psi.iter().any(|x| x.is_nan() || x.is_infinite()) {
410 tracing::warn!("Psi contains NaN or Inf values, coercing to 0.0");
411 for i in 0..psi.nrows() {
412 for j in 0..psi.ncols() {
413 let val = psi.get_mut((i, j)).unwrap();
414 if val.is_nan() || val.is_infinite() {
415 *val = 0.0;
416 }
417 }
418 }
419 }
420
421 let (_, col) = psi.dim();
423 let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
424 let plam = psi.dot(&ecol);
425 let w = 1. / &plam;
426
427 let indices: Vec<usize> = w
429 .iter()
430 .enumerate()
431 .filter(|(_, x)| x.is_nan() || x.is_infinite())
432 .map(|(i, _)| i)
433 .collect::<Vec<_>>();
434
435 if !indices.is_empty() {
437 let subject: Vec<&Subject> = self.data.subjects();
438 let zero_probability_subjects: Vec<&String> =
439 indices.iter().map(|&i| subject[i].id()).collect();
440
441 return Err(anyhow::anyhow!(
442 "The probability of one or more subjects, given the model, is zero. The following subjects have zero probability: {:?}", zero_probability_subjects
443 ));
444 }
445
446 Ok(())
447 }
448}