pmcore/routines/
output.rs

1use crate::algorithms::Status;
2use crate::prelude::*;
3use crate::routines::settings::Settings;
4use crate::structs::psi::Psi;
5use crate::structs::theta::Theta;
6use anyhow::{bail, Context, Result};
7use csv::WriterBuilder;
8use faer::linalg::zip::IntoView;
9use faer::{Col, Mat};
10use faer_ext::IntoNdarray;
11use ndarray::{Array, Array1, Array2, Axis};
12use pharmsol::prelude::data::*;
13use pharmsol::prelude::simulator::Equation;
14use serde::Serialize;
15use std::fs::{create_dir_all, File, OpenOptions};
16use std::path::{Path, PathBuf};
17
/// Defines the result objects from an NPAG run
/// An [NPResult] contains the necessary information to generate predictions and summary statistics
///
/// It bundles the model, the data, the final support points ([Theta]) with their
/// weights, the likelihood matrix ([Psi]) and run metadata, and knows how to write
/// all of these to CSV files (see the `write_*` methods).
#[derive(Debug)]
pub struct NPResult<E: Equation> {
    /// The model used to simulate predictions for each subject.
    equation: E,
    /// The subjects the algorithm was run on.
    data: Data,
    /// Final support points; each row is one support point, each column one parameter.
    theta: Theta,
    /// Likelihood matrix; row `i` corresponds to subject `i` and column `j` to
    /// support point `j` (its column count must match the length of `w`).
    psi: Psi,
    /// Weights (probabilities) of the support points in `theta`.
    w: Col<f64>,
    /// Objective function value at the end of the run.
    objf: f64,
    /// Number of cycles the algorithm ran.
    cycles: usize,
    /// Final algorithm status (e.g. converged or not).
    status: Status,
    /// Parameter names, taken from `settings.parameters().names()`.
    par_names: Vec<String>,
    /// Settings of the run, including output location and prediction options.
    settings: Settings,
    /// Per-cycle summaries collected during the run.
    cyclelog: CycleLog,
}
34
35#[allow(clippy::too_many_arguments)]
36impl<E: Equation> NPResult<E> {
37    /// Create a new NPResult object
38    pub fn new(
39        equation: E,
40        data: Data,
41        theta: Theta,
42        psi: Psi,
43        w: Col<f64>,
44        objf: f64,
45        cycles: usize,
46        status: Status,
47        settings: Settings,
48        cyclelog: CycleLog,
49    ) -> Self {
50        // TODO: Add support for fixed and constant parameters
51
52        let par_names = settings.parameters().names();
53
54        Self {
55            equation,
56            data,
57            theta,
58            psi,
59            w,
60            objf,
61            cycles,
62            status,
63            par_names,
64            settings,
65            cyclelog,
66        }
67    }
68
69    pub fn cycles(&self) -> usize {
70        self.cycles
71    }
72
73    pub fn objf(&self) -> f64 {
74        self.objf
75    }
76
77    pub fn converged(&self) -> bool {
78        self.status == Status::Converged
79    }
80
81    pub fn get_theta(&self) -> &Theta {
82        &self.theta
83    }
84
85    /// Get the [Psi] structure
86    pub fn psi(&self) -> &Psi {
87        &self.psi
88    }
89
90    /// Get the weights (probabilities) of the support points
91    pub fn w(&self) -> &Col<f64> {
92        &self.w
93    }
94
95    pub fn write_outputs(&self) -> Result<()> {
96        if self.settings.output().write {
97            tracing::info!("Writing outputs to {:?}", self.settings.output().path);
98            self.settings.write()?;
99            let idelta: f64 = self.settings.predictions().idelta;
100            let tad = self.settings.predictions().tad;
101            self.cyclelog.write(&self.settings)?;
102            self.write_obs().context("Failed to write observations")?;
103            self.write_theta().context("Failed to write theta")?;
104            self.write_obspred()
105                .context("Failed to write observed-predicted file")?;
106            self.write_pred(idelta, tad)
107                .context("Failed to write predictions")?;
108            self.write_covs().context("Failed to write covariates")?;
109            self.write_posterior()
110                .context("Failed to write posterior")?;
111        }
112        Ok(())
113    }
114
115    /// Writes the observations and predictions to a single file
116    pub fn write_obspred(&self) -> Result<()> {
117        tracing::debug!("Writing observations and predictions...");
118
119        #[derive(Debug, Clone, Serialize)]
120        struct Row {
121            id: String,
122            time: f64,
123            outeq: usize,
124            block: usize,
125            obs: f64,
126            pop_mean: f64,
127            pop_median: f64,
128            post_mean: f64,
129            post_median: f64,
130        }
131
132        let theta: Array2<f64> = self
133            .theta
134            .matrix()
135            .clone()
136            .as_mut()
137            .into_ndarray()
138            .to_owned();
139        let w: Array1<f64> = self.w.clone().into_view().iter().cloned().collect();
140        let psi: Array2<f64> = self.psi.matrix().as_ref().into_ndarray().to_owned();
141
142        let (post_mean, post_median) = posterior_mean_median(&theta, &psi, &w)
143            .context("Failed to calculate posterior mean and median")?;
144
145        let (pop_mean, pop_median) = population_mean_median(&theta, &w)
146            .context("Failed to calculate posterior mean and median")?;
147
148        let subjects = self.data.get_subjects();
149        if subjects.len() != post_mean.nrows() {
150            bail!(
151                "Number of subjects: {} and number of posterior means: {} do not match",
152                subjects.len(),
153                post_mean.nrows()
154            );
155        }
156
157        let outputfile = OutputFile::new(&self.settings.output().path, "op.csv")?;
158        let mut writer = WriterBuilder::new()
159            .has_headers(true)
160            .from_writer(&outputfile.file);
161
162        for (i, subject) in subjects.iter().enumerate() {
163            for occasion in subject.occasions() {
164                let id = subject.id();
165                let occ = occasion.index();
166
167                let subject = Subject::from_occasions(id.clone(), vec![occasion.clone()]);
168
169                // Population predictions
170                let pop_mean_pred = self
171                    .equation
172                    .simulate_subject(&subject, &pop_mean.to_vec(), None)?
173                    .0
174                    .get_predictions()
175                    .clone();
176
177                let pop_median_pred = self
178                    .equation
179                    .simulate_subject(&subject, &pop_median.to_vec(), None)?
180                    .0
181                    .get_predictions()
182                    .clone();
183
184                // Posterior predictions
185                let post_mean_spp: Vec<f64> = post_mean.row(i).to_vec();
186                let post_mean_pred = self
187                    .equation
188                    .simulate_subject(&subject, &post_mean_spp, None)?
189                    .0
190                    .get_predictions()
191                    .clone();
192                let post_median_spp: Vec<f64> = post_median.row(i).to_vec();
193                let post_median_pred = self
194                    .equation
195                    .simulate_subject(&subject, &post_median_spp, None)?
196                    .0
197                    .get_predictions()
198                    .clone();
199                assert_eq!(
200                    pop_mean_pred.len(),
201                    pop_median_pred.len(),
202                    "The number of predictions do not match (pop_mean vs pop_median)"
203                );
204
205                assert_eq!(
206                    post_mean_pred.len(),
207                    post_median_pred.len(),
208                    "The number of predictions do not match (post_mean vs post_median)"
209                );
210
211                assert_eq!(
212                    pop_mean_pred.len(),
213                    post_mean_pred.len(),
214                    "The number of predictions do not match (pop_mean vs post_mean)"
215                );
216
217                for (((pop_mean_pred, pop_median_pred), post_mean_pred), post_median_pred) in
218                    pop_mean_pred
219                        .iter()
220                        .zip(pop_median_pred.iter())
221                        .zip(post_mean_pred.iter())
222                        .zip(post_median_pred.iter())
223                {
224                    let row = Row {
225                        id: id.clone(),
226                        time: pop_mean_pred.time(),
227                        outeq: pop_mean_pred.outeq(),
228                        block: occ,
229                        obs: pop_mean_pred.observation(),
230                        pop_mean: pop_mean_pred.prediction(),
231                        pop_median: pop_median_pred.prediction(),
232                        post_mean: post_mean_pred.prediction(),
233                        post_median: post_median_pred.prediction(),
234                    };
235                    writer.serialize(row)?;
236                }
237            }
238        }
239        writer.flush()?;
240        tracing::debug!(
241            "Observations with predictions written to {:?}",
242            &outputfile.get_relative_path()
243        );
244        Ok(())
245    }
246
247    /// Writes theta, which contains the population support points and their associated probabilities
248    /// Each row is one support point, the last column being probability
249    pub fn write_theta(&self) -> Result<()> {
250        tracing::debug!("Writing population parameter distribution...");
251
252        let theta = &self.theta;
253        let w: Vec<f64> = self.w.clone().into_view().iter().cloned().collect();
254        /* let w = if self.w.len() != theta.matrix().nrows() {
255                   tracing::warn!("Number of weights and number of support points do not match. Setting all weights to 0.");
256                   Array1::zeros(theta.matrix().nrows())
257               } else {
258                   self.w.clone()
259               };
260        */
261        let outputfile = OutputFile::new(&self.settings.output().path, "theta.csv")
262            .context("Failed to create output file for theta")?;
263
264        let mut writer = WriterBuilder::new()
265            .has_headers(true)
266            .from_writer(&outputfile.file);
267
268        // Create the headers
269        let mut theta_header = self.par_names.clone();
270        theta_header.push("prob".to_string());
271        writer.write_record(&theta_header)?;
272
273        // Write contents
274        for (theta_row, &w_val) in theta.matrix().row_iter().zip(w.iter()) {
275            let mut row: Vec<String> = theta_row.iter().map(|&val| val.to_string()).collect();
276            row.push(w_val.to_string());
277            writer.write_record(&row)?;
278        }
279        writer.flush()?;
280        tracing::debug!(
281            "Population parameter distribution written to {:?}",
282            &outputfile.get_relative_path()
283        );
284        Ok(())
285    }
286
287    /// Writes the posterior support points for each individual
288    pub fn write_posterior(&self) -> Result<()> {
289        tracing::debug!("Writing posterior parameter probabilities...");
290        let theta = &self.theta;
291        let w = &self.w;
292        let psi = &self.psi;
293
294        // Calculate the posterior probabilities
295        let posterior = posterior(psi, w)?;
296
297        // Create the output folder if it doesn't exist
298        let outputfile = match OutputFile::new(&self.settings.output().path, "posterior.csv") {
299            Ok(of) => of,
300            Err(e) => {
301                tracing::error!("Failed to create output file: {}", e);
302                return Err(e.context("Failed to create output file"));
303            }
304        };
305
306        // Create a new writer
307        let mut writer = WriterBuilder::new()
308            .has_headers(true)
309            .from_writer(&outputfile.file);
310
311        // Create the headers
312        writer.write_field("id")?;
313        writer.write_field("point")?;
314        theta.param_names().iter().for_each(|name| {
315            writer.write_field(name).unwrap();
316        });
317        writer.write_field("prob")?;
318        writer.write_record(None::<&[u8]>)?;
319
320        // Write contents
321        let subjects = self.data.get_subjects();
322        posterior.row_iter().enumerate().for_each(|(i, row)| {
323            let subject = subjects.get(i).unwrap();
324            let id = subject.id();
325
326            row.iter().enumerate().for_each(|(spp, prob)| {
327                writer.write_field(id.clone()).unwrap();
328                writer.write_field(spp.to_string()).unwrap();
329
330                theta.matrix().row(spp).iter().for_each(|val| {
331                    writer.write_field(val.to_string()).unwrap();
332                });
333
334                writer.write_field(prob.to_string()).unwrap();
335                writer.write_record(None::<&[u8]>).unwrap();
336            });
337        });
338
339        writer.flush()?;
340        tracing::debug!(
341            "Posterior parameters written to {:?}",
342            &outputfile.get_relative_path()
343        );
344
345        Ok(())
346    }
347
348    /// Write the observations, which is the reformatted input data
349    pub fn write_obs(&self) -> Result<()> {
350        tracing::debug!("Writing observations...");
351        let outputfile = OutputFile::new(&self.settings.output().path, "obs.csv")?;
352
353        let mut writer = WriterBuilder::new()
354            .has_headers(true)
355            .from_writer(&outputfile.file);
356
357        #[derive(Serialize)]
358        struct Row {
359            id: String,
360            block: usize,
361            time: f64,
362            out: f64,
363            outeq: usize,
364        }
365
366        for subject in self.data.get_subjects() {
367            for occasion in subject.occasions() {
368                for event in occasion.get_events(None, false) {
369                    if let Event::Observation(event) = event {
370                        let row = Row {
371                            id: subject.id().clone(),
372                            block: occasion.index(),
373                            time: event.time(),
374                            out: event.value(),
375                            outeq: event.outeq(),
376                        };
377                        writer.serialize(row)?;
378                    }
379                }
380            }
381        }
382        writer.flush()?;
383
384        tracing::debug!(
385            "Observations written to {:?}",
386            &outputfile.get_relative_path()
387        );
388        Ok(())
389    }
390
391    /// Writes the predictions
392    pub fn write_pred(&self, idelta: f64, tad: f64) -> Result<()> {
393        tracing::debug!("Writing predictions...");
394        let data = self.data.expand(idelta, tad);
395
396        let theta: Array2<f64> = self
397            .theta
398            .matrix()
399            .clone()
400            .as_mut()
401            .into_ndarray()
402            .to_owned();
403        let w: Array1<f64> = self.w.clone().into_view().iter().cloned().collect();
404        let psi: Array2<f64> = self.psi.matrix().as_ref().into_ndarray().to_owned();
405
406        let (post_mean, post_median) = posterior_mean_median(&theta, &psi, &w)
407            .context("Failed to calculate posterior mean and median")?;
408
409        let (pop_mean, pop_median) = population_mean_median(&theta, &w)
410            .context("Failed to calculate population mean and median")?;
411
412        let subjects = data.get_subjects();
413        if subjects.len() != post_mean.nrows() {
414            bail!("Number of subjects and number of posterior means do not match");
415        }
416
417        let outputfile = OutputFile::new(&self.settings.output().path, "pred.csv")?;
418        let mut writer = WriterBuilder::new()
419            .has_headers(true)
420            .from_writer(&outputfile.file);
421
422        #[derive(Debug, Clone, Serialize)]
423        struct Row {
424            id: String,
425            time: f64,
426            outeq: usize,
427            block: usize,
428            pop_mean: f64,
429            pop_median: f64,
430            post_mean: f64,
431            post_median: f64,
432        }
433
434        for (i, subject) in subjects.iter().enumerate() {
435            for occasion in subject.occasions() {
436                let id = subject.id();
437                let block = occasion.index();
438
439                // Create a new subject with only the current occasion
440                let subject = Subject::from_occasions(id.clone(), vec![occasion.clone()]);
441
442                // Population predictions
443                let pop_mean_pred = self
444                    .equation
445                    .simulate_subject(&subject, &pop_mean.to_vec(), None)?
446                    .0
447                    .get_predictions()
448                    .clone();
449                let pop_median_pred = self
450                    .equation
451                    .simulate_subject(&subject, &pop_median.to_vec(), None)?
452                    .0
453                    .get_predictions()
454                    .clone();
455
456                // Posterior predictions
457                let post_mean_spp: Vec<f64> = post_mean.row(i).to_vec();
458                let post_mean_pred = self
459                    .equation
460                    .simulate_subject(&subject, &post_mean_spp, None)?
461                    .0
462                    .get_predictions()
463                    .clone();
464                let post_median_spp: Vec<f64> = post_median.row(i).to_vec();
465                let post_median_pred = self
466                    .equation
467                    .simulate_subject(&subject, &post_median_spp, None)?
468                    .0
469                    .get_predictions()
470                    .clone();
471
472                // Write predictions for each time point
473                for (((pop_mean, pop_median), post_mean), post_median) in pop_mean_pred
474                    .iter()
475                    .zip(pop_median_pred.iter())
476                    .zip(post_mean_pred.iter())
477                    .zip(post_median_pred.iter())
478                {
479                    let row = Row {
480                        id: id.clone(),
481                        time: pop_mean.time(),
482                        outeq: pop_mean.outeq(),
483                        block,
484                        pop_mean: pop_mean.prediction(),
485                        pop_median: pop_median.prediction(),
486                        post_mean: post_mean.prediction(),
487                        post_median: post_median.prediction(),
488                    };
489                    writer.serialize(row)?;
490                }
491            }
492        }
493        writer.flush()?;
494        tracing::debug!(
495            "Predictions written to {:?}",
496            &outputfile.get_relative_path()
497        );
498        Ok(())
499    }
500
501    /// Writes the covariates
502    pub fn write_covs(&self) -> Result<()> {
503        tracing::debug!("Writing covariates...");
504        let outputfile = OutputFile::new(&self.settings.output().path, "covs.csv")?;
505        let mut writer = WriterBuilder::new()
506            .has_headers(true)
507            .from_writer(&outputfile.file);
508
509        // Collect all unique covariate names
510        let mut covariate_names = std::collections::HashSet::new();
511        for subject in self.data.get_subjects() {
512            for occasion in subject.occasions() {
513                let cov = occasion.covariates();
514                let covmap = cov.covariates();
515                for cov_name in covmap.keys() {
516                    covariate_names.insert(cov_name.clone());
517                }
518            }
519        }
520        let mut covariate_names: Vec<String> = covariate_names.into_iter().collect();
521        covariate_names.sort(); // Ensure consistent order
522
523        // Write the header row: id, time, block, covariate names
524        let mut headers = vec!["id", "time", "block"];
525        headers.extend(covariate_names.iter().map(|s| s.as_str()));
526        writer.write_record(&headers)?;
527
528        // Write the data rows
529        for subject in self.data.get_subjects() {
530            for occasion in subject.occasions() {
531                let cov = occasion.covariates();
532                let covmap = cov.covariates();
533
534                for event in occasion.get_events(None, false) {
535                    let time = match event {
536                        Event::Bolus(bolus) => bolus.time(),
537                        Event::Infusion(infusion) => infusion.time(),
538                        Event::Observation(observation) => observation.time(),
539                    };
540
541                    let mut row: Vec<String> = Vec::new();
542                    row.push(subject.id().clone());
543                    row.push(time.to_string());
544                    row.push(occasion.index().to_string());
545
546                    // Add covariate values to the row
547                    for cov_name in &covariate_names {
548                        if let Some(cov) = covmap.get(cov_name) {
549                            if let Some(value) = cov.interpolate(time) {
550                                row.push(value.to_string());
551                            } else {
552                                row.push(String::new());
553                            }
554                        } else {
555                            row.push(String::new());
556                        }
557                    }
558
559                    writer.write_record(&row)?;
560                }
561            }
562        }
563
564        writer.flush()?;
565        tracing::debug!(
566            "Covariates written to {:?}",
567            &outputfile.get_relative_path()
568        );
569        Ok(())
570    }
571}
572
573/// An [NPCycle] object contains the summary of a cycle
574/// It holds the following information:
575/// - `cycle`: The cycle number
576/// - `objf`: The objective function value
577/// - `gamlam`: The assay noise parameter, either gamma or lambda
578/// - `theta`: The support points and their associated probabilities
579/// - `nspp`: The number of support points
580/// - `delta_objf`: The change in objective function value from last cycle
581/// - `converged`: Whether the algorithm has reached convergence
582#[derive(Debug, Clone)]
583pub struct NPCycle {
584    pub cycle: usize,
585    pub objf: f64,
586    pub error_models: ErrorModels,
587    pub theta: Theta,
588    pub nspp: usize,
589    pub delta_objf: f64,
590    pub status: Status,
591}
592
593impl NPCycle {
594    pub fn new(
595        cycle: usize,
596        objf: f64,
597        error_models: ErrorModels,
598        theta: Theta,
599        nspp: usize,
600        delta_objf: f64,
601        status: Status,
602    ) -> Self {
603        Self {
604            cycle,
605            objf,
606            error_models,
607            theta,
608            nspp,
609            delta_objf,
610            status,
611        }
612    }
613
614    pub fn placeholder() -> Self {
615        Self {
616            cycle: 0,
617            objf: 0.0,
618            error_models: ErrorModels::default(),
619            theta: Theta::new(),
620            nspp: 0,
621            delta_objf: 0.0,
622            status: Status::Starting,
623        }
624    }
625}
626
627/// This holdes a vector of [NPCycle] objects to provide a more detailed log
628#[derive(Debug, Clone)]
629pub struct CycleLog {
630    pub cycles: Vec<NPCycle>,
631}
632
633impl CycleLog {
634    pub fn new() -> Self {
635        Self { cycles: Vec::new() }
636    }
637
638    pub fn push(&mut self, cycle: NPCycle) {
639        self.cycles.push(cycle);
640    }
641
642    pub fn write(&self, settings: &Settings) -> Result<()> {
643        tracing::debug!("Writing cycles...");
644        let outputfile = OutputFile::new(&settings.output().path, "cycles.csv")?;
645        let mut writer = WriterBuilder::new()
646            .has_headers(false)
647            .from_writer(&outputfile.file);
648
649        // Write headers
650        writer.write_field("cycle")?;
651        writer.write_field("converged")?;
652        writer.write_field("status")?;
653        writer.write_field("neg2ll")?;
654        writer.write_field("nspp")?;
655        if let Some(first_cycle) = self.cycles.first() {
656            first_cycle.error_models.iter().try_for_each(
657                |(outeq, errmod): (usize, &ErrorModel)| -> Result<(), csv::Error> {
658                    match errmod {
659                        ErrorModel::Additive { .. } => {
660                            writer.write_field(format!("gamlam.{}", outeq))?;
661                        }
662                        ErrorModel::Proportional { .. } => {
663                            writer.write_field(format!("gamlam.{}", outeq))?;
664                        }
665                        ErrorModel::None { .. } => {}
666                    }
667                    Ok(())
668                },
669            )?;
670        }
671
672        let parameter_names = settings.parameters().names();
673        for param_name in &parameter_names {
674            writer.write_field(format!("{}.mean", param_name))?;
675            writer.write_field(format!("{}.median", param_name))?;
676            writer.write_field(format!("{}.sd", param_name))?;
677        }
678
679        writer.write_record(None::<&[u8]>)?;
680
681        for cycle in &self.cycles {
682            writer.write_field(format!("{}", cycle.cycle))?;
683            writer.write_field(format!("{}", cycle.status == Status::Converged))?;
684            writer.write_field(format!("{}", cycle.status))?;
685            writer.write_field(format!("{}", cycle.objf))?;
686            writer
687                .write_field(format!("{}", cycle.theta.nspp()))
688                .unwrap();
689
690            // Write the error models
691            cycle.error_models.iter().try_for_each(
692                |(_, errmod): (usize, &ErrorModel)| -> Result<()> {
693                    match errmod {
694                        ErrorModel::Additive { .. } => {
695                            writer.write_field(format!("{:.5}", errmod.scalar()?))?;
696                        }
697                        ErrorModel::Proportional { .. } => {
698                            writer.write_field(format!("{:.5}", errmod.scalar()?))?;
699                        }
700                        ErrorModel::None { .. } => {}
701                    }
702                    Ok(())
703                },
704            )?;
705
706            for param in cycle.theta.matrix().col_iter() {
707                let param_values: Vec<f64> = param.iter().cloned().collect();
708
709                let mean: f64 = param_values.iter().sum::<f64>() / param_values.len() as f64;
710                let median = median(param_values.clone());
711                let std = param_values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
712                    / (param_values.len() as f64 - 1.0);
713
714                writer.write_field(format!("{}", mean))?;
715                writer.write_field(format!("{}", median))?;
716                writer.write_field(format!("{}", std))?;
717            }
718            writer.write_record(None::<&[u8]>)?;
719        }
720        writer.flush()?;
721        tracing::debug!("Cycles written to {:?}", &outputfile.get_relative_path());
722        Ok(())
723    }
724}
725
726impl Default for CycleLog {
727    fn default() -> Self {
728        Self::new()
729    }
730}
731
732/// Calculates the posterior probabilities for each support point given the weights
733pub fn posterior(psi: &Psi, w: &Col<f64>) -> Result<Mat<f64>> {
734    if psi.matrix().ncols() != w.nrows() {
735        bail!(
736            "Number of rows in psi ({}) and number of weights ({}) do not match.",
737            psi.matrix().nrows(),
738            w.nrows()
739        );
740    }
741
742    let psi_matrix = psi.matrix();
743    let py = psi_matrix * w;
744
745    let posterior = Mat::from_fn(psi_matrix.nrows(), psi_matrix.ncols(), |i, j| {
746        psi_matrix.get(i, j) * w.get(j) / py.get(i)
747    });
748
749    Ok(posterior)
750}
751
/// Returns the (unweighted) median of `data`.
///
/// For an even number of values, the result is the mean of the two central values.
/// Values are ordered with [`f64::total_cmp`], so NaN inputs do not cause a panic
/// (positive NaNs sort last). An empty input yields `f64::NAN`.
pub fn median(mut data: Vec<f64>) -> f64 {
    if data.is_empty() {
        return f64::NAN;
    }
    data.sort_by(f64::total_cmp);

    let mid = data.len() / 2;
    if data.len() % 2 == 0 {
        (data[mid - 1] + data[mid]) / 2.0
    } else {
        data[mid]
    }
}
766
767fn weighted_median(data: &Array1<f64>, weights: &Array1<f64>) -> f64 {
768    // Ensure the data and weights arrays have the same length
769    assert_eq!(
770        data.len(),
771        weights.len(),
772        "The length of data and weights must be the same"
773    );
774    assert!(
775        weights.iter().all(|&x| x >= 0.0),
776        "Weights must be non-negative, weights: {:?}",
777        weights
778    );
779
780    // Create a vector of tuples (data, weight)
781    let mut weighted_data: Vec<(f64, f64)> = data
782        .iter()
783        .zip(weights.iter())
784        .map(|(&d, &w)| (d, w))
785        .collect();
786
787    // Sort the vector by the data values
788    weighted_data.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
789
790    // Calculate the cumulative sum of weights
791    let total_weight: f64 = weights.sum();
792    let mut cumulative_sum = 0.0;
793
794    for (i, &(_, weight)) in weighted_data.iter().enumerate() {
795        cumulative_sum += weight;
796
797        if cumulative_sum == total_weight / 2.0 {
798            // If the cumulative sum equals half the total weight, average this value with the next
799            if i + 1 < weighted_data.len() {
800                return (weighted_data[i].0 + weighted_data[i + 1].0) / 2.0;
801            } else {
802                return weighted_data[i].0;
803            }
804        } else if cumulative_sum > total_weight / 2.0 {
805            return weighted_data[i].0;
806        }
807    }
808
809    unreachable!("The function should have returned a value before reaching this point.");
810}
811
812pub fn population_mean_median(
813    theta: &Array2<f64>,
814    w: &Array1<f64>,
815) -> Result<(Array1<f64>, Array1<f64>)> {
816    let w = if w.is_empty() {
817        tracing::warn!("w.len() == 0, setting all weights to 1/n");
818        Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
819    } else {
820        w.clone()
821    };
822    // Check for compatible sizes
823    if theta.nrows() != w.len() {
824        bail!(
825            "Number of parameters and number of weights do not match. Theta: {}, w: {}",
826            theta.nrows(),
827            w.len()
828        );
829    }
830
831    let mut mean = Array1::zeros(theta.ncols());
832    let mut median = Array1::zeros(theta.ncols());
833
834    for (i, (mn, mdn)) in mean.iter_mut().zip(&mut median).enumerate() {
835        // Calculate the weighted mean
836        let col = theta.column(i).to_owned() * w.to_owned();
837        *mn = col.sum();
838
839        // Calculate the median
840        let ct = theta.column(i);
841        let mut params = vec![];
842        let mut weights = vec![];
843        for (ti, wi) in ct.iter().zip(w.clone()) {
844            params.push(*ti);
845            weights.push(wi);
846        }
847
848        *mdn = weighted_median(&Array::from(params), &Array::from(weights));
849    }
850
851    Ok((mean, median))
852}
853
854pub fn posterior_mean_median(
855    theta: &Array2<f64>,
856    psi: &Array2<f64>,
857    w: &Array1<f64>,
858) -> Result<(Array2<f64>, Array2<f64>)> {
859    let mut mean = Array2::zeros((0, theta.ncols()));
860    let mut median = Array2::zeros((0, theta.ncols()));
861
862    let w = if w.is_empty() {
863        tracing::warn!("w is empty, setting all weights to 1/n");
864        Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
865    } else {
866        w.clone()
867    };
868
869    // Check for compatible sizes
870    if theta.nrows() != w.len() || theta.nrows() != psi.ncols() || psi.ncols() != w.len() {
871        bail!("Number of parameters and number of weights do not match, theta.nrows(): {}, w.len(): {}, psi.ncols(): {}", theta.nrows(), w.len(), psi.ncols());
872    }
873
874    // Normalize psi to get probabilities of each spp for each id
875    let mut psi_norm: Array2<f64> = Array2::zeros((0, psi.ncols()));
876    for (i, row) in psi.axis_iter(Axis(0)).enumerate() {
877        let row_w = row.to_owned() * w.to_owned();
878        let row_sum = row_w.sum();
879        let row_norm = if row_sum == 0.0 {
880            tracing::warn!("Sum of row {} of psi is 0.0, setting that row to 1/n", i);
881            Array1::from_elem(psi.ncols(), 1.0 / psi.ncols() as f64)
882        } else {
883            &row_w / row_sum
884        };
885        psi_norm.push_row(row_norm.view())?;
886    }
887    if psi_norm.iter().any(|&x| x.is_nan()) {
888        dbg!(&psi);
889        bail!("NaN values found in psi_norm");
890    };
891
892    // Transpose normalized psi to get ID (col) by prob (row)
893    // let psi_norm_transposed = psi_norm.t();
894
895    // For each subject..
896    for probs in psi_norm.axis_iter(Axis(0)) {
897        let mut post_mean: Vec<f64> = Vec::new();
898        let mut post_median: Vec<f64> = Vec::new();
899
900        // For each parameter
901        for pars in theta.axis_iter(Axis(1)) {
902            // Calculate the mean
903            let weighted_par = &probs * &pars;
904            let the_mean = weighted_par.sum();
905            post_mean.push(the_mean);
906
907            // Calculate the median
908            let median = weighted_median(&pars.to_owned(), &probs.to_owned());
909            post_median.push(median);
910        }
911
912        mean.push_row(Array::from(post_mean.clone()).view())?;
913        median.push_row(Array::from(post_median.clone()).view())?;
914    }
915
916    Ok((mean, median))
917}
918
919/// Contains all the necessary information of an output file
920#[derive(Debug)]
921pub struct OutputFile {
922    pub file: File,
923    pub relative_path: PathBuf,
924}
925
926impl OutputFile {
927    pub fn new(folder: &str, file_name: &str) -> Result<Self> {
928        let relative_path = Path::new(&folder).join(file_name);
929
930        if let Some(parent) = relative_path.parent() {
931            create_dir_all(parent)
932                .with_context(|| format!("Failed to create directories for {:?}", parent))?;
933        }
934
935        let file = OpenOptions::new()
936            .write(true)
937            .create(true)
938            .truncate(true)
939            .open(&relative_path)
940            .with_context(|| format!("Failed to open file: {:?}", relative_path))?;
941
942        Ok(OutputFile {
943            file,
944            relative_path,
945        })
946    }
947
948    pub fn get_relative_path(&self) -> &Path {
949        &self.relative_path
950    }
951}
952
#[cfg(test)]
mod tests {
    use super::{median, population_mean_median, posterior_mean_median, weighted_median};
    use ndarray::{array, Array1};

