pmcore/routines/
settings.rs

1use crate::algorithms::Algorithm;
2use crate::routines::initialization::Prior;
3use crate::routines::output::OutputFile;
4use anyhow::{bail, Result};
5use pharmsol::prelude::data::ErrorModel;
6use pharmsol::ErrorPoly;
7use serde::{Deserialize, Serialize};
8use serde_json;
9use std::fmt::Display;
10use std::path::PathBuf;
11
12/// Contains all settings for PMcore
13#[derive(Debug, Deserialize, Clone, Serialize)]
14#[serde(deny_unknown_fields)]
15pub struct Settings {
16    /// General configuration settings
17    pub(crate) config: Config,
18    /// Parameters to be estimated
19    pub(crate) parameters: Parameters,
20    /// Defines the error model and polynomial to be used
21    pub(crate) error: Error,
22    /// Configuration for predictions
23    pub(crate) predictions: Predictions,
24    /// Configuration for logging
25    pub(crate) log: Log,
26    /// Configuration for (optional) prior
27    pub(crate) prior: Prior,
28    /// Configuration for the output files
29    pub(crate) output: Output,
30    /// Configuration for the convergence criteria
31    pub(crate) convergence: Convergence,
32    /// Advanced options, mostly hyperparameters, for the algorithm(s)
33    pub(crate) advanced: Advanced,
34}
35
36impl Settings {
37    /// Create a new [SettingsBuilder]
38    pub fn builder() -> SettingsBuilder<InitialState> {
39        SettingsBuilder::new()
40    }
41
42    /// Validate the settings
43    pub fn validate(&self) -> Result<()> {
44        self.error.validate()?;
45        self.predictions.validate()?;
46        Ok(())
47    }
48
49    /* Getters */
50    pub fn config(&self) -> &Config {
51        &self.config
52    }
53
54    pub fn parameters(&self) -> &Parameters {
55        &self.parameters
56    }
57
58    pub fn error(&self) -> &Error {
59        &self.error
60    }
61
62    pub fn predictions(&self) -> &Predictions {
63        &self.predictions
64    }
65
66    pub fn log(&self) -> &Log {
67        &self.log
68    }
69
70    pub fn prior(&self) -> &Prior {
71        &self.prior
72    }
73
74    pub fn output(&self) -> &Output {
75        &self.output
76    }
77    pub fn convergence(&self) -> &Convergence {
78        &self.convergence
79    }
80
81    pub fn advanced(&self) -> &Advanced {
82        &self.advanced
83    }
84
85    /* Setters */
86    pub fn set_cycles(&mut self, cycles: usize) {
87        self.config.cycles = cycles;
88    }
89
90    pub fn set_algorithm(&mut self, algorithm: Algorithm) {
91        self.config.algorithm = algorithm;
92    }
93
94    pub fn set_cache(&mut self, cache: bool) {
95        self.config.cache = cache;
96    }
97
98    pub fn set_idelta(&mut self, idelta: f64) {
99        self.predictions.idelta = idelta;
100    }
101
102    pub fn set_tad(&mut self, tad: f64) {
103        self.predictions.tad = tad;
104    }
105
106    pub fn set_prior(&mut self, prior: Prior) {
107        self.prior = prior;
108    }
109
110    pub fn disable_output(&mut self) {
111        self.output.write = false;
112    }
113
114    pub fn set_output_path(&mut self, path: impl Into<String>) {
115        self.output.path = parse_output_folder(path.into());
116    }
117
118    pub fn set_log_stdout(&mut self, stdout: bool) {
119        self.log.stdout = stdout;
120    }
121
122    pub fn set_write_logs(&mut self, write: bool) {
123        self.log.write = write;
124    }
125
126    pub fn set_log_level(&mut self, level: LogLevel) {
127        self.log.level = level;
128    }
129
130    pub fn set_progress(&mut self, progress: bool) {
131        self.config.progress = progress;
132    }
133
134    pub fn initialize_logs(&mut self) -> Result<()> {
135        crate::routines::logger::setup_log(self)
136    }
137
138    /// Writes a copy of the settings to file
139    /// The is written to output folder specified in the [Output] and is named `settings.json`.
140    pub fn write(&self) -> Result<()> {
141        let serialized = serde_json::to_string_pretty(self).map_err(std::io::Error::other)?;
142
143        let outputfile = OutputFile::new(self.output.path.as_str(), "settings.json")?;
144        let mut file = outputfile.file;
145        std::io::Write::write_all(&mut file, serialized.as_bytes())?;
146        Ok(())
147    }
148}
149
150/// General configuration settings
151#[derive(Debug, Deserialize, Clone, Serialize)]
152#[serde(deny_unknown_fields, default)]
153pub struct Config {
154    /// Maximum number of cycles to run
155    pub cycles: usize,
156    /// Denotes the algorithm to use
157    pub algorithm: Algorithm,
158    /// If true (default), cache predicted values
159    pub cache: bool,
160    /// Should a progress bar be displayed for the first cycle
161    ///
162    /// The progress bar is not written to logs, but is written to stdout. It incurs a minor performance penalty.
163    pub progress: bool,
164}
165
166impl Default for Config {
167    fn default() -> Self {
168        Config {
169            cycles: 100,
170            algorithm: Algorithm::NPAG,
171            cache: true,
172            progress: true,
173        }
174    }
175}
176
177/// Defines a parameter to be estimated
178///
179/// In non-parametric algorithms, parameters must be bounded. The lower and upper bounds are defined by the `lower` and `upper` fields, respectively.
180#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
181pub struct Parameter {
182    pub(crate) name: String,
183    pub(crate) lower: f64,
184    pub(crate) upper: f64,
185}
186
187impl Parameter {
188    /// Create a new parameter
189    pub fn new(name: impl Into<String>, lower: f64, upper: f64) -> Self {
190        Self {
191            name: name.into(),
192            lower,
193            upper,
194        }
195    }
196}
197
198/// This structure contains information on all [Parameter]s to be estimated
199#[derive(Debug, Clone, Deserialize, Serialize, Default, PartialEq)]
200pub struct Parameters {
201    pub(crate) parameters: Vec<Parameter>,
202}
203
204impl Parameters {
205    pub fn new() -> Self {
206        Parameters {
207            parameters: Vec::new(),
208        }
209    }
210
211    pub fn add(mut self, name: impl Into<String>, lower: f64, upper: f64) -> Parameters {
212        let parameter = Parameter::new(name, lower, upper);
213        self.parameters.push(parameter);
214        self
215    }
216
217    // Get a parameter by name
218    pub fn get(&self, name: impl Into<String>) -> Option<&Parameter> {
219        let name = name.into();
220        self.parameters.iter().find(|p| p.name == name)
221    }
222
223    /// Get the names of the parameters
224    pub fn names(&self) -> Vec<String> {
225        self.parameters.iter().map(|p| p.name.clone()).collect()
226    }
227    /// Get the ranges of the parameters
228    ///
229    /// Returns a vector of tuples, where each tuple contains the lower and upper bounds of the parameter
230    pub fn ranges(&self) -> Vec<(f64, f64)> {
231        self.parameters.iter().map(|p| (p.lower, p.upper)).collect()
232    }
233
234    /// Get the number of parameters
235    pub fn len(&self) -> usize {
236        self.parameters.len()
237    }
238
239    /// Check if the parameters are empty
240    pub fn is_empty(&self) -> bool {
241        self.parameters.is_empty()
242    }
243
244    /// Iterate over the parameters
245    pub fn iter(&self) -> std::slice::Iter<'_, Parameter> {
246        self.parameters.iter()
247    }
248}
249
250impl IntoIterator for Parameters {
251    type Item = Parameter;
252    type IntoIter = std::vec::IntoIter<Parameter>;
253
254    fn into_iter(self) -> Self::IntoIter {
255        self.parameters.into_iter()
256    }
257}
258
259impl From<Vec<Parameter>> for Parameters {
260    fn from(parameters: Vec<Parameter>) -> Self {
261        Parameters { parameters }
262    }
263}
264
265#[derive(Debug, Deserialize, Clone, Serialize)]
266pub enum ErrorType {
267    Additive,
268    Proportional,
269}
270
271/// Defines the error model and polynomial to be used
272#[derive(Debug, Deserialize, Clone, Serialize)]
273#[serde(deny_unknown_fields, default)]
274pub struct Error {
275    /// The initial value of `gamma` or `lambda`
276    pub value: f64,
277    /// The error class, either `additive` or `proportional`
278    pub errortype: ErrorType,
279    /// The assay error polynomial
280    pub poly: (f64, f64, f64, f64),
281}
282
283impl Default for Error {
284    fn default() -> Self {
285        Error {
286            value: 0.0,
287            errortype: ErrorType::Additive,
288            poly: (0.0, 0.1, 0.0, 0.0),
289        }
290    }
291}
292
293impl Error {
294    fn new(value: f64, errortype: ErrorType, poly: (f64, f64, f64, f64)) -> Self {
295        Error {
296            value,
297            errortype,
298            poly,
299        }
300    }
301
302    fn validate(&self) -> Result<()> {
303        if self.value < 0.0 {
304            bail!("The initial value of gamma or lambda must be non-negative");
305        }
306        Ok(())
307    }
308}
309
310impl From<Error> for pharmsol::prelude::data::ErrorModel {
311    fn from(error: Error) -> Self {
312        match error.errortype {
313            ErrorType::Additive => ErrorModel::additive(
314                ErrorPoly::new(error.poly.0, error.poly.1, error.poly.2, error.poly.3),
315                error.value,
316            ),
317            ErrorType::Proportional => ErrorModel::proportional(
318                ErrorPoly::new(error.poly.0, error.poly.1, error.poly.2, error.poly.3),
319                error.value,
320            ),
321        }
322    }
323}
324
325/// This struct contains advanced options and hyperparameters
326#[derive(Debug, Deserialize, Clone, Serialize)]
327#[serde(deny_unknown_fields, default)]
328pub struct Advanced {
329    /// The minimum distance required between a candidate point and the existing grid (THETA_D)
330    ///
331    /// This is general for all non-parametric algorithms
332    pub min_distance: f64,
333    /// Maximum number of steps in Nelder-Mead optimization
334    /// This is used in the [NPOD](crate::algorithms::npod) algorithm, specifically in the [D-optimizer](crate::routines::optimization::d_optimizer)
335    pub nm_steps: usize,
336    /// Tolerance (in standard deviations) for the Nelder-Mead optimization
337    ///
338    /// This is used in the [NPOD](crate::algorithms::npod) algorithm, specifically in the [D-optimizer](crate::routines::optimization::d_optimizer)
339    pub tolerance: f64,
340}
341
342impl Default for Advanced {
343    fn default() -> Self {
344        Advanced {
345            min_distance: 1e-4,
346            nm_steps: 100,
347            tolerance: 1e-6,
348        }
349    }
350}
351
352#[derive(Debug, Deserialize, Clone, Serialize)]
353#[serde(deny_unknown_fields, default)]
354/// This struct contains the convergence criteria for the algorithm
355pub struct Convergence {
356    /// The objective function convergence criterion for the algorithm
357    ///
358    /// The objective function is the negative log likelihood
359    /// Previously referred to as THETA_G
360    pub likelihood: f64,
361    /// The PYL convergence criterion for the algorithm
362    ///
363    /// P(Y|L) represents the probability of the observation given its weighted support
364    /// Previously referred to as THETA_F
365    pub pyl: f64,
366    /// Precision convergence criterion for the algorithm
367    ///
368    /// The precision variable, sometimes referred to as `eps`, is the distance from existing points in the grid to the candidate point. A candidate point is suggested at a distance of `eps` times the range of the parameter.
369    /// For example, if the parameter `alpha` has a range of `[0.0, 1.0]`, and `eps` is `0.1`, then the candidate point will be at a distance of `0.1 * (1.0 - 0.0) = 0.1` from the existing grid point(s).
370    /// Previously referred to as THETA_E
371    pub eps: f64,
372}
373
374impl Default for Convergence {
375    fn default() -> Self {
376        Convergence {
377            likelihood: 1e-4,
378            pyl: 1e-2,
379            eps: 1e-2,
380        }
381    }
382}
383
384#[derive(Debug, Deserialize, Clone, Serialize)]
385#[serde(deny_unknown_fields, default)]
386pub struct Predictions {
387    /// The interval for which predictions are generated
388    pub idelta: f64,
389    /// The time after the last dose for which predictions are generated
390    ///
391    /// Predictions will always be generated until the last event (observation or dose) in the data.
392    /// This setting is used to generate predictions beyond the last event if the `tad` if sufficiently large.
393    /// This can be useful for generating predictions for a subject who only received a dose, but has no observations.
394    pub tad: f64,
395}
396
397impl Default for Predictions {
398    fn default() -> Self {
399        Predictions {
400            idelta: 0.12,
401            tad: 0.0,
402        }
403    }
404}
405
406impl Predictions {
407    /// Validate the prediction settings
408    pub fn validate(&self) -> Result<()> {
409        if self.idelta < 0.0 {
410            bail!("The interval for predictions must be non-negative");
411        }
412        if self.tad < 0.0 {
413            bail!("The time after dose for predictions must be non-negative");
414        }
415        Ok(())
416    }
417}
418
419/// The log level, which can be one of the following:
420/// - `TRACE`
421/// - `DEBUG`
422/// - `INFO` (Default)
423/// - `WARN`
424/// - `ERROR`
425#[derive(Debug, Deserialize, Clone, Serialize, Default)]
426pub enum LogLevel {
427    TRACE,
428    DEBUG,
429    #[default]
430    INFO,
431    WARN,
432    ERROR,
433}
434
435impl From<LogLevel> for tracing::Level {
436    fn from(log_level: LogLevel) -> tracing::Level {
437        match log_level {
438            LogLevel::TRACE => tracing::Level::TRACE,
439            LogLevel::DEBUG => tracing::Level::DEBUG,
440            LogLevel::INFO => tracing::Level::INFO,
441            LogLevel::WARN => tracing::Level::WARN,
442            LogLevel::ERROR => tracing::Level::ERROR,
443        }
444    }
445}
446
447impl AsRef<str> for LogLevel {
448    fn as_ref(&self) -> &str {
449        match self {
450            LogLevel::TRACE => "trace",
451            LogLevel::DEBUG => "debug",
452            LogLevel::INFO => "info",
453            LogLevel::WARN => "warn",
454            LogLevel::ERROR => "error",
455        }
456    }
457}
458
459impl Display for LogLevel {
460    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
461        write!(f, "{}", self.as_ref())
462    }
463}
464
465#[derive(Debug, Deserialize, Clone, Serialize)]
466#[serde(deny_unknown_fields, default)]
467pub struct Log {
468    /// The maximum log level to display, as defined by [LogLevel]
469    ///
470    /// [LogLevel] is a thin wrapper around `tracing::Level`, but can be serialized
471    pub level: LogLevel,
472    /// Should the logs be written to a file
473    ///
474    /// If true, a file will be created in the output folder with the name `log.txt`, or, if [Output::write] is false, in the current directory.
475    pub write: bool,
476    /// Define if logs should be written to stdout
477    pub stdout: bool,
478}
479
480impl Default for Log {
481    fn default() -> Self {
482        Log {
483            level: LogLevel::INFO,
484            write: false,
485            stdout: true,
486        }
487    }
488}
489
490/// Configuration for the output files
491#[derive(Debug, Deserialize, Clone, Serialize)]
492#[serde(deny_unknown_fields, default)]
493pub struct Output {
494    /// Whether to write the output files
495    pub write: bool,
496    /// The (relative) path to write the output files to
497    pub path: String,
498}
499
500impl Default for Output {
501    fn default() -> Self {
502        let path = PathBuf::from("outputs/").to_string_lossy().to_string();
503
504        Output { write: true, path }
505    }
506}
507
/// Typestate builder for [Settings].
///
/// `State` is one of [InitialState], [AlgorithmSet], [ParametersSet] or [ErrorSet]. Each state's
/// impl block only exposes the next required step, so `build` is reachable only after the
/// algorithm, the parameters and the error model have all been set.
///
/// The remaining sections (predictions, log, prior, output, convergence, advanced) have no
/// setters on the builder in this file, so `build` fills them with their `Default` values.
pub struct SettingsBuilder<State> {
    config: Option<Config>,
    parameters: Option<Parameters>,
    error: Option<Error>,
    predictions: Option<Predictions>,
    log: Option<Log>,
    prior: Option<Prior>,
    output: Option<Output>,
    convergence: Option<Convergence>,
    advanced: Option<Advanced>,
    // Zero-sized marker that carries the builder state at the type level only.
    _marker: std::marker::PhantomData<State>,
}
520
// Marker traits for builder states
// NOTE(review): nothing in this file implements these traits or uses them as bounds; they look
// like unused placeholders — confirm before relying on them.
pub trait AlgorithmDefined {}
pub trait ParametersDefined {}
pub trait ErrorModelDefined {}

