1use anyhow::bail;
2use anyhow::Result;
3use faer::Mat;
4use faer_ext::IntoFaer;
5use faer_ext::IntoNdarray;
6use ndarray::{Array2, ArrayView2};
7use pharmsol::prelude::simulator::psi;
8use pharmsol::Data;
9use pharmsol::Equation;
10use pharmsol::ErrorModels;
11use serde::{Deserialize, Serialize};
12
13use super::theta::Theta;
14
15#[derive(Debug, Clone, PartialEq)]
17pub struct Psi {
18 matrix: Mat<f64>,
19}
20
21impl Psi {
22 pub fn new() -> Self {
23 Psi { matrix: Mat::new() }
24 }
25
26 pub fn matrix(&self) -> &Mat<f64> {
27 &self.matrix
28 }
29
30 pub fn nspp(&self) -> usize {
31 self.matrix.nrows()
32 }
33
34 pub fn nsub(&self) -> usize {
35 self.matrix.ncols()
36 }
37
38 pub(crate) fn filter_column_indices(&mut self, indices: &[usize]) {
40 let matrix = self.matrix.to_owned();
41
42 let new = Mat::from_fn(matrix.nrows(), indices.len(), |r, c| {
43 *matrix.get(r, indices[c])
44 });
45
46 self.matrix = new;
47 }
48
49 pub fn write(&self, path: &str) {
51 let mut writer = csv::Writer::from_path(path).unwrap();
52 for row in self.matrix.row_iter() {
53 writer
54 .write_record(row.iter().map(|x| x.to_string()))
55 .unwrap();
56 }
57 }
58
59 pub fn to_csv<W: std::io::Write>(&self, writer: W) -> Result<()> {
62 let mut csv_writer = csv::Writer::from_writer(writer);
63
64 for i in 0..self.matrix.nrows() {
66 let row: Vec<f64> = (0..self.matrix.ncols())
67 .map(|j| *self.matrix.get(i, j))
68 .collect();
69 csv_writer.serialize(row)?;
70 }
71
72 csv_writer.flush()?;
73 Ok(())
74 }
75
76 pub fn from_csv<R: std::io::Read>(reader: R) -> Result<Self> {
79 let mut csv_reader = csv::Reader::from_reader(reader);
80 let mut rows: Vec<Vec<f64>> = Vec::new();
81
82 for result in csv_reader.deserialize() {
83 let row: Vec<f64> = result?;
84 rows.push(row);
85 }
86
87 if rows.is_empty() {
88 bail!("CSV file is empty");
89 }
90
91 let nrows = rows.len();
92 let ncols = rows[0].len();
93
94 for (i, row) in rows.iter().enumerate() {
96 if row.len() != ncols {
97 bail!("Row {} has {} columns, expected {}", i, row.len(), ncols);
98 }
99 }
100
101 let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);
103
104 Ok(Psi { matrix: mat })
105 }
106}
107
108impl Default for Psi {
109 fn default() -> Self {
110 Psi::new()
111 }
112}
113
114impl From<Array2<f64>> for Psi {
115 fn from(array: Array2<f64>) -> Self {
116 let matrix = array.view().into_faer().to_owned();
117 Psi { matrix }
118 }
119}
120
121impl From<Mat<f64>> for Psi {
122 fn from(matrix: Mat<f64>) -> Self {
123 Psi { matrix }
124 }
125}
126
127impl From<ArrayView2<'_, f64>> for Psi {
128 fn from(array_view: ArrayView2<'_, f64>) -> Self {
129 let matrix = array_view.into_faer().to_owned();
130 Psi { matrix }
131 }
132}
133
134impl From<&Array2<f64>> for Psi {
135 fn from(array: &Array2<f64>) -> Self {
136 let matrix = array.view().into_faer().to_owned();
137 Psi { matrix }
138 }
139}
140
141impl Serialize for Psi {
142 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
143 where
144 S: serde::Serializer,
145 {
146 use serde::ser::SerializeSeq;
147
148 let mut seq = serializer.serialize_seq(Some(self.matrix.nrows()))?;
149
150 for i in 0..self.matrix.nrows() {
152 let row: Vec<f64> = (0..self.matrix.ncols())
153 .map(|j| *self.matrix.get(i, j))
154 .collect();
155 seq.serialize_element(&row)?;
156 }
157
158 seq.end()
159 }
160}
161
162impl<'de> Deserialize<'de> for Psi {
163 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
164 where
165 D: serde::Deserializer<'de>,
166 {
167 use serde::de::{SeqAccess, Visitor};
168 use std::fmt;
169
170 struct PsiVisitor;
171
172 impl<'de> Visitor<'de> for PsiVisitor {
173 type Value = Psi;
174
175 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
176 formatter.write_str("a sequence of rows (vectors of f64)")
177 }
178
179 fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
180 where
181 A: SeqAccess<'de>,
182 {
183 let mut rows: Vec<Vec<f64>> = Vec::new();
184
185 while let Some(row) = seq.next_element::<Vec<f64>>()? {
186 rows.push(row);
187 }
188
189 if rows.is_empty() {
190 return Err(serde::de::Error::custom("Empty matrix not allowed"));
191 }
192
193 let nrows = rows.len();
194 let ncols = rows[0].len();
195
196 for (i, row) in rows.iter().enumerate() {
198 if row.len() != ncols {
199 return Err(serde::de::Error::custom(format!(
200 "Row {} has {} columns, expected {}",
201 i,
202 row.len(),
203 ncols
204 )));
205 }
206 }
207
208 let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);
210
211 Ok(Psi { matrix: mat })
212 }
213 }
214
215 deserializer.deserialize_seq(PsiVisitor)
216 }
217}
218
219pub(crate) fn calculate_psi(
220 equation: &impl Equation,
221 subjects: &Data,
222 theta: &Theta,
223 error_models: &ErrorModels,
224 progress: bool,
225 cache: bool,
226) -> Result<Psi> {
227 let psi_ndarray = psi(
228 equation,
229 subjects,
230 &theta.matrix().clone().as_ref().into_ndarray().to_owned(),
231 error_models,
232 progress,
233 cache,
234 )?;
235
236 Ok(psi_ndarray.view().into())
237}
238
239#[cfg(test)]
240mod tests {
241 use super::*;
242 use ndarray::Array2;
243
244 #[test]
245 fn test_from_array2() {
246 let array = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
248
249 let psi = Psi::from(array.clone());
250
251 assert_eq!(psi.nspp(), 2);
253 assert_eq!(psi.nsub(), 3);
254
255 let result_array = psi.matrix().as_ref().into_ndarray();
257 for i in 0..2 {
258 for j in 0..3 {
259 assert_eq!(result_array[[i, j]], array[[i, j]]);
260 }
261 }
262 }
263
264 #[test]
265 fn test_from_array2_ref() {
266 let array =
268 Array2::from_shape_vec((3, 2), vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0]).unwrap();
269
270 let psi = Psi::from(&array);
271
272 assert_eq!(psi.nspp(), 3);
274 assert_eq!(psi.nsub(), 2);
275
276 let result_array = psi.matrix().as_ref().into_ndarray();
278 for i in 0..3 {
279 for j in 0..2 {
280 assert_eq!(result_array[[i, j]], array[[i, j]]);
281 }
282 }
283 }
284
285 #[test]
286 fn test_nspp() {
287 let array =
289 Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
290 let psi = Psi::from(array);
291
292 assert_eq!(psi.nspp(), 4);
293 }
294
295 #[test]
296 fn test_nspp_empty() {
297 let psi = Psi::new();
299 assert_eq!(psi.nspp(), 0);
300 }
301
302 #[test]
303 fn test_nspp_single_row() {
304 let array = Array2::from_shape_vec((1, 3), vec![1.0, 2.0, 3.0]).unwrap();
306 let psi = Psi::from(array);
307
308 assert_eq!(psi.nspp(), 1);
309 }
310
311 #[test]
312 fn test_nsub() {
313 let array = Array2::from_shape_vec(
315 (2, 5),
316 vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
317 )
318 .unwrap();
319 let psi = Psi::from(array);
320
321 assert_eq!(psi.nsub(), 5);
322 }
323
324 #[test]
325 fn test_nsub_empty() {
326 let psi = Psi::new();
328 assert_eq!(psi.nsub(), 0);
329 }
330
331 #[test]
332 fn test_nsub_single_column() {
333 let array = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
335 let psi = Psi::from(array);
336
337 assert_eq!(psi.nsub(), 1);
338 }
339
340 #[test]
341 fn test_from_implementations_consistency() {
342 let array = Array2::from_shape_vec((2, 3), vec![1.5, 2.5, 3.5, 4.5, 5.5, 6.5]).unwrap();
344
345 let psi_from_owned = Psi::from(array.clone());
346 let psi_from_ref = Psi::from(&array);
347
348 assert_eq!(psi_from_owned.nspp(), psi_from_ref.nspp());
350 assert_eq!(psi_from_owned.nsub(), psi_from_ref.nsub());
351
352 let owned_array = psi_from_owned.matrix().as_ref().into_ndarray();
354 let ref_array = psi_from_ref.matrix().as_ref().into_ndarray();
355
356 for i in 0..2 {
357 for j in 0..3 {
358 assert_eq!(owned_array[[i, j]], ref_array[[i, j]]);
359 }
360 }
361 }
362}