pmcore/algorithms/
npod.rs

1use crate::{
2    prelude::{
3        algorithms::Algorithms,
4        routines::{
5            evaluation::{ipm::burke, qr},
6            output::{CycleLog, NPCycle, NPResult},
7            settings::Settings,
8        },
9    },
10    structs::{
11        psi::{calculate_psi, Psi},
12        theta::Theta,
13    },
14};
15use anyhow::bail;
16use anyhow::Result;
17use faer_ext::IntoNdarray;
18use pharmsol::{
19    prelude::{
20        data::{Data, ErrorModel},
21        simulator::Equation,
22    },
23    Subject,
24};
25
26use faer::Col;
27
28use ndarray::{
29    parallel::prelude::{IntoParallelRefMutIterator, ParallelIterator},
30    Array, Array1, ArrayBase, Dim, OwnedRepr,
31};
32
33use crate::routines::{initialization, optimization::SppOptimizer};
34
/// Convergence tolerance: a cycle is considered converged once the absolute change in
/// `objf` since the previous cycle is at most this value.
const THETA_F: f64 = 0.01;
/// Threshold handed to `Theta::suggest_point` when offering optimized candidate points
/// (presumably a minimum distance to existing support points — confirm in `Theta`).
const THETA_D: f64 = 0.0001;
37
38pub struct NPOD<E: Equation> {
39    equation: E,
40    psi: Psi,
41    theta: Theta,
42    lambda: Col<f64>,
43    w: Col<f64>,
44    last_objf: f64,
45    objf: f64,
46    cycle: usize,
47    gamma_delta: f64,
48    error_model: ErrorModel,
49    converged: bool,
50    cycle_log: CycleLog,
51    data: Data,
52    settings: Settings,
53}
54
55impl<E: Equation> Algorithms<E> for NPOD<E> {
56    fn new(settings: Settings, equation: E, data: Data) -> Result<Box<Self>, anyhow::Error> {
57        Ok(Box::new(Self {
58            equation,
59            psi: Psi::new(),
60            theta: Theta::new(),
61            lambda: Col::zeros(0),
62            w: Col::zeros(0),
63            last_objf: -1e30,
64            objf: f64::NEG_INFINITY,
65            cycle: 0,
66            gamma_delta: 0.1,
67            error_model: settings.error().clone().into(),
68            converged: false,
69            cycle_log: CycleLog::new(),
70            settings,
71            data,
72        }))
73    }
74    fn into_npresult(&self) -> NPResult<E> {
75        NPResult::new(
76            self.equation.clone(),
77            self.data.clone(),
78            self.theta.clone(),
79            self.psi.clone(),
80            self.w.clone(),
81            -2. * self.objf,
82            self.cycle,
83            self.converged,
84            self.settings.clone(),
85            self.cycle_log.clone(),
86        )
87    }
88
89    fn equation(&self) -> &E {
90        &self.equation
91    }
92
93    fn get_settings(&self) -> &Settings {
94        &self.settings
95    }
96
97    fn get_data(&self) -> &Data {
98        &self.data
99    }
100
101    fn get_prior(&self) -> Theta {
102        initialization::sample_space(&self.settings).unwrap()
103    }
104
105    fn inc_cycle(&mut self) -> usize {
106        self.cycle += 1;
107        self.cycle
108    }
109
110    fn get_cycle(&self) -> usize {
111        self.cycle
112    }
113
114    fn set_theta(&mut self, theta: Theta) {
115        self.theta = theta;
116    }
117
118    fn get_theta(&self) -> &Theta {
119        &self.theta
120    }
121
122    fn psi(&self) -> &Psi {
123        &self.psi
124    }
125
126    fn likelihood(&self) -> f64 {
127        self.objf
128    }
129
130    fn convergence_evaluation(&mut self) {
131        if (self.last_objf - self.objf).abs() <= THETA_F {
132            tracing::info!("Objective function convergence reached");
133            self.converged = true;
134        }
135
136        // Stop if we have reached maximum number of cycles
137        if self.cycle >= self.settings.config().cycles {
138            tracing::warn!("Maximum number of cycles reached");
139            self.converged = true;
140        }
141
142        // Stop if stopfile exists
143        if std::path::Path::new("stop").exists() {
144            tracing::warn!("Stopfile detected - breaking");
145            self.converged = true;
146        }
147
148        // Create state object
149        let state = NPCycle {
150            cycle: self.cycle,
151            objf: -2. * self.objf,
152            delta_objf: (self.last_objf - self.objf).abs(),
153            nspp: self.theta.nspp(),
154            theta: self.theta.clone(),
155            gamlam: self.error_model.scalar(),
156            converged: self.converged,
157        };
158
159        // Write cycle log
160        self.cycle_log.push(state);
161        self.last_objf = self.objf;
162    }
163
164    fn converged(&self) -> bool {
165        self.converged
166    }
167
168    fn evaluation(&mut self) -> Result<()> {
169        let error_model: ErrorModel = self.error_model.clone();
170
171        self.psi = calculate_psi(
172            &self.equation,
173            &self.data,
174            &self.theta,
175            &error_model,
176            self.cycle == 1 && self.settings.config().progress,
177            self.cycle != 1,
178        )?;
179
180        if let Err(err) = self.validate_psi() {
181            bail!(err);
182        }
183
184        (self.lambda, _) = match burke(&self.psi) {
185            Ok((lambda, objf)) => (lambda, objf),
186            Err(err) => {
187                bail!(err);
188            }
189        };
190        Ok(())
191    }
192
193    fn condensation(&mut self) -> Result<()> {
194        let max_lambda = self
195            .lambda
196            .iter()
197            .fold(f64::NEG_INFINITY, |acc, &x| x.max(acc));
198
199        let mut keep = Vec::<usize>::new();
200        for (index, lam) in self.lambda.iter().enumerate() {
201            if *lam > max_lambda / 1000_f64 {
202                keep.push(index);
203            }
204        }
205        if self.psi.matrix().ncols() != keep.len() {
206            tracing::debug!(
207                "Lambda (max/1000) dropped {} support point(s)",
208                self.psi.matrix().ncols() - keep.len(),
209            );
210        }
211
212        self.theta.filter_indices(keep.as_slice());
213        self.psi.filter_column_indices(keep.as_slice());
214
215        //Rank-Revealing Factorization
216        let (r, perm) = qr::qrd(&self.psi)?;
217
218        let mut keep = Vec::<usize>::new();
219
220        // The minimum between the number of subjects and the actual number of support points
221        let keep_n = self.psi.matrix().ncols().min(self.psi.matrix().nrows());
222        for i in 0..keep_n {
223            let test = r.col(i).norm_l2();
224            let r_diag_val = r.get(i, i);
225            let ratio = r_diag_val / test;
226            if ratio.abs() >= 1e-8 {
227                keep.push(*perm.get(i).unwrap());
228            }
229        }
230
231        // If a support point is dropped, log it as a debug message
232        if self.psi.matrix().ncols() != keep.len() {
233            tracing::debug!(
234                "QR decomposition dropped {} support point(s)",
235                self.psi.matrix().ncols() - keep.len(),
236            );
237        }
238
239        self.theta.filter_indices(keep.as_slice());
240        self.psi.filter_column_indices(keep.as_slice());
241
242        (self.lambda, self.objf) = match burke(&self.psi) {
243            Ok((lambda, objf)) => (lambda, objf),
244            Err(err) => {
245                return Err(anyhow::anyhow!("Error in IPM: {:?}", err));
246            }
247        };
248        self.w = self.lambda.clone();
249        Ok(())
250    }
251
252    fn optimizations(&mut self) -> Result<()> {
253        // Gam/Lam optimization
254        // TODO: Move this to e.g. /evaluation/error.rs
255        let gamma_up = self.error_model.scalar() * (1.0 + self.gamma_delta);
256        let gamma_down = self.error_model.scalar() / (1.0 + self.gamma_delta);
257
258        let mut error_model_up: ErrorModel = self.error_model.clone();
259        error_model_up.set_scalar(gamma_up);
260
261        let mut error_model_down: ErrorModel = self.error_model.clone();
262        error_model_down.set_scalar(gamma_down);
263
264        let psi_up = calculate_psi(
265            &self.equation,
266            &self.data,
267            &self.theta,
268            &error_model_up,
269            false,
270            true,
271        )?;
272        let psi_down = calculate_psi(
273            &self.equation,
274            &self.data,
275            &self.theta,
276            &error_model_down,
277            false,
278            true,
279        )?;
280
281        let (lambda_up, objf_up) = match burke(&psi_up) {
282            Ok((lambda, objf)) => (lambda, objf),
283            Err(err) => {
284                return Err(anyhow::anyhow!("Error in IPM: {:?}", err));
285            }
286        };
287        let (lambda_down, objf_down) = match burke(&psi_down) {
288            Ok((lambda, objf)) => (lambda, objf),
289            Err(err) => {
290                return Err(anyhow::anyhow!("Error in IPM: {:?}", err));
291            }
292        };
293
294        if objf_up > self.objf {
295            self.error_model.set_scalar(gamma_up);
296            self.objf = objf_up;
297            self.gamma_delta *= 4.;
298            self.lambda = lambda_up;
299            self.psi = psi_up;
300        }
301        if objf_down > self.objf {
302            self.error_model.set_scalar(gamma_down);
303            self.objf = objf_down;
304            self.gamma_delta *= 4.;
305            self.lambda = lambda_down;
306            self.psi = psi_down;
307        }
308        self.gamma_delta *= 0.5;
309        if self.gamma_delta <= 0.01 {
310            self.gamma_delta = 0.1;
311        }
312        Ok(())
313    }
314
315    fn logs(&self) {
316        tracing::info!("Objective function = {:.4}", -2.0 * self.objf);
317        tracing::debug!("Support points: {}", self.theta.nspp());
318        tracing::debug!("Gamma = {:.16}", self.error_model.scalar());
319        // Increasing objf signals instability or model misspecification.
320        if self.last_objf > self.objf + 1e-4 {
321            tracing::warn!(
322                "Objective function decreased from {:.4} to {:.4} (delta = {})",
323                -2.0 * self.last_objf,
324                -2.0 * self.objf,
325                -2.0 * self.last_objf - -2.0 * self.objf
326            );
327        }
328    }
329
330    fn expansion(&mut self) -> Result<()> {
331        // If no stop signal, add new point to theta based on the optimization of the D function
332        let psi = self.psi().matrix().as_ref().into_ndarray().to_owned();
333        let w: Array1<f64> = self.w.clone().iter().cloned().collect();
334        let pyl = psi.dot(&w);
335
336        // Add new point to theta based on the optimization of the D function
337        let error_model: ErrorModel = self.error_model.clone();
338
339        let mut candididate_points: Vec<Array1<f64>> = Vec::default();
340        for spp in self.theta.matrix().row_iter() {
341            let candidate: Vec<f64> = spp.iter().cloned().collect();
342            let spp = Array1::from(candidate);
343            candididate_points.push(spp.to_owned());
344        }
345        candididate_points.par_iter_mut().for_each(|spp| {
346            let optimizer = SppOptimizer::new(&self.equation, &self.data, &error_model, &pyl);
347            let candidate_point = optimizer.optimize_point(spp.to_owned()).unwrap();
348            *spp = candidate_point;
349            // add spp to theta
350            // recalculate psi
351            // re-run ipm to re-calculate w
352            // re-calculate pyl
353            // re-define a new optimization
354        });
355        for cp in candididate_points {
356            self.theta.suggest_point(cp.to_vec().as_slice(), THETA_D);
357        }
358        Ok(())
359    }
360}
361
362impl<E: Equation> NPOD<E> {
363    fn validate_psi(&mut self) -> Result<()> {
364        let mut psi = self.psi().matrix().as_ref().into_ndarray().to_owned();
365        // First coerce all NaN and infinite in psi to 0.0
366        if psi.iter().any(|x| x.is_nan() || x.is_infinite()) {
367            tracing::warn!("Psi contains NaN or Inf values, coercing to 0.0");
368            for i in 0..psi.nrows() {
369                for j in 0..psi.ncols() {
370                    let val = psi.get_mut((i, j)).unwrap();
371                    if val.is_nan() || val.is_infinite() {
372                        *val = 0.0;
373                    }
374                }
375            }
376        }
377
378        // Calculate the sum of each column in psi
379        let (_, col) = psi.dim();
380        let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
381        let plam = psi.dot(&ecol);
382        let w = 1. / &plam;
383
384        // Get the index of each element in `w` that is NaN or infinite
385        let indices: Vec<usize> = w
386            .iter()
387            .enumerate()
388            .filter(|(_, x)| x.is_nan() || x.is_infinite())
389            .map(|(i, _)| i)
390            .collect::<Vec<_>>();
391
392        // If any elements in `w` are NaN or infinite, return the subject IDs for each index
393        if !indices.is_empty() {
394            let subject: Vec<&Subject> = self.data.get_subjects();
395            let zero_probability_subjects: Vec<&String> =
396                indices.iter().map(|&i| subject[i].id()).collect();
397
398            return Err(anyhow::anyhow!(
399                "The probability of one or more subjects, given the model, is zero. The following subjects have zero probability: {:?}", zero_probability_subjects
400            ));
401        }
402
403        Ok(())
404    }
405}