1use crate::algorithms::Algorithm;
2use crate::routines::initialization::Prior;
3use crate::routines::output::OutputFile;
4use anyhow::{bail, Result};
5use pharmsol::prelude::data::ErrorModels;
6
7use serde::{Deserialize, Serialize};
8use serde_json;
9use std::fmt::Display;
10use std::path::PathBuf;
11
/// Top-level configuration for a run.
///
/// Each field is one section of the serialized settings; unknown top-level keys
/// are rejected during deserialization (`deny_unknown_fields`). Instances are
/// normally created with [`Settings::builder`] and then adjusted with the
/// `set_*` methods.
#[derive(Debug, Deserialize, Clone, Serialize)]
#[serde(deny_unknown_fields)]
pub struct Settings {
    /// General run options (cycles, algorithm, cache, progress).
    pub(crate) config: Config,
    /// Named parameters with their lower/upper bounds.
    pub(crate) parameters: Parameters,
    /// Error models (type provided by `pharmsol`).
    pub(crate) errormodels: ErrorModels,
    /// Prediction options (`idelta`, `tad`).
    pub(crate) predictions: Predictions,
    /// Logging options.
    pub(crate) log: Log,
    /// Prior; type defined in `crate::routines::initialization`.
    pub(crate) prior: Prior,
    /// Output options (write flag and output path).
    pub(crate) output: Output,
    /// Convergence settings.
    pub(crate) convergence: Convergence,
    /// Advanced tuning settings.
    pub(crate) advanced: Advanced,
}
35
36impl Settings {
37 pub fn builder() -> SettingsBuilder<InitialState> {
39 SettingsBuilder::new()
40 }
41
42 pub fn config(&self) -> &Config {
44 &self.config
45 }
46
47 pub fn parameters(&self) -> &Parameters {
48 &self.parameters
49 }
50
51 pub fn errormodels(&self) -> &ErrorModels {
52 &self.errormodels
53 }
54
55 pub fn predictions(&self) -> &Predictions {
56 &self.predictions
57 }
58
59 pub fn log(&self) -> &Log {
60 &self.log
61 }
62
63 pub fn prior(&self) -> &Prior {
64 &self.prior
65 }
66
67 pub fn output(&self) -> &Output {
68 &self.output
69 }
70 pub fn convergence(&self) -> &Convergence {
71 &self.convergence
72 }
73
74 pub fn advanced(&self) -> &Advanced {
75 &self.advanced
76 }
77
78 pub fn set_cycles(&mut self, cycles: usize) {
80 self.config.cycles = cycles;
81 }
82
83 pub fn set_algorithm(&mut self, algorithm: Algorithm) {
84 self.config.algorithm = algorithm;
85 }
86
87 pub fn set_cache(&mut self, cache: bool) {
88 self.config.cache = cache;
89 }
90
91 pub fn set_idelta(&mut self, idelta: f64) {
92 self.predictions.idelta = idelta;
93 }
94
95 pub fn set_tad(&mut self, tad: f64) {
96 self.predictions.tad = tad;
97 }
98
99 pub fn set_prior(&mut self, prior: Prior) {
100 self.prior = prior;
101 }
102
103 pub fn disable_output(&mut self) {
104 self.output.write = false;
105 }
106
107 pub fn set_output_path(&mut self, path: impl Into<String>) {
108 self.output.path = parse_output_folder(path.into());
109 }
110
111 pub fn set_log_stdout(&mut self, stdout: bool) {
112 self.log.stdout = stdout;
113 }
114
115 pub fn set_write_logs(&mut self, write: bool) {
116 self.log.write = write;
117 }
118
119 pub fn set_log_level(&mut self, level: LogLevel) {
120 self.log.level = level;
121 }
122
123 pub fn set_progress(&mut self, progress: bool) {
124 self.config.progress = progress;
125 }
126
127 pub fn initialize_logs(&mut self) -> Result<()> {
128 crate::routines::logger::setup_log(self)
129 }
130
131 pub fn write(&self) -> Result<()> {
134 let serialized = serde_json::to_string_pretty(self).map_err(std::io::Error::other)?;
135
136 let outputfile = OutputFile::new(self.output.path.as_str(), "settings.json")?;
137 let mut file = outputfile.file;
138 std::io::Write::write_all(&mut file, serialized.as_bytes())?;
139 Ok(())
140 }
141}
142
143#[derive(Debug, Deserialize, Clone, Serialize)]
145#[serde(deny_unknown_fields, default)]
146pub struct Config {
147 pub cycles: usize,
149 pub algorithm: Algorithm,
151 pub cache: bool,
153 pub progress: bool,
157}
158
159impl Default for Config {
160 fn default() -> Self {
161 Config {
162 cycles: 100,
163 algorithm: Algorithm::NPAG,
164 cache: true,
165 progress: true,
166 }
167 }
168}
169
170#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
174pub struct Parameter {
175 pub(crate) name: String,
176 pub(crate) lower: f64,
177 pub(crate) upper: f64,
178}
179
180impl Parameter {
181 pub fn new(name: impl Into<String>, lower: f64, upper: f64) -> Self {
183 Self {
184 name: name.into(),
185 lower,
186 upper,
187 }
188 }
189}
190
191#[derive(Debug, Clone, Deserialize, Serialize, Default, PartialEq)]
193pub struct Parameters {
194 pub(crate) parameters: Vec<Parameter>,
195}
196
197impl Parameters {
198 pub fn new() -> Self {
199 Parameters {
200 parameters: Vec::new(),
201 }
202 }
203
204 pub fn add(mut self, name: impl Into<String>, lower: f64, upper: f64) -> Parameters {
205 let parameter = Parameter::new(name, lower, upper);
206 self.parameters.push(parameter);
207 self
208 }
209
210 pub fn get(&self, name: impl Into<String>) -> Option<&Parameter> {
212 let name = name.into();
213 self.parameters.iter().find(|p| p.name == name)
214 }
215
216 pub fn names(&self) -> Vec<String> {
218 self.parameters.iter().map(|p| p.name.clone()).collect()
219 }
220 pub fn ranges(&self) -> Vec<(f64, f64)> {
224 self.parameters.iter().map(|p| (p.lower, p.upper)).collect()
225 }
226
227 pub fn len(&self) -> usize {
229 self.parameters.len()
230 }
231
232 pub fn is_empty(&self) -> bool {
234 self.parameters.is_empty()
235 }
236
237 pub fn iter(&self) -> std::slice::Iter<'_, Parameter> {
239 self.parameters.iter()
240 }
241}
242
243impl IntoIterator for Parameters {
244 type Item = Parameter;
245 type IntoIter = std::vec::IntoIter<Parameter>;
246
247 fn into_iter(self) -> Self::IntoIter {
248 self.parameters.into_iter()
249 }
250}
251
252impl From<Vec<Parameter>> for Parameters {
253 fn from(parameters: Vec<Parameter>) -> Self {
254 Parameters { parameters }
255 }
256}
257
258#[derive(Debug, Deserialize, Clone, Serialize)]
260#[serde(deny_unknown_fields, default)]
261pub struct Advanced {
262 pub min_distance: f64,
266 pub nm_steps: usize,
269 pub tolerance: f64,
273}
274
275impl Default for Advanced {
276 fn default() -> Self {
277 Advanced {
278 min_distance: 1e-4,
279 nm_steps: 100,
280 tolerance: 1e-6,
281 }
282 }
283}
284
285#[derive(Debug, Deserialize, Clone, Serialize)]
286#[serde(deny_unknown_fields, default)]
287pub struct Convergence {
289 pub likelihood: f64,
294 pub pyl: f64,
299 pub eps: f64,
305}
306
307impl Default for Convergence {
308 fn default() -> Self {
309 Convergence {
310 likelihood: 1e-4,
311 pyl: 1e-2,
312 eps: 1e-2,
313 }
314 }
315}
316
317#[derive(Debug, Deserialize, Clone, Serialize)]
318#[serde(deny_unknown_fields, default)]
319pub struct Predictions {
320 pub idelta: f64,
322 pub tad: f64,
328}
329
330impl Default for Predictions {
331 fn default() -> Self {
332 Predictions {
333 idelta: 0.12,
334 tad: 0.0,
335 }
336 }
337}
338
339impl Predictions {
340 pub fn validate(&self) -> Result<()> {
342 if self.idelta < 0.0 {
343 bail!("The interval for predictions must be non-negative");
344 }
345 if self.tad < 0.0 {
346 bail!("The time after dose for predictions must be non-negative");
347 }
348 Ok(())
349 }
350}
351
/// Log level, with one variant per `tracing::Level` (see the `From` impl).
///
/// The default is `INFO`. With the derived serde impls, variants are
/// (de)serialized by their uppercase names (e.g. `"INFO"`), while `Display`
/// and `AsRef<str>` produce lowercase names (e.g. `info`).
#[derive(Debug, Deserialize, Clone, Serialize, Default)]
pub enum LogLevel {
    TRACE,
    DEBUG,
    #[default]
    INFO,
    WARN,
    ERROR,
}
367
368impl From<LogLevel> for tracing::Level {
369 fn from(log_level: LogLevel) -> tracing::Level {
370 match log_level {
371 LogLevel::TRACE => tracing::Level::TRACE,
372 LogLevel::DEBUG => tracing::Level::DEBUG,
373 LogLevel::INFO => tracing::Level::INFO,
374 LogLevel::WARN => tracing::Level::WARN,
375 LogLevel::ERROR => tracing::Level::ERROR,
376 }
377 }
378}
379
380impl AsRef<str> for LogLevel {
381 fn as_ref(&self) -> &str {
382 match self {
383 LogLevel::TRACE => "trace",
384 LogLevel::DEBUG => "debug",
385 LogLevel::INFO => "info",
386 LogLevel::WARN => "warn",
387 LogLevel::ERROR => "error",
388 }
389 }
390}
391
392impl Display for LogLevel {
393 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
394 write!(f, "{}", self.as_ref())
395 }
396}
397
398#[derive(Debug, Deserialize, Clone, Serialize)]
399#[serde(deny_unknown_fields, default)]
400pub struct Log {
401 pub level: LogLevel,
405 pub write: bool,
409 pub stdout: bool,
411}
412
413impl Default for Log {
414 fn default() -> Self {
415 Log {
416 level: LogLevel::INFO,
417 write: false,
418 stdout: true,
419 }
420 }
421}
422
423#[derive(Debug, Deserialize, Clone, Serialize)]
425#[serde(deny_unknown_fields, default)]
426pub struct Output {
427 pub write: bool,
429 pub path: String,
431}
432
433impl Default for Output {
434 fn default() -> Self {
435 let path = PathBuf::from("outputs/").to_string_lossy().to_string();
436
437 Output { write: true, path }
438 }
439}
440
/// Type-state builder for [`Settings`].
///
/// `State` is one of the unit marker types `InitialState`, `AlgorithmSet`,
/// `ParametersSet` or `ErrorSet`. It records which required sections have been
/// supplied, so `build` only exists once the algorithm, parameters and error
/// models are set. No method in this file sets the remaining sections; `build`
/// fills them with their `Default` values.
pub struct SettingsBuilder<State> {
    /// Set by `set_algorithm` (as `Config::default()` with the chosen algorithm).
    config: Option<Config>,
    /// Set by `set_parameters`.
    parameters: Option<Parameters>,
    /// Set by `set_error_models`.
    errormodels: Option<ErrorModels>,
    // Optional sections: `None` means "use the default" in `build`.
    predictions: Option<Predictions>,
    log: Option<Log>,
    prior: Option<Prior>,
    output: Option<Output>,
    convergence: Option<Convergence>,
    advanced: Option<Advanced>,
    /// Zero-sized marker carrying the type-level `State`.
    _marker: std::marker::PhantomData<State>,
}
453
/// Marker trait.
/// NOTE(review): not implemented or referenced anywhere in this file; possibly unused.
pub trait AlgorithmDefined {}
/// Marker trait.
/// NOTE(review): not implemented or referenced anywhere in this file; possibly unused.
pub trait ParametersDefined {}
/// Marker trait.
/// NOTE(review): not implemented or referenced anywhere in this file; possibly unused.
pub trait ErrorModelDefined {}

