1use crate::prelude::*;
2use crate::structs::psi::Psi;
3use crate::structs::theta::Theta;
4use anyhow::{bail, Context, Result};
5use csv::WriterBuilder;
6use faer::linalg::zip::IntoView;
7use faer_ext::IntoNdarray;
8use ndarray::{Array, Array1, Array2, Axis};
9use pharmsol::prelude::data::*;
10use pharmsol::prelude::simulator::Equation;
11use serde::Serialize;
12use crate::routines::settings::Settings;
14use faer::{Col, Mat};
15use std::fs::{create_dir_all, File, OpenOptions};
16use std::path::{Path, PathBuf};
17
18#[derive(Debug)]
21pub struct NPResult<E: Equation> {
22 equation: E,
23 data: Data,
24 theta: Theta,
25 psi: Psi,
26 w: Col<f64>,
27 objf: f64,
28 cycles: usize,
29 converged: bool,
30 par_names: Vec<String>,
31 settings: Settings,
32 cyclelog: CycleLog,
33}
34
35#[allow(clippy::too_many_arguments)]
36impl<E: Equation> NPResult<E> {
37 pub fn new(
39 equation: E,
40 data: Data,
41 theta: Theta,
42 psi: Psi,
43 w: Col<f64>,
44 objf: f64,
45 cycles: usize,
46 converged: bool,
47 settings: Settings,
48 cyclelog: CycleLog,
49 ) -> Self {
50 let par_names = settings.parameters().names();
53
54 Self {
55 equation,
56 data,
57 theta,
58 psi,
59 w,
60 objf,
61 cycles,
62 converged,
63 par_names,
64 settings,
65 cyclelog,
66 }
67 }
68
69 pub fn cycles(&self) -> usize {
70 self.cycles
71 }
72
73 pub fn objf(&self) -> f64 {
74 self.objf
75 }
76
77 pub fn converged(&self) -> bool {
78 self.converged
79 }
80
81 pub fn get_theta(&self) -> &Theta {
82 &self.theta
83 }
84
85 pub fn get_psi(&self) -> &Psi {
86 &self.psi
87 }
88
89 pub fn get_w(&self) -> &Col<f64> {
90 &self.w
91 }
92
93 pub fn write_outputs(&self) -> Result<()> {
94 if self.settings.output().write {
95 self.settings.write()?;
96 let idelta: f64 = self.settings.predictions().idelta;
97 let tad = self.settings.predictions().tad;
98 self.cyclelog.write(&self.settings)?;
99 self.write_obs().context("Failed to write observations")?;
100 self.write_theta().context("Failed to write theta")?;
101 self.write_obspred()
102 .context("Failed to write observed-predicted file")?;
103 self.write_pred(idelta, tad)
104 .context("Failed to write predictions")?;
105 self.write_covs().context("Failed to write covariates")?;
106 self.write_posterior()
107 .context("Failed to write posterior")?;
108 }
109 Ok(())
110 }
111
112 pub fn write_obspred(&self) -> Result<()> {
114 tracing::debug!("Writing observations and predictions...");
115
116 #[derive(Debug, Clone, Serialize)]
117 struct Row {
118 id: String,
119 time: f64,
120 outeq: usize,
121 block: usize,
122 obs: f64,
123 pop_mean: f64,
124 pop_median: f64,
125 post_mean: f64,
126 post_median: f64,
127 }
128
129 let theta: Array2<f64> = self
130 .theta
131 .matrix()
132 .clone()
133 .as_mut()
134 .into_ndarray()
135 .to_owned();
136 let w: Array1<f64> = self.w.clone().into_view().iter().cloned().collect();
137 let psi: Array2<f64> = self.psi.matrix().as_ref().into_ndarray().to_owned();
138
139 let (post_mean, post_median) = posterior_mean_median(&theta, &psi, &w)
140 .context("Failed to calculate posterior mean and median")?;
141
142 let (pop_mean, pop_median) = population_mean_median(&theta, &w)
143 .context("Failed to calculate posterior mean and median")?;
144
145 let subjects = self.data.get_subjects();
146 if subjects.len() != post_mean.nrows() {
147 bail!(
148 "Number of subjects: {} and number of posterior means: {} do not match",
149 subjects.len(),
150 post_mean.nrows()
151 );
152 }
153
154 let outputfile = OutputFile::new(&self.settings.output().path, "op.csv")?;
155 let mut writer = WriterBuilder::new()
156 .has_headers(true)
157 .from_writer(&outputfile.file);
158
159 for (i, subject) in subjects.iter().enumerate() {
160 for occasion in subject.occasions() {
161 let id = subject.id();
162 let occ = occasion.index();
163
164 let subject = Subject::from_occasions(id.clone(), vec![occasion.clone()]);
165
166 let pop_mean_pred = self
168 .equation
169 .simulate_subject(&subject, &pop_mean.to_vec(), None)?
170 .0
171 .get_predictions()
172 .clone();
173
174 let pop_median_pred = self
175 .equation
176 .simulate_subject(&subject, &pop_median.to_vec(), None)?
177 .0
178 .get_predictions()
179 .clone();
180
181 let post_mean_spp: Vec<f64> = post_mean.row(i).to_vec();
183 let post_mean_pred = self
184 .equation
185 .simulate_subject(&subject, &post_mean_spp, None)?
186 .0
187 .get_predictions()
188 .clone();
189 let post_median_spp: Vec<f64> = post_median.row(i).to_vec();
190 let post_median_pred = self
191 .equation
192 .simulate_subject(&subject, &post_median_spp, None)?
193 .0
194 .get_predictions()
195 .clone();
196 assert_eq!(
197 pop_mean_pred.len(),
198 pop_median_pred.len(),
199 "The number of predictions do not match (pop_mean vs pop_median)"
200 );
201
202 assert_eq!(
203 post_mean_pred.len(),
204 post_median_pred.len(),
205 "The number of predictions do not match (post_mean vs post_median)"
206 );
207
208 assert_eq!(
209 pop_mean_pred.len(),
210 post_mean_pred.len(),
211 "The number of predictions do not match (pop_mean vs post_mean)"
212 );
213
214 for (((pop_mean_pred, pop_median_pred), post_mean_pred), post_median_pred) in
215 pop_mean_pred
216 .iter()
217 .zip(pop_median_pred.iter())
218 .zip(post_mean_pred.iter())
219 .zip(post_median_pred.iter())
220 {
221 let row = Row {
222 id: id.clone(),
223 time: pop_mean_pred.time(),
224 outeq: pop_mean_pred.outeq(),
225 block: occ,
226 obs: pop_mean_pred.observation(),
227 pop_mean: pop_mean_pred.prediction(),
228 pop_median: pop_median_pred.prediction(),
229 post_mean: post_mean_pred.prediction(),
230 post_median: post_median_pred.prediction(),
231 };
232 writer.serialize(row)?;
233 }
234 }
235 }
236 writer.flush()?;
237 tracing::info!(
238 "Observations with predictions written to {:?}",
239 &outputfile.get_relative_path()
240 );
241 Ok(())
242 }
243
244 pub fn write_theta(&self) -> Result<()> {
247 tracing::debug!("Writing population parameter distribution...");
248
249 let theta = &self.theta;
250 let w: Vec<f64> = self.w.clone().into_view().iter().cloned().collect();
251 let outputfile = OutputFile::new(&self.settings.output().path, "theta.csv")
259 .context("Failed to create output file for theta")?;
260
261 let mut writer = WriterBuilder::new()
262 .has_headers(true)
263 .from_writer(&outputfile.file);
264
265 let mut theta_header = self.par_names.clone();
267 theta_header.push("prob".to_string());
268 writer.write_record(&theta_header)?;
269
270 for (theta_row, &w_val) in theta.matrix().row_iter().zip(w.iter()) {
272 let mut row: Vec<String> = theta_row.iter().map(|&val| val.to_string()).collect();
273 row.push(w_val.to_string());
274 writer.write_record(&row)?;
275 }
276 writer.flush()?;
277 tracing::info!(
278 "Population parameter distribution written to {:?}",
279 &outputfile.get_relative_path()
280 );
281 Ok(())
282 }
283
284 pub fn write_posterior(&self) -> Result<()> {
286 tracing::debug!("Writing posterior parameter probabilities...");
287 let theta = &self.theta;
288 let w = &self.w;
289 let psi = &self.psi;
290
291 let posterior = posterior(psi, w)?;
293
294 let outputfile = match OutputFile::new(&self.settings.output().path, "posterior.csv") {
296 Ok(of) => of,
297 Err(e) => {
298 tracing::error!("Failed to create output file: {}", e);
299 return Err(e.context("Failed to create output file"));
300 }
301 };
302
303 let mut writer = WriterBuilder::new()
305 .has_headers(true)
306 .from_writer(&outputfile.file);
307
308 writer.write_field("id")?;
310 writer.write_field("point")?;
311 theta.param_names().iter().for_each(|name| {
312 writer.write_field(name).unwrap();
313 });
314 writer.write_field("prob")?;
315 writer.write_record(None::<&[u8]>)?;
316
317 let subjects = self.data.get_subjects();
319 posterior.row_iter().enumerate().for_each(|(i, row)| {
320 let subject = subjects.get(i).unwrap();
321 let id = subject.id();
322
323 row.iter().enumerate().for_each(|(spp, prob)| {
324 writer.write_field(id.clone()).unwrap();
325 writer.write_field(i.to_string()).unwrap();
326
327 theta.matrix().row(spp).iter().for_each(|val| {
328 writer.write_field(val.to_string()).unwrap();
329 });
330
331 writer.write_field(prob.to_string()).unwrap();
332 writer.write_record(None::<&[u8]>).unwrap();
333 });
334 });
335
336 writer.flush()?;
337 tracing::info!(
338 "Posterior parameters written to {:?}",
339 &outputfile.get_relative_path()
340 );
341
342 Ok(())
343 }
344
345 pub fn write_obs(&self) -> Result<()> {
347 tracing::debug!("Writing observations...");
348 let outputfile = OutputFile::new(&self.settings.output().path, "obs.csv")?;
349 write_pmetrics_observations(&self.data, &outputfile.file)?;
350 tracing::info!(
351 "Observations written to {:?}",
352 &outputfile.get_relative_path()
353 );
354 Ok(())
355 }
356
357 pub fn write_pred(&self, idelta: f64, tad: f64) -> Result<()> {
359 tracing::debug!("Writing predictions...");
360 let data = self.data.expand(idelta, tad);
361
362 let theta: Array2<f64> = self
363 .theta
364 .matrix()
365 .clone()
366 .as_mut()
367 .into_ndarray()
368 .to_owned();
369 let w: Array1<f64> = self.w.clone().into_view().iter().cloned().collect();
370 let psi: Array2<f64> = self.psi.matrix().as_ref().into_ndarray().to_owned();
371
372 let (post_mean, post_median) = posterior_mean_median(&theta, &psi, &w)
373 .context("Failed to calculate posterior mean and median")?;
374
375 let (pop_mean, pop_median) = population_mean_median(&theta, &w)
376 .context("Failed to calculate population mean and median")?;
377
378 let subjects = data.get_subjects();
379 if subjects.len() != post_mean.nrows() {
380 bail!("Number of subjects and number of posterior means do not match");
381 }
382
383 let outputfile = OutputFile::new(&self.settings.output().path, "pred.csv")?;
384 let mut writer = WriterBuilder::new()
385 .has_headers(true)
386 .from_writer(&outputfile.file);
387
388 #[derive(Debug, Clone, Serialize)]
389 struct Row {
390 id: String,
391 time: f64,
392 outeq: usize,
393 block: usize,
394 pop_mean: f64,
395 pop_median: f64,
396 post_mean: f64,
397 post_median: f64,
398 }
399
400 for (i, subject) in subjects.iter().enumerate() {
401 for occasion in subject.occasions() {
402 let id = subject.id();
403 let block = occasion.index();
404
405 let subject = Subject::from_occasions(id.clone(), vec![occasion.clone()]);
407
408 let pop_mean_pred = self
410 .equation
411 .simulate_subject(&subject, &pop_mean.to_vec(), None)?
412 .0
413 .get_predictions()
414 .clone();
415 let pop_median_pred = self
416 .equation
417 .simulate_subject(&subject, &pop_median.to_vec(), None)?
418 .0
419 .get_predictions()
420 .clone();
421
422 let post_mean_spp: Vec<f64> = post_mean.row(i).to_vec();
424 let post_mean_pred = self
425 .equation
426 .simulate_subject(&subject, &post_mean_spp, None)?
427 .0
428 .get_predictions()
429 .clone();
430 let post_median_spp: Vec<f64> = post_median.row(i).to_vec();
431 let post_median_pred = self
432 .equation
433 .simulate_subject(&subject, &post_median_spp, None)?
434 .0
435 .get_predictions()
436 .clone();
437
438 for (((pop_mean, pop_median), post_mean), post_median) in pop_mean_pred
440 .iter()
441 .zip(pop_median_pred.iter())
442 .zip(post_mean_pred.iter())
443 .zip(post_median_pred.iter())
444 {
445 let row = Row {
446 id: id.clone(),
447 time: pop_mean.time(),
448 outeq: pop_mean.outeq(),
449 block,
450 pop_mean: pop_mean.prediction(),
451 pop_median: pop_median.prediction(),
452 post_mean: post_mean.prediction(),
453 post_median: post_median.prediction(),
454 };
455 writer.serialize(row)?;
456 }
457 }
458 }
459 writer.flush()?;
460 tracing::info!(
461 "Predictions written to {:?}",
462 &outputfile.get_relative_path()
463 );
464 Ok(())
465 }
466
467 pub fn write_covs(&self) -> Result<()> {
469 tracing::debug!("Writing covariates...");
470 let outputfile = OutputFile::new(&self.settings.output().path, "covs.csv")?;
471 let mut writer = WriterBuilder::new()
472 .has_headers(true)
473 .from_writer(&outputfile.file);
474
475 let mut covariate_names = std::collections::HashSet::new();
477 for subject in self.data.get_subjects() {
478 for occasion in subject.occasions() {
479 if let Some(cov) = occasion.get_covariates() {
480 let covmap = cov.covariates();
481 for cov_name in covmap.keys() {
482 covariate_names.insert(cov_name.clone());
483 }
484 }
485 }
486 }
487 let mut covariate_names: Vec<String> = covariate_names.into_iter().collect();
488 covariate_names.sort(); let mut headers = vec!["id", "time", "block"];
492 headers.extend(covariate_names.iter().map(|s| s.as_str()));
493 writer.write_record(&headers)?;
494
495 for subject in self.data.get_subjects() {
497 for occasion in subject.occasions() {
498 if let Some(cov) = occasion.get_covariates() {
499 let covmap = cov.covariates();
500
501 for event in occasion.get_events(&None, &None, false) {
502 let time = match event {
503 Event::Bolus(bolus) => bolus.time(),
504 Event::Infusion(infusion) => infusion.time(),
505 Event::Observation(observation) => observation.time(),
506 };
507
508 let mut row: Vec<String> = Vec::new();
509 row.push(subject.id().clone());
510 row.push(time.to_string());
511 row.push(occasion.index().to_string());
512
513 for cov_name in &covariate_names {
515 if let Some(cov) = covmap.get(cov_name) {
516 if let Some(value) = cov.interpolate(time) {
517 row.push(value.to_string());
518 } else {
519 row.push(String::new());
520 }
521 } else {
522 row.push(String::new());
523 }
524 }
525
526 writer.write_record(&row)?;
527 }
528 }
529 }
530 }
531
532 writer.flush()?;
533 tracing::info!(
534 "Covariates written to {:?}",
535 &outputfile.get_relative_path()
536 );
537 Ok(())
538 }
539}
540
541#[derive(Debug, Clone)]
551pub struct NPCycle {
552 pub cycle: usize,
553 pub objf: f64,
554 pub gamlam: f64,
555 pub theta: Theta,
556 pub nspp: usize,
557 pub delta_objf: f64,
558 pub converged: bool,
559}
560
561impl NPCycle {
562 pub fn new(
563 cycle: usize,
564 objf: f64,
565 gamlam: f64,
566 theta: Theta,
567 nspp: usize,
568 delta_objf: f64,
569 converged: bool,
570 ) -> Self {
571 Self {
572 cycle,
573 objf,
574 gamlam,
575 theta,
576 nspp,
577 delta_objf,
578 converged,
579 }
580 }
581
582 pub fn placeholder() -> Self {
583 Self {
584 cycle: 0,
585 objf: 0.0,
586 gamlam: 0.0,
587 theta: Theta::new(),
588 nspp: 0,
589 delta_objf: 0.0,
590 converged: false,
591 }
592 }
593}
594
595#[derive(Debug, Clone)]
597pub struct CycleLog {
598 pub cycles: Vec<NPCycle>,
599}
600
601impl CycleLog {
602 pub fn new() -> Self {
603 Self { cycles: Vec::new() }
604 }
605
606 pub fn push(&mut self, cycle: NPCycle) {
607 self.cycles.push(cycle);
608 }
609
610 pub fn write(&self, settings: &Settings) -> Result<()> {
611 tracing::debug!("Writing cycles...");
612 let outputfile = OutputFile::new(&settings.output().path, "cycles.csv")?;
613 let mut writer = WriterBuilder::new()
614 .has_headers(false)
615 .from_writer(&outputfile.file);
616
617 writer.write_field("cycle")?;
619 writer.write_field("converged")?;
620 writer.write_field("neg2ll")?;
621 writer.write_field("gamlam")?;
622 writer.write_field("nspp")?;
623
624 let parameter_names = settings.parameters().names();
625 for param_name in ¶meter_names {
626 writer.write_field(format!("{}.mean", param_name))?;
627 writer.write_field(format!("{}.median", param_name))?;
628 writer.write_field(format!("{}.sd", param_name))?;
629 }
630
631 writer.write_record(None::<&[u8]>)?;
632
633 for cycle in &self.cycles {
634 writer.write_field(format!("{}", cycle.cycle))?;
635 writer.write_field(format!("{}", cycle.converged))?;
636 writer.write_field(format!("{}", cycle.objf))?;
637 writer.write_field(format!("{}", cycle.gamlam))?;
638 writer
639 .write_field(format!("{}", cycle.theta.matrix().nrows()))
640 .unwrap();
641
642 for param in cycle.theta.matrix().col_iter() {
643 let param_values: Vec<f64> = param.iter().cloned().collect();
644
645 let mean: f64 = param_values.iter().sum::<f64>() / param_values.len() as f64;
646 let median = median(param_values.clone());
647 let std = param_values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
648 / (param_values.len() as f64 - 1.0);
649
650 writer.write_field(format!("{}", mean))?;
651 writer.write_field(format!("{}", median))?;
652 writer.write_field(format!("{}", std))?;
653 }
654 writer.write_record(None::<&[u8]>)?;
655 }
656 writer.flush()?;
657 tracing::info!("Cycles written to {:?}", &outputfile.get_relative_path());
658 Ok(())
659 }
660}
661
662impl Default for CycleLog {
663 fn default() -> Self {
664 Self::new()
665 }
666}
667
668pub fn posterior(psi: &Psi, w: &Col<f64>) -> Result<Mat<f64>> {
669 if psi.matrix().ncols() != w.nrows() {
670 bail!(
671 "Number of rows in psi ({}) and number of weights ({}) do not match.",
672 psi.matrix().nrows(),
673 w.nrows()
674 );
675 }
676
677 let psi_matrix = psi.matrix();
678 let py = psi_matrix * w;
679
680 let posterior = Mat::from_fn(psi_matrix.nrows(), psi_matrix.ncols(), |i, j| {
681 psi_matrix.get(i, j) * w.get(j) / py.get(i)
682 });
683
684 Ok(posterior)
685}
686
/// Returns the median of `data`.
///
/// For an even number of values the two middle values are averaged. An empty input returns
/// `f64::NAN` instead of panicking. Values are ordered with [`f64::total_cmp`], so NaN entries
/// are sorted (positive NaN after every number) rather than causing a panic.
pub fn median(mut data: Vec<f64>) -> f64 {
    data.sort_by(f64::total_cmp);

    let n = data.len();
    if n == 0 {
        return f64::NAN;
    }

    let mid = n / 2;
    if n % 2 == 0 {
        (data[mid - 1] + data[mid]) / 2.0
    } else {
        data[mid]
    }
}
701
702fn weighted_median(data: &Array1<f64>, weights: &Array1<f64>) -> f64 {
703 assert_eq!(
705 data.len(),
706 weights.len(),
707 "The length of data and weights must be the same"
708 );
709 assert!(
710 weights.iter().all(|&x| x >= 0.0),
711 "Weights must be non-negative, weights: {:?}",
712 weights
713 );
714
715 let mut weighted_data: Vec<(f64, f64)> = data
717 .iter()
718 .zip(weights.iter())
719 .map(|(&d, &w)| (d, w))
720 .collect();
721
722 weighted_data.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
724
725 let total_weight: f64 = weights.sum();
727 let mut cumulative_sum = 0.0;
728
729 for (i, &(_, weight)) in weighted_data.iter().enumerate() {
730 cumulative_sum += weight;
731
732 if cumulative_sum == total_weight / 2.0 {
733 if i + 1 < weighted_data.len() {
735 return (weighted_data[i].0 + weighted_data[i + 1].0) / 2.0;
736 } else {
737 return weighted_data[i].0;
738 }
739 } else if cumulative_sum > total_weight / 2.0 {
740 return weighted_data[i].0;
741 }
742 }
743
744 unreachable!("The function should have returned a value before reaching this point.");
745}
746
747pub fn population_mean_median(
748 theta: &Array2<f64>,
749 w: &Array1<f64>,
750) -> Result<(Array1<f64>, Array1<f64>)> {
751 let w = if w.is_empty() {
752 tracing::warn!("w.len() == 0, setting all weights to 1/n");
753 Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
754 } else {
755 w.clone()
756 };
757 if theta.nrows() != w.len() {
759 bail!(
760 "Number of parameters and number of weights do not match. Theta: {}, w: {}",
761 theta.nrows(),
762 w.len()
763 );
764 }
765
766 let mut mean = Array1::zeros(theta.ncols());
767 let mut median = Array1::zeros(theta.ncols());
768
769 for (i, (mn, mdn)) in mean.iter_mut().zip(&mut median).enumerate() {
770 let col = theta.column(i).to_owned() * w.to_owned();
772 *mn = col.sum();
773
774 let ct = theta.column(i);
776 let mut params = vec![];
777 let mut weights = vec![];
778 for (ti, wi) in ct.iter().zip(w.clone()) {
779 params.push(*ti);
780 weights.push(wi);
781 }
782
783 *mdn = weighted_median(&Array::from(params), &Array::from(weights));
784 }
785
786 Ok((mean, median))
787}
788
789pub fn posterior_mean_median(
790 theta: &Array2<f64>,
791 psi: &Array2<f64>,
792 w: &Array1<f64>,
793) -> Result<(Array2<f64>, Array2<f64>)> {
794 let mut mean = Array2::zeros((0, theta.ncols()));
795 let mut median = Array2::zeros((0, theta.ncols()));
796
797 let w = if w.is_empty() {
798 tracing::warn!("w is empty, setting all weights to 1/n");
799 Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
800 } else {
801 w.clone()
802 };
803
804 if theta.nrows() != w.len() || theta.nrows() != psi.ncols() || psi.ncols() != w.len() {
806 bail!("Number of parameters and number of weights do not match, theta.nrows(): {}, w.len(): {}, psi.ncols(): {}", theta.nrows(), w.len(), psi.ncols());
807 }
808
809 let mut psi_norm: Array2<f64> = Array2::zeros((0, psi.ncols()));
811 for (i, row) in psi.axis_iter(Axis(0)).enumerate() {
812 let row_w = row.to_owned() * w.to_owned();
813 let row_sum = row_w.sum();
814 let row_norm = if row_sum == 0.0 {
815 tracing::warn!("Sum of row {} of psi is 0.0, setting that row to 1/n", i);
816 Array1::from_elem(psi.ncols(), 1.0 / psi.ncols() as f64)
817 } else {
818 &row_w / row_sum
819 };
820 psi_norm.push_row(row_norm.view())?;
821 }
822 if psi_norm.iter().any(|&x| x.is_nan()) {
823 dbg!(&psi);
824 bail!("NaN values found in psi_norm");
825 };
826
827 for probs in psi_norm.axis_iter(Axis(0)) {
832 let mut post_mean: Vec<f64> = Vec::new();
833 let mut post_median: Vec<f64> = Vec::new();
834
835 for pars in theta.axis_iter(Axis(1)) {
837 let weighted_par = &probs * &pars;
839 let the_mean = weighted_par.sum();
840 post_mean.push(the_mean);
841
842 let median = weighted_median(&pars.to_owned(), &probs.to_owned());
844 post_median.push(median);
845 }
846
847 mean.push_row(Array::from(post_mean.clone()).view())?;
848 median.push_row(Array::from(post_median.clone()).view())?;
849 }
850
851 Ok((mean, median))
852}
853
854#[derive(Debug)]
856pub struct OutputFile {
857 pub file: File,
858 pub relative_path: PathBuf,
859}
860
861impl OutputFile {
862 pub fn new(folder: &str, file_name: &str) -> Result<Self> {
863 let relative_path = Path::new(&folder).join(file_name);
864
865 if let Some(parent) = relative_path.parent() {
866 create_dir_all(parent)
867 .with_context(|| format!("Failed to create directories for {:?}", parent))?;
868 }
869
870 let file = OpenOptions::new()
871 .write(true)
872 .create(true)
873 .truncate(true)
874 .open(&relative_path)
875 .with_context(|| format!("Failed to open file: {:?}", relative_path))?;
876
877 Ok(OutputFile {
878 file,
879 relative_path,
880 })
881 }
882
883 pub fn get_relative_path(&self) -> &Path {
884 &self.relative_path
885 }
886}
887
888pub fn write_pmetrics_observations(data: &Data, file: &std::fs::File) -> Result<()> {
889 let mut writer = WriterBuilder::new().has_headers(true).from_writer(file);
890
891 writer.write_record(["id", "block", "time", "out", "outeq"])?;
892 for subject in data.get_subjects() {
893 for occasion in subject.occasions() {
894 for event in occasion.get_events(&None, &None, false) {
895 if let Event::Observation(event) = event {
896 writer.write_record([
897 subject.id(),
898 &occasion.index().to_string(),
899 &event.time().to_string(),
900 &event.value().to_string(),
901 &event.outeq().to_string(),
902 ])?;
903 }
904 }
905 }
906 }
907 Ok(())
908}
909
#[cfg(test)]
mod tests {
    use super::{median, weighted_median};
    use ndarray::Array1;

