1use crate::algorithms::Status;
2use crate::prelude::*;
3use crate::routines::settings::Settings;
4use crate::structs::psi::Psi;
5use crate::structs::theta::Theta;
6use anyhow::{bail, Context, Result};
7use csv::WriterBuilder;
8use faer::linalg::zip::IntoView;
9use faer::{Col, Mat};
10use faer_ext::IntoNdarray;
11use ndarray::{Array, Array1, Array2, Axis};
12use pharmsol::prelude::data::*;
13use pharmsol::prelude::simulator::{Equation, Prediction};
14use serde::Serialize;
15use std::fs::{create_dir_all, File, OpenOptions};
16use std::path::{Path, PathBuf};
17
18#[derive(Debug)]
21pub struct NPResult<E: Equation> {
22 equation: E,
23 data: Data,
24 theta: Theta,
25 psi: Psi,
26 w: Col<f64>,
27 objf: f64,
28 cycles: usize,
29 status: Status,
30 par_names: Vec<String>,
31 settings: Settings,
32 cyclelog: CycleLog,
33}
34
35#[allow(clippy::too_many_arguments)]
36impl<E: Equation> NPResult<E> {
37 pub fn new(
39 equation: E,
40 data: Data,
41 theta: Theta,
42 psi: Psi,
43 w: Col<f64>,
44 objf: f64,
45 cycles: usize,
46 status: Status,
47 settings: Settings,
48 cyclelog: CycleLog,
49 ) -> Self {
50 let par_names = settings.parameters().names();
53
54 Self {
55 equation,
56 data,
57 theta,
58 psi,
59 w,
60 objf,
61 cycles,
62 status,
63 par_names,
64 settings,
65 cyclelog,
66 }
67 }
68
69 pub fn cycles(&self) -> usize {
70 self.cycles
71 }
72
73 pub fn objf(&self) -> f64 {
74 self.objf
75 }
76
77 pub fn converged(&self) -> bool {
78 self.status == Status::Converged
79 }
80
81 pub fn get_theta(&self) -> &Theta {
82 &self.theta
83 }
84
85 pub fn psi(&self) -> &Psi {
87 &self.psi
88 }
89
90 pub fn w(&self) -> &Col<f64> {
92 &self.w
93 }
94
95 pub fn write_outputs(&self) -> Result<()> {
96 if self.settings.output().write {
97 tracing::debug!("Writing outputs to {:?}", self.settings.output().path);
98 self.settings.write()?;
99 let idelta: f64 = self.settings.predictions().idelta;
100 let tad = self.settings.predictions().tad;
101 self.cyclelog.write(&self.settings)?;
102 self.write_obs().context("Failed to write observations")?;
103 self.write_theta().context("Failed to write theta")?;
104 self.write_obspred()
105 .context("Failed to write observed-predicted file")?;
106 self.write_pred(idelta, tad)
107 .context("Failed to write predictions")?;
108 self.write_covs().context("Failed to write covariates")?;
109 self.write_posterior()
110 .context("Failed to write posterior")?;
111 }
112 Ok(())
113 }
114
115 pub fn write_obspred(&self) -> Result<()> {
117 tracing::debug!("Writing observations and predictions...");
118
119 #[derive(Debug, Clone, Serialize)]
120 struct Row {
121 id: String,
122 time: f64,
123 outeq: usize,
124 block: usize,
125 obs: Option<f64>,
126 pop_mean: f64,
127 pop_median: f64,
128 post_mean: f64,
129 post_median: f64,
130 }
131
132 let theta: Array2<f64> = self
133 .theta
134 .matrix()
135 .clone()
136 .as_mut()
137 .into_ndarray()
138 .to_owned();
139 let w: Array1<f64> = self.w.clone().into_view().iter().cloned().collect();
140 let psi: Array2<f64> = self.psi.matrix().as_ref().into_ndarray().to_owned();
141
142 let (post_mean, post_median) = posterior_mean_median(&theta, &psi, &w)
143 .context("Failed to calculate posterior mean and median")?;
144
145 let (pop_mean, pop_median) = population_mean_median(&theta, &w)
146 .context("Failed to calculate posterior mean and median")?;
147
148 let subjects = self.data.subjects();
149 if subjects.len() != post_mean.nrows() {
150 bail!(
151 "Number of subjects: {} and number of posterior means: {} do not match",
152 subjects.len(),
153 post_mean.nrows()
154 );
155 }
156
157 let outputfile = OutputFile::new(&self.settings.output().path, "op.csv")?;
158 let mut writer = WriterBuilder::new()
159 .has_headers(true)
160 .from_writer(&outputfile.file);
161
162 for (i, subject) in subjects.iter().enumerate() {
163 for occasion in subject.occasions() {
164 let id = subject.id();
165 let occ = occasion.index();
166
167 let subject = Subject::from_occasions(id.clone(), vec![occasion.clone()]);
168
169 let pop_mean_pred = self
171 .equation
172 .simulate_subject(&subject, &pop_mean.to_vec(), None)?
173 .0
174 .get_predictions()
175 .clone();
176
177 let pop_median_pred = self
178 .equation
179 .simulate_subject(&subject, &pop_median.to_vec(), None)?
180 .0
181 .get_predictions()
182 .clone();
183
184 let post_mean_spp: Vec<f64> = post_mean.row(i).to_vec();
186 let post_mean_pred = self
187 .equation
188 .simulate_subject(&subject, &post_mean_spp, None)?
189 .0
190 .get_predictions()
191 .clone();
192 let post_median_spp: Vec<f64> = post_median.row(i).to_vec();
193 let post_median_pred = self
194 .equation
195 .simulate_subject(&subject, &post_median_spp, None)?
196 .0
197 .get_predictions()
198 .clone();
199 assert_eq!(
200 pop_mean_pred.len(),
201 pop_median_pred.len(),
202 "The number of predictions do not match (pop_mean vs pop_median)"
203 );
204
205 assert_eq!(
206 post_mean_pred.len(),
207 post_median_pred.len(),
208 "The number of predictions do not match (post_mean vs post_median)"
209 );
210
211 assert_eq!(
212 pop_mean_pred.len(),
213 post_mean_pred.len(),
214 "The number of predictions do not match (pop_mean vs post_mean)"
215 );
216
217 for (((pop_mean_pred, pop_median_pred), post_mean_pred), post_median_pred) in
218 pop_mean_pred
219 .iter()
220 .zip(pop_median_pred.iter())
221 .zip(post_mean_pred.iter())
222 .zip(post_median_pred.iter())
223 {
224 let row = Row {
225 id: id.clone(),
226 time: pop_mean_pred.time(),
227 outeq: pop_mean_pred.outeq(),
228 block: occ,
229 obs: pop_mean_pred.observation(),
230 pop_mean: pop_mean_pred.prediction(),
231 pop_median: pop_median_pred.prediction(),
232 post_mean: post_mean_pred.prediction(),
233 post_median: post_median_pred.prediction(),
234 };
235 writer.serialize(row)?;
236 }
237 }
238 }
239 writer.flush()?;
240 tracing::debug!(
241 "Observations with predictions written to {:?}",
242 &outputfile.get_relative_path()
243 );
244 Ok(())
245 }
246
247 pub fn write_theta(&self) -> Result<()> {
250 tracing::debug!("Writing population parameter distribution...");
251
252 let theta = &self.theta;
253 let w: Vec<f64> = self.w.clone().into_view().iter().cloned().collect();
254
255 if w.len() != theta.matrix().nrows() {
256 bail!(
257 "Number of weights ({}) and number of support points ({}) do not match.",
258 w.len(),
259 theta.matrix().nrows()
260 );
261 }
262
263 let outputfile = OutputFile::new(&self.settings.output().path, "theta.csv")
264 .context("Failed to create output file for theta")?;
265
266 let mut writer = WriterBuilder::new()
267 .has_headers(true)
268 .from_writer(&outputfile.file);
269
270 let mut theta_header = self.par_names.clone();
272 theta_header.push("prob".to_string());
273 writer.write_record(&theta_header)?;
274
275 for (theta_row, &w_val) in theta.matrix().row_iter().zip(w.iter()) {
277 let mut row: Vec<String> = theta_row.iter().map(|&val| val.to_string()).collect();
278 row.push(w_val.to_string());
279 writer.write_record(&row)?;
280 }
281 writer.flush()?;
282 tracing::debug!(
283 "Population parameter distribution written to {:?}",
284 &outputfile.get_relative_path()
285 );
286 Ok(())
287 }
288
289 pub fn write_posterior(&self) -> Result<()> {
291 tracing::debug!("Writing posterior parameter probabilities...");
292 let theta = &self.theta;
293 let w = &self.w;
294 let psi = &self.psi;
295
296 let posterior = posterior(psi, w)?;
298
299 let outputfile = match OutputFile::new(&self.settings.output().path, "posterior.csv") {
301 Ok(of) => of,
302 Err(e) => {
303 tracing::error!("Failed to create output file: {}", e);
304 return Err(e.context("Failed to create output file"));
305 }
306 };
307
308 let mut writer = WriterBuilder::new()
310 .has_headers(true)
311 .from_writer(&outputfile.file);
312
313 writer.write_field("id")?;
315 writer.write_field("point")?;
316 theta.param_names().iter().for_each(|name| {
317 writer.write_field(name).unwrap();
318 });
319 writer.write_field("prob")?;
320 writer.write_record(None::<&[u8]>)?;
321
322 let subjects = self.data.subjects();
324 posterior.row_iter().enumerate().for_each(|(i, row)| {
325 let subject = subjects.get(i).unwrap();
326 let id = subject.id();
327
328 row.iter().enumerate().for_each(|(spp, prob)| {
329 writer.write_field(id.clone()).unwrap();
330 writer.write_field(spp.to_string()).unwrap();
331
332 theta.matrix().row(spp).iter().for_each(|val| {
333 writer.write_field(val.to_string()).unwrap();
334 });
335
336 writer.write_field(prob.to_string()).unwrap();
337 writer.write_record(None::<&[u8]>).unwrap();
338 });
339 });
340
341 writer.flush()?;
342 tracing::debug!(
343 "Posterior parameters written to {:?}",
344 &outputfile.get_relative_path()
345 );
346
347 Ok(())
348 }
349
350 pub fn write_obs(&self) -> Result<()> {
352 tracing::debug!("Writing observations...");
353 let outputfile = OutputFile::new(&self.settings.output().path, "obs.csv")?;
354
355 let mut writer = WriterBuilder::new()
356 .has_headers(true)
357 .from_writer(&outputfile.file);
358
359 #[derive(Serialize)]
360 struct Row {
361 id: String,
362 block: usize,
363 time: f64,
364 out: Option<f64>,
365 outeq: usize,
366 }
367
368 for subject in self.data.subjects() {
369 for occasion in subject.occasions() {
370 for event in occasion.iter() {
371 if let Event::Observation(event) = event {
372 let row = Row {
373 id: subject.id().clone(),
374 block: occasion.index(),
375 time: event.time(),
376 out: event.value(),
377 outeq: event.outeq(),
378 };
379 writer.serialize(row)?;
380 }
381 }
382 }
383 }
384 writer.flush()?;
385
386 tracing::debug!(
387 "Observations written to {:?}",
388 &outputfile.get_relative_path()
389 );
390 Ok(())
391 }
392
393 pub fn write_pred(&self, idelta: f64, tad: f64) -> Result<()> {
395 tracing::debug!("Writing predictions...");
396
397 let theta = self.theta.matrix();
399 let w: Vec<f64> = self.w.iter().cloned().collect();
400 let posterior = posterior(&self.psi, &self.w)?;
401
402 let data = self.data.clone().expand(idelta, tad);
403
404 let subjects = data.subjects();
405
406 if subjects.len() != posterior.nrows() {
407 bail!("Number of subjects and number of posterior means do not match");
408 };
409
410 let outputfile = OutputFile::new(&self.settings.output().path, "pred.csv")?;
412 let mut writer = WriterBuilder::new()
413 .has_headers(true)
414 .from_writer(&outputfile.file);
415
416 for subject in subjects.iter().enumerate() {
418 let (subject_index, subject) = subject;
419
420 let occasions = subject
422 .occasions()
423 .iter()
424 .flat_map(|o| {
425 o.events()
426 .iter()
427 .filter_map(|e| {
428 if let Event::Observation(_obs) = e {
429 Some(o.index())
430 } else {
431 None
432 }
433 })
434 .collect::<Vec<_>>()
435 })
436 .collect::<Vec<usize>>();
437
438 let mut predictions: Vec<Vec<Prediction>> = Vec::new();
443
444 for spp in theta.row_iter() {
446 let spp_values = spp.iter().cloned().collect::<Vec<f64>>();
448 let pred = self
449 .equation
450 .simulate_subject(subject, &spp_values, None)?
451 .0
452 .get_predictions();
453 predictions.push(pred);
454 }
455
456 if predictions.is_empty() {
457 continue; }
459
460 let mut pop_mean: Vec<f64> = vec![0.0; predictions.first().unwrap().len()];
462 for outer_pred in predictions.iter().enumerate() {
463 let (i, outer_pred) = outer_pred;
464 for inner_pred in outer_pred.iter().enumerate() {
465 let (j, pred) = inner_pred;
466 pop_mean[j] += pred.prediction() * w[i];
467 }
468 }
469
470 let mut pop_median: Vec<f64> = Vec::new();
472 for j in 0..predictions.first().unwrap().len() {
473 let mut values: Vec<f64> = Vec::new();
474 let mut weights: Vec<f64> = Vec::new();
475
476 for (i, outer_pred) in predictions.iter().enumerate() {
477 values.push(outer_pred[j].prediction());
478 weights.push(w[i]);
479 }
480
481 let median_val = weighted_median(&values, &weights);
482 pop_median.push(median_val);
483 }
484
485 let mut posterior_mean: Vec<f64> = vec![0.0; predictions.first().unwrap().len()];
487 for outer_pred in predictions.iter().enumerate() {
488 let (i, outer_pred) = outer_pred;
489 for inner_pred in outer_pred.iter().enumerate() {
490 let (j, pred) = inner_pred;
491 posterior_mean[j] += pred.prediction() * posterior[(subject_index, i)];
492 }
493 }
494
495 let mut posterior_median: Vec<f64> = Vec::new();
497 for j in 0..predictions.first().unwrap().len() {
498 let mut values: Vec<f64> = Vec::new();
499 let mut weights: Vec<f64> = Vec::new();
500
501 for (i, outer_pred) in predictions.iter().enumerate() {
502 values.push(outer_pred[j].prediction());
503 weights.push(posterior[(subject_index, i)]);
504 }
505
506 let median_val = weighted_median(&values, &weights);
507 posterior_median.push(median_val);
508 }
509
510 #[derive(Debug, Clone, Serialize)]
512 struct Row {
513 id: String,
514 time: f64,
515 outeq: usize,
516 block: usize,
517 pop_mean: f64,
518 pop_median: f64,
519 post_mean: f64,
520 post_median: f64,
521 }
522
523 for pred in predictions.iter().enumerate() {
524 let (_, preds) = pred;
525 for (j, p) in preds.iter().enumerate() {
526 let row = Row {
527 id: subject.id().clone(),
528 time: p.time(),
529 outeq: p.outeq(),
530 block: occasions[j],
531 pop_mean: pop_mean[j],
532 pop_median: pop_median[j],
533 post_mean: posterior_mean[j],
534 post_median: posterior_median[j],
535 };
536 writer.serialize(row)?;
537 }
538 }
539 }
540
541 writer.flush()?;
542 tracing::debug!(
543 "Predictions written to {:?}",
544 &outputfile.get_relative_path()
545 );
546
547 Ok(())
548 }
549
550 pub fn write_covs(&self) -> Result<()> {
552 tracing::debug!("Writing covariates...");
553 let outputfile = OutputFile::new(&self.settings.output().path, "covs.csv")?;
554 let mut writer = WriterBuilder::new()
555 .has_headers(true)
556 .from_writer(&outputfile.file);
557
558 let mut covariate_names = std::collections::HashSet::new();
560 for subject in self.data.subjects() {
561 for occasion in subject.occasions() {
562 let cov = occasion.covariates();
563 let covmap = cov.covariates();
564 for cov_name in covmap.keys() {
565 covariate_names.insert(cov_name.clone());
566 }
567 }
568 }
569 let mut covariate_names: Vec<String> = covariate_names.into_iter().collect();
570 covariate_names.sort(); let mut headers = vec!["id", "time", "block"];
574 headers.extend(covariate_names.iter().map(|s| s.as_str()));
575 writer.write_record(&headers)?;
576
577 for subject in self.data.subjects() {
579 for occasion in subject.occasions() {
580 let cov = occasion.covariates();
581 let covmap = cov.covariates();
582
583 for event in occasion.iter() {
584 let time = match event {
585 Event::Bolus(bolus) => bolus.time(),
586 Event::Infusion(infusion) => infusion.time(),
587 Event::Observation(observation) => observation.time(),
588 };
589
590 let mut row: Vec<String> = Vec::new();
591 row.push(subject.id().clone());
592 row.push(time.to_string());
593 row.push(occasion.index().to_string());
594
595 for cov_name in &covariate_names {
597 if let Some(cov) = covmap.get(cov_name) {
598 if let Ok(value) = cov.interpolate(time) {
599 row.push(value.to_string());
600 } else {
601 row.push(String::new());
602 }
603 } else {
604 row.push(String::new());
605 }
606 }
607
608 writer.write_record(&row)?;
609 }
610 }
611 }
612
613 writer.flush()?;
614 tracing::debug!(
615 "Covariates written to {:?}",
616 &outputfile.get_relative_path()
617 );
618 Ok(())
619 }
620}
621
622#[derive(Debug, Clone)]
632pub struct NPCycle {
633 pub cycle: usize,
634 pub objf: f64,
635 pub error_models: ErrorModels,
636 pub theta: Theta,
637 pub nspp: usize,
638 pub delta_objf: f64,
639 pub status: Status,
640}
641
642impl NPCycle {
643 pub fn new(
644 cycle: usize,
645 objf: f64,
646 error_models: ErrorModels,
647 theta: Theta,
648 nspp: usize,
649 delta_objf: f64,
650 status: Status,
651 ) -> Self {
652 Self {
653 cycle,
654 objf,
655 error_models,
656 theta,
657 nspp,
658 delta_objf,
659 status,
660 }
661 }
662
663 pub fn placeholder() -> Self {
664 Self {
665 cycle: 0,
666 objf: 0.0,
667 error_models: ErrorModels::default(),
668 theta: Theta::new(),
669 nspp: 0,
670 delta_objf: 0.0,
671 status: Status::Starting,
672 }
673 }
674}
675
676#[derive(Debug, Clone)]
678pub struct CycleLog {
679 pub cycles: Vec<NPCycle>,
680}
681
682impl CycleLog {
683 pub fn new() -> Self {
684 Self { cycles: Vec::new() }
685 }
686
687 pub fn push(&mut self, cycle: NPCycle) {
688 self.cycles.push(cycle);
689 }
690
691 pub fn write(&self, settings: &Settings) -> Result<()> {
692 tracing::debug!("Writing cycles...");
693 let outputfile = OutputFile::new(&settings.output().path, "cycles.csv")?;
694 let mut writer = WriterBuilder::new()
695 .has_headers(false)
696 .from_writer(&outputfile.file);
697
698 writer.write_field("cycle")?;
700 writer.write_field("converged")?;
701 writer.write_field("status")?;
702 writer.write_field("neg2ll")?;
703 writer.write_field("nspp")?;
704 if let Some(first_cycle) = self.cycles.first() {
705 first_cycle.error_models.iter().try_for_each(
706 |(outeq, errmod): (usize, &ErrorModel)| -> Result<(), csv::Error> {
707 match errmod {
708 ErrorModel::Additive { .. } => {
709 writer.write_field(format!("gamlam.{}", outeq))?;
710 }
711 ErrorModel::Proportional { .. } => {
712 writer.write_field(format!("gamlam.{}", outeq))?;
713 }
714 ErrorModel::None => {}
715 }
716 Ok(())
717 },
718 )?;
719 }
720
721 let parameter_names = settings.parameters().names();
722 for param_name in ¶meter_names {
723 writer.write_field(format!("{}.mean", param_name))?;
724 writer.write_field(format!("{}.median", param_name))?;
725 writer.write_field(format!("{}.sd", param_name))?;
726 }
727
728 writer.write_record(None::<&[u8]>)?;
729
730 for cycle in &self.cycles {
731 writer.write_field(format!("{}", cycle.cycle))?;
732 writer.write_field(format!("{}", cycle.status == Status::Converged))?;
733 writer.write_field(format!("{}", cycle.status))?;
734 writer.write_field(format!("{}", cycle.objf))?;
735 writer
736 .write_field(format!("{}", cycle.theta.nspp()))
737 .unwrap();
738
739 cycle.error_models.iter().try_for_each(
741 |(_, errmod): (usize, &ErrorModel)| -> Result<()> {
742 match errmod {
743 ErrorModel::Additive {
744 lambda: _,
745 poly: _,
746 lloq: _,
747 } => {
748 writer.write_field(format!("{:.5}", errmod.factor()?))?;
749 }
750 ErrorModel::Proportional {
751 gamma: _,
752 poly: _,
753 lloq: _,
754 } => {
755 writer.write_field(format!("{:.5}", errmod.factor()?))?;
756 }
757 ErrorModel::None => {}
758 }
759 Ok(())
760 },
761 )?;
762
763 for param in cycle.theta.matrix().col_iter() {
764 let param_values: Vec<f64> = param.iter().cloned().collect();
765
766 let mean: f64 = param_values.iter().sum::<f64>() / param_values.len() as f64;
767 let median = median(param_values.clone());
768 let std = param_values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
769 / (param_values.len() as f64 - 1.0);
770
771 writer.write_field(format!("{}", mean))?;
772 writer.write_field(format!("{}", median))?;
773 writer.write_field(format!("{}", std))?;
774 }
775 writer.write_record(None::<&[u8]>)?;
776 }
777 writer.flush()?;
778 tracing::debug!("Cycles written to {:?}", &outputfile.get_relative_path());
779 Ok(())
780 }
781}
782
783impl Default for CycleLog {
784 fn default() -> Self {
785 Self::new()
786 }
787}
788
789pub fn posterior(psi: &Psi, w: &Col<f64>) -> Result<Mat<f64>> {
793 if psi.matrix().ncols() != w.nrows() {
794 bail!(
795 "Number of rows in psi ({}) and number of weights ({}) do not match.",
796 psi.matrix().nrows(),
797 w.nrows()
798 );
799 }
800
801 let psi_matrix = psi.matrix();
802 let py = psi_matrix * w;
803
804 let posterior = Mat::from_fn(psi_matrix.nrows(), psi_matrix.ncols(), |i, j| {
805 psi_matrix.get(i, j) * w.get(j) / py.get(i)
806 });
807
808 Ok(posterior)
809}
810
/// Median of `data`. For an even number of values, this is the mean of the two middle values.
///
/// Values are ordered with [`f64::total_cmp`], so NaN inputs do not cause a panic.
/// Returns `NaN` when `data` is empty.
pub fn median(data: Vec<f64>) -> f64 {
    let mut data = data;
    if data.is_empty() {
        return f64::NAN;
    }
    data.sort_by(|a, b| a.total_cmp(b));

    let mid = data.len() / 2;
    if data.len() % 2 == 0 {
        (data[mid - 1] + data[mid]) / 2.0
    } else {
        data[mid]
    }
}
825
/// Weighted median of `data`, where `weights[i]` is the weight of `data[i]`.
///
/// The values are sorted in ascending order and their weights accumulated. The
/// first value at which the cumulative weight exceeds half of the total weight
/// is returned.
///
/// If the cumulative weight lands exactly on half, the result is the midpoint
/// between that value and the next value with a *positive* weight. Zero-weight
/// values carry no probability mass, so they are skipped. When no later value
/// has positive weight (e.g. all weights are zero), the immediate next value is
/// used instead.
///
/// Returns `NaN` for empty input. NaN values in `data` are ordered with
/// [`f64::total_cmp`] rather than causing a panic.
///
/// # Panics
/// Panics if `data` and `weights` have different lengths, or if any weight is
/// negative or NaN.
fn weighted_median(data: &[f64], weights: &[f64]) -> f64 {
    assert_eq!(
        data.len(),
        weights.len(),
        "The length of data and weights must be the same"
    );
    assert!(
        weights.iter().all(|&x| x >= 0.0),
        "Weights must be non-negative, weights: {:?}",
        weights
    );

    if data.is_empty() {
        return f64::NAN;
    }

    let mut weighted_data: Vec<(f64, f64)> = data
        .iter()
        .copied()
        .zip(weights.iter().copied())
        .collect();
    weighted_data.sort_by(|a, b| a.0.total_cmp(&b.0));

    let total_weight: f64 = weights.iter().sum();
    let half_weight = total_weight / 2.0;
    let mut cumulative_sum = 0.0;

    for (i, &(value, weight)) in weighted_data.iter().enumerate() {
        cumulative_sum += weight;

        if cumulative_sum == half_weight {
            // Exactly half the mass lies at or below `value`: average it with the
            // next value that actually carries weight.
            let upper = weighted_data[i + 1..]
                .iter()
                .find(|&&(_, w)| w > 0.0)
                .or_else(|| weighted_data.get(i + 1));
            return match upper {
                Some(&(upper_value, _)) => (value + upper_value) / 2.0,
                None => value,
            };
        } else if cumulative_sum > half_weight {
            return value;
        }
    }

    unreachable!("The function should have returned a value before reaching this point.");
}
870
871pub fn population_mean_median(
872 theta: &Array2<f64>,
873 w: &Array1<f64>,
874) -> Result<(Array1<f64>, Array1<f64>)> {
875 let w = if w.is_empty() {
876 tracing::warn!("w.len() == 0, setting all weights to 1/n");
877 Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
878 } else {
879 w.clone()
880 };
881 if theta.nrows() != w.len() {
883 bail!(
884 "Number of parameters and number of weights do not match. Theta: {}, w: {}",
885 theta.nrows(),
886 w.len()
887 );
888 }
889
890 let mut mean = Array1::zeros(theta.ncols());
891 let mut median = Array1::zeros(theta.ncols());
892
893 for (i, (mn, mdn)) in mean.iter_mut().zip(&mut median).enumerate() {
894 let col = theta.column(i).to_owned() * w.to_owned();
896 *mn = col.sum();
897
898 let ct = theta.column(i);
900 let mut params = vec![];
901 let mut weights = vec![];
902 for (ti, wi) in ct.iter().zip(w.clone()) {
903 params.push(*ti);
904 weights.push(wi);
905 }
906
907 *mdn = weighted_median(¶ms, &weights);
908 }
909
910 Ok((mean, median))
911}
912
913pub fn posterior_mean_median(
914 theta: &Array2<f64>,
915 psi: &Array2<f64>,
916 w: &Array1<f64>,
917) -> Result<(Array2<f64>, Array2<f64>)> {
918 let mut mean = Array2::zeros((0, theta.ncols()));
919 let mut median = Array2::zeros((0, theta.ncols()));
920
921 let w = if w.is_empty() {
922 tracing::warn!("w is empty, setting all weights to 1/n");
923 Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
924 } else {
925 w.clone()
926 };
927
928 if theta.nrows() != w.len() || theta.nrows() != psi.ncols() || psi.ncols() != w.len() {
930 bail!("Number of parameters and number of weights do not match, theta.nrows(): {}, w.len(): {}, psi.ncols(): {}", theta.nrows(), w.len(), psi.ncols());
931 }
932
933 let mut psi_norm: Array2<f64> = Array2::zeros((0, psi.ncols()));
935 for (i, row) in psi.axis_iter(Axis(0)).enumerate() {
936 let row_w = row.to_owned() * w.to_owned();
937 let row_sum = row_w.sum();
938 let row_norm = if row_sum == 0.0 {
939 tracing::warn!("Sum of row {} of psi is 0.0, setting that row to 1/n", i);
940 Array1::from_elem(psi.ncols(), 1.0 / psi.ncols() as f64)
941 } else {
942 &row_w / row_sum
943 };
944 psi_norm.push_row(row_norm.view())?;
945 }
946 if psi_norm.iter().any(|&x| x.is_nan()) {
947 dbg!(&psi);
948 bail!("NaN values found in psi_norm");
949 };
950
951 for probs in psi_norm.axis_iter(Axis(0)) {
956 let mut post_mean: Vec<f64> = Vec::new();
957 let mut post_median: Vec<f64> = Vec::new();
958
959 for pars in theta.axis_iter(Axis(1)) {
961 let weighted_par = &probs * &pars;
963 let the_mean = weighted_par.sum();
964 post_mean.push(the_mean);
965
966 let median = weighted_median(&pars.to_vec(), &probs.to_vec());
968 post_median.push(median);
969 }
970
971 mean.push_row(Array::from(post_mean.clone()).view())?;
972 median.push_row(Array::from(post_median.clone()).view())?;
973 }
974
975 Ok((mean, median))
976}
977
/// An output file opened for writing, together with the path it was opened at.
#[derive(Debug)]
pub struct OutputFile {
    /// Writable handle to the file.
    pub file: File,
    /// The path the file was opened at (`folder` joined with the file name).
    pub relative_path: PathBuf,
}
984
985impl OutputFile {
986 pub fn new(folder: &str, file_name: &str) -> Result<Self> {
987 let relative_path = Path::new(&folder).join(file_name);
988
989 if let Some(parent) = relative_path.parent() {
990 create_dir_all(parent)
991 .with_context(|| format!("Failed to create directories for {:?}", parent))?;
992 }
993
994 let file = OpenOptions::new()
995 .write(true)
996 .create(true)
997 .truncate(true)
998 .open(&relative_path)
999 .with_context(|| format!("Failed to open file: {:?}", relative_path))?;
1000
1001 Ok(OutputFile {
1002 file,
1003 relative_path,
1004 })
1005 }
1006
1007 pub fn get_relative_path(&self) -> &Path {
1008 &self.relative_path
1009 }
1010}
1011
#[cfg(test)]
mod tests {
    use super::{median, weighted_median};

