pmcore/algorithms/
npod.rs

1use crate::algorithms::StopReason;
2use crate::routines::initialization::sample_space;
3use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult};
4use crate::structs::weights::Weights;
5use crate::{
6    algorithms::Status,
7    prelude::{
8        algorithms::Algorithms,
9        routines::{
10            estimation::{ipm::burke, qr},
11            settings::Settings,
12        },
13    },
14    structs::{
15        psi::{calculate_psi, Psi},
16        theta::Theta,
17    },
18};
19use pharmsol::SppOptimizer;
20
21use anyhow::bail;
22use anyhow::Result;
23use faer_ext::IntoNdarray;
24use pharmsol::{prelude::ErrorModel, ErrorModels};
25use pharmsol::{
26    prelude::{data::Data, simulator::Equation},
27    Subject,
28};
29
30use ndarray::{
31    parallel::prelude::{IntoParallelRefMutIterator, ParallelIterator},
32    Array, Array1, ArrayBase, Dim, OwnedRepr,
33};
34
/// Convergence tolerance on the objective function: `evaluation` stops the run once the
/// absolute change in `objf` between two consecutive cycles is at most this value.
const THETA_F: f64 = 1e-2;
/// Second argument handed to `Theta::suggest_point` for every candidate produced by
/// `expansion`.
// NOTE(review): presumably the minimum distance a candidate must keep from the existing
// support points to be accepted — confirm against `Theta::suggest_point`.
const THETA_D: f64 = 1e-4;
37
38pub struct NPOD<E: Equation + Send + 'static> {
39    equation: E,
40    psi: Psi,
41    theta: Theta,
42    lambda: Weights,
43    w: Weights,
44    last_objf: f64,
45    objf: f64,
46    cycle: usize,
47    gamma_delta: Vec<f64>,
48    error_models: ErrorModels,
49    converged: bool,
50    status: Status,
51    cycle_log: CycleLog,
52    data: Data,
53    settings: Settings,
54}
55
56impl<E: Equation + Send + 'static> Algorithms<E> for NPOD<E> {
57    fn new(settings: Settings, equation: E, data: Data) -> Result<Box<Self>, anyhow::Error> {
58        Ok(Box::new(Self {
59            equation,
60            psi: Psi::new(),
61            theta: Theta::new(),
62            lambda: Weights::default(),
63            w: Weights::default(),
64            last_objf: -1e30,
65            objf: f64::NEG_INFINITY,
66            cycle: 0,
67            gamma_delta: vec![0.1; settings.errormodels().len()],
68            error_models: settings.errormodels().clone(),
69            converged: false,
70            status: Status::Continue,
71            cycle_log: CycleLog::new(),
72            settings,
73            data,
74        }))
75    }
76    fn into_npresult(&self) -> NPResult<E> {
77        NPResult::new(
78            self.equation.clone(),
79            self.data.clone(),
80            self.theta.clone(),
81            self.psi.clone(),
82            self.w.clone(),
83            -2. * self.objf,
84            self.cycle,
85            self.status.clone(),
86            self.settings.clone(),
87            self.cycle_log.clone(),
88        )
89    }
90
91    fn equation(&self) -> &E {
92        &self.equation
93    }
94
95    fn settings(&self) -> &Settings {
96        &self.settings
97    }
98
99    fn data(&self) -> &Data {
100        &self.data
101    }
102
103    fn get_prior(&self) -> Theta {
104        sample_space(&self.settings).unwrap()
105    }
106
107    fn increment_cycle(&mut self) -> usize {
108        self.cycle += 1;
109        self.cycle
110    }
111
112    fn cycle(&self) -> usize {
113        self.cycle
114    }
115
116    fn set_theta(&mut self, theta: Theta) {
117        self.theta = theta;
118    }
119
120    fn theta(&self) -> &Theta {
121        &self.theta
122    }
123
124    fn psi(&self) -> &Psi {
125        &self.psi
126    }
127
128    fn likelihood(&self) -> f64 {
129        self.objf
130    }
131
132    fn set_status(&mut self, status: Status) {
133        self.status = status;
134    }
135
136    fn status(&self) -> &Status {
137        &self.status
138    }
139
140    fn log_cycle_state(&mut self) {
141        let state = NPCycle::new(
142            self.cycle,
143            -2. * self.objf,
144            self.error_models.clone(),
145            self.theta.clone(),
146            self.theta.nspp(),
147            (self.last_objf - self.objf).abs(),
148            self.status.clone(),
149        );
150        self.cycle_log.push(state);
151        self.last_objf = self.objf;
152    }
153
154    fn evaluation(&mut self) -> Result<Status> {
155        tracing::info!("Objective function = {:.4}", -2.0 * self.objf);
156        tracing::debug!("Support points: {}", self.theta.nspp());
157        self.error_models.iter().for_each(|(outeq, em)| {
158            if ErrorModel::None == *em {
159                return;
160            }
161            tracing::debug!(
162                "Error model for outeq {}: {:.16}",
163                outeq,
164                em.factor().unwrap_or_default()
165            );
166        });
167        // Increasing objf signals instability or model misspecification.
168        if self.last_objf > self.objf + 1e-4 {
169            tracing::warn!(
170                "Objective function decreased from {:.4} to {:.4} (delta = {})",
171                -2.0 * self.last_objf,
172                -2.0 * self.objf,
173                -2.0 * self.last_objf - -2.0 * self.objf
174            );
175        }
176
177        if (self.last_objf - self.objf).abs() <= THETA_F {
178            tracing::info!("Objective function convergence reached");
179            self.converged = true;
180            self.set_status(Status::Stop(StopReason::Converged));
181            self.log_cycle_state();
182            return Ok(self.status.clone());
183        }
184
185        // Stop if we have reached maximum number of cycles
186        if self.cycle >= self.settings.config().cycles {
187            tracing::warn!("Maximum number of cycles reached");
188            self.converged = true;
189            self.set_status(Status::Stop(StopReason::MaxCycles));
190            self.log_cycle_state();
191            return Ok(self.status.clone());
192        }
193
194        // Stop if stopfile exists
195        if std::path::Path::new("stop").exists() {
196            tracing::warn!("Stopfile detected - breaking");
197            self.converged = true;
198            self.set_status(Status::Stop(StopReason::Stopped));
199            self.log_cycle_state();
200            return Ok(self.status.clone());
201        }
202
203        // Continue with normal operation
204        self.status = Status::Continue;
205        self.log_cycle_state();
206        Ok(self.status.clone())
207    }
208
209    fn estimation(&mut self) -> Result<()> {
210        let error_model: ErrorModels = self.error_models.clone();
211
212        self.psi = calculate_psi(
213            &self.equation,
214            &self.data,
215            &self.theta,
216            &error_model,
217            self.cycle == 1 && self.settings.config().progress,
218            self.cycle != 1,
219        )?;
220
221        if let Err(err) = self.validate_psi() {
222            bail!(err);
223        }
224
225        (self.lambda, _) = match burke(&self.psi) {
226            Ok((lambda, objf)) => (lambda, objf),
227            Err(err) => {
228                bail!(err);
229            }
230        };
231        Ok(())
232    }
233
234    fn condensation(&mut self) -> Result<()> {
235        let max_lambda = self
236            .lambda
237            .iter()
238            .fold(f64::NEG_INFINITY, |acc, x| x.max(acc));
239
240        let mut keep = Vec::<usize>::new();
241        for (index, lam) in self.lambda.iter().enumerate() {
242            if lam > max_lambda / 1000_f64 {
243                keep.push(index);
244            }
245        }
246        if self.psi.matrix().ncols() != keep.len() {
247            tracing::debug!(
248                "Lambda (max/1000) dropped {} support point(s)",
249                self.psi.matrix().ncols() - keep.len(),
250            );
251        }
252
253        self.theta.filter_indices(keep.as_slice());
254        self.psi.filter_column_indices(keep.as_slice());
255
256        //Rank-Revealing Factorization
257        let (r, perm) = qr::qrd(&self.psi)?;
258
259        let mut keep = Vec::<usize>::new();
260
261        // The minimum between the number of subjects and the actual number of support points
262        let keep_n = self.psi.matrix().ncols().min(self.psi.matrix().nrows());
263        for i in 0..keep_n {
264            let test = r.col(i).norm_l2();
265            let r_diag_val = r.get(i, i);
266            let ratio = r_diag_val / test;
267            if ratio.abs() >= 1e-8 {
268                keep.push(*perm.get(i).unwrap());
269            }
270        }
271
272        // If a support point is dropped, log it as a debug message
273        if self.psi.matrix().ncols() != keep.len() {
274            tracing::debug!(
275                "QR decomposition dropped {} support point(s)",
276                self.psi.matrix().ncols() - keep.len(),
277            );
278        }
279
280        self.theta.filter_indices(keep.as_slice());
281        self.psi.filter_column_indices(keep.as_slice());
282
283        (self.lambda, self.objf) = match burke(&self.psi) {
284            Ok((lambda, objf)) => (lambda, objf),
285            Err(err) => {
286                return Err(anyhow::anyhow!("Error in IPM: {:?}", err));
287            }
288        };
289        self.w = self.lambda.clone();
290        Ok(())
291    }
292
293    fn optimizations(&mut self) -> Result<()> {
294        self.error_models
295            .clone()
296            .iter_mut()
297            .filter_map(|(outeq, em)| {
298                if *em == ErrorModel::None || em.is_factor_fixed().unwrap_or(true) {
299                    None
300                } else {
301                    Some((outeq, em))
302                }
303            })
304            .try_for_each(|(outeq, em)| -> Result<()> {
305                // OPTIMIZATION
306
307                let gamma_up = em.factor()? * (1.0 + self.gamma_delta[outeq]);
308                let gamma_down = em.factor()? / (1.0 + self.gamma_delta[outeq]);
309
310                let mut error_model_up = self.error_models.clone();
311                error_model_up.set_factor(outeq, gamma_up)?;
312
313                let mut error_model_down = self.error_models.clone();
314                error_model_down.set_factor(outeq, gamma_down)?;
315
316                let psi_up = calculate_psi(
317                    &self.equation,
318                    &self.data,
319                    &self.theta,
320                    &error_model_up,
321                    false,
322                    true,
323                )?;
324                let psi_down = calculate_psi(
325                    &self.equation,
326                    &self.data,
327                    &self.theta,
328                    &error_model_down,
329                    false,
330                    true,
331                )?;
332
333                let (lambda_up, objf_up) = match burke(&psi_up) {
334                    Ok((lambda, objf)) => (lambda, objf),
335                    Err(err) => {
336                        //todo: write out report
337                        return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
338                    }
339                };
340                let (lambda_down, objf_down) = match burke(&psi_down) {
341                    Ok((lambda, objf)) => (lambda, objf),
342                    Err(err) => {
343                        //todo: write out report
344                        //panic!("Error in IPM: {:?}", err);
345                        return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
346                        //(Array1::zeros(1), f64::NEG_INFINITY)
347                    }
348                };
349                if objf_up > self.objf {
350                    self.error_models.set_factor(outeq, gamma_up)?;
351                    self.objf = objf_up;
352                    self.gamma_delta[outeq] *= 4.;
353                    self.lambda = lambda_up;
354                    self.psi = psi_up;
355                }
356                if objf_down > self.objf {
357                    self.error_models.set_factor(outeq, gamma_down)?;
358                    self.objf = objf_down;
359                    self.gamma_delta[outeq] *= 4.;
360                    self.lambda = lambda_down;
361                    self.psi = psi_down;
362                }
363                self.gamma_delta[outeq] *= 0.5;
364                if self.gamma_delta[outeq] <= 0.01 {
365                    self.gamma_delta[outeq] = 0.1;
366                }
367                Ok(())
368            })?;
369
370        Ok(())
371    }
372
373    fn expansion(&mut self) -> Result<()> {
374        // If no stop signal, add new point to theta based on the optimization of the D function
375        let psi = self.psi().matrix().as_ref().into_ndarray().to_owned();
376        let w: Array1<f64> = self.w.clone().iter().collect();
377        let pyl = psi.dot(&w);
378
379        // Add new point to theta based on the optimization of the D function
380        let error_model: ErrorModels = self.error_models.clone();
381
382        let mut candididate_points: Vec<Array1<f64>> = Vec::default();
383        for spp in self.theta.matrix().row_iter() {
384            let candidate: Vec<f64> = spp.iter().cloned().collect();
385            let spp = Array1::from(candidate);
386            candididate_points.push(spp.to_owned());
387        }
388        candididate_points.par_iter_mut().for_each(|spp| {
389            let optimizer = SppOptimizer::new(&self.equation, &self.data, &error_model, &pyl);
390            let candidate_point = optimizer.optimize_point(spp.to_owned()).unwrap();
391            *spp = candidate_point;
392            // add spp to theta
393            // recalculate psi
394            // re-run ipm to re-calculate w
395            // re-calculate pyl
396            // re-define a new optimization
397        });
398        for cp in candididate_points {
399            self.theta.suggest_point(cp.to_vec().as_slice(), THETA_D)?;
400        }
401        Ok(())
402    }
403}
404
405impl<E: Equation + Send + 'static> NPOD<E> {
406    fn validate_psi(&mut self) -> Result<()> {
407        let mut psi = self.psi().matrix().as_ref().into_ndarray().to_owned();
408        // First coerce all NaN and infinite in psi to 0.0
409        if psi.iter().any(|x| x.is_nan() || x.is_infinite()) {
410            tracing::warn!("Psi contains NaN or Inf values, coercing to 0.0");
411            for i in 0..psi.nrows() {
412                for j in 0..psi.ncols() {
413                    let val = psi.get_mut((i, j)).unwrap();
414                    if val.is_nan() || val.is_infinite() {
415                        *val = 0.0;
416                    }
417                }
418            }
419        }
420
421        // Calculate the sum of each column in psi
422        let (_, col) = psi.dim();
423        let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
424        let plam = psi.dot(&ecol);
425        let w = 1. / &plam;
426
427        // Get the index of each element in `w` that is NaN or infinite
428        let indices: Vec<usize> = w
429            .iter()
430            .enumerate()
431            .filter(|(_, x)| x.is_nan() || x.is_infinite())
432            .map(|(i, _)| i)
433            .collect::<Vec<_>>();
434
435        // If any elements in `w` are NaN or infinite, return the subject IDs for each index
436        if !indices.is_empty() {
437            let subject: Vec<&Subject> = self.data.subjects();
438            let zero_probability_subjects: Vec<&String> =
439                indices.iter().map(|&i| subject[i].id()).collect();
440
441            return Err(anyhow::anyhow!(
442                "The probability of one or more subjects, given the model, is zero. The following subjects have zero probability: {:?}", zero_probability_subjects
443            ));
444        }
445
446        Ok(())
447    }
448}