pharmsol/simulator/equation/sde/mod.rs

1mod em;
2
3use std::collections::HashMap;
4
5use nalgebra::DVector;
6use ndarray::{concatenate, Array2, Axis};
7use rand::{rng, Rng};
8use rayon::prelude::*;
9
10use cached::proc_macro::cached;
11use cached::UnboundCache;
12
13use crate::{
14    data::{Covariates, Infusion},
15    error_model::ErrorModel,
16    prelude::simulator::Prediction,
17    simulator::{Diffusion, Drift, Fa, Init, Lag, Neqs, Out, V},
18    Subject,
19};
20
21use crate::PharmsolError;
22
23use super::{Equation, EquationPriv, EquationTypes, Predictions, State};
24
25/// Simulate a stochastic differential equation (SDE) event.
26///
27/// This function advances the SDE system from time `ti` to `tf` using
28/// the Euler-Maruyama method implemented in the `em` module.
29///
30/// # Arguments
31///
32/// * `drift` - Function defining the deterministic component of the SDE
33/// * `difussion` - Function defining the stochastic component of the SDE
34/// * `x` - Current state vector
35/// * `support_point` - Parameter vector for the model
36/// * `cov` - Covariates that may influence the system dynamics
37/// * `infusions` - Infusion events to be applied during simulation
38/// * `ti` - Starting time
39/// * `tf` - Ending time
40///
41/// # Returns
42///
43/// The state vector at time `tf` after simulation.
44#[inline(always)]
45pub(crate) fn simulate_sde_event(
46    drift: &Drift,
47    difussion: &Diffusion,
48    x: V,
49    support_point: &[f64],
50    cov: &Covariates,
51    infusions: &[Infusion],
52    ti: f64,
53    tf: f64,
54) -> V {
55    if ti == tf {
56        return x;
57    }
58
59    let mut sde = em::EM::new(
60        *drift,
61        *difussion,
62        DVector::from_column_slice(support_point),
63        x,
64        cov.clone(),
65        infusions.to_vec(),
66        1e-2,
67        1e-2,
68    );
69    let (_time, solution) = sde.solve(ti, tf);
70    solution.last().unwrap().clone()
71}
72
73/// Stochastic Differential Equation solver for pharmacometric models.
74///
75/// This struct represents a stochastic differential equation system and provides
76/// methods to simulate particles and estimate likelihood for PKPD modeling.
77///
78/// SDE models introduce stochasticity into the system dynamics, allowing for more
79/// realistic modeling of biological variability and uncertainty.
80#[derive(Clone, Debug)]
81pub struct SDE {
82    drift: Drift,
83    diffusion: Diffusion,
84    lag: Lag,
85    fa: Fa,
86    init: Init,
87    out: Out,
88    neqs: Neqs,
89    nparticles: usize,
90}
91
92impl SDE {
93    /// Creates a new stochastic differential equation solver.
94    ///
95    /// # Arguments
96    ///
97    /// * `drift` - Function defining the deterministic component of the SDE
98    /// * `diffusion` - Function defining the stochastic component of the SDE
99    /// * `lag` - Function to compute absorption lag times
100    /// * `fa` - Function to compute bioavailability fractions
101    /// * `init` - Function to initialize the system state
102    /// * `out` - Function to compute output equations
103    /// * `neqs` - Tuple containing the number of state and output equations
104    /// * `nparticles` - Number of particles to use in the simulation
105    ///
106    /// # Returns
107    ///
108    /// A new SDE solver instance configured with the given components.
109    #[allow(clippy::too_many_arguments)]
110    pub fn new(
111        drift: Drift,
112        diffusion: Diffusion,
113        lag: Lag,
114        fa: Fa,
115        init: Init,
116        out: Out,
117        neqs: Neqs,
118        nparticles: usize,
119    ) -> Self {
120        Self {
121            drift,
122            diffusion,
123            lag,
124            fa,
125            init,
126            out,
127            neqs,
128            nparticles,
129        }
130    }
131}
132
133/// State trait implementation for particle-based SDE simulation.
134///
135/// This implementation allows adding bolus doses to all particles in the system.
136impl State for Vec<DVector<f64>> {
137    /// Adds a bolus dose to a specific input compartment across all particles.
138    ///
139    /// # Arguments
140    ///
141    /// * `input` - Index of the input compartment
142    /// * `amount` - Amount to add to the compartment
143    fn add_bolus(&mut self, input: usize, amount: f64) {
144        self.par_iter_mut().for_each(|particle| {
145            particle[input] += amount;
146        });
147    }
148}
149
150/// Predictions implementation for particle-based SDE simulation outputs.
151///
152/// This implementation manages and processes predictions from multiple particles.
153impl Predictions for Array2<Prediction> {
154    fn new(nparticles: usize) -> Self {
155        Array2::from_shape_fn((nparticles, 0), |_| Prediction::default())
156    }
157    fn squared_error(&self) -> f64 {
158        unimplemented!();
159    }
160    fn get_predictions(&self) -> Vec<Prediction> {
161        //TODO: This is only returning the first particle, not the best, not the worst, THE FIRST
162        // CHANGE THIS
163        let row = self.row(0).to_vec();
164        row
165    }
166}
167
168impl EquationTypes for SDE {
169    type S = Vec<DVector<f64>>; // Vec -> particles, DVector -> state
170    type P = Array2<Prediction>; // Rows -> particles, Columns -> time
171}
172
173impl EquationPriv for SDE {
174    // #[inline(always)]
175    // fn get_init(&self) -> &Init {
176    //     &self.init
177    // }
178
179    // #[inline(always)]
180    // fn get_out(&self) -> &Out {
181    //     &self.out
182    // }
183
184    #[inline(always)]
185    fn get_lag(&self, spp: &[f64]) -> Option<HashMap<usize, f64>> {
186        Some((self.lag)(&V::from_vec(spp.to_owned())))
187    }
188
189    #[inline(always)]
190    fn get_fa(&self, spp: &[f64]) -> Option<HashMap<usize, f64>> {
191        Some((self.fa)(&V::from_vec(spp.to_owned())))
192    }
193
194    #[inline(always)]
195    fn get_nstates(&self) -> usize {
196        self.neqs.0
197    }
198
199    #[inline(always)]
200    fn get_nouteqs(&self) -> usize {
201        self.neqs.1
202    }
203    #[inline(always)]
204    fn solve(
205        &self,
206        state: &mut Self::S,
207        support_point: &Vec<f64>,
208        covariates: &Covariates,
209        infusions: &Vec<Infusion>,
210        ti: f64,
211        tf: f64,
212    ) -> Result<(), PharmsolError> {
213        state.par_iter_mut().for_each(|particle| {
214            *particle = simulate_sde_event(
215                &self.drift,
216                &self.diffusion,
217                particle.clone(),
218                support_point,
219                covariates,
220                infusions,
221                ti,
222                tf,
223            );
224        });
225        Ok(())
226    }
227    fn nparticles(&self) -> usize {
228        self.nparticles
229    }
230
231    fn is_sde(&self) -> bool {
232        true
233    }
234    #[inline(always)]
235    fn process_observation(
236        &self,
237        support_point: &Vec<f64>,
238        observation: &crate::Observation,
239        error_model: Option<&ErrorModel>,
240        _time: f64,
241        covariates: &Covariates,
242        x: &mut Self::S,
243        likelihood: &mut Vec<f64>,
244        output: &mut Self::P,
245    ) -> Result<(), PharmsolError> {
246        let mut pred = vec![Prediction::default(); self.nparticles];
247        pred.par_iter_mut().enumerate().for_each(|(i, p)| {
248            let mut y = V::zeros(self.get_nouteqs());
249            (self.out)(
250                &x[i],
251                &V::from_vec(support_point.clone()),
252                observation.time(),
253                covariates,
254                &mut y,
255            );
256            *p = observation.to_prediction(y[observation.outeq()], x[i].as_slice().to_vec());
257        });
258        let out = Array2::from_shape_vec((self.nparticles, 1), pred.clone())?;
259        *output = concatenate(Axis(1), &[output.view(), out.view()]).unwrap();
260        //e = y[t] .- x[:,1]
261        // q = pdf.(Distributions.Normal(0, 0.5), e)
262        if let Some(em) = error_model {
263            let mut q: Vec<f64> = Vec::with_capacity(self.nparticles);
264
265            pred.iter().for_each(|p| {
266                let lik = p.likelihood(em);
267                match lik {
268                    Ok(l) => q.push(l),
269                    Err(e) => panic!("Error in likelihood calculation: {:?}", e),
270                }
271            });
272            let sum_q: f64 = q.iter().sum();
273            let w: Vec<f64> = q.iter().map(|qi| qi / sum_q).collect();
274            let i = sysresample(&w);
275            let a: Vec<DVector<f64>> = i.iter().map(|&i| x[i].clone()).collect();
276            *x = a;
277            likelihood.push(sum_q / self.nparticles as f64);
278            // let qq: Vec<f64> = i.iter().map(|&i| q[i]).collect();
279            // likelihood.push(qq.iter().sum::<f64>() / self.nparticles as f64);
280        }
281        Ok(())
282    }
283    #[inline(always)]
284    fn initial_state(
285        &self,
286        support_point: &Vec<f64>,
287        covariates: &Covariates,
288        occasion_index: usize,
289    ) -> Self::S {
290        let mut x = Vec::with_capacity(self.nparticles);
291        for _ in 0..self.nparticles {
292            let mut state = DVector::zeros(self.get_nstates());
293            if occasion_index == 0 {
294                (self.init)(
295                    &V::from_vec(support_point.to_vec()),
296                    0.0,
297                    covariates,
298                    &mut state,
299                );
300            }
301            x.push(state);
302        }
303        x
304    }
305}
306
307impl Equation for SDE {
308    /// Estimates the likelihood of observed data given a model and parameters.
309    ///
310    /// # Arguments
311    ///
312    /// * `subject` - Subject data containing observations
313    /// * `support_point` - Parameter vector for the model
314    /// * `error_model` - Error model to use for likelihood calculations
315    /// * `cache` - Whether to cache likelihood results for reuse
316    ///
317    /// # Returns
318    ///
319    /// The log-likelihood of the observed data given the model and parameters.
320    fn estimate_likelihood(
321        &self,
322        subject: &Subject,
323        support_point: &Vec<f64>,
324        error_model: &ErrorModel,
325        cache: bool,
326    ) -> Result<f64, PharmsolError> {
327        if cache {
328            _estimate_likelihood(self, subject, support_point, error_model)
329        } else {
330            _estimate_likelihood_no_cache(self, subject, support_point, error_model)
331        }
332    }
333}
334
/// Computes a hash value for a parameter vector.
///
/// The vector's length and the exact bit pattern of every element are fed, in order,
/// into a `DefaultHasher`. Permuted vectors therefore hash differently, and the
/// computation cannot overflow. (The previous sum of `to_bits()` values overflowed
/// `u64` for ordinary inputs and ignored element order.) Instances created with
/// `DefaultHasher::new()` are identical within a program run, so the hash is
/// deterministic for the lifetime of an in-memory cache. It is not guaranteed to be
/// stable across Rust releases.
///
/// # Arguments
///
/// * `spp` - Parameter vector
///
/// # Returns
///
/// A u64 hash value representing the parameter vector.
fn spphash(spp: &[f64]) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::Hasher;

    let mut hasher = DefaultHasher::new();
    // Include the length so e.g. `[0.0]` and `[0.0, 0.0]` are distinguished.
    hasher.write_usize(spp.len());
    for value in spp {
        hasher.write_u64(value.to_bits());
    }
    hasher.finish()
}
347
348#[inline(always)]
349#[cached(
350    ty = "UnboundCache<String, f64>",
351    create = "{ UnboundCache::with_capacity(100_000) }",
352    convert = r#"{ format!("{}{}{}", subject.id(), spphash(support_point), error_model.scalar()) }"#,
353    result = "true"
354)]
355fn _estimate_likelihood(
356    sde: &SDE,
357    subject: &Subject,
358    support_point: &Vec<f64>,
359    error_model: &ErrorModel,
360) -> Result<f64, PharmsolError> {
361    let ypred = sde.simulate_subject(subject, support_point, Some(error_model))?;
362    Ok(ypred.1.unwrap())
363}
364
365/// Performs systematic resampling of particles based on weights.
366///
367/// # Arguments
368///
369/// * `q` - Vector of particle weights
370///
371/// # Returns
372///
373/// Vector of indices to use for resampling.
374fn sysresample(q: &[f64]) -> Vec<usize> {
375    let mut qc = vec![0.0; q.len()];
376    qc[0] = q[0];
377    for i in 1..q.len() {
378        qc[i] = qc[i - 1] + q[i];
379    }
380    let m = q.len();
381    let mut rng = rng();
382    let u: Vec<f64> = (0..m)
383        .map(|i| (i as f64 + rng.random::<f64>()) / m as f64)
384        .collect();
385    let mut i = vec![0; m];
386    let mut k = 0;
387    for j in 0..m {
388        while qc[k] < u[j] {
389            k += 1;
390        }
391        i[j] = k;
392    }
393    i
394}