Skip to main content

pharmsol/simulator/equation/sde/mod.rs

1mod em;
2
3use diffsol::{NalgebraContext, Vector};
4use nalgebra::DVector;
5use ndarray::{concatenate, Array2, Axis};
6use rand::{rng, RngExt};
7use rayon::prelude::*;
8
9use crate::{
10    data::{Covariates, Infusion},
11    error_model::AssayErrorModels,
12    prelude::simulator::Prediction,
13    simulator::{Diffusion, Drift, Fa, Init, Lag, Neqs, Out, V},
14    Subject,
15};
16
17use super::spphash;
18use crate::simulator::cache::{SdeLikelihoodCache, DEFAULT_CACHE_SIZE};
19
20use diffsol::VectorCommon;
21
22use crate::PharmsolError;
23
24use super::{Equation, EquationPriv, EquationTypes, Predictions, State};
25
26/// Simulate a stochastic differential equation (SDE) event.
27///
28/// This function advances the SDE system from time `ti` to `tf` using
29/// the Euler-Maruyama method implemented in the `em` module.
30///
31/// # Arguments
32///
33/// * `drift` - Function defining the deterministic component of the SDE
34/// * `difussion` - Function defining the stochastic component of the SDE
35/// * `x` - Current state vector
36/// * `support_point` - Parameter vector for the model
37/// * `cov` - Covariates that may influence the system dynamics
38/// * `infusions` - Infusion events to be applied during simulation
39/// * `ti` - Starting time
40/// * `tf` - Ending time
41///
42/// # Returns
43///
44/// The state vector at time `tf` after simulation.
45#[inline(always)]
46#[allow(clippy::too_many_arguments)]
47pub(crate) fn simulate_sde_event(
48    drift: &Drift,
49    difussion: &Diffusion,
50    x: V,
51    support_point: &[f64],
52    cov: &Covariates,
53    infusions: &[Infusion],
54    ndrugs: usize,
55    ti: f64,
56    tf: f64,
57) -> V {
58    if ti == tf {
59        return x;
60    }
61
62    let params_v = V::from_vec(support_point.to_vec(), NalgebraContext);
63    let covariates = cov.clone();
64    let infusion_events = infusions.to_vec();
65    let drift_fn = *drift;
66    let diffusion_fn = *difussion;
67
68    let params_for_drift = params_v.clone();
69    let drift_closure = move |time: f64, state: &DVector<f64>, out: &mut DVector<f64>| {
70        let mut rateiv = V::zeros(ndrugs, NalgebraContext);
71        for infusion in &infusion_events {
72            if time >= infusion.time() && time <= infusion.duration() + infusion.time() {
73                rateiv[infusion.input()] += infusion.amount() / infusion.duration();
74            }
75        }
76
77        let state_v: V = state.clone().into();
78        let mut out_v = V::zeros(state.len(), NalgebraContext);
79        drift_fn(
80            &state_v,
81            &params_for_drift,
82            time,
83            &mut out_v,
84            &rateiv,
85            &covariates,
86        );
87        out.copy_from(out_v.inner());
88    };
89
90    let diffusion_closure = move |_time: f64, _state: &DVector<f64>, out: &mut DVector<f64>| {
91        let mut out_v = V::zeros(out.len(), NalgebraContext);
92        diffusion_fn(&params_v, &mut out_v);
93        out.copy_from(out_v.inner());
94    };
95
96    simulate_sde_event_with(drift_closure, diffusion_closure, x.inner().clone(), ti, tf).into()
97}
98
99pub(crate) fn simulate_sde_event_with<D, G>(
100    drift: D,
101    diffusion: G,
102    initial_state: DVector<f64>,
103    ti: f64,
104    tf: f64,
105) -> DVector<f64>
106where
107    D: Fn(f64, &DVector<f64>, &mut DVector<f64>),
108    G: Fn(f64, &DVector<f64>, &mut DVector<f64>),
109{
110    if ti == tf {
111        return initial_state;
112    }
113
114    let mut sde = em::EM::new(drift, diffusion, initial_state, 1e-2, 1e-2);
115    let (_time, solution) = sde.solve(ti, tf);
116    solution.last().unwrap().clone()
117}
118
119/// Stochastic Differential Equation solver for pharmacometric models.
120///
121/// This struct represents a stochastic differential equation system and provides
122/// methods to simulate particles and estimate likelihood for PKPD modeling.
123///
124/// SDE models introduce stochasticity into the system dynamics, allowing for more
125/// realistic modeling of biological variability and uncertainty.
126#[derive(Clone, Debug)]
127pub struct SDE {
128    drift: Drift,
129    diffusion: Diffusion,
130    lag: Lag,
131    fa: Fa,
132    init: Init,
133    out: Out,
134    neqs: Neqs,
135    nparticles: usize,
136    cache: Option<SdeLikelihoodCache>,
137}
138
139impl SDE {
140    /// Creates a new stochastic differential equation solver with default Neqs.
141    ///
142    /// Use builder methods to configure dimensions:
143    /// ```ignore
144    /// SDE::new(drift, diffusion, lag, fa, init, out, nparticles)
145    ///     .with_nstates(2)
146    ///     .with_ndrugs(1)
147    ///     .with_nout(1)
148    /// ```
149    pub fn new(
150        drift: Drift,
151        diffusion: Diffusion,
152        lag: Lag,
153        fa: Fa,
154        init: Init,
155        out: Out,
156        nparticles: usize,
157    ) -> Self {
158        Self {
159            drift,
160            diffusion,
161            lag,
162            fa,
163            init,
164            out,
165            neqs: Neqs::default(),
166            nparticles,
167            cache: Some(SdeLikelihoodCache::new(DEFAULT_CACHE_SIZE)),
168        }
169    }
170
171    /// Set the number of state variables.
172    pub fn with_nstates(mut self, nstates: usize) -> Self {
173        self.neqs.nstates = nstates;
174        self
175    }
176
177    /// Set the number of drug input channels (size of bolus[] and rateiv[]).
178    pub fn with_ndrugs(mut self, ndrugs: usize) -> Self {
179        self.neqs.ndrugs = ndrugs;
180        self
181    }
182
183    /// Set the number of output equations.
184    pub fn with_nout(mut self, nout: usize) -> Self {
185        self.neqs.nout = nout;
186        self
187    }
188}
189
190impl super::Cache for SDE {
191    fn with_cache_capacity(mut self, size: u64) -> Self {
192        self.cache = Some(SdeLikelihoodCache::new(size));
193        self
194    }
195
196    fn enable_cache(mut self) -> Self {
197        self.cache = Some(SdeLikelihoodCache::new(DEFAULT_CACHE_SIZE));
198        self
199    }
200
201    fn clear_cache(&self) {
202        if let Some(cache) = &self.cache {
203            cache.invalidate_all();
204        }
205    }
206
207    fn disable_cache(mut self) -> Self {
208        self.cache = None;
209        self
210    }
211}
212
213/// State trait implementation for particle-based SDE simulation.
214///
215/// This implementation allows adding bolus doses to all particles in the system.
216impl State for Vec<DVector<f64>> {
217    /// Adds a bolus dose to a specific input compartment across all particles.
218    ///
219    /// # Arguments
220    ///
221    /// * `input` - Index of the input compartment
222    /// * `amount` - Amount to add to the compartment
223    fn add_bolus(&mut self, input: usize, amount: f64) {
224        self.par_iter_mut().for_each(|particle| {
225            particle[input] += amount;
226        });
227    }
228}
229
230/// Predictions implementation for particle-based SDE simulation outputs.
231///
232/// This implementation manages and processes predictions from multiple particles.
233impl Predictions for Array2<Prediction> {
234    fn new(nparticles: usize) -> Self {
235        Array2::from_shape_fn((nparticles, 0), |_| Prediction::default())
236    }
237    fn squared_error(&self) -> f64 {
238        unimplemented!();
239    }
240    fn get_predictions(&self) -> Vec<Prediction> {
241        // Make this return the mean prediction across all particles
242        if self.is_empty() || self.ncols() == 0 {
243            return Vec::new();
244        }
245
246        let mut result = Vec::with_capacity(self.ncols());
247
248        for col in 0..self.ncols() {
249            let column = self.column(col);
250
251            let mean_prediction: f64 = column
252                .iter()
253                .map(|pred: &Prediction| pred.prediction())
254                .sum::<f64>()
255                / self.nrows() as f64;
256
257            let mut prediction = column.first().unwrap().clone();
258            prediction.set_prediction(mean_prediction);
259            result.push(prediction);
260        }
261
262        result
263    }
264    fn log_likelihood(&self, error_models: &AssayErrorModels) -> Result<f64, crate::PharmsolError> {
265        // For SDE, compute log-likelihood using mean predictions across particles
266        let predictions = self.get_predictions();
267        if predictions.is_empty() {
268            return Ok(0.0);
269        }
270
271        let log_liks: Result<Vec<f64>, _> = predictions
272            .iter()
273            .filter(|p| p.observation().is_some())
274            .map(|p| p.log_likelihood(error_models))
275            .collect();
276
277        log_liks.map(|lls| lls.iter().sum())
278    }
279}
280
281impl EquationTypes for SDE {
282    type S = Vec<DVector<f64>>; // Vec -> particles, DVector -> state
283    type P = Array2<Prediction>; // Rows -> particles, Columns -> time
284}
285
286impl EquationPriv for SDE {
287    // #[inline(always)]
288    // fn get_init(&self) -> &Init {
289    //     &self.init
290    // }
291
292    // #[inline(always)]
293    // fn get_out(&self) -> &Out {
294    //     &self.out
295    // }
296
297    // #[inline(always)]
298    // fn get_lag(&self, spp: &[f64]) -> Option<HashMap<usize, f64>> {
299    //     Some((self.lag)(&V::from_vec(spp.to_owned())))
300    // }
301
302    // #[inline(always)]
303    // fn get_fa(&self, spp: &[f64]) -> Option<HashMap<usize, f64>> {
304    //     Some((self.fa)(&V::from_vec(spp.to_owned())))
305    // }
306
307    #[inline(always)]
308    fn lag(&self) -> &Lag {
309        &self.lag
310    }
311
312    #[inline(always)]
313    fn fa(&self) -> &Fa {
314        &self.fa
315    }
316
317    #[inline(always)]
318    fn get_nstates(&self) -> usize {
319        self.neqs.nstates
320    }
321
322    #[inline(always)]
323    fn get_ndrugs(&self) -> usize {
324        self.neqs.ndrugs
325    }
326
327    #[inline(always)]
328    fn get_nouteqs(&self) -> usize {
329        self.neqs.nout
330    }
331    #[inline(always)]
332    fn solve(
333        &self,
334        state: &mut Self::S,
335        support_point: &[f64],
336        covariates: &Covariates,
337        infusions: &[Infusion],
338        ti: f64,
339        tf: f64,
340    ) -> Result<(), PharmsolError> {
341        let ndrugs = self.get_ndrugs();
342        state.par_iter_mut().for_each(|particle| {
343            *particle = simulate_sde_event(
344                &self.drift,
345                &self.diffusion,
346                particle.clone().into(),
347                support_point,
348                covariates,
349                infusions,
350                ndrugs,
351                ti,
352                tf,
353            )
354            .inner()
355            .clone();
356        });
357        Ok(())
358    }
359    fn nparticles(&self) -> usize {
360        self.nparticles
361    }
362
363    fn is_sde(&self) -> bool {
364        true
365    }
366    #[inline(always)]
367    fn process_observation(
368        &self,
369        support_point: &[f64],
370        observation: &crate::Observation,
371        error_models: Option<&AssayErrorModels>,
372        _time: f64,
373        covariates: &Covariates,
374        x: &mut Self::S,
375        likelihood: &mut Vec<f64>,
376        output: &mut Self::P,
377    ) -> Result<(), PharmsolError> {
378        let mut pred = vec![Prediction::default(); self.nparticles];
379
380        pred.par_iter_mut().enumerate().for_each(|(i, p)| {
381            let mut y = V::zeros(self.get_nouteqs(), NalgebraContext);
382            (self.out)(
383                &x[i].clone().into(),
384                &V::from_vec(support_point.to_vec(), NalgebraContext),
385                observation.time(),
386                covariates,
387                &mut y,
388            );
389            *p = observation.to_prediction(y[observation.outeq()], x[i].as_slice().to_vec());
390        });
391        let out = Array2::from_shape_vec((self.nparticles, 1), pred.clone())?;
392        *output = concatenate(Axis(1), &[output.view(), out.view()]).unwrap();
393        //e = y[t] .- x[:,1]
394        // q = pdf.(Distributions.Normal(0, 0.5), e)
395        if let Some(em) = error_models {
396            let mut q: Vec<f64> = Vec::with_capacity(self.nparticles);
397
398            pred.iter().for_each(|p| {
399                let lik = p.log_likelihood(em).map(f64::exp);
400                match lik {
401                    Ok(l) => q.push(l),
402                    Err(e) => panic!("Error in likelihood calculation: {:?}", e),
403                }
404            });
405            let sum_q: f64 = q.iter().sum();
406            let w: Vec<f64> = q.iter().map(|qi| qi / sum_q).collect();
407            let i = sysresample(&w);
408            let a: Vec<DVector<f64>> = i.iter().map(|&i| x[i].clone()).collect();
409            *x = a;
410            likelihood.push(sum_q / self.nparticles as f64);
411            // let qq: Vec<f64> = i.iter().map(|&i| q[i]).collect();
412            // likelihood.push(qq.iter().sum::<f64>() / self.nparticles as f64);
413        }
414        Ok(())
415    }
416    #[inline(always)]
417    fn initial_state(
418        &self,
419        support_point: &[f64],
420        covariates: &Covariates,
421        occasion_index: usize,
422    ) -> Self::S {
423        let mut x = Vec::with_capacity(self.nparticles);
424        for _ in 0..self.nparticles {
425            let mut state: V = DVector::zeros(self.get_nstates()).into();
426            if occasion_index == 0 {
427                (self.init)(
428                    &V::from_vec(support_point.to_vec(), NalgebraContext),
429                    0.0,
430                    covariates,
431                    &mut state,
432                );
433            }
434            x.push(state.inner().clone());
435        }
436        x
437    }
438}
439
440impl Equation for SDE {
441    /// Estimates the likelihood of observed data given a model and parameters.
442    ///
443    /// # Arguments
444    ///
445    /// * `subject` - Subject data containing observations
446    /// * `support_point` - Parameter vector for the model
447    /// * `error_model` - Error model to use for likelihood calculations
448    ///
449    /// # Returns
450    ///
451    /// The log-likelihood of the observed data given the model and parameters.
452    fn estimate_likelihood(
453        &self,
454        subject: &Subject,
455        support_point: &[f64],
456        error_models: &AssayErrorModels,
457    ) -> Result<f64, PharmsolError> {
458        _estimate_likelihood(self, subject, support_point, error_models)
459    }
460
461    fn estimate_log_likelihood(
462        &self,
463        subject: &Subject,
464        support_point: &[f64],
465        error_models: &AssayErrorModels,
466    ) -> Result<f64, PharmsolError> {
467        // For SDE, the particle filter computes likelihood in regular space.
468        // We compute it directly and then take the log.
469        let lik = _estimate_likelihood(self, subject, support_point, error_models)?;
470
471        if lik > 0.0 {
472            Ok(lik.ln())
473        } else {
474            Ok(f64::NEG_INFINITY)
475        }
476    }
477
478    fn kind() -> crate::EqnKind {
479        crate::EqnKind::SDE
480    }
481}
482
483#[inline(always)]
484fn _estimate_likelihood(
485    sde: &SDE,
486    subject: &Subject,
487    support_point: &[f64],
488    error_models: &AssayErrorModels,
489) -> Result<f64, PharmsolError> {
490    if let Some(cache) = &sde.cache {
491        let key = (subject.hash(), spphash(support_point), error_models.hash());
492        if let Some(cached) = cache.get(&key) {
493            return Ok(cached);
494        }
495
496        let ypred = sde.simulate_subject(subject, support_point, Some(error_models))?;
497        let result = ypred.1.unwrap();
498        cache.insert(key, result);
499        Ok(result)
500    } else {
501        let ypred = sde.simulate_subject(subject, support_point, Some(error_models))?;
502        Ok(ypred.1.unwrap())
503    }
504}
505
506/// Performs systematic resampling of particles based on weights.
507///
508/// # Arguments
509///
510/// * `q` - Vector of particle weights
511///
512/// # Returns
513///
514/// Vector of indices to use for resampling.
515fn sysresample(q: &[f64]) -> Vec<usize> {
516    let mut qc = vec![0.0; q.len()];
517    qc[0] = q[0];
518    for i in 1..q.len() {
519        qc[i] = qc[i - 1] + q[i];
520    }
521    let m = q.len();
522    let mut rng = rng();
523    let u: Vec<f64> = (0..m)
524        .map(|i| (i as f64 + rng.random::<f64>()) / m as f64)
525        .collect();
526    let mut i = vec![0; m];
527    let mut k = 0;
528    for j in 0..m {
529        while qc[k] < u[j] {
530            k += 1;
531        }
532        i[j] = k;
533    }
534    i
535}