pmcore/routines/evaluation/ipm.rs

1use crate::structs::psi::Psi;
2use anyhow::bail;
3use faer::linalg::triangular_solve::solve_lower_triangular_in_place;
4use faer::linalg::triangular_solve::solve_upper_triangular_in_place;
5use faer::{Col, Mat, Row};
6use rayon::prelude::*;
7/// Applies Burke's Interior Point Method (IPM) to solve a convex optimization problem.
8///
9/// The objective function to maximize is:
10///     f(x) = Σ(log(Σ(ψ_ij * x_j)))   for i = 1 to n_sub
11///
12/// subject to:
13///     1. x_j ≥ 0 for all j = 1 to n_point,
14///     2. Σ(x_j) = 1,
15///
16/// where ψ is an n_sub×n_point matrix with non-negative entries and x is a probability vector.
17///
18/// # Arguments
19///
20/// * `psi` - A reference to a Psi structure containing the input matrix.
21///
22/// # Returns
23///
24/// On success, returns a tuple `(lam, obj)` where:
25///   - `lam` is a faer::Col<f64> containing the computed probability vector,
26///   - `obj` is the value of the objective function at the solution.
27///
28/// # Errors
29///
30/// This function returns an error if any step in the optimization (e.g. Cholesky factorization)
31/// fails.
32pub fn burke(psi: &Psi) -> anyhow::Result<(Col<f64>, f64)> {
33    let mut psi = psi.matrix().to_owned();
34
35    // Ensure all entries are finite and make them non-negative.
36    psi.row_iter_mut().try_for_each(|row| {
37        row.iter_mut().try_for_each(|x| {
38            if !x.is_finite() {
39                bail!("Input matrix must have finite entries")
40            } else {
41                // Coerce negatives to non-negative (could alternatively return an error)
42                *x = x.abs();
43                Ok(())
44            }
45        })
46    })?;
47
48    // Let psi be of shape (n_sub, n_point)
49    let (n_sub, n_point) = psi.shape();
50
51    // Create unit vectors:
52    // ecol: ones vector of length n_point (used for sums over points)
53    // erow: ones row of length n_sub (used for sums over subproblems)
54    let ecol: Col<f64> = Col::from_fn(n_point, |_| 1.0);
55    let erow: Row<f64> = Row::from_fn(n_sub, |_| 1.0);
56
57    // Compute plam = psi · ecol. This gives a column vector of length n_sub.
58    let mut plam: Col<f64> = &psi * &ecol;
59    let eps: f64 = 1e-8;
60    let mut sig: f64 = 0.0;
61
62    // Initialize lam (the variable we optimize) as a column vector of ones (length n_point).
63    let mut lam = ecol.clone();
64
65    // w = 1 ./ plam, elementwise.
66    let mut w: Col<f64> = Col::from_fn(plam.nrows(), |i| 1.0 / plam.get(i));
67
68    // ptw = ψᵀ · w, which will be a vector of length n_point.
69    let mut ptw: Col<f64> = psi.transpose() * &w;
70
71    // Use the maximum entry in ptw for scaling (the "shrink" factor).
72    let ptw_max = ptw.iter().fold(f64::NEG_INFINITY, |acc, &x| x.max(acc));
73    let shrink = 2.0 * ptw_max;
74    lam *= shrink;
75    plam *= shrink;
76    w /= shrink;
77    ptw /= shrink;
78
79    // y = ecol - ptw (a vector of length n_point).
80    let mut y: Col<f64> = &ecol - &ptw;
81    // r = erow - (w .* plam) (elementwise product; r has length n_sub).
82    let mut r: Col<f64> = Col::from_fn(n_sub, |i| erow.get(i) - w.get(i) * plam.get(i));
83    let mut norm_r: f64 = r.iter().fold(0.0, |max, &val| max.max(val.abs()));
84
85    // Compute the duality gap.
86    let sum_log_plam: f64 = plam.iter().map(|x| x.ln()).sum();
87    let sum_log_w: f64 = w.iter().map(|x| x.ln()).sum();
88    let mut gap: f64 = (sum_log_w + sum_log_plam).abs() / (1.0 + sum_log_plam);
89
90    // Compute the duality measure mu.
91    let mut mu = lam.transpose() * &y / n_point as f64;
92
93    let mut psi_inner: Mat<f64> = Mat::zeros(psi.nrows(), psi.ncols());
94
95    let n_threads = faer::get_global_parallelism().degree();
96
97    let rows = psi.nrows();
98
99    let mut output: Vec<Mat<f64>> = (0..n_threads).map(|_| Mat::zeros(rows, rows)).collect();
100
101    let mut h: Mat<f64> = Mat::zeros(rows, rows);
102
103    while mu > eps || norm_r > eps || gap > eps {
104        let smu = sig * mu;
105        // inner = lam ./ y, elementwise.
106        let inner = Col::from_fn(lam.nrows(), |i| lam.get(i) / y.get(i));
107        // w_plam = plam ./ w, elementwise (length n_sub).
108        let w_plam = Col::from_fn(plam.nrows(), |i| plam.get(i) / w.get(i));
109
110        // Scale each column of psi by the corresponding element of 'inner'
111
112        if psi.ncols() > n_threads * 128 {
113            psi_inner
114                .par_col_partition_mut(n_threads)
115                .zip(psi.par_col_partition(n_threads))
116                .zip(inner.par_partition(n_threads))
117                .zip(output.par_iter_mut())
118                .for_each(|(((mut psi_inner, psi), inner), output)| {
119                    psi_inner
120                        .as_mut()
121                        .col_iter_mut()
122                        .zip(psi.col_iter())
123                        .zip(inner.iter())
124                        .for_each(|((col, psi_col), inner_val)| {
125                            col.iter_mut().zip(psi_col.iter()).for_each(|(x, psi_val)| {
126                                *x = psi_val * inner_val;
127                            });
128                        });
129                    faer::linalg::matmul::triangular::matmul(
130                        output.as_mut(),
131                        faer::linalg::matmul::triangular::BlockStructure::TriangularLower,
132                        faer::Accum::Replace,
133                        &psi_inner,
134                        faer::linalg::matmul::triangular::BlockStructure::Rectangular,
135                        psi.transpose(),
136                        faer::linalg::matmul::triangular::BlockStructure::Rectangular,
137                        1.0,
138                        faer::Par::Seq,
139                    );
140                });
141
142            let mut first_iter = true;
143            for output in &output {
144                if first_iter {
145                    h.copy_from(output);
146                    first_iter = false;
147                } else {
148                    h += output;
149                }
150            }
151        } else {
152            psi_inner
153                .as_mut()
154                .col_iter_mut()
155                .zip(psi.col_iter())
156                .zip(inner.iter())
157                .for_each(|((col, psi_col), inner_val)| {
158                    col.iter_mut().zip(psi_col.iter()).for_each(|(x, psi_val)| {
159                        *x = psi_val * inner_val;
160                    });
161                });
162            faer::linalg::matmul::triangular::matmul(
163                h.as_mut(),
164                faer::linalg::matmul::triangular::BlockStructure::TriangularLower,
165                faer::Accum::Replace,
166                &psi_inner,
167                faer::linalg::matmul::triangular::BlockStructure::Rectangular,
168                psi.transpose(),
169                faer::linalg::matmul::triangular::BlockStructure::Rectangular,
170                1.0,
171                faer::Par::Seq,
172            );
173        }
174
175        for i in 0..h.nrows() {
176            h[(i, i)] += w_plam[i];
177        }
178
179        let uph = match h.llt(faer::Side::Lower) {
180            Ok(llt) => llt,
181            Err(_) => {
182                bail!("Error during Cholesky decomposition")
183            }
184        };
185        let uph = uph.L().transpose().to_owned();
186
187        // smuyinv = smu * (ecol ./ y)
188        let smuyinv: Col<f64> = Col::from_fn(ecol.nrows(), |i| smu * (ecol[i] / y[i]));
189
190        // let smuyinv = smu * (&ecol / &y);
191        // rhsdw = (erow ./ w) - (psi · smuyinv)
192        let psi_dot_muyinv: Col<f64> = &psi * &smuyinv;
193
194        let rhsdw: Row<f64> = Row::from_fn(erow.ncols(), |i| erow[i] / w[i] - psi_dot_muyinv[i]);
195
196        //let rhsdw = (&erow / &w) - psi * &smuyinv;
197        // Reshape rhsdw into a column vector.
198        let mut dw = Mat::from_fn(rhsdw.ncols(), 1, |i, _j| *rhsdw.get(i));
199
200        // let a = rhsdw
201        //     .into_shape((n_sub, 1))
202        //     .context("Failed to reshape rhsdw").unwrap();
203
204        // Solve the triangular systems:
205
206        solve_lower_triangular_in_place(uph.transpose().as_ref(), dw.as_mut(), faer::Par::rayon(0));
207
208        solve_upper_triangular_in_place(uph.as_ref(), dw.as_mut(), faer::Par::rayon(0));
209
210        // Extract dw (a column vector) from the solution.
211        let dw = dw.col(0);
212
213        // let dw = dw_aux.column(0);
214        // Compute dy = - (ψᵀ · dw)
215        let dy = -(psi.transpose() * dw);
216
217        let inner_times_dy = Col::from_fn(ecol.nrows(), |i| inner[i] * dy[i]);
218
219        let dlam: Row<f64> =
220            Row::from_fn(ecol.nrows(), |i| smuyinv[i] - lam[i] - inner_times_dy[i]);
221        // let dlam = &smuyinv - &lam - inner.transpose() * &dy;
222
223        // Compute the primal step length alfpri.
224        let ratio_dlam_lam = Row::from_fn(lam.nrows(), |i| dlam[i] / lam[i]);
225        //let ratio_dlam_lam = &dlam / &lam;
226        let min_ratio_dlam = ratio_dlam_lam.iter().cloned().fold(f64::INFINITY, f64::min);
227        let mut alfpri: f64 = -1.0 / min_ratio_dlam.min(-0.5);
228        alfpri = (0.99995 * alfpri).min(1.0);
229
230        // Compute the dual step length alfdual.
231        let ratio_dy_y = Row::from_fn(y.nrows(), |i| dy[i] / y[i]);
232        // let ratio_dy_y = &dy / &y;
233        let min_ratio_dy = ratio_dy_y.iter().cloned().fold(f64::INFINITY, f64::min);
234        let ratio_dw_w = Row::from_fn(dw.nrows(), |i| dw[i] / w[i]);
235        //let ratio_dw_w = &dw / &w;
236        let min_ratio_dw = ratio_dw_w.iter().cloned().fold(f64::INFINITY, f64::min);
237        let mut alfdual = -1.0 / min_ratio_dy.min(-0.5);
238        alfdual = alfdual.min(-1.0 / min_ratio_dw.min(-0.5));
239        alfdual = (0.99995 * alfdual).min(1.0);
240
241        // Update the iterates.
242        lam += alfpri * dlam.transpose();
243        w += alfdual * dw;
244        y += alfdual * &dy;
245
246        mu = lam.transpose() * &y / n_point as f64;
247        plam = &psi * &lam;
248
249        // mu = lam.dot(&y) / n_point as f64;
250        // plam = psi.dot(&lam);
251        r = Col::from_fn(n_sub, |i| erow.get(i) - w.get(i) * plam.get(i));
252        ptw -= alfdual * dy;
253
254        norm_r = r.norm_max();
255        let sum_log_plam: f64 = plam.iter().map(|x| x.ln()).sum();
256        let sum_log_w: f64 = w.iter().map(|x| x.ln()).sum();
257        gap = (sum_log_w + sum_log_plam).abs() / (1.0 + sum_log_plam);
258
259        // Adjust sigma.
260        if mu < eps && norm_r > eps {
261            sig = 1.0;
262        } else {
263            let candidate1 = (1.0 - alfpri).powi(2);
264            let candidate2 = (1.0 - alfdual).powi(2);
265            let candidate3 = (norm_r - mu) / (norm_r + 100.0 * mu);
266            sig = candidate1.max(candidate2).max(candidate3).min(0.3);
267        }
268    }
269    // Scale lam.
270    lam /= n_sub as f64;
271    // Compute the objective function value: sum(ln(psi·lam)).
272    let obj = (psi * &lam).iter().map(|x| x.ln()).sum();
273    // Normalize lam to sum to 1.
274    let lam_sum: f64 = lam.iter().sum();
275    lam = &lam / lam_sum;
276
277    Ok((lam, obj))
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use faer::Mat;

    /// Runs `burke` on an `n_sub × n_point` matrix of ones and checks that the
    /// result is the uniform probability vector with objective `n_sub · log(1) = 0`.
    fn assert_uniform_ones_solution(n_sub: usize, n_point: usize) {
        let mat = Mat::from_fn(n_sub, n_point, |_, _| 1.0);
        let psi = Psi::from(mat);

        let (lam, obj) = burke(&psi).unwrap();

        // Check that lambda sums to 1
        assert_relative_eq!(lam.iter().sum::<f64>(), 1.0, epsilon = 1e-10);

        // For a uniform matrix, all lambda values should be equal
        let expected = 1.0 / n_point as f64;
        for i in 0..n_point {
            assert_relative_eq!(lam[i], expected, epsilon = 1e-10);
        }

        // Every entry of ψ·λ equals Σλ = 1, so each log term is 0.
        assert_relative_eq!(obj, 0.0, epsilon = 1e-4);
    }

    #[test]
    fn test_burke_identity() {
        // For an identity matrix, each support point should have equal weight
        let n = 100;
        let mat = Mat::identity(n, n);
        let psi = Psi::from(mat);

        let (lam, obj) = burke(&psi).unwrap();

        // For identity matrix, all lambda values should be equal
        let expected = 1.0 / n as f64;
        for i in 0..n {
            assert_relative_eq!(lam[i], expected, epsilon = 1e-10);
        }

        // Check that lambda sums to 1
        assert_relative_eq!(lam.iter().sum::<f64>(), 1.0, epsilon = 1e-10);

        // Each entry of ψ·λ is 1/n, so the objective is n · log(1/n).
        assert_relative_eq!(obj, -(n as f64) * (n as f64).ln(), epsilon = 1e-4);
    }

    #[test]
    fn test_burke_uniform_square() {
        assert_uniform_ones_solution(10, 10);
    }

    #[test]
    fn test_burke_uniform_wide() {
        assert_uniform_ones_solution(10, 100);
    }

    #[test]
    fn test_burke_uniform_long() {
        assert_uniform_ones_solution(100, 10);
    }

    #[test]
    fn test_burke_with_non_uniform_matrix() {
        // Column 0 dominates every other column, so all mass should go to it.
        let n_sub = 3;
        let n_point = 4;
        let mat = Mat::from_fn(n_sub, n_point, |_, j| if j == 0 { 10.0 } else { 1.0 });
        let psi = Psi::from(mat);

        let (lam, obj) = burke(&psi).unwrap();

        // Check that lambda sums to 1
        assert_relative_eq!(lam.iter().sum::<f64>(), 1.0, epsilon = 1e-10);

        // First support point should have highest weight
        assert!(lam[0] > lam[1]);
        assert!(lam[0] > lam[2]);
        assert!(lam[0] > lam[3]);
        assert_relative_eq!(lam[0], 1.0, epsilon = 1e-4);

        // At λ = e_0 every entry of ψ·λ is 10.
        assert_relative_eq!(obj, n_sub as f64 * 10f64.ln(), epsilon = 1e-4);
    }

    #[test]
    fn test_burke_with_negative_values() {
        // The algorithm should handle negative values by taking their absolute value
        let n_sub = 2;
        let n_point = 3;
        let mat = Mat::from_fn(
            n_sub,
            n_point,
            |i, j| if i == 0 && j == 0 { -5.0 } else { 1.0 },
        );
        let psi = Psi::from(mat);

        let result = burke(&psi);
        assert!(result.is_ok());

        let (lam, obj) = result.unwrap();
        // Check that lambda sums to 1
        assert_relative_eq!(lam.iter().sum::<f64>(), 1.0, epsilon = 1e-10);

        // First support point should have highest weight due to the high absolute value
        assert!(lam[0] > lam[1]);
        assert!(lam[0] > lam[2]);

        // After coercion ψ = [[5, 1, 1], [1, 1, 1]]; the objective log(1 + 4λ_0)
        // is maximized at λ = e_0, giving log 5 + log 1.
        assert_relative_eq!(obj, 5f64.ln(), epsilon = 1e-4);
    }

    #[test]
    fn test_burke_with_non_finite_values() {
        // The algorithm should return an error for any non-finite entry
        let n_sub = 10;
        let n_point = 10;
        for bad in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let mat = Mat::from_fn(n_sub, n_point, |i, j| {
                if i == 0 && j == 0 {
                    bad
                } else {
                    1.0
                }
            });
            let psi = Psi::from(mat);

            let result = burke(&psi);
            assert!(result.is_err());
        }
    }
}