1use crate::{
2 prelude::{
3 algorithms::Algorithms,
4 routines::{
5 evaluation::{ipm::burke, qr},
6 output::{CycleLog, NPCycle, NPResult},
7 settings::Settings,
8 },
9 },
10 structs::{
11 psi::{calculate_psi, Psi},
12 theta::Theta,
13 },
14};
15use anyhow::bail;
16use anyhow::Result;
17use faer_ext::IntoNdarray;
18use pharmsol::{
19 prelude::{
20 data::{Data, ErrorModel},
21 simulator::Equation,
22 },
23 Subject,
24};
25
26use faer::Col;
27
28use ndarray::{
29 parallel::prelude::{IntoParallelRefMutIterator, ParallelIterator},
30 Array, Array1, ArrayBase, Dim, OwnedRepr,
31};
32
33use crate::routines::{initialization, optimization::SppOptimizer};
34
/// Convergence tolerance: the run is considered converged once the absolute change in the
/// objective function between consecutive cycles is at most this value.
const THETA_F: f64 = 0.01;
/// Threshold handed to `Theta::suggest_point` when adding optimised candidate points during
/// expansion (presumably the minimum allowed distance between support points — confirm
/// against `Theta`).
const THETA_D: f64 = 0.0001;
37
38pub struct NPOD<E: Equation> {
39 equation: E,
40 ranges: Vec<(f64, f64)>,
41 psi: Psi,
42 theta: Theta,
43 lambda: Col<f64>,
44 w: Col<f64>,
45 last_objf: f64,
46 objf: f64,
47 cycle: usize,
48 gamma_delta: f64,
49 gamma: f64,
50 converged: bool,
51 cycle_log: CycleLog,
52 data: Data,
53 settings: Settings,
54}
55
56impl<E: Equation> Algorithms<E> for NPOD<E> {
57 fn new(settings: Settings, equation: E, data: Data) -> Result<Box<Self>, anyhow::Error> {
58 Ok(Box::new(Self {
59 equation,
60 ranges: settings.parameters().ranges(),
61 psi: Psi::new(),
62 theta: Theta::new(),
63 lambda: Col::zeros(0),
64 w: Col::zeros(0),
65 last_objf: -1e30,
66 objf: f64::NEG_INFINITY,
67 cycle: 0,
68 gamma_delta: 0.1,
69 gamma: settings.error().value,
70 converged: false,
71 cycle_log: CycleLog::new(),
72 settings,
73 data,
74 }))
75 }
76 fn into_npresult(&self) -> NPResult<E> {
77 NPResult::new(
78 self.equation.clone(),
79 self.data.clone(),
80 self.theta.clone(),
81 self.psi.clone(),
82 self.w.clone(),
83 -2. * self.objf,
84 self.cycle,
85 self.converged,
86 self.settings.clone(),
87 self.cycle_log.clone(),
88 )
89 }
90
91 fn equation(&self) -> &E {
92 &self.equation
93 }
94
95 fn get_settings(&self) -> &Settings {
96 &self.settings
97 }
98
99 fn get_data(&self) -> &Data {
100 &self.data
101 }
102
103 fn get_prior(&self) -> Theta {
104 initialization::sample_space(&self.settings).unwrap()
105 }
106
107 fn inc_cycle(&mut self) -> usize {
108 self.cycle += 1;
109 self.cycle
110 }
111
112 fn get_cycle(&self) -> usize {
113 self.cycle
114 }
115
116 fn set_theta(&mut self, theta: Theta) {
117 self.theta = theta;
118 }
119
120 fn get_theta(&self) -> &Theta {
121 &self.theta
122 }
123
124 fn psi(&self) -> &Psi {
125 &self.psi
126 }
127
128 fn likelihood(&self) -> f64 {
129 self.objf
130 }
131
132 fn convergence_evaluation(&mut self) {
133 if (self.last_objf - self.objf).abs() <= THETA_F {
134 tracing::info!("Objective function convergence reached");
135 self.converged = true;
136 }
137
138 if self.cycle >= self.settings.config().cycles {
140 tracing::warn!("Maximum number of cycles reached");
141 self.converged = true;
142 }
143
144 if std::path::Path::new("stop").exists() {
146 tracing::warn!("Stopfile detected - breaking");
147 self.converged = true;
148 }
149
150 let state = NPCycle {
152 cycle: self.cycle,
153 objf: -2. * self.objf,
154 delta_objf: (self.last_objf - self.objf).abs(),
155 nspp: self.theta.nspp(),
156 theta: self.theta.clone(),
157 gamlam: self.gamma,
158 converged: self.converged,
159 };
160
161 self.cycle_log.push(state);
163 self.last_objf = self.objf;
164 }
165
166 fn converged(&self) -> bool {
167 self.converged
168 }
169
170 fn evaluation(&mut self) -> Result<()> {
171 self.psi = calculate_psi(
172 &self.equation,
173 &self.data,
174 &self.theta,
175 &ErrorModel::new(
176 self.settings.error().poly,
177 self.gamma,
178 &self.settings.error().error_model().into(),
179 ),
180 self.cycle == 1 && self.settings.config().progress,
181 self.cycle != 1,
182 );
183
184 if let Err(err) = self.validate_psi() {
185 bail!(err);
186 }
187
188 (self.lambda, _) = match burke(&self.psi) {
189 Ok((lambda, objf)) => (lambda, objf),
190 Err(err) => {
191 bail!(err);
192 }
193 };
194 Ok(())
195 }
196
197 fn condensation(&mut self) -> Result<()> {
198 let max_lambda = self
199 .lambda
200 .iter()
201 .fold(f64::NEG_INFINITY, |acc, &x| x.max(acc));
202
203 let mut keep = Vec::<usize>::new();
204 for (index, lam) in self.lambda.iter().enumerate() {
205 if *lam > max_lambda / 1000_f64 {
206 keep.push(index);
207 }
208 }
209 if self.psi.matrix().ncols() != keep.len() {
210 tracing::debug!(
211 "Lambda (max/1000) dropped {} support point(s)",
212 self.psi.matrix().ncols() - keep.len(),
213 );
214 }
215
216 self.theta.filter_indices(keep.as_slice());
217 self.psi.filter_column_indices(keep.as_slice());
218
219 let (r, perm) = qr::qrd(&self.psi)?;
221
222 let mut keep = Vec::<usize>::new();
223
224 let keep_n = self.psi.matrix().ncols().min(self.psi.matrix().nrows());
226 for i in 0..keep_n {
227 let test = r.col(i).norm_l2();
228 let r_diag_val = r.get(i, i);
229 let ratio = r_diag_val / test;
230 if ratio.abs() >= 1e-8 {
231 keep.push(*perm.get(i).unwrap());
232 }
233 }
234
235 if self.psi.matrix().ncols() != keep.len() {
237 tracing::debug!(
238 "QR decomposition dropped {} support point(s)",
239 self.psi.matrix().ncols() - keep.len(),
240 );
241 }
242
243 self.theta.filter_indices(keep.as_slice());
244 self.psi.filter_column_indices(keep.as_slice());
245
246 (self.lambda, self.objf) = match burke(&self.psi) {
247 Ok((lambda, objf)) => (lambda, objf),
248 Err(err) => {
249 return Err(anyhow::anyhow!("Error in IPM: {:?}", err));
250 }
251 };
252 self.w = self.lambda.clone();
253 Ok(())
254 }
255
256 fn optimizations(&mut self) -> Result<()> {
257 let gamma_up = self.gamma * (1.0 + self.gamma_delta);
260 let gamma_down = self.gamma / (1.0 + self.gamma_delta);
261
262 let psi_up = calculate_psi(
263 &self.equation,
264 &self.data,
265 &self.theta,
266 &ErrorModel::new(
267 self.settings.error().poly,
268 self.gamma,
269 &self.settings.error().error_model().into(),
270 ),
271 false,
272 true,
273 );
274 let psi_down = calculate_psi(
275 &self.equation,
276 &self.data,
277 &self.theta,
278 &ErrorModel::new(
279 self.settings.error().poly,
280 self.gamma,
281 &self.settings.error().error_model().into(),
282 ),
283 false,
284 true,
285 );
286
287 let (lambda_up, objf_up) = match burke(&psi_up) {
288 Ok((lambda, objf)) => (lambda, objf),
289 Err(err) => {
290 return Err(anyhow::anyhow!("Error in IPM: {:?}", err));
291 }
292 };
293 let (lambda_down, objf_down) = match burke(&psi_down) {
294 Ok((lambda, objf)) => (lambda, objf),
295 Err(err) => {
296 return Err(anyhow::anyhow!("Error in IPM: {:?}", err));
297 }
298 };
299
300 if objf_up > self.objf {
301 self.gamma = gamma_up;
302 self.objf = objf_up;
303 self.gamma_delta *= 4.;
304 self.lambda = lambda_up;
305 self.psi = psi_up;
306 }
307 if objf_down > self.objf {
308 self.gamma = gamma_down;
309 self.objf = objf_down;
310 self.gamma_delta *= 4.;
311 self.lambda = lambda_down;
312 self.psi = psi_down;
313 }
314 self.gamma_delta *= 0.5;
315 if self.gamma_delta <= 0.01 {
316 self.gamma_delta = 0.1;
317 }
318 Ok(())
319 }
320
321 fn logs(&self) {
322 tracing::info!("Objective function = {:.4}", -2.0 * self.objf);
323 tracing::debug!("Support points: {}", self.theta.nspp());
324 tracing::debug!("Gamma = {:.16}", self.gamma);
325 if self.last_objf > self.objf + 1e-4 {
327 tracing::warn!(
328 "Objective function decreased from {:.4} to {:.4} (delta = {})",
329 -2.0 * self.last_objf,
330 -2.0 * self.objf,
331 -2.0 * self.last_objf - -2.0 * self.objf
332 );
333 }
334 }
335
336 fn expansion(&mut self) -> Result<()> {
337 let psi = self.psi().matrix().as_ref().into_ndarray().to_owned();
339 let w: Array1<f64> = self.w.clone().iter().cloned().collect();
340 let pyl = psi.dot(&w);
341
342 let error_type = self.settings.error().error_model().into();
344 let sigma = &ErrorModel::new(self.settings.error().poly, self.gamma, &error_type);
345
346 let mut candididate_points: Vec<Array1<f64>> = Vec::default();
347 for spp in self.theta.matrix().row_iter() {
348 let candidate: Vec<f64> = spp.iter().cloned().collect();
349 let spp = Array1::from(candidate);
350 candididate_points.push(spp.to_owned());
351 }
352 candididate_points.par_iter_mut().for_each(|spp| {
353 let optimizer = SppOptimizer::new(&self.equation, &self.data, sigma, &pyl);
354 let candidate_point = optimizer.optimize_point(spp.to_owned()).unwrap();
355 *spp = candidate_point;
356 });
362 for cp in candididate_points {
363 self.theta
364 .suggest_point(cp.to_vec().as_slice(), THETA_D, &self.ranges);
365 }
366 Ok(())
367 }
368}
369
370impl<E: Equation> NPOD<E> {
371 fn validate_psi(&mut self) -> Result<()> {
372 let mut psi = self.psi().matrix().as_ref().into_ndarray().to_owned();
373 if psi.iter().any(|x| x.is_nan() || x.is_infinite()) {
375 tracing::warn!("Psi contains NaN or Inf values, coercing to 0.0");
376 for i in 0..psi.nrows() {
377 for j in 0..psi.ncols() {
378 let val = psi.get_mut((i, j)).unwrap();
379 if val.is_nan() || val.is_infinite() {
380 *val = 0.0;
381 }
382 }
383 }
384 }
385
386 let (_, col) = psi.dim();
388 let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
389 let plam = psi.dot(&ecol);
390 let w = 1. / &plam;
391
392 let indices: Vec<usize> = w
394 .iter()
395 .enumerate()
396 .filter(|(_, x)| x.is_nan() || x.is_infinite())
397 .map(|(i, _)| i)
398 .collect::<Vec<_>>();
399
400 if !indices.is_empty() {
402 let subject: Vec<&Subject> = self.data.get_subjects();
403 let zero_probability_subjects: Vec<&String> =
404 indices.iter().map(|&i| subject[i].id()).collect();
405
406 return Err(anyhow::anyhow!(
407 "The probability of one or more subjects, given the model, is zero. The following subjects have zero probability: {:?}", zero_probability_subjects
408 ));
409 }
410
411 Ok(())
412 }
413}