pmcore/routines/
settings.rs

1use crate::algorithms::Algorithm;
2use crate::routines::initialization::Prior;
3use crate::routines::output::OutputFile;
4use anyhow::{bail, Result};
5use pharmsol::prelude::data::ErrorModels;
6
7use serde::{Deserialize, Serialize};
8use serde_json;
9use std::fmt::Display;
10use std::path::PathBuf;
11
12/// Contains all settings for PMcore
13#[derive(Debug, Deserialize, Clone, Serialize)]
14#[serde(deny_unknown_fields)]
15pub struct Settings {
16    /// General configuration settings
17    pub(crate) config: Config,
18    /// Parameters to be estimated
19    pub(crate) parameters: Parameters,
20    /// Defines the error models and polynomials to be used
21    pub(crate) errormodels: ErrorModels,
22    /// Configuration for predictions
23    pub(crate) predictions: Predictions,
24    /// Configuration for logging
25    pub(crate) log: Log,
26    /// Configuration for (optional) prior
27    pub(crate) prior: Prior,
28    /// Configuration for the output files
29    pub(crate) output: Output,
30    /// Configuration for the convergence criteria
31    pub(crate) convergence: Convergence,
32    /// Advanced options, mostly hyperparameters, for the algorithm(s)
33    pub(crate) advanced: Advanced,
34}
35
36impl Settings {
37    /// Create a new [SettingsBuilder]
38    pub fn builder() -> SettingsBuilder<InitialState> {
39        SettingsBuilder::new()
40    }
41
42    /* Getters */
43    pub fn config(&self) -> &Config {
44        &self.config
45    }
46
47    pub fn parameters(&self) -> &Parameters {
48        &self.parameters
49    }
50
51    pub fn errormodels(&self) -> &ErrorModels {
52        &self.errormodels
53    }
54
55    pub fn predictions(&self) -> &Predictions {
56        &self.predictions
57    }
58
59    pub fn log(&self) -> &Log {
60        &self.log
61    }
62
63    pub fn prior(&self) -> &Prior {
64        &self.prior
65    }
66
67    pub fn output(&self) -> &Output {
68        &self.output
69    }
70    pub fn convergence(&self) -> &Convergence {
71        &self.convergence
72    }
73
74    pub fn advanced(&self) -> &Advanced {
75        &self.advanced
76    }
77
78    /* Setters */
79    pub fn set_cycles(&mut self, cycles: usize) {
80        self.config.cycles = cycles;
81    }
82
83    pub fn set_algorithm(&mut self, algorithm: Algorithm) {
84        self.config.algorithm = algorithm;
85    }
86
87    pub fn set_cache(&mut self, cache: bool) {
88        self.config.cache = cache;
89    }
90
91    pub fn set_idelta(&mut self, idelta: f64) {
92        self.predictions.idelta = idelta;
93    }
94
95    pub fn set_tad(&mut self, tad: f64) {
96        self.predictions.tad = tad;
97    }
98
99    pub fn set_prior(&mut self, prior: Prior) {
100        self.prior = prior;
101    }
102
103    pub fn disable_output(&mut self) {
104        self.output.write = false;
105    }
106
107    pub fn set_output_path(&mut self, path: impl Into<String>) {
108        self.output.path = parse_output_folder(path.into());
109    }
110
111    pub fn set_log_stdout(&mut self, stdout: bool) {
112        self.log.stdout = stdout;
113    }
114
115    pub fn set_write_logs(&mut self, write: bool) {
116        self.log.write = write;
117    }
118
119    pub fn set_log_level(&mut self, level: LogLevel) {
120        self.log.level = level;
121    }
122
123    pub fn set_progress(&mut self, progress: bool) {
124        self.config.progress = progress;
125    }
126
127    pub fn initialize_logs(&mut self) -> Result<()> {
128        crate::routines::logger::setup_log(self)
129    }
130
131    /// Writes a copy of the settings to file
132    /// The is written to output folder specified in the [Output] and is named `settings.json`.
133    pub fn write(&self) -> Result<()> {
134        let serialized = serde_json::to_string_pretty(self).map_err(std::io::Error::other)?;
135
136        let outputfile = OutputFile::new(self.output.path.as_str(), "settings.json")?;
137        let mut file = outputfile.file;
138        std::io::Write::write_all(&mut file, serialized.as_bytes())?;
139        Ok(())
140    }
141}
142
143/// General configuration settings
144#[derive(Debug, Deserialize, Clone, Serialize)]
145#[serde(deny_unknown_fields, default)]
146pub struct Config {
147    /// Maximum number of cycles to run
148    pub cycles: usize,
149    /// Denotes the algorithm to use
150    pub algorithm: Algorithm,
151    /// If true (default), cache predicted values
152    pub cache: bool,
153    /// Should a progress bar be displayed for the first cycle
154    ///
155    /// The progress bar is not written to logs, but is written to stdout. It incurs a minor performance penalty.
156    pub progress: bool,
157}
158
159impl Default for Config {
160    fn default() -> Self {
161        Config {
162            cycles: 100,
163            algorithm: Algorithm::NPAG,
164            cache: true,
165            progress: true,
166        }
167    }
168}
169
170/// Defines a parameter to be estimated
171///
172/// In non-parametric algorithms, parameters must be bounded. The lower and upper bounds are defined by the `lower` and `upper` fields, respectively.
173#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
174pub struct Parameter {
175    pub(crate) name: String,
176    pub(crate) lower: f64,
177    pub(crate) upper: f64,
178}
179
180impl Parameter {
181    /// Create a new parameter
182    pub fn new(name: impl Into<String>, lower: f64, upper: f64) -> Self {
183        Self {
184            name: name.into(),
185            lower,
186            upper,
187        }
188    }
189}
190
191/// This structure contains information on all [Parameter]s to be estimated
192#[derive(Debug, Clone, Deserialize, Serialize, Default, PartialEq)]
193pub struct Parameters {
194    pub(crate) parameters: Vec<Parameter>,
195}
196
197impl Parameters {
198    pub fn new() -> Self {
199        Parameters {
200            parameters: Vec::new(),
201        }
202    }
203
204    pub fn add(mut self, name: impl Into<String>, lower: f64, upper: f64) -> Parameters {
205        let parameter = Parameter::new(name, lower, upper);
206        self.parameters.push(parameter);
207        self
208    }
209
210    // Get a parameter by name
211    pub fn get(&self, name: impl Into<String>) -> Option<&Parameter> {
212        let name = name.into();
213        self.parameters.iter().find(|p| p.name == name)
214    }
215
216    /// Get the names of the parameters
217    pub fn names(&self) -> Vec<String> {
218        self.parameters.iter().map(|p| p.name.clone()).collect()
219    }
220    /// Get the ranges of the parameters
221    ///
222    /// Returns a vector of tuples, where each tuple contains the lower and upper bounds of the parameter
223    pub fn ranges(&self) -> Vec<(f64, f64)> {
224        self.parameters.iter().map(|p| (p.lower, p.upper)).collect()
225    }
226
227    /// Get the number of parameters
228    pub fn len(&self) -> usize {
229        self.parameters.len()
230    }
231
232    /// Check if the parameters are empty
233    pub fn is_empty(&self) -> bool {
234        self.parameters.is_empty()
235    }
236
237    /// Iterate over the parameters
238    pub fn iter(&self) -> std::slice::Iter<'_, Parameter> {
239        self.parameters.iter()
240    }
241}
242
243impl IntoIterator for Parameters {
244    type Item = Parameter;
245    type IntoIter = std::vec::IntoIter<Parameter>;
246
247    fn into_iter(self) -> Self::IntoIter {
248        self.parameters.into_iter()
249    }
250}
251
252impl From<Vec<Parameter>> for Parameters {
253    fn from(parameters: Vec<Parameter>) -> Self {
254        Parameters { parameters }
255    }
256}
257
258/// This struct contains advanced options and hyperparameters
259#[derive(Debug, Deserialize, Clone, Serialize)]
260#[serde(deny_unknown_fields, default)]
261pub struct Advanced {
262    /// The minimum distance required between a candidate point and the existing grid (THETA_D)
263    ///
264    /// This is general for all non-parametric algorithms
265    pub min_distance: f64,
266    /// Maximum number of steps in Nelder-Mead optimization
267    /// This is used in the [NPOD](crate::algorithms::npod) algorithm, specifically in the [D-optimizer](crate::routines::optimization::d_optimizer)
268    pub nm_steps: usize,
269    /// Tolerance (in standard deviations) for the Nelder-Mead optimization
270    ///
271    /// This is used in the [NPOD](crate::algorithms::npod) algorithm, specifically in the [D-optimizer](crate::routines::optimization::d_optimizer)
272    pub tolerance: f64,
273}
274
275impl Default for Advanced {
276    fn default() -> Self {
277        Advanced {
278            min_distance: 1e-4,
279            nm_steps: 100,
280            tolerance: 1e-6,
281        }
282    }
283}
284
285#[derive(Debug, Deserialize, Clone, Serialize)]
286#[serde(deny_unknown_fields, default)]
287/// This struct contains the convergence criteria for the algorithm
288pub struct Convergence {
289    /// The objective function convergence criterion for the algorithm
290    ///
291    /// The objective function is the negative log likelihood
292    /// Previously referred to as THETA_G
293    pub likelihood: f64,
294    /// The PYL convergence criterion for the algorithm
295    ///
296    /// P(Y|L) represents the probability of the observation given its weighted support
297    /// Previously referred to as THETA_F
298    pub pyl: f64,
299    /// Precision convergence criterion for the algorithm
300    ///
301    /// The precision variable, sometimes referred to as `eps`, is the distance from existing points in the grid to the candidate point. A candidate point is suggested at a distance of `eps` times the range of the parameter.
302    /// For example, if the parameter `alpha` has a range of `[0.0, 1.0]`, and `eps` is `0.1`, then the candidate point will be at a distance of `0.1 * (1.0 - 0.0) = 0.1` from the existing grid point(s).
303    /// Previously referred to as THETA_E
304    pub eps: f64,
305}
306
307impl Default for Convergence {
308    fn default() -> Self {
309        Convergence {
310            likelihood: 1e-4,
311            pyl: 1e-2,
312            eps: 1e-2,
313        }
314    }
315}
316
317#[derive(Debug, Deserialize, Clone, Serialize)]
318#[serde(deny_unknown_fields, default)]
319pub struct Predictions {
320    /// The interval for which predictions are generated
321    pub idelta: f64,
322    /// The time after the last dose for which predictions are generated
323    ///
324    /// Predictions will always be generated until the last event (observation or dose) in the data.
325    /// This setting is used to generate predictions beyond the last event if the `tad` if sufficiently large.
326    /// This can be useful for generating predictions for a subject who only received a dose, but has no observations.
327    pub tad: f64,
328}
329
330impl Default for Predictions {
331    fn default() -> Self {
332        Predictions {
333            idelta: 0.12,
334            tad: 0.0,
335        }
336    }
337}
338
339impl Predictions {
340    /// Validate the prediction settings
341    pub fn validate(&self) -> Result<()> {
342        if self.idelta < 0.0 {
343            bail!("The interval for predictions must be non-negative");
344        }
345        if self.tad < 0.0 {
346            bail!("The time after dose for predictions must be non-negative");
347        }
348        Ok(())
349    }
350}
351
352/// The log level, which can be one of the following:
353/// - `TRACE`
354/// - `DEBUG`
355/// - `INFO` (Default)
356/// - `WARN`
357/// - `ERROR`
358#[derive(Debug, Deserialize, Clone, Serialize, Default)]
359pub enum LogLevel {
360    TRACE,
361    DEBUG,
362    #[default]
363    INFO,
364    WARN,
365    ERROR,
366}
367
368impl From<LogLevel> for tracing::Level {
369    fn from(log_level: LogLevel) -> tracing::Level {
370        match log_level {
371            LogLevel::TRACE => tracing::Level::TRACE,
372            LogLevel::DEBUG => tracing::Level::DEBUG,
373            LogLevel::INFO => tracing::Level::INFO,
374            LogLevel::WARN => tracing::Level::WARN,
375            LogLevel::ERROR => tracing::Level::ERROR,
376        }
377    }
378}
379
380impl AsRef<str> for LogLevel {
381    fn as_ref(&self) -> &str {
382        match self {
383            LogLevel::TRACE => "trace",
384            LogLevel::DEBUG => "debug",
385            LogLevel::INFO => "info",
386            LogLevel::WARN => "warn",
387            LogLevel::ERROR => "error",
388        }
389    }
390}
391
392impl Display for LogLevel {
393    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
394        write!(f, "{}", self.as_ref())
395    }
396}
397
398#[derive(Debug, Deserialize, Clone, Serialize)]
399#[serde(deny_unknown_fields, default)]
400pub struct Log {
401    /// The maximum log level to display, as defined by [LogLevel]
402    ///
403    /// [LogLevel] is a thin wrapper around `tracing::Level`, but can be serialized
404    pub level: LogLevel,
405    /// Should the logs be written to a file
406    ///
407    /// If true, a file will be created in the output folder with the name `log.txt`, or, if [Output::write] is false, in the current directory.
408    pub write: bool,
409    /// Define if logs should be written to stdout
410    pub stdout: bool,
411}
412
413impl Default for Log {
414    fn default() -> Self {
415        Log {
416            level: LogLevel::INFO,
417            write: false,
418            stdout: true,
419        }
420    }
421}
422
423/// Configuration for the output files
424#[derive(Debug, Deserialize, Clone, Serialize)]
425#[serde(deny_unknown_fields, default)]
426pub struct Output {
427    /// Whether to write the output files
428    pub write: bool,
429    /// The (relative) path to write the output files to
430    pub path: String,
431}
432
433impl Default for Output {
434    fn default() -> Self {
435        let path = PathBuf::from("outputs/").to_string_lossy().to_string();
436
437        Output { write: true, path }
438    }
439}
440
441pub struct SettingsBuilder<State> {
442    config: Option<Config>,
443    parameters: Option<Parameters>,
444    errormodels: Option<ErrorModels>,
445    predictions: Option<Predictions>,
446    log: Option<Log>,
447    prior: Option<Prior>,
448    output: Option<Output>,
449    convergence: Option<Convergence>,
450    advanced: Option<Advanced>,
451    _marker: std::marker::PhantomData<State>,
452}
453
// Marker traits for builder states.
//
// NOTE(review): none of these traits is implemented for the state types below,
// and no bound in this module uses them. The builder's type-state is enforced
// solely by the concrete `impl SettingsBuilder<State>` blocks.

