1use crate::algorithms::Status;
2use crate::prelude::*;
3use crate::routines::settings::Settings;
4use crate::structs::psi::Psi;
5use crate::structs::theta::Theta;
6use anyhow::{bail, Context, Result};
7use csv::WriterBuilder;
8use faer::linalg::zip::IntoView;
9use faer::{Col, Mat};
10use faer_ext::IntoNdarray;
11use ndarray::{Array, Array1, Array2, Axis};
12use pharmsol::prelude::data::*;
13use pharmsol::prelude::simulator::Equation;
14use serde::Serialize;
15use std::fs::{create_dir_all, File, OpenOptions};
16use std::path::{Path, PathBuf};
17
/// Final result of a non-parametric population run, bundled with everything
/// needed to write the output files.
#[derive(Debug)]
pub struct NPResult<E: Equation> {
    /// Model used to simulate predictions for `op.csv` and `pred.csv`.
    equation: E,
    /// Subject data the model was fitted to.
    data: Data,
    /// Support points: one row per support point, one column per parameter.
    theta: Theta,
    /// Subject-by-support-point matrix (rows are subjects, columns are support
    /// points); presumably per-subject likelihoods — confirm with the algorithm.
    psi: Psi,
    /// Probability of each support point, one entry per row of `theta`.
    w: Col<f64>,
    /// Final objective function value.
    objf: f64,
    /// Cycle count reported by `cycles()`.
    cycles: usize,
    /// Final algorithm status.
    status: Status,
    /// Parameter names, used as CSV column headers.
    par_names: Vec<String>,
    /// Run settings (output path, prediction options, parameter definitions).
    settings: Settings,
    /// Per-cycle log, written to `cycles.csv`.
    cyclelog: CycleLog,
}
34
35#[allow(clippy::too_many_arguments)]
36impl<E: Equation> NPResult<E> {
37 pub fn new(
39 equation: E,
40 data: Data,
41 theta: Theta,
42 psi: Psi,
43 w: Col<f64>,
44 objf: f64,
45 cycles: usize,
46 status: Status,
47 settings: Settings,
48 cyclelog: CycleLog,
49 ) -> Self {
50 let par_names = settings.parameters().names();
53
54 Self {
55 equation,
56 data,
57 theta,
58 psi,
59 w,
60 objf,
61 cycles,
62 status,
63 par_names,
64 settings,
65 cyclelog,
66 }
67 }
68
69 pub fn cycles(&self) -> usize {
70 self.cycles
71 }
72
73 pub fn objf(&self) -> f64 {
74 self.objf
75 }
76
77 pub fn converged(&self) -> bool {
78 self.status == Status::Converged
79 }
80
81 pub fn get_theta(&self) -> &Theta {
82 &self.theta
83 }
84
85 pub fn psi(&self) -> &Psi {
87 &self.psi
88 }
89
90 pub fn w(&self) -> &Col<f64> {
92 &self.w
93 }
94
95 pub fn write_outputs(&self) -> Result<()> {
96 if self.settings.output().write {
97 self.settings.write()?;
98 let idelta: f64 = self.settings.predictions().idelta;
99 let tad = self.settings.predictions().tad;
100 self.cyclelog.write(&self.settings)?;
101 self.write_obs().context("Failed to write observations")?;
102 self.write_theta().context("Failed to write theta")?;
103 self.write_obspred()
104 .context("Failed to write observed-predicted file")?;
105 self.write_pred(idelta, tad)
106 .context("Failed to write predictions")?;
107 self.write_covs().context("Failed to write covariates")?;
108 self.write_posterior()
109 .context("Failed to write posterior")?;
110 }
111 Ok(())
112 }
113
114 pub fn write_obspred(&self) -> Result<()> {
116 tracing::debug!("Writing observations and predictions...");
117
118 #[derive(Debug, Clone, Serialize)]
119 struct Row {
120 id: String,
121 time: f64,
122 outeq: usize,
123 block: usize,
124 obs: f64,
125 pop_mean: f64,
126 pop_median: f64,
127 post_mean: f64,
128 post_median: f64,
129 }
130
131 let theta: Array2<f64> = self
132 .theta
133 .matrix()
134 .clone()
135 .as_mut()
136 .into_ndarray()
137 .to_owned();
138 let w: Array1<f64> = self.w.clone().into_view().iter().cloned().collect();
139 let psi: Array2<f64> = self.psi.matrix().as_ref().into_ndarray().to_owned();
140
141 let (post_mean, post_median) = posterior_mean_median(&theta, &psi, &w)
142 .context("Failed to calculate posterior mean and median")?;
143
144 let (pop_mean, pop_median) = population_mean_median(&theta, &w)
145 .context("Failed to calculate posterior mean and median")?;
146
147 let subjects = self.data.get_subjects();
148 if subjects.len() != post_mean.nrows() {
149 bail!(
150 "Number of subjects: {} and number of posterior means: {} do not match",
151 subjects.len(),
152 post_mean.nrows()
153 );
154 }
155
156 let outputfile = OutputFile::new(&self.settings.output().path, "op.csv")?;
157 let mut writer = WriterBuilder::new()
158 .has_headers(true)
159 .from_writer(&outputfile.file);
160
161 for (i, subject) in subjects.iter().enumerate() {
162 for occasion in subject.occasions() {
163 let id = subject.id();
164 let occ = occasion.index();
165
166 let subject = Subject::from_occasions(id.clone(), vec![occasion.clone()]);
167
168 let pop_mean_pred = self
170 .equation
171 .simulate_subject(&subject, &pop_mean.to_vec(), None)?
172 .0
173 .get_predictions()
174 .clone();
175
176 let pop_median_pred = self
177 .equation
178 .simulate_subject(&subject, &pop_median.to_vec(), None)?
179 .0
180 .get_predictions()
181 .clone();
182
183 let post_mean_spp: Vec<f64> = post_mean.row(i).to_vec();
185 let post_mean_pred = self
186 .equation
187 .simulate_subject(&subject, &post_mean_spp, None)?
188 .0
189 .get_predictions()
190 .clone();
191 let post_median_spp: Vec<f64> = post_median.row(i).to_vec();
192 let post_median_pred = self
193 .equation
194 .simulate_subject(&subject, &post_median_spp, None)?
195 .0
196 .get_predictions()
197 .clone();
198 assert_eq!(
199 pop_mean_pred.len(),
200 pop_median_pred.len(),
201 "The number of predictions do not match (pop_mean vs pop_median)"
202 );
203
204 assert_eq!(
205 post_mean_pred.len(),
206 post_median_pred.len(),
207 "The number of predictions do not match (post_mean vs post_median)"
208 );
209
210 assert_eq!(
211 pop_mean_pred.len(),
212 post_mean_pred.len(),
213 "The number of predictions do not match (pop_mean vs post_mean)"
214 );
215
216 for (((pop_mean_pred, pop_median_pred), post_mean_pred), post_median_pred) in
217 pop_mean_pred
218 .iter()
219 .zip(pop_median_pred.iter())
220 .zip(post_mean_pred.iter())
221 .zip(post_median_pred.iter())
222 {
223 let row = Row {
224 id: id.clone(),
225 time: pop_mean_pred.time(),
226 outeq: pop_mean_pred.outeq(),
227 block: occ,
228 obs: pop_mean_pred.observation(),
229 pop_mean: pop_mean_pred.prediction(),
230 pop_median: pop_median_pred.prediction(),
231 post_mean: post_mean_pred.prediction(),
232 post_median: post_median_pred.prediction(),
233 };
234 writer.serialize(row)?;
235 }
236 }
237 }
238 writer.flush()?;
239 tracing::info!(
240 "Observations with predictions written to {:?}",
241 &outputfile.get_relative_path()
242 );
243 Ok(())
244 }
245
246 pub fn write_theta(&self) -> Result<()> {
249 tracing::debug!("Writing population parameter distribution...");
250
251 let theta = &self.theta;
252 let w: Vec<f64> = self.w.clone().into_view().iter().cloned().collect();
253 let outputfile = OutputFile::new(&self.settings.output().path, "theta.csv")
261 .context("Failed to create output file for theta")?;
262
263 let mut writer = WriterBuilder::new()
264 .has_headers(true)
265 .from_writer(&outputfile.file);
266
267 let mut theta_header = self.par_names.clone();
269 theta_header.push("prob".to_string());
270 writer.write_record(&theta_header)?;
271
272 for (theta_row, &w_val) in theta.matrix().row_iter().zip(w.iter()) {
274 let mut row: Vec<String> = theta_row.iter().map(|&val| val.to_string()).collect();
275 row.push(w_val.to_string());
276 writer.write_record(&row)?;
277 }
278 writer.flush()?;
279 tracing::info!(
280 "Population parameter distribution written to {:?}",
281 &outputfile.get_relative_path()
282 );
283 Ok(())
284 }
285
286 pub fn write_posterior(&self) -> Result<()> {
288 tracing::debug!("Writing posterior parameter probabilities...");
289 let theta = &self.theta;
290 let w = &self.w;
291 let psi = &self.psi;
292
293 let posterior = posterior(psi, w)?;
295
296 let outputfile = match OutputFile::new(&self.settings.output().path, "posterior.csv") {
298 Ok(of) => of,
299 Err(e) => {
300 tracing::error!("Failed to create output file: {}", e);
301 return Err(e.context("Failed to create output file"));
302 }
303 };
304
305 let mut writer = WriterBuilder::new()
307 .has_headers(true)
308 .from_writer(&outputfile.file);
309
310 writer.write_field("id")?;
312 writer.write_field("point")?;
313 theta.param_names().iter().for_each(|name| {
314 writer.write_field(name).unwrap();
315 });
316 writer.write_field("prob")?;
317 writer.write_record(None::<&[u8]>)?;
318
319 let subjects = self.data.get_subjects();
321 posterior.row_iter().enumerate().for_each(|(i, row)| {
322 let subject = subjects.get(i).unwrap();
323 let id = subject.id();
324
325 row.iter().enumerate().for_each(|(spp, prob)| {
326 writer.write_field(id.clone()).unwrap();
327 writer.write_field(i.to_string()).unwrap();
328
329 theta.matrix().row(spp).iter().for_each(|val| {
330 writer.write_field(val.to_string()).unwrap();
331 });
332
333 writer.write_field(prob.to_string()).unwrap();
334 writer.write_record(None::<&[u8]>).unwrap();
335 });
336 });
337
338 writer.flush()?;
339 tracing::info!(
340 "Posterior parameters written to {:?}",
341 &outputfile.get_relative_path()
342 );
343
344 Ok(())
345 }
346
347 pub fn write_obs(&self) -> Result<()> {
349 tracing::debug!("Writing observations...");
350 let outputfile = OutputFile::new(&self.settings.output().path, "obs.csv")?;
351 write_pmetrics_observations(&self.data, &outputfile.file)?;
352 tracing::info!(
353 "Observations written to {:?}",
354 &outputfile.get_relative_path()
355 );
356 Ok(())
357 }
358
359 pub fn write_pred(&self, idelta: f64, tad: f64) -> Result<()> {
361 tracing::debug!("Writing predictions...");
362 let data = self.data.expand(idelta, tad);
363
364 let theta: Array2<f64> = self
365 .theta
366 .matrix()
367 .clone()
368 .as_mut()
369 .into_ndarray()
370 .to_owned();
371 let w: Array1<f64> = self.w.clone().into_view().iter().cloned().collect();
372 let psi: Array2<f64> = self.psi.matrix().as_ref().into_ndarray().to_owned();
373
374 let (post_mean, post_median) = posterior_mean_median(&theta, &psi, &w)
375 .context("Failed to calculate posterior mean and median")?;
376
377 let (pop_mean, pop_median) = population_mean_median(&theta, &w)
378 .context("Failed to calculate population mean and median")?;
379
380 let subjects = data.get_subjects();
381 if subjects.len() != post_mean.nrows() {
382 bail!("Number of subjects and number of posterior means do not match");
383 }
384
385 let outputfile = OutputFile::new(&self.settings.output().path, "pred.csv")?;
386 let mut writer = WriterBuilder::new()
387 .has_headers(true)
388 .from_writer(&outputfile.file);
389
390 #[derive(Debug, Clone, Serialize)]
391 struct Row {
392 id: String,
393 time: f64,
394 outeq: usize,
395 block: usize,
396 pop_mean: f64,
397 pop_median: f64,
398 post_mean: f64,
399 post_median: f64,
400 }
401
402 for (i, subject) in subjects.iter().enumerate() {
403 for occasion in subject.occasions() {
404 let id = subject.id();
405 let block = occasion.index();
406
407 let subject = Subject::from_occasions(id.clone(), vec![occasion.clone()]);
409
410 let pop_mean_pred = self
412 .equation
413 .simulate_subject(&subject, &pop_mean.to_vec(), None)?
414 .0
415 .get_predictions()
416 .clone();
417 let pop_median_pred = self
418 .equation
419 .simulate_subject(&subject, &pop_median.to_vec(), None)?
420 .0
421 .get_predictions()
422 .clone();
423
424 let post_mean_spp: Vec<f64> = post_mean.row(i).to_vec();
426 let post_mean_pred = self
427 .equation
428 .simulate_subject(&subject, &post_mean_spp, None)?
429 .0
430 .get_predictions()
431 .clone();
432 let post_median_spp: Vec<f64> = post_median.row(i).to_vec();
433 let post_median_pred = self
434 .equation
435 .simulate_subject(&subject, &post_median_spp, None)?
436 .0
437 .get_predictions()
438 .clone();
439
440 for (((pop_mean, pop_median), post_mean), post_median) in pop_mean_pred
442 .iter()
443 .zip(pop_median_pred.iter())
444 .zip(post_mean_pred.iter())
445 .zip(post_median_pred.iter())
446 {
447 let row = Row {
448 id: id.clone(),
449 time: pop_mean.time(),
450 outeq: pop_mean.outeq(),
451 block,
452 pop_mean: pop_mean.prediction(),
453 pop_median: pop_median.prediction(),
454 post_mean: post_mean.prediction(),
455 post_median: post_median.prediction(),
456 };
457 writer.serialize(row)?;
458 }
459 }
460 }
461 writer.flush()?;
462 tracing::info!(
463 "Predictions written to {:?}",
464 &outputfile.get_relative_path()
465 );
466 Ok(())
467 }
468
469 pub fn write_covs(&self) -> Result<()> {
471 tracing::debug!("Writing covariates...");
472 let outputfile = OutputFile::new(&self.settings.output().path, "covs.csv")?;
473 let mut writer = WriterBuilder::new()
474 .has_headers(true)
475 .from_writer(&outputfile.file);
476
477 let mut covariate_names = std::collections::HashSet::new();
479 for subject in self.data.get_subjects() {
480 for occasion in subject.occasions() {
481 let cov = occasion.covariates();
482 let covmap = cov.covariates();
483 for cov_name in covmap.keys() {
484 covariate_names.insert(cov_name.clone());
485 }
486 }
487 }
488 let mut covariate_names: Vec<String> = covariate_names.into_iter().collect();
489 covariate_names.sort(); let mut headers = vec!["id", "time", "block"];
493 headers.extend(covariate_names.iter().map(|s| s.as_str()));
494 writer.write_record(&headers)?;
495
496 for subject in self.data.get_subjects() {
498 for occasion in subject.occasions() {
499 let cov = occasion.covariates();
500 let covmap = cov.covariates();
501
502 for event in occasion.get_events(&None, &None, false) {
503 let time = match event {
504 Event::Bolus(bolus) => bolus.time(),
505 Event::Infusion(infusion) => infusion.time(),
506 Event::Observation(observation) => observation.time(),
507 };
508
509 let mut row: Vec<String> = Vec::new();
510 row.push(subject.id().clone());
511 row.push(time.to_string());
512 row.push(occasion.index().to_string());
513
514 for cov_name in &covariate_names {
516 if let Some(cov) = covmap.get(cov_name) {
517 if let Some(value) = cov.interpolate(time) {
518 row.push(value.to_string());
519 } else {
520 row.push(String::new());
521 }
522 } else {
523 row.push(String::new());
524 }
525 }
526
527 writer.write_record(&row)?;
528 }
529 }
530 }
531
532 writer.flush()?;
533 tracing::info!(
534 "Covariates written to {:?}",
535 &outputfile.get_relative_path()
536 );
537 Ok(())
538 }
539}
540
/// Snapshot of the algorithm state at the end of one cycle, as stored in a `CycleLog`.
#[derive(Debug, Clone)]
pub struct NPCycle {
    /// Cycle number.
    pub cycle: usize,
    /// Objective function value; `CycleLog::write` emits it under the `neg2ll` header.
    pub objf: f64,
    /// Error models in effect for this cycle; their scalars are written to `cycles.csv`.
    pub error_models: ErrorModels,
    /// Support points at the end of this cycle.
    pub theta: Theta,
    /// Number of support points as supplied by the caller.
    /// NOTE(review): `CycleLog::write` reports `theta.nspp()` rather than this field.
    pub nspp: usize,
    /// Presumably the change in `objf` relative to the previous cycle — confirm with the caller.
    pub delta_objf: f64,
    /// Algorithm status at the end of this cycle.
    pub status: Status,
}
560
561impl NPCycle {
562 pub fn new(
563 cycle: usize,
564 objf: f64,
565 error_models: ErrorModels,
566 theta: Theta,
567 nspp: usize,
568 delta_objf: f64,
569 status: Status,
570 ) -> Self {
571 Self {
572 cycle,
573 objf,
574 error_models,
575 theta,
576 nspp,
577 delta_objf,
578 status,
579 }
580 }
581
582 pub fn placeholder() -> Self {
583 Self {
584 cycle: 0,
585 objf: 0.0,
586 error_models: ErrorModels::default(),
587 theta: Theta::new(),
588 nspp: 0,
589 delta_objf: 0.0,
590 status: Status::Starting,
591 }
592 }
593}
594
/// Ordered record of every [`NPCycle`] produced during a run; written to `cycles.csv`.
#[derive(Debug, Clone)]
pub struct CycleLog {
    /// Cycles in the order they were pushed.
    pub cycles: Vec<NPCycle>,
}
600
601impl CycleLog {
602 pub fn new() -> Self {
603 Self { cycles: Vec::new() }
604 }
605
606 pub fn push(&mut self, cycle: NPCycle) {
607 self.cycles.push(cycle);
608 }
609
610 pub fn write(&self, settings: &Settings) -> Result<()> {
611 tracing::debug!("Writing cycles...");
612 let outputfile = OutputFile::new(&settings.output().path, "cycles.csv")?;
613 let mut writer = WriterBuilder::new()
614 .has_headers(false)
615 .from_writer(&outputfile.file);
616
617 writer.write_field("cycle")?;
619 writer.write_field("converged")?;
620 writer.write_field("status")?;
621 writer.write_field("neg2ll")?;
622 writer.write_field("nspp")?;
623 if let Some(first_cycle) = self.cycles.first() {
624 first_cycle.error_models.iter().try_for_each(
625 |(outeq, errmod): (usize, &ErrorModel)| -> Result<(), csv::Error> {
626 match errmod {
627 ErrorModel::Additive { .. } => {
628 writer.write_field(format!("gamlam.{}", outeq))?;
629 }
630 ErrorModel::Proportional { .. } => {
631 writer.write_field(format!("gamlam.{}", outeq))?;
632 }
633 ErrorModel::None { .. } => {}
634 }
635 Ok(())
636 },
637 )?;
638 }
639
640 let parameter_names = settings.parameters().names();
641 for param_name in ¶meter_names {
642 writer.write_field(format!("{}.mean", param_name))?;
643 writer.write_field(format!("{}.median", param_name))?;
644 writer.write_field(format!("{}.sd", param_name))?;
645 }
646
647 writer.write_record(None::<&[u8]>)?;
648
649 for cycle in &self.cycles {
650 writer.write_field(format!("{}", cycle.cycle))?;
651 writer.write_field(format!("{}", cycle.status == Status::Converged))?;
652 writer.write_field(format!("{}", cycle.status))?;
653 writer.write_field(format!("{}", cycle.objf))?;
654 writer
655 .write_field(format!("{}", cycle.theta.nspp()))
656 .unwrap();
657
658 cycle.error_models.iter().try_for_each(
660 |(_, errmod): (usize, &ErrorModel)| -> Result<()> {
661 match errmod {
662 ErrorModel::Additive { .. } => {
663 writer.write_field(format!("{:.5}", errmod.scalar()?))?;
664 }
665 ErrorModel::Proportional { .. } => {
666 writer.write_field(format!("{:.5}", errmod.scalar()?))?;
667 }
668 ErrorModel::None { .. } => {}
669 }
670 Ok(())
671 },
672 )?;
673
674 for param in cycle.theta.matrix().col_iter() {
675 let param_values: Vec<f64> = param.iter().cloned().collect();
676
677 let mean: f64 = param_values.iter().sum::<f64>() / param_values.len() as f64;
678 let median = median(param_values.clone());
679 let std = param_values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
680 / (param_values.len() as f64 - 1.0);
681
682 writer.write_field(format!("{}", mean))?;
683 writer.write_field(format!("{}", median))?;
684 writer.write_field(format!("{}", std))?;
685 }
686 writer.write_record(None::<&[u8]>)?;
687 }
688 writer.flush()?;
689 tracing::info!("Cycles written to {:?}", &outputfile.get_relative_path());
690 Ok(())
691 }
692}
693
694impl Default for CycleLog {
695 fn default() -> Self {
696 Self::new()
697 }
698}
699
700pub fn posterior(psi: &Psi, w: &Col<f64>) -> Result<Mat<f64>> {
701 if psi.matrix().ncols() != w.nrows() {
702 bail!(
703 "Number of rows in psi ({}) and number of weights ({}) do not match.",
704 psi.matrix().nrows(),
705 w.nrows()
706 );
707 }
708
709 let psi_matrix = psi.matrix();
710 let py = psi_matrix * w;
711
712 let posterior = Mat::from_fn(psi_matrix.nrows(), psi_matrix.ncols(), |i, j| {
713 psi_matrix.get(i, j) * w.get(j) / py.get(i)
714 });
715
716 Ok(posterior)
717}
718
/// Returns the median of `data`.
///
/// For an even number of elements the two middle values are averaged. An empty
/// input has no median and yields `f64::NAN`; the original panicked on an
/// index underflow. Values are ordered with `f64::total_cmp`, so a NaN in the
/// input no longer panics the sort.
pub fn median(mut data: Vec<f64>) -> f64 {
    data.sort_by(|a, b| a.total_cmp(b));

    let mid = data.len() / 2;
    match data.len() {
        0 => f64::NAN,
        n if n % 2 == 0 => (data[mid - 1] + data[mid]) / 2.0,
        _ => data[mid],
    }
}
733
734fn weighted_median(data: &Array1<f64>, weights: &Array1<f64>) -> f64 {
735 assert_eq!(
737 data.len(),
738 weights.len(),
739 "The length of data and weights must be the same"
740 );
741 assert!(
742 weights.iter().all(|&x| x >= 0.0),
743 "Weights must be non-negative, weights: {:?}",
744 weights
745 );
746
747 let mut weighted_data: Vec<(f64, f64)> = data
749 .iter()
750 .zip(weights.iter())
751 .map(|(&d, &w)| (d, w))
752 .collect();
753
754 weighted_data.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
756
757 let total_weight: f64 = weights.sum();
759 let mut cumulative_sum = 0.0;
760
761 for (i, &(_, weight)) in weighted_data.iter().enumerate() {
762 cumulative_sum += weight;
763
764 if cumulative_sum == total_weight / 2.0 {
765 if i + 1 < weighted_data.len() {
767 return (weighted_data[i].0 + weighted_data[i + 1].0) / 2.0;
768 } else {
769 return weighted_data[i].0;
770 }
771 } else if cumulative_sum > total_weight / 2.0 {
772 return weighted_data[i].0;
773 }
774 }
775
776 unreachable!("The function should have returned a value before reaching this point.");
777}
778
779pub fn population_mean_median(
780 theta: &Array2<f64>,
781 w: &Array1<f64>,
782) -> Result<(Array1<f64>, Array1<f64>)> {
783 let w = if w.is_empty() {
784 tracing::warn!("w.len() == 0, setting all weights to 1/n");
785 Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
786 } else {
787 w.clone()
788 };
789 if theta.nrows() != w.len() {
791 bail!(
792 "Number of parameters and number of weights do not match. Theta: {}, w: {}",
793 theta.nrows(),
794 w.len()
795 );
796 }
797
798 let mut mean = Array1::zeros(theta.ncols());
799 let mut median = Array1::zeros(theta.ncols());
800
801 for (i, (mn, mdn)) in mean.iter_mut().zip(&mut median).enumerate() {
802 let col = theta.column(i).to_owned() * w.to_owned();
804 *mn = col.sum();
805
806 let ct = theta.column(i);
808 let mut params = vec![];
809 let mut weights = vec![];
810 for (ti, wi) in ct.iter().zip(w.clone()) {
811 params.push(*ti);
812 weights.push(wi);
813 }
814
815 *mdn = weighted_median(&Array::from(params), &Array::from(weights));
816 }
817
818 Ok((mean, median))
819}
820
821pub fn posterior_mean_median(
822 theta: &Array2<f64>,
823 psi: &Array2<f64>,
824 w: &Array1<f64>,
825) -> Result<(Array2<f64>, Array2<f64>)> {
826 let mut mean = Array2::zeros((0, theta.ncols()));
827 let mut median = Array2::zeros((0, theta.ncols()));
828
829 let w = if w.is_empty() {
830 tracing::warn!("w is empty, setting all weights to 1/n");
831 Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
832 } else {
833 w.clone()
834 };
835
836 if theta.nrows() != w.len() || theta.nrows() != psi.ncols() || psi.ncols() != w.len() {
838 bail!("Number of parameters and number of weights do not match, theta.nrows(): {}, w.len(): {}, psi.ncols(): {}", theta.nrows(), w.len(), psi.ncols());
839 }
840
841 let mut psi_norm: Array2<f64> = Array2::zeros((0, psi.ncols()));
843 for (i, row) in psi.axis_iter(Axis(0)).enumerate() {
844 let row_w = row.to_owned() * w.to_owned();
845 let row_sum = row_w.sum();
846 let row_norm = if row_sum == 0.0 {
847 tracing::warn!("Sum of row {} of psi is 0.0, setting that row to 1/n", i);
848 Array1::from_elem(psi.ncols(), 1.0 / psi.ncols() as f64)
849 } else {
850 &row_w / row_sum
851 };
852 psi_norm.push_row(row_norm.view())?;
853 }
854 if psi_norm.iter().any(|&x| x.is_nan()) {
855 dbg!(&psi);
856 bail!("NaN values found in psi_norm");
857 };
858
859 for probs in psi_norm.axis_iter(Axis(0)) {
864 let mut post_mean: Vec<f64> = Vec::new();
865 let mut post_median: Vec<f64> = Vec::new();
866
867 for pars in theta.axis_iter(Axis(1)) {
869 let weighted_par = &probs * &pars;
871 let the_mean = weighted_par.sum();
872 post_mean.push(the_mean);
873
874 let median = weighted_median(&pars.to_owned(), &probs.to_owned());
876 post_median.push(median);
877 }
878
879 mean.push_row(Array::from(post_mean.clone()).view())?;
880 median.push_row(Array::from(post_median.clone()).view())?;
881 }
882
883 Ok((mean, median))
884}
885
/// An output file opened for writing (truncated on open), together with the path it was opened at.
#[derive(Debug)]
pub struct OutputFile {
    /// Writable handle to the file.
    pub file: File,
    /// `folder` joined with the file name; only relative if `folder` was relative.
    pub relative_path: PathBuf,
}
892
893impl OutputFile {
894 pub fn new(folder: &str, file_name: &str) -> Result<Self> {
895 let relative_path = Path::new(&folder).join(file_name);
896
897 if let Some(parent) = relative_path.parent() {
898 create_dir_all(parent)
899 .with_context(|| format!("Failed to create directories for {:?}", parent))?;
900 }
901
902 let file = OpenOptions::new()
903 .write(true)
904 .create(true)
905 .truncate(true)
906 .open(&relative_path)
907 .with_context(|| format!("Failed to open file: {:?}", relative_path))?;
908
909 Ok(OutputFile {
910 file,
911 relative_path,
912 })
913 }
914
915 pub fn get_relative_path(&self) -> &Path {
916 &self.relative_path
917 }
918}
919
920pub fn write_pmetrics_observations(data: &Data, file: &std::fs::File) -> Result<()> {
921 let mut writer = WriterBuilder::new().has_headers(true).from_writer(file);
922
923 writer.write_record(["id", "block", "time", "out", "outeq"])?;
924 for subject in data.get_subjects() {
925 for occasion in subject.occasions() {
926 for event in occasion.get_events(&None, &None, false) {
927 if let Event::Observation(event) = event {
928 writer.write_record([
929 subject.id(),
930 &occasion.index().to_string(),
931 &event.time().to_string(),
932 &event.value().to_string(),
933 &event.outeq().to_string(),
934 ])?;
935 }
936 }
937 }
938 }
939 Ok(())
940}
941
// Unit tests for the `median` and `weighted_median` helpers.
#[cfg(test)]
mod tests {
    use super::median;

