1use crate::algorithms::{Status, StopReason};
2use crate::prelude::*;
3use crate::routines::output::cycles::CycleLog;
4use crate::routines::output::posterior::Posterior;
5use crate::routines::output::predictions::NPPredictions;
6use crate::routines::settings::Settings;
7use crate::structs::psi::Psi;
8use crate::structs::theta::Theta;
9use crate::structs::weights::Weights;
10use anyhow::{bail, Context, Result};
11use csv::WriterBuilder;
12use ndarray::{Array, Array1, Array2, Axis};
13use pharmsol::prelude::data::*;
14use pharmsol::prelude::simulator::Equation;
15use serde::Serialize;
16use std::fs::{create_dir_all, File, OpenOptions};
17use std::path::{Path, PathBuf};
18
19pub mod cycles;
20pub mod posterior;
21pub mod predictions;
22
23use posterior::posterior;
24
/// Result of a non-parametric fit, bundled with the data and settings that produced it.
#[derive(Debug, Serialize)]
pub struct NPResult<E: Equation> {
    /// Model equation used for simulation; excluded from serialization.
    #[serde(skip)]
    equation: E,
    /// The dataset of subjects the result refers to.
    data: Data,
    /// Support points, one row per support point (written to `theta.csv` with their weights).
    theta: Theta,
    /// Matrix with one row per subject and one column per support point; combined with `w`
    /// to form the posterior. NOTE(review): presumably subject likelihoods — confirm.
    psi: Psi,
    /// Weight (probability) of each support point in `theta`.
    w: Weights,
    /// Objective function value reported for this result.
    objf: f64,
    /// Cycle count reported by the algorithm.
    cycles: usize,
    /// Final algorithm status; `converged()` compares it with `StopReason::Converged`.
    status: Status,
    /// Settings of the run, including output and prediction options.
    settings: Settings,
    /// Per-cycle log, written out by `write_outputs`.
    cyclelog: CycleLog,
    /// Predictions; `None` until `calculate_predictions` has been called.
    predictions: Option<NPPredictions>,
    /// Per-subject posterior probabilities over the support points, computed in `new`.
    posterior: Posterior,
}
43
44#[allow(clippy::too_many_arguments)]
45impl<E: Equation> NPResult<E> {
46 pub(crate) fn new(
50 equation: E,
51 data: Data,
52 theta: Theta,
53 psi: Psi,
54 w: Weights,
55 objf: f64,
56 cycles: usize,
57 status: Status,
58 settings: Settings,
59 cyclelog: CycleLog,
60 ) -> Result<Self> {
61 let posterior = posterior(&psi, &w)
63 .context("Failed to calculate posterior during initialization of NPResult")?;
64
65 let result = Self {
66 equation,
67 data,
68 theta,
69 psi,
70 w,
71 objf,
72 cycles,
73 status,
74 settings,
75 cyclelog,
76 predictions: None,
77 posterior,
78 };
79
80 Ok(result)
81 }
82
83 pub fn cycles(&self) -> usize {
84 self.cycles
85 }
86
87 pub fn objf(&self) -> f64 {
88 self.objf
89 }
90
91 pub fn converged(&self) -> bool {
92 self.status == Status::Stop(StopReason::Converged)
93 }
94
95 pub fn get_theta(&self) -> &Theta {
96 &self.theta
97 }
98
99 pub fn data(&self) -> &Data {
100 &self.data
101 }
102
103 pub fn cycle_log(&self) -> &CycleLog {
104 &self.cyclelog
105 }
106
107 pub fn settings(&self) -> &Settings {
108 &self.settings
109 }
110
111 pub fn psi(&self) -> &Psi {
113 &self.psi
114 }
115
116 pub fn weights(&self) -> &Weights {
118 &self.w
119 }
120
121 pub fn calculate_predictions(&mut self, idelta: f64, tad: f64) -> Result<()> {
125 let predictions = NPPredictions::calculate(
126 &self.equation,
127 &self.data,
128 &self.theta,
129 &self.w,
130 &self.posterior,
131 idelta,
132 tad,
133 )?;
134 self.predictions = Some(predictions);
135 Ok(())
136 }
137
138 pub fn write_outputs(&mut self) -> Result<()> {
139 if self.settings.output().write {
140 tracing::debug!("Writing outputs to {:?}", self.settings.output().path);
141 self.settings.write()?;
142 let idelta: f64 = self.settings.predictions().idelta;
143 let tad = self.settings.predictions().tad;
144 self.cyclelog.write(&self.settings)?;
145 self.write_theta().context("Failed to write theta")?;
146 self.write_covs().context("Failed to write covariates")?;
147 self.write_predictions(idelta, tad)
148 .context("Failed to write predictions")?;
149 self.write_posterior()
150 .context("Failed to write posterior")?;
151 }
152 Ok(())
153 }
154
155 pub fn write_obspred(&self) -> Result<()> {
157 tracing::debug!("Writing observations and predictions...");
158
159 #[derive(Debug, Clone, Serialize)]
160 struct Row {
161 id: String,
162 time: f64,
163 outeq: usize,
164 block: usize,
165 obs: Option<f64>,
166 pop_mean: f64,
167 pop_median: f64,
168 post_mean: f64,
169 post_median: f64,
170 }
171
172 let tm = self.theta.matrix();
173 let theta = Array2::from_shape_fn((tm.nrows(), tm.ncols()), |(i, j)| tm[(i, j)]);
174 let w: Array1<f64> = self.w.iter().collect();
175 let pm = self.psi.matrix();
176 let psi = Array2::from_shape_fn((pm.nrows(), pm.ncols()), |(i, j)| pm[(i, j)]);
177
178 let (post_mean, post_median) = posterior_mean_median(&theta, &psi, &w)
179 .context("Failed to calculate posterior mean and median")?;
180
181 let (pop_mean, pop_median) = population_mean_median(&theta, &w)
182 .context("Failed to calculate posterior mean and median")?;
183
184 let subjects = self.data.subjects();
185 if subjects.len() != post_mean.nrows() {
186 bail!(
187 "Number of subjects: {} and number of posterior means: {} do not match",
188 subjects.len(),
189 post_mean.nrows()
190 );
191 }
192
193 let outputfile = OutputFile::new(&self.settings.output().path, "op.csv")?;
194 let mut writer = WriterBuilder::new()
195 .has_headers(true)
196 .from_writer(&outputfile.file);
197
198 for (i, subject) in subjects.iter().enumerate() {
199 for occasion in subject.occasions() {
200 let id = subject.id();
201 let occ = occasion.index();
202
203 let subject = Subject::from_occasions(id.clone(), vec![occasion.clone()]);
204
205 let pop_mean_pred = self
207 .equation
208 .simulate_subject(&subject, &pop_mean.to_vec(), None)?
209 .0
210 .get_predictions()
211 .clone();
212
213 let pop_median_pred = self
214 .equation
215 .simulate_subject(&subject, &pop_median.to_vec(), None)?
216 .0
217 .get_predictions()
218 .clone();
219
220 let post_mean_spp: Vec<f64> = post_mean.row(i).to_vec();
222 let post_mean_pred = self
223 .equation
224 .simulate_subject(&subject, &post_mean_spp, None)?
225 .0
226 .get_predictions()
227 .clone();
228 let post_median_spp: Vec<f64> = post_median.row(i).to_vec();
229 let post_median_pred = self
230 .equation
231 .simulate_subject(&subject, &post_median_spp, None)?
232 .0
233 .get_predictions()
234 .clone();
235 assert_eq!(
236 pop_mean_pred.len(),
237 pop_median_pred.len(),
238 "The number of predictions do not match (pop_mean vs pop_median)"
239 );
240
241 assert_eq!(
242 post_mean_pred.len(),
243 post_median_pred.len(),
244 "The number of predictions do not match (post_mean vs post_median)"
245 );
246
247 assert_eq!(
248 pop_mean_pred.len(),
249 post_mean_pred.len(),
250 "The number of predictions do not match (pop_mean vs post_mean)"
251 );
252
253 for (((pop_mean_pred, pop_median_pred), post_mean_pred), post_median_pred) in
254 pop_mean_pred
255 .iter()
256 .zip(pop_median_pred.iter())
257 .zip(post_mean_pred.iter())
258 .zip(post_median_pred.iter())
259 {
260 let row = Row {
261 id: id.clone(),
262 time: pop_mean_pred.time(),
263 outeq: pop_mean_pred.outeq(),
264 block: occ,
265 obs: pop_mean_pred.observation(),
266 pop_mean: pop_mean_pred.prediction(),
267 pop_median: pop_median_pred.prediction(),
268 post_mean: post_mean_pred.prediction(),
269 post_median: post_median_pred.prediction(),
270 };
271 writer.serialize(row)?;
272 }
273 }
274 }
275 writer.flush()?;
276 tracing::debug!(
277 "Observations with predictions written to {:?}",
278 &outputfile.relative_path()
279 );
280 Ok(())
281 }
282
283 pub fn write_theta(&self) -> Result<()> {
286 tracing::debug!("Writing population parameter distribution...");
287
288 let theta = &self.theta;
289 let w: Vec<f64> = self.w.to_vec();
290
291 if w.len() != theta.matrix().nrows() {
292 bail!(
293 "Number of weights ({}) and number of support points ({}) do not match.",
294 w.len(),
295 theta.matrix().nrows()
296 );
297 }
298
299 let outputfile = OutputFile::new(&self.settings.output().path, "theta.csv")
300 .context("Failed to create output file for theta")?;
301
302 let mut writer = WriterBuilder::new()
303 .has_headers(true)
304 .from_writer(&outputfile.file);
305
306 let mut theta_header = self.settings.parameters().names();
308 theta_header.push("prob".to_string());
309 writer.write_record(&theta_header)?;
310
311 for (theta_row, &w_val) in theta.matrix().row_iter().zip(w.iter()) {
313 let mut row: Vec<String> = theta_row.iter().map(|&val| val.to_string()).collect();
314 row.push(w_val.to_string());
315 writer.write_record(&row)?;
316 }
317 writer.flush()?;
318 tracing::debug!(
319 "Population parameter distribution written to {:?}",
320 &outputfile.relative_path()
321 );
322 Ok(())
323 }
324
325 pub fn write_posterior(&self) -> Result<()> {
327 tracing::debug!("Writing posterior parameter probabilities...");
328 let theta = &self.theta;
329
330 let posterior = self.posterior.clone();
332
333 let outputfile = match OutputFile::new(&self.settings.output().path, "posterior.csv") {
335 Ok(of) => of,
336 Err(e) => {
337 tracing::error!("Failed to create output file: {}", e);
338 return Err(e.context("Failed to create output file"));
339 }
340 };
341
342 let mut writer = WriterBuilder::new()
344 .has_headers(true)
345 .from_writer(&outputfile.file);
346
347 writer.write_field("id")?;
349 writer.write_field("point")?;
350 theta.param_names().iter().for_each(|name| {
351 writer.write_field(name).unwrap();
352 });
353 writer.write_field("prob")?;
354 writer.write_record(None::<&[u8]>)?;
355
356 let subjects = self.data.subjects();
358 posterior
359 .matrix()
360 .row_iter()
361 .enumerate()
362 .for_each(|(i, row)| {
363 let subject = subjects.get(i).unwrap();
364 let id = subject.id();
365
366 row.iter().enumerate().for_each(|(spp, prob)| {
367 writer.write_field(id.clone()).unwrap();
368 writer.write_field(spp.to_string()).unwrap();
369
370 theta.matrix().row(spp).iter().for_each(|val| {
371 writer.write_field(val.to_string()).unwrap();
372 });
373
374 writer.write_field(prob.to_string()).unwrap();
375 writer.write_record(None::<&[u8]>).unwrap();
376 });
377 });
378
379 writer.flush()?;
380 tracing::debug!(
381 "Posterior parameters written to {:?}",
382 &outputfile.relative_path()
383 );
384
385 Ok(())
386 }
387
388 pub fn write_predictions(&mut self, idelta: f64, tad: f64) -> Result<()> {
390 tracing::debug!("Writing predictions...");
391
392 self.calculate_predictions(idelta, tad)?;
393
394 let predictions = self
395 .predictions
396 .as_ref()
397 .expect("Predictions should have been calculated, but are of type None.");
398
399 let outputfile_pred = OutputFile::new(&self.settings.output().path, "pred.csv")?;
401 let mut writer = WriterBuilder::new()
402 .has_headers(true)
403 .from_writer(&outputfile_pred.file);
404
405 for row in predictions.predictions() {
407 writer.serialize(row)?;
408 }
409
410 writer.flush()?;
411 tracing::debug!(
412 "Predictions written to {:?}",
413 &outputfile_pred.relative_path()
414 );
415
416 Ok(())
417 }
418
419 pub fn write_covs(&self) -> Result<()> {
421 tracing::debug!("Writing covariates...");
422 let outputfile = OutputFile::new(&self.settings.output().path, "covs.csv")?;
423 let mut writer = WriterBuilder::new()
424 .has_headers(true)
425 .from_writer(&outputfile.file);
426
427 let mut covariate_names = std::collections::HashSet::new();
429 for subject in self.data.subjects() {
430 for occasion in subject.occasions() {
431 let cov = occasion.covariates();
432 let covmap = cov.covariates();
433 for cov_name in covmap.keys() {
434 covariate_names.insert(cov_name.clone());
435 }
436 }
437 }
438 let mut covariate_names: Vec<String> = covariate_names.into_iter().collect();
439 covariate_names.sort(); let mut headers = vec!["id", "time", "block"];
443 headers.extend(covariate_names.iter().map(|s| s.as_str()));
444 writer.write_record(&headers)?;
445
446 for subject in self.data.subjects() {
448 for occasion in subject.occasions() {
449 let cov = occasion.covariates();
450 let covmap = cov.covariates();
451
452 for event in occasion.iter() {
453 let time = match event {
454 Event::Bolus(bolus) => bolus.time(),
455 Event::Infusion(infusion) => infusion.time(),
456 Event::Observation(observation) => observation.time(),
457 };
458
459 let mut row: Vec<String> = Vec::new();
460 row.push(subject.id().clone());
461 row.push(time.to_string());
462 row.push(occasion.index().to_string());
463
464 for cov_name in &covariate_names {
466 if let Some(cov) = covmap.get(cov_name) {
467 if let Ok(value) = cov.interpolate(time) {
468 row.push(value.to_string());
469 } else {
470 row.push(String::new());
471 }
472 } else {
473 row.push(String::new());
474 }
475 }
476
477 writer.write_record(&row)?;
478 }
479 }
480 }
481
482 writer.flush()?;
483 tracing::debug!("Covariates written to {:?}", &outputfile.relative_path());
484 Ok(())
485 }
486}
487
/// Median of `data`: the middle value, or the mean of the two middle values for even lengths.
///
/// Returns `f64::NAN` for empty input. Values are ordered with `f64::total_cmp`, so NaN
/// entries do not panic. A positive NaN sorts after every number.
pub(crate) fn median(data: &[f64]) -> f64 {
    if data.is_empty() {
        return f64::NAN;
    }

    let mut sorted: Vec<f64> = data.to_vec();
    sorted.sort_by(f64::total_cmp);

    let mid = sorted.len() / 2;
    if sorted.len() % 2 == 0 {
        (sorted[mid - 1] + sorted[mid]) / 2.0
    } else {
        sorted[mid]
    }
}
502
/// Weighted median of `data`, where `weights[i]` is the non-negative weight of `data[i]`.
///
/// The values are sorted, and the first value whose cumulative weight exceeds half of the
/// total weight is returned. If the cumulative weight lands on the midpoint, the result is
/// the mean of that value and the next value carrying positive weight. The midpoint test
/// uses a tolerance for floating-point rounding.
///
/// Returns `f64::NAN` for empty input. If every weight is zero, all values count equally.
///
/// # Panics
///
/// Panics if `data` and `weights` differ in length, or if any weight is negative or NaN.
fn weighted_median(data: &[f64], weights: &[f64]) -> f64 {
    assert_eq!(
        data.len(),
        weights.len(),
        "The length of data and weights must be the same"
    );
    assert!(
        weights.iter().all(|&x| x >= 0.0),
        "Weights must be non-negative, weights: {:?}",
        weights
    );

    if data.is_empty() {
        return f64::NAN;
    }

    let total_weight: f64 = weights.iter().sum();
    if total_weight <= 0.0 {
        // No value is preferred over another: fall back to equal weighting.
        return weighted_median(data, &vec![1.0; data.len()]);
    }

    let mut weighted_data: Vec<(f64, f64)> = data
        .iter()
        .copied()
        .zip(weights.iter().copied())
        .collect();
    // total_cmp is a total order, so NaN values sort last instead of panicking.
    weighted_data.sort_by(|a, b| a.0.total_cmp(&b.0));

    let half = total_weight / 2.0;
    // Summed weights carry rounding error (six weights of 1/6 sum to 0.9999999999999999),
    // so an exact `==` would miss genuine midpoint ties.
    let tolerance = total_weight * f64::EPSILON * weighted_data.len() as f64;

    let mut cumulative_sum = 0.0;
    for (i, &(value, weight)) in weighted_data.iter().enumerate() {
        cumulative_sum += weight;

        if (cumulative_sum - half).abs() <= tolerance {
            // Exactly half the mass lies at or below `value`. Average with the next value
            // that carries weight; zero-weight points are not part of the distribution.
            return match weighted_data[i + 1..].iter().find(|&&(_, w)| w > 0.0) {
                Some(&(next, _)) => (value + next) / 2.0,
                None => value,
            };
        }
        if cumulative_sum > half {
            return value;
        }
    }

    // Only reachable through extreme rounding; the largest value is the safe answer.
    weighted_data[weighted_data.len() - 1].0
}
547
548pub fn population_mean_median(
549 theta: &Array2<f64>,
550 w: &Array1<f64>,
551) -> Result<(Array1<f64>, Array1<f64>)> {
552 let w = if w.is_empty() {
553 tracing::warn!("w.len() == 0, setting all weights to 1/n");
554 Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
555 } else {
556 w.clone()
557 };
558 if theta.nrows() != w.len() {
560 bail!(
561 "Number of parameters and number of weights do not match. Theta: {}, w: {}",
562 theta.nrows(),
563 w.len()
564 );
565 }
566
567 let mut mean = Array1::zeros(theta.ncols());
568 let mut median = Array1::zeros(theta.ncols());
569
570 for (i, (mn, mdn)) in mean.iter_mut().zip(&mut median).enumerate() {
571 let col = theta.column(i).to_owned() * w.to_owned();
573 *mn = col.sum();
574
575 let ct = theta.column(i);
577 let mut params = vec![];
578 let mut weights = vec![];
579 for (ti, wi) in ct.iter().zip(w.clone()) {
580 params.push(*ti);
581 weights.push(wi);
582 }
583
584 *mdn = weighted_median(¶ms, &weights);
585 }
586
587 Ok((mean, median))
588}
589
590pub fn posterior_mean_median(
591 theta: &Array2<f64>,
592 psi: &Array2<f64>,
593 w: &Array1<f64>,
594) -> Result<(Array2<f64>, Array2<f64>)> {
595 let mut mean = Array2::zeros((0, theta.ncols()));
596 let mut median = Array2::zeros((0, theta.ncols()));
597
598 let w = if w.is_empty() {
599 tracing::warn!("w is empty, setting all weights to 1/n");
600 Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
601 } else {
602 w.clone()
603 };
604
605 if theta.nrows() != w.len() || theta.nrows() != psi.ncols() || psi.ncols() != w.len() {
607 bail!("Number of parameters and number of weights do not match, theta.nrows(): {}, w.len(): {}, psi.ncols(): {}", theta.nrows(), w.len(), psi.ncols());
608 }
609
610 let mut psi_norm: Array2<f64> = Array2::zeros((0, psi.ncols()));
612 for (i, row) in psi.axis_iter(Axis(0)).enumerate() {
613 let row_w = row.to_owned() * w.to_owned();
614 let row_sum = row_w.sum();
615 let row_norm = if row_sum == 0.0 {
616 tracing::warn!("Sum of row {} of psi is 0.0, setting that row to 1/n", i);
617 Array1::from_elem(psi.ncols(), 1.0 / psi.ncols() as f64)
618 } else {
619 &row_w / row_sum
620 };
621 psi_norm.push_row(row_norm.view())?;
622 }
623 if psi_norm.iter().any(|&x| x.is_nan()) {
624 dbg!(&psi);
625 bail!("NaN values found in psi_norm");
626 };
627
628 for probs in psi_norm.axis_iter(Axis(0)) {
633 let mut post_mean: Vec<f64> = Vec::new();
634 let mut post_median: Vec<f64> = Vec::new();
635
636 for pars in theta.axis_iter(Axis(1)) {
638 let weighted_par = &probs * &pars;
640 let the_mean = weighted_par.sum();
641 post_mean.push(the_mean);
642
643 let median = weighted_median(&pars.to_vec(), &probs.to_vec());
645 post_median.push(median);
646 }
647
648 mean.push_row(Array::from(post_mean.clone()).view())?;
649 median.push_row(Array::from(post_median.clone()).view())?;
650 }
651
652 Ok((mean, median))
653}
654
/// An output file opened for writing, together with the path it was created at.
#[derive(Debug)]
pub struct OutputFile {
    /// Handle to the created (truncated) file.
    file: File,
    /// Path of the file: the output folder joined with the file name.
    relative_path: PathBuf,
}
661
662impl OutputFile {
663 pub fn new(folder: &str, file_name: &str) -> Result<Self> {
664 let relative_path = Path::new(&folder).join(file_name);
665
666 if let Some(parent) = relative_path.parent() {
667 create_dir_all(parent)
668 .with_context(|| format!("Failed to create directories for {:?}", parent))?;
669 }
670
671 let file = OpenOptions::new()
672 .write(true)
673 .create(true)
674 .truncate(true)
675 .open(&relative_path)
676 .with_context(|| format!("Failed to open file: {:?}", relative_path))?;
677
678 Ok(OutputFile {
679 file,
680 relative_path,
681 })
682 }
683
684 pub fn file(&self) -> &File {
685 &self.file
686 }
687
688 pub fn file_owned(self) -> File {
689 self.file
690 }
691
692 pub fn relative_path(&self) -> &Path {
693 &self.relative_path
694 }
695}
696
#[cfg(test)]
mod tests {
    use super::{median, weighted_median};

