pmcore/routines/
settings.rs

1use crate::algorithms::Algorithm;
2use crate::routines::initialization::Prior;
3use crate::routines::output::OutputFile;
4use anyhow::{bail, Result};
5use pharmsol::prelude::data::ErrorType;
6use serde::{Deserialize, Serialize};
7use serde_json;
8use std::fmt::Display;
9use std::path::PathBuf;
10
/// Contains all settings for PMcore
///
/// The settings are grouped into sections (one field per section). Instances
/// are typically created with [Settings::builder], or deserialized with serde.
/// When deserializing, unknown fields are rejected (`deny_unknown_fields`).
///
/// The fields are crate-private; use the getter and setter methods on
/// [Settings] to access them from outside the crate.
#[derive(Debug, Deserialize, Clone, Serialize)]
#[serde(deny_unknown_fields)]
pub struct Settings {
    /// General configuration settings
    pub(crate) config: Config,
    /// Parameters to be estimated
    pub(crate) parameters: Parameters,
    /// Defines the error model and polynomial to be used
    pub(crate) error: Error,
    /// Configuration for predictions
    pub(crate) predictions: Predictions,
    /// Configuration for logging
    pub(crate) log: Log,
    /// Configuration for (optional) prior
    pub(crate) prior: Prior,
    /// Configuration for the output files
    pub(crate) output: Output,
    /// Configuration for the convergence criteria
    pub(crate) convergence: Convergence,
    /// Advanced options, mostly hyperparameters, for the algorithm(s)
    pub(crate) advanced: Advanced,
}
34
35impl Settings {
36    /// Create a new [SettingsBuilder]
37    pub fn builder() -> SettingsBuilder<InitialState> {
38        SettingsBuilder::new()
39    }
40
41    /// Validate the settings
42    pub fn validate(&self) -> Result<()> {
43        self.error.validate()?;
44        self.predictions.validate()?;
45        Ok(())
46    }
47
48    /* Getters */
49    pub fn config(&self) -> &Config {
50        &self.config
51    }
52
53    pub fn parameters(&self) -> &Parameters {
54        &self.parameters
55    }
56
57    pub fn error(&self) -> &Error {
58        &self.error
59    }
60
61    pub fn predictions(&self) -> &Predictions {
62        &self.predictions
63    }
64
65    pub fn log(&self) -> &Log {
66        &self.log
67    }
68
69    pub fn prior(&self) -> &Prior {
70        &self.prior
71    }
72
73    pub fn output(&self) -> &Output {
74        &self.output
75    }
76    pub fn convergence(&self) -> &Convergence {
77        &self.convergence
78    }
79
80    pub fn advanced(&self) -> &Advanced {
81        &self.advanced
82    }
83
84    /* Setters */
85    pub fn set_cycles(&mut self, cycles: usize) {
86        self.config.cycles = cycles;
87    }
88
89    pub fn set_algorithm(&mut self, algorithm: Algorithm) {
90        self.config.algorithm = algorithm;
91    }
92
93    pub fn set_cache(&mut self, cache: bool) {
94        self.config.cache = cache;
95    }
96
97    pub fn set_idelta(&mut self, idelta: f64) {
98        self.predictions.idelta = idelta;
99    }
100
101    pub fn set_tad(&mut self, tad: f64) {
102        self.predictions.tad = tad;
103    }
104
105    pub fn set_prior(&mut self, prior: Prior) {
106        self.prior = prior;
107    }
108
109    pub fn disable_output(&mut self) {
110        self.output.write = false;
111    }
112
113    pub fn set_output_path(&mut self, path: impl Into<String>) {
114        self.output.path = parse_output_folder(path.into());
115    }
116
117    pub fn set_log_stdout(&mut self, stdout: bool) {
118        self.log.stdout = stdout;
119    }
120
121    pub fn set_write_logs(&mut self, write: bool) {
122        self.log.write = write;
123    }
124
125    pub fn set_log_level(&mut self, level: LogLevel) {
126        self.log.level = level;
127    }
128
129    pub fn set_progress(&mut self, progress: bool) {
130        self.config.progress = progress;
131    }
132
133    pub fn initialize_logs(&mut self) -> Result<()> {
134        crate::routines::logger::setup_log(self)
135    }
136
137    /// Writes a copy of the settings to file
138    /// The is written to output folder specified in the [Output] and is named `settings.json`.
139    pub fn write(&self) -> Result<()> {
140        let serialized = serde_json::to_string_pretty(self)
141            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
142
143        let outputfile = OutputFile::new(self.output.path.as_str(), "settings.json")?;
144        let mut file = outputfile.file;
145        std::io::Write::write_all(&mut file, serialized.as_bytes())?;
146        Ok(())
147    }
148}
149
/// General configuration settings
///
/// Fields missing during deserialization take their value from the
/// `Default` implementation of this struct; unknown fields are rejected.
#[derive(Debug, Deserialize, Clone, Serialize)]
#[serde(deny_unknown_fields, default)]
pub struct Config {
    /// Maximum number of cycles to run
    pub cycles: usize,
    /// Denotes the algorithm to use
    pub algorithm: Algorithm,
    /// If true (default), cache predicted values
    pub cache: bool,
    /// Should a progress bar be displayed for the first cycle
    ///
    /// The progress bar is not written to logs, but is written to stdout. It incurs a minor performance penalty.
    pub progress: bool,
}
165
166impl Default for Config {
167    fn default() -> Self {
168        Config {
169            cycles: 100,
170            algorithm: Algorithm::NPAG,
171            cache: true,
172            progress: true,
173        }
174    }
175}
176
/// Defines a parameter to be estimated
///
/// In non-parametric algorithms, parameters must be bounded. The lower and upper bounds are defined by the `lower` and `upper` fields, respectively.
/// Fixed parameters are unknown, but common among all subjects.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Parameter {
    /// Name of the parameter
    pub(crate) name: String,
    /// Lower bound of the parameter
    pub(crate) lower: f64,
    /// Upper bound of the parameter
    pub(crate) upper: f64,
    /// Whether the parameter is fixed (see the struct-level documentation)
    pub(crate) fixed: bool,
}
188
189impl Parameter {
190    /// Create a new parameter
191    pub fn new(name: impl Into<String>, lower: f64, upper: f64, fixed: bool) -> Self {
192        Self {
193            name: name.into(),
194            lower,
195            upper,
196            fixed,
197        }
198    }
199}
200
/// This structure contains information on all [Parameter]s to be estimated
///
/// The parameters are kept in insertion order.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct Parameters {
    /// The parameters, in the order they were added
    pub(crate) parameters: Vec<Parameter>,
}
206
207impl Parameters {
208    pub fn new() -> Self {
209        Parameters {
210            parameters: Vec::new(),
211        }
212    }
213
214    pub fn add(
215        mut self,
216        name: impl Into<String>,
217        lower: f64,
218        upper: f64,
219        fixed: bool,
220    ) -> Parameters {
221        let parameter = Parameter::new(name, lower, upper, fixed);
222        self.parameters.push(parameter);
223        self
224    }
225
226    // Get a parameter by name
227    pub fn get(&self, name: impl Into<String>) -> Option<&Parameter> {
228        let name = name.into();
229        self.parameters.iter().find(|p| p.name == name)
230    }
231
232    /// Get the names of the parameters
233    pub fn names(&self) -> Vec<String> {
234        self.parameters.iter().map(|p| p.name.clone()).collect()
235    }
236    /// Get the ranges of the parameters
237    ///
238    /// Returns a vector of tuples, where each tuple contains the lower and upper bounds of the parameter
239    pub fn ranges(&self) -> Vec<(f64, f64)> {
240        self.parameters.iter().map(|p| (p.lower, p.upper)).collect()
241    }
242    pub fn len(&self) -> usize {
243        self.parameters.len()
244    }
245
246    pub fn is_empty(&self) -> bool {
247        self.parameters.is_empty()
248    }
249
250    pub fn iter(&self) -> std::slice::Iter<'_, Parameter> {
251        self.parameters.iter()
252    }
253}
254
255impl IntoIterator for Parameters {
256    type Item = Parameter;
257    type IntoIter = std::vec::IntoIter<Parameter>;
258
259    fn into_iter(self) -> Self::IntoIter {
260        self.parameters.into_iter()
261    }
262}
263
264impl From<Vec<Parameter>> for Parameters {
265    fn from(parameters: Vec<Parameter>) -> Self {
266        Parameters { parameters }
267    }
268}
269
/// The class of error model
///
/// Converts into the corresponding `ErrorType` via `From<ErrorModel>`.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub enum ErrorModel {
    /// Additive error model (converts to `ErrorType::Add`)
    Additive,
    /// Proportional error model (converts to `ErrorType::Prop`)
    Proportional,
}
275
276impl From<ErrorModel> for ErrorType {
277    fn from(error_model: ErrorModel) -> ErrorType {
278        match error_model {
279            ErrorModel::Additive => ErrorType::Add,
280            ErrorModel::Proportional => ErrorType::Prop,
281        }
282    }
283}
284
/// Defines the error model and polynomial to be used
///
/// Fields missing during deserialization take their value from the
/// `Default` implementation of this struct; unknown fields are rejected.
#[derive(Debug, Deserialize, Clone, Serialize)]
#[serde(deny_unknown_fields, default)]
pub struct Error {
    /// The initial value of `gamma` or `lambda`
    ///
    /// Validation requires this value to be non-negative.
    pub value: f64,
    /// The error class, either [ErrorModel::Additive] or [ErrorModel::Proportional]
    pub model: ErrorModel,
    /// The assay error polynomial
    // NOTE(review): presumably the four coefficients of the polynomial in
    // ascending order — confirm against the code that consumes `poly`.
    pub poly: (f64, f64, f64, f64),
}
296
297impl Default for Error {
298    fn default() -> Self {
299        Error {
300            value: 0.0,
301            model: ErrorModel::Additive,
302            poly: (0.0, 0.1, 0.0, 0.0),
303        }
304    }
305}
306
307impl Error {
308    fn new(value: f64, model: ErrorModel, poly: (f64, f64, f64, f64)) -> Self {
309        Error { value, model, poly }
310    }
311
312    fn validate(&self) -> Result<()> {
313        if self.value < 0.0 {
314            bail!(format!(
315                "Error value must be non-negative, got {}",
316                self.value
317            ));
318        }
319        Ok(())
320    }
321
322    pub fn error_model(&self) -> ErrorModel {
323        self.model.clone()
324    }
325}
326
/// This struct contains advanced options and hyperparameters
///
/// Fields missing during deserialization take their value from the
/// `Default` implementation of this struct; unknown fields are rejected.
#[derive(Debug, Deserialize, Clone, Serialize)]
#[serde(deny_unknown_fields, default)]
pub struct Advanced {
    /// The minimum distance required between a candidate point and the existing grid (THETA_D)
    ///
    /// This is general for all non-parametric algorithms
    pub min_distance: f64,
    /// Maximum number of steps in Nelder-Mead optimization
    ///
    /// This is used in the [NPOD](crate::algorithms::npod) algorithm, specifically in the [D-optimizer](crate::routines::optimization::d_optimizer)
    pub nm_steps: usize,
    /// Tolerance (in standard deviations) for the Nelder-Mead optimization
    ///
    /// This is used in the [NPOD](crate::algorithms::npod) algorithm, specifically in the [D-optimizer](crate::routines::optimization::d_optimizer)
    pub tolerance: f64,
}
343
344impl Default for Advanced {
345    fn default() -> Self {
346        Advanced {
347            min_distance: 1e-4,
348            nm_steps: 100,
349            tolerance: 1e-6,
350        }
351    }
352}
353
#[derive(Debug, Deserialize, Clone, Serialize)]
#[serde(deny_unknown_fields, default)]
/// This struct contains the convergence criteria for the algorithm
///
/// Fields missing during deserialization take their value from the
/// `Default` implementation of this struct; unknown fields are rejected.
pub struct Convergence {
    /// The objective function convergence criterion for the algorithm
    ///
    /// The objective function is the negative log likelihood
    /// Previously referred to as THETA_G
    pub likelihood: f64,
    /// The PYL convergence criterion for the algorithm
    ///
    /// P(Y|L) represents the probability of the observation given its weighted support
    /// Previously referred to as THETA_F
    pub pyl: f64,
    /// Precision convergence criterion for the algorithm
    ///
    /// The precision variable, sometimes referred to as `eps`, is the distance from existing points in the grid to the candidate point. A candidate point is suggested at a distance of `eps` times the range of the parameter.
    /// For example, if the parameter `alpha` has a range of `[0.0, 1.0]`, and `eps` is `0.1`, then the candidate point will be at a distance of `0.1 * (1.0 - 0.0) = 0.1` from the existing grid point(s).
    /// Previously referred to as THETA_E
    pub eps: f64,
}
375
376impl Default for Convergence {
377    fn default() -> Self {
378        Convergence {
379            likelihood: 1e-4,
380            pyl: 1e-2,
381            eps: 1e-2,
382        }
383    }
384}
385
/// Configuration for predictions
///
/// Fields missing during deserialization take their value from the
/// `Default` implementation of this struct; unknown fields are rejected.
#[derive(Debug, Deserialize, Clone, Serialize)]
#[serde(deny_unknown_fields, default)]
pub struct Predictions {
    /// The interval for which predictions are generated
    pub idelta: f64,
    /// The time after the last dose for which predictions are generated
    ///
    /// Predictions will always be generated until the last event (observation or dose) in the data.
    /// This setting is used to generate predictions beyond the last event if the `tad` is sufficiently large.
    /// This can be useful for generating predictions for a subject who only received a dose, but has no observations.
    pub tad: f64,
}
398
399impl Default for Predictions {
400    fn default() -> Self {
401        Predictions {
402            idelta: 0.12,
403            tad: 0.0,
404        }
405    }
406}
407
408impl Predictions {
409    /// Validate the prediction settings
410    pub fn validate(&self) -> Result<()> {
411        if self.idelta < 0.0 {
412            bail!("The interval for predictions must be non-negative");
413        }
414        if self.tad < 0.0 {
415            bail!("The time after dose for predictions must be non-negative");
416        }
417        Ok(())
418    }
419}
420
/// The log level, which can be one of the following:
/// - `TRACE`
/// - `DEBUG`
/// - `INFO` (Default)
/// - `WARN`
/// - `ERROR`
///
/// Each variant converts into the `tracing::Level` of the same name, and
/// its string form (`AsRef<str>` / `Display`) is the lowercase variant name.
#[derive(Debug, Deserialize, Clone, Serialize, Default)]
pub enum LogLevel {
    /// Corresponds to `tracing::Level::TRACE`
    TRACE,
    /// Corresponds to `tracing::Level::DEBUG`
    DEBUG,
    /// Corresponds to `tracing::Level::INFO`; this is the default level
    #[default]
    INFO,
    /// Corresponds to `tracing::Level::WARN`
    WARN,
    /// Corresponds to `tracing::Level::ERROR`
    ERROR,
}
436
437impl From<LogLevel> for tracing::Level {
438    fn from(log_level: LogLevel) -> tracing::Level {
439        match log_level {
440            LogLevel::TRACE => tracing::Level::TRACE,
441            LogLevel::DEBUG => tracing::Level::DEBUG,
442            LogLevel::INFO => tracing::Level::INFO,
443            LogLevel::WARN => tracing::Level::WARN,
444            LogLevel::ERROR => tracing::Level::ERROR,
445        }
446    }
447}
448
449impl AsRef<str> for LogLevel {
450    fn as_ref(&self) -> &str {
451        match self {
452            LogLevel::TRACE => "trace",
453            LogLevel::DEBUG => "debug",
454            LogLevel::INFO => "info",
455            LogLevel::WARN => "warn",
456            LogLevel::ERROR => "error",
457        }
458    }
459}
460
461impl Display for LogLevel {
462    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
463        write!(f, "{}", self.as_ref())
464    }
465}
466
/// Configuration for logging
///
/// Fields missing during deserialization take their value from the
/// `Default` implementation of this struct; unknown fields are rejected.
#[derive(Debug, Deserialize, Clone, Serialize)]
#[serde(deny_unknown_fields, default)]
pub struct Log {
    /// The maximum log level to display, as defined by [LogLevel]
    ///
    /// [LogLevel] is a thin wrapper around `tracing::Level`, but can be serialized
    pub level: LogLevel,
    /// Should the logs be written to a file
    ///
    /// If true, a file will be created in the output folder with the name `log.txt`, or, if [Output::write] is false, in the current directory.
    pub write: bool,
    /// Define if logs should be written to stdout
    pub stdout: bool,
}
481
482impl Default for Log {
483    fn default() -> Self {
484        Log {
485            level: LogLevel::INFO,
486            write: false,
487            stdout: true,
488        }
489    }
490}
491
/// Configuration for the output files
///
/// Fields missing during deserialization take their value from the
/// `Default` implementation of this struct; unknown fields are rejected.
#[derive(Debug, Deserialize, Clone, Serialize)]
#[serde(deny_unknown_fields, default)]
pub struct Output {
    /// Whether to write the output files
    pub write: bool,
    /// The (relative) path to write the output files to
    pub path: String,
}
501
502impl Default for Output {
503    fn default() -> Self {
504        let path = PathBuf::from("outputs/").to_string_lossy().to_string();
505
506        Output { write: true, path }
507    }
508}
509
/// Step-wise builder for [Settings]
///
/// The `State` type parameter records how far the build has progressed.
/// Each step consumes the builder and returns one in the next state:
/// the algorithm is set first, then the parameters, then the error model,
/// after which `build` becomes available.
///
/// Every section is stored as an `Option`; sections that are never set are
/// filled with their defaults by `build`.
pub struct SettingsBuilder<State> {
    config: Option<Config>,
    parameters: Option<Parameters>,
    error: Option<Error>,
    predictions: Option<Predictions>,
    log: Option<Log>,
    prior: Option<Prior>,
    output: Option<Output>,
    convergence: Option<Convergence>,
    advanced: Option<Advanced>,
    /// Zero-sized marker carrying the current builder state
    _marker: std::marker::PhantomData<State>,
}
522
// Marker traits for builder states
// NOTE(review): no implementations of these traits are visible in this file;
// the builder steps are keyed on the state structs below instead — confirm
// whether the traits are still needed.
pub trait AlgorithmDefined {}
pub trait ParametersDefined {}
pub trait ErrorModelDefined {}

