pmcore/routines/optimization/
mod.rs

1use argmin::{
2    core::{CostFunction, Error, Executor},
3    solver::neldermead::NelderMead,
4};
5
6use ndarray::{Array1, Axis};
7
8use pharmsol::prelude::{
9    data::{Data, ErrorModels},
10    simulator::{psi, Equation},
11};
12
13pub struct SppOptimizer<'a, E: Equation> {
14    equation: &'a E,
15    data: &'a Data,
16    sig: &'a ErrorModels,
17    pyl: &'a Array1<f64>,
18}
19
20impl<E: Equation> CostFunction for SppOptimizer<'_, E> {
21    type Param = Vec<f64>;
22    type Output = f64;
23    fn cost(&self, spp: &Self::Param) -> Result<Self::Output, Error> {
24        let theta = Array1::from(spp.clone()).insert_axis(Axis(0));
25
26        let psi = psi(self.equation, self.data, &theta, self.sig, false, false)?;
27
28        if psi.ncols() > 1 {
29            tracing::error!("Psi in SppOptimizer has more than one column");
30        }
31        if psi.nrows() != self.pyl.len() {
32            tracing::error!(
33                "Psi in SppOptimizer has {} rows, but spp has {}",
34                psi.nrows(),
35                self.pyl.len()
36            );
37        }
38        let nsub = psi.nrows() as f64;
39        let mut sum = -nsub;
40        for (p_i, pyl_i) in psi.iter().zip(self.pyl.iter()) {
41            sum += p_i / pyl_i;
42        }
43        Ok(-sum)
44    }
45}
46
47impl<'a, E: Equation> SppOptimizer<'a, E> {
48    pub fn new(
49        equation: &'a E,
50        data: &'a Data,
51        sig: &'a ErrorModels,
52        pyl: &'a Array1<f64>,
53    ) -> Self {
54        Self {
55            equation,
56            data,
57            sig,
58            pyl,
59        }
60    }
61    pub fn optimize_point(self, spp: Array1<f64>) -> Result<Array1<f64>, Error> {
62        let simplex = create_initial_simplex(&spp.to_vec());
63        let solver: NelderMead<Vec<f64>, f64> = NelderMead::new(simplex).with_sd_tolerance(1e-2)?;
64        let res = Executor::new(self, solver)
65            .configure(|state| state.max_iters(5))
66            // .add_observer(SlogLogger::term(), ObserverMode::Always)
67            .run()?;
68        Ok(Array1::from(res.state.best_param.unwrap()))
69    }
70}
71
/// Builds the starting Nelder–Mead simplex around `initial_point`.
///
/// For an `n`-dimensional point the simplex has `n + 1` vertices:
/// * the point itself;
/// * one vertex per dimension `i`, in which only component `i` is shifted.
///
/// The shift is 0.8 % of the component's value (`0.008 * x`). If that relative
/// shift is zero, a fixed absolute step of `0.00025` is used instead. This
/// happens when the component is `0.0` (or `-0.0`), or so small that the
/// product underflows. The fallback guarantees that no shifted vertex coincides
/// with the starting point.
///
/// An empty `initial_point` yields a single empty vertex.
fn create_initial_simplex(initial_point: &[f64]) -> Vec<Vec<f64>> {
    // Relative step applied to each component (0.8 %).
    const RELATIVE_STEP: f64 = 0.008;
    // Absolute step used when the relative step is exactly zero.
    const ABSOLUTE_STEP: f64 = 0.00025;

    let mut vertices = Vec::with_capacity(initial_point.len() + 1);
    vertices.push(initial_point.to_vec());

    vertices.extend(initial_point.iter().enumerate().map(|(i, &x)| {
        let relative = RELATIVE_STEP * x;
        // Test the computed step rather than `x`. This also catches subnormal
        // inputs whose relative step underflows to 0, which would otherwise
        // produce a vertex identical to the starting point.
        let step = if relative == 0.0 { ABSOLUTE_STEP } else { relative };
        let mut vertex = initial_point.to_vec();
        vertex[i] += step;
        vertex
    }));

    vertices
}