pmcore/algorithms/
npag.rs

1use crate::prelude::algorithms::Algorithms;
2
3pub use crate::routines::evaluation::ipm::burke;
4pub use crate::routines::evaluation::qr;
5use crate::routines::settings::Settings;
6
7use crate::routines::output::{CycleLog, NPCycle, NPResult};
8use crate::structs::psi::{calculate_psi, Psi};
9use crate::structs::theta::Theta;
10
11use anyhow::bail;
12use anyhow::Result;
13use pharmsol::prelude::{
14    data::{Data, ErrorModel},
15    simulator::Equation,
16};
17
18use faer::Col;
19
20use crate::routines::initialization;
21
22use crate::routines::expansion::adaptative_grid::adaptative_grid;
23
/// Lower bound on the grid step `eps`. While `eps` is above this value a stalled objective
/// halves it; once it has been halved to or below this value, the major-cycle test on
/// `|f1 - f0|` (see `THETA_F`) is performed.
const THETA_E: f64 = 1e-4;
/// Tolerance on the cycle-to-cycle change of the objective, `|last_objf - objf|`.
/// A change at or below this value halves `eps`.
const THETA_G: f64 = 1e-4;
/// Tolerance on the change of `sum(ln(psi · w))` between two major convergence checkpoints;
/// a change at or below this value marks the run as converged.
const THETA_F: f64 = 1e-2;
/// Final argument passed to `adaptative_grid` during expansion.
/// NOTE(review): presumably the minimum allowed distance between support points — confirm
/// against `adaptative_grid`.
const THETA_D: f64 = 1e-4;
28
29#[derive(Debug)]
30pub struct NPAG<E: Equation> {
31    equation: E,
32    ranges: Vec<(f64, f64)>,
33    psi: Psi,
34    theta: Theta,
35    lambda: Col<f64>,
36    w: Col<f64>,
37    eps: f64,
38    last_objf: f64,
39    objf: f64,
40    f0: f64,
41    f1: f64,
42    cycle: usize,
43    gamma_delta: f64,
44    error_model: ErrorModel,
45    converged: bool,
46    cycle_log: CycleLog,
47    data: Data,
48    settings: Settings,
49}
50
51impl<E: Equation> Algorithms<E> for NPAG<E> {
52    fn new(settings: Settings, equation: E, data: Data) -> Result<Box<Self>, anyhow::Error> {
53        Ok(Box::new(Self {
54            equation,
55            ranges: settings.parameters().ranges(),
56            psi: Psi::new(),
57            theta: Theta::new(),
58            lambda: Col::zeros(0),
59            w: Col::zeros(0),
60            eps: 0.2,
61            last_objf: -1e30,
62            objf: f64::NEG_INFINITY,
63            f0: -1e30,
64            f1: f64::default(),
65            cycle: 0,
66            gamma_delta: 0.1,
67            error_model: settings.error().clone().into(),
68            converged: false,
69            cycle_log: CycleLog::new(),
70            settings,
71            data,
72        }))
73    }
74
75    fn equation(&self) -> &E {
76        &self.equation
77    }
78    fn into_npresult(&self) -> NPResult<E> {
79        NPResult::new(
80            self.equation.clone(),
81            self.data.clone(),
82            self.theta.clone(),
83            self.psi.clone(),
84            self.w.clone(),
85            -2. * self.objf,
86            self.cycle,
87            self.converged,
88            self.settings.clone(),
89            self.cycle_log.clone(),
90        )
91    }
92
93    fn get_settings(&self) -> &Settings {
94        &self.settings
95    }
96
97    fn get_data(&self) -> &Data {
98        &self.data
99    }
100
101    fn get_prior(&self) -> Theta {
102        initialization::sample_space(&self.settings).unwrap()
103    }
104
105    fn likelihood(&self) -> f64 {
106        self.objf
107    }
108
109    fn inc_cycle(&mut self) -> usize {
110        self.cycle += 1;
111        self.cycle
112    }
113
114    fn get_cycle(&self) -> usize {
115        self.cycle
116    }
117
118    fn set_theta(&mut self, theta: Theta) {
119        self.theta = theta;
120    }
121
122    fn get_theta(&self) -> &Theta {
123        &self.theta
124    }
125
126    fn psi(&self) -> &Psi {
127        &self.psi
128    }
129
130    fn convergence_evaluation(&mut self) {
131        let psi = self.psi.matrix();
132        let w = &self.w;
133        if (self.last_objf - self.objf).abs() <= THETA_G && self.eps > THETA_E {
134            self.eps /= 2.;
135            if self.eps <= THETA_E {
136                let pyl = psi * w;
137                self.f1 = pyl.iter().map(|x| x.ln()).sum();
138                if (self.f1 - self.f0).abs() <= THETA_F {
139                    tracing::info!("The model converged after {} cycles", self.cycle,);
140                    self.converged = true;
141                } else {
142                    self.f0 = self.f1;
143                    self.eps = 0.2;
144                }
145            }
146        }
147
148        // Stop if we have reached maximum number of cycles
149        if self.cycle >= self.settings.config().cycles {
150            tracing::warn!("Maximum number of cycles reached");
151            self.converged = true;
152        }
153
154        // Stop if stopfile exists
155        if std::path::Path::new("stop").exists() {
156            tracing::warn!("Stopfile detected - breaking");
157            self.converged = true;
158        }
159
160        // Create state object
161        let state = NPCycle {
162            cycle: self.cycle,
163            objf: -2. * self.objf,
164            delta_objf: (self.last_objf - self.objf).abs(),
165            nspp: self.theta.nspp(),
166            theta: self.theta.clone(),
167            gamlam: self.error_model.scalar(),
168            converged: self.converged,
169        };
170
171        // Write cycle log
172        self.cycle_log.push(state);
173        self.last_objf = self.objf;
174    }
175
176    fn converged(&self) -> bool {
177        self.converged
178    }
179
180    fn evaluation(&mut self) -> Result<()> {
181        self.psi = calculate_psi(
182            &self.equation,
183            &self.data,
184            &self.theta,
185            &self.error_model,
186            self.cycle == 1 && self.settings.config().progress,
187            self.cycle != 1,
188        )?;
189
190        if let Err(err) = self.validate_psi() {
191            bail!(err);
192        }
193
194        (self.lambda, _) = match burke(&self.psi) {
195            Ok((lambda, objf)) => (lambda, objf),
196            Err(err) => {
197                bail!("Error in IPM during evaluation: {:?}", err);
198            }
199        };
200        Ok(())
201    }
202
203    fn condensation(&mut self) -> Result<()> {
204        // Filter out the support points with lambda < max(lambda)/1000
205
206        let max_lambda = self
207            .lambda
208            .iter()
209            .fold(f64::NEG_INFINITY, |acc, &x| x.max(acc));
210
211        let mut keep = Vec::<usize>::new();
212        for (index, lam) in self.lambda.iter().enumerate() {
213            if *lam > max_lambda / 1000_f64 {
214                keep.push(index);
215            }
216        }
217        if self.psi.matrix().ncols() != keep.len() {
218            tracing::debug!(
219                "Lambda (max/1000) dropped {} support point(s)",
220                self.psi.matrix().ncols() - keep.len(),
221            );
222        }
223
224        self.theta.filter_indices(keep.as_slice());
225        self.psi.filter_column_indices(keep.as_slice());
226
227        //Rank-Revealing Factorization
228        let (r, perm) = qr::qrd(&self.psi)?;
229
230        let mut keep = Vec::<usize>::new();
231
232        // The minimum between the number of subjects and the actual number of support points
233        let keep_n = self.psi.matrix().ncols().min(self.psi.matrix().nrows());
234        for i in 0..keep_n {
235            let test = r.col(i).norm_l2();
236            let r_diag_val = r.get(i, i);
237            let ratio = r_diag_val / test;
238            if ratio.abs() >= 1e-8 {
239                keep.push(*perm.get(i).unwrap());
240            }
241        }
242
243        // If a support point is dropped, log it as a debug message
244        if self.psi.matrix().ncols() != keep.len() {
245            tracing::debug!(
246                "QR decomposition dropped {} support point(s)",
247                self.psi.matrix().ncols() - keep.len(),
248            );
249        }
250
251        // Filter to keep only the support points (rows) that are in the `keep` vector
252        self.theta.filter_indices(keep.as_slice());
253        // Filter to keep only the support points (columns) that are in the `keep` vector
254        self.psi.filter_column_indices(keep.as_slice());
255
256        self.validate_psi()?;
257        (self.lambda, self.objf) = match burke(&self.psi) {
258            Ok((lambda, objf)) => (lambda, objf),
259            Err(err) => {
260                return Err(anyhow::anyhow!(
261                    "Error in IPM during condensation: {:?}",
262                    err
263                ));
264            }
265        };
266        self.w = self.lambda.clone();
267        Ok(())
268    }
269
270    fn optimizations(&mut self) -> Result<()> {
271        // Gam/Lam optimization
272        // TODO: Move this to e.g. /evaluation/error.rs
273        let gamma_up = self.error_model.scalar() * (1.0 + self.gamma_delta);
274        let gamma_down = self.error_model.scalar() / (1.0 + self.gamma_delta);
275
276        let mut error_model_up = self.error_model.clone();
277        error_model_up.set_scalar(gamma_up);
278
279        let mut error_model_down = self.error_model.clone();
280        error_model_down.set_scalar(gamma_down);
281
282        let psi_up = calculate_psi(
283            &self.equation,
284            &self.data,
285            &self.theta,
286            &error_model_up,
287            false,
288            true,
289        )?;
290        let psi_down = calculate_psi(
291            &self.equation,
292            &self.data,
293            &self.theta,
294            &error_model_down,
295            false,
296            true,
297        )?;
298
299        let (lambda_up, objf_up) = match burke(&psi_up) {
300            Ok((lambda, objf)) => (lambda, objf),
301            Err(err) => {
302                //todo: write out report
303                return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
304            }
305        };
306        let (lambda_down, objf_down) = match burke(&psi_down) {
307            Ok((lambda, objf)) => (lambda, objf),
308            Err(err) => {
309                //todo: write out report
310                //panic!("Error in IPM: {:?}", err);
311                return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
312                //(Array1::zeros(1), f64::NEG_INFINITY)
313            }
314        };
315        if objf_up > self.objf {
316            self.error_model.set_scalar(gamma_up);
317            self.objf = objf_up;
318            self.gamma_delta *= 4.;
319            self.lambda = lambda_up;
320            self.psi = psi_up;
321        }
322        if objf_down > self.objf {
323            self.error_model.set_scalar(gamma_down);
324            self.objf = objf_down;
325            self.gamma_delta *= 4.;
326            self.lambda = lambda_down;
327            self.psi = psi_down;
328        }
329        self.gamma_delta *= 0.5;
330        if self.gamma_delta <= 0.01 {
331            self.gamma_delta = 0.1;
332        }
333        Ok(())
334    }
335
336    fn logs(&self) {
337        tracing::info!("Objective function = {:.4}", -2.0 * self.objf);
338        tracing::debug!("Support points: {}", self.theta.nspp());
339        tracing::debug!("Gamma = {:.16}", self.error_model.scalar());
340        tracing::debug!("EPS = {:.4}", self.eps);
341        // Increasing objf signals instability or model misspecification.
342        if self.last_objf > self.objf + 1e-4 {
343            tracing::warn!(
344                "Objective function decreased from {:.4} to {:.4} (delta = {})",
345                -2.0 * self.last_objf,
346                -2.0 * self.objf,
347                -2.0 * self.last_objf - -2.0 * self.objf
348            );
349        }
350    }
351
352    fn expansion(&mut self) -> Result<()> {
353        adaptative_grid(&mut self.theta, self.eps, &self.ranges, THETA_D);
354        Ok(())
355    }
356}