// pmcore/algorithms/postprob.rs

1use crate::{
2    algorithms::{Status, StopReason},
3    prelude::algorithms::Algorithms,
4    structs::{
5        psi::{calculate_psi, Psi},
6        theta::Theta,
7        weights::Weights,
8    },
9};
10use anyhow::{Context, Result};
11
12use pharmsol::prelude::{
13    data::{Data, ErrorModels},
14    simulator::Equation,
15};
16
17use crate::routines::estimation::ipm::burke;
18use crate::routines::initialization;
19use crate::routines::output::{cycles::CycleLog, NPResult};
20use crate::routines::settings::Settings;
21
22/// Posterior probability algorithm
23/// Reweights the prior probabilities to the observed data and error model
24pub struct POSTPROB<E: Equation + Send + 'static> {
25    equation: E,
26    psi: Psi,
27    theta: Theta,
28    w: Weights,
29    objf: f64,
30    cycle: usize,
31    status: Status,
32    data: Data,
33    settings: Settings,
34    cyclelog: CycleLog,
35    error_models: ErrorModels,
36}
37
38impl<E: Equation + Send + 'static> Algorithms<E> for POSTPROB<E> {
39    fn new(settings: Settings, equation: E, data: Data) -> Result<Box<Self>, anyhow::Error> {
40        Ok(Box::new(Self {
41            equation,
42            psi: Psi::new(),
43            theta: Theta::new(),
44            w: Weights::default(),
45            objf: f64::INFINITY,
46            cycle: 0,
47            status: Status::Continue,
48            error_models: settings.errormodels().clone(),
49            settings,
50            data,
51            cyclelog: CycleLog::new(),
52        }))
53    }
54    fn into_npresult(&self) -> NPResult<E> {
55        NPResult::new(
56            self.equation.clone(),
57            self.data.clone(),
58            self.theta.clone(),
59            self.psi.clone(),
60            self.w.clone(),
61            self.objf,
62            self.cycle,
63            self.status.clone(),
64            self.settings.clone(),
65            self.cyclelog.clone(),
66        )
67    }
68    fn settings(&self) -> &Settings {
69        &self.settings
70    }
71
72    fn equation(&self) -> &E {
73        &self.equation
74    }
75
76    fn data(&self) -> &Data {
77        &self.data
78    }
79
80    fn get_prior(&self) -> Theta {
81        initialization::sample_space(&self.settings).unwrap()
82    }
83
84    fn likelihood(&self) -> f64 {
85        self.objf
86    }
87
88    fn increment_cycle(&mut self) -> usize {
89        0
90    }
91
92    fn cycle(&self) -> usize {
93        0
94    }
95
96    fn set_theta(&mut self, theta: Theta) {
97        self.theta = theta;
98    }
99
100    fn theta(&self) -> &Theta {
101        &self.theta
102    }
103
104    fn psi(&self) -> &Psi {
105        &self.psi
106    }
107
108    fn set_status(&mut self, status: Status) {
109        self.status = status;
110    }
111
112    fn status(&self) -> &Status {
113        &self.status
114    }
115
116    fn evaluation(&mut self) -> Result<Status> {
117        self.status = Status::Stop(StopReason::Converged);
118        Ok(self.status.clone())
119    }
120
121    fn estimation(&mut self) -> Result<()> {
122        self.psi = calculate_psi(
123            &self.equation,
124            &self.data,
125            &self.theta,
126            &self.error_models,
127            false,
128            false,
129        )?;
130        (self.w, self.objf) = burke(&self.psi).context("Error in IPM")?;
131        Ok(())
132    }
133
134    fn condensation(&mut self) -> Result<()> {
135        Ok(())
136    }
137    fn optimizations(&mut self) -> Result<()> {
138        Ok(())
139    }
140
141    fn expansion(&mut self) -> Result<()> {
142        Ok(())
143    }
144
145    fn log_cycle_state(&mut self) {
146        // Postprob doesn't track last_objf, so we use 0.0 as the delta
147        let state = crate::routines::output::cycles::NPCycle::new(
148            self.cycle,
149            self.objf,
150            self.error_models.clone(),
151            self.theta.clone(),
152            self.theta.nspp(),
153            0.0,
154            self.status.clone(),
155        );
156        self.cyclelog.push(state);
157    }
158}