// Typestate marker types used as the `State` parameter of `SettingsBuilder`.
/// Builder state before an algorithm has been chosen.
pub struct InitialState;
/// Builder state after the algorithm is set; the parameters are required next.
pub struct AlgorithmSet;
/// Builder state after the parameters are set; the error model is required next.
pub struct ParametersSet;
/// Builder state after the error model is set; `build` is available.
pub struct ErrorSet;
531
532// Initial state: no algorithm set yet
533impl SettingsBuilder<InitialState> {
534    pub fn new() -> Self {
535        SettingsBuilder {
536            config: None,
537            parameters: None,
538            error: None,
539            predictions: None,
540            log: None,
541            prior: None,
542            output: None,
543            convergence: None,
544            advanced: None,
545            _marker: std::marker::PhantomData,
546        }
547    }
548
549    pub fn set_algorithm(self, algorithm: Algorithm) -> SettingsBuilder<AlgorithmSet> {
550        SettingsBuilder {
551            config: Some(Config {
552                algorithm,
553                ..Config::default()
554            }),
555            parameters: self.parameters,
556            error: self.error,
557            predictions: self.predictions,
558            log: self.log,
559            prior: self.prior,
560            output: self.output,
561            convergence: self.convergence,
562            advanced: self.advanced,
563            _marker: std::marker::PhantomData,
564        }
565    }
566}
567
568impl Default for SettingsBuilder<InitialState> {
569    fn default() -> Self {
570        SettingsBuilder::new()
571    }
572}
573
574// Algorithm is set, move to defining parameters
575impl SettingsBuilder<AlgorithmSet> {
576    pub fn set_parameters(self, parameters: Parameters) -> SettingsBuilder<ParametersSet> {
577        SettingsBuilder {
578            config: self.config,
579            parameters: Some(parameters),
580            error: self.error,
581            predictions: self.predictions,
582            log: self.log,
583            prior: self.prior,
584            output: self.output,
585            convergence: self.convergence,
586            advanced: self.advanced,
587            _marker: std::marker::PhantomData,
588        }
589    }
590}
591
592// Parameters are set, move to defining error model
593impl SettingsBuilder<ParametersSet> {
594    pub fn set_error_model(
595        self,
596        errortype: ErrorType,
597        value: f64,
598        poly: (f64, f64, f64, f64),
599    ) -> SettingsBuilder<ErrorSet> {
600        let error = Error::new(value, errortype, poly);
601
602        SettingsBuilder {
603            config: self.config,
604            parameters: self.parameters,
605            error: Some(error),
606            predictions: self.predictions,
607            log: self.log,
608            prior: self.prior,
609            output: self.output,
610            convergence: self.convergence,
611            advanced: self.advanced,
612            _marker: std::marker::PhantomData,
613        }
614    }
615}
616
617// Error model is set, allow optional settings and final build
618impl SettingsBuilder<ErrorSet> {
619    pub fn build(self) -> Settings {
620        Settings {
621            config: self.config.unwrap(),
622            parameters: self.parameters.unwrap(),
623            error: self.error.unwrap(),
624            predictions: self.predictions.unwrap_or_default(),
625            log: self.log.unwrap_or_default(),
626            prior: self.prior.unwrap_or_default(),
627            output: self.output.unwrap_or_default(),
628            convergence: self.convergence.unwrap_or_default(),
629            advanced: self.advanced.unwrap_or_default(),
630        }
631    }
632}
633
/// Resolve an output-folder template.
///
/// A path without `#` is returned unchanged. Otherwise every `#` is replaced by the counter
/// 1, 2, 3, ..., and the first resulting path that does not exist on disk yet is returned.
fn parse_output_folder(path: String) -> String {
    // No placeholder: use the path verbatim.
    if !path.contains('#') {
        return path;
    }

    let mut counter = 1;
    loop {
        let candidate = path.replace('#', &counter.to_string());
        if !std::path::Path::new(&candidate).exists() {
            break candidate;
        }
        counter += 1;
    }
}
648
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algorithms::Algorithm;

    /// Settings with two parameters and an additive error model, built through the builder.
    fn example_settings() -> Settings {
        let parameters = Parameters::new().add("Ke", 0.0, 5.0).add("V", 10.0, 200.0);

        SettingsBuilder::new()
            .set_algorithm(Algorithm::NPAG) // Step 1: Define algorithm
            .set_parameters(parameters) // Step 2: Define parameters
            .set_error_model(ErrorType::Additive, 5.0, (0.0, 0.1, 0.0, 0.0)) // Step 3: Error model
            .build()
    }

    #[test]
    fn test_builder() {
        let mut settings = example_settings();

        settings.set_cycles(100);

        assert_eq!(settings.config.algorithm, Algorithm::NPAG);
        assert_eq!(settings.config.cycles, 100);
        assert!(settings.config.cache);
        assert_eq!(settings.parameters().names(), vec!["Ke", "V"]);
    }

    #[test]
    fn test_built_settings_are_valid() {
        assert!(example_settings().validate().is_ok());
    }

    #[test]
    fn test_negative_idelta_is_rejected() {
        let mut settings = example_settings();
        settings.set_idelta(-1.0);
        assert!(settings.validate().is_err());
    }

    #[test]
    fn test_parameter_lookup_and_ranges() {
        let parameters = Parameters::new().add("Ke", 0.0, 5.0).add("V", 10.0, 200.0);

        assert_eq!(parameters.len(), 2);
        assert!(!parameters.is_empty());
        assert_eq!(parameters.ranges(), vec![(0.0, 5.0), (10.0, 200.0)]);
        assert_eq!(parameters.get("V"), Some(&Parameter::new("V", 10.0, 200.0)));
        assert!(parameters.get("missing").is_none());
    }

    #[test]
    fn test_log_level_display() {
        assert_eq!(LogLevel::INFO.to_string(), "info");
        assert_eq!(LogLevel::TRACE.to_string(), "trace");
        assert_eq!(LogLevel::ERROR.to_string(), "error");
    }

    #[test]
    fn test_parse_output_folder_without_placeholder() {
        assert_eq!(parse_output_folder("outputs/".to_string()), "outputs/");
    }
}