use crate::algorithms::Algorithm;
use crate::routines::initialization::Prior;
use crate::routines::output::OutputFile;
use anyhow::{bail, Result};
use pharmsol::prelude::data::ErrorType;
use serde::{Deserialize, Serialize};
use serde_json;
use std::fmt::Display;
use std::path::PathBuf;
/// Complete run configuration, grouped into one section per concern.
///
/// Deserialization rejects unknown keys (`deny_unknown_fields`). This struct has
/// no container-level `#[serde(default)]`, so every section key must be present
/// in the input. Sections that opt into `#[serde(default)]` fill their own
/// missing fields from their `Default` impl.
/// Field order here is also the key order of the JSON produced by [`Settings::write`].
#[derive(Debug, Deserialize, Clone, Serialize)]
#[serde(deny_unknown_fields)]
pub struct Settings {
/// Run-level options: cycles, algorithm, cache, progress.
pub(crate) config: Config,
/// Model parameters with their `(lower, upper)` ranges.
pub(crate) parameters: Parameters,
/// Error-model section; checked by [`Settings::validate`].
pub(crate) error: Error,
/// Prediction options (`idelta`, `tad`); checked by [`Settings::validate`].
pub(crate) predictions: Predictions,
/// Logging options (`level`, `write`, `stdout`).
pub(crate) log: Log,
/// Prior configuration (type defined in `crate::routines::initialization`).
pub(crate) prior: Prior,
/// Output options; `path` is the directory [`Settings::write`] writes into.
pub(crate) output: Output,
/// Convergence thresholds; see [`Convergence`].
pub(crate) convergence: Convergence,
/// Advanced tuning options; see [`Advanced`].
pub(crate) advanced: Advanced,
}
impl Settings {
    /// Start a typestate [`SettingsBuilder`].
    ///
    /// An algorithm, the parameters and the error model must be supplied, in that
    /// order, before `build` becomes available.
    pub fn builder() -> SettingsBuilder<InitialState> {
        SettingsBuilder::new()
    }

    /// Run the per-section checks.
    ///
    /// Only the `error` and `predictions` sections are validated here; the other
    /// sections are accepted as-is.
    ///
    /// # Errors
    /// Returns the first failure reported by the `error` or `predictions` section.
    pub fn validate(&self) -> Result<()> {
        self.error.validate()?;
        self.predictions.validate()?;
        Ok(())
    }

    /// Run-level options (cycles, algorithm, cache, progress).
    pub fn config(&self) -> &Config {
        &self.config
    }

    /// Model parameters and their ranges.
    pub fn parameters(&self) -> &Parameters {
        &self.parameters
    }

    /// Error-model section.
    pub fn error(&self) -> &Error {
        &self.error
    }

    /// Prediction options (`idelta`, `tad`).
    pub fn predictions(&self) -> &Predictions {
        &self.predictions
    }

    /// Logging options.
    pub fn log(&self) -> &Log {
        &self.log
    }

    /// Prior configuration.
    pub fn prior(&self) -> &Prior {
        &self.prior
    }

    /// Output options.
    pub fn output(&self) -> &Output {
        &self.output
    }

    /// Convergence thresholds.
    pub fn convergence(&self) -> &Convergence {
        &self.convergence
    }

    /// Advanced tuning options.
    pub fn advanced(&self) -> &Advanced {
        &self.advanced
    }

    /// Set `config.cycles`.
    pub fn set_cycles(&mut self, cycles: usize) {
        self.config.cycles = cycles;
    }

    /// Set `config.algorithm`.
    pub fn set_algorithm(&mut self, algorithm: Algorithm) {
        self.config.algorithm = algorithm;
    }

    /// Set `config.cache`.
    pub fn set_cache(&mut self, cache: bool) {
        self.config.cache = cache;
    }

    /// Set the interval for predictions (`predictions.idelta`).
    ///
    /// The value is not checked here; [`Settings::validate`] rejects negative values.
    pub fn set_idelta(&mut self, idelta: f64) {
        self.predictions.idelta = idelta;
    }

    /// Set the time after dose for predictions (`predictions.tad`).
    ///
    /// The value is not checked here; [`Settings::validate`] rejects negative values.
    pub fn set_tad(&mut self, tad: f64) {
        self.predictions.tad = tad;
    }

