1use anyhow::Result;
2use faer::Mat;
3use faer_ext::IntoFaer;
4use faer_ext::IntoNdarray;
5use ndarray::{Array2, ArrayView2};
6use pharmsol::prelude::simulator::psi;
7use pharmsol::Data;
8use pharmsol::Equation;
9use pharmsol::ErrorModels;
10
11use super::theta::Theta;
12
13#[derive(Debug, Clone, PartialEq)]
15pub struct Psi {
16 matrix: Mat<f64>,
17}
18
19impl Psi {
20 pub fn new() -> Self {
21 Psi { matrix: Mat::new() }
22 }
23
24 pub fn matrix(&self) -> &Mat<f64> {
25 &self.matrix
26 }
27
28 pub fn nspp(&self) -> usize {
29 self.matrix.nrows()
30 }
31
32 pub fn nsub(&self) -> usize {
33 self.matrix.ncols()
34 }
35
36 pub(crate) fn filter_column_indices(&mut self, indices: &[usize]) {
38 let matrix = self.matrix.to_owned();
39
40 let new = Mat::from_fn(matrix.nrows(), indices.len(), |r, c| {
41 *matrix.get(r, indices[c])
42 });
43
44 self.matrix = new;
45 }
46
47 pub fn write(&self, path: &str) {
49 let mut writer = csv::Writer::from_path(path).unwrap();
50 for row in self.matrix.row_iter() {
51 writer
52 .write_record(row.iter().map(|x| x.to_string()))
53 .unwrap();
54 }
55 }
56}
57
58impl Default for Psi {
59 fn default() -> Self {
60 Psi::new()
61 }
62}
63
64impl From<Array2<f64>> for Psi {
65 fn from(array: Array2<f64>) -> Self {
66 let matrix = array.view().into_faer().to_owned();
67 Psi { matrix }
68 }
69}
70
71impl From<Mat<f64>> for Psi {
72 fn from(matrix: Mat<f64>) -> Self {
73 Psi { matrix }
74 }
75}
76
77impl From<ArrayView2<'_, f64>> for Psi {
78 fn from(array_view: ArrayView2<'_, f64>) -> Self {
79 let matrix = array_view.into_faer().to_owned();
80 Psi { matrix }
81 }
82}
83
84impl From<&Array2<f64>> for Psi {
85 fn from(array: &Array2<f64>) -> Self {
86 let matrix = array.view().into_faer().to_owned();
87 Psi { matrix }
88 }
89}
90
91pub(crate) fn calculate_psi(
92 equation: &impl Equation,
93 subjects: &Data,
94 theta: &Theta,
95 error_models: &ErrorModels,
96 progress: bool,
97 cache: bool,
98) -> Result<Psi> {
99 let psi_ndarray = psi(
100 equation,
101 subjects,
102 &theta.matrix().clone().as_ref().into_ndarray().to_owned(),
103 error_models,
104 progress,
105 cache,
106 )?;
107
108 Ok(psi_ndarray.view().into())
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Asserts that `psi` has the same shape and element values as `expected`.
    fn assert_same_values(psi: &Psi, expected: &Array2<f64>) {
        let (rows, cols) = expected.dim();
        assert_eq!(psi.nspp(), rows);
        assert_eq!(psi.nsub(), cols);

        let actual = psi.matrix().as_ref().into_ndarray();
        for ((i, j), &value) in expected.indexed_iter() {
            assert_eq!(actual[[i, j]], value);
        }
    }

    #[test]
    fn test_from_array2() {
        let array = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();

        let psi = Psi::from(array.clone());

        assert_eq!(psi.nspp(), 2);
        assert_eq!(psi.nsub(), 3);
        assert_same_values(&psi, &array);
    }

    #[test]
    fn test_from_array2_ref() {
        let array =
            Array2::from_shape_vec((3, 2), vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0]).unwrap();

        let psi = Psi::from(&array);

        assert_eq!(psi.nspp(), 3);
        assert_eq!(psi.nsub(), 2);
        assert_same_values(&psi, &array);
    }

    #[test]
    fn test_from_array_view() {
        let array = Array2::from_shape_vec((2, 2), vec![0.5, 1.5, 2.5, 3.5]).unwrap();

        let psi = Psi::from(array.view());

        assert_same_values(&psi, &array);
    }

    #[test]
    fn test_from_mat_round_trip() {
        let array = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let original = Psi::from(&array);

        let rebuilt = Psi::from(original.matrix().clone());

        assert_eq!(rebuilt, original);
    }

    #[test]
    fn test_default_is_empty() {
        let psi = Psi::default();

        assert_eq!(psi, Psi::new());
        assert_eq!(psi.nspp(), 0);
        assert_eq!(psi.nsub(), 0);
    }

    #[test]
    fn test_nspp() {
        let array =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let psi = Psi::from(array);

        assert_eq!(psi.nspp(), 4);
    }

    #[test]
    fn test_nspp_empty() {
        let psi = Psi::new();
        assert_eq!(psi.nspp(), 0);
    }

    #[test]
    fn test_nspp_single_row() {
        let array = Array2::from_shape_vec((1, 3), vec![1.0, 2.0, 3.0]).unwrap();
        let psi = Psi::from(array);

        assert_eq!(psi.nspp(), 1);
    }

    #[test]
    fn test_nsub() {
        let array = Array2::from_shape_vec(
            (2, 5),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
        )
        .unwrap();
        let psi = Psi::from(array);

        assert_eq!(psi.nsub(), 5);
    }

    #[test]
    fn test_nsub_empty() {
        let psi = Psi::new();
        assert_eq!(psi.nsub(), 0);
    }

    #[test]
    fn test_nsub_single_column() {
        let array = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let psi = Psi::from(array);

        assert_eq!(psi.nsub(), 1);
    }

    #[test]
    fn test_from_implementations_consistency() {
        let array = Array2::from_shape_vec((2, 3), vec![1.5, 2.5, 3.5, 4.5, 5.5, 6.5]).unwrap();

        let psi_from_owned = Psi::from(array.clone());
        let psi_from_ref = Psi::from(&array);

        assert_eq!(psi_from_owned.nspp(), psi_from_ref.nspp());
        assert_eq!(psi_from_owned.nsub(), psi_from_ref.nsub());
        assert_eq!(psi_from_owned, psi_from_ref);
        assert_same_values(&psi_from_owned, &array);
        assert_same_values(&psi_from_ref, &array);
    }

    #[test]
    fn test_filter_column_indices_selects_and_reorders() {
        let array = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let mut psi = Psi::from(&array);

        psi.filter_column_indices(&[2, 0]);

        // Column 0 of the result is old column 2, column 1 is old column 0; rows are kept.
        let expected = Array2::from_shape_vec((2, 2), vec![3.0, 1.0, 6.0, 4.0]).unwrap();
        assert_same_values(&psi, &expected);
    }

    #[test]
    fn test_filter_column_indices_empty_selection() {
        let array = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let mut psi = Psi::from(&array);

        psi.filter_column_indices(&[]);

        assert_eq!(psi.nspp(), 2);
        assert_eq!(psi.nsub(), 0);
    }
}