pmcore/routines/evaluation/
ipm.rs

1use crate::structs::psi::Psi;
2use crate::structs::weights::Weights;
3use anyhow::bail;
4use faer::linalg::triangular_solve::solve_lower_triangular_in_place;
5use faer::linalg::triangular_solve::solve_upper_triangular_in_place;
6use faer::{Col, Mat, Row};
7use rayon::prelude::*;
8/// Applies Burke's Interior Point Method (IPM) to solve a convex optimization problem.
9///
10/// The objective function to maximize is:
11///     f(x) = Σ(log(Σ(ψ_ij * x_j)))   for i = 1 to n_sub
12///
13/// subject to:
14///     1. x_j ≥ 0 for all j = 1 to n_point,
15///     2. Σ(x_j) = 1,
16///
17/// where ψ is an n_sub×n_point matrix with non-negative entries and x is a probability vector.
18///
19/// # Arguments
20///
21/// * `psi` - A reference to a Psi structure containing the input matrix.
22///
23/// # Returns
24///
25/// On success, returns a tuple `(weights, obj)` where:
26///   - [Weights] contains the optimized weights (probabilities) for each support point.
27///   - `obj` is the value of the objective function at the solution.
28///
29/// # Errors
30///
31/// This function returns an error if any step in the optimization (e.g. Cholesky factorization)
32/// fails.
33pub fn burke(psi: &Psi) -> anyhow::Result<(Weights, f64)> {
34    let mut psi = psi.matrix().to_owned();
35
36    // Ensure all entries are finite and make them non-negative.
37    psi.row_iter_mut().try_for_each(|row| {
38        row.iter_mut().try_for_each(|x| {
39            if !x.is_finite() {
40                bail!("Input matrix must have finite entries")
41            } else {
42                // Coerce negatives to non-negative (could alternatively return an error)
43                *x = x.abs();
44                Ok(())
45            }
46        })
47    })?;
48
49    // Let psi be of shape (n_sub, n_point)
50    let (n_sub, n_point) = psi.shape();
51
52    // Create unit vectors:
53    // ecol: ones vector of length n_point (used for sums over points)
54    // erow: ones row of length n_sub (used for sums over subproblems)
55    let ecol: Col<f64> = Col::from_fn(n_point, |_| 1.0);
56    let erow: Row<f64> = Row::from_fn(n_sub, |_| 1.0);
57
58    // Compute plam = psi · ecol. This gives a column vector of length n_sub.
59    let mut plam: Col<f64> = &psi * &ecol;
60    let eps: f64 = 1e-8;
61    let mut sig: f64 = 0.0;
62
63    // Initialize lam (the variable we optimize) as a column vector of ones (length n_point).
64    let mut lam = ecol.clone();
65
66    // w = 1 ./ plam, elementwise.
67    let mut w: Col<f64> = Col::from_fn(plam.nrows(), |i| 1.0 / plam.get(i));
68
69    // ptw = ψᵀ · w, which will be a vector of length n_point.
70    let mut ptw: Col<f64> = psi.transpose() * &w;
71
72    // Use the maximum entry in ptw for scaling (the "shrink" factor).
73    let ptw_max = ptw.iter().fold(f64::NEG_INFINITY, |acc, &x| x.max(acc));
74    let shrink = 2.0 * ptw_max;
75    lam *= shrink;
76    plam *= shrink;
77    w /= shrink;
78    ptw /= shrink;
79
80    // y = ecol - ptw (a vector of length n_point).
81    let mut y: Col<f64> = &ecol - &ptw;
82    // r = erow - (w .* plam) (elementwise product; r has length n_sub).
83    let mut r: Col<f64> = Col::from_fn(n_sub, |i| erow.get(i) - w.get(i) * plam.get(i));
84    let mut norm_r: f64 = r.iter().fold(0.0, |max, &val| max.max(val.abs()));
85
86    // Compute the duality gap.
87    let sum_log_plam: f64 = plam.iter().map(|x| x.ln()).sum();
88    let sum_log_w: f64 = w.iter().map(|x| x.ln()).sum();
89    let mut gap: f64 = (sum_log_w + sum_log_plam).abs() / (1.0 + sum_log_plam);
90
91    // Compute the duality measure mu.
92    let mut mu = lam.transpose() * &y / n_point as f64;
93
94    let mut psi_inner: Mat<f64> = Mat::zeros(psi.nrows(), psi.ncols());
95
96    let n_threads = faer::get_global_parallelism().degree();
97
98    let rows = psi.nrows();
99
100    let mut output: Vec<Mat<f64>> = (0..n_threads).map(|_| Mat::zeros(rows, rows)).collect();
101
102    let mut h: Mat<f64> = Mat::zeros(rows, rows);
103
104    while mu > eps || norm_r > eps || gap > eps {
105        let smu = sig * mu;
106        // inner = lam ./ y, elementwise.
107        let inner = Col::from_fn(lam.nrows(), |i| lam.get(i) / y.get(i));
108        // w_plam = plam ./ w, elementwise (length n_sub).
109        let w_plam = Col::from_fn(plam.nrows(), |i| plam.get(i) / w.get(i));
110
111        // Scale each column of psi by the corresponding element of 'inner'
112
113        if psi.ncols() > n_threads * 128 {
114            psi_inner
115                .par_col_partition_mut(n_threads)
116                .zip(psi.par_col_partition(n_threads))
117                .zip(inner.par_partition(n_threads))
118                .zip(output.par_iter_mut())
119                .for_each(|(((mut psi_inner, psi), inner), output)| {
120                    psi_inner
121                        .as_mut()
122                        .col_iter_mut()
123                        .zip(psi.col_iter())
124                        .zip(inner.iter())
125                        .for_each(|((col, psi_col), inner_val)| {
126                            col.iter_mut().zip(psi_col.iter()).for_each(|(x, psi_val)| {
127                                *x = psi_val * inner_val;
128                            });
129                        });
130                    faer::linalg::matmul::triangular::matmul(
131                        output.as_mut(),
132                        faer::linalg::matmul::triangular::BlockStructure::TriangularLower,
133                        faer::Accum::Replace,
134                        &psi_inner,
135                        faer::linalg::matmul::triangular::BlockStructure::Rectangular,
136                        psi.transpose(),
137                        faer::linalg::matmul::triangular::BlockStructure::Rectangular,
138                        1.0,
139                        faer::Par::Seq,
140                    );
141                });
142
143            let mut first_iter = true;
144            for output in &output {
145                if first_iter {
146                    h.copy_from(output);
147                    first_iter = false;
148                } else {
149                    h += output;
150                }
151            }
152        } else {
153            psi_inner
154                .as_mut()
155                .col_iter_mut()
156                .zip(psi.col_iter())
157                .zip(inner.iter())
158                .for_each(|((col, psi_col), inner_val)| {
159                    col.iter_mut().zip(psi_col.iter()).for_each(|(x, psi_val)| {
160                        *x = psi_val * inner_val;
161                    });
162                });
163            faer::linalg::matmul::triangular::matmul(
164                h.as_mut(),
165                faer::linalg::matmul::triangular::BlockStructure::TriangularLower,
166                faer::Accum::Replace,
167                &psi_inner,
168                faer::linalg::matmul::triangular::BlockStructure::Rectangular,
169                psi.transpose(),
170                faer::linalg::matmul::triangular::BlockStructure::Rectangular,
171                1.0,
172                faer::Par::Seq,
173            );
174        }
175
176        for i in 0..h.nrows() {
177            h[(i, i)] += w_plam[i];
178        }
179
180        let uph = match h.llt(faer::Side::Lower) {
181            Ok(llt) => llt,
182            Err(_) => {
183                bail!("Error during Cholesky decomposition. The matrix might not be positive definite. This is usually due to model misspecification or numerical issues.")
184            }
185        };
186        let uph = uph.L().transpose().to_owned();
187
188        // smuyinv = smu * (ecol ./ y)
189        let smuyinv: Col<f64> = Col::from_fn(ecol.nrows(), |i| smu * (ecol[i] / y[i]));
190
191        // let smuyinv = smu * (&ecol / &y);
192        // rhsdw = (erow ./ w) - (psi · smuyinv)
193        let psi_dot_muyinv: Col<f64> = &psi * &smuyinv;
194
195        let rhsdw: Row<f64> = Row::from_fn(erow.ncols(), |i| erow[i] / w[i] - psi_dot_muyinv[i]);
196
197        //let rhsdw = (&erow / &w) - psi * &smuyinv;
198        // Reshape rhsdw into a column vector.
199        let mut dw = Mat::from_fn(rhsdw.ncols(), 1, |i, _j| *rhsdw.get(i));
200
201        // let a = rhsdw
202        //     .into_shape((n_sub, 1))
203        //     .context("Failed to reshape rhsdw").unwrap();
204
205        // Solve the triangular systems:
206
207        solve_lower_triangular_in_place(uph.transpose().as_ref(), dw.as_mut(), faer::Par::rayon(0));
208
209        solve_upper_triangular_in_place(uph.as_ref(), dw.as_mut(), faer::Par::rayon(0));
210
211        // Extract dw (a column vector) from the solution.
212        let dw = dw.col(0);
213
214        // let dw = dw_aux.column(0);
215        // Compute dy = - (ψᵀ · dw)
216        let dy = -(psi.transpose() * dw);
217
218        let inner_times_dy = Col::from_fn(ecol.nrows(), |i| inner[i] * dy[i]);
219
220        let dlam: Row<f64> =
221            Row::from_fn(ecol.nrows(), |i| smuyinv[i] - lam[i] - inner_times_dy[i]);
222        // let dlam = &smuyinv - &lam - inner.transpose() * &dy;
223
224        // Compute the primal step length alfpri.
225        let ratio_dlam_lam = Row::from_fn(lam.nrows(), |i| dlam[i] / lam[i]);
226        //let ratio_dlam_lam = &dlam / &lam;
227        let min_ratio_dlam = ratio_dlam_lam.iter().cloned().fold(f64::INFINITY, f64::min);
228        let mut alfpri: f64 = -1.0 / min_ratio_dlam.min(-0.5);
229        alfpri = (0.99995 * alfpri).min(1.0);
230
231        // Compute the dual step length alfdual.
232        let ratio_dy_y = Row::from_fn(y.nrows(), |i| dy[i] / y[i]);
233        // let ratio_dy_y = &dy / &y;
234        let min_ratio_dy = ratio_dy_y.iter().cloned().fold(f64::INFINITY, f64::min);
235        let ratio_dw_w = Row::from_fn(dw.nrows(), |i| dw[i] / w[i]);
236        //let ratio_dw_w = &dw / &w;
237        let min_ratio_dw = ratio_dw_w.iter().cloned().fold(f64::INFINITY, f64::min);
238        let mut alfdual = -1.0 / min_ratio_dy.min(-0.5);
239        alfdual = alfdual.min(-1.0 / min_ratio_dw.min(-0.5));
240        alfdual = (0.99995 * alfdual).min(1.0);
241
242        // Update the iterates.
243        lam += alfpri * dlam.transpose();
244        w += alfdual * dw;
245        y += alfdual * &dy;
246
247        mu = lam.transpose() * &y / n_point as f64;
248        plam = &psi * &lam;
249
250        // mu = lam.dot(&y) / n_point as f64;
251        // plam = psi.dot(&lam);
252        r = Col::from_fn(n_sub, |i| erow.get(i) - w.get(i) * plam.get(i));
253        ptw -= alfdual * dy;
254
255        norm_r = r.norm_max();
256        let sum_log_plam: f64 = plam.iter().map(|x| x.ln()).sum();
257        let sum_log_w: f64 = w.iter().map(|x| x.ln()).sum();
258        gap = (sum_log_w + sum_log_plam).abs() / (1.0 + sum_log_plam);
259
260        // Adjust sigma.
261        if mu < eps && norm_r > eps {
262            sig = 1.0;
263        } else {
264            let candidate1 = (1.0 - alfpri).powi(2);
265            let candidate2 = (1.0 - alfdual).powi(2);
266            let candidate3 = (norm_r - mu) / (norm_r + 100.0 * mu);
267            sig = candidate1.max(candidate2).max(candidate3).min(0.3);
268        }
269    }
270    // Scale lam.
271    lam /= n_sub as f64;
272    // Compute the objective function value: sum(ln(psi·lam)).
273    let obj = (psi * &lam).iter().map(|x| x.ln()).sum();
274    // Normalize lam to sum to 1.
275    let lam_sum: f64 = lam.iter().sum();
276    lam = &lam / lam_sum;
277
278    Ok((lam.into(), obj))
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use faer::Mat;

    /// Builds an `n_sub` × `n_point` Psi whose entries are all 1.0.
    fn uniform_psi(n_sub: usize, n_point: usize) -> Psi {
        let mat: Mat<f64> = Mat::from_fn(n_sub, n_point, |_, _| 1.0);
        Psi::from(mat)
    }

    /// Asserts that the weights sum to 1 (tolerance 1e-10).
    fn assert_sums_to_one(lam: &Weights) {
        assert_relative_eq!(lam.iter().sum::<f64>(), 1.0, epsilon = 1e-10);
    }

    /// Asserts that each of the first `n_point` weights equals `1 / n_point` (tolerance 1e-10).
    fn assert_uniform(lam: &Weights, n_point: usize) {
        let expected = 1.0 / n_point as f64;
        for i in 0..n_point {
            assert_relative_eq!(lam[i], expected, epsilon = 1e-10);
        }
    }

    /// Asserts that each of the first `n_point` weights is non-negative.
    fn assert_non_negative(lam: &Weights, n_point: usize) {
        for i in 0..n_point {
            assert!(lam[i] >= 0.0, "Lambda values should be non-negative");
        }
    }

    #[test]
    fn test_burke_identity() {
        // Test with a 100 × 100 identity matrix.
        // For an identity matrix, each support point should have equal weight.
        let n = 100;
        let mat = Mat::identity(n, n);
        let psi = Psi::from(mat);

        let (lam, _) = burke(&psi).unwrap();

        // For identity matrix, all lambda values should be equal
        assert_uniform(&lam, n);

        // Check that lambda sums to 1
        assert_sums_to_one(&lam);
    }

    #[test]
    fn test_burke_uniform_square() {
        // Test with a square matrix of all ones.
        // This should also result in uniform weights.
        let n_sub = 10;
        let n_point = 10;
        let psi = uniform_psi(n_sub, n_point);

        let (lam, _) = burke(&psi).unwrap();

        assert_sums_to_one(&lam);
        assert_uniform(&lam, n_point);
    }

    #[test]
    fn test_burke_uniform_wide() {
        // Test with a matrix of all ones that has more columns than rows.
        // This should also result in uniform weights.
        let n_sub = 10;
        let n_point = 100;
        let psi = uniform_psi(n_sub, n_point);

        let (lam, _) = burke(&psi).unwrap();

        assert_sums_to_one(&lam);
        assert_uniform(&lam, n_point);
    }

    #[test]
    fn test_burke_uniform_long() {
        // Test with a matrix of all ones that has more rows than columns.
        // This should also result in uniform weights.
        let n_sub = 100;
        let n_point = 10;
        let psi = uniform_psi(n_sub, n_point);

        let (lam, _) = burke(&psi).unwrap();

        assert_sums_to_one(&lam);
        assert_uniform(&lam, n_point);
    }

    #[test]
    fn test_burke_with_non_uniform_matrix() {
        // Test with a non-uniform matrix
        // Create a matrix where one column is clearly better
        let n_sub = 3;
        let n_point = 4;
        let mat = Mat::from_fn(n_sub, n_point, |_, j| if j == 0 { 10.0 } else { 1.0 });
        let psi = Psi::from(mat);

        let (lam, _) = burke(&psi).unwrap();

        assert_sums_to_one(&lam);

        // First support point should have highest weight
        assert!(lam[0] > lam[1]);
        assert!(lam[0] > lam[2]);
        assert!(lam[0] > lam[3]);
    }

    #[test]
    fn test_burke_with_negative_values() {
        // The algorithm should handle negative values by taking their absolute value
        let n_sub = 2;
        let n_point = 3;
        let mat = Mat::from_fn(
            n_sub,
            n_point,
            |i, j| if i == 0 && j == 0 { -5.0 } else { 1.0 },
        );
        let psi = Psi::from(mat);

        let result = burke(&psi);
        assert!(result.is_ok());

        let (lam, _) = result.unwrap();
        assert_sums_to_one(&lam);

        // First support point should have highest weight due to the high absolute value
        assert!(lam[0] > lam[1]);
        assert!(lam[0] > lam[2]);
    }

    #[test]
    fn test_burke_with_non_finite_values() {
        // The algorithm should return an error for non-finite values
        let n_sub = 10;
        let n_point = 10;
        let mat = Mat::from_fn(n_sub, n_point, |i, j| {
            if i == 0 && j == 0 {
                f64::NAN
            } else {
                1.0
            }
        });
        let psi = Psi::from(mat);

        let result = burke(&psi);
        assert!(result.is_err());
    }

    #[test]
    fn test_burke_large_matrix_parallel_processing() {
        // Test with a large matrix to trigger the parallel processing code path.
        // burke takes that path when n_point > n_threads * 128, so 10000 columns
        // reaches it for any configuration with up to 78 threads.
        let n_sub = 50;
        let n_point = 10000;

        // Create a simple uniform matrix
        // The main goal is to test that parallel processing works correctly
        let psi = uniform_psi(n_sub, n_point);

        let result = burke(&psi);
        assert!(
            result.is_ok(),
            "Burke algorithm should succeed with large matrix"
        );

        let (lam, obj) = result.unwrap();

        // Verify basic mathematical properties of the solution
        assert_sums_to_one(&lam);
        assert_non_negative(&lam, n_point);

        // The objective function should be finite
        assert!(obj.is_finite(), "Objective function should be finite");

        // For a uniform matrix, we expect roughly uniform weights, but the exact
        // distribution depends on the optimization algorithm's convergence.
        // Just verify that no single weight dominates excessively (basic sanity check)
        let max_weight = lam
            .weights()
            .iter()
            .cloned()
            .fold(f64::NEG_INFINITY, f64::max);
        assert!(
            max_weight < 0.1,
            "No single weight should dominate in uniform matrix (max weight: {})",
            max_weight
        );
    }

    #[test]
    fn test_burke_medium_matrix_sequential_processing() {
        // Test with a medium-sized matrix that must NOT trigger parallel processing.
        // burke only takes the parallel path when n_point > n_threads * 128, so 128
        // columns stays on the sequential path for any thread count of at least 1.
        let n_sub = 50;
        let n_point = 128;

        // Non-uniform pattern: every 32nd column is markedly stronger than the rest.
        let mat = Mat::from_fn(n_sub, n_point, |i, j| {
            if j % 32 == 0 {
                5.0 + 0.1 * (i as f64)
            } else {
                1.0 + 0.01 * (i as f64) + 0.001 * (j as f64)
            }
        });
        let psi = Psi::from(mat);

        let result = burke(&psi);
        assert!(
            result.is_ok(),
            "Burke algorithm should succeed with medium matrix"
        );

        let (lam, obj) = result.unwrap();

        // Verify basic properties of the solution
        assert_sums_to_one(&lam);
        assert_non_negative(&lam, n_point);

        // The objective function should be finite
        assert!(obj.is_finite(), "Objective function should be finite");
    }
}