1use crate::algorithms::Algorithm;
2use crate::routines::initialization::Prior;
3use crate::routines::output::OutputFile;
4use anyhow::{bail, Result};
5use pharmsol::prelude::data::ErrorModel;
6use pharmsol::ErrorPoly;
7use serde::{Deserialize, Serialize};
8use serde_json;
9use std::fmt::Display;
10use std::path::PathBuf;
11
/// Complete configuration for a run, grouped into one section per concern.
///
/// Can be produced with [`Settings::builder`] or deserialized; unknown
/// top-level keys are rejected because of `deny_unknown_fields`.
/// [`Settings::validate`] checks the error-model and prediction sections.
#[derive(Debug, Deserialize, Clone, Serialize)]
#[serde(deny_unknown_fields)]
pub struct Settings {
    /// General run options (cycles, algorithm, cache, progress).
    pub(crate) config: Config,
    /// Named parameters with their lower/upper bounds.
    pub(crate) parameters: Parameters,
    /// Error-model description, convertible into a `pharmsol` `ErrorModel`.
    pub(crate) error: Error,
    /// Prediction interval (`idelta`) and time after dose (`tad`).
    pub(crate) predictions: Predictions,
    /// Logging level and destinations.
    pub(crate) log: Log,
    /// Prior used to initialize the run (type defined in `routines::initialization`).
    pub(crate) prior: Prior,
    /// Whether and where output files (e.g. `settings.json`) are written.
    pub(crate) output: Output,
    /// Convergence thresholds.
    pub(crate) convergence: Convergence,
    /// Advanced tuning values.
    pub(crate) advanced: Advanced,
}
35
36impl Settings {
37 pub fn builder() -> SettingsBuilder<InitialState> {
39 SettingsBuilder::new()
40 }
41
42 pub fn validate(&self) -> Result<()> {
44 self.error.validate()?;
45 self.predictions.validate()?;
46 Ok(())
47 }
48
49 pub fn config(&self) -> &Config {
51 &self.config
52 }
53
54 pub fn parameters(&self) -> &Parameters {
55 &self.parameters
56 }
57
58 pub fn error(&self) -> &Error {
59 &self.error
60 }
61
62 pub fn predictions(&self) -> &Predictions {
63 &self.predictions
64 }
65
66 pub fn log(&self) -> &Log {
67 &self.log
68 }
69
70 pub fn prior(&self) -> &Prior {
71 &self.prior
72 }
73
74 pub fn output(&self) -> &Output {
75 &self.output
76 }
77 pub fn convergence(&self) -> &Convergence {
78 &self.convergence
79 }
80
81 pub fn advanced(&self) -> &Advanced {
82 &self.advanced
83 }
84
85 pub fn set_cycles(&mut self, cycles: usize) {
87 self.config.cycles = cycles;
88 }
89
90 pub fn set_algorithm(&mut self, algorithm: Algorithm) {
91 self.config.algorithm = algorithm;
92 }
93
94 pub fn set_cache(&mut self, cache: bool) {
95 self.config.cache = cache;
96 }
97
98 pub fn set_idelta(&mut self, idelta: f64) {
99 self.predictions.idelta = idelta;
100 }
101
102 pub fn set_tad(&mut self, tad: f64) {
103 self.predictions.tad = tad;
104 }
105
106 pub fn set_prior(&mut self, prior: Prior) {
107 self.prior = prior;
108 }
109
110 pub fn disable_output(&mut self) {
111 self.output.write = false;
112 }
113
114 pub fn set_output_path(&mut self, path: impl Into<String>) {
115 self.output.path = parse_output_folder(path.into());
116 }
117
118 pub fn set_log_stdout(&mut self, stdout: bool) {
119 self.log.stdout = stdout;
120 }
121
122 pub fn set_write_logs(&mut self, write: bool) {
123 self.log.write = write;
124 }
125
126 pub fn set_log_level(&mut self, level: LogLevel) {
127 self.log.level = level;
128 }
129
130 pub fn set_progress(&mut self, progress: bool) {
131 self.config.progress = progress;
132 }
133
134 pub fn initialize_logs(&mut self) -> Result<()> {
135 crate::routines::logger::setup_log(self)
136 }
137
138 pub fn write(&self) -> Result<()> {
141 let serialized = serde_json::to_string_pretty(self).map_err(std::io::Error::other)?;
142
143 let outputfile = OutputFile::new(self.output.path.as_str(), "settings.json")?;
144 let mut file = outputfile.file;
145 std::io::Write::write_all(&mut file, serialized.as_bytes())?;
146 Ok(())
147 }
148}
149
150#[derive(Debug, Deserialize, Clone, Serialize)]
152#[serde(deny_unknown_fields, default)]
153pub struct Config {
154 pub cycles: usize,
156 pub algorithm: Algorithm,
158 pub cache: bool,
160 pub progress: bool,
164}
165
166impl Default for Config {
167 fn default() -> Self {
168 Config {
169 cycles: 100,
170 algorithm: Algorithm::NPAG,
171 cache: true,
172 progress: true,
173 }
174 }
175}
176
177#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
181pub struct Parameter {
182 pub(crate) name: String,
183 pub(crate) lower: f64,
184 pub(crate) upper: f64,
185}
186
187impl Parameter {
188 pub fn new(name: impl Into<String>, lower: f64, upper: f64) -> Self {
190 Self {
191 name: name.into(),
192 lower,
193 upper,
194 }
195 }
196}
197
198#[derive(Debug, Clone, Deserialize, Serialize, Default, PartialEq)]
200pub struct Parameters {
201 pub(crate) parameters: Vec<Parameter>,
202}
203
204impl Parameters {
205 pub fn new() -> Self {
206 Parameters {
207 parameters: Vec::new(),
208 }
209 }
210
211 pub fn add(mut self, name: impl Into<String>, lower: f64, upper: f64) -> Parameters {
212 let parameter = Parameter::new(name, lower, upper);
213 self.parameters.push(parameter);
214 self
215 }
216
217 pub fn get(&self, name: impl Into<String>) -> Option<&Parameter> {
219 let name = name.into();
220 self.parameters.iter().find(|p| p.name == name)
221 }
222
223 pub fn names(&self) -> Vec<String> {
225 self.parameters.iter().map(|p| p.name.clone()).collect()
226 }
227 pub fn ranges(&self) -> Vec<(f64, f64)> {
231 self.parameters.iter().map(|p| (p.lower, p.upper)).collect()
232 }
233
234 pub fn len(&self) -> usize {
236 self.parameters.len()
237 }
238
239 pub fn is_empty(&self) -> bool {
241 self.parameters.is_empty()
242 }
243
244 pub fn iter(&self) -> std::slice::Iter<'_, Parameter> {
246 self.parameters.iter()
247 }
248}
249
250impl IntoIterator for Parameters {
251 type Item = Parameter;
252 type IntoIter = std::vec::IntoIter<Parameter>;
253
254 fn into_iter(self) -> Self::IntoIter {
255 self.parameters.into_iter()
256 }
257}
258
259impl From<Vec<Parameter>> for Parameters {
260 fn from(parameters: Vec<Parameter>) -> Self {
261 Parameters { parameters }
262 }
263}
264
265#[derive(Debug, Deserialize, Clone, Serialize)]
266pub enum ErrorType {
267 Additive,
268 Proportional,
269}
270
271#[derive(Debug, Deserialize, Clone, Serialize)]
273#[serde(deny_unknown_fields, default)]
274pub struct Error {
275 pub value: f64,
277 pub errortype: ErrorType,
279 pub poly: (f64, f64, f64, f64),
281}
282
283impl Default for Error {
284 fn default() -> Self {
285 Error {
286 value: 0.0,
287 errortype: ErrorType::Additive,
288 poly: (0.0, 0.1, 0.0, 0.0),
289 }
290 }
291}
292
293impl Error {
294 fn new(value: f64, errortype: ErrorType, poly: (f64, f64, f64, f64)) -> Self {
295 Error {
296 value,
297 errortype,
298 poly,
299 }
300 }
301
302 fn validate(&self) -> Result<()> {
303 if self.value < 0.0 {
304 bail!("The initial value of gamma or lambda must be non-negative");
305 }
306 Ok(())
307 }
308}
309
310impl From<Error> for pharmsol::prelude::data::ErrorModel {
311 fn from(error: Error) -> Self {
312 match error.errortype {
313 ErrorType::Additive => ErrorModel::additive(
314 ErrorPoly::new(error.poly.0, error.poly.1, error.poly.2, error.poly.3),
315 error.value,
316 ),
317 ErrorType::Proportional => ErrorModel::proportional(
318 ErrorPoly::new(error.poly.0, error.poly.1, error.poly.2, error.poly.3),
319 error.value,
320 ),
321 }
322 }
323}
324
325#[derive(Debug, Deserialize, Clone, Serialize)]
327#[serde(deny_unknown_fields, default)]
328pub struct Advanced {
329 pub min_distance: f64,
333 pub nm_steps: usize,
336 pub tolerance: f64,
340}
341
342impl Default for Advanced {
343 fn default() -> Self {
344 Advanced {
345 min_distance: 1e-4,
346 nm_steps: 100,
347 tolerance: 1e-6,
348 }
349 }
350}
351
352#[derive(Debug, Deserialize, Clone, Serialize)]
353#[serde(deny_unknown_fields, default)]
354pub struct Convergence {
356 pub likelihood: f64,
361 pub pyl: f64,
366 pub eps: f64,
372}
373
374impl Default for Convergence {
375 fn default() -> Self {
376 Convergence {
377 likelihood: 1e-4,
378 pyl: 1e-2,
379 eps: 1e-2,
380 }
381 }
382}
383
384#[derive(Debug, Deserialize, Clone, Serialize)]
385#[serde(deny_unknown_fields, default)]
386pub struct Predictions {
387 pub idelta: f64,
389 pub tad: f64,
395}
396
397impl Default for Predictions {
398 fn default() -> Self {
399 Predictions {
400 idelta: 0.12,
401 tad: 0.0,
402 }
403 }
404}
405
406impl Predictions {
407 pub fn validate(&self) -> Result<()> {
409 if self.idelta < 0.0 {
410 bail!("The interval for predictions must be non-negative");
411 }
412 if self.tad < 0.0 {
413 bail!("The time after dose for predictions must be non-negative");
414 }
415 Ok(())
416 }
417}
418
419#[derive(Debug, Deserialize, Clone, Serialize, Default)]
426pub enum LogLevel {
427 TRACE,
428 DEBUG,
429 #[default]
430 INFO,
431 WARN,
432 ERROR,
433}
434
435impl From<LogLevel> for tracing::Level {
436 fn from(log_level: LogLevel) -> tracing::Level {
437 match log_level {
438 LogLevel::TRACE => tracing::Level::TRACE,
439 LogLevel::DEBUG => tracing::Level::DEBUG,
440 LogLevel::INFO => tracing::Level::INFO,
441 LogLevel::WARN => tracing::Level::WARN,
442 LogLevel::ERROR => tracing::Level::ERROR,
443 }
444 }
445}
446
447impl AsRef<str> for LogLevel {
448 fn as_ref(&self) -> &str {
449 match self {
450 LogLevel::TRACE => "trace",
451 LogLevel::DEBUG => "debug",
452 LogLevel::INFO => "info",
453 LogLevel::WARN => "warn",
454 LogLevel::ERROR => "error",
455 }
456 }
457}
458
459impl Display for LogLevel {
460 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
461 write!(f, "{}", self.as_ref())
462 }
463}
464
465#[derive(Debug, Deserialize, Clone, Serialize)]
466#[serde(deny_unknown_fields, default)]
467pub struct Log {
468 pub level: LogLevel,
472 pub write: bool,
476 pub stdout: bool,
478}
479
480impl Default for Log {
481 fn default() -> Self {
482 Log {
483 level: LogLevel::INFO,
484 write: false,
485 stdout: true,
486 }
487 }
488}
489
490#[derive(Debug, Deserialize, Clone, Serialize)]
492#[serde(deny_unknown_fields, default)]
493pub struct Output {
494 pub write: bool,
496 pub path: String,
498}
499
500impl Default for Output {
501 fn default() -> Self {
502 let path = PathBuf::from("outputs/").to_string_lossy().to_string();
503
504 Output { write: true, path }
505 }
506}
507
/// Typestate builder for [`Settings`].
///
/// `State` is one of the marker types `InitialState`, `AlgorithmSet`,
/// `ParametersSet` or `ErrorSet`. The required steps (`set_algorithm`,
/// `set_parameters`, `set_error_model`) each move to the next state, and
/// `build` exists only on `SettingsBuilder<ErrorSet>`.
pub struct SettingsBuilder<State> {
    // Filled by `set_algorithm`.
    config: Option<Config>,
    // Filled by `set_parameters`.
    parameters: Option<Parameters>,
    // Filled by `set_error_model`.
    error: Option<Error>,
    // NOTE(review): the sections below have no builder setter in this file;
    // `build` replaces `None` with each section's `Default`.
    predictions: Option<Predictions>,
    log: Option<Log>,
    prior: Option<Prior>,
    output: Option<Output>,
    convergence: Option<Convergence>,
    advanced: Option<Advanced>,
    // Carries the typestate without storing a value.
    _marker: std::marker::PhantomData<State>,
}

