// pmcore/algorithms/npod.rs

1use crate::{
2    prelude::{
3        algorithms::Algorithms,
4        routines::{
5            evaluation::{ipm::burke, qr},
6            output::{CycleLog, NPCycle, NPResult},
7            settings::Settings,
8        },
9    },
10    structs::{
11        psi::{calculate_psi, Psi},
12        theta::Theta,
13    },
14};
15use anyhow::bail;
16use anyhow::Result;
17use faer_ext::IntoNdarray;
18use pharmsol::{
19    prelude::{
20        data::{Data, ErrorModel},
21        simulator::Equation,
22    },
23    Subject,
24};
25
26use faer::Col;
27
28use ndarray::{
29    parallel::prelude::{IntoParallelRefMutIterator, ParallelIterator},
30    Array, Array1, ArrayBase, Dim, OwnedRepr,
31};
32
33use crate::routines::{initialization, optimization::SppOptimizer};
34
/// Convergence tolerance: a cycle is flagged as converged when `objf` changed by at
/// most this amount since the previous cycle.
const THETA_F: f64 = 0.01;
/// Tolerance handed to `Theta::suggest_point` when offering new support points
/// (presumably a minimum distance to existing points — confirm in `Theta`).
const THETA_D: f64 = 0.000_1;
37
38pub struct NPOD<E: Equation> {
39    equation: E,
40    ranges: Vec<(f64, f64)>,
41    psi: Psi,
42    theta: Theta,
43    lambda: Col<f64>,
44    w: Col<f64>,
45    last_objf: f64,
46    objf: f64,
47    cycle: usize,
48    gamma_delta: f64,
49    gamma: f64,
50    converged: bool,
51    cycle_log: CycleLog,
52    data: Data,
53    settings: Settings,
54}
55
56impl<E: Equation> Algorithms<E> for NPOD<E> {
57    fn new(settings: Settings, equation: E, data: Data) -> Result<Box<Self>, anyhow::Error> {
58        Ok(Box::new(Self {
59            equation,
60            ranges: settings.parameters().ranges(),
61            psi: Psi::new(),
62            theta: Theta::new(),
63            lambda: Col::zeros(0),
64            w: Col::zeros(0),
65            last_objf: -1e30,
66            objf: f64::NEG_INFINITY,
67            cycle: 0,
68            gamma_delta: 0.1,
69            gamma: settings.error().value,
70            converged: false,
71            cycle_log: CycleLog::new(),
72            settings,
73            data,
74        }))
75    }
76    fn into_npresult(&self) -> NPResult<E> {
77        NPResult::new(
78            self.equation.clone(),
79            self.data.clone(),
80            self.theta.clone(),
81            self.psi.clone(),
82            self.w.clone(),
83            -2. * self.objf,
84            self.cycle,
85            self.converged,
86            self.settings.clone(),
87            self.cycle_log.clone(),
88        )
89    }
90
91    fn equation(&self) -> &E {
92        &self.equation
93    }
94
95    fn get_settings(&self) -> &Settings {
96        &self.settings
97    }
98
99    fn get_data(&self) -> &Data {
100        &self.data
101    }
102
103    fn get_prior(&self) -> Theta {
104        initialization::sample_space(&self.settings).unwrap()
105    }
106
107    fn inc_cycle(&mut self) -> usize {
108        self.cycle += 1;
109        self.cycle
110    }
111
112    fn get_cycle(&self) -> usize {
113        self.cycle
114    }
115
116    fn set_theta(&mut self, theta: Theta) {
117        self.theta = theta;
118    }
119
120    fn get_theta(&self) -> &Theta {
121        &self.theta
122    }
123
124    fn psi(&self) -> &Psi {
125        &self.psi
126    }
127
128    fn likelihood(&self) -> f64 {
129        self.objf
130    }
131
132    fn convergence_evaluation(&mut self) {
133        if (self.last_objf - self.objf).abs() <= THETA_F {
134            tracing::info!("Objective function convergence reached");
135            self.converged = true;
136        }
137
138        // Stop if we have reached maximum number of cycles
139        if self.cycle >= self.settings.config().cycles {
140            tracing::warn!("Maximum number of cycles reached");
141            self.converged = true;
142        }
143
144        // Stop if stopfile exists
145        if std::path::Path::new("stop").exists() {
146            tracing::warn!("Stopfile detected - breaking");
147            self.converged = true;
148        }
149
150        // Create state object
151        let state = NPCycle {
152            cycle: self.cycle,
153            objf: -2. * self.objf,
154            delta_objf: (self.last_objf - self.objf).abs(),
155            nspp: self.theta.nspp(),
156            theta: self.theta.clone(),
157            gamlam: self.gamma,
158            converged: self.converged,
159        };
160
161        // Write cycle log
162        self.cycle_log.push(state);
163        self.last_objf = self.objf;
164    }
165
166    fn converged(&self) -> bool {
167        self.converged
168    }
169
170    fn evaluation(&mut self) -> Result<()> {
171        self.psi = calculate_psi(
172            &self.equation,
173            &self.data,
174            &self.theta,
175            &ErrorModel::new(
176                self.settings.error().poly,
177                self.gamma,
178                &self.settings.error().error_model().into(),
179            ),
180            self.cycle == 1 && self.settings.config().progress,
181            self.cycle != 1,
182        );
183
184        if let Err(err) = self.validate_psi() {
185            bail!(err);
186        }
187
188        (self.lambda, _) = match burke(&self.psi) {
189            Ok((lambda, objf)) => (lambda, objf),
190            Err(err) => {
191                bail!(err);
192            }
193        };
194        Ok(())
195    }
196
197    fn condensation(&mut self) -> Result<()> {
198        let max_lambda = self
199            .lambda
200            .iter()
201            .fold(f64::NEG_INFINITY, |acc, &x| x.max(acc));
202
203        let mut keep = Vec::<usize>::new();
204        for (index, lam) in self.lambda.iter().enumerate() {
205            if *lam > max_lambda / 1000_f64 {
206                keep.push(index);
207            }
208        }
209        if self.psi.matrix().ncols() != keep.len() {
210            tracing::debug!(
211                "Lambda (max/1000) dropped {} support point(s)",
212                self.psi.matrix().ncols() - keep.len(),
213            );
214        }
215
216        self.theta.filter_indices(keep.as_slice());
217        self.psi.filter_column_indices(keep.as_slice());
218
219        //Rank-Revealing Factorization
220        let (r, perm) = qr::qrd(&self.psi)?;
221
222        let mut keep = Vec::<usize>::new();
223
224        // The minimum between the number of subjects and the actual number of support points
225        let keep_n = self.psi.matrix().ncols().min(self.psi.matrix().nrows());
226        for i in 0..keep_n {
227            let test = r.col(i).norm_l2();
228            let r_diag_val = r.get(i, i);
229            let ratio = r_diag_val / test;
230            if ratio.abs() >= 1e-8 {
231                keep.push(*perm.get(i).unwrap());
232            }
233        }
234
235        // If a support point is dropped, log it as a debug message
236        if self.psi.matrix().ncols() != keep.len() {
237            tracing::debug!(
238                "QR decomposition dropped {} support point(s)",
239                self.psi.matrix().ncols() - keep.len(),
240            );
241        }
242
243        self.theta.filter_indices(keep.as_slice());
244        self.psi.filter_column_indices(keep.as_slice());
245
246        (self.lambda, self.objf) = match burke(&self.psi) {
247            Ok((lambda, objf)) => (lambda, objf),
248            Err(err) => {
249                return Err(anyhow::anyhow!("Error in IPM: {:?}", err));
250            }
251        };
252        self.w = self.lambda.clone();
253        Ok(())
254    }
255
256    fn optimizations(&mut self) -> Result<()> {
257        // Gam/Lam optimization
258        // TODO: Move this to e.g. /evaluation/error.rs
259        let gamma_up = self.gamma * (1.0 + self.gamma_delta);
260        let gamma_down = self.gamma / (1.0 + self.gamma_delta);
261
262        let psi_up = calculate_psi(
263            &self.equation,
264            &self.data,
265            &self.theta,
266            &ErrorModel::new(
267                self.settings.error().poly,
268                self.gamma,
269                &self.settings.error().error_model().into(),
270            ),
271            false,
272            true,
273        );
274        let psi_down = calculate_psi(
275            &self.equation,
276            &self.data,
277            &self.theta,
278            &ErrorModel::new(
279                self.settings.error().poly,
280                self.gamma,
281                &self.settings.error().error_model().into(),
282            ),
283            false,
284            true,
285        );
286
287        let (lambda_up, objf_up) = match burke(&psi_up) {
288            Ok((lambda, objf)) => (lambda, objf),
289            Err(err) => {
290                return Err(anyhow::anyhow!("Error in IPM: {:?}", err));
291            }
292        };
293        let (lambda_down, objf_down) = match burke(&psi_down) {
294            Ok((lambda, objf)) => (lambda, objf),
295            Err(err) => {
296                return Err(anyhow::anyhow!("Error in IPM: {:?}", err));
297            }
298        };
299
300        if objf_up > self.objf {
301            self.gamma = gamma_up;
302            self.objf = objf_up;
303            self.gamma_delta *= 4.;
304            self.lambda = lambda_up;
305            self.psi = psi_up;
306        }
307        if objf_down > self.objf {
308            self.gamma = gamma_down;
309            self.objf = objf_down;
310            self.gamma_delta *= 4.;
311            self.lambda = lambda_down;
312            self.psi = psi_down;
313        }
314        self.gamma_delta *= 0.5;
315        if self.gamma_delta <= 0.01 {
316            self.gamma_delta = 0.1;
317        }
318        Ok(())
319    }
320
321    fn logs(&self) {
322        tracing::info!("Objective function = {:.4}", -2.0 * self.objf);
323        tracing::debug!("Support points: {}", self.theta.nspp());
324        tracing::debug!("Gamma = {:.16}", self.gamma);
325        // Increasing objf signals instability or model misspecification.
326        if self.last_objf > self.objf + 1e-4 {
327            tracing::warn!(
328                "Objective function decreased from {:.4} to {:.4} (delta = {})",
329                -2.0 * self.last_objf,
330                -2.0 * self.objf,
331                -2.0 * self.last_objf - -2.0 * self.objf
332            );
333        }
334    }
335
336    fn expansion(&mut self) -> Result<()> {
337        // If no stop signal, add new point to theta based on the optimization of the D function
338        let psi = self.psi().matrix().as_ref().into_ndarray().to_owned();
339        let w: Array1<f64> = self.w.clone().iter().cloned().collect();
340        let pyl = psi.dot(&w);
341
342        // Add new point to theta based on the optimization of the D function
343        let error_type = self.settings.error().error_model().into();
344        let sigma = &ErrorModel::new(self.settings.error().poly, self.gamma, &error_type);
345
346        let mut candididate_points: Vec<Array1<f64>> = Vec::default();
347        for spp in self.theta.matrix().row_iter() {
348            let candidate: Vec<f64> = spp.iter().cloned().collect();
349            let spp = Array1::from(candidate);
350            candididate_points.push(spp.to_owned());
351        }
352        candididate_points.par_iter_mut().for_each(|spp| {
353            let optimizer = SppOptimizer::new(&self.equation, &self.data, sigma, &pyl);
354            let candidate_point = optimizer.optimize_point(spp.to_owned()).unwrap();
355            *spp = candidate_point;
356            // add spp to theta
357            // recalculate psi
358            // re-run ipm to re-calculate w
359            // re-calculate pyl
360            // re-define a new optimization
361        });
362        for cp in candididate_points {
363            self.theta
364                .suggest_point(cp.to_vec().as_slice(), THETA_D, &self.ranges);
365        }
366        Ok(())
367    }
368}
369
370impl<E: Equation> NPOD<E> {
371    fn validate_psi(&mut self) -> Result<()> {
372        let mut psi = self.psi().matrix().as_ref().into_ndarray().to_owned();
373        // First coerce all NaN and infinite in psi to 0.0
374        if psi.iter().any(|x| x.is_nan() || x.is_infinite()) {
375            tracing::warn!("Psi contains NaN or Inf values, coercing to 0.0");
376            for i in 0..psi.nrows() {
377                for j in 0..psi.ncols() {
378                    let val = psi.get_mut((i, j)).unwrap();
379                    if val.is_nan() || val.is_infinite() {
380                        *val = 0.0;
381                    }
382                }
383            }
384        }
385
386        // Calculate the sum of each column in psi
387        let (_, col) = psi.dim();
388        let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
389        let plam = psi.dot(&ecol);
390        let w = 1. / &plam;
391
392        // Get the index of each element in `w` that is NaN or infinite
393        let indices: Vec<usize> = w
394            .iter()
395            .enumerate()
396            .filter(|(_, x)| x.is_nan() || x.is_infinite())
397            .map(|(i, _)| i)
398            .collect::<Vec<_>>();
399
400        // If any elements in `w` are NaN or infinite, return the subject IDs for each index
401        if !indices.is_empty() {
402            let subject: Vec<&Subject> = self.data.get_subjects();
403            let zero_probability_subjects: Vec<&String> =
404                indices.iter().map(|&i| subject[i].id()).collect();
405
406            return Err(anyhow::anyhow!(
407                "The probability of one or more subjects, given the model, is zero. The following subjects have zero probability: {:?}", zero_probability_subjects
408            ));
409        }
410
411        Ok(())
412    }
413}