1use crate::algorithms::{Status, StopReason};
2use crate::prelude::*;
3use crate::routines::output::cycles::CycleLog;
4use crate::routines::output::predictions::NPPredictions;
5use crate::routines::settings::Settings;
6use crate::structs::psi::Psi;
7use crate::structs::theta::Theta;
8use crate::structs::weights::Weights;
9use anyhow::{bail, Context, Result};
10use csv::WriterBuilder;
11use faer::linalg::zip::IntoView;
12use faer_ext::IntoNdarray;
13use ndarray::{Array, Array1, Array2, Axis};
14use pharmsol::prelude::data::*;
15use pharmsol::prelude::simulator::Equation;
16use serde::Serialize;
17use std::fs::{create_dir_all, File, OpenOptions};
18use std::path::{Path, PathBuf};
19
20pub mod cycles;
21pub mod posterior;
22pub mod predictions;
23
24use posterior::posterior;
25
26#[derive(Debug, Serialize)]
29pub struct NPResult<E: Equation> {
30 #[serde(skip)]
31 equation: E,
32 data: Data,
33 theta: Theta,
34 psi: Psi,
35 w: Weights,
36 objf: f64,
37 cycles: usize,
38 status: Status,
39 par_names: Vec<String>,
40 settings: Settings,
41 cyclelog: CycleLog,
42}
43
44#[allow(clippy::too_many_arguments)]
45impl<E: Equation> NPResult<E> {
46 pub fn new(
48 equation: E,
49 data: Data,
50 theta: Theta,
51 psi: Psi,
52 w: Weights,
53 objf: f64,
54 cycles: usize,
55 status: Status,
56 settings: Settings,
57 cyclelog: CycleLog,
58 ) -> Self {
59 let par_names = settings.parameters().names();
62
63 Self {
64 equation,
65 data,
66 theta,
67 psi,
68 w,
69 objf,
70 cycles,
71 status,
72 par_names,
73 settings,
74 cyclelog,
75 }
76 }
77
78 pub fn cycles(&self) -> usize {
79 self.cycles
80 }
81
82 pub fn objf(&self) -> f64 {
83 self.objf
84 }
85
86 pub fn converged(&self) -> bool {
87 self.status == Status::Stop(StopReason::Converged)
88 }
89
90 pub fn get_theta(&self) -> &Theta {
91 &self.theta
92 }
93
94 pub fn psi(&self) -> &Psi {
96 &self.psi
97 }
98
99 pub fn weights(&self) -> &Weights {
101 &self.w
102 }
103
104 pub fn write_outputs(&self) -> Result<()> {
105 if self.settings.output().write {
106 tracing::debug!("Writing outputs to {:?}", self.settings.output().path);
107 self.settings.write()?;
108 let idelta: f64 = self.settings.predictions().idelta;
109 let tad = self.settings.predictions().tad;
110 self.cyclelog.write(&self.settings)?;
111 self.write_theta().context("Failed to write theta")?;
112 self.write_covs().context("Failed to write covariates")?;
113 self.write_predictions(idelta, tad)
114 .context("Failed to write predictions")?;
115 self.write_posterior()
116 .context("Failed to write posterior")?;
117 }
118 Ok(())
119 }
120
121 pub fn write_obspred(&self) -> Result<()> {
123 tracing::debug!("Writing observations and predictions...");
124
125 #[derive(Debug, Clone, Serialize)]
126 struct Row {
127 id: String,
128 time: f64,
129 outeq: usize,
130 block: usize,
131 obs: Option<f64>,
132 pop_mean: f64,
133 pop_median: f64,
134 post_mean: f64,
135 post_median: f64,
136 }
137
138 let theta: Array2<f64> = self
139 .theta
140 .matrix()
141 .clone()
142 .as_mut()
143 .into_ndarray()
144 .to_owned();
145 let w: Array1<f64> = self
146 .w
147 .weights()
148 .clone()
149 .into_view()
150 .iter()
151 .cloned()
152 .collect();
153 let psi: Array2<f64> = self.psi.matrix().as_ref().into_ndarray().to_owned();
154
155 let (post_mean, post_median) = posterior_mean_median(&theta, &psi, &w)
156 .context("Failed to calculate posterior mean and median")?;
157
158 let (pop_mean, pop_median) = population_mean_median(&theta, &w)
159 .context("Failed to calculate posterior mean and median")?;
160
161 let subjects = self.data.subjects();
162 if subjects.len() != post_mean.nrows() {
163 bail!(
164 "Number of subjects: {} and number of posterior means: {} do not match",
165 subjects.len(),
166 post_mean.nrows()
167 );
168 }
169
170 let outputfile = OutputFile::new(&self.settings.output().path, "op.csv")?;
171 let mut writer = WriterBuilder::new()
172 .has_headers(true)
173 .from_writer(&outputfile.file);
174
175 for (i, subject) in subjects.iter().enumerate() {
176 for occasion in subject.occasions() {
177 let id = subject.id();
178 let occ = occasion.index();
179
180 let subject = Subject::from_occasions(id.clone(), vec![occasion.clone()]);
181
182 let pop_mean_pred = self
184 .equation
185 .simulate_subject(&subject, &pop_mean.to_vec(), None)?
186 .0
187 .get_predictions()
188 .clone();
189
190 let pop_median_pred = self
191 .equation
192 .simulate_subject(&subject, &pop_median.to_vec(), None)?
193 .0
194 .get_predictions()
195 .clone();
196
197 let post_mean_spp: Vec<f64> = post_mean.row(i).to_vec();
199 let post_mean_pred = self
200 .equation
201 .simulate_subject(&subject, &post_mean_spp, None)?
202 .0
203 .get_predictions()
204 .clone();
205 let post_median_spp: Vec<f64> = post_median.row(i).to_vec();
206 let post_median_pred = self
207 .equation
208 .simulate_subject(&subject, &post_median_spp, None)?
209 .0
210 .get_predictions()
211 .clone();
212 assert_eq!(
213 pop_mean_pred.len(),
214 pop_median_pred.len(),
215 "The number of predictions do not match (pop_mean vs pop_median)"
216 );
217
218 assert_eq!(
219 post_mean_pred.len(),
220 post_median_pred.len(),
221 "The number of predictions do not match (post_mean vs post_median)"
222 );
223
224 assert_eq!(
225 pop_mean_pred.len(),
226 post_mean_pred.len(),
227 "The number of predictions do not match (pop_mean vs post_mean)"
228 );
229
230 for (((pop_mean_pred, pop_median_pred), post_mean_pred), post_median_pred) in
231 pop_mean_pred
232 .iter()
233 .zip(pop_median_pred.iter())
234 .zip(post_mean_pred.iter())
235 .zip(post_median_pred.iter())
236 {
237 let row = Row {
238 id: id.clone(),
239 time: pop_mean_pred.time(),
240 outeq: pop_mean_pred.outeq(),
241 block: occ,
242 obs: pop_mean_pred.observation(),
243 pop_mean: pop_mean_pred.prediction(),
244 pop_median: pop_median_pred.prediction(),
245 post_mean: post_mean_pred.prediction(),
246 post_median: post_median_pred.prediction(),
247 };
248 writer.serialize(row)?;
249 }
250 }
251 }
252 writer.flush()?;
253 tracing::debug!(
254 "Observations with predictions written to {:?}",
255 &outputfile.relative_path()
256 );
257 Ok(())
258 }
259
260 pub fn write_theta(&self) -> Result<()> {
263 tracing::debug!("Writing population parameter distribution...");
264
265 let theta = &self.theta;
266 let w: Vec<f64> = self
267 .w
268 .weights()
269 .clone()
270 .into_view()
271 .iter()
272 .cloned()
273 .collect();
274
275 if w.len() != theta.matrix().nrows() {
276 bail!(
277 "Number of weights ({}) and number of support points ({}) do not match.",
278 w.len(),
279 theta.matrix().nrows()
280 );
281 }
282
283 let outputfile = OutputFile::new(&self.settings.output().path, "theta.csv")
284 .context("Failed to create output file for theta")?;
285
286 let mut writer = WriterBuilder::new()
287 .has_headers(true)
288 .from_writer(&outputfile.file);
289
290 let mut theta_header = self.par_names.clone();
292 theta_header.push("prob".to_string());
293 writer.write_record(&theta_header)?;
294
295 for (theta_row, &w_val) in theta.matrix().row_iter().zip(w.iter()) {
297 let mut row: Vec<String> = theta_row.iter().map(|&val| val.to_string()).collect();
298 row.push(w_val.to_string());
299 writer.write_record(&row)?;
300 }
301 writer.flush()?;
302 tracing::debug!(
303 "Population parameter distribution written to {:?}",
304 &outputfile.relative_path()
305 );
306 Ok(())
307 }
308
309 pub fn write_posterior(&self) -> Result<()> {
311 tracing::debug!("Writing posterior parameter probabilities...");
312 let theta = &self.theta;
313 let w = &self.w;
314 let psi = &self.psi;
315
316 let posterior = posterior(psi, w)?;
318
319 let outputfile = match OutputFile::new(&self.settings.output().path, "posterior.csv") {
321 Ok(of) => of,
322 Err(e) => {
323 tracing::error!("Failed to create output file: {}", e);
324 return Err(e.context("Failed to create output file"));
325 }
326 };
327
328 let mut writer = WriterBuilder::new()
330 .has_headers(true)
331 .from_writer(&outputfile.file);
332
333 writer.write_field("id")?;
335 writer.write_field("point")?;
336 theta.param_names().iter().for_each(|name| {
337 writer.write_field(name).unwrap();
338 });
339 writer.write_field("prob")?;
340 writer.write_record(None::<&[u8]>)?;
341
342 let subjects = self.data.subjects();
344 posterior
345 .matrix()
346 .row_iter()
347 .enumerate()
348 .for_each(|(i, row)| {
349 let subject = subjects.get(i).unwrap();
350 let id = subject.id();
351
352 row.iter().enumerate().for_each(|(spp, prob)| {
353 writer.write_field(id.clone()).unwrap();
354 writer.write_field(spp.to_string()).unwrap();
355
356 theta.matrix().row(spp).iter().for_each(|val| {
357 writer.write_field(val.to_string()).unwrap();
358 });
359
360 writer.write_field(prob.to_string()).unwrap();
361 writer.write_record(None::<&[u8]>).unwrap();
362 });
363 });
364
365 writer.flush()?;
366 tracing::debug!(
367 "Posterior parameters written to {:?}",
368 &outputfile.relative_path()
369 );
370
371 Ok(())
372 }
373
374 pub fn write_predictions(&self, idelta: f64, tad: f64) -> Result<()> {
376 tracing::debug!("Writing predictions...");
377
378 let posterior = posterior(&self.psi, &self.w)?;
379
380 let predictions = NPPredictions::calculate(
382 &self.equation,
383 &self.data,
384 self.theta.clone(),
385 &self.w,
386 &posterior,
387 idelta,
388 tad,
389 )?;
390
391 let outputfile_pred = OutputFile::new(&self.settings.output().path, "pred.csv")?;
393 let mut writer = WriterBuilder::new()
394 .has_headers(true)
395 .from_writer(&outputfile_pred.file);
396
397 for row in predictions.predictions() {
399 writer.serialize(row)?;
400 }
401
402 writer.flush()?;
403 tracing::debug!(
404 "Predictions written to {:?}",
405 &outputfile_pred.relative_path()
406 );
407
408 Ok(())
409 }
410
411 pub fn write_covs(&self) -> Result<()> {
413 tracing::debug!("Writing covariates...");
414 let outputfile = OutputFile::new(&self.settings.output().path, "covs.csv")?;
415 let mut writer = WriterBuilder::new()
416 .has_headers(true)
417 .from_writer(&outputfile.file);
418
419 let mut covariate_names = std::collections::HashSet::new();
421 for subject in self.data.subjects() {
422 for occasion in subject.occasions() {
423 let cov = occasion.covariates();
424 let covmap = cov.covariates();
425 for cov_name in covmap.keys() {
426 covariate_names.insert(cov_name.clone());
427 }
428 }
429 }
430 let mut covariate_names: Vec<String> = covariate_names.into_iter().collect();
431 covariate_names.sort(); let mut headers = vec!["id", "time", "block"];
435 headers.extend(covariate_names.iter().map(|s| s.as_str()));
436 writer.write_record(&headers)?;
437
438 for subject in self.data.subjects() {
440 for occasion in subject.occasions() {
441 let cov = occasion.covariates();
442 let covmap = cov.covariates();
443
444 for event in occasion.iter() {
445 let time = match event {
446 Event::Bolus(bolus) => bolus.time(),
447 Event::Infusion(infusion) => infusion.time(),
448 Event::Observation(observation) => observation.time(),
449 };
450
451 let mut row: Vec<String> = Vec::new();
452 row.push(subject.id().clone());
453 row.push(time.to_string());
454 row.push(occasion.index().to_string());
455
456 for cov_name in &covariate_names {
458 if let Some(cov) = covmap.get(cov_name) {
459 if let Ok(value) = cov.interpolate(time) {
460 row.push(value.to_string());
461 } else {
462 row.push(String::new());
463 }
464 } else {
465 row.push(String::new());
466 }
467 }
468
469 writer.write_record(&row)?;
470 }
471 }
472 }
473
474 writer.flush()?;
475 tracing::debug!("Covariates written to {:?}", &outputfile.relative_path());
476 Ok(())
477 }
478}
479
/// Returns the median of `data`: the middle value for an odd count, or the mean of the two
/// middle values for an even count.
///
/// Values are ordered with [`f64::total_cmp`], so NaN inputs do not panic (positive NaNs sort
/// after every number). An empty slice has no median and yields `f64::NAN`.
pub(crate) fn median(data: &[f64]) -> f64 {
    if data.is_empty() {
        return f64::NAN;
    }

    let mut sorted: Vec<f64> = data.to_vec();
    sorted.sort_unstable_by(f64::total_cmp);

    let mid = sorted.len() / 2;
    if sorted.len() % 2 == 0 {
        (sorted[mid - 1] + sorted[mid]) / 2.0
    } else {
        sorted[mid]
    }
}
494
/// Returns the weighted median of `data` under the non-negative `weights`.
///
/// Values are sorted (NaN-safe, via [`f64::total_cmp`]), and their weights are accumulated until
/// the running total reaches half of the overall weight. If the running total lands exactly on
/// that half, the result is the midpoint between the current value and the next value that
/// carries *positive* weight. Zero-weight values cannot pull the median toward themselves.
///
/// Edge cases:
/// * an empty input yields `f64::NAN`;
/// * if every weight is zero, all values are treated as equally weighted.
///
/// # Panics
/// Panics if `data` and `weights` differ in length, or if any weight is negative or NaN.
fn weighted_median(data: &[f64], weights: &[f64]) -> f64 {
    assert_eq!(
        data.len(),
        weights.len(),
        "The length of data and weights must be the same"
    );
    assert!(
        weights.iter().all(|&x| x >= 0.0),
        "Weights must be non-negative, weights: {:?}",
        weights
    );

    if data.is_empty() {
        return f64::NAN;
    }

    let mut weighted_data: Vec<(f64, f64)> = data
        .iter()
        .copied()
        .zip(weights.iter().copied())
        .collect();
    weighted_data.sort_by(|a, b| a.0.total_cmp(&b.0));

    // Sum in the same (sorted) order as the running total below, so an exact half-way tie is
    // detected consistently.
    let total_weight: f64 = weighted_data.iter().map(|&(_, w)| w).sum();
    if total_weight == 0.0 {
        // Degenerate: no weight information at all, so fall back to equal weights.
        let uniform = vec![1.0; data.len()];
        return weighted_median(data, &uniform);
    }
    let half = total_weight / 2.0;

    let mut cumulative_sum = 0.0;
    for (i, &(value, weight)) in weighted_data.iter().enumerate() {
        cumulative_sum += weight;

        if cumulative_sum == half {
            // Exactly half the mass lies at or below `value`; average with the next value that
            // actually carries weight (skipping zero-weight points).
            let next = weighted_data[i + 1..].iter().find(|&&(_, w)| w > 0.0);
            return match next {
                Some(&(next_value, _)) => (value + next_value) / 2.0,
                None => value,
            };
        }
        if cumulative_sum > half {
            return value;
        }
    }

    unreachable!("The function should have returned a value before reaching this point.");
}
539
540pub fn population_mean_median(
541 theta: &Array2<f64>,
542 w: &Array1<f64>,
543) -> Result<(Array1<f64>, Array1<f64>)> {
544 let w = if w.is_empty() {
545 tracing::warn!("w.len() == 0, setting all weights to 1/n");
546 Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
547 } else {
548 w.clone()
549 };
550 if theta.nrows() != w.len() {
552 bail!(
553 "Number of parameters and number of weights do not match. Theta: {}, w: {}",
554 theta.nrows(),
555 w.len()
556 );
557 }
558
559 let mut mean = Array1::zeros(theta.ncols());
560 let mut median = Array1::zeros(theta.ncols());
561
562 for (i, (mn, mdn)) in mean.iter_mut().zip(&mut median).enumerate() {
563 let col = theta.column(i).to_owned() * w.to_owned();
565 *mn = col.sum();
566
567 let ct = theta.column(i);
569 let mut params = vec![];
570 let mut weights = vec![];
571 for (ti, wi) in ct.iter().zip(w.clone()) {
572 params.push(*ti);
573 weights.push(wi);
574 }
575
576 *mdn = weighted_median(¶ms, &weights);
577 }
578
579 Ok((mean, median))
580}
581
582pub fn posterior_mean_median(
583 theta: &Array2<f64>,
584 psi: &Array2<f64>,
585 w: &Array1<f64>,
586) -> Result<(Array2<f64>, Array2<f64>)> {
587 let mut mean = Array2::zeros((0, theta.ncols()));
588 let mut median = Array2::zeros((0, theta.ncols()));
589
590 let w = if w.is_empty() {
591 tracing::warn!("w is empty, setting all weights to 1/n");
592 Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
593 } else {
594 w.clone()
595 };
596
597 if theta.nrows() != w.len() || theta.nrows() != psi.ncols() || psi.ncols() != w.len() {
599 bail!("Number of parameters and number of weights do not match, theta.nrows(): {}, w.len(): {}, psi.ncols(): {}", theta.nrows(), w.len(), psi.ncols());
600 }
601
602 let mut psi_norm: Array2<f64> = Array2::zeros((0, psi.ncols()));
604 for (i, row) in psi.axis_iter(Axis(0)).enumerate() {
605 let row_w = row.to_owned() * w.to_owned();
606 let row_sum = row_w.sum();
607 let row_norm = if row_sum == 0.0 {
608 tracing::warn!("Sum of row {} of psi is 0.0, setting that row to 1/n", i);
609 Array1::from_elem(psi.ncols(), 1.0 / psi.ncols() as f64)
610 } else {
611 &row_w / row_sum
612 };
613 psi_norm.push_row(row_norm.view())?;
614 }
615 if psi_norm.iter().any(|&x| x.is_nan()) {
616 dbg!(&psi);
617 bail!("NaN values found in psi_norm");
618 };
619
620 for probs in psi_norm.axis_iter(Axis(0)) {
625 let mut post_mean: Vec<f64> = Vec::new();
626 let mut post_median: Vec<f64> = Vec::new();
627
628 for pars in theta.axis_iter(Axis(1)) {
630 let weighted_par = &probs * &pars;
632 let the_mean = weighted_par.sum();
633 post_mean.push(the_mean);
634
635 let median = weighted_median(&pars.to_vec(), &probs.to_vec());
637 post_median.push(median);
638 }
639
640 mean.push_row(Array::from(post_mean.clone()).view())?;
641 median.push_row(Array::from(post_median.clone()).view())?;
642 }
643
644 Ok((mean, median))
645}
646
/// A file opened for writing in the output folder, together with the path it was opened at.
#[derive(Debug)]
pub struct OutputFile {
    /// Handle to the opened file.
    file: File,
    /// The output folder joined with the file name.
    relative_path: PathBuf,
}
653
654impl OutputFile {
655 pub fn new(folder: &str, file_name: &str) -> Result<Self> {
656 let relative_path = Path::new(&folder).join(file_name);
657
658 if let Some(parent) = relative_path.parent() {
659 create_dir_all(parent)
660 .with_context(|| format!("Failed to create directories for {:?}", parent))?;
661 }
662
663 let file = OpenOptions::new()
664 .write(true)
665 .create(true)
666 .truncate(true)
667 .open(&relative_path)
668 .with_context(|| format!("Failed to open file: {:?}", relative_path))?;
669
670 Ok(OutputFile {
671 file,
672 relative_path,
673 })
674 }
675
676 pub fn file(&self) -> &File {
677 &self.file
678 }
679
680 pub fn file_owned(self) -> File {
681 self.file
682 }
683
684 pub fn relative_path(&self) -> &Path {
685 &self.relative_path
686 }
687}
688
#[cfg(test)]
mod tests {
    use super::{median, weighted_median};

