1use crate::algorithms::Status;
2use crate::prelude::algorithms::Algorithms;
3
4pub use crate::routines::evaluation::ipm::burke;
5pub use crate::routines::evaluation::qr;
6use crate::routines::settings::Settings;
7
8use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult};
9use crate::structs::psi::{calculate_psi, Psi};
10use crate::structs::theta::Theta;
11use crate::structs::weights::Weights;
12
13use anyhow::bail;
14use anyhow::Result;
15use pharmsol::prelude::{
16 data::{Data, ErrorModels},
17 simulator::Equation,
18};
19
20use pharmsol::prelude::ErrorModel;
21
22use crate::routines::initialization;
23
24use crate::routines::expansion::adaptative_grid::adaptative_grid;
25
/// Lower bound for the grid step `eps`.
///
/// While `eps` is above this value, a stalled objective only halves `eps`. Once `eps` drops to
/// or below it, the `f1`/`f0` convergence test is performed.
const THETA_E: f64 = 1.0e-4;

/// Largest change in the objective between two consecutive cycles that still counts as
/// "stalled" (which triggers halving of `eps`).
const THETA_G: f64 = 1.0e-4;

/// Largest change in the summed log of `psi * w` between two convergence checks that is
/// accepted as convergence.
const THETA_F: f64 = 0.01;

