use std::fmt::Debug;
use faer::Mat;
/// A set of support points plus the parameter metadata they were drawn for.
///
/// Every row of `matrix` is one support point; `nspp` reports the row count.
#[derive(PartialEq, Clone)]
pub struct Theta {
    /// Support points, one per row; column `j` holds coordinate `j` of each point.
    matrix: Mat<f64>,
    /// Random-parameter entries. The `String` is the parameter name (see `param_names`);
    /// the two `f64`s are presumably the lower/upper bounds — TODO confirm against callers.
    random: Vec<(String, f64, f64)>,
    /// Fixed-parameter entries; presumably `(name, value)` — TODO confirm against callers.
    fixed: Vec<(String, f64)>,
}
/// The default `Theta` is empty: an empty matrix (no support points) and no
/// random or fixed parameter entries.
impl Default for Theta {
    fn default() -> Self {
        Self {
            fixed: Vec::new(),
            random: Vec::new(),
            matrix: Mat::new(),
        }
    }
}
impl Theta {
    /// Creates an empty `Theta` with no support points and no parameter metadata.
    pub fn new() -> Self {
        Theta::default()
    }

    /// Assembles a `Theta` directly from its parts, without any validation.
    ///
    /// `matrix` holds one support point per row; `random` and `fixed` are stored as given.
    pub(crate) fn from_parts(
        matrix: Mat<f64>,
        random: Vec<(String, f64, f64)>,
        fixed: Vec<(String, f64)>,
    ) -> Self {
        Theta {
            matrix,
            random,
            fixed,
        }
    }

    /// Returns the support-point matrix (one support point per row).
    pub fn matrix(&self) -> &Mat<f64> {
        &self.matrix
    }

    /// Replaces the support-point matrix wholesale; parameter metadata is left untouched.
    pub fn set_matrix(&mut self, matrix: Mat<f64>) {
        self.matrix = matrix;
    }

    /// Number of support points, i.e. the number of rows in the matrix.
    pub fn nspp(&self) -> usize {
        self.matrix.nrows()
    }

    /// Names of the random parameters, in stored order.
    pub fn param_names(&self) -> Vec<String> {
        self.random
            .iter()
            .map(|(name, _, _)| name.clone())
            .collect()
    }

    /// Keeps only the rows listed in `indices`, in that order.
    ///
    /// An index may appear more than once, in which case that row is duplicated.
    /// Every index must be `< self.nspp()`; faer's bounds-checked `get` panics otherwise.
    pub(crate) fn filter_indices(&mut self, indices: &[usize]) {
        // Borrow the current matrix directly: no copy is needed because the borrow taken by
        // the closure ends before the assignment below.
        let filtered = Mat::from_fn(indices.len(), self.matrix.ncols(), |r, c| {
            *self.matrix.get(indices[r], c)
        });
        self.matrix = filtered;
    }

    /// Appends `spp` as a new support point (a new last row).
    ///
    /// If the matrix has no rows yet, its width is taken from `spp`.
    ///
    /// # Panics
    /// Panics if the matrix already has rows and `spp.len()` differs from its column count.
    pub(crate) fn add_point(&mut self, spp: &[f64]) {
        if self.matrix.nrows() == 0 {
            // A row-less matrix may have no columns either (`Theta::new()` starts out 0x0).
            // Growing it in place would keep those zero columns and silently drop every
            // coordinate, so size the first row from `spp` instead.
            self.matrix = Mat::from_fn(1, spp.len(), |_, c| spp[c]);
            return;
        }
        let (nrows, ncols) = (self.matrix.nrows(), self.matrix.ncols());
        assert_eq!(
            spp.len(),
            ncols,
            "support point has {} coordinates but theta has {} columns",
            spp.len(),
            ncols
        );
        self.matrix.resize_with(nrows + 1, ncols, |_, c| spp[c]);
    }

    /// Adds `spp` only if [`check_point`](Self::check_point) accepts it.
    pub(crate) fn suggest_point(&mut self, spp: &[f64], min_dist: f64, limits: &[(f64, f64)]) {
        if self.check_point(spp, min_dist, limits) {
            self.add_point(spp);
        }
    }

    /// Returns `true` if `spp` is farther than `min_dist` from every existing support point.
    ///
    /// The distance is Euclidean after dividing each coordinate difference `i` by the width of
    /// `limits[i]` (`hi - lo`). A point at distance exactly `min_dist` is rejected. An empty
    /// theta accepts any point.
    pub(crate) fn check_point(&self, spp: &[f64], min_dist: f64, limits: &[(f64, f64)]) -> bool {
        let too_close = |row_idx: usize| {
            let squared_dist: f64 = spp
                .iter()
                .enumerate()
                .map(|(i, &val)| {
                    let diff = val - *self.matrix.get(row_idx, i);
                    if diff == 0.0 {
                        // Identical coordinates contribute nothing. Returning here also keeps a
                        // degenerate range (lo == hi) from producing 0/0 = NaN, which would make
                        // the `<=` test below false and let exact duplicates through.
                        return 0.0;
                    }
                    let (lo, hi) = limits[i];
                    let normalized = diff / (hi - lo);
                    normalized * normalized
                })
                .sum();
            squared_dist.sqrt() <= min_dist
        };
        !(0..self.matrix.nrows()).any(too_close)
    }