    /// Replace the prior configuration.
    pub fn set_prior(&mut self, prior: Prior) {
        self.prior = prior;
    }

    /// Set `output.write` to `false`.
    ///
    /// NOTE(review): [`Settings::write`] does not read this flag. Callers presumably
    /// check `output().write` themselves — confirm.
    pub fn disable_output(&mut self) {
        self.output.write = false;
    }

    /// Set the output directory.
    ///
    /// If `path` contains `#`, every `#` is replaced with the smallest positive
    /// integer for which the resulting path does not exist yet.
    pub fn set_output_path(&mut self, path: impl Into<String>) {
        self.output.path = parse_output_folder(path.into());
    }

    /// Set `log.stdout`.
    pub fn set_log_stdout(&mut self, stdout: bool) {
        self.log.stdout = stdout;
    }

    /// Set `log.write`.
    pub fn set_write_logs(&mut self, write: bool) {
        self.log.write = write;
    }

    /// Set `log.level`.
    pub fn set_log_level(&mut self, level: LogLevel) {
        self.log.level = level;
    }

    /// Set `config.progress`.
    pub fn set_progress(&mut self, progress: bool) {
        self.config.progress = progress;
    }

    /// Initialize logging by delegating to `crate::routines::logger::setup_log`.
    ///
    /// # Errors
    /// Propagates any error returned by `setup_log`.
    pub fn initialize_logs(&mut self) -> Result<()> {
        crate::routines::logger::setup_log(self)
    }

    /// Serialize these settings as pretty-printed JSON into `settings.json` inside
    /// `output.path`.
    ///
    /// # Errors
    /// Fails if serialization fails, if `OutputFile::new` cannot create the file,
    /// or if writing the bytes fails.
    pub fn write(&self) -> Result<()> {
        // `serde_json::Error` converts into `anyhow::Error` directly. Re-wrapping it
        // as an `io::Error` of kind `Other` added nothing and hid the original type.
        let serialized = serde_json::to_string_pretty(self)?;
        let outputfile = OutputFile::new(self.output.path.as_str(), "settings.json")?;
        let mut file = outputfile.file;
        std::io::Write::write_all(&mut file, serialized.as_bytes())?;
        Ok(())
    }
}
/// Run-level options.
///
/// Missing keys fall back to [`Config::default`] (100 cycles, `Algorithm::NPAG`,
/// `cache` and `progress` both `true`). Unknown keys are rejected.
#[derive(Debug, Deserialize, Clone, Serialize)]
#[serde(deny_unknown_fields, default)]
pub struct Config {
/// Number of cycles; presumably the iteration limit of the algorithm — confirm.
pub cycles: usize,
/// Algorithm to run.
pub algorithm: Algorithm,
/// Caching toggle. NOTE(review): what is cached is decided outside this file.
pub cache: bool,
/// Progress toggle. NOTE(review): the consumer is not visible in this file.
pub progress: bool,
}
impl Default for Config {
    /// 100 cycles of `Algorithm::NPAG`, with `cache` and `progress` enabled.
    fn default() -> Self {
        let algorithm = Algorithm::NPAG;
        Self {
            algorithm,
            cycles: 100,
            progress: true,
            cache: true,
        }
    }
}
/// A single named model parameter with a `(lower, upper)` range and a `fixed` flag.
///
/// NOTE(review): nothing in this file enforces `lower <= upper`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Parameter {
/// Parameter name; `Parameters::get` matches on it exactly.
pub(crate) name: String,
/// Lower end of the range.
pub(crate) lower: f64,
/// Upper end of the range.
pub(crate) upper: f64,
/// Stored as given. NOTE(review): how fixed parameters are treated is decided outside this file.
pub(crate) fixed: bool,
}
impl Parameter {
pub fn new(name: impl Into<String>, lower: f64, upper: f64, fixed: bool) -> Self {
Self {
name: name.into(),
lower,
upper,
fixed,
}
}
}
/// Ordered collection of [`Parameter`]s; insertion order is preserved.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct Parameters {
/// Backing storage, in insertion order.
pub(crate) parameters: Vec<Parameter>,
}
impl Parameters {
    /// Create an empty parameter list (equivalent to `Parameters::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Append a parameter and return the list so calls can be chained.
    ///
    /// No deduplication is done. Adding the same name twice stores two entries,
    /// and [`Parameters::get`] returns the first.
    pub fn add(
        mut self,
        name: impl Into<String>,
        lower: f64,
        upper: f64,
        fixed: bool,
    ) -> Self {
        self.parameters.push(Parameter::new(name, lower, upper, fixed));
        self
    }

