Skip to main content

pharmsol/simulator/equation/analytical/mod.rs

1pub mod one_compartment_cl_models;
2pub mod one_compartment_models;
3pub mod three_compartment_cl_models;
4pub mod three_compartment_models;
5pub mod two_compartment_cl_models;
6pub mod two_compartment_models;
7
8use diffsol::{NalgebraContext, Vector, VectorHost};
9pub use one_compartment_cl_models::*;
10pub use one_compartment_models::*;
11pub use three_compartment_cl_models::*;
12pub use three_compartment_models::*;
13pub use two_compartment_cl_models::*;
14pub use two_compartment_models::*;
15
16use super::spphash;
17
18use crate::data::error_model::AssayErrorModels;
19use crate::simulator::cache::{PredictionCache, DEFAULT_CACHE_SIZE};
20use crate::PharmsolError;
21use crate::{
22    data::Covariates, simulator::*, Equation, EquationPriv, EquationTypes, Observation, Subject,
23};
24
/// Model equation using analytical solutions.
///
/// This implementation uses closed-form analytical solutions for the model
/// equations rather than numerical integration.
#[derive(Clone, Debug)]
pub struct Analytical {
    /// Closed-form propagator, called as `eq(state, support_point, dt, rateiv, covariates)`
    /// and returning the state `dt` time units later.
    eq: AnalyticalEq,
    /// Secondary equations; called as `seq_eq(&mut support_point, t, covariates)` to
    /// update the support-point vector in place before each sub-interval is propagated.
    seq_eq: SecEq,
    /// Lag function, returned by `EquationPriv::lag` (presumably per-input lag times).
    lag: Lag,
    /// `fa` function, returned by `EquationPriv::fa` (presumably per-input bioavailability).
    fa: Fa,
    /// Initial-condition function; only applied for occasion 0 in `initial_state`.
    init: Init,
    /// Output function; fills the vector of output-equation predictions.
    out: Out,
    /// Model dimensions: number of states, drug inputs and output equations.
    neqs: Neqs,
    /// Optional prediction cache keyed by `(subject hash, support-point hash)`;
    /// `None` means caching is disabled.
    cache: Option<PredictionCache>,
}
40
41#[inline(always)]
42pub(crate) fn compact_public_vector(vector: &V) -> V {
43    V::from_vec(
44        vector.as_slice().get(1..).unwrap_or(&[]).to_vec(),
45        NalgebraContext,
46    )
47}
48
49#[inline(always)]
50pub(crate) fn pad_public_vector(vector: &V) -> V {
51    let mut padded = Vec::with_capacity(vector.len() + 1);
52    padded.push(0.0);
53    padded.extend(vector.as_slice().iter().copied());
54    V::from_vec(padded, NalgebraContext)
55}
56
57#[inline(always)]
58pub(crate) fn wrap_pmetrics_analytical(
59    x: &V,
60    p: &V,
61    t: T,
62    rateiv: &V,
63    cov: &Covariates,
64    native: AnalyticalEq,
65) -> V {
66    let compact_x = compact_public_vector(x);
67    let compact_rateiv = compact_public_vector(rateiv);
68    let compact_output = native(&compact_x, p, t, &compact_rateiv, cov);
69    pad_public_vector(&compact_output)
70}
71
72impl Analytical {
73    /// Create a new Analytical equation model with default Neqs (all sizes = 5).
74    ///
75    /// Use builder methods to configure dimensions:
76    /// ```ignore
77    /// Analytical::new(eq, seq_eq, lag, fa, init, out)
78    ///     .with_nstates(2)
79    ///     .with_ndrugs(1)
80    ///     .with_nout(1)
81    /// ```
82    pub fn new(eq: AnalyticalEq, seq_eq: SecEq, lag: Lag, fa: Fa, init: Init, out: Out) -> Self {
83        Self {
84            eq,
85            seq_eq,
86            lag,
87            fa,
88            init,
89            out,
90            neqs: Neqs::default(),
91            cache: Some(PredictionCache::new(DEFAULT_CACHE_SIZE)),
92        }
93    }
94
95    /// Set the number of state variables.
96    pub fn with_nstates(mut self, nstates: usize) -> Self {
97        self.neqs.nstates = nstates;
98        self
99    }
100
101    /// Set the number of drug input channels (size of bolus[] and rateiv[]).
102    pub fn with_ndrugs(mut self, ndrugs: usize) -> Self {
103        self.neqs.ndrugs = ndrugs;
104        self
105    }
106
107    /// Set the number of output equations.
108    pub fn with_nout(mut self, nout: usize) -> Self {
109        self.neqs.nout = nout;
110        self
111    }
112}
113
114impl super::Cache for Analytical {
115    fn with_cache_capacity(mut self, size: u64) -> Self {
116        self.cache = Some(PredictionCache::new(size));
117        self
118    }
119
120    fn enable_cache(mut self) -> Self {
121        self.cache = Some(PredictionCache::new(DEFAULT_CACHE_SIZE));
122        self
123    }
124
125    fn clear_cache(&self) {
126        if let Some(cache) = &self.cache {
127            cache.invalidate_all();
128        }
129    }
130
131    fn disable_cache(mut self) -> Self {
132        self.cache = None;
133        self
134    }
135}
136
137impl EquationTypes for Analytical {
138    type S = V;
139    type P = SubjectPredictions;
140}
141
142impl EquationPriv for Analytical {
143    // #[inline(always)]
144    // fn get_init(&self) -> &Init {
145    //     &self.init
146    // }
147
148    // #[inline(always)]
149    // fn get_out(&self) -> &Out {
150    //     &self.out
151    // }
152
153    // #[inline(always)]
154    // fn get_lag(&self, spp: &[f64]) -> Option<HashMap<usize, f64>> {
155    //     Some((self.lag)(&V::from_vec(spp.to_owned())))
156    // }
157
158    // #[inline(always)]
159    // fn get_fa(&self, spp: &[f64]) -> Option<HashMap<usize, f64>> {
160    //     Some((self.fa)(&V::from_vec(spp.to_owned())))
161    // }
162
163    #[inline(always)]
164    fn lag(&self) -> &Lag {
165        &self.lag
166    }
167
168    #[inline(always)]
169    fn fa(&self) -> &Fa {
170        &self.fa
171    }
172
173    #[inline(always)]
174    fn get_nstates(&self) -> usize {
175        self.neqs.nstates
176    }
177
178    #[inline(always)]
179    fn get_ndrugs(&self) -> usize {
180        self.neqs.ndrugs
181    }
182
183    #[inline(always)]
184    fn get_nouteqs(&self) -> usize {
185        self.neqs.nout
186    }
187    #[inline(always)]
188    fn solve(
189        &self,
190        x: &mut Self::S,
191        support_point: &[f64],
192        covariates: &Covariates,
193        infusions: &[Infusion],
194        ti: f64,
195        tf: f64,
196    ) -> Result<(), PharmsolError> {
197        if ti == tf {
198            return Ok(());
199        }
200
201        // 1) Build and sort event times
202        let mut ts = Vec::new();
203        ts.push(ti);
204        ts.push(tf);
205        for inf in infusions {
206            let t0 = inf.time();
207            let t1 = t0 + inf.duration();
208            if t0 > ti && t0 < tf {
209                ts.push(t0)
210            }
211            if t1 > ti && t1 < tf {
212                ts.push(t1)
213            }
214        }
215        ts.sort_by(|a, b| a.partial_cmp(b).unwrap());
216        ts.dedup_by(|a, b| (*a - *b).abs() < 1e-12);
217
218        // 2) March over each sub-interval
219        let mut current_t = ts[0];
220        let mut sp = V::from_vec(support_point.to_vec(), NalgebraContext);
221        let mut rateiv = V::zeros(self.get_ndrugs(), NalgebraContext);
222
223        for &next_t in &ts[1..] {
224            // prepare support and infusion rate for [current_t .. next_t]
225            rateiv.fill(0.0);
226            for inf in infusions {
227                let s = inf.time();
228                let e = s + inf.duration();
229                if current_t >= s && next_t <= e {
230                    if inf.input() >= self.get_ndrugs() {
231                        return Err(PharmsolError::InputOutOfRange {
232                            input: inf.input(),
233                            ndrugs: self.get_ndrugs(),
234                        });
235                    }
236                    rateiv[inf.input()] += inf.amount() / inf.duration();
237                }
238            }
239
240            // advance the support-point to next_t
241            (self.seq_eq)(&mut sp, next_t, covariates);
242
243            // advance state by dt
244            let dt = next_t - current_t;
245            *x = (self.eq)(x, &sp, dt, &rateiv, covariates);
246
247            current_t = next_t;
248        }
249
250        Ok(())
251    }
252
253    #[inline(always)]
254    fn process_observation(
255        &self,
256        support_point: &[f64],
257        observation: &Observation,
258        error_models: Option<&AssayErrorModels>,
259        _time: f64,
260        covariates: &Covariates,
261        x: &mut Self::S,
262        likelihood: &mut Vec<f64>,
263        output: &mut Self::P,
264    ) -> Result<(), PharmsolError> {
265        let mut y = V::zeros(self.get_nouteqs(), NalgebraContext);
266        let out = &self.out;
267        (out)(
268            x,
269            &V::from_vec(support_point.to_vec(), NalgebraContext),
270            observation.time(),
271            covariates,
272            &mut y,
273        );
274        let pred = y[observation.outeq()];
275        let pred = observation.to_prediction(pred, x.as_slice().to_vec());
276        if let Some(error_models) = error_models {
277            likelihood.push(pred.log_likelihood(error_models)?.exp());
278        }
279        output.add_prediction(pred);
280        Ok(())
281    }
282    #[inline(always)]
283    fn initial_state(&self, spp: &[f64], covariates: &Covariates, occasion_index: usize) -> V {
284        let init = &self.init;
285        let mut x = V::zeros(self.get_nstates(), NalgebraContext);
286        if occasion_index == 0 {
287            (init)(
288                &V::from_vec(spp.to_vec(), NalgebraContext),
289                0.0,
290                covariates,
291                &mut x,
292            );
293        }
294        x
295    }
296}
297
// `clippy::items_after_test_module` is allowed because the `Equation` impl and the
// private helpers `_subject_predictions` / `_estimate_likelihood` follow this module.
#[allow(clippy::items_after_test_module)]
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::SubjectBuilderExt;
    use approx::assert_relative_eq;
    use diffsol::Vector;
    use std::collections::HashMap;

    /// Canned dosing scenarios for analytical-model tests.
    ///
    /// Builder arguments appear to be `(time, amount, input[, duration])` for doses
    /// and `(time, outeq)` for missing observations. This is inferred from the tests
    /// in this module; confirm against `SubjectBuilderExt`.
    pub(crate) enum SubjectInfo {
        /// Bolus at t = 0 plus an infusion starting at t = 24, both into input 0.
        InfusionDosing,
        /// Bolus into input 1 at t = 0 (presumably the oral/depot input), then an
        /// infusion into input 0 at t = 24 and a second bolus into input 0 at t = 48.
        OralInfusionDosage,
    }
    impl SubjectInfo {
        /// Build the subject for this scenario; every observation is a missing
        /// observation of output equation 0.
        pub(crate) fn get_subject(&self) -> Subject {
            match self {
                SubjectInfo::InfusionDosing => Subject::builder("id1")
                    .bolus(0.0, 100.0, 0)
                    .infusion(24.0, 150.0, 0, 3.0)
                    .missing_observation(0.0, 0)
                    .missing_observation(1.0, 0)
                    .missing_observation(2.0, 0)
                    .missing_observation(4.0, 0)
                    .missing_observation(8.0, 0)
                    .missing_observation(12.0, 0)
                    .missing_observation(24.0, 0)
                    .missing_observation(25.0, 0)
                    .missing_observation(26.0, 0)
                    .missing_observation(27.0, 0)
                    .missing_observation(28.0, 0)
                    .missing_observation(32.0, 0)
                    .missing_observation(36.0, 0)
                    .build(),

                SubjectInfo::OralInfusionDosage => Subject::builder("id1")
                    .bolus(0.0, 100.0, 1)
                    .infusion(24.0, 150.0, 0, 3.0)
                    .bolus(48.0, 100.0, 0)
                    .missing_observation(0.0, 0)
                    .missing_observation(1.0, 0)
                    .missing_observation(2.0, 0)
                    .missing_observation(4.0, 0)
                    .missing_observation(8.0, 0)
                    .missing_observation(12.0, 0)
                    .missing_observation(24.0, 0)
                    .missing_observation(25.0, 0)
                    .missing_observation(26.0, 0)
                    .missing_observation(27.0, 0)
                    .missing_observation(28.0, 0)
                    .missing_observation(32.0, 0)
                    .missing_observation(36.0, 0)
                    .missing_observation(48.0, 0)
                    .missing_observation(49.0, 0)
                    .missing_observation(50.0, 0)
                    .missing_observation(52.0, 0)
                    .missing_observation(56.0, 0)
                    .missing_observation(60.0, 0)
                    .build(),
            }
        }
    }

    /// Checks that the support point updated by the secondary equations carries over
    /// between the sub-intervals of one `solve` call, rather than being reset for
    /// each sub-interval.
    ///
    /// `seq_eq` adds 1 to `p[0]` per sub-interval, and `eq` integrates `p[0]` over
    /// `dt`. This assumes `solve` is called on [0, 0.25] and [0.25, 1.0], with the
    /// infusion end at 0.5 splitting the latter (TODO confirm against the simulator).
    /// Then `p[0]` is 2, followed by 2 and 3, giving 0.25*2 + 0.25*2 + 0.5*3 = 2.5.
    /// A per-sub-interval reset would give 2.0.
    #[test]
    fn secondary_equations_accumulate_within_single_solve() {
        let eq = |x: &V, p: &V, dt: f64, _rateiv: &V, _cov: &Covariates| {
            let mut next = x.clone();
            next[0] += p[0] * dt;
            next
        };
        let seq_eq = |params: &mut V, _t: f64, _cov: &Covariates| {
            params[0] += 1.0;
        };
        let lag = |_p: &V, _t: f64, _cov: &Covariates| HashMap::new();
        let fa = |_p: &V, _t: f64, _cov: &Covariates| HashMap::new();
        let init = |_p: &V, _t: f64, _cov: &Covariates, x: &mut V| {
            x.fill(0.0);
        };
        let out = |x: &V, _p: &V, _t: f64, _cov: &Covariates, y: &mut V| {
            y[0] = x[0];
        };

        let analytical = Analytical::new(eq, seq_eq, lag, fa, init, out)
            .with_nstates(1)
            .with_ndrugs(1)
            .with_nout(1);
        let subject = Subject::builder("seq")
            .bolus(0.0, 0.0, 0)
            .infusion(0.25, 1.0, 0, 0.25)
            .observation(1.0, 0.0, 0)
            .build();

        let predictions = analytical.estimate_predictions(&subject, &[1.0]).unwrap();

        let value = predictions.predictions()[0].prediction();
        assert!((value - 2.5).abs() < 1e-12);
    }

    /// Checks that an infusion into input 3 (with `ndrugs = 4`) lands in `rateiv[3]`.
    ///
    /// 4 units over 1 time unit is a rate of 4. `eq` integrates that rate over
    /// dt = 1, giving exactly 4.0.
    #[test]
    fn infusion_inputs_match_state_dimension() {
        let eq = |x: &V, _p: &V, dt: f64, rateiv: &V, _cov: &Covariates| {
            let mut next = x.clone();
            next[0] += rateiv[3] * dt;
            next
        };
        let seq_eq = |_params: &mut V, _t: f64, _cov: &Covariates| {};
        let lag = |_p: &V, _t: f64, _cov: &Covariates| HashMap::new();
        let fa = |_p: &V, _t: f64, _cov: &Covariates| HashMap::new();
        let init = |_p: &V, _t: f64, _cov: &Covariates, x: &mut V| {
            x.fill(0.0);
        };
        let out = |x: &V, _p: &V, _t: f64, _cov: &Covariates, y: &mut V| {
            y[0] = x[0];
        };

        let analytical = Analytical::new(eq, seq_eq, lag, fa, init, out)
            .with_nstates(4)
            .with_ndrugs(4)
            .with_nout(1);
        let subject = Subject::builder("inf")
            .infusion(0.0, 4.0, 3, 1.0)
            .observation(1.0, 0.0, 0)
            .build();

        let predictions = analytical.estimate_predictions(&subject, &[0.0]).unwrap();

        assert_eq!(predictions.predictions()[0].prediction(), 4.0);
    }

    /// Asserts that `wrapper`, applied to padded inputs, reproduces `native` applied to
    /// the compact inputs: same values shifted up one slot, with `0.0` in slot 0.
    ///
    /// Slot 0 of the padded state and rate vectors holds sentinels (1234.0 and 5678.0).
    /// A wrapper that wrongly used slot 0 would therefore very likely mismatch.
    fn assert_pm_wrapper_matches_native(
        native: AnalyticalEq,
        wrapper: AnalyticalEq,
        compact_x: Vec<f64>,
        params: Vec<f64>,
        compact_rateiv: Vec<f64>,
    ) {
        let covariates = Covariates::new();
        let compact_x = V::from_vec(compact_x, NalgebraContext);
        let params = V::from_vec(params, NalgebraContext);
        let compact_rateiv = V::from_vec(compact_rateiv, NalgebraContext);

        let mut padded_x = vec![1234.0];
        padded_x.extend(compact_x.as_slice().iter().copied());
        let padded_x = V::from_vec(padded_x, NalgebraContext);

        let mut padded_rateiv = vec![5678.0];
        padded_rateiv.extend(compact_rateiv.as_slice().iter().copied());
        let padded_rateiv = V::from_vec(padded_rateiv, NalgebraContext);

        let native_output = native(&compact_x, &params, 1.5, &compact_rateiv, &covariates);
        let wrapped_output = wrapper(&padded_x, &params, 1.5, &padded_rateiv, &covariates);

        assert_eq!(wrapped_output[0], 0.0);
        assert_eq!(wrapped_output.len(), native_output.len() + 1);

        for (wrapped, native) in wrapped_output
            .as_slice()
            .iter()
            .skip(1)
            .zip(native_output.as_slice().iter())
        {
            assert_relative_eq!(*wrapped, *native, max_relative = 1e-10, epsilon = 1e-10);
        }
    }

    /// Checks every `pm_*` wrapper against its native counterpart for the one-, two-
    /// and three-compartment variants (with and without absorption, plus the `_cl` forms).
    #[test]
    fn pmetrics_wrappers_match_native_helpers() {
        assert_pm_wrapper_matches_native(
            one_compartment,
            pm_one_compartment,
            vec![100.0],
            vec![0.2],
            vec![5.0],
        );
        assert_pm_wrapper_matches_native(
            one_compartment_cl,
            pm_one_compartment_cl,
            vec![100.0],
            vec![0.2, 2.0],
            vec![5.0],
        );
        assert_pm_wrapper_matches_native(
            one_compartment_with_absorption,
            pm_one_compartment_with_absorption,
            vec![10.0, 20.0],
            vec![1.1, 0.2],
            vec![5.0],
        );
        assert_pm_wrapper_matches_native(
            one_compartment_cl_with_absorption,
            pm_one_compartment_cl_with_absorption,
            vec![10.0, 20.0],
            vec![1.1, 0.2, 2.0],
            vec![5.0],
        );
        assert_pm_wrapper_matches_native(
            two_compartments,
            pm_two_compartments,
            vec![100.0, 40.0],
            vec![0.1, 0.3, 0.2],
            vec![3.0],
        );
        assert_pm_wrapper_matches_native(
            two_compartments_cl,
            pm_two_compartments_cl,
            vec![100.0, 40.0],
            vec![0.1, 0.3, 1.0, 2.0],
            vec![3.0],
        );
        assert_pm_wrapper_matches_native(
            two_compartments_with_absorption,
            pm_two_compartments_with_absorption,
            vec![10.0, 100.0, 40.0],
            vec![0.1, 1.0, 0.3, 0.2],
            vec![3.0],
        );
        assert_pm_wrapper_matches_native(
            two_compartments_cl_with_absorption,
            pm_two_compartments_cl_with_absorption,
            vec![10.0, 100.0, 40.0],
            vec![1.0, 0.1, 0.3, 1.0, 2.0],
            vec![3.0],
        );
        assert_pm_wrapper_matches_native(
            three_compartments,
            pm_three_compartments,
            vec![100.0, 40.0, 20.0],
            vec![0.1, 3.0, 2.0, 1.0, 0.5],
            vec![2.0],
        );
        assert_pm_wrapper_matches_native(
            three_compartments_cl,
            pm_three_compartments_cl,
            vec![100.0, 40.0, 20.0],
            vec![0.1, 3.0, 2.0, 1.0, 3.0, 4.0],
            vec![2.0],
        );
        assert_pm_wrapper_matches_native(
            three_compartments_with_absorption,
            pm_three_compartments_with_absorption,
            vec![10.0, 100.0, 40.0, 20.0],
            vec![1.0, 0.1, 3.0, 2.0, 1.0, 0.5],
            vec![2.0],
        );
        assert_pm_wrapper_matches_native(
            three_compartments_cl_with_absorption,
            pm_three_compartments_cl_with_absorption,
            vec![10.0, 100.0, 40.0, 20.0],
            vec![1.0, 0.1, 3.0, 2.0, 1.0, 3.0, 4.0],
            vec![2.0],
        );
    }
}
550impl Equation for Analytical {
551    fn estimate_likelihood(
552        &self,
553        subject: &Subject,
554        support_point: &[f64],
555        error_models: &AssayErrorModels,
556    ) -> Result<f64, PharmsolError> {
557        _estimate_likelihood(self, subject, support_point, error_models)
558    }
559
560    fn estimate_log_likelihood(
561        &self,
562        subject: &Subject,
563        support_point: &[f64],
564        error_models: &AssayErrorModels,
565    ) -> Result<f64, PharmsolError> {
566        let ypred = _subject_predictions(self, subject, support_point)?;
567        ypred.log_likelihood(error_models)
568    }
569
570    fn kind() -> crate::EqnKind {
571        crate::EqnKind::Analytical
572    }
573}
574
575#[inline(always)]
576fn _subject_predictions(
577    analytical: &Analytical,
578    subject: &Subject,
579    support_point: &[f64],
580) -> Result<SubjectPredictions, PharmsolError> {
581    if let Some(cache) = &analytical.cache {
582        let key = (subject.hash(), spphash(support_point));
583        if let Some(cached) = cache.get(&key) {
584            return Ok(cached);
585        }
586
587        let result = analytical.simulate_subject(subject, support_point, None)?.0;
588        cache.insert(key, result.clone());
589        Ok(result)
590    } else {
591        Ok(analytical.simulate_subject(subject, support_point, None)?.0)
592    }
593}
594
595fn _estimate_likelihood(
596    ode: &Analytical,
597    subject: &Subject,
598    support_point: &[f64],
599    error_models: &AssayErrorModels,
600) -> Result<f64, PharmsolError> {
601    let ypred = _subject_predictions(ode, subject, support_point)?;
602    Ok(ypred.log_likelihood(error_models)?.exp())
603}