1use crate::algorithms::Status;
2use crate::prelude::*;
3use crate::routines::output::cycles::CycleLog;
4use crate::routines::output::predictions::NPPredictions;
5use crate::routines::settings::Settings;
6use crate::structs::psi::Psi;
7use crate::structs::theta::Theta;
8use crate::structs::weights::Weights;
9use anyhow::{bail, Context, Result};
10use csv::WriterBuilder;
11use faer::linalg::zip::IntoView;
12use faer_ext::IntoNdarray;
13use ndarray::{Array, Array1, Array2, Axis};
14use pharmsol::prelude::data::*;
15use pharmsol::prelude::simulator::Equation;
16use serde::Serialize;
17use std::fs::{create_dir_all, File, OpenOptions};
18use std::path::{Path, PathBuf};
19
20pub mod cycles;
21pub mod posterior;
22pub mod predictions;
23
24use posterior::posterior;
25
/// Everything produced by a run, bundled together so the output files can be
/// written from a single value.
#[derive(Debug)]
pub struct NPResult<E: Equation> {
    /// Model equation; used to simulate subjects and to calculate predictions.
    equation: E,
    /// The data set; its subjects are iterated by the writers.
    data: Data,
    /// Support points, one row per point (the columns are written under
    /// `par_names` in `theta.csv`).
    theta: Theta,
    /// Matrix combined with `w` to obtain the posterior. The dimension checks
    /// in this module expect one row per subject and one column per support
    /// point.
    psi: Psi,
    /// Weights of the support points; expected to have one entry per row of
    /// `theta`.
    w: Weights,
    /// Objective function value, as passed to [`NPResult::new`].
    objf: f64,
    /// Cycle count, as passed to [`NPResult::new`].
    cycles: usize,
    /// Final status; compared against `Status::Converged` by `converged()`.
    status: Status,
    /// Parameter names, taken from `settings.parameters().names()` on
    /// construction.
    par_names: Vec<String>,
    /// Settings of the run; provide the output folder and prediction options.
    settings: Settings,
    /// Per-cycle log, written out by `write_outputs`.
    cyclelog: CycleLog,
}
42
43#[allow(clippy::too_many_arguments)]
44impl<E: Equation> NPResult<E> {
45 pub fn new(
47 equation: E,
48 data: Data,
49 theta: Theta,
50 psi: Psi,
51 w: Weights,
52 objf: f64,
53 cycles: usize,
54 status: Status,
55 settings: Settings,
56 cyclelog: CycleLog,
57 ) -> Self {
58 let par_names = settings.parameters().names();
61
62 Self {
63 equation,
64 data,
65 theta,
66 psi,
67 w,
68 objf,
69 cycles,
70 status,
71 par_names,
72 settings,
73 cyclelog,
74 }
75 }
76
    /// Returns the stored cycle count.
    pub fn cycles(&self) -> usize {
        self.cycles
    }
80
    /// Returns the stored objective function value.
    pub fn objf(&self) -> f64 {
        self.objf
    }
84
    /// Returns `true` when the stored status equals `Status::Converged`.
    pub fn converged(&self) -> bool {
        self.status == Status::Converged
    }
88
    /// Returns a reference to the stored support points (`theta`).
    pub fn get_theta(&self) -> &Theta {
        &self.theta
    }
92
    /// Returns a reference to the stored `psi` matrix.
    pub fn psi(&self) -> &Psi {
        &self.psi
    }
97
    /// Returns a reference to the stored support point weights.
    pub fn weights(&self) -> &Weights {
        &self.w
    }
102
103 pub fn write_outputs(&self) -> Result<()> {
104 if self.settings.output().write {
105 tracing::debug!("Writing outputs to {:?}", self.settings.output().path);
106 self.settings.write()?;
107 let idelta: f64 = self.settings.predictions().idelta;
108 let tad = self.settings.predictions().tad;
109 self.cyclelog.write(&self.settings)?;
110 self.write_theta().context("Failed to write theta")?;
111 self.write_covs().context("Failed to write covariates")?;
112 self.write_predictions(idelta, tad)
113 .context("Failed to write predictions")?;
114 self.write_posterior()
115 .context("Failed to write posterior")?;
116 }
117 Ok(())
118 }
119
120 pub fn write_obspred(&self) -> Result<()> {
122 tracing::debug!("Writing observations and predictions...");
123
124 #[derive(Debug, Clone, Serialize)]
125 struct Row {
126 id: String,
127 time: f64,
128 outeq: usize,
129 block: usize,
130 obs: Option<f64>,
131 pop_mean: f64,
132 pop_median: f64,
133 post_mean: f64,
134 post_median: f64,
135 }
136
137 let theta: Array2<f64> = self
138 .theta
139 .matrix()
140 .clone()
141 .as_mut()
142 .into_ndarray()
143 .to_owned();
144 let w: Array1<f64> = self
145 .w
146 .weights()
147 .clone()
148 .into_view()
149 .iter()
150 .cloned()
151 .collect();
152 let psi: Array2<f64> = self.psi.matrix().as_ref().into_ndarray().to_owned();
153
154 let (post_mean, post_median) = posterior_mean_median(&theta, &psi, &w)
155 .context("Failed to calculate posterior mean and median")?;
156
157 let (pop_mean, pop_median) = population_mean_median(&theta, &w)
158 .context("Failed to calculate posterior mean and median")?;
159
160 let subjects = self.data.subjects();
161 if subjects.len() != post_mean.nrows() {
162 bail!(
163 "Number of subjects: {} and number of posterior means: {} do not match",
164 subjects.len(),
165 post_mean.nrows()
166 );
167 }
168
169 let outputfile = OutputFile::new(&self.settings.output().path, "op.csv")?;
170 let mut writer = WriterBuilder::new()
171 .has_headers(true)
172 .from_writer(&outputfile.file);
173
174 for (i, subject) in subjects.iter().enumerate() {
175 for occasion in subject.occasions() {
176 let id = subject.id();
177 let occ = occasion.index();
178
179 let subject = Subject::from_occasions(id.clone(), vec![occasion.clone()]);
180
181 let pop_mean_pred = self
183 .equation
184 .simulate_subject(&subject, &pop_mean.to_vec(), None)?
185 .0
186 .get_predictions()
187 .clone();
188
189 let pop_median_pred = self
190 .equation
191 .simulate_subject(&subject, &pop_median.to_vec(), None)?
192 .0
193 .get_predictions()
194 .clone();
195
196 let post_mean_spp: Vec<f64> = post_mean.row(i).to_vec();
198 let post_mean_pred = self
199 .equation
200 .simulate_subject(&subject, &post_mean_spp, None)?
201 .0
202 .get_predictions()
203 .clone();
204 let post_median_spp: Vec<f64> = post_median.row(i).to_vec();
205 let post_median_pred = self
206 .equation
207 .simulate_subject(&subject, &post_median_spp, None)?
208 .0
209 .get_predictions()
210 .clone();
211 assert_eq!(
212 pop_mean_pred.len(),
213 pop_median_pred.len(),
214 "The number of predictions do not match (pop_mean vs pop_median)"
215 );
216
217 assert_eq!(
218 post_mean_pred.len(),
219 post_median_pred.len(),
220 "The number of predictions do not match (post_mean vs post_median)"
221 );
222
223 assert_eq!(
224 pop_mean_pred.len(),
225 post_mean_pred.len(),
226 "The number of predictions do not match (pop_mean vs post_mean)"
227 );
228
229 for (((pop_mean_pred, pop_median_pred), post_mean_pred), post_median_pred) in
230 pop_mean_pred
231 .iter()
232 .zip(pop_median_pred.iter())
233 .zip(post_mean_pred.iter())
234 .zip(post_median_pred.iter())
235 {
236 let row = Row {
237 id: id.clone(),
238 time: pop_mean_pred.time(),
239 outeq: pop_mean_pred.outeq(),
240 block: occ,
241 obs: pop_mean_pred.observation(),
242 pop_mean: pop_mean_pred.prediction(),
243 pop_median: pop_median_pred.prediction(),
244 post_mean: post_mean_pred.prediction(),
245 post_median: post_median_pred.prediction(),
246 };
247 writer.serialize(row)?;
248 }
249 }
250 }
251 writer.flush()?;
252 tracing::debug!(
253 "Observations with predictions written to {:?}",
254 &outputfile.relative_path()
255 );
256 Ok(())
257 }
258
259 pub fn write_theta(&self) -> Result<()> {
262 tracing::debug!("Writing population parameter distribution...");
263
264 let theta = &self.theta;
265 let w: Vec<f64> = self
266 .w
267 .weights()
268 .clone()
269 .into_view()
270 .iter()
271 .cloned()
272 .collect();
273
274 if w.len() != theta.matrix().nrows() {
275 bail!(
276 "Number of weights ({}) and number of support points ({}) do not match.",
277 w.len(),
278 theta.matrix().nrows()
279 );
280 }
281
282 let outputfile = OutputFile::new(&self.settings.output().path, "theta.csv")
283 .context("Failed to create output file for theta")?;
284
285 let mut writer = WriterBuilder::new()
286 .has_headers(true)
287 .from_writer(&outputfile.file);
288
289 let mut theta_header = self.par_names.clone();
291 theta_header.push("prob".to_string());
292 writer.write_record(&theta_header)?;
293
294 for (theta_row, &w_val) in theta.matrix().row_iter().zip(w.iter()) {
296 let mut row: Vec<String> = theta_row.iter().map(|&val| val.to_string()).collect();
297 row.push(w_val.to_string());
298 writer.write_record(&row)?;
299 }
300 writer.flush()?;
301 tracing::debug!(
302 "Population parameter distribution written to {:?}",
303 &outputfile.relative_path()
304 );
305 Ok(())
306 }
307
308 pub fn write_posterior(&self) -> Result<()> {
310 tracing::debug!("Writing posterior parameter probabilities...");
311 let theta = &self.theta;
312 let w = &self.w;
313 let psi = &self.psi;
314
315 let posterior = posterior(psi, w)?;
317
318 let outputfile = match OutputFile::new(&self.settings.output().path, "posterior.csv") {
320 Ok(of) => of,
321 Err(e) => {
322 tracing::error!("Failed to create output file: {}", e);
323 return Err(e.context("Failed to create output file"));
324 }
325 };
326
327 let mut writer = WriterBuilder::new()
329 .has_headers(true)
330 .from_writer(&outputfile.file);
331
332 writer.write_field("id")?;
334 writer.write_field("point")?;
335 theta.param_names().iter().for_each(|name| {
336 writer.write_field(name).unwrap();
337 });
338 writer.write_field("prob")?;
339 writer.write_record(None::<&[u8]>)?;
340
341 let subjects = self.data.subjects();
343 posterior
344 .matrix()
345 .row_iter()
346 .enumerate()
347 .for_each(|(i, row)| {
348 let subject = subjects.get(i).unwrap();
349 let id = subject.id();
350
351 row.iter().enumerate().for_each(|(spp, prob)| {
352 writer.write_field(id.clone()).unwrap();
353 writer.write_field(spp.to_string()).unwrap();
354
355 theta.matrix().row(spp).iter().for_each(|val| {
356 writer.write_field(val.to_string()).unwrap();
357 });
358
359 writer.write_field(prob.to_string()).unwrap();
360 writer.write_record(None::<&[u8]>).unwrap();
361 });
362 });
363
364 writer.flush()?;
365 tracing::debug!(
366 "Posterior parameters written to {:?}",
367 &outputfile.relative_path()
368 );
369
370 Ok(())
371 }
372
373 pub fn write_predictions(&self, idelta: f64, tad: f64) -> Result<()> {
375 tracing::debug!("Writing predictions...");
376
377 let posterior = posterior(&self.psi, &self.w)?;
378
379 let predictions = NPPredictions::calculate(
381 &self.equation,
382 &self.data,
383 self.theta.clone(),
384 &self.w,
385 &posterior,
386 idelta,
387 tad,
388 )?;
389
390 let outputfile_pred = OutputFile::new(&self.settings.output().path, "pred.csv")?;
392 let mut writer = WriterBuilder::new()
393 .has_headers(true)
394 .from_writer(&outputfile_pred.file);
395
396 for row in predictions.predictions() {
398 writer.serialize(row)?;
399 }
400
401 writer.flush()?;
402 tracing::debug!(
403 "Predictions written to {:?}",
404 &outputfile_pred.relative_path()
405 );
406
407 let outputfile_op = OutputFile::new(&self.settings.output().path, "op.csv")?;
409 let mut writer = WriterBuilder::new()
410 .has_headers(true)
411 .from_writer(&outputfile_op.file);
412
413 for row in predictions
415 .predictions()
416 .iter()
417 .filter(|r| r.obs().is_some())
418 {
419 writer.serialize(row)?;
420 }
421
422 writer.flush()?;
423 tracing::debug!(
424 "Observed-predicted values written to {:?}",
425 &outputfile_op.relative_path()
426 );
427
428 Ok(())
429 }
430
431 pub fn write_covs(&self) -> Result<()> {
433 tracing::debug!("Writing covariates...");
434 let outputfile = OutputFile::new(&self.settings.output().path, "covs.csv")?;
435 let mut writer = WriterBuilder::new()
436 .has_headers(true)
437 .from_writer(&outputfile.file);
438
439 let mut covariate_names = std::collections::HashSet::new();
441 for subject in self.data.subjects() {
442 for occasion in subject.occasions() {
443 let cov = occasion.covariates();
444 let covmap = cov.covariates();
445 for cov_name in covmap.keys() {
446 covariate_names.insert(cov_name.clone());
447 }
448 }
449 }
450 let mut covariate_names: Vec<String> = covariate_names.into_iter().collect();
451 covariate_names.sort(); let mut headers = vec!["id", "time", "block"];
455 headers.extend(covariate_names.iter().map(|s| s.as_str()));
456 writer.write_record(&headers)?;
457
458 for subject in self.data.subjects() {
460 for occasion in subject.occasions() {
461 let cov = occasion.covariates();
462 let covmap = cov.covariates();
463
464 for event in occasion.iter() {
465 let time = match event {
466 Event::Bolus(bolus) => bolus.time(),
467 Event::Infusion(infusion) => infusion.time(),
468 Event::Observation(observation) => observation.time(),
469 };
470
471 let mut row: Vec<String> = Vec::new();
472 row.push(subject.id().clone());
473 row.push(time.to_string());
474 row.push(occasion.index().to_string());
475
476 for cov_name in &covariate_names {
478 if let Some(cov) = covmap.get(cov_name) {
479 if let Ok(value) = cov.interpolate(time) {
480 row.push(value.to_string());
481 } else {
482 row.push(String::new());
483 }
484 } else {
485 row.push(String::new());
486 }
487 }
488
489 writer.write_record(&row)?;
490 }
491 }
492 }
493
494 writer.flush()?;
495 tracing::debug!("Covariates written to {:?}", &outputfile.relative_path());
496 Ok(())
497 }
498}
499
/// Returns the median of `data`.
///
/// The input is copied and sorted; for an even number of values the mean of
/// the two middle values is returned.
///
/// Returns `NaN` when `data` is empty or contains a `NaN`, since the median is
/// not defined in either case. The function never panics.
pub(crate) fn median(data: &[f64]) -> f64 {
    if data.is_empty() || data.iter().any(|x| x.is_nan()) {
        return f64::NAN;
    }

    let mut sorted: Vec<f64> = data.to_vec();
    // `total_cmp` is a total order on f64, so no unwrap is needed.
    sorted.sort_unstable_by(f64::total_cmp);

    let mid = sorted.len() / 2;
    if sorted.len() % 2 == 0 {
        // `mid >= 1` here because the slice is non-empty and of even length.
        (sorted[mid - 1] + sorted[mid]) / 2.0
    } else {
        sorted[mid]
    }
}
514
/// Returns the weighted median of `data`, where `weights[i]` is the weight of
/// `data[i]`.
///
/// The values are sorted and their weights accumulated. The result is the
/// first value at which the cumulative weight exceeds half of the total
/// weight. If the cumulative weight is exactly half of the total, the result
/// is the mean of that value and the next value carrying positive weight;
/// values with zero weight do not influence the result.
///
/// Returns `NaN` when `data` is empty or when all weights are zero, because no
/// median is defined then.
///
/// # Panics
///
/// Panics when `data` and `weights` differ in length, or when any weight is
/// negative or `NaN`.
fn weighted_median(data: &[f64], weights: &[f64]) -> f64 {
    assert_eq!(
        data.len(),
        weights.len(),
        "The length of data and weights must be the same"
    );
    assert!(
        weights.iter().all(|&x| x >= 0.0),
        "Weights must be non-negative, weights: {:?}",
        weights
    );

    let total_weight: f64 = weights.iter().sum();
    if data.is_empty() || total_weight == 0.0 {
        return f64::NAN;
    }

    let mut weighted_data: Vec<(f64, f64)> = data
        .iter()
        .copied()
        .zip(weights.iter().copied())
        .collect();

    // `total_cmp` gives a total order, so NaN values in `data` cannot panic here.
    weighted_data.sort_by(|a, b| a.0.total_cmp(&b.0));

    let half = total_weight / 2.0;
    let mut cumulative_sum = 0.0;

    for (i, &(value, weight)) in weighted_data.iter().enumerate() {
        cumulative_sum += weight;

        if cumulative_sum > half {
            return value;
        }
        if cumulative_sum == half {
            // Exactly half of the mass lies at or below `value`. The upper
            // median is the next value that actually carries weight; skipping
            // zero-weight values keeps them from shifting the result.
            return weighted_data[i + 1..]
                .iter()
                .find(|&&(_, next_weight)| next_weight > 0.0)
                .map_or(value, |&(next, _)| (value + next) / 2.0);
        }
    }

    // Only reachable if rounding keeps the cumulative sum at or below half of
    // the total; fall back to the largest value instead of panicking.
    weighted_data[weighted_data.len() - 1].0
}
559
560pub fn population_mean_median(
561 theta: &Array2<f64>,
562 w: &Array1<f64>,
563) -> Result<(Array1<f64>, Array1<f64>)> {
564 let w = if w.is_empty() {
565 tracing::warn!("w.len() == 0, setting all weights to 1/n");
566 Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
567 } else {
568 w.clone()
569 };
570 if theta.nrows() != w.len() {
572 bail!(
573 "Number of parameters and number of weights do not match. Theta: {}, w: {}",
574 theta.nrows(),
575 w.len()
576 );
577 }
578
579 let mut mean = Array1::zeros(theta.ncols());
580 let mut median = Array1::zeros(theta.ncols());
581
582 for (i, (mn, mdn)) in mean.iter_mut().zip(&mut median).enumerate() {
583 let col = theta.column(i).to_owned() * w.to_owned();
585 *mn = col.sum();
586
587 let ct = theta.column(i);
589 let mut params = vec![];
590 let mut weights = vec![];
591 for (ti, wi) in ct.iter().zip(w.clone()) {
592 params.push(*ti);
593 weights.push(wi);
594 }
595
596 *mdn = weighted_median(¶ms, &weights);
597 }
598
599 Ok((mean, median))
600}
601
602pub fn posterior_mean_median(
603 theta: &Array2<f64>,
604 psi: &Array2<f64>,
605 w: &Array1<f64>,
606) -> Result<(Array2<f64>, Array2<f64>)> {
607 let mut mean = Array2::zeros((0, theta.ncols()));
608 let mut median = Array2::zeros((0, theta.ncols()));
609
610 let w = if w.is_empty() {
611 tracing::warn!("w is empty, setting all weights to 1/n");
612 Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
613 } else {
614 w.clone()
615 };
616
617 if theta.nrows() != w.len() || theta.nrows() != psi.ncols() || psi.ncols() != w.len() {
619 bail!("Number of parameters and number of weights do not match, theta.nrows(): {}, w.len(): {}, psi.ncols(): {}", theta.nrows(), w.len(), psi.ncols());
620 }
621
622 let mut psi_norm: Array2<f64> = Array2::zeros((0, psi.ncols()));
624 for (i, row) in psi.axis_iter(Axis(0)).enumerate() {
625 let row_w = row.to_owned() * w.to_owned();
626 let row_sum = row_w.sum();
627 let row_norm = if row_sum == 0.0 {
628 tracing::warn!("Sum of row {} of psi is 0.0, setting that row to 1/n", i);
629 Array1::from_elem(psi.ncols(), 1.0 / psi.ncols() as f64)
630 } else {
631 &row_w / row_sum
632 };
633 psi_norm.push_row(row_norm.view())?;
634 }
635 if psi_norm.iter().any(|&x| x.is_nan()) {
636 dbg!(&psi);
637 bail!("NaN values found in psi_norm");
638 };
639
640 for probs in psi_norm.axis_iter(Axis(0)) {
645 let mut post_mean: Vec<f64> = Vec::new();
646 let mut post_median: Vec<f64> = Vec::new();
647
648 for pars in theta.axis_iter(Axis(1)) {
650 let weighted_par = &probs * &pars;
652 let the_mean = weighted_par.sum();
653 post_mean.push(the_mean);
654
655 let median = weighted_median(&pars.to_vec(), &probs.to_vec());
657 post_median.push(median);
658 }
659
660 mean.push_row(Array::from(post_mean.clone()).view())?;
661 median.push_row(Array::from(post_median.clone()).view())?;
662 }
663
664 Ok((mean, median))
665}
666
/// An opened output file together with the path it was opened at.
#[derive(Debug)]
pub struct OutputFile {
    /// Handle of the opened file.
    file: File,
    /// Path the file was opened at: the folder joined with the file name, as
    /// passed to `OutputFile::new`.
    relative_path: PathBuf,
}
673
674impl OutputFile {
675 pub fn new(folder: &str, file_name: &str) -> Result<Self> {
676 let relative_path = Path::new(&folder).join(file_name);
677
678 if let Some(parent) = relative_path.parent() {
679 create_dir_all(parent)
680 .with_context(|| format!("Failed to create directories for {:?}", parent))?;
681 }
682
683 let file = OpenOptions::new()
684 .write(true)
685 .create(true)
686 .truncate(true)
687 .open(&relative_path)
688 .with_context(|| format!("Failed to open file: {:?}", relative_path))?;
689
690 Ok(OutputFile {
691 file,
692 relative_path,
693 })
694 }
695
696 pub fn file(&self) -> &File {
697 &self.file
698 }
699
700 pub fn file_owned(self) -> File {
701 self.file
702 }
703
704 pub fn relative_path(&self) -> &Path {
705 &self.relative_path
706 }
707}
708
#[cfg(test)]
mod tests {
    use super::{median, weighted_median};

