pmcore/structs/
theta.rs

1use std::fmt::Debug;
2
3use faer::Mat;
4
5use crate::prelude::Parameters;
6
7/// [Theta] is a structure that holds the support points
8/// These represent the joint population parameter distribution
9///
10/// Each row represents a support points, and each column a parameter
11#[derive(Clone, PartialEq)]
12pub struct Theta {
13    matrix: Mat<f64>,
14    parameters: Parameters,
15}
16
17impl Default for Theta {
18    fn default() -> Self {
19        Theta {
20            matrix: Mat::new(),
21            parameters: Parameters::new(),
22        }
23    }
24}
25
26impl Theta {
27    pub fn new() -> Self {
28        Theta::default()
29    }
30
31    pub(crate) fn from_parts(matrix: Mat<f64>, parameters: Parameters) -> Self {
32        Theta { matrix, parameters }
33    }
34
35    /// Get the matrix containing parameter values
36    ///
37    /// The matrix is a 2D array where each row represents a support point, and each column a parameter
38    pub fn matrix(&self) -> &Mat<f64> {
39        &self.matrix
40    }
41
42    /// Get the number of support points, equal to the number of rows in the matrix
43    pub fn nspp(&self) -> usize {
44        self.matrix.nrows()
45    }
46
47    /// Get the parameter names
48    pub fn param_names(&self) -> Vec<String> {
49        self.parameters.names()
50    }
51
52    /// Modify the [Theta::matrix] to only include the rows specified by `indices`
53    pub(crate) fn filter_indices(&mut self, indices: &[usize]) {
54        let matrix = self.matrix.to_owned();
55
56        let new = Mat::from_fn(indices.len(), matrix.ncols(), |r, c| {
57            *matrix.get(indices[r], c)
58        });
59
60        self.matrix = new;
61    }
62
63    /// Forcibly add a support point to the matrix
64    pub(crate) fn add_point(&mut self, spp: &[f64]) {
65        self.matrix
66            .resize_with(self.matrix.nrows() + 1, self.matrix.ncols(), |_, i| spp[i]);
67    }
68
69    /// Suggest a new support point to add to the matrix
70    /// The point is only added if it is at least `min_dist` away from all existing support points
71    /// and within the limits specified by `limits`
72    pub(crate) fn suggest_point(&mut self, spp: &[f64], min_dist: f64) {
73        if self.check_point(spp, min_dist) {
74            self.add_point(spp);
75        }
76    }
77
78    /// Check if a point is at least `min_dist` away from all existing support points
79    pub(crate) fn check_point(&self, spp: &[f64], min_dist: f64) -> bool {
80        if self.matrix.nrows() == 0 {
81            return true;
82        }
83
84        let limits = self.parameters.ranges();
85
86        for row_idx in 0..self.matrix.nrows() {
87            let mut squared_dist = 0.0;
88            for (i, val) in spp.iter().enumerate() {
89                // Normalized squared difference for this dimension
90                let normalized_diff =
91                    (val - self.matrix.get(row_idx, i)) / (limits[i].1 - limits[i].0);
92                squared_dist += normalized_diff * normalized_diff;
93            }
94            let dist = squared_dist.sqrt();
95            if dist <= min_dist {
96                return false; // This point is too close to an existing point
97            }
98        }
99        true // Point is sufficiently distant from all existing points
100    }
101
102    /// Write the matrix to a CSV file
103    pub fn write(&self, path: &str) {
104        let mut writer = csv::Writer::from_path(path).unwrap();
105        for row in self.matrix.row_iter() {
106            writer
107                .write_record(row.iter().map(|x| x.to_string()))
108                .unwrap();
109        }
110    }
111}
112
113impl Debug for Theta {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        // Write nspp and nsub
116        writeln!(f, "\nTheta contains {} support points\n", self.nspp())?;
117
118        // Write the parameter names
119        for name in self.parameters.names().iter() {
120            write!(f, "\t{}", name)?;
121        }
122        writeln!(f)?;
123        // Write the matrix
124        self.matrix.row_iter().enumerate().for_each(|(index, row)| {
125            write!(f, "{}", index).unwrap();
126            for val in row.iter() {
127                write!(f, "\t{:.2}", val).unwrap();
128            }
129            writeln!(f).unwrap();
130        });
131        Ok(())
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use faer::mat;

    /// Two parameters, both ranging over [0, 10], so normalized distance = raw distance / 10
    fn two_params() -> Parameters {
        Parameters::new().add("A", 0.0, 10.0).add("B", 0.0, 10.0)
    }

    #[test]
    fn test_filter_indices() {
        // Create a 4x2 matrix with recognizable values
        let matrix = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let mut theta = Theta::from_parts(matrix, two_params());

        theta.filter_indices(&[0, 3]);

        // Expected result is a 2x2 matrix with filtered rows
        let expected = mat![[1.0, 2.0], [7.0, 8.0]];

        assert_eq!(theta.matrix, expected);
    }

    #[test]
    fn test_add_point() {
        let matrix = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let mut theta = Theta::from_parts(matrix, two_params());

        theta.add_point(&[7.0, 8.0]);

        let expected = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        assert_eq!(theta.matrix, expected);
    }

    #[test]
    fn test_check_point_empty_theta_accepts_any_point() {
        let theta = Theta::new();
        assert!(theta.check_point(&[1.0, 2.0], 0.5));
    }

    #[test]
    fn test_check_point_uses_normalized_distance() {
        let theta = Theta::from_parts(mat![[1.0, 2.0], [3.0, 4.0]], two_params());

        // An identical point has distance 0, which is never more than `min_dist`
        assert!(!theta.check_point(&[1.0, 2.0], 0.1));

        // Closest existing point to [9, 9] is [3, 4]: sqrt(0.6^2 + 0.5^2) ≈ 0.781
        assert!(theta.check_point(&[9.0, 9.0], 0.5));
        assert!(!theta.check_point(&[9.0, 9.0], 0.8));
    }

    #[test]
    fn test_suggest_point_only_adds_distant_points() {
        let mut theta = Theta::from_parts(mat![[1.0, 2.0]], two_params());

        // Too close (identical): rejected
        theta.suggest_point(&[1.0, 2.0], 0.1);
        assert_eq!(theta.nspp(), 1);

        // Far enough: appended as a new row
        theta.suggest_point(&[9.0, 9.0], 0.1);
        assert_eq!(theta.matrix, mat![[1.0, 2.0], [9.0, 9.0]]);
    }
}
171}