/// Intended marker for builder states in which an algorithm has been defined.
pub trait AlgorithmDefined {}
/// Intended marker for builder states in which the parameters have been defined.
pub trait ParametersDefined {}
/// Intended marker for builder states in which the error models have been defined.
pub trait ErrorModelDefined {}

/// Builder state: nothing has been set yet (returned by `SettingsBuilder::new`).
pub struct InitialState;
/// Builder state: the algorithm has been set; the parameters are required next.
pub struct AlgorithmSet;
/// Builder state: algorithm and parameters have been set; error models are required next.
pub struct ParametersSet;
/// Builder state: all required settings are present and `build` is available.
pub struct ErrorSet;
464
465// Initial state: no algorithm set yet
466impl SettingsBuilder<InitialState> {
467    pub fn new() -> Self {
468        SettingsBuilder {
469            config: None,
470            parameters: None,
471            errormodels: None,
472            predictions: None,
473            log: None,
474            prior: None,
475            output: None,
476            convergence: None,
477            advanced: None,
478            _marker: std::marker::PhantomData,
479        }
480    }
481
482    pub fn set_algorithm(self, algorithm: Algorithm) -> SettingsBuilder<AlgorithmSet> {
483        SettingsBuilder {
484            config: Some(Config {
485                algorithm,
486                ..Config::default()
487            }),
488            parameters: self.parameters,
489            errormodels: self.errormodels,
490            predictions: self.predictions,
491            log: self.log,
492            prior: self.prior,
493            output: self.output,
494            convergence: self.convergence,
495            advanced: self.advanced,
496            _marker: std::marker::PhantomData,
497        }
498    }
499}
500
501impl Default for SettingsBuilder<InitialState> {
502    fn default() -> Self {
503        SettingsBuilder::new()
504    }
505}
506
507// Algorithm is set, move to defining parameters
508impl SettingsBuilder<AlgorithmSet> {
509    pub fn set_parameters(self, parameters: Parameters) -> SettingsBuilder<ParametersSet> {
510        SettingsBuilder {
511            config: self.config,
512            parameters: Some(parameters),
513            errormodels: self.errormodels,
514            predictions: self.predictions,
515            log: self.log,
516            prior: self.prior,
517            output: self.output,
518            convergence: self.convergence,
519            advanced: self.advanced,
520            _marker: std::marker::PhantomData,
521        }
522    }
523}
524
525// Parameters are set, move to defining error model
526impl SettingsBuilder<ParametersSet> {
527    pub fn set_error_models(self, ems: ErrorModels) -> SettingsBuilder<ErrorSet> {
528        SettingsBuilder {
529            config: self.config,
530            parameters: self.parameters,
531            errormodels: Some(ems),
532            predictions: self.predictions,
533            log: self.log,
534            prior: self.prior,
535            output: self.output,
536            convergence: self.convergence,
537            advanced: self.advanced,
538            _marker: std::marker::PhantomData,
539        }
540    }
541}
542
543// Error model is set, allow optional settings and final build
544impl SettingsBuilder<ErrorSet> {
545    pub fn build(self) -> Settings {
546        Settings {
547            config: self.config.unwrap(),
548            parameters: self.parameters.unwrap(),
549            errormodels: self.errormodels.unwrap(),
550            predictions: self.predictions.unwrap_or_default(),
551            log: self.log.unwrap_or_default(),
552            prior: self.prior.unwrap_or_default(),
553            output: self.output.unwrap_or_default(),
554            convergence: self.convergence.unwrap_or_default(),
555            advanced: self.advanced.unwrap_or_default(),
556        }
557    }
558}
559
/// Resolve a `#` placeholder in an output folder path.
///
/// A path without `#` is returned unchanged. Otherwise every `#` is replaced by
/// the smallest positive integer for which the resulting path does not exist
/// yet. For example, `outputs/run#` becomes `outputs/run1`, or `outputs/run2` if
/// `outputs/run1` already exists.
fn parse_output_folder(path: String) -> String {
    if !path.contains('#') {
        return path;
    }

    let candidate_for = |n: i32| path.replace('#', &n.to_string());

    let mut n: i32 = 1;
    loop {
        let candidate = candidate_for(n);
        if !std::path::Path::new(&candidate).exists() {
            break candidate;
        }
        n += 1;
    }
}
574
#[cfg(test)]
mod tests {
    use pharmsol::{ErrorModel, ErrorPoly};