    // ---- median -------------------------------------------------------------

    #[test]
    fn test_median_odd() {
        assert_eq!(median(&[1.0, 3.0, 2.0]), 2.0);
    }

    #[test]
    fn test_median_even() {
        assert_eq!(median(&[1.0, 2.0, 3.0, 4.0]), 2.5);
    }

    #[test]
    fn test_median_single() {
        assert_eq!(median(&[42.0]), 42.0);
    }

    #[test]
    fn test_median_sorted() {
        assert_eq!(median(&[5.0, 10.0, 15.0, 20.0, 25.0]), 15.0);
    }

    #[test]
    fn test_median_unsorted() {
        assert_eq!(median(&[10.0, 30.0, 20.0, 50.0, 40.0]), 30.0);
    }

    #[test]
    fn test_median_with_duplicates() {
        assert_eq!(median(&[1.0, 2.0, 2.0, 3.0, 4.0]), 2.0);
    }

    // ---- weighted_median ----------------------------------------------------

    #[test]
    fn test_weighted_median_simple() {
        assert_eq!(weighted_median(&[1.0, 2.0, 3.0], &[0.2, 0.5, 0.3]), 2.0);
    }

    #[test]
    fn test_weighted_median_even_weights() {
        let weights = [0.25; 4];
        assert_eq!(weighted_median(&[1.0, 2.0, 3.0, 4.0], &weights), 2.5);
    }