/// NOTE(review): not implemented or used anywhere in this file.
pub trait AlgorithmDefined {}
/// NOTE(review): not implemented or used anywhere in this file.
pub trait ParametersDefined {}
/// NOTE(review): not implemented or used anywhere in this file.
pub trait ErrorModelDefined {}

/// Builder state before an algorithm is chosen.
pub struct InitialState;
/// Builder state after `set_algorithm`.
pub struct AlgorithmSet;
/// Builder state after `set_parameters`.
pub struct ParametersSet;
/// Builder state after `set_error_model`; `build` is available.
pub struct ErrorSet;
531
532impl SettingsBuilder<InitialState> {
534 pub fn new() -> Self {
535 SettingsBuilder {
536 config: None,
537 parameters: None,
538 error: None,
539 predictions: None,
540 log: None,
541 prior: None,
542 output: None,
543 convergence: None,
544 advanced: None,
545 _marker: std::marker::PhantomData,
546 }
547 }
548
549 pub fn set_algorithm(self, algorithm: Algorithm) -> SettingsBuilder<AlgorithmSet> {
550 SettingsBuilder {
551 config: Some(Config {
552 algorithm,
553 ..Config::default()
554 }),
555 parameters: self.parameters,
556 error: self.error,
557 predictions: self.predictions,
558 log: self.log,
559 prior: self.prior,
560 output: self.output,
561 convergence: self.convergence,
562 advanced: self.advanced,
563 _marker: std::marker::PhantomData,
564 }
565 }
566}
567
568impl Default for SettingsBuilder<InitialState> {
569 fn default() -> Self {
570 SettingsBuilder::new()
571 }
572}
573
574impl SettingsBuilder<AlgorithmSet> {
576 pub fn set_parameters(self, parameters: Parameters) -> SettingsBuilder<ParametersSet> {
577 SettingsBuilder {
578 config: self.config,
579 parameters: Some(parameters),
580 error: self.error,
581 predictions: self.predictions,
582 log: self.log,
583 prior: self.prior,
584 output: self.output,
585 convergence: self.convergence,
586 advanced: self.advanced,
587 _marker: std::marker::PhantomData,
588 }
589 }
590}
591
592impl SettingsBuilder<ParametersSet> {
594 pub fn set_error_model(
595 self,
596 errortype: ErrorType,
597 value: f64,
598 poly: (f64, f64, f64, f64),
599 ) -> SettingsBuilder<ErrorSet> {
600 let error = Error::new(value, errortype, poly);
601
602 SettingsBuilder {
603 config: self.config,
604 parameters: self.parameters,
605 error: Some(error),
606 predictions: self.predictions,
607 log: self.log,
608 prior: self.prior,
609 output: self.output,
610 convergence: self.convergence,
611 advanced: self.advanced,
612 _marker: std::marker::PhantomData,
613 }
614 }
615}
616
617impl SettingsBuilder<ErrorSet> {
619 pub fn build(self) -> Settings {
620 Settings {
621 config: self.config.unwrap(),
622 parameters: self.parameters.unwrap(),
623 error: self.error.unwrap(),
624 predictions: self.predictions.unwrap_or_default(),
625 log: self.log.unwrap_or_default(),
626 prior: self.prior.unwrap_or_default(),
627 output: self.output.unwrap_or_default(),
628 convergence: self.convergence.unwrap_or_default(),
629 advanced: self.advanced.unwrap_or_default(),
630 }
631 }
632}
633
/// Resolves a `#` placeholder in an output folder path.
///
/// A path without `#` is returned unchanged. Otherwise every `#` is replaced
/// by the smallest positive integer for which the resulting path does not
/// exist yet (checked with `Path::exists`).
fn parse_output_folder(path: String) -> String {
    if !path.contains('#') {
        return path;
    }

    let candidate = |n: i32| path.replace('#', &n.to_string());
    let mut n: i32 = 1;
    loop {
        let resolved = candidate(n);
        if !std::path::Path::new(&resolved).exists() {
            return resolved;
        }
        n += 1;
    }
}
648
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algorithms::Algorithm;

    /// Builds NPAG settings with two parameters and an additive error model
    /// whose initial value is `error_value`.
    fn build_settings(error_value: f64) -> Settings {
        let parameters = Parameters::new().add("Ke", 0.0, 5.0).add("V", 10.0, 200.0);
        SettingsBuilder::new()
            .set_algorithm(Algorithm::NPAG)
            .set_parameters(parameters)
            .set_error_model(ErrorType::Additive, error_value, (0.0, 0.1, 0.0, 0.0))
            .build()
    }

    #[test]
    fn test_builder() {
        let mut settings = build_settings(5.0);

        settings.set_cycles(100);

        assert_eq!(settings.config.algorithm, Algorithm::NPAG);
        assert_eq!(settings.config.cycles, 100);
        assert!(settings.config.cache);
        assert_eq!(settings.parameters().names(), vec!["Ke", "V"]);
    }

    #[test]
    fn built_settings_pass_validation() {
        assert!(build_settings(5.0).validate().is_ok());
    }

    #[test]
    fn negative_error_value_is_rejected() {
        assert!(build_settings(-1.0).validate().is_err());
    }

    #[test]
    fn negative_idelta_is_rejected() {
        let mut settings = build_settings(5.0);
        settings.set_idelta(-1.0);
        assert!(settings.validate().is_err());
    }

    #[test]
    fn output_path_without_placeholder_is_unchanged() {
        assert_eq!(parse_output_folder(String::from("outputs/")), "outputs/");
    }
}