pmcore/structs/
weights.rs1use faer::Col;
2use serde::{Deserialize, Serialize};
3use std::ops::{Index, IndexMut};
4
5#[derive(Debug, Clone)]
10pub struct Weights {
11 weights: Col<f64>,
12}
13
14impl Default for Weights {
15 fn default() -> Self {
16 Self {
17 weights: Col::from_fn(0, |_| 0.0),
18 }
19 }
20}
21
22impl Weights {
23 pub fn new(weights: Col<f64>) -> Self {
24 Self { weights }
25 }
26
27 pub fn from_vec(weights: Vec<f64>) -> Self {
29 Self {
30 weights: Col::from_fn(weights.len(), |i| weights[i]),
31 }
32 }
33
34 pub fn weights(&self) -> &Col<f64> {
36 &self.weights
37 }
38
39 pub fn weights_mut(&mut self) -> &mut Col<f64> {
41 &mut self.weights
42 }
43
44 pub fn len(&self) -> usize {
46 self.weights.nrows()
47 }
48
49 pub fn to_vec(&self) -> Vec<f64> {
51 self.weights.iter().cloned().collect()
52 }
53
54 pub fn iter(&self) -> impl Iterator<Item = f64> + '_ {
55 self.weights.iter().cloned()
56 }
57}
58
59impl Serialize for Weights {
60 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
61 where
62 S: serde::Serializer,
63 {
64 self.to_vec().serialize(serializer)
65 }
66}
67
68impl<'de> Deserialize<'de> for Weights {
69 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
70 where
71 D: serde::Deserializer<'de>,
72 {
73 let weights_vec = Vec::<f64>::deserialize(deserializer)?;
74 Ok(Self::from_vec(weights_vec))
75 }
76}
77
78impl From<Vec<f64>> for Weights {
79 fn from(weights: Vec<f64>) -> Self {
80 Self::from_vec(weights)
81 }
82}
83
84impl From<Col<f64>> for Weights {
85 fn from(weights: Col<f64>) -> Self {
86 Self { weights }
87 }
88}
89
90impl Index<usize> for Weights {
91 type Output = f64;
92
93 fn index(&self, index: usize) -> &Self::Output {
94 &self.weights[index]
95 }
96}
97
98impl IndexMut<usize> for Weights {
99 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
100 &mut self.weights[index]
101 }
102}