pharmsol/simulator/equation/sde/
mod.rs

1mod em;
2
3use nalgebra::DVector;
4use ndarray::{concatenate, Array2, Axis};
5use rand::{rng, Rng};
6use rayon::prelude::*;
7
8use cached::proc_macro::cached;
9use cached::UnboundCache;
10
11use crate::{
12    data::{Covariates, Infusion},
13    error_model::ErrorModels,
14    prelude::simulator::Prediction,
15    simulator::{Diffusion, Drift, Fa, Init, Lag, Neqs, Out, V},
16    Subject,
17};
18
19use crate::PharmsolError;
20
21use super::{Equation, EquationPriv, EquationTypes, Predictions, State};
22
23/// Simulate a stochastic differential equation (SDE) event.
24///
25/// This function advances the SDE system from time `ti` to `tf` using
26/// the Euler-Maruyama method implemented in the `em` module.
27///
28/// # Arguments
29///
30/// * `drift` - Function defining the deterministic component of the SDE
31/// * `difussion` - Function defining the stochastic component of the SDE
32/// * `x` - Current state vector
33/// * `support_point` - Parameter vector for the model
34/// * `cov` - Covariates that may influence the system dynamics
35/// * `infusions` - Infusion events to be applied during simulation
36/// * `ti` - Starting time
37/// * `tf` - Ending time
38///
39/// # Returns
40///
41/// The state vector at time `tf` after simulation.
42#[inline(always)]
43pub(crate) fn simulate_sde_event(
44    drift: &Drift,
45    difussion: &Diffusion,
46    x: V,
47    support_point: &[f64],
48    cov: &Covariates,
49    infusions: &[Infusion],
50    ti: f64,
51    tf: f64,
52) -> V {
53    if ti == tf {
54        return x;
55    }
56
57    let mut sde = em::EM::new(
58        *drift,
59        *difussion,
60        DVector::from_column_slice(support_point),
61        x,
62        cov.clone(),
63        infusions.to_vec(),
64        1e-2,
65        1e-2,
66    );
67    let (_time, solution) = sde.solve(ti, tf);
68    solution.last().unwrap().clone()
69}
70
71/// Stochastic Differential Equation solver for pharmacometric models.
72///
73/// This struct represents a stochastic differential equation system and provides
74/// methods to simulate particles and estimate likelihood for PKPD modeling.
75///
76/// SDE models introduce stochasticity into the system dynamics, allowing for more
77/// realistic modeling of biological variability and uncertainty.
78#[derive(Clone, Debug)]
79pub struct SDE {
80    drift: Drift,
81    diffusion: Diffusion,
82    lag: Lag,
83    fa: Fa,
84    init: Init,
85    out: Out,
86    neqs: Neqs,
87    nparticles: usize,
88}
89
90impl SDE {
91    /// Creates a new stochastic differential equation solver.
92    ///
93    /// # Arguments
94    ///
95    /// * `drift` - Function defining the deterministic component of the SDE
96    /// * `diffusion` - Function defining the stochastic component of the SDE
97    /// * `lag` - Function to compute absorption lag times
98    /// * `fa` - Function to compute bioavailability fractions
99    /// * `init` - Function to initialize the system state
100    /// * `out` - Function to compute output equations
101    /// * `neqs` - Tuple containing the number of state and output equations
102    /// * `nparticles` - Number of particles to use in the simulation
103    ///
104    /// # Returns
105    ///
106    /// A new SDE solver instance configured with the given components.
107    #[allow(clippy::too_many_arguments)]
108    pub fn new(
109        drift: Drift,
110        diffusion: Diffusion,
111        lag: Lag,
112        fa: Fa,
113        init: Init,
114        out: Out,
115        neqs: Neqs,
116        nparticles: usize,
117    ) -> Self {
118        Self {
119            drift,
120            diffusion,
121            lag,
122            fa,
123            init,
124            out,
125            neqs,
126            nparticles,
127        }
128    }
129}
130
131/// State trait implementation for particle-based SDE simulation.
132///
133/// This implementation allows adding bolus doses to all particles in the system.
134impl State for Vec<DVector<f64>> {
135    /// Adds a bolus dose to a specific input compartment across all particles.
136    ///
137    /// # Arguments
138    ///
139    /// * `input` - Index of the input compartment
140    /// * `amount` - Amount to add to the compartment
141    fn add_bolus(&mut self, input: usize, amount: f64) {
142        self.par_iter_mut().for_each(|particle| {
143            particle[input] += amount;
144        });
145    }
146}
147
148/// Predictions implementation for particle-based SDE simulation outputs.
149///
150/// This implementation manages and processes predictions from multiple particles.
151impl Predictions for Array2<Prediction> {
152    fn new(nparticles: usize) -> Self {
153        Array2::from_shape_fn((nparticles, 0), |_| Prediction::default())
154    }
155    fn squared_error(&self) -> f64 {
156        unimplemented!();
157    }
158    fn get_predictions(&self) -> Vec<Prediction> {
159        //TODO: This is only returning the first particle, not the best, not the worst, THE FIRST
160        // CHANGE THIS
161        let row = self.row(0).to_vec();
162        row
163    }
164}
165
166impl EquationTypes for SDE {
167    type S = Vec<DVector<f64>>; // Vec -> particles, DVector -> state
168    type P = Array2<Prediction>; // Rows -> particles, Columns -> time
169}
170
171impl EquationPriv for SDE {
172    // #[inline(always)]
173    // fn get_init(&self) -> &Init {
174    //     &self.init
175    // }
176
177    // #[inline(always)]
178    // fn get_out(&self) -> &Out {
179    //     &self.out
180    // }
181
182    // #[inline(always)]
183    // fn get_lag(&self, spp: &[f64]) -> Option<HashMap<usize, f64>> {
184    //     Some((self.lag)(&V::from_vec(spp.to_owned())))
185    // }
186
187    // #[inline(always)]
188    // fn get_fa(&self, spp: &[f64]) -> Option<HashMap<usize, f64>> {
189    //     Some((self.fa)(&V::from_vec(spp.to_owned())))
190    // }
191
192    #[inline(always)]
193    fn lag(&self) -> &Lag {
194        &self.lag
195    }
196
197    #[inline(always)]
198    fn fa(&self) -> &Fa {
199        &self.fa
200    }
201
202    #[inline(always)]
203    fn get_nstates(&self) -> usize {
204        self.neqs.0
205    }
206
207    #[inline(always)]
208    fn get_nouteqs(&self) -> usize {
209        self.neqs.1
210    }
211    #[inline(always)]
212    fn solve(
213        &self,
214        state: &mut Self::S,
215        support_point: &Vec<f64>,
216        covariates: &Covariates,
217        infusions: &Vec<Infusion>,
218        ti: f64,
219        tf: f64,
220    ) -> Result<(), PharmsolError> {
221        state.par_iter_mut().for_each(|particle| {
222            *particle = simulate_sde_event(
223                &self.drift,
224                &self.diffusion,
225                particle.clone(),
226                support_point,
227                covariates,
228                infusions,
229                ti,
230                tf,
231            );
232        });
233        Ok(())
234    }
235    fn nparticles(&self) -> usize {
236        self.nparticles
237    }
238
239    fn is_sde(&self) -> bool {
240        true
241    }
242    #[inline(always)]
243    fn process_observation(
244        &self,
245        support_point: &Vec<f64>,
246        observation: &crate::Observation,
247        error_models: Option<&ErrorModels>,
248        _time: f64,
249        covariates: &Covariates,
250        x: &mut Self::S,
251        likelihood: &mut Vec<f64>,
252        output: &mut Self::P,
253    ) -> Result<(), PharmsolError> {
254        let mut pred = vec![Prediction::default(); self.nparticles];
255        pred.par_iter_mut().enumerate().for_each(|(i, p)| {
256            let mut y = V::zeros(self.get_nouteqs());
257            (self.out)(
258                &x[i],
259                &V::from_vec(support_point.clone()),
260                observation.time(),
261                covariates,
262                &mut y,
263            );
264            *p = observation.to_prediction(y[observation.outeq()], x[i].as_slice().to_vec());
265        });
266        let out = Array2::from_shape_vec((self.nparticles, 1), pred.clone())?;
267        *output = concatenate(Axis(1), &[output.view(), out.view()]).unwrap();
268        //e = y[t] .- x[:,1]
269        // q = pdf.(Distributions.Normal(0, 0.5), e)
270        if let Some(em) = error_models {
271            let mut q: Vec<f64> = Vec::with_capacity(self.nparticles);
272
273            pred.iter().for_each(|p| {
274                let lik = p.likelihood(em);
275                match lik {
276                    Ok(l) => q.push(l),
277                    Err(e) => panic!("Error in likelihood calculation: {:?}", e),
278                }
279            });
280            let sum_q: f64 = q.iter().sum();
281            let w: Vec<f64> = q.iter().map(|qi| qi / sum_q).collect();
282            let i = sysresample(&w);
283            let a: Vec<DVector<f64>> = i.iter().map(|&i| x[i].clone()).collect();
284            *x = a;
285            likelihood.push(sum_q / self.nparticles as f64);
286            // let qq: Vec<f64> = i.iter().map(|&i| q[i]).collect();
287            // likelihood.push(qq.iter().sum::<f64>() / self.nparticles as f64);
288        }
289        Ok(())
290    }
291    #[inline(always)]
292    fn initial_state(
293        &self,
294        support_point: &Vec<f64>,
295        covariates: &Covariates,
296        occasion_index: usize,
297    ) -> Self::S {
298        let mut x = Vec::with_capacity(self.nparticles);
299        for _ in 0..self.nparticles {
300            let mut state = DVector::zeros(self.get_nstates());
301            if occasion_index == 0 {
302                (self.init)(
303                    &V::from_vec(support_point.to_vec()),
304                    0.0,
305                    covariates,
306                    &mut state,
307                );
308            }
309            x.push(state);
310        }
311        x
312    }
313}
314
315impl Equation for SDE {
316    /// Estimates the likelihood of observed data given a model and parameters.
317    ///
318    /// # Arguments
319    ///
320    /// * `subject` - Subject data containing observations
321    /// * `support_point` - Parameter vector for the model
322    /// * `error_model` - Error model to use for likelihood calculations
323    /// * `cache` - Whether to cache likelihood results for reuse
324    ///
325    /// # Returns
326    ///
327    /// The log-likelihood of the observed data given the model and parameters.
328    fn estimate_likelihood(
329        &self,
330        subject: &Subject,
331        support_point: &Vec<f64>,
332        error_models: &ErrorModels,
333        cache: bool,
334    ) -> Result<f64, PharmsolError> {
335        if cache {
336            _estimate_likelihood(self, subject, support_point, error_models)
337        } else {
338            _estimate_likelihood_no_cache(self, subject, support_point, error_models)
339        }
340    }
341
342    fn kind() -> crate::EqnKind {
343        crate::EqnKind::SDE
344    }
345}
346
/// Computes a hash value for a parameter vector.
///
/// Each element's IEEE-754 bit pattern is fed, in order, into the standard
/// library's `DefaultHasher`. Because `DefaultHasher::new()` always starts
/// from the same state, the result is deterministic within a process, which
/// is all an in-memory cache key needs.
///
/// This replaces a plain sum of bit patterns, which had two problems:
/// * it overflowed `u64`, panicking in debug builds;
/// * it was order-insensitive, so `[a, b]` and `[b, a]` shared a cache key.
///
/// # Arguments
///
/// * `spp` - Parameter vector
///
/// # Returns
///
/// A u64 hash value representing the parameter vector.
fn spphash(spp: &[f64]) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::Hasher;

    let mut hasher = DefaultHasher::new();
    for value in spp {
        hasher.write_u64(value.to_bits());
    }
    hasher.finish()
}
359
360#[inline(always)]
361#[cached(
362    ty = "UnboundCache<String, f64>",
363    create = "{ UnboundCache::with_capacity(100_000) }",
364    convert = r#"{ format!("{}{}{:#?}", subject.id(), spphash(support_point), error_models.hash()) }"#,
365    result = "true"
366)]
367fn _estimate_likelihood(
368    sde: &SDE,
369    subject: &Subject,
370    support_point: &Vec<f64>,
371    error_models: &ErrorModels,
372) -> Result<f64, PharmsolError> {
373    let ypred = sde.simulate_subject(subject, support_point, Some(error_models))?;
374    Ok(ypred.1.unwrap())
375}
376
377/// Performs systematic resampling of particles based on weights.
378///
379/// # Arguments
380///
381/// * `q` - Vector of particle weights
382///
383/// # Returns
384///
385/// Vector of indices to use for resampling.
386fn sysresample(q: &[f64]) -> Vec<usize> {
387    let mut qc = vec![0.0; q.len()];
388    qc[0] = q[0];
389    for i in 1..q.len() {
390        qc[i] = qc[i - 1] + q[i];
391    }
392    let m = q.len();
393    let mut rng = rng();
394    let u: Vec<f64> = (0..m)
395        .map(|i| (i as f64 + rng.random::<f64>()) / m as f64)
396        .collect();
397    let mut i = vec![0; m];
398    let mut k = 0;
399    for j in 0..m {
400        while qc[k] < u[j] {
401            k += 1;
402        }
403        i[j] = k;
404    }
405    i
406}