    /// Writes every support point to `path` as one CSV record per row, with no header row.
    ///
    /// # Panics
    /// Panics if the file cannot be created, a record cannot be written, or the final flush
    /// fails. The panic message names the path and the underlying error.
    pub fn write(&self, path: &str) {
        let mut writer = csv::Writer::from_path(path)
            .unwrap_or_else(|e| panic!("failed to create theta file {path}: {e}"));
        for row in self.matrix.row_iter() {
            writer
                .write_record(row.iter().map(|x| x.to_string()))
                .unwrap_or_else(|e| panic!("failed to write theta row to {path}: {e}"));
        }
        // Flush explicitly: dropping a csv::Writer also flushes, but it discards any error.
        writer
            .flush()
            .unwrap_or_else(|e| panic!("failed to flush theta file {path}: {e}"));
    }
}
/// Diagnostic dump: a header line with the support-point count, then one line per
/// support point formatted as `index<TAB>row`.
impl Debug for Theta {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "\nTheta contains {} support points\n", self.nspp())?;
        // Propagate formatter errors with `?`. A failing sink must surface as `fmt::Error`,
        // not as a panic inside `Debug`.
        for (index, row) in self.matrix.row_iter().enumerate() {
            writeln!(f, "{index}\t{:?}", row)?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use faer::mat;

    #[test]
    fn test_filter_indices() {
        let matrix = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let mut theta = Theta::from_parts(matrix, vec![], vec![]);
        theta.filter_indices(&[0, 3]);
        let expected = mat![[1.0, 2.0], [7.0, 8.0]];
        assert_eq!(theta.matrix, expected);
    }

    #[test]
    fn test_filter_indices_reorders_and_duplicates() {
        let matrix = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let mut theta = Theta::from_parts(matrix, vec![], vec![]);
        theta.filter_indices(&[2, 0, 2]);
        let expected = mat![[5.0, 6.0], [1.0, 2.0], [5.0, 6.0]];
        assert_eq!(theta.matrix, expected);
    }

    #[test]
    fn test_filter_indices_empty_selection() {
        let matrix = mat![[1.0, 2.0], [3.0, 4.0]];
        let mut theta = Theta::from_parts(matrix, vec![], vec![]);
        theta.filter_indices(&[]);
        assert_eq!(theta.nspp(), 0);
    }

    #[test]
    fn test_add_point() {
        let matrix = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let mut theta = Theta::from_parts(matrix, vec![], vec![]);
        theta.add_point(&[7.0, 8.0]);
        let expected = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        assert_eq!(theta.matrix, expected);
    }

    #[test]
    fn test_check_point_empty_theta_accepts_anything() {
        let theta = Theta::new();
        assert!(theta.check_point(&[0.5, 0.5], 0.1, &[(0.0, 1.0), (0.0, 1.0)]));
    }

    #[test]
    fn test_check_point_uses_normalized_distance() {
        let matrix = mat![[5.0, 5.0]];
        let theta = Theta::from_parts(matrix, vec![], vec![]);
        let limits = [(0.0, 10.0), (0.0, 10.0)];
        // Identical point: distance 0 is within any non-negative threshold.
        assert!(!theta.check_point(&[5.0, 5.0], 0.1, &limits));
        // 0.5 apart in raw units is only 0.05 after scaling by the range width of 10.
        assert!(!theta.check_point(&[5.5, 5.0], 0.1, &limits));
        // 3.0 apart in both coordinates: scaled distance is sqrt(0.09 + 0.09) ~ 0.42.
        assert!(theta.check_point(&[8.0, 8.0], 0.1, &limits));
    }

    #[test]
    fn test_suggest_point_only_adds_distant_points() {
        let matrix = mat![[0.5, 0.5]];
        let mut theta = Theta::from_parts(matrix, vec![], vec![]);
        let limits = [(0.0, 1.0), (0.0, 1.0)];
        theta.suggest_point(&[0.52, 0.5], 0.1, &limits);
        assert_eq!(theta.nspp(), 1);
        theta.suggest_point(&[0.9, 0.1], 0.1, &limits);
        assert_eq!(theta.nspp(), 2);
        assert_eq!(theta.matrix, mat![[0.5, 0.5], [0.9, 0.1]]);
    }

    #[test]
    fn test_param_names_follow_random_order() {
        let theta = Theta::from_parts(
            mat![[0.0, 0.0]],
            vec![("ke".to_string(), 0.1, 1.0), ("v".to_string(), 1.0, 10.0)],
            vec![],
        );
        assert_eq!(theta.param_names(), vec!["ke".to_string(), "v".to_string()]);
    }

    #[test]
    fn test_debug_header_for_empty_theta() {
        let rendered = format!("{:?}", Theta::new());
        assert_eq!(rendered, "\nTheta contains 0 support points\n\n");
    }
}