    // ----- median -----

    #[test]
    fn test_median_odd() {
        let values = [1.0, 3.0, 2.0];
        assert_eq!(median(&values), 2.0);
    }

    #[test]
    fn test_median_even() {
        let values = [1.0, 2.0, 3.0, 4.0];
        assert_eq!(median(&values), 2.5);
    }

    #[test]
    fn test_median_single() {
        let values = [42.0];
        assert_eq!(median(&values), 42.0);
    }

    #[test]
    fn test_median_sorted() {
        let values = [5.0, 10.0, 15.0, 20.0, 25.0];
        assert_eq!(median(&values), 15.0);
    }

    #[test]
    fn test_median_unsorted() {
        let values = [10.0, 30.0, 20.0, 50.0, 40.0];
        assert_eq!(median(&values), 30.0);
    }

    #[test]
    fn test_median_with_duplicates() {
        let values = [1.0, 2.0, 2.0, 3.0, 4.0];
        assert_eq!(median(&values), 2.0);
    }

    // ----- weighted_median -----

    #[test]
    fn test_weighted_median_simple() {
        let (values, weights) = ([1.0, 2.0, 3.0], [0.2, 0.5, 0.3]);
        assert_eq!(weighted_median(&values, &weights), 2.0);
    }

    #[test]
    fn test_weighted_median_even_weights() {
        let (values, weights) = ([1.0, 2.0, 3.0, 4.0], [0.25; 4]);
        assert_eq!(weighted_median(&values, &weights), 2.5);
    }