    #[test]
    fn test_median_odd() {
        let data = vec![1.0, 3.0, 2.0];
        assert_eq!(median(data), 2.0);
    }

    #[test]
    fn test_median_even() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        assert_eq!(median(data), 2.5);
    }

    #[test]
    fn test_median_single() {
        let data = vec![42.0];
        assert_eq!(median(data), 42.0);
    }

    #[test]
    fn test_median_sorted() {
        let data = vec![5.0, 10.0, 15.0, 20.0, 25.0];
        assert_eq!(median(data), 15.0);
    }

    #[test]
    fn test_median_unsorted() {
        let data = vec![10.0, 30.0, 20.0, 50.0, 40.0];
        assert_eq!(median(data), 30.0);
    }

    #[test]
    fn test_median_with_duplicates() {
        let data = vec![1.0, 2.0, 2.0, 3.0, 4.0];
        assert_eq!(median(data), 2.0);
    }

    #[test]
    fn test_weighted_median_simple() {
        let data = Array1::from(vec![1.0, 2.0, 3.0]);
        let weights = Array1::from(vec![0.2, 0.5, 0.3]);
        assert_eq!(weighted_median(&data, &weights), 2.0);
    }

    #[test]
    fn test_weighted_median_even_weights() {
        let data = Array1::from(vec![1.0, 2.0, 3.0, 4.0]);
        let weights = Array1::from(vec![0.25, 0.25, 0.25, 0.25]);
        assert_eq!(weighted_median(&data, &weights), 2.5);
    }

    #[test]
    fn test_weighted_median_single_element() {
        let data = Array1::from(vec![42.0]);
        let weights = Array1::from(vec![1.0]);
        assert_eq!(weighted_median(&data, &weights), 42.0);
    }

    #[test]
    #[should_panic(expected = "The length of data and weights must be the same")]
    fn test_weighted_median_mismatched_lengths() {
        let data = Array1::from(vec![1.0, 2.0, 3.0]);
        let weights = Array1::from(vec![0.1, 0.2]);
        weighted_median(&data, &weights);
    }

    #[test]
    fn test_weighted_median_all_same_elements() {
        let data = Array1::from(vec![5.0, 5.0, 5.0, 5.0]);
        let weights = Array1::from(vec![0.1, 0.2, 0.3, 0.4]);
        assert_eq!(weighted_median(&data, &weights), 5.0);
    }

    #[test]
    #[should_panic(expected = "Weights must be non-negative")]
    fn test_weighted_median_negative_weights() {
        let data = Array1::from(vec![1.0, 2.0, 3.0, 4.0]);
        let weights = Array1::from(vec![0.2, -0.5, 0.5, 0.8]);
        // The call itself must panic; there is no return value to check.
        weighted_median(&data, &weights);
    }

    #[test]
    fn test_weighted_median_unsorted_data() {
        let data = Array1::from(vec![3.0, 1.0, 4.0, 2.0]);
        let weights = Array1::from(vec![0.1, 0.3, 0.4, 0.2]);
        assert_eq!(weighted_median(&data, &weights), 2.5);
    }

    #[test]
    fn test_weighted_median_with_zero_weights() {
        let data = Array1::from(vec![1.0, 2.0, 3.0, 4.0]);
        let weights = Array1::from(vec![0.0, 0.0, 1.0, 0.0]);
        assert_eq!(weighted_median(&data, &weights), 3.0);
    }

    #[test]
    fn test_population_mean_median_weighted() {
        let theta = array![[1.0, 10.0], [3.0, 30.0]];
        let w = array![0.25, 0.75];
        let (mean, median) = population_mean_median(&theta, &w).unwrap();
        assert_eq!(mean, array![2.5, 25.0]);
        assert_eq!(median, array![3.0, 30.0]);
    }

    #[test]
    fn test_population_mean_median_empty_weights_are_uniform() {
        let theta = array![[1.0], [3.0]];
        let (mean, _) = population_mean_median(&theta, &Array1::zeros(0)).unwrap();
        assert_eq!(mean, array![2.0]);
    }

    #[test]
    fn test_population_mean_median_mismatched_lengths() {
        let theta = array![[1.0], [3.0]];
        assert!(population_mean_median(&theta, &array![1.0]).is_err());
    }

    #[test]
    fn test_posterior_mean_median_per_subject() {
        let theta = array![[1.0], [3.0]];
        let psi = array![[1.0, 1.0], [1.0, 0.0]];
        let w = array![0.5, 0.5];
        let (mean, median) = posterior_mean_median(&theta, &psi, &w).unwrap();
        assert_eq!(mean.shape(), &[2, 1]);
        assert_eq!(mean[[0, 0]], 2.0);
        assert_eq!(mean[[1, 0]], 1.0);
        assert_eq!(median[[1, 0]], 1.0);
    }

    #[test]
    fn test_posterior_mean_median_zero_row_is_uniform() {
        let theta = array![[1.0], [3.0]];
        let psi = array![[0.0, 0.0]];
        let w = array![0.5, 0.5];
        let (mean, _) = posterior_mean_median(&theta, &psi, &w).unwrap();
        assert_eq!(mean[[0, 0]], 2.0);
    }

    #[test]
    fn test_posterior_mean_median_nan_psi_is_error() {
        let theta = array![[1.0], [3.0]];
        let psi = array![[f64::NAN, 1.0]];
        let w = array![0.5, 0.5];
        assert!(posterior_mean_median(&theta, &psi, &w).is_err());
    }

    #[test]
    fn test_posterior_mean_median_mismatched_sizes() {
        let theta = array![[1.0], [3.0]];
        let psi = array![[1.0, 1.0, 1.0]];
        let w = array![0.5, 0.5];
        assert!(posterior_mean_median(&theta, &psi, &w).is_err());
    }
}