1use crate::algorithms::{Status, StopReason};
2use crate::prelude::*;
3use crate::routines::output::cycles::CycleLog;
4use crate::routines::output::posterior::Posterior;
5use crate::routines::output::predictions::NPPredictions;
6use crate::routines::settings::Settings;
7use crate::structs::psi::Psi;
8use crate::structs::theta::Theta;
9use crate::structs::weights::Weights;
10use anyhow::{bail, Context, Result};
11use csv::WriterBuilder;
12use faer::linalg::zip::IntoView;
13use faer_ext::IntoNdarray;
14use ndarray::{Array, Array1, Array2, Axis};
15use pharmsol::prelude::data::*;
16use pharmsol::prelude::simulator::Equation;
17use serde::Serialize;
18use std::fs::{create_dir_all, File, OpenOptions};
19use std::path::{Path, PathBuf};
20
21pub mod cycles;
22pub mod posterior;
23pub mod predictions;
24
25use posterior::posterior;
26
27#[derive(Debug, Serialize)]
30pub struct NPResult<E: Equation> {
31 #[serde(skip)]
32 equation: E,
33 data: Data,
34 theta: Theta,
35 psi: Psi,
36 w: Weights,
37 objf: f64,
38 cycles: usize,
39 status: Status,
40 settings: Settings,
41 cyclelog: CycleLog,
42 predictions: Option<NPPredictions>,
43 posterior: Posterior,
44}
45
46#[allow(clippy::too_many_arguments)]
47impl<E: Equation> NPResult<E> {
48 pub(crate) fn new(
52 equation: E,
53 data: Data,
54 theta: Theta,
55 psi: Psi,
56 w: Weights,
57 objf: f64,
58 cycles: usize,
59 status: Status,
60 settings: Settings,
61 cyclelog: CycleLog,
62 ) -> Result<Self> {
63 let posterior = posterior(&psi, &w)
65 .context("Failed to calculate posterior during initialization of NPResult")?;
66
67 let result = Self {
68 equation,
69 data,
70 theta,
71 psi,
72 w,
73 objf,
74 cycles,
75 status,
76 settings,
77 cyclelog,
78 predictions: None,
79 posterior,
80 };
81
82 Ok(result)
83 }
84
85 pub fn cycles(&self) -> usize {
86 self.cycles
87 }
88
89 pub fn objf(&self) -> f64 {
90 self.objf
91 }
92
93 pub fn converged(&self) -> bool {
94 self.status == Status::Stop(StopReason::Converged)
95 }
96
97 pub fn get_theta(&self) -> &Theta {
98 &self.theta
99 }
100
101 pub fn data(&self) -> &Data {
102 &self.data
103 }
104
105 pub fn cycle_log(&self) -> &CycleLog {
106 &self.cyclelog
107 }
108
109 pub fn settings(&self) -> &Settings {
110 &self.settings
111 }
112
113 pub fn psi(&self) -> &Psi {
115 &self.psi
116 }
117
118 pub fn weights(&self) -> &Weights {
120 &self.w
121 }
122
123 pub fn calculate_predictions(&mut self, idelta: f64, tad: f64) -> Result<()> {
127 let predictions = NPPredictions::calculate(
128 &self.equation,
129 &self.data,
130 &self.theta,
131 &self.w,
132 &self.posterior,
133 idelta,
134 tad,
135 )?;
136 self.predictions = Some(predictions);
137 Ok(())
138 }
139
140 pub fn write_outputs(&mut self) -> Result<()> {
141 if self.settings.output().write {
142 tracing::debug!("Writing outputs to {:?}", self.settings.output().path);
143 self.settings.write()?;
144 let idelta: f64 = self.settings.predictions().idelta;
145 let tad = self.settings.predictions().tad;
146 self.cyclelog.write(&self.settings)?;
147 self.write_theta().context("Failed to write theta")?;
148 self.write_covs().context("Failed to write covariates")?;
149 self.write_predictions(idelta, tad)
150 .context("Failed to write predictions")?;
151 self.write_posterior()
152 .context("Failed to write posterior")?;
153 }
154 Ok(())
155 }
156
157 pub fn write_obspred(&self) -> Result<()> {
159 tracing::debug!("Writing observations and predictions...");
160
161 #[derive(Debug, Clone, Serialize)]
162 struct Row {
163 id: String,
164 time: f64,
165 outeq: usize,
166 block: usize,
167 obs: Option<f64>,
168 pop_mean: f64,
169 pop_median: f64,
170 post_mean: f64,
171 post_median: f64,
172 }
173
174 let theta: Array2<f64> = self
175 .theta
176 .matrix()
177 .clone()
178 .as_mut()
179 .into_ndarray()
180 .to_owned();
181 let w: Array1<f64> = self
182 .w
183 .weights()
184 .clone()
185 .into_view()
186 .iter()
187 .cloned()
188 .collect();
189 let psi: Array2<f64> = self.psi.matrix().as_ref().into_ndarray().to_owned();
190
191 let (post_mean, post_median) = posterior_mean_median(&theta, &psi, &w)
192 .context("Failed to calculate posterior mean and median")?;
193
194 let (pop_mean, pop_median) = population_mean_median(&theta, &w)
195 .context("Failed to calculate posterior mean and median")?;
196
197 let subjects = self.data.subjects();
198 if subjects.len() != post_mean.nrows() {
199 bail!(
200 "Number of subjects: {} and number of posterior means: {} do not match",
201 subjects.len(),
202 post_mean.nrows()
203 );
204 }
205
206 let outputfile = OutputFile::new(&self.settings.output().path, "op.csv")?;
207 let mut writer = WriterBuilder::new()
208 .has_headers(true)
209 .from_writer(&outputfile.file);
210
211 for (i, subject) in subjects.iter().enumerate() {
212 for occasion in subject.occasions() {
213 let id = subject.id();
214 let occ = occasion.index();
215
216 let subject = Subject::from_occasions(id.clone(), vec![occasion.clone()]);
217
218 let pop_mean_pred = self
220 .equation
221 .simulate_subject(&subject, &pop_mean.to_vec(), None)?
222 .0
223 .get_predictions()
224 .clone();
225
226 let pop_median_pred = self
227 .equation
228 .simulate_subject(&subject, &pop_median.to_vec(), None)?
229 .0
230 .get_predictions()
231 .clone();
232
233 let post_mean_spp: Vec<f64> = post_mean.row(i).to_vec();
235 let post_mean_pred = self
236 .equation
237 .simulate_subject(&subject, &post_mean_spp, None)?
238 .0
239 .get_predictions()
240 .clone();
241 let post_median_spp: Vec<f64> = post_median.row(i).to_vec();
242 let post_median_pred = self
243 .equation
244 .simulate_subject(&subject, &post_median_spp, None)?
245 .0
246 .get_predictions()
247 .clone();
248 assert_eq!(
249 pop_mean_pred.len(),
250 pop_median_pred.len(),
251 "The number of predictions do not match (pop_mean vs pop_median)"
252 );
253
254 assert_eq!(
255 post_mean_pred.len(),
256 post_median_pred.len(),
257 "The number of predictions do not match (post_mean vs post_median)"
258 );
259
260 assert_eq!(
261 pop_mean_pred.len(),
262 post_mean_pred.len(),
263 "The number of predictions do not match (pop_mean vs post_mean)"
264 );
265
266 for (((pop_mean_pred, pop_median_pred), post_mean_pred), post_median_pred) in
267 pop_mean_pred
268 .iter()
269 .zip(pop_median_pred.iter())
270 .zip(post_mean_pred.iter())
271 .zip(post_median_pred.iter())
272 {
273 let row = Row {
274 id: id.clone(),
275 time: pop_mean_pred.time(),
276 outeq: pop_mean_pred.outeq(),
277 block: occ,
278 obs: pop_mean_pred.observation(),
279 pop_mean: pop_mean_pred.prediction(),
280 pop_median: pop_median_pred.prediction(),
281 post_mean: post_mean_pred.prediction(),
282 post_median: post_median_pred.prediction(),
283 };
284 writer.serialize(row)?;
285 }
286 }
287 }
288 writer.flush()?;
289 tracing::debug!(
290 "Observations with predictions written to {:?}",
291 &outputfile.relative_path()
292 );
293 Ok(())
294 }
295
296 pub fn write_theta(&self) -> Result<()> {
299 tracing::debug!("Writing population parameter distribution...");
300
301 let theta = &self.theta;
302 let w: Vec<f64> = self
303 .w
304 .weights()
305 .clone()
306 .into_view()
307 .iter()
308 .cloned()
309 .collect();
310
311 if w.len() != theta.matrix().nrows() {
312 bail!(
313 "Number of weights ({}) and number of support points ({}) do not match.",
314 w.len(),
315 theta.matrix().nrows()
316 );
317 }
318
319 let outputfile = OutputFile::new(&self.settings.output().path, "theta.csv")
320 .context("Failed to create output file for theta")?;
321
322 let mut writer = WriterBuilder::new()
323 .has_headers(true)
324 .from_writer(&outputfile.file);
325
326 let mut theta_header = self.settings.parameters().names();
328 theta_header.push("prob".to_string());
329 writer.write_record(&theta_header)?;
330
331 for (theta_row, &w_val) in theta.matrix().row_iter().zip(w.iter()) {
333 let mut row: Vec<String> = theta_row.iter().map(|&val| val.to_string()).collect();
334 row.push(w_val.to_string());
335 writer.write_record(&row)?;
336 }
337 writer.flush()?;
338 tracing::debug!(
339 "Population parameter distribution written to {:?}",
340 &outputfile.relative_path()
341 );
342 Ok(())
343 }
344
345 pub fn write_posterior(&self) -> Result<()> {
347 tracing::debug!("Writing posterior parameter probabilities...");
348 let theta = &self.theta;
349
350 let posterior = self.posterior.clone();
352
353 let outputfile = match OutputFile::new(&self.settings.output().path, "posterior.csv") {
355 Ok(of) => of,
356 Err(e) => {
357 tracing::error!("Failed to create output file: {}", e);
358 return Err(e.context("Failed to create output file"));
359 }
360 };
361
362 let mut writer = WriterBuilder::new()
364 .has_headers(true)
365 .from_writer(&outputfile.file);
366
367 writer.write_field("id")?;
369 writer.write_field("point")?;
370 theta.param_names().iter().for_each(|name| {
371 writer.write_field(name).unwrap();
372 });
373 writer.write_field("prob")?;
374 writer.write_record(None::<&[u8]>)?;
375
376 let subjects = self.data.subjects();
378 posterior
379 .matrix()
380 .row_iter()
381 .enumerate()
382 .for_each(|(i, row)| {
383 let subject = subjects.get(i).unwrap();
384 let id = subject.id();
385
386 row.iter().enumerate().for_each(|(spp, prob)| {
387 writer.write_field(id.clone()).unwrap();
388 writer.write_field(spp.to_string()).unwrap();
389
390 theta.matrix().row(spp).iter().for_each(|val| {
391 writer.write_field(val.to_string()).unwrap();
392 });
393
394 writer.write_field(prob.to_string()).unwrap();
395 writer.write_record(None::<&[u8]>).unwrap();
396 });
397 });
398
399 writer.flush()?;
400 tracing::debug!(
401 "Posterior parameters written to {:?}",
402 &outputfile.relative_path()
403 );
404
405 Ok(())
406 }
407
408 pub fn write_predictions(&mut self, idelta: f64, tad: f64) -> Result<()> {
410 tracing::debug!("Writing predictions...");
411
412 self.calculate_predictions(idelta, tad)?;
413
414 let predictions = self
415 .predictions
416 .as_ref()
417 .expect("Predictions should have been calculated, but are of type None.");
418
419 let outputfile_pred = OutputFile::new(&self.settings.output().path, "pred.csv")?;
421 let mut writer = WriterBuilder::new()
422 .has_headers(true)
423 .from_writer(&outputfile_pred.file);
424
425 for row in predictions.predictions() {
427 writer.serialize(row)?;
428 }
429
430 writer.flush()?;
431 tracing::debug!(
432 "Predictions written to {:?}",
433 &outputfile_pred.relative_path()
434 );
435
436 Ok(())
437 }
438
439 pub fn write_covs(&self) -> Result<()> {
441 tracing::debug!("Writing covariates...");
442 let outputfile = OutputFile::new(&self.settings.output().path, "covs.csv")?;
443 let mut writer = WriterBuilder::new()
444 .has_headers(true)
445 .from_writer(&outputfile.file);
446
447 let mut covariate_names = std::collections::HashSet::new();
449 for subject in self.data.subjects() {
450 for occasion in subject.occasions() {
451 let cov = occasion.covariates();
452 let covmap = cov.covariates();
453 for cov_name in covmap.keys() {
454 covariate_names.insert(cov_name.clone());
455 }
456 }
457 }
458 let mut covariate_names: Vec<String> = covariate_names.into_iter().collect();
459 covariate_names.sort(); let mut headers = vec!["id", "time", "block"];
463 headers.extend(covariate_names.iter().map(|s| s.as_str()));
464 writer.write_record(&headers)?;
465
466 for subject in self.data.subjects() {
468 for occasion in subject.occasions() {
469 let cov = occasion.covariates();
470 let covmap = cov.covariates();
471
472 for event in occasion.iter() {
473 let time = match event {
474 Event::Bolus(bolus) => bolus.time(),
475 Event::Infusion(infusion) => infusion.time(),
476 Event::Observation(observation) => observation.time(),
477 };
478
479 let mut row: Vec<String> = Vec::new();
480 row.push(subject.id().clone());
481 row.push(time.to_string());
482 row.push(occasion.index().to_string());
483
484 for cov_name in &covariate_names {
486 if let Some(cov) = covmap.get(cov_name) {
487 if let Ok(value) = cov.interpolate(time) {
488 row.push(value.to_string());
489 } else {
490 row.push(String::new());
491 }
492 } else {
493 row.push(String::new());
494 }
495 }
496
497 writer.write_record(&row)?;
498 }
499 }
500 }
501
502 writer.flush()?;
503 tracing::debug!("Covariates written to {:?}", &outputfile.relative_path());
504 Ok(())
505 }
506}
507
/// Returns the (unweighted) median of `data`.
///
/// The slice is copied and sorted, so the caller's data is untouched. With an even number of
/// elements the two middle values are averaged.
///
/// Values are ordered with [`f64::total_cmp`], so NaN entries no longer cause a panic; they
/// are placed according to IEEE-754 total order. An empty slice yields `f64::NAN` instead of
/// panicking.
pub(crate) fn median(data: &[f64]) -> f64 {
    let mut sorted = data.to_vec();
    sorted.sort_by(f64::total_cmp);

    let n = sorted.len();
    if n == 0 {
        return f64::NAN;
    }

    let mid = n / 2;
    if n % 2 == 0 {
        (sorted[mid - 1] + sorted[mid]) / 2.0
    } else {
        sorted[mid]
    }
}
522
/// Weighted median of `data` under non-negative `weights`.
///
/// The data is sorted, and the first value at which the cumulative weight exceeds half of the
/// total is returned. If the cumulative weight lands exactly on one half, the result is the
/// average of that value and the next value that carries positive weight. Zero-weight
/// neighbours are never averaged in.
///
/// The "exactly half" test tolerates floating-point summation error (`n · ε · total`). For
/// example, weights `[0.3, 0.1, 0.2]` split the mass evenly after the first value, even
/// though `0.3 != (0.3 + 0.1 + 0.2) / 2` in `f64`.
///
/// Special cases:
/// * empty input returns `f64::NAN`;
/// * if every weight is zero, all points are weighted equally (plain median);
/// * NaN data no longer panics (values are ordered with [`f64::total_cmp`]).
///
/// # Panics
/// Panics if the slices differ in length or any weight is negative (or NaN).
fn weighted_median(data: &[f64], weights: &[f64]) -> f64 {
    assert_eq!(
        data.len(),
        weights.len(),
        "The length of data and weights must be the same"
    );
    assert!(
        weights.iter().all(|&x| x >= 0.0),
        "Weights must be non-negative, weights: {:?}",
        weights
    );

    if data.is_empty() {
        return f64::NAN;
    }

    let mut pairs: Vec<(f64, f64)> = data
        .iter()
        .copied()
        .zip(weights.iter().copied())
        .collect();
    pairs.sort_by(|a, b| a.0.total_cmp(&b.0));

    // No weighted information at all: fall back to uniform weights (the plain median).
    if pairs.iter().all(|&(_, w)| w == 0.0) {
        pairs.iter_mut().for_each(|pair| pair.1 = 1.0);
    }

    // Sum in the same (sorted) order as the cumulative sum below, so the final cumulative
    // value equals `total` exactly and the loop is guaranteed to return.
    let total: f64 = pairs.iter().map(|&(_, w)| w).sum();
    let half = total / 2.0;
    let tolerance = f64::EPSILON * total * pairs.len() as f64;

    let mut cumulative = 0.0;
    for (i, &(value, weight)) in pairs.iter().enumerate() {
        cumulative += weight;

        if cumulative == half || (cumulative - half).abs() <= tolerance {
            // Exactly half the mass lies at or below `value`: average with the next value
            // that actually carries weight.
            return match pairs[i + 1..].iter().find(|&&(_, w)| w > 0.0) {
                Some(&(next, _)) => (value + next) / 2.0,
                None => value,
            };
        }
        if cumulative > half {
            return value;
        }
    }

    unreachable!("The function should have returned a value before reaching this point.");
}
567
568pub fn population_mean_median(
569 theta: &Array2<f64>,
570 w: &Array1<f64>,
571) -> Result<(Array1<f64>, Array1<f64>)> {
572 let w = if w.is_empty() {
573 tracing::warn!("w.len() == 0, setting all weights to 1/n");
574 Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
575 } else {
576 w.clone()
577 };
578 if theta.nrows() != w.len() {
580 bail!(
581 "Number of parameters and number of weights do not match. Theta: {}, w: {}",
582 theta.nrows(),
583 w.len()
584 );
585 }
586
587 let mut mean = Array1::zeros(theta.ncols());
588 let mut median = Array1::zeros(theta.ncols());
589
590 for (i, (mn, mdn)) in mean.iter_mut().zip(&mut median).enumerate() {
591 let col = theta.column(i).to_owned() * w.to_owned();
593 *mn = col.sum();
594
595 let ct = theta.column(i);
597 let mut params = vec![];
598 let mut weights = vec![];
599 for (ti, wi) in ct.iter().zip(w.clone()) {
600 params.push(*ti);
601 weights.push(wi);
602 }
603
604 *mdn = weighted_median(¶ms, &weights);
605 }
606
607 Ok((mean, median))
608}
609
610pub fn posterior_mean_median(
611 theta: &Array2<f64>,
612 psi: &Array2<f64>,
613 w: &Array1<f64>,
614) -> Result<(Array2<f64>, Array2<f64>)> {
615 let mut mean = Array2::zeros((0, theta.ncols()));
616 let mut median = Array2::zeros((0, theta.ncols()));
617
618 let w = if w.is_empty() {
619 tracing::warn!("w is empty, setting all weights to 1/n");
620 Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
621 } else {
622 w.clone()
623 };
624
625 if theta.nrows() != w.len() || theta.nrows() != psi.ncols() || psi.ncols() != w.len() {
627 bail!("Number of parameters and number of weights do not match, theta.nrows(): {}, w.len(): {}, psi.ncols(): {}", theta.nrows(), w.len(), psi.ncols());
628 }
629
630 let mut psi_norm: Array2<f64> = Array2::zeros((0, psi.ncols()));
632 for (i, row) in psi.axis_iter(Axis(0)).enumerate() {
633 let row_w = row.to_owned() * w.to_owned();
634 let row_sum = row_w.sum();
635 let row_norm = if row_sum == 0.0 {
636 tracing::warn!("Sum of row {} of psi is 0.0, setting that row to 1/n", i);
637 Array1::from_elem(psi.ncols(), 1.0 / psi.ncols() as f64)
638 } else {
639 &row_w / row_sum
640 };
641 psi_norm.push_row(row_norm.view())?;
642 }
643 if psi_norm.iter().any(|&x| x.is_nan()) {
644 dbg!(&psi);
645 bail!("NaN values found in psi_norm");
646 };
647
648 for probs in psi_norm.axis_iter(Axis(0)) {
653 let mut post_mean: Vec<f64> = Vec::new();
654 let mut post_median: Vec<f64> = Vec::new();
655
656 for pars in theta.axis_iter(Axis(1)) {
658 let weighted_par = &probs * &pars;
660 let the_mean = weighted_par.sum();
661 post_mean.push(the_mean);
662
663 let median = weighted_median(&pars.to_vec(), &probs.to_vec());
665 post_median.push(median);
666 }
667
668 mean.push_row(Array::from(post_mean.clone()).view())?;
669 median.push_row(Array::from(post_median.clone()).view())?;
670 }
671
672 Ok((mean, median))
673}
674
675#[derive(Debug)]
677pub struct OutputFile {
678 file: File,
679 relative_path: PathBuf,
680}
681
682impl OutputFile {
683 pub fn new(folder: &str, file_name: &str) -> Result<Self> {
684 let relative_path = Path::new(&folder).join(file_name);
685
686 if let Some(parent) = relative_path.parent() {
687 create_dir_all(parent)
688 .with_context(|| format!("Failed to create directories for {:?}", parent))?;
689 }
690
691 let file = OpenOptions::new()
692 .write(true)
693 .create(true)
694 .truncate(true)
695 .open(&relative_path)
696 .with_context(|| format!("Failed to open file: {:?}", relative_path))?;
697
698 Ok(OutputFile {
699 file,
700 relative_path,
701 })
702 }
703
704 pub fn file(&self) -> &File {
705 &self.file
706 }
707
708 pub fn file_owned(self) -> File {
709 self.file
710 }
711
712 pub fn relative_path(&self) -> &Path {
713 &self.relative_path
714 }
715}
716
#[cfg(test)]
mod tests {
    use super::{median, weighted_median};

