// pharmsol/simulator/equation/sde/mod.rs

1mod em;
2
3use diffsol::{NalgebraContext, Vector};
4use nalgebra::DVector;
5use ndarray::{concatenate, Array2, Axis};
6use rand::{rng, RngExt};
7use rayon::prelude::*;
8
9use crate::{
10    data::{Covariates, Infusion},
11    error_model::AssayErrorModels,
12    prelude::simulator::Prediction,
13    simulator::{Diffusion, Drift, Fa, Init, Lag, Neqs, Out, V},
14    Subject,
15};
16
17use super::id_hash;
18use super::spphash;
19use crate::simulator::cache::{cache_enabled, sde_cache_lock_read};
20
21use diffsol::VectorCommon;
22
23use crate::PharmsolError;
24
25use super::{Equation, EquationPriv, EquationTypes, Predictions, State};
26
27/// Simulate a stochastic differential equation (SDE) event.
28///
29/// This function advances the SDE system from time `ti` to `tf` using
30/// the Euler-Maruyama method implemented in the `em` module.
31///
32/// # Arguments
33///
34/// * `drift` - Function defining the deterministic component of the SDE
35/// * `difussion` - Function defining the stochastic component of the SDE
36/// * `x` - Current state vector
37/// * `support_point` - Parameter vector for the model
38/// * `cov` - Covariates that may influence the system dynamics
39/// * `infusions` - Infusion events to be applied during simulation
40/// * `ti` - Starting time
41/// * `tf` - Ending time
42///
43/// # Returns
44///
45/// The state vector at time `tf` after simulation.
46#[inline(always)]
47pub(crate) fn simulate_sde_event(
48    drift: &Drift,
49    difussion: &Diffusion,
50    x: V,
51    support_point: &[f64],
52    cov: &Covariates,
53    infusions: &[Infusion],
54    ndrugs: usize,
55    ti: f64,
56    tf: f64,
57) -> V {
58    if ti == tf {
59        return x;
60    }
61
62    let mut sde = em::EM::new(
63        *drift,
64        *difussion,
65        DVector::from_column_slice(support_point),
66        x.inner().clone(),
67        cov.clone(),
68        infusions.to_vec(),
69        1e-2,
70        1e-2,
71        ndrugs,
72    );
73    let (_time, solution) = sde.solve(ti, tf);
74    solution.last().unwrap().clone().into()
75}
76
77/// Stochastic Differential Equation solver for pharmacometric models.
78///
79/// This struct represents a stochastic differential equation system and provides
80/// methods to simulate particles and estimate likelihood for PKPD modeling.
81///
82/// SDE models introduce stochasticity into the system dynamics, allowing for more
83/// realistic modeling of biological variability and uncertainty.
84#[derive(Clone, Debug)]
85pub struct SDE {
86    drift: Drift,
87    diffusion: Diffusion,
88    lag: Lag,
89    fa: Fa,
90    init: Init,
91    out: Out,
92    neqs: Neqs,
93    nparticles: usize,
94}
95
96impl SDE {
97    /// Creates a new stochastic differential equation solver with default Neqs.
98    ///
99    /// Use builder methods to configure dimensions:
100    /// ```ignore
101    /// SDE::new(drift, diffusion, lag, fa, init, out, nparticles)
102    ///     .with_nstates(2)
103    ///     .with_ndrugs(1)
104    ///     .with_nout(1)
105    /// ```
106    pub fn new(
107        drift: Drift,
108        diffusion: Diffusion,
109        lag: Lag,
110        fa: Fa,
111        init: Init,
112        out: Out,
113        nparticles: usize,
114    ) -> Self {
115        Self {
116            drift,
117            diffusion,
118            lag,
119            fa,
120            init,
121            out,
122            neqs: Neqs::default(),
123            nparticles,
124        }
125    }
126
127    /// Set the number of state variables.
128    pub fn with_nstates(mut self, nstates: usize) -> Self {
129        self.neqs.nstates = nstates;
130        self
131    }
132
133    /// Set the number of drug input channels (size of bolus[] and rateiv[]).
134    pub fn with_ndrugs(mut self, ndrugs: usize) -> Self {
135        self.neqs.ndrugs = ndrugs;
136        self
137    }
138
139    /// Set the number of output equations.
140    pub fn with_nout(mut self, nout: usize) -> Self {
141        self.neqs.nout = nout;
142        self
143    }
144}
145
146/// State trait implementation for particle-based SDE simulation.
147///
148/// This implementation allows adding bolus doses to all particles in the system.
149impl State for Vec<DVector<f64>> {
150    /// Adds a bolus dose to a specific input compartment across all particles.
151    ///
152    /// # Arguments
153    ///
154    /// * `input` - Index of the input compartment
155    /// * `amount` - Amount to add to the compartment
156    fn add_bolus(&mut self, input: usize, amount: f64) {
157        self.par_iter_mut().for_each(|particle| {
158            particle[input] += amount;
159        });
160    }
161}
162
163/// Predictions implementation for particle-based SDE simulation outputs.
164///
165/// This implementation manages and processes predictions from multiple particles.
166impl Predictions for Array2<Prediction> {
167    fn new(nparticles: usize) -> Self {
168        Array2::from_shape_fn((nparticles, 0), |_| Prediction::default())
169    }
170    fn squared_error(&self) -> f64 {
171        unimplemented!();
172    }
173    fn get_predictions(&self) -> Vec<Prediction> {
174        // Make this return the mean prediction across all particles
175        if self.is_empty() || self.ncols() == 0 {
176            return Vec::new();
177        }
178
179        let mut result = Vec::with_capacity(self.ncols());
180
181        for col in 0..self.ncols() {
182            let column = self.column(col);
183
184            let mean_prediction: f64 = column
185                .iter()
186                .map(|pred: &Prediction| pred.prediction())
187                .sum::<f64>()
188                / self.nrows() as f64;
189
190            let mut prediction = column.first().unwrap().clone();
191            prediction.set_prediction(mean_prediction);
192            result.push(prediction);
193        }
194
195        result
196    }
197    fn log_likelihood(&self, error_models: &AssayErrorModels) -> Result<f64, crate::PharmsolError> {
198        // For SDE, compute log-likelihood using mean predictions across particles
199        let predictions = self.get_predictions();
200        if predictions.is_empty() {
201            return Ok(0.0);
202        }
203
204        let log_liks: Result<Vec<f64>, _> = predictions
205            .iter()
206            .filter(|p| p.observation().is_some())
207            .map(|p| p.log_likelihood(error_models))
208            .collect();
209
210        log_liks.map(|lls| lls.iter().sum())
211    }
212}
213
214impl EquationTypes for SDE {
215    type S = Vec<DVector<f64>>; // Vec -> particles, DVector -> state
216    type P = Array2<Prediction>; // Rows -> particles, Columns -> time
217}
218
219impl EquationPriv for SDE {
220    // #[inline(always)]
221    // fn get_init(&self) -> &Init {
222    //     &self.init
223    // }
224
225    // #[inline(always)]
226    // fn get_out(&self) -> &Out {
227    //     &self.out
228    // }
229
230    // #[inline(always)]
231    // fn get_lag(&self, spp: &[f64]) -> Option<HashMap<usize, f64>> {
232    //     Some((self.lag)(&V::from_vec(spp.to_owned())))
233    // }
234
235    // #[inline(always)]
236    // fn get_fa(&self, spp: &[f64]) -> Option<HashMap<usize, f64>> {
237    //     Some((self.fa)(&V::from_vec(spp.to_owned())))
238    // }
239
240    #[inline(always)]
241    fn lag(&self) -> &Lag {
242        &self.lag
243    }
244
245    #[inline(always)]
246    fn fa(&self) -> &Fa {
247        &self.fa
248    }
249
250    #[inline(always)]
251    fn get_nstates(&self) -> usize {
252        self.neqs.nstates
253    }
254
255    #[inline(always)]
256    fn get_ndrugs(&self) -> usize {
257        self.neqs.ndrugs
258    }
259
260    #[inline(always)]
261    fn get_nouteqs(&self) -> usize {
262        self.neqs.nout
263    }
264    #[inline(always)]
265    fn solve(
266        &self,
267        state: &mut Self::S,
268        support_point: &Vec<f64>,
269        covariates: &Covariates,
270        infusions: &Vec<Infusion>,
271        ti: f64,
272        tf: f64,
273    ) -> Result<(), PharmsolError> {
274        let ndrugs = self.get_ndrugs();
275        state.par_iter_mut().for_each(|particle| {
276            *particle = simulate_sde_event(
277                &self.drift,
278                &self.diffusion,
279                particle.clone().into(),
280                support_point,
281                covariates,
282                infusions,
283                ndrugs,
284                ti,
285                tf,
286            )
287            .inner()
288            .clone();
289        });
290        Ok(())
291    }
292    fn nparticles(&self) -> usize {
293        self.nparticles
294    }
295
296    fn is_sde(&self) -> bool {
297        true
298    }
299    #[inline(always)]
300    fn process_observation(
301        &self,
302        support_point: &Vec<f64>,
303        observation: &crate::Observation,
304        error_models: Option<&AssayErrorModels>,
305        _time: f64,
306        covariates: &Covariates,
307        x: &mut Self::S,
308        likelihood: &mut Vec<f64>,
309        output: &mut Self::P,
310    ) -> Result<(), PharmsolError> {
311        let mut pred = vec![Prediction::default(); self.nparticles];
312
313        pred.par_iter_mut().enumerate().for_each(|(i, p)| {
314            let mut y = V::zeros(self.get_nouteqs(), NalgebraContext);
315            (self.out)(
316                &x[i].clone().into(),
317                &V::from_vec(support_point.clone(), NalgebraContext),
318                observation.time(),
319                covariates,
320                &mut y,
321            );
322            *p = observation.to_prediction(y[observation.outeq()], x[i].as_slice().to_vec());
323        });
324        let out = Array2::from_shape_vec((self.nparticles, 1), pred.clone())?;
325        *output = concatenate(Axis(1), &[output.view(), out.view()]).unwrap();
326        //e = y[t] .- x[:,1]
327        // q = pdf.(Distributions.Normal(0, 0.5), e)
328        if let Some(em) = error_models {
329            let mut q: Vec<f64> = Vec::with_capacity(self.nparticles);
330
331            pred.iter().for_each(|p| {
332                let lik = p.log_likelihood(em).map(f64::exp);
333                match lik {
334                    Ok(l) => q.push(l),
335                    Err(e) => panic!("Error in likelihood calculation: {:?}", e),
336                }
337            });
338            let sum_q: f64 = q.iter().sum();
339            let w: Vec<f64> = q.iter().map(|qi| qi / sum_q).collect();
340            let i = sysresample(&w);
341            let a: Vec<DVector<f64>> = i.iter().map(|&i| x[i].clone()).collect();
342            *x = a;
343            likelihood.push(sum_q / self.nparticles as f64);
344            // let qq: Vec<f64> = i.iter().map(|&i| q[i]).collect();
345            // likelihood.push(qq.iter().sum::<f64>() / self.nparticles as f64);
346        }
347        Ok(())
348    }
349    #[inline(always)]
350    fn initial_state(
351        &self,
352        support_point: &Vec<f64>,
353        covariates: &Covariates,
354        occasion_index: usize,
355    ) -> Self::S {
356        let mut x = Vec::with_capacity(self.nparticles);
357        for _ in 0..self.nparticles {
358            let mut state: V = DVector::zeros(self.get_nstates()).into();
359            if occasion_index == 0 {
360                (self.init)(
361                    &V::from_vec(support_point.to_vec(), NalgebraContext),
362                    0.0,
363                    covariates,
364                    &mut state,
365                );
366            }
367            x.push(state.inner().clone());
368        }
369        x
370    }
371}
372
373impl Equation for SDE {
374    /// Estimates the likelihood of observed data given a model and parameters.
375    ///
376    /// # Arguments
377    ///
378    /// * `subject` - Subject data containing observations
379    /// * `support_point` - Parameter vector for the model
380    /// * `error_model` - Error model to use for likelihood calculations
381    ///
382    /// # Returns
383    ///
384    /// The log-likelihood of the observed data given the model and parameters.
385    fn estimate_likelihood(
386        &self,
387        subject: &Subject,
388        support_point: &Vec<f64>,
389        error_models: &AssayErrorModels,
390    ) -> Result<f64, PharmsolError> {
391        _estimate_likelihood(self, subject, support_point, error_models)
392    }
393
394    fn estimate_log_likelihood(
395        &self,
396        subject: &Subject,
397        support_point: &Vec<f64>,
398        error_models: &AssayErrorModels,
399    ) -> Result<f64, PharmsolError> {
400        // For SDE, the particle filter computes likelihood in regular space.
401        // We compute it directly and then take the log.
402        let lik = _estimate_likelihood(self, subject, support_point, error_models)?;
403
404        if lik > 0.0 {
405            Ok(lik.ln())
406        } else {
407            Ok(f64::NEG_INFINITY)
408        }
409    }
410
411    fn kind() -> crate::EqnKind {
412        crate::EqnKind::SDE
413    }
414}
415
416#[inline(always)]
417fn _estimate_likelihood(
418    sde: &SDE,
419    subject: &Subject,
420    support_point: &Vec<f64>,
421    error_models: &AssayErrorModels,
422) -> Result<f64, PharmsolError> {
423    if cache_enabled() {
424        let key = (
425            id_hash(subject.id()),
426            spphash(support_point),
427            error_models.hash(),
428        );
429        let cache_guard = sde_cache_lock_read()?;
430        if let Some(cached) = cache_guard.get(&key) {
431            return Ok(cached);
432        }
433        drop(cache_guard);
434
435        let ypred = sde.simulate_subject(subject, support_point, Some(error_models))?;
436        let result = ypred.1.unwrap();
437        let cache_guard = sde_cache_lock_read()?;
438        cache_guard.insert(key, result);
439        Ok(result)
440    } else {
441        let ypred = sde.simulate_subject(subject, support_point, Some(error_models))?;
442        Ok(ypred.1.unwrap())
443    }
444}
445
446/// Performs systematic resampling of particles based on weights.
447///
448/// # Arguments
449///
450/// * `q` - Vector of particle weights
451///
452/// # Returns
453///
454/// Vector of indices to use for resampling.
455fn sysresample(q: &[f64]) -> Vec<usize> {
456    let mut qc = vec![0.0; q.len()];
457    qc[0] = q[0];
458    for i in 1..q.len() {
459        qc[i] = qc[i - 1] + q[i];
460    }
461    let m = q.len();
462    let mut rng = rng();
463    let u: Vec<f64> = (0..m)
464        .map(|i| (i as f64 + rng.random::<f64>()) / m as f64)
465        .collect();
466    let mut i = vec![0; m];
467    let mut k = 0;
468    for j in 0..m {
469        while qc[k] < u[j] {
470            k += 1;
471        }
472        i[j] = k;
473    }
474    i
475}