pharmsol/simulator/equation/sde/
mod.rs1mod em;
2
3use diffsol::{NalgebraContext, Vector};
4use nalgebra::DVector;
5use ndarray::{concatenate, Array2, Axis};
6use rand::{rng, Rng};
7use rayon::prelude::*;
8
9use cached::proc_macro::cached;
10use cached::UnboundCache;
11
12use crate::{
13 data::{Covariates, Infusion},
14 error_model::ErrorModels,
15 prelude::simulator::Prediction,
16 simulator::{Diffusion, Drift, Fa, Init, Lag, Neqs, Out, V},
17 Subject,
18};
19
20use diffsol::VectorCommon;
21
22use crate::PharmsolError;
23
24use super::{Equation, EquationPriv, EquationTypes, Predictions, State};
25
26#[inline(always)]
46pub(crate) fn simulate_sde_event(
47 drift: &Drift,
48 difussion: &Diffusion,
49 x: V,
50 support_point: &[f64],
51 cov: &Covariates,
52 infusions: &[Infusion],
53 ti: f64,
54 tf: f64,
55) -> V {
56 if ti == tf {
57 return x;
58 }
59
60 let mut sde = em::EM::new(
61 *drift,
62 *difussion,
63 DVector::from_column_slice(support_point),
64 x.inner().clone(),
65 cov.clone(),
66 infusions.to_vec(),
67 1e-2,
68 1e-2,
69 );
70 let (_time, solution) = sde.solve(ti, tf);
71 solution.last().unwrap().clone().into()
72}
73
74#[derive(Clone, Debug)]
82pub struct SDE {
83 drift: Drift,
84 diffusion: Diffusion,
85 lag: Lag,
86 fa: Fa,
87 init: Init,
88 out: Out,
89 neqs: Neqs,
90 nparticles: usize,
91}
92
93impl SDE {
94 #[allow(clippy::too_many_arguments)]
111 pub fn new(
112 drift: Drift,
113 diffusion: Diffusion,
114 lag: Lag,
115 fa: Fa,
116 init: Init,
117 out: Out,
118 neqs: Neqs,
119 nparticles: usize,
120 ) -> Self {
121 Self {
122 drift,
123 diffusion,
124 lag,
125 fa,
126 init,
127 out,
128 neqs,
129 nparticles,
130 }
131 }
132}
133
134impl State for Vec<DVector<f64>> {
138 fn add_bolus(&mut self, input: usize, amount: f64) {
145 self.par_iter_mut().for_each(|particle| {
146 particle[input] += amount;
147 });
148 }
149}
150
151impl Predictions for Array2<Prediction> {
155 fn new(nparticles: usize) -> Self {
156 Array2::from_shape_fn((nparticles, 0), |_| Prediction::default())
157 }
158 fn squared_error(&self) -> f64 {
159 unimplemented!();
160 }
161 fn get_predictions(&self) -> Vec<Prediction> {
162 if self.is_empty() || self.ncols() == 0 {
164 return Vec::new();
165 }
166
167 let mut result = Vec::with_capacity(self.ncols());
168
169 for col in 0..self.ncols() {
170 let column = self.column(col);
171
172 let mean_prediction: f64 = column
173 .iter()
174 .map(|pred: &Prediction| pred.prediction())
175 .sum::<f64>()
176 / self.nrows() as f64;
177
178 let mut prediction = column.first().unwrap().clone();
179 prediction.set_prediction(mean_prediction);
180 result.push(prediction);
181 }
182
183 result
184 }
185}
186
187impl EquationTypes for SDE {
188 type S = Vec<DVector<f64>>; type P = Array2<Prediction>; }
191
192impl EquationPriv for SDE {
193 #[inline(always)]
214 fn lag(&self) -> &Lag {
215 &self.lag
216 }
217
218 #[inline(always)]
219 fn fa(&self) -> &Fa {
220 &self.fa
221 }
222
223 #[inline(always)]
224 fn get_nstates(&self) -> usize {
225 self.neqs.0
226 }
227
228 #[inline(always)]
229 fn get_nouteqs(&self) -> usize {
230 self.neqs.1
231 }
232 #[inline(always)]
233 fn solve(
234 &self,
235 state: &mut Self::S,
236 support_point: &Vec<f64>,
237 covariates: &Covariates,
238 infusions: &Vec<Infusion>,
239 ti: f64,
240 tf: f64,
241 ) -> Result<(), PharmsolError> {
242 state.par_iter_mut().for_each(|particle| {
243 *particle = simulate_sde_event(
244 &self.drift,
245 &self.diffusion,
246 particle.clone().into(),
247 support_point,
248 covariates,
249 infusions,
250 ti,
251 tf,
252 )
253 .inner()
254 .clone();
255 });
256 Ok(())
257 }
258 fn nparticles(&self) -> usize {
259 self.nparticles
260 }
261
262 fn is_sde(&self) -> bool {
263 true
264 }
265 #[inline(always)]
266 fn process_observation(
267 &self,
268 support_point: &Vec<f64>,
269 observation: &crate::Observation,
270 error_models: Option<&ErrorModels>,
271 _time: f64,
272 covariates: &Covariates,
273 x: &mut Self::S,
274 likelihood: &mut Vec<f64>,
275 output: &mut Self::P,
276 ) -> Result<(), PharmsolError> {
277 let mut pred = vec![Prediction::default(); self.nparticles];
278 pred.par_iter_mut().enumerate().for_each(|(i, p)| {
279 let mut y = V::zeros(self.get_nouteqs(), NalgebraContext);
280 (self.out)(
281 &x[i].clone().into(),
282 &V::from_vec(support_point.clone(), NalgebraContext),
283 observation.time(),
284 covariates,
285 &mut y,
286 );
287 *p = observation.to_prediction(y[observation.outeq()], x[i].as_slice().to_vec());
288 });
289 let out = Array2::from_shape_vec((self.nparticles, 1), pred.clone())?;
290 *output = concatenate(Axis(1), &[output.view(), out.view()]).unwrap();
291 if let Some(em) = error_models {
294 let mut q: Vec<f64> = Vec::with_capacity(self.nparticles);
295
296 pred.iter().for_each(|p| {
297 let lik = p.likelihood(em);
298 match lik {
299 Ok(l) => q.push(l),
300 Err(e) => panic!("Error in likelihood calculation: {:?}", e),
301 }
302 });
303 let sum_q: f64 = q.iter().sum();
304 let w: Vec<f64> = q.iter().map(|qi| qi / sum_q).collect();
305 let i = sysresample(&w);
306 let a: Vec<DVector<f64>> = i.iter().map(|&i| x[i].clone()).collect();
307 *x = a;
308 likelihood.push(sum_q / self.nparticles as f64);
309 }
312 Ok(())
313 }
314 #[inline(always)]
315 fn initial_state(
316 &self,
317 support_point: &Vec<f64>,
318 covariates: &Covariates,
319 occasion_index: usize,
320 ) -> Self::S {
321 let mut x = Vec::with_capacity(self.nparticles);
322 for _ in 0..self.nparticles {
323 let mut state: V = DVector::zeros(self.get_nstates()).into();
324 if occasion_index == 0 {
325 (self.init)(
326 &V::from_vec(support_point.to_vec(), NalgebraContext),
327 0.0,
328 covariates,
329 &mut state,
330 );
331 }
332 x.push(state.inner().clone());
333 }
334 x
335 }
336}
337
338impl Equation for SDE {
339 fn estimate_likelihood(
352 &self,
353 subject: &Subject,
354 support_point: &Vec<f64>,
355 error_models: &ErrorModels,
356 cache: bool,
357 ) -> Result<f64, PharmsolError> {
358 if cache {
359 _estimate_likelihood(self, subject, support_point, error_models)
360 } else {
361 _estimate_likelihood_no_cache(self, subject, support_point, error_models)
362 }
363 }
364
365 fn kind() -> crate::EqnKind {
366 crate::EqnKind::SDE
367 }
368}
369
/// Hashes a support point into a `u64` for use in likelihood cache keys.
///
/// The hash is order-sensitive and length-prefixed, so permuted or truncated
/// support points produce different keys. Values are hashed by their exact bit
/// pattern (`f64::to_bits`), so `0.0` and `-0.0` hash differently.
///
/// The previous additive fold overflowed `u64` for ordinary parameter values,
/// which panics in debug builds. Because addition is commutative, it also mapped
/// permutations such as `[1.0, 2.0]` and `[2.0, 1.0]` to the same cache entry.
fn spphash(spp: &[f64]) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    // DefaultHasher::new() uses fixed keys, so the value is stable for the
    // lifetime of the process (sufficient for an in-memory cache).
    let mut hasher = DefaultHasher::new();
    spp.len().hash(&mut hasher);
    for value in spp {
        value.to_bits().hash(&mut hasher);
    }
    hasher.finish()
}
382
383#[inline(always)]
384#[cached(
385 ty = "UnboundCache<String, f64>",
386 create = "{ UnboundCache::with_capacity(100_000) }",
387 convert = r#"{ format!("{}{}{:#?}", subject.id(), spphash(support_point), error_models.hash()) }"#,
388 result = "true"
389)]
390fn _estimate_likelihood(
391 sde: &SDE,
392 subject: &Subject,
393 support_point: &Vec<f64>,
394 error_models: &ErrorModels,
395) -> Result<f64, PharmsolError> {
396 let ypred = sde.simulate_subject(subject, support_point, Some(error_models))?;
397 Ok(ypred.1.unwrap())
398}
399
400fn sysresample(q: &[f64]) -> Vec<usize> {
410 let mut qc = vec![0.0; q.len()];
411 qc[0] = q[0];
412 for i in 1..q.len() {
413 qc[i] = qc[i - 1] + q[i];
414 }
415 let m = q.len();
416 let mut rng = rng();
417 let u: Vec<f64> = (0..m)
418 .map(|i| (i as f64 + rng.random::<f64>()) / m as f64)
419 .collect();
420 let mut i = vec![0; m];
421 let mut k = 0;
422 for j in 0..m {
423 while qc[k] < u[j] {
424 k += 1;
425 }
426 i[j] = k;
427 }
428 i
429}