1use std::fmt::Debug;
2
3use faer::Mat;
4
5#[derive(Clone, PartialEq)]
10pub struct Theta {
11 matrix: Mat<f64>,
12 random: Vec<(String, f64, f64)>,
13 fixed: Vec<(String, f64)>,
14}
15
16impl Default for Theta {
17 fn default() -> Self {
18 Theta {
19 matrix: Mat::new(),
20 random: Vec::new(),
21 fixed: Vec::new(),
22 }
23 }
24}
25
26impl Theta {
27 pub fn new() -> Self {
28 Theta::default()
29 }
30
31 pub(crate) fn from_parts(
32 matrix: Mat<f64>,
33 random: Vec<(String, f64, f64)>,
34 fixed: Vec<(String, f64)>,
35 ) -> Self {
36 Theta {
37 matrix,
38 random,
39 fixed,
40 }
41 }
42
43 pub fn matrix(&self) -> &Mat<f64> {
47 &self.matrix
48 }
49
50 pub fn set_matrix(&mut self, matrix: Mat<f64>) {
52 self.matrix = matrix;
53 }
54
55 pub fn nspp(&self) -> usize {
57 self.matrix.nrows()
58 }
59
60 pub fn param_names(&self) -> Vec<String> {
62 self.random
63 .iter()
64 .map(|(name, _, _)| name.clone())
65 .collect()
66 }
67
68 pub(crate) fn filter_indices(&mut self, indices: &[usize]) {
70 let matrix = self.matrix.to_owned();
71
72 let new = Mat::from_fn(indices.len(), matrix.ncols(), |r, c| {
73 *matrix.get(indices[r], c)
74 });
75
76 self.matrix = new;
77 }
78
79 pub(crate) fn add_point(&mut self, spp: &[f64]) {
81 self.matrix
82 .resize_with(self.matrix.nrows() + 1, self.matrix.ncols(), |_, i| spp[i]);
83 }
84
85 pub(crate) fn suggest_point(&mut self, spp: &[f64], min_dist: f64, limits: &[(f64, f64)]) {
89 if self.check_point(spp, min_dist, limits) {
90 self.add_point(spp);
91 }
92 }
93
94 pub(crate) fn check_point(&self, spp: &[f64], min_dist: f64, limits: &[(f64, f64)]) -> bool {
96 if self.matrix.nrows() == 0 {
97 return true;
98 }
99
100 for row_idx in 0..self.matrix.nrows() {
101 let mut squared_dist = 0.0;
102 for (i, val) in spp.iter().enumerate() {
103 let normalized_diff =
105 (val - self.matrix.get(row_idx, i)) / (limits[i].1 - limits[i].0);
106 squared_dist += normalized_diff * normalized_diff;
107 }
108 let dist = squared_dist.sqrt();
109 if dist <= min_dist {
110 return false; }
112 }
113 true }
115
116 pub fn write(&self, path: &str) {
118 let mut writer = csv::Writer::from_path(path).unwrap();
119 for row in self.matrix.row_iter() {
120 writer
121 .write_record(row.iter().map(|x| x.to_string()))
122 .unwrap();
123 }
124 }
125}
126
127impl Debug for Theta {
128 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129 writeln!(f, "\nTheta contains {} support points\n", self.nspp())?;
131 self.matrix.row_iter().enumerate().for_each(|(index, row)| {
133 writeln!(f, "{index}\t{:?}", row).unwrap();
134 });
135 Ok(())
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use faer::mat;

    /// Selecting rows 0 and 3 keeps exactly those two points, in order.
    #[test]
    fn test_filter_indices() {
        let mut subject = Theta::from_parts(
            mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
            Vec::new(),
            Vec::new(),
        );

        subject.filter_indices(&[0, 3]);

        assert_eq!(subject.matrix, mat![[1.0, 2.0], [7.0, 8.0]]);
    }

    /// Adding a point appends it as the new last row.
    #[test]
    fn test_add_point() {
        let mut subject = Theta::from_parts(
            mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
            Vec::new(),
            Vec::new(),
        );

        subject.add_point(&[7.0, 8.0]);

        let want = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        assert_eq!(subject.matrix, want);
    }
}