pmcore/structs/theta.rs

1use std::fmt::Debug;
2
3use anyhow::{bail, Result};
4use faer::Mat;
5use serde::{Deserialize, Serialize};
6
7use crate::prelude::Parameters;
8
9/// [Theta] is a structure that holds the support points
10/// These represent the joint population parameter distribution
11///
12/// Each row represents a support points, and each column a parameter
13#[derive(Clone, PartialEq)]
14pub struct Theta {
15    matrix: Mat<f64>,
16    parameters: Parameters,
17}
18
19impl Default for Theta {
20    fn default() -> Self {
21        Theta {
22            matrix: Mat::new(),
23            parameters: Parameters::new(),
24        }
25    }
26}
27
28impl Theta {
29    pub fn new() -> Self {
30        Theta::default()
31    }
32
33    pub(crate) fn from_parts(matrix: Mat<f64>, parameters: Parameters) -> Self {
34        Theta { matrix, parameters }
35    }
36
37    /// Get the matrix containing parameter values
38    ///
39    /// The matrix is a 2D array where each row represents a support point, and each column a parameter
40    pub fn matrix(&self) -> &Mat<f64> {
41        &self.matrix
42    }
43
44    /// Get the number of support points, equal to the number of rows in the matrix
45    pub fn nspp(&self) -> usize {
46        self.matrix.nrows()
47    }
48
49    /// Get the parameter names
50    pub fn param_names(&self) -> Vec<String> {
51        self.parameters.names()
52    }
53
54    /// Modify the [Theta::matrix] to only include the rows specified by `indices`
55    pub(crate) fn filter_indices(&mut self, indices: &[usize]) {
56        let matrix = self.matrix.to_owned();
57
58        let new = Mat::from_fn(indices.len(), matrix.ncols(), |r, c| {
59            *matrix.get(indices[r], c)
60        });
61
62        self.matrix = new;
63    }
64
65    /// Forcibly add a support point to the matrix
66    pub(crate) fn add_point(&mut self, spp: &[f64]) {
67        self.matrix
68            .resize_with(self.matrix.nrows() + 1, self.matrix.ncols(), |_, i| spp[i]);
69    }
70
71    /// Suggest a new support point to add to the matrix
72    /// The point is only added if it is at least `min_dist` away from all existing support points
73    /// and within the limits specified by `limits`
74    pub(crate) fn suggest_point(&mut self, spp: &[f64], min_dist: f64) {
75        if self.check_point(spp, min_dist) {
76            self.add_point(spp);
77        }
78    }
79
80    /// Check if a point is at least `min_dist` away from all existing support points
81    pub(crate) fn check_point(&self, spp: &[f64], min_dist: f64) -> bool {
82        if self.matrix.nrows() == 0 {
83            return true;
84        }
85
86        let limits = self.parameters.ranges();
87
88        for row_idx in 0..self.matrix.nrows() {
89            let mut squared_dist = 0.0;
90            for (i, val) in spp.iter().enumerate() {
91                // Normalized squared difference for this dimension
92                let normalized_diff =
93                    (val - self.matrix.get(row_idx, i)) / (limits[i].1 - limits[i].0);
94                squared_dist += normalized_diff * normalized_diff;
95            }
96            let dist = squared_dist.sqrt();
97            if dist <= min_dist {
98                return false; // This point is too close to an existing point
99            }
100        }
101        true // Point is sufficiently distant from all existing points
102    }
103
104    /// Write the matrix to a CSV file
105    pub fn write(&self, path: &str) {
106        let mut writer = csv::Writer::from_path(path).unwrap();
107        for row in self.matrix.row_iter() {
108            writer
109                .write_record(row.iter().map(|x| x.to_string()))
110                .unwrap();
111        }
112    }
113
114    /// Write the theta matrix to a CSV writer
115    /// Each row represents a support point, each column represents a parameter
116    pub fn to_csv<W: std::io::Write>(&self, writer: W) -> Result<()> {
117        let mut csv_writer = csv::Writer::from_writer(writer);
118
119        // Write each row
120        for i in 0..self.matrix.nrows() {
121            let row: Vec<f64> = (0..self.matrix.ncols())
122                .map(|j| *self.matrix.get(i, j))
123                .collect();
124            csv_writer.serialize(row)?;
125        }
126
127        csv_writer.flush()?;
128        Ok(())
129    }
130
131    /// Read theta matrix from a CSV reader
132    /// Each row represents a support point, each column represents a parameter
133    /// Note: This only reads the matrix values, not the parameter metadata
134    pub fn from_csv<R: std::io::Read>(reader: R) -> Result<Self> {
135        let mut csv_reader = csv::Reader::from_reader(reader);
136        let mut rows: Vec<Vec<f64>> = Vec::new();
137
138        for result in csv_reader.deserialize() {
139            let row: Vec<f64> = result?;
140            rows.push(row);
141        }
142
143        if rows.is_empty() {
144            bail!("CSV file is empty");
145        }
146
147        let nrows = rows.len();
148        let ncols = rows[0].len();
149
150        // Verify all rows have the same length
151        for (i, row) in rows.iter().enumerate() {
152            if row.len() != ncols {
153                bail!("Row {} has {} columns, expected {}", i, row.len(), ncols);
154            }
155        }
156
157        // Create matrix from rows
158        let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);
159
160        // Create empty parameters - user will need to set these separately
161        let parameters = Parameters::new();
162
163        Ok(Theta::from_parts(mat, parameters))
164    }
165}
166
167impl Debug for Theta {
168    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169        // Write nspp and nsub
170        writeln!(f, "\nTheta contains {} support points\n", self.nspp())?;
171
172        // Write the parameter names
173        for name in self.parameters.names().iter() {
174            write!(f, "\t{}", name)?;
175        }
176        writeln!(f)?;
177        // Write the matrix
178        self.matrix.row_iter().enumerate().for_each(|(index, row)| {
179            write!(f, "{}", index).unwrap();
180            for val in row.iter() {
181                write!(f, "\t{:.2}", val).unwrap();
182            }
183            writeln!(f).unwrap();
184        });
185        Ok(())
186    }
187}
188
189impl Serialize for Theta {
190    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
191    where
192        S: serde::Serializer,
193    {
194        use serde::ser::SerializeSeq;
195
196        let mut seq = serializer.serialize_seq(Some(self.matrix.nrows()))?;
197
198        // Serialize each row as a vector
199        for i in 0..self.matrix.nrows() {
200            let row: Vec<f64> = (0..self.matrix.ncols())
201                .map(|j| *self.matrix.get(i, j))
202                .collect();
203            seq.serialize_element(&row)?;
204        }
205
206        seq.end()
207    }
208}
209
210impl<'de> Deserialize<'de> for Theta {
211    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
212    where
213        D: serde::Deserializer<'de>,
214    {
215        use serde::de::{SeqAccess, Visitor};
216        use std::fmt;
217
218        struct ThetaVisitor;
219
220        impl<'de> Visitor<'de> for ThetaVisitor {
221            type Value = Theta;
222
223            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
224                formatter.write_str("a sequence of rows (vectors of f64)")
225            }
226
227            fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
228            where
229                A: SeqAccess<'de>,
230            {
231                let mut rows: Vec<Vec<f64>> = Vec::new();
232
233                while let Some(row) = seq.next_element::<Vec<f64>>()? {
234                    rows.push(row);
235                }
236
237                if rows.is_empty() {
238                    return Err(serde::de::Error::custom("Empty matrix not allowed"));
239                }
240
241                let nrows = rows.len();
242                let ncols = rows[0].len();
243
244                // Verify all rows have the same length
245                for (i, row) in rows.iter().enumerate() {
246                    if row.len() != ncols {
247                        return Err(serde::de::Error::custom(format!(
248                            "Row {} has {} columns, expected {}",
249                            i,
250                            row.len(),
251                            ncols
252                        )));
253                    }
254                }
255
256                // Create matrix from rows
257                let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);
258
259                // Create empty parameters - user will need to set these separately
260                let parameters = Parameters::new();
261
262                Ok(Theta::from_parts(mat, parameters))
263            }
264        }
265
266        deserializer.deserialize_seq(ThetaVisitor)
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use faer::mat;

    /// Two parameters, `A` and `B`, each with range (0.0, 10.0).
    fn two_params() -> Parameters {
        Parameters::new().add("A", 0.0, 10.0).add("B", 0.0, 10.0)
    }

    /// The three-point fixture shared by several tests.
    fn three_point_theta() -> Theta {
        Theta::from_parts(mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], two_params())
    }

    #[test]
    fn test_filter_indices() {
        let mut theta = Theta::from_parts(
            mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
            two_params(),
        );

        // Keep only the first and last support points
        theta.filter_indices(&[0, 3]);

        assert_eq!(theta.matrix, mat![[1.0, 2.0], [7.0, 8.0]]);
    }

    #[test]
    fn test_add_point() {
        let mut theta = three_point_theta();

        theta.add_point(&[7.0, 8.0]);

        assert_eq!(
            theta.matrix,
            mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
        );
    }

    #[test]
    fn test_suggest_point() {
        let mut theta = three_point_theta();

        // Normalized distance to the nearest point (5, 6) is ~0.28 > 0.2, so it is accepted
        theta.suggest_point(&[7.0, 8.0], 0.2);
        assert_eq!(
            theta.matrix,
            mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
        );

        // Normalized distance to (7, 8) is ~0.014 <= 0.2, so it is rejected
        theta.suggest_point(&[7.1, 8.1], 0.2);
        assert_eq!(theta.matrix.nrows(), 4);
    }

    #[test]
    fn test_param_names() {
        let names = three_point_theta().param_names();
        assert_eq!(names, ["A", "B"]);
    }
}