    /// Find the first parameter whose name equals `name` exactly.
    pub fn get(&self, name: impl Into<String>) -> Option<&Parameter> {
        let wanted: String = name.into();
        self.iter().find(|candidate| candidate.name == wanted)
    }

    /// Parameter names, in insertion order.
    pub fn names(&self) -> Vec<String> {
        self.iter().map(|param| param.name.to_owned()).collect()
    }

    /// `(lower, upper)` pairs, in insertion order.
    pub fn ranges(&self) -> Vec<(f64, f64)> {
        self.iter().map(|param| (param.lower, param.upper)).collect()
    }

    /// Number of parameters.
    pub fn len(&self) -> usize {
        self.iter().len()
    }

    /// `true` when no parameter has been added.
    pub fn is_empty(&self) -> bool {
        self.parameters.is_empty()
    }

    /// Borrowing iterator over the parameters, in insertion order.
    pub fn iter(&self) -> std::slice::Iter<'_, Parameter> {
        self.parameters.as_slice().iter()
    }
}
impl IntoIterator for Parameters {
type Item = Parameter;
type IntoIter = std::vec::IntoIter<Parameter>;
fn into_iter(self) -> Self::IntoIter {
self.parameters.into_iter()
}
}
impl From<Vec<Parameter>> for Parameters {
fn from(parameters: Vec<Parameter>) -> Self {
Parameters { parameters }
}
}
/// Error-model selector.
///
/// Converts into pharmsol's `ErrorType`: `Additive` maps to `ErrorType::Add`
/// and `Proportional` maps to `ErrorType::Prop`.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub enum ErrorModel {
/// Additive error model (`ErrorType::Add`).
Additive,
/// Proportional error model (`ErrorType::Prop`).
Proportional,
}
impl From<ErrorModel> for ErrorType {
    /// Map the settings-level [`ErrorModel`] onto pharmsol's `ErrorType`.
    fn from(error_model: ErrorModel) -> ErrorType {
        match error_model {
            ErrorModel::Proportional => Self::Prop,
            ErrorModel::Additive => Self::Add,
        }
    }
}
/// Error-model section.
///
/// Missing keys fall back to [`Error::default`]: additive model, `value` 0.0,
/// `poly` (0.0, 0.1, 0.0, 0.0). Unknown keys are rejected.
#[derive(Debug, Deserialize, Clone, Serialize)]
#[serde(deny_unknown_fields, default)]
pub struct Error {
/// Scalar error value; `Error::validate` requires it to be non-negative.
pub value: f64,
/// Additive or proportional model.
pub model: ErrorModel,
/// Four coefficients; presumably the terms of an error polynomial (c0, c1, c2, c3) — TODO confirm.
pub poly: (f64, f64, f64, f64),
}
impl Default for Error {
    /// Additive model with `value` 0.0 and polynomial `(0.0, 0.1, 0.0, 0.0)`.
    fn default() -> Self {
        Self {
            model: ErrorModel::Additive,
            poly: (0.0, 0.1, 0.0, 0.0),
            value: 0.0,
        }
    }
}
impl Error {
fn new(value: f64, model: ErrorModel, poly: (f64, f64, f64, f64)) -> Self {
Error { value, model, poly }
}
fn validate(&self) -> Result<()> {
if self.value < 0.0 {
bail!(format!(
"Error value must be non-negative, got {}",
self.value
));
}
Ok(())
}
pub fn error_model(&self) -> ErrorModel {
self.model.clone()
}
}
/// Advanced tuning options.
///
/// Missing keys fall back to [`Advanced::default`] (`min_distance` 1e-4,
/// `nm_steps` 100, `tolerance` 1e-6). Unknown keys are rejected.
#[derive(Debug, Deserialize, Clone, Serialize)]
#[serde(deny_unknown_fields, default)]
pub struct Advanced {
/// Presumably the minimum distance between candidate points — confirm with the algorithms.
pub min_distance: f64,
/// Presumably the number of Nelder–Mead optimizer steps — confirm.
pub nm_steps: usize,
/// Numerical tolerance. NOTE(review): its consumer is not visible in this file.
pub tolerance: f64,
}
impl Default for Advanced {
fn default() -> Self {
Advanced {
min_distance: 1e-4,
nm_steps: 100,
tolerance: 1e-6,
}
}
}
/// Convergence thresholds.
///
/// Missing keys fall back to [`Convergence::default`] (`likelihood` 1e-4,
/// `pyl` 1e-2, `eps` 1e-2). Unknown keys are rejected.
/// NOTE(review): how each threshold is used is decided by the algorithms, not this file.
#[derive(Debug, Deserialize, Clone, Serialize)]
#[serde(deny_unknown_fields, default)]
pub struct Convergence {
/// Presumably the tolerance on the change in likelihood between cycles — confirm.
pub likelihood: f64,
/// Threshold named `pyl`; meaning not visible here — confirm.
pub pyl: f64,
/// Threshold named `eps`; meaning not visible here — confirm.
pub eps: f64,
}
impl Default for Convergence {
fn default() -> Self {
Convergence {
likelihood: 1e-4,
pyl: 1e-2,
eps: 1e-2,
}
}
}
/// Prediction options.
///
/// Missing keys fall back to [`Predictions::default`] (`idelta` 0.12, `tad` 0.0).
/// Unknown keys are rejected.
#[derive(Debug, Deserialize, Clone, Serialize)]
#[serde(deny_unknown_fields, default)]
pub struct Predictions {
/// Interval for predictions; must be non-negative (see `Predictions::validate`).
pub idelta: f64,
/// Time after dose for predictions; must be non-negative (see `Predictions::validate`).
pub tad: f64,
}
impl Default for Predictions {
fn default() -> Self {
Predictions {
idelta: 0.12,
tad: 0.0,
}
}
}
impl Predictions {
    /// Check that the prediction interval and the time after dose are
    /// non-negative numbers.
    ///
    /// # Errors
    /// Fails when `idelta` or `tad` is negative or NaN.
    pub fn validate(&self) -> Result<()> {
        // Comparisons with NaN are always `false`, so NaN is rejected explicitly.
        // Without this check it would slip through `< 0.0`.
        if self.idelta.is_nan() || self.idelta < 0.0 {
            bail!("The interval for predictions must be non-negative");
        }
        if self.tad.is_nan() || self.tad < 0.0 {
            bail!("The time after dose for predictions must be non-negative");
        }
        Ok(())
    }
}
/// Logging verbosity.
///
/// Converts one-to-one into `tracing::Level`. Displays as the lower-case name
/// (e.g. `info`). The default is `INFO`.
#[derive(Debug, Deserialize, Clone, Serialize, Default)]
pub enum LogLevel {
TRACE,
DEBUG,
/// Default level.
#[default]
INFO,
WARN,
ERROR,
}
impl From<LogLevel> for tracing::Level {
fn from(log_level: LogLevel) -> tracing::Level {
match log_level {
LogLevel::TRACE => tracing::Level::TRACE,
LogLevel::DEBUG => tracing::Level::DEBUG,
LogLevel::INFO => tracing::Level::INFO,
LogLevel::WARN => tracing::Level::WARN,
LogLevel::ERROR => tracing::Level::ERROR,
}
}
}
impl AsRef<str> for LogLevel {
fn as_ref(&self) -> &str {
match self {
LogLevel::TRACE => "trace",
LogLevel::DEBUG => "debug",
LogLevel::INFO => "info",
LogLevel::WARN => "warn",
LogLevel::ERROR => "error",
}
}
}
impl Display for LogLevel {
    /// Writes the same lower-case name returned by the `AsRef<str>` impl.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name: &str = self.as_ref();
        f.write_str(name)
    }
}
/// Logging options.
///
/// Missing keys fall back to [`Log::default`] (`INFO`, `write` false,
/// `stdout` true). Unknown keys are rejected.
#[derive(Debug, Deserialize, Clone, Serialize)]
#[serde(deny_unknown_fields, default)]
pub struct Log {
/// Verbosity level.
pub level: LogLevel,
/// Set by `Settings::set_write_logs`; presumably enables writing logs to a file — confirm.
pub write: bool,
/// Set by `Settings::set_log_stdout`; presumably enables logging to stdout — confirm.
pub stdout: bool,
}
impl Default for Log {
    /// `INFO` level, `write` disabled, `stdout` enabled.
    fn default() -> Self {
        Self {
            stdout: true,
            write: false,
            level: LogLevel::INFO,
        }
    }
}
/// Output options.
///
/// Missing keys fall back to [`Output::default`] (`write` true, `path`
/// `"outputs/"`). Unknown keys are rejected.
#[derive(Debug, Deserialize, Clone, Serialize)]
#[serde(deny_unknown_fields, default)]
pub struct Output {
/// Output toggle; cleared by `Settings::disable_output`.
pub write: bool,
/// Output directory; `Settings::write` places `settings.json` here.
pub path: String,
}
impl Default for Output {
fn default() -> Self {
let path = PathBuf::from("outputs/").to_string_lossy().to_string();
Output { write: true, path }
}
}
/// Typestate builder for [`Settings`].
///
/// `State` is one of the zero-sized markers `InitialState`, `AlgorithmSet`,
/// `ParametersSet` or `ErrorSet`. It tracks which required sections have been
/// supplied, and `build` exists only in the `ErrorSet` state. Optional sections
/// still `None` at `build` time become their `Default`.
pub struct SettingsBuilder<State> {
/// Filled by `set_algorithm`.
config: Option<Config>,
/// Filled by `set_parameters`.
parameters: Option<Parameters>,
/// Filled by `set_error_model`.
error: Option<Error>,
predictions: Option<Predictions>,
log: Option<Log>,
prior: Option<Prior>,
output: Option<Output>,
convergence: Option<Convergence>,
advanced: Option<Advanced>,
/// Carries the typestate without storing a value.
_marker: std::marker::PhantomData<State>,
}
/// Apparently a marker for builder states that have an algorithm.
/// NOTE(review): nothing in this file implements or bounds on this trait — confirm it is used elsewhere.
pub trait AlgorithmDefined {}
/// Apparently a marker for builder states that have parameters.
/// NOTE(review): nothing in this file implements or bounds on this trait — confirm it is used elsewhere.
pub trait ParametersDefined {}
/// Apparently a marker for builder states that have an error model.
/// NOTE(review): nothing in this file implements or bounds on this trait — confirm it is used elsewhere.
pub trait ErrorModelDefined {}
/// Builder state returned by `SettingsBuilder::new`; nothing supplied yet.
pub struct InitialState;
/// Builder state after `set_algorithm`.
pub struct AlgorithmSet;
/// Builder state after `set_parameters`.
pub struct ParametersSet;
/// Builder state after `set_error_model`; the only state that offers `build`.
pub struct ErrorSet;
impl SettingsBuilder<InitialState> {
    /// Create a builder with every section unset.
    pub fn new() -> Self {
        Self {
            _marker: std::marker::PhantomData,
            advanced: None,
            convergence: None,
            output: None,
            prior: None,
            log: None,
            predictions: None,
            error: None,
            parameters: None,
            config: None,
        }
    }