    // --- median ---------------------------------------------------------

    #[test]
    fn test_median_odd() {
        assert_eq!(median(vec![1.0, 3.0, 2.0]), 2.0);
    }

    #[test]
    fn test_median_even() {
        assert_eq!(median(vec![1.0, 2.0, 3.0, 4.0]), 2.5);
    }

    #[test]
    fn test_median_single() {
        assert_eq!(median(vec![42.0]), 42.0);
    }

    #[test]
    fn test_median_sorted() {
        assert_eq!(median(vec![5.0, 10.0, 15.0, 20.0, 25.0]), 15.0);
    }

    #[test]
    fn test_median_unsorted() {
        assert_eq!(median(vec![10.0, 30.0, 20.0, 50.0, 40.0]), 30.0);
    }

    #[test]
    fn test_median_with_duplicates() {
        assert_eq!(median(vec![1.0, 2.0, 2.0, 3.0, 4.0]), 2.0);
    }

    // --- weighted_median ------------------------------------------------

    #[test]
    fn test_weighted_median_simple() {
        assert_eq!(
            weighted_median(&vec![1.0, 2.0, 3.0], &vec![0.2, 0.5, 0.3]),
            2.0
        );
    }

    #[test]
    fn test_weighted_median_even_weights() {
        assert_eq!(
            weighted_median(&vec![1.0, 2.0, 3.0, 4.0], &vec![0.25; 4]),
            2.5
        );
    }