    #[test]
    fn test_weighted_median_single_element() {
        let (values, weights) = ([42.0], [1.0]);
        assert_eq!(weighted_median(&values, &weights), 42.0);
    }

    #[test]
    #[should_panic(expected = "The length of data and weights must be the same")]
    fn test_weighted_median_mismatched_lengths() {
        let values = [1.0, 2.0, 3.0];
        let weights = [0.1, 0.2];
        weighted_median(&values, &weights);
    }

    #[test]
    fn test_weighted_median_all_same_elements() {
        let (values, weights) = ([5.0; 4], [0.1, 0.2, 0.3, 0.4]);
        assert_eq!(weighted_median(&values, &weights), 5.0);
    }

    #[test]
    #[should_panic(expected = "Weights must be non-negative")]
    fn test_weighted_median_negative_weights() {
        let values = [1.0, 2.0, 3.0, 4.0];
        let weights = [0.2, -0.5, 0.5, 0.8];
        assert_eq!(weighted_median(&values, &weights), 4.0);
    }

    #[test]
    fn test_weighted_median_unsorted_data() {
        let (values, weights) = ([3.0, 1.0, 4.0, 2.0], [0.1, 0.3, 0.4, 0.2]);
        assert_eq!(weighted_median(&values, &weights), 2.5);
    }

    #[test]
    fn test_weighted_median_with_zero_weights() {
        let (values, weights) = ([1.0, 2.0, 3.0, 4.0], [0.0, 0.0, 1.0, 0.0]);
        assert_eq!(weighted_median(&values, &weights), 3.0);
    }
}