    #[test]
    fn test_weighted_median_single_element() {
        assert_eq!(weighted_median(&[42.0], &[1.0]), 42.0);
    }

    #[test]
    #[should_panic(expected = "The length of data and weights must be the same")]
    fn test_weighted_median_mismatched_lengths() {
        weighted_median(&[1.0, 2.0, 3.0], &[0.1, 0.2]);
    }

    #[test]
    fn test_weighted_median_all_same_elements() {
        let values = [5.0; 4];
        assert_eq!(weighted_median(&values, &[0.1, 0.2, 0.3, 0.4]), 5.0);
    }

    #[test]
    #[should_panic(expected = "Weights must be non-negative")]
    fn test_weighted_median_negative_weights() {
        // The assertion on the weights fires before any value is returned.
        let _ = weighted_median(&[1.0, 2.0, 3.0, 4.0], &[0.2, -0.5, 0.5, 0.8]);
    }

    #[test]
    fn test_weighted_median_unsorted_data() {
        assert_eq!(
            weighted_median(&[3.0, 1.0, 4.0, 2.0], &[0.1, 0.3, 0.4, 0.2]),
            2.5
        );
    }

    #[test]
    fn test_weighted_median_with_zero_weights() {
        assert_eq!(
            weighted_median(&[1.0, 2.0, 3.0, 4.0], &[0.0, 0.0, 1.0, 0.0]),
            3.0
        );
    }
}