1use crate::prelude::*;
2use crate::structs::psi::Psi;
3use crate::structs::theta::Theta;
4use anyhow::{bail, Context, Result};
5use csv::WriterBuilder;
6use faer::linalg::zip::IntoView;
7use faer_ext::IntoNdarray;
8use ndarray::{Array, Array1, Array2, Axis};
9use pharmsol::prelude::data::*;
10use pharmsol::prelude::simulator::Equation;
11use serde::Serialize;
12use crate::routines::settings::Settings;
14use faer::{Col, Mat};
15use std::fs::{create_dir_all, File, OpenOptions};
16use std::path::{Path, PathBuf};
17
/// Outcome of a non-parametric run together with everything the CSV writers need.
///
/// Instances are built with [`NPResult::new`] and written to disk with
/// [`NPResult::write_outputs`].
#[derive(Debug)]
pub struct NPResult<E: Equation> {
    /// Model equation; used to simulate predictions in `write_obspred` and `write_pred`.
    equation: E,
    /// Subject data whose observations, covariates and predictions are written out.
    data: Data,
    /// Support points; each matrix row is paired with one entry of `w` when writing `theta.csv`.
    theta: Theta,
    /// Matrix whose rows are matched to subjects and whose columns are matched to the entries
    /// of `w` (see `posterior` and `write_posterior`).
    psi: Psi,
    /// Weight of each support point; may be empty, in which case the posterior file is skipped.
    w: Col<f64>,
    /// Objective function value, exposed through `objf()`.
    objf: f64,
    /// Cycle count, exposed through `cycles()`.
    cycles: usize,
    /// Convergence flag, exposed through `converged()`.
    converged: bool,
    /// Parameter names taken from `settings.parameters().names()` at construction time.
    par_names: Vec<String>,
    /// Run settings; provide the output folder and the prediction options (`idelta`, `tad`).
    settings: Settings,
    /// Per-cycle log, written to `cycles.csv` by `write_outputs`.
    cyclelog: CycleLog,
}
34
35#[allow(clippy::too_many_arguments)]
36impl<E: Equation> NPResult<E> {
37 pub fn new(
39 equation: E,
40 data: Data,
41 theta: Theta,
42 psi: Psi,
43 w: Col<f64>,
44 objf: f64,
45 cycles: usize,
46 converged: bool,
47 settings: Settings,
48 cyclelog: CycleLog,
49 ) -> Self {
50 let par_names = settings.parameters().names();
53
54 Self {
55 equation,
56 data,
57 theta,
58 psi,
59 w,
60 objf,
61 cycles,
62 converged,
63 par_names,
64 settings,
65 cyclelog,
66 }
67 }
68
69 pub fn cycles(&self) -> usize {
70 self.cycles
71 }
72
73 pub fn objf(&self) -> f64 {
74 self.objf
75 }
76
77 pub fn converged(&self) -> bool {
78 self.converged
79 }
80
81 pub fn get_theta(&self) -> &Theta {
82 &self.theta
83 }
84
85 pub fn get_psi(&self) -> &Psi {
86 &self.psi
87 }
88
89 pub fn get_w(&self) -> &Col<f64> {
90 &self.w
91 }
92
93 pub fn write_outputs(&self) -> Result<()> {
94 if self.settings.output().write {
95 self.settings.write()?;
96 let idelta: f64 = self.settings.predictions().idelta;
97 let tad = self.settings.predictions().tad;
98 self.cyclelog.write(&self.settings)?;
99 self.write_obs().context("Failed to write observations")?;
100 self.write_theta().context("Failed to write theta")?;
101 self.write_obspred()
102 .context("Failed to write observed-predicted file")?;
103 self.write_pred(idelta, tad)
104 .context("Failed to write predictions")?;
105 self.write_covs().context("Failed to write covariates")?;
106 if !(self.w.nrows() == 0) {
107 self.write_posterior()
108 .context("Failed to write posterior")?;
109 }
110 }
111 Ok(())
112 }
113
114 pub fn write_obspred(&self) -> Result<()> {
116 tracing::debug!("Writing observations and predictions...");
117
118 #[derive(Debug, Clone, Serialize)]
119 struct Row {
120 id: String,
121 time: f64,
122 outeq: usize,
123 block: usize,
124 obs: f64,
125 pop_mean: f64,
126 pop_median: f64,
127 post_mean: f64,
128 post_median: f64,
129 }
130
131 let theta: Array2<f64> = self
132 .theta
133 .matrix()
134 .clone()
135 .as_mut()
136 .into_ndarray()
137 .to_owned();
138 let w: Array1<f64> = self.w.clone().into_view().iter().cloned().collect();
139 let psi: Array2<f64> = self.psi.matrix().as_ref().into_ndarray().to_owned();
140
141 let (post_mean, post_median) = posterior_mean_median(&theta, &psi, &w)
142 .context("Failed to calculate posterior mean and median")?;
143
144 let (pop_mean, pop_median) = population_mean_median(&theta, &w)
145 .context("Failed to calculate posterior mean and median")?;
146
147 let subjects = self.data.get_subjects();
148 if subjects.len() != post_mean.nrows() {
149 bail!(
150 "Number of subjects: {} and number of posterior means: {} do not match",
151 subjects.len(),
152 post_mean.nrows()
153 );
154 }
155
156 let outputfile = OutputFile::new(&self.settings.output().path, "op.csv")?;
157 let mut writer = WriterBuilder::new()
158 .has_headers(true)
159 .from_writer(&outputfile.file);
160
161 for (i, subject) in subjects.iter().enumerate() {
162 for occasion in subject.occasions() {
163 let id = subject.id();
164 let occ = occasion.index();
165
166 let subject = Subject::from_occasions(id.clone(), vec![occasion.clone()]);
167
168 let pop_mean_pred = self
170 .equation
171 .simulate_subject(&subject, &pop_mean.to_vec(), None)
172 .0
173 .get_predictions()
174 .clone();
175
176 let pop_median_pred = self
177 .equation
178 .simulate_subject(&subject, &pop_median.to_vec(), None)
179 .0
180 .get_predictions()
181 .clone();
182
183 let post_mean_spp: Vec<f64> = post_mean.row(i).to_vec();
185 let post_mean_pred = self
186 .equation
187 .simulate_subject(&subject, &post_mean_spp, None)
188 .0
189 .get_predictions()
190 .clone();
191 let post_median_spp: Vec<f64> = post_median.row(i).to_vec();
192 let post_median_pred = self
193 .equation
194 .simulate_subject(&subject, &post_median_spp, None)
195 .0
196 .get_predictions()
197 .clone();
198 assert_eq!(
199 pop_mean_pred.len(),
200 pop_median_pred.len(),
201 "The number of predictions do not match (pop_mean vs pop_median)"
202 );
203
204 assert_eq!(
205 post_mean_pred.len(),
206 post_median_pred.len(),
207 "The number of predictions do not match (post_mean vs post_median)"
208 );
209
210 assert_eq!(
211 pop_mean_pred.len(),
212 post_mean_pred.len(),
213 "The number of predictions do not match (pop_mean vs post_mean)"
214 );
215
216 for (((pop_mean_pred, pop_median_pred), post_mean_pred), post_median_pred) in
217 pop_mean_pred
218 .iter()
219 .zip(pop_median_pred.iter())
220 .zip(post_mean_pred.iter())
221 .zip(post_median_pred.iter())
222 {
223 let row = Row {
224 id: id.clone(),
225 time: pop_mean_pred.time(),
226 outeq: pop_mean_pred.outeq(),
227 block: occ,
228 obs: pop_mean_pred.observation(),
229 pop_mean: pop_mean_pred.prediction(),
230 pop_median: pop_median_pred.prediction(),
231 post_mean: post_mean_pred.prediction(),
232 post_median: post_median_pred.prediction(),
233 };
234 writer.serialize(row)?;
235 }
236 }
237 }
238 writer.flush()?;
239 tracing::info!(
240 "Observations with predictions written to {:?}",
241 &outputfile.get_relative_path()
242 );
243 Ok(())
244 }
245
246 pub fn write_theta(&self) -> Result<()> {
249 tracing::debug!("Writing population parameter distribution...");
250
251 let theta = &self.theta;
252 let w: Vec<f64> = self.w.clone().into_view().iter().cloned().collect();
253 let outputfile = OutputFile::new(&self.settings.output().path, "theta.csv")
261 .context("Failed to create output file for theta")?;
262
263 let mut writer = WriterBuilder::new()
264 .has_headers(true)
265 .from_writer(&outputfile.file);
266
267 let mut theta_header = self.par_names.clone();
269 theta_header.push("prob".to_string());
270 writer.write_record(&theta_header)?;
271
272 for (theta_row, &w_val) in theta.matrix().row_iter().zip(w.iter()) {
274 let mut row: Vec<String> = theta_row.iter().map(|&val| val.to_string()).collect();
275 row.push(w_val.to_string());
276 writer.write_record(&row)?;
277 }
278 writer.flush()?;
279 tracing::info!(
280 "Population parameter distribution written to {:?}",
281 &outputfile.get_relative_path()
282 );
283 Ok(())
284 }
285
286 pub fn write_posterior(&self) -> Result<()> {
288 tracing::debug!("Writing posterior parameter probabilities...");
289 let theta = &self.theta;
290 let w = &self.w;
291 let psi = &self.psi;
292
293 let posterior = posterior(psi, w)?;
295
296 let outputfile = match OutputFile::new(&self.settings.output().path, "posterior.csv") {
298 Ok(of) => of,
299 Err(e) => {
300 tracing::error!("Failed to create output file: {}", e);
301 return Err(e.context("Failed to create output file"));
302 }
303 };
304
305 let mut writer = WriterBuilder::new()
307 .has_headers(true)
308 .from_writer(&outputfile.file);
309
310 writer.write_field("id")?;
312 writer.write_field("point")?;
313 theta.param_names().iter().for_each(|name| {
314 writer.write_field(name).unwrap();
315 });
316 writer.write_field("prob")?;
317 writer.write_record(None::<&[u8]>)?;
318
319 let subjects = self.data.get_subjects();
321 posterior.row_iter().enumerate().for_each(|(i, row)| {
322 let subject = subjects.get(i).unwrap();
323 let id = subject.id();
324
325 row.iter().enumerate().for_each(|(spp, prob)| {
326 writer.write_field(id.clone()).unwrap();
327 writer.write_field(i.to_string()).unwrap();
328
329 theta
330 .matrix()
331 .row(spp)
332 .iter()
333 .enumerate()
334 .for_each(|(_, val)| {
335 writer.write_field(val.to_string()).unwrap();
336 });
337
338 writer.write_field(prob.to_string()).unwrap();
339 writer.write_record(None::<&[u8]>).unwrap();
340 });
341 });
342
343 writer.flush()?;
344 tracing::info!(
345 "Posterior parameters written to {:?}",
346 &outputfile.get_relative_path()
347 );
348
349 Ok(())
350 }
351
352 pub fn write_obs(&self) -> Result<()> {
354 tracing::debug!("Writing observations...");
355 let outputfile = OutputFile::new(&self.settings.output().path, "obs.csv")?;
356 write_pmetrics_observations(&self.data, &outputfile.file)?;
357 tracing::info!(
358 "Observations written to {:?}",
359 &outputfile.get_relative_path()
360 );
361 Ok(())
362 }
363
364 pub fn write_pred(&self, idelta: f64, tad: f64) -> Result<()> {
366 tracing::debug!("Writing predictions...");
367 let data = self.data.expand(idelta, tad);
368
369 let theta: Array2<f64> = self
370 .theta
371 .matrix()
372 .clone()
373 .as_mut()
374 .into_ndarray()
375 .to_owned();
376 let w: Array1<f64> = self.w.clone().into_view().iter().cloned().collect();
377 let psi: Array2<f64> = self.psi.matrix().as_ref().into_ndarray().to_owned();
378
379 let (post_mean, post_median) = posterior_mean_median(&theta, &psi, &w)
380 .context("Failed to calculate posterior mean and median")?;
381
382 let (pop_mean, pop_median) = population_mean_median(&theta, &w)
383 .context("Failed to calculate population mean and median")?;
384
385 let subjects = data.get_subjects();
386 if subjects.len() != post_mean.nrows() {
387 bail!("Number of subjects and number of posterior means do not match");
388 }
389
390 let outputfile = OutputFile::new(&self.settings.output().path, "pred.csv")?;
391 let mut writer = WriterBuilder::new()
392 .has_headers(true)
393 .from_writer(&outputfile.file);
394
395 #[derive(Debug, Clone, Serialize)]
396 struct Row {
397 id: String,
398 time: f64,
399 outeq: usize,
400 block: usize,
401 pop_mean: f64,
402 pop_median: f64,
403 post_mean: f64,
404 post_median: f64,
405 }
406
407 for (i, subject) in subjects.iter().enumerate() {
408 for occasion in subject.occasions() {
409 let id = subject.id();
410 let block = occasion.index();
411
412 let subject = Subject::from_occasions(id.clone(), vec![occasion.clone()]);
414
415 let pop_mean_pred = self
417 .equation
418 .simulate_subject(&subject, &pop_mean.to_vec(), None)
419 .0
420 .get_predictions()
421 .clone();
422 let pop_median_pred = self
423 .equation
424 .simulate_subject(&subject, &pop_median.to_vec(), None)
425 .0
426 .get_predictions()
427 .clone();
428
429 let post_mean_spp: Vec<f64> = post_mean.row(i).to_vec();
431 let post_mean_pred = self
432 .equation
433 .simulate_subject(&subject, &post_mean_spp, None)
434 .0
435 .get_predictions()
436 .clone();
437 let post_median_spp: Vec<f64> = post_median.row(i).to_vec();
438 let post_median_pred = self
439 .equation
440 .simulate_subject(&subject, &post_median_spp, None)
441 .0
442 .get_predictions()
443 .clone();
444
445 for (((pop_mean, pop_median), post_mean), post_median) in pop_mean_pred
447 .iter()
448 .zip(pop_median_pred.iter())
449 .zip(post_mean_pred.iter())
450 .zip(post_median_pred.iter())
451 {
452 let row = Row {
453 id: id.clone(),
454 time: pop_mean.time(),
455 outeq: pop_mean.outeq(),
456 block,
457 pop_mean: pop_mean.prediction(),
458 pop_median: pop_median.prediction(),
459 post_mean: post_mean.prediction(),
460 post_median: post_median.prediction(),
461 };
462 writer.serialize(row)?;
463 }
464 }
465 }
466 writer.flush()?;
467 tracing::info!(
468 "Predictions written to {:?}",
469 &outputfile.get_relative_path()
470 );
471 Ok(())
472 }
473
474 pub fn write_covs(&self) -> Result<()> {
476 tracing::debug!("Writing covariates...");
477 let outputfile = OutputFile::new(&self.settings.output().path, "covs.csv")?;
478 let mut writer = WriterBuilder::new()
479 .has_headers(true)
480 .from_writer(&outputfile.file);
481
482 let mut covariate_names = std::collections::HashSet::new();
484 for subject in self.data.get_subjects() {
485 for occasion in subject.occasions() {
486 if let Some(cov) = occasion.get_covariates() {
487 let covmap = cov.covariates();
488 for cov_name in covmap.keys() {
489 covariate_names.insert(cov_name.clone());
490 }
491 }
492 }
493 }
494 let mut covariate_names: Vec<String> = covariate_names.into_iter().collect();
495 covariate_names.sort(); let mut headers = vec!["id", "time", "block"];
499 headers.extend(covariate_names.iter().map(|s| s.as_str()));
500 writer.write_record(&headers)?;
501
502 for subject in self.data.get_subjects() {
504 for occasion in subject.occasions() {
505 if let Some(cov) = occasion.get_covariates() {
506 let covmap = cov.covariates();
507
508 for event in occasion.get_events(&None, &None, false) {
509 let time = match event {
510 Event::Bolus(bolus) => bolus.time(),
511 Event::Infusion(infusion) => infusion.time(),
512 Event::Observation(observation) => observation.time(),
513 };
514
515 let mut row: Vec<String> = Vec::new();
516 row.push(subject.id().clone());
517 row.push(time.to_string());
518 row.push(occasion.index().to_string());
519
520 for cov_name in &covariate_names {
522 if let Some(cov) = covmap.get(cov_name) {
523 if let Some(value) = cov.interpolate(time) {
524 row.push(value.to_string());
525 } else {
526 row.push(String::new());
527 }
528 } else {
529 row.push(String::new());
530 }
531 }
532
533 writer.write_record(&row)?;
534 }
535 }
536 }
537 }
538
539 writer.flush()?;
540 tracing::info!(
541 "Covariates written to {:?}",
542 &outputfile.get_relative_path()
543 );
544 Ok(())
545 }
546}
547
/// Snapshot of the algorithm state for a single cycle, as stored in a [`CycleLog`].
#[derive(Debug, Clone)]
pub struct NPCycle {
    /// Cycle number.
    pub cycle: usize,
    /// Objective function value; written under the `neg2ll` header of `cycles.csv`.
    pub objf: f64,
    /// Written under the `gamlam` header of `cycles.csv`.
    pub gamlam: f64,
    /// Support points of this cycle; summarised per parameter in `cycles.csv`.
    pub theta: Theta,
    /// Number of support points.
    /// NOTE(review): `CycleLog::write` reports `theta.matrix().nrows()` rather than this field.
    pub nspp: usize,
    /// Presumably the change in `objf` relative to the previous cycle - TODO confirm.
    pub delta_objf: f64,
    /// Convergence flag for this cycle.
    pub converged: bool,
}
567
568impl NPCycle {
569 pub fn new(
570 cycle: usize,
571 objf: f64,
572 gamlam: f64,
573 theta: Theta,
574 nspp: usize,
575 delta_objf: f64,
576 converged: bool,
577 ) -> Self {
578 Self {
579 cycle,
580 objf,
581 gamlam,
582 theta,
583 nspp,
584 delta_objf,
585 converged,
586 }
587 }
588
589 pub fn placeholder() -> Self {
590 Self {
591 cycle: 0,
592 objf: 0.0,
593 gamlam: 0.0,
594 theta: Theta::new(),
595 nspp: 0,
596 delta_objf: 0.0,
597 converged: false,
598 }
599 }
600}
601
/// Ordered collection of [`NPCycle`] records; written to `cycles.csv` by `CycleLog::write`.
#[derive(Debug, Clone)]
pub struct CycleLog {
    /// Recorded cycles, in the order they were pushed.
    pub cycles: Vec<NPCycle>,
}
607
608impl CycleLog {
609 pub fn new() -> Self {
610 Self { cycles: Vec::new() }
611 }
612
613 pub fn push(&mut self, cycle: NPCycle) {
614 self.cycles.push(cycle);
615 }
616
617 pub fn write(&self, settings: &Settings) -> Result<()> {
618 tracing::debug!("Writing cycles...");
619 let outputfile = OutputFile::new(&settings.output().path, "cycles.csv")?;
620 let mut writer = WriterBuilder::new()
621 .has_headers(false)
622 .from_writer(&outputfile.file);
623
624 writer.write_field("cycle")?;
626 writer.write_field("converged")?;
627 writer.write_field("neg2ll")?;
628 writer.write_field("gamlam")?;
629 writer.write_field("nspp")?;
630
631 let parameter_names = settings.parameters().names();
632 for param_name in ¶meter_names {
633 writer.write_field(format!("{}.mean", param_name))?;
634 writer.write_field(format!("{}.median", param_name))?;
635 writer.write_field(format!("{}.sd", param_name))?;
636 }
637
638 writer.write_record(None::<&[u8]>)?;
639
640 for cycle in &self.cycles {
641 writer.write_field(format!("{}", cycle.cycle))?;
642 writer.write_field(format!("{}", cycle.converged))?;
643 writer.write_field(format!("{}", cycle.objf))?;
644 writer.write_field(format!("{}", cycle.gamlam))?;
645 writer
646 .write_field(format!("{}", cycle.theta.matrix().nrows()))
647 .unwrap();
648
649 for param in cycle.theta.matrix().col_iter() {
650 let param_values: Vec<f64> = param.iter().cloned().collect();
651
652 let mean: f64 = param_values.iter().sum::<f64>() / param_values.len() as f64;
653 let median = median(param_values.clone());
654 let std = param_values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
655 / (param_values.len() as f64 - 1.0);
656
657 writer.write_field(format!("{}", mean))?;
658 writer.write_field(format!("{}", median))?;
659 writer.write_field(format!("{}", std))?;
660 }
661 writer.write_record(None::<&[u8]>)?;
662 }
663 writer.flush()?;
664 tracing::info!("Cycles written to {:?}", &outputfile.get_relative_path());
665 Ok(())
666 }
667}
668
669impl Default for CycleLog {
670 fn default() -> Self {
671 Self::new()
672 }
673}
674
675pub fn posterior(psi: &Psi, w: &Col<f64>) -> Result<Mat<f64>> {
676 if psi.matrix().ncols() != w.nrows() {
677 bail!(
678 "Number of rows in psi ({}) and number of weights ({}) do not match.",
679 psi.matrix().nrows(),
680 w.nrows()
681 );
682 }
683
684 let psi_matrix = psi.matrix();
685 let py = psi_matrix * w;
686
687 let posterior = Mat::from_fn(psi_matrix.nrows(), psi_matrix.ncols(), |i, j| {
688 psi_matrix.get(i, j) * w.get(j) / py.get(i)
689 });
690
691 Ok(posterior)
692}
693
/// Return the median of `data`.
///
/// For an even number of values the two middle values are averaged. An empty input
/// yields `f64::NAN` instead of panicking. Values are ordered with `f64::total_cmp`, so
/// NaN entries do not cause a panic either.
pub fn median(mut data: Vec<f64>) -> f64 {
    if data.is_empty() {
        return f64::NAN;
    }

    // The vector is owned, so it can be sorted in place without copying it first.
    data.sort_by(f64::total_cmp);

    let mid = data.len() / 2;
    if data.len() % 2 == 0 {
        (data[mid - 1] + data[mid]) / 2.0
    } else {
        data[mid]
    }
}
708
709fn weighted_median(data: &Array1<f64>, weights: &Array1<f64>) -> f64 {
710 assert_eq!(
712 data.len(),
713 weights.len(),
714 "The length of data and weights must be the same"
715 );
716 assert!(
717 weights.iter().all(|&x| x >= 0.0),
718 "Weights must be non-negative, weights: {:?}",
719 weights
720 );
721
722 let mut weighted_data: Vec<(f64, f64)> = data
724 .iter()
725 .zip(weights.iter())
726 .map(|(&d, &w)| (d, w))
727 .collect();
728
729 weighted_data.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
731
732 let total_weight: f64 = weights.sum();
734 let mut cumulative_sum = 0.0;
735
736 for (i, &(_, weight)) in weighted_data.iter().enumerate() {
737 cumulative_sum += weight;
738
739 if cumulative_sum == total_weight / 2.0 {
740 if i + 1 < weighted_data.len() {
742 return (weighted_data[i].0 + weighted_data[i + 1].0) / 2.0;
743 } else {
744 return weighted_data[i].0;
745 }
746 } else if cumulative_sum > total_weight / 2.0 {
747 return weighted_data[i].0;
748 }
749 }
750
751 unreachable!("The function should have returned a value before reaching this point.");
752}
753
754pub fn population_mean_median(
755 theta: &Array2<f64>,
756 w: &Array1<f64>,
757) -> Result<(Array1<f64>, Array1<f64>)> {
758 let w = if w.is_empty() {
759 tracing::warn!("w.len() == 0, setting all weights to 1/n");
760 Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
761 } else {
762 w.clone()
763 };
764 if theta.nrows() != w.len() {
766 bail!(
767 "Number of parameters and number of weights do not match. Theta: {}, w: {}",
768 theta.nrows(),
769 w.len()
770 );
771 }
772
773 let mut mean = Array1::zeros(theta.ncols());
774 let mut median = Array1::zeros(theta.ncols());
775
776 for (i, (mn, mdn)) in mean.iter_mut().zip(&mut median).enumerate() {
777 let col = theta.column(i).to_owned() * w.to_owned();
779 *mn = col.sum();
780
781 let ct = theta.column(i);
783 let mut params = vec![];
784 let mut weights = vec![];
785 for (ti, wi) in ct.iter().zip(w.clone()) {
786 params.push(*ti);
787 weights.push(wi);
788 }
789
790 *mdn = weighted_median(&Array::from(params), &Array::from(weights));
791 }
792
793 Ok((mean, median))
794}
795
796pub fn posterior_mean_median(
797 theta: &Array2<f64>,
798 psi: &Array2<f64>,
799 w: &Array1<f64>,
800) -> Result<(Array2<f64>, Array2<f64>)> {
801 let mut mean = Array2::zeros((0, theta.ncols()));
802 let mut median = Array2::zeros((0, theta.ncols()));
803
804 let w = if w.is_empty() {
805 tracing::warn!("w is empty, setting all weights to 1/n");
806 Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
807 } else {
808 w.clone()
809 };
810
811 if theta.nrows() != w.len() || theta.nrows() != psi.ncols() || psi.ncols() != w.len() {
813 bail!("Number of parameters and number of weights do not match, theta.nrows(): {}, w.len(): {}, psi.ncols(): {}", theta.nrows(), w.len(), psi.ncols());
814 }
815
816 let mut psi_norm: Array2<f64> = Array2::zeros((0, psi.ncols()));
818 for (i, row) in psi.axis_iter(Axis(0)).enumerate() {
819 let row_w = row.to_owned() * w.to_owned();
820 let row_sum = row_w.sum();
821 let row_norm = if row_sum == 0.0 {
822 tracing::warn!("Sum of row {} of psi is 0.0, setting that row to 1/n", i);
823 Array1::from_elem(psi.ncols(), 1.0 / psi.ncols() as f64)
824 } else {
825 &row_w / row_sum
826 };
827 psi_norm.push_row(row_norm.view())?;
828 }
829 if psi_norm.iter().any(|&x| x.is_nan()) {
830 dbg!(&psi);
831 bail!("NaN values found in psi_norm");
832 };
833
834 for probs in psi_norm.axis_iter(Axis(0)) {
839 let mut post_mean: Vec<f64> = Vec::new();
840 let mut post_median: Vec<f64> = Vec::new();
841
842 for pars in theta.axis_iter(Axis(1)) {
844 let weighted_par = &probs * &pars;
846 let the_mean = weighted_par.sum();
847 post_mean.push(the_mean);
848
849 let median = weighted_median(&pars.to_owned(), &probs.to_owned());
851 post_median.push(median);
852 }
853
854 mean.push_row(Array::from(post_mean.clone()).view())?;
855 median.push_row(Array::from(post_median.clone()).view())?;
856 }
857
858 Ok((mean, median))
859}
860
/// An output file opened for writing, together with the path it was opened at.
#[derive(Debug)]
pub struct OutputFile {
    /// Handle of the opened file.
    pub file: File,
    /// Path built by joining the output folder and the file name.
    pub relative_path: PathBuf,
}
867
868impl OutputFile {
869 pub fn new(folder: &str, file_name: &str) -> Result<Self> {
870 let relative_path = Path::new(&folder).join(file_name);
871
872 if let Some(parent) = relative_path.parent() {
873 create_dir_all(parent)
874 .with_context(|| format!("Failed to create directories for {:?}", parent))?;
875 }
876
877 let file = OpenOptions::new()
878 .write(true)
879 .create(true)
880 .truncate(true)
881 .open(&relative_path)
882 .with_context(|| format!("Failed to open file: {:?}", relative_path))?;
883
884 Ok(OutputFile {
885 file,
886 relative_path,
887 })
888 }
889
890 pub fn get_relative_path(&self) -> &Path {
891 &self.relative_path
892 }
893}
894
895pub fn write_pmetrics_observations(data: &Data, file: &std::fs::File) -> Result<()> {
896 let mut writer = WriterBuilder::new().has_headers(true).from_writer(file);
897
898 writer.write_record(["id", "block", "time", "out", "outeq"])?;
899 for subject in data.get_subjects() {
900 for occasion in subject.occasions() {
901 for event in occasion.get_events(&None, &None, false) {
902 if let Event::Observation(event) = event {
903 writer.write_record([
904 subject.id(),
905 &occasion.index().to_string(),
906 &event.time().to_string(),
907 &event.value().to_string(),
908 &event.outeq().to_string(),
909 ])?;
910 }
911 }
912 }
913 }
914 Ok(())
915}
916
/// Unit tests for the `median` and `weighted_median` helpers.
#[cfg(test)]
mod tests {
    use super::median;

