// pmcore/routines/output.rs

1use crate::prelude::*;
2use crate::structs::psi::Psi;
3use crate::structs::theta::Theta;
4use anyhow::{bail, Context, Result};
5use csv::WriterBuilder;
6use faer::linalg::zip::IntoView;
7use faer_ext::IntoNdarray;
8use ndarray::{Array, Array1, Array2, Axis};
9use pharmsol::prelude::data::*;
10use pharmsol::prelude::simulator::Equation;
11use serde::Serialize;
12// use pharmsol::Cache;
13use crate::routines::settings::Settings;
14use faer::{Col, Mat};
15use std::fs::{create_dir_all, File, OpenOptions};
16use std::path::{Path, PathBuf};
17
18/// Defines the result objects from an NPAG run
19/// An [NPResult] contains the necessary information to generate predictions and summary statistics
20#[derive(Debug)]
21pub struct NPResult<E: Equation> {
22    equation: E,
23    data: Data,
24    theta: Theta,
25    psi: Psi,
26    w: Col<f64>,
27    objf: f64,
28    cycles: usize,
29    converged: bool,
30    par_names: Vec<String>,
31    settings: Settings,
32    cyclelog: CycleLog,
33}
34
35#[allow(clippy::too_many_arguments)]
36impl<E: Equation> NPResult<E> {
37    /// Create a new NPResult object
38    pub fn new(
39        equation: E,
40        data: Data,
41        theta: Theta,
42        psi: Psi,
43        w: Col<f64>,
44        objf: f64,
45        cycles: usize,
46        converged: bool,
47        settings: Settings,
48        cyclelog: CycleLog,
49    ) -> Self {
50        // TODO: Add support for fixed and constant parameters
51
52        let par_names = settings.parameters().names();
53
54        Self {
55            equation,
56            data,
57            theta,
58            psi,
59            w,
60            objf,
61            cycles,
62            converged,
63            par_names,
64            settings,
65            cyclelog,
66        }
67    }
68
69    pub fn cycles(&self) -> usize {
70        self.cycles
71    }
72
73    pub fn objf(&self) -> f64 {
74        self.objf
75    }
76
77    pub fn converged(&self) -> bool {
78        self.converged
79    }
80
81    pub fn get_theta(&self) -> &Theta {
82        &self.theta
83    }
84
85    pub fn get_psi(&self) -> &Psi {
86        &self.psi
87    }
88
89    pub fn get_w(&self) -> &Col<f64> {
90        &self.w
91    }
92
93    pub fn write_outputs(&self) -> Result<()> {
94        if self.settings.output().write {
95            self.settings.write()?;
96            let idelta: f64 = self.settings.predictions().idelta;
97            let tad = self.settings.predictions().tad;
98            self.cyclelog.write(&self.settings)?;
99            self.write_obs().context("Failed to write observations")?;
100            self.write_theta().context("Failed to write theta")?;
101            self.write_obspred()
102                .context("Failed to write observed-predicted file")?;
103            self.write_pred(idelta, tad)
104                .context("Failed to write predictions")?;
105            self.write_covs().context("Failed to write covariates")?;
106            if !(self.w.nrows() == 0) {
107                self.write_posterior()
108                    .context("Failed to write posterior")?;
109            }
110        }
111        Ok(())
112    }
113
114    /// Writes the observations and predictions to a single file
115    pub fn write_obspred(&self) -> Result<()> {
116        tracing::debug!("Writing observations and predictions...");
117
118        #[derive(Debug, Clone, Serialize)]
119        struct Row {
120            id: String,
121            time: f64,
122            outeq: usize,
123            block: usize,
124            obs: f64,
125            pop_mean: f64,
126            pop_median: f64,
127            post_mean: f64,
128            post_median: f64,
129        }
130
131        let theta: Array2<f64> = self
132            .theta
133            .matrix()
134            .clone()
135            .as_mut()
136            .into_ndarray()
137            .to_owned();
138        let w: Array1<f64> = self.w.clone().into_view().iter().cloned().collect();
139        let psi: Array2<f64> = self.psi.matrix().as_ref().into_ndarray().to_owned();
140
141        let (post_mean, post_median) = posterior_mean_median(&theta, &psi, &w)
142            .context("Failed to calculate posterior mean and median")?;
143
144        let (pop_mean, pop_median) = population_mean_median(&theta, &w)
145            .context("Failed to calculate posterior mean and median")?;
146
147        let subjects = self.data.get_subjects();
148        if subjects.len() != post_mean.nrows() {
149            bail!(
150                "Number of subjects: {} and number of posterior means: {} do not match",
151                subjects.len(),
152                post_mean.nrows()
153            );
154        }
155
156        let outputfile = OutputFile::new(&self.settings.output().path, "op.csv")?;
157        let mut writer = WriterBuilder::new()
158            .has_headers(true)
159            .from_writer(&outputfile.file);
160
161        for (i, subject) in subjects.iter().enumerate() {
162            for occasion in subject.occasions() {
163                let id = subject.id();
164                let occ = occasion.index();
165
166                let subject = Subject::from_occasions(id.clone(), vec![occasion.clone()]);
167
168                // Population predictions
169                let pop_mean_pred = self
170                    .equation
171                    .simulate_subject(&subject, &pop_mean.to_vec(), None)
172                    .0
173                    .get_predictions()
174                    .clone();
175
176                let pop_median_pred = self
177                    .equation
178                    .simulate_subject(&subject, &pop_median.to_vec(), None)
179                    .0
180                    .get_predictions()
181                    .clone();
182
183                // Posterior predictions
184                let post_mean_spp: Vec<f64> = post_mean.row(i).to_vec();
185                let post_mean_pred = self
186                    .equation
187                    .simulate_subject(&subject, &post_mean_spp, None)
188                    .0
189                    .get_predictions()
190                    .clone();
191                let post_median_spp: Vec<f64> = post_median.row(i).to_vec();
192                let post_median_pred = self
193                    .equation
194                    .simulate_subject(&subject, &post_median_spp, None)
195                    .0
196                    .get_predictions()
197                    .clone();
198                assert_eq!(
199                    pop_mean_pred.len(),
200                    pop_median_pred.len(),
201                    "The number of predictions do not match (pop_mean vs pop_median)"
202                );
203
204                assert_eq!(
205                    post_mean_pred.len(),
206                    post_median_pred.len(),
207                    "The number of predictions do not match (post_mean vs post_median)"
208                );
209
210                assert_eq!(
211                    pop_mean_pred.len(),
212                    post_mean_pred.len(),
213                    "The number of predictions do not match (pop_mean vs post_mean)"
214                );
215
216                for (((pop_mean_pred, pop_median_pred), post_mean_pred), post_median_pred) in
217                    pop_mean_pred
218                        .iter()
219                        .zip(pop_median_pred.iter())
220                        .zip(post_mean_pred.iter())
221                        .zip(post_median_pred.iter())
222                {
223                    let row = Row {
224                        id: id.clone(),
225                        time: pop_mean_pred.time(),
226                        outeq: pop_mean_pred.outeq(),
227                        block: occ,
228                        obs: pop_mean_pred.observation(),
229                        pop_mean: pop_mean_pred.prediction(),
230                        pop_median: pop_median_pred.prediction(),
231                        post_mean: post_mean_pred.prediction(),
232                        post_median: post_median_pred.prediction(),
233                    };
234                    writer.serialize(row)?;
235                }
236            }
237        }
238        writer.flush()?;
239        tracing::info!(
240            "Observations with predictions written to {:?}",
241            &outputfile.get_relative_path()
242        );
243        Ok(())
244    }
245
246    /// Writes theta, which contains the population support points and their associated probabilities
247    /// Each row is one support point, the last column being probability
248    pub fn write_theta(&self) -> Result<()> {
249        tracing::debug!("Writing population parameter distribution...");
250
251        let theta = &self.theta;
252        let w: Vec<f64> = self.w.clone().into_view().iter().cloned().collect();
253        /* let w = if self.w.len() != theta.matrix().nrows() {
254                   tracing::warn!("Number of weights and number of support points do not match. Setting all weights to 0.");
255                   Array1::zeros(theta.matrix().nrows())
256               } else {
257                   self.w.clone()
258               };
259        */
260        let outputfile = OutputFile::new(&self.settings.output().path, "theta.csv")
261            .context("Failed to create output file for theta")?;
262
263        let mut writer = WriterBuilder::new()
264            .has_headers(true)
265            .from_writer(&outputfile.file);
266
267        // Create the headers
268        let mut theta_header = self.par_names.clone();
269        theta_header.push("prob".to_string());
270        writer.write_record(&theta_header)?;
271
272        // Write contents
273        for (theta_row, &w_val) in theta.matrix().row_iter().zip(w.iter()) {
274            let mut row: Vec<String> = theta_row.iter().map(|&val| val.to_string()).collect();
275            row.push(w_val.to_string());
276            writer.write_record(&row)?;
277        }
278        writer.flush()?;
279        tracing::info!(
280            "Population parameter distribution written to {:?}",
281            &outputfile.get_relative_path()
282        );
283        Ok(())
284    }
285
286    /// Writes the posterior support points for each individual
287    pub fn write_posterior(&self) -> Result<()> {
288        tracing::debug!("Writing posterior parameter probabilities...");
289        let theta = &self.theta;
290        let w = &self.w;
291        let psi = &self.psi;
292
293        // Calculate the posterior probabilities
294        let posterior = posterior(psi, w)?;
295
296        // Create the output folder if it doesn't exist
297        let outputfile = match OutputFile::new(&self.settings.output().path, "posterior.csv") {
298            Ok(of) => of,
299            Err(e) => {
300                tracing::error!("Failed to create output file: {}", e);
301                return Err(e.context("Failed to create output file"));
302            }
303        };
304
305        // Create a new writer
306        let mut writer = WriterBuilder::new()
307            .has_headers(true)
308            .from_writer(&outputfile.file);
309
310        // Create the headers
311        writer.write_field("id")?;
312        writer.write_field("point")?;
313        theta.param_names().iter().for_each(|name| {
314            writer.write_field(name).unwrap();
315        });
316        writer.write_field("prob")?;
317        writer.write_record(None::<&[u8]>)?;
318
319        // Write contents
320        let subjects = self.data.get_subjects();
321        posterior.row_iter().enumerate().for_each(|(i, row)| {
322            let subject = subjects.get(i).unwrap();
323            let id = subject.id();
324
325            row.iter().enumerate().for_each(|(spp, prob)| {
326                writer.write_field(id.clone()).unwrap();
327                writer.write_field(i.to_string()).unwrap();
328
329                theta
330                    .matrix()
331                    .row(spp)
332                    .iter()
333                    .enumerate()
334                    .for_each(|(_, val)| {
335                        writer.write_field(val.to_string()).unwrap();
336                    });
337
338                writer.write_field(prob.to_string()).unwrap();
339                writer.write_record(None::<&[u8]>).unwrap();
340            });
341        });
342
343        writer.flush()?;
344        tracing::info!(
345            "Posterior parameters written to {:?}",
346            &outputfile.get_relative_path()
347        );
348
349        Ok(())
350    }
351
352    /// Write the observations, which is the reformatted input data
353    pub fn write_obs(&self) -> Result<()> {
354        tracing::debug!("Writing observations...");
355        let outputfile = OutputFile::new(&self.settings.output().path, "obs.csv")?;
356        write_pmetrics_observations(&self.data, &outputfile.file)?;
357        tracing::info!(
358            "Observations written to {:?}",
359            &outputfile.get_relative_path()
360        );
361        Ok(())
362    }
363
364    /// Writes the predictions
365    pub fn write_pred(&self, idelta: f64, tad: f64) -> Result<()> {
366        tracing::debug!("Writing predictions...");
367        let data = self.data.expand(idelta, tad);
368
369        let theta: Array2<f64> = self
370            .theta
371            .matrix()
372            .clone()
373            .as_mut()
374            .into_ndarray()
375            .to_owned();
376        let w: Array1<f64> = self.w.clone().into_view().iter().cloned().collect();
377        let psi: Array2<f64> = self.psi.matrix().as_ref().into_ndarray().to_owned();
378
379        let (post_mean, post_median) = posterior_mean_median(&theta, &psi, &w)
380            .context("Failed to calculate posterior mean and median")?;
381
382        let (pop_mean, pop_median) = population_mean_median(&theta, &w)
383            .context("Failed to calculate population mean and median")?;
384
385        let subjects = data.get_subjects();
386        if subjects.len() != post_mean.nrows() {
387            bail!("Number of subjects and number of posterior means do not match");
388        }
389
390        let outputfile = OutputFile::new(&self.settings.output().path, "pred.csv")?;
391        let mut writer = WriterBuilder::new()
392            .has_headers(true)
393            .from_writer(&outputfile.file);
394
395        #[derive(Debug, Clone, Serialize)]
396        struct Row {
397            id: String,
398            time: f64,
399            outeq: usize,
400            block: usize,
401            pop_mean: f64,
402            pop_median: f64,
403            post_mean: f64,
404            post_median: f64,
405        }
406
407        for (i, subject) in subjects.iter().enumerate() {
408            for occasion in subject.occasions() {
409                let id = subject.id();
410                let block = occasion.index();
411
412                // Create a new subject with only the current occasion
413                let subject = Subject::from_occasions(id.clone(), vec![occasion.clone()]);
414
415                // Population predictions
416                let pop_mean_pred = self
417                    .equation
418                    .simulate_subject(&subject, &pop_mean.to_vec(), None)
419                    .0
420                    .get_predictions()
421                    .clone();
422                let pop_median_pred = self
423                    .equation
424                    .simulate_subject(&subject, &pop_median.to_vec(), None)
425                    .0
426                    .get_predictions()
427                    .clone();
428
429                // Posterior predictions
430                let post_mean_spp: Vec<f64> = post_mean.row(i).to_vec();
431                let post_mean_pred = self
432                    .equation
433                    .simulate_subject(&subject, &post_mean_spp, None)
434                    .0
435                    .get_predictions()
436                    .clone();
437                let post_median_spp: Vec<f64> = post_median.row(i).to_vec();
438                let post_median_pred = self
439                    .equation
440                    .simulate_subject(&subject, &post_median_spp, None)
441                    .0
442                    .get_predictions()
443                    .clone();
444
445                // Write predictions for each time point
446                for (((pop_mean, pop_median), post_mean), post_median) in pop_mean_pred
447                    .iter()
448                    .zip(pop_median_pred.iter())
449                    .zip(post_mean_pred.iter())
450                    .zip(post_median_pred.iter())
451                {
452                    let row = Row {
453                        id: id.clone(),
454                        time: pop_mean.time(),
455                        outeq: pop_mean.outeq(),
456                        block,
457                        pop_mean: pop_mean.prediction(),
458                        pop_median: pop_median.prediction(),
459                        post_mean: post_mean.prediction(),
460                        post_median: post_median.prediction(),
461                    };
462                    writer.serialize(row)?;
463                }
464            }
465        }
466        writer.flush()?;
467        tracing::info!(
468            "Predictions written to {:?}",
469            &outputfile.get_relative_path()
470        );
471        Ok(())
472    }
473
474    /// Writes the covariates
475    pub fn write_covs(&self) -> Result<()> {
476        tracing::debug!("Writing covariates...");
477        let outputfile = OutputFile::new(&self.settings.output().path, "covs.csv")?;
478        let mut writer = WriterBuilder::new()
479            .has_headers(true)
480            .from_writer(&outputfile.file);
481
482        // Collect all unique covariate names
483        let mut covariate_names = std::collections::HashSet::new();
484        for subject in self.data.get_subjects() {
485            for occasion in subject.occasions() {
486                if let Some(cov) = occasion.get_covariates() {
487                    let covmap = cov.covariates();
488                    for cov_name in covmap.keys() {
489                        covariate_names.insert(cov_name.clone());
490                    }
491                }
492            }
493        }
494        let mut covariate_names: Vec<String> = covariate_names.into_iter().collect();
495        covariate_names.sort(); // Ensure consistent order
496
497        // Write the header row: id, time, block, covariate names
498        let mut headers = vec!["id", "time", "block"];
499        headers.extend(covariate_names.iter().map(|s| s.as_str()));
500        writer.write_record(&headers)?;
501
502        // Write the data rows
503        for subject in self.data.get_subjects() {
504            for occasion in subject.occasions() {
505                if let Some(cov) = occasion.get_covariates() {
506                    let covmap = cov.covariates();
507
508                    for event in occasion.get_events(&None, &None, false) {
509                        let time = match event {
510                            Event::Bolus(bolus) => bolus.time(),
511                            Event::Infusion(infusion) => infusion.time(),
512                            Event::Observation(observation) => observation.time(),
513                        };
514
515                        let mut row: Vec<String> = Vec::new();
516                        row.push(subject.id().clone());
517                        row.push(time.to_string());
518                        row.push(occasion.index().to_string());
519
520                        // Add covariate values to the row
521                        for cov_name in &covariate_names {
522                            if let Some(cov) = covmap.get(cov_name) {
523                                if let Some(value) = cov.interpolate(time) {
524                                    row.push(value.to_string());
525                                } else {
526                                    row.push(String::new());
527                                }
528                            } else {
529                                row.push(String::new());
530                            }
531                        }
532
533                        writer.write_record(&row)?;
534                    }
535                }
536            }
537        }
538
539        writer.flush()?;
540        tracing::info!(
541            "Covariates written to {:?}",
542            &outputfile.get_relative_path()
543        );
544        Ok(())
545    }
546}
547
548/// An [NPCycle] object contains the summary of a cycle
549/// It holds the following information:
550/// - `cycle`: The cycle number
551/// - `objf`: The objective function value
552/// - `gamlam`: The assay noise parameter, either gamma or lambda
553/// - `theta`: The support points and their associated probabilities
554/// - `nspp`: The number of support points
555/// - `delta_objf`: The change in objective function value from last cycle
556/// - `converged`: Whether the algorithm has reached convergence
557#[derive(Debug, Clone)]
558pub struct NPCycle {
559    pub cycle: usize,
560    pub objf: f64,
561    pub gamlam: f64,
562    pub theta: Theta,
563    pub nspp: usize,
564    pub delta_objf: f64,
565    pub converged: bool,
566}
567
568impl NPCycle {
569    pub fn new(
570        cycle: usize,
571        objf: f64,
572        gamlam: f64,
573        theta: Theta,
574        nspp: usize,
575        delta_objf: f64,
576        converged: bool,
577    ) -> Self {
578        Self {
579            cycle,
580            objf,
581            gamlam,
582            theta,
583            nspp,
584            delta_objf,
585            converged,
586        }
587    }
588
589    pub fn placeholder() -> Self {
590        Self {
591            cycle: 0,
592            objf: 0.0,
593            gamlam: 0.0,
594            theta: Theta::new(),
595            nspp: 0,
596            delta_objf: 0.0,
597            converged: false,
598        }
599    }
600}
601
602/// This holdes a vector of [NPCycle] objects to provide a more detailed log
603#[derive(Debug, Clone)]
604pub struct CycleLog {
605    pub cycles: Vec<NPCycle>,
606}
607
608impl CycleLog {
609    pub fn new() -> Self {
610        Self { cycles: Vec::new() }
611    }
612
613    pub fn push(&mut self, cycle: NPCycle) {
614        self.cycles.push(cycle);
615    }
616
617    pub fn write(&self, settings: &Settings) -> Result<()> {
618        tracing::debug!("Writing cycles...");
619        let outputfile = OutputFile::new(&settings.output().path, "cycles.csv")?;
620        let mut writer = WriterBuilder::new()
621            .has_headers(false)
622            .from_writer(&outputfile.file);
623
624        // Write headers
625        writer.write_field("cycle")?;
626        writer.write_field("converged")?;
627        writer.write_field("neg2ll")?;
628        writer.write_field("gamlam")?;
629        writer.write_field("nspp")?;
630
631        let parameter_names = settings.parameters().names();
632        for param_name in &parameter_names {
633            writer.write_field(format!("{}.mean", param_name))?;
634            writer.write_field(format!("{}.median", param_name))?;
635            writer.write_field(format!("{}.sd", param_name))?;
636        }
637
638        writer.write_record(None::<&[u8]>)?;
639
640        for cycle in &self.cycles {
641            writer.write_field(format!("{}", cycle.cycle))?;
642            writer.write_field(format!("{}", cycle.converged))?;
643            writer.write_field(format!("{}", cycle.objf))?;
644            writer.write_field(format!("{}", cycle.gamlam))?;
645            writer
646                .write_field(format!("{}", cycle.theta.matrix().nrows()))
647                .unwrap();
648
649            for param in cycle.theta.matrix().col_iter() {
650                let param_values: Vec<f64> = param.iter().cloned().collect();
651
652                let mean: f64 = param_values.iter().sum::<f64>() / param_values.len() as f64;
653                let median = median(param_values.clone());
654                let std = param_values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
655                    / (param_values.len() as f64 - 1.0);
656
657                writer.write_field(format!("{}", mean))?;
658                writer.write_field(format!("{}", median))?;
659                writer.write_field(format!("{}", std))?;
660            }
661            writer.write_record(None::<&[u8]>)?;
662        }
663        writer.flush()?;
664        tracing::info!("Cycles written to {:?}", &outputfile.get_relative_path());
665        Ok(())
666    }
667}
668
669impl Default for CycleLog {
670    fn default() -> Self {
671        Self::new()
672    }
673}
674
675pub fn posterior(psi: &Psi, w: &Col<f64>) -> Result<Mat<f64>> {
676    if psi.matrix().ncols() != w.nrows() {
677        bail!(
678            "Number of rows in psi ({}) and number of weights ({}) do not match.",
679            psi.matrix().nrows(),
680            w.nrows()
681        );
682    }
683
684    let psi_matrix = psi.matrix();
685    let py = psi_matrix * w;
686
687    let posterior = Mat::from_fn(psi_matrix.nrows(), psi_matrix.ncols(), |i, j| {
688        psi_matrix.get(i, j) * w.get(j) / py.get(i)
689    });
690
691    Ok(posterior)
692}
693
/// Returns the median of `data`.
///
/// For an even number of values, the two middle values are averaged. An empty input yields
/// `NaN`, mirroring `sum / len` for the mean, instead of panicking. Values are ordered with
/// [`f64::total_cmp`], so `NaN` entries do not cause a panic either (positive `NaN` sorts
/// last).
pub fn median(mut data: Vec<f64>) -> f64 {
    // `data` is already owned, so it can be sorted in place without another copy.
    data.sort_by(f64::total_cmp);

    let n = data.len();
    if n == 0 {
        return f64::NAN;
    }

    let mid = n / 2;
    if n % 2 == 0 {
        (data[mid - 1] + data[mid]) / 2.0
    } else {
        data[mid]
    }
}
708
709fn weighted_median(data: &Array1<f64>, weights: &Array1<f64>) -> f64 {
710    // Ensure the data and weights arrays have the same length
711    assert_eq!(
712        data.len(),
713        weights.len(),
714        "The length of data and weights must be the same"
715    );
716    assert!(
717        weights.iter().all(|&x| x >= 0.0),
718        "Weights must be non-negative, weights: {:?}",
719        weights
720    );
721
722    // Create a vector of tuples (data, weight)
723    let mut weighted_data: Vec<(f64, f64)> = data
724        .iter()
725        .zip(weights.iter())
726        .map(|(&d, &w)| (d, w))
727        .collect();
728
729    // Sort the vector by the data values
730    weighted_data.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
731
732    // Calculate the cumulative sum of weights
733    let total_weight: f64 = weights.sum();
734    let mut cumulative_sum = 0.0;
735
736    for (i, &(_, weight)) in weighted_data.iter().enumerate() {
737        cumulative_sum += weight;
738
739        if cumulative_sum == total_weight / 2.0 {
740            // If the cumulative sum equals half the total weight, average this value with the next
741            if i + 1 < weighted_data.len() {
742                return (weighted_data[i].0 + weighted_data[i + 1].0) / 2.0;
743            } else {
744                return weighted_data[i].0;
745            }
746        } else if cumulative_sum > total_weight / 2.0 {
747            return weighted_data[i].0;
748        }
749    }
750
751    unreachable!("The function should have returned a value before reaching this point.");
752}
753
754pub fn population_mean_median(
755    theta: &Array2<f64>,
756    w: &Array1<f64>,
757) -> Result<(Array1<f64>, Array1<f64>)> {
758    let w = if w.is_empty() {
759        tracing::warn!("w.len() == 0, setting all weights to 1/n");
760        Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
761    } else {
762        w.clone()
763    };
764    // Check for compatible sizes
765    if theta.nrows() != w.len() {
766        bail!(
767            "Number of parameters and number of weights do not match. Theta: {}, w: {}",
768            theta.nrows(),
769            w.len()
770        );
771    }
772
773    let mut mean = Array1::zeros(theta.ncols());
774    let mut median = Array1::zeros(theta.ncols());
775
776    for (i, (mn, mdn)) in mean.iter_mut().zip(&mut median).enumerate() {
777        // Calculate the weighted mean
778        let col = theta.column(i).to_owned() * w.to_owned();
779        *mn = col.sum();
780
781        // Calculate the median
782        let ct = theta.column(i);
783        let mut params = vec![];
784        let mut weights = vec![];
785        for (ti, wi) in ct.iter().zip(w.clone()) {
786            params.push(*ti);
787            weights.push(wi);
788        }
789
790        *mdn = weighted_median(&Array::from(params), &Array::from(weights));
791    }
792
793    Ok((mean, median))
794}
795
796pub fn posterior_mean_median(
797    theta: &Array2<f64>,
798    psi: &Array2<f64>,
799    w: &Array1<f64>,
800) -> Result<(Array2<f64>, Array2<f64>)> {
801    let mut mean = Array2::zeros((0, theta.ncols()));
802    let mut median = Array2::zeros((0, theta.ncols()));
803
804    let w = if w.is_empty() {
805        tracing::warn!("w is empty, setting all weights to 1/n");
806        Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
807    } else {
808        w.clone()
809    };
810
811    // Check for compatible sizes
812    if theta.nrows() != w.len() || theta.nrows() != psi.ncols() || psi.ncols() != w.len() {
813        bail!("Number of parameters and number of weights do not match, theta.nrows(): {}, w.len(): {}, psi.ncols(): {}", theta.nrows(), w.len(), psi.ncols());
814    }
815
816    // Normalize psi to get probabilities of each spp for each id
817    let mut psi_norm: Array2<f64> = Array2::zeros((0, psi.ncols()));
818    for (i, row) in psi.axis_iter(Axis(0)).enumerate() {
819        let row_w = row.to_owned() * w.to_owned();
820        let row_sum = row_w.sum();
821        let row_norm = if row_sum == 0.0 {
822            tracing::warn!("Sum of row {} of psi is 0.0, setting that row to 1/n", i);
823            Array1::from_elem(psi.ncols(), 1.0 / psi.ncols() as f64)
824        } else {
825            &row_w / row_sum
826        };
827        psi_norm.push_row(row_norm.view())?;
828    }
829    if psi_norm.iter().any(|&x| x.is_nan()) {
830        dbg!(&psi);
831        bail!("NaN values found in psi_norm");
832    };
833
834    // Transpose normalized psi to get ID (col) by prob (row)
835    // let psi_norm_transposed = psi_norm.t();
836
837    // For each subject..
838    for probs in psi_norm.axis_iter(Axis(0)) {
839        let mut post_mean: Vec<f64> = Vec::new();
840        let mut post_median: Vec<f64> = Vec::new();
841
842        // For each parameter
843        for pars in theta.axis_iter(Axis(1)) {
844            // Calculate the mean
845            let weighted_par = &probs * &pars;
846            let the_mean = weighted_par.sum();
847            post_mean.push(the_mean);
848
849            // Calculate the median
850            let median = weighted_median(&pars.to_owned(), &probs.to_owned());
851            post_median.push(median);
852        }
853
854        mean.push_row(Array::from(post_mean.clone()).view())?;
855        median.push_row(Array::from(post_median.clone()).view())?;
856    }
857
858    Ok((mean, median))
859}
860
/// An opened output file together with the path it was opened at.
#[derive(Debug)]
pub struct OutputFile {
    /// Handle opened for writing (truncated on open).
    pub file: File,
    /// Path built from the output folder and the file name.
    pub relative_path: PathBuf,
}
867
868impl OutputFile {
869    pub fn new(folder: &str, file_name: &str) -> Result<Self> {
870        let relative_path = Path::new(&folder).join(file_name);
871
872        if let Some(parent) = relative_path.parent() {
873            create_dir_all(parent)
874                .with_context(|| format!("Failed to create directories for {:?}", parent))?;
875        }
876
877        let file = OpenOptions::new()
878            .write(true)
879            .create(true)
880            .truncate(true)
881            .open(&relative_path)
882            .with_context(|| format!("Failed to open file: {:?}", relative_path))?;
883
884        Ok(OutputFile {
885            file,
886            relative_path,
887        })
888    }
889
890    pub fn get_relative_path(&self) -> &Path {
891        &self.relative_path
892    }
893}
894
895pub fn write_pmetrics_observations(data: &Data, file: &std::fs::File) -> Result<()> {
896    let mut writer = WriterBuilder::new().has_headers(true).from_writer(file);
897
898    writer.write_record(["id", "block", "time", "out", "outeq"])?;
899    for subject in data.get_subjects() {
900        for occasion in subject.occasions() {
901            for event in occasion.get_events(&None, &None, false) {
902                if let Event::Observation(event) = event {
903                    writer.write_record([
904                        subject.id(),
905                        &occasion.index().to_string(),
906                        &event.time().to_string(),
907                        &event.value().to_string(),
908                        &event.outeq().to_string(),
909                    ])?;
910                }
911            }
912        }
913    }
914    Ok(())
915}
916
#[cfg(test)]
mod tests {
    //! Unit tests for the `median` and `weighted_median` helpers.

