1use crate::algorithms::StopReason;
2use crate::routines::initialization::sample_space;
3use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult};
4use crate::structs::weights::Weights;
5use crate::{
6 algorithms::Status,
7 prelude::{
8 algorithms::Algorithms,
9 routines::{
10 estimation::{ipm::burke, qr},
11 settings::Settings,
12 },
13 },
14 structs::{
15 psi::{calculate_psi, Psi},
16 theta::Theta,
17 },
18};
19use pharmsol::SppOptimizer;
20
21use anyhow::bail;
22use anyhow::Result;
23use pharmsol::{prelude::AssayErrorModel, AssayErrorModels};
24use pharmsol::{
25 prelude::{data::Data, simulator::Equation},
26 Subject,
27};
28
29use ndarray::Array1;
30use rayon::prelude::{IntoParallelRefMutIterator, ParallelIterator};
31
/// Convergence threshold: `evaluation` stops the run once `objf` changes by at
/// most this much between consecutive logged cycles.
const THETA_F: f64 = 0.01;
/// Tolerance handed to `Theta::suggest_point` when `expansion` offers optimised
/// candidate points (presumably a minimum distance to existing points — confirm
/// in `Theta::suggest_point`).
const THETA_D: f64 = 0.0001;
34
35pub struct NPOD<E: Equation + Send + 'static> {
36 equation: E,
37 psi: Psi,
38 theta: Theta,
39 lambda: Weights,
40 w: Weights,
41 last_objf: f64,
42 objf: f64,
43 cycle: usize,
44 gamma_delta: Vec<f64>,
45 error_models: AssayErrorModels,
46 converged: bool,
47 status: Status,
48 cycle_log: CycleLog,
49 data: Data,
50 settings: Settings,
51}
52
53impl<E: Equation + Send + 'static> Algorithms<E> for NPOD<E> {
54 fn new(settings: Settings, equation: E, data: Data) -> Result<Box<Self>, anyhow::Error> {
55 Ok(Box::new(Self {
56 equation,
57 psi: Psi::new(),
58 theta: Theta::new(),
59 lambda: Weights::default(),
60 w: Weights::default(),
61 last_objf: -1e30,
62 objf: f64::NEG_INFINITY,
63 cycle: 0,
64 gamma_delta: vec![0.1; settings.errormodels().len()],
65 error_models: settings.errormodels().clone(),
66 converged: false,
67 status: Status::Continue,
68 cycle_log: CycleLog::new(),
69 settings,
70 data,
71 }))
72 }
73 fn into_npresult(&self) -> Result<NPResult<E>> {
74 NPResult::new(
75 self.equation.clone(),
76 self.data.clone(),
77 self.theta.clone(),
78 self.psi.clone(),
79 self.w.clone(),
80 -2. * self.objf,
81 self.cycle,
82 self.status.clone(),
83 self.settings.clone(),
84 self.cycle_log.clone(),
85 )
86 }
87
88 fn equation(&self) -> &E {
89 &self.equation
90 }
91
92 fn settings(&self) -> &Settings {
93 &self.settings
94 }
95
96 fn data(&self) -> &Data {
97 &self.data
98 }
99
100 fn get_prior(&self) -> Theta {
101 sample_space(&self.settings).unwrap()
102 }
103
104 fn increment_cycle(&mut self) -> usize {
105 self.cycle += 1;
106 self.cycle
107 }
108
109 fn cycle(&self) -> usize {
110 self.cycle
111 }
112
113 fn set_theta(&mut self, theta: Theta) {
114 self.theta = theta;
115 }
116
117 fn theta(&self) -> &Theta {
118 &self.theta
119 }
120
121 fn psi(&self) -> &Psi {
122 &self.psi
123 }
124
125 fn likelihood(&self) -> f64 {
126 self.objf
127 }
128
129 fn set_status(&mut self, status: Status) {
130 self.status = status;
131 }
132
133 fn status(&self) -> &Status {
134 &self.status
135 }
136
137 fn log_cycle_state(&mut self) {
138 let state = NPCycle::new(
139 self.cycle,
140 -2. * self.objf,
141 self.error_models.clone(),
142 self.theta.clone(),
143 self.theta.nspp(),
144 (self.last_objf - self.objf).abs(),
145 self.status.clone(),
146 );
147 self.cycle_log.push(state);
148 self.last_objf = self.objf;
149 }
150
151 fn evaluation(&mut self) -> Result<Status> {
152 tracing::info!("Objective function = {:.4}", -2.0 * self.objf);
153 tracing::debug!("Support points: {}", self.theta.nspp());
154 self.error_models.iter().for_each(|(outeq, em)| {
155 if AssayErrorModel::None == *em {
156 return;
157 }
158 tracing::debug!(
159 "Error model for outeq {}: {:.16}",
160 outeq,
161 em.factor().unwrap_or_default()
162 );
163 });
164 if self.last_objf > self.objf + 1e-4 {
166 tracing::warn!(
167 "Objective function decreased from {:.4} to {:.4} (delta = {})",
168 -2.0 * self.last_objf,
169 -2.0 * self.objf,
170 -2.0 * self.last_objf - -2.0 * self.objf
171 );
172 }
173
174 if (self.last_objf - self.objf).abs() <= THETA_F {
175 tracing::info!("Objective function convergence reached");
176 self.converged = true;
177 self.set_status(Status::Stop(StopReason::Converged));
178 self.log_cycle_state();
179 return Ok(self.status.clone());
180 }
181
182 if self.cycle >= self.settings.config().cycles {
184 tracing::warn!("Maximum number of cycles reached");
185 self.converged = true;
186 self.set_status(Status::Stop(StopReason::MaxCycles));
187 self.log_cycle_state();
188 return Ok(self.status.clone());
189 }
190
191 if std::path::Path::new("stop").exists() {
193 tracing::warn!("Stopfile detected - breaking");
194 self.converged = true;
195 self.set_status(Status::Stop(StopReason::Stopped));
196 self.log_cycle_state();
197 return Ok(self.status.clone());
198 }
199
200 self.status = Status::Continue;
202 self.log_cycle_state();
203 Ok(self.status.clone())
204 }
205
206 fn estimation(&mut self) -> Result<()> {
207 let error_model: AssayErrorModels = self.error_models.clone();
208
209 self.psi = calculate_psi(
210 &self.equation,
211 &self.data,
212 &self.theta,
213 &error_model,
214 self.cycle == 1 && self.settings.config().progress,
215 )?;
216
217 if let Err(err) = self.validate_psi() {
218 bail!(err);
219 }
220
221 (self.lambda, _) = match burke(&self.psi) {
222 Ok((lambda, objf)) => (lambda, objf),
223 Err(err) => {
224 bail!(err);
225 }
226 };
227 Ok(())
228 }
229
230 fn condensation(&mut self) -> Result<()> {
231 let max_lambda = self
232 .lambda
233 .iter()
234 .fold(f64::NEG_INFINITY, |acc, x| x.max(acc));
235
236 let mut keep = Vec::<usize>::new();
237 for (index, lam) in self.lambda.iter().enumerate() {
238 if lam > max_lambda / 1000_f64 {
239 keep.push(index);
240 }
241 }
242 if self.psi.matrix().ncols() != keep.len() {
243 tracing::debug!(
244 "Lambda (max/1000) dropped {} support point(s)",
245 self.psi.matrix().ncols() - keep.len(),
246 );
247 }
248
249 self.theta.filter_indices(keep.as_slice());
250 self.psi.filter_column_indices(keep.as_slice());
251
252 let (r, perm) = qr::qrd(&self.psi)?;
254
255 let mut keep = Vec::<usize>::new();
256
257 let keep_n = self.psi.matrix().ncols().min(self.psi.matrix().nrows());
259 for i in 0..keep_n {
260 let test = r.col(i).norm_l2();
261 let r_diag_val = r.get(i, i);
262 let ratio = r_diag_val / test;
263 if ratio.abs() >= 1e-8 {
264 keep.push(*perm.get(i).unwrap());
265 }
266 }
267
268 if self.psi.matrix().ncols() != keep.len() {
270 tracing::debug!(
271 "QR decomposition dropped {} support point(s)",
272 self.psi.matrix().ncols() - keep.len(),
273 );
274 }
275
276 self.theta.filter_indices(keep.as_slice());
277 self.psi.filter_column_indices(keep.as_slice());
278
279 (self.lambda, self.objf) = match burke(&self.psi) {
280 Ok((lambda, objf)) => (lambda, objf),
281 Err(err) => {
282 return Err(anyhow::anyhow!("Error in IPM: {:?}", err));
283 }
284 };
285 self.w = self.lambda.clone();
286 Ok(())
287 }
288
289 fn optimizations(&mut self) -> Result<()> {
290 self.error_models
291 .clone()
292 .iter_mut()
293 .filter_map(|(outeq, em)| {
294 if *em == AssayErrorModel::None || em.is_factor_fixed().unwrap_or(true) {
295 None
296 } else {
297 Some((outeq, em))
298 }
299 })
300 .try_for_each(|(outeq, em)| -> Result<()> {
301 let gamma_up = em.factor()? * (1.0 + self.gamma_delta[outeq]);
304 let gamma_down = em.factor()? / (1.0 + self.gamma_delta[outeq]);
305
306 let mut error_model_up = self.error_models.clone();
307 error_model_up.set_factor(outeq, gamma_up)?;
308
309 let mut error_model_down = self.error_models.clone();
310 error_model_down.set_factor(outeq, gamma_down)?;
311
312 let psi_up = calculate_psi(
313 &self.equation,
314 &self.data,
315 &self.theta,
316 &error_model_up,
317 false,
318 )?;
319 let psi_down = calculate_psi(
320 &self.equation,
321 &self.data,
322 &self.theta,
323 &error_model_down,
324 false,
325 )?;
326
327 let (lambda_up, objf_up) = match burke(&psi_up) {
328 Ok((lambda, objf)) => (lambda, objf),
329 Err(err) => {
330 bail!("Error in IPM during optim: {:?}", err);
331 }
332 };
333 let (lambda_down, objf_down) = match burke(&psi_down) {
334 Ok((lambda, objf)) => (lambda, objf),
335 Err(err) => {
336 bail!("Error in IPM during optim: {:?}", err);
337 }
338 };
339 if objf_up > self.objf {
340 self.error_models.set_factor(outeq, gamma_up)?;
341 self.objf = objf_up;
342 self.gamma_delta[outeq] *= 4.;
343 self.lambda = lambda_up;
344 self.psi = psi_up;
345 }
346 if objf_down > self.objf {
347 self.error_models.set_factor(outeq, gamma_down)?;
348 self.objf = objf_down;
349 self.gamma_delta[outeq] *= 4.;
350 self.lambda = lambda_down;
351 self.psi = psi_down;
352 }
353 self.gamma_delta[outeq] *= 0.5;
354 if self.gamma_delta[outeq] <= 0.01 {
355 self.gamma_delta[outeq] = 0.1;
356 }
357 Ok(())
358 })?;
359
360 Ok(())
361 }
362
363 fn expansion(&mut self) -> Result<()> {
364 let pyl_col = self.psi().matrix().as_ref() * self.w.weights().as_ref();
366 let pyl: Array1<f64> = pyl_col.iter().copied().collect();
367
368 let error_model: AssayErrorModels = self.error_models.clone();
370
371 let mut candididate_points: Vec<Array1<f64>> = Vec::default();
372 for spp in self.theta.matrix().row_iter() {
373 let candidate: Vec<f64> = spp.iter().cloned().collect();
374 let spp = Array1::from(candidate);
375 candididate_points.push(spp.to_owned());
376 }
377 candididate_points.par_iter_mut().for_each(|spp| {
378 let optimizer = SppOptimizer::new(&self.equation, &self.data, &error_model, &pyl);
379 let candidate_point = optimizer.optimize_point(spp.to_owned()).unwrap();
380 *spp = candidate_point;
381 });
387 for cp in candididate_points {
388 self.theta.suggest_point(cp.to_vec().as_slice(), THETA_D)?;
389 }
390 Ok(())
391 }
392}
393
394impl<E: Equation + Send + 'static> NPOD<E> {
395 fn validate_psi(&mut self) -> Result<()> {
396 let mut psi = self.psi().matrix().to_owned();
397 let mut has_bad_values = false;
399 for i in 0..psi.nrows() {
400 for j in 0..psi.ncols() {
401 let val = psi[(i, j)];
402 if val.is_nan() || val.is_infinite() {
403 has_bad_values = true;
404 psi[(i, j)] = 0.0;
405 }
406 }
407 }
408 if has_bad_values {
409 tracing::warn!("Psi contains NaN or Inf values, coercing to 0.0");
410 }
411
412 let nrows = psi.nrows();
414 let ncols = psi.ncols();
415 let indices: Vec<usize> = (0..nrows)
416 .filter(|&i| {
417 let row_sum: f64 = (0..ncols).map(|j| psi[(i, j)]).sum();
418 let w: f64 = 1.0 / row_sum;
419 w.is_nan() || w.is_infinite()
420 })
421 .collect();
422
423 if !indices.is_empty() {
425 let subject: Vec<&Subject> = self.data.subjects();
426 let zero_probability_subjects: Vec<&String> =
427 indices.iter().map(|&i| subject[i].id()).collect();
428
429 return Err(anyhow::anyhow!(
430 "The probability of one or more subjects, given the model, is zero. The following subjects have zero probability: {:?}", zero_probability_subjects
431 ));
432 }
433
434 Ok(())
435 }
436}