    // ---- median ----------------------------------------------------------

    /// Odd number of values: the middle value after sorting.
    #[test]
    fn test_median_odd() {
        let values = [1.0, 3.0, 2.0];
        assert_eq!(median(&values), 2.0);
    }

    /// Even number of values: the mean of the two middle values.
    #[test]
    fn test_median_even() {
        let values = [1.0, 2.0, 3.0, 4.0];
        assert_eq!(median(&values), 2.5);
    }

    /// A single value is its own median.
    #[test]
    fn test_median_single() {
        let values = [42.0];
        assert_eq!(median(&values), 42.0);
    }

    /// Already sorted input.
    #[test]
    fn test_median_sorted() {
        let values = [5.0, 10.0, 15.0, 20.0, 25.0];
        assert_eq!(median(&values), 15.0);
    }

    /// Unsorted input is sorted internally.
    #[test]
    fn test_median_unsorted() {
        let values = [10.0, 30.0, 20.0, 50.0, 40.0];
        assert_eq!(median(&values), 30.0);
    }

    /// Duplicated values are counted individually.
    #[test]
    fn test_median_with_duplicates() {
        let values = [1.0, 2.0, 2.0, 3.0, 4.0];
        assert_eq!(median(&values), 2.0);
    }

    // ---- weighted_median -------------------------------------------------

