// pharmsol/simulator/equation/analytical/mod.rs

1pub mod one_compartment_cl_models;
2pub mod one_compartment_models;
3pub mod three_compartment_cl_models;
4pub mod three_compartment_models;
5pub mod two_compartment_cl_models;
6pub mod two_compartment_models;
7
8use diffsol::{NalgebraContext, Vector, VectorHost};
9pub use one_compartment_cl_models::*;
10pub use one_compartment_models::*;
11pub use three_compartment_cl_models::*;
12pub use three_compartment_models::*;
13pub use two_compartment_cl_models::*;
14pub use two_compartment_models::*;
15
16use super::id_hash;
17use super::spphash;
18
19use crate::data::error_model::AssayErrorModels;
20use crate::simulator::cache::{ana_cache_lock_read, cache_enabled};
21use crate::PharmsolError;
22use crate::{
23    data::Covariates, simulator::*, Equation, EquationPriv, EquationTypes, Observation, Subject,
24};
25
/// Model equation using analytical solutions.
///
/// This implementation uses closed-form analytical solutions for the model
/// equations rather than numerical integration. The state is advanced piece by
/// piece between infusion boundaries by `EquationPriv::solve`, which calls the
/// `seq_eq` and `eq` callbacks stored here.
// NOTE(review): the reason for `repr(C)` is not visible in this file; it fixes the
// field order/layout, so do not reorder fields without checking who relies on it.
#[repr(C)]
#[derive(Clone, Debug)]
pub struct Analytical {
    /// Closed-form propagator, called as `eq(x, params, dt, rateiv, covariates)`;
    /// its return value replaces the state after a step of length `dt`.
    eq: AnalyticalEq,
    /// Secondary equations, called as `seq_eq(&mut params, t, covariates)` to update
    /// the working parameter vector in place before each propagation step.
    seq_eq: SecEq,
    /// Lag callback, handed out by `EquationPriv::lag` (presumably per-input lag
    /// times; the tests build it as a map — confirm against the simulator).
    lag: Lag,
    /// `fa` callback, handed out by `EquationPriv::fa` (presumably per-input
    /// fraction absorbed / bioavailability — confirm against the simulator).
    fa: Fa,
    /// Initial-condition callback; `initial_state` only applies it on occasion 0.
    init: Init,
    /// Output equations: fill a vector of length `nout` from the state.
    out: Out,
    /// Problem dimensions: number of states, drug inputs and output equations.
    neqs: Neqs,
}
41
42impl Analytical {
43    /// Create a new Analytical equation model with default Neqs (all sizes = 5).
44    ///
45    /// Use builder methods to configure dimensions:
46    /// ```ignore
47    /// Analytical::new(eq, seq_eq, lag, fa, init, out)
48    ///     .with_nstates(2)
49    ///     .with_ndrugs(1)
50    ///     .with_nout(1)
51    /// ```
52    pub fn new(eq: AnalyticalEq, seq_eq: SecEq, lag: Lag, fa: Fa, init: Init, out: Out) -> Self {
53        Self {
54            eq,
55            seq_eq,
56            lag,
57            fa,
58            init,
59            out,
60            neqs: Neqs::default(),
61        }
62    }
63
64    /// Set the number of state variables.
65    pub fn with_nstates(mut self, nstates: usize) -> Self {
66        self.neqs.nstates = nstates;
67        self
68    }
69
70    /// Set the number of drug input channels (size of bolus[] and rateiv[]).
71    pub fn with_ndrugs(mut self, ndrugs: usize) -> Self {
72        self.neqs.ndrugs = ndrugs;
73        self
74    }
75
76    /// Set the number of output equations.
77    pub fn with_nout(mut self, nout: usize) -> Self {
78        self.neqs.nout = nout;
79        self
80    }
81}
82
83impl EquationTypes for Analytical {
84    type S = V;
85    type P = SubjectPredictions;
86}
87
88impl EquationPriv for Analytical {
89    // #[inline(always)]
90    // fn get_init(&self) -> &Init {
91    //     &self.init
92    // }
93
94    // #[inline(always)]
95    // fn get_out(&self) -> &Out {
96    //     &self.out
97    // }
98
99    // #[inline(always)]
100    // fn get_lag(&self, spp: &[f64]) -> Option<HashMap<usize, f64>> {
101    //     Some((self.lag)(&V::from_vec(spp.to_owned())))
102    // }
103
104    // #[inline(always)]
105    // fn get_fa(&self, spp: &[f64]) -> Option<HashMap<usize, f64>> {
106    //     Some((self.fa)(&V::from_vec(spp.to_owned())))
107    // }
108
109    #[inline(always)]
110    fn lag(&self) -> &Lag {
111        &self.lag
112    }
113
114    #[inline(always)]
115    fn fa(&self) -> &Fa {
116        &self.fa
117    }
118
119    #[inline(always)]
120    fn get_nstates(&self) -> usize {
121        self.neqs.nstates
122    }
123
124    #[inline(always)]
125    fn get_ndrugs(&self) -> usize {
126        self.neqs.ndrugs
127    }
128
129    #[inline(always)]
130    fn get_nouteqs(&self) -> usize {
131        self.neqs.nout
132    }
133    #[inline(always)]
134    fn solve(
135        &self,
136        x: &mut Self::S,
137        support_point: &Vec<f64>,
138        covariates: &Covariates,
139        infusions: &Vec<Infusion>,
140        ti: f64,
141        tf: f64,
142    ) -> Result<(), PharmsolError> {
143        if ti == tf {
144            return Ok(());
145        }
146
147        // 1) Build and sort event times
148        let mut ts = Vec::new();
149        ts.push(ti);
150        ts.push(tf);
151        for inf in infusions {
152            let t0 = inf.time();
153            let t1 = t0 + inf.duration();
154            if t0 > ti && t0 < tf {
155                ts.push(t0)
156            }
157            if t1 > ti && t1 < tf {
158                ts.push(t1)
159            }
160        }
161        ts.sort_by(|a, b| a.partial_cmp(b).unwrap());
162        ts.dedup_by(|a, b| (*a - *b).abs() < 1e-12);
163
164        // 2) March over each sub-interval
165        let mut current_t = ts[0];
166        let mut sp = V::from_vec(support_point.to_owned(), NalgebraContext);
167        let mut rateiv = V::zeros(self.get_ndrugs(), NalgebraContext);
168
169        for &next_t in &ts[1..] {
170            // prepare support and infusion rate for [current_t .. next_t]
171            rateiv.fill(0.0);
172            for inf in infusions {
173                let s = inf.time();
174                let e = s + inf.duration();
175                if current_t >= s && next_t <= e {
176                    if inf.input() >= self.get_ndrugs() {
177                        return Err(PharmsolError::InputOutOfRange {
178                            input: inf.input(),
179                            ndrugs: self.get_ndrugs(),
180                        });
181                    }
182                    rateiv[inf.input()] += inf.amount() / inf.duration();
183                }
184            }
185
186            // advance the support-point to next_t
187            (self.seq_eq)(&mut sp, next_t, covariates);
188
189            // advance state by dt
190            let dt = next_t - current_t;
191            *x = (self.eq)(x, &sp, dt, &rateiv, covariates);
192
193            current_t = next_t;
194        }
195
196        Ok(())
197    }
198
199    #[inline(always)]
200    fn process_observation(
201        &self,
202        support_point: &Vec<f64>,
203        observation: &Observation,
204        error_models: Option<&AssayErrorModels>,
205        _time: f64,
206        covariates: &Covariates,
207        x: &mut Self::S,
208        likelihood: &mut Vec<f64>,
209        output: &mut Self::P,
210    ) -> Result<(), PharmsolError> {
211        let mut y = V::zeros(self.get_nouteqs(), NalgebraContext);
212        let out = &self.out;
213        (out)(
214            x,
215            &V::from_vec(support_point.clone(), NalgebraContext),
216            observation.time(),
217            covariates,
218            &mut y,
219        );
220        let pred = y[observation.outeq()];
221        let pred = observation.to_prediction(pred, x.as_slice().to_vec());
222        if let Some(error_models) = error_models {
223            likelihood.push(pred.log_likelihood(error_models)?.exp());
224        }
225        output.add_prediction(pred);
226        Ok(())
227    }
228    #[inline(always)]
229    fn initial_state(&self, spp: &Vec<f64>, covariates: &Covariates, occasion_index: usize) -> V {
230        let init = &self.init;
231        let mut x = V::zeros(self.get_nstates(), NalgebraContext);
232        if occasion_index == 0 {
233            (init)(
234                &V::from_vec(spp.to_vec(), NalgebraContext),
235                0.0,
236                covariates,
237                &mut x,
238            );
239        }
240        x
241    }
242}
243
#[cfg(test)]
pub(crate) mod tests {
    // Unit tests for `Analytical::solve`, plus canned dosing fixtures.
    use super::*;
    use crate::SubjectBuilderExt;
    use std::collections::HashMap;