    #[test]
    fn test_weighted_median_single_element() {
        assert_eq!(weighted_median(&vec![42.0], &vec![1.0]), 42.0);
    }

    #[test]
    #[should_panic(expected = "The length of data and weights must be the same")]
    fn test_weighted_median_mismatched_lengths() {
        weighted_median(&vec![1.0, 2.0, 3.0], &vec![0.1, 0.2]);
    }

    #[test]
    fn test_weighted_median_all_same_elements() {
        assert_eq!(
            weighted_median(&vec![5.0; 4], &vec![0.1, 0.2, 0.3, 0.4]),
            5.0
        );
    }

    #[test]
    #[should_panic(expected = "Weights must be non-negative")]
    fn test_weighted_median_negative_weights() {
        assert_eq!(
            weighted_median(&vec![1.0, 2.0, 3.0, 4.0], &vec![0.2, -0.5, 0.5, 0.8]),
            4.0
        );
    }

    #[test]
    fn test_weighted_median_unsorted_data() {
        assert_eq!(
            weighted_median(&vec![3.0, 1.0, 4.0, 2.0], &vec![0.1, 0.3, 0.4, 0.2]),
            2.5
        );
    }

    #[test]
    fn test_weighted_median_with_zero_weights() {
        assert_eq!(
            weighted_median(&vec![1.0, 2.0, 3.0, 4.0], &vec![0.0, 0.0, 1.0, 0.0]),
            3.0
        );
    }
}