pmcore/structs/
psi.rs

1use anyhow::Result;
2use faer::Mat;
3use faer_ext::IntoFaer;
4use faer_ext::IntoNdarray;
5use ndarray::{Array2, ArrayView2};
6use pharmsol::prelude::simulator::psi;
7use pharmsol::Data;
8use pharmsol::Equation;
9use pharmsol::ErrorModels;
10
11use super::theta::Theta;
12
/// [Psi] is a structure that holds the likelihood for each subject (row), for each support point (column)
///
/// The values are stored as `f64` in a [`Mat`]; use [Psi::matrix] for read-only access to it.
#[derive(Debug, Clone, PartialEq)]
pub struct Psi {
    /// Backing matrix of likelihood values
    matrix: Mat<f64>,
}
18
19impl Psi {
20    pub fn new() -> Self {
21        Psi { matrix: Mat::new() }
22    }
23
24    pub fn matrix(&self) -> &Mat<f64> {
25        &self.matrix
26    }
27
28    pub fn nspp(&self) -> usize {
29        self.matrix.nrows()
30    }
31
32    pub fn nsub(&self) -> usize {
33        self.matrix.ncols()
34    }
35
36    /// Modify the [Psi::matrix] to only include the columns specified by `indices`
37    pub(crate) fn filter_column_indices(&mut self, indices: &[usize]) {
38        let matrix = self.matrix.to_owned();
39
40        let new = Mat::from_fn(matrix.nrows(), indices.len(), |r, c| {
41            *matrix.get(r, indices[c])
42        });
43
44        self.matrix = new;
45    }
46
47    /// Write the matrix to a CSV file
48    pub fn write(&self, path: &str) {
49        let mut writer = csv::Writer::from_path(path).unwrap();
50        for row in self.matrix.row_iter() {
51            writer
52                .write_record(row.iter().map(|x| x.to_string()))
53                .unwrap();
54        }
55    }
56}
57
58impl Default for Psi {
59    fn default() -> Self {
60        Psi::new()
61    }
62}
63
64impl From<Array2<f64>> for Psi {
65    fn from(array: Array2<f64>) -> Self {
66        let matrix = array.view().into_faer().to_owned();
67        Psi { matrix }
68    }
69}
70
71impl From<Mat<f64>> for Psi {
72    fn from(matrix: Mat<f64>) -> Self {
73        Psi { matrix }
74    }
75}
76
77impl From<ArrayView2<'_, f64>> for Psi {
78    fn from(array_view: ArrayView2<'_, f64>) -> Self {
79        let matrix = array_view.into_faer().to_owned();
80        Psi { matrix }
81    }
82}
83
84impl From<&Array2<f64>> for Psi {
85    fn from(array: &Array2<f64>) -> Self {
86        let matrix = array.view().into_faer().to_owned();
87        Psi { matrix }
88    }
89}
90
91pub(crate) fn calculate_psi(
92    equation: &impl Equation,
93    subjects: &Data,
94    theta: &Theta,
95    error_models: &ErrorModels,
96    progress: bool,
97    cache: bool,
98) -> Result<Psi> {
99    let psi_ndarray = psi(
100        equation,
101        subjects,
102        &theta.matrix().clone().as_ref().into_ndarray().to_owned(),
103        error_models,
104        progress,
105        cache,
106    )?;
107
108    Ok(psi_ndarray.view().into())
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Build a `rows x cols` array from `values`
    fn array_of(rows: usize, cols: usize, values: &[f64]) -> Array2<f64> {
        Array2::from_shape_vec((rows, cols), values.to_vec()).unwrap()
    }

    /// Assert that `psi` holds the same value as `expected` at every position of `expected`
    fn assert_same_entries(psi: &Psi, expected: &Array2<f64>) {
        let actual = psi.matrix().as_ref().into_ndarray();
        for ((row, col), value) in expected.indexed_iter() {
            assert_eq!(actual[[row, col]], *value);
        }
    }

    #[test]
    fn test_from_array2() {
        // Conversion from an owned 2x3 array
        let array = array_of(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let psi = Psi::from(array.clone());

        // Dimensions are carried over
        assert_eq!(psi.nspp(), 2);
        assert_eq!(psi.nsub(), 3);

        // Values survive the round trip back to ndarray
        assert_same_entries(&psi, &array);
    }

    #[test]
    fn test_from_array2_ref() {
        // Conversion from a borrowed 3x2 array
        let array = array_of(3, 2, &[10.0, 20.0, 30.0, 40.0, 50.0, 60.0]);
        let psi = Psi::from(&array);

        // Dimensions are carried over
        assert_eq!(psi.nspp(), 3);
        assert_eq!(psi.nsub(), 2);

        // Values survive the round trip back to ndarray
        assert_same_entries(&psi, &array);
    }

    #[test]
    fn test_nspp() {
        // 4x2 matrix
        let psi = Psi::from(array_of(4, 2, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]));

        assert_eq!(psi.nspp(), 4);
    }

    #[test]
    fn test_nspp_empty() {
        // A freshly created Psi has no rows
        assert_eq!(Psi::new().nspp(), 0);
    }

    #[test]
    fn test_nspp_single_row() {
        // 1x3 matrix
        let psi = Psi::from(array_of(1, 3, &[1.0, 2.0, 3.0]));

        assert_eq!(psi.nspp(), 1);
    }

    #[test]
    fn test_nsub() {
        // 2x5 matrix
        let psi = Psi::from(array_of(
            2,
            5,
            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
        ));

        assert_eq!(psi.nsub(), 5);
    }

    #[test]
    fn test_nsub_empty() {
        // A freshly created Psi has no columns
        assert_eq!(Psi::new().nsub(), 0);
    }

    #[test]
    fn test_nsub_single_column() {
        // 3x1 matrix
        let psi = Psi::from(array_of(3, 1, &[1.0, 2.0, 3.0]));

        assert_eq!(psi.nsub(), 1);
    }

    #[test]
    fn test_from_implementations_consistency() {
        // The owned and the borrowed conversion must agree with each other
        let array = array_of(2, 3, &[1.5, 2.5, 3.5, 4.5, 5.5, 6.5]);

        let psi_from_owned = Psi::from(array.clone());
        let psi_from_ref = Psi::from(&array);

        // Same dimensions
        assert_eq!(psi_from_owned.nspp(), psi_from_ref.nspp());
        assert_eq!(psi_from_owned.nsub(), psi_from_ref.nsub());

        // Same values, position by position
        let owned_array = psi_from_owned.matrix().as_ref().into_ndarray();
        let ref_array = psi_from_ref.matrix().as_ref().into_ndarray();

        for ((row, col), value) in owned_array.indexed_iter() {
            assert_eq!(*value, ref_array[[row, col]]);
        }
    }
}