1use crate::algorithms::{Status, StopReason};
2use crate::prelude::algorithms::Algorithms;
3
4pub use crate::routines::estimation::ipm::burke;
5pub use crate::routines::estimation::qr;
6use crate::routines::settings::Settings;
7
8use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult};
9use crate::structs::psi::{calculate_psi, Psi};
10use crate::structs::theta::Theta;
11use crate::structs::weights::Weights;
12
13use anyhow::bail;
14use anyhow::Result;
15use pharmsol::prelude::{
16 data::{Data, ErrorModels},
17 simulator::Equation,
18};
19
20use pharmsol::prelude::ErrorModel;
21
22use crate::routines::initialization;
23
24use crate::routines::expansion::adaptative_grid::adaptative_grid;
25
/// Threshold on `eps`: `eps` is only halved while it is above this value, and the
/// `f0`/`f1` convergence test runs once a halving brings `eps` to or below it.
const THETA_E: f64 = 1e-4;

/// Largest absolute change of the objective between two logged cycles that still
/// triggers a halving of `eps`.
const THETA_G: f64 = 1e-4;

/// Largest absolute difference between `f1` and `f0` that is accepted as convergence.
const THETA_F: f64 = 1e-2;

/// Forwarded as the last argument of `adaptative_grid` during expansion.
/// NOTE(review): presumably the minimum distance between support points — confirm
/// against `adaptative_grid`.
const THETA_D: f64 = 1e-4;
30
31#[derive(Debug)]
32pub struct NPAG<E: Equation + Send + 'static> {
33 equation: E,
34 ranges: Vec<(f64, f64)>,
35 psi: Psi,
36 theta: Theta,
37 lambda: Weights,
38 w: Weights,
39 eps: f64,
40 last_objf: f64,
41 objf: f64,
42 f0: f64,
43 f1: f64,
44 cycle: usize,
45 gamma_delta: Vec<f64>,
46 error_models: ErrorModels,
47 status: Status,
48 cycle_log: CycleLog,
49 data: Data,
50 settings: Settings,
51}
52
53impl<E: Equation + Send + 'static> Algorithms<E> for NPAG<E> {
54 fn new(settings: Settings, equation: E, data: Data) -> Result<Box<Self>, anyhow::Error> {
55 Ok(Box::new(Self {
56 equation,
57 ranges: settings.parameters().ranges(),
58 psi: Psi::new(),
59 theta: Theta::new(),
60 lambda: Weights::default(),
61 w: Weights::default(),
62 eps: 0.2,
63 last_objf: -1e30,
64 objf: f64::NEG_INFINITY,
65 f0: -1e30,
66 f1: f64::default(),
67 cycle: 0,
68 gamma_delta: vec![0.1; settings.errormodels().len()],
69 error_models: settings.errormodels().clone(),
70 status: Status::Continue,
71 cycle_log: CycleLog::new(),
72 settings,
73 data,
74 }))
75 }
76
77 fn equation(&self) -> &E {
78 &self.equation
79 }
80 fn into_npresult(&self) -> NPResult<E> {
81 NPResult::new(
82 self.equation.clone(),
83 self.data.clone(),
84 self.theta.clone(),
85 self.psi.clone(),
86 self.w.clone(),
87 -2. * self.objf,
88 self.cycle,
89 self.status.clone(),
90 self.settings.clone(),
91 self.cycle_log.clone(),
92 )
93 }
94
95 fn settings(&self) -> &Settings {
96 &self.settings
97 }
98
99 fn data(&self) -> &Data {
100 &self.data
101 }
102
103 fn get_prior(&self) -> Theta {
104 initialization::sample_space(&self.settings).unwrap()
105 }
106
107 fn likelihood(&self) -> f64 {
108 self.objf
109 }
110
111 fn increment_cycle(&mut self) -> usize {
112 self.cycle += 1;
113 self.cycle
114 }
115
116 fn cycle(&self) -> usize {
117 self.cycle
118 }
119
120 fn set_theta(&mut self, theta: Theta) {
121 self.theta = theta;
122 }
123
124 fn theta(&self) -> &Theta {
125 &self.theta
126 }
127
128 fn psi(&self) -> &Psi {
129 &self.psi
130 }
131
132 fn evaluation(&mut self) -> Result<Status> {
133 tracing::info!("Objective function = {:.4}", -2.0 * self.objf);
134 tracing::debug!("Support points: {}", self.theta.nspp());
135
136 self.error_models.iter().for_each(|(outeq, em)| {
137 if ErrorModel::None == *em {
138 return;
139 }
140 tracing::debug!(
141 "Error model for outeq {}: {:.2}",
142 outeq,
143 em.factor().unwrap_or_default()
144 );
145 });
146
147 tracing::debug!("EPS = {:.4}", self.eps);
148 if self.last_objf > self.objf + 1e-4 {
150 tracing::warn!(
151 "Objective function decreased from {:.4} to {:.4} (delta = {})",
152 -2.0 * self.last_objf,
153 -2.0 * self.objf,
154 -2.0 * self.last_objf - -2.0 * self.objf
155 );
156 }
157
158 let psi = self.psi.matrix();
159 let w = &self.w;
160 if (self.last_objf - self.objf).abs() <= THETA_G && self.eps > THETA_E {
161 self.eps /= 2.;
162 if self.eps <= THETA_E {
163 let pyl = psi * w.weights();
164 self.f1 = pyl.iter().map(|x| x.ln()).sum();
165 if (self.f1 - self.f0).abs() <= THETA_F {
166 tracing::info!("The model converged after {} cycles", self.cycle,);
167 self.set_status(Status::Stop(StopReason::Converged));
168 self.log_cycle_state();
169 return Ok(self.status().clone());
170 } else {
171 self.f0 = self.f1;
172 self.eps = 0.2;
173 }
174 }
175 }
176
177 if self.cycle >= self.settings.config().cycles {
179 tracing::warn!("Maximum number of cycles reached");
180 self.set_status(Status::Stop(StopReason::MaxCycles));
181 self.log_cycle_state();
182 return Ok(self.status().clone());
183 }
184
185 if std::path::Path::new("stop").exists() {
187 tracing::warn!("Stopfile detected - breaking");
188 self.set_status(Status::Stop(StopReason::Stopped));
189 self.log_cycle_state();
190 return Ok(self.status().clone());
191 }
192
193 self.set_status(Status::Continue);
195 self.log_cycle_state();
196 Ok(self.status().clone())
197 }
198
199 fn estimation(&mut self) -> Result<()> {
200 self.psi = calculate_psi(
201 &self.equation,
202 &self.data,
203 &self.theta,
204 &self.error_models,
205 self.cycle == 1 && self.settings.config().progress,
206 self.cycle != 1,
207 )?;
208
209 if let Err(err) = self.validate_psi() {
210 bail!(err);
211 }
212
213 (self.lambda, _) = match burke(&self.psi) {
214 Ok((lambda, objf)) => (lambda.into(), objf),
215 Err(err) => {
216 bail!("Error in IPM during estimation: {:?}", err);
217 }
218 };
219 Ok(())
220 }
221
222 fn condensation(&mut self) -> Result<()> {
223 let max_lambda = self
226 .lambda
227 .iter()
228 .fold(f64::NEG_INFINITY, |acc, x| x.max(acc));
229
230 let mut keep = Vec::<usize>::new();
231 for (index, lam) in self.lambda.iter().enumerate() {
232 if lam > max_lambda / 1000_f64 {
233 keep.push(index);
234 }
235 }
236 if self.psi.matrix().ncols() != keep.len() {
237 tracing::debug!(
238 "Lambda (max/1000) dropped {} support point(s)",
239 self.psi.matrix().ncols() - keep.len(),
240 );
241 }
242
243 self.theta.filter_indices(keep.as_slice());
244 self.psi.filter_column_indices(keep.as_slice());
245
246 let (r, perm) = qr::qrd(&self.psi)?;
248
249 let mut keep = Vec::<usize>::new();
250
251 let keep_n = self.psi.matrix().ncols().min(self.psi.matrix().nrows());
253 for i in 0..keep_n {
254 let test = r.col(i).norm_l2();
255 let r_diag_val = r.get(i, i);
256 let ratio = r_diag_val / test;
257 if ratio.abs() >= 1e-8 {
258 keep.push(*perm.get(i).unwrap());
259 }
260 }
261
262 if self.psi.matrix().ncols() != keep.len() {
264 tracing::debug!(
265 "QR decomposition dropped {} support point(s)",
266 self.psi.matrix().ncols() - keep.len(),
267 );
268 }
269
270 self.theta.filter_indices(keep.as_slice());
272 self.psi.filter_column_indices(keep.as_slice());
274
275 self.validate_psi()?;
276 (self.lambda, self.objf) = match burke(&self.psi) {
277 Ok((lambda, objf)) => (lambda.into(), objf),
278 Err(err) => {
279 return Err(anyhow::anyhow!(
280 "Error in IPM during condensation: {:?}",
281 err
282 ));
283 }
284 };
285 self.w = self.lambda.clone().into();
286 Ok(())
287 }
288
289 fn optimizations(&mut self) -> Result<()> {
290 self.error_models
291 .clone()
292 .iter_mut()
293 .filter_map(|(outeq, em)| {
294 if em.optimize() {
295 Some((outeq, em))
296 } else {
297 None
298 }
299 })
300 .try_for_each(|(outeq, em)| -> Result<()> {
301 let gamma_up = em.factor()? * (1.0 + self.gamma_delta[outeq]);
304 let gamma_down = em.factor()? / (1.0 + self.gamma_delta[outeq]);
305
306 let mut error_model_up = self.error_models.clone();
307 error_model_up.set_factor(outeq, gamma_up)?;
308
309 let mut error_model_down = self.error_models.clone();
310 error_model_down.set_factor(outeq, gamma_down)?;
311
312 let psi_up = calculate_psi(
313 &self.equation,
314 &self.data,
315 &self.theta,
316 &error_model_up,
317 false,
318 true,
319 )?;
320 let psi_down = calculate_psi(
321 &self.equation,
322 &self.data,
323 &self.theta,
324 &error_model_down,
325 false,
326 true,
327 )?;
328
329 let (lambda_up, objf_up) = match burke(&psi_up) {
330 Ok((lambda, objf)) => (lambda, objf),
331 Err(err) => {
332 return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
334 }
335 };
336 let (lambda_down, objf_down) = match burke(&psi_down) {
337 Ok((lambda, objf)) => (lambda, objf),
338 Err(err) => {
339 return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
342 }
344 };
345 if objf_up > self.objf {
346 self.error_models.set_factor(outeq, gamma_up)?;
347 self.objf = objf_up;
348 self.gamma_delta[outeq] *= 4.;
349 self.lambda = lambda_up;
350 self.psi = psi_up;
351 }
352 if objf_down > self.objf {
353 self.error_models.set_factor(outeq, gamma_down)?;
354 self.objf = objf_down;
355 self.gamma_delta[outeq] *= 4.;
356 self.lambda = lambda_down;
357 self.psi = psi_down;
358 }
359 self.gamma_delta[outeq] *= 0.5;
360 if self.gamma_delta[outeq] <= 0.01 {
361 self.gamma_delta[outeq] = 0.1;
362 }
363 Ok(())
364 })?;
365
366 Ok(())
367 }
368
369 fn expansion(&mut self) -> Result<()> {
370 adaptative_grid(&mut self.theta, self.eps, &self.ranges, THETA_D)?;
371 Ok(())
372 }
373
374 fn set_status(&mut self, status: Status) {
375 self.status = status;
376 }
377
378 fn status(&self) -> &Status {
379 &self.status
380 }
381
382 fn log_cycle_state(&mut self) {
383 let state = NPCycle::new(
384 self.cycle,
385 -2. * self.objf,
386 self.error_models.clone(),
387 self.theta.clone(),
388 self.theta.nspp(),
389 (self.last_objf - self.objf).abs(),
390 self.status.clone(),
391 );
392 self.cycle_log.push(state);
393 self.last_objf = self.objf;
394 }
395}