    // ---- median ---------------------------------------------------------------------

    #[test]
    fn test_median_odd() {
        assert_eq!(median(&[1.0, 3.0, 2.0]), 2.0);
    }

    #[test]
    fn test_median_even() {
        assert_eq!(median(&[1.0, 2.0, 3.0, 4.0]), 2.5);
    }

    #[test]
    fn test_median_single() {
        assert_eq!(median(&[42.0]), 42.0);
    }

    #[test]
    fn test_median_sorted() {
        assert_eq!(median(&[5.0, 10.0, 15.0, 20.0, 25.0]), 15.0);
    }

    #[test]
    fn test_median_unsorted() {
        assert_eq!(median(&[10.0, 30.0, 20.0, 50.0, 40.0]), 30.0);
    }

    #[test]
    fn test_median_with_duplicates() {
        assert_eq!(median(&[1.0, 2.0, 2.0, 3.0, 4.0]), 2.0);
    }

    // ---- weighted_median ------------------------------------------------------------

    #[test]
    fn test_weighted_median_simple() {
        assert_eq!(weighted_median(&[1.0, 2.0, 3.0], &[0.2, 0.5, 0.3]), 2.0);
    }

    #[test]
    fn test_weighted_median_even_weights() {
        let values = [1.0, 2.0, 3.0, 4.0];
        let probs = [0.25; 4];
        assert_eq!(weighted_median(&values, &probs), 2.5);
    }

