pmcore/algorithms/
npag.rs

1use crate::algorithms::{Status, StopReason};
2use crate::prelude::algorithms::Algorithms;
3
4pub use crate::routines::estimation::ipm::burke;
5pub use crate::routines::estimation::qr;
6use crate::routines::settings::Settings;
7
8use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult};
9use crate::structs::psi::{calculate_psi, Psi};
10use crate::structs::theta::Theta;
11use crate::structs::weights::Weights;
12
13use anyhow::bail;
14use anyhow::Result;
15use pharmsol::prelude::{
16    data::{Data, ErrorModels},
17    simulator::Equation,
18};
19
20use pharmsol::prelude::ErrorModel;
21
22use crate::routines::initialization;
23
24use crate::routines::expansion::adaptative_grid::adaptative_grid;
25
// Thresholds used by the NPAG cycle below.

// Threshold on `eps`: `evaluation` only halves `eps` while it is above this
// value, and runs the final convergence test once `eps` drops to or below it.
const THETA_E: f64 = 1e-4; // Convergence criteria
// Threshold on `|last_objf - objf|`: a cycle-to-cycle change at or below this
// value is treated as a stalled objective function and triggers halving `eps`.
const THETA_G: f64 = 1e-4; // Objective function convergence criteria
// Threshold on `|f1 - f0|` in the final convergence test of `evaluation`.
const THETA_F: f64 = 1e-2;
// Passed as the last argument to `adaptative_grid` in `expansion`.
// NOTE(review): presumably the minimum distance allowed between support
// points - confirm against `adaptative_grid`.
const THETA_D: f64 = 1e-4;
30
/// State carried across the cycles of the NPAG algorithm.
///
/// An instance is created by `Algorithms::new` and then mutated in place by
/// the `estimation`, `condensation`, `optimizations`, `expansion` and
/// `evaluation` steps implemented below.
#[derive(Debug)]
pub struct NPAG<E: Equation + Send + 'static> {
    /// Model equation handed to `calculate_psi` and cloned into the result.
    equation: E,
    /// Parameter ranges read from the settings; passed to `adaptative_grid`.
    ranges: Vec<(f64, f64)>,
    /// Matrix produced by `calculate_psi`; its columns are filtered together
    /// with the support points in `theta`.
    psi: Psi,
    /// Current set of support points.
    theta: Theta,
    /// Weights most recently returned by `burke`.
    lambda: Weights,
    /// Copy of `lambda` taken at the end of `condensation`; used by the
    /// convergence test and exported in the result.
    w: Weights,
    /// Value passed to `adaptative_grid`; starts at 0.2, is halved when the
    /// objective function stalls, and is reset to 0.2 when the final
    /// convergence test fails.
    eps: f64,
    /// Value of `objf` when the previous cycle was logged.
    last_objf: f64,
    /// Objective value returned by `burke`; reported as `-2 * objf`.
    objf: f64,
    /// Reference value for the final convergence test (compared against `f1`).
    f0: f64,
    /// Sum of `ln(psi * w)` computed by the final convergence test.
    f1: f64,
    /// Number of the current cycle.
    cycle: usize,
    /// Relative step, one per error model, used to perturb the error model
    /// factor in `optimizations`; starts at 0.1.
    gamma_delta: Vec<f64>,
    /// Error models currently in use; their factors may be updated by
    /// `optimizations`.
    error_models: ErrorModels,
    /// Status set by the last call to `evaluation`.
    status: Status,
    /// One entry per logged cycle.
    cycle_log: CycleLog,
    /// Data the model is fitted to.
    data: Data,
    /// Settings of the run.
    settings: Settings,
}
52
53impl<E: Equation + Send + 'static> Algorithms<E> for NPAG<E> {
54    fn new(settings: Settings, equation: E, data: Data) -> Result<Box<Self>, anyhow::Error> {
55        Ok(Box::new(Self {
56            equation,
57            ranges: settings.parameters().ranges(),
58            psi: Psi::new(),
59            theta: Theta::new(),
60            lambda: Weights::default(),
61            w: Weights::default(),
62            eps: 0.2,
63            last_objf: -1e30,
64            objf: f64::NEG_INFINITY,
65            f0: -1e30,
66            f1: f64::default(),
67            cycle: 0,
68            gamma_delta: vec![0.1; settings.errormodels().len()],
69            error_models: settings.errormodels().clone(),
70            status: Status::Continue,
71            cycle_log: CycleLog::new(),
72            settings,
73            data,
74        }))
75    }
76
77    fn equation(&self) -> &E {
78        &self.equation
79    }
80    fn into_npresult(&self) -> NPResult<E> {
81        NPResult::new(
82            self.equation.clone(),
83            self.data.clone(),
84            self.theta.clone(),
85            self.psi.clone(),
86            self.w.clone(),
87            -2. * self.objf,
88            self.cycle,
89            self.status.clone(),
90            self.settings.clone(),
91            self.cycle_log.clone(),
92        )
93    }
94
95    fn settings(&self) -> &Settings {
96        &self.settings
97    }
98
99    fn data(&self) -> &Data {
100        &self.data
101    }
102
103    fn get_prior(&self) -> Theta {
104        initialization::sample_space(&self.settings).unwrap()
105    }
106
107    fn likelihood(&self) -> f64 {
108        self.objf
109    }
110
111    fn increment_cycle(&mut self) -> usize {
112        self.cycle += 1;
113        self.cycle
114    }
115
116    fn cycle(&self) -> usize {
117        self.cycle
118    }
119
120    fn set_theta(&mut self, theta: Theta) {
121        self.theta = theta;
122    }
123
124    fn theta(&self) -> &Theta {
125        &self.theta
126    }
127
128    fn psi(&self) -> &Psi {
129        &self.psi
130    }
131
132    fn evaluation(&mut self) -> Result<Status> {
133        tracing::info!("Objective function = {:.4}", -2.0 * self.objf);
134        tracing::debug!("Support points: {}", self.theta.nspp());
135
136        self.error_models.iter().for_each(|(outeq, em)| {
137            if ErrorModel::None == *em {
138                return;
139            }
140            tracing::debug!(
141                "Error model for outeq {}: {:.2}",
142                outeq,
143                em.factor().unwrap_or_default()
144            );
145        });
146
147        tracing::debug!("EPS = {:.4}", self.eps);
148        // Increasing objf signals instability or model misspecification.
149        if self.last_objf > self.objf + 1e-4 {
150            tracing::warn!(
151                "Objective function decreased from {:.4} to {:.4} (delta = {})",
152                -2.0 * self.last_objf,
153                -2.0 * self.objf,
154                -2.0 * self.last_objf - -2.0 * self.objf
155            );
156        }
157
158        let psi = self.psi.matrix();
159        let w = &self.w;
160        if (self.last_objf - self.objf).abs() <= THETA_G && self.eps > THETA_E {
161            self.eps /= 2.;
162            if self.eps <= THETA_E {
163                let pyl = psi * w.weights();
164                self.f1 = pyl.iter().map(|x| x.ln()).sum();
165                if (self.f1 - self.f0).abs() <= THETA_F {
166                    tracing::info!("The model converged after {} cycles", self.cycle,);
167                    self.set_status(Status::Stop(StopReason::Converged));
168                    self.log_cycle_state();
169                    return Ok(self.status().clone());
170                } else {
171                    self.f0 = self.f1;
172                    self.eps = 0.2;
173                }
174            }
175        }
176
177        // Stop if we have reached maximum number of cycles
178        if self.cycle >= self.settings.config().cycles {
179            tracing::warn!("Maximum number of cycles reached");
180            self.set_status(Status::Stop(StopReason::MaxCycles));
181            self.log_cycle_state();
182            return Ok(self.status().clone());
183        }
184
185        // Stop if stopfile exists
186        if std::path::Path::new("stop").exists() {
187            tracing::warn!("Stopfile detected - breaking");
188            self.set_status(Status::Stop(StopReason::Stopped));
189            self.log_cycle_state();
190            return Ok(self.status().clone());
191        }
192
193        // Continue with normal operation
194        self.set_status(Status::Continue);
195        self.log_cycle_state();
196        Ok(self.status().clone())
197    }
198
199    fn estimation(&mut self) -> Result<()> {
200        self.psi = calculate_psi(
201            &self.equation,
202            &self.data,
203            &self.theta,
204            &self.error_models,
205            self.cycle == 1 && self.settings.config().progress,
206            self.cycle != 1,
207        )?;
208
209        if let Err(err) = self.validate_psi() {
210            bail!(err);
211        }
212
213        (self.lambda, _) = match burke(&self.psi) {
214            Ok((lambda, objf)) => (lambda.into(), objf),
215            Err(err) => {
216                bail!("Error in IPM during estimation: {:?}", err);
217            }
218        };
219        Ok(())
220    }
221
222    fn condensation(&mut self) -> Result<()> {
223        // Filter out the support points with lambda < max(lambda)/1000
224
225        let max_lambda = self
226            .lambda
227            .iter()
228            .fold(f64::NEG_INFINITY, |acc, x| x.max(acc));
229
230        let mut keep = Vec::<usize>::new();
231        for (index, lam) in self.lambda.iter().enumerate() {
232            if lam > max_lambda / 1000_f64 {
233                keep.push(index);
234            }
235        }
236        if self.psi.matrix().ncols() != keep.len() {
237            tracing::debug!(
238                "Lambda (max/1000) dropped {} support point(s)",
239                self.psi.matrix().ncols() - keep.len(),
240            );
241        }
242
243        self.theta.filter_indices(keep.as_slice());
244        self.psi.filter_column_indices(keep.as_slice());
245
246        //Rank-Revealing Factorization
247        let (r, perm) = qr::qrd(&self.psi)?;
248
249        let mut keep = Vec::<usize>::new();
250
251        // The minimum between the number of subjects and the actual number of support points
252        let keep_n = self.psi.matrix().ncols().min(self.psi.matrix().nrows());
253        for i in 0..keep_n {
254            let test = r.col(i).norm_l2();
255            let r_diag_val = r.get(i, i);
256            let ratio = r_diag_val / test;
257            if ratio.abs() >= 1e-8 {
258                keep.push(*perm.get(i).unwrap());
259            }
260        }
261
262        // If a support point is dropped, log it as a debug message
263        if self.psi.matrix().ncols() != keep.len() {
264            tracing::debug!(
265                "QR decomposition dropped {} support point(s)",
266                self.psi.matrix().ncols() - keep.len(),
267            );
268        }
269
270        // Filter to keep only the support points (rows) that are in the `keep` vector
271        self.theta.filter_indices(keep.as_slice());
272        // Filter to keep only the support points (columns) that are in the `keep` vector
273        self.psi.filter_column_indices(keep.as_slice());
274
275        self.validate_psi()?;
276        (self.lambda, self.objf) = match burke(&self.psi) {
277            Ok((lambda, objf)) => (lambda.into(), objf),
278            Err(err) => {
279                return Err(anyhow::anyhow!(
280                    "Error in IPM during condensation: {:?}",
281                    err
282                ));
283            }
284        };
285        self.w = self.lambda.clone().into();
286        Ok(())
287    }
288
289    fn optimizations(&mut self) -> Result<()> {
290        self.error_models
291            .clone()
292            .iter_mut()
293            .filter_map(|(outeq, em)| {
294                if em.optimize() {
295                    Some((outeq, em))
296                } else {
297                    None
298                }
299            })
300            .try_for_each(|(outeq, em)| -> Result<()> {
301                // OPTIMIZATION
302
303                let gamma_up = em.factor()? * (1.0 + self.gamma_delta[outeq]);
304                let gamma_down = em.factor()? / (1.0 + self.gamma_delta[outeq]);
305
306                let mut error_model_up = self.error_models.clone();
307                error_model_up.set_factor(outeq, gamma_up)?;
308
309                let mut error_model_down = self.error_models.clone();
310                error_model_down.set_factor(outeq, gamma_down)?;
311
312                let psi_up = calculate_psi(
313                    &self.equation,
314                    &self.data,
315                    &self.theta,
316                    &error_model_up,
317                    false,
318                    true,
319                )?;
320                let psi_down = calculate_psi(
321                    &self.equation,
322                    &self.data,
323                    &self.theta,
324                    &error_model_down,
325                    false,
326                    true,
327                )?;
328
329                let (lambda_up, objf_up) = match burke(&psi_up) {
330                    Ok((lambda, objf)) => (lambda, objf),
331                    Err(err) => {
332                        //todo: write out report
333                        return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
334                    }
335                };
336                let (lambda_down, objf_down) = match burke(&psi_down) {
337                    Ok((lambda, objf)) => (lambda, objf),
338                    Err(err) => {
339                        //todo: write out report
340                        //panic!("Error in IPM: {:?}", err);
341                        return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
342                        //(Array1::zeros(1), f64::NEG_INFINITY)
343                    }
344                };
345                if objf_up > self.objf {
346                    self.error_models.set_factor(outeq, gamma_up)?;
347                    self.objf = objf_up;
348                    self.gamma_delta[outeq] *= 4.;
349                    self.lambda = lambda_up;
350                    self.psi = psi_up;
351                }
352                if objf_down > self.objf {
353                    self.error_models.set_factor(outeq, gamma_down)?;
354                    self.objf = objf_down;
355                    self.gamma_delta[outeq] *= 4.;
356                    self.lambda = lambda_down;
357                    self.psi = psi_down;
358                }
359                self.gamma_delta[outeq] *= 0.5;
360                if self.gamma_delta[outeq] <= 0.01 {
361                    self.gamma_delta[outeq] = 0.1;
362                }
363                Ok(())
364            })?;
365
366        Ok(())
367    }
368
369    fn expansion(&mut self) -> Result<()> {
370        adaptative_grid(&mut self.theta, self.eps, &self.ranges, THETA_D)?;
371        Ok(())
372    }
373
374    fn set_status(&mut self, status: Status) {
375        self.status = status;
376    }
377
378    fn status(&self) -> &Status {
379        &self.status
380    }
381
382    fn log_cycle_state(&mut self) {
383        let state = NPCycle::new(
384            self.cycle,
385            -2. * self.objf,
386            self.error_models.clone(),
387            self.theta.clone(),
388            self.theta.nspp(),
389            (self.last_objf - self.objf).abs(),
390            self.status.clone(),
391        );
392        self.cycle_log.push(state);
393        self.last_objf = self.objf;
394    }
395}