1use crate::algorithms::Status;
2use crate::prelude::*;
3use crate::routines::settings::Settings;
4use crate::structs::psi::Psi;
5use crate::structs::theta::Theta;
6use anyhow::{bail, Context, Result};
7use csv::WriterBuilder;
8use faer::linalg::zip::IntoView;
9use faer::{Col, Mat};
10use faer_ext::IntoNdarray;
11use ndarray::{Array, Array1, Array2, Axis};
12use pharmsol::prelude::data::*;
13use pharmsol::prelude::simulator::Equation;
14use serde::Serialize;
15use std::fs::{create_dir_all, File, OpenOptions};
16use std::path::{Path, PathBuf};
17
18#[derive(Debug)]
21pub struct NPResult<E: Equation> {
22 equation: E,
23 data: Data,
24 theta: Theta,
25 psi: Psi,
26 w: Col<f64>,
27 objf: f64,
28 cycles: usize,
29 status: Status,
30 par_names: Vec<String>,
31 settings: Settings,
32 cyclelog: CycleLog,
33}
34
35#[allow(clippy::too_many_arguments)]
36impl<E: Equation> NPResult<E> {
37 pub fn new(
39 equation: E,
40 data: Data,
41 theta: Theta,
42 psi: Psi,
43 w: Col<f64>,
44 objf: f64,
45 cycles: usize,
46 status: Status,
47 settings: Settings,
48 cyclelog: CycleLog,
49 ) -> Self {
50 let par_names = settings.parameters().names();
53
54 Self {
55 equation,
56 data,
57 theta,
58 psi,
59 w,
60 objf,
61 cycles,
62 status,
63 par_names,
64 settings,
65 cyclelog,
66 }
67 }
68
69 pub fn cycles(&self) -> usize {
70 self.cycles
71 }
72
73 pub fn objf(&self) -> f64 {
74 self.objf
75 }
76
77 pub fn converged(&self) -> bool {
78 self.status == Status::Converged
79 }
80
81 pub fn get_theta(&self) -> &Theta {
82 &self.theta
83 }
84
85 pub fn psi(&self) -> &Psi {
87 &self.psi
88 }
89
90 pub fn w(&self) -> &Col<f64> {
92 &self.w
93 }
94
95 pub fn write_outputs(&self) -> Result<()> {
96 if self.settings.output().write {
97 tracing::info!("Writing outputs to {:?}", self.settings.output().path);
98 self.settings.write()?;
99 let idelta: f64 = self.settings.predictions().idelta;
100 let tad = self.settings.predictions().tad;
101 self.cyclelog.write(&self.settings)?;
102 self.write_obs().context("Failed to write observations")?;
103 self.write_theta().context("Failed to write theta")?;
104 self.write_obspred()
105 .context("Failed to write observed-predicted file")?;
106 self.write_pred(idelta, tad)
107 .context("Failed to write predictions")?;
108 self.write_covs().context("Failed to write covariates")?;
109 self.write_posterior()
110 .context("Failed to write posterior")?;
111 }
112 Ok(())
113 }
114
115 pub fn write_obspred(&self) -> Result<()> {
117 tracing::debug!("Writing observations and predictions...");
118
119 #[derive(Debug, Clone, Serialize)]
120 struct Row {
121 id: String,
122 time: f64,
123 outeq: usize,
124 block: usize,
125 obs: f64,
126 pop_mean: f64,
127 pop_median: f64,
128 post_mean: f64,
129 post_median: f64,
130 }
131
132 let theta: Array2<f64> = self
133 .theta
134 .matrix()
135 .clone()
136 .as_mut()
137 .into_ndarray()
138 .to_owned();
139 let w: Array1<f64> = self.w.clone().into_view().iter().cloned().collect();
140 let psi: Array2<f64> = self.psi.matrix().as_ref().into_ndarray().to_owned();
141
142 let (post_mean, post_median) = posterior_mean_median(&theta, &psi, &w)
143 .context("Failed to calculate posterior mean and median")?;
144
145 let (pop_mean, pop_median) = population_mean_median(&theta, &w)
146 .context("Failed to calculate posterior mean and median")?;
147
148 let subjects = self.data.get_subjects();
149 if subjects.len() != post_mean.nrows() {
150 bail!(
151 "Number of subjects: {} and number of posterior means: {} do not match",
152 subjects.len(),
153 post_mean.nrows()
154 );
155 }
156
157 let outputfile = OutputFile::new(&self.settings.output().path, "op.csv")?;
158 let mut writer = WriterBuilder::new()
159 .has_headers(true)
160 .from_writer(&outputfile.file);
161
162 for (i, subject) in subjects.iter().enumerate() {
163 for occasion in subject.occasions() {
164 let id = subject.id();
165 let occ = occasion.index();
166
167 let subject = Subject::from_occasions(id.clone(), vec![occasion.clone()]);
168
169 let pop_mean_pred = self
171 .equation
172 .simulate_subject(&subject, &pop_mean.to_vec(), None)?
173 .0
174 .get_predictions()
175 .clone();
176
177 let pop_median_pred = self
178 .equation
179 .simulate_subject(&subject, &pop_median.to_vec(), None)?
180 .0
181 .get_predictions()
182 .clone();
183
184 let post_mean_spp: Vec<f64> = post_mean.row(i).to_vec();
186 let post_mean_pred = self
187 .equation
188 .simulate_subject(&subject, &post_mean_spp, None)?
189 .0
190 .get_predictions()
191 .clone();
192 let post_median_spp: Vec<f64> = post_median.row(i).to_vec();
193 let post_median_pred = self
194 .equation
195 .simulate_subject(&subject, &post_median_spp, None)?
196 .0
197 .get_predictions()
198 .clone();
199 assert_eq!(
200 pop_mean_pred.len(),
201 pop_median_pred.len(),
202 "The number of predictions do not match (pop_mean vs pop_median)"
203 );
204
205 assert_eq!(
206 post_mean_pred.len(),
207 post_median_pred.len(),
208 "The number of predictions do not match (post_mean vs post_median)"
209 );
210
211 assert_eq!(
212 pop_mean_pred.len(),
213 post_mean_pred.len(),
214 "The number of predictions do not match (pop_mean vs post_mean)"
215 );
216
217 for (((pop_mean_pred, pop_median_pred), post_mean_pred), post_median_pred) in
218 pop_mean_pred
219 .iter()
220 .zip(pop_median_pred.iter())
221 .zip(post_mean_pred.iter())
222 .zip(post_median_pred.iter())
223 {
224 let row = Row {
225 id: id.clone(),
226 time: pop_mean_pred.time(),
227 outeq: pop_mean_pred.outeq(),
228 block: occ,
229 obs: pop_mean_pred.observation(),
230 pop_mean: pop_mean_pred.prediction(),
231 pop_median: pop_median_pred.prediction(),
232 post_mean: post_mean_pred.prediction(),
233 post_median: post_median_pred.prediction(),
234 };
235 writer.serialize(row)?;
236 }
237 }
238 }
239 writer.flush()?;
240 tracing::debug!(
241 "Observations with predictions written to {:?}",
242 &outputfile.get_relative_path()
243 );
244 Ok(())
245 }
246
247 pub fn write_theta(&self) -> Result<()> {
250 tracing::debug!("Writing population parameter distribution...");
251
252 let theta = &self.theta;
253 let w: Vec<f64> = self.w.clone().into_view().iter().cloned().collect();
254 let outputfile = OutputFile::new(&self.settings.output().path, "theta.csv")
262 .context("Failed to create output file for theta")?;
263
264 let mut writer = WriterBuilder::new()
265 .has_headers(true)
266 .from_writer(&outputfile.file);
267
268 let mut theta_header = self.par_names.clone();
270 theta_header.push("prob".to_string());
271 writer.write_record(&theta_header)?;
272
273 for (theta_row, &w_val) in theta.matrix().row_iter().zip(w.iter()) {
275 let mut row: Vec<String> = theta_row.iter().map(|&val| val.to_string()).collect();
276 row.push(w_val.to_string());
277 writer.write_record(&row)?;
278 }
279 writer.flush()?;
280 tracing::debug!(
281 "Population parameter distribution written to {:?}",
282 &outputfile.get_relative_path()
283 );
284 Ok(())
285 }
286
287 pub fn write_posterior(&self) -> Result<()> {
289 tracing::debug!("Writing posterior parameter probabilities...");
290 let theta = &self.theta;
291 let w = &self.w;
292 let psi = &self.psi;
293
294 let posterior = posterior(psi, w)?;
296
297 let outputfile = match OutputFile::new(&self.settings.output().path, "posterior.csv") {
299 Ok(of) => of,
300 Err(e) => {
301 tracing::error!("Failed to create output file: {}", e);
302 return Err(e.context("Failed to create output file"));
303 }
304 };
305
306 let mut writer = WriterBuilder::new()
308 .has_headers(true)
309 .from_writer(&outputfile.file);
310
311 writer.write_field("id")?;
313 writer.write_field("point")?;
314 theta.param_names().iter().for_each(|name| {
315 writer.write_field(name).unwrap();
316 });
317 writer.write_field("prob")?;
318 writer.write_record(None::<&[u8]>)?;
319
320 let subjects = self.data.get_subjects();
322 posterior.row_iter().enumerate().for_each(|(i, row)| {
323 let subject = subjects.get(i).unwrap();
324 let id = subject.id();
325
326 row.iter().enumerate().for_each(|(spp, prob)| {
327 writer.write_field(id.clone()).unwrap();
328 writer.write_field(spp.to_string()).unwrap();
329
330 theta.matrix().row(spp).iter().for_each(|val| {
331 writer.write_field(val.to_string()).unwrap();
332 });
333
334 writer.write_field(prob.to_string()).unwrap();
335 writer.write_record(None::<&[u8]>).unwrap();
336 });
337 });
338
339 writer.flush()?;
340 tracing::debug!(
341 "Posterior parameters written to {:?}",
342 &outputfile.get_relative_path()
343 );
344
345 Ok(())
346 }
347
348 pub fn write_obs(&self) -> Result<()> {
350 tracing::debug!("Writing observations...");
351 let outputfile = OutputFile::new(&self.settings.output().path, "obs.csv")?;
352
353 let mut writer = WriterBuilder::new()
354 .has_headers(true)
355 .from_writer(&outputfile.file);
356
357 #[derive(Serialize)]
358 struct Row {
359 id: String,
360 block: usize,
361 time: f64,
362 out: f64,
363 outeq: usize,
364 }
365
366 for subject in self.data.get_subjects() {
367 for occasion in subject.occasions() {
368 for event in occasion.get_events(None, false) {
369 if let Event::Observation(event) = event {
370 let row = Row {
371 id: subject.id().clone(),
372 block: occasion.index(),
373 time: event.time(),
374 out: event.value(),
375 outeq: event.outeq(),
376 };
377 writer.serialize(row)?;
378 }
379 }
380 }
381 }
382 writer.flush()?;
383
384 tracing::debug!(
385 "Observations written to {:?}",
386 &outputfile.get_relative_path()
387 );
388 Ok(())
389 }
390
391 pub fn write_pred(&self, idelta: f64, tad: f64) -> Result<()> {
393 tracing::debug!("Writing predictions...");
394 let data = self.data.expand(idelta, tad);
395
396 let theta: Array2<f64> = self
397 .theta
398 .matrix()
399 .clone()
400 .as_mut()
401 .into_ndarray()
402 .to_owned();
403 let w: Array1<f64> = self.w.clone().into_view().iter().cloned().collect();
404 let psi: Array2<f64> = self.psi.matrix().as_ref().into_ndarray().to_owned();
405
406 let (post_mean, post_median) = posterior_mean_median(&theta, &psi, &w)
407 .context("Failed to calculate posterior mean and median")?;
408
409 let (pop_mean, pop_median) = population_mean_median(&theta, &w)
410 .context("Failed to calculate population mean and median")?;
411
412 let subjects = data.get_subjects();
413 if subjects.len() != post_mean.nrows() {
414 bail!("Number of subjects and number of posterior means do not match");
415 }
416
417 let outputfile = OutputFile::new(&self.settings.output().path, "pred.csv")?;
418 let mut writer = WriterBuilder::new()
419 .has_headers(true)
420 .from_writer(&outputfile.file);
421
422 #[derive(Debug, Clone, Serialize)]
423 struct Row {
424 id: String,
425 time: f64,
426 outeq: usize,
427 block: usize,
428 pop_mean: f64,
429 pop_median: f64,
430 post_mean: f64,
431 post_median: f64,
432 }
433
434 for (i, subject) in subjects.iter().enumerate() {
435 for occasion in subject.occasions() {
436 let id = subject.id();
437 let block = occasion.index();
438
439 let subject = Subject::from_occasions(id.clone(), vec![occasion.clone()]);
441
442 let pop_mean_pred = self
444 .equation
445 .simulate_subject(&subject, &pop_mean.to_vec(), None)?
446 .0
447 .get_predictions()
448 .clone();
449 let pop_median_pred = self
450 .equation
451 .simulate_subject(&subject, &pop_median.to_vec(), None)?
452 .0
453 .get_predictions()
454 .clone();
455
456 let post_mean_spp: Vec<f64> = post_mean.row(i).to_vec();
458 let post_mean_pred = self
459 .equation
460 .simulate_subject(&subject, &post_mean_spp, None)?
461 .0
462 .get_predictions()
463 .clone();
464 let post_median_spp: Vec<f64> = post_median.row(i).to_vec();
465 let post_median_pred = self
466 .equation
467 .simulate_subject(&subject, &post_median_spp, None)?
468 .0
469 .get_predictions()
470 .clone();
471
472 for (((pop_mean, pop_median), post_mean), post_median) in pop_mean_pred
474 .iter()
475 .zip(pop_median_pred.iter())
476 .zip(post_mean_pred.iter())
477 .zip(post_median_pred.iter())
478 {
479 let row = Row {
480 id: id.clone(),
481 time: pop_mean.time(),
482 outeq: pop_mean.outeq(),
483 block,
484 pop_mean: pop_mean.prediction(),
485 pop_median: pop_median.prediction(),
486 post_mean: post_mean.prediction(),
487 post_median: post_median.prediction(),
488 };
489 writer.serialize(row)?;
490 }
491 }
492 }
493 writer.flush()?;
494 tracing::debug!(
495 "Predictions written to {:?}",
496 &outputfile.get_relative_path()
497 );
498 Ok(())
499 }
500
501 pub fn write_covs(&self) -> Result<()> {
503 tracing::debug!("Writing covariates...");
504 let outputfile = OutputFile::new(&self.settings.output().path, "covs.csv")?;
505 let mut writer = WriterBuilder::new()
506 .has_headers(true)
507 .from_writer(&outputfile.file);
508
509 let mut covariate_names = std::collections::HashSet::new();
511 for subject in self.data.get_subjects() {
512 for occasion in subject.occasions() {
513 let cov = occasion.covariates();
514 let covmap = cov.covariates();
515 for cov_name in covmap.keys() {
516 covariate_names.insert(cov_name.clone());
517 }
518 }
519 }
520 let mut covariate_names: Vec<String> = covariate_names.into_iter().collect();
521 covariate_names.sort(); let mut headers = vec!["id", "time", "block"];
525 headers.extend(covariate_names.iter().map(|s| s.as_str()));
526 writer.write_record(&headers)?;
527
528 for subject in self.data.get_subjects() {
530 for occasion in subject.occasions() {
531 let cov = occasion.covariates();
532 let covmap = cov.covariates();
533
534 for event in occasion.get_events(None, false) {
535 let time = match event {
536 Event::Bolus(bolus) => bolus.time(),
537 Event::Infusion(infusion) => infusion.time(),
538 Event::Observation(observation) => observation.time(),
539 };
540
541 let mut row: Vec<String> = Vec::new();
542 row.push(subject.id().clone());
543 row.push(time.to_string());
544 row.push(occasion.index().to_string());
545
546 for cov_name in &covariate_names {
548 if let Some(cov) = covmap.get(cov_name) {
549 if let Some(value) = cov.interpolate(time) {
550 row.push(value.to_string());
551 } else {
552 row.push(String::new());
553 }
554 } else {
555 row.push(String::new());
556 }
557 }
558
559 writer.write_record(&row)?;
560 }
561 }
562 }
563
564 writer.flush()?;
565 tracing::debug!(
566 "Covariates written to {:?}",
567 &outputfile.get_relative_path()
568 );
569 Ok(())
570 }
571}
572
573#[derive(Debug, Clone)]
583pub struct NPCycle {
584 pub cycle: usize,
585 pub objf: f64,
586 pub error_models: ErrorModels,
587 pub theta: Theta,
588 pub nspp: usize,
589 pub delta_objf: f64,
590 pub status: Status,
591}
592
593impl NPCycle {
594 pub fn new(
595 cycle: usize,
596 objf: f64,
597 error_models: ErrorModels,
598 theta: Theta,
599 nspp: usize,
600 delta_objf: f64,
601 status: Status,
602 ) -> Self {
603 Self {
604 cycle,
605 objf,
606 error_models,
607 theta,
608 nspp,
609 delta_objf,
610 status,
611 }
612 }
613
614 pub fn placeholder() -> Self {
615 Self {
616 cycle: 0,
617 objf: 0.0,
618 error_models: ErrorModels::default(),
619 theta: Theta::new(),
620 nspp: 0,
621 delta_objf: 0.0,
622 status: Status::Starting,
623 }
624 }
625}
626
627#[derive(Debug, Clone)]
629pub struct CycleLog {
630 pub cycles: Vec<NPCycle>,
631}
632
633impl CycleLog {
634 pub fn new() -> Self {
635 Self { cycles: Vec::new() }
636 }
637
638 pub fn push(&mut self, cycle: NPCycle) {
639 self.cycles.push(cycle);
640 }
641
642 pub fn write(&self, settings: &Settings) -> Result<()> {
643 tracing::debug!("Writing cycles...");
644 let outputfile = OutputFile::new(&settings.output().path, "cycles.csv")?;
645 let mut writer = WriterBuilder::new()
646 .has_headers(false)
647 .from_writer(&outputfile.file);
648
649 writer.write_field("cycle")?;
651 writer.write_field("converged")?;
652 writer.write_field("status")?;
653 writer.write_field("neg2ll")?;
654 writer.write_field("nspp")?;
655 if let Some(first_cycle) = self.cycles.first() {
656 first_cycle.error_models.iter().try_for_each(
657 |(outeq, errmod): (usize, &ErrorModel)| -> Result<(), csv::Error> {
658 match errmod {
659 ErrorModel::Additive { .. } => {
660 writer.write_field(format!("gamlam.{}", outeq))?;
661 }
662 ErrorModel::Proportional { .. } => {
663 writer.write_field(format!("gamlam.{}", outeq))?;
664 }
665 ErrorModel::None { .. } => {}
666 }
667 Ok(())
668 },
669 )?;
670 }
671
672 let parameter_names = settings.parameters().names();
673 for param_name in ¶meter_names {
674 writer.write_field(format!("{}.mean", param_name))?;
675 writer.write_field(format!("{}.median", param_name))?;
676 writer.write_field(format!("{}.sd", param_name))?;
677 }
678
679 writer.write_record(None::<&[u8]>)?;
680
681 for cycle in &self.cycles {
682 writer.write_field(format!("{}", cycle.cycle))?;
683 writer.write_field(format!("{}", cycle.status == Status::Converged))?;
684 writer.write_field(format!("{}", cycle.status))?;
685 writer.write_field(format!("{}", cycle.objf))?;
686 writer
687 .write_field(format!("{}", cycle.theta.nspp()))
688 .unwrap();
689
690 cycle.error_models.iter().try_for_each(
692 |(_, errmod): (usize, &ErrorModel)| -> Result<()> {
693 match errmod {
694 ErrorModel::Additive { .. } => {
695 writer.write_field(format!("{:.5}", errmod.scalar()?))?;
696 }
697 ErrorModel::Proportional { .. } => {
698 writer.write_field(format!("{:.5}", errmod.scalar()?))?;
699 }
700 ErrorModel::None { .. } => {}
701 }
702 Ok(())
703 },
704 )?;
705
706 for param in cycle.theta.matrix().col_iter() {
707 let param_values: Vec<f64> = param.iter().cloned().collect();
708
709 let mean: f64 = param_values.iter().sum::<f64>() / param_values.len() as f64;
710 let median = median(param_values.clone());
711 let std = param_values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
712 / (param_values.len() as f64 - 1.0);
713
714 writer.write_field(format!("{}", mean))?;
715 writer.write_field(format!("{}", median))?;
716 writer.write_field(format!("{}", std))?;
717 }
718 writer.write_record(None::<&[u8]>)?;
719 }
720 writer.flush()?;
721 tracing::debug!("Cycles written to {:?}", &outputfile.get_relative_path());
722 Ok(())
723 }
724}
725
726impl Default for CycleLog {
727 fn default() -> Self {
728 Self::new()
729 }
730}
731
732pub fn posterior(psi: &Psi, w: &Col<f64>) -> Result<Mat<f64>> {
734 if psi.matrix().ncols() != w.nrows() {
735 bail!(
736 "Number of rows in psi ({}) and number of weights ({}) do not match.",
737 psi.matrix().nrows(),
738 w.nrows()
739 );
740 }
741
742 let psi_matrix = psi.matrix();
743 let py = psi_matrix * w;
744
745 let posterior = Mat::from_fn(psi_matrix.nrows(), psi_matrix.ncols(), |i, j| {
746 psi_matrix.get(i, j) * w.get(j) / py.get(i)
747 });
748
749 Ok(posterior)
750}
751
/// Returns the median of `data`.
///
/// For an even number of values the two middle values are averaged. An empty input
/// yields `f64::NAN` instead of panicking. Values are ordered with `f64::total_cmp`,
/// so NaN inputs cannot cause a panic either; positive NaNs sort after every number.
pub fn median(mut data: Vec<f64>) -> f64 {
    if data.is_empty() {
        return f64::NAN;
    }

    // The vector is already owned, so it can be sorted in place without a copy.
    data.sort_by(f64::total_cmp);

    let mid = data.len() / 2;
    if data.len() % 2 == 0 {
        (data[mid - 1] + data[mid]) / 2.0
    } else {
        data[mid]
    }
}
766
767fn weighted_median(data: &Array1<f64>, weights: &Array1<f64>) -> f64 {
768 assert_eq!(
770 data.len(),
771 weights.len(),
772 "The length of data and weights must be the same"
773 );
774 assert!(
775 weights.iter().all(|&x| x >= 0.0),
776 "Weights must be non-negative, weights: {:?}",
777 weights
778 );
779
780 let mut weighted_data: Vec<(f64, f64)> = data
782 .iter()
783 .zip(weights.iter())
784 .map(|(&d, &w)| (d, w))
785 .collect();
786
787 weighted_data.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
789
790 let total_weight: f64 = weights.sum();
792 let mut cumulative_sum = 0.0;
793
794 for (i, &(_, weight)) in weighted_data.iter().enumerate() {
795 cumulative_sum += weight;
796
797 if cumulative_sum == total_weight / 2.0 {
798 if i + 1 < weighted_data.len() {
800 return (weighted_data[i].0 + weighted_data[i + 1].0) / 2.0;
801 } else {
802 return weighted_data[i].0;
803 }
804 } else if cumulative_sum > total_weight / 2.0 {
805 return weighted_data[i].0;
806 }
807 }
808
809 unreachable!("The function should have returned a value before reaching this point.");
810}
811
812pub fn population_mean_median(
813 theta: &Array2<f64>,
814 w: &Array1<f64>,
815) -> Result<(Array1<f64>, Array1<f64>)> {
816 let w = if w.is_empty() {
817 tracing::warn!("w.len() == 0, setting all weights to 1/n");
818 Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
819 } else {
820 w.clone()
821 };
822 if theta.nrows() != w.len() {
824 bail!(
825 "Number of parameters and number of weights do not match. Theta: {}, w: {}",
826 theta.nrows(),
827 w.len()
828 );
829 }
830
831 let mut mean = Array1::zeros(theta.ncols());
832 let mut median = Array1::zeros(theta.ncols());
833
834 for (i, (mn, mdn)) in mean.iter_mut().zip(&mut median).enumerate() {
835 let col = theta.column(i).to_owned() * w.to_owned();
837 *mn = col.sum();
838
839 let ct = theta.column(i);
841 let mut params = vec![];
842 let mut weights = vec![];
843 for (ti, wi) in ct.iter().zip(w.clone()) {
844 params.push(*ti);
845 weights.push(wi);
846 }
847
848 *mdn = weighted_median(&Array::from(params), &Array::from(weights));
849 }
850
851 Ok((mean, median))
852}
853
854pub fn posterior_mean_median(
855 theta: &Array2<f64>,
856 psi: &Array2<f64>,
857 w: &Array1<f64>,
858) -> Result<(Array2<f64>, Array2<f64>)> {
859 let mut mean = Array2::zeros((0, theta.ncols()));
860 let mut median = Array2::zeros((0, theta.ncols()));
861
862 let w = if w.is_empty() {
863 tracing::warn!("w is empty, setting all weights to 1/n");
864 Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
865 } else {
866 w.clone()
867 };
868
869 if theta.nrows() != w.len() || theta.nrows() != psi.ncols() || psi.ncols() != w.len() {
871 bail!("Number of parameters and number of weights do not match, theta.nrows(): {}, w.len(): {}, psi.ncols(): {}", theta.nrows(), w.len(), psi.ncols());
872 }
873
874 let mut psi_norm: Array2<f64> = Array2::zeros((0, psi.ncols()));
876 for (i, row) in psi.axis_iter(Axis(0)).enumerate() {
877 let row_w = row.to_owned() * w.to_owned();
878 let row_sum = row_w.sum();
879 let row_norm = if row_sum == 0.0 {
880 tracing::warn!("Sum of row {} of psi is 0.0, setting that row to 1/n", i);
881 Array1::from_elem(psi.ncols(), 1.0 / psi.ncols() as f64)
882 } else {
883 &row_w / row_sum
884 };
885 psi_norm.push_row(row_norm.view())?;
886 }
887 if psi_norm.iter().any(|&x| x.is_nan()) {
888 dbg!(&psi);
889 bail!("NaN values found in psi_norm");
890 };
891
892 for probs in psi_norm.axis_iter(Axis(0)) {
897 let mut post_mean: Vec<f64> = Vec::new();
898 let mut post_median: Vec<f64> = Vec::new();
899
900 for pars in theta.axis_iter(Axis(1)) {
902 let weighted_par = &probs * &pars;
904 let the_mean = weighted_par.sum();
905 post_mean.push(the_mean);
906
907 let median = weighted_median(&pars.to_owned(), &probs.to_owned());
909 post_median.push(median);
910 }
911
912 mean.push_row(Array::from(post_mean.clone()).view())?;
913 median.push_row(Array::from(post_median.clone()).view())?;
914 }
915
916 Ok((mean, median))
917}
918
/// An output file opened for writing, together with the path it was opened at.
#[derive(Debug)]
pub struct OutputFile {
    /// Handle to the opened file.
    pub file: File,
    /// Path of the file: the output folder joined with the file name.
    pub relative_path: PathBuf,
}
925
926impl OutputFile {
927 pub fn new(folder: &str, file_name: &str) -> Result<Self> {
928 let relative_path = Path::new(&folder).join(file_name);
929
930 if let Some(parent) = relative_path.parent() {
931 create_dir_all(parent)
932 .with_context(|| format!("Failed to create directories for {:?}", parent))?;
933 }
934
935 let file = OpenOptions::new()
936 .write(true)
937 .create(true)
938 .truncate(true)
939 .open(&relative_path)
940 .with_context(|| format!("Failed to open file: {:?}", relative_path))?;
941
942 Ok(OutputFile {
943 file,
944 relative_path,
945 })
946 }
947
948 pub fn get_relative_path(&self) -> &Path {
949 &self.relative_path
950 }
951}
952
#[cfg(test)]
mod tests {
    use super::{median, weighted_median};
    use ndarray::Array1;

