1use crate::algorithms::Algorithm;
2use crate::routines::initialization::Prior;
3use crate::routines::output::OutputFile;
4use anyhow::{bail, Result};
5use pharmsol::prelude::data::AssayErrorModels;
6
7use serde::{Deserialize, Serialize};
8use serde_json;
9use std::fmt::Display;
10use std::path::PathBuf;
11
12#[derive(Debug, Deserialize, Clone, Serialize)]
14#[serde(deny_unknown_fields)]
15pub struct Settings {
16 pub(crate) config: Config,
18 pub(crate) parameters: Parameters,
20 pub(crate) errormodels: AssayErrorModels,
22 pub(crate) predictions: Predictions,
24 pub(crate) log: Log,
26 pub(crate) prior: Prior,
28 pub(crate) output: Output,
30 pub(crate) convergence: Convergence,
32 pub(crate) advanced: Advanced,
34}
35
36impl Settings {
37 pub fn builder() -> SettingsBuilder<InitialState> {
39 SettingsBuilder::new()
40 }
41
42 pub fn config(&self) -> &Config {
44 &self.config
45 }
46
47 pub fn parameters(&self) -> &Parameters {
48 &self.parameters
49 }
50
51 pub fn errormodels(&self) -> &AssayErrorModels {
52 &self.errormodels
53 }
54
55 pub fn predictions(&self) -> &Predictions {
56 &self.predictions
57 }
58
59 pub fn log(&self) -> &Log {
60 &self.log
61 }
62
63 pub fn prior(&self) -> &Prior {
64 &self.prior
65 }
66
67 pub fn output(&self) -> &Output {
68 &self.output
69 }
70 pub fn convergence(&self) -> &Convergence {
71 &self.convergence
72 }
73
74 pub fn advanced(&self) -> &Advanced {
75 &self.advanced
76 }
77
78 pub fn set_cycles(&mut self, cycles: usize) {
80 self.config.cycles = cycles;
81 }
82
83 pub fn set_algorithm(&mut self, algorithm: Algorithm) {
84 self.config.algorithm = algorithm;
85 }
86
87 pub fn set_idelta(&mut self, idelta: f64) {
88 self.predictions.idelta = idelta;
89 }
90
91 pub fn set_tad(&mut self, tad: f64) {
92 self.predictions.tad = tad;
93 }
94
95 pub fn set_prior(&mut self, prior: Prior) {
96 self.prior = prior;
97 }
98
99 pub fn disable_output(&mut self) {
100 self.output.write = false;
101 }
102
103 pub fn set_output_path(&mut self, path: impl Into<String>) {
104 self.output.path = parse_output_folder(path.into());
105 }
106
107 pub fn set_log_stdout(&mut self, stdout: bool) {
108 self.log.stdout = stdout;
109 }
110
111 pub fn set_write_logs(&mut self, write: bool) {
112 self.log.write = write;
113 }
114
115 pub fn set_log_level(&mut self, level: LogLevel) {
116 self.log.level = level;
117 }
118
119 pub fn set_progress(&mut self, progress: bool) {
120 self.config.progress = progress;
121 }
122
123 pub fn initialize_logs(&mut self) -> Result<()> {
124 crate::routines::logger::setup_log(self)
125 }
126
127 pub fn write(&self) -> Result<()> {
130 let serialized = serde_json::to_string_pretty(self).map_err(std::io::Error::other)?;
131
132 let outputfile = OutputFile::new(self.output.path.as_str(), "settings.json")?;
133 let mut file = outputfile.file_owned();
134 std::io::Write::write_all(&mut file, serialized.as_bytes())?;
135 Ok(())
136 }
137}
138
139#[derive(Debug, Deserialize, Clone, Serialize)]
141#[serde(deny_unknown_fields, default)]
142pub struct Config {
143 pub cycles: usize,
145 pub algorithm: Algorithm,
147 pub progress: bool,
151}
152
153impl Default for Config {
154 fn default() -> Self {
155 Config {
156 cycles: 100,
157 algorithm: Algorithm::NPAG,
158 progress: true,
159 }
160 }
161}
162
163#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
167pub struct Parameter {
168 pub(crate) name: String,
169 pub(crate) lower: f64,
170 pub(crate) upper: f64,
171}
172
173impl Parameter {
174 pub fn new(name: impl Into<String>, lower: f64, upper: f64) -> Self {
176 Self {
177 name: name.into(),
178 lower,
179 upper,
180 }
181 }
182}
183
184#[derive(Debug, Clone, Deserialize, Serialize, Default, PartialEq)]
186pub struct Parameters {
187 pub(crate) parameters: Vec<Parameter>,
188}
189
190impl Parameters {
191 pub fn new() -> Self {
192 Parameters {
193 parameters: Vec::new(),
194 }
195 }
196
197 pub fn add(mut self, name: impl Into<String>, lower: f64, upper: f64) -> Parameters {
198 let parameter = Parameter::new(name, lower, upper);
199 self.parameters.push(parameter);
200 self
201 }
202
203 pub fn get(&self, name: impl Into<String>) -> Option<&Parameter> {
205 let name = name.into();
206 self.parameters.iter().find(|p| p.name == name)
207 }
208
209 pub fn names(&self) -> Vec<String> {
211 self.parameters.iter().map(|p| p.name.clone()).collect()
212 }
213 pub fn ranges(&self) -> Vec<(f64, f64)> {
217 self.parameters.iter().map(|p| (p.lower, p.upper)).collect()
218 }
219
220 pub fn len(&self) -> usize {
222 self.parameters.len()
223 }
224
225 pub fn is_empty(&self) -> bool {
227 self.parameters.is_empty()
228 }
229
230 pub fn iter(&self) -> std::slice::Iter<'_, Parameter> {
232 self.parameters.iter()
233 }
234}
235
236impl IntoIterator for Parameters {
237 type Item = Parameter;
238 type IntoIter = std::vec::IntoIter<Parameter>;
239
240 fn into_iter(self) -> Self::IntoIter {
241 self.parameters.into_iter()
242 }
243}
244
245impl From<Vec<Parameter>> for Parameters {
246 fn from(parameters: Vec<Parameter>) -> Self {
247 Parameters { parameters }
248 }
249}
250
251#[derive(Debug, Deserialize, Clone, Serialize)]
253#[serde(deny_unknown_fields, default)]
254pub struct Advanced {
255 pub min_distance: f64,
259 pub nm_steps: usize,
262 pub tolerance: f64,
266}
267
268impl Default for Advanced {
269 fn default() -> Self {
270 Advanced {
271 min_distance: 1e-4,
272 nm_steps: 100,
273 tolerance: 1e-6,
274 }
275 }
276}
277
278#[derive(Debug, Deserialize, Clone, Serialize)]
279#[serde(deny_unknown_fields, default)]
280pub struct Convergence {
282 pub likelihood: f64,
287 pub pyl: f64,
292 pub eps: f64,
298}
299
300impl Default for Convergence {
301 fn default() -> Self {
302 Convergence {
303 likelihood: 1e-4,
304 pyl: 1e-2,
305 eps: 1e-2,
306 }
307 }
308}
309
310#[derive(Debug, Deserialize, Clone, Serialize)]
311#[serde(deny_unknown_fields, default)]
312pub struct Predictions {
313 pub idelta: f64,
315 pub tad: f64,
321}
322
323impl Default for Predictions {
324 fn default() -> Self {
325 Predictions {
326 idelta: 0.12,
327 tad: 0.0,
328 }
329 }
330}
331
332impl Predictions {
333 pub fn validate(&self) -> Result<()> {
335 if self.idelta < 0.0 {
336 bail!("The interval for predictions must be non-negative");
337 }
338 if self.tad < 0.0 {
339 bail!("The time after dose for predictions must be non-negative");
340 }
341 Ok(())
342 }
343}
344
345#[derive(Debug, Deserialize, Clone, Serialize, Default)]
352pub enum LogLevel {
353 TRACE,
354 DEBUG,
355 #[default]
356 INFO,
357 WARN,
358 ERROR,
359}
360
361impl From<LogLevel> for tracing::Level {
362 fn from(log_level: LogLevel) -> tracing::Level {
363 match log_level {
364 LogLevel::TRACE => tracing::Level::TRACE,
365 LogLevel::DEBUG => tracing::Level::DEBUG,
366 LogLevel::INFO => tracing::Level::INFO,
367 LogLevel::WARN => tracing::Level::WARN,
368 LogLevel::ERROR => tracing::Level::ERROR,
369 }
370 }
371}
372
373impl AsRef<str> for LogLevel {
374 fn as_ref(&self) -> &str {
375 match self {
376 LogLevel::TRACE => "trace",
377 LogLevel::DEBUG => "debug",
378 LogLevel::INFO => "info",
379 LogLevel::WARN => "warn",
380 LogLevel::ERROR => "error",
381 }
382 }
383}
384
385impl Display for LogLevel {
386 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
387 write!(f, "{}", self.as_ref())
388 }
389}
390
391#[derive(Debug, Deserialize, Clone, Serialize)]
392#[serde(deny_unknown_fields, default)]
393pub struct Log {
394 pub level: LogLevel,
398 pub write: bool,
402 pub stdout: bool,
404}
405
406impl Default for Log {
407 fn default() -> Self {
408 Log {
409 level: LogLevel::INFO,
410 write: false,
411 stdout: true,
412 }
413 }
414}
415
416#[derive(Debug, Deserialize, Clone, Serialize)]
418#[serde(deny_unknown_fields, default)]
419pub struct Output {
420 pub write: bool,
422 pub path: String,
424}
425
426impl Default for Output {
427 fn default() -> Self {
428 let path = PathBuf::from("outputs/").to_string_lossy().to_string();
429
430 Output { write: true, path }
431 }
432}
433
434pub struct SettingsBuilder<State> {
435 config: Option<Config>,
436 parameters: Option<Parameters>,
437 errormodels: Option<AssayErrorModels>,
438 predictions: Option<Predictions>,
439 log: Option<Log>,
440 prior: Option<Prior>,
441 output: Option<Output>,
442 convergence: Option<Convergence>,
443 advanced: Option<Advanced>,
444 _marker: std::marker::PhantomData<State>,
445}
446
447pub trait AlgorithmDefined {}
449pub trait ParametersDefined {}
450pub trait ErrorModelDefined {}
451
452pub struct InitialState;
454pub struct AlgorithmSet;
455pub struct ParametersSet;
456pub struct ErrorSet;
457
458impl SettingsBuilder<InitialState> {
460 pub fn new() -> Self {
461 SettingsBuilder {
462 config: None,
463 parameters: None,
464 errormodels: None,
465 predictions: None,
466 log: None,
467 prior: None,
468 output: None,
469 convergence: None,
470 advanced: None,
471 _marker: std::marker::PhantomData,
472 }
473 }
474
475 pub fn set_algorithm(self, algorithm: Algorithm) -> SettingsBuilder<AlgorithmSet> {
476 SettingsBuilder {
477 config: Some(Config {
478 algorithm,
479 ..Config::default()
480 }),
481 parameters: self.parameters,
482 errormodels: self.errormodels,
483 predictions: self.predictions,
484 log: self.log,
485 prior: self.prior,
486 output: self.output,
487 convergence: self.convergence,
488 advanced: self.advanced,
489 _marker: std::marker::PhantomData,
490 }
491 }
492}
493
494impl Default for SettingsBuilder<InitialState> {
495 fn default() -> Self {
496 SettingsBuilder::new()
497 }
498}
499
500impl SettingsBuilder<AlgorithmSet> {
502 pub fn set_parameters(self, parameters: Parameters) -> SettingsBuilder<ParametersSet> {
503 SettingsBuilder {
504 config: self.config,
505 parameters: Some(parameters),
506 errormodels: self.errormodels,
507 predictions: self.predictions,
508 log: self.log,
509 prior: self.prior,
510 output: self.output,
511 convergence: self.convergence,
512 advanced: self.advanced,
513 _marker: std::marker::PhantomData,
514 }
515 }
516}
517
518impl SettingsBuilder<ParametersSet> {
520 pub fn set_error_models(self, ems: AssayErrorModels) -> SettingsBuilder<ErrorSet> {
521 SettingsBuilder {
522 config: self.config,
523 parameters: self.parameters,
524 errormodels: Some(ems),
525 predictions: self.predictions,
526 log: self.log,
527 prior: self.prior,
528 output: self.output,
529 convergence: self.convergence,
530 advanced: self.advanced,
531 _marker: std::marker::PhantomData,
532 }
533 }
534}
535
536impl SettingsBuilder<ErrorSet> {
538 pub fn build(self) -> Settings {
539 Settings {
540 config: self.config.unwrap(),
541 parameters: self.parameters.unwrap(),
542 errormodels: self.errormodels.unwrap(),
543 predictions: self.predictions.unwrap_or_default(),
544 log: self.log.unwrap_or_default(),
545 prior: self.prior.unwrap_or_default(),
546 output: self.output.unwrap_or_default(),
547 convergence: self.convergence.unwrap_or_default(),
548 advanced: self.advanced.unwrap_or_default(),
549 }
550 }
551}
552
/// Expands a `#` placeholder in an output path to the first unused run number.
///
/// A path without `#` is returned unchanged. Otherwise every `#` is replaced by
/// 1, 2, 3, … and the first candidate that does not exist on disk is returned.
/// For example, `outputs_#` becomes `outputs_3` when `outputs_1` and
/// `outputs_2` already exist.
fn parse_output_folder(path: String) -> String {
    if !path.contains('#') {
        return path;
    }

    let mut run: i32 = 1;
    loop {
        let candidate = path.replace('#', &run.to_string());
        if !std::path::Path::new(&candidate).exists() {
            return candidate;
        }
        run += 1;
    }
}
567
#[cfg(test)]
mod tests {
    use pharmsol::{AssayErrorModel, AssayErrorModels, ErrorPoly};

