pmcore/algorithms/
npod.rs

1use crate::algorithms::StopReason;
2use crate::routines::initialization::sample_space;
3use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult};
4use crate::structs::weights::Weights;
5use crate::{
6    algorithms::Status,
7    prelude::{
8        algorithms::Algorithms,
9        routines::{
10            estimation::{ipm::burke, qr},
11            settings::Settings,
12        },
13    },
14    structs::{
15        psi::{calculate_psi, Psi},
16        theta::Theta,
17    },
18};
19use pharmsol::SppOptimizer;
20
21use anyhow::bail;
22use anyhow::Result;
23use pharmsol::{prelude::AssayErrorModel, AssayErrorModels};
24use pharmsol::{
25    prelude::{data::Data, simulator::Equation},
26    Subject,
27};
28
29use ndarray::Array1;
30use rayon::prelude::{IntoParallelRefMutIterator, ParallelIterator};
31
/// Convergence tolerance on the objective: `evaluation` stops with
/// `StopReason::Converged` once `|last_objf - objf|` is at most this value.
const THETA_F: f64 = 0.01;
/// Threshold passed to `Theta::suggest_point` when adding expanded support points
/// (presumably the minimum allowed distance to existing points — confirm in `Theta`).
const THETA_D: f64 = 0.0001;
34
35pub struct NPOD<E: Equation + Send + 'static> {
36    equation: E,
37    psi: Psi,
38    theta: Theta,
39    lambda: Weights,
40    w: Weights,
41    last_objf: f64,
42    objf: f64,
43    cycle: usize,
44    gamma_delta: Vec<f64>,
45    error_models: AssayErrorModels,
46    converged: bool,
47    status: Status,
48    cycle_log: CycleLog,
49    data: Data,
50    settings: Settings,
51}
52
53impl<E: Equation + Send + 'static> Algorithms<E> for NPOD<E> {
54    fn new(settings: Settings, equation: E, data: Data) -> Result<Box<Self>, anyhow::Error> {
55        Ok(Box::new(Self {
56            equation,
57            psi: Psi::new(),
58            theta: Theta::new(),
59            lambda: Weights::default(),
60            w: Weights::default(),
61            last_objf: -1e30,
62            objf: f64::NEG_INFINITY,
63            cycle: 0,
64            gamma_delta: vec![0.1; settings.errormodels().len()],
65            error_models: settings.errormodels().clone(),
66            converged: false,
67            status: Status::Continue,
68            cycle_log: CycleLog::new(),
69            settings,
70            data,
71        }))
72    }
73    fn into_npresult(&self) -> Result<NPResult<E>> {
74        NPResult::new(
75            self.equation.clone(),
76            self.data.clone(),
77            self.theta.clone(),
78            self.psi.clone(),
79            self.w.clone(),
80            -2. * self.objf,
81            self.cycle,
82            self.status.clone(),
83            self.settings.clone(),
84            self.cycle_log.clone(),
85        )
86    }
87
88    fn equation(&self) -> &E {
89        &self.equation
90    }
91
92    fn settings(&self) -> &Settings {
93        &self.settings
94    }
95
96    fn data(&self) -> &Data {
97        &self.data
98    }
99
100    fn get_prior(&self) -> Theta {
101        sample_space(&self.settings).unwrap()
102    }
103
104    fn increment_cycle(&mut self) -> usize {
105        self.cycle += 1;
106        self.cycle
107    }
108
109    fn cycle(&self) -> usize {
110        self.cycle
111    }
112
113    fn set_theta(&mut self, theta: Theta) {
114        self.theta = theta;
115    }
116
117    fn theta(&self) -> &Theta {
118        &self.theta
119    }
120
121    fn psi(&self) -> &Psi {
122        &self.psi
123    }
124
125    fn likelihood(&self) -> f64 {
126        self.objf
127    }
128
129    fn set_status(&mut self, status: Status) {
130        self.status = status;
131    }
132
133    fn status(&self) -> &Status {
134        &self.status
135    }
136
137    fn log_cycle_state(&mut self) {
138        let state = NPCycle::new(
139            self.cycle,
140            -2. * self.objf,
141            self.error_models.clone(),
142            self.theta.clone(),
143            self.theta.nspp(),
144            (self.last_objf - self.objf).abs(),
145            self.status.clone(),
146        );
147        self.cycle_log.push(state);
148        self.last_objf = self.objf;
149    }
150
151    fn evaluation(&mut self) -> Result<Status> {
152        tracing::info!("Objective function = {:.4}", -2.0 * self.objf);
153        tracing::debug!("Support points: {}", self.theta.nspp());
154        self.error_models.iter().for_each(|(outeq, em)| {
155            if AssayErrorModel::None == *em {
156                return;
157            }
158            tracing::debug!(
159                "Error model for outeq {}: {:.16}",
160                outeq,
161                em.factor().unwrap_or_default()
162            );
163        });
164        // Increasing objf signals instability or model misspecification.
165        if self.last_objf > self.objf + 1e-4 {
166            tracing::warn!(
167                "Objective function decreased from {:.4} to {:.4} (delta = {})",
168                -2.0 * self.last_objf,
169                -2.0 * self.objf,
170                -2.0 * self.last_objf - -2.0 * self.objf
171            );
172        }
173
174        if (self.last_objf - self.objf).abs() <= THETA_F {
175            tracing::info!("Objective function convergence reached");
176            self.converged = true;
177            self.set_status(Status::Stop(StopReason::Converged));
178            self.log_cycle_state();
179            return Ok(self.status.clone());
180        }
181
182        // Stop if we have reached maximum number of cycles
183        if self.cycle >= self.settings.config().cycles {
184            tracing::warn!("Maximum number of cycles reached");
185            self.converged = true;
186            self.set_status(Status::Stop(StopReason::MaxCycles));
187            self.log_cycle_state();
188            return Ok(self.status.clone());
189        }
190
191        // Stop if stopfile exists
192        if std::path::Path::new("stop").exists() {
193            tracing::warn!("Stopfile detected - breaking");
194            self.converged = true;
195            self.set_status(Status::Stop(StopReason::Stopped));
196            self.log_cycle_state();
197            return Ok(self.status.clone());
198        }
199
200        // Continue with normal operation
201        self.status = Status::Continue;
202        self.log_cycle_state();
203        Ok(self.status.clone())
204    }
205
206    fn estimation(&mut self) -> Result<()> {
207        let error_model: AssayErrorModels = self.error_models.clone();
208
209        self.psi = calculate_psi(
210            &self.equation,
211            &self.data,
212            &self.theta,
213            &error_model,
214            self.cycle == 1 && self.settings.config().progress,
215        )?;
216
217        if let Err(err) = self.validate_psi() {
218            bail!(err);
219        }
220
221        (self.lambda, _) = match burke(&self.psi) {
222            Ok((lambda, objf)) => (lambda, objf),
223            Err(err) => {
224                bail!(err);
225            }
226        };
227        Ok(())
228    }
229
230    fn condensation(&mut self) -> Result<()> {
231        let max_lambda = self
232            .lambda
233            .iter()
234            .fold(f64::NEG_INFINITY, |acc, x| x.max(acc));
235
236        let mut keep = Vec::<usize>::new();
237        for (index, lam) in self.lambda.iter().enumerate() {
238            if lam > max_lambda / 1000_f64 {
239                keep.push(index);
240            }
241        }
242        if self.psi.matrix().ncols() != keep.len() {
243            tracing::debug!(
244                "Lambda (max/1000) dropped {} support point(s)",
245                self.psi.matrix().ncols() - keep.len(),
246            );
247        }
248
249        self.theta.filter_indices(keep.as_slice());
250        self.psi.filter_column_indices(keep.as_slice());
251
252        //Rank-Revealing Factorization
253        let (r, perm) = qr::qrd(&self.psi)?;
254
255        let mut keep = Vec::<usize>::new();
256
257        // The minimum between the number of subjects and the actual number of support points
258        let keep_n = self.psi.matrix().ncols().min(self.psi.matrix().nrows());
259        for i in 0..keep_n {
260            let test = r.col(i).norm_l2();
261            let r_diag_val = r.get(i, i);
262            let ratio = r_diag_val / test;
263            if ratio.abs() >= 1e-8 {
264                keep.push(*perm.get(i).unwrap());
265            }
266        }
267
268        // If a support point is dropped, log it as a debug message
269        if self.psi.matrix().ncols() != keep.len() {
270            tracing::debug!(
271                "QR decomposition dropped {} support point(s)",
272                self.psi.matrix().ncols() - keep.len(),
273            );
274        }
275
276        self.theta.filter_indices(keep.as_slice());
277        self.psi.filter_column_indices(keep.as_slice());
278
279        (self.lambda, self.objf) = match burke(&self.psi) {
280            Ok((lambda, objf)) => (lambda, objf),
281            Err(err) => {
282                return Err(anyhow::anyhow!("Error in IPM: {:?}", err));
283            }
284        };
285        self.w = self.lambda.clone();
286        Ok(())
287    }
288
289    fn optimizations(&mut self) -> Result<()> {
290        self.error_models
291            .clone()
292            .iter_mut()
293            .filter_map(|(outeq, em)| {
294                if *em == AssayErrorModel::None || em.is_factor_fixed().unwrap_or(true) {
295                    None
296                } else {
297                    Some((outeq, em))
298                }
299            })
300            .try_for_each(|(outeq, em)| -> Result<()> {
301                // OPTIMIZATION
302
303                let gamma_up = em.factor()? * (1.0 + self.gamma_delta[outeq]);
304                let gamma_down = em.factor()? / (1.0 + self.gamma_delta[outeq]);
305
306                let mut error_model_up = self.error_models.clone();
307                error_model_up.set_factor(outeq, gamma_up)?;
308
309                let mut error_model_down = self.error_models.clone();
310                error_model_down.set_factor(outeq, gamma_down)?;
311
312                let psi_up = calculate_psi(
313                    &self.equation,
314                    &self.data,
315                    &self.theta,
316                    &error_model_up,
317                    false,
318                )?;
319                let psi_down = calculate_psi(
320                    &self.equation,
321                    &self.data,
322                    &self.theta,
323                    &error_model_down,
324                    false,
325                )?;
326
327                let (lambda_up, objf_up) = match burke(&psi_up) {
328                    Ok((lambda, objf)) => (lambda, objf),
329                    Err(err) => {
330                        bail!("Error in IPM during optim: {:?}", err);
331                    }
332                };
333                let (lambda_down, objf_down) = match burke(&psi_down) {
334                    Ok((lambda, objf)) => (lambda, objf),
335                    Err(err) => {
336                        bail!("Error in IPM during optim: {:?}", err);
337                    }
338                };
339                if objf_up > self.objf {
340                    self.error_models.set_factor(outeq, gamma_up)?;
341                    self.objf = objf_up;
342                    self.gamma_delta[outeq] *= 4.;
343                    self.lambda = lambda_up;
344                    self.psi = psi_up;
345                }
346                if objf_down > self.objf {
347                    self.error_models.set_factor(outeq, gamma_down)?;
348                    self.objf = objf_down;
349                    self.gamma_delta[outeq] *= 4.;
350                    self.lambda = lambda_down;
351                    self.psi = psi_down;
352                }
353                self.gamma_delta[outeq] *= 0.5;
354                if self.gamma_delta[outeq] <= 0.01 {
355                    self.gamma_delta[outeq] = 0.1;
356                }
357                Ok(())
358            })?;
359
360        Ok(())
361    }
362
363    fn expansion(&mut self) -> Result<()> {
364        // Compute pyl = psi * w using faer native operations
365        let pyl_col = self.psi().matrix().as_ref() * self.w.weights().as_ref();
366        let pyl: Array1<f64> = pyl_col.iter().copied().collect();
367
368        // Add new point to theta based on the optimization of the D function
369        let error_model: AssayErrorModels = self.error_models.clone();
370
371        let mut candididate_points: Vec<Array1<f64>> = Vec::default();
372        for spp in self.theta.matrix().row_iter() {
373            let candidate: Vec<f64> = spp.iter().cloned().collect();
374            let spp = Array1::from(candidate);
375            candididate_points.push(spp.to_owned());
376        }
377        candididate_points.par_iter_mut().for_each(|spp| {
378            let optimizer = SppOptimizer::new(&self.equation, &self.data, &error_model, &pyl);
379            let candidate_point = optimizer.optimize_point(spp.to_owned()).unwrap();
380            *spp = candidate_point;
381            // add spp to theta
382            // recalculate psi
383            // re-run ipm to re-calculate w
384            // re-calculate pyl
385            // re-define a new optimization
386        });
387        for cp in candididate_points {
388            self.theta.suggest_point(cp.to_vec().as_slice(), THETA_D)?;
389        }
390        Ok(())
391    }
392}
393
394impl<E: Equation + Send + 'static> NPOD<E> {
395    fn validate_psi(&mut self) -> Result<()> {
396        let mut psi = self.psi().matrix().to_owned();
397        // First coerce all NaN and infinite in psi to 0.0
398        let mut has_bad_values = false;
399        for i in 0..psi.nrows() {
400            for j in 0..psi.ncols() {
401                let val = psi[(i, j)];
402                if val.is_nan() || val.is_infinite() {
403                    has_bad_values = true;
404                    psi[(i, j)] = 0.0;
405                }
406            }
407        }
408        if has_bad_values {
409            tracing::warn!("Psi contains NaN or Inf values, coercing to 0.0");
410        }
411
412        // Calculate row sums and check for zero-probability subjects
413        let nrows = psi.nrows();
414        let ncols = psi.ncols();
415        let indices: Vec<usize> = (0..nrows)
416            .filter(|&i| {
417                let row_sum: f64 = (0..ncols).map(|j| psi[(i, j)]).sum();
418                let w: f64 = 1.0 / row_sum;
419                w.is_nan() || w.is_infinite()
420            })
421            .collect();
422
423        // If any elements in `w` are NaN or infinite, return the subject IDs for each index
424        if !indices.is_empty() {
425            let subject: Vec<&Subject> = self.data.subjects();
426            let zero_probability_subjects: Vec<&String> =
427                indices.iter().map(|&i| subject[i].id()).collect();
428
429            return Err(anyhow::anyhow!(
430                "The probability of one or more subjects, given the model, is zero. The following subjects have zero probability: {:?}", zero_probability_subjects
431            ));
432        }
433
434        Ok(())
435    }
436}