    use super::{median, weighted_median};
    use ndarray::Array1;

    // ----- median -----

    #[test]
    fn test_median_odd() {
        let data = vec![1.0, 3.0, 2.0];
        assert_eq!(median(data), 2.0);
    }

    #[test]
    fn test_median_even() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        assert_eq!(median(data), 2.5);
    }

    #[test]
    fn test_median_single() {
        let data = vec![42.0];
        assert_eq!(median(data), 42.0);
    }

    #[test]
    fn test_median_sorted() {
        let data = vec![5.0, 10.0, 15.0, 20.0, 25.0];
        assert_eq!(median(data), 15.0);
    }

    #[test]
    fn test_median_unsorted() {
        let data = vec![10.0, 30.0, 20.0, 50.0, 40.0];
        assert_eq!(median(data), 30.0);
    }

    #[test]
    fn test_median_with_duplicates() {
        let data = vec![1.0, 2.0, 2.0, 3.0, 4.0];
        assert_eq!(median(data), 2.0);
    }

    // ----- weighted_median -----

    #[test]
    fn test_weighted_median_simple() {
        let data = Array1::from(vec![1.0, 2.0, 3.0]);
        let weights = Array1::from(vec![0.2, 0.5, 0.3]);
        assert_eq!(weighted_median(&data, &weights), 2.0);
    }

    #[test]
    fn test_weighted_median_even_weights() {
        let data = Array1::from(vec![1.0, 2.0, 3.0, 4.0]);
        let weights = Array1::from(vec![0.25, 0.25, 0.25, 0.25]);
        assert_eq!(weighted_median(&data, &weights), 2.5);
    }

    #[test]
    fn test_weighted_median_single_element() {
        let data = Array1::from(vec![42.0]);
        let weights = Array1::from(vec![1.0]);
        assert_eq!(weighted_median(&data, &weights), 42.0);
    }

    #[test]
    #[should_panic(expected = "The length of data and weights must be the same")]
    fn test_weighted_median_mismatched_lengths() {
        let data = Array1::from(vec![1.0, 2.0, 3.0]);
        let weights = Array1::from(vec![0.1, 0.2]);
        weighted_median(&data, &weights);
    }

    #[test]
    fn test_weighted_median_all_same_elements() {
        let data = Array1::from(vec![5.0, 5.0, 5.0, 5.0]);
        let weights = Array1::from(vec![0.1, 0.2, 0.3, 0.4]);
        assert_eq!(weighted_median(&data, &weights), 5.0);
    }

    #[test]
    #[should_panic(expected = "Weights must be non-negative")]
    fn test_weighted_median_negative_weights() {
        let data = Array1::from(vec![1.0, 2.0, 3.0, 4.0]);
        let weights = Array1::from(vec![0.2, -0.5, 0.5, 0.8]);
        // The call must panic, so there is no return value worth asserting on.
        weighted_median(&data, &weights);
    }

    #[test]
    fn test_weighted_median_unsorted_data() {
        let data = Array1::from(vec![3.0, 1.0, 4.0, 2.0]);
        let weights = Array1::from(vec![0.1, 0.3, 0.4, 0.2]);
        assert_eq!(weighted_median(&data, &weights), 2.5);
    }

    #[test]
    fn test_weighted_median_with_zero_weights() {
        let data = Array1::from(vec![1.0, 2.0, 3.0, 4.0]);
        let weights = Array1::from(vec![0.0, 0.0, 1.0, 0.0]);
        assert_eq!(weighted_median(&data, &weights), 3.0);
    }
}