pmcore/algorithms/
npod.rs

1use crate::{
2    algorithms::Status,
3    prelude::{
4        algorithms::Algorithms,
5        routines::{
6            evaluation::{ipm::burke, qr},
7            output::{CycleLog, NPCycle, NPResult},
8            settings::Settings,
9        },
10    },
11    structs::{
12        psi::{calculate_psi, Psi},
13        theta::Theta,
14    },
15};
16use anyhow::bail;
17use anyhow::Result;
18use faer::Col;
19use faer_ext::IntoNdarray;
20use pharmsol::{prelude::ErrorModel, ErrorModels};
21use pharmsol::{
22    prelude::{data::Data, simulator::Equation},
23    Subject,
24};
25
26use ndarray::{
27    parallel::prelude::{IntoParallelRefMutIterator, ParallelIterator},
28    Array, Array1, ArrayBase, Dim, OwnedRepr,
29};
30
31use crate::routines::{initialization, optimization::SppOptimizer};
32
/// Convergence tolerance: the run is flagged converged once the absolute change in the
/// objective between consecutive cycles is at most this value (see `convergence_evaluation`).
const THETA_F: f64 = 0.01;
/// Tolerance handed to `Theta::suggest_point` when proposing expanded support points
/// (presumably a minimum distance to existing points — TODO confirm against `Theta`).
const THETA_D: f64 = 0.0001;
35
36pub struct NPOD<E: Equation> {
37    equation: E,
38    psi: Psi,
39    theta: Theta,
40    lambda: Col<f64>,
41    w: Col<f64>,
42    last_objf: f64,
43    objf: f64,
44    cycle: usize,
45    gamma_delta: Vec<f64>,
46    error_models: ErrorModels,
47    converged: bool,
48    status: Status,
49    cycle_log: CycleLog,
50    data: Data,
51    settings: Settings,
52}
53
54impl<E: Equation> Algorithms<E> for NPOD<E> {
55    fn new(settings: Settings, equation: E, data: Data) -> Result<Box<Self>, anyhow::Error> {
56        Ok(Box::new(Self {
57            equation,
58            psi: Psi::new(),
59            theta: Theta::new(),
60            lambda: Col::zeros(0),
61            w: Col::zeros(0),
62            last_objf: -1e30,
63            objf: f64::NEG_INFINITY,
64            cycle: 0,
65            gamma_delta: vec![0.1; settings.errormodels().len()],
66            error_models: settings.errormodels().clone(),
67            converged: false,
68            status: Status::Starting,
69            cycle_log: CycleLog::new(),
70            settings,
71            data,
72        }))
73    }
74    fn into_npresult(&self) -> NPResult<E> {
75        NPResult::new(
76            self.equation.clone(),
77            self.data.clone(),
78            self.theta.clone(),
79            self.psi.clone(),
80            self.w.clone(),
81            -2. * self.objf,
82            self.cycle,
83            self.status.clone(),
84            self.settings.clone(),
85            self.cycle_log.clone(),
86        )
87    }
88
89    fn equation(&self) -> &E {
90        &self.equation
91    }
92
93    fn get_settings(&self) -> &Settings {
94        &self.settings
95    }
96
97    fn get_data(&self) -> &Data {
98        &self.data
99    }
100
101    fn get_prior(&self) -> Theta {
102        initialization::sample_space(&self.settings).unwrap()
103    }
104
105    fn inc_cycle(&mut self) -> usize {
106        self.cycle += 1;
107        self.cycle
108    }
109
110    fn get_cycle(&self) -> usize {
111        self.cycle
112    }
113
114    fn set_theta(&mut self, theta: Theta) {
115        self.theta = theta;
116    }
117
118    fn theta(&self) -> &Theta {
119        &self.theta
120    }
121
122    fn psi(&self) -> &Psi {
123        &self.psi
124    }
125
126    fn likelihood(&self) -> f64 {
127        self.objf
128    }
129
130    fn set_status(&mut self, status: Status) {
131        self.status = status;
132    }
133
134    fn status(&self) -> &Status {
135        &self.status
136    }
137
138    fn convergence_evaluation(&mut self) {
139        if (self.last_objf - self.objf).abs() <= THETA_F {
140            tracing::info!("Objective function convergence reached");
141            self.converged = true;
142            self.status = Status::Converged;
143        }
144
145        // Stop if we have reached maximum number of cycles
146        if self.cycle >= self.settings.config().cycles {
147            tracing::warn!("Maximum number of cycles reached");
148            self.converged = true;
149            self.status = Status::MaxCycles;
150        }
151
152        // Stop if stopfile exists
153        if std::path::Path::new("stop").exists() {
154            tracing::warn!("Stopfile detected - breaking");
155            self.converged = true;
156            self.status = Status::ManualStop;
157        }
158
159        // Create state object
160        let state = NPCycle {
161            cycle: self.cycle,
162            objf: -2. * self.objf,
163            delta_objf: (self.last_objf - self.objf).abs(),
164            nspp: self.theta.nspp(),
165            theta: self.theta.clone(),
166            error_models: self.error_models.clone(),
167            status: self.status.clone(),
168        };
169
170        // Write cycle log
171        self.cycle_log.push(state);
172        self.last_objf = self.objf;
173    }
174
175    fn converged(&self) -> bool {
176        self.converged
177    }
178
179    fn evaluation(&mut self) -> Result<()> {
180        let error_model: ErrorModels = self.error_models.clone();
181
182        self.psi = calculate_psi(
183            &self.equation,
184            &self.data,
185            &self.theta,
186            &error_model,
187            self.cycle == 1 && self.settings.config().progress,
188            self.cycle != 1,
189        )?;
190
191        if let Err(err) = self.validate_psi() {
192            bail!(err);
193        }
194
195        (self.lambda, _) = match burke(&self.psi) {
196            Ok((lambda, objf)) => (lambda, objf),
197            Err(err) => {
198                bail!(err);
199            }
200        };
201        Ok(())
202    }
203
204    fn condensation(&mut self) -> Result<()> {
205        let max_lambda = self
206            .lambda
207            .iter()
208            .fold(f64::NEG_INFINITY, |acc, &x| x.max(acc));
209
210        let mut keep = Vec::<usize>::new();
211        for (index, lam) in self.lambda.iter().enumerate() {
212            if *lam > max_lambda / 1000_f64 {
213                keep.push(index);
214            }
215        }
216        if self.psi.matrix().ncols() != keep.len() {
217            tracing::debug!(
218                "Lambda (max/1000) dropped {} support point(s)",
219                self.psi.matrix().ncols() - keep.len(),
220            );
221        }
222
223        self.theta.filter_indices(keep.as_slice());
224        self.psi.filter_column_indices(keep.as_slice());
225
226        //Rank-Revealing Factorization
227        let (r, perm) = qr::qrd(&self.psi)?;
228
229        let mut keep = Vec::<usize>::new();
230
231        // The minimum between the number of subjects and the actual number of support points
232        let keep_n = self.psi.matrix().ncols().min(self.psi.matrix().nrows());
233        for i in 0..keep_n {
234            let test = r.col(i).norm_l2();
235            let r_diag_val = r.get(i, i);
236            let ratio = r_diag_val / test;
237            if ratio.abs() >= 1e-8 {
238                keep.push(*perm.get(i).unwrap());
239            }
240        }
241
242        // If a support point is dropped, log it as a debug message
243        if self.psi.matrix().ncols() != keep.len() {
244            tracing::debug!(
245                "QR decomposition dropped {} support point(s)",
246                self.psi.matrix().ncols() - keep.len(),
247            );
248        }
249
250        self.theta.filter_indices(keep.as_slice());
251        self.psi.filter_column_indices(keep.as_slice());
252
253        (self.lambda, self.objf) = match burke(&self.psi) {
254            Ok((lambda, objf)) => (lambda, objf),
255            Err(err) => {
256                return Err(anyhow::anyhow!("Error in IPM: {:?}", err));
257            }
258        };
259        self.w = self.lambda.clone();
260        Ok(())
261    }
262
263    fn optimizations(&mut self) -> Result<()> {
264        self.error_models
265            .clone()
266            .iter_mut()
267            .filter_map(|(outeq, em)| {
268                if *em == ErrorModel::None || em.is_factor_fixed().unwrap_or(true) {
269                    None
270                } else {
271                    Some((outeq, em))
272                }
273            })
274            .try_for_each(|(outeq, em)| -> Result<()> {
275                // OPTIMIZATION
276
277                let gamma_up = em.factor()? * (1.0 + self.gamma_delta[outeq]);
278                let gamma_down = em.factor()? / (1.0 + self.gamma_delta[outeq]);
279
280                let mut error_model_up = self.error_models.clone();
281                error_model_up.set_factor(outeq, gamma_up)?;
282
283                let mut error_model_down = self.error_models.clone();
284                error_model_down.set_factor(outeq, gamma_down)?;
285
286                let psi_up = calculate_psi(
287                    &self.equation,
288                    &self.data,
289                    &self.theta,
290                    &error_model_up,
291                    false,
292                    true,
293                )?;
294                let psi_down = calculate_psi(
295                    &self.equation,
296                    &self.data,
297                    &self.theta,
298                    &error_model_down,
299                    false,
300                    true,
301                )?;
302
303                let (lambda_up, objf_up) = match burke(&psi_up) {
304                    Ok((lambda, objf)) => (lambda, objf),
305                    Err(err) => {
306                        //todo: write out report
307                        return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
308                    }
309                };
310                let (lambda_down, objf_down) = match burke(&psi_down) {
311                    Ok((lambda, objf)) => (lambda, objf),
312                    Err(err) => {
313                        //todo: write out report
314                        //panic!("Error in IPM: {:?}", err);
315                        return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
316                        //(Array1::zeros(1), f64::NEG_INFINITY)
317                    }
318                };
319                if objf_up > self.objf {
320                    self.error_models.set_factor(outeq, gamma_up)?;
321                    self.objf = objf_up;
322                    self.gamma_delta[outeq] *= 4.;
323                    self.lambda = lambda_up;
324                    self.psi = psi_up;
325                }
326                if objf_down > self.objf {
327                    self.error_models.set_factor(outeq, gamma_down)?;
328                    self.objf = objf_down;
329                    self.gamma_delta[outeq] *= 4.;
330                    self.lambda = lambda_down;
331                    self.psi = psi_down;
332                }
333                self.gamma_delta[outeq] *= 0.5;
334                if self.gamma_delta[outeq] <= 0.01 {
335                    self.gamma_delta[outeq] = 0.1;
336                }
337                Ok(())
338            })?;
339
340        Ok(())
341    }
342
343    fn logs(&self) {
344        tracing::info!("Objective function = {:.4}", -2.0 * self.objf);
345        tracing::debug!("Support points: {}", self.theta.nspp());
346        self.error_models.iter().for_each(|(outeq, em)| {
347            if ErrorModel::None == *em {
348                return;
349            }
350            tracing::debug!(
351                "Error model for outeq {}: {:.16}",
352                outeq,
353                em.factor().unwrap_or_default()
354            );
355        });
356        // Increasing objf signals instability or model misspecification.
357        if self.last_objf > self.objf + 1e-4 {
358            tracing::warn!(
359                "Objective function decreased from {:.4} to {:.4} (delta = {})",
360                -2.0 * self.last_objf,
361                -2.0 * self.objf,
362                -2.0 * self.last_objf - -2.0 * self.objf
363            );
364        }
365    }
366
367    fn expansion(&mut self) -> Result<()> {
368        // If no stop signal, add new point to theta based on the optimization of the D function
369        let psi = self.psi().matrix().as_ref().into_ndarray().to_owned();
370        let w: Array1<f64> = self.w.clone().iter().cloned().collect();
371        let pyl = psi.dot(&w);
372
373        // Add new point to theta based on the optimization of the D function
374        let error_model: ErrorModels = self.error_models.clone();
375
376        let mut candididate_points: Vec<Array1<f64>> = Vec::default();
377        for spp in self.theta.matrix().row_iter() {
378            let candidate: Vec<f64> = spp.iter().cloned().collect();
379            let spp = Array1::from(candidate);
380            candididate_points.push(spp.to_owned());
381        }
382        candididate_points.par_iter_mut().for_each(|spp| {
383            let optimizer = SppOptimizer::new(&self.equation, &self.data, &error_model, &pyl);
384            let candidate_point = optimizer.optimize_point(spp.to_owned()).unwrap();
385            *spp = candidate_point;
386            // add spp to theta
387            // recalculate psi
388            // re-run ipm to re-calculate w
389            // re-calculate pyl
390            // re-define a new optimization
391        });
392        for cp in candididate_points {
393            self.theta.suggest_point(cp.to_vec().as_slice(), THETA_D);
394        }
395        Ok(())
396    }
397}
398
399impl<E: Equation> NPOD<E> {
400    fn validate_psi(&mut self) -> Result<()> {
401        let mut psi = self.psi().matrix().as_ref().into_ndarray().to_owned();
402        // First coerce all NaN and infinite in psi to 0.0
403        if psi.iter().any(|x| x.is_nan() || x.is_infinite()) {
404            tracing::warn!("Psi contains NaN or Inf values, coercing to 0.0");
405            for i in 0..psi.nrows() {
406                for j in 0..psi.ncols() {
407                    let val = psi.get_mut((i, j)).unwrap();
408                    if val.is_nan() || val.is_infinite() {
409                        *val = 0.0;
410                    }
411                }
412            }
413        }
414
415        // Calculate the sum of each column in psi
416        let (_, col) = psi.dim();
417        let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
418        let plam = psi.dot(&ecol);
419        let w = 1. / &plam;
420
421        // Get the index of each element in `w` that is NaN or infinite
422        let indices: Vec<usize> = w
423            .iter()
424            .enumerate()
425            .filter(|(_, x)| x.is_nan() || x.is_infinite())
426            .map(|(i, _)| i)
427            .collect::<Vec<_>>();
428
429        // If any elements in `w` are NaN or infinite, return the subject IDs for each index
430        if !indices.is_empty() {
431            let subject: Vec<&Subject> = self.data.subjects();
432            let zero_probability_subjects: Vec<&String> =
433                indices.iter().map(|&i| subject[i].id()).collect();
434
435            return Err(anyhow::anyhow!(
436                "The probability of one or more subjects, given the model, is zero. The following subjects have zero probability: {:?}", zero_probability_subjects
437            ));
438        }
439
440        Ok(())
441    }
442}