pmcore/structs/
theta.rs

1use std::fmt::Debug;
2
3use anyhow::{bail, Result};
4use faer::Mat;
5use serde::{Deserialize, Serialize};
6
7use crate::{prelude::Parameters, structs::weights::Weights};
8
9/// [Theta] is a structure that holds the support points
10/// These represent the joint population parameter distribution
11///
12/// Each row represents a support points, and each column a parameter
13#[derive(Clone, PartialEq)]
14pub struct Theta {
15    matrix: Mat<f64>,
16    parameters: Parameters,
17}
18
19impl Default for Theta {
20    fn default() -> Self {
21        Theta {
22            matrix: Mat::new(),
23            parameters: Parameters::new(),
24        }
25    }
26}
27
28impl Theta {
29    pub fn new() -> Self {
30        Theta::default()
31    }
32
33    /// Create a new [Theta] from a matrix and [Parameters]
34    ///
35    /// It is important that the number of columns in the matrix matches the number of parameters
36    /// in the [Parameters] object
37    ///
38    /// The order of parameters in the [Parameters] object should match the order of columns in the matrix
39    pub fn from_parts(matrix: Mat<f64>, parameters: Parameters) -> Result<Self> {
40        if matrix.ncols() != parameters.len() {
41            bail!(
42                "Number of columns in matrix ({}) does not match number of parameters ({})",
43                matrix.ncols(),
44                parameters.len()
45            );
46        }
47
48        Ok(Theta { matrix, parameters })
49    }
50
51    /// Get the matrix containing parameter values
52    ///
53    /// The matrix is a 2D array where each row represents a support point, and each column a parameter
54    pub fn matrix(&self) -> &Mat<f64> {
55        &self.matrix
56    }
57
58    /// Get a mutable reference to the matrix
59    pub fn matrix_mut(&mut self) -> &mut Mat<f64> {
60        &mut self.matrix
61    }
62
63    /// Get the [Parameters] object associated with this [Theta]
64    pub fn parameters(&self) -> &Parameters {
65        &self.parameters
66    }
67
68    /// Get a mutable reference to the [Parameters] object
69    pub fn parameters_mut(&mut self) -> &mut Parameters {
70        &mut self.parameters
71    }
72
73    /// Get the number of support points, equal to the number of rows in the matrix
74    pub fn nspp(&self) -> usize {
75        self.matrix.nrows()
76    }
77
78    /// Get the parameter names
79    pub fn param_names(&self) -> Vec<String> {
80        self.parameters.names()
81    }
82
83    /// Modify the [Theta::matrix] to only include the rows specified by `indices`
84    pub(crate) fn filter_indices(&mut self, indices: &[usize]) {
85        let matrix = self.matrix.to_owned();
86
87        let new = Mat::from_fn(indices.len(), matrix.ncols(), |r, c| {
88            *matrix.get(indices[r], c)
89        });
90
91        self.matrix = new;
92    }
93
94    /// Forcibly add a support point to the matrix
95    pub fn add_point(&mut self, spp: &[f64]) -> Result<()> {
96        if spp.len() != self.matrix.ncols() {
97            bail!(
98                "Support point length ({}) does not match number of parameters ({})",
99                spp.len(),
100                self.matrix.ncols()
101            );
102        }
103
104        self.matrix
105            .resize_with(self.matrix.nrows() + 1, self.matrix.ncols(), |_, i| spp[i]);
106        Ok(())
107    }
108
109    /// Suggest a new support point to add to the matrix
110    /// The point is only added if it is at least `min_dist` away from all existing support points
111    /// and within the limits specified by `limits`
112    pub(crate) fn suggest_point(&mut self, spp: &[f64], min_dist: f64) -> Result<()> {
113        if self.check_point(spp, min_dist) {
114            self.add_point(spp)?;
115        }
116        Ok(())
117    }
118
119    /// Check if a point is at least `min_dist` away from all existing support points
120    pub(crate) fn check_point(&self, spp: &[f64], min_dist: f64) -> bool {
121        if self.matrix.nrows() == 0 {
122            return true;
123        }
124
125        let limits = self.parameters.ranges();
126
127        for row_idx in 0..self.matrix.nrows() {
128            let mut squared_dist = 0.0;
129            for (i, val) in spp.iter().enumerate() {
130                // Normalized squared difference for this dimension
131                let normalized_diff =
132                    (val - self.matrix.get(row_idx, i)) / (limits[i].1 - limits[i].0);
133                squared_dist += normalized_diff * normalized_diff;
134            }
135            let dist = squared_dist.sqrt();
136            if dist <= min_dist {
137                return false; // This point is too close to an existing point
138            }
139        }
140        true // Point is sufficiently distant from all existing points
141    }
142
143    /// Write the matrix to a CSV file
144    pub fn write(&self, path: &str) {
145        let mut writer = csv::Writer::from_path(path).unwrap();
146        for row in self.matrix.row_iter() {
147            writer
148                .write_record(row.iter().map(|x| x.to_string()))
149                .unwrap();
150        }
151    }
152
153    /// Write the matrix to a CSV file with weights
154    pub fn write_with_weights(&self, path: &str, weights: &Weights) -> Result<()> {
155        if self.nspp() != weights.len() {
156            bail!(
157                "Number of support points ({}) does not match number of weights ({})",
158                self.nspp(),
159                weights.len()
160            );
161        }
162
163        let mut writer = csv::Writer::from_path(path)?;
164
165        let header: Vec<String> = self
166            .parameters
167            .names()
168            .iter()
169            .cloned()
170            .chain(std::iter::once("prob".to_string()))
171            .collect();
172
173        writer.write_record(header)?;
174
175        for (row_idx, row) in self.matrix.row_iter().enumerate() {
176            let mut record: Vec<String> = row.iter().map(|x| x.to_string()).collect();
177            record.push(weights[row_idx].to_string());
178            writer.write_record(record)?;
179        }
180        Ok(())
181    }
182
183    /// Write the theta matrix to a CSV writer
184    /// Each row represents a support point, each column represents a parameter
185    pub fn to_csv<W: std::io::Write>(&self, writer: W) -> Result<()> {
186        let mut csv_writer = csv::Writer::from_writer(writer);
187
188        // Write each row
189        for i in 0..self.matrix.nrows() {
190            let row: Vec<f64> = (0..self.matrix.ncols())
191                .map(|j| *self.matrix.get(i, j))
192                .collect();
193            csv_writer.serialize(row)?;
194        }
195
196        csv_writer.flush()?;
197        Ok(())
198    }
199
200    /// Read theta matrix from a CSV reader
201    /// Each row represents a support point, each column represents a parameter
202    /// Note: This only reads the matrix values, not the parameter metadata
203    pub fn from_csv<R: std::io::Read>(reader: R) -> Result<Self> {
204        let mut csv_reader = csv::Reader::from_reader(reader);
205        let mut rows: Vec<Vec<f64>> = Vec::new();
206
207        for result in csv_reader.deserialize() {
208            let row: Vec<f64> = result?;
209            rows.push(row);
210        }
211
212        if rows.is_empty() {
213            bail!("CSV file is empty");
214        }
215
216        let nrows = rows.len();
217        let ncols = rows[0].len();
218
219        // Verify all rows have the same length
220        for (i, row) in rows.iter().enumerate() {
221            if row.len() != ncols {
222                bail!("Row {} has {} columns, expected {}", i, row.len(), ncols);
223            }
224        }
225
226        // Create matrix from rows
227        let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);
228
229        // Create empty parameters - user will need to set these separately
230        let parameters = Parameters::new();
231
232        Theta::from_parts(mat, parameters)
233    }
234}
235
236impl Debug for Theta {
237    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238        // Write nspp and nsub
239        writeln!(f, "\nTheta contains {} support points\n", self.nspp())?;
240
241        // Write the parameter names
242        for name in self.parameters.names().iter() {
243            write!(f, "\t{}", name)?;
244        }
245        writeln!(f)?;
246        // Write the matrix
247        self.matrix.row_iter().enumerate().for_each(|(index, row)| {
248            write!(f, "{}", index).unwrap();
249            for val in row.iter() {
250                write!(f, "\t{:.2}", val).unwrap();
251            }
252            writeln!(f).unwrap();
253        });
254        Ok(())
255    }
256}
257
258impl Serialize for Theta {
259    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
260    where
261        S: serde::Serializer,
262    {
263        use serde::ser::SerializeSeq;
264
265        let mut seq = serializer.serialize_seq(Some(self.matrix.nrows()))?;
266
267        // Serialize each row as a vector
268        for i in 0..self.matrix.nrows() {
269            let row: Vec<f64> = (0..self.matrix.ncols())
270                .map(|j| *self.matrix.get(i, j))
271                .collect();
272            seq.serialize_element(&row)?;
273        }
274
275        seq.end()
276    }
277}
278
279impl<'de> Deserialize<'de> for Theta {
280    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
281    where
282        D: serde::Deserializer<'de>,
283    {
284        use serde::de::{SeqAccess, Visitor};
285        use std::fmt;
286
287        struct ThetaVisitor;
288
289        impl<'de> Visitor<'de> for ThetaVisitor {
290            type Value = Theta;
291
292            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
293                formatter.write_str("a sequence of rows (vectors of f64)")
294            }
295
296            fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
297            where
298                A: SeqAccess<'de>,
299            {
300                let mut rows: Vec<Vec<f64>> = Vec::new();
301
302                while let Some(row) = seq.next_element::<Vec<f64>>()? {
303                    rows.push(row);
304                }
305
306                if rows.is_empty() {
307                    return Err(serde::de::Error::custom("Empty matrix not allowed"));
308                }
309
310                let nrows = rows.len();
311                let ncols = rows[0].len();
312
313                // Verify all rows have the same length
314                for (i, row) in rows.iter().enumerate() {
315                    if row.len() != ncols {
316                        return Err(serde::de::Error::custom(format!(
317                            "Row {} has {} columns, expected {}",
318                            i,
319                            row.len(),
320                            ncols
321                        )));
322                    }
323                }
324
325                // Create matrix from rows
326                let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);
327
328                // Create empty parameters - user will need to set these separately
329                let parameters = Parameters::new();
330
331                Theta::from_parts(mat, parameters).map_err(serde::de::Error::custom)
332            }
333        }
334
335        deserializer.deserialize_seq(ThetaVisitor)
336    }
337}
338
#[cfg(test)]
mod tests {
    use super::*;
    use faer::mat;

