// pharmsol/simulator/equation/sde/mod.rs

1mod em;
2
3use diffsol::{NalgebraContext, Vector};
4use nalgebra::DVector;
5use ndarray::{concatenate, Array2, Axis};
6use rand::{rng, Rng};
7use rayon::prelude::*;
8
9use cached::proc_macro::cached;
10use cached::UnboundCache;
11
12use crate::{
13    data::{Covariates, Infusion},
14    error_model::ErrorModels,
15    mapping::Mappings,
16    prelude::simulator::Prediction,
17    simulator::{Diffusion, Drift, Fa, Init, Lag, Neqs, Out, V},
18    Subject,
19};
20
21use diffsol::VectorCommon;
22
23use crate::PharmsolError;
24
25use super::{Equation, EquationPriv, EquationTypes, Predictions, State};
26
27/// Simulate a stochastic differential equation (SDE) event.
28///
29/// This function advances the SDE system from time `ti` to `tf` using
30/// the Euler-Maruyama method implemented in the `em` module.
31///
32/// # Arguments
33///
34/// * `drift` - Function defining the deterministic component of the SDE
35/// * `difussion` - Function defining the stochastic component of the SDE
36/// * `x` - Current state vector
37/// * `support_point` - Parameter vector for the model
38/// * `cov` - Covariates that may influence the system dynamics
39/// * `infusions` - Infusion events to be applied during simulation
40/// * `ti` - Starting time
41/// * `tf` - Ending time
42///
43/// # Returns
44///
45/// The state vector at time `tf` after simulation.
46#[inline(always)]
47pub(crate) fn simulate_sde_event(
48    drift: &Drift,
49    difussion: &Diffusion,
50    x: V,
51    support_point: &[f64],
52    cov: &Covariates,
53    infusions: &[Infusion],
54    ti: f64,
55    tf: f64,
56) -> V {
57    if ti == tf {
58        return x;
59    }
60
61    let mut sde = em::EM::new(
62        *drift,
63        *difussion,
64        DVector::from_column_slice(support_point),
65        x.inner().clone(),
66        cov.clone(),
67        infusions.to_vec(),
68        1e-2,
69        1e-2,
70    );
71    let (_time, solution) = sde.solve(ti, tf);
72    solution.last().unwrap().clone().into()
73}
74
75/// Stochastic Differential Equation solver for pharmacometric models.
76///
77/// This struct represents a stochastic differential equation system and provides
78/// methods to simulate particles and estimate likelihood for PKPD modeling.
79///
80/// SDE models introduce stochasticity into the system dynamics, allowing for more
81/// realistic modeling of biological variability and uncertainty.
82#[derive(Clone, Debug)]
83pub struct SDE {
84    drift: Drift,
85    diffusion: Diffusion,
86    lag: Lag,
87    fa: Fa,
88    init: Init,
89    out: Out,
90    neqs: Neqs,
91    nparticles: usize,
92    mappings: Mappings,
93}
94
95impl SDE {
96    /// Creates a new stochastic differential equation solver.
97    ///
98    /// # Arguments
99    ///
100    /// * `drift` - Function defining the deterministic component of the SDE
101    /// * `diffusion` - Function defining the stochastic component of the SDE
102    /// * `lag` - Function to compute absorption lag times
103    /// * `fa` - Function to compute bioavailability fractions
104    /// * `init` - Function to initialize the system state
105    /// * `out` - Function to compute output equations
106    /// * `neqs` - Tuple containing the number of state and output equations
107    /// * `nparticles` - Number of particles to use in the simulation
108    ///
109    /// # Returns
110    ///
111    /// A new SDE solver instance configured with the given components.
112    #[allow(clippy::too_many_arguments)]
113    pub fn new(
114        drift: Drift,
115        diffusion: Diffusion,
116        lag: Lag,
117        fa: Fa,
118        init: Init,
119        out: Out,
120        neqs: Neqs,
121        nparticles: usize,
122    ) -> Self {
123        Self {
124            drift,
125            diffusion,
126            lag,
127            fa,
128            init,
129            out,
130            neqs,
131            nparticles,
132            mappings: Mappings::new(),
133        }
134    }
135}
136
137/// State trait implementation for particle-based SDE simulation.
138///
139/// This implementation allows adding bolus doses to all particles in the system.
140impl State for Vec<DVector<f64>> {
141    /// Adds a bolus dose to a specific input compartment across all particles.
142    ///
143    /// # Arguments
144    ///
145    /// * `input` - Index of the input compartment
146    /// * `amount` - Amount to add to the compartment
147    fn add_bolus(&mut self, input: usize, amount: f64) {
148        self.par_iter_mut().for_each(|particle| {
149            particle[input] += amount;
150        });
151    }
152}
153
154/// Predictions implementation for particle-based SDE simulation outputs.
155///
156/// This implementation manages and processes predictions from multiple particles.
157impl Predictions for Array2<Prediction> {
158    fn new(nparticles: usize) -> Self {
159        Array2::from_shape_fn((nparticles, 0), |_| Prediction::default())
160    }
161    fn squared_error(&self) -> f64 {
162        unimplemented!();
163    }
164    fn get_predictions(&self) -> Vec<Prediction> {
165        // Make this return the mean prediction across all particles
166        if self.is_empty() || self.ncols() == 0 {
167            return Vec::new();
168        }
169
170        let mut result = Vec::with_capacity(self.ncols());
171
172        for col in 0..self.ncols() {
173            let column = self.column(col);
174
175            let mean_prediction: f64 = column
176                .iter()
177                .map(|pred: &Prediction| pred.prediction())
178                .sum::<f64>()
179                / self.nrows() as f64;
180
181            let mut prediction = column.first().unwrap().clone();
182            prediction.set_prediction(mean_prediction);
183            result.push(prediction);
184        }
185
186        result
187    }
188}
189
190impl EquationTypes for SDE {
191    type S = Vec<DVector<f64>>; // Vec -> particles, DVector -> state
192    type P = Array2<Prediction>; // Rows -> particles, Columns -> time
193}
194
195impl EquationPriv for SDE {
196    // #[inline(always)]
197    // fn get_init(&self) -> &Init {
198    //     &self.init
199    // }
200
201    // #[inline(always)]
202    // fn get_out(&self) -> &Out {
203    //     &self.out
204    // }
205
206    // #[inline(always)]
207    // fn get_lag(&self, spp: &[f64]) -> Option<HashMap<usize, f64>> {
208    //     Some((self.lag)(&V::from_vec(spp.to_owned())))
209    // }
210
211    // #[inline(always)]
212    // fn get_fa(&self, spp: &[f64]) -> Option<HashMap<usize, f64>> {
213    //     Some((self.fa)(&V::from_vec(spp.to_owned())))
214    // }
215
216    #[inline(always)]
217    fn lag(&self) -> &Lag {
218        &self.lag
219    }
220
221    #[inline(always)]
222    fn fa(&self) -> &Fa {
223        &self.fa
224    }
225
226    #[inline(always)]
227    fn get_nstates(&self) -> usize {
228        self.neqs.0
229    }
230
231    #[inline(always)]
232    fn get_nouteqs(&self) -> usize {
233        self.neqs.1
234    }
235    #[inline(always)]
236    fn solve(
237        &self,
238        state: &mut Self::S,
239        support_point: &Vec<f64>,
240        covariates: &Covariates,
241        infusions: &Vec<Infusion>,
242        ti: f64,
243        tf: f64,
244    ) -> Result<(), PharmsolError> {
245        state.par_iter_mut().for_each(|particle| {
246            *particle = simulate_sde_event(
247                &self.drift,
248                &self.diffusion,
249                particle.clone().into(),
250                support_point,
251                covariates,
252                infusions,
253                ti,
254                tf,
255            )
256            .inner()
257            .clone();
258        });
259        Ok(())
260    }
261    fn nparticles(&self) -> usize {
262        self.nparticles
263    }
264
265    fn is_sde(&self) -> bool {
266        true
267    }
268    #[inline(always)]
269    fn process_observation(
270        &self,
271        support_point: &Vec<f64>,
272        observation: &crate::Observation,
273        error_models: Option<&ErrorModels>,
274        _time: f64,
275        covariates: &Covariates,
276        x: &mut Self::S,
277        likelihood: &mut Vec<f64>,
278        output: &mut Self::P,
279    ) -> Result<(), PharmsolError> {
280        let mut pred = vec![Prediction::default(); self.nparticles];
281        pred.par_iter_mut().enumerate().for_each(|(i, p)| {
282            let mut y = V::zeros(self.get_nouteqs(), NalgebraContext);
283            (self.out)(
284                &x[i].clone().into(),
285                &V::from_vec(support_point.clone(), NalgebraContext),
286                observation.time(),
287                covariates,
288                &mut y,
289            );
290            *p = observation.to_prediction(y[observation.outeq()], x[i].as_slice().to_vec());
291        });
292        let out = Array2::from_shape_vec((self.nparticles, 1), pred.clone())?;
293        *output = concatenate(Axis(1), &[output.view(), out.view()]).unwrap();
294        //e = y[t] .- x[:,1]
295        // q = pdf.(Distributions.Normal(0, 0.5), e)
296        if let Some(em) = error_models {
297            let mut q: Vec<f64> = Vec::with_capacity(self.nparticles);
298
299            pred.iter().for_each(|p| {
300                let lik = p.likelihood(em);
301                match lik {
302                    Ok(l) => q.push(l),
303                    Err(e) => panic!("Error in likelihood calculation: {:?}", e),
304                }
305            });
306            let sum_q: f64 = q.iter().sum();
307            let w: Vec<f64> = q.iter().map(|qi| qi / sum_q).collect();
308            let i = sysresample(&w);
309            let a: Vec<DVector<f64>> = i.iter().map(|&i| x[i].clone()).collect();
310            *x = a;
311            likelihood.push(sum_q / self.nparticles as f64);
312            // let qq: Vec<f64> = i.iter().map(|&i| q[i]).collect();
313            // likelihood.push(qq.iter().sum::<f64>() / self.nparticles as f64);
314        }
315        Ok(())
316    }
317    #[inline(always)]
318    fn initial_state(
319        &self,
320        support_point: &Vec<f64>,
321        covariates: &Covariates,
322        occasion_index: usize,
323    ) -> Self::S {
324        let mut x = Vec::with_capacity(self.nparticles);
325        for _ in 0..self.nparticles {
326            let mut state: V = DVector::zeros(self.get_nstates()).into();
327            if occasion_index == 0 {
328                (self.init)(
329                    &V::from_vec(support_point.to_vec(), NalgebraContext),
330                    0.0,
331                    covariates,
332                    &mut state,
333                );
334            }
335            x.push(state.inner().clone());
336        }
337        x
338    }
339}
340
341impl Equation for SDE {
342    /// Estimates the likelihood of observed data given a model and parameters.
343    ///
344    /// # Arguments
345    ///
346    /// * `subject` - Subject data containing observations
347    /// * `support_point` - Parameter vector for the model
348    /// * `error_model` - Error model to use for likelihood calculations
349    /// * `cache` - Whether to cache likelihood results for reuse
350    ///
351    /// # Returns
352    ///
353    /// The log-likelihood of the observed data given the model and parameters.
354    fn estimate_likelihood(
355        &self,
356        subject: &Subject,
357        support_point: &Vec<f64>,
358        error_models: &ErrorModels,
359        cache: bool,
360    ) -> Result<f64, PharmsolError> {
361        if cache {
362            _estimate_likelihood(self, subject, support_point, error_models)
363        } else {
364            _estimate_likelihood_no_cache(self, subject, support_point, error_models)
365        }
366    }
367    fn mappings_ref(&self) -> &Mappings {
368        &self.mappings
369    }
370    fn mappings_mut(&mut self) -> &mut Mappings {
371        &mut self.mappings
372    }
373
374    fn kind() -> crate::EqnKind {
375        crate::EqnKind::SDE
376    }
377}
378
/// Computes a hash value for a parameter vector.
///
/// The exact bit pattern of every element is fed, in order, into the standard
/// library's `DefaultHasher`. The hash is therefore sensitive to element order
/// and vector length, and cannot overflow. The previous implementation summed
/// the raw bits, which panicked on overflow in debug builds. That sum also
/// collided for permutations (`[1, 2]` vs `[2, 1]`) and trailing zeros
/// (`[x]` vs `[x, 0]`).
///
/// The value is stable within a process (`DefaultHasher::new` uses fixed keys),
/// which is all an in-memory cache key needs. It is not a persistent fingerprint.
///
/// # Arguments
///
/// * `spp` - Parameter vector
///
/// # Returns
///
/// A `u64` hash of `spp`.
fn spphash(spp: &[f64]) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::Hasher;

    let mut hasher = DefaultHasher::new();
    for value in spp {
        hasher.write_u64(value.to_bits());
    }
    hasher.finish()
}
391
392#[inline(always)]
393#[cached(
394    ty = "UnboundCache<String, f64>",
395    create = "{ UnboundCache::with_capacity(100_000) }",
396    convert = r#"{ format!("{}{}{:#?}", subject.id(), spphash(support_point), error_models.hash()) }"#,
397    result = "true"
398)]
399fn _estimate_likelihood(
400    sde: &SDE,
401    subject: &Subject,
402    support_point: &Vec<f64>,
403    error_models: &ErrorModels,
404) -> Result<f64, PharmsolError> {
405    let ypred = sde.simulate_subject(subject, support_point, Some(error_models))?;
406    Ok(ypred.1.unwrap())
407}
408
409/// Performs systematic resampling of particles based on weights.
410///
411/// # Arguments
412///
413/// * `q` - Vector of particle weights
414///
415/// # Returns
416///
417/// Vector of indices to use for resampling.
418fn sysresample(q: &[f64]) -> Vec<usize> {
419    let mut qc = vec![0.0; q.len()];
420    qc[0] = q[0];
421    for i in 1..q.len() {
422        qc[i] = qc[i - 1] + q[i];
423    }
424    let m = q.len();
425    let mut rng = rng();
426    let u: Vec<f64> = (0..m)
427        .map(|i| (i as f64 + rng.random::<f64>()) / m as f64)
428        .collect();
429    let mut i = vec![0; m];
430    let mut k = 0;
431    for j in 0..m {
432        while qc[k] < u[j] {
433            k += 1;
434        }
435        i[j] = k;
436    }
437    i
438}