    use super::*;
    use crate::algorithms::Algorithm;

    #[test]
    fn test_builder() {
        let parameters = Parameters::new().add("Ke", 0.0, 5.0).add("V", 10.0, 200.0);

        let ems = ErrorModels::new()
            .add(
                0,
                ErrorModel::Proportional {
                    gamma: 5.0,
                    poly: ErrorPoly::new(0.0, 0.1, 0.0, 0.0),
                },
            )
            .unwrap();
        let mut settings = SettingsBuilder::new()
            .set_algorithm(Algorithm::NPAG) // Step 1: Define algorithm
            .set_parameters(parameters) // Step 2: Define parameters
            .set_error_models(ems) // Step 3: Define error models
            .build();

        // Use a non-default value so the assertion actually exercises the setter
        // (the default is 100 cycles).
        settings.set_cycles(200);

        assert_eq!(settings.config.algorithm, Algorithm::NPAG);
        assert_eq!(settings.config.cycles, 200);
        assert!(settings.config.cache);
        assert_eq!(settings.parameters().names(), vec!["Ke", "V"]);
        assert_eq!(
            settings.parameters().ranges(),
            vec![(0.0, 5.0), (10.0, 200.0)]
        );
    }

    #[test]
    fn test_parameters_get() {
        let parameters = Parameters::new().add("Ke", 0.0, 5.0);

        assert_eq!(
            parameters.get("Ke"),
            Some(&Parameter::new("Ke", 0.0, 5.0))
        );
        assert!(parameters.get("V").is_none());
    }

    #[test]
    fn test_parse_output_folder_without_hash() {
        assert_eq!(parse_output_folder("outputs/".to_string()), "outputs/");
    }
}