    /// Record the algorithm and advance to the [`AlgorithmSet`] state.
    ///
    /// The `config` section is created here from [`Config::default`], with only
    /// `algorithm` overridden. All other sections carry over unchanged.
    pub fn set_algorithm(self, algorithm: Algorithm) -> SettingsBuilder<AlgorithmSet> {
        let SettingsBuilder {
            parameters,
            error,
            predictions,
            log,
            prior,
            output,
            convergence,
            advanced,
            ..
        } = self;
        let config = Config {
            algorithm,
            ..Config::default()
        };
        SettingsBuilder {
            config: Some(config),
            parameters,
            error,
            predictions,
            log,
            prior,
            output,
            convergence,
            advanced,
            _marker: std::marker::PhantomData,
        }
    }
}
impl Default for SettingsBuilder<InitialState> {
fn default() -> Self {
SettingsBuilder::new()
}
}
impl SettingsBuilder<AlgorithmSet> {
    /// Attach the model parameters and advance to the [`ParametersSet`] state.
    pub fn set_parameters(self, parameters: Parameters) -> SettingsBuilder<ParametersSet> {
        let SettingsBuilder {
            config,
            error,
            predictions,
            log,
            prior,
            output,
            convergence,
            advanced,
            ..
        } = self;
        SettingsBuilder {
            config,
            parameters: Some(parameters),
            error,
            predictions,
            log,
            prior,
            output,
            convergence,
            advanced,
            _marker: std::marker::PhantomData,
        }
    }
}
impl SettingsBuilder<ParametersSet> {
pub fn set_error_model(
self,
model: ErrorModel,
value: f64,
poly: (f64, f64, f64, f64),
) -> SettingsBuilder<ErrorSet> {
let error = Error::new(value, model, poly);
SettingsBuilder {
config: self.config,
parameters: self.parameters,
error: Some(error),
predictions: self.predictions,
log: self.log,
prior: self.prior,
output: self.output,
convergence: self.convergence,
advanced: self.advanced,
_marker: std::marker::PhantomData,
}
}
}
impl SettingsBuilder<ErrorSet> {
    /// Consume the builder and produce the final [`Settings`].
    ///
    /// `predictions`, `log`, `prior`, `output`, `convergence` and `advanced` fall
    /// back to their `Default` values when they were not supplied.
    ///
    /// # Panics
    /// Only if the typestate invariant is broken. The builder's fields are private,
    /// so the `ErrorSet` state can only be reached through `set_algorithm`, then
    /// `set_parameters`, then `set_error_model`. Those calls fill `config`,
    /// `parameters` and `error` respectively.
    pub fn build(self) -> Settings {
        Settings {
            config: self
                .config
                .expect("typestate invariant: `config` is populated by `set_algorithm`"),
            parameters: self
                .parameters
                .expect("typestate invariant: `parameters` is populated by `set_parameters`"),
            error: self
                .error
                .expect("typestate invariant: `error` is populated by `set_error_model`"),
            predictions: self.predictions.unwrap_or_default(),
            log: self.log.unwrap_or_default(),
            prior: self.prior.unwrap_or_default(),
            output: self.output.unwrap_or_default(),
            convergence: self.convergence.unwrap_or_default(),
            advanced: self.advanced.unwrap_or_default(),
        }
    }
}
/// Resolve an output-folder path that may contain a `#` placeholder.
///
/// A `path` without `#` is returned unchanged. Otherwise every `#` is replaced
/// with the smallest positive integer, starting at 1, for which the resulting
/// path does not currently exist. For example, `run_#` becomes `run_1`, or
/// `run_2` if `run_1` already exists.
///
/// This only checks for existence and does not create anything. Two concurrent
/// callers resolving the same pattern can therefore receive the same path.
fn parse_output_folder(path: String) -> String {
    if !path.contains('#') {
        return path;
    }
    (1u64..)
        .map(|num| path.replace('#', &num.to_string()))
        .find(|candidate| !std::path::Path::new(candidate).exists())
        .expect("an unbounded counter always yields a non-existent path")
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algorithms::Algorithm;

    /// Build NPAG settings with two parameters and an additive error model of `value`.
    fn settings_with_error_value(value: f64) -> Settings {
        let parameters = Parameters::new()
            .add("Ke", 0.0, 5.0, false)
            .add("V", 10.0, 200.0, true);
        SettingsBuilder::new()
            .set_algorithm(Algorithm::NPAG)
            .set_parameters(parameters)
            .set_error_model(ErrorModel::Additive, value, (0.0, 0.1, 0.0, 0.0))
            .build()
    }

    #[test]
    fn test_builder() {
        let mut settings = settings_with_error_value(5.0);
        settings.set_cycles(100);
        assert_eq!(settings.config.algorithm, Algorithm::NPAG);
        assert_eq!(settings.config.cycles, 100);
        // `assert!` instead of `assert_eq!(_, true)` (clippy::bool_assert_comparison).
        assert!(settings.config.cache);
        assert_eq!(settings.parameters().names(), vec!["Ke", "V"]);
    }

    #[test]
    fn validate_accepts_builder_defaults() {
        assert!(settings_with_error_value(5.0).validate().is_ok());
    }

    #[test]
    fn validate_rejects_negative_error_value() {
        assert!(settings_with_error_value(-1.0).validate().is_err());
    }

    #[test]
    fn parse_output_folder_without_placeholder_is_unchanged() {
        assert_eq!(parse_output_folder("outputs/".to_string()), "outputs/");
    }
}