pmcore/structs/
weights.rs1use faer::Col;
2use serde::{Deserialize, Serialize};
3use std::ops::{Index, IndexMut};
4
5#[derive(Debug, Clone)]
10pub struct Weights {
11 weights: Col<f64>,
12}
13
14impl Default for Weights {
15 fn default() -> Self {
16 Self {
17 weights: Col::from_fn(0, |_| 0.0),
18 }
19 }
20}
21
22impl Weights {
23 pub fn new(weights: Col<f64>) -> Self {
24 Self { weights }
25 }
26
27 pub fn from_vec(weights: Vec<f64>) -> Self {
29 Self {
30 weights: Col::from_fn(weights.len(), |i| weights[i]),
31 }
32 }
33
34 pub fn uniform(n: usize) -> Self {
37 if n == 0 {
38 return Self::default();
39 }
40 let uniform_weight = 1.0 / n as f64;
41 Self {
42 weights: Col::from_fn(n, |_| uniform_weight),
43 }
44 }
45
46 pub fn weights(&self) -> &Col<f64> {
48 &self.weights
49 }
50
51 pub fn weights_mut(&mut self) -> &mut Col<f64> {
53 &mut self.weights
54 }
55
56 pub fn len(&self) -> usize {
58 self.weights.nrows()
59 }
60
61 pub fn to_vec(&self) -> Vec<f64> {
63 self.weights.iter().cloned().collect()
64 }
65
66 pub fn iter(&self) -> impl Iterator<Item = f64> + '_ {
67 self.weights.iter().cloned()
68 }
69}
70
71impl Serialize for Weights {
72 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
73 where
74 S: serde::Serializer,
75 {
76 self.to_vec().serialize(serializer)
77 }
78}
79
80impl<'de> Deserialize<'de> for Weights {
81 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
82 where
83 D: serde::Deserializer<'de>,
84 {
85 let weights_vec = Vec::<f64>::deserialize(deserializer)?;
86 Ok(Self::from_vec(weights_vec))
87 }
88}
89
90impl From<Vec<f64>> for Weights {
91 fn from(weights: Vec<f64>) -> Self {
92 Self::from_vec(weights)
93 }
94}
95
96impl From<Col<f64>> for Weights {
97 fn from(weights: Col<f64>) -> Self {
98 Self { weights }
99 }
100}
101
102impl Index<usize> for Weights {
103 type Output = f64;
104
105 fn index(&self, index: usize) -> &Self::Output {
106 &self.weights[index]
107 }
108}
109
110impl IndexMut<usize> for Weights {
111 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
112 &mut self.weights[index]
113 }
114}