1use crate::prelude::algorithms::Algorithms;
2
3pub use crate::routines::evaluation::ipm::burke;
4pub use crate::routines::evaluation::qr;
5use crate::routines::settings::Settings;
6
7use crate::routines::output::{CycleLog, NPCycle, NPResult};
8use crate::structs::psi::{calculate_psi, Psi};
9use crate::structs::theta::Theta;
10
11use anyhow::bail;
12use anyhow::Result;
13use pharmsol::prelude::{
14 data::{Data, ErrorModel, ErrorType},
15 simulator::Equation,
16};
17
18use faer::Col;
19
20use crate::routines::initialization;
21
22use crate::routines::expansion::adaptative_grid::adaptative_grid;
23
/// Lower bound for the grid-expansion step `eps`: once `eps` has been halved down to this
/// value, the `f1` vs `f0` convergence comparison is performed.
const THETA_E: f64 = 1e-4;

/// Tolerance on the absolute change of the objective function between consecutive cycles;
/// a change at or below it triggers halving of `eps`.
const THETA_G: f64 = 1e-4;

/// Tolerance on `|f1 - f0|` (sums of `ln(psi * w)`) below which the run is declared converged.
const THETA_F: f64 = 1e-2;

/// Passed to `adaptative_grid` as its final argument.
/// NOTE(review): presumably the minimum distance between support points — confirm.
const THETA_D: f64 = 1e-4;
28
29#[derive(Debug)]
30pub struct NPAG<E: Equation> {
31 equation: E,
32 ranges: Vec<(f64, f64)>,
33 psi: Psi,
34 theta: Theta,
35 lambda: Col<f64>,
36 w: Col<f64>,
37 eps: f64,
38 last_objf: f64,
39 objf: f64,
40 f0: f64,
41 f1: f64,
42 cycle: usize,
43 gamma_delta: f64,
44 gamma: f64,
45 error_type: ErrorType,
46 converged: bool,
47 cycle_log: CycleLog,
48 data: Data,
49 settings: Settings,
50}
51
52impl<E: Equation> Algorithms<E> for NPAG<E> {
53 fn new(settings: Settings, equation: E, data: Data) -> Result<Box<Self>, anyhow::Error> {
54 Ok(Box::new(Self {
55 equation,
56 ranges: settings.parameters().ranges(),
57 psi: Psi::new(),
58 theta: Theta::new(),
59 lambda: Col::zeros(0),
60 w: Col::zeros(0),
61 eps: 0.2,
62 last_objf: -1e30,
63 objf: f64::NEG_INFINITY,
64 f0: -1e30,
65 f1: f64::default(),
66 cycle: 0,
67 gamma_delta: 0.1,
68 gamma: settings.error().value,
69 error_type: settings.error().error_model().into(),
70 converged: false,
71 cycle_log: CycleLog::new(),
72 settings,
73 data,
74 }))
75 }
76
77 fn equation(&self) -> &E {
78 &self.equation
79 }
80 fn into_npresult(&self) -> NPResult<E> {
81 NPResult::new(
82 self.equation.clone(),
83 self.data.clone(),
84 self.theta.clone(),
85 self.psi.clone(),
86 self.w.clone(),
87 -2. * self.objf,
88 self.cycle,
89 self.converged,
90 self.settings.clone(),
91 self.cycle_log.clone(),
92 )
93 }
94
95 fn get_settings(&self) -> &Settings {
96 &self.settings
97 }
98
99 fn get_data(&self) -> &Data {
100 &self.data
101 }
102
103 fn get_prior(&self) -> Theta {
104 initialization::sample_space(&self.settings).unwrap()
105 }
106
107 fn likelihood(&self) -> f64 {
108 self.objf
109 }
110
111 fn inc_cycle(&mut self) -> usize {
112 self.cycle += 1;
113 self.cycle
114 }
115
116 fn get_cycle(&self) -> usize {
117 self.cycle
118 }
119
120 fn set_theta(&mut self, theta: Theta) {
121 self.theta = theta;
122 }
123
124 fn get_theta(&self) -> &Theta {
125 &self.theta
126 }
127
128 fn psi(&self) -> &Psi {
129 &self.psi
130 }
131
132 fn convergence_evaluation(&mut self) {
133 let psi = self.psi.matrix();
134 let w = &self.w;
135 if (self.last_objf - self.objf).abs() <= THETA_G && self.eps > THETA_E {
136 self.eps /= 2.;
137 if self.eps <= THETA_E {
138 let pyl = psi * w;
139 self.f1 = pyl.iter().map(|x| x.ln()).sum();
140 if (self.f1 - self.f0).abs() <= THETA_F {
141 tracing::info!("The model converged after {} cycles", self.cycle,);
142 self.converged = true;
143 } else {
144 self.f0 = self.f1;
145 self.eps = 0.2;
146 }
147 }
148 }
149
150 if self.cycle >= self.settings.config().cycles {
152 tracing::warn!("Maximum number of cycles reached");
153 self.converged = true;
154 }
155
156 if std::path::Path::new("stop").exists() {
158 tracing::warn!("Stopfile detected - breaking");
159 self.converged = true;
160 }
161
162 let state = NPCycle {
164 cycle: self.cycle,
165 objf: -2. * self.objf,
166 delta_objf: (self.last_objf - self.objf).abs(),
167 nspp: self.theta.nspp(),
168 theta: self.theta.clone(),
169 gamlam: self.gamma,
170 converged: self.converged,
171 };
172
173 self.cycle_log.push(state);
175 self.last_objf = self.objf;
176 }
177
178 fn converged(&self) -> bool {
179 self.converged
180 }
181
182 fn evaluation(&mut self) -> Result<()> {
183 self.psi = calculate_psi(
184 &self.equation,
185 &self.data,
186 &self.theta,
187 &ErrorModel::new(self.settings.error().poly, self.gamma, &self.error_type),
188 self.cycle == 1 && self.settings.config().progress,
189 self.cycle != 1,
190 );
191
192 if let Err(err) = self.validate_psi() {
193 bail!(err);
194 }
195
196 (self.lambda, _) = match burke(&self.psi) {
197 Ok((lambda, objf)) => (lambda, objf),
198 Err(err) => {
199 bail!("Error in IPM during evaluation: {:?}", err);
200 }
201 };
202 Ok(())
203 }
204
205 fn condensation(&mut self) -> Result<()> {
206 let max_lambda = self
209 .lambda
210 .iter()
211 .fold(f64::NEG_INFINITY, |acc, &x| x.max(acc));
212
213 let mut keep = Vec::<usize>::new();
214 for (index, lam) in self.lambda.iter().enumerate() {
215 if *lam > max_lambda / 1000_f64 {
216 keep.push(index);
217 }
218 }
219 if self.psi.matrix().ncols() != keep.len() {
220 tracing::debug!(
221 "Lambda (max/1000) dropped {} support point(s)",
222 self.psi.matrix().ncols() - keep.len(),
223 );
224 }
225
226 self.theta.filter_indices(keep.as_slice());
227 self.psi.filter_column_indices(keep.as_slice());
228
229 let (r, perm) = qr::qrd(&self.psi)?;
231
232 let mut keep = Vec::<usize>::new();
233
234 let keep_n = self.psi.matrix().ncols().min(self.psi.matrix().nrows());
236 for i in 0..keep_n {
237 let test = r.col(i).norm_l2();
238 let r_diag_val = r.get(i, i);
239 let ratio = r_diag_val / test;
240 if ratio.abs() >= 1e-8 {
241 keep.push(*perm.get(i).unwrap());
242 }
243 }
244
245 if self.psi.matrix().ncols() != keep.len() {
247 tracing::debug!(
248 "QR decomposition dropped {} support point(s)",
249 self.psi.matrix().ncols() - keep.len(),
250 );
251 }
252
253 self.theta.filter_indices(keep.as_slice());
255 self.psi.filter_column_indices(keep.as_slice());
257
258 self.validate_psi()?;
259 (self.lambda, self.objf) = match burke(&self.psi) {
260 Ok((lambda, objf)) => (lambda, objf),
261 Err(err) => {
262 return Err(anyhow::anyhow!(
263 "Error in IPM during condensation: {:?}",
264 err
265 ));
266 }
267 };
268 self.w = self.lambda.clone();
269 Ok(())
270 }
271
272 fn optimizations(&mut self) -> Result<()> {
273 let gamma_up = self.gamma * (1.0 + self.gamma_delta);
276 let gamma_down = self.gamma / (1.0 + self.gamma_delta);
277
278 let psi_up = calculate_psi(
279 &self.equation,
280 &self.data,
281 &self.theta,
282 &ErrorModel::new(self.settings.error().poly, gamma_up, &self.error_type),
283 false,
284 true,
285 );
286 let psi_down = calculate_psi(
287 &self.equation,
288 &self.data,
289 &self.theta,
290 &ErrorModel::new(self.settings.error().poly, gamma_down, &self.error_type),
291 false,
292 true,
293 );
294
295 let (lambda_up, objf_up) = match burke(&psi_up) {
296 Ok((lambda, objf)) => (lambda, objf),
297 Err(err) => {
298 return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
300 }
301 };
302 let (lambda_down, objf_down) = match burke(&psi_down) {
303 Ok((lambda, objf)) => (lambda, objf),
304 Err(err) => {
305 return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err));
308 }
310 };
311 if objf_up > self.objf {
312 self.gamma = gamma_up;
313 self.objf = objf_up;
314 self.gamma_delta *= 4.;
315 self.lambda = lambda_up;
316 self.psi = psi_up;
317 }
318 if objf_down > self.objf {
319 self.gamma = gamma_down;
320 self.objf = objf_down;
321 self.gamma_delta *= 4.;
322 self.lambda = lambda_down;
323 self.psi = psi_down;
324 }
325 self.gamma_delta *= 0.5;
326 if self.gamma_delta <= 0.01 {
327 self.gamma_delta = 0.1;
328 }
329 Ok(())
330 }
331
332 fn logs(&self) {
333 tracing::info!("Objective function = {:.4}", -2.0 * self.objf);
334 tracing::debug!("Support points: {}", self.theta.nspp());
335 tracing::debug!("Gamma = {:.16}", self.gamma);
336 tracing::debug!("EPS = {:.4}", self.eps);
337 if self.last_objf > self.objf + 1e-4 {
339 tracing::warn!(
340 "Objective function decreased from {:.4} to {:.4} (delta = {})",
341 -2.0 * self.last_objf,
342 -2.0 * self.objf,
343 -2.0 * self.last_objf - -2.0 * self.objf
344 );
345 }
346 }
347
348 fn expansion(&mut self) -> Result<()> {
349 adaptative_grid(&mut self.theta, self.eps, &self.ranges, THETA_D);
350 Ok(())
351 }
352}