pharmsol/simulator/equation/sde/mod.rs

1mod em;
2
3use diffsol::{NalgebraContext, Vector};
4use nalgebra::DVector;
5use ndarray::{concatenate, Array2, Axis};
6use rand::{rng, Rng};
7use rayon::prelude::*;
8
9use cached::proc_macro::cached;
10use cached::UnboundCache;
11
12use crate::{
13    data::{Covariates, Infusion},
14    error_model::ErrorModels,
15    prelude::simulator::Prediction,
16    simulator::{Diffusion, Drift, Fa, Init, Lag, Neqs, Out, V},
17    Subject,
18};
19
20use diffsol::VectorCommon;
21
22use crate::PharmsolError;
23
24use super::{Equation, EquationPriv, EquationTypes, Predictions, State};
25
26/// Simulate a stochastic differential equation (SDE) event.
27///
28/// This function advances the SDE system from time `ti` to `tf` using
29/// the Euler-Maruyama method implemented in the `em` module.
30///
31/// # Arguments
32///
33/// * `drift` - Function defining the deterministic component of the SDE
34/// * `difussion` - Function defining the stochastic component of the SDE
35/// * `x` - Current state vector
36/// * `support_point` - Parameter vector for the model
37/// * `cov` - Covariates that may influence the system dynamics
38/// * `infusions` - Infusion events to be applied during simulation
39/// * `ti` - Starting time
40/// * `tf` - Ending time
41///
42/// # Returns
43///
44/// The state vector at time `tf` after simulation.
45#[inline(always)]
46pub(crate) fn simulate_sde_event(
47    drift: &Drift,
48    difussion: &Diffusion,
49    x: V,
50    support_point: &[f64],
51    cov: &Covariates,
52    infusions: &[Infusion],
53    ti: f64,
54    tf: f64,
55) -> V {
56    if ti == tf {
57        return x;
58    }
59
60    let mut sde = em::EM::new(
61        *drift,
62        *difussion,
63        DVector::from_column_slice(support_point),
64        x.inner().clone(),
65        cov.clone(),
66        infusions.to_vec(),
67        1e-2,
68        1e-2,
69    );
70    let (_time, solution) = sde.solve(ti, tf);
71    solution.last().unwrap().clone().into()
72}
73
74/// Stochastic Differential Equation solver for pharmacometric models.
75///
76/// This struct represents a stochastic differential equation system and provides
77/// methods to simulate particles and estimate likelihood for PKPD modeling.
78///
79/// SDE models introduce stochasticity into the system dynamics, allowing for more
80/// realistic modeling of biological variability and uncertainty.
81#[derive(Clone, Debug)]
82pub struct SDE {
83    drift: Drift,
84    diffusion: Diffusion,
85    lag: Lag,
86    fa: Fa,
87    init: Init,
88    out: Out,
89    neqs: Neqs,
90    nparticles: usize,
91}
92
93impl SDE {
94    /// Creates a new stochastic differential equation solver.
95    ///
96    /// # Arguments
97    ///
98    /// * `drift` - Function defining the deterministic component of the SDE
99    /// * `diffusion` - Function defining the stochastic component of the SDE
100    /// * `lag` - Function to compute absorption lag times
101    /// * `fa` - Function to compute bioavailability fractions
102    /// * `init` - Function to initialize the system state
103    /// * `out` - Function to compute output equations
104    /// * `neqs` - Tuple containing the number of state and output equations
105    /// * `nparticles` - Number of particles to use in the simulation
106    ///
107    /// # Returns
108    ///
109    /// A new SDE solver instance configured with the given components.
110    #[allow(clippy::too_many_arguments)]
111    pub fn new(
112        drift: Drift,
113        diffusion: Diffusion,
114        lag: Lag,
115        fa: Fa,
116        init: Init,
117        out: Out,
118        neqs: Neqs,
119        nparticles: usize,
120    ) -> Self {
121        Self {
122            drift,
123            diffusion,
124            lag,
125            fa,
126            init,
127            out,
128            neqs,
129            nparticles,
130        }
131    }
132}
133
134/// State trait implementation for particle-based SDE simulation.
135///
136/// This implementation allows adding bolus doses to all particles in the system.
137impl State for Vec<DVector<f64>> {
138    /// Adds a bolus dose to a specific input compartment across all particles.
139    ///
140    /// # Arguments
141    ///
142    /// * `input` - Index of the input compartment
143    /// * `amount` - Amount to add to the compartment
144    fn add_bolus(&mut self, input: usize, amount: f64) {
145        self.par_iter_mut().for_each(|particle| {
146            particle[input] += amount;
147        });
148    }
149}
150
151/// Predictions implementation for particle-based SDE simulation outputs.
152///
153/// This implementation manages and processes predictions from multiple particles.
154impl Predictions for Array2<Prediction> {
155    fn new(nparticles: usize) -> Self {
156        Array2::from_shape_fn((nparticles, 0), |_| Prediction::default())
157    }
158    fn squared_error(&self) -> f64 {
159        unimplemented!();
160    }
161    fn get_predictions(&self) -> Vec<Prediction> {
162        // Make this return the mean prediction across all particles
163        if self.is_empty() || self.ncols() == 0 {
164            return Vec::new();
165        }
166
167        let mut result = Vec::with_capacity(self.ncols());
168
169        for col in 0..self.ncols() {
170            let column = self.column(col);
171
172            let mean_prediction: f64 = column
173                .iter()
174                .map(|pred: &Prediction| pred.prediction())
175                .sum::<f64>()
176                / self.nrows() as f64;
177
178            let mut prediction = column.first().unwrap().clone();
179            prediction.set_prediction(mean_prediction);
180            result.push(prediction);
181        }
182
183        result
184    }
185}
186
187impl EquationTypes for SDE {
188    type S = Vec<DVector<f64>>; // Vec -> particles, DVector -> state
189    type P = Array2<Prediction>; // Rows -> particles, Columns -> time
190}
191
192impl EquationPriv for SDE {
193    // #[inline(always)]
194    // fn get_init(&self) -> &Init {
195    //     &self.init
196    // }
197
198    // #[inline(always)]
199    // fn get_out(&self) -> &Out {
200    //     &self.out
201    // }
202
203    // #[inline(always)]
204    // fn get_lag(&self, spp: &[f64]) -> Option<HashMap<usize, f64>> {
205    //     Some((self.lag)(&V::from_vec(spp.to_owned())))
206    // }
207
208    // #[inline(always)]
209    // fn get_fa(&self, spp: &[f64]) -> Option<HashMap<usize, f64>> {
210    //     Some((self.fa)(&V::from_vec(spp.to_owned())))
211    // }
212
213    #[inline(always)]
214    fn lag(&self) -> &Lag {
215        &self.lag
216    }
217
218    #[inline(always)]
219    fn fa(&self) -> &Fa {
220        &self.fa
221    }
222
223    #[inline(always)]
224    fn get_nstates(&self) -> usize {
225        self.neqs.0
226    }
227
228    #[inline(always)]
229    fn get_nouteqs(&self) -> usize {
230        self.neqs.1
231    }
232    #[inline(always)]
233    fn solve(
234        &self,
235        state: &mut Self::S,
236        support_point: &Vec<f64>,
237        covariates: &Covariates,
238        infusions: &Vec<Infusion>,
239        ti: f64,
240        tf: f64,
241    ) -> Result<(), PharmsolError> {
242        state.par_iter_mut().for_each(|particle| {
243            *particle = simulate_sde_event(
244                &self.drift,
245                &self.diffusion,
246                particle.clone().into(),
247                support_point,
248                covariates,
249                infusions,
250                ti,
251                tf,
252            )
253            .inner()
254            .clone();
255        });
256        Ok(())
257    }
258    fn nparticles(&self) -> usize {
259        self.nparticles
260    }
261
262    fn is_sde(&self) -> bool {
263        true
264    }
265    #[inline(always)]
266    fn process_observation(
267        &self,
268        support_point: &Vec<f64>,
269        observation: &crate::Observation,
270        error_models: Option<&ErrorModels>,
271        _time: f64,
272        covariates: &Covariates,
273        x: &mut Self::S,
274        likelihood: &mut Vec<f64>,
275        output: &mut Self::P,
276    ) -> Result<(), PharmsolError> {
277        let mut pred = vec![Prediction::default(); self.nparticles];
278        pred.par_iter_mut().enumerate().for_each(|(i, p)| {
279            let mut y = V::zeros(self.get_nouteqs(), NalgebraContext);
280            (self.out)(
281                &x[i].clone().into(),
282                &V::from_vec(support_point.clone(), NalgebraContext),
283                observation.time(),
284                covariates,
285                &mut y,
286            );
287            *p = observation.to_prediction(y[observation.outeq()], x[i].as_slice().to_vec());
288        });
289        let out = Array2::from_shape_vec((self.nparticles, 1), pred.clone())?;
290        *output = concatenate(Axis(1), &[output.view(), out.view()]).unwrap();
291        //e = y[t] .- x[:,1]
292        // q = pdf.(Distributions.Normal(0, 0.5), e)
293        if let Some(em) = error_models {
294            let mut q: Vec<f64> = Vec::with_capacity(self.nparticles);
295
296            pred.iter().for_each(|p| {
297                let lik = p.likelihood(em);
298                match lik {
299                    Ok(l) => q.push(l),
300                    Err(e) => panic!("Error in likelihood calculation: {:?}", e),
301                }
302            });
303            let sum_q: f64 = q.iter().sum();
304            let w: Vec<f64> = q.iter().map(|qi| qi / sum_q).collect();
305            let i = sysresample(&w);
306            let a: Vec<DVector<f64>> = i.iter().map(|&i| x[i].clone()).collect();
307            *x = a;
308            likelihood.push(sum_q / self.nparticles as f64);
309            // let qq: Vec<f64> = i.iter().map(|&i| q[i]).collect();
310            // likelihood.push(qq.iter().sum::<f64>() / self.nparticles as f64);
311        }
312        Ok(())
313    }
314    #[inline(always)]
315    fn initial_state(
316        &self,
317        support_point: &Vec<f64>,
318        covariates: &Covariates,
319        occasion_index: usize,
320    ) -> Self::S {
321        let mut x = Vec::with_capacity(self.nparticles);
322        for _ in 0..self.nparticles {
323            let mut state: V = DVector::zeros(self.get_nstates()).into();
324            if occasion_index == 0 {
325                (self.init)(
326                    &V::from_vec(support_point.to_vec(), NalgebraContext),
327                    0.0,
328                    covariates,
329                    &mut state,
330                );
331            }
332            x.push(state.inner().clone());
333        }
334        x
335    }
336}
337
338impl Equation for SDE {
339    /// Estimates the likelihood of observed data given a model and parameters.
340    ///
341    /// # Arguments
342    ///
343    /// * `subject` - Subject data containing observations
344    /// * `support_point` - Parameter vector for the model
345    /// * `error_model` - Error model to use for likelihood calculations
346    /// * `cache` - Whether to cache likelihood results for reuse
347    ///
348    /// # Returns
349    ///
350    /// The log-likelihood of the observed data given the model and parameters.
351    fn estimate_likelihood(
352        &self,
353        subject: &Subject,
354        support_point: &Vec<f64>,
355        error_models: &ErrorModels,
356        cache: bool,
357    ) -> Result<f64, PharmsolError> {
358        if cache {
359            _estimate_likelihood(self, subject, support_point, error_models)
360        } else {
361            _estimate_likelihood_no_cache(self, subject, support_point, error_models)
362        }
363    }
364
365    fn kind() -> crate::EqnKind {
366        crate::EqnKind::SDE
367    }
368}
369
/// Computes an order-sensitive hash of a parameter vector.
///
/// The exact bit pattern of every element (`f64::to_bits`) is fed, in order,
/// into the standard library's default hasher. Hashing in order means permuted
/// vectors such as `[1.0, 2.0]` and `[2.0, 1.0]` get different hashes, so the
/// likelihood cache cannot return a result computed for another support point.
/// It also avoids integer overflow, which a plain sum of bit patterns would hit
/// for ordinary inputs (a panic in debug builds).
///
/// # Arguments
///
/// * `spp` - Parameter vector
///
/// # Returns
///
/// A `u64` hash of the parameter vector. It is deterministic for a given
/// build, but not guaranteed stable across Rust releases, so do not persist it.
fn spphash(spp: &[f64]) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut hasher = DefaultHasher::new();
    for value in spp {
        value.to_bits().hash(&mut hasher);
    }
    hasher.finish()
}
382
383#[inline(always)]
384#[cached(
385    ty = "UnboundCache<String, f64>",
386    create = "{ UnboundCache::with_capacity(100_000) }",
387    convert = r#"{ format!("{}{}{:#?}", subject.id(), spphash(support_point), error_models.hash()) }"#,
388    result = "true"
389)]
390fn _estimate_likelihood(
391    sde: &SDE,
392    subject: &Subject,
393    support_point: &Vec<f64>,
394    error_models: &ErrorModels,
395) -> Result<f64, PharmsolError> {
396    let ypred = sde.simulate_subject(subject, support_point, Some(error_models))?;
397    Ok(ypred.1.unwrap())
398}
399
400/// Performs systematic resampling of particles based on weights.
401///
402/// # Arguments
403///
404/// * `q` - Vector of particle weights
405///
406/// # Returns
407///
408/// Vector of indices to use for resampling.
409fn sysresample(q: &[f64]) -> Vec<usize> {
410    let mut qc = vec![0.0; q.len()];
411    qc[0] = q[0];
412    for i in 1..q.len() {
413        qc[i] = qc[i - 1] + q[i];
414    }
415    let m = q.len();
416    let mut rng = rng();
417    let u: Vec<f64> = (0..m)
418        .map(|i| (i as f64 + rng.random::<f64>()) / m as f64)
419        .collect();
420    let mut i = vec![0; m];
421    let mut k = 0;
422    for j in 0..m {
423        while qc[k] < u[j] {
424            k += 1;
425        }
426        i[j] = k;
427    }
428    i
429}