pmcore/routines/output/
cycles.rs

1use anyhow::Result;
2use csv::WriterBuilder;
3use pharmsol::{ErrorModel, ErrorModels};
4
5use crate::{
6    algorithms::Status,
7    prelude::Settings,
8    routines::output::{median, OutputFile},
9    structs::theta::Theta,
10};
11
12/// An [NPCycle] object contains the summary of a cycle
13/// It holds the following information:
14/// - `cycle`: The cycle number
15/// - `objf`: The objective function value
16/// - `gamlam`: The assay noise parameter, either gamma or lambda
17/// - `theta`: The support points and their associated probabilities
18/// - `nspp`: The number of support points
19/// - `delta_objf`: The change in objective function value from last cycle
20/// - `converged`: Whether the algorithm has reached convergence
21#[derive(Debug, Clone)]
22pub struct NPCycle {
23    cycle: usize,
24    objf: f64,
25    error_models: ErrorModels,
26    theta: Theta,
27    nspp: usize,
28    delta_objf: f64,
29    status: Status,
30}
31
32impl NPCycle {
33    pub fn new(
34        cycle: usize,
35        objf: f64,
36        error_models: ErrorModels,
37        theta: Theta,
38        nspp: usize,
39        delta_objf: f64,
40        status: Status,
41    ) -> Self {
42        Self {
43            cycle,
44            objf,
45            error_models,
46            theta,
47            nspp,
48            delta_objf,
49            status,
50        }
51    }
52
53    pub fn cycle(&self) -> usize {
54        self.cycle
55    }
56    pub fn objf(&self) -> f64 {
57        self.objf
58    }
59    pub fn error_models(&self) -> &ErrorModels {
60        &self.error_models
61    }
62    pub fn theta(&self) -> &Theta {
63        &self.theta
64    }
65    pub fn nspp(&self) -> usize {
66        self.nspp
67    }
68    pub fn delta_objf(&self) -> f64 {
69        self.delta_objf
70    }
71    pub fn status(&self) -> &Status {
72        &self.status
73    }
74
75    pub fn placeholder() -> Self {
76        Self {
77            cycle: 0,
78            objf: 0.0,
79            error_models: ErrorModels::default(),
80            theta: Theta::new(),
81            nspp: 0,
82            delta_objf: 0.0,
83            status: Status::Starting,
84        }
85    }
86}
87
88/// This holdes a vector of [NPCycle] objects to provide a more detailed log
89#[derive(Debug, Clone)]
90pub struct CycleLog {
91    cycles: Vec<NPCycle>,
92}
93
94impl CycleLog {
95    pub fn new() -> Self {
96        Self { cycles: Vec::new() }
97    }
98
99    pub fn cycles(&self) -> &[NPCycle] {
100        &self.cycles
101    }
102
103    pub fn push(&mut self, cycle: NPCycle) {
104        self.cycles.push(cycle);
105    }
106
107    pub fn write(&self, settings: &Settings) -> Result<()> {
108        tracing::debug!("Writing cycles...");
109        let outputfile = OutputFile::new(&settings.output().path, "cycles.csv")?;
110        let mut writer = WriterBuilder::new()
111            .has_headers(false)
112            .from_writer(&outputfile.file);
113
114        // Write headers
115        writer.write_field("cycle")?;
116        writer.write_field("converged")?;
117        writer.write_field("status")?;
118        writer.write_field("neg2ll")?;
119        writer.write_field("nspp")?;
120        if let Some(first_cycle) = self.cycles.first() {
121            first_cycle.error_models.iter().try_for_each(
122                |(outeq, errmod): (usize, &ErrorModel)| -> Result<(), csv::Error> {
123                    match errmod {
124                        ErrorModel::Additive { .. } => {
125                            writer.write_field(format!("gamlam.{}", outeq))?;
126                        }
127                        ErrorModel::Proportional { .. } => {
128                            writer.write_field(format!("gamlam.{}", outeq))?;
129                        }
130                        ErrorModel::None => {}
131                    }
132                    Ok(())
133                },
134            )?;
135        }
136
137        let parameter_names = settings.parameters().names();
138        for param_name in &parameter_names {
139            writer.write_field(format!("{}.mean", param_name))?;
140            writer.write_field(format!("{}.median", param_name))?;
141            writer.write_field(format!("{}.sd", param_name))?;
142        }
143
144        writer.write_record(None::<&[u8]>)?;
145
146        for cycle in &self.cycles {
147            writer.write_field(format!("{}", cycle.cycle))?;
148            writer.write_field(format!("{}", cycle.status == Status::Converged))?;
149            writer.write_field(format!("{}", cycle.status))?;
150            writer.write_field(format!("{}", cycle.objf))?;
151            writer
152                .write_field(format!("{}", cycle.theta.nspp()))
153                .unwrap();
154
155            // Write the error models
156            cycle.error_models.iter().try_for_each(
157                |(_, errmod): (usize, &ErrorModel)| -> Result<()> {
158                    match errmod {
159                        ErrorModel::Additive {
160                            lambda: _,
161                            poly: _,
162                            lloq: _,
163                        } => {
164                            writer.write_field(format!("{:.5}", errmod.factor()?))?;
165                        }
166                        ErrorModel::Proportional {
167                            gamma: _,
168                            poly: _,
169                            lloq: _,
170                        } => {
171                            writer.write_field(format!("{:.5}", errmod.factor()?))?;
172                        }
173                        ErrorModel::None => {}
174                    }
175                    Ok(())
176                },
177            )?;
178
179            for param in cycle.theta.matrix().col_iter() {
180                let param_values: Vec<f64> = param.iter().cloned().collect();
181
182                let mean: f64 = param_values.iter().sum::<f64>() / param_values.len() as f64;
183                let median = median(&param_values);
184                let std = param_values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
185                    / (param_values.len() as f64 - 1.0);
186
187                writer.write_field(format!("{}", mean))?;
188                writer.write_field(format!("{}", median))?;
189                writer.write_field(format!("{}", std))?;
190            }
191            writer.write_record(None::<&[u8]>)?;
192        }
193        writer.flush()?;
194        tracing::debug!("Cycles written to {:?}", &outputfile.relative_path());
195        Ok(())
196    }
197}
198
199impl Default for CycleLog {
200    fn default() -> Self {
201        Self::new()
202    }
203}