pmcore/routines/output/
posterior.rs1pub use anyhow::{bail, Result};
2use faer::{Col, Mat};
3use serde::{Deserialize, Serialize};
4
5use crate::structs::{psi::Psi, weights::Weights};
6
7#[derive(Debug, Clone)]
9pub struct Posterior {
10 mat: Mat<f64>,
11}
12
13impl Posterior {
14 fn new(mat: Mat<f64>) -> Self {
16 Posterior { mat }
17 }
18
19 pub fn calculate(psi: &Psi, w: &Col<f64>) -> Result<Self> {
31 if psi.matrix().ncols() != w.nrows() {
32 bail!(
33 "Number of rows in psi ({}) and number of weights ({}) do not match.",
34 psi.matrix().nrows(),
35 w.nrows()
36 );
37 }
38
39 let psi_matrix = psi.matrix();
40 let py = psi_matrix * w;
41
42 let posterior = Mat::from_fn(psi_matrix.nrows(), psi_matrix.ncols(), |i, j| {
43 psi_matrix.get(i, j) * w.get(j) / py.get(i)
44 });
45
46 Ok(posterior.into())
47 }
48
49 pub fn matrix(&self) -> &Mat<f64> {
51 &self.mat
52 }
53
54 pub fn to_csv<W: std::io::Write>(&self, writer: W) -> Result<()> {
57 let mut csv_writer = csv::Writer::from_writer(writer);
58
59 for i in 0..self.mat.nrows() {
61 let row: Vec<f64> = (0..self.mat.ncols()).map(|j| *self.mat.get(i, j)).collect();
62 csv_writer.serialize(row)?;
63 }
64
65 csv_writer.flush()?;
66 Ok(())
67 }
68
69 pub fn from_csv<R: std::io::Read>(reader: R) -> Result<Self> {
72 let mut csv_reader = csv::Reader::from_reader(reader);
73 let mut rows: Vec<Vec<f64>> = Vec::new();
74
75 for result in csv_reader.deserialize() {
76 let row: Vec<f64> = result?;
77 rows.push(row);
78 }
79
80 if rows.is_empty() {
81 bail!("CSV file is empty");
82 }
83
84 let nrows = rows.len();
85 let ncols = rows[0].len();
86
87 for (i, row) in rows.iter().enumerate() {
89 if row.len() != ncols {
90 bail!("Row {} has {} columns, expected {}", i, row.len(), ncols);
91 }
92 }
93
94 let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);
96
97 Ok(Posterior::new(mat))
98 }
99}
100
101impl From<Mat<f64>> for Posterior {
103 fn from(mat: Mat<f64>) -> Self {
104 Posterior::new(mat)
105 }
106}
107
108impl Serialize for Posterior {
109 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
110 where
111 S: serde::Serializer,
112 {
113 use serde::ser::SerializeSeq;
114
115 let mut seq = serializer.serialize_seq(Some(self.mat.nrows()))?;
116
117 for i in 0..self.mat.nrows() {
119 let row: Vec<f64> = (0..self.mat.ncols()).map(|j| *self.mat.get(i, j)).collect();
120 seq.serialize_element(&row)?;
121 }
122
123 seq.end()
124 }
125}
126
127impl<'de> Deserialize<'de> for Posterior {
128 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
129 where
130 D: serde::Deserializer<'de>,
131 {
132 use serde::de::{SeqAccess, Visitor};
133 use std::fmt;
134
135 struct PosteriorVisitor;
136
137 impl<'de> Visitor<'de> for PosteriorVisitor {
138 type Value = Posterior;
139
140 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
141 formatter.write_str("a sequence of rows (vectors of f64)")
142 }
143
144 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
145 where
146 A: SeqAccess<'de>,
147 {
148 let mut rows: Vec<Vec<f64>> = Vec::new();
149
150 while let Some(row) = seq.next_element::<Vec<f64>>()? {
151 rows.push(row);
152 }
153
154 if rows.is_empty() {
155 return Err(serde::de::Error::custom("Empty matrix not allowed"));
156 }
157
158 let nrows = rows.len();
159 let ncols = rows[0].len();
160
161 for (i, row) in rows.iter().enumerate() {
163 if row.len() != ncols {
164 return Err(serde::de::Error::custom(format!(
165 "Row {} has {} columns, expected {}",
166 i,
167 row.len(),
168 ncols
169 )));
170 }
171 }
172
173 let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);
175
176 Ok(Posterior::new(mat))
177 }
178 }
179
180 deserializer.deserialize_seq(PosteriorVisitor)
181 }
182}
183
184pub fn posterior(psi: &Psi, w: &Weights) -> Result<Posterior> {
188 if psi.matrix().ncols() != w.len() {
189 bail!(
190 "Number of rows in psi ({}) and number of weights ({}) do not match.",
191 psi.matrix().nrows(),
192 w.len()
193 );
194 }
195
196 let psi_matrix = psi.matrix();
197 let py = psi_matrix * w.weights();
198
199 let posterior = Mat::from_fn(psi_matrix.nrows(), psi_matrix.ncols(), |i, j| {
200 psi_matrix.get(i, j) * w.weights().get(j) / py.get(i)
201 });
202
203 Ok(posterior.into())
204}