pmcore/routines/output/
posterior.rs

1pub use anyhow::{bail, Result};
2use faer::{Col, Mat};
3use serde::{Deserialize, Serialize};
4
5use crate::structs::{psi::Psi, weights::Weights};
6
7/// Posterior probabilities for each support points
8#[derive(Debug, Clone)]
9pub struct Posterior {
10    mat: Mat<f64>,
11}
12
13impl Posterior {
14    /// Create a new Posterior from a matrix
15    fn new(mat: Mat<f64>) -> Self {
16        Posterior { mat }
17    }
18
19    /// Calculate the posterior probabilities for each support point given the weights
20    ///
21    /// The shape is the same as [Psi], and thus subjects are the rows and support points are the columns.
22    /// /// # Errors
23    /// Returns an error if the number of rows in `psi` does not match the number of weights in `w`.
24    /// # Arguments
25    /// * `psi` - The Psi object containing the matrix of support points.
26    /// * `w` - The weights for each support point.
27    /// # Returns
28    /// A Result containing the Posterior probabilities if successful, or an error if the
29    /// dimensions do not match.
30    pub fn calculate(psi: &Psi, w: &Col<f64>) -> Result<Self> {
31        if psi.matrix().ncols() != w.nrows() {
32            bail!(
33                "Number of rows in psi ({}) and number of weights ({}) do not match.",
34                psi.matrix().nrows(),
35                w.nrows()
36            );
37        }
38
39        let psi_matrix = psi.matrix();
40        let py = psi_matrix * w;
41
42        let posterior = Mat::from_fn(psi_matrix.nrows(), psi_matrix.ncols(), |i, j| {
43            psi_matrix.get(i, j) * w.get(j) / py.get(i)
44        });
45
46        Ok(posterior.into())
47    }
48
49    /// Get a reference to the underlying matrix
50    pub fn matrix(&self) -> &Mat<f64> {
51        &self.mat
52    }
53
54    /// Write the posterior probabilities to a CSV file
55    /// Each row represents a subject, each column represents a support point
56    pub fn to_csv<W: std::io::Write>(&self, writer: W) -> Result<()> {
57        let mut csv_writer = csv::Writer::from_writer(writer);
58
59        // Write each row
60        for i in 0..self.mat.nrows() {
61            let row: Vec<f64> = (0..self.mat.ncols()).map(|j| *self.mat.get(i, j)).collect();
62            csv_writer.serialize(row)?;
63        }
64
65        csv_writer.flush()?;
66        Ok(())
67    }
68
69    /// Read posterior probabilities from a CSV file
70    /// Each row represents a subject, each column represents a support point
71    pub fn from_csv<R: std::io::Read>(reader: R) -> Result<Self> {
72        let mut csv_reader = csv::Reader::from_reader(reader);
73        let mut rows: Vec<Vec<f64>> = Vec::new();
74
75        for result in csv_reader.deserialize() {
76            let row: Vec<f64> = result?;
77            rows.push(row);
78        }
79
80        if rows.is_empty() {
81            bail!("CSV file is empty");
82        }
83
84        let nrows = rows.len();
85        let ncols = rows[0].len();
86
87        // Verify all rows have the same length
88        for (i, row) in rows.iter().enumerate() {
89            if row.len() != ncols {
90                bail!("Row {} has {} columns, expected {}", i, row.len(), ncols);
91            }
92        }
93
94        // Create matrix from rows
95        let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);
96
97        Ok(Posterior::new(mat))
98    }
99}
100
101/// Convert a matrix to a [Posterior]
102impl From<Mat<f64>> for Posterior {
103    fn from(mat: Mat<f64>) -> Self {
104        Posterior::new(mat)
105    }
106}
107
108impl Serialize for Posterior {
109    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
110    where
111        S: serde::Serializer,
112    {
113        use serde::ser::SerializeSeq;
114
115        let mut seq = serializer.serialize_seq(Some(self.mat.nrows()))?;
116
117        // Serialize each row as a vector
118        for i in 0..self.mat.nrows() {
119            let row: Vec<f64> = (0..self.mat.ncols()).map(|j| *self.mat.get(i, j)).collect();
120            seq.serialize_element(&row)?;
121        }
122
123        seq.end()
124    }
125}
126
127impl<'de> Deserialize<'de> for Posterior {
128    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
129    where
130        D: serde::Deserializer<'de>,
131    {
132        use serde::de::{SeqAccess, Visitor};
133        use std::fmt;
134
135        struct PosteriorVisitor;
136
137        impl<'de> Visitor<'de> for PosteriorVisitor {
138            type Value = Posterior;
139
140            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
141                formatter.write_str("a sequence of rows (vectors of f64)")
142            }
143
144            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
145            where
146                A: SeqAccess<'de>,
147            {
148                let mut rows: Vec<Vec<f64>> = Vec::new();
149
150                while let Some(row) = seq.next_element::<Vec<f64>>()? {
151                    rows.push(row);
152                }
153
154                if rows.is_empty() {
155                    return Err(serde::de::Error::custom("Empty matrix not allowed"));
156                }
157
158                let nrows = rows.len();
159                let ncols = rows[0].len();
160
161                // Verify all rows have the same length
162                for (i, row) in rows.iter().enumerate() {
163                    if row.len() != ncols {
164                        return Err(serde::de::Error::custom(format!(
165                            "Row {} has {} columns, expected {}",
166                            i,
167                            row.len(),
168                            ncols
169                        )));
170                    }
171                }
172
173                // Create matrix from rows
174                let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);
175
176                Ok(Posterior::new(mat))
177            }
178        }
179
180        deserializer.deserialize_seq(PosteriorVisitor)
181    }
182}
183
184/// Calculates the posterior probabilities for each support point given the weights
185///
186/// The shape is the same as [Psi], and thus subjects are the rows and support points are the columns.
187pub fn posterior(psi: &Psi, w: &Weights) -> Result<Posterior> {
188    if psi.matrix().ncols() != w.len() {
189        bail!(
190            "Number of rows in psi ({}) and number of weights ({}) do not match.",
191            psi.matrix().nrows(),
192            w.len()
193        );
194    }
195
196    let psi_matrix = psi.matrix();
197    let py = psi_matrix * w.weights();
198
199    let posterior = Mat::from_fn(psi_matrix.nrows(), psi_matrix.ncols(), |i, j| {
200        psi_matrix.get(i, j) * w.weights().get(j) / py.get(i)
201    });
202
203    Ok(posterior.into())
204}