Skip to main content

pmcore/routines/
settings.rs

1use crate::algorithms::Algorithm;
2use crate::routines::initialization::Prior;
3use crate::routines::output::OutputFile;
4use anyhow::{bail, Result};
5use pharmsol::prelude::data::AssayErrorModels;
6
7use serde::{Deserialize, Serialize};
8use serde_json;
9use std::fmt::Display;
10use std::path::PathBuf;
11
12/// Contains all settings for PMcore
13#[derive(Debug, Deserialize, Clone, Serialize)]
14#[serde(deny_unknown_fields)]
15pub struct Settings {
16    /// General configuration settings
17    pub(crate) config: Config,
18    /// Parameters to be estimated
19    pub(crate) parameters: Parameters,
20    /// Defines the error models and polynomials to be used
21    pub(crate) errormodels: AssayErrorModels,
22    /// Configuration for predictions
23    pub(crate) predictions: Predictions,
24    /// Configuration for logging
25    pub(crate) log: Log,
26    /// Configuration for (optional) prior
27    pub(crate) prior: Prior,
28    /// Configuration for the output files
29    pub(crate) output: Output,
30    /// Configuration for the convergence criteria
31    pub(crate) convergence: Convergence,
32    /// Advanced options, mostly hyperparameters, for the algorithm(s)
33    pub(crate) advanced: Advanced,
34}
35
36impl Settings {
37    /// Create a new [SettingsBuilder]
38    pub fn builder() -> SettingsBuilder<InitialState> {
39        SettingsBuilder::new()
40    }
41
42    /* Getters */
43    pub fn config(&self) -> &Config {
44        &self.config
45    }
46
47    pub fn parameters(&self) -> &Parameters {
48        &self.parameters
49    }
50
51    pub fn errormodels(&self) -> &AssayErrorModels {
52        &self.errormodels
53    }
54
55    pub fn predictions(&self) -> &Predictions {
56        &self.predictions
57    }
58
59    pub fn log(&self) -> &Log {
60        &self.log
61    }
62
63    pub fn prior(&self) -> &Prior {
64        &self.prior
65    }
66
67    pub fn output(&self) -> &Output {
68        &self.output
69    }
70    pub fn convergence(&self) -> &Convergence {
71        &self.convergence
72    }
73
74    pub fn advanced(&self) -> &Advanced {
75        &self.advanced
76    }
77
78    /* Setters */
79    pub fn set_cycles(&mut self, cycles: usize) {
80        self.config.cycles = cycles;
81    }
82
83    pub fn set_algorithm(&mut self, algorithm: Algorithm) {
84        self.config.algorithm = algorithm;
85    }
86
87    pub fn set_idelta(&mut self, idelta: f64) {
88        self.predictions.idelta = idelta;
89    }
90
91    pub fn set_tad(&mut self, tad: f64) {
92        self.predictions.tad = tad;
93    }
94
95    pub fn set_prior(&mut self, prior: Prior) {
96        self.prior = prior;
97    }
98
99    pub fn disable_output(&mut self) {
100        self.output.write = false;
101    }
102
103    pub fn set_output_path(&mut self, path: impl Into<String>) {
104        self.output.path = parse_output_folder(path.into());
105    }
106
107    pub fn set_log_stdout(&mut self, stdout: bool) {
108        self.log.stdout = stdout;
109    }
110
111    pub fn set_write_logs(&mut self, write: bool) {
112        self.log.write = write;
113    }
114
115    pub fn set_log_level(&mut self, level: LogLevel) {
116        self.log.level = level;
117    }
118
119    pub fn set_progress(&mut self, progress: bool) {
120        self.config.progress = progress;
121    }
122
123    pub fn initialize_logs(&mut self) -> Result<()> {
124        crate::routines::logger::setup_log(self)
125    }
126
127    /// Writes a copy of the settings to file
128    /// The is written to output folder specified in the [Output] and is named `settings.json`.
129    pub fn write(&self) -> Result<()> {
130        let serialized = serde_json::to_string_pretty(self).map_err(std::io::Error::other)?;
131
132        let outputfile = OutputFile::new(self.output.path.as_str(), "settings.json")?;
133        let mut file = outputfile.file_owned();
134        std::io::Write::write_all(&mut file, serialized.as_bytes())?;
135        Ok(())
136    }
137}
138
139/// General configuration settings
140#[derive(Debug, Deserialize, Clone, Serialize)]
141#[serde(deny_unknown_fields, default)]
142pub struct Config {
143    /// Maximum number of cycles to run
144    pub cycles: usize,
145    /// Denotes the algorithm to use
146    pub algorithm: Algorithm,
147    /// Should a progress bar be displayed for the first cycle
148    ///
149    /// The progress bar is not written to logs, but is written to stdout. It incurs a minor performance penalty.
150    pub progress: bool,
151}
152
153impl Default for Config {
154    fn default() -> Self {
155        Config {
156            cycles: 100,
157            algorithm: Algorithm::NPAG,
158            progress: true,
159        }
160    }
161}
162
163/// Defines a parameter to be estimated
164///
165/// In non-parametric algorithms, parameters must be bounded. The lower and upper bounds are defined by the `lower` and `upper` fields, respectively.
166#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
167pub struct Parameter {
168    pub(crate) name: String,
169    pub(crate) lower: f64,
170    pub(crate) upper: f64,
171}
172
173impl Parameter {
174    /// Create a new parameter
175    pub fn new(name: impl Into<String>, lower: f64, upper: f64) -> Self {
176        Self {
177            name: name.into(),
178            lower,
179            upper,
180        }
181    }
182}
183
184/// This structure contains information on all [Parameter]s to be estimated
185#[derive(Debug, Clone, Deserialize, Serialize, Default, PartialEq)]
186pub struct Parameters {
187    pub(crate) parameters: Vec<Parameter>,
188}
189
190impl Parameters {
191    pub fn new() -> Self {
192        Parameters {
193            parameters: Vec::new(),
194        }
195    }
196
197    pub fn add(mut self, name: impl Into<String>, lower: f64, upper: f64) -> Parameters {
198        let parameter = Parameter::new(name, lower, upper);
199        self.parameters.push(parameter);
200        self
201    }
202
203    // Get a parameter by name
204    pub fn get(&self, name: impl Into<String>) -> Option<&Parameter> {
205        let name = name.into();
206        self.parameters.iter().find(|p| p.name == name)
207    }
208
209    /// Get the names of the parameters
210    pub fn names(&self) -> Vec<String> {
211        self.parameters.iter().map(|p| p.name.clone()).collect()
212    }
213    /// Get the ranges of the parameters
214    ///
215    /// Returns a vector of tuples, where each tuple contains the lower and upper bounds of the parameter
216    pub fn ranges(&self) -> Vec<(f64, f64)> {
217        self.parameters.iter().map(|p| (p.lower, p.upper)).collect()
218    }
219
220    /// Get the number of parameters
221    pub fn len(&self) -> usize {
222        self.parameters.len()
223    }
224
225    /// Check if the parameters are empty
226    pub fn is_empty(&self) -> bool {
227        self.parameters.is_empty()
228    }
229
230    /// Iterate over the parameters
231    pub fn iter(&self) -> std::slice::Iter<'_, Parameter> {
232        self.parameters.iter()
233    }
234}
235
236impl IntoIterator for Parameters {
237    type Item = Parameter;
238    type IntoIter = std::vec::IntoIter<Parameter>;
239
240    fn into_iter(self) -> Self::IntoIter {
241        self.parameters.into_iter()
242    }
243}
244
245impl From<Vec<Parameter>> for Parameters {
246    fn from(parameters: Vec<Parameter>) -> Self {
247        Parameters { parameters }
248    }
249}
250
251/// This struct contains advanced options and hyperparameters
252#[derive(Debug, Deserialize, Clone, Serialize)]
253#[serde(deny_unknown_fields, default)]
254pub struct Advanced {
255    /// The minimum distance required between a candidate point and the existing grid (THETA_D)
256    ///
257    /// This is general for all non-parametric algorithms
258    pub min_distance: f64,
259    /// Maximum number of steps in Nelder-Mead optimization
260    /// This is used in the [NPOD](crate::algorithms::npod) algorithm, specifically in the [D-optimizer](crate::routines::optimization::d_optimizer)
261    pub nm_steps: usize,
262    /// Tolerance (in standard deviations) for the Nelder-Mead optimization
263    ///
264    /// This is used in the [NPOD](crate::algorithms::npod) algorithm, specifically in the [D-optimizer](crate::routines::optimization::d_optimizer)
265    pub tolerance: f64,
266}
267
268impl Default for Advanced {
269    fn default() -> Self {
270        Advanced {
271            min_distance: 1e-4,
272            nm_steps: 100,
273            tolerance: 1e-6,
274        }
275    }
276}
277
278#[derive(Debug, Deserialize, Clone, Serialize)]
279#[serde(deny_unknown_fields, default)]
280/// This struct contains the convergence criteria for the algorithm
281pub struct Convergence {
282    /// The objective function convergence criterion for the algorithm
283    ///
284    /// The objective function is the negative log likelihood
285    /// Previously referred to as THETA_G
286    pub likelihood: f64,
287    /// The PYL convergence criterion for the algorithm
288    ///
289    /// P(Y|L) represents the probability of the observation given its weighted support
290    /// Previously referred to as THETA_F
291    pub pyl: f64,
292    /// Precision convergence criterion for the algorithm
293    ///
294    /// The precision variable, sometimes referred to as `eps`, is the distance from existing points in the grid to the candidate point. A candidate point is suggested at a distance of `eps` times the range of the parameter.
295    /// For example, if the parameter `alpha` has a range of `[0.0, 1.0]`, and `eps` is `0.1`, then the candidate point will be at a distance of `0.1 * (1.0 - 0.0) = 0.1` from the existing grid point(s).
296    /// Previously referred to as THETA_E
297    pub eps: f64,
298}
299
300impl Default for Convergence {
301    fn default() -> Self {
302        Convergence {
303            likelihood: 1e-4,
304            pyl: 1e-2,
305            eps: 1e-2,
306        }
307    }
308}
309
310#[derive(Debug, Deserialize, Clone, Serialize)]
311#[serde(deny_unknown_fields, default)]
312pub struct Predictions {
313    /// The interval for which predictions are generated
314    pub idelta: f64,
315    /// The time after the last dose for which predictions are generated
316    ///
317    /// Predictions will always be generated until the last event (observation or dose) in the data.
318    /// This setting is used to generate predictions beyond the last event if the `tad` if sufficiently large.
319    /// This can be useful for generating predictions for a subject who only received a dose, but has no observations.
320    pub tad: f64,
321}
322
323impl Default for Predictions {
324    fn default() -> Self {
325        Predictions {
326            idelta: 0.12,
327            tad: 0.0,
328        }
329    }
330}
331
332impl Predictions {
333    /// Validate the prediction settings
334    pub fn validate(&self) -> Result<()> {
335        if self.idelta < 0.0 {
336            bail!("The interval for predictions must be non-negative");
337        }
338        if self.tad < 0.0 {
339            bail!("The time after dose for predictions must be non-negative");
340        }
341        Ok(())
342    }
343}
344
345/// The log level, which can be one of the following:
346/// - `TRACE`
347/// - `DEBUG`
348/// - `INFO` (Default)
349/// - `WARN`
350/// - `ERROR`
351#[derive(Debug, Deserialize, Clone, Serialize, Default)]
352pub enum LogLevel {
353    TRACE,
354    DEBUG,
355    #[default]
356    INFO,
357    WARN,
358    ERROR,
359}
360
361impl From<LogLevel> for tracing::Level {
362    fn from(log_level: LogLevel) -> tracing::Level {
363        match log_level {
364            LogLevel::TRACE => tracing::Level::TRACE,
365            LogLevel::DEBUG => tracing::Level::DEBUG,
366            LogLevel::INFO => tracing::Level::INFO,
367            LogLevel::WARN => tracing::Level::WARN,
368            LogLevel::ERROR => tracing::Level::ERROR,
369        }
370    }
371}
372
373impl AsRef<str> for LogLevel {
374    fn as_ref(&self) -> &str {
375        match self {
376            LogLevel::TRACE => "trace",
377            LogLevel::DEBUG => "debug",
378            LogLevel::INFO => "info",
379            LogLevel::WARN => "warn",
380            LogLevel::ERROR => "error",
381        }
382    }
383}
384
385impl Display for LogLevel {
386    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
387        write!(f, "{}", self.as_ref())
388    }
389}
390
391#[derive(Debug, Deserialize, Clone, Serialize)]
392#[serde(deny_unknown_fields, default)]
393pub struct Log {
394    /// The maximum log level to display, as defined by [LogLevel]
395    ///
396    /// [LogLevel] is a thin wrapper around `tracing::Level`, but can be serialized
397    pub level: LogLevel,
398    /// Should the logs be written to a file
399    ///
400    /// If true, a file will be created in the output folder with the name `log.txt`, or, if [Output::write] is false, in the current directory.
401    pub write: bool,
402    /// Define if logs should be written to stdout
403    pub stdout: bool,
404}
405
406impl Default for Log {
407    fn default() -> Self {
408        Log {
409            level: LogLevel::INFO,
410            write: false,
411            stdout: true,
412        }
413    }
414}
415
416/// Configuration for the output files
417#[derive(Debug, Deserialize, Clone, Serialize)]
418#[serde(deny_unknown_fields, default)]
419pub struct Output {
420    /// Whether to write the output files
421    pub write: bool,
422    /// The (relative) path to write the output files to
423    pub path: String,
424}
425
426impl Default for Output {
427    fn default() -> Self {
428        let path = PathBuf::from("outputs/").to_string_lossy().to_string();
429
430        Output { write: true, path }
431    }
432}
433
434pub struct SettingsBuilder<State> {
435    config: Option<Config>,
436    parameters: Option<Parameters>,
437    errormodels: Option<AssayErrorModels>,
438    predictions: Option<Predictions>,
439    log: Option<Log>,
440    prior: Option<Prior>,
441    output: Option<Output>,
442    convergence: Option<Convergence>,
443    advanced: Option<Advanced>,
444    _marker: std::marker::PhantomData<State>,
445}
446
// Marker traits naming which required builder sections have been supplied.
//
// NOTE(review): nothing in this module implements these traits or uses them as bounds. The
// builder's ordering is enforced solely by the concrete state types below. Either implement
// them for the matching states or remove them.
pub trait AlgorithmDefined {}
pub trait ParametersDefined {}
pub trait ErrorModelDefined {}