/// Last argument passed to `adaptative_grid` during expansion.
/// NOTE(review): presumably the minimum allowed distance between support points; confirm
/// against `adaptative_grid`.
const THETA_D: f64 = 1.0e-4;
30
/// State of a Non-Parametric Adaptive Grid (NPAG) run.
///
/// Bundles the model, the data and the settings with the state that evolves from cycle to
/// cycle: support points, the likelihood matrix, mixture weights, the grid step and the
/// convergence bookkeeping.
#[derive(Debug)]
pub struct NPAG<E: Equation> {
    /// The model being fitted.
    equation: E,
    /// `(lower, upper)` search ranges taken from `settings.parameters().ranges()`; handed to
    /// the adaptive grid expansion.
    ranges: Vec<(f64, f64)>,
    /// Matrix produced by `calculate_psi`. Its columns are kept aligned with the support
    /// points in `theta`, because both are filtered with the same indices during condensation.
    psi: Psi,
    /// Current support points.
    theta: Theta,
    /// Weights returned by the most recent `burke` (interior-point) solve.
    lambda: Weights,
    /// Weights used for the convergence check and for the final result; set from `lambda`
    /// at the end of condensation.
    w: Weights,
    /// Grid expansion step. Halved when the objective stalls; reset to 0.2 when the
    /// convergence check fails.
    eps: f64,
    /// Objective value from the previous cycle. It is reported as `-2 * objf`, so it is
    /// presumably a log-likelihood (TODO confirm against `burke`).
    last_objf: f64,
    /// Objective value from the current cycle.
    objf: f64,
    /// Sum of `ln(psi * w)` recorded at the previous convergence check.
    f0: f64,
    /// Sum of `ln(psi * w)` computed at the latest convergence check.
    f1: f64,
    /// Current cycle number.
    cycle: usize,
    /// Per-output-equation relative step used when searching the error-model factor.
    gamma_delta: Vec<f64>,
    /// Error models being fitted; their factors may change during optimization.
    error_models: ErrorModels,
    /// Stop flag exposed through `converged()`.
    converged: bool,
    /// Current run status (starting, converged, max cycles, manual stop, ...).
    status: Status,
    /// One entry per completed cycle.
    cycle_log: CycleLog,
    /// Observed data the model is fitted to.
    data: Data,
    /// User settings for the run.
    settings: Settings,
}
53
54impl<E: Equation> Algorithms<E> for NPAG<E> {
55 fn new(settings: Settings, equation: E, data: Data) -> Result<Box<Self>, anyhow::Error> {
56 Ok(Box::new(Self {
57 equation,
58 ranges: settings.parameters().ranges(),
59 psi: Psi::new(),
60 theta: Theta::new(),
61 lambda: Weights::default(),
62 w: Weights::default(),
63 eps: 0.2,
64 last_objf: -1e30,
65 objf: f64::NEG_INFINITY,
66 f0: -1e30,
67 f1: f64::default(),
68 cycle: 0,
69 gamma_delta: vec![0.1; settings.errormodels().len()],
70 error_models: settings.errormodels().clone(),
71 converged: false,
72 status: Status::Starting,
73 cycle_log: CycleLog::new(),
74 settings,
75 data,
76 }))
77 }
78
79 fn equation(&self) -> &E {
80 &self.equation
81 }
82 fn into_npresult(&self) -> NPResult<E> {
83 NPResult::new(
84 self.equation.clone(),
85 self.data.clone(),
86 self.theta.clone(),
87 self.psi.clone(),
88 self.w.clone(),
89 -2. * self.objf,
90 self.cycle,
91 self.status.clone(),
92 self.settings.clone(),
93 self.cycle_log.clone(),
94 )
95 }
96
97 fn get_settings(&self) -> &Settings {
98 &self.settings
99 }
100
101 fn get_data(&self) -> &Data {
102 &self.data
103 }
104
105 fn get_prior(&self) -> Theta {
106 initialization::sample_space(&self.settings).unwrap()
107 }
108
109 fn likelihood(&self) -> f64 {
110 self.objf
111 }
112
113 fn inc_cycle(&mut self) -> usize {
114 self.cycle += 1;
115 self.cycle
116 }
117
118 fn get_cycle(&self) -> usize {
119 self.cycle
120 }
121
122 fn set_theta(&mut self, theta: Theta) {
123 self.theta = theta;
124 }
125
126 fn theta(&self) -> &Theta {
127 &self.theta
128 }
129
130 fn psi(&self) -> &Psi {
131 &self.psi
132 }
133
134 fn convergence_evaluation(&mut self) {
135 let psi = self.psi.matrix();
136 let w = &self.w;
137 if (self.last_objf - self.objf).abs() <= THETA_G && self.eps > THETA_E {
138 self.eps /= 2.;
139 if self.eps <= THETA_E {
140 let pyl = psi * w.weights();
141 self.f1 = pyl.iter().map(|x| x.ln()).sum();
142 if (self.f1 - self.f0).abs() <= THETA_F {
143 tracing::info!("The model converged after {} cycles", self.cycle,);
144 self.converged = true;
145 self.status = Status::Converged;
146 } else {
147 self.f0 = self.f1;
148 self.eps = 0.2;
149 }
150 }
151 }
152
153 if self.cycle >= self.settings.config().cycles {
155 tracing::warn!("Maximum number of cycles reached");
156 self.converged = true;
157 self.status = Status::MaxCycles;
158 }
159
160 if std::path::Path::new("stop").exists() {
162 tracing::warn!("Stopfile detected - breaking");
163 self.status = Status::ManualStop;
164 }
165
166 let state = NPCycle::new(
168 self.cycle,
169 -2. * self.objf,
170 self.error_models.clone(),
171 self.theta.clone(),
172 self.theta.nspp(),
173 (self.last_objf - self.objf).abs(),
174 self.status.clone(),
175 );
176
177 self.cycle_log.push(state);
179 self.last_objf = self.objf;
180 }
181
182 fn converged(&self) -> bool {
183 self.converged
184 }
185
186 fn evaluation(&mut self) -> Result<()> {
187 self.psi = calculate_psi(
188 &self.equation,
189 &self.data,
190 &self.theta,
191 &self.error_models,
192 self.cycle == 1 && self.settings.config().progress,
193 self.cycle != 1,
194 )?;
195
196 if let Err(err) = self.validate_psi() {
197 bail!(err);
198 }
199
200 (self.lambda, _) = match burke(&self.psi) {
201 Ok((lambda, objf)) => (lambda.into(), objf),
202 Err(err) => {
203 bail!("Error in IPM during evaluation: {:?}", err);
204 }
205 };
206 Ok(())
207 }
208
209 fn condensation(&mut self) -> Result<()> {
210 let max_lambda = self
213 .lambda
214 .iter()
215 .fold(f64::NEG_INFINITY, |acc, x| x.max(acc));
216
217 let mut keep = Vec::<usize>::new();
218 for (index, lam) in self.lambda.iter().enumerate() {
219 if lam > max_lambda / 1000_f64 {
220 keep.push(index);
221 }
222 }
223 if self.psi.matrix().ncols() != keep.len() {
224 tracing::debug!(
225 "Lambda (max/1000) dropped {} support point(s)",
226 self.psi.matrix().ncols() - keep.len(),
227 );
228 }
229
230 self.theta.filter_indices(keep.as_slice());
231 self.psi.filter_column_indices(keep.as_slice());
232
233 let (r, perm) = qr::qrd(&self.psi)?;
235
236 let mut keep = Vec::<usize>::new();
237
238 let keep_n = self.psi.matrix().ncols().min(self.psi.matrix().nrows());
240 for i in 0..keep_n {
241 let test = r.col(i).norm_l2();
242 let r_diag_val = r.get(i, i);
243 let ratio = r_diag_val / test;
244 if ratio.abs() >= 1e-8 {
245 keep.push(*perm.get(i).unwrap());
246 }
247 }
248
249 if self.psi.matrix().ncols() != keep.len() {
251 tracing::debug!(
252 "QR decomposition dropped {} support point(s)",
253 self.psi.matrix().ncols() - keep.len(),
254 );
255 }
256
257 self.theta.filter_indices(keep.as_slice());
259 self.psi.filter_column_indices(keep.as_slice());
261
262 self.validate_psi()?;
263 (self.lambda, self.objf) = match burke(&self.psi) {
264 Ok((lambda, objf)) => (lambda.into(), objf),
265 Err(err) => {
266 return Err(anyhow::anyhow!(
267 "Error in IPM during condensation: {:?}",
268 err
269 ));
270 }
271 };
272 self.w = self.lambda.clone().into();
273 Ok(())
274 }
275
276 fn optimizations(&mut self) -> Result<()> {
277 self.error_models
278 .clone()
279 .iter_mut()
280 .filter_map(|(outeq, em)| {
281 if em.optimize() {
282 Some((outeq, em))
283 } else {
284 None
285 }
286 })
287 .try_for_each(|(outeq, em)| -> Result<()> {
288 let gamma_up = em.factor()? * (1.0 + self.gamma_delta[outeq]);
291 let gamma_down = em.factor()? / (1.0 + self.gamma_delta[outeq]);
292
293 let mut error_model_up = self.error_models.clone();
294 error_model_up.set_factor(outeq, gamma_up)?;
295
296 let mut error_model_down = self.error_models.clone();
297 error_model_down.set_factor(outeq, gamma_down)?;
298
299 let psi_up = calculate_psi(
300 &self.equation,
301 &self.data,
302 &self.theta,
303 &error_model_up,
304 false,
305 true,
306 )?;
307 let psi_down = calculate_psi(
308 &self.equation,
309 &self.data,
310 &self.theta,
311 &error_model_down,
312 false,
313 true,
314 )?;
315
316 let (lambda_up, objf_up) = match burke(&psi_up) {
317 Ok((lambda, objf)) => (lambda, objf),
318 Err(err) => {
319 return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
321 }
322 };
323 let (lambda_down, objf_down) = match burke(&psi_down) {
324 Ok((lambda, objf)) => (lambda, objf),
325 Err(err) => {
326 return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
329 }
331 };
332 if objf_up > self.objf {
333 self.error_models.set_factor(outeq, gamma_up)?;
334 self.objf = objf_up;
335 self.gamma_delta[outeq] *= 4.;
336 self.lambda = lambda_up;
337 self.psi = psi_up;
338 }
339 if objf_down > self.objf {
340 self.error_models.set_factor(outeq, gamma_down)?;
341 self.objf = objf_down;
342 self.gamma_delta[outeq] *= 4.;
343 self.lambda = lambda_down;
344 self.psi = psi_down;
345 }
346 self.gamma_delta[outeq] *= 0.5;
347 if self.gamma_delta[outeq] <= 0.01 {
348 self.gamma_delta[outeq] = 0.1;
349 }
350 Ok(())
351 })?;
352
353 Ok(())
354 }
355
356 fn logs(&self) {
357 tracing::info!("Objective function = {:.4}", -2.0 * self.objf);
358 tracing::debug!("Support points: {}", self.theta.nspp());
359
360 self.error_models.iter().for_each(|(outeq, em)| {
361 if ErrorModel::None == *em {
362 return;
363 }
364 tracing::debug!(
365 "Error model for outeq {}: {:.16}",
366 outeq,
367 em.factor().unwrap_or_default()
368 );
369 });
370
371 tracing::debug!("EPS = {:.4}", self.eps);
372 if self.last_objf > self.objf + 1e-4 {
374 tracing::warn!(
375 "Objective function decreased from {:.4} to {:.4} (delta = {})",
376 -2.0 * self.last_objf,
377 -2.0 * self.objf,
378 -2.0 * self.last_objf - -2.0 * self.objf
379 );
380 }
381 }
382
383 fn expansion(&mut self) -> Result<()> {
384 adaptative_grid(&mut self.theta, self.eps, &self.ranges, THETA_D)?;
385 Ok(())
386 }
387
388 fn set_status(&mut self, status: Status) {
389 self.status = status;
390 }
391
392 fn status(&self) -> &Status {
393 &self.status
394 }
395}