1use std::fmt::Debug;
2
3use anyhow::{bail, Result};
4use faer::Mat;
5use serde::{Deserialize, Serialize};
6
7use crate::prelude::Parameters;
8
9#[derive(Clone, PartialEq)]
14pub struct Theta {
15 matrix: Mat<f64>,
16 parameters: Parameters,
17}
18
19impl Default for Theta {
20 fn default() -> Self {
21 Theta {
22 matrix: Mat::new(),
23 parameters: Parameters::new(),
24 }
25 }
26}
27
28impl Theta {
29 pub fn new() -> Self {
30 Theta::default()
31 }
32
33 pub(crate) fn from_parts(matrix: Mat<f64>, parameters: Parameters) -> Self {
34 Theta { matrix, parameters }
35 }
36
37 pub fn matrix(&self) -> &Mat<f64> {
41 &self.matrix
42 }
43
44 pub fn nspp(&self) -> usize {
46 self.matrix.nrows()
47 }
48
49 pub fn param_names(&self) -> Vec<String> {
51 self.parameters.names()
52 }
53
54 pub(crate) fn filter_indices(&mut self, indices: &[usize]) {
56 let matrix = self.matrix.to_owned();
57
58 let new = Mat::from_fn(indices.len(), matrix.ncols(), |r, c| {
59 *matrix.get(indices[r], c)
60 });
61
62 self.matrix = new;
63 }
64
65 pub(crate) fn add_point(&mut self, spp: &[f64]) {
67 self.matrix
68 .resize_with(self.matrix.nrows() + 1, self.matrix.ncols(), |_, i| spp[i]);
69 }
70
71 pub(crate) fn suggest_point(&mut self, spp: &[f64], min_dist: f64) {
75 if self.check_point(spp, min_dist) {
76 self.add_point(spp);
77 }
78 }
79
80 pub(crate) fn check_point(&self, spp: &[f64], min_dist: f64) -> bool {
82 if self.matrix.nrows() == 0 {
83 return true;
84 }
85
86 let limits = self.parameters.ranges();
87
88 for row_idx in 0..self.matrix.nrows() {
89 let mut squared_dist = 0.0;
90 for (i, val) in spp.iter().enumerate() {
91 let normalized_diff =
93 (val - self.matrix.get(row_idx, i)) / (limits[i].1 - limits[i].0);
94 squared_dist += normalized_diff * normalized_diff;
95 }
96 let dist = squared_dist.sqrt();
97 if dist <= min_dist {
98 return false; }
100 }
101 true }
103
104 pub fn write(&self, path: &str) {
106 let mut writer = csv::Writer::from_path(path).unwrap();
107 for row in self.matrix.row_iter() {
108 writer
109 .write_record(row.iter().map(|x| x.to_string()))
110 .unwrap();
111 }
112 }
113
114 pub fn to_csv<W: std::io::Write>(&self, writer: W) -> Result<()> {
117 let mut csv_writer = csv::Writer::from_writer(writer);
118
119 for i in 0..self.matrix.nrows() {
121 let row: Vec<f64> = (0..self.matrix.ncols())
122 .map(|j| *self.matrix.get(i, j))
123 .collect();
124 csv_writer.serialize(row)?;
125 }
126
127 csv_writer.flush()?;
128 Ok(())
129 }
130
131 pub fn from_csv<R: std::io::Read>(reader: R) -> Result<Self> {
135 let mut csv_reader = csv::Reader::from_reader(reader);
136 let mut rows: Vec<Vec<f64>> = Vec::new();
137
138 for result in csv_reader.deserialize() {
139 let row: Vec<f64> = result?;
140 rows.push(row);
141 }
142
143 if rows.is_empty() {
144 bail!("CSV file is empty");
145 }
146
147 let nrows = rows.len();
148 let ncols = rows[0].len();
149
150 for (i, row) in rows.iter().enumerate() {
152 if row.len() != ncols {
153 bail!("Row {} has {} columns, expected {}", i, row.len(), ncols);
154 }
155 }
156
157 let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);
159
160 let parameters = Parameters::new();
162
163 Ok(Theta::from_parts(mat, parameters))
164 }
165}
166
167impl Debug for Theta {
168 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169 writeln!(f, "\nTheta contains {} support points\n", self.nspp())?;
171
172 for name in self.parameters.names().iter() {
174 write!(f, "\t{}", name)?;
175 }
176 writeln!(f)?;
177 self.matrix.row_iter().enumerate().for_each(|(index, row)| {
179 write!(f, "{}", index).unwrap();
180 for val in row.iter() {
181 write!(f, "\t{:.2}", val).unwrap();
182 }
183 writeln!(f).unwrap();
184 });
185 Ok(())
186 }
187}
188
189impl Serialize for Theta {
190 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
191 where
192 S: serde::Serializer,
193 {
194 use serde::ser::SerializeSeq;
195
196 let mut seq = serializer.serialize_seq(Some(self.matrix.nrows()))?;
197
198 for i in 0..self.matrix.nrows() {
200 let row: Vec<f64> = (0..self.matrix.ncols())
201 .map(|j| *self.matrix.get(i, j))
202 .collect();
203 seq.serialize_element(&row)?;
204 }
205
206 seq.end()
207 }
208}
209
210impl<'de> Deserialize<'de> for Theta {
211 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
212 where
213 D: serde::Deserializer<'de>,
214 {
215 use serde::de::{SeqAccess, Visitor};
216 use std::fmt;
217
218 struct ThetaVisitor;
219
220 impl<'de> Visitor<'de> for ThetaVisitor {
221 type Value = Theta;
222
223 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
224 formatter.write_str("a sequence of rows (vectors of f64)")
225 }
226
227 fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
228 where
229 A: SeqAccess<'de>,
230 {
231 let mut rows: Vec<Vec<f64>> = Vec::new();
232
233 while let Some(row) = seq.next_element::<Vec<f64>>()? {
234 rows.push(row);
235 }
236
237 if rows.is_empty() {
238 return Err(serde::de::Error::custom("Empty matrix not allowed"));
239 }
240
241 let nrows = rows.len();
242 let ncols = rows[0].len();
243
244 for (i, row) in rows.iter().enumerate() {
246 if row.len() != ncols {
247 return Err(serde::de::Error::custom(format!(
248 "Row {} has {} columns, expected {}",
249 i,
250 row.len(),
251 ncols
252 )));
253 }
254 }
255
256 let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);
258
259 let parameters = Parameters::new();
261
262 Ok(Theta::from_parts(mat, parameters))
263 }
264 }
265
266 deserializer.deserialize_seq(ThetaVisitor)
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use faer::mat;

    /// Two parameters, `A` and `B`, each ranging over [0, 10].
    fn two_params() -> Parameters {
        Parameters::new().add("A", 0.0, 10.0).add("B", 0.0, 10.0)
    }

    #[test]
    fn test_filter_indices() {
        let matrix = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let mut theta = Theta::from_parts(matrix, two_params());

        theta.filter_indices(&[0, 3]);

        let expected = mat![[1.0, 2.0], [7.0, 8.0]];
        assert_eq!(theta.matrix, expected);
    }

    #[test]
    fn test_add_point() {
        let matrix = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let mut theta = Theta::from_parts(matrix, two_params());

        theta.add_point(&[7.0, 8.0]);

        let expected = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        assert_eq!(theta.matrix, expected);
    }

    #[test]
    fn test_suggest_point() {
        let matrix = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let mut theta = Theta::from_parts(matrix, two_params());

        // Normalized distance to the nearest point ([5, 6]) is ~0.28 > 0.2: accepted.
        theta.suggest_point(&[7.0, 8.0], 0.2);
        let expected = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        assert_eq!(theta.matrix, expected);

        // Normalized distance to [7, 8] is ~0.014 <= 0.2: rejected, matrix unchanged.
        theta.suggest_point(&[7.1, 8.1], 0.2);
        assert_eq!(theta.matrix.nrows(), 4);
        assert_eq!(theta.matrix, expected);
    }

    #[test]
    fn test_check_point() {
        // An empty theta accepts any point.
        assert!(Theta::new().check_point(&[1.0, 2.0], 0.5));

        let matrix = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let theta = Theta::from_parts(matrix, two_params());

        assert!(theta.check_point(&[7.0, 8.0], 0.2));
        assert!(!theta.check_point(&[5.1, 6.1], 0.2));
        // A point identical to an existing one has distance 0, which is <= min_dist.
        assert!(!theta.check_point(&[3.0, 4.0], 0.0));
    }

    #[test]
    fn test_param_names() {
        let matrix = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let theta = Theta::from_parts(matrix, two_params());

        let names = theta.param_names();
        assert_eq!(names, vec!["A".to_string(), "B".to_string()]);
    }
}