pmcore/routines/output/cycles.rs

1use anyhow::Result;
2use csv::WriterBuilder;
3use pharmsol::{ErrorModel, ErrorModels};
4use serde::Serialize;
5
6use crate::{
7    algorithms::{Status, StopReason},
8    prelude::Settings,
9    routines::output::{median, OutputFile},
10    structs::theta::Theta,
11};
12
13/// An [NPCycle] object contains the summary of a cycle
14/// It holds the following information:
15/// - `cycle`: The cycle number
16/// - `objf`: The objective function value
17/// - `gamlam`: The assay noise parameter, either gamma or lambda
18/// - `theta`: The support points and their associated probabilities
19/// - `nspp`: The number of support points
20/// - `delta_objf`: The change in objective function value from last cycle
21/// - `converged`: Whether the algorithm has reached convergence
22#[derive(Debug, Clone, Serialize)]
23pub struct NPCycle {
24    cycle: usize,
25    objf: f64,
26    error_models: ErrorModels,
27    theta: Theta,
28    nspp: usize,
29    delta_objf: f64,
30    status: Status,
31}
32
33impl NPCycle {
34    pub fn new(
35        cycle: usize,
36        objf: f64,
37        error_models: ErrorModels,
38        theta: Theta,
39        nspp: usize,
40        delta_objf: f64,
41        status: Status,
42    ) -> Self {
43        Self {
44            cycle,
45            objf,
46            error_models,
47            theta,
48            nspp,
49            delta_objf,
50            status,
51        }
52    }
53
54    pub fn cycle(&self) -> usize {
55        self.cycle
56    }
57    pub fn objf(&self) -> f64 {
58        self.objf
59    }
60    pub fn error_models(&self) -> &ErrorModels {
61        &self.error_models
62    }
63    pub fn theta(&self) -> &Theta {
64        &self.theta
65    }
66    pub fn nspp(&self) -> usize {
67        self.nspp
68    }
69    pub fn delta_objf(&self) -> f64 {
70        self.delta_objf
71    }
72    pub fn status(&self) -> &Status {
73        &self.status
74    }
75
76    pub fn placeholder() -> Self {
77        Self {
78            cycle: 0,
79            objf: 0.0,
80            error_models: ErrorModels::default(),
81            theta: Theta::new(),
82            nspp: 0,
83            delta_objf: 0.0,
84            status: Status::Continue,
85        }
86    }
87}
88
89/// This holdes a vector of [NPCycle] objects to provide a more detailed log
90#[derive(Debug, Clone, Serialize)]
91pub struct CycleLog {
92    cycles: Vec<NPCycle>,
93}
94
95impl CycleLog {
96    pub fn new() -> Self {
97        Self { cycles: Vec::new() }
98    }
99
100    pub fn cycles(&self) -> &[NPCycle] {
101        &self.cycles
102    }
103
104    pub fn push(&mut self, cycle: NPCycle) {
105        self.cycles.push(cycle);
106    }
107
108    pub fn write(&self, settings: &Settings) -> Result<()> {
109        tracing::debug!("Writing cycles...");
110        let outputfile = OutputFile::new(&settings.output().path, "cycles.csv")?;
111        let mut writer = WriterBuilder::new()
112            .has_headers(false)
113            .from_writer(&outputfile.file);
114
115        // Write headers
116        writer.write_field("cycle")?;
117        writer.write_field("converged")?;
118        writer.write_field("status")?;
119        writer.write_field("neg2ll")?;
120        writer.write_field("nspp")?;
121        if let Some(first_cycle) = self.cycles.first() {
122            first_cycle.error_models.iter().try_for_each(
123                |(outeq, errmod): (usize, &ErrorModel)| -> Result<(), csv::Error> {
124                    match errmod {
125                        ErrorModel::Additive { .. } => {
126                            writer.write_field(format!("gamlam.{}", outeq))?;
127                        }
128                        ErrorModel::Proportional { .. } => {
129                            writer.write_field(format!("gamlam.{}", outeq))?;
130                        }
131                        ErrorModel::None => {}
132                    }
133                    Ok(())
134                },
135            )?;
136        }
137
138        let parameter_names = settings.parameters().names();
139        for param_name in &parameter_names {
140            writer.write_field(format!("{}.mean", param_name))?;
141            writer.write_field(format!("{}.median", param_name))?;
142            writer.write_field(format!("{}.sd", param_name))?;
143        }
144
145        writer.write_record(None::<&[u8]>)?;
146
147        for cycle in &self.cycles {
148            writer.write_field(format!("{}", cycle.cycle))?;
149            writer.write_field(format!(
150                "{}",
151                cycle.status == Status::Stop(StopReason::Converged)
152            ))?;
153            writer.write_field(format!("{}", cycle.status))?;
154            writer.write_field(format!("{}", cycle.objf))?;
155            writer
156                .write_field(format!("{}", cycle.theta.nspp()))
157                .unwrap();
158
159            // Write the error models
160            cycle.error_models.iter().try_for_each(
161                |(_, errmod): (usize, &ErrorModel)| -> Result<()> {
162                    match errmod {
163                        ErrorModel::Additive { lambda: _, poly: _ } => {
164                            writer.write_field(format!("{:.5}", errmod.factor()?))?;
165                        }
166                        ErrorModel::Proportional { gamma: _, poly: _ } => {
167                            writer.write_field(format!("{:.5}", errmod.factor()?))?;
168                        }
169                        ErrorModel::None => {}
170                    }
171                    Ok(())
172                },
173            )?;
174
175            for param in cycle.theta.matrix().col_iter() {
176                let param_values: Vec<f64> = param.iter().cloned().collect();
177
178                let mean: f64 = param_values.iter().sum::<f64>() / param_values.len() as f64;
179                let median = median(&param_values);
180                let std = param_values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
181                    / (param_values.len() as f64 - 1.0);
182
183                writer.write_field(format!("{}", mean))?;
184                writer.write_field(format!("{}", median))?;
185                writer.write_field(format!("{}", std))?;
186            }
187            writer.write_record(None::<&[u8]>)?;
188        }
189        writer.flush()?;
190        tracing::debug!("Cycles written to {:?}", &outputfile.relative_path());
191        Ok(())
192    }
193}
194
195impl Default for CycleLog {
196    fn default() -> Self {
197        Self::new()
198    }
199}