pharmsol/simulator/equation/sde/
mod.rs1mod em;
2
3use std::collections::HashMap;
4
5use nalgebra::DVector;
6use ndarray::{concatenate, Array2, Axis};
7use rand::{rng, Rng};
8use rayon::prelude::*;
9
10use cached::proc_macro::cached;
11use cached::UnboundCache;
12
13use crate::{
14 data::{Covariates, Infusion},
15 error_model::ErrorModel,
16 prelude::simulator::Prediction,
17 simulator::{Diffusion, Drift, Fa, Init, Lag, Neqs, Out, V},
18 Subject,
19};
20
21use crate::PharmsolError;
22
23use super::{Equation, EquationPriv, EquationTypes, Predictions, State};
24
25#[inline(always)]
45pub(crate) fn simulate_sde_event(
46 drift: &Drift,
47 difussion: &Diffusion,
48 x: V,
49 support_point: &[f64],
50 cov: &Covariates,
51 infusions: &[Infusion],
52 ti: f64,
53 tf: f64,
54) -> V {
55 if ti == tf {
56 return x;
57 }
58
59 let mut sde = em::EM::new(
60 *drift,
61 *difussion,
62 DVector::from_column_slice(support_point),
63 x,
64 cov.clone(),
65 infusions.to_vec(),
66 1e-2,
67 1e-2,
68 );
69 let (_time, solution) = sde.solve(ti, tf);
70 solution.last().unwrap().clone()
71}
72
/// Stochastic differential equation model simulated with an ensemble of particles.
///
/// Each particle's state is advanced independently by `simulate_sde_event`. When an
/// error model is supplied, the particles are weighted by observation likelihood and
/// resampled at each observation.
#[derive(Clone, Debug)]
pub struct SDE {
    /// Drift term of the SDE, passed to the `em::EM` solver.
    drift: Drift,
    /// Diffusion term of the SDE, passed to the `em::EM` solver.
    diffusion: Diffusion,
    /// Lag closure, evaluated on the support point.
    lag: Lag,
    /// `fa` closure, evaluated on the support point (presumably per-input bioavailability — confirm).
    fa: Fa,
    /// Initial-state closure; only applied on the first occasion (`occasion_index == 0`).
    init: Init,
    /// Output closure that maps a particle state to the vector of output-equation values.
    out: Out,
    /// `(number of states, number of output equations)`.
    neqs: Neqs,
    /// Number of particles simulated per support point.
    nparticles: usize,
}
91
92impl SDE {
93 #[allow(clippy::too_many_arguments)]
110 pub fn new(
111 drift: Drift,
112 diffusion: Diffusion,
113 lag: Lag,
114 fa: Fa,
115 init: Init,
116 out: Out,
117 neqs: Neqs,
118 nparticles: usize,
119 ) -> Self {
120 Self {
121 drift,
122 diffusion,
123 lag,
124 fa,
125 init,
126 out,
127 neqs,
128 nparticles,
129 }
130 }
131}
132
133impl State for Vec<DVector<f64>> {
137 fn add_bolus(&mut self, input: usize, amount: f64) {
144 self.par_iter_mut().for_each(|particle| {
145 particle[input] += amount;
146 });
147 }
148}
149
150impl Predictions for Array2<Prediction> {
154 fn new(nparticles: usize) -> Self {
155 Array2::from_shape_fn((nparticles, 0), |_| Prediction::default())
156 }
157 fn squared_error(&self) -> f64 {
158 unimplemented!();
159 }
160 fn get_predictions(&self) -> Vec<Prediction> {
161 let row = self.row(0).to_vec();
164 row
165 }
166}
167
168impl EquationTypes for SDE {
169 type S = Vec<DVector<f64>>; type P = Array2<Prediction>; }
172
173impl EquationPriv for SDE {
174 #[inline(always)]
185 fn get_lag(&self, spp: &[f64]) -> Option<HashMap<usize, f64>> {
186 Some((self.lag)(&V::from_vec(spp.to_owned())))
187 }
188
189 #[inline(always)]
190 fn get_fa(&self, spp: &[f64]) -> Option<HashMap<usize, f64>> {
191 Some((self.fa)(&V::from_vec(spp.to_owned())))
192 }
193
194 #[inline(always)]
195 fn get_nstates(&self) -> usize {
196 self.neqs.0
197 }
198
199 #[inline(always)]
200 fn get_nouteqs(&self) -> usize {
201 self.neqs.1
202 }
203 #[inline(always)]
204 fn solve(
205 &self,
206 state: &mut Self::S,
207 support_point: &Vec<f64>,
208 covariates: &Covariates,
209 infusions: &Vec<Infusion>,
210 ti: f64,
211 tf: f64,
212 ) -> Result<(), PharmsolError> {
213 state.par_iter_mut().for_each(|particle| {
214 *particle = simulate_sde_event(
215 &self.drift,
216 &self.diffusion,
217 particle.clone(),
218 support_point,
219 covariates,
220 infusions,
221 ti,
222 tf,
223 );
224 });
225 Ok(())
226 }
227 fn nparticles(&self) -> usize {
228 self.nparticles
229 }
230
231 fn is_sde(&self) -> bool {
232 true
233 }
234 #[inline(always)]
235 fn process_observation(
236 &self,
237 support_point: &Vec<f64>,
238 observation: &crate::Observation,
239 error_model: Option<&ErrorModel>,
240 _time: f64,
241 covariates: &Covariates,
242 x: &mut Self::S,
243 likelihood: &mut Vec<f64>,
244 output: &mut Self::P,
245 ) -> Result<(), PharmsolError> {
246 let mut pred = vec![Prediction::default(); self.nparticles];
247 pred.par_iter_mut().enumerate().for_each(|(i, p)| {
248 let mut y = V::zeros(self.get_nouteqs());
249 (self.out)(
250 &x[i],
251 &V::from_vec(support_point.clone()),
252 observation.time(),
253 covariates,
254 &mut y,
255 );
256 *p = observation.to_prediction(y[observation.outeq()], x[i].as_slice().to_vec());
257 });
258 let out = Array2::from_shape_vec((self.nparticles, 1), pred.clone())?;
259 *output = concatenate(Axis(1), &[output.view(), out.view()]).unwrap();
260 if let Some(em) = error_model {
263 let mut q: Vec<f64> = Vec::with_capacity(self.nparticles);
264
265 pred.iter().for_each(|p| {
266 let lik = p.likelihood(em);
267 match lik {
268 Ok(l) => q.push(l),
269 Err(e) => panic!("Error in likelihood calculation: {:?}", e),
270 }
271 });
272 let sum_q: f64 = q.iter().sum();
273 let w: Vec<f64> = q.iter().map(|qi| qi / sum_q).collect();
274 let i = sysresample(&w);
275 let a: Vec<DVector<f64>> = i.iter().map(|&i| x[i].clone()).collect();
276 *x = a;
277 likelihood.push(sum_q / self.nparticles as f64);
278 }
281 Ok(())
282 }
283 #[inline(always)]
284 fn initial_state(
285 &self,
286 support_point: &Vec<f64>,
287 covariates: &Covariates,
288 occasion_index: usize,
289 ) -> Self::S {
290 let mut x = Vec::with_capacity(self.nparticles);
291 for _ in 0..self.nparticles {
292 let mut state = DVector::zeros(self.get_nstates());
293 if occasion_index == 0 {
294 (self.init)(
295 &V::from_vec(support_point.to_vec()),
296 0.0,
297 covariates,
298 &mut state,
299 );
300 }
301 x.push(state);
302 }
303 x
304 }
305}
306
307impl Equation for SDE {
308 fn estimate_likelihood(
321 &self,
322 subject: &Subject,
323 support_point: &Vec<f64>,
324 error_model: &ErrorModel,
325 cache: bool,
326 ) -> Result<f64, PharmsolError> {
327 if cache {
328 _estimate_likelihood(self, subject, support_point, error_model)
329 } else {
330 _estimate_likelihood_no_cache(self, subject, support_point, error_model)
331 }
332 }
333}
334
/// Hashes a support point into a `u64` for use in the likelihood cache key.
///
/// Each coordinate's bit pattern is fed, in order, into a `DefaultHasher`, so the result
/// depends on both the values and their positions. A plain sum of the bit patterns would
/// overflow `u64` for ordinary parameter vectors (for example, four values of `2.0`), which
/// panics in debug builds. It would also map permutations such as `[1.0, 2.0]` and
/// `[2.0, 1.0]` to the same key, so the cache would return the wrong likelihood.
fn spphash(spp: &[f64]) -> u64 {
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    for value in spp {
        std::hash::Hasher::write_u64(&mut hasher, value.to_bits());
    }
    std::hash::Hasher::finish(&hasher)
}
347
348#[inline(always)]
349#[cached(
350 ty = "UnboundCache<String, f64>",
351 create = "{ UnboundCache::with_capacity(100_000) }",
352 convert = r#"{ format!("{}{}{}", subject.id(), spphash(support_point), error_model.scalar()) }"#,
353 result = "true"
354)]
355fn _estimate_likelihood(
356 sde: &SDE,
357 subject: &Subject,
358 support_point: &Vec<f64>,
359 error_model: &ErrorModel,
360) -> Result<f64, PharmsolError> {
361 let ypred = sde.simulate_subject(subject, support_point, Some(error_model))?;
362 Ok(ypred.1.unwrap())
363}
364
365fn sysresample(q: &[f64]) -> Vec<usize> {
375 let mut qc = vec![0.0; q.len()];
376 qc[0] = q[0];
377 for i in 1..q.len() {
378 qc[i] = qc[i - 1] + q[i];
379 }
380 let m = q.len();
381 let mut rng = rng();
382 let u: Vec<f64> = (0..m)
383 .map(|i| (i as f64 + rng.random::<f64>()) / m as f64)
384 .collect();
385 let mut i = vec![0; m];
386 let mut k = 0;
387 for j in 0..m {
388 while qc[k] < u[j] {
389 k += 1;
390 }
391 i[j] = k;
392 }
393 i
394}