1use std::fmt::Debug;
2
3use anyhow::{bail, Result};
4use faer::Mat;
5use serde::{Deserialize, Serialize};
6
7use crate::prelude::Parameters;
8
9#[derive(Clone, PartialEq)]
14pub struct Theta {
15 matrix: Mat<f64>,
16 parameters: Parameters,
17}
18
19impl Default for Theta {
20 fn default() -> Self {
21 Theta {
22 matrix: Mat::new(),
23 parameters: Parameters::new(),
24 }
25 }
26}
27
28impl Theta {
29 pub fn new() -> Self {
30 Theta::default()
31 }
32
33 pub fn from_parts(matrix: Mat<f64>, parameters: Parameters) -> Result<Self> {
40 if matrix.ncols() != parameters.len() {
41 bail!(
42 "Number of columns in matrix ({}) does not match number of parameters ({})",
43 matrix.ncols(),
44 parameters.len()
45 );
46 }
47
48 Ok(Theta { matrix, parameters })
49 }
50
51 pub fn matrix(&self) -> &Mat<f64> {
55 &self.matrix
56 }
57
58 pub fn matrix_mut(&mut self) -> &mut Mat<f64> {
60 &mut self.matrix
61 }
62
63 pub fn parameters(&self) -> &Parameters {
65 &self.parameters
66 }
67
68 pub fn parameters_mut(&mut self) -> &mut Parameters {
70 &mut self.parameters
71 }
72
73 pub fn nspp(&self) -> usize {
75 self.matrix.nrows()
76 }
77
78 pub fn param_names(&self) -> Vec<String> {
80 self.parameters.names()
81 }
82
83 pub(crate) fn filter_indices(&mut self, indices: &[usize]) {
85 let matrix = self.matrix.to_owned();
86
87 let new = Mat::from_fn(indices.len(), matrix.ncols(), |r, c| {
88 *matrix.get(indices[r], c)
89 });
90
91 self.matrix = new;
92 }
93
94 pub fn add_point(&mut self, spp: &[f64]) -> Result<()> {
96 if spp.len() != self.matrix.ncols() {
97 bail!(
98 "Support point length ({}) does not match number of parameters ({})",
99 spp.len(),
100 self.matrix.ncols()
101 );
102 }
103
104 self.matrix
105 .resize_with(self.matrix.nrows() + 1, self.matrix.ncols(), |_, i| spp[i]);
106 Ok(())
107 }
108
109 pub(crate) fn suggest_point(&mut self, spp: &[f64], min_dist: f64) -> Result<()> {
113 if self.check_point(spp, min_dist) {
114 self.add_point(spp)?;
115 }
116 Ok(())
117 }
118
119 pub(crate) fn check_point(&self, spp: &[f64], min_dist: f64) -> bool {
121 if self.matrix.nrows() == 0 {
122 return true;
123 }
124
125 let limits = self.parameters.ranges();
126
127 for row_idx in 0..self.matrix.nrows() {
128 let mut squared_dist = 0.0;
129 for (i, val) in spp.iter().enumerate() {
130 let normalized_diff =
132 (val - self.matrix.get(row_idx, i)) / (limits[i].1 - limits[i].0);
133 squared_dist += normalized_diff * normalized_diff;
134 }
135 let dist = squared_dist.sqrt();
136 if dist <= min_dist {
137 return false; }
139 }
140 true }
142
143 pub fn write(&self, path: &str) {
145 let mut writer = csv::Writer::from_path(path).unwrap();
146 for row in self.matrix.row_iter() {
147 writer
148 .write_record(row.iter().map(|x| x.to_string()))
149 .unwrap();
150 }
151 }
152
153 pub fn to_csv<W: std::io::Write>(&self, writer: W) -> Result<()> {
156 let mut csv_writer = csv::Writer::from_writer(writer);
157
158 for i in 0..self.matrix.nrows() {
160 let row: Vec<f64> = (0..self.matrix.ncols())
161 .map(|j| *self.matrix.get(i, j))
162 .collect();
163 csv_writer.serialize(row)?;
164 }
165
166 csv_writer.flush()?;
167 Ok(())
168 }
169
170 pub fn from_csv<R: std::io::Read>(reader: R) -> Result<Self> {
174 let mut csv_reader = csv::Reader::from_reader(reader);
175 let mut rows: Vec<Vec<f64>> = Vec::new();
176
177 for result in csv_reader.deserialize() {
178 let row: Vec<f64> = result?;
179 rows.push(row);
180 }
181
182 if rows.is_empty() {
183 bail!("CSV file is empty");
184 }
185
186 let nrows = rows.len();
187 let ncols = rows[0].len();
188
189 for (i, row) in rows.iter().enumerate() {
191 if row.len() != ncols {
192 bail!("Row {} has {} columns, expected {}", i, row.len(), ncols);
193 }
194 }
195
196 let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);
198
199 let parameters = Parameters::new();
201
202 Ok(Theta::from_parts(mat, parameters)?)
203 }
204}
205
206impl Debug for Theta {
207 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208 writeln!(f, "\nTheta contains {} support points\n", self.nspp())?;
210
211 for name in self.parameters.names().iter() {
213 write!(f, "\t{}", name)?;
214 }
215 writeln!(f)?;
216 self.matrix.row_iter().enumerate().for_each(|(index, row)| {
218 write!(f, "{}", index).unwrap();
219 for val in row.iter() {
220 write!(f, "\t{:.2}", val).unwrap();
221 }
222 writeln!(f).unwrap();
223 });
224 Ok(())
225 }
226}
227
228impl Serialize for Theta {
229 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
230 where
231 S: serde::Serializer,
232 {
233 use serde::ser::SerializeSeq;
234
235 let mut seq = serializer.serialize_seq(Some(self.matrix.nrows()))?;
236
237 for i in 0..self.matrix.nrows() {
239 let row: Vec<f64> = (0..self.matrix.ncols())
240 .map(|j| *self.matrix.get(i, j))
241 .collect();
242 seq.serialize_element(&row)?;
243 }
244
245 seq.end()
246 }
247}
248
249impl<'de> Deserialize<'de> for Theta {
250 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
251 where
252 D: serde::Deserializer<'de>,
253 {
254 use serde::de::{SeqAccess, Visitor};
255 use std::fmt;
256
257 struct ThetaVisitor;
258
259 impl<'de> Visitor<'de> for ThetaVisitor {
260 type Value = Theta;
261
262 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
263 formatter.write_str("a sequence of rows (vectors of f64)")
264 }
265
266 fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
267 where
268 A: SeqAccess<'de>,
269 {
270 let mut rows: Vec<Vec<f64>> = Vec::new();
271
272 while let Some(row) = seq.next_element::<Vec<f64>>()? {
273 rows.push(row);
274 }
275
276 if rows.is_empty() {
277 return Err(serde::de::Error::custom("Empty matrix not allowed"));
278 }
279
280 let nrows = rows.len();
281 let ncols = rows[0].len();
282
283 for (i, row) in rows.iter().enumerate() {
285 if row.len() != ncols {
286 return Err(serde::de::Error::custom(format!(
287 "Row {} has {} columns, expected {}",
288 i,
289 row.len(),
290 ncols
291 )));
292 }
293 }
294
295 let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);
297
298 let parameters = Parameters::new();
300
301 Theta::from_parts(mat, parameters).map_err(serde::de::Error::custom)
302 }
303 }
304
305 deserializer.deserialize_seq(ThetaVisitor)
306 }
307}
308
#[cfg(test)]
mod tests {
    use super::*;
    use faer::mat;

