pharmsol/simulator/equation/sde/
mod.rs1mod em;
2
3use nalgebra::DVector;
4use ndarray::{concatenate, Array2, Axis};
5use rand::{rng, Rng};
6use rayon::prelude::*;
7
8use cached::proc_macro::cached;
9use cached::UnboundCache;
10
11use crate::{
12 data::{Covariates, Infusion},
13 error_model::ErrorModels,
14 prelude::simulator::Prediction,
15 simulator::{Diffusion, Drift, Fa, Init, Lag, Neqs, Out, V},
16 Subject,
17};
18
19use crate::PharmsolError;
20
21use super::{Equation, EquationPriv, EquationTypes, Predictions, State};
22
23#[inline(always)]
43pub(crate) fn simulate_sde_event(
44 drift: &Drift,
45 difussion: &Diffusion,
46 x: V,
47 support_point: &[f64],
48 cov: &Covariates,
49 infusions: &[Infusion],
50 ti: f64,
51 tf: f64,
52) -> V {
53 if ti == tf {
54 return x;
55 }
56
57 let mut sde = em::EM::new(
58 *drift,
59 *difussion,
60 DVector::from_column_slice(support_point),
61 x,
62 cov.clone(),
63 infusions.to_vec(),
64 1e-2,
65 1e-2,
66 );
67 let (_time, solution) = sde.solve(ti, tf);
68 solution.last().unwrap().clone()
69}
70
71#[derive(Clone, Debug)]
79pub struct SDE {
80 drift: Drift,
81 diffusion: Diffusion,
82 lag: Lag,
83 fa: Fa,
84 init: Init,
85 out: Out,
86 neqs: Neqs,
87 nparticles: usize,
88}
89
90impl SDE {
91 #[allow(clippy::too_many_arguments)]
108 pub fn new(
109 drift: Drift,
110 diffusion: Diffusion,
111 lag: Lag,
112 fa: Fa,
113 init: Init,
114 out: Out,
115 neqs: Neqs,
116 nparticles: usize,
117 ) -> Self {
118 Self {
119 drift,
120 diffusion,
121 lag,
122 fa,
123 init,
124 out,
125 neqs,
126 nparticles,
127 }
128 }
129}
130
131impl State for Vec<DVector<f64>> {
135 fn add_bolus(&mut self, input: usize, amount: f64) {
142 self.par_iter_mut().for_each(|particle| {
143 particle[input] += amount;
144 });
145 }
146}
147
148impl Predictions for Array2<Prediction> {
152 fn new(nparticles: usize) -> Self {
153 Array2::from_shape_fn((nparticles, 0), |_| Prediction::default())
154 }
155 fn squared_error(&self) -> f64 {
156 unimplemented!();
157 }
158 fn get_predictions(&self) -> Vec<Prediction> {
159 let row = self.row(0).to_vec();
162 row
163 }
164}
165
166impl EquationTypes for SDE {
167 type S = Vec<DVector<f64>>; type P = Array2<Prediction>; }
170
171impl EquationPriv for SDE {
172 #[inline(always)]
193 fn lag(&self) -> &Lag {
194 &self.lag
195 }
196
197 #[inline(always)]
198 fn fa(&self) -> &Fa {
199 &self.fa
200 }
201
202 #[inline(always)]
203 fn get_nstates(&self) -> usize {
204 self.neqs.0
205 }
206
207 #[inline(always)]
208 fn get_nouteqs(&self) -> usize {
209 self.neqs.1
210 }
211 #[inline(always)]
212 fn solve(
213 &self,
214 state: &mut Self::S,
215 support_point: &Vec<f64>,
216 covariates: &Covariates,
217 infusions: &Vec<Infusion>,
218 ti: f64,
219 tf: f64,
220 ) -> Result<(), PharmsolError> {
221 state.par_iter_mut().for_each(|particle| {
222 *particle = simulate_sde_event(
223 &self.drift,
224 &self.diffusion,
225 particle.clone(),
226 support_point,
227 covariates,
228 infusions,
229 ti,
230 tf,
231 );
232 });
233 Ok(())
234 }
235 fn nparticles(&self) -> usize {
236 self.nparticles
237 }
238
239 fn is_sde(&self) -> bool {
240 true
241 }
242 #[inline(always)]
243 fn process_observation(
244 &self,
245 support_point: &Vec<f64>,
246 observation: &crate::Observation,
247 error_models: Option<&ErrorModels>,
248 _time: f64,
249 covariates: &Covariates,
250 x: &mut Self::S,
251 likelihood: &mut Vec<f64>,
252 output: &mut Self::P,
253 ) -> Result<(), PharmsolError> {
254 let mut pred = vec![Prediction::default(); self.nparticles];
255 pred.par_iter_mut().enumerate().for_each(|(i, p)| {
256 let mut y = V::zeros(self.get_nouteqs());
257 (self.out)(
258 &x[i],
259 &V::from_vec(support_point.clone()),
260 observation.time(),
261 covariates,
262 &mut y,
263 );
264 *p = observation.to_prediction(y[observation.outeq()], x[i].as_slice().to_vec());
265 });
266 let out = Array2::from_shape_vec((self.nparticles, 1), pred.clone())?;
267 *output = concatenate(Axis(1), &[output.view(), out.view()]).unwrap();
268 if let Some(em) = error_models {
271 let mut q: Vec<f64> = Vec::with_capacity(self.nparticles);
272
273 pred.iter().for_each(|p| {
274 let lik = p.likelihood(em);
275 match lik {
276 Ok(l) => q.push(l),
277 Err(e) => panic!("Error in likelihood calculation: {:?}", e),
278 }
279 });
280 let sum_q: f64 = q.iter().sum();
281 let w: Vec<f64> = q.iter().map(|qi| qi / sum_q).collect();
282 let i = sysresample(&w);
283 let a: Vec<DVector<f64>> = i.iter().map(|&i| x[i].clone()).collect();
284 *x = a;
285 likelihood.push(sum_q / self.nparticles as f64);
286 }
289 Ok(())
290 }
291 #[inline(always)]
292 fn initial_state(
293 &self,
294 support_point: &Vec<f64>,
295 covariates: &Covariates,
296 occasion_index: usize,
297 ) -> Self::S {
298 let mut x = Vec::with_capacity(self.nparticles);
299 for _ in 0..self.nparticles {
300 let mut state = DVector::zeros(self.get_nstates());
301 if occasion_index == 0 {
302 (self.init)(
303 &V::from_vec(support_point.to_vec()),
304 0.0,
305 covariates,
306 &mut state,
307 );
308 }
309 x.push(state);
310 }
311 x
312 }
313}
314
315impl Equation for SDE {
316 fn estimate_likelihood(
329 &self,
330 subject: &Subject,
331 support_point: &Vec<f64>,
332 error_models: &ErrorModels,
333 cache: bool,
334 ) -> Result<f64, PharmsolError> {
335 if cache {
336 _estimate_likelihood(self, subject, support_point, error_models)
337 } else {
338 _estimate_likelihood_no_cache(self, subject, support_point, error_models)
339 }
340 }
341
342 fn kind() -> crate::EqnKind {
343 crate::EqnKind::SDE
344 }
345}
346
/// Hashes a support point, for use in the likelihood cache key.
///
/// The hash depends on the length, the order and the exact bit pattern of every value.
/// The previous additive fold overflowed `u64` (a panic in debug builds) for ordinary
/// inputs such as `[2.0; 4]`, and it mapped permutations like `[a, b]` and `[b, a]` to
/// the same key.
///
/// The hash is stable within a single build only, which is all an in-memory cache needs.
fn spphash(spp: &[f64]) -> u64 {
    use std::hash::Hasher;

    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    // Including the length prevents prefix collisions between points of different sizes.
    hasher.write_usize(spp.len());
    for value in spp {
        hasher.write_u64(value.to_bits());
    }
    hasher.finish()
}
359
360#[inline(always)]
361#[cached(
362 ty = "UnboundCache<String, f64>",
363 create = "{ UnboundCache::with_capacity(100_000) }",
364 convert = r#"{ format!("{}{}{:#?}", subject.id(), spphash(support_point), error_models.hash()) }"#,
365 result = "true"
366)]
367fn _estimate_likelihood(
368 sde: &SDE,
369 subject: &Subject,
370 support_point: &Vec<f64>,
371 error_models: &ErrorModels,
372) -> Result<f64, PharmsolError> {
373 let ypred = sde.simulate_subject(subject, support_point, Some(error_models))?;
374 Ok(ypred.1.unwrap())
375}
376
377fn sysresample(q: &[f64]) -> Vec<usize> {
387 let mut qc = vec![0.0; q.len()];
388 qc[0] = q[0];
389 for i in 1..q.len() {
390 qc[i] = qc[i - 1] + q[i];
391 }
392 let m = q.len();
393 let mut rng = rng();
394 let u: Vec<f64> = (0..m)
395 .map(|i| (i as f64 + rng.random::<f64>()) / m as f64)
396 .collect();
397 let mut i = vec![0; m];
398 let mut k = 0;
399 for j in 0..m {
400 while qc[k] < u[j] {
401 k += 1;
402 }
403 i[j] = k;
404 }
405 i
406}