1use std::fmt::Debug;
2
3use faer::Mat;
4
5use crate::prelude::Parameters;
6
7#[derive(Clone, PartialEq)]
12pub struct Theta {
13 matrix: Mat<f64>,
14 parameters: Parameters,
15}
16
17impl Default for Theta {
18 fn default() -> Self {
19 Theta {
20 matrix: Mat::new(),
21 parameters: Parameters::new(),
22 }
23 }
24}
25
26impl Theta {
27 pub fn new() -> Self {
28 Theta::default()
29 }
30
31 pub(crate) fn from_parts(matrix: Mat<f64>, parameters: Parameters) -> Self {
32 Theta { matrix, parameters }
33 }
34
35 pub fn matrix(&self) -> &Mat<f64> {
39 &self.matrix
40 }
41
42 pub fn nspp(&self) -> usize {
44 self.matrix.nrows()
45 }
46
47 pub fn param_names(&self) -> Vec<String> {
49 self.parameters.names()
50 }
51
52 pub(crate) fn filter_indices(&mut self, indices: &[usize]) {
54 let matrix = self.matrix.to_owned();
55
56 let new = Mat::from_fn(indices.len(), matrix.ncols(), |r, c| {
57 *matrix.get(indices[r], c)
58 });
59
60 self.matrix = new;
61 }
62
63 pub(crate) fn add_point(&mut self, spp: &[f64]) {
65 self.matrix
66 .resize_with(self.matrix.nrows() + 1, self.matrix.ncols(), |_, i| spp[i]);
67 }
68
69 pub(crate) fn suggest_point(&mut self, spp: &[f64], min_dist: f64) {
73 if self.check_point(spp, min_dist) {
74 self.add_point(spp);
75 }
76 }
77
78 pub(crate) fn check_point(&self, spp: &[f64], min_dist: f64) -> bool {
80 if self.matrix.nrows() == 0 {
81 return true;
82 }
83
84 let limits = self.parameters.ranges();
85
86 for row_idx in 0..self.matrix.nrows() {
87 let mut squared_dist = 0.0;
88 for (i, val) in spp.iter().enumerate() {
89 let normalized_diff =
91 (val - self.matrix.get(row_idx, i)) / (limits[i].1 - limits[i].0);
92 squared_dist += normalized_diff * normalized_diff;
93 }
94 let dist = squared_dist.sqrt();
95 if dist <= min_dist {
96 return false; }
98 }
99 true }
101
102 pub fn write(&self, path: &str) {
104 let mut writer = csv::Writer::from_path(path).unwrap();
105 for row in self.matrix.row_iter() {
106 writer
107 .write_record(row.iter().map(|x| x.to_string()))
108 .unwrap();
109 }
110 }
111}
112
113impl Debug for Theta {
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 writeln!(f, "\nTheta contains {} support points\n", self.nspp())?;
117
118 for name in self.parameters.names().iter() {
120 write!(f, "\t{}", name)?;
121 }
122 writeln!(f)?;
123 self.matrix.row_iter().enumerate().for_each(|(index, row)| {
125 write!(f, "{}", index).unwrap();
126 for val in row.iter() {
127 write!(f, "\t{:.2}", val).unwrap();
128 }
129 writeln!(f).unwrap();
130 });
131 Ok(())
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use faer::mat;

    /// Parameters `A` and `B`, both added with bounds 0.0 and 10.0.
    fn two_params() -> Parameters {
        Parameters::new().add("A", 0.0, 10.0).add("B", 0.0, 10.0)
    }

    #[test]
    fn test_filter_indices() {
        let matrix = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let mut theta = Theta::from_parts(matrix, two_params());

        theta.filter_indices(&[0, 3]);

        let expected = mat![[1.0, 2.0], [7.0, 8.0]];
        assert_eq!(theta.matrix, expected);
    }

    #[test]
    fn test_add_point() {
        let matrix = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let mut theta = Theta::from_parts(matrix, two_params());

        theta.add_point(&[7.0, 8.0]);

        let expected = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        assert_eq!(theta.matrix, expected);
    }

    #[test]
    fn test_check_point_empty_theta_accepts_anything() {
        let theta = Theta::new();
        assert!(theta.check_point(&[1.0, 2.0], 0.5));
    }

    #[test]
    fn test_check_point_uses_normalized_distance() {
        let theta = Theta::from_parts(mat![[1.0, 2.0]], two_params());

        // An exact duplicate is rejected.
        assert!(!theta.check_point(&[1.0, 2.0], 0.1));
        // 0.5 apart on a range of width 10 is a normalized distance of 0.05.
        assert!(!theta.check_point(&[1.5, 2.0], 0.1));
        assert!(theta.check_point(&[1.5, 2.0], 0.01));
        // A distant point is accepted.
        assert!(theta.check_point(&[9.0, 9.0], 0.1));
    }

    #[test]
    fn test_suggest_point_respects_min_dist() {
        let mut theta = Theta::from_parts(mat![[1.0, 2.0]], two_params());

        // Too close to the existing point: ignored.
        theta.suggest_point(&[1.05, 2.0], 0.1);
        assert_eq!(theta.nspp(), 1);

        // Far enough away: appended as a new row.
        theta.suggest_point(&[5.0, 5.0], 0.1);
        assert_eq!(theta.nspp(), 2);
        assert_eq!(theta.matrix, mat![[1.0, 2.0], [5.0, 5.0]]);
    }
}