pmcore/algorithms/
npod.rs

1use crate::algorithms::StopReason;
2use crate::routines::initialization::sample_space;
3use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult};
4use crate::structs::weights::Weights;
5use crate::{
6    algorithms::Status,
7    prelude::{
8        algorithms::Algorithms,
9        routines::{
10            estimation::{ipm::burke, qr},
11            settings::Settings,
12        },
13    },
14    structs::{
15        psi::{calculate_psi, Psi},
16        theta::Theta,
17    },
18};
19use pharmsol::SppOptimizer;
20
21use anyhow::bail;
22use anyhow::Result;
23use faer_ext::IntoNdarray;
24use pharmsol::{prelude::ErrorModel, ErrorModels};
25use pharmsol::{
26    prelude::{data::Data, simulator::Equation},
27    Subject,
28};
29
30use ndarray::{
31    parallel::prelude::{IntoParallelRefMutIterator, ParallelIterator},
32    Array, Array1, ArrayBase, Dim, OwnedRepr,
33};
34
/// Convergence tolerance on the objective function: `evaluation` stops once the
/// absolute change in `objf` between two consecutive cycles is at most this value.
const THETA_F: f64 = 1e-2;
/// Threshold handed to `Theta::suggest_point` when `expansion` proposes new
/// support points (presumably a minimum distance to existing points — confirm).
const THETA_D: f64 = 1e-4;
37
38pub struct NPOD<E: Equation + Send + 'static> {
39    equation: E,
40    psi: Psi,
41    theta: Theta,
42    lambda: Weights,
43    w: Weights,
44    last_objf: f64,
45    objf: f64,
46    cycle: usize,
47    gamma_delta: Vec<f64>,
48    error_models: ErrorModels,
49    converged: bool,
50    status: Status,
51    cycle_log: CycleLog,
52    data: Data,
53    settings: Settings,
54}
55
56impl<E: Equation + Send + 'static> Algorithms<E> for NPOD<E> {
57    fn new(settings: Settings, equation: E, data: Data) -> Result<Box<Self>, anyhow::Error> {
58        Ok(Box::new(Self {
59            equation,
60            psi: Psi::new(),
61            theta: Theta::new(),
62            lambda: Weights::default(),
63            w: Weights::default(),
64            last_objf: -1e30,
65            objf: f64::NEG_INFINITY,
66            cycle: 0,
67            gamma_delta: vec![0.1; settings.errormodels().len()],
68            error_models: settings.errormodels().clone(),
69            converged: false,
70            status: Status::Continue,
71            cycle_log: CycleLog::new(),
72            settings,
73            data,
74        }))
75    }
76    fn into_npresult(&self) -> Result<NPResult<E>> {
77        NPResult::new(
78            self.equation.clone(),
79            self.data.clone(),
80            self.theta.clone(),
81            self.psi.clone(),
82            self.w.clone(),
83            -2. * self.objf,
84            self.cycle,
85            self.status.clone(),
86            self.settings.clone(),
87            self.cycle_log.clone(),
88        )
89    }
90
91    fn equation(&self) -> &E {
92        &self.equation
93    }
94
95    fn settings(&self) -> &Settings {
96        &self.settings
97    }
98
99    fn data(&self) -> &Data {
100        &self.data
101    }
102
103    fn get_prior(&self) -> Theta {
104        sample_space(&self.settings).unwrap()
105    }
106
107    fn increment_cycle(&mut self) -> usize {
108        self.cycle += 1;
109        self.cycle
110    }
111
112    fn cycle(&self) -> usize {
113        self.cycle
114    }
115
116    fn set_theta(&mut self, theta: Theta) {
117        self.theta = theta;
118    }
119
120    fn theta(&self) -> &Theta {
121        &self.theta
122    }
123
124    fn psi(&self) -> &Psi {
125        &self.psi
126    }
127
128    fn likelihood(&self) -> f64 {
129        self.objf
130    }
131
132    fn set_status(&mut self, status: Status) {
133        self.status = status;
134    }
135
136    fn status(&self) -> &Status {
137        &self.status
138    }
139
140    fn log_cycle_state(&mut self) {
141        let state = NPCycle::new(
142            self.cycle,
143            -2. * self.objf,
144            self.error_models.clone(),
145            self.theta.clone(),
146            self.theta.nspp(),
147            (self.last_objf - self.objf).abs(),
148            self.status.clone(),
149        );
150        self.cycle_log.push(state);
151        self.last_objf = self.objf;
152    }
153
154    fn evaluation(&mut self) -> Result<Status> {
155        tracing::info!("Objective function = {:.4}", -2.0 * self.objf);
156        tracing::debug!("Support points: {}", self.theta.nspp());
157        self.error_models.iter().for_each(|(outeq, em)| {
158            if ErrorModel::None == *em {
159                return;
160            }
161            tracing::debug!(
162                "Error model for outeq {}: {:.16}",
163                outeq,
164                em.factor().unwrap_or_default()
165            );
166        });
167        // Increasing objf signals instability or model misspecification.
168        if self.last_objf > self.objf + 1e-4 {
169            tracing::warn!(
170                "Objective function decreased from {:.4} to {:.4} (delta = {})",
171                -2.0 * self.last_objf,
172                -2.0 * self.objf,
173                -2.0 * self.last_objf - -2.0 * self.objf
174            );
175        }
176
177        if (self.last_objf - self.objf).abs() <= THETA_F {
178            tracing::info!("Objective function convergence reached");
179            self.converged = true;
180            self.set_status(Status::Stop(StopReason::Converged));
181            self.log_cycle_state();
182            return Ok(self.status.clone());
183        }
184
185        // Stop if we have reached maximum number of cycles
186        if self.cycle >= self.settings.config().cycles {
187            tracing::warn!("Maximum number of cycles reached");
188            self.converged = true;
189            self.set_status(Status::Stop(StopReason::MaxCycles));
190            self.log_cycle_state();
191            return Ok(self.status.clone());
192        }
193
194        // Stop if stopfile exists
195        if std::path::Path::new("stop").exists() {
196            tracing::warn!("Stopfile detected - breaking");
197            self.converged = true;
198            self.set_status(Status::Stop(StopReason::Stopped));
199            self.log_cycle_state();
200            return Ok(self.status.clone());
201        }
202
203        // Continue with normal operation
204        self.status = Status::Continue;
205        self.log_cycle_state();
206        Ok(self.status.clone())
207    }
208
209    fn estimation(&mut self) -> Result<()> {
210        let error_model: ErrorModels = self.error_models.clone();
211
212        self.psi = calculate_psi(
213            &self.equation,
214            &self.data,
215            &self.theta,
216            &error_model,
217            self.cycle == 1 && self.settings.config().progress,
218            self.cycle != 1,
219        )?;
220
221        if let Err(err) = self.validate_psi() {
222            bail!(err);
223        }
224
225        (self.lambda, _) = match burke(&self.psi) {
226            Ok((lambda, objf)) => (lambda, objf),
227            Err(err) => {
228                bail!(err);
229            }
230        };
231        Ok(())
232    }
233
234    fn condensation(&mut self) -> Result<()> {
235        let max_lambda = self
236            .lambda
237            .iter()
238            .fold(f64::NEG_INFINITY, |acc, x| x.max(acc));
239
240        let mut keep = Vec::<usize>::new();
241        for (index, lam) in self.lambda.iter().enumerate() {
242            if lam > max_lambda / 1000_f64 {
243                keep.push(index);
244            }
245        }
246        if self.psi.matrix().ncols() != keep.len() {
247            tracing::debug!(
248                "Lambda (max/1000) dropped {} support point(s)",
249                self.psi.matrix().ncols() - keep.len(),
250            );
251        }
252
253        self.theta.filter_indices(keep.as_slice());
254        self.psi.filter_column_indices(keep.as_slice());
255
256        //Rank-Revealing Factorization
257        let (r, perm) = qr::qrd(&self.psi)?;
258
259        let mut keep = Vec::<usize>::new();
260
261        // The minimum between the number of subjects and the actual number of support points
262        let keep_n = self.psi.matrix().ncols().min(self.psi.matrix().nrows());
263        for i in 0..keep_n {
264            let test = r.col(i).norm_l2();
265            let r_diag_val = r.get(i, i);
266            let ratio = r_diag_val / test;
267            if ratio.abs() >= 1e-8 {
268                keep.push(*perm.get(i).unwrap());
269            }
270        }
271
272        // If a support point is dropped, log it as a debug message
273        if self.psi.matrix().ncols() != keep.len() {
274            tracing::debug!(
275                "QR decomposition dropped {} support point(s)",
276                self.psi.matrix().ncols() - keep.len(),
277            );
278        }
279
280        self.theta.filter_indices(keep.as_slice());
281        self.psi.filter_column_indices(keep.as_slice());
282
283        (self.lambda, self.objf) = match burke(&self.psi) {
284            Ok((lambda, objf)) => (lambda, objf),
285            Err(err) => {
286                return Err(anyhow::anyhow!("Error in IPM: {:?}", err));
287            }
288        };
289        self.w = self.lambda.clone();
290        Ok(())
291    }
292
293    fn optimizations(&mut self) -> Result<()> {
294        self.error_models
295            .clone()
296            .iter_mut()
297            .filter_map(|(outeq, em)| {
298                if *em == ErrorModel::None || em.is_factor_fixed().unwrap_or(true) {
299                    None
300                } else {
301                    Some((outeq, em))
302                }
303            })
304            .try_for_each(|(outeq, em)| -> Result<()> {
305                // OPTIMIZATION
306
307                let gamma_up = em.factor()? * (1.0 + self.gamma_delta[outeq]);
308                let gamma_down = em.factor()? / (1.0 + self.gamma_delta[outeq]);
309
310                let mut error_model_up = self.error_models.clone();
311                error_model_up.set_factor(outeq, gamma_up)?;
312
313                let mut error_model_down = self.error_models.clone();
314                error_model_down.set_factor(outeq, gamma_down)?;
315
316                let psi_up = calculate_psi(
317                    &self.equation,
318                    &self.data,
319                    &self.theta,
320                    &error_model_up,
321                    false,
322                    true,
323                )?;
324                let psi_down = calculate_psi(
325                    &self.equation,
326                    &self.data,
327                    &self.theta,
328                    &error_model_down,
329                    false,
330                    true,
331                )?;
332
333                let (lambda_up, objf_up) = match burke(&psi_up) {
334                    Ok((lambda, objf)) => (lambda, objf),
335                    Err(err) => {
336                        bail!("Error in IPM during optim: {:?}", err);
337                    }
338                };
339                let (lambda_down, objf_down) = match burke(&psi_down) {
340                    Ok((lambda, objf)) => (lambda, objf),
341                    Err(err) => {
342                        bail!("Error in IPM during optim: {:?}", err);
343                    }
344                };
345                if objf_up > self.objf {
346                    self.error_models.set_factor(outeq, gamma_up)?;
347                    self.objf = objf_up;
348                    self.gamma_delta[outeq] *= 4.;
349                    self.lambda = lambda_up;
350                    self.psi = psi_up;
351                }
352                if objf_down > self.objf {
353                    self.error_models.set_factor(outeq, gamma_down)?;
354                    self.objf = objf_down;
355                    self.gamma_delta[outeq] *= 4.;
356                    self.lambda = lambda_down;
357                    self.psi = psi_down;
358                }
359                self.gamma_delta[outeq] *= 0.5;
360                if self.gamma_delta[outeq] <= 0.01 {
361                    self.gamma_delta[outeq] = 0.1;
362                }
363                Ok(())
364            })?;
365
366        Ok(())
367    }
368
369    fn expansion(&mut self) -> Result<()> {
370        // If no stop signal, add new point to theta based on the optimization of the D function
371        let psi = self.psi().matrix().as_ref().into_ndarray().to_owned();
372        let w: Array1<f64> = self.w.clone().iter().collect();
373        let pyl = psi.dot(&w);
374
375        // Add new point to theta based on the optimization of the D function
376        let error_model: ErrorModels = self.error_models.clone();
377
378        let mut candididate_points: Vec<Array1<f64>> = Vec::default();
379        for spp in self.theta.matrix().row_iter() {
380            let candidate: Vec<f64> = spp.iter().cloned().collect();
381            let spp = Array1::from(candidate);
382            candididate_points.push(spp.to_owned());
383        }
384        candididate_points.par_iter_mut().for_each(|spp| {
385            let optimizer = SppOptimizer::new(&self.equation, &self.data, &error_model, &pyl);
386            let candidate_point = optimizer.optimize_point(spp.to_owned()).unwrap();
387            *spp = candidate_point;
388            // add spp to theta
389            // recalculate psi
390            // re-run ipm to re-calculate w
391            // re-calculate pyl
392            // re-define a new optimization
393        });
394        for cp in candididate_points {
395            self.theta.suggest_point(cp.to_vec().as_slice(), THETA_D)?;
396        }
397        Ok(())
398    }
399}
400
401impl<E: Equation + Send + 'static> NPOD<E> {
402    fn validate_psi(&mut self) -> Result<()> {
403        let mut psi = self.psi().matrix().as_ref().into_ndarray().to_owned();
404        // First coerce all NaN and infinite in psi to 0.0
405        if psi.iter().any(|x| x.is_nan() || x.is_infinite()) {
406            tracing::warn!("Psi contains NaN or Inf values, coercing to 0.0");
407            for i in 0..psi.nrows() {
408                for j in 0..psi.ncols() {
409                    let val = psi.get_mut((i, j)).unwrap();
410                    if val.is_nan() || val.is_infinite() {
411                        *val = 0.0;
412                    }
413                }
414            }
415        }
416
417        // Calculate the sum of each column in psi
418        let (_, col) = psi.dim();
419        let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
420        let plam = psi.dot(&ecol);
421        let w = 1. / &plam;
422
423        // Get the index of each element in `w` that is NaN or infinite
424        let indices: Vec<usize> = w
425            .iter()
426            .enumerate()
427            .filter(|(_, x)| x.is_nan() || x.is_infinite())
428            .map(|(i, _)| i)
429            .collect::<Vec<_>>();
430
431        // If any elements in `w` are NaN or infinite, return the subject IDs for each index
432        if !indices.is_empty() {
433            let subject: Vec<&Subject> = self.data.subjects();
434            let zero_probability_subjects: Vec<&String> =
435                indices.iter().map(|&i| subject[i].id()).collect();
436
437            return Err(anyhow::anyhow!(
438                "The probability of one or more subjects, given the model, is zero. The following subjects have zero probability: {:?}", zero_probability_subjects
439            ));
440        }
441
442        Ok(())
443    }
444}