    /// Two parameters, both ranging over [0, 10].
    fn two_params() -> Parameters {
        Parameters::new().add("A", 0.0, 10.0).add("B", 0.0, 10.0)
    }

    #[test]
    fn test_filter_indices() {
        // Create a 4x2 matrix with recognizable values
        let matrix = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let mut theta = Theta::from_parts(matrix, two_params()).unwrap();

        theta.filter_indices(&[0, 3]);

        // Expected result is a 2x2 matrix with filtered rows
        let expected = mat![[1.0, 2.0], [7.0, 8.0]];
        assert_eq!(theta.matrix, expected);
    }

    #[test]
    fn test_add_point() {
        let matrix = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let mut theta = Theta::from_parts(matrix, two_params()).unwrap();

        theta.add_point(&[7.0, 8.0]).unwrap();

        let expected = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        assert_eq!(theta.matrix, expected);
    }

    #[test]
    fn test_add_point_rejects_wrong_length() {
        let matrix = mat![[1.0, 2.0], [3.0, 4.0]];
        let mut theta = Theta::from_parts(matrix, two_params()).unwrap();

        assert!(theta.add_point(&[1.0]).is_err());
        assert!(theta.add_point(&[1.0, 2.0, 3.0]).is_err());
        // A rejected point must leave the matrix untouched
        assert_eq!(theta.nspp(), 2);
    }