    /// Canned dosing scenarios. `pub(crate)`, presumably so other test modules
    /// (e.g. the per-model analytical tests) can reuse them — confirm usages.
    pub(crate) enum SubjectInfo {
        /// Bolus of 100 into input 0 at t = 0, then an infusion of 150 into input 0
        /// starting at t = 24 with duration 3.0; observation slots on output 0.
        InfusionDosing,
        /// Bolus of 100 into input 1 at t = 0, an infusion of 150 into input 0 at
        /// t = 24 (duration 3.0), and a bolus of 100 into input 0 at t = 48.
        OralInfusionDosage,
    }
    impl SubjectInfo {
        /// Build the `Subject` for this scenario. `missing_observation(t, 0)` adds an
        /// observation slot on output 0 at time `t`, presumably without a measured value.
        pub(crate) fn get_subject(&self) -> Subject {
            match self {
                SubjectInfo::InfusionDosing => Subject::builder("id1")
                    .bolus(0.0, 100.0, 0)
                    .infusion(24.0, 150.0, 0, 3.0)
                    .missing_observation(0.0, 0)
                    .missing_observation(1.0, 0)
                    .missing_observation(2.0, 0)
                    .missing_observation(4.0, 0)
                    .missing_observation(8.0, 0)
                    .missing_observation(12.0, 0)
                    .missing_observation(24.0, 0)
                    .missing_observation(25.0, 0)
                    .missing_observation(26.0, 0)
                    .missing_observation(27.0, 0)
                    .missing_observation(28.0, 0)
                    .missing_observation(32.0, 0)
                    .missing_observation(36.0, 0)
                    .build(),

                SubjectInfo::OralInfusionDosage => Subject::builder("id1")
                    .bolus(0.0, 100.0, 1)
                    .infusion(24.0, 150.0, 0, 3.0)
                    .bolus(48.0, 100.0, 0)
                    .missing_observation(0.0, 0)
                    .missing_observation(1.0, 0)
                    .missing_observation(2.0, 0)
                    .missing_observation(4.0, 0)
                    .missing_observation(8.0, 0)
                    .missing_observation(12.0, 0)
                    .missing_observation(24.0, 0)
                    .missing_observation(25.0, 0)
                    .missing_observation(26.0, 0)
                    .missing_observation(27.0, 0)
                    .missing_observation(28.0, 0)
                    .missing_observation(32.0, 0)
                    .missing_observation(36.0, 0)
                    .missing_observation(48.0, 0)
                    .missing_observation(49.0, 0)
                    .missing_observation(50.0, 0)
                    .missing_observation(52.0, 0)
                    .missing_observation(56.0, 0)
                    .missing_observation(60.0, 0)
                    .build(),
            }
        }
    }