    #[test]
    fn test_median_odd() {
        let data = vec![1.0, 3.0, 2.0];
        assert_eq!(median(data), 2.0);
    }

    // Even length: the two middle values are averaged.
    #[test]
    fn test_median_even() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        assert_eq!(median(data), 2.5);
    }

    #[test]
    fn test_median_single() {
        let data = vec![42.0];
        assert_eq!(median(data), 42.0);
    }

    #[test]
    fn test_median_sorted() {
        let data = vec![5.0, 10.0, 15.0, 20.0, 25.0];
        assert_eq!(median(data), 15.0);
    }

    #[test]
    fn test_median_unsorted() {
        let data = vec![10.0, 30.0, 20.0, 50.0, 40.0];
        assert_eq!(median(data), 30.0);
    }

    #[test]
    fn test_median_with_duplicates() {
        let data = vec![1.0, 2.0, 2.0, 3.0, 4.0];
        assert_eq!(median(data), 2.0);
    }

    use super::weighted_median;
    use ndarray::Array1;

    #[test]
    fn test_weighted_median_simple() {
        let data = Array1::from(vec![1.0, 2.0, 3.0]);
        let weights = Array1::from(vec![0.2, 0.5, 0.3]);
        assert_eq!(weighted_median(&data, &weights), 2.0);
    }

    // Cumulative weight reaches exactly half at 2.0, so 2.0 and 3.0 are averaged.
    #[test]
    fn test_weighted_median_even_weights() {
        let data = Array1::from(vec![1.0, 2.0, 3.0, 4.0]);
        let weights = Array1::from(vec![0.25, 0.25, 0.25, 0.25]);
        assert_eq!(weighted_median(&data, &weights), 2.5);
    }

    #[test]
    fn test_weighted_median_single_element() {
        let data = Array1::from(vec![42.0]);
        let weights = Array1::from(vec![1.0]);
        assert_eq!(weighted_median(&data, &weights), 42.0);
    }

    #[test]
    #[should_panic(expected = "The length of data and weights must be the same")]
    fn test_weighted_median_mismatched_lengths() {
        let data = Array1::from(vec![1.0, 2.0, 3.0]);
        let weights = Array1::from(vec![0.1, 0.2]);
        weighted_median(&data, &weights);
    }

    #[test]
    fn test_weighted_median_all_same_elements() {
        let data = Array1::from(vec![5.0, 5.0, 5.0, 5.0]);
        let weights = Array1::from(vec![0.1, 0.2, 0.3, 0.4]);
        assert_eq!(weighted_median(&data, &weights), 5.0);
    }

    #[test]
    #[should_panic(expected = "Weights must be non-negative")]
    fn test_weighted_median_negative_weights() {
        let data = Array1::from(vec![1.0, 2.0, 3.0, 4.0]);
        let weights = Array1::from(vec![0.2, -0.5, 0.5, 0.8]);
        assert_eq!(weighted_median(&data, &weights), 4.0);
    }

    // After sorting by value the weights are [0.3, 0.2, 0.1, 0.4]. The
    // cumulative weight reaches exactly half at 2.0, so 2.0 and 3.0 are averaged.
    #[test]
    fn test_weighted_median_unsorted_data() {
        let data = Array1::from(vec![3.0, 1.0, 4.0, 2.0]);
        let weights = Array1::from(vec![0.1, 0.3, 0.4, 0.2]);
        assert_eq!(weighted_median(&data, &weights), 2.5);
    }

    #[test]
    fn test_weighted_median_with_zero_weights() {
        let data = Array1::from(vec![1.0, 2.0, 3.0, 4.0]);
        let weights = Array1::from(vec![0.0, 0.0, 1.0, 0.0]);
        assert_eq!(weighted_median(&data, &weights), 3.0);
    }
}
1042}