pmcore/routines/evaluation/
qr.rs

1use crate::structs::psi::Psi;
2use anyhow::{bail, Result};
3use faer::linalg::solvers::ColPivQr;
4use faer::Mat;
5
6/// Perform a QR decomposition on the Psi matrix
7///
8/// Normalizes each row of the matrix to sum to 1 before decomposition.
9/// Returns the R matrix from QR decomposition and the column permutation vector.
10///
11/// # Arguments
12/// * `psi` - The Psi matrix to decompose
13///
14/// # Returns
15/// * Tuple containing the R matrix (as [faer::Mat<f64>]) and permutation vector (as [Vec<usize>])
16/// * Error if any row in the matrix sums to zero
17pub fn qrd(psi: &Psi) -> Result<(Mat<f64>, Vec<usize>)> {
18    let mut mat = psi.matrix().to_owned();
19
20    // Normalize the rows to sum to 1
21    for (index, row) in mat.row_iter_mut().enumerate() {
22        let row_sum: f64 = row.as_ref().iter().sum();
23
24        // Check if the row sum is zero
25        if row_sum.abs() == 0.0 {
26            bail!("In psi, the row with index {} sums to zero", index);
27        }
28        row.iter_mut().for_each(|x| *x /= row_sum);
29    }
30
31    // Perform column pivoted QR decomposition
32    let qr: ColPivQr<f64> = mat.col_piv_qr();
33
34    // Extract the R matrix
35    let r_mat: faer::Mat<f64> = qr.R().to_owned();
36
37    // Get the permutation information
38    let perm = qr.P().arrays().0.to_vec();
39    Ok((r_mat, perm))
40}
41
#[cfg(test)]
mod tests {
    use super::*;

    /// An identity matrix has rows that already sum to 1, so the decomposition
    /// should hand back the identity as R together with the identity permutation.
    #[test]
    fn test_identity() {
        // Build a 10x10 identity matrix and wrap it in a Psi
        let size: usize = 10;
        let psi = Psi::from(Mat::<f64>::identity(size, size));

        let (r_mat, perm) = qrd(&psi).unwrap();

        // R must equal the identity matrix
        assert_eq!(r_mat, Mat::<f64>::identity(size, size));

        // No column should have been moved: 0, 1, ..., 9
        assert_eq!(perm, (0..size).collect::<Vec<usize>>());
    }

    /// A row whose entries sum to zero cannot be normalized and must be reported
    /// as an error.
    #[test]
    fn test_with_zero_row_sum() {
        // 2x2 matrix: first row is [1, 2], second row is all zeros
        let mat = Mat::from_fn(2, 2, |i, j| match (i, j) {
            (0, 0) => 1.0,
            (0, 1) => 2.0,
            _ => 0.0,
        });
        let psi = Psi::from(mat);

        assert!(
            qrd(&psi).is_err(),
            "Expected an error due to zero row sum"
        );
    }

    /// An empty Psi must be handled without panicking and yield empty outputs.
    #[test]
    fn test_empty_matrix() {
        let psi = Psi::from(Mat::<f64>::new());

        // Should succeed rather than panic or error
        let (r_mat, perm) = qrd(&psi).unwrap();

        // Nothing in, nothing out
        assert_eq!(r_mat.nrows(), 0);
        assert_eq!(r_mat.ncols(), 0);
        assert!(perm.is_empty());
    }
}