pmcore/routines/optimization/mod.rs

1use argmin::{
2    core::{CostFunction, Error, Executor},
3    solver::neldermead::NelderMead,
4};
5use ndarray::{Array1, Axis};
6
7use pharmsol::prelude::{
8    data::{Data, ErrorModel},
9    simulator::{psi, Equation},
10};
11
/// Cost-function adapter used to refine a single support point with argmin's Nelder–Mead solver.
///
/// It only borrows its inputs. Each call to `cost` simulates a candidate point with `psi` and
/// compares the resulting column against `pyl`.
pub struct SppOptimizer<'a, E: Equation> {
    /// Model equation forwarded to `psi` for simulation.
    equation: &'a E,
    /// Observation data forwarded to `psi` (presumably all subjects — TODO confirm with caller).
    data: &'a Data,
    /// Error model forwarded to `psi`.
    sig: &'a ErrorModel<'a>,
    /// Per-row divisors used by the cost; expected to have one entry per row of the `psi` matrix.
    /// NOTE(review): presumably each subject's likelihood under the current distribution — confirm.
    pyl: &'a Array1<f64>,
}
18
19impl<E: Equation> CostFunction for SppOptimizer<'_, E> {
20    type Param = Vec<f64>;
21    type Output = f64;
22    fn cost(&self, spp: &Self::Param) -> Result<Self::Output, Error> {
23        let theta = Array1::from(spp.clone()).insert_axis(Axis(0));
24
25        let psi = psi(self.equation, self.data, &theta, self.sig, false, false);
26
27        if psi.ncols() > 1 {
28            tracing::error!("Psi in SppOptimizer has more than one column");
29        }
30        if psi.nrows() != self.pyl.len() {
31            tracing::error!(
32                "Psi in SppOptimizer has {} rows, but spp has {}",
33                psi.nrows(),
34                self.pyl.len()
35            );
36        }
37        let nsub = psi.nrows() as f64;
38        let mut sum = -nsub;
39        for (p_i, pyl_i) in psi.iter().zip(self.pyl.iter()) {
40            sum += p_i / pyl_i;
41        }
42        Ok(-sum)
43    }
44}
45
46impl<'a, E: Equation> SppOptimizer<'a, E> {
47    pub fn new(equation: &'a E, data: &'a Data, sig: &'a ErrorModel, pyl: &'a Array1<f64>) -> Self {
48        Self {
49            equation,
50            data,
51            sig,
52            pyl,
53        }
54    }
55    pub fn optimize_point(self, spp: Array1<f64>) -> Result<Array1<f64>, Error> {
56        let simplex = create_initial_simplex(&spp.to_vec());
57        let solver = NelderMead::new(simplex).with_sd_tolerance(1e-2)?;
58        let res = Executor::new(self, solver)
59            .configure(|state| state.max_iters(5))
60            // .add_observer(SlogLogger::term(), ObserverMode::Always)
61            .run()?;
62        Ok(Array1::from(res.state.best_param.unwrap()))
63    }
64}
65
/// Builds the starting simplex for Nelder–Mead around `initial_point`.
///
/// The result has `initial_point.len() + 1` vertices. The first is `initial_point` itself. It is
/// followed by one vertex per coordinate in which only that coordinate has been shifted:
/// - a non-zero coordinate `x` moves by `0.008 * x` (0.8 % of its value, keeping its sign);
/// - a coordinate equal to zero moves by the absolute amount `0.00025`, because a relative step
///   would leave it unchanged.
fn create_initial_simplex(initial_point: &[f64]) -> Vec<Vec<f64>> {
    const RELATIVE_STEP: f64 = 0.008;
    const ZERO_STEP: f64 = 0.00025;

    let shifted_vertices = initial_point.iter().enumerate().map(|(axis, &coord)| {
        let step = if coord == 0.0 {
            ZERO_STEP
        } else {
            RELATIVE_STEP * coord
        };
        let mut vertex = initial_point.to_vec();
        vertex[axis] += step;
        vertex
    });

    std::iter::once(initial_point.to_vec())
        .chain(shifted_vertices)
        .collect()
}
88
89    vertices
90}