1mod em;
2
3use diffsol::{NalgebraContext, Vector};
4use nalgebra::DVector;
5use ndarray::{concatenate, Array2, Axis};
6use rand::{rng, RngExt};
7use rayon::prelude::*;
8
9use crate::{
10 data::{Covariates, Infusion},
11 error_model::AssayErrorModels,
12 prelude::simulator::Prediction,
13 simulator::{Diffusion, Drift, Fa, Init, Lag, Neqs, Out, V},
14 Subject,
15};
16
17use super::spphash;
18use crate::simulator::cache::{SdeLikelihoodCache, DEFAULT_CACHE_SIZE};
19
20use diffsol::VectorCommon;
21
22use crate::PharmsolError;
23
24use super::{Equation, EquationPriv, EquationTypes, Predictions, State};
25
26#[inline(always)]
46#[allow(clippy::too_many_arguments)]
47pub(crate) fn simulate_sde_event(
48 drift: &Drift,
49 difussion: &Diffusion,
50 x: V,
51 support_point: &[f64],
52 cov: &Covariates,
53 infusions: &[Infusion],
54 ndrugs: usize,
55 ti: f64,
56 tf: f64,
57) -> V {
58 if ti == tf {
59 return x;
60 }
61
62 let params_v = V::from_vec(support_point.to_vec(), NalgebraContext);
63 let covariates = cov.clone();
64 let infusion_events = infusions.to_vec();
65 let drift_fn = *drift;
66 let diffusion_fn = *difussion;
67
68 let params_for_drift = params_v.clone();
69 let drift_closure = move |time: f64, state: &DVector<f64>, out: &mut DVector<f64>| {
70 let mut rateiv = V::zeros(ndrugs, NalgebraContext);
71 for infusion in &infusion_events {
72 if time >= infusion.time() && time <= infusion.duration() + infusion.time() {
73 rateiv[infusion.input()] += infusion.amount() / infusion.duration();
74 }
75 }
76
77 let state_v: V = state.clone().into();
78 let mut out_v = V::zeros(state.len(), NalgebraContext);
79 drift_fn(
80 &state_v,
81 ¶ms_for_drift,
82 time,
83 &mut out_v,
84 &rateiv,
85 &covariates,
86 );
87 out.copy_from(out_v.inner());
88 };
89
90 let diffusion_closure = move |_time: f64, _state: &DVector<f64>, out: &mut DVector<f64>| {
91 let mut out_v = V::zeros(out.len(), NalgebraContext);
92 diffusion_fn(¶ms_v, &mut out_v);
93 out.copy_from(out_v.inner());
94 };
95
96 simulate_sde_event_with(drift_closure, diffusion_closure, x.inner().clone(), ti, tf).into()
97}
98
99pub(crate) fn simulate_sde_event_with<D, G>(
100 drift: D,
101 diffusion: G,
102 initial_state: DVector<f64>,
103 ti: f64,
104 tf: f64,
105) -> DVector<f64>
106where
107 D: Fn(f64, &DVector<f64>, &mut DVector<f64>),
108 G: Fn(f64, &DVector<f64>, &mut DVector<f64>),
109{
110 if ti == tf {
111 return initial_state;
112 }
113
114 let mut sde = em::EM::new(drift, diffusion, initial_state, 1e-2, 1e-2);
115 let (_time, solution) = sde.solve(ti, tf);
116 solution.last().unwrap().clone()
117}
118
/// Stochastic differential equation (SDE) model evaluated with a particle filter.
///
/// Each subject is represented by `nparticles` state vectors ("particles").
/// - Between events, every particle is propagated with [`simulate_sde_event`].
/// - At observations that come with error models, particles are weighted by their
///   likelihood and then resampled.
#[derive(Clone, Debug)]
pub struct SDE {
    /// Drift function; evaluated with the state, parameters, time, infusion rates and covariates.
    drift: Drift,
    /// Diffusion function; evaluated with the parameter vector only.
    diffusion: Diffusion,
    /// Lag function (see [`Lag`]); exposed through `EquationPriv::lag`.
    lag: Lag,
    /// Fa function (see [`Fa`]); exposed through `EquationPriv::fa`.
    fa: Fa,
    /// Initial-state function; applied only for the first occasion.
    init: Init,
    /// Output function mapping a particle's state to the output equations.
    out: Out,
    /// Number of states, drug inputs and output equations.
    neqs: Neqs,
    /// Number of particles per subject.
    nparticles: usize,
    /// Optional likelihood cache keyed on (subject, support point, error models) hashes;
    /// `None` disables caching.
    cache: Option<SdeLikelihoodCache>,
}
138
139impl SDE {
140 pub fn new(
150 drift: Drift,
151 diffusion: Diffusion,
152 lag: Lag,
153 fa: Fa,
154 init: Init,
155 out: Out,
156 nparticles: usize,
157 ) -> Self {
158 Self {
159 drift,
160 diffusion,
161 lag,
162 fa,
163 init,
164 out,
165 neqs: Neqs::default(),
166 nparticles,
167 cache: Some(SdeLikelihoodCache::new(DEFAULT_CACHE_SIZE)),
168 }
169 }
170
171 pub fn with_nstates(mut self, nstates: usize) -> Self {
173 self.neqs.nstates = nstates;
174 self
175 }
176
177 pub fn with_ndrugs(mut self, ndrugs: usize) -> Self {
179 self.neqs.ndrugs = ndrugs;
180 self
181 }
182
183 pub fn with_nout(mut self, nout: usize) -> Self {
185 self.neqs.nout = nout;
186 self
187 }
188}
189
190impl super::Cache for SDE {
191 fn with_cache_capacity(mut self, size: u64) -> Self {
192 self.cache = Some(SdeLikelihoodCache::new(size));
193 self
194 }
195
196 fn enable_cache(mut self) -> Self {
197 self.cache = Some(SdeLikelihoodCache::new(DEFAULT_CACHE_SIZE));
198 self
199 }
200
201 fn clear_cache(&self) {
202 if let Some(cache) = &self.cache {
203 cache.invalidate_all();
204 }
205 }
206
207 fn disable_cache(mut self) -> Self {
208 self.cache = None;
209 self
210 }
211}
212
213impl State for Vec<DVector<f64>> {
217 fn add_bolus(&mut self, input: usize, amount: f64) {
224 self.par_iter_mut().for_each(|particle| {
225 particle[input] += amount;
226 });
227 }
228}
229
230impl Predictions for Array2<Prediction> {
234 fn new(nparticles: usize) -> Self {
235 Array2::from_shape_fn((nparticles, 0), |_| Prediction::default())
236 }
237 fn squared_error(&self) -> f64 {
238 unimplemented!();
239 }
240 fn get_predictions(&self) -> Vec<Prediction> {
241 if self.is_empty() || self.ncols() == 0 {
243 return Vec::new();
244 }
245
246 let mut result = Vec::with_capacity(self.ncols());
247
248 for col in 0..self.ncols() {
249 let column = self.column(col);
250
251 let mean_prediction: f64 = column
252 .iter()
253 .map(|pred: &Prediction| pred.prediction())
254 .sum::<f64>()
255 / self.nrows() as f64;
256
257 let mut prediction = column.first().unwrap().clone();
258 prediction.set_prediction(mean_prediction);
259 result.push(prediction);
260 }
261
262 result
263 }
264 fn log_likelihood(&self, error_models: &AssayErrorModels) -> Result<f64, crate::PharmsolError> {
265 let predictions = self.get_predictions();
267 if predictions.is_empty() {
268 return Ok(0.0);
269 }
270
271 let log_liks: Result<Vec<f64>, _> = predictions
272 .iter()
273 .filter(|p| p.observation().is_some())
274 .map(|p| p.log_likelihood(error_models))
275 .collect();
276
277 log_liks.map(|lls| lls.iter().sum())
278 }
279}
280
281impl EquationTypes for SDE {
282 type S = Vec<DVector<f64>>; type P = Array2<Prediction>; }
285
286impl EquationPriv for SDE {
287 #[inline(always)]
308 fn lag(&self) -> &Lag {
309 &self.lag
310 }
311
312 #[inline(always)]
313 fn fa(&self) -> &Fa {
314 &self.fa
315 }
316
317 #[inline(always)]
318 fn get_nstates(&self) -> usize {
319 self.neqs.nstates
320 }
321
322 #[inline(always)]
323 fn get_ndrugs(&self) -> usize {
324 self.neqs.ndrugs
325 }
326
327 #[inline(always)]
328 fn get_nouteqs(&self) -> usize {
329 self.neqs.nout
330 }
331 #[inline(always)]
332 fn solve(
333 &self,
334 state: &mut Self::S,
335 support_point: &[f64],
336 covariates: &Covariates,
337 infusions: &[Infusion],
338 ti: f64,
339 tf: f64,
340 ) -> Result<(), PharmsolError> {
341 let ndrugs = self.get_ndrugs();
342 state.par_iter_mut().for_each(|particle| {
343 *particle = simulate_sde_event(
344 &self.drift,
345 &self.diffusion,
346 particle.clone().into(),
347 support_point,
348 covariates,
349 infusions,
350 ndrugs,
351 ti,
352 tf,
353 )
354 .inner()
355 .clone();
356 });
357 Ok(())
358 }
359 fn nparticles(&self) -> usize {
360 self.nparticles
361 }
362
363 fn is_sde(&self) -> bool {
364 true
365 }
366 #[inline(always)]
367 fn process_observation(
368 &self,
369 support_point: &[f64],
370 observation: &crate::Observation,
371 error_models: Option<&AssayErrorModels>,
372 _time: f64,
373 covariates: &Covariates,
374 x: &mut Self::S,
375 likelihood: &mut Vec<f64>,
376 output: &mut Self::P,
377 ) -> Result<(), PharmsolError> {
378 let mut pred = vec![Prediction::default(); self.nparticles];
379
380 pred.par_iter_mut().enumerate().for_each(|(i, p)| {
381 let mut y = V::zeros(self.get_nouteqs(), NalgebraContext);
382 (self.out)(
383 &x[i].clone().into(),
384 &V::from_vec(support_point.to_vec(), NalgebraContext),
385 observation.time(),
386 covariates,
387 &mut y,
388 );
389 *p = observation.to_prediction(y[observation.outeq()], x[i].as_slice().to_vec());
390 });
391 let out = Array2::from_shape_vec((self.nparticles, 1), pred.clone())?;
392 *output = concatenate(Axis(1), &[output.view(), out.view()]).unwrap();
393 if let Some(em) = error_models {
396 let mut q: Vec<f64> = Vec::with_capacity(self.nparticles);
397
398 pred.iter().for_each(|p| {
399 let lik = p.log_likelihood(em).map(f64::exp);
400 match lik {
401 Ok(l) => q.push(l),
402 Err(e) => panic!("Error in likelihood calculation: {:?}", e),
403 }
404 });
405 let sum_q: f64 = q.iter().sum();
406 let w: Vec<f64> = q.iter().map(|qi| qi / sum_q).collect();
407 let i = sysresample(&w);
408 let a: Vec<DVector<f64>> = i.iter().map(|&i| x[i].clone()).collect();
409 *x = a;
410 likelihood.push(sum_q / self.nparticles as f64);
411 }
414 Ok(())
415 }
416 #[inline(always)]
417 fn initial_state(
418 &self,
419 support_point: &[f64],
420 covariates: &Covariates,
421 occasion_index: usize,
422 ) -> Self::S {
423 let mut x = Vec::with_capacity(self.nparticles);
424 for _ in 0..self.nparticles {
425 let mut state: V = DVector::zeros(self.get_nstates()).into();
426 if occasion_index == 0 {
427 (self.init)(
428 &V::from_vec(support_point.to_vec(), NalgebraContext),
429 0.0,
430 covariates,
431 &mut state,
432 );
433 }
434 x.push(state.inner().clone());
435 }
436 x
437 }
438}
439
440impl Equation for SDE {
441 fn estimate_likelihood(
453 &self,
454 subject: &Subject,
455 support_point: &[f64],
456 error_models: &AssayErrorModels,
457 ) -> Result<f64, PharmsolError> {
458 _estimate_likelihood(self, subject, support_point, error_models)
459 }
460
461 fn estimate_log_likelihood(
462 &self,
463 subject: &Subject,
464 support_point: &[f64],
465 error_models: &AssayErrorModels,
466 ) -> Result<f64, PharmsolError> {
467 let lik = _estimate_likelihood(self, subject, support_point, error_models)?;
470
471 if lik > 0.0 {
472 Ok(lik.ln())
473 } else {
474 Ok(f64::NEG_INFINITY)
475 }
476 }
477
478 fn kind() -> crate::EqnKind {
479 crate::EqnKind::SDE
480 }
481}
482
483#[inline(always)]
484fn _estimate_likelihood(
485 sde: &SDE,
486 subject: &Subject,
487 support_point: &[f64],
488 error_models: &AssayErrorModels,
489) -> Result<f64, PharmsolError> {
490 if let Some(cache) = &sde.cache {
491 let key = (subject.hash(), spphash(support_point), error_models.hash());
492 if let Some(cached) = cache.get(&key) {
493 return Ok(cached);
494 }
495
496 let ypred = sde.simulate_subject(subject, support_point, Some(error_models))?;
497 let result = ypred.1.unwrap();
498 cache.insert(key, result);
499 Ok(result)
500 } else {
501 let ypred = sde.simulate_subject(subject, support_point, Some(error_models))?;
502 Ok(ypred.1.unwrap())
503 }
504}
505
506fn sysresample(q: &[f64]) -> Vec<usize> {
516 let mut qc = vec![0.0; q.len()];
517 qc[0] = q[0];
518 for i in 1..q.len() {
519 qc[i] = qc[i - 1] + q[i];
520 }
521 let m = q.len();
522 let mut rng = rng();
523 let u: Vec<f64> = (0..m)
524 .map(|i| (i as f64 + rng.random::<f64>()) / m as f64)
525 .collect();
526 let mut i = vec![0; m];
527 let mut k = 0;
528 for j in 0..m {
529 while qc[k] < u[j] {
530 k += 1;
531 }
532 i[j] = k;
533 }
534 i
535}