pmcore/structs/
theta.rs

1use std::fmt::Debug;
2
3use faer::Mat;
4
5/// [Theta] is a structure that holds the support points
6/// These represent the joint population parameter distribution
7///
8/// Each row represents a support points, and each column a parameter
9#[derive(Clone, PartialEq)]
10pub struct Theta {
11    matrix: Mat<f64>,
12    random: Vec<(String, f64, f64)>,
13    fixed: Vec<(String, f64)>,
14}
15
16impl Default for Theta {
17    fn default() -> Self {
18        Theta {
19            matrix: Mat::new(),
20            random: Vec::new(),
21            fixed: Vec::new(),
22        }
23    }
24}
25
26impl Theta {
27    pub fn new() -> Self {
28        Theta::default()
29    }
30
31    pub(crate) fn from_parts(
32        matrix: Mat<f64>,
33        random: Vec<(String, f64, f64)>,
34        fixed: Vec<(String, f64)>,
35    ) -> Self {
36        Theta {
37            matrix,
38            random,
39            fixed,
40        }
41    }
42
43    /// Get the matrix containing parameter values
44    ///
45    /// The matrix is a 2D array where each row represents a support point, and each column a parameter
46    pub fn matrix(&self) -> &Mat<f64> {
47        &self.matrix
48    }
49
50    /// Set the matrix containing parameter values
51    pub fn set_matrix(&mut self, matrix: Mat<f64>) {
52        self.matrix = matrix;
53    }
54
55    /// Get the number of support points, equal to the number of rows in the matrix
56    pub fn nspp(&self) -> usize {
57        self.matrix.nrows()
58    }
59
60    /// Get the parameter names
61    pub fn param_names(&self) -> Vec<String> {
62        self.random
63            .iter()
64            .map(|(name, _, _)| name.clone())
65            .collect()
66    }
67
68    /// Modify the [Theta::matrix] to only include the rows specified by `indices`
69    pub(crate) fn filter_indices(&mut self, indices: &[usize]) {
70        let matrix = self.matrix.to_owned();
71
72        let new = Mat::from_fn(indices.len(), matrix.ncols(), |r, c| {
73            *matrix.get(indices[r], c)
74        });
75
76        self.matrix = new;
77    }
78
79    /// Forcibly add a support point to the matrix
80    pub(crate) fn add_point(&mut self, spp: &[f64]) {
81        self.matrix
82            .resize_with(self.matrix.nrows() + 1, self.matrix.ncols(), |_, i| spp[i]);
83    }
84
85    /// Suggest a new support point to add to the matrix
86    /// The point is only added if it is at least `min_dist` away from all existing support points
87    /// and within the limits specified by `limits`
88    pub(crate) fn suggest_point(&mut self, spp: &[f64], min_dist: f64, limits: &[(f64, f64)]) {
89        if self.check_point(spp, min_dist, limits) {
90            self.add_point(spp);
91        }
92    }
93
94    /// Check if a point is at least `min_dist` away from all existing support points
95    pub(crate) fn check_point(&self, spp: &[f64], min_dist: f64, limits: &[(f64, f64)]) -> bool {
96        if self.matrix.nrows() == 0 {
97            return true;
98        }
99
100        for row_idx in 0..self.matrix.nrows() {
101            let mut squared_dist = 0.0;
102            for (i, val) in spp.iter().enumerate() {
103                // Normalized squared difference for this dimension
104                let normalized_diff =
105                    (val - self.matrix.get(row_idx, i)) / (limits[i].1 - limits[i].0);
106                squared_dist += normalized_diff * normalized_diff;
107            }
108            let dist = squared_dist.sqrt();
109            if dist <= min_dist {
110                return false; // This point is too close to an existing point
111            }
112        }
113        true // Point is sufficiently distant from all existing points
114    }
115
116    /// Write the matrix to a CSV file
117    pub fn write(&self, path: &str) {
118        let mut writer = csv::Writer::from_path(path).unwrap();
119        for row in self.matrix.row_iter() {
120            writer
121                .write_record(row.iter().map(|x| x.to_string()))
122                .unwrap();
123        }
124    }
125}
126
127impl Debug for Theta {
128    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129        // Write nspp and nsub
130        writeln!(f, "\nTheta contains {} support points\n", self.nspp())?;
131        // Write the matrix
132        self.matrix.row_iter().enumerate().for_each(|(index, row)| {
133            writeln!(f, "{index}\t{:?}", row).unwrap();
134        });
135        Ok(())
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use faer::mat;

    /// Keeping rows 0 and 3 of a 4x2 matrix must leave exactly those two rows, in that order
    #[test]
    fn test_filter_indices() {
        let mut theta = Theta::from_parts(
            mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
            Vec::new(),
            Vec::new(),
        );

        theta.filter_indices(&[0, 3]);

        assert_eq!(theta.matrix, mat![[1.0, 2.0], [7.0, 8.0]]);
    }

    /// Adding a point to a 3x2 matrix must append it as the fourth row
    #[test]
    fn test_add_point() {
        let mut theta = Theta::from_parts(
            mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
            Vec::new(),
            Vec::new(),
        );

        theta.add_point(&[7.0, 8.0]);

        assert_eq!(
            theta.matrix,
            mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
        );
    }
}