    #[test]
    fn test_weighted_median_single_element() {
        assert_eq!(weighted_median(&[42.0], &[1.0]), 42.0);
    }

    #[test]
    #[should_panic(expected = "The length of data and weights must be the same")]
    fn test_weighted_median_mismatched_lengths() {
        weighted_median(&[1.0, 2.0, 3.0], &[0.1, 0.2]);
    }

    #[test]
    fn test_weighted_median_all_same_elements() {
        let values = [5.0; 4];
        let probs = [0.1, 0.2, 0.3, 0.4];
        assert_eq!(weighted_median(&values, &probs), 5.0);
    }

    #[test]
    #[should_panic(expected = "Weights must be non-negative")]
    fn test_weighted_median_negative_weights() {
        let values = [1.0, 2.0, 3.0, 4.0];
        let probs = [0.2, -0.5, 0.5, 0.8];
        assert_eq!(weighted_median(&values, &probs), 4.0);
    }

    #[test]
    fn test_weighted_median_unsorted_data() {
        let values = [3.0, 1.0, 4.0, 2.0];
        let probs = [0.1, 0.3, 0.4, 0.2];
        assert_eq!(weighted_median(&values, &probs), 2.5);
    }

    #[test]
    fn test_weighted_median_with_zero_weights() {
        let values = [1.0, 2.0, 3.0, 4.0];
        let probs = [0.0, 0.0, 1.0, 0.0];
        assert_eq!(weighted_median(&values, &probs), 3.0);
    }
}
816}