    /// Pins that the secondary equations mutate one working copy of the support point
    /// for a whole `solve` call, so their effect compounds across the sub-intervals
    /// created by infusion boundaries.
    #[test]
    fn secondary_equations_accumulate_within_single_solve() {
        // State grows linearly at rate p[0].
        let eq = |x: &V, p: &V, dt: f64, _rateiv: &V, _cov: &Covariates| {
            let mut next = x.clone();
            next[0] += p[0] * dt;
            next
        };
        // Every call bumps p[0] by 1.
        let seq_eq = |params: &mut V, _t: f64, _cov: &Covariates| {
            params[0] += 1.0;
        };
        let lag = |_p: &V, _t: f64, _cov: &Covariates| HashMap::new();
        let fa = |_p: &V, _t: f64, _cov: &Covariates| HashMap::new();
        let init = |_p: &V, _t: f64, _cov: &Covariates, x: &mut V| {
            x.fill(0.0);
        };
        let out = |x: &V, _p: &V, _t: f64, _cov: &Covariates, y: &mut V| {
            y[0] = x[0];
        };

        let analytical = Analytical::new(eq, seq_eq, lag, fa, init, out)
            .with_nstates(1)
            .with_ndrugs(1)
            .with_nout(1);
        // The infusion (0.25 .. 0.5) only serves to split a solve interval in two.
        let subject = Subject::builder("seq")
            .bolus(0.0, 0.0, 0)
            .infusion(0.25, 1.0, 0, 0.25)
            .observation(1.0, 0.0, 0)
            .build();

        let predictions = analytical
            .estimate_predictions(&subject, &vec![1.0])
            .unwrap();

        // Expected 2.5, assuming the simulator calls `solve` over [0, 0.25] and
        // [0.25, 1.0], and that the infusion end at 0.5 splits the second call:
        //   [0, 0.25]:   p = 2 -> +0.5   (fresh copy of [1.0] for this call)
        //   [0.25, 0.5]: p = 2 -> +0.5   (fresh copy again for the new call)
        //   [0.5, 1.0]:  p = 3 -> +1.5   (accumulated within the same call)
        let value = predictions.predictions()[0].prediction();
        assert!((value - 2.5).abs() < 1e-12);
    }

