// pmcore/structs/psi.rs

1use anyhow::bail;
2use anyhow::Result;
3use faer::Mat;
4use ndarray::Array2;
5use pharmsol::prelude::simulator::log_likelihood_matrix;
6use pharmsol::AssayErrorModels;
7use pharmsol::Data;
8use pharmsol::Equation;
9use serde::{Deserialize, Serialize};
10
11use super::theta::Theta;
12
13/// [Psi] is a structure that holds the likelihood for each subject (row), for each support point (column)
14#[derive(Debug, Clone, PartialEq)]
15pub struct Psi {
16    matrix: Mat<f64>,
17}
18
19impl Psi {
20    pub fn new() -> Self {
21        Psi { matrix: Mat::new() }
22    }
23
24    pub fn matrix(&self) -> &Mat<f64> {
25        &self.matrix
26    }
27
28    pub fn nspp(&self) -> usize {
29        self.matrix.nrows()
30    }
31
32    pub fn nsub(&self) -> usize {
33        self.matrix.ncols()
34    }
35
36    /// Modify the [Psi::matrix] to only include the columns specified by `indices`
37    pub(crate) fn filter_column_indices(&mut self, indices: &[usize]) {
38        let matrix = self.matrix.to_owned();
39
40        let new = Mat::from_fn(matrix.nrows(), indices.len(), |r, c| {
41            *matrix.get(r, indices[c])
42        });
43
44        self.matrix = new;
45    }
46
47    /// Write the matrix to a CSV file
48    pub fn write(&self, path: &str) {
49        let mut writer = csv::Writer::from_path(path).unwrap();
50        for row in self.matrix.row_iter() {
51            writer
52                .write_record(row.iter().map(|x| x.to_string()))
53                .unwrap();
54        }
55    }
56
57    /// Write the psi matrix to a CSV writer
58    /// Each row represents a subject, each column represents a support point
59    pub fn to_csv<W: std::io::Write>(&self, writer: W) -> Result<()> {
60        let mut csv_writer = csv::Writer::from_writer(writer);
61
62        // Write each row
63        for i in 0..self.matrix.nrows() {
64            let row: Vec<f64> = (0..self.matrix.ncols())
65                .map(|j| *self.matrix.get(i, j))
66                .collect();
67            csv_writer.serialize(row)?;
68        }
69
70        csv_writer.flush()?;
71        Ok(())
72    }
73
74    /// Read psi matrix from a CSV reader
75    /// Each row represents a subject, each column represents a support point
76    pub fn from_csv<R: std::io::Read>(reader: R) -> Result<Self> {
77        let mut csv_reader = csv::Reader::from_reader(reader);
78        let mut rows: Vec<Vec<f64>> = Vec::new();
79
80        for result in csv_reader.deserialize() {
81            let row: Vec<f64> = result?;
82            rows.push(row);
83        }
84
85        if rows.is_empty() {
86            bail!("CSV file is empty");
87        }
88
89        let nrows = rows.len();
90        let ncols = rows[0].len();
91
92        // Verify all rows have the same length
93        for (i, row) in rows.iter().enumerate() {
94            if row.len() != ncols {
95                bail!("Row {} has {} columns, expected {}", i, row.len(), ncols);
96            }
97        }
98
99        // Create matrix from rows
100        let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);
101
102        Ok(Psi { matrix: mat })
103    }
104}
105
106impl Default for Psi {
107    fn default() -> Self {
108        Psi::new()
109    }
110}
111
112impl From<Array2<f64>> for Psi {
113    fn from(array: Array2<f64>) -> Self {
114        let matrix = Mat::from_fn(array.nrows(), array.ncols(), |i, j| array[(i, j)]);
115        Psi { matrix }
116    }
117}
118
119impl From<Mat<f64>> for Psi {
120    fn from(matrix: Mat<f64>) -> Self {
121        Psi { matrix }
122    }
123}
124
125impl From<&Array2<f64>> for Psi {
126    fn from(array: &Array2<f64>) -> Self {
127        let matrix = Mat::from_fn(array.nrows(), array.ncols(), |i, j| array[(i, j)]);
128        Psi { matrix }
129    }
130}
131
132impl Serialize for Psi {
133    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
134    where
135        S: serde::Serializer,
136    {
137        use serde::ser::SerializeSeq;
138
139        let mut seq = serializer.serialize_seq(Some(self.matrix.nrows()))?;
140
141        // Serialize each row as a vector
142        for i in 0..self.matrix.nrows() {
143            let row: Vec<f64> = (0..self.matrix.ncols())
144                .map(|j| *self.matrix.get(i, j))
145                .collect();
146            seq.serialize_element(&row)?;
147        }
148
149        seq.end()
150    }
151}
152
153impl<'de> Deserialize<'de> for Psi {
154    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
155    where
156        D: serde::Deserializer<'de>,
157    {
158        use serde::de::{SeqAccess, Visitor};
159        use std::fmt;
160
161        struct PsiVisitor;
162
163        impl<'de> Visitor<'de> for PsiVisitor {
164            type Value = Psi;
165
166            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
167                formatter.write_str("a sequence of rows (vectors of f64)")
168            }
169
170            fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
171            where
172                A: SeqAccess<'de>,
173            {
174                let mut rows: Vec<Vec<f64>> = Vec::new();
175
176                while let Some(row) = seq.next_element::<Vec<f64>>()? {
177                    rows.push(row);
178                }
179
180                if rows.is_empty() {
181                    return Err(serde::de::Error::custom("Empty matrix not allowed"));
182                }
183
184                let nrows = rows.len();
185                let ncols = rows[0].len();
186
187                // Verify all rows have the same length
188                for (i, row) in rows.iter().enumerate() {
189                    if row.len() != ncols {
190                        return Err(serde::de::Error::custom(format!(
191                            "Row {} has {} columns, expected {}",
192                            i,
193                            row.len(),
194                            ncols
195                        )));
196                    }
197                }
198
199                // Create matrix from rows
200                let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);
201
202                Ok(Psi { matrix: mat })
203            }
204        }
205
206        deserializer.deserialize_seq(PsiVisitor)
207    }
208}
209
210pub(crate) fn calculate_psi(
211    equation: &impl Equation,
212    subjects: &Data,
213    theta: &Theta,
214    error_models: &AssayErrorModels,
215    progress: bool,
216) -> Result<Psi> {
217    let tm = theta.matrix();
218    let theta_ndarray = Array2::from_shape_fn((tm.nrows(), tm.ncols()), |(i, j)| tm[(i, j)]);
219    let log_psi =
220        log_likelihood_matrix(equation, subjects, &theta_ndarray, error_models, progress)?;
221    let psi_ndarray = log_psi.mapv(f64::exp);
222
223    Ok(Psi::from(psi_ndarray))
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Builds a row-major `Array2` of the given shape; panics if `values` does not fit.
    fn array_of(shape: (usize, usize), values: Vec<f64>) -> Array2<f64> {
        Array2::from_shape_vec(shape, values).unwrap()
    }

    /// Asserts that every element of `expected` is stored at the same position in `psi`.
    fn assert_entries_match(psi: &Psi, expected: &Array2<f64>) {
        let m = psi.matrix();
        for ((i, j), &value) in expected.indexed_iter() {
            assert_eq!(m[(i, j)], value);
        }
    }

    #[test]
    fn test_from_array2() {
        // 2 x 3 input converted by value.
        let array = array_of((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let psi = Psi::from(array.clone());

        assert_eq!((psi.nspp(), psi.nsub()), (2, 3));
        assert_entries_match(&psi, &array);
    }

    #[test]
    fn test_from_array2_ref() {
        // 3 x 2 input converted by reference.
        let array = array_of((3, 2), vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0]);
        let psi = Psi::from(&array);

        assert_eq!((psi.nspp(), psi.nsub()), (3, 2));
        assert_entries_match(&psi, &array);
    }

    #[test]
    fn test_nspp() {
        let psi = Psi::from(array_of(
            (4, 2),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
        ));
        assert_eq!(psi.nspp(), 4);
    }

    #[test]
    fn test_nspp_empty() {
        assert_eq!(Psi::new().nspp(), 0);
    }

    #[test]
    fn test_nspp_single_row() {
        let psi = Psi::from(array_of((1, 3), vec![1.0, 2.0, 3.0]));
        assert_eq!(psi.nspp(), 1);
    }

    #[test]
    fn test_nsub() {
        let psi = Psi::from(array_of(
            (2, 5),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
        ));
        assert_eq!(psi.nsub(), 5);
    }

    #[test]
    fn test_nsub_empty() {
        assert_eq!(Psi::new().nsub(), 0);
    }

    #[test]
    fn test_nsub_single_column() {
        let psi = Psi::from(array_of((3, 1), vec![1.0, 2.0, 3.0]));
        assert_eq!(psi.nsub(), 1);
    }

    #[test]
    fn test_from_implementations_consistency() {
        // Owned and borrowed conversions must agree in shape and contents.
        let array = array_of((2, 3), vec![1.5, 2.5, 3.5, 4.5, 5.5, 6.5]);

        let by_value = Psi::from(array.clone());
        let by_ref = Psi::from(&array);

        assert_eq!(by_value.nspp(), by_ref.nspp());
        assert_eq!(by_value.nsub(), by_ref.nsub());
        assert_entries_match(&by_value, &array);
        assert_entries_match(&by_ref, &array);
    }
}