/// Builder state before an algorithm has been chosen.
pub struct InitialState;
/// Builder state after `set_algorithm`; parameters are required next.
pub struct AlgorithmSet;
/// Builder state after `set_parameters`; error models are required next.
pub struct ParametersSet;
/// Builder state after `set_error_models`; `build` is available.
pub struct ErrorSet;
457
458// Initial state: no algorithm set yet
459impl SettingsBuilder<InitialState> {
460    pub fn new() -> Self {
461        SettingsBuilder {
462            config: None,
463            parameters: None,
464            errormodels: None,
465            predictions: None,
466            log: None,
467            prior: None,
468            output: None,
469            convergence: None,
470            advanced: None,
471            _marker: std::marker::PhantomData,
472        }
473    }
474
475    pub fn set_algorithm(self, algorithm: Algorithm) -> SettingsBuilder<AlgorithmSet> {
476        SettingsBuilder {
477            config: Some(Config {
478                algorithm,
479                ..Config::default()
480            }),
481            parameters: self.parameters,
482            errormodels: self.errormodels,
483            predictions: self.predictions,
484            log: self.log,
485            prior: self.prior,
486            output: self.output,
487            convergence: self.convergence,
488            advanced: self.advanced,
489            _marker: std::marker::PhantomData,
490        }
491    }
492}
493
494impl Default for SettingsBuilder<InitialState> {
495    fn default() -> Self {
496        SettingsBuilder::new()
497    }
498}
499
500// Algorithm is set, move to defining parameters
501impl SettingsBuilder<AlgorithmSet> {
502    pub fn set_parameters(self, parameters: Parameters) -> SettingsBuilder<ParametersSet> {
503        SettingsBuilder {
504            config: self.config,
505            parameters: Some(parameters),
506            errormodels: self.errormodels,
507            predictions: self.predictions,
508            log: self.log,
509            prior: self.prior,
510            output: self.output,
511            convergence: self.convergence,
512            advanced: self.advanced,
513            _marker: std::marker::PhantomData,
514        }
515    }
516}
517
518// Parameters are set, move to defining error model
519impl SettingsBuilder<ParametersSet> {
520    pub fn set_error_models(self, ems: AssayErrorModels) -> SettingsBuilder<ErrorSet> {
521        SettingsBuilder {
522            config: self.config,
523            parameters: self.parameters,
524            errormodels: Some(ems),
525            predictions: self.predictions,
526            log: self.log,
527            prior: self.prior,
528            output: self.output,
529            convergence: self.convergence,
530            advanced: self.advanced,
531            _marker: std::marker::PhantomData,
532        }
533    }
534}
535
536// Error model is set, allow optional settings and final build
537impl SettingsBuilder<ErrorSet> {
538    pub fn build(self) -> Settings {
539        Settings {
540            config: self.config.unwrap(),
541            parameters: self.parameters.unwrap(),
542            errormodels: self.errormodels.unwrap(),
543            predictions: self.predictions.unwrap_or_default(),
544            log: self.log.unwrap_or_default(),
545            prior: self.prior.unwrap_or_default(),
546            output: self.output.unwrap_or_default(),
547            convergence: self.convergence.unwrap_or_default(),
548            advanced: self.advanced.unwrap_or_default(),
549        }
550    }
551}
552
/// Resolves the `#` placeholder in an output folder path.
///
/// A path without `#` is returned untouched. Otherwise every `#` is replaced by the same
/// counter value, starting at 1 and counting up until the resulting path does not exist on
/// disk.
///
/// NOTE: the existence check is not atomic with whatever later creates the folder, so two
/// concurrent runs could resolve to the same path.
fn parse_output_folder(path: String) -> String {
    if !path.contains('#') {
        return path;
    }

    let mut counter = 1;
    loop {
        let candidate = path.replace('#', &counter.to_string());
        if !std::path::Path::new(&candidate).exists() {
            break candidate;
        }
        counter += 1;
    }
}
567
#[cfg(test)]
mod tests {
    use pharmsol::{AssayErrorModel, AssayErrorModels, ErrorPoly};

