pmcore/routines/evaluation/
ipm.rs1use crate::structs::psi::Psi;
2use crate::structs::weights::Weights;
3use anyhow::bail;
4use faer::linalg::triangular_solve::solve_lower_triangular_in_place;
5use faer::linalg::triangular_solve::solve_upper_triangular_in_place;
6use faer::{Col, Mat, Row};
7use rayon::prelude::*;
8pub fn burke(psi: &Psi) -> anyhow::Result<(Weights, f64)> {
34 let mut psi = psi.matrix().to_owned();
35
36 psi.row_iter_mut().try_for_each(|row| {
38 row.iter_mut().try_for_each(|x| {
39 if !x.is_finite() {
40 bail!("Input matrix must have finite entries")
41 } else {
42 *x = x.abs();
44 Ok(())
45 }
46 })
47 })?;
48
49 let (n_sub, n_point) = psi.shape();
51
52 let ecol: Col<f64> = Col::from_fn(n_point, |_| 1.0);
56 let erow: Row<f64> = Row::from_fn(n_sub, |_| 1.0);
57
58 let mut plam: Col<f64> = &psi * &ecol;
60 let eps: f64 = 1e-8;
61 let mut sig: f64 = 0.0;
62
63 let mut lam = ecol.clone();
65
66 let mut w: Col<f64> = Col::from_fn(plam.nrows(), |i| 1.0 / plam.get(i));
68
69 let mut ptw: Col<f64> = psi.transpose() * &w;
71
72 let ptw_max = ptw.iter().fold(f64::NEG_INFINITY, |acc, &x| x.max(acc));
74 let shrink = 2.0 * ptw_max;
75 lam *= shrink;
76 plam *= shrink;
77 w /= shrink;
78 ptw /= shrink;
79
80 let mut y: Col<f64> = &ecol - &ptw;
82 let mut r: Col<f64> = Col::from_fn(n_sub, |i| erow.get(i) - w.get(i) * plam.get(i));
84 let mut norm_r: f64 = r.iter().fold(0.0, |max, &val| max.max(val.abs()));
85
86 let sum_log_plam: f64 = plam.iter().map(|x| x.ln()).sum();
88 let sum_log_w: f64 = w.iter().map(|x| x.ln()).sum();
89 let mut gap: f64 = (sum_log_w + sum_log_plam).abs() / (1.0 + sum_log_plam);
90
91 let mut mu = lam.transpose() * &y / n_point as f64;
93
94 let mut psi_inner: Mat<f64> = Mat::zeros(psi.nrows(), psi.ncols());
95
96 let n_threads = faer::get_global_parallelism().degree();
97
98 let rows = psi.nrows();
99
100 let mut output: Vec<Mat<f64>> = (0..n_threads).map(|_| Mat::zeros(rows, rows)).collect();
101
102 let mut h: Mat<f64> = Mat::zeros(rows, rows);
103
104 while mu > eps || norm_r > eps || gap > eps {
105 let smu = sig * mu;
106 let inner = Col::from_fn(lam.nrows(), |i| lam.get(i) / y.get(i));
108 let w_plam = Col::from_fn(plam.nrows(), |i| plam.get(i) / w.get(i));
110
111 if psi.ncols() > n_threads * 128 {
114 psi_inner
115 .par_col_partition_mut(n_threads)
116 .zip(psi.par_col_partition(n_threads))
117 .zip(inner.par_partition(n_threads))
118 .zip(output.par_iter_mut())
119 .for_each(|(((mut psi_inner, psi), inner), output)| {
120 psi_inner
121 .as_mut()
122 .col_iter_mut()
123 .zip(psi.col_iter())
124 .zip(inner.iter())
125 .for_each(|((col, psi_col), inner_val)| {
126 col.iter_mut().zip(psi_col.iter()).for_each(|(x, psi_val)| {
127 *x = psi_val * inner_val;
128 });
129 });
130 faer::linalg::matmul::triangular::matmul(
131 output.as_mut(),
132 faer::linalg::matmul::triangular::BlockStructure::TriangularLower,
133 faer::Accum::Replace,
134 &psi_inner,
135 faer::linalg::matmul::triangular::BlockStructure::Rectangular,
136 psi.transpose(),
137 faer::linalg::matmul::triangular::BlockStructure::Rectangular,
138 1.0,
139 faer::Par::Seq,
140 );
141 });
142
143 let mut first_iter = true;
144 for output in &output {
145 if first_iter {
146 h.copy_from(output);
147 first_iter = false;
148 } else {
149 h += output;
150 }
151 }
152 } else {
153 psi_inner
154 .as_mut()
155 .col_iter_mut()
156 .zip(psi.col_iter())
157 .zip(inner.iter())
158 .for_each(|((col, psi_col), inner_val)| {
159 col.iter_mut().zip(psi_col.iter()).for_each(|(x, psi_val)| {
160 *x = psi_val * inner_val;
161 });
162 });
163 faer::linalg::matmul::triangular::matmul(
164 h.as_mut(),
165 faer::linalg::matmul::triangular::BlockStructure::TriangularLower,
166 faer::Accum::Replace,
167 &psi_inner,
168 faer::linalg::matmul::triangular::BlockStructure::Rectangular,
169 psi.transpose(),
170 faer::linalg::matmul::triangular::BlockStructure::Rectangular,
171 1.0,
172 faer::Par::Seq,
173 );
174 }
175
176 for i in 0..h.nrows() {
177 h[(i, i)] += w_plam[i];
178 }
179
180 let uph = match h.llt(faer::Side::Lower) {
181 Ok(llt) => llt,
182 Err(_) => {
183 bail!("Error during Cholesky decomposition. The matrix might not be positive definite. This is usually due to model misspecification or numerical issues.")
184 }
185 };
186 let uph = uph.L().transpose().to_owned();
187
188 let smuyinv: Col<f64> = Col::from_fn(ecol.nrows(), |i| smu * (ecol[i] / y[i]));
190
191 let psi_dot_muyinv: Col<f64> = &psi * &smuyinv;
194
195 let rhsdw: Row<f64> = Row::from_fn(erow.ncols(), |i| erow[i] / w[i] - psi_dot_muyinv[i]);
196
197 let mut dw = Mat::from_fn(rhsdw.ncols(), 1, |i, _j| *rhsdw.get(i));
200
201 solve_lower_triangular_in_place(uph.transpose().as_ref(), dw.as_mut(), faer::Par::rayon(0));
208
209 solve_upper_triangular_in_place(uph.as_ref(), dw.as_mut(), faer::Par::rayon(0));
210
211 let dw = dw.col(0);
213
214 let dy = -(psi.transpose() * dw);
217
218 let inner_times_dy = Col::from_fn(ecol.nrows(), |i| inner[i] * dy[i]);
219
220 let dlam: Row<f64> =
221 Row::from_fn(ecol.nrows(), |i| smuyinv[i] - lam[i] - inner_times_dy[i]);
222 let ratio_dlam_lam = Row::from_fn(lam.nrows(), |i| dlam[i] / lam[i]);
226 let min_ratio_dlam = ratio_dlam_lam.iter().cloned().fold(f64::INFINITY, f64::min);
228 let mut alfpri: f64 = -1.0 / min_ratio_dlam.min(-0.5);
229 alfpri = (0.99995 * alfpri).min(1.0);
230
231 let ratio_dy_y = Row::from_fn(y.nrows(), |i| dy[i] / y[i]);
233 let min_ratio_dy = ratio_dy_y.iter().cloned().fold(f64::INFINITY, f64::min);
235 let ratio_dw_w = Row::from_fn(dw.nrows(), |i| dw[i] / w[i]);
236 let min_ratio_dw = ratio_dw_w.iter().cloned().fold(f64::INFINITY, f64::min);
238 let mut alfdual = -1.0 / min_ratio_dy.min(-0.5);
239 alfdual = alfdual.min(-1.0 / min_ratio_dw.min(-0.5));
240 alfdual = (0.99995 * alfdual).min(1.0);
241
242 lam += alfpri * dlam.transpose();
244 w += alfdual * dw;
245 y += alfdual * &dy;
246
247 mu = lam.transpose() * &y / n_point as f64;
248 plam = &psi * &lam;
249
250 r = Col::from_fn(n_sub, |i| erow.get(i) - w.get(i) * plam.get(i));
253 ptw -= alfdual * dy;
254
255 norm_r = r.norm_max();
256 let sum_log_plam: f64 = plam.iter().map(|x| x.ln()).sum();
257 let sum_log_w: f64 = w.iter().map(|x| x.ln()).sum();
258 gap = (sum_log_w + sum_log_plam).abs() / (1.0 + sum_log_plam);
259
260 if mu < eps && norm_r > eps {
262 sig = 1.0;
263 } else {
264 let candidate1 = (1.0 - alfpri).powi(2);
265 let candidate2 = (1.0 - alfdual).powi(2);
266 let candidate3 = (norm_r - mu) / (norm_r + 100.0 * mu);
267 sig = candidate1.max(candidate2).max(candidate3).min(0.3);
268 }
269 }
270 lam /= n_sub as f64;
272 let obj = (psi * &lam).iter().map(|x| x.ln()).sum();
274 let lam_sum: f64 = lam.iter().sum();
276 lam = &lam / lam_sum;
277
278 Ok((lam.into(), obj))
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use faer::Mat;

    /// Wraps `mat` in a `Psi` and runs `burke` on it.
    fn run(mat: Mat<f64>) -> anyhow::Result<(Weights, f64)> {
        let psi = Psi::from(mat);
        burke(&psi)
    }

    /// Builds an `n_sub x n_point` matrix with every entry equal to one.
    fn ones(n_sub: usize, n_point: usize) -> Mat<f64> {
        Mat::from_fn(n_sub, n_point, |_, _| 1.0)
    }

    /// Checks that the weights add up to one.
    fn assert_unit_sum(lam: &Weights) {
        assert_relative_eq!(lam.iter().sum::<f64>(), 1.0, epsilon = 1e-10);
    }

    /// Checks that each of the first `n` weights equals `1 / n`.
    fn assert_uniform(lam: &Weights, n: usize) {
        let expected = 1.0 / n as f64;
        (0..n).for_each(|i| {
            assert_relative_eq!(lam[i], expected, epsilon = 1e-10);
        });
    }

    /// Checks that none of the first `n` weights is negative.
    fn assert_non_negative(lam: &Weights, n: usize) {
        (0..n).for_each(|i| {
            assert!(lam[i] >= 0.0, "Lambda values should be non-negative");
        });
    }

    // An identity matrix must yield equal weights for every point.
    #[test]
    fn test_burke_identity() {
        let n = 100;
        let (lam, _) = run(Mat::identity(n, n)).unwrap();

        assert_uniform(&lam, n);
        assert_unit_sum(&lam);
    }

    // All-ones matrices of various shapes must yield uniform weights.
    #[test]
    fn test_burke_uniform_square() {
        let (n_sub, n_point) = (10, 10);
        let (lam, _) = run(ones(n_sub, n_point)).unwrap();

        assert_unit_sum(&lam);
        assert_uniform(&lam, n_point);
    }

    #[test]
    fn test_burke_uniform_wide() {
        let (n_sub, n_point) = (10, 100);
        let (lam, _) = run(ones(n_sub, n_point)).unwrap();

        assert_unit_sum(&lam);
        assert_uniform(&lam, n_point);
    }

    #[test]
    fn test_burke_uniform_long() {
        let (n_sub, n_point) = (100, 10);
        let (lam, _) = run(ones(n_sub, n_point)).unwrap();

        assert_unit_sum(&lam);
        assert_uniform(&lam, n_point);
    }

    // A column with larger entries must receive the largest weight.
    #[test]
    fn test_burke_with_non_uniform_matrix() {
        let (n_sub, n_point) = (3, 4);
        let mat = Mat::from_fn(n_sub, n_point, |_, col| match col {
            0 => 10.0,
            _ => 1.0,
        });

        let (lam, _) = run(mat).unwrap();

        assert_unit_sum(&lam);

        assert!(lam[0] > lam[1]);
        assert!(lam[0] > lam[2]);
        assert!(lam[0] > lam[3]);
    }

    // Negative entries are accepted; the entry of magnitude 5 makes the first
    // column dominate.
    #[test]
    fn test_burke_with_negative_values() {
        let (n_sub, n_point) = (2, 3);
        let mat = Mat::from_fn(n_sub, n_point, |row, col| match (row, col) {
            (0, 0) => -5.0,
            _ => 1.0,
        });

        let result = run(mat);
        assert!(result.is_ok());

        let (lam, _) = result.unwrap();
        assert_unit_sum(&lam);

        assert!(lam[0] > lam[1]);
        assert!(lam[0] > lam[2]);
    }

    // A NaN entry must be rejected.
    #[test]
    fn test_burke_with_non_finite_values() {
        let (n_sub, n_point) = (10, 10);
        let mat = Mat::from_fn(n_sub, n_point, |row, col| match (row, col) {
            (0, 0) => f64::NAN,
            _ => 1.0,
        });

        assert!(run(mat).is_err());
    }

    // Many columns, intended to exercise the column-partitioned branch.
    #[test]
    fn test_burke_large_matrix_parallel_processing() {
        let (n_sub, n_point) = (50, 10000);

        let result = run(Mat::from_fn(n_sub, n_point, |_i, _j| 1.0));
        assert!(
            result.is_ok(),
            "Burke algorithm should succeed with large matrix"
        );

        let (lam, obj) = result.unwrap();

        assert_unit_sum(&lam);
        assert_non_negative(&lam, n_point);

        assert!(obj.is_finite(), "Objective function should be finite");

        let max_weight = lam
            .weights()
            .iter()
            .cloned()
            .fold(f64::NEG_INFINITY, f64::max);
        assert!(
            max_weight < 0.1,
            "No single weight should dominate in uniform matrix (max weight: {})",
            max_weight
        );
    }

    // Fewer columns with structured, non-uniform entries.
    #[test]
    fn test_burke_medium_matrix_sequential_processing() {
        let (n_sub, n_point) = (50, 500);
        let mat = Mat::from_fn(n_sub, n_point, |i, j| {
            if j % 100 == 0 {
                5.0 + 0.1 * (i as f64)
            } else {
                1.0 + 0.01 * (i as f64) + 0.001 * (j as f64)
            }
        });

        let result = run(mat);
        assert!(
            result.is_ok(),
            "Burke algorithm should succeed with medium matrix"
        );

        let (lam, obj) = result.unwrap();

        assert_unit_sum(&lam);
        assert_non_negative(&lam, n_point);

        assert!(obj.is_finite(), "Objective function should be finite");
    }
}