pmcore/routines/optimization/
mod.rs1use argmin::{
2 core::{CostFunction, Error, Executor},
3 solver::neldermead::NelderMead,
4};
5use ndarray::{Array1, Axis};
6
7use pharmsol::prelude::{
8 data::{Data, ErrorModel},
9 simulator::{psi, Equation},
10};
11
12pub struct SppOptimizer<'a, E: Equation> {
13 equation: &'a E,
14 data: &'a Data,
15 sig: &'a ErrorModel<'a>,
16 pyl: &'a Array1<f64>,
17}
18
19impl<E: Equation> CostFunction for SppOptimizer<'_, E> {
20 type Param = Vec<f64>;
21 type Output = f64;
22 fn cost(&self, spp: &Self::Param) -> Result<Self::Output, Error> {
23 let theta = Array1::from(spp.clone()).insert_axis(Axis(0));
24
25 let psi = psi(self.equation, self.data, &theta, self.sig, false, false);
26
27 if psi.ncols() > 1 {
28 tracing::error!("Psi in SppOptimizer has more than one column");
29 }
30 if psi.nrows() != self.pyl.len() {
31 tracing::error!(
32 "Psi in SppOptimizer has {} rows, but spp has {}",
33 psi.nrows(),
34 self.pyl.len()
35 );
36 }
37 let nsub = psi.nrows() as f64;
38 let mut sum = -nsub;
39 for (p_i, pyl_i) in psi.iter().zip(self.pyl.iter()) {
40 sum += p_i / pyl_i;
41 }
42 Ok(-sum)
43 }
44}
45
46impl<'a, E: Equation> SppOptimizer<'a, E> {
47 pub fn new(equation: &'a E, data: &'a Data, sig: &'a ErrorModel, pyl: &'a Array1<f64>) -> Self {
48 Self {
49 equation,
50 data,
51 sig,
52 pyl,
53 }
54 }
55 pub fn optimize_point(self, spp: Array1<f64>) -> Result<Array1<f64>, Error> {
56 let simplex = create_initial_simplex(&spp.to_vec());
57 let solver = NelderMead::new(simplex).with_sd_tolerance(1e-2)?;
58 let res = Executor::new(self, solver)
59 .configure(|state| state.max_iters(5))
60 .run()?;
62 Ok(Array1::from(res.state.best_param.unwrap()))
63 }
64}
65
/// Builds the starting simplex for Nelder-Mead around `initial_point`.
///
/// The result has `initial_point.len() + 1` vertices. The first is the point itself.
/// Each following vertex nudges exactly one coordinate:
/// - a non-zero coordinate `x` is shifted by `0.008 * x` (a 0.8% relative step);
/// - a coordinate equal to zero gets the absolute step `0.00025`, because a relative
///   step would leave it unchanged.
fn create_initial_simplex(initial_point: &[f64]) -> Vec<Vec<f64>> {
    const RELATIVE_STEP: f64 = 0.008;
    const ZERO_STEP: f64 = 0.00025;

    let nudged = initial_point.iter().enumerate().map(|(axis, &value)| {
        let step = if value == 0.0 {
            ZERO_STEP
        } else {
            RELATIVE_STEP * value
        };
        let mut vertex = initial_point.to_vec();
        vertex[axis] += step;
        vertex
    });

    std::iter::once(initial_point.to_vec())
        .chain(nudged)
        .collect()
}