1use crate::algorithms::Algorithm;
2use crate::routines::initialization::Prior;
3use crate::routines::output::OutputFile;
4use anyhow::{bail, Result};
5use pharmsol::prelude::data::ErrorType;
6use serde::{Deserialize, Serialize};
7use serde_json;
8use std::fmt::Display;
9use std::path::PathBuf;
10
11#[derive(Debug, Deserialize, Clone, Serialize)]
13#[serde(deny_unknown_fields)]
14pub struct Settings {
15 pub(crate) config: Config,
17 pub(crate) parameters: Parameters,
19 pub(crate) error: Error,
21 pub(crate) predictions: Predictions,
23 pub(crate) log: Log,
25 pub(crate) prior: Prior,
27 pub(crate) output: Output,
29 pub(crate) convergence: Convergence,
31 pub(crate) advanced: Advanced,
33}
34
35impl Settings {
36 pub fn builder() -> SettingsBuilder<InitialState> {
38 SettingsBuilder::new()
39 }
40
41 pub fn validate(&self) -> Result<()> {
43 self.error.validate()?;
44 self.predictions.validate()?;
45 Ok(())
46 }
47
48 pub fn config(&self) -> &Config {
50 &self.config
51 }
52
53 pub fn parameters(&self) -> &Parameters {
54 &self.parameters
55 }
56
57 pub fn error(&self) -> &Error {
58 &self.error
59 }
60
61 pub fn predictions(&self) -> &Predictions {
62 &self.predictions
63 }
64
65 pub fn log(&self) -> &Log {
66 &self.log
67 }
68
69 pub fn prior(&self) -> &Prior {
70 &self.prior
71 }
72
73 pub fn output(&self) -> &Output {
74 &self.output
75 }
76 pub fn convergence(&self) -> &Convergence {
77 &self.convergence
78 }
79
80 pub fn advanced(&self) -> &Advanced {
81 &self.advanced
82 }
83
84 pub fn set_cycles(&mut self, cycles: usize) {
86 self.config.cycles = cycles;
87 }
88
89 pub fn set_algorithm(&mut self, algorithm: Algorithm) {
90 self.config.algorithm = algorithm;
91 }
92
93 pub fn set_cache(&mut self, cache: bool) {
94 self.config.cache = cache;
95 }
96
97 pub fn set_idelta(&mut self, idelta: f64) {
98 self.predictions.idelta = idelta;
99 }
100
101 pub fn set_tad(&mut self, tad: f64) {
102 self.predictions.tad = tad;
103 }
104
105 pub fn set_prior(&mut self, prior: Prior) {
106 self.prior = prior;
107 }
108
109 pub fn disable_output(&mut self) {
110 self.output.write = false;
111 }
112
113 pub fn set_output_path(&mut self, path: impl Into<String>) {
114 self.output.path = parse_output_folder(path.into());
115 }
116
117 pub fn set_log_stdout(&mut self, stdout: bool) {
118 self.log.stdout = stdout;
119 }
120
121 pub fn set_write_logs(&mut self, write: bool) {
122 self.log.write = write;
123 }
124
125 pub fn set_log_level(&mut self, level: LogLevel) {
126 self.log.level = level;
127 }
128
129 pub fn set_progress(&mut self, progress: bool) {
130 self.config.progress = progress;
131 }
132
133 pub fn initialize_logs(&mut self) -> Result<()> {
134 crate::routines::logger::setup_log(self)
135 }
136
137 pub fn write(&self) -> Result<()> {
140 let serialized = serde_json::to_string_pretty(self)
141 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
142
143 let outputfile = OutputFile::new(self.output.path.as_str(), "settings.json")?;
144 let mut file = outputfile.file;
145 std::io::Write::write_all(&mut file, serialized.as_bytes())?;
146 Ok(())
147 }
148}
149
150#[derive(Debug, Deserialize, Clone, Serialize)]
152#[serde(deny_unknown_fields, default)]
153pub struct Config {
154 pub cycles: usize,
156 pub algorithm: Algorithm,
158 pub cache: bool,
160 pub progress: bool,
164}
165
166impl Default for Config {
167 fn default() -> Self {
168 Config {
169 cycles: 100,
170 algorithm: Algorithm::NPAG,
171 cache: true,
172 progress: true,
173 }
174 }
175}
176
177#[derive(Debug, Clone, Deserialize, Serialize)]
182pub struct Parameter {
183 pub(crate) name: String,
184 pub(crate) lower: f64,
185 pub(crate) upper: f64,
186 pub(crate) fixed: bool,
187}
188
189impl Parameter {
190 pub fn new(name: impl Into<String>, lower: f64, upper: f64, fixed: bool) -> Self {
192 Self {
193 name: name.into(),
194 lower,
195 upper,
196 fixed,
197 }
198 }
199}
200
201#[derive(Debug, Clone, Deserialize, Serialize, Default)]
203pub struct Parameters {
204 pub(crate) parameters: Vec<Parameter>,
205}
206
207impl Parameters {
208 pub fn new() -> Self {
209 Parameters {
210 parameters: Vec::new(),
211 }
212 }
213
214 pub fn add(
215 mut self,
216 name: impl Into<String>,
217 lower: f64,
218 upper: f64,
219 fixed: bool,
220 ) -> Parameters {
221 let parameter = Parameter::new(name, lower, upper, fixed);
222 self.parameters.push(parameter);
223 self
224 }
225
226 pub fn get(&self, name: impl Into<String>) -> Option<&Parameter> {
228 let name = name.into();
229 self.parameters.iter().find(|p| p.name == name)
230 }
231
232 pub fn names(&self) -> Vec<String> {
234 self.parameters.iter().map(|p| p.name.clone()).collect()
235 }
236 pub fn ranges(&self) -> Vec<(f64, f64)> {
240 self.parameters.iter().map(|p| (p.lower, p.upper)).collect()
241 }
242 pub fn len(&self) -> usize {
243 self.parameters.len()
244 }
245
246 pub fn is_empty(&self) -> bool {
247 self.parameters.is_empty()
248 }
249
250 pub fn iter(&self) -> std::slice::Iter<'_, Parameter> {
251 self.parameters.iter()
252 }
253}
254
255impl IntoIterator for Parameters {
256 type Item = Parameter;
257 type IntoIter = std::vec::IntoIter<Parameter>;
258
259 fn into_iter(self) -> Self::IntoIter {
260 self.parameters.into_iter()
261 }
262}
263
264impl From<Vec<Parameter>> for Parameters {
265 fn from(parameters: Vec<Parameter>) -> Self {
266 Parameters { parameters }
267 }
268}
269
270#[derive(Debug, Deserialize, Clone, Serialize)]
271pub enum ErrorModel {
272 Additive,
273 Proportional,
274}
275
276impl From<ErrorModel> for ErrorType {
277 fn from(error_model: ErrorModel) -> ErrorType {
278 match error_model {
279 ErrorModel::Additive => ErrorType::Add,
280 ErrorModel::Proportional => ErrorType::Prop,
281 }
282 }
283}
284
285#[derive(Debug, Deserialize, Clone, Serialize)]
287#[serde(deny_unknown_fields, default)]
288pub struct Error {
289 pub value: f64,
291 pub model: ErrorModel,
293 pub poly: (f64, f64, f64, f64),
295}
296
297impl Default for Error {
298 fn default() -> Self {
299 Error {
300 value: 0.0,
301 model: ErrorModel::Additive,
302 poly: (0.0, 0.1, 0.0, 0.0),
303 }
304 }
305}
306
307impl Error {
308 fn new(value: f64, model: ErrorModel, poly: (f64, f64, f64, f64)) -> Self {
309 Error { value, model, poly }
310 }
311
312 fn validate(&self) -> Result<()> {
313 if self.value < 0.0 {
314 bail!(format!(
315 "Error value must be non-negative, got {}",
316 self.value
317 ));
318 }
319 Ok(())
320 }
321
322 pub fn error_model(&self) -> ErrorModel {
323 self.model.clone()
324 }
325}
326
327#[derive(Debug, Deserialize, Clone, Serialize)]
329#[serde(deny_unknown_fields, default)]
330pub struct Advanced {
331 pub min_distance: f64,
335 pub nm_steps: usize,
338 pub tolerance: f64,
342}
343
344impl Default for Advanced {
345 fn default() -> Self {
346 Advanced {
347 min_distance: 1e-4,
348 nm_steps: 100,
349 tolerance: 1e-6,
350 }
351 }
352}
353
354#[derive(Debug, Deserialize, Clone, Serialize)]
355#[serde(deny_unknown_fields, default)]
356pub struct Convergence {
358 pub likelihood: f64,
363 pub pyl: f64,
368 pub eps: f64,
374}
375
376impl Default for Convergence {
377 fn default() -> Self {
378 Convergence {
379 likelihood: 1e-4,
380 pyl: 1e-2,
381 eps: 1e-2,
382 }
383 }
384}
385
386#[derive(Debug, Deserialize, Clone, Serialize)]
387#[serde(deny_unknown_fields, default)]
388pub struct Predictions {
389 pub idelta: f64,
391 pub tad: f64,
397}
398
399impl Default for Predictions {
400 fn default() -> Self {
401 Predictions {
402 idelta: 0.12,
403 tad: 0.0,
404 }
405 }
406}
407
408impl Predictions {
409 pub fn validate(&self) -> Result<()> {
411 if self.idelta < 0.0 {
412 bail!("The interval for predictions must be non-negative");
413 }
414 if self.tad < 0.0 {
415 bail!("The time after dose for predictions must be non-negative");
416 }
417 Ok(())
418 }
419}
420
421#[derive(Debug, Deserialize, Clone, Serialize, Default)]
428pub enum LogLevel {
429 TRACE,
430 DEBUG,
431 #[default]
432 INFO,
433 WARN,
434 ERROR,
435}
436
437impl From<LogLevel> for tracing::Level {
438 fn from(log_level: LogLevel) -> tracing::Level {
439 match log_level {
440 LogLevel::TRACE => tracing::Level::TRACE,
441 LogLevel::DEBUG => tracing::Level::DEBUG,
442 LogLevel::INFO => tracing::Level::INFO,
443 LogLevel::WARN => tracing::Level::WARN,
444 LogLevel::ERROR => tracing::Level::ERROR,
445 }
446 }
447}
448
449impl AsRef<str> for LogLevel {
450 fn as_ref(&self) -> &str {
451 match self {
452 LogLevel::TRACE => "trace",
453 LogLevel::DEBUG => "debug",
454 LogLevel::INFO => "info",
455 LogLevel::WARN => "warn",
456 LogLevel::ERROR => "error",
457 }
458 }
459}
460
461impl Display for LogLevel {
462 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
463 write!(f, "{}", self.as_ref())
464 }
465}
466
467#[derive(Debug, Deserialize, Clone, Serialize)]
468#[serde(deny_unknown_fields, default)]
469pub struct Log {
470 pub level: LogLevel,
474 pub write: bool,
478 pub stdout: bool,
480}
481
482impl Default for Log {
483 fn default() -> Self {
484 Log {
485 level: LogLevel::INFO,
486 write: false,
487 stdout: true,
488 }
489 }
490}
491
492#[derive(Debug, Deserialize, Clone, Serialize)]
494#[serde(deny_unknown_fields, default)]
495pub struct Output {
496 pub write: bool,
498 pub path: String,
500}
501
502impl Default for Output {
503 fn default() -> Self {
504 let path = PathBuf::from("outputs/").to_string_lossy().to_string();
505
506 Output { write: true, path }
507 }
508}
509
510pub struct SettingsBuilder<State> {
511 config: Option<Config>,
512 parameters: Option<Parameters>,
513 error: Option<Error>,
514 predictions: Option<Predictions>,
515 log: Option<Log>,
516 prior: Option<Prior>,
517 output: Option<Output>,
518 convergence: Option<Convergence>,
519 advanced: Option<Advanced>,
520 _marker: std::marker::PhantomData<State>,
521}
522
/// Type-state: nothing required has been set yet.
pub struct InitialState;
/// Type-state: the algorithm has been chosen.
pub struct AlgorithmSet;
/// Type-state: the parameters have been supplied.
pub struct ParametersSet;
/// Type-state: the error model has been supplied; `build()` is available.
pub struct ErrorSet;

