pmcore/structs/
weights.rs1use faer::Col;
2use serde::{Deserialize, Serialize};
3use std::ops::{Index, IndexMut};
4
5#[derive(Debug, Clone)]
10pub struct Weights {
11 weights: Col<f64>,
12}
13
14impl Default for Weights {
15 fn default() -> Self {
16 Self {
17 weights: Col::from_fn(0, |_| 0.0),
18 }
19 }
20}
21
22impl Weights {
23 pub fn new(weights: Col<f64>) -> Self {
24 Self { weights }
25 }
26
27 pub fn from_vec(weights: Vec<f64>) -> Self {
29 Self {
30 weights: Col::from_fn(weights.len(), |i| weights[i]),
31 }
32 }
33
34 pub fn uniform(n: usize) -> Self {
37 if n == 0 {
38 return Self::default();
39 }
40 let uniform_weight = 1.0 / n as f64;
41 Self {
42 weights: Col::from_fn(n, |_| uniform_weight),
43 }
44 }
45
46 pub fn weights(&self) -> &Col<f64> {
48 &self.weights
49 }
50
51 pub fn weights_mut(&mut self) -> &mut Col<f64> {
53 &mut self.weights
54 }
55
56 pub fn len(&self) -> usize {
58 self.weights.nrows()
59 }
60
61 pub fn is_empty(&self) -> bool {
63 self.len() == 0
64 }
65
66 pub fn to_vec(&self) -> Vec<f64> {
68 self.weights.iter().cloned().collect()
69 }
70
71 pub fn iter(&self) -> impl Iterator<Item = f64> + '_ {
72 self.weights.iter().cloned()
73 }
74}
75
76impl Serialize for Weights {
77 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
78 where
79 S: serde::Serializer,
80 {
81 self.to_vec().serialize(serializer)
82 }
83}
84
85impl<'de> Deserialize<'de> for Weights {
86 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
87 where
88 D: serde::Deserializer<'de>,
89 {
90 let weights_vec = Vec::<f64>::deserialize(deserializer)?;
91 Ok(Self::from_vec(weights_vec))
92 }
93}
94
95impl From<Vec<f64>> for Weights {
96 fn from(weights: Vec<f64>) -> Self {
97 Self::from_vec(weights)
98 }
99}
100
101impl From<Col<f64>> for Weights {
102 fn from(weights: Col<f64>) -> Self {
103 Self { weights }
104 }
105}
106
107impl Index<usize> for Weights {
108 type Output = f64;
109
110 fn index(&self, index: usize) -> &Self::Output {
111 &self.weights[index]
112 }
113}
114
115impl IndexMut<usize> for Weights {
116 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
117 &mut self.weights[index]
118 }
119}