    use super::*;
    use crate::algorithms::Algorithm;

    /// Builds settings through the full typestate chain and checks that the
    /// supplied values, the defaults, and a setter are all reflected.
    #[test]
    fn test_builder() {
        let parameters = Parameters::new().add("Ke", 0.0, 5.0).add("V", 10.0, 200.0);

        let ems = AssayErrorModels::new()
            .add(
                0,
                AssayErrorModel::Proportional {
                    gamma: pharmsol::Factor::Variable(5.0),
                    poly: ErrorPoly::new(0.0, 0.1, 0.0, 0.0),
                },
            )
            .unwrap();
        let mut settings = SettingsBuilder::new()
            .set_algorithm(Algorithm::NPAG)
            .set_parameters(parameters)
            .set_error_models(ems)
            .build();

        // `set_algorithm` starts from `Config::default()`, which uses 100 cycles.
        assert_eq!(settings.config().cycles, 100);

        // Use a non-default value so the assertion proves the setter ran.
        settings.set_cycles(250);

        assert_eq!(settings.config.algorithm, Algorithm::NPAG);
        assert_eq!(settings.config.cycles, 250);
        assert_eq!(settings.parameters().names(), vec!["Ke", "V"]);
        assert_eq!(
            settings.parameters().ranges(),
            vec![(0.0, 5.0), (10.0, 200.0)]
        );
        assert_eq!(
            settings.parameters().get("V"),
            Some(&Parameter::new("V", 10.0, 200.0))
        );
    }
}