pmcore/algorithms/
npag.rs

1use crate::algorithms::Status;
2use crate::prelude::algorithms::Algorithms;
3
4pub use crate::routines::evaluation::ipm::burke;
5pub use crate::routines::evaluation::qr;
6use crate::routines::settings::Settings;
7
8use crate::routines::output::{CycleLog, NPCycle, NPResult};
9use crate::structs::psi::{calculate_psi, Psi};
10use crate::structs::theta::Theta;
11
12use anyhow::bail;
13use anyhow::Result;
14use pharmsol::prelude::{
15    data::{Data, ErrorModels},
16    simulator::Equation,
17};
18
19use pharmsol::prelude::ErrorModel;
20
21use faer::Col;
22
23use crate::routines::initialization;
24
25use crate::routines::expansion::adaptative_grid::adaptative_grid;
26
/// Lower bound for the expansion step `eps`: once halving brings `eps` to this
/// value or below, the secondary (`THETA_F`) convergence test is evaluated.
const THETA_E: f64 = 1.0e-4;
/// Largest change in the objective between consecutive cycles that is treated
/// as "stalled"; a stalled cycle halves `eps`.
const THETA_G: f64 = 1.0e-4;
/// Tolerance on `|f1 - f0|` (change in `sum(ln(psi * w))`) for declaring convergence.
const THETA_F: f64 = 1.0e-2;
/// Final argument passed to `adaptative_grid` during expansion
/// (presumably a minimum distance between support points — confirm there).
const THETA_D: f64 = 1.0e-4;
31
32#[derive(Debug)]
33pub struct NPAG<E: Equation> {
34    equation: E,
35    ranges: Vec<(f64, f64)>,
36    psi: Psi,
37    theta: Theta,
38    lambda: Col<f64>,
39    w: Col<f64>,
40    eps: f64,
41    last_objf: f64,
42    objf: f64,
43    f0: f64,
44    f1: f64,
45    cycle: usize,
46    gamma_delta: Vec<f64>,
47    error_models: ErrorModels,
48    converged: bool,
49    status: Status,
50    cycle_log: CycleLog,
51    data: Data,
52    settings: Settings,
53}
54
55impl<E: Equation> Algorithms<E> for NPAG<E> {
56    fn new(settings: Settings, equation: E, data: Data) -> Result<Box<Self>, anyhow::Error> {
57        Ok(Box::new(Self {
58            equation,
59            ranges: settings.parameters().ranges(),
60            psi: Psi::new(),
61            theta: Theta::new(),
62            lambda: Col::zeros(0),
63            w: Col::zeros(0),
64            eps: 0.2,
65            last_objf: -1e30,
66            objf: f64::NEG_INFINITY,
67            f0: -1e30,
68            f1: f64::default(),
69            cycle: 0,
70            gamma_delta: vec![0.1; settings.errormodels().len()],
71            error_models: settings.errormodels().clone().into(),
72            converged: false,
73            status: Status::Starting,
74            cycle_log: CycleLog::new(),
75            settings,
76            data,
77        }))
78    }
79
80    fn equation(&self) -> &E {
81        &self.equation
82    }
83    fn into_npresult(&self) -> NPResult<E> {
84        NPResult::new(
85            self.equation.clone(),
86            self.data.clone(),
87            self.theta.clone(),
88            self.psi.clone(),
89            self.w.clone(),
90            -2. * self.objf,
91            self.cycle,
92            self.status.clone(),
93            self.settings.clone(),
94            self.cycle_log.clone(),
95        )
96    }
97
98    fn get_settings(&self) -> &Settings {
99        &self.settings
100    }
101
102    fn get_data(&self) -> &Data {
103        &self.data
104    }
105
106    fn get_prior(&self) -> Theta {
107        initialization::sample_space(&self.settings).unwrap()
108    }
109
110    fn likelihood(&self) -> f64 {
111        self.objf
112    }
113
114    fn inc_cycle(&mut self) -> usize {
115        self.cycle += 1;
116        self.cycle
117    }
118
119    fn get_cycle(&self) -> usize {
120        self.cycle
121    }
122
123    fn set_theta(&mut self, theta: Theta) {
124        self.theta = theta;
125    }
126
127    fn theta(&self) -> &Theta {
128        &self.theta
129    }
130
131    fn psi(&self) -> &Psi {
132        &self.psi
133    }
134
135    fn convergence_evaluation(&mut self) {
136        let psi = self.psi.matrix();
137        let w = &self.w;
138        if (self.last_objf - self.objf).abs() <= THETA_G && self.eps > THETA_E {
139            self.eps /= 2.;
140            if self.eps <= THETA_E {
141                let pyl = psi * w;
142                self.f1 = pyl.iter().map(|x| x.ln()).sum();
143                if (self.f1 - self.f0).abs() <= THETA_F {
144                    tracing::info!("The model converged after {} cycles", self.cycle,);
145                    self.converged = true;
146                    self.status = Status::Converged;
147                } else {
148                    self.f0 = self.f1;
149                    self.eps = 0.2;
150                }
151            }
152        }
153
154        // Stop if we have reached maximum number of cycles
155        if self.cycle >= self.settings.config().cycles {
156            tracing::warn!("Maximum number of cycles reached");
157            self.converged = true;
158            self.status = Status::MaxCycles;
159        }
160
161        // Stop if stopfile exists
162        if std::path::Path::new("stop").exists() {
163            tracing::warn!("Stopfile detected - breaking");
164            self.status = Status::ManualStop;
165        }
166
167        // Create state object
168        let state = NPCycle {
169            cycle: self.cycle,
170            objf: -2. * self.objf,
171            delta_objf: (self.last_objf - self.objf).abs(),
172            nspp: self.theta.nspp(),
173            theta: self.theta.clone(),
174            error_models: self.error_models.clone(),
175            status: self.status.clone(),
176        };
177
178        // Write cycle log
179        self.cycle_log.push(state);
180        self.last_objf = self.objf;
181    }
182
183    fn converged(&self) -> bool {
184        self.converged
185    }
186
187    fn evaluation(&mut self) -> Result<()> {
188        self.psi = calculate_psi(
189            &self.equation,
190            &self.data,
191            &self.theta,
192            &self.error_models,
193            self.cycle == 1 && self.settings.config().progress,
194            self.cycle != 1,
195        )?;
196
197        if let Err(err) = self.validate_psi() {
198            bail!(err);
199        }
200
201        (self.lambda, _) = match burke(&self.psi) {
202            Ok((lambda, objf)) => (lambda, objf),
203            Err(err) => {
204                bail!("Error in IPM during evaluation: {:?}", err);
205            }
206        };
207        Ok(())
208    }
209
210    fn condensation(&mut self) -> Result<()> {
211        // Filter out the support points with lambda < max(lambda)/1000
212
213        let max_lambda = self
214            .lambda
215            .iter()
216            .fold(f64::NEG_INFINITY, |acc, &x| x.max(acc));
217
218        let mut keep = Vec::<usize>::new();
219        for (index, lam) in self.lambda.iter().enumerate() {
220            if *lam > max_lambda / 1000_f64 {
221                keep.push(index);
222            }
223        }
224        if self.psi.matrix().ncols() != keep.len() {
225            tracing::debug!(
226                "Lambda (max/1000) dropped {} support point(s)",
227                self.psi.matrix().ncols() - keep.len(),
228            );
229        }
230
231        self.theta.filter_indices(keep.as_slice());
232        self.psi.filter_column_indices(keep.as_slice());
233
234        //Rank-Revealing Factorization
235        let (r, perm) = qr::qrd(&self.psi)?;
236
237        let mut keep = Vec::<usize>::new();
238
239        // The minimum between the number of subjects and the actual number of support points
240        let keep_n = self.psi.matrix().ncols().min(self.psi.matrix().nrows());
241        for i in 0..keep_n {
242            let test = r.col(i).norm_l2();
243            let r_diag_val = r.get(i, i);
244            let ratio = r_diag_val / test;
245            if ratio.abs() >= 1e-8 {
246                keep.push(*perm.get(i).unwrap());
247            }
248        }
249
250        // If a support point is dropped, log it as a debug message
251        if self.psi.matrix().ncols() != keep.len() {
252            tracing::debug!(
253                "QR decomposition dropped {} support point(s)",
254                self.psi.matrix().ncols() - keep.len(),
255            );
256        }
257
258        // Filter to keep only the support points (rows) that are in the `keep` vector
259        self.theta.filter_indices(keep.as_slice());
260        // Filter to keep only the support points (columns) that are in the `keep` vector
261        self.psi.filter_column_indices(keep.as_slice());
262
263        self.validate_psi()?;
264        (self.lambda, self.objf) = match burke(&self.psi) {
265            Ok((lambda, objf)) => (lambda, objf),
266            Err(err) => {
267                return Err(anyhow::anyhow!(
268                    "Error in IPM during condensation: {:?}",
269                    err
270                ));
271            }
272        };
273        self.w = self.lambda.clone();
274        Ok(())
275    }
276
277    fn optimizations(&mut self) -> Result<()> {
278        self.error_models
279            .clone()
280            .iter_mut()
281            .filter_map(|(outeq, em)| match em {
282                ErrorModel::None => None,
283                _ => Some((outeq, em)),
284            })
285            .try_for_each(|(outeq, em)| -> Result<()> {
286                // OPTIMIZATION
287
288                let gamma_up = em.scalar()? * (1.0 + self.gamma_delta[outeq]);
289                let gamma_down = em.scalar()? / (1.0 + self.gamma_delta[outeq]);
290
291                let mut error_model_up = self.error_models.clone();
292                error_model_up.set_scalar(outeq, gamma_up)?;
293
294                let mut error_model_down = self.error_models.clone();
295                error_model_down.set_scalar(outeq, gamma_down)?;
296
297                let psi_up = calculate_psi(
298                    &self.equation,
299                    &self.data,
300                    &self.theta,
301                    &error_model_up,
302                    false,
303                    true,
304                )?;
305                let psi_down = calculate_psi(
306                    &self.equation,
307                    &self.data,
308                    &self.theta,
309                    &error_model_down,
310                    false,
311                    true,
312                )?;
313
314                let (lambda_up, objf_up) = match burke(&psi_up) {
315                    Ok((lambda, objf)) => (lambda, objf),
316                    Err(err) => {
317                        //todo: write out report
318                        return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
319                    }
320                };
321                let (lambda_down, objf_down) = match burke(&psi_down) {
322                    Ok((lambda, objf)) => (lambda, objf),
323                    Err(err) => {
324                        //todo: write out report
325                        //panic!("Error in IPM: {:?}", err);
326                        return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
327                        //(Array1::zeros(1), f64::NEG_INFINITY)
328                    }
329                };
330                if objf_up > self.objf {
331                    self.error_models.set_scalar(outeq, gamma_up)?;
332                    self.objf = objf_up;
333                    self.gamma_delta[outeq] *= 4.;
334                    self.lambda = lambda_up;
335                    self.psi = psi_up;
336                }
337                if objf_down > self.objf {
338                    self.error_models.set_scalar(outeq, gamma_down)?;
339                    self.objf = objf_down;
340                    self.gamma_delta[outeq] *= 4.;
341                    self.lambda = lambda_down;
342                    self.psi = psi_down;
343                }
344                self.gamma_delta[outeq] *= 0.5;
345                if self.gamma_delta[outeq] <= 0.01 {
346                    self.gamma_delta[outeq] = 0.1;
347                }
348                Ok(())
349            })?;
350
351        Ok(())
352    }
353
354    fn logs(&self) {
355        tracing::info!("Objective function = {:.4}", -2.0 * self.objf);
356        tracing::debug!("Support points: {}", self.theta.nspp());
357
358        self.error_models.iter().for_each(|(outeq, em)| {
359            if ErrorModel::None == *em {
360                return;
361            }
362            tracing::debug!(
363                "Error model for outeq {}: {:.16}",
364                outeq,
365                em.scalar().unwrap_or_default()
366            );
367        });
368
369        tracing::debug!("EPS = {:.4}", self.eps);
370        // Increasing objf signals instability or model misspecification.
371        if self.last_objf > self.objf + 1e-4 {
372            tracing::warn!(
373                "Objective function decreased from {:.4} to {:.4} (delta = {})",
374                -2.0 * self.last_objf,
375                -2.0 * self.objf,
376                -2.0 * self.last_objf - -2.0 * self.objf
377            );
378        }
379    }
380
381    fn expansion(&mut self) -> Result<()> {
382        adaptative_grid(&mut self.theta, self.eps, &self.ranges, THETA_D);
383        Ok(())
384    }
385
386    fn set_status(&mut self, status: Status) {
387        self.status = status;
388    }
389
390    fn status(&self) -> &Status {
391        &self.status
392    }
393}