pharmsol/simulator/equation/analytical/mod.rs

1pub mod one_compartment_models;
2pub mod three_compartment_models;
3pub mod two_compartment_models;
4
5use diffsol::{NalgebraContext, Vector, VectorHost};
6pub use one_compartment_models::*;
7pub use three_compartment_models::*;
8pub use two_compartment_models::*;
9
10use crate::PharmsolError;
11use crate::{
12    data::Covariates, simulator::*, Equation, EquationPriv, EquationTypes, Observation, Subject,
13};
14use cached::proc_macro::cached;
15use cached::UnboundCache;
16
17/// Model equation using analytical solutions.
18///
19/// This implementation uses closed-form analytical solutions for the model
20/// equations rather than numerical integration.
21#[repr(C)]
22#[derive(Clone, Debug)]
23pub struct Analytical {
24    eq: AnalyticalEq,
25    seq_eq: SecEq,
26    lag: Lag,
27    fa: Fa,
28    init: Init,
29    out: Out,
30    neqs: Neqs,
31}
32
33impl Analytical {
34    /// Create a new Analytical equation model.
35    ///
36    /// # Parameters
37    /// - `eq`: The analytical equation function
38    /// - `seq_eq`: The secondary equation function
39    /// - `lag`: The lag time function
40    /// - `fa`: The fraction absorbed function
41    /// - `init`: The initial state function
42    /// - `out`: The output equation function
43    /// - `neqs`: The number of states and output equations
44    pub fn new(
45        eq: AnalyticalEq,
46        seq_eq: SecEq,
47        lag: Lag,
48        fa: Fa,
49        init: Init,
50        out: Out,
51        neqs: Neqs,
52    ) -> Self {
53        Self {
54            eq,
55            seq_eq,
56            lag,
57            fa,
58            init,
59            out,
60            neqs,
61        }
62    }
63}
64
65impl EquationTypes for Analytical {
66    type S = V;
67    type P = SubjectPredictions;
68}
69
70impl EquationPriv for Analytical {
71    // #[inline(always)]
72    // fn get_init(&self) -> &Init {
73    //     &self.init
74    // }
75
76    // #[inline(always)]
77    // fn get_out(&self) -> &Out {
78    //     &self.out
79    // }
80
81    // #[inline(always)]
82    // fn get_lag(&self, spp: &[f64]) -> Option<HashMap<usize, f64>> {
83    //     Some((self.lag)(&V::from_vec(spp.to_owned())))
84    // }
85
86    // #[inline(always)]
87    // fn get_fa(&self, spp: &[f64]) -> Option<HashMap<usize, f64>> {
88    //     Some((self.fa)(&V::from_vec(spp.to_owned())))
89    // }
90
91    #[inline(always)]
92    fn lag(&self) -> &Lag {
93        &self.lag
94    }
95
96    #[inline(always)]
97    fn fa(&self) -> &Fa {
98        &self.fa
99    }
100
101    #[inline(always)]
102    fn get_nstates(&self) -> usize {
103        self.neqs.0
104    }
105
106    #[inline(always)]
107    fn get_nouteqs(&self) -> usize {
108        self.neqs.1
109    }
110    #[inline(always)]
111    fn solve(
112        &self,
113        x: &mut Self::S,
114        support_point: &Vec<f64>,
115        covariates: &Covariates,
116        infusions: &Vec<Infusion>,
117        ti: f64,
118        tf: f64,
119    ) -> Result<(), PharmsolError> {
120        if ti == tf {
121            return Ok(());
122        }
123
124        // 1) Build and sort event times
125        let mut ts = Vec::new();
126        ts.push(ti);
127        ts.push(tf);
128        for inf in infusions {
129            let t0 = inf.time();
130            let t1 = t0 + inf.duration();
131            if t0 > ti && t0 < tf {
132                ts.push(t0)
133            }
134            if t1 > ti && t1 < tf {
135                ts.push(t1)
136            }
137        }
138        ts.sort_by(|a, b| a.partial_cmp(b).unwrap());
139        ts.dedup_by(|a, b| (*a - *b).abs() < 1e-12);
140
141        // 2) March over each sub-interval
142        let mut current_t = ts[0];
143        for &next_t in &ts[1..] {
144            // prepare support and infusion rate for [current_t .. next_t]
145            let mut sp = V::from_vec(support_point.to_owned(), NalgebraContext);
146            let mut rateiv = V::from_vec(vec![0.0; 3], NalgebraContext);
147            for inf in infusions {
148                let s = inf.time();
149                let e = s + inf.duration();
150                if current_t >= s && next_t <= e {
151                    rateiv[inf.input()] += inf.amount() / inf.duration();
152                }
153            }
154
155            // advance the support-point to next_t
156            (self.seq_eq)(&mut sp, next_t, covariates);
157
158            // advance state by dt
159            let dt = next_t - current_t;
160            *x = (self.eq)(x, &sp, dt, rateiv, covariates);
161
162            current_t = next_t;
163        }
164
165        Ok(())
166    }
167
168    #[inline(always)]
169    fn process_observation(
170        &self,
171        support_point: &Vec<f64>,
172        observation: &Observation,
173        error_models: Option<&ErrorModels>,
174        _time: f64,
175        covariates: &Covariates,
176        x: &mut Self::S,
177        likelihood: &mut Vec<f64>,
178        output: &mut Self::P,
179    ) -> Result<(), PharmsolError> {
180        let mut y = V::zeros(self.get_nouteqs(), NalgebraContext);
181        let out = &self.out;
182        (out)(
183            x,
184            &V::from_vec(support_point.clone(), NalgebraContext),
185            observation.time(),
186            covariates,
187            &mut y,
188        );
189        let pred = y[observation.outeq()];
190        let pred = observation.to_prediction(pred, x.as_slice().to_vec());
191        if let Some(error_models) = error_models {
192            likelihood.push(pred.likelihood(error_models)?);
193        }
194        output.add_prediction(pred);
195        Ok(())
196    }
197    #[inline(always)]
198    fn initial_state(&self, spp: &Vec<f64>, covariates: &Covariates, occasion_index: usize) -> V {
199        let init = &self.init;
200        let mut x = V::zeros(self.get_nstates(), NalgebraContext);
201        if occasion_index == 0 {
202            (init)(
203                &V::from_vec(spp.to_vec(), NalgebraContext),
204                0.0,
205                covariates,
206                &mut x,
207            );
208        }
209        x
210    }
211}
212
213impl Equation for Analytical {
214    fn estimate_likelihood(
215        &self,
216        subject: &Subject,
217        support_point: &Vec<f64>,
218        error_models: &ErrorModels,
219        cache: bool,
220    ) -> Result<f64, PharmsolError> {
221        _estimate_likelihood(self, subject, support_point, error_models, cache)
222    }
223    fn kind() -> crate::EqnKind {
224        crate::EqnKind::Analytical
225    }
226}
/// Hash a support point into a `u64` for use in the prediction-cache key.
///
/// Each value is hashed by its exact bit pattern (`f64::to_bits`), in order,
/// together with the slice length. As a result, permuted or
/// different-length support points hash differently. A plain sum of the bit
/// patterns was order-insensitive and overflowed (panicking in debug builds)
/// with just a few values.
/// `DefaultHasher::new()` uses fixed keys, so the result is deterministic
/// within a process, which is what an in-memory cache key needs.
fn spphash(spp: &[f64]) -> u64 {
    use std::hash::{Hash, Hasher};

    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    spp.len().hash(&mut hasher);
    for value in spp {
        value.to_bits().hash(&mut hasher);
    }
    hasher.finish()
}
230
231#[inline(always)]
232#[cached(
233    ty = "UnboundCache<String, SubjectPredictions>",
234    create = "{ UnboundCache::with_capacity(100_000) }",
235    convert = r#"{ format!("{}{}", subject.id(), spphash(support_point)) }"#,
236    result = "true"
237)]
238fn _subject_predictions(
239    ode: &Analytical,
240    subject: &Subject,
241    support_point: &Vec<f64>,
242) -> Result<SubjectPredictions, PharmsolError> {
243    Ok(ode.simulate_subject(subject, support_point, None)?.0)
244}
245
246fn _estimate_likelihood(
247    ode: &Analytical,
248    subject: &Subject,
249    support_point: &Vec<f64>,
250    error_models: &ErrorModels,
251    cache: bool,
252) -> Result<f64, PharmsolError> {
253    let ypred = if cache {
254        _subject_predictions(ode, subject, support_point)
255    } else {
256        _subject_predictions_no_cache(ode, subject, support_point)
257    }?;
258    ypred.likelihood(error_models)
259}