1use anyhow::bail;
2use anyhow::Result;
3use faer::Mat;
4use ndarray::Array2;
5use pharmsol::prelude::simulator::log_likelihood_matrix;
6use pharmsol::AssayErrorModels;
7use pharmsol::Data;
8use pharmsol::Equation;
9use serde::{Deserialize, Serialize};
10
11use super::theta::Theta;
12
/// Dense `f64` matrix wrapper.
///
/// `calculate_psi` in this module fills it with exponentiated log-likelihoods
/// returned by `pharmsol`'s `log_likelihood_matrix`. The row count is reported
/// as `nspp` and the column count as `nsub`.
// NOTE(review): the subject/support-point orientation of the rows and columns
// is defined by `log_likelihood_matrix`, which is not visible here — confirm it
// agrees with the `nspp`/`nsub` accessor names.
#[derive(Debug, Clone, PartialEq)]
pub struct Psi {
    // Backing storage; always an owned faer matrix.
    matrix: Mat<f64>,
}
18
19impl Psi {
20 pub fn new() -> Self {
21 Psi { matrix: Mat::new() }
22 }
23
24 pub fn matrix(&self) -> &Mat<f64> {
25 &self.matrix
26 }
27
28 pub fn nspp(&self) -> usize {
29 self.matrix.nrows()
30 }
31
32 pub fn nsub(&self) -> usize {
33 self.matrix.ncols()
34 }
35
36 pub(crate) fn filter_column_indices(&mut self, indices: &[usize]) {
38 let matrix = self.matrix.to_owned();
39
40 let new = Mat::from_fn(matrix.nrows(), indices.len(), |r, c| {
41 *matrix.get(r, indices[c])
42 });
43
44 self.matrix = new;
45 }
46
47 pub fn write(&self, path: &str) {
49 let mut writer = csv::Writer::from_path(path).unwrap();
50 for row in self.matrix.row_iter() {
51 writer
52 .write_record(row.iter().map(|x| x.to_string()))
53 .unwrap();
54 }
55 }
56
57 pub fn to_csv<W: std::io::Write>(&self, writer: W) -> Result<()> {
60 let mut csv_writer = csv::Writer::from_writer(writer);
61
62 for i in 0..self.matrix.nrows() {
64 let row: Vec<f64> = (0..self.matrix.ncols())
65 .map(|j| *self.matrix.get(i, j))
66 .collect();
67 csv_writer.serialize(row)?;
68 }
69
70 csv_writer.flush()?;
71 Ok(())
72 }
73
74 pub fn from_csv<R: std::io::Read>(reader: R) -> Result<Self> {
77 let mut csv_reader = csv::Reader::from_reader(reader);
78 let mut rows: Vec<Vec<f64>> = Vec::new();
79
80 for result in csv_reader.deserialize() {
81 let row: Vec<f64> = result?;
82 rows.push(row);
83 }
84
85 if rows.is_empty() {
86 bail!("CSV file is empty");
87 }
88
89 let nrows = rows.len();
90 let ncols = rows[0].len();
91
92 for (i, row) in rows.iter().enumerate() {
94 if row.len() != ncols {
95 bail!("Row {} has {} columns, expected {}", i, row.len(), ncols);
96 }
97 }
98
99 let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);
101
102 Ok(Psi { matrix: mat })
103 }
104}
105
106impl Default for Psi {
107 fn default() -> Self {
108 Psi::new()
109 }
110}
111
112impl From<Array2<f64>> for Psi {
113 fn from(array: Array2<f64>) -> Self {
114 let matrix = Mat::from_fn(array.nrows(), array.ncols(), |i, j| array[(i, j)]);
115 Psi { matrix }
116 }
117}
118
119impl From<Mat<f64>> for Psi {
120 fn from(matrix: Mat<f64>) -> Self {
121 Psi { matrix }
122 }
123}
124
125impl From<&Array2<f64>> for Psi {
126 fn from(array: &Array2<f64>) -> Self {
127 let matrix = Mat::from_fn(array.nrows(), array.ncols(), |i, j| array[(i, j)]);
128 Psi { matrix }
129 }
130}
131
132impl Serialize for Psi {
133 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
134 where
135 S: serde::Serializer,
136 {
137 use serde::ser::SerializeSeq;
138
139 let mut seq = serializer.serialize_seq(Some(self.matrix.nrows()))?;
140
141 for i in 0..self.matrix.nrows() {
143 let row: Vec<f64> = (0..self.matrix.ncols())
144 .map(|j| *self.matrix.get(i, j))
145 .collect();
146 seq.serialize_element(&row)?;
147 }
148
149 seq.end()
150 }
151}
152
153impl<'de> Deserialize<'de> for Psi {
154 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
155 where
156 D: serde::Deserializer<'de>,
157 {
158 use serde::de::{SeqAccess, Visitor};
159 use std::fmt;
160
161 struct PsiVisitor;
162
163 impl<'de> Visitor<'de> for PsiVisitor {
164 type Value = Psi;
165
166 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
167 formatter.write_str("a sequence of rows (vectors of f64)")
168 }
169
170 fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
171 where
172 A: SeqAccess<'de>,
173 {
174 let mut rows: Vec<Vec<f64>> = Vec::new();
175
176 while let Some(row) = seq.next_element::<Vec<f64>>()? {
177 rows.push(row);
178 }
179
180 if rows.is_empty() {
181 return Err(serde::de::Error::custom("Empty matrix not allowed"));
182 }
183
184 let nrows = rows.len();
185 let ncols = rows[0].len();
186
187 for (i, row) in rows.iter().enumerate() {
189 if row.len() != ncols {
190 return Err(serde::de::Error::custom(format!(
191 "Row {} has {} columns, expected {}",
192 i,
193 row.len(),
194 ncols
195 )));
196 }
197 }
198
199 let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);
201
202 Ok(Psi { matrix: mat })
203 }
204 }
205
206 deserializer.deserialize_seq(PsiVisitor)
207 }
208}
209
210pub(crate) fn calculate_psi(
211 equation: &impl Equation,
212 subjects: &Data,
213 theta: &Theta,
214 error_models: &AssayErrorModels,
215 progress: bool,
216) -> Result<Psi> {
217 let tm = theta.matrix();
218 let theta_ndarray = Array2::from_shape_fn((tm.nrows(), tm.ncols()), |(i, j)| tm[(i, j)]);
219 let log_psi =
220 log_likelihood_matrix(equation, subjects, &theta_ndarray, error_models, progress)?;
221 let psi_ndarray = log_psi.mapv(f64::exp);
222
223 Ok(Psi::from(psi_ndarray))
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Asserts that every element of `psi` equals the element at the same
    /// position in `expected`.
    fn assert_elements_match(psi: &Psi, expected: &Array2<f64>) {
        let m = psi.matrix();
        for ((r, c), &want) in expected.indexed_iter() {
            assert_eq!(m[(r, c)], want);
        }
    }

    #[test]
    fn test_from_array2() {
        let source = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();

        let psi = Psi::from(source.clone());

        assert_eq!(psi.nspp(), 2);
        assert_eq!(psi.nsub(), 3);
        assert_elements_match(&psi, &source);
    }

    #[test]
    fn test_from_array2_ref() {
        let source =
            Array2::from_shape_vec((3, 2), vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0]).unwrap();

        let psi = Psi::from(&source);

        assert_eq!(psi.nspp(), 3);
        assert_eq!(psi.nsub(), 2);
        assert_elements_match(&psi, &source);
    }

    #[test]
    fn test_nspp() {
        let values: Vec<f64> = (1..=8).map(f64::from).collect();
        let psi = Psi::from(Array2::from_shape_vec((4, 2), values).unwrap());

        assert_eq!(psi.nspp(), 4);
    }

    #[test]
    fn test_nspp_empty() {
        assert_eq!(Psi::new().nspp(), 0);
    }

    #[test]
    fn test_nspp_single_row() {
        let psi = Psi::from(Array2::from_shape_vec((1, 3), vec![1.0, 2.0, 3.0]).unwrap());

        assert_eq!(psi.nspp(), 1);
    }

    #[test]
    fn test_nsub() {
        let values: Vec<f64> = (1..=10).map(f64::from).collect();
        let psi = Psi::from(Array2::from_shape_vec((2, 5), values).unwrap());

        assert_eq!(psi.nsub(), 5);
    }

    #[test]
    fn test_nsub_empty() {
        assert_eq!(Psi::new().nsub(), 0);
    }

    #[test]
    fn test_nsub_single_column() {
        let psi = Psi::from(Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap());

        assert_eq!(psi.nsub(), 1);
    }

    #[test]
    fn test_from_implementations_consistency() {
        let source = Array2::from_shape_vec((2, 3), vec![1.5, 2.5, 3.5, 4.5, 5.5, 6.5]).unwrap();

        let owned = Psi::from(source.clone());
        let borrowed = Psi::from(&source);

        assert_eq!(owned.nspp(), borrowed.nspp());
        assert_eq!(owned.nsub(), borrowed.nsub());

        let (lhs, rhs) = (owned.matrix(), borrowed.matrix());
        for r in 0..2 {
            for c in 0..3 {
                assert_eq!(lhs[(r, c)], rhs[(r, c)]);
            }
        }
    }
}