    /// Builds an owned 1-D array from a slice literal.
    fn arr(values: &[f64]) -> Array1<f64> {
        Array1::from_vec(values.to_vec())
    }

    #[test]
    fn test_median_odd() {
        assert_eq!(median(vec![1.0, 3.0, 2.0]), 2.0);
    }

    #[test]
    fn test_median_even() {
        assert_eq!(median(vec![1.0, 2.0, 3.0, 4.0]), 2.5);
    }

    #[test]
    fn test_median_single() {
        assert_eq!(median(vec![42.0]), 42.0);
    }

    #[test]
    fn test_median_sorted() {
        assert_eq!(median(vec![5.0, 10.0, 15.0, 20.0, 25.0]), 15.0);
    }

    #[test]
    fn test_median_unsorted() {
        assert_eq!(median(vec![10.0, 30.0, 20.0, 50.0, 40.0]), 30.0);
    }

    #[test]
    fn test_median_with_duplicates() {
        assert_eq!(median(vec![1.0, 2.0, 2.0, 3.0, 4.0]), 2.0);
    }

    #[test]
    fn test_weighted_median_simple() {
        let result = weighted_median(&arr(&[1.0, 2.0, 3.0]), &arr(&[0.2, 0.5, 0.3]));
        assert_eq!(result, 2.0);
    }

