// pmcore/routines/evaluation/ipm.rs
use crate::structs::psi::Psi;
2use anyhow::bail;
3use faer::linalg::triangular_solve::solve_lower_triangular_in_place;
4use faer::linalg::triangular_solve::solve_upper_triangular_in_place;
5use faer::{Col, Mat, Row};
6use rayon::prelude::*;
7pub fn burke(psi: &Psi) -> anyhow::Result<(Col<f64>, f64)> {
33 let mut psi = psi.matrix().to_owned();
34
35 psi.row_iter_mut().try_for_each(|row| {
37 row.iter_mut().try_for_each(|x| {
38 if !x.is_finite() {
39 bail!("Input matrix must have finite entries")
40 } else {
41 *x = x.abs();
43 Ok(())
44 }
45 })
46 })?;
47
48 let (n_sub, n_point) = psi.shape();
50
51 let ecol: Col<f64> = Col::from_fn(n_point, |_| 1.0);
55 let erow: Row<f64> = Row::from_fn(n_sub, |_| 1.0);
56
57 let mut plam: Col<f64> = &psi * &ecol;
59 let eps: f64 = 1e-8;
60 let mut sig: f64 = 0.0;
61
62 let mut lam = ecol.clone();
64
65 let mut w: Col<f64> = Col::from_fn(plam.nrows(), |i| 1.0 / plam.get(i));
67
68 let mut ptw: Col<f64> = psi.transpose() * &w;
70
71 let ptw_max = ptw.iter().fold(f64::NEG_INFINITY, |acc, &x| x.max(acc));
73 let shrink = 2.0 * ptw_max;
74 lam *= shrink;
75 plam *= shrink;
76 w /= shrink;
77 ptw /= shrink;
78
79 let mut y: Col<f64> = &ecol - &ptw;
81 let mut r: Col<f64> = Col::from_fn(n_sub, |i| erow.get(i) - w.get(i) * plam.get(i));
83 let mut norm_r: f64 = r.iter().fold(0.0, |max, &val| max.max(val.abs()));
84
85 let sum_log_plam: f64 = plam.iter().map(|x| x.ln()).sum();
87 let sum_log_w: f64 = w.iter().map(|x| x.ln()).sum();
88 let mut gap: f64 = (sum_log_w + sum_log_plam).abs() / (1.0 + sum_log_plam);
89
90 let mut mu = lam.transpose() * &y / n_point as f64;
92
93 let mut psi_inner: Mat<f64> = Mat::zeros(psi.nrows(), psi.ncols());
94
95 let n_threads = faer::get_global_parallelism().degree();
96
97 let rows = psi.nrows();
98
99 let mut output: Vec<Mat<f64>> = (0..n_threads).map(|_| Mat::zeros(rows, rows)).collect();
100
101 let mut h: Mat<f64> = Mat::zeros(rows, rows);
102
103 while mu > eps || norm_r > eps || gap > eps {
104 let smu = sig * mu;
105 let inner = Col::from_fn(lam.nrows(), |i| lam.get(i) / y.get(i));
107 let w_plam = Col::from_fn(plam.nrows(), |i| plam.get(i) / w.get(i));
109
110 if psi.ncols() > n_threads * 128 {
113 psi_inner
114 .par_col_partition_mut(n_threads)
115 .zip(psi.par_col_partition(n_threads))
116 .zip(inner.par_partition(n_threads))
117 .zip(output.par_iter_mut())
118 .for_each(|(((mut psi_inner, psi), inner), output)| {
119 psi_inner
120 .as_mut()
121 .col_iter_mut()
122 .zip(psi.col_iter())
123 .zip(inner.iter())
124 .for_each(|((col, psi_col), inner_val)| {
125 col.iter_mut().zip(psi_col.iter()).for_each(|(x, psi_val)| {
126 *x = psi_val * inner_val;
127 });
128 });
129 faer::linalg::matmul::triangular::matmul(
130 output.as_mut(),
131 faer::linalg::matmul::triangular::BlockStructure::TriangularLower,
132 faer::Accum::Replace,
133 &psi_inner,
134 faer::linalg::matmul::triangular::BlockStructure::Rectangular,
135 psi.transpose(),
136 faer::linalg::matmul::triangular::BlockStructure::Rectangular,
137 1.0,
138 faer::Par::Seq,
139 );
140 });
141
142 let mut first_iter = true;
143 for output in &output {
144 if first_iter {
145 h.copy_from(output);
146 first_iter = false;
147 } else {
148 h += output;
149 }
150 }
151 } else {
152 psi_inner
153 .as_mut()
154 .col_iter_mut()
155 .zip(psi.col_iter())
156 .zip(inner.iter())
157 .for_each(|((col, psi_col), inner_val)| {
158 col.iter_mut().zip(psi_col.iter()).for_each(|(x, psi_val)| {
159 *x = psi_val * inner_val;
160 });
161 });
162 faer::linalg::matmul::triangular::matmul(
163 h.as_mut(),
164 faer::linalg::matmul::triangular::BlockStructure::TriangularLower,
165 faer::Accum::Replace,
166 &psi_inner,
167 faer::linalg::matmul::triangular::BlockStructure::Rectangular,
168 psi.transpose(),
169 faer::linalg::matmul::triangular::BlockStructure::Rectangular,
170 1.0,
171 faer::Par::Seq,
172 );
173 }
174
175 for i in 0..h.nrows() {
176 h[(i, i)] += w_plam[i];
177 }
178
179 let uph = match h.llt(faer::Side::Lower) {
180 Ok(llt) => llt,
181 Err(_) => {
182 bail!("Error during Cholesky decomposition")
183 }
184 };
185 let uph = uph.L().transpose().to_owned();
186
187 let smuyinv: Col<f64> = Col::from_fn(ecol.nrows(), |i| smu * (ecol[i] / y[i]));
189
190 let psi_dot_muyinv: Col<f64> = &psi * &smuyinv;
193
194 let rhsdw: Row<f64> = Row::from_fn(erow.ncols(), |i| erow[i] / w[i] - psi_dot_muyinv[i]);
195
196 let mut dw = Mat::from_fn(rhsdw.ncols(), 1, |i, _j| *rhsdw.get(i));
199
200 solve_lower_triangular_in_place(uph.transpose().as_ref(), dw.as_mut(), faer::Par::rayon(0));
207
208 solve_upper_triangular_in_place(uph.as_ref(), dw.as_mut(), faer::Par::rayon(0));
209
210 let dw = dw.col(0);
212
213 let dy = -(psi.transpose() * dw);
216
217 let inner_times_dy = Col::from_fn(ecol.nrows(), |i| inner[i] * dy[i]);
218
219 let dlam: Row<f64> =
220 Row::from_fn(ecol.nrows(), |i| smuyinv[i] - lam[i] - inner_times_dy[i]);
221 let ratio_dlam_lam = Row::from_fn(lam.nrows(), |i| dlam[i] / lam[i]);
225 let min_ratio_dlam = ratio_dlam_lam.iter().cloned().fold(f64::INFINITY, f64::min);
227 let mut alfpri: f64 = -1.0 / min_ratio_dlam.min(-0.5);
228 alfpri = (0.99995 * alfpri).min(1.0);
229
230 let ratio_dy_y = Row::from_fn(y.nrows(), |i| dy[i] / y[i]);
232 let min_ratio_dy = ratio_dy_y.iter().cloned().fold(f64::INFINITY, f64::min);
234 let ratio_dw_w = Row::from_fn(dw.nrows(), |i| dw[i] / w[i]);
235 let min_ratio_dw = ratio_dw_w.iter().cloned().fold(f64::INFINITY, f64::min);
237 let mut alfdual = -1.0 / min_ratio_dy.min(-0.5);
238 alfdual = alfdual.min(-1.0 / min_ratio_dw.min(-0.5));
239 alfdual = (0.99995 * alfdual).min(1.0);
240
241 lam += alfpri * dlam.transpose();
243 w += alfdual * dw;
244 y += alfdual * &dy;
245
246 mu = lam.transpose() * &y / n_point as f64;
247 plam = &psi * &lam;
248
249 r = Col::from_fn(n_sub, |i| erow.get(i) - w.get(i) * plam.get(i));
252 ptw -= alfdual * dy;
253
254 norm_r = r.norm_max();
255 let sum_log_plam: f64 = plam.iter().map(|x| x.ln()).sum();
256 let sum_log_w: f64 = w.iter().map(|x| x.ln()).sum();
257 gap = (sum_log_w + sum_log_plam).abs() / (1.0 + sum_log_plam);
258
259 if mu < eps && norm_r > eps {
261 sig = 1.0;
262 } else {
263 let candidate1 = (1.0 - alfpri).powi(2);
264 let candidate2 = (1.0 - alfdual).powi(2);
265 let candidate3 = (norm_r - mu) / (norm_r + 100.0 * mu);
266 sig = candidate1.max(candidate2).max(candidate3).min(0.3);
267 }
268 }
269 lam /= n_sub as f64;
271 let obj = (psi * &lam).iter().map(|x| x.ln()).sum();
273 let lam_sum: f64 = lam.iter().sum();
275 lam = &lam / lam_sum;
276
277 Ok((lam, obj))
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use faer::Mat;

    /// Runs `burke` on `mat` and returns the weights, panicking if the solver fails.
    fn weights_for(mat: Mat<f64>) -> Col<f64> {
        let psi = Psi::from(mat);
        let (weights, _) = burke(&psi).unwrap();
        weights
    }

    /// Asserts that the entries of `weights` add up to one.
    fn assert_normalised(weights: &Col<f64>) {
        assert_relative_eq!(weights.iter().sum::<f64>(), 1.0, epsilon = 1e-10);
    }

    /// Asserts that the first `count` entries of `weights` all equal `1 / count`.
    fn assert_equal_shares(weights: &Col<f64>, count: usize) {
        let share = 1.0 / count as f64;
        (0..count).for_each(|idx| {
            assert_relative_eq!(weights[idx], share, epsilon = 1e-10);
        });
    }

    /// Solves an all-ones matrix of the given shape and checks for uniform weights.
    fn check_all_ones(n_sub: usize, n_point: usize) {
        let weights = weights_for(Mat::from_fn(n_sub, n_point, |_, _| 1.0));
        assert_normalised(&weights);
        assert_equal_shares(&weights, n_point);
    }

    // Identity matrix: every column is weighted equally.
    #[test]
    fn test_burke_identity() {
        let size = 100;
        let weights = weights_for(Mat::identity(size, size));

        assert_equal_shares(&weights, size);
        assert_normalised(&weights);
    }

    // All-ones matrix with as many rows as columns.
    #[test]
    fn test_burke_uniform_square() {
        check_all_ones(10, 10);
    }

    // All-ones matrix with more columns than rows.
    #[test]
    fn test_burke_uniform_wide() {
        check_all_ones(10, 100);
    }

    // All-ones matrix with more rows than columns.
    #[test]
    fn test_burke_uniform_long() {
        check_all_ones(100, 10);
    }

    // The first column is ten times larger than the others, so it must dominate.
    #[test]
    fn test_burke_with_non_uniform_matrix() {
        let weights = weights_for(Mat::from_fn(3, 4, |_, col| match col {
            0 => 10.0,
            _ => 1.0,
        }));

        assert_normalised(&weights);

        for other in 1..4 {
            assert!(weights[0] > weights[other]);
        }
    }

    // A negative entry is accepted; its column still ends up with the largest weight.
    #[test]
    fn test_burke_with_negative_values() {
        let mat = Mat::from_fn(2, 3, |row, col| match (row, col) {
            (0, 0) => -5.0,
            _ => 1.0,
        });
        let psi = Psi::from(mat);

        let outcome = burke(&psi);
        assert!(outcome.is_ok());

        let (weights, _) = outcome.unwrap();
        assert_normalised(&weights);

        for other in 1..3 {
            assert!(weights[0] > weights[other]);
        }
    }

    // A single NaN entry must be reported as an error.
    #[test]
    fn test_burke_with_non_finite_values() {
        let mat = Mat::from_fn(10, 10, |row, col| match (row, col) {
            (0, 0) => f64::NAN,
            _ => 1.0,
        });
        let psi = Psi::from(mat);

        assert!(burke(&psi).is_err());
    }
}