/// Marker trait for builder states with an algorithm defined.
/// NOTE(review): no implementations or bounds using it are visible in this file.
pub trait AlgorithmDefined {}
/// Marker trait for builder states with parameters defined.
/// NOTE(review): no implementations or bounds using it are visible in this file.
pub trait ParametersDefined {}
/// Marker trait for builder states with an error model defined.
/// NOTE(review): no implementations or bounds using it are visible in this file.
pub trait ErrorModelDefined {}
533
534impl SettingsBuilder<InitialState> {
536 pub fn new() -> Self {
537 SettingsBuilder {
538 config: None,
539 parameters: None,
540 error: None,
541 predictions: None,
542 log: None,
543 prior: None,
544 output: None,
545 convergence: None,
546 advanced: None,
547 _marker: std::marker::PhantomData,
548 }
549 }
550
551 pub fn set_algorithm(self, algorithm: Algorithm) -> SettingsBuilder<AlgorithmSet> {
552 SettingsBuilder {
553 config: Some(Config {
554 algorithm,
555 ..Config::default()
556 }),
557 parameters: self.parameters,
558 error: self.error,
559 predictions: self.predictions,
560 log: self.log,
561 prior: self.prior,
562 output: self.output,
563 convergence: self.convergence,
564 advanced: self.advanced,
565 _marker: std::marker::PhantomData,
566 }
567 }
568}
569
570impl Default for SettingsBuilder<InitialState> {
571 fn default() -> Self {
572 SettingsBuilder::new()
573 }
574}
575
576impl SettingsBuilder<AlgorithmSet> {
578 pub fn set_parameters(self, parameters: Parameters) -> SettingsBuilder<ParametersSet> {
579 SettingsBuilder {
580 config: self.config,
581 parameters: Some(parameters),
582 error: self.error,
583 predictions: self.predictions,
584 log: self.log,
585 prior: self.prior,
586 output: self.output,
587 convergence: self.convergence,
588 advanced: self.advanced,
589 _marker: std::marker::PhantomData,
590 }
591 }
592}
593
594impl SettingsBuilder<ParametersSet> {
596 pub fn set_error_model(
597 self,
598 model: ErrorModel,
599 value: f64,
600 poly: (f64, f64, f64, f64),
601 ) -> SettingsBuilder<ErrorSet> {
602 let error = Error::new(value, model, poly);
603
604 SettingsBuilder {
605 config: self.config,
606 parameters: self.parameters,
607 error: Some(error),
608 predictions: self.predictions,
609 log: self.log,
610 prior: self.prior,
611 output: self.output,
612 convergence: self.convergence,
613 advanced: self.advanced,
614 _marker: std::marker::PhantomData,
615 }
616 }
617}
618
619impl SettingsBuilder<ErrorSet> {
621 pub fn build(self) -> Settings {
622 Settings {
623 config: self.config.unwrap(),
624 parameters: self.parameters.unwrap(),
625 error: self.error.unwrap(),
626 predictions: self.predictions.unwrap_or_default(),
627 log: self.log.unwrap_or_default(),
628 prior: self.prior.unwrap_or_default(),
629 output: self.output.unwrap_or_default(),
630 convergence: self.convergence.unwrap_or_default(),
631 advanced: self.advanced.unwrap_or_default(),
632 }
633 }
634}
635
/// Resolves a `#` placeholder in an output-folder template.
///
/// A path without `#` is returned unchanged. Otherwise every `#` is replaced with the
/// smallest positive integer, starting at 1, for which the resulting path does not exist
/// yet. Repeated runs with a template like `"outputs_#"` therefore land in fresh folders.
///
/// Existence is only checked at call time. Nothing reserves the chosen path, so another
/// process could create it before it is used.
fn parse_output_folder(path: String) -> String {
    if !path.contains('#') {
        return path;
    }

    // Build each candidate once and return the first free one. The original repeated
    // the replacement a second time after the search finished.
    let mut num: u64 = 1;
    loop {
        let candidate = path.replace('#', &num.to_string());
        if !std::path::Path::new(&candidate).exists() {
            return candidate;
        }
        num += 1;
    }
}
651
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algorithms::Algorithm;

    /// Builds settings with only the required sections, using `error_value` as the
    /// error value.
    fn build_settings(error_value: f64) -> Settings {
        let parameters = Parameters::new()
            .add("Ke", 0.0, 5.0, false)
            .add("V", 10.0, 200.0, true);

        SettingsBuilder::new()
            .set_algorithm(Algorithm::NPAG)
            .set_parameters(parameters)
            .set_error_model(ErrorModel::Additive, error_value, (0.0, 0.1, 0.0, 0.0))
            .build()
    }

    #[test]
    fn test_builder() {
        let mut settings = build_settings(5.0);

        // 250 differs from the default of 100, so this actually proves `set_cycles`
        // took effect.
        settings.set_cycles(250);

        assert_eq!(settings.config.algorithm, Algorithm::NPAG);
        assert_eq!(settings.config.cycles, 250);
        assert!(settings.config.cache);
        assert_eq!(settings.parameters().names(), vec!["Ke", "V"]);
        assert!(settings.validate().is_ok());
    }

    #[test]
    fn validate_rejects_negative_error_value() {
        assert!(build_settings(-1.0).validate().is_err());
    }
}