/// Builder state: nothing has been set yet.
pub struct InitialState;
/// Builder state: the algorithm has been set.
pub struct AlgorithmSet;
/// Builder state: the algorithm and parameters have been set.
pub struct ParametersSet;
/// Builder state: algorithm, parameters and error models are set; `build` is available.
pub struct ErrorSet;
464
465impl SettingsBuilder<InitialState> {
467 pub fn new() -> Self {
468 SettingsBuilder {
469 config: None,
470 parameters: None,
471 errormodels: None,
472 predictions: None,
473 log: None,
474 prior: None,
475 output: None,
476 convergence: None,
477 advanced: None,
478 _marker: std::marker::PhantomData,
479 }
480 }
481
482 pub fn set_algorithm(self, algorithm: Algorithm) -> SettingsBuilder<AlgorithmSet> {
483 SettingsBuilder {
484 config: Some(Config {
485 algorithm,
486 ..Config::default()
487 }),
488 parameters: self.parameters,
489 errormodels: self.errormodels,
490 predictions: self.predictions,
491 log: self.log,
492 prior: self.prior,
493 output: self.output,
494 convergence: self.convergence,
495 advanced: self.advanced,
496 _marker: std::marker::PhantomData,
497 }
498 }
499}
500
501impl Default for SettingsBuilder<InitialState> {
502 fn default() -> Self {
503 SettingsBuilder::new()
504 }
505}
506
507impl SettingsBuilder<AlgorithmSet> {
509 pub fn set_parameters(self, parameters: Parameters) -> SettingsBuilder<ParametersSet> {
510 SettingsBuilder {
511 config: self.config,
512 parameters: Some(parameters),
513 errormodels: self.errormodels,
514 predictions: self.predictions,
515 log: self.log,
516 prior: self.prior,
517 output: self.output,
518 convergence: self.convergence,
519 advanced: self.advanced,
520 _marker: std::marker::PhantomData,
521 }
522 }
523}
524
525impl SettingsBuilder<ParametersSet> {
527 pub fn set_error_models(self, ems: ErrorModels) -> SettingsBuilder<ErrorSet> {
528 SettingsBuilder {
529 config: self.config,
530 parameters: self.parameters,
531 errormodels: Some(ems),
532 predictions: self.predictions,
533 log: self.log,
534 prior: self.prior,
535 output: self.output,
536 convergence: self.convergence,
537 advanced: self.advanced,
538 _marker: std::marker::PhantomData,
539 }
540 }
541}
542
543impl SettingsBuilder<ErrorSet> {
545 pub fn build(self) -> Settings {
546 Settings {
547 config: self.config.unwrap(),
548 parameters: self.parameters.unwrap(),
549 errormodels: self.errormodels.unwrap(),
550 predictions: self.predictions.unwrap_or_default(),
551 log: self.log.unwrap_or_default(),
552 prior: self.prior.unwrap_or_default(),
553 output: self.output.unwrap_or_default(),
554 convergence: self.convergence.unwrap_or_default(),
555 advanced: self.advanced.unwrap_or_default(),
556 }
557 }
558}
559
/// Expands `#` placeholders in an output path.
///
/// If `path` contains no `#`, it is returned unchanged. Otherwise every `#` is
/// replaced by the smallest integer `n >= 1` for which the resulting path does
/// not exist yet. For example, `"outputs_#"` becomes `"outputs_1"`, or
/// `"outputs_2"` if `outputs_1` already exists.
///
/// This only checks existence and does not create anything. Two concurrent
/// callers can therefore pick the same number.
fn parse_output_folder(path: String) -> String {
    if path.contains('#') {
        (1_i32..)
            .map(|n| path.replace('#', &n.to_string()))
            .find(|candidate| !std::path::Path::new(candidate).exists())
            .expect("an unbounded counter always yields an unused path")
    } else {
        path
    }
}
574
#[cfg(test)]
mod tests {
    use pharmsol::{ErrorModel, ErrorPoly};

