pmcore/structs/weights.rs

1use faer::Col;
2use serde::{Deserialize, Serialize};
3use std::ops::{Index, IndexMut};
4
5/// The weight (probabilities) for each support point in the model.
6///
7/// This struct is used to hold the weights for each support point in the model.
8/// It is a thin wrapper around [faer::Col<f64>] to provide additional functionality and context
9#[derive(Debug, Clone)]
10pub struct Weights {
11    weights: Col<f64>,
12}
13
14impl Default for Weights {
15    fn default() -> Self {
16        Self {
17            weights: Col::from_fn(0, |_| 0.0),
18        }
19    }
20}
21
22impl Weights {
23    pub fn new(weights: Col<f64>) -> Self {
24        Self { weights }
25    }
26
27    /// Create a new [Weights] instance from a vector of weights.
28    pub fn from_vec(weights: Vec<f64>) -> Self {
29        Self {
30            weights: Col::from_fn(weights.len(), |i| weights[i]),
31        }
32    }
33
34    /// Create a new [Weights] instance with uniform weights.
35    /// If `n` is 0, returns an empty [Weights] instance.
36    pub fn uniform(n: usize) -> Self {
37        if n == 0 {
38            return Self::default();
39        }
40        let uniform_weight = 1.0 / n as f64;
41        Self {
42            weights: Col::from_fn(n, |_| uniform_weight),
43        }
44    }
45
46    /// Get a reference to the weights.
47    pub fn weights(&self) -> &Col<f64> {
48        &self.weights
49    }
50
51    /// Get a mutable reference to the weights.
52    pub fn weights_mut(&mut self) -> &mut Col<f64> {
53        &mut self.weights
54    }
55
56    /// Get the number of weights.
57    pub fn len(&self) -> usize {
58        self.weights.nrows()
59    }
60
61    /// Get a vector representation of the weights.
62    pub fn to_vec(&self) -> Vec<f64> {
63        self.weights.iter().cloned().collect()
64    }
65
66    pub fn iter(&self) -> impl Iterator<Item = f64> + '_ {
67        self.weights.iter().cloned()
68    }
69}
70
71impl Serialize for Weights {
72    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
73    where
74        S: serde::Serializer,
75    {
76        self.to_vec().serialize(serializer)
77    }
78}
79
80impl<'de> Deserialize<'de> for Weights {
81    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
82    where
83        D: serde::Deserializer<'de>,
84    {
85        let weights_vec = Vec::<f64>::deserialize(deserializer)?;
86        Ok(Self::from_vec(weights_vec))
87    }
88}
89
90impl From<Vec<f64>> for Weights {
91    fn from(weights: Vec<f64>) -> Self {
92        Self::from_vec(weights)
93    }
94}
95
96impl From<Col<f64>> for Weights {
97    fn from(weights: Col<f64>) -> Self {
98        Self { weights }
99    }
100}
101
102impl Index<usize> for Weights {
103    type Output = f64;
104
105    fn index(&self, index: usize) -> &Self::Output {
106        &self.weights[index]
107    }
108}
109
110impl IndexMut<usize> for Weights {
111    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
112        &mut self.weights[index]
113    }
114}