pmcore/algorithms/
npag.rs

1use crate::algorithms::Status;
2use crate::prelude::algorithms::Algorithms;
3
4pub use crate::routines::evaluation::ipm::burke;
5pub use crate::routines::evaluation::qr;
6use crate::routines::settings::Settings;
7
8use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult};
9use crate::structs::psi::{calculate_psi, Psi};
10use crate::structs::theta::Theta;
11use crate::structs::weights::Weights;
12
13use anyhow::bail;
14use anyhow::Result;
15use pharmsol::prelude::{
16    data::{Data, ErrorModels},
17    simulator::Equation,
18};
19
20use pharmsol::prelude::ErrorModel;
21
22use crate::routines::initialization;
23
24use crate::routines::expansion::adaptative_grid::adaptative_grid;
25
/// Floor for the grid-expansion step `eps`. Once `eps` has been halved down to
/// (or below) this value, the likelihood-based convergence check is performed.
const THETA_E: f64 = 0.0001;
/// Largest change in the objective function between two consecutive cycles
/// that still counts as a plateau (a plateau halves `eps`).
const THETA_G: f64 = 0.0001;
/// Largest change in the summed log-likelihood between two successive `eps`
/// checkpoints for which the run is declared converged.
const THETA_F: f64 = 0.01;
/// Last argument forwarded to `adaptative_grid` during expansion
/// (presumably a minimum distance between support points — TODO confirm).
const THETA_D: f64 = 0.0001;
30
31#[derive(Debug)]
32pub struct NPAG<E: Equation> {
33    equation: E,
34    ranges: Vec<(f64, f64)>,
35    psi: Psi,
36    theta: Theta,
37    lambda: Weights,
38    w: Weights,
39    eps: f64,
40    last_objf: f64,
41    objf: f64,
42    f0: f64,
43    f1: f64,
44    cycle: usize,
45    gamma_delta: Vec<f64>,
46    error_models: ErrorModels,
47    converged: bool,
48    status: Status,
49    cycle_log: CycleLog,
50    data: Data,
51    settings: Settings,
52}
53
54impl<E: Equation> Algorithms<E> for NPAG<E> {
55    fn new(settings: Settings, equation: E, data: Data) -> Result<Box<Self>, anyhow::Error> {
56        Ok(Box::new(Self {
57            equation,
58            ranges: settings.parameters().ranges(),
59            psi: Psi::new(),
60            theta: Theta::new(),
61            lambda: Weights::default(),
62            w: Weights::default(),
63            eps: 0.2,
64            last_objf: -1e30,
65            objf: f64::NEG_INFINITY,
66            f0: -1e30,
67            f1: f64::default(),
68            cycle: 0,
69            gamma_delta: vec![0.1; settings.errormodels().len()],
70            error_models: settings.errormodels().clone(),
71            converged: false,
72            status: Status::Starting,
73            cycle_log: CycleLog::new(),
74            settings,
75            data,
76        }))
77    }
78
79    fn equation(&self) -> &E {
80        &self.equation
81    }
82    fn into_npresult(&self) -> NPResult<E> {
83        NPResult::new(
84            self.equation.clone(),
85            self.data.clone(),
86            self.theta.clone(),
87            self.psi.clone(),
88            self.w.clone(),
89            -2. * self.objf,
90            self.cycle,
91            self.status.clone(),
92            self.settings.clone(),
93            self.cycle_log.clone(),
94        )
95    }
96
97    fn get_settings(&self) -> &Settings {
98        &self.settings
99    }
100
101    fn get_data(&self) -> &Data {
102        &self.data
103    }
104
105    fn get_prior(&self) -> Theta {
106        initialization::sample_space(&self.settings).unwrap()
107    }
108
109    fn likelihood(&self) -> f64 {
110        self.objf
111    }
112
113    fn inc_cycle(&mut self) -> usize {
114        self.cycle += 1;
115        self.cycle
116    }
117
118    fn get_cycle(&self) -> usize {
119        self.cycle
120    }
121
122    fn set_theta(&mut self, theta: Theta) {
123        self.theta = theta;
124    }
125
126    fn theta(&self) -> &Theta {
127        &self.theta
128    }
129
130    fn psi(&self) -> &Psi {
131        &self.psi
132    }
133
134    fn convergence_evaluation(&mut self) {
135        let psi = self.psi.matrix();
136        let w = &self.w;
137        if (self.last_objf - self.objf).abs() <= THETA_G && self.eps > THETA_E {
138            self.eps /= 2.;
139            if self.eps <= THETA_E {
140                let pyl = psi * w.weights();
141                self.f1 = pyl.iter().map(|x| x.ln()).sum();
142                if (self.f1 - self.f0).abs() <= THETA_F {
143                    tracing::info!("The model converged after {} cycles", self.cycle,);
144                    self.converged = true;
145                    self.status = Status::Converged;
146                } else {
147                    self.f0 = self.f1;
148                    self.eps = 0.2;
149                }
150            }
151        }
152
153        // Stop if we have reached maximum number of cycles
154        if self.cycle >= self.settings.config().cycles {
155            tracing::warn!("Maximum number of cycles reached");
156            self.converged = true;
157            self.status = Status::MaxCycles;
158        }
159
160        // Stop if stopfile exists
161        if std::path::Path::new("stop").exists() {
162            tracing::warn!("Stopfile detected - breaking");
163            self.status = Status::ManualStop;
164        }
165
166        // Create state object
167        let state = NPCycle::new(
168            self.cycle,
169            -2. * self.objf,
170            self.error_models.clone(),
171            self.theta.clone(),
172            self.theta.nspp(),
173            (self.last_objf - self.objf).abs(),
174            self.status.clone(),
175        );
176
177        // Write cycle log
178        self.cycle_log.push(state);
179        self.last_objf = self.objf;
180    }
181
182    fn converged(&self) -> bool {
183        self.converged
184    }
185
186    fn evaluation(&mut self) -> Result<()> {
187        self.psi = calculate_psi(
188            &self.equation,
189            &self.data,
190            &self.theta,
191            &self.error_models,
192            self.cycle == 1 && self.settings.config().progress,
193            self.cycle != 1,
194        )?;
195
196        if let Err(err) = self.validate_psi() {
197            bail!(err);
198        }
199
200        (self.lambda, _) = match burke(&self.psi) {
201            Ok((lambda, objf)) => (lambda.into(), objf),
202            Err(err) => {
203                bail!("Error in IPM during evaluation: {:?}", err);
204            }
205        };
206        Ok(())
207    }
208
209    fn condensation(&mut self) -> Result<()> {
210        // Filter out the support points with lambda < max(lambda)/1000
211
212        let max_lambda = self
213            .lambda
214            .iter()
215            .fold(f64::NEG_INFINITY, |acc, x| x.max(acc));
216
217        let mut keep = Vec::<usize>::new();
218        for (index, lam) in self.lambda.iter().enumerate() {
219            if lam > max_lambda / 1000_f64 {
220                keep.push(index);
221            }
222        }
223        if self.psi.matrix().ncols() != keep.len() {
224            tracing::debug!(
225                "Lambda (max/1000) dropped {} support point(s)",
226                self.psi.matrix().ncols() - keep.len(),
227            );
228        }
229
230        self.theta.filter_indices(keep.as_slice());
231        self.psi.filter_column_indices(keep.as_slice());
232
233        //Rank-Revealing Factorization
234        let (r, perm) = qr::qrd(&self.psi)?;
235
236        let mut keep = Vec::<usize>::new();
237
238        // The minimum between the number of subjects and the actual number of support points
239        let keep_n = self.psi.matrix().ncols().min(self.psi.matrix().nrows());
240        for i in 0..keep_n {
241            let test = r.col(i).norm_l2();
242            let r_diag_val = r.get(i, i);
243            let ratio = r_diag_val / test;
244            if ratio.abs() >= 1e-8 {
245                keep.push(*perm.get(i).unwrap());
246            }
247        }
248
249        // If a support point is dropped, log it as a debug message
250        if self.psi.matrix().ncols() != keep.len() {
251            tracing::debug!(
252                "QR decomposition dropped {} support point(s)",
253                self.psi.matrix().ncols() - keep.len(),
254            );
255        }
256
257        // Filter to keep only the support points (rows) that are in the `keep` vector
258        self.theta.filter_indices(keep.as_slice());
259        // Filter to keep only the support points (columns) that are in the `keep` vector
260        self.psi.filter_column_indices(keep.as_slice());
261
262        self.validate_psi()?;
263        (self.lambda, self.objf) = match burke(&self.psi) {
264            Ok((lambda, objf)) => (lambda.into(), objf),
265            Err(err) => {
266                return Err(anyhow::anyhow!(
267                    "Error in IPM during condensation: {:?}",
268                    err
269                ));
270            }
271        };
272        self.w = self.lambda.clone().into();
273        Ok(())
274    }
275
276    fn optimizations(&mut self) -> Result<()> {
277        self.error_models
278            .clone()
279            .iter_mut()
280            .filter_map(|(outeq, em)| {
281                if em.optimize() {
282                    Some((outeq, em))
283                } else {
284                    None
285                }
286            })
287            .try_for_each(|(outeq, em)| -> Result<()> {
288                // OPTIMIZATION
289
290                let gamma_up = em.factor()? * (1.0 + self.gamma_delta[outeq]);
291                let gamma_down = em.factor()? / (1.0 + self.gamma_delta[outeq]);
292
293                let mut error_model_up = self.error_models.clone();
294                error_model_up.set_factor(outeq, gamma_up)?;
295
296                let mut error_model_down = self.error_models.clone();
297                error_model_down.set_factor(outeq, gamma_down)?;
298
299                let psi_up = calculate_psi(
300                    &self.equation,
301                    &self.data,
302                    &self.theta,
303                    &error_model_up,
304                    false,
305                    true,
306                )?;
307                let psi_down = calculate_psi(
308                    &self.equation,
309                    &self.data,
310                    &self.theta,
311                    &error_model_down,
312                    false,
313                    true,
314                )?;
315
316                let (lambda_up, objf_up) = match burke(&psi_up) {
317                    Ok((lambda, objf)) => (lambda, objf),
318                    Err(err) => {
319                        //todo: write out report
320                        return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
321                    }
322                };
323                let (lambda_down, objf_down) = match burke(&psi_down) {
324                    Ok((lambda, objf)) => (lambda, objf),
325                    Err(err) => {
326                        //todo: write out report
327                        //panic!("Error in IPM: {:?}", err);
328                        return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
329                        //(Array1::zeros(1), f64::NEG_INFINITY)
330                    }
331                };
332                if objf_up > self.objf {
333                    self.error_models.set_factor(outeq, gamma_up)?;
334                    self.objf = objf_up;
335                    self.gamma_delta[outeq] *= 4.;
336                    self.lambda = lambda_up;
337                    self.psi = psi_up;
338                }
339                if objf_down > self.objf {
340                    self.error_models.set_factor(outeq, gamma_down)?;
341                    self.objf = objf_down;
342                    self.gamma_delta[outeq] *= 4.;
343                    self.lambda = lambda_down;
344                    self.psi = psi_down;
345                }
346                self.gamma_delta[outeq] *= 0.5;
347                if self.gamma_delta[outeq] <= 0.01 {
348                    self.gamma_delta[outeq] = 0.1;
349                }
350                Ok(())
351            })?;
352
353        Ok(())
354    }
355
356    fn logs(&self) {
357        tracing::info!("Objective function = {:.4}", -2.0 * self.objf);
358        tracing::debug!("Support points: {}", self.theta.nspp());
359
360        self.error_models.iter().for_each(|(outeq, em)| {
361            if ErrorModel::None == *em {
362                return;
363            }
364            tracing::debug!(
365                "Error model for outeq {}: {:.16}",
366                outeq,
367                em.factor().unwrap_or_default()
368            );
369        });
370
371        tracing::debug!("EPS = {:.4}", self.eps);
372        // Increasing objf signals instability or model misspecification.
373        if self.last_objf > self.objf + 1e-4 {
374            tracing::warn!(
375                "Objective function decreased from {:.4} to {:.4} (delta = {})",
376                -2.0 * self.last_objf,
377                -2.0 * self.objf,
378                -2.0 * self.last_objf - -2.0 * self.objf
379            );
380        }
381    }
382
383    fn expansion(&mut self) -> Result<()> {
384        adaptative_grid(&mut self.theta, self.eps, &self.ranges, THETA_D)?;
385        Ok(())
386    }
387
388    fn set_status(&mut self, status: Status) {
389        self.status = status;
390    }
391
392    fn status(&self) -> &Status {
393        &self.status
394    }
395}