    /// Two parameters, `A` and `B`, each spanning [0, 10].
    fn ab_parameters() -> Parameters {
        Parameters::new().add("A", 0.0, 10.0).add("B", 0.0, 10.0)
    }

    #[test]
    fn test_filter_indices() {
        let mut theta = Theta::from_parts(
            mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
            ab_parameters(),
        )
        .unwrap();

        // Keep only the first and the last row.
        theta.filter_indices(&[0, 3]);

        let want = mat![[1.0, 2.0], [7.0, 8.0]];
        assert_eq!(theta.matrix, want);
    }

    #[test]
    fn test_add_point() {
        let mut theta =
            Theta::from_parts(mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], ab_parameters()).unwrap();

        theta.add_point(&[7.0, 8.0]).unwrap();

        let want = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        assert_eq!(theta.matrix, want);
    }

    #[test]
    fn test_suggest_point() {
        let mut theta =
            Theta::from_parts(mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], ab_parameters()).unwrap();

        // (7, 8) is about 0.28 normalised units from its nearest neighbour (5, 6): accepted.
        theta.suggest_point(&[7.0, 8.0], 0.2).unwrap();
        let want = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        assert_eq!(theta.matrix, want);

        // (7.1, 8.1) is about 0.014 normalised units from (7, 8): rejected.
        theta.suggest_point(&[7.1, 8.1], 0.2).unwrap();
        assert_eq!(theta.matrix.nrows(), 4);
    }

    #[test]
    fn test_param_names() {
        let theta =
            Theta::from_parts(mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], ab_parameters()).unwrap();

        let expected_names = vec!["A".to_string(), "B".to_string()];
        assert_eq!(theta.param_names(), expected_names);
    }

    #[test]
    fn test_set_matrix() {
        let mut theta = Theta::from_parts(mat![[1.0, 2.0], [3.0, 4.0]], ab_parameters()).unwrap();

        // Replace the matrix with one that has a different number of rows.
        let replacement = mat![[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]];
        theta.matrix_mut().clone_from(&replacement);

        assert_eq!(theta.matrix(), &replacement);
    }
}