pmcore/routines/optimization/
bestmo.rs1use anyhow::Result;
2use argmin::{
3 core::{CostFunction, Executor, TerminationReason, TerminationStatus},
4 solver::neldermead::NelderMead,
5};
6
/// Parameters of the scalar problem minimised by Nelder–Mead: find `xm > 0` such that
/// `a / xm^h1 + b / xm^h2 + w / xm^xx = 1`.
///
/// The optimiser works on `y = ln(xm)`; the cost is the squared residual of the
/// equation above (see the `CostFunction` impl for this type).
#[derive(Clone, Debug)]
struct BestM0 {
    /// Coefficient of the `xm^-h1` term.
    a: f64,
    /// Coefficient of the `xm^-h2` term.
    b: f64,
    /// Coefficient of the `xm^-xx` term.
    w: f64,
    /// Exponent paired with `a`.
    h1: f64,
    /// Exponent paired with `b`.
    h2: f64,
    /// Exponent paired with `w`; `get_xm0best` sets it to `(h1 + h2) / 2`.
    xx: f64,
}
16
17impl CostFunction for BestM0 {
19 type Param = f64; type Output = f64;
21
22 fn cost(&self, y: &Self::Param) -> Result<Self::Output> {
23 let xm = y.exp();
25
26 if !(xm.is_finite() && xm > 0.0) {
28 return Ok(1.0e100);
30 }
31
32 let t1 = if self.a == 0.0 {
34 0.0
35 } else {
36 self.a / xm.powf(self.h1)
37 };
38 let t2 = if self.b == 0.0 {
39 0.0
40 } else {
41 self.b / xm.powf(self.h2)
42 };
43 let t3 = if self.w == 0.0 {
44 0.0
45 } else {
46 self.w / xm.powf(self.xx)
47 };
48
49 if !t1.is_finite() || !t2.is_finite() || !t3.is_finite() {
51 return Ok(1.0e100);
52 }
53
54 let val = (1.0 - t1 - t2 - t3).powi(2);
55 if !val.is_finite() {
56 return Ok(1.0e100);
57 }
58
59 Ok(val)
60 }
61}
62
63impl BestM0 {
64 fn get_best(&self, start_log: f64, step_log: f64) -> Result<(f64, f64, bool)> {
66 let second = start_log + step_log;
68 let initial_simplex = if !(second.is_finite()) || (second - start_log).abs() < 1e-12 {
70 vec![start_log, start_log + 0.1_f64] } else {
72 vec![start_log, second]
73 };
74
75 let solver = NelderMead::new(initial_simplex)
76 .with_sd_tolerance(1e-8)
77 .map_err(|e| anyhow::anyhow!("Failed creating NelderMead: {}", e))?;
78
79 let res = Executor::new(self.clone(), solver)
80 .configure(|state| state.max_iters(1000))
81 .run()
82 .map_err(|e| anyhow::anyhow!("Optimizer run failed: {}", e))?;
83
84 let converged = match &res.state.termination_status {
85 TerminationStatus::Terminated(reason) => {
86 matches!(reason, TerminationReason::SolverConverged)
87 }
88 _ => false,
89 };
90
91 let best_log = res.state.best_param.unwrap();
93 let xm = best_log.exp();
94
95 Ok((xm, res.state.best_cost, converged))
96 }
97}
98
/// Traces the root `xm` of
/// `u / xm^h1 + b / xm^h2 + alpha * u * b / xm^((h1 + h2) / 2) = 1`
/// from `u = 0` to `u = afinal` using 1000 forward-Euler steps.
///
/// At `u = 0` the equation reduces to `b / xm^h2 = 1`, so the path starts at
/// `xm = b^(1/h2)` (or `1.0` when `b <= 0`). Each step advances `xm` by
/// `(dxm/du) * du`, with `dxm/du` obtained by implicit differentiation of the
/// left-hand side.
///
/// Returns the final `xm` (finite and positive), or `-1.0` if `xm` ever leaves the
/// positive finite reals or the slope's denominator is zero or non-finite.
fn find_m0(afinal: f64, b: f64, alpha: f64, h1: f64, h2: f64) -> f64 {
    const STEPS: u32 = 1000;
    const FAILURE: f64 = -1.0;

    let du = afinal / f64::from(STEPS);
    let hh = (h1 + h2) / 2.0;
    let is_usable = |x: f64| x.is_finite() && x > 0.0;

    let mut xm = if b > 0.0 { b.powf(1.0 / h2) } else { 1.0 };
    // Checked once up front; afterwards every update is validated at the end of its step.
    if !is_usable(xm) {
        return FAILURE;
    }

    // `u` lags one step behind the counter: step `k` evaluates the slope at u_{k-1}.
    let mut u = 0.0;
    for k in 1..=STEPS {
        // d(lhs)/du ...
        let numerator = 1.0 / xm.powf(h1) + alpha * b / xm.powf(hh);
        // ... and -d(lhs)/dxm.
        let denominator = u * h1 / xm.powf(h1 + 1.0)
            + b * h2 / xm.powf(h2 + 1.0)
            + alpha * u * b * hh / xm.powf(hh + 1.0);

        if denominator == 0.0 || !denominator.is_finite() {
            return FAILURE;
        }

        xm += numerator / denominator * du;
        if !is_usable(xm) {
            return FAILURE;
        }

        u = du * f64::from(k);
    }

    xm
}
137
138pub fn get_xm0best(a: f64, b: f64, w: f64, h1: f64, h2: f64, alpha_s: f64) -> f64 {
139 if a.abs() < 1.0e-12 && b.abs() < 1.0e-12 {
141 return 0.0;
142 }
143
144 let xx = (h1 + h2) / 2.0;
146 let bm0 = BestM0 {
147 a,
148 b,
149 w,
150 h1,
151 h2,
152 xx,
153 };
154
155 if b <= 0.0 && a > 0.0 {
157 let xm0best = a.powf(1.0 / h1);
158 return xm0best / (xm0best + 1.0);
159 }
160 if a <= 0.0 && b > 0.0 {
161 let xm0best = b.powf(1.0 / h2);
162 return xm0best / (xm0best + 1.0);
163 }
164
165 let xm_guess = if b > 0.0 {
168 b.powf(1.0 / h2)
169 } else if a > 0.0 {
170 a.powf(1.0 / h1)
171 } else {
172 1.0
173 };
174 let start_log = xm_guess.max(1e-12).ln();
175 let step_log = 0.1_f64; match bm0.get_best(start_log, step_log) {
179 Ok((xm0best1, valmin1, conv1)) => {
180 if !conv1 {
181 if valmin1 < 1e-10 {
183 return xm0best1 / (xm0best1 + 1.0);
184 }
185 let xm0est = find_m0(a, b, alpha_s, h1, h2);
187 if xm0est < 0.0 {
188 return xm0best1 / (xm0best1 + 1.0);
189 }
190 let start_log2 = xm0est.ln();
192 if let Ok((xm0best2, valmin2, conv2)) = bm0.get_best(start_log2, 0.1) {
193 if conv2 && valmin2 < valmin1 {
194 return xm0best2 / (xm0best2 + 1.0);
195 } else {
196 return xm0best1 / (xm0best1 + 1.0);
197 }
198 } else {
199 return xm0best1 / (xm0best1 + 1.0);
200 }
201 } else {
202 return xm0best1 / (xm0best1 + 1.0);
203 }
204 }
205 Err(_) => {
206 let xm0est = find_m0(a, b, alpha_s, h1, h2);
208 if xm0est > 0.0 {
209 return xm0est / (xm0est + 1.0);
210 } else {
211 if a > 0.0 {
213 let xm0best = a.powf(1.0 / h1);
214 return xm0best / (xm0best + 1.0);
215 }
216 if b > 0.0 {
217 let xm0best = b.powf(1.0 / h2);
218 return xm0best / (xm0best + 1.0);
219 }
220 return 0.0;
221 }
222 }
223 }
224}