1use crate::algorithms::Status;
2use crate::prelude::algorithms::Algorithms;
3
4pub use crate::routines::evaluation::ipm::burke;
5pub use crate::routines::evaluation::qr;
6use crate::routines::settings::Settings;
7
8use crate::routines::output::{CycleLog, NPCycle, NPResult};
9use crate::structs::psi::{calculate_psi, Psi};
10use crate::structs::theta::Theta;
11
12use anyhow::bail;
13use anyhow::Result;
14use pharmsol::prelude::{
15 data::{Data, ErrorModels},
16 simulator::Equation,
17};
18
19use pharmsol::prelude::ErrorModel;
20
21use faer::Col;
22
23use crate::routines::initialization;
24
25use crate::routines::expansion::adaptative_grid::adaptative_grid;
26
/// Lower bound on the grid-expansion step `eps`: once halving drives `eps` to or below this
/// value, the likelihood-based convergence test is performed.
const THETA_E: f64 = 1e-4;

/// Maximum absolute change of the objective between two cycles that counts as a stall
/// (a stall triggers halving of `eps`).
const THETA_G: f64 = 1e-4;

/// Maximum absolute change of the convergence criterion (`f1` vs. `f0`) for the run to be
/// declared converged.
const THETA_F: f64 = 1e-2;