    #[test]
    fn test_from_parts_rejects_mismatched_columns() {
        let matrix = mat![[1.0, 2.0]];
        let parameters = Parameters::new().add("A", 0.0, 10.0);
        assert!(Theta::from_parts(matrix, parameters).is_err());
    }

    #[test]
    fn test_suggest_point() {
        let matrix = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let mut theta = Theta::from_parts(matrix, two_params()).unwrap();
        theta.suggest_point(&[7.0, 8.0], 0.2).unwrap();
        let expected = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        assert_eq!(theta.matrix, expected);

        // Suggest a point that is too close
        theta.suggest_point(&[7.1, 8.1], 0.2).unwrap();
        // The point should not be added
        assert_eq!(theta.matrix.nrows(), 4);
    }

    #[test]
    fn test_check_point_distance() {
        let matrix = mat![[1.0, 2.0], [3.0, 4.0]];
        let theta = Theta::from_parts(matrix, two_params()).unwrap();

        // An exact duplicate has distance 0, which is never strictly greater
        assert!(!theta.check_point(&[1.0, 2.0], 0.0));
        // A distant point is accepted
        assert!(theta.check_point(&[9.0, 9.0], 0.1));
    }

    #[test]
    fn test_check_point_on_empty_theta() {
        let matrix = mat![[1.0, 2.0]];
        let mut theta = Theta::from_parts(matrix, two_params()).unwrap();

        // Drop every row, keeping the column count
        theta.filter_indices(&[]);
        assert_eq!(theta.nspp(), 0);

        // Any point is acceptable when there are no support points yet
        assert!(theta.check_point(&[5.0, 5.0], 0.5));
        theta.add_point(&[5.0, 5.0]).unwrap();
        assert_eq!(theta.matrix, mat![[5.0, 5.0]]);
    }

    #[test]
    fn test_param_names() {
        let matrix = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let theta = Theta::from_parts(matrix, two_params()).unwrap();

        let names = theta.param_names();
        assert_eq!(names, vec!["A".to_string(), "B".to_string()]);
    }

    #[test]
    fn test_set_matrix() {
        let matrix = mat![[1.0, 2.0], [3.0, 4.0]];
        let mut theta = Theta::from_parts(matrix, two_params()).unwrap();

        let new_matrix = mat![[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]];
        theta.matrix_mut().clone_from(&new_matrix);

        assert_eq!(theta.matrix(), &new_matrix);
    }
}
412}