    /// Pins that the infusion-rate vector is sized by `ndrugs`, so an infusion into the
    /// highest input index (3 of 4) is delivered: rate 4.0 / 1.0 over dt = 1.0 gives 4.0.
    #[test]
    fn infusion_inputs_match_state_dimension() {
        let eq = |x: &V, _p: &V, dt: f64, rateiv: &V, _cov: &Covariates| {
            let mut next = x.clone();
            next[0] += rateiv[3] * dt;
            next
        };
        let seq_eq = |_params: &mut V, _t: f64, _cov: &Covariates| {};
        let lag = |_p: &V, _t: f64, _cov: &Covariates| HashMap::new();
        let fa = |_p: &V, _t: f64, _cov: &Covariates| HashMap::new();
        let init = |_p: &V, _t: f64, _cov: &Covariates, x: &mut V| {
            x.fill(0.0);
        };
        let out = |x: &V, _p: &V, _t: f64, _cov: &Covariates, y: &mut V| {
            y[0] = x[0];
        };

        let analytical = Analytical::new(eq, seq_eq, lag, fa, init, out)
            .with_nstates(4)
            .with_ndrugs(4)
            .with_nout(1);
        let subject = Subject::builder("inf")
            .infusion(0.0, 4.0, 3, 1.0)
            .observation(1.0, 0.0, 0)
            .build();

        let predictions = analytical
            .estimate_predictions(&subject, &vec![0.0])
            .unwrap();

        assert_eq!(predictions.predictions()[0].prediction(), 4.0);
    }
}
373impl Equation for Analytical {
374    fn estimate_likelihood(
375        &self,
376        subject: &Subject,
377        support_point: &Vec<f64>,
378        error_models: &AssayErrorModels,
379    ) -> Result<f64, PharmsolError> {
380        _estimate_likelihood(self, subject, support_point, error_models)
381    }
382
383    fn estimate_log_likelihood(
384        &self,
385        subject: &Subject,
386        support_point: &Vec<f64>,
387        error_models: &AssayErrorModels,
388    ) -> Result<f64, PharmsolError> {
389        let ypred = _subject_predictions(self, subject, support_point)?;
390        ypred.log_likelihood(error_models)
391    }
392
393    fn kind() -> crate::EqnKind {
394        crate::EqnKind::Analytical
395    }
396}
397
398#[inline(always)]
399fn _subject_predictions(
400    ode: &Analytical,
401    subject: &Subject,
402    support_point: &Vec<f64>,
403) -> Result<SubjectPredictions, PharmsolError> {
404    if cache_enabled() {
405        let key = (id_hash(subject.id()), spphash(support_point));
406        let cache_guard = ana_cache_lock_read()?;
407        if let Some(cached) = cache_guard.get(&key) {
408            return Ok(cached);
409        }
410        drop(cache_guard);
411
412        let result = ode.simulate_subject(subject, support_point, None)?.0;
413        let cache_guard = ana_cache_lock_read()?;
414        cache_guard.insert(key, result.clone());
415        Ok(result)
416    } else {
417        Ok(ode.simulate_subject(subject, support_point, None)?.0)
418    }
419}
420
421fn _estimate_likelihood(
422    ode: &Analytical,
423    subject: &Subject,
424    support_point: &Vec<f64>,
425    error_models: &AssayErrorModels,
426) -> Result<f64, PharmsolError> {
427    let ypred = _subject_predictions(ode, subject, support_point)?;
428    Ok(ypred.log_likelihood(error_models)?.exp())
429}