    #[test]
    fn test_weighted_median_even_weights() {
        let result = weighted_median(
            &arr(&[1.0, 2.0, 3.0, 4.0]),
            &arr(&[0.25, 0.25, 0.25, 0.25]),
        );
        assert_eq!(result, 2.5);
    }

    #[test]
    fn test_weighted_median_single_element() {
        assert_eq!(weighted_median(&arr(&[42.0]), &arr(&[1.0])), 42.0);
    }

    #[test]
    #[should_panic(expected = "The length of data and weights must be the same")]
    fn test_weighted_median_mismatched_lengths() {
        let _ = weighted_median(&arr(&[1.0, 2.0, 3.0]), &arr(&[0.1, 0.2]));
    }

    #[test]
    fn test_weighted_median_all_same_elements() {
        let result = weighted_median(&arr(&[5.0, 5.0, 5.0, 5.0]), &arr(&[0.1, 0.2, 0.3, 0.4]));
        assert_eq!(result, 5.0);
    }

    #[test]
    #[should_panic(expected = "Weights must be non-negative")]
    fn test_weighted_median_negative_weights() {
        let result = weighted_median(&arr(&[1.0, 2.0, 3.0, 4.0]), &arr(&[0.2, -0.5, 0.5, 0.8]));
        assert_eq!(result, 4.0);
    }

    #[test]
    fn test_weighted_median_unsorted_data() {
        let result = weighted_median(&arr(&[3.0, 1.0, 4.0, 2.0]), &arr(&[0.1, 0.3, 0.4, 0.2]));
        assert_eq!(result, 2.5);
    }

    #[test]
    fn test_weighted_median_with_zero_weights() {
        let result = weighted_median(&arr(&[1.0, 2.0, 3.0, 4.0]), &arr(&[0.0, 0.0, 1.0, 0.0]));
        assert_eq!(result, 3.0);
    }
}
1010}