    /// Builds an `Array1<f64>` from a slice to keep the test bodies short.
    fn arr(values: &[f64]) -> Array1<f64> {
        Array1::from(values.to_vec())
    }

    // --- median ---

    #[test]
    fn test_median_odd() {
        assert_eq!(median(vec![1.0, 3.0, 2.0]), 2.0);
    }

    #[test]
    fn test_median_even() {
        assert_eq!(median(vec![1.0, 2.0, 3.0, 4.0]), 2.5);
    }

    #[test]
    fn test_median_single() {
        assert_eq!(median(vec![42.0]), 42.0);
    }

    #[test]
    fn test_median_sorted() {
        assert_eq!(median(vec![5.0, 10.0, 15.0, 20.0, 25.0]), 15.0);
    }

    #[test]
    fn test_median_unsorted() {
        assert_eq!(median(vec![10.0, 30.0, 20.0, 50.0, 40.0]), 30.0);
    }

    #[test]
    fn test_median_with_duplicates() {
        assert_eq!(median(vec![1.0, 2.0, 2.0, 3.0, 4.0]), 2.0);
    }

    // --- weighted_median ---

    #[test]
    fn test_weighted_median_simple() {
        let result = weighted_median(&arr(&[1.0, 2.0, 3.0]), &arr(&[0.2, 0.5, 0.3]));
        assert_eq!(result, 2.0);
    }

