pharmsol/simulator/equation/sde/
mod.rs1mod em;
2
3use diffsol::{NalgebraContext, Vector};
4use nalgebra::DVector;
5use ndarray::{concatenate, Array2, Axis};
6use rand::{rng, RngExt};
7use rayon::prelude::*;
8
9use crate::{
10 data::{Covariates, Infusion},
11 error_model::AssayErrorModels,
12 prelude::simulator::Prediction,
13 simulator::{Diffusion, Drift, Fa, Init, Lag, Neqs, Out, V},
14 Subject,
15};
16
17use super::id_hash;
18use super::spphash;
19use crate::simulator::cache::{cache_enabled, sde_cache_lock_read};
20
21use diffsol::VectorCommon;
22
23use crate::PharmsolError;
24
25use super::{Equation, EquationPriv, EquationTypes, Predictions, State};
26
27#[inline(always)]
47pub(crate) fn simulate_sde_event(
48 drift: &Drift,
49 difussion: &Diffusion,
50 x: V,
51 support_point: &[f64],
52 cov: &Covariates,
53 infusions: &[Infusion],
54 ndrugs: usize,
55 ti: f64,
56 tf: f64,
57) -> V {
58 if ti == tf {
59 return x;
60 }
61
62 let mut sde = em::EM::new(
63 *drift,
64 *difussion,
65 DVector::from_column_slice(support_point),
66 x.inner().clone(),
67 cov.clone(),
68 infusions.to_vec(),
69 1e-2,
70 1e-2,
71 ndrugs,
72 );
73 let (_time, solution) = sde.solve(ti, tf);
74 solution.last().unwrap().clone().into()
75}
76
/// A stochastic differential equation (SDE) model that is simulated with a cloud of particles.
///
/// Bundles the user-supplied model callbacks with the equation dimensions
/// (`neqs`) and the number of particles used to represent the state.
#[derive(Clone, Debug)]
pub struct SDE {
    /// Drift callback; copied into the `em::EM` solver for every integration step range.
    drift: Drift,
    /// Diffusion callback; copied into the `em::EM` solver alongside `drift`.
    diffusion: Diffusion,
    /// Lag-time callback.
    lag: Lag,
    /// `Fa` callback (presumably bioavailability — confirm against the `Fa` definition).
    fa: Fa,
    /// Initial-state callback, applied to each particle on the first occasion only.
    init: Init,
    /// Output callback that maps a particle's state to the output-equation values.
    out: Out,
    /// Numbers of states, drugs and output equations; configured with the `with_*` builders.
    neqs: Neqs,
    /// Number of particles in the simulated state cloud.
    nparticles: usize,
}
95
96impl SDE {
97 pub fn new(
107 drift: Drift,
108 diffusion: Diffusion,
109 lag: Lag,
110 fa: Fa,
111 init: Init,
112 out: Out,
113 nparticles: usize,
114 ) -> Self {
115 Self {
116 drift,
117 diffusion,
118 lag,
119 fa,
120 init,
121 out,
122 neqs: Neqs::default(),
123 nparticles,
124 }
125 }
126
127 pub fn with_nstates(mut self, nstates: usize) -> Self {
129 self.neqs.nstates = nstates;
130 self
131 }
132
133 pub fn with_ndrugs(mut self, ndrugs: usize) -> Self {
135 self.neqs.ndrugs = ndrugs;
136 self
137 }
138
139 pub fn with_nout(mut self, nout: usize) -> Self {
141 self.neqs.nout = nout;
142 self
143 }
144}
145
146impl State for Vec<DVector<f64>> {
150 fn add_bolus(&mut self, input: usize, amount: f64) {
157 self.par_iter_mut().for_each(|particle| {
158 particle[input] += amount;
159 });
160 }
161}
162
163impl Predictions for Array2<Prediction> {
167 fn new(nparticles: usize) -> Self {
168 Array2::from_shape_fn((nparticles, 0), |_| Prediction::default())
169 }
170 fn squared_error(&self) -> f64 {
171 unimplemented!();
172 }
173 fn get_predictions(&self) -> Vec<Prediction> {
174 if self.is_empty() || self.ncols() == 0 {
176 return Vec::new();
177 }
178
179 let mut result = Vec::with_capacity(self.ncols());
180
181 for col in 0..self.ncols() {
182 let column = self.column(col);
183
184 let mean_prediction: f64 = column
185 .iter()
186 .map(|pred: &Prediction| pred.prediction())
187 .sum::<f64>()
188 / self.nrows() as f64;
189
190 let mut prediction = column.first().unwrap().clone();
191 prediction.set_prediction(mean_prediction);
192 result.push(prediction);
193 }
194
195 result
196 }
197 fn log_likelihood(&self, error_models: &AssayErrorModels) -> Result<f64, crate::PharmsolError> {
198 let predictions = self.get_predictions();
200 if predictions.is_empty() {
201 return Ok(0.0);
202 }
203
204 let log_liks: Result<Vec<f64>, _> = predictions
205 .iter()
206 .filter(|p| p.observation().is_some())
207 .map(|p| p.log_likelihood(error_models))
208 .collect();
209
210 log_liks.map(|lls| lls.iter().sum())
211 }
212}
213
214impl EquationTypes for SDE {
215 type S = Vec<DVector<f64>>; type P = Array2<Prediction>; }
218
219impl EquationPriv for SDE {
220 #[inline(always)]
241 fn lag(&self) -> &Lag {
242 &self.lag
243 }
244
245 #[inline(always)]
246 fn fa(&self) -> &Fa {
247 &self.fa
248 }
249
250 #[inline(always)]
251 fn get_nstates(&self) -> usize {
252 self.neqs.nstates
253 }
254
255 #[inline(always)]
256 fn get_ndrugs(&self) -> usize {
257 self.neqs.ndrugs
258 }
259
260 #[inline(always)]
261 fn get_nouteqs(&self) -> usize {
262 self.neqs.nout
263 }
264 #[inline(always)]
265 fn solve(
266 &self,
267 state: &mut Self::S,
268 support_point: &Vec<f64>,
269 covariates: &Covariates,
270 infusions: &Vec<Infusion>,
271 ti: f64,
272 tf: f64,
273 ) -> Result<(), PharmsolError> {
274 let ndrugs = self.get_ndrugs();
275 state.par_iter_mut().for_each(|particle| {
276 *particle = simulate_sde_event(
277 &self.drift,
278 &self.diffusion,
279 particle.clone().into(),
280 support_point,
281 covariates,
282 infusions,
283 ndrugs,
284 ti,
285 tf,
286 )
287 .inner()
288 .clone();
289 });
290 Ok(())
291 }
292 fn nparticles(&self) -> usize {
293 self.nparticles
294 }
295
296 fn is_sde(&self) -> bool {
297 true
298 }
299 #[inline(always)]
300 fn process_observation(
301 &self,
302 support_point: &Vec<f64>,
303 observation: &crate::Observation,
304 error_models: Option<&AssayErrorModels>,
305 _time: f64,
306 covariates: &Covariates,
307 x: &mut Self::S,
308 likelihood: &mut Vec<f64>,
309 output: &mut Self::P,
310 ) -> Result<(), PharmsolError> {
311 let mut pred = vec![Prediction::default(); self.nparticles];
312
313 pred.par_iter_mut().enumerate().for_each(|(i, p)| {
314 let mut y = V::zeros(self.get_nouteqs(), NalgebraContext);
315 (self.out)(
316 &x[i].clone().into(),
317 &V::from_vec(support_point.clone(), NalgebraContext),
318 observation.time(),
319 covariates,
320 &mut y,
321 );
322 *p = observation.to_prediction(y[observation.outeq()], x[i].as_slice().to_vec());
323 });
324 let out = Array2::from_shape_vec((self.nparticles, 1), pred.clone())?;
325 *output = concatenate(Axis(1), &[output.view(), out.view()]).unwrap();
326 if let Some(em) = error_models {
329 let mut q: Vec<f64> = Vec::with_capacity(self.nparticles);
330
331 pred.iter().for_each(|p| {
332 let lik = p.log_likelihood(em).map(f64::exp);
333 match lik {
334 Ok(l) => q.push(l),
335 Err(e) => panic!("Error in likelihood calculation: {:?}", e),
336 }
337 });
338 let sum_q: f64 = q.iter().sum();
339 let w: Vec<f64> = q.iter().map(|qi| qi / sum_q).collect();
340 let i = sysresample(&w);
341 let a: Vec<DVector<f64>> = i.iter().map(|&i| x[i].clone()).collect();
342 *x = a;
343 likelihood.push(sum_q / self.nparticles as f64);
344 }
347 Ok(())
348 }
349 #[inline(always)]
350 fn initial_state(
351 &self,
352 support_point: &Vec<f64>,
353 covariates: &Covariates,
354 occasion_index: usize,
355 ) -> Self::S {
356 let mut x = Vec::with_capacity(self.nparticles);
357 for _ in 0..self.nparticles {
358 let mut state: V = DVector::zeros(self.get_nstates()).into();
359 if occasion_index == 0 {
360 (self.init)(
361 &V::from_vec(support_point.to_vec(), NalgebraContext),
362 0.0,
363 covariates,
364 &mut state,
365 );
366 }
367 x.push(state.inner().clone());
368 }
369 x
370 }
371}
372
373impl Equation for SDE {
374 fn estimate_likelihood(
386 &self,
387 subject: &Subject,
388 support_point: &Vec<f64>,
389 error_models: &AssayErrorModels,
390 ) -> Result<f64, PharmsolError> {
391 _estimate_likelihood(self, subject, support_point, error_models)
392 }
393
394 fn estimate_log_likelihood(
395 &self,
396 subject: &Subject,
397 support_point: &Vec<f64>,
398 error_models: &AssayErrorModels,
399 ) -> Result<f64, PharmsolError> {
400 let lik = _estimate_likelihood(self, subject, support_point, error_models)?;
403
404 if lik > 0.0 {
405 Ok(lik.ln())
406 } else {
407 Ok(f64::NEG_INFINITY)
408 }
409 }
410
411 fn kind() -> crate::EqnKind {
412 crate::EqnKind::SDE
413 }
414}
415
416#[inline(always)]
417fn _estimate_likelihood(
418 sde: &SDE,
419 subject: &Subject,
420 support_point: &Vec<f64>,
421 error_models: &AssayErrorModels,
422) -> Result<f64, PharmsolError> {
423 if cache_enabled() {
424 let key = (
425 id_hash(subject.id()),
426 spphash(support_point),
427 error_models.hash(),
428 );
429 let cache_guard = sde_cache_lock_read()?;
430 if let Some(cached) = cache_guard.get(&key) {
431 return Ok(cached);
432 }
433 drop(cache_guard);
434
435 let ypred = sde.simulate_subject(subject, support_point, Some(error_models))?;
436 let result = ypred.1.unwrap();
437 let cache_guard = sde_cache_lock_read()?;
438 cache_guard.insert(key, result);
439 Ok(result)
440 } else {
441 let ypred = sde.simulate_subject(subject, support_point, Some(error_models))?;
442 Ok(ypred.1.unwrap())
443 }
444}
445
446fn sysresample(q: &[f64]) -> Vec<usize> {
456 let mut qc = vec![0.0; q.len()];
457 qc[0] = q[0];
458 for i in 1..q.len() {
459 qc[i] = qc[i - 1] + q[i];
460 }
461 let m = q.len();
462 let mut rng = rng();
463 let u: Vec<f64> = (0..m)
464 .map(|i| (i as f64 + rng.random::<f64>()) / m as f64)
465 .collect();
466 let mut i = vec![0; m];
467 let mut k = 0;
468 for j in 0..m {
469 while qc[k] < u[j] {
470 k += 1;
471 }
472 i[j] = k;
473 }
474 i
475}