pmcore/structs/
weights.rs

1use faer::Col;
2use serde::{Deserialize, Serialize};
3use std::ops::{Index, IndexMut};
4
5/// The weight (probabilities) for each support point in the model.
6///
7/// This struct is used to hold the weights for each support point in the model.
8/// It is a thin wrapper around [faer::Col<f64>] to provide additional functionality and context
9#[derive(Debug, Clone)]
10pub struct Weights {
11    weights: Col<f64>,
12}
13
14impl Default for Weights {
15    fn default() -> Self {
16        Self {
17            weights: Col::from_fn(0, |_| 0.0),
18        }
19    }
20}
21
22impl Weights {
23    pub fn new(weights: Col<f64>) -> Self {
24        Self { weights }
25    }
26
27    /// Create a new [Weights] instance from a vector of weights.
28    pub fn from_vec(weights: Vec<f64>) -> Self {
29        Self {
30            weights: Col::from_fn(weights.len(), |i| weights[i]),
31        }
32    }
33
34    /// Get a reference to the weights.
35    pub fn weights(&self) -> &Col<f64> {
36        &self.weights
37    }
38
39    /// Get a mutable reference to the weights.
40    pub fn weights_mut(&mut self) -> &mut Col<f64> {
41        &mut self.weights
42    }
43
44    /// Get the number of weights.
45    pub fn len(&self) -> usize {
46        self.weights.nrows()
47    }
48
49    /// Get a vector representation of the weights.
50    pub fn to_vec(&self) -> Vec<f64> {
51        self.weights.iter().cloned().collect()
52    }
53
54    pub fn iter(&self) -> impl Iterator<Item = f64> + '_ {
55        self.weights.iter().cloned()
56    }
57}
58
59impl Serialize for Weights {
60    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
61    where
62        S: serde::Serializer,
63    {
64        self.to_vec().serialize(serializer)
65    }
66}
67
68impl<'de> Deserialize<'de> for Weights {
69    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
70    where
71        D: serde::Deserializer<'de>,
72    {
73        let weights_vec = Vec::<f64>::deserialize(deserializer)?;
74        Ok(Self::from_vec(weights_vec))
75    }
76}
77
78impl From<Vec<f64>> for Weights {
79    fn from(weights: Vec<f64>) -> Self {
80        Self::from_vec(weights)
81    }
82}
83
84impl From<Col<f64>> for Weights {
85    fn from(weights: Col<f64>) -> Self {
86        Self { weights }
87    }
88}
89
90impl Index<usize> for Weights {
91    type Output = f64;
92
93    fn index(&self, index: usize) -> &Self::Output {
94        &self.weights[index]
95    }
96}
97
98impl IndexMut<usize> for Weights {
99    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
100        &mut self.weights[index]
101    }
102}