pmcore/routines/optimization/bestmo.rs

use anyhow::Result;
use argmin::{
    core::{CostFunction, Executor, TerminationReason, TerminationStatus},
    solver::neldermead::NelderMead,
};

/// Coefficients of the residual 1 - a/xm^h1 - b/xm^h2 - w/xm^xx,
/// where xx = (h1 + h2) / 2 (set by `get_xm0best`).
#[derive(Debug, Clone)]
struct BestM0 {
    a: f64,
    b: f64,
    w: f64,
    h1: f64,
    h2: f64,
    xx: f64,
}

/// We optimize over y = ln(xm), so Param = f64 (the log of xm).
/// Working in log-space keeps xm > 0 without needing explicit bounds.
impl CostFunction for BestM0 {
    type Param = f64; // this is ln(xm)
    type Output = f64;

    fn cost(&self, y: &Self::Param) -> Result<Self::Output> {
        // Recover xm from the log-parameter.
        let xm = y.exp();

        // Guard: xm must be finite and strictly positive.
        if !(xm.is_finite() && xm > 0.0) {
            // Return a very large cost instead of NaN so the solver backs off.
            return Ok(1.0e100);
        }

        // Skip terms whose coefficient is zero; otherwise powf with a
        // positive base is well-defined for any real exponent.
        let t1 = if self.a == 0.0 {
            0.0
        } else {
            self.a / xm.powf(self.h1)
        };
        let t2 = if self.b == 0.0 {
            0.0
        } else {
            self.b / xm.powf(self.h2)
        };
        let t3 = if self.w == 0.0 {
            0.0
        } else {
            self.w / xm.powf(self.xx)
        };

        // If any term is NaN or infinite, treat this as a bad point.
        if !t1.is_finite() || !t2.is_finite() || !t3.is_finite() {
            return Ok(1.0e100);
        }

        // Squared residual: zero exactly when t1 + t2 + t3 == 1.
        let val = (1.0 - t1 - t2 - t3).powi(2);
        if !val.is_finite() {
            return Ok(1.0e100);
        }

        Ok(val)
    }
}
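
// A minimal check of the log-space cost (an illustrative addition, not part
// of the original routine): with a = b = 1, w = 0, h1 = h2 = 1 the residual
// 1 - 1/xm - 1/xm vanishes at xm = 2, so the cost at ln(2) is exactly 0.
#[cfg(test)]
mod cost_tests {
    use super::*;

    #[test]
    fn cost_is_zero_at_known_root() {
        let bm0 = BestM0 {
            a: 1.0,
            b: 1.0,
            w: 0.0,
            h1: 1.0,
            h2: 1.0,
            xx: 1.0, // (h1 + h2) / 2
        };
        let cost = bm0.cost(&2.0_f64.ln()).unwrap();
        assert!(cost < 1e-20);
    }
}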

impl BestM0 {
    /// `start_log` and `step_log` are in log-space (ln(xm)).
    /// Returns (xm, best cost, converged).
    fn get_best(&self, start_log: f64, step_log: f64) -> Result<(f64, f64, bool)> {
        // Build a 1-D simplex from two log-parameters, both finite and distinct.
        let second = start_log + step_log;
        // If the step produced an invalid or degenerate simplex, fall back to
        // a small fixed step (0.1 in log-space is roughly a 10% change in xm).
        let initial_simplex = if !second.is_finite() || (second - start_log).abs() < 1e-12 {
            vec![start_log, start_log + 0.1_f64]
        } else {
            vec![start_log, second]
        };

        let solver = NelderMead::new(initial_simplex)
            .with_sd_tolerance(1e-8)
            .map_err(|e| anyhow::anyhow!("Failed creating NelderMead: {}", e))?;

        let res = Executor::new(self.clone(), solver)
            .configure(|state| state.max_iters(1000))
            .run()
            .map_err(|e| anyhow::anyhow!("Optimizer run failed: {}", e))?;

        let converged = match &res.state.termination_status {
            TerminationStatus::Terminated(reason) => {
                matches!(reason, TerminationReason::SolverConverged)
            }
            _ => false,
        };

        let best_cost = res.state.best_cost;
        // best_param is ln(xm); convert back to xm. Avoid unwrap so a run
        // with no recorded best parameter surfaces as an error, not a panic.
        let best_log = res
            .state
            .best_param
            .ok_or_else(|| anyhow::anyhow!("Optimizer returned no best parameter"))?;
        let xm = best_log.exp();

        Ok((xm, best_cost, converged))
    }
}
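
// A quick convergence check for `get_best` (illustrative; it assumes
// Nelder-Mead converges on this smooth 1-D problem, which it should from a
// start this close to the root). With a = b = 1, w = 0, h1 = h2 = 1 the
// minimizer is xm = 2.
#[cfg(test)]
mod get_best_tests {
    use super::*;

    #[test]
    fn finds_root_of_simple_residual() {
        let bm0 = BestM0 {
            a: 1.0,
            b: 1.0,
            w: 0.0,
            h1: 1.0,
            h2: 1.0,
            xx: 1.0,
        };
        // Start at ln(1) = 0 with a 10% log-space step.
        let (xm, cost, _converged) = bm0.get_best(0.0, 0.1).unwrap();
        assert!((xm - 2.0).abs() < 1e-2, "xm = {xm}");
        assert!(cost < 1e-4);
    }
}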

/// Continuation estimate of xm: starting from the a = 0 root xm = b^(1/h2),
/// march `a` from 0 to `afinal` in 1000 forward-Euler steps of
/// dxm/da = (dF/da) / (-dF/dxm), obtained by implicitly differentiating
/// F(xm, a) = a/xm^h1 + b/xm^h2 + alpha*a*b/xm^((h1+h2)/2) - 1 = 0.
/// Expects b, h1, h2 in valid (positive) ranges. Returns -1.0 if the
/// iteration leaves the valid domain (xm not finite and positive, or a
/// vanishing denominator); callers must treat a negative return as failure.
fn find_m0(afinal: f64, b: f64, alpha: f64, h1: f64, h2: f64) -> f64 {
    let noint = 1000;
    let del_a = afinal / (noint as f64);
    // Initial guess: the exact root at a = 0; must be positive.
    let mut xm = if b > 0.0 { b.powf(1.0 / h2) } else { 1.0 };
    let mut a = 0.0;
    let hh = (h1 + h2) / 2.0;

    for int in 1..=noint {
        // Guard: avoid dividing by zero or stepping outside the domain.
        if !(xm.is_finite() && xm > 0.0) {
            return -1.0;
        }

        // dF/da and the terms of -dF/dxm for the implicit derivative.
        let top = 1.0 / xm.powf(h1) + alpha * b / xm.powf(hh);
        let b1 = a * h1 / xm.powf(h1 + 1.0);
        let b2 = b * h2 / xm.powf(h2 + 1.0);
        let b3 = alpha * a * b * hh / xm.powf(hh + 1.0);

        let denom = b1 + b2 + b3;
        if denom == 0.0 || !denom.is_finite() {
            return -1.0;
        }

        let xmp = top / denom;
        xm += xmp * del_a;

        if !(xm.is_finite() && xm > 0.0) {
            return -1.0;
        }

        a = del_a * (int as f64);
    }

    xm
}
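
// A small accuracy check for `find_m0` (an illustrative addition; values are
// chosen so the continuation has a closed form). With alpha = 0 and
// h1 = h2 = h, the implicit equation reduces to (a + b)/xm^h = 1, i.e.
// xm = (a + b)^(1/h), so find_m0(1.0, 1.0, 0.0, 1.0, 1.0) should approach 2
// up to the forward-Euler discretization error.
#[cfg(test)]
mod find_m0_tests {
    use super::*;

    #[test]
    fn matches_closed_form_when_alpha_is_zero() {
        let xm = find_m0(1.0, 1.0, 0.0, 1.0, 1.0);
        assert!((xm - 2.0).abs() < 1e-2, "xm = {xm}");
    }
}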

/// Finds the xm that minimizes (1 - a/xm^h1 - b/xm^h2 - w/xm^((h1+h2)/2))^2
/// and returns it mapped into (0, 1) via xm / (xm + 1). Tries a log-space
/// Nelder-Mead search first, then falls back to the `find_m0` continuation
/// estimate, and finally to closed-form single-term roots.
pub fn get_xm0best(a: f64, b: f64, w: f64, h1: f64, h2: f64, alpha_s: f64) -> f64 {
    // Trivial case: both coefficients are (numerically) zero.
    if a.abs() < 1.0e-12 && b.abs() < 1.0e-12 {
        return 0.0;
    }

    // Precompute the exponent of the cross term.
    let xx = (h1 + h2) / 2.0;
    let bm0 = BestM0 { a, b, w, h1, h2, xx };

    // If only one of a, b is positive, the dominant term gives a closed-form
    // root: a/xm^h1 = 1 or b/xm^h2 = 1 (the w term is ignored here).
    if b <= 0.0 && a > 0.0 {
        let xm0best = a.powf(1.0 / h1);
        return xm0best / (xm0best + 1.0);
    }
    if a <= 0.0 && b > 0.0 {
        let xm0best = b.powf(1.0 / h2);
        return xm0best / (xm0best + 1.0);
    }

    // Both positive: optimize in log-space from a safe initial guess > 0.
    let xm_guess = if b > 0.0 {
        b.powf(1.0 / h2)
    } else if a > 0.0 {
        a.powf(1.0 / h1)
    } else {
        1.0
    };
    let start_log = xm_guess.max(1e-12).ln();
    let step_log = 0.1_f64; // roughly a 10% step in xm

    // First optimization from the closed-form guess.
    match bm0.get_best(start_log, step_log) {
        Ok((xm0best1, valmin1, conv1)) => {
            if !conv1 {
                // Keep the answer anyway if the cost is already tiny.
                if valmin1 < 1e-10 {
                    return xm0best1 / (xm0best1 + 1.0);
                }
                // Fall back to the iterative continuation estimator.
                let xm0est = find_m0(a, b, alpha_s, h1, h2);
                if xm0est < 0.0 {
                    return xm0best1 / (xm0best1 + 1.0);
                }
                // Refine from the continuation estimate.
                let start_log2 = xm0est.ln();
                if let Ok((xm0best2, valmin2, conv2)) = bm0.get_best(start_log2, 0.1) {
                    if conv2 && valmin2 < valmin1 {
                        return xm0best2 / (xm0best2 + 1.0);
                    } else {
                        return xm0best1 / (xm0best1 + 1.0);
                    }
                } else {
                    return xm0best1 / (xm0best1 + 1.0);
                }
            } else {
                return xm0best1 / (xm0best1 + 1.0);
            }
        }
        Err(_) => {
            // The optimizer itself failed: fall back to the numerical estimator.
            let xm0est = find_m0(a, b, alpha_s, h1, h2);
            if xm0est > 0.0 {
                return xm0est / (xm0est + 1.0);
            } else {
                // Last resort: closed-form single-term roots, if possible.
                if a > 0.0 {
                    let xm0best = a.powf(1.0 / h1);
                    return xm0best / (xm0best + 1.0);
                }
                if b > 0.0 {
                    let xm0best = b.powf(1.0 / h2);
                    return xm0best / (xm0best + 1.0);
                }
                return 0.0;
            }
        }
    }
}
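
// End-to-end checks for `get_xm0best` (an illustrative addition, not part of
// the original routine). The first two cases follow directly from the
// trivial and closed-form branches above; the third exercises the
// Nelder-Mead path, where a = b = 1, w = 0, h1 = h2 = 1 gives xm = 2 and
// hence 2 / (2 + 1) = 2/3, assuming the solver converges on this smooth
// 1-D problem.
#[cfg(test)]
mod get_xm0best_tests {
    use super::*;

    #[test]
    fn trivial_case_returns_zero() {
        assert_eq!(get_xm0best(0.0, 0.0, 0.0, 1.0, 2.0, 0.5), 0.0);
    }

    #[test]
    fn single_positive_coefficient_uses_closed_form() {
        // b <= 0, a > 0: xm = a^(1/h1) = 2, mapped to 2/3.
        let out = get_xm0best(2.0, 0.0, 0.0, 1.0, 2.0, 0.5);
        assert!((out - 2.0 / 3.0).abs() < 1e-12);
    }

    #[test]
    fn both_positive_optimizes_in_log_space() {
        let out = get_xm0best(1.0, 1.0, 0.0, 1.0, 1.0, 0.0);
        assert!((out - 2.0 / 3.0).abs() < 1e-2, "out = {out}");
    }
}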