    /// The cumulative weight first exceeds one half at the second value.
    #[test]
    fn test_weighted_median_simple() {
        let values = [1.0, 2.0, 3.0];
        let probs = [0.2, 0.5, 0.3];
        assert_eq!(weighted_median(&values, &probs), 2.0);
    }

    /// Equal weights on an even number of values: mean of the middle pair.
    #[test]
    fn test_weighted_median_even_weights() {
        let values = [1.0, 2.0, 3.0, 4.0];
        let probs = [0.25, 0.25, 0.25, 0.25];
        assert_eq!(weighted_median(&values, &probs), 2.5);
    }

    /// A single value is its own weighted median.
    #[test]
    fn test_weighted_median_single_element() {
        let values = [42.0];
        let probs = [1.0];
        assert_eq!(weighted_median(&values, &probs), 42.0);
    }

    /// Differing lengths of data and weights must panic.
    #[test]
    #[should_panic(expected = "The length of data and weights must be the same")]
    fn test_weighted_median_mismatched_lengths() {
        let values = [1.0, 2.0, 3.0];
        let probs = [0.1, 0.2];
        let _ = weighted_median(&values, &probs);
    }

    /// Identical values give that value regardless of the weights.
    #[test]
    fn test_weighted_median_all_same_elements() {
        let values = [5.0, 5.0, 5.0, 5.0];
        let probs = [0.1, 0.2, 0.3, 0.4];
        assert_eq!(weighted_median(&values, &probs), 5.0);
    }

    /// A negative weight must panic before any result is produced.
    #[test]
    #[should_panic(expected = "Weights must be non-negative")]
    fn test_weighted_median_negative_weights() {
        let values = [1.0, 2.0, 3.0, 4.0];
        let probs = [0.2, -0.5, 0.5, 0.8];
        let _ = weighted_median(&values, &probs);
    }

    /// Values and weights are sorted together.
    #[test]
    fn test_weighted_median_unsorted_data() {
        let values = [3.0, 1.0, 4.0, 2.0];
        let probs = [0.1, 0.3, 0.4, 0.2];
        assert_eq!(weighted_median(&values, &probs), 2.5);
    }

    /// All of the weight sits on a single value.
    #[test]
    fn test_weighted_median_with_zero_weights() {
        let values = [1.0, 2.0, 3.0, 4.0];
        let probs = [0.0, 0.0, 1.0, 0.0];
        assert_eq!(weighted_median(&values, &probs), 3.0);
    }
}