pmcore/routines/optimization/
mod.rs1use argmin::{
2 core::{CostFunction, Error, Executor},
3 solver::neldermead::NelderMead,
4};
5
6use ndarray::{Array1, Axis};
7
8use pharmsol::prelude::{
9 data::{Data, ErrorModels},
10 simulator::{psi, Equation},
11};
12
13pub struct SppOptimizer<'a, E: Equation> {
14 equation: &'a E,
15 data: &'a Data,
16 sig: &'a ErrorModels,
17 pyl: &'a Array1<f64>,
18}
19
20impl<E: Equation> CostFunction for SppOptimizer<'_, E> {
21 type Param = Vec<f64>;
22 type Output = f64;
23 fn cost(&self, spp: &Self::Param) -> Result<Self::Output, Error> {
24 let theta = Array1::from(spp.clone()).insert_axis(Axis(0));
25
26 let psi = psi(self.equation, self.data, &theta, self.sig, false, false)?;
27
28 if psi.ncols() > 1 {
29 tracing::error!("Psi in SppOptimizer has more than one column");
30 }
31 if psi.nrows() != self.pyl.len() {
32 tracing::error!(
33 "Psi in SppOptimizer has {} rows, but spp has {}",
34 psi.nrows(),
35 self.pyl.len()
36 );
37 }
38 let nsub = psi.nrows() as f64;
39 let mut sum = -nsub;
40 for (p_i, pyl_i) in psi.iter().zip(self.pyl.iter()) {
41 sum += p_i / pyl_i;
42 }
43 Ok(-sum)
44 }
45}
46
47impl<'a, E: Equation> SppOptimizer<'a, E> {
48 pub fn new(
49 equation: &'a E,
50 data: &'a Data,
51 sig: &'a ErrorModels,
52 pyl: &'a Array1<f64>,
53 ) -> Self {
54 Self {
55 equation,
56 data,
57 sig,
58 pyl,
59 }
60 }
61 pub fn optimize_point(self, spp: Array1<f64>) -> Result<Array1<f64>, Error> {
62 let simplex = create_initial_simplex(&spp.to_vec());
63 let solver: NelderMead<Vec<f64>, f64> = NelderMead::new(simplex).with_sd_tolerance(1e-2)?;
64 let res = Executor::new(self, solver)
65 .configure(|state| state.max_iters(5))
66 .run()?;
68 Ok(Array1::from(res.state.best_param.unwrap()))
69 }
70}
71
/// Builds the starting Nelder-Mead simplex around `initial_point`.
///
/// Vertex 0 is `initial_point` itself. Vertex `k + 1` is a copy of `initial_point`
/// with coordinate `k` shifted by `0.008 * value`. When that coordinate is exactly
/// zero, it is shifted by `0.00025` instead, because a relative step would leave
/// the vertex unchanged. An `n`-dimensional point therefore yields `n + 1` vertices.
fn create_initial_simplex(initial_point: &[f64]) -> Vec<Vec<f64>> {
    const RELATIVE_STEP: f64 = 0.008;
    const ZERO_STEP: f64 = 0.00025;

    let shifted_vertices = initial_point.iter().enumerate().map(|(axis, &coord)| {
        let step = if coord == 0.0 {
            ZERO_STEP
        } else {
            RELATIVE_STEP * coord
        };
        let mut vertex = initial_point.to_vec();
        vertex[axis] += step;
        vertex
    });

    std::iter::once(initial_point.to_vec())
        .chain(shifted_vertices)
        .collect()
}