    #[test]
    fn test_weighted_median_even_weights() {
        let result = weighted_median(
            &arr(&[1.0, 2.0, 3.0, 4.0]),
            &arr(&[0.25, 0.25, 0.25, 0.25]),
        );
        assert_eq!(result, 2.5);
    }

    #[test]
    fn test_weighted_median_single_element() {
        let result = weighted_median(&arr(&[42.0]), &arr(&[1.0]));
        assert_eq!(result, 42.0);
    }

    #[test]
    #[should_panic(expected = "The length of data and weights must be the same")]
    fn test_weighted_median_mismatched_lengths() {
        weighted_median(&arr(&[1.0, 2.0, 3.0]), &arr(&[0.1, 0.2]));
    }

    #[test]
    fn test_weighted_median_all_same_elements() {
        let result = weighted_median(&arr(&[5.0, 5.0, 5.0, 5.0]), &arr(&[0.1, 0.2, 0.3, 0.4]));
        assert_eq!(result, 5.0);
    }

    #[test]
    #[should_panic(expected = "Weights must be non-negative")]
    fn test_weighted_median_negative_weights() {
        let result = weighted_median(&arr(&[1.0, 2.0, 3.0, 4.0]), &arr(&[0.2, -0.5, 0.5, 0.8]));
        assert_eq!(result, 4.0);
    }

    #[test]
    fn test_weighted_median_unsorted_data() {
        let result = weighted_median(&arr(&[3.0, 1.0, 4.0, 2.0]), &arr(&[0.1, 0.3, 0.4, 0.2]));
        assert_eq!(result, 2.5);
    }

    #[test]
    fn test_weighted_median_with_zero_weights() {
        let result = weighted_median(&arr(&[1.0, 2.0, 3.0, 4.0]), &arr(&[0.0, 0.0, 1.0, 0.0]));
        assert_eq!(result, 3.0);
    }
}