pmcore/routines/condensation/mod.rs

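//! Condensation of the support point grid: lambda filtering, QR-based removal of
//! linearly dependent points, and weight recomputation with Burke's interior point method.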
use crate::algorithms::npag::{burke, qr};
use crate::structs::psi::Psi;
use crate::structs::theta::Theta;
use crate::structs::weights::Weights;
use anyhow::Result;

/// Apply lambda filtering and QR decomposition to condense the support points
///
/// This implements the condensation step used in NPAG algorithms:
/// 1. Filter out support points whose lambda (probability) falls below a threshold
///    relative to the largest lambda
/// 2. Apply QR decomposition to remove linearly dependent points
/// 3. Recalculate the weights of the remaining points with Burke's interior point method (IPM)
///
/// # Arguments
///
/// * `theta` - Support point matrix (one row per point)
/// * `psi` - Likelihood matrix (subjects × support points)
/// * `lambda` - Initial probability weights of the support points
/// * `lambda_threshold` - Minimum lambda, as a fraction of the maximum lambda, required to keep a point
/// * `qr_threshold` - QR decomposition threshold for linear independence (typically 1e-8)
///
/// # Returns
///
/// The filtered theta and psi, the recalculated weights, and the objective function value from the IPM
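///
/// # Examples
///
/// A minimal, illustrative call site (marked `ignore` rather than compiled as a
/// doc-test). The input variables are assumed to come from an earlier optimization
/// cycle, and the `1e-3` lambda threshold is an example value, not a library default:
///
/// ```ignore
/// // `theta`, `psi` and `lambda` are assumed to be produced by a previous cycle
/// let (theta, psi, weights, objf) =
///     condense_support_points(&theta, &psi, &lambda, 1e-3, 1e-8)?;
///
/// // `weights` and `objf` now describe the condensed set of support points
/// ```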
pub fn condense_support_points(
    theta: &Theta,
    psi: &Psi,
    lambda: &Weights,
    lambda_threshold: f64,
    qr_threshold: f64,
) -> Result<(Theta, Psi, Weights, f64)> {
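    // Work on clones so the caller's theta and psi remain untouched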
    let mut filtered_theta = theta.clone();
    let mut filtered_psi = psi.clone();

    // Step 1: Lambda filtering
    let max_lambda = lambda.iter().fold(f64::NEG_INFINITY, |acc, x| x.max(acc));

    let threshold = max_lambda * lambda_threshold;
    let keep_lambda: Vec<usize> = lambda
        .iter()
        .enumerate()
        .filter(|(_, lam)| *lam > threshold)
        .map(|(i, _)| i)
        .collect();

    let initial_count = theta.matrix().nrows();
    let after_lambda = keep_lambda.len();

    if initial_count != after_lambda {
        tracing::debug!(
            "Lambda filtering ({:.0e} × max): {} -> {} support points",
            lambda_threshold,
            initial_count,
            after_lambda
        );
    }

    filtered_theta.filter_indices(&keep_lambda);
    filtered_psi.filter_column_indices(&keep_lambda);

    // Step 2: QR decomposition filtering
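    // `qrd` returns the R factor and the column permutation of a pivoted QR
    // decomposition; `perm[i]` maps column `i` of `r` back to a column of `filtered_psi`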
    let (r, perm) = qr::qrd(&filtered_psi)?;

    let mut keep_qr = Vec::<usize>::new();

    // At most min(number of subjects, number of remaining support points) columns
    // of psi can be linearly independent, so only that many are examined
    let keep_n = filtered_psi
        .matrix()
        .ncols()
        .min(filtered_psi.matrix().nrows());

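    // Keep a column when its diagonal element of `r` is non-negligible relative to the
    // full column norm, i.e. the column is numerically independent of those already kept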
    for i in 0..keep_n {
        let test = r.col(i).norm_l2();
        let r_diag_val = r.get(i, i);
        let ratio = r_diag_val / test;
        if ratio.abs() >= qr_threshold {
            keep_qr.push(*perm.get(i).unwrap());
        }
    }

    let after_qr = keep_qr.len();

    if after_lambda != after_qr {
        tracing::debug!(
            "QR decomposition (threshold {:.0e}): {} -> {} support points",
            qr_threshold,
            after_lambda,
            after_qr
        );
    }

    filtered_theta.filter_indices(&keep_qr);
    filtered_psi.filter_column_indices(&keep_qr);

    // Step 3: Recalculate weights with Burke's IPM
    let (final_weights, objf) = burke(&filtered_psi)?;

    tracing::debug!(
        "Condensation complete: {} -> {} support points (objective: {:.4})",
        initial_count,
        filtered_theta.matrix().nrows(),
        objf
    );

    Ok((filtered_theta, filtered_psi, final_weights, objf))
}