1use crate::{
2 prelude::{
3 algorithms::Algorithms,
4 routines::{
5 evaluation::{ipm::burke, qr},
6 output::{CycleLog, NPCycle, NPResult},
7 settings::Settings,
8 },
9 },
10 structs::{
11 psi::{calculate_psi, Psi},
12 theta::Theta,
13 },
14};
15use anyhow::bail;
16use anyhow::Result;
17use faer_ext::IntoNdarray;
18use pharmsol::{
19 prelude::{
20 data::{Data, ErrorModel},
21 simulator::Equation,
22 },
23 Subject,
24};
25
26use faer::Col;
27
28use ndarray::{
29 parallel::prelude::{IntoParallelRefMutIterator, ParallelIterator},
30 Array, Array1, ArrayBase, Dim, OwnedRepr,
31};
32
33use crate::routines::{initialization, optimization::SppOptimizer};
34
/// Objective-function tolerance: a cycle is flagged as converged once
/// `|last_objf - objf|` is at most this value.
const THETA_F: f64 = 0.01;
/// Second argument passed to `Theta::suggest_point` when candidate points are added during
/// expansion (presumably a minimum-distance threshold — confirm in `Theta::suggest_point`).
const THETA_D: f64 = 0.0001;
37
38pub struct NPOD<E: Equation> {
39 equation: E,
40 psi: Psi,
41 theta: Theta,
42 lambda: Col<f64>,
43 w: Col<f64>,
44 last_objf: f64,
45 objf: f64,
46 cycle: usize,
47 gamma_delta: f64,
48 error_model: ErrorModel,
49 converged: bool,
50 cycle_log: CycleLog,
51 data: Data,
52 settings: Settings,
53}
54
55impl<E: Equation> Algorithms<E> for NPOD<E> {
56 fn new(settings: Settings, equation: E, data: Data) -> Result<Box<Self>, anyhow::Error> {
57 Ok(Box::new(Self {
58 equation,
59 psi: Psi::new(),
60 theta: Theta::new(),
61 lambda: Col::zeros(0),
62 w: Col::zeros(0),
63 last_objf: -1e30,
64 objf: f64::NEG_INFINITY,
65 cycle: 0,
66 gamma_delta: 0.1,
67 error_model: settings.error().clone().into(),
68 converged: false,
69 cycle_log: CycleLog::new(),
70 settings,
71 data,
72 }))
73 }
74 fn into_npresult(&self) -> NPResult<E> {
75 NPResult::new(
76 self.equation.clone(),
77 self.data.clone(),
78 self.theta.clone(),
79 self.psi.clone(),
80 self.w.clone(),
81 -2. * self.objf,
82 self.cycle,
83 self.converged,
84 self.settings.clone(),
85 self.cycle_log.clone(),
86 )
87 }
88
89 fn equation(&self) -> &E {
90 &self.equation
91 }
92
93 fn get_settings(&self) -> &Settings {
94 &self.settings
95 }
96
97 fn get_data(&self) -> &Data {
98 &self.data
99 }
100
101 fn get_prior(&self) -> Theta {
102 initialization::sample_space(&self.settings).unwrap()
103 }
104
105 fn inc_cycle(&mut self) -> usize {
106 self.cycle += 1;
107 self.cycle
108 }
109
110 fn get_cycle(&self) -> usize {
111 self.cycle
112 }
113
114 fn set_theta(&mut self, theta: Theta) {
115 self.theta = theta;
116 }
117
118 fn get_theta(&self) -> &Theta {
119 &self.theta
120 }
121
122 fn psi(&self) -> &Psi {
123 &self.psi
124 }
125
126 fn likelihood(&self) -> f64 {
127 self.objf
128 }
129
130 fn convergence_evaluation(&mut self) {
131 if (self.last_objf - self.objf).abs() <= THETA_F {
132 tracing::info!("Objective function convergence reached");
133 self.converged = true;
134 }
135
136 if self.cycle >= self.settings.config().cycles {
138 tracing::warn!("Maximum number of cycles reached");
139 self.converged = true;
140 }
141
142 if std::path::Path::new("stop").exists() {
144 tracing::warn!("Stopfile detected - breaking");
145 self.converged = true;
146 }
147
148 let state = NPCycle {
150 cycle: self.cycle,
151 objf: -2. * self.objf,
152 delta_objf: (self.last_objf - self.objf).abs(),
153 nspp: self.theta.nspp(),
154 theta: self.theta.clone(),
155 gamlam: self.error_model.scalar(),
156 converged: self.converged,
157 };
158
159 self.cycle_log.push(state);
161 self.last_objf = self.objf;
162 }
163
164 fn converged(&self) -> bool {
165 self.converged
166 }
167
168 fn evaluation(&mut self) -> Result<()> {
169 let error_model: ErrorModel = self.error_model.clone();
170
171 self.psi = calculate_psi(
172 &self.equation,
173 &self.data,
174 &self.theta,
175 &error_model,
176 self.cycle == 1 && self.settings.config().progress,
177 self.cycle != 1,
178 )?;
179
180 if let Err(err) = self.validate_psi() {
181 bail!(err);
182 }
183
184 (self.lambda, _) = match burke(&self.psi) {
185 Ok((lambda, objf)) => (lambda, objf),
186 Err(err) => {
187 bail!(err);
188 }
189 };
190 Ok(())
191 }
192
193 fn condensation(&mut self) -> Result<()> {
194 let max_lambda = self
195 .lambda
196 .iter()
197 .fold(f64::NEG_INFINITY, |acc, &x| x.max(acc));
198
199 let mut keep = Vec::<usize>::new();
200 for (index, lam) in self.lambda.iter().enumerate() {
201 if *lam > max_lambda / 1000_f64 {
202 keep.push(index);
203 }
204 }
205 if self.psi.matrix().ncols() != keep.len() {
206 tracing::debug!(
207 "Lambda (max/1000) dropped {} support point(s)",
208 self.psi.matrix().ncols() - keep.len(),
209 );
210 }
211
212 self.theta.filter_indices(keep.as_slice());
213 self.psi.filter_column_indices(keep.as_slice());
214
215 let (r, perm) = qr::qrd(&self.psi)?;
217
218 let mut keep = Vec::<usize>::new();
219
220 let keep_n = self.psi.matrix().ncols().min(self.psi.matrix().nrows());
222 for i in 0..keep_n {
223 let test = r.col(i).norm_l2();
224 let r_diag_val = r.get(i, i);
225 let ratio = r_diag_val / test;
226 if ratio.abs() >= 1e-8 {
227 keep.push(*perm.get(i).unwrap());
228 }
229 }
230
231 if self.psi.matrix().ncols() != keep.len() {
233 tracing::debug!(
234 "QR decomposition dropped {} support point(s)",
235 self.psi.matrix().ncols() - keep.len(),
236 );
237 }
238
239 self.theta.filter_indices(keep.as_slice());
240 self.psi.filter_column_indices(keep.as_slice());
241
242 (self.lambda, self.objf) = match burke(&self.psi) {
243 Ok((lambda, objf)) => (lambda, objf),
244 Err(err) => {
245 return Err(anyhow::anyhow!("Error in IPM: {:?}", err));
246 }
247 };
248 self.w = self.lambda.clone();
249 Ok(())
250 }
251
252 fn optimizations(&mut self) -> Result<()> {
253 let gamma_up = self.error_model.scalar() * (1.0 + self.gamma_delta);
256 let gamma_down = self.error_model.scalar() / (1.0 + self.gamma_delta);
257
258 let mut error_model_up: ErrorModel = self.error_model.clone();
259 error_model_up.set_scalar(gamma_up);
260
261 let mut error_model_down: ErrorModel = self.error_model.clone();
262 error_model_down.set_scalar(gamma_down);
263
264 let psi_up = calculate_psi(
265 &self.equation,
266 &self.data,
267 &self.theta,
268 &error_model_up,
269 false,
270 true,
271 )?;
272 let psi_down = calculate_psi(
273 &self.equation,
274 &self.data,
275 &self.theta,
276 &error_model_down,
277 false,
278 true,
279 )?;
280
281 let (lambda_up, objf_up) = match burke(&psi_up) {
282 Ok((lambda, objf)) => (lambda, objf),
283 Err(err) => {
284 return Err(anyhow::anyhow!("Error in IPM: {:?}", err));
285 }
286 };
287 let (lambda_down, objf_down) = match burke(&psi_down) {
288 Ok((lambda, objf)) => (lambda, objf),
289 Err(err) => {
290 return Err(anyhow::anyhow!("Error in IPM: {:?}", err));
291 }
292 };
293
294 if objf_up > self.objf {
295 self.error_model.set_scalar(gamma_up);
296 self.objf = objf_up;
297 self.gamma_delta *= 4.;
298 self.lambda = lambda_up;
299 self.psi = psi_up;
300 }
301 if objf_down > self.objf {
302 self.error_model.set_scalar(gamma_down);
303 self.objf = objf_down;
304 self.gamma_delta *= 4.;
305 self.lambda = lambda_down;
306 self.psi = psi_down;
307 }
308 self.gamma_delta *= 0.5;
309 if self.gamma_delta <= 0.01 {
310 self.gamma_delta = 0.1;
311 }
312 Ok(())
313 }
314
315 fn logs(&self) {
316 tracing::info!("Objective function = {:.4}", -2.0 * self.objf);
317 tracing::debug!("Support points: {}", self.theta.nspp());
318 tracing::debug!("Gamma = {:.16}", self.error_model.scalar());
319 if self.last_objf > self.objf + 1e-4 {
321 tracing::warn!(
322 "Objective function decreased from {:.4} to {:.4} (delta = {})",
323 -2.0 * self.last_objf,
324 -2.0 * self.objf,
325 -2.0 * self.last_objf - -2.0 * self.objf
326 );
327 }
328 }
329
330 fn expansion(&mut self) -> Result<()> {
331 let psi = self.psi().matrix().as_ref().into_ndarray().to_owned();
333 let w: Array1<f64> = self.w.clone().iter().cloned().collect();
334 let pyl = psi.dot(&w);
335
336 let error_model: ErrorModel = self.error_model.clone();
338
339 let mut candididate_points: Vec<Array1<f64>> = Vec::default();
340 for spp in self.theta.matrix().row_iter() {
341 let candidate: Vec<f64> = spp.iter().cloned().collect();
342 let spp = Array1::from(candidate);
343 candididate_points.push(spp.to_owned());
344 }
345 candididate_points.par_iter_mut().for_each(|spp| {
346 let optimizer = SppOptimizer::new(&self.equation, &self.data, &error_model, &pyl);
347 let candidate_point = optimizer.optimize_point(spp.to_owned()).unwrap();
348 *spp = candidate_point;
349 });
355 for cp in candididate_points {
356 self.theta.suggest_point(cp.to_vec().as_slice(), THETA_D);
357 }
358 Ok(())
359 }
360}
361
362impl<E: Equation> NPOD<E> {
363 fn validate_psi(&mut self) -> Result<()> {
364 let mut psi = self.psi().matrix().as_ref().into_ndarray().to_owned();
365 if psi.iter().any(|x| x.is_nan() || x.is_infinite()) {
367 tracing::warn!("Psi contains NaN or Inf values, coercing to 0.0");
368 for i in 0..psi.nrows() {
369 for j in 0..psi.ncols() {
370 let val = psi.get_mut((i, j)).unwrap();
371 if val.is_nan() || val.is_infinite() {
372 *val = 0.0;
373 }
374 }
375 }
376 }
377
378 let (_, col) = psi.dim();
380 let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
381 let plam = psi.dot(&ecol);
382 let w = 1. / &plam;
383
384 let indices: Vec<usize> = w
386 .iter()
387 .enumerate()
388 .filter(|(_, x)| x.is_nan() || x.is_infinite())
389 .map(|(i, _)| i)
390 .collect::<Vec<_>>();
391
392 if !indices.is_empty() {
394 let subject: Vec<&Subject> = self.data.get_subjects();
395 let zero_probability_subjects: Vec<&String> =
396 indices.iter().map(|&i| subject[i].id()).collect();
397
398 return Err(anyhow::anyhow!(
399 "The probability of one or more subjects, given the model, is zero. The following subjects have zero probability: {:?}", zero_probability_subjects
400 ));
401 }
402
403 Ok(())
404 }
405}