// State types used as the `State` parameter of `SettingsBuilder`
/// Builder state: nothing has been set yet
pub struct InitialState;
/// Builder state: the algorithm has been set
pub struct AlgorithmSet;
/// Builder state: the algorithm and parameters have been set
pub struct ParametersSet;
/// Builder state: the algorithm, parameters and error model have been set
pub struct ErrorSet;
533
534// Initial state: no algorithm set yet
535impl SettingsBuilder<InitialState> {
536    pub fn new() -> Self {
537        SettingsBuilder {
538            config: None,
539            parameters: None,
540            error: None,
541            predictions: None,
542            log: None,
543            prior: None,
544            output: None,
545            convergence: None,
546            advanced: None,
547            _marker: std::marker::PhantomData,
548        }
549    }
550
551    pub fn set_algorithm(self, algorithm: Algorithm) -> SettingsBuilder<AlgorithmSet> {
552        SettingsBuilder {
553            config: Some(Config {
554                algorithm,
555                ..Config::default()
556            }),
557            parameters: self.parameters,
558            error: self.error,
559            predictions: self.predictions,
560            log: self.log,
561            prior: self.prior,
562            output: self.output,
563            convergence: self.convergence,
564            advanced: self.advanced,
565            _marker: std::marker::PhantomData,
566        }
567    }
568}
569
570impl Default for SettingsBuilder<InitialState> {
571    fn default() -> Self {
572        SettingsBuilder::new()
573    }
574}
575
576// Algorithm is set, move to defining parameters
577impl SettingsBuilder<AlgorithmSet> {
578    pub fn set_parameters(self, parameters: Parameters) -> SettingsBuilder<ParametersSet> {
579        SettingsBuilder {
580            config: self.config,
581            parameters: Some(parameters),
582            error: self.error,
583            predictions: self.predictions,
584            log: self.log,
585            prior: self.prior,
586            output: self.output,
587            convergence: self.convergence,
588            advanced: self.advanced,
589            _marker: std::marker::PhantomData,
590        }
591    }
592}
593
594// Parameters are set, move to defining error model
595impl SettingsBuilder<ParametersSet> {
596    pub fn set_error_model(
597        self,
598        model: ErrorModel,
599        value: f64,
600        poly: (f64, f64, f64, f64),
601    ) -> SettingsBuilder<ErrorSet> {
602        let error = Error::new(value, model, poly);
603
604        SettingsBuilder {
605            config: self.config,
606            parameters: self.parameters,
607            error: Some(error),
608            predictions: self.predictions,
609            log: self.log,
610            prior: self.prior,
611            output: self.output,
612            convergence: self.convergence,
613            advanced: self.advanced,
614            _marker: std::marker::PhantomData,
615        }
616    }
617}
618
619// Error model is set, allow optional settings and final build
620impl SettingsBuilder<ErrorSet> {
621    pub fn build(self) -> Settings {
622        Settings {
623            config: self.config.unwrap(),
624            parameters: self.parameters.unwrap(),
625            error: self.error.unwrap(),
626            predictions: self.predictions.unwrap_or_default(),
627            log: self.log.unwrap_or_default(),
628            prior: self.prior.unwrap_or_default(),
629            output: self.output.unwrap_or_default(),
630            convergence: self.convergence.unwrap_or_default(),
631            advanced: self.advanced.unwrap_or_default(),
632        }
633    }
634}
635
/// Resolve the `#` placeholder in an output folder path.
///
/// If `path` contains no `#`, it is returned unchanged. Otherwise every `#`
/// is replaced by the same integer: the smallest one, starting at 1, for
/// which the resulting path does not exist on the filesystem yet.
fn parse_output_folder(path: String) -> String {
    // Without a placeholder there is nothing to resolve
    if !path.contains('#') {
        return path;
    }

    // Try 1, 2, 3, ... until the substituted path is free
    let mut num: u32 = 1;
    loop {
        let candidate = path.replace('#', &num.to_string());
        if !std::path::Path::new(&candidate).exists() {
            return candidate;
        }
        num += 1;
    }
}
651
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algorithms::Algorithm;

    /// Builds settings through every builder step and checks that the chosen
    /// values, the defaults and a setter are reflected in the result.
    #[test]
    fn test_builder() {
        let parameters = Parameters::new()
            .add("Ke", 0.0, 5.0, false)
            .add("V", 10.0, 200.0, true);

        let mut settings = SettingsBuilder::new()
            .set_algorithm(Algorithm::NPAG) // Step 1: Define algorithm
            .set_parameters(parameters) // Step 2: Define parameters
            .set_error_model(ErrorModel::Additive, 5.0, (0.0, 0.1, 0.0, 0.0)) // Step 3: Define error model
            .build();

        // Use a value that differs from the default (100), so the assertion
        // below actually exercises the setter.
        settings.set_cycles(50);

        assert_eq!(settings.config.algorithm, Algorithm::NPAG);
        assert_eq!(settings.config.cycles, 50);
        assert!(settings.config.cache);
        assert_eq!(settings.parameters().names(), vec!["Ke", "V"]);
    }
}