    // Odd number of values: the middle value after sorting.
    #[test]
    fn test_median_odd() {
        let data = vec![1.0, 3.0, 2.0];
        assert_eq!(median(data), 2.0);
    }

    // Even number of values: the average of the two middle values.
    #[test]
    fn test_median_even() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        assert_eq!(median(data), 2.5);
    }

    // A single value is its own median.
    #[test]
    fn test_median_single() {
        let data = vec![42.0];
        assert_eq!(median(data), 42.0);
    }

    // Input that is already sorted.
    #[test]
    fn test_median_sorted() {
        let data = vec![5.0, 10.0, 15.0, 20.0, 25.0];
        assert_eq!(median(data), 15.0);
    }

    // Input order must not matter.
    #[test]
    fn test_median_unsorted() {
        let data = vec![10.0, 30.0, 20.0, 50.0, 40.0];
        assert_eq!(median(data), 30.0);
    }

    // Repeated values are counted individually.
    #[test]
    fn test_median_with_duplicates() {
        let data = vec![1.0, 2.0, 2.0, 3.0, 4.0];
        assert_eq!(median(data), 2.0);
    }

    use super::weighted_median;
    use ndarray::Array1;

    // The cumulative weight first exceeds one half at the second value.
    #[test]
    fn test_weighted_median_simple() {
        let data = Array1::from(vec![1.0, 2.0, 3.0]);
        let weights = Array1::from(vec![0.2, 0.5, 0.3]);
        assert_eq!(weighted_median(&data, &weights), 2.0);
    }

    // Equal weights with an exact half split: midpoint of the two middle values.
    #[test]
    fn test_weighted_median_even_weights() {
        let data = Array1::from(vec![1.0, 2.0, 3.0, 4.0]);
        let weights = Array1::from(vec![0.25, 0.25, 0.25, 0.25]);
        assert_eq!(weighted_median(&data, &weights), 2.5);
    }

    // A single value is its own weighted median.
    #[test]
    fn test_weighted_median_single_element() {
        let data = Array1::from(vec![42.0]);
        let weights = Array1::from(vec![1.0]);
        assert_eq!(weighted_median(&data, &weights), 42.0);
    }

    // Mismatched input lengths must trip the length assertion.
    #[test]
    #[should_panic(expected = "The length of data and weights must be the same")]
    fn test_weighted_median_mismatched_lengths() {
        let data = Array1::from(vec![1.0, 2.0, 3.0]);
        let weights = Array1::from(vec![0.1, 0.2]);
        weighted_median(&data, &weights);
    }

    // Identical values give that value regardless of the weights.
    #[test]
    fn test_weighted_median_all_same_elements() {
        let data = Array1::from(vec![5.0, 5.0, 5.0, 5.0]);
        let weights = Array1::from(vec![0.1, 0.2, 0.3, 0.4]);
        assert_eq!(weighted_median(&data, &weights), 5.0);
    }

    // A negative weight must trip the non-negativity assertion before any result is produced.
    #[test]
    #[should_panic(expected = "Weights must be non-negative")]
    fn test_weighted_median_negative_weights() {
        let data = Array1::from(vec![1.0, 2.0, 3.0, 4.0]);
        let weights = Array1::from(vec![0.2, -0.5, 0.5, 0.8]);
        assert_eq!(weighted_median(&data, &weights), 4.0);
    }

    // Values are sorted (with their weights) before accumulating; here the split is exact.
    #[test]
    fn test_weighted_median_unsorted_data() {
        let data = Array1::from(vec![3.0, 1.0, 4.0, 2.0]);
        let weights = Array1::from(vec![0.1, 0.3, 0.4, 0.2]);
        assert_eq!(weighted_median(&data, &weights), 2.5);
    }

    // All the weight sits on one value, which is therefore the median.
    #[test]
    fn test_weighted_median_with_zero_weights() {
        let data = Array1::from(vec![1.0, 2.0, 3.0, 4.0]);
        let weights = Array1::from(vec![0.0, 0.0, 1.0, 0.0]);
        assert_eq!(weighted_median(&data, &weights), 3.0);
    }
}