pmcore/routines/condensation/
mod.rs1use crate::algorithms::npag::{burke, qr};
2use crate::structs::psi::Psi;
3use crate::structs::theta::Theta;
4use crate::structs::weights::Weights;
5use anyhow::Result;
6
7pub fn condense_support_points(
26 theta: &Theta,
27 psi: &Psi,
28 lambda: &Weights,
29 lambda_threshold: f64,
30 qr_threshold: f64,
31) -> Result<(Theta, Psi, Weights, f64)> {
32 let mut filtered_theta = theta.clone();
33 let mut filtered_psi = psi.clone();
34
35 let max_lambda = lambda.iter().fold(f64::NEG_INFINITY, |acc, x| x.max(acc));
37
38 let threshold = max_lambda * lambda_threshold;
39
40 let keep_lambda: Vec<usize> = lambda
41 .iter()
42 .enumerate()
43 .filter(|(_, lam)| *lam > threshold)
44 .map(|(i, _)| i)
45 .collect();
46
47 let initial_count = theta.matrix().nrows();
48 let after_lambda = keep_lambda.len();
49
50 if initial_count != after_lambda {
51 tracing::debug!(
52 "Lambda filtering ({:.0e} × max): {} -> {} support points",
53 lambda_threshold,
54 initial_count,
55 after_lambda
56 );
57 }
58
59 filtered_theta.filter_indices(&keep_lambda);
60 filtered_psi.filter_column_indices(&keep_lambda);
61
62 let (r, perm) = qr::qrd(&filtered_psi)?;
64
65 let mut keep_qr = Vec::<usize>::new();
66
67 let keep_n = filtered_psi
69 .matrix()
70 .ncols()
71 .min(filtered_psi.matrix().nrows());
72
73 for i in 0..keep_n {
74 let test = r.col(i).norm_l2();
75 let r_diag_val = r.get(i, i);
76 let ratio = r_diag_val / test;
77 if ratio.abs() >= qr_threshold {
78 keep_qr.push(*perm.get(i).unwrap());
79 }
80 }
81
82 let after_qr = keep_qr.len();
83
84 if after_lambda != after_qr {
85 tracing::debug!(
86 "QR decomposition (threshold {:.0e}): {} -> {} support points",
87 qr_threshold,
88 after_lambda,
89 after_qr
90 );
91 }
92
93 filtered_theta.filter_indices(&keep_qr);
94 filtered_psi.filter_column_indices(&keep_qr);
95
96 let (final_weights, objf) = burke(&filtered_psi)?;
98
99 tracing::debug!(
100 "Condensation complete: {} -> {} support points (objective: {:.4})",
101 initial_count,
102 filtered_theta.matrix().nrows(),
103 objf
104 );
105
106 Ok((filtered_theta, filtered_psi, final_weights, objf))
107}