pmcore/algorithms/
npod.rs

1use crate::routines::initialization::sample_space;
2use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult};
3use crate::structs::weights::Weights;
4use crate::{
5    algorithms::Status,
6    prelude::{
7        algorithms::Algorithms,
8        routines::{
9            evaluation::{ipm::burke, qr},
10            settings::Settings,
11        },
12    },
13    structs::{
14        psi::{calculate_psi, Psi},
15        theta::Theta,
16    },
17};
18use pharmsol::SppOptimizer;
19
20use anyhow::bail;
21use anyhow::Result;
22use faer_ext::IntoNdarray;
23use pharmsol::{prelude::ErrorModel, ErrorModels};
24use pharmsol::{
25    prelude::{data::Data, simulator::Equation},
26    Subject,
27};
28
29use ndarray::{
30    parallel::prelude::{IntoParallelRefMutIterator, ParallelIterator},
31    Array, Array1, ArrayBase, Dim, OwnedRepr,
32};
33
/// Convergence tolerance on the objective function: `convergence_evaluation` flags the run as
/// converged once the absolute change in `objf` between two consecutive cycles is at most this.
const THETA_F: f64 = 1e-2;
/// Threshold handed to `Theta::suggest_point` when optimized candidate points are proposed.
// NOTE(review): presumably the minimum distance to existing support points -- confirm in `Theta`.
const THETA_D: f64 = 1e-4;
36
37pub struct NPOD<E: Equation> {
38    equation: E,
39    psi: Psi,
40    theta: Theta,
41    lambda: Weights,
42    w: Weights,
43    last_objf: f64,
44    objf: f64,
45    cycle: usize,
46    gamma_delta: Vec<f64>,
47    error_models: ErrorModels,
48    converged: bool,
49    status: Status,
50    cycle_log: CycleLog,
51    data: Data,
52    settings: Settings,
53}
54
55impl<E: Equation> Algorithms<E> for NPOD<E> {
56    fn new(settings: Settings, equation: E, data: Data) -> Result<Box<Self>, anyhow::Error> {
57        Ok(Box::new(Self {
58            equation,
59            psi: Psi::new(),
60            theta: Theta::new(),
61            lambda: Weights::default(),
62            w: Weights::default(),
63            last_objf: -1e30,
64            objf: f64::NEG_INFINITY,
65            cycle: 0,
66            gamma_delta: vec![0.1; settings.errormodels().len()],
67            error_models: settings.errormodels().clone(),
68            converged: false,
69            status: Status::Starting,
70            cycle_log: CycleLog::new(),
71            settings,
72            data,
73        }))
74    }
75    fn into_npresult(&self) -> NPResult<E> {
76        NPResult::new(
77            self.equation.clone(),
78            self.data.clone(),
79            self.theta.clone(),
80            self.psi.clone(),
81            self.w.clone(),
82            -2. * self.objf,
83            self.cycle,
84            self.status.clone(),
85            self.settings.clone(),
86            self.cycle_log.clone(),
87        )
88    }
89
90    fn equation(&self) -> &E {
91        &self.equation
92    }
93
94    fn get_settings(&self) -> &Settings {
95        &self.settings
96    }
97
98    fn get_data(&self) -> &Data {
99        &self.data
100    }
101
102    fn get_prior(&self) -> Theta {
103        sample_space(&self.settings).unwrap()
104    }
105
106    fn inc_cycle(&mut self) -> usize {
107        self.cycle += 1;
108        self.cycle
109    }
110
111    fn get_cycle(&self) -> usize {
112        self.cycle
113    }
114
115    fn set_theta(&mut self, theta: Theta) {
116        self.theta = theta;
117    }
118
119    fn theta(&self) -> &Theta {
120        &self.theta
121    }
122
123    fn psi(&self) -> &Psi {
124        &self.psi
125    }
126
127    fn likelihood(&self) -> f64 {
128        self.objf
129    }
130
131    fn set_status(&mut self, status: Status) {
132        self.status = status;
133    }
134
135    fn status(&self) -> &Status {
136        &self.status
137    }
138
139    fn convergence_evaluation(&mut self) {
140        if (self.last_objf - self.objf).abs() <= THETA_F {
141            tracing::info!("Objective function convergence reached");
142            self.converged = true;
143            self.status = Status::Converged;
144        }
145
146        // Stop if we have reached maximum number of cycles
147        if self.cycle >= self.settings.config().cycles {
148            tracing::warn!("Maximum number of cycles reached");
149            self.converged = true;
150            self.status = Status::MaxCycles;
151        }
152
153        // Stop if stopfile exists
154        if std::path::Path::new("stop").exists() {
155            tracing::warn!("Stopfile detected - breaking");
156            self.converged = true;
157            self.status = Status::ManualStop;
158        }
159
160        // Create state object
161        let state = NPCycle::new(
162            self.cycle,
163            -2. * self.objf,
164            self.error_models.clone(),
165            self.theta.clone(),
166            self.theta.nspp(),
167            (self.last_objf - self.objf).abs(),
168            self.status.clone(),
169        );
170
171        // Write cycle log
172        self.cycle_log.push(state);
173        self.last_objf = self.objf;
174    }
175
176    fn converged(&self) -> bool {
177        self.converged
178    }
179
180    fn evaluation(&mut self) -> Result<()> {
181        let error_model: ErrorModels = self.error_models.clone();
182
183        self.psi = calculate_psi(
184            &self.equation,
185            &self.data,
186            &self.theta,
187            &error_model,
188            self.cycle == 1 && self.settings.config().progress,
189            self.cycle != 1,
190        )?;
191
192        if let Err(err) = self.validate_psi() {
193            bail!(err);
194        }
195
196        (self.lambda, _) = match burke(&self.psi) {
197            Ok((lambda, objf)) => (lambda, objf),
198            Err(err) => {
199                bail!(err);
200            }
201        };
202        Ok(())
203    }
204
205    fn condensation(&mut self) -> Result<()> {
206        let max_lambda = self
207            .lambda
208            .iter()
209            .fold(f64::NEG_INFINITY, |acc, x| x.max(acc));
210
211        let mut keep = Vec::<usize>::new();
212        for (index, lam) in self.lambda.iter().enumerate() {
213            if lam > max_lambda / 1000_f64 {
214                keep.push(index);
215            }
216        }
217        if self.psi.matrix().ncols() != keep.len() {
218            tracing::debug!(
219                "Lambda (max/1000) dropped {} support point(s)",
220                self.psi.matrix().ncols() - keep.len(),
221            );
222        }
223
224        self.theta.filter_indices(keep.as_slice());
225        self.psi.filter_column_indices(keep.as_slice());
226
227        //Rank-Revealing Factorization
228        let (r, perm) = qr::qrd(&self.psi)?;
229
230        let mut keep = Vec::<usize>::new();
231
232        // The minimum between the number of subjects and the actual number of support points
233        let keep_n = self.psi.matrix().ncols().min(self.psi.matrix().nrows());
234        for i in 0..keep_n {
235            let test = r.col(i).norm_l2();
236            let r_diag_val = r.get(i, i);
237            let ratio = r_diag_val / test;
238            if ratio.abs() >= 1e-8 {
239                keep.push(*perm.get(i).unwrap());
240            }
241        }
242
243        // If a support point is dropped, log it as a debug message
244        if self.psi.matrix().ncols() != keep.len() {
245            tracing::debug!(
246                "QR decomposition dropped {} support point(s)",
247                self.psi.matrix().ncols() - keep.len(),
248            );
249        }
250
251        self.theta.filter_indices(keep.as_slice());
252        self.psi.filter_column_indices(keep.as_slice());
253
254        (self.lambda, self.objf) = match burke(&self.psi) {
255            Ok((lambda, objf)) => (lambda, objf),
256            Err(err) => {
257                return Err(anyhow::anyhow!("Error in IPM: {:?}", err));
258            }
259        };
260        self.w = self.lambda.clone();
261        Ok(())
262    }
263
264    fn optimizations(&mut self) -> Result<()> {
265        self.error_models
266            .clone()
267            .iter_mut()
268            .filter_map(|(outeq, em)| {
269                if *em == ErrorModel::None || em.is_factor_fixed().unwrap_or(true) {
270                    None
271                } else {
272                    Some((outeq, em))
273                }
274            })
275            .try_for_each(|(outeq, em)| -> Result<()> {
276                // OPTIMIZATION
277
278                let gamma_up = em.factor()? * (1.0 + self.gamma_delta[outeq]);
279                let gamma_down = em.factor()? / (1.0 + self.gamma_delta[outeq]);
280
281                let mut error_model_up = self.error_models.clone();
282                error_model_up.set_factor(outeq, gamma_up)?;
283
284                let mut error_model_down = self.error_models.clone();
285                error_model_down.set_factor(outeq, gamma_down)?;
286
287                let psi_up = calculate_psi(
288                    &self.equation,
289                    &self.data,
290                    &self.theta,
291                    &error_model_up,
292                    false,
293                    true,
294                )?;
295                let psi_down = calculate_psi(
296                    &self.equation,
297                    &self.data,
298                    &self.theta,
299                    &error_model_down,
300                    false,
301                    true,
302                )?;
303
304                let (lambda_up, objf_up) = match burke(&psi_up) {
305                    Ok((lambda, objf)) => (lambda, objf),
306                    Err(err) => {
307                        //todo: write out report
308                        return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
309                    }
310                };
311                let (lambda_down, objf_down) = match burke(&psi_down) {
312                    Ok((lambda, objf)) => (lambda, objf),
313                    Err(err) => {
314                        //todo: write out report
315                        //panic!("Error in IPM: {:?}", err);
316                        return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
317                        //(Array1::zeros(1), f64::NEG_INFINITY)
318                    }
319                };
320                if objf_up > self.objf {
321                    self.error_models.set_factor(outeq, gamma_up)?;
322                    self.objf = objf_up;
323                    self.gamma_delta[outeq] *= 4.;
324                    self.lambda = lambda_up;
325                    self.psi = psi_up;
326                }
327                if objf_down > self.objf {
328                    self.error_models.set_factor(outeq, gamma_down)?;
329                    self.objf = objf_down;
330                    self.gamma_delta[outeq] *= 4.;
331                    self.lambda = lambda_down;
332                    self.psi = psi_down;
333                }
334                self.gamma_delta[outeq] *= 0.5;
335                if self.gamma_delta[outeq] <= 0.01 {
336                    self.gamma_delta[outeq] = 0.1;
337                }
338                Ok(())
339            })?;
340
341        Ok(())
342    }
343
344    fn logs(&self) {
345        tracing::info!("Objective function = {:.4}", -2.0 * self.objf);
346        tracing::debug!("Support points: {}", self.theta.nspp());
347        self.error_models.iter().for_each(|(outeq, em)| {
348            if ErrorModel::None == *em {
349                return;
350            }
351            tracing::debug!(
352                "Error model for outeq {}: {:.16}",
353                outeq,
354                em.factor().unwrap_or_default()
355            );
356        });
357        // Increasing objf signals instability or model misspecification.
358        if self.last_objf > self.objf + 1e-4 {
359            tracing::warn!(
360                "Objective function decreased from {:.4} to {:.4} (delta = {})",
361                -2.0 * self.last_objf,
362                -2.0 * self.objf,
363                -2.0 * self.last_objf - -2.0 * self.objf
364            );
365        }
366    }
367
368    fn expansion(&mut self) -> Result<()> {
369        // If no stop signal, add new point to theta based on the optimization of the D function
370        let psi = self.psi().matrix().as_ref().into_ndarray().to_owned();
371        let w: Array1<f64> = self.w.clone().iter().collect();
372        let pyl = psi.dot(&w);
373
374        // Add new point to theta based on the optimization of the D function
375        let error_model: ErrorModels = self.error_models.clone();
376
377        let mut candididate_points: Vec<Array1<f64>> = Vec::default();
378        for spp in self.theta.matrix().row_iter() {
379            let candidate: Vec<f64> = spp.iter().cloned().collect();
380            let spp = Array1::from(candidate);
381            candididate_points.push(spp.to_owned());
382        }
383        candididate_points.par_iter_mut().for_each(|spp| {
384            let optimizer = SppOptimizer::new(&self.equation, &self.data, &error_model, &pyl);
385            let candidate_point = optimizer.optimize_point(spp.to_owned()).unwrap();
386            *spp = candidate_point;
387            // add spp to theta
388            // recalculate psi
389            // re-run ipm to re-calculate w
390            // re-calculate pyl
391            // re-define a new optimization
392        });
393        for cp in candididate_points {
394            self.theta.suggest_point(cp.to_vec().as_slice(), THETA_D)?;
395        }
396        Ok(())
397    }
398}
399
400impl<E: Equation> NPOD<E> {
401    fn validate_psi(&mut self) -> Result<()> {
402        let mut psi = self.psi().matrix().as_ref().into_ndarray().to_owned();
403        // First coerce all NaN and infinite in psi to 0.0
404        if psi.iter().any(|x| x.is_nan() || x.is_infinite()) {
405            tracing::warn!("Psi contains NaN or Inf values, coercing to 0.0");
406            for i in 0..psi.nrows() {
407                for j in 0..psi.ncols() {
408                    let val = psi.get_mut((i, j)).unwrap();
409                    if val.is_nan() || val.is_infinite() {
410                        *val = 0.0;
411                    }
412                }
413            }
414        }
415
416        // Calculate the sum of each column in psi
417        let (_, col) = psi.dim();
418        let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
419        let plam = psi.dot(&ecol);
420        let w = 1. / &plam;
421
422        // Get the index of each element in `w` that is NaN or infinite
423        let indices: Vec<usize> = w
424            .iter()
425            .enumerate()
426            .filter(|(_, x)| x.is_nan() || x.is_infinite())
427            .map(|(i, _)| i)
428            .collect::<Vec<_>>();
429
430        // If any elements in `w` are NaN or infinite, return the subject IDs for each index
431        if !indices.is_empty() {
432            let subject: Vec<&Subject> = self.data.subjects();
433            let zero_probability_subjects: Vec<&String> =
434                indices.iter().map(|&i| subject[i].id()).collect();
435
436            return Err(anyhow::anyhow!(
437                "The probability of one or more subjects, given the model, is zero. The following subjects have zero probability: {:?}", zero_probability_subjects
438            ));
439        }
440
441        Ok(())
442    }
443}