1use std::fmt::Debug;
2
3use anyhow::{bail, Result};
4use faer::Mat;
5use serde::{Deserialize, Serialize};
6
7use crate::{prelude::Parameters, structs::weights::Weights};
8
9#[derive(Clone, PartialEq)]
14pub struct Theta {
15 matrix: Mat<f64>,
16 parameters: Parameters,
17}
18
19impl Default for Theta {
20 fn default() -> Self {
21 Theta {
22 matrix: Mat::new(),
23 parameters: Parameters::new(),
24 }
25 }
26}
27
28impl Theta {
29 pub fn new() -> Self {
30 Theta::default()
31 }
32
33 pub fn from_parts(matrix: Mat<f64>, parameters: Parameters) -> Result<Self> {
40 if matrix.ncols() != parameters.len() {
41 bail!(
42 "Number of columns in matrix ({}) does not match number of parameters ({})",
43 matrix.ncols(),
44 parameters.len()
45 );
46 }
47
48 Ok(Theta { matrix, parameters })
49 }
50
51 pub fn matrix(&self) -> &Mat<f64> {
55 &self.matrix
56 }
57
58 pub fn matrix_mut(&mut self) -> &mut Mat<f64> {
60 &mut self.matrix
61 }
62
63 pub fn parameters(&self) -> &Parameters {
65 &self.parameters
66 }
67
68 pub fn parameters_mut(&mut self) -> &mut Parameters {
70 &mut self.parameters
71 }
72
73 pub fn nspp(&self) -> usize {
75 self.matrix.nrows()
76 }
77
78 pub fn param_names(&self) -> Vec<String> {
80 self.parameters.names()
81 }
82
83 pub(crate) fn filter_indices(&mut self, indices: &[usize]) {
85 let matrix = self.matrix.to_owned();
86
87 let new = Mat::from_fn(indices.len(), matrix.ncols(), |r, c| {
88 *matrix.get(indices[r], c)
89 });
90
91 self.matrix = new;
92 }
93
94 pub fn add_point(&mut self, spp: &[f64]) -> Result<()> {
96 if spp.len() != self.matrix.ncols() {
97 bail!(
98 "Support point length ({}) does not match number of parameters ({})",
99 spp.len(),
100 self.matrix.ncols()
101 );
102 }
103
104 self.matrix
105 .resize_with(self.matrix.nrows() + 1, self.matrix.ncols(), |_, i| spp[i]);
106 Ok(())
107 }
108
109 pub(crate) fn suggest_point(&mut self, spp: &[f64], min_dist: f64) -> Result<()> {
113 if self.check_point(spp, min_dist) {
114 self.add_point(spp)?;
115 }
116 Ok(())
117 }
118
119 pub(crate) fn check_point(&self, spp: &[f64], min_dist: f64) -> bool {
121 if self.matrix.nrows() == 0 {
122 return true;
123 }
124
125 let limits = self.parameters.ranges();
126
127 for row_idx in 0..self.matrix.nrows() {
128 let mut squared_dist = 0.0;
129 for (i, val) in spp.iter().enumerate() {
130 let normalized_diff =
132 (val - self.matrix.get(row_idx, i)) / (limits[i].1 - limits[i].0);
133 squared_dist += normalized_diff * normalized_diff;
134 }
135 let dist = squared_dist.sqrt();
136 if dist <= min_dist {
137 return false; }
139 }
140 true }
142
143 pub fn write(&self, path: &str) {
145 let mut writer = csv::Writer::from_path(path).unwrap();
146 for row in self.matrix.row_iter() {
147 writer
148 .write_record(row.iter().map(|x| x.to_string()))
149 .unwrap();
150 }
151 }
152
153 pub fn write_with_weights(&self, path: &str, weights: &Weights) -> Result<()> {
155 if self.nspp() != weights.len() {
156 bail!(
157 "Number of support points ({}) does not match number of weights ({})",
158 self.nspp(),
159 weights.len()
160 );
161 }
162
163 let mut writer = csv::Writer::from_path(path)?;
164
165 let header: Vec<String> = self
166 .parameters
167 .names()
168 .iter()
169 .cloned()
170 .chain(std::iter::once("prob".to_string()))
171 .collect();
172
173 writer.write_record(header)?;
174
175 for (row_idx, row) in self.matrix.row_iter().enumerate() {
176 let mut record: Vec<String> = row.iter().map(|x| x.to_string()).collect();
177 record.push(weights[row_idx].to_string());
178 writer.write_record(record)?;
179 }
180 Ok(())
181 }
182
183 pub fn to_csv<W: std::io::Write>(&self, writer: W) -> Result<()> {
186 let mut csv_writer = csv::Writer::from_writer(writer);
187
188 for i in 0..self.matrix.nrows() {
190 let row: Vec<f64> = (0..self.matrix.ncols())
191 .map(|j| *self.matrix.get(i, j))
192 .collect();
193 csv_writer.serialize(row)?;
194 }
195
196 csv_writer.flush()?;
197 Ok(())
198 }
199
200 pub fn from_csv<R: std::io::Read>(reader: R) -> Result<Self> {
204 let mut csv_reader = csv::Reader::from_reader(reader);
205 let mut rows: Vec<Vec<f64>> = Vec::new();
206
207 for result in csv_reader.deserialize() {
208 let row: Vec<f64> = result?;
209 rows.push(row);
210 }
211
212 if rows.is_empty() {
213 bail!("CSV file is empty");
214 }
215
216 let nrows = rows.len();
217 let ncols = rows[0].len();
218
219 for (i, row) in rows.iter().enumerate() {
221 if row.len() != ncols {
222 bail!("Row {} has {} columns, expected {}", i, row.len(), ncols);
223 }
224 }
225
226 let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);
228
229 let parameters = Parameters::new();
231
232 Theta::from_parts(mat, parameters)
233 }
234}
235
236impl Debug for Theta {
237 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238 writeln!(f, "\nTheta contains {} support points\n", self.nspp())?;
240
241 for name in self.parameters.names().iter() {
243 write!(f, "\t{}", name)?;
244 }
245 writeln!(f)?;
246 self.matrix.row_iter().enumerate().for_each(|(index, row)| {
248 write!(f, "{}", index).unwrap();
249 for val in row.iter() {
250 write!(f, "\t{:.2}", val).unwrap();
251 }
252 writeln!(f).unwrap();
253 });
254 Ok(())
255 }
256}
257
258impl Serialize for Theta {
259 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
260 where
261 S: serde::Serializer,
262 {
263 use serde::ser::SerializeSeq;
264
265 let mut seq = serializer.serialize_seq(Some(self.matrix.nrows()))?;
266
267 for i in 0..self.matrix.nrows() {
269 let row: Vec<f64> = (0..self.matrix.ncols())
270 .map(|j| *self.matrix.get(i, j))
271 .collect();
272 seq.serialize_element(&row)?;
273 }
274
275 seq.end()
276 }
277}
278
279impl<'de> Deserialize<'de> for Theta {
280 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
281 where
282 D: serde::Deserializer<'de>,
283 {
284 use serde::de::{SeqAccess, Visitor};
285 use std::fmt;
286
287 struct ThetaVisitor;
288
289 impl<'de> Visitor<'de> for ThetaVisitor {
290 type Value = Theta;
291
292 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
293 formatter.write_str("a sequence of rows (vectors of f64)")
294 }
295
296 fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
297 where
298 A: SeqAccess<'de>,
299 {
300 let mut rows: Vec<Vec<f64>> = Vec::new();
301
302 while let Some(row) = seq.next_element::<Vec<f64>>()? {
303 rows.push(row);
304 }
305
306 if rows.is_empty() {
307 return Err(serde::de::Error::custom("Empty matrix not allowed"));
308 }
309
310 let nrows = rows.len();
311 let ncols = rows[0].len();
312
313 for (i, row) in rows.iter().enumerate() {
315 if row.len() != ncols {
316 return Err(serde::de::Error::custom(format!(
317 "Row {} has {} columns, expected {}",
318 i,
319 row.len(),
320 ncols
321 )));
322 }
323 }
324
325 let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);
327
328 let parameters = Parameters::new();
330
331 Theta::from_parts(mat, parameters).map_err(serde::de::Error::custom)
332 }
333 }
334
335 deserializer.deserialize_seq(ThetaVisitor)
336 }
337}
338
#[cfg(test)]
mod tests {
    use super::*;
    use faer::mat;