/// Tolerance handed to `adaptative_grid` during expansion (presumably the minimum distance
/// between support points — confirm against `adaptative_grid`).
const THETA_D: f64 = 1e-4;
31
32#[derive(Debug)]
33pub struct NPAG<E: Equation> {
34 equation: E,
35 ranges: Vec<(f64, f64)>,
36 psi: Psi,
37 theta: Theta,
38 lambda: Col<f64>,
39 w: Col<f64>,
40 eps: f64,
41 last_objf: f64,
42 objf: f64,
43 f0: f64,
44 f1: f64,
45 cycle: usize,
46 gamma_delta: Vec<f64>,
47 error_models: ErrorModels,
48 converged: bool,
49 status: Status,
50 cycle_log: CycleLog,
51 data: Data,
52 settings: Settings,
53}
54
55impl<E: Equation> Algorithms<E> for NPAG<E> {
56 fn new(settings: Settings, equation: E, data: Data) -> Result<Box<Self>, anyhow::Error> {
57 Ok(Box::new(Self {
58 equation,
59 ranges: settings.parameters().ranges(),
60 psi: Psi::new(),
61 theta: Theta::new(),
62 lambda: Col::zeros(0),
63 w: Col::zeros(0),
64 eps: 0.2,
65 last_objf: -1e30,
66 objf: f64::NEG_INFINITY,
67 f0: -1e30,
68 f1: f64::default(),
69 cycle: 0,
70 gamma_delta: vec![0.1; settings.errormodels().len()],
71 error_models: settings.errormodels().clone().into(),
72 converged: false,
73 status: Status::Starting,
74 cycle_log: CycleLog::new(),
75 settings,
76 data,
77 }))
78 }
79
80 fn equation(&self) -> &E {
81 &self.equation
82 }
83 fn into_npresult(&self) -> NPResult<E> {
84 NPResult::new(
85 self.equation.clone(),
86 self.data.clone(),
87 self.theta.clone(),
88 self.psi.clone(),
89 self.w.clone(),
90 -2. * self.objf,
91 self.cycle,
92 self.status.clone(),
93 self.settings.clone(),
94 self.cycle_log.clone(),
95 )
96 }
97
98 fn get_settings(&self) -> &Settings {
99 &self.settings
100 }
101
102 fn get_data(&self) -> &Data {
103 &self.data
104 }
105
106 fn get_prior(&self) -> Theta {
107 initialization::sample_space(&self.settings).unwrap()
108 }
109
110 fn likelihood(&self) -> f64 {
111 self.objf
112 }
113
114 fn inc_cycle(&mut self) -> usize {
115 self.cycle += 1;
116 self.cycle
117 }
118
119 fn get_cycle(&self) -> usize {
120 self.cycle
121 }
122
123 fn set_theta(&mut self, theta: Theta) {
124 self.theta = theta;
125 }
126
127 fn theta(&self) -> &Theta {
128 &self.theta
129 }
130
131 fn psi(&self) -> &Psi {
132 &self.psi
133 }
134
135 fn convergence_evaluation(&mut self) {
136 let psi = self.psi.matrix();
137 let w = &self.w;
138 if (self.last_objf - self.objf).abs() <= THETA_G && self.eps > THETA_E {
139 self.eps /= 2.;
140 if self.eps <= THETA_E {
141 let pyl = psi * w;
142 self.f1 = pyl.iter().map(|x| x.ln()).sum();
143 if (self.f1 - self.f0).abs() <= THETA_F {
144 tracing::info!("The model converged after {} cycles", self.cycle,);
145 self.converged = true;
146 self.status = Status::Converged;
147 } else {
148 self.f0 = self.f1;
149 self.eps = 0.2;
150 }
151 }
152 }
153
154 if self.cycle >= self.settings.config().cycles {
156 tracing::warn!("Maximum number of cycles reached");
157 self.converged = true;
158 self.status = Status::MaxCycles;
159 }
160
161 if std::path::Path::new("stop").exists() {
163 tracing::warn!("Stopfile detected - breaking");
164 self.status = Status::ManualStop;
165 }
166
167 let state = NPCycle {
169 cycle: self.cycle,
170 objf: -2. * self.objf,
171 delta_objf: (self.last_objf - self.objf).abs(),
172 nspp: self.theta.nspp(),
173 theta: self.theta.clone(),
174 error_models: self.error_models.clone(),
175 status: self.status.clone(),
176 };
177
178 self.cycle_log.push(state);
180 self.last_objf = self.objf;
181 }
182
183 fn converged(&self) -> bool {
184 self.converged
185 }
186
187 fn evaluation(&mut self) -> Result<()> {
188 self.psi = calculate_psi(
189 &self.equation,
190 &self.data,
191 &self.theta,
192 &self.error_models,
193 self.cycle == 1 && self.settings.config().progress,
194 self.cycle != 1,
195 )?;
196
197 if let Err(err) = self.validate_psi() {
198 bail!(err);
199 }
200
201 (self.lambda, _) = match burke(&self.psi) {
202 Ok((lambda, objf)) => (lambda, objf),
203 Err(err) => {
204 bail!("Error in IPM during evaluation: {:?}", err);
205 }
206 };
207 Ok(())
208 }
209
210 fn condensation(&mut self) -> Result<()> {
211 let max_lambda = self
214 .lambda
215 .iter()
216 .fold(f64::NEG_INFINITY, |acc, &x| x.max(acc));
217
218 let mut keep = Vec::<usize>::new();
219 for (index, lam) in self.lambda.iter().enumerate() {
220 if *lam > max_lambda / 1000_f64 {
221 keep.push(index);
222 }
223 }
224 if self.psi.matrix().ncols() != keep.len() {
225 tracing::debug!(
226 "Lambda (max/1000) dropped {} support point(s)",
227 self.psi.matrix().ncols() - keep.len(),
228 );
229 }
230
231 self.theta.filter_indices(keep.as_slice());
232 self.psi.filter_column_indices(keep.as_slice());
233
234 let (r, perm) = qr::qrd(&self.psi)?;
236
237 let mut keep = Vec::<usize>::new();
238
239 let keep_n = self.psi.matrix().ncols().min(self.psi.matrix().nrows());
241 for i in 0..keep_n {
242 let test = r.col(i).norm_l2();
243 let r_diag_val = r.get(i, i);
244 let ratio = r_diag_val / test;
245 if ratio.abs() >= 1e-8 {
246 keep.push(*perm.get(i).unwrap());
247 }
248 }
249
250 if self.psi.matrix().ncols() != keep.len() {
252 tracing::debug!(
253 "QR decomposition dropped {} support point(s)",
254 self.psi.matrix().ncols() - keep.len(),
255 );
256 }
257
258 self.theta.filter_indices(keep.as_slice());
260 self.psi.filter_column_indices(keep.as_slice());
262
263 self.validate_psi()?;
264 (self.lambda, self.objf) = match burke(&self.psi) {
265 Ok((lambda, objf)) => (lambda, objf),
266 Err(err) => {
267 return Err(anyhow::anyhow!(
268 "Error in IPM during condensation: {:?}",
269 err
270 ));
271 }
272 };
273 self.w = self.lambda.clone();
274 Ok(())
275 }
276
277 fn optimizations(&mut self) -> Result<()> {
278 self.error_models
279 .clone()
280 .iter_mut()
281 .filter_map(|(outeq, em)| match em {
282 ErrorModel::None => None,
283 _ => Some((outeq, em)),
284 })
285 .try_for_each(|(outeq, em)| -> Result<()> {
286 let gamma_up = em.scalar()? * (1.0 + self.gamma_delta[outeq]);
289 let gamma_down = em.scalar()? / (1.0 + self.gamma_delta[outeq]);
290
291 let mut error_model_up = self.error_models.clone();
292 error_model_up.set_scalar(outeq, gamma_up)?;
293
294 let mut error_model_down = self.error_models.clone();
295 error_model_down.set_scalar(outeq, gamma_down)?;
296
297 let psi_up = calculate_psi(
298 &self.equation,
299 &self.data,
300 &self.theta,
301 &error_model_up,
302 false,
303 true,
304 )?;
305 let psi_down = calculate_psi(
306 &self.equation,
307 &self.data,
308 &self.theta,
309 &error_model_down,
310 false,
311 true,
312 )?;
313
314 let (lambda_up, objf_up) = match burke(&psi_up) {
315 Ok((lambda, objf)) => (lambda, objf),
316 Err(err) => {
317 return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
319 }
320 };
321 let (lambda_down, objf_down) = match burke(&psi_down) {
322 Ok((lambda, objf)) => (lambda, objf),
323 Err(err) => {
324 return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
327 }
329 };
330 if objf_up > self.objf {
331 self.error_models.set_scalar(outeq, gamma_up)?;
332 self.objf = objf_up;
333 self.gamma_delta[outeq] *= 4.;
334 self.lambda = lambda_up;
335 self.psi = psi_up;
336 }
337 if objf_down > self.objf {
338 self.error_models.set_scalar(outeq, gamma_down)?;
339 self.objf = objf_down;
340 self.gamma_delta[outeq] *= 4.;
341 self.lambda = lambda_down;
342 self.psi = psi_down;
343 }
344 self.gamma_delta[outeq] *= 0.5;
345 if self.gamma_delta[outeq] <= 0.01 {
346 self.gamma_delta[outeq] = 0.1;
347 }
348 Ok(())
349 })?;
350
351 Ok(())
352 }
353
354 fn logs(&self) {
355 tracing::info!("Objective function = {:.4}", -2.0 * self.objf);
356 tracing::debug!("Support points: {}", self.theta.nspp());
357
358 self.error_models.iter().for_each(|(outeq, em)| {
359 if ErrorModel::None == *em {
360 return;
361 }
362 tracing::debug!(
363 "Error model for outeq {}: {:.16}",
364 outeq,
365 em.scalar().unwrap_or_default()
366 );
367 });
368
369 tracing::debug!("EPS = {:.4}", self.eps);
370 if self.last_objf > self.objf + 1e-4 {
372 tracing::warn!(
373 "Objective function decreased from {:.4} to {:.4} (delta = {})",
374 -2.0 * self.last_objf,
375 -2.0 * self.objf,
376 -2.0 * self.last_objf - -2.0 * self.objf
377 );
378 }
379 }
380
381 fn expansion(&mut self) -> Result<()> {
382 adaptative_grid(&mut self.theta, self.eps, &self.ranges, THETA_D);
383 Ok(())
384 }
385
386 fn set_status(&mut self, status: Status) {
387 self.status = status;
388 }
389
390 fn status(&self) -> &Status {
391 &self.status
392 }
393}