pmcore/algorithms/
npag.rs

1use crate::prelude::algorithms::Algorithms;
2
3pub use crate::routines::evaluation::ipm::burke;
4pub use crate::routines::evaluation::qr;
5use crate::routines::settings::Settings;
6
7use crate::routines::output::{CycleLog, NPCycle, NPResult};
8use crate::structs::psi::{calculate_psi, Psi};
9use crate::structs::theta::Theta;
10
11use anyhow::bail;
12use anyhow::Result;
13use pharmsol::prelude::{
14    data::{Data, ErrorModel, ErrorType},
15    simulator::Equation,
16};
17
18use faer::Col;
19
20use crate::routines::initialization;
21
22use crate::routines::expansion::adaptative_grid::adaptative_grid;
23
/// Lower bound for the grid-expansion step `eps`: once repeated halving brings
/// `eps` to or below this value, the `THETA_F` likelihood check is performed.
const THETA_E: f64 = 1.0e-4;
/// Largest change in the objective between two consecutive cycles that is still
/// treated as "no improvement" (which triggers halving of `eps`).
const THETA_G: f64 = 1.0e-4;
/// Tolerance on `|f1 - f0|` below which the run is declared converged.
const THETA_F: f64 = 1.0e-2;
/// Threshold handed to `adaptative_grid` during expansion
/// (presumably a minimum distance between support points — TODO confirm).
const THETA_D: f64 = 1.0e-4;
28
29#[derive(Debug)]
30pub struct NPAG<E: Equation> {
31    equation: E,
32    ranges: Vec<(f64, f64)>,
33    psi: Psi,
34    theta: Theta,
35    lambda: Col<f64>,
36    w: Col<f64>,
37    eps: f64,
38    last_objf: f64,
39    objf: f64,
40    f0: f64,
41    f1: f64,
42    cycle: usize,
43    gamma_delta: f64,
44    gamma: f64,
45    error_type: ErrorType,
46    converged: bool,
47    cycle_log: CycleLog,
48    data: Data,
49    settings: Settings,
50}
51
52impl<E: Equation> Algorithms<E> for NPAG<E> {
53    fn new(settings: Settings, equation: E, data: Data) -> Result<Box<Self>, anyhow::Error> {
54        Ok(Box::new(Self {
55            equation,
56            ranges: settings.parameters().ranges(),
57            psi: Psi::new(),
58            theta: Theta::new(),
59            lambda: Col::zeros(0),
60            w: Col::zeros(0),
61            eps: 0.2,
62            last_objf: -1e30,
63            objf: f64::NEG_INFINITY,
64            f0: -1e30,
65            f1: f64::default(),
66            cycle: 0,
67            gamma_delta: 0.1,
68            gamma: settings.error().value,
69            error_type: settings.error().error_model().into(),
70            converged: false,
71            cycle_log: CycleLog::new(),
72            settings,
73            data,
74        }))
75    }
76
77    fn equation(&self) -> &E {
78        &self.equation
79    }
80    fn into_npresult(&self) -> NPResult<E> {
81        NPResult::new(
82            self.equation.clone(),
83            self.data.clone(),
84            self.theta.clone(),
85            self.psi.clone(),
86            self.w.clone(),
87            -2. * self.objf,
88            self.cycle,
89            self.converged,
90            self.settings.clone(),
91            self.cycle_log.clone(),
92        )
93    }
94
95    fn get_settings(&self) -> &Settings {
96        &self.settings
97    }
98
99    fn get_data(&self) -> &Data {
100        &self.data
101    }
102
103    fn get_prior(&self) -> Theta {
104        initialization::sample_space(&self.settings).unwrap()
105    }
106
107    fn likelihood(&self) -> f64 {
108        self.objf
109    }
110
111    fn inc_cycle(&mut self) -> usize {
112        self.cycle += 1;
113        self.cycle
114    }
115
116    fn get_cycle(&self) -> usize {
117        self.cycle
118    }
119
120    fn set_theta(&mut self, theta: Theta) {
121        self.theta = theta;
122    }
123
124    fn get_theta(&self) -> &Theta {
125        &self.theta
126    }
127
128    fn psi(&self) -> &Psi {
129        &self.psi
130    }
131
132    fn convergence_evaluation(&mut self) {
133        let psi = self.psi.matrix();
134        let w = &self.w;
135        if (self.last_objf - self.objf).abs() <= THETA_G && self.eps > THETA_E {
136            self.eps /= 2.;
137            if self.eps <= THETA_E {
138                let pyl = psi * w;
139                self.f1 = pyl.iter().map(|x| x.ln()).sum();
140                if (self.f1 - self.f0).abs() <= THETA_F {
141                    tracing::info!("The model converged after {} cycles", self.cycle,);
142                    self.converged = true;
143                } else {
144                    self.f0 = self.f1;
145                    self.eps = 0.2;
146                }
147            }
148        }
149
150        // Stop if we have reached maximum number of cycles
151        if self.cycle >= self.settings.config().cycles {
152            tracing::warn!("Maximum number of cycles reached");
153            self.converged = true;
154        }
155
156        // Stop if stopfile exists
157        if std::path::Path::new("stop").exists() {
158            tracing::warn!("Stopfile detected - breaking");
159            self.converged = true;
160        }
161
162        // Create state object
163        let state = NPCycle {
164            cycle: self.cycle,
165            objf: -2. * self.objf,
166            delta_objf: (self.last_objf - self.objf).abs(),
167            nspp: self.theta.nspp(),
168            theta: self.theta.clone(),
169            gamlam: self.gamma,
170            converged: self.converged,
171        };
172
173        // Write cycle log
174        self.cycle_log.push(state);
175        self.last_objf = self.objf;
176    }
177
178    fn converged(&self) -> bool {
179        self.converged
180    }
181
182    fn evaluation(&mut self) -> Result<()> {
183        self.psi = calculate_psi(
184            &self.equation,
185            &self.data,
186            &self.theta,
187            &ErrorModel::new(self.settings.error().poly, self.gamma, &self.error_type),
188            self.cycle == 1 && self.settings.config().progress,
189            self.cycle != 1,
190        );
191
192        if let Err(err) = self.validate_psi() {
193            bail!(err);
194        }
195
196        (self.lambda, _) = match burke(&self.psi) {
197            Ok((lambda, objf)) => (lambda, objf),
198            Err(err) => {
199                bail!("Error in IPM during evaluation: {:?}", err);
200            }
201        };
202        Ok(())
203    }
204
205    fn condensation(&mut self) -> Result<()> {
206        // Filter out the support points with lambda < max(lambda)/1000
207
208        let max_lambda = self
209            .lambda
210            .iter()
211            .fold(f64::NEG_INFINITY, |acc, &x| x.max(acc));
212
213        let mut keep = Vec::<usize>::new();
214        for (index, lam) in self.lambda.iter().enumerate() {
215            if *lam > max_lambda / 1000_f64 {
216                keep.push(index);
217            }
218        }
219        if self.psi.matrix().ncols() != keep.len() {
220            tracing::debug!(
221                "Lambda (max/1000) dropped {} support point(s)",
222                self.psi.matrix().ncols() - keep.len(),
223            );
224        }
225
226        self.theta.filter_indices(keep.as_slice());
227        self.psi.filter_column_indices(keep.as_slice());
228
229        //Rank-Revealing Factorization
230        let (r, perm) = qr::qrd(&self.psi)?;
231
232        let mut keep = Vec::<usize>::new();
233
234        // The minimum between the number of subjects and the actual number of support points
235        let keep_n = self.psi.matrix().ncols().min(self.psi.matrix().nrows());
236        for i in 0..keep_n {
237            let test = r.col(i).norm_l2();
238            let r_diag_val = r.get(i, i);
239            let ratio = r_diag_val / test;
240            if ratio.abs() >= 1e-8 {
241                keep.push(*perm.get(i).unwrap());
242            }
243        }
244
245        // If a support point is dropped, log it as a debug message
246        if self.psi.matrix().ncols() != keep.len() {
247            tracing::debug!(
248                "QR decomposition dropped {} support point(s)",
249                self.psi.matrix().ncols() - keep.len(),
250            );
251        }
252
253        // Filter to keep only the support points (rows) that are in the `keep` vector
254        self.theta.filter_indices(keep.as_slice());
255        // Filter to keep only the support points (columns) that are in the `keep` vector
256        self.psi.filter_column_indices(keep.as_slice());
257
258        self.validate_psi()?;
259        (self.lambda, self.objf) = match burke(&self.psi) {
260            Ok((lambda, objf)) => (lambda, objf),
261            Err(err) => {
262                return Err(anyhow::anyhow!(
263                    "Error in IPM during condensation: {:?}",
264                    err
265                ));
266            }
267        };
268        self.w = self.lambda.clone();
269        Ok(())
270    }
271
272    fn optimizations(&mut self) -> Result<()> {
273        // Gam/Lam optimization
274        // TODO: Move this to e.g. /evaluation/error.rs
275        let gamma_up = self.gamma * (1.0 + self.gamma_delta);
276        let gamma_down = self.gamma / (1.0 + self.gamma_delta);
277
278        let psi_up = calculate_psi(
279            &self.equation,
280            &self.data,
281            &self.theta,
282            &ErrorModel::new(self.settings.error().poly, gamma_up, &self.error_type),
283            false,
284            true,
285        );
286        let psi_down = calculate_psi(
287            &self.equation,
288            &self.data,
289            &self.theta,
290            &ErrorModel::new(self.settings.error().poly, gamma_down, &self.error_type),
291            false,
292            true,
293        );
294
295        let (lambda_up, objf_up) = match burke(&psi_up) {
296            Ok((lambda, objf)) => (lambda, objf),
297            Err(err) => {
298                //todo: write out report
299                return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
300            }
301        };
302        let (lambda_down, objf_down) = match burke(&psi_down) {
303            Ok((lambda, objf)) => (lambda, objf),
304            Err(err) => {
305                //todo: write out report
306                //panic!("Error in IPM: {:?}", err);
307                return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
308                //(Array1::zeros(1), f64::NEG_INFINITY)
309            }
310        };
311        if objf_up > self.objf {
312            self.gamma = gamma_up;
313            self.objf = objf_up;
314            self.gamma_delta *= 4.;
315            self.lambda = lambda_up;
316            self.psi = psi_up;
317        }
318        if objf_down > self.objf {
319            self.gamma = gamma_down;
320            self.objf = objf_down;
321            self.gamma_delta *= 4.;
322            self.lambda = lambda_down;
323            self.psi = psi_down;
324        }
325        self.gamma_delta *= 0.5;
326        if self.gamma_delta <= 0.01 {
327            self.gamma_delta = 0.1;
328        }
329        Ok(())
330    }
331
332    fn logs(&self) {
333        tracing::info!("Objective function = {:.4}", -2.0 * self.objf);
334        tracing::debug!("Support points: {}", self.theta.nspp());
335        tracing::debug!("Gamma = {:.16}", self.gamma);
336        tracing::debug!("EPS = {:.4}", self.eps);
337        // Increasing objf signals instability or model misspecification.
338        if self.last_objf > self.objf + 1e-4 {
339            tracing::warn!(
340                "Objective function decreased from {:.4} to {:.4} (delta = {})",
341                -2.0 * self.last_objf,
342                -2.0 * self.objf,
343                -2.0 * self.last_objf - -2.0 * self.objf
344            );
345        }
346    }
347
348    fn expansion(&mut self) -> Result<()> {
349        adaptative_grid(&mut self.theta, self.eps, &self.ranges, THETA_D);
350        Ok(())
351    }
352}