    /// The two-parameter space shared by these tests: `A` and `B`, both over [0, 10].
    fn two_params() -> Parameters {
        Parameters::new().add("A", 0.0, 10.0).add("B", 0.0, 10.0)
    }

    /// Wraps a two-column `matrix` in a `Theta` over [`two_params`].
    fn theta_with(matrix: Mat<f64>) -> Theta {
        Theta::from_parts(matrix, two_params()).unwrap()
    }

    #[test]
    fn test_filter_indices() {
        let mut theta = theta_with(mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]);

        theta.filter_indices(&[0, 3]);

        let expected = mat![[1.0, 2.0], [7.0, 8.0]];
        assert_eq!(theta.matrix, expected);
    }

    #[test]
    fn test_add_point() {
        let mut theta = theta_with(mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);

        theta.add_point(&[7.0, 8.0]).unwrap();

        let expected = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        assert_eq!(theta.matrix, expected);
    }

    #[test]
    fn test_suggest_point() {
        let mut theta = theta_with(mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);

        // Far enough from every existing point to be accepted.
        theta.suggest_point(&[7.0, 8.0], 0.2).unwrap();
        let expected = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        assert_eq!(theta.matrix, expected);

        // Too close to the point just added: rejected, so the row count stays at 4.
        theta.suggest_point(&[7.1, 8.1], 0.2).unwrap();
        assert_eq!(theta.matrix.nrows(), 4);
    }

    #[test]
    fn test_param_names() {
        let theta = theta_with(mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);

        assert_eq!(theta.param_names(), vec!["A".to_string(), "B".to_string()]);
    }

    #[test]
    fn test_set_matrix() {
        let mut theta = theta_with(mat![[1.0, 2.0], [3.0, 4.0]]);

        let replacement = mat![[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]];
        theta.matrix_mut().clone_from(&replacement);

        assert_eq!(theta.matrix(), &replacement);
    }
}
412}