    use super::*;
    use crate::algorithms::Algorithm;

    /// Builds a minimal valid `Settings` through the type-state builder.
    fn example_settings() -> Settings {
        let parameters = Parameters::new().add("Ke", 0.0, 5.0).add("V", 10.0, 200.0);

        let ems = ErrorModels::new()
            .add(
                0,
                ErrorModel::Proportional {
                    gamma: 5.0,
                    poly: ErrorPoly::new(0.0, 0.1, 0.0, 0.0),
                },
            )
            .unwrap();

        SettingsBuilder::new()
            .set_algorithm(Algorithm::NPAG)
            .set_parameters(parameters)
            .set_error_models(ems)
            .build()
    }

    #[test]
    fn test_builder() {
        let mut settings = example_settings();

        // Use a non-default value so the assertion actually exercises `set_cycles`
        // (the default is already 100).
        settings.set_cycles(250);

        assert_eq!(settings.config.algorithm, Algorithm::NPAG);
        assert_eq!(settings.config.cycles, 250);
        assert!(settings.config.cache);
        assert_eq!(settings.parameters().names(), vec!["Ke", "V"]);
        assert_eq!(
            settings.parameters().ranges(),
            vec![(0.0, 5.0), (10.0, 200.0)]
        );
    }

    #[test]
    fn test_parameters_lookup() {
        let parameters = Parameters::new().add("Ke", 0.0, 5.0);
        assert_eq!(parameters.get("Ke"), Some(&Parameter::new("Ke", 0.0, 5.0)));
        assert!(parameters.get("V").is_none());
        assert_eq!(parameters.len(), 1);
        assert!(!parameters.is_empty());
    }

    #[test]
    fn test_predictions_validate() {
        assert!(Predictions::default().validate().is_ok());
        let negative_idelta = Predictions {
            idelta: -1.0,
            tad: 0.0,
        };
        assert!(negative_idelta.validate().is_err());
        let negative_tad = Predictions {
            idelta: 0.12,
            tad: -0.5,
        };
        assert!(negative_tad.validate().is_err());
    }

    #[test]
    fn test_log_level_display() {
        assert_eq!(LogLevel::default().to_string(), "info");
        assert_eq!(LogLevel::TRACE.to_string(), "trace");
    }

    #[test]
    fn test_output_path_without_placeholder_is_unchanged() {
        assert_eq!(parse_output_folder("outputs/".to_string()), "outputs/");
    }
}