    use super::*;
    use crate::algorithms::Algorithm;

    /// A single proportional error model for output equation 0, used by the builder tests.
    fn error_models() -> AssayErrorModels {
        AssayErrorModels::new()
            .add(
                0,
                AssayErrorModel::Proportional {
                    gamma: pharmsol::Factor::Variable(5.0),
                    poly: ErrorPoly::new(0.0, 0.1, 0.0, 0.0),
                },
            )
            .unwrap()
    }

    #[test]
    fn test_builder() {
        let parameters = Parameters::new().add("Ke", 0.0, 5.0).add("V", 10.0, 200.0);

        let mut settings = SettingsBuilder::new()
            .set_algorithm(Algorithm::NPAG) // Step 1: define the algorithm
            .set_parameters(parameters) // Step 2: define the parameters
            .set_error_models(error_models()) // Step 3: define the error models
            .build();

        // Before overriding, cycles come from `Config::default()`.
        assert_eq!(settings.config().cycles, 100);

        // Use a value that differs from the default, so the assertion proves the setter ran.
        settings.set_cycles(250);

        assert_eq!(settings.config().algorithm, Algorithm::NPAG);
        assert_eq!(settings.config().cycles, 250);
        assert_eq!(settings.parameters().names(), vec!["Ke", "V"]);
        assert_eq!(
            settings.parameters().ranges(),
            vec![(0.0, 5.0), (10.0, 200.0)]
        );
    }

    #[test]
    fn parse_output_folder_without_placeholder_is_unchanged() {
        assert_eq!(parse_output_folder("outputs/".to_string()), "outputs/");
    }

    #[test]
    fn predictions_validate_rejects_negative_values() {
        assert!(Predictions::default().validate().is_ok());
        assert!(Predictions {
            idelta: -1.0,
            tad: 0.0
        }
        .validate()
        .is_err());
        assert!(Predictions {
            idelta: 0.12,
            tad: -1.0
        }
        .validate()
        .is_err());
    }
}