pmcore/structs/
theta.rs

1use std::fmt::Debug;
2
3use anyhow::{bail, Result};
4use faer::Mat;
5use serde::{Deserialize, Serialize};
6
7use crate::prelude::Parameters;
8
/// [Theta] is a structure that holds the support points
/// These represent the joint population parameter distribution
///
/// Each row represents a support point, and each column a parameter
#[derive(Clone, PartialEq)]
pub struct Theta {
    /// Support point values: one row per support point, one column per parameter
    matrix: Mat<f64>,
    /// Parameter definitions; [Theta::from_parts] requires their count to equal
    /// the number of matrix columns
    parameters: Parameters,
}
18
19impl Default for Theta {
20    fn default() -> Self {
21        Theta {
22            matrix: Mat::new(),
23            parameters: Parameters::new(),
24        }
25    }
26}
27
28impl Theta {
29    pub fn new() -> Self {
30        Theta::default()
31    }
32
33    /// Create a new [Theta] from a matrix and [Parameters]
34    ///
35    /// It is important that the number of columns in the matrix matches the number of parameters
36    /// in the [Parameters] object
37    ///
38    /// The order of parameters in the [Parameters] object should match the order of columns in the matrix
39    pub fn from_parts(matrix: Mat<f64>, parameters: Parameters) -> Result<Self> {
40        if matrix.ncols() != parameters.len() {
41            bail!(
42                "Number of columns in matrix ({}) does not match number of parameters ({})",
43                matrix.ncols(),
44                parameters.len()
45            );
46        }
47
48        Ok(Theta { matrix, parameters })
49    }
50
51    /// Get the matrix containing parameter values
52    ///
53    /// The matrix is a 2D array where each row represents a support point, and each column a parameter
54    pub fn matrix(&self) -> &Mat<f64> {
55        &self.matrix
56    }
57
58    /// Get a mutable reference to the matrix
59    pub fn matrix_mut(&mut self) -> &mut Mat<f64> {
60        &mut self.matrix
61    }
62
63    /// Get the [Parameters] object associated with this [Theta]
64    pub fn parameters(&self) -> &Parameters {
65        &self.parameters
66    }
67
68    /// Get a mutable reference to the [Parameters] object
69    pub fn parameters_mut(&mut self) -> &mut Parameters {
70        &mut self.parameters
71    }
72
73    /// Get the number of support points, equal to the number of rows in the matrix
74    pub fn nspp(&self) -> usize {
75        self.matrix.nrows()
76    }
77
78    /// Get the parameter names
79    pub fn param_names(&self) -> Vec<String> {
80        self.parameters.names()
81    }
82
83    /// Modify the [Theta::matrix] to only include the rows specified by `indices`
84    pub(crate) fn filter_indices(&mut self, indices: &[usize]) {
85        let matrix = self.matrix.to_owned();
86
87        let new = Mat::from_fn(indices.len(), matrix.ncols(), |r, c| {
88            *matrix.get(indices[r], c)
89        });
90
91        self.matrix = new;
92    }
93
94    /// Forcibly add a support point to the matrix
95    pub fn add_point(&mut self, spp: &[f64]) -> Result<()> {
96        if spp.len() != self.matrix.ncols() {
97            bail!(
98                "Support point length ({}) does not match number of parameters ({})",
99                spp.len(),
100                self.matrix.ncols()
101            );
102        }
103
104        self.matrix
105            .resize_with(self.matrix.nrows() + 1, self.matrix.ncols(), |_, i| spp[i]);
106        Ok(())
107    }
108
109    /// Suggest a new support point to add to the matrix
110    /// The point is only added if it is at least `min_dist` away from all existing support points
111    /// and within the limits specified by `limits`
112    pub(crate) fn suggest_point(&mut self, spp: &[f64], min_dist: f64) -> Result<()> {
113        if self.check_point(spp, min_dist) {
114            self.add_point(spp)?;
115        }
116        Ok(())
117    }
118
119    /// Check if a point is at least `min_dist` away from all existing support points
120    pub(crate) fn check_point(&self, spp: &[f64], min_dist: f64) -> bool {
121        if self.matrix.nrows() == 0 {
122            return true;
123        }
124
125        let limits = self.parameters.ranges();
126
127        for row_idx in 0..self.matrix.nrows() {
128            let mut squared_dist = 0.0;
129            for (i, val) in spp.iter().enumerate() {
130                // Normalized squared difference for this dimension
131                let normalized_diff =
132                    (val - self.matrix.get(row_idx, i)) / (limits[i].1 - limits[i].0);
133                squared_dist += normalized_diff * normalized_diff;
134            }
135            let dist = squared_dist.sqrt();
136            if dist <= min_dist {
137                return false; // This point is too close to an existing point
138            }
139        }
140        true // Point is sufficiently distant from all existing points
141    }
142
143    /// Write the matrix to a CSV file
144    pub fn write(&self, path: &str) {
145        let mut writer = csv::Writer::from_path(path).unwrap();
146        for row in self.matrix.row_iter() {
147            writer
148                .write_record(row.iter().map(|x| x.to_string()))
149                .unwrap();
150        }
151    }
152
153    /// Write the theta matrix to a CSV writer
154    /// Each row represents a support point, each column represents a parameter
155    pub fn to_csv<W: std::io::Write>(&self, writer: W) -> Result<()> {
156        let mut csv_writer = csv::Writer::from_writer(writer);
157
158        // Write each row
159        for i in 0..self.matrix.nrows() {
160            let row: Vec<f64> = (0..self.matrix.ncols())
161                .map(|j| *self.matrix.get(i, j))
162                .collect();
163            csv_writer.serialize(row)?;
164        }
165
166        csv_writer.flush()?;
167        Ok(())
168    }
169
170    /// Read theta matrix from a CSV reader
171    /// Each row represents a support point, each column represents a parameter
172    /// Note: This only reads the matrix values, not the parameter metadata
173    pub fn from_csv<R: std::io::Read>(reader: R) -> Result<Self> {
174        let mut csv_reader = csv::Reader::from_reader(reader);
175        let mut rows: Vec<Vec<f64>> = Vec::new();
176
177        for result in csv_reader.deserialize() {
178            let row: Vec<f64> = result?;
179            rows.push(row);
180        }
181
182        if rows.is_empty() {
183            bail!("CSV file is empty");
184        }
185
186        let nrows = rows.len();
187        let ncols = rows[0].len();
188
189        // Verify all rows have the same length
190        for (i, row) in rows.iter().enumerate() {
191            if row.len() != ncols {
192                bail!("Row {} has {} columns, expected {}", i, row.len(), ncols);
193            }
194        }
195
196        // Create matrix from rows
197        let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);
198
199        // Create empty parameters - user will need to set these separately
200        let parameters = Parameters::new();
201
202        Ok(Theta::from_parts(mat, parameters)?)
203    }
204}
205
206impl Debug for Theta {
207    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208        // Write nspp and nsub
209        writeln!(f, "\nTheta contains {} support points\n", self.nspp())?;
210
211        // Write the parameter names
212        for name in self.parameters.names().iter() {
213            write!(f, "\t{}", name)?;
214        }
215        writeln!(f)?;
216        // Write the matrix
217        self.matrix.row_iter().enumerate().for_each(|(index, row)| {
218            write!(f, "{}", index).unwrap();
219            for val in row.iter() {
220                write!(f, "\t{:.2}", val).unwrap();
221            }
222            writeln!(f).unwrap();
223        });
224        Ok(())
225    }
226}
227
228impl Serialize for Theta {
229    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
230    where
231        S: serde::Serializer,
232    {
233        use serde::ser::SerializeSeq;
234
235        let mut seq = serializer.serialize_seq(Some(self.matrix.nrows()))?;
236
237        // Serialize each row as a vector
238        for i in 0..self.matrix.nrows() {
239            let row: Vec<f64> = (0..self.matrix.ncols())
240                .map(|j| *self.matrix.get(i, j))
241                .collect();
242            seq.serialize_element(&row)?;
243        }
244
245        seq.end()
246    }
247}
248
249impl<'de> Deserialize<'de> for Theta {
250    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
251    where
252        D: serde::Deserializer<'de>,
253    {
254        use serde::de::{SeqAccess, Visitor};
255        use std::fmt;
256
257        struct ThetaVisitor;
258
259        impl<'de> Visitor<'de> for ThetaVisitor {
260            type Value = Theta;
261
262            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
263                formatter.write_str("a sequence of rows (vectors of f64)")
264            }
265
266            fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
267            where
268                A: SeqAccess<'de>,
269            {
270                let mut rows: Vec<Vec<f64>> = Vec::new();
271
272                while let Some(row) = seq.next_element::<Vec<f64>>()? {
273                    rows.push(row);
274                }
275
276                if rows.is_empty() {
277                    return Err(serde::de::Error::custom("Empty matrix not allowed"));
278                }
279
280                let nrows = rows.len();
281                let ncols = rows[0].len();
282
283                // Verify all rows have the same length
284                for (i, row) in rows.iter().enumerate() {
285                    if row.len() != ncols {
286                        return Err(serde::de::Error::custom(format!(
287                            "Row {} has {} columns, expected {}",
288                            i,
289                            row.len(),
290                            ncols
291                        )));
292                    }
293                }
294
295                // Create matrix from rows
296                let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);
297
298                // Create empty parameters - user will need to set these separately
299                let parameters = Parameters::new();
300
301                Theta::from_parts(mat, parameters).map_err(serde::de::Error::custom)
302            }
303        }
304
305        deserializer.deserialize_seq(ThetaVisitor)
306    }
307}
308
#[cfg(test)]
mod tests {
    use super::*;
    use faer::mat;

    /// Two parameters, "A" and "B", each ranging from 0 to 10
    fn ab_parameters() -> Parameters {
        Parameters::new().add("A", 0.0, 10.0).add("B", 0.0, 10.0)
    }

    /// Build a [Theta] over the "A"/"B" parameters from a two-column matrix
    fn theta_from(matrix: Mat<f64>) -> Theta {
        Theta::from_parts(matrix, ab_parameters()).unwrap()
    }

    #[test]
    fn test_filter_indices() {
        // Start from a 4x2 matrix with recognizable values
        let mut theta = theta_from(mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]);

        // Keep only the first and the last row
        theta.filter_indices(&[0, 3]);

        let expected = mat![[1.0, 2.0], [7.0, 8.0]];
        assert_eq!(theta.matrix(), &expected);
    }

    #[test]
    fn test_add_point() {
        let mut theta = theta_from(mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);

        theta.add_point(&[7.0, 8.0]).unwrap();

        // The new point is appended as the last row
        let expected = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        assert_eq!(theta.matrix(), &expected);
    }

    #[test]
    fn test_suggest_point() {
        let mut theta = theta_from(mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);

        // Far enough from every existing point: accepted
        theta.suggest_point(&[7.0, 8.0], 0.2).unwrap();
        let expected = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        assert_eq!(theta.matrix(), &expected);

        // Too close to the point that was just added: ignored
        theta.suggest_point(&[7.1, 8.1], 0.2).unwrap();
        assert_eq!(theta.matrix().nrows(), 4);
    }

    #[test]
    fn test_param_names() {
        let theta = theta_from(mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);

        let expected = vec!["A".to_string(), "B".to_string()];
        assert_eq!(theta.param_names(), expected);
    }

    #[test]
    fn test_set_matrix() {
        let mut theta = theta_from(mat![[1.0, 2.0], [3.0, 4.0]]);

        // Overwrite the stored matrix through the mutable accessor
        let replacement = mat![[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]];
        theta.matrix_mut().clone_from(&replacement);

        assert_eq!(theta.matrix(), &replacement);
    }
}