pmcore/structs/
psi.rs

1use anyhow::bail;
2use anyhow::Result;
3use faer::Mat;
4use faer_ext::IntoFaer;
5use faer_ext::IntoNdarray;
6use ndarray::{Array2, ArrayView2};
7use pharmsol::prelude::simulator::psi;
8use pharmsol::Data;
9use pharmsol::Equation;
10use pharmsol::ErrorModels;
11use serde::{Deserialize, Serialize};
12
13use super::theta::Theta;
14
15/// [Psi] is a structure that holds the likelihood for each subject (row), for each support point (column)
16#[derive(Debug, Clone, PartialEq)]
17pub struct Psi {
18    matrix: Mat<f64>,
19}
20
21impl Psi {
22    pub fn new() -> Self {
23        Psi { matrix: Mat::new() }
24    }
25
26    pub fn matrix(&self) -> &Mat<f64> {
27        &self.matrix
28    }
29
30    pub fn nspp(&self) -> usize {
31        self.matrix.nrows()
32    }
33
34    pub fn nsub(&self) -> usize {
35        self.matrix.ncols()
36    }
37
38    /// Modify the [Psi::matrix] to only include the columns specified by `indices`
39    pub(crate) fn filter_column_indices(&mut self, indices: &[usize]) {
40        let matrix = self.matrix.to_owned();
41
42        let new = Mat::from_fn(matrix.nrows(), indices.len(), |r, c| {
43            *matrix.get(r, indices[c])
44        });
45
46        self.matrix = new;
47    }
48
49    /// Write the matrix to a CSV file
50    pub fn write(&self, path: &str) {
51        let mut writer = csv::Writer::from_path(path).unwrap();
52        for row in self.matrix.row_iter() {
53            writer
54                .write_record(row.iter().map(|x| x.to_string()))
55                .unwrap();
56        }
57    }
58
59    /// Write the psi matrix to a CSV writer
60    /// Each row represents a subject, each column represents a support point
61    pub fn to_csv<W: std::io::Write>(&self, writer: W) -> Result<()> {
62        let mut csv_writer = csv::Writer::from_writer(writer);
63
64        // Write each row
65        for i in 0..self.matrix.nrows() {
66            let row: Vec<f64> = (0..self.matrix.ncols())
67                .map(|j| *self.matrix.get(i, j))
68                .collect();
69            csv_writer.serialize(row)?;
70        }
71
72        csv_writer.flush()?;
73        Ok(())
74    }
75
76    /// Read psi matrix from a CSV reader
77    /// Each row represents a subject, each column represents a support point
78    pub fn from_csv<R: std::io::Read>(reader: R) -> Result<Self> {
79        let mut csv_reader = csv::Reader::from_reader(reader);
80        let mut rows: Vec<Vec<f64>> = Vec::new();
81
82        for result in csv_reader.deserialize() {
83            let row: Vec<f64> = result?;
84            rows.push(row);
85        }
86
87        if rows.is_empty() {
88            bail!("CSV file is empty");
89        }
90
91        let nrows = rows.len();
92        let ncols = rows[0].len();
93
94        // Verify all rows have the same length
95        for (i, row) in rows.iter().enumerate() {
96            if row.len() != ncols {
97                bail!("Row {} has {} columns, expected {}", i, row.len(), ncols);
98            }
99        }
100
101        // Create matrix from rows
102        let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);
103
104        Ok(Psi { matrix: mat })
105    }
106}
107
108impl Default for Psi {
109    fn default() -> Self {
110        Psi::new()
111    }
112}
113
114impl From<Array2<f64>> for Psi {
115    fn from(array: Array2<f64>) -> Self {
116        let matrix = array.view().into_faer().to_owned();
117        Psi { matrix }
118    }
119}
120
121impl From<Mat<f64>> for Psi {
122    fn from(matrix: Mat<f64>) -> Self {
123        Psi { matrix }
124    }
125}
126
127impl From<ArrayView2<'_, f64>> for Psi {
128    fn from(array_view: ArrayView2<'_, f64>) -> Self {
129        let matrix = array_view.into_faer().to_owned();
130        Psi { matrix }
131    }
132}
133
134impl From<&Array2<f64>> for Psi {
135    fn from(array: &Array2<f64>) -> Self {
136        let matrix = array.view().into_faer().to_owned();
137        Psi { matrix }
138    }
139}
140
141impl Serialize for Psi {
142    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
143    where
144        S: serde::Serializer,
145    {
146        use serde::ser::SerializeSeq;
147
148        let mut seq = serializer.serialize_seq(Some(self.matrix.nrows()))?;
149
150        // Serialize each row as a vector
151        for i in 0..self.matrix.nrows() {
152            let row: Vec<f64> = (0..self.matrix.ncols())
153                .map(|j| *self.matrix.get(i, j))
154                .collect();
155            seq.serialize_element(&row)?;
156        }
157
158        seq.end()
159    }
160}
161
162impl<'de> Deserialize<'de> for Psi {
163    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
164    where
165        D: serde::Deserializer<'de>,
166    {
167        use serde::de::{SeqAccess, Visitor};
168        use std::fmt;
169
170        struct PsiVisitor;
171
172        impl<'de> Visitor<'de> for PsiVisitor {
173            type Value = Psi;
174
175            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
176                formatter.write_str("a sequence of rows (vectors of f64)")
177            }
178
179            fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
180            where
181                A: SeqAccess<'de>,
182            {
183                let mut rows: Vec<Vec<f64>> = Vec::new();
184
185                while let Some(row) = seq.next_element::<Vec<f64>>()? {
186                    rows.push(row);
187                }
188
189                if rows.is_empty() {
190                    return Err(serde::de::Error::custom("Empty matrix not allowed"));
191                }
192
193                let nrows = rows.len();
194                let ncols = rows[0].len();
195
196                // Verify all rows have the same length
197                for (i, row) in rows.iter().enumerate() {
198                    if row.len() != ncols {
199                        return Err(serde::de::Error::custom(format!(
200                            "Row {} has {} columns, expected {}",
201                            i,
202                            row.len(),
203                            ncols
204                        )));
205                    }
206                }
207
208                // Create matrix from rows
209                let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);
210
211                Ok(Psi { matrix: mat })
212            }
213        }
214
215        deserializer.deserialize_seq(PsiVisitor)
216    }
217}
218
219pub(crate) fn calculate_psi(
220    equation: &impl Equation,
221    subjects: &Data,
222    theta: &Theta,
223    error_models: &ErrorModels,
224    progress: bool,
225    cache: bool,
226) -> Result<Psi> {
227    let psi_ndarray = psi(
228        equation,
229        subjects,
230        &theta.matrix().clone().as_ref().into_ndarray().to_owned(),
231        error_models,
232        progress,
233        cache,
234    )?;
235
236    Ok(psi_ndarray.view().into())
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Check, element by element, that `psi` stores exactly the values of `expected`.
    fn assert_values_match(psi: &Psi, expected: &Array2<f64>) {
        let stored = psi.matrix().as_ref().into_ndarray();
        for ((row, col), &want) in expected.indexed_iter() {
            assert_eq!(stored[[row, col]], want);
        }
    }

    #[test]
    fn test_from_array2() {
        // 2 rows x 3 columns, converted by value.
        let source = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();

        let psi = Psi::from(source.clone());

        assert_eq!(psi.nspp(), 2);
        assert_eq!(psi.nsub(), 3);
        assert_values_match(&psi, &source);
    }

    #[test]
    fn test_from_array2_ref() {
        // 3 rows x 2 columns, converted by reference.
        let source =
            Array2::from_shape_vec((3, 2), vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0]).unwrap();

        let psi = Psi::from(&source);

        assert_eq!(psi.nspp(), 3);
        assert_eq!(psi.nsub(), 2);
        assert_values_match(&psi, &source);
    }

    #[test]
    fn test_nspp() {
        let source =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        assert_eq!(Psi::from(source).nspp(), 4);
    }

    #[test]
    fn test_nspp_empty() {
        assert_eq!(Psi::new().nspp(), 0);
    }

    #[test]
    fn test_nspp_single_row() {
        let source = Array2::from_shape_vec((1, 3), vec![1.0, 2.0, 3.0]).unwrap();
        assert_eq!(Psi::from(source).nspp(), 1);
    }

    #[test]
    fn test_nsub() {
        let source = Array2::from_shape_vec(
            (2, 5),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
        )
        .unwrap();
        assert_eq!(Psi::from(source).nsub(), 5);
    }

    #[test]
    fn test_nsub_empty() {
        assert_eq!(Psi::new().nsub(), 0);
    }

    #[test]
    fn test_nsub_single_column() {
        let source = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        assert_eq!(Psi::from(source).nsub(), 1);
    }

    #[test]
    fn test_from_implementations_consistency() {
        // Owned and borrowed conversions must agree on shape and contents.
        let source = Array2::from_shape_vec((2, 3), vec![1.5, 2.5, 3.5, 4.5, 5.5, 6.5]).unwrap();

        let by_value = Psi::from(source.clone());
        let by_ref = Psi::from(&source);

        assert_eq!(by_value.nspp(), by_ref.nspp());
        assert_eq!(by_value.nsub(), by_ref.nsub());

        let lhs = by_value.matrix().as_ref().into_ndarray();
        let rhs = by_ref.matrix().as_ref().into_ndarray();
        for ((row, col), &value) in lhs.indexed_iter() {
            assert_eq!(value, rhs[[row, col]]);
        }
    }
}