pharmsol/simulator/equation/sde/
mod.rs1mod em;
2
3use diffsol::{NalgebraContext, Vector};
4use nalgebra::DVector;
5use ndarray::{concatenate, Array2, Axis};
6use rand::{rng, Rng};
7use rayon::prelude::*;
8
9use cached::proc_macro::cached;
10use cached::UnboundCache;
11
12use crate::{
13 data::{Covariates, Infusion},
14 error_model::ErrorModels,
15 mapping::Mappings,
16 prelude::simulator::Prediction,
17 simulator::{Diffusion, Drift, Fa, Init, Lag, Neqs, Out, V},
18 Subject,
19};
20
21use diffsol::VectorCommon;
22
23use crate::PharmsolError;
24
25use super::{Equation, EquationPriv, EquationTypes, Predictions, State};
26
27#[inline(always)]
47pub(crate) fn simulate_sde_event(
48 drift: &Drift,
49 difussion: &Diffusion,
50 x: V,
51 support_point: &[f64],
52 cov: &Covariates,
53 infusions: &[Infusion],
54 ti: f64,
55 tf: f64,
56) -> V {
57 if ti == tf {
58 return x;
59 }
60
61 let mut sde = em::EM::new(
62 *drift,
63 *difussion,
64 DVector::from_column_slice(support_point),
65 x.inner().clone(),
66 cov.clone(),
67 infusions.to_vec(),
68 1e-2,
69 1e-2,
70 );
71 let (_time, solution) = sde.solve(ti, tf);
72 solution.last().unwrap().clone().into()
73}
74
/// Stochastic differential equation model whose state is an ensemble of
/// particles (see `EquationTypes::S` for `SDE`).
#[derive(Clone, Debug)]
pub struct SDE {
    /// Drift function, handed to the `em::EM` integrator.
    drift: Drift,
    /// Diffusion function, handed to the `em::EM` integrator.
    diffusion: Diffusion,
    /// Lag definition returned by `EquationPriv::lag` (presumably absorption lag — confirm).
    lag: Lag,
    /// `fa` definition returned by `EquationPriv::fa` (presumably bioavailability — confirm).
    fa: Fa,
    /// Initial-state function, evaluated per particle on occasion 0.
    init: Init,
    /// Output function mapping a particle state to output-equation values.
    out: Out,
    /// `(number of states, number of output equations)`.
    neqs: Neqs,
    /// Number of particles in the state ensemble.
    nparticles: usize,
    /// Parameter mappings; created empty by `SDE::new`.
    mappings: Mappings,
}
94
95impl SDE {
96 #[allow(clippy::too_many_arguments)]
113 pub fn new(
114 drift: Drift,
115 diffusion: Diffusion,
116 lag: Lag,
117 fa: Fa,
118 init: Init,
119 out: Out,
120 neqs: Neqs,
121 nparticles: usize,
122 ) -> Self {
123 Self {
124 drift,
125 diffusion,
126 lag,
127 fa,
128 init,
129 out,
130 neqs,
131 nparticles,
132 mappings: Mappings::new(),
133 }
134 }
135}
136
137impl State for Vec<DVector<f64>> {
141 fn add_bolus(&mut self, input: usize, amount: f64) {
148 self.par_iter_mut().for_each(|particle| {
149 particle[input] += amount;
150 });
151 }
152}
153
154impl Predictions for Array2<Prediction> {
158 fn new(nparticles: usize) -> Self {
159 Array2::from_shape_fn((nparticles, 0), |_| Prediction::default())
160 }
161 fn squared_error(&self) -> f64 {
162 unimplemented!();
163 }
164 fn get_predictions(&self) -> Vec<Prediction> {
165 if self.is_empty() || self.ncols() == 0 {
167 return Vec::new();
168 }
169
170 let mut result = Vec::with_capacity(self.ncols());
171
172 for col in 0..self.ncols() {
173 let column = self.column(col);
174
175 let mean_prediction: f64 = column
176 .iter()
177 .map(|pred: &Prediction| pred.prediction())
178 .sum::<f64>()
179 / self.nrows() as f64;
180
181 let mut prediction = column.first().unwrap().clone();
182 prediction.set_prediction(mean_prediction);
183 result.push(prediction);
184 }
185
186 result
187 }
188}
189
190impl EquationTypes for SDE {
191 type S = Vec<DVector<f64>>; type P = Array2<Prediction>; }
194
195impl EquationPriv for SDE {
196 #[inline(always)]
217 fn lag(&self) -> &Lag {
218 &self.lag
219 }
220
221 #[inline(always)]
222 fn fa(&self) -> &Fa {
223 &self.fa
224 }
225
226 #[inline(always)]
227 fn get_nstates(&self) -> usize {
228 self.neqs.0
229 }
230
231 #[inline(always)]
232 fn get_nouteqs(&self) -> usize {
233 self.neqs.1
234 }
235 #[inline(always)]
236 fn solve(
237 &self,
238 state: &mut Self::S,
239 support_point: &Vec<f64>,
240 covariates: &Covariates,
241 infusions: &Vec<Infusion>,
242 ti: f64,
243 tf: f64,
244 ) -> Result<(), PharmsolError> {
245 state.par_iter_mut().for_each(|particle| {
246 *particle = simulate_sde_event(
247 &self.drift,
248 &self.diffusion,
249 particle.clone().into(),
250 support_point,
251 covariates,
252 infusions,
253 ti,
254 tf,
255 )
256 .inner()
257 .clone();
258 });
259 Ok(())
260 }
261 fn nparticles(&self) -> usize {
262 self.nparticles
263 }
264
265 fn is_sde(&self) -> bool {
266 true
267 }
268 #[inline(always)]
269 fn process_observation(
270 &self,
271 support_point: &Vec<f64>,
272 observation: &crate::Observation,
273 error_models: Option<&ErrorModels>,
274 _time: f64,
275 covariates: &Covariates,
276 x: &mut Self::S,
277 likelihood: &mut Vec<f64>,
278 output: &mut Self::P,
279 ) -> Result<(), PharmsolError> {
280 let mut pred = vec![Prediction::default(); self.nparticles];
281 pred.par_iter_mut().enumerate().for_each(|(i, p)| {
282 let mut y = V::zeros(self.get_nouteqs(), NalgebraContext);
283 (self.out)(
284 &x[i].clone().into(),
285 &V::from_vec(support_point.clone(), NalgebraContext),
286 observation.time(),
287 covariates,
288 &mut y,
289 );
290 *p = observation.to_prediction(y[observation.outeq()], x[i].as_slice().to_vec());
291 });
292 let out = Array2::from_shape_vec((self.nparticles, 1), pred.clone())?;
293 *output = concatenate(Axis(1), &[output.view(), out.view()]).unwrap();
294 if let Some(em) = error_models {
297 let mut q: Vec<f64> = Vec::with_capacity(self.nparticles);
298
299 pred.iter().for_each(|p| {
300 let lik = p.likelihood(em);
301 match lik {
302 Ok(l) => q.push(l),
303 Err(e) => panic!("Error in likelihood calculation: {:?}", e),
304 }
305 });
306 let sum_q: f64 = q.iter().sum();
307 let w: Vec<f64> = q.iter().map(|qi| qi / sum_q).collect();
308 let i = sysresample(&w);
309 let a: Vec<DVector<f64>> = i.iter().map(|&i| x[i].clone()).collect();
310 *x = a;
311 likelihood.push(sum_q / self.nparticles as f64);
312 }
315 Ok(())
316 }
317 #[inline(always)]
318 fn initial_state(
319 &self,
320 support_point: &Vec<f64>,
321 covariates: &Covariates,
322 occasion_index: usize,
323 ) -> Self::S {
324 let mut x = Vec::with_capacity(self.nparticles);
325 for _ in 0..self.nparticles {
326 let mut state: V = DVector::zeros(self.get_nstates()).into();
327 if occasion_index == 0 {
328 (self.init)(
329 &V::from_vec(support_point.to_vec(), NalgebraContext),
330 0.0,
331 covariates,
332 &mut state,
333 );
334 }
335 x.push(state.inner().clone());
336 }
337 x
338 }
339}
340
341impl Equation for SDE {
342 fn estimate_likelihood(
355 &self,
356 subject: &Subject,
357 support_point: &Vec<f64>,
358 error_models: &ErrorModels,
359 cache: bool,
360 ) -> Result<f64, PharmsolError> {
361 if cache {
362 _estimate_likelihood(self, subject, support_point, error_models)
363 } else {
364 _estimate_likelihood_no_cache(self, subject, support_point, error_models)
365 }
366 }
367 fn mappings_ref(&self) -> &Mappings {
368 &self.mappings
369 }
370 fn mappings_mut(&mut self) -> &mut Mappings {
371 &mut self.mappings
372 }
373
374 fn kind() -> crate::EqnKind {
375 crate::EqnKind::SDE
376 }
377}
378
/// Hashes a support point into a `u64` that forms part of the likelihood
/// cache key.
///
/// The exact bit pattern of every value is hashed *in order*, prefixed by
/// the length. Permuted vectors such as `[1.0, 2.0]` and `[2.0, 1.0]`
/// therefore get different keys. The hash never overflows.
///
/// A plain sum of bit patterns would be order-insensitive, so different
/// support points would share a cached likelihood. It would also overflow
/// `u64` (a debug-build panic) for just a handful of ordinary values.
///
/// `DefaultHasher::new()` is keyed identically on every call, so the result
/// is deterministic within a process, which is all an in-memory cache needs.
fn spphash(spp: &[f64]) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut hasher = DefaultHasher::new();
    spp.len().hash(&mut hasher);
    for value in spp {
        value.to_bits().hash(&mut hasher);
    }
    hasher.finish()
}
391
392#[inline(always)]
393#[cached(
394 ty = "UnboundCache<String, f64>",
395 create = "{ UnboundCache::with_capacity(100_000) }",
396 convert = r#"{ format!("{}{}{:#?}", subject.id(), spphash(support_point), error_models.hash()) }"#,
397 result = "true"
398)]
399fn _estimate_likelihood(
400 sde: &SDE,
401 subject: &Subject,
402 support_point: &Vec<f64>,
403 error_models: &ErrorModels,
404) -> Result<f64, PharmsolError> {
405 let ypred = sde.simulate_subject(subject, support_point, Some(error_models))?;
406 Ok(ypred.1.unwrap())
407}
408
409fn sysresample(q: &[f64]) -> Vec<usize> {
419 let mut qc = vec![0.0; q.len()];
420 qc[0] = q[0];
421 for i in 1..q.len() {
422 qc[i] = qc[i - 1] + q[i];
423 }
424 let m = q.len();
425 let mut rng = rng();
426 let u: Vec<f64> = (0..m)
427 .map(|i| (i as f64 + rng.random::<f64>()) / m as f64)
428 .collect();
429 let mut i = vec![0; m];
430 let mut k = 0;
431 for j in 0..m {
432 while qc[k] < u[j] {
433 k += 1;
434 }
435 i[j] = k;
436 }
437 i
438}