pmcore/routines/output.rs

use crate::algorithms::Status;
use crate::prelude::*;
use crate::routines::settings::Settings;
use crate::structs::psi::Psi;
use crate::structs::theta::Theta;
use anyhow::{bail, Context, Result};
use csv::WriterBuilder;
use faer::linalg::zip::IntoView;
use faer::{Col, Mat};
use faer_ext::IntoNdarray;
use ndarray::{Array, Array1, Array2, Axis};
use pharmsol::prelude::data::*;
use pharmsol::prelude::simulator::{Equation, Prediction};
use serde::Serialize;
use std::fs::{create_dir_all, File, OpenOptions};
use std::path::{Path, PathBuf};

/// Defines the result objects from an NPAG run
/// An [NPResult] contains the necessary information to generate predictions and summary statistics
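///
/// # Example (sketch)
///
/// ```ignore
/// // Illustrative only: the arguments are assumed to come from a completed run.
/// let result = NPResult::new(
///     equation, data, theta, psi, w, objf, cycles, status, settings, cyclelog,
/// );
/// result.write_outputs()?;
/// ```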
#[derive(Debug)]
pub struct NPResult<E: Equation> {
    equation: E,
    data: Data,
    theta: Theta,
    psi: Psi,
    w: Col<f64>,
    objf: f64,
    cycles: usize,
    status: Status,
    par_names: Vec<String>,
    settings: Settings,
    cyclelog: CycleLog,
}

#[allow(clippy::too_many_arguments)]
impl<E: Equation> NPResult<E> {
    /// Create a new NPResult object
    pub fn new(
        equation: E,
        data: Data,
        theta: Theta,
        psi: Psi,
        w: Col<f64>,
        objf: f64,
        cycles: usize,
        status: Status,
        settings: Settings,
        cyclelog: CycleLog,
    ) -> Self {
        // TODO: Add support for fixed and constant parameters

        let par_names = settings.parameters().names();

        Self {
            equation,
            data,
            theta,
            psi,
            w,
            objf,
            cycles,
            status,
            par_names,
            settings,
            cyclelog,
        }
    }

    pub fn cycles(&self) -> usize {
        self.cycles
    }

    pub fn objf(&self) -> f64 {
        self.objf
    }

    pub fn converged(&self) -> bool {
        self.status == Status::Converged
    }

    pub fn get_theta(&self) -> &Theta {
        &self.theta
    }

    /// Get the [Psi] structure
    pub fn psi(&self) -> &Psi {
        &self.psi
    }

    /// Get the weights (probabilities) of the support points
    pub fn w(&self) -> &Col<f64> {
        &self.w
    }

    pub fn write_outputs(&self) -> Result<()> {
        if self.settings.output().write {
            tracing::debug!("Writing outputs to {:?}", self.settings.output().path);
            self.settings.write()?;
            let idelta: f64 = self.settings.predictions().idelta;
            let tad = self.settings.predictions().tad;
            self.cyclelog.write(&self.settings)?;
            self.write_obs().context("Failed to write observations")?;
            self.write_theta().context("Failed to write theta")?;
            self.write_obspred()
                .context("Failed to write observed-predicted file")?;
            self.write_pred(idelta, tad)
                .context("Failed to write predictions")?;
            self.write_covs().context("Failed to write covariates")?;
            self.write_posterior()
                .context("Failed to write posterior")?;
        }
        Ok(())
    }

    /// Writes the observations and predictions to a single file
    pub fn write_obspred(&self) -> Result<()> {
        tracing::debug!("Writing observations and predictions...");

        #[derive(Debug, Clone, Serialize)]
        struct Row {
            id: String,
            time: f64,
            outeq: usize,
            block: usize,
            obs: Option<f64>,
            pop_mean: f64,
            pop_median: f64,
            post_mean: f64,
            post_median: f64,
        }

        let theta: Array2<f64> = self
            .theta
            .matrix()
            .clone()
            .as_mut()
            .into_ndarray()
            .to_owned();
        let w: Array1<f64> = self.w.clone().into_view().iter().cloned().collect();
        let psi: Array2<f64> = self.psi.matrix().as_ref().into_ndarray().to_owned();

        let (post_mean, post_median) = posterior_mean_median(&theta, &psi, &w)
            .context("Failed to calculate posterior mean and median")?;

        let (pop_mean, pop_median) = population_mean_median(&theta, &w)
            .context("Failed to calculate population mean and median")?;

        let subjects = self.data.subjects();
        if subjects.len() != post_mean.nrows() {
            bail!(
                "Number of subjects: {} and number of posterior means: {} do not match",
                subjects.len(),
                post_mean.nrows()
            );
        }

        let outputfile = OutputFile::new(&self.settings.output().path, "op.csv")?;
        let mut writer = WriterBuilder::new()
            .has_headers(true)
            .from_writer(&outputfile.file);

        for (i, subject) in subjects.iter().enumerate() {
            for occasion in subject.occasions() {
                let id = subject.id();
                let occ = occasion.index();

                let subject = Subject::from_occasions(id.clone(), vec![occasion.clone()]);

                // Population predictions
                let pop_mean_pred = self
                    .equation
                    .simulate_subject(&subject, &pop_mean.to_vec(), None)?
                    .0
                    .get_predictions()
                    .clone();

                let pop_median_pred = self
                    .equation
                    .simulate_subject(&subject, &pop_median.to_vec(), None)?
                    .0
                    .get_predictions()
                    .clone();

                // Posterior predictions
                let post_mean_spp: Vec<f64> = post_mean.row(i).to_vec();
                let post_mean_pred = self
                    .equation
                    .simulate_subject(&subject, &post_mean_spp, None)?
                    .0
                    .get_predictions()
                    .clone();
                let post_median_spp: Vec<f64> = post_median.row(i).to_vec();
                let post_median_pred = self
                    .equation
                    .simulate_subject(&subject, &post_median_spp, None)?
                    .0
                    .get_predictions()
                    .clone();
                assert_eq!(
                    pop_mean_pred.len(),
                    pop_median_pred.len(),
                    "The number of predictions does not match (pop_mean vs pop_median)"
                );

                assert_eq!(
                    post_mean_pred.len(),
                    post_median_pred.len(),
                    "The number of predictions does not match (post_mean vs post_median)"
                );

                assert_eq!(
                    pop_mean_pred.len(),
                    post_mean_pred.len(),
                    "The number of predictions does not match (pop_mean vs post_mean)"
                );

                for (((pop_mean_pred, pop_median_pred), post_mean_pred), post_median_pred) in
                    pop_mean_pred
                        .iter()
                        .zip(pop_median_pred.iter())
                        .zip(post_mean_pred.iter())
                        .zip(post_median_pred.iter())
                {
                    let row = Row {
                        id: id.clone(),
                        time: pop_mean_pred.time(),
                        outeq: pop_mean_pred.outeq(),
                        block: occ,
                        obs: pop_mean_pred.observation(),
                        pop_mean: pop_mean_pred.prediction(),
                        pop_median: pop_median_pred.prediction(),
                        post_mean: post_mean_pred.prediction(),
                        post_median: post_median_pred.prediction(),
                    };
                    writer.serialize(row)?;
                }
            }
        }
        writer.flush()?;
        tracing::debug!(
            "Observations with predictions written to {:?}",
            &outputfile.get_relative_path()
        );
        Ok(())
    }

    /// Writes theta, which contains the population support points and their associated probabilities
    /// Each row is one support point, the last column being probability
    pub fn write_theta(&self) -> Result<()> {
        tracing::debug!("Writing population parameter distribution...");

        let theta = &self.theta;
        let w: Vec<f64> = self.w.clone().into_view().iter().cloned().collect();

        if w.len() != theta.matrix().nrows() {
            bail!(
                "Number of weights ({}) and number of support points ({}) do not match.",
                w.len(),
                theta.matrix().nrows()
            );
        }

        let outputfile = OutputFile::new(&self.settings.output().path, "theta.csv")
            .context("Failed to create output file for theta")?;

        let mut writer = WriterBuilder::new()
            .has_headers(true)
            .from_writer(&outputfile.file);

        // Create the headers
        let mut theta_header = self.par_names.clone();
        theta_header.push("prob".to_string());
        writer.write_record(&theta_header)?;

        // Write contents
        for (theta_row, &w_val) in theta.matrix().row_iter().zip(w.iter()) {
            let mut row: Vec<String> = theta_row.iter().map(|&val| val.to_string()).collect();
            row.push(w_val.to_string());
            writer.write_record(&row)?;
        }
        writer.flush()?;
        tracing::debug!(
            "Population parameter distribution written to {:?}",
            &outputfile.get_relative_path()
        );
        Ok(())
    }

    /// Writes the posterior support points for each individual
    pub fn write_posterior(&self) -> Result<()> {
        tracing::debug!("Writing posterior parameter probabilities...");
        let theta = &self.theta;
        let w = &self.w;
        let psi = &self.psi;

        // Calculate the posterior probabilities
        let posterior = posterior(psi, w)?;

        // Create the output folder if it doesn't exist
        let outputfile = match OutputFile::new(&self.settings.output().path, "posterior.csv") {
            Ok(of) => of,
            Err(e) => {
                tracing::error!("Failed to create output file: {}", e);
                return Err(e.context("Failed to create output file"));
            }
        };

        // Create a new writer
        let mut writer = WriterBuilder::new()
            .has_headers(true)
            .from_writer(&outputfile.file);

        // Create the headers
        writer.write_field("id")?;
        writer.write_field("point")?;
        theta.param_names().iter().for_each(|name| {
            writer.write_field(name).unwrap();
        });
        writer.write_field("prob")?;
        writer.write_record(None::<&[u8]>)?;

        // Write contents
        let subjects = self.data.subjects();
        posterior.row_iter().enumerate().for_each(|(i, row)| {
            let subject = subjects.get(i).unwrap();
            let id = subject.id();

            row.iter().enumerate().for_each(|(spp, prob)| {
                writer.write_field(id.clone()).unwrap();
                writer.write_field(spp.to_string()).unwrap();

                theta.matrix().row(spp).iter().for_each(|val| {
                    writer.write_field(val.to_string()).unwrap();
                });

                writer.write_field(prob.to_string()).unwrap();
                writer.write_record(None::<&[u8]>).unwrap();
            });
        });

        writer.flush()?;
        tracing::debug!(
            "Posterior parameters written to {:?}",
            &outputfile.get_relative_path()
        );

        Ok(())
    }

    /// Write the observations, which are the reformatted input data
    pub fn write_obs(&self) -> Result<()> {
        tracing::debug!("Writing observations...");
        let outputfile = OutputFile::new(&self.settings.output().path, "obs.csv")?;

        let mut writer = WriterBuilder::new()
            .has_headers(true)
            .from_writer(&outputfile.file);

        #[derive(Serialize)]
        struct Row {
            id: String,
            block: usize,
            time: f64,
            out: Option<f64>,
            outeq: usize,
        }

        for subject in self.data.subjects() {
            for occasion in subject.occasions() {
                for event in occasion.iter() {
                    if let Event::Observation(event) = event {
                        let row = Row {
                            id: subject.id().clone(),
                            block: occasion.index(),
                            time: event.time(),
                            out: event.value(),
                            outeq: event.outeq(),
                        };
                        writer.serialize(row)?;
                    }
                }
            }
        }
        writer.flush()?;

        tracing::debug!(
            "Observations written to {:?}",
            &outputfile.get_relative_path()
        );
        Ok(())
    }

    /// Writes the predictions
    pub fn write_pred(&self, idelta: f64, tad: f64) -> Result<()> {
        tracing::debug!("Writing predictions...");

        // Get necessary data
        let theta = self.theta.matrix();
        let w: Vec<f64> = self.w.iter().cloned().collect();
        let posterior = posterior(&self.psi, &self.w)?;

        let data = self.data.clone().expand(idelta, tad);

        let subjects = data.subjects();

        if subjects.len() != posterior.nrows() {
            bail!("Number of subjects and number of posterior means do not match");
        };

        // Create the output file and writer for pred.csv
        let outputfile = OutputFile::new(&self.settings.output().path, "pred.csv")?;
        let mut writer = WriterBuilder::new()
            .has_headers(true)
            .from_writer(&outputfile.file);

        // Iterate over each subject and then each support point
        for subject in subjects.iter().enumerate() {
            let (subject_index, subject) = subject;

            // Get a vector of occasion indices for this subject, one for each prediction
            let occasions = subject
                .occasions()
                .iter()
                .flat_map(|o| {
                    o.events()
                        .iter()
                        .filter_map(|e| {
                            if let Event::Observation(_obs) = e {
                                Some(o.index())
                            } else {
                                None
                            }
                        })
                        .collect::<Vec<_>>()
                })
                .collect::<Vec<usize>>();

            // Container for predictions for this subject
            // This will hold predictions for each support point
            // The outer vector is for each support point
            // The inner vector is for the vector of predictions for that support point
            let mut predictions: Vec<Vec<Prediction>> = Vec::new();

            // For each support point
            for spp in theta.row_iter() {
                // Simulate the subject with the current support point
                let spp_values = spp.iter().cloned().collect::<Vec<f64>>();
                let pred = self
                    .equation
                    .simulate_subject(subject, &spp_values, None)?
                    .0
                    .get_predictions();
                predictions.push(pred);
            }

            if predictions.is_empty() {
                continue; // Skip this subject if no predictions are available
            }

            // Calculate population mean as the weighted average of predictions over support points
            let mut pop_mean: Vec<f64> = vec![0.0; predictions.first().unwrap().len()];
            for outer_pred in predictions.iter().enumerate() {
                let (i, outer_pred) = outer_pred;
                for inner_pred in outer_pred.iter().enumerate() {
                    let (j, pred) = inner_pred;
                    pop_mean[j] += pred.prediction() * w[i];
                }
            }

            // Calculate population median
            let mut pop_median: Vec<f64> = Vec::new();
            for j in 0..predictions.first().unwrap().len() {
                let mut values: Vec<f64> = Vec::new();
                let mut weights: Vec<f64> = Vec::new();

                for (i, outer_pred) in predictions.iter().enumerate() {
                    values.push(outer_pred[j].prediction());
                    weights.push(w[i]);
                }

                let median_val = weighted_median(&values, &weights);
                pop_median.push(median_val);
            }

            // Calculate posterior mean
            let mut posterior_mean: Vec<f64> = vec![0.0; predictions.first().unwrap().len()];
            for outer_pred in predictions.iter().enumerate() {
                let (i, outer_pred) = outer_pred;
                for inner_pred in outer_pred.iter().enumerate() {
                    let (j, pred) = inner_pred;
                    posterior_mean[j] += pred.prediction() * posterior[(subject_index, i)];
                }
            }

            // Calculate posterior median
            let mut posterior_median: Vec<f64> = Vec::new();
            for j in 0..predictions.first().unwrap().len() {
                let mut values: Vec<f64> = Vec::new();
                let mut weights: Vec<f64> = Vec::new();

                for (i, outer_pred) in predictions.iter().enumerate() {
                    values.push(outer_pred[j].prediction());
                    weights.push(posterior[(subject_index, i)]);
                }

                let median_val = weighted_median(&values, &weights);
                posterior_median.push(median_val);
            }

            // Structure for the output
            #[derive(Debug, Clone, Serialize)]
            struct Row {
                id: String,
                time: f64,
                outeq: usize,
                block: usize,
                pop_mean: f64,
                pop_median: f64,
                post_mean: f64,
                post_median: f64,
            }

            for pred in predictions.iter().enumerate() {
                let (_, preds) = pred;
                for (j, p) in preds.iter().enumerate() {
                    let row = Row {
                        id: subject.id().clone(),
                        time: p.time(),
                        outeq: p.outeq(),
                        block: occasions[j],
                        pop_mean: pop_mean[j],
                        pop_median: pop_median[j],
                        post_mean: posterior_mean[j],
                        post_median: posterior_median[j],
                    };
                    writer.serialize(row)?;
                }
            }
        }

        writer.flush()?;
        tracing::debug!(
            "Predictions written to {:?}",
            &outputfile.get_relative_path()
        );

        Ok(())
    }

    /// Writes the covariates
    pub fn write_covs(&self) -> Result<()> {
        tracing::debug!("Writing covariates...");
        let outputfile = OutputFile::new(&self.settings.output().path, "covs.csv")?;
        let mut writer = WriterBuilder::new()
            .has_headers(true)
            .from_writer(&outputfile.file);

        // Collect all unique covariate names
        let mut covariate_names = std::collections::HashSet::new();
        for subject in self.data.subjects() {
            for occasion in subject.occasions() {
                let cov = occasion.covariates();
                let covmap = cov.covariates();
                for cov_name in covmap.keys() {
                    covariate_names.insert(cov_name.clone());
                }
            }
        }
        let mut covariate_names: Vec<String> = covariate_names.into_iter().collect();
        covariate_names.sort(); // Ensure consistent order

        // Write the header row: id, time, block, covariate names
        let mut headers = vec!["id", "time", "block"];
        headers.extend(covariate_names.iter().map(|s| s.as_str()));
        writer.write_record(&headers)?;

        // Write the data rows
        for subject in self.data.subjects() {
            for occasion in subject.occasions() {
                let cov = occasion.covariates();
                let covmap = cov.covariates();

                for event in occasion.iter() {
                    let time = match event {
                        Event::Bolus(bolus) => bolus.time(),
                        Event::Infusion(infusion) => infusion.time(),
                        Event::Observation(observation) => observation.time(),
                    };

                    let mut row: Vec<String> = Vec::new();
                    row.push(subject.id().clone());
                    row.push(time.to_string());
                    row.push(occasion.index().to_string());

                    // Add covariate values to the row
                    for cov_name in &covariate_names {
                        if let Some(cov) = covmap.get(cov_name) {
                            if let Ok(value) = cov.interpolate(time) {
                                row.push(value.to_string());
                            } else {
                                row.push(String::new());
                            }
                        } else {
                            row.push(String::new());
                        }
                    }

                    writer.write_record(&row)?;
                }
            }
        }

        writer.flush()?;
        tracing::debug!(
            "Covariates written to {:?}",
            &outputfile.get_relative_path()
        );
        Ok(())
    }
}

/// An [NPCycle] object contains the summary of a cycle
/// It holds the following information:
/// - `cycle`: The cycle number
/// - `objf`: The objective function value
/// - `error_models`: The error models, including the assay noise parameter (gamma or lambda)
/// - `theta`: The support points
/// - `nspp`: The number of support points
/// - `delta_objf`: The change in objective function value from the last cycle
/// - `status`: The [Status] of the algorithm, e.g. whether it has converged
#[derive(Debug, Clone)]
pub struct NPCycle {
    pub cycle: usize,
    pub objf: f64,
    pub error_models: ErrorModels,
    pub theta: Theta,
    pub nspp: usize,
    pub delta_objf: f64,
    pub status: Status,
}

impl NPCycle {
    pub fn new(
        cycle: usize,
        objf: f64,
        error_models: ErrorModels,
        theta: Theta,
        nspp: usize,
        delta_objf: f64,
        status: Status,
    ) -> Self {
        Self {
            cycle,
            objf,
            error_models,
            theta,
            nspp,
            delta_objf,
            status,
        }
    }

    pub fn placeholder() -> Self {
        Self {
            cycle: 0,
            objf: 0.0,
            error_models: ErrorModels::default(),
            theta: Theta::new(),
            nspp: 0,
            delta_objf: 0.0,
            status: Status::Starting,
        }
    }
}

/// This holds a vector of [NPCycle] objects to provide a more detailed log
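///
/// # Example (sketch)
///
/// ```ignore
/// // Illustrative only: `cycle` is an [NPCycle] produced during a run.
/// let mut log = CycleLog::new();
/// log.push(cycle);
/// log.write(&settings)?;
/// ```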
#[derive(Debug, Clone)]
pub struct CycleLog {
    pub cycles: Vec<NPCycle>,
}

impl CycleLog {
    pub fn new() -> Self {
        Self { cycles: Vec::new() }
    }

    pub fn push(&mut self, cycle: NPCycle) {
        self.cycles.push(cycle);
    }

    pub fn write(&self, settings: &Settings) -> Result<()> {
        tracing::debug!("Writing cycles...");
        let outputfile = OutputFile::new(&settings.output().path, "cycles.csv")?;
        let mut writer = WriterBuilder::new()
            .has_headers(false)
            .from_writer(&outputfile.file);

        // Write headers
        writer.write_field("cycle")?;
        writer.write_field("converged")?;
        writer.write_field("status")?;
        writer.write_field("neg2ll")?;
        writer.write_field("nspp")?;
        if let Some(first_cycle) = self.cycles.first() {
            first_cycle.error_models.iter().try_for_each(
                |(outeq, errmod): (usize, &ErrorModel)| -> Result<(), csv::Error> {
                    match errmod {
                        ErrorModel::Additive { .. } => {
                            writer.write_field(format!("gamlam.{}", outeq))?;
                        }
                        ErrorModel::Proportional { .. } => {
                            writer.write_field(format!("gamlam.{}", outeq))?;
                        }
                        ErrorModel::None => {}
                    }
                    Ok(())
                },
            )?;
        }

        let parameter_names = settings.parameters().names();
        for param_name in &parameter_names {
            writer.write_field(format!("{}.mean", param_name))?;
            writer.write_field(format!("{}.median", param_name))?;
            writer.write_field(format!("{}.sd", param_name))?;
        }

        writer.write_record(None::<&[u8]>)?;

        for cycle in &self.cycles {
            writer.write_field(format!("{}", cycle.cycle))?;
            writer.write_field(format!("{}", cycle.status == Status::Converged))?;
            writer.write_field(format!("{}", cycle.status))?;
            writer.write_field(format!("{}", cycle.objf))?;
            writer.write_field(format!("{}", cycle.theta.nspp()))?;

            // Write the error models
            cycle.error_models.iter().try_for_each(
                |(_, errmod): (usize, &ErrorModel)| -> Result<()> {
                    match errmod {
                        ErrorModel::Additive {
                            lambda: _,
                            poly: _,
                            lloq: _,
                        } => {
                            writer.write_field(format!("{:.5}", errmod.factor()?))?;
                        }
                        ErrorModel::Proportional {
                            gamma: _,
                            poly: _,
                            lloq: _,
                        } => {
                            writer.write_field(format!("{:.5}", errmod.factor()?))?;
                        }
                        ErrorModel::None => {}
                    }
                    Ok(())
                },
            )?;

            for param in cycle.theta.matrix().col_iter() {
                let param_values: Vec<f64> = param.iter().cloned().collect();

                let mean: f64 = param_values.iter().sum::<f64>() / param_values.len() as f64;
                let median = median(param_values.clone());
                // Sample standard deviation (the header for this column is `<param>.sd`)
                let std = (param_values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
                    / (param_values.len() as f64 - 1.0))
                    .sqrt();

                writer.write_field(format!("{}", mean))?;
                writer.write_field(format!("{}", median))?;
                writer.write_field(format!("{}", std))?;
            }
            writer.write_record(None::<&[u8]>)?;
        }
        writer.flush()?;
        tracing::debug!("Cycles written to {:?}", &outputfile.get_relative_path());
        Ok(())
    }
}

impl Default for CycleLog {
    fn default() -> Self {
        Self::new()
    }
}

/// Calculates the posterior probabilities for each support point given the weights
///
/// The shape is the same as [Psi], and thus subjects are the rows and support points are the columns.
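///
/// For subject `i` and support point `j`, the returned probability is
/// `psi[i, j] * w[j] / Σ_k (psi[i, k] * w[k])`, i.e. each row of [Psi] is
/// weighted by `w` and normalized so that it sums to one.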
pub fn posterior(psi: &Psi, w: &Col<f64>) -> Result<Mat<f64>> {
    if psi.matrix().ncols() != w.nrows() {
        bail!(
            "Number of columns in psi ({}) and number of weights ({}) do not match.",
            psi.matrix().ncols(),
            w.nrows()
        );
    }

    let psi_matrix = psi.matrix();
    let py = psi_matrix * w;

    let posterior = Mat::from_fn(psi_matrix.nrows(), psi_matrix.ncols(), |i, j| {
        psi_matrix.get(i, j) * w.get(j) / py.get(i)
    });

    Ok(posterior)
}

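/// Returns the median of the supplied values.
///
/// The values are sorted; for an even number of elements the two middle values
/// are averaged. For example, `median(vec![1.0, 3.0, 2.0])` returns `2.0` (see
/// the unit tests below).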
pub fn median(mut data: Vec<f64>) -> f64 {
    data.sort_by(|a, b| a.partial_cmp(b).unwrap());

    let size = data.len();
    match size {
        even if even % 2 == 0 => {
            let fst = data.get(even / 2 - 1).unwrap();
            let snd = data.get(even / 2).unwrap();
            (fst + snd) / 2.0
        }
        odd => *data.get(odd / 2_usize).unwrap(),
    }
}

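/// Returns the weighted median of `data`, given non-negative `weights`.
///
/// The pairs are sorted by value and the cumulative weight is accumulated until
/// it reaches half of the total weight; if it lands exactly on the half-way
/// point, the value is averaged with the next one. For example,
/// `weighted_median(&vec![1.0, 2.0, 3.0], &vec![0.2, 0.5, 0.3])` returns `2.0`
/// (see the unit tests below).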
fn weighted_median(data: &Vec<f64>, weights: &Vec<f64>) -> f64 {
    // Ensure the data and weights arrays have the same length
    assert_eq!(
        data.len(),
        weights.len(),
        "The length of data and weights must be the same"
    );
    assert!(
        weights.iter().all(|&x| x >= 0.0),
        "Weights must be non-negative, weights: {:?}",
        weights
    );

    // Create a vector of tuples (data, weight)
    let mut weighted_data: Vec<(f64, f64)> = data
        .iter()
        .zip(weights.iter())
        .map(|(&d, &w)| (d, w))
        .collect();

    // Sort the vector by the data values
    weighted_data.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

    // Calculate the cumulative sum of weights
    let total_weight: f64 = weights.iter().sum();
    let mut cumulative_sum = 0.0;

    for (i, &(_, weight)) in weighted_data.iter().enumerate() {
        cumulative_sum += weight;

        if cumulative_sum == total_weight / 2.0 {
            // If the cumulative sum equals half the total weight, average this value with the next
            if i + 1 < weighted_data.len() {
                return (weighted_data[i].0 + weighted_data[i + 1].0) / 2.0;
            } else {
                return weighted_data[i].0;
            }
        } else if cumulative_sum > total_weight / 2.0 {
            return weighted_data[i].0;
        }
    }

    unreachable!("The function should have returned a value before reaching this point.");
}

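/// Calculates the weighted mean and weighted median of each parameter (column of
/// `theta`) across the population, using the support point probabilities `w` as
/// weights.
///
/// If `w` is empty, uniform weights of `1/n` are used and a warning is logged.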
pub fn population_mean_median(
    theta: &Array2<f64>,
    w: &Array1<f64>,
) -> Result<(Array1<f64>, Array1<f64>)> {
    let w = if w.is_empty() {
        tracing::warn!("w.len() == 0, setting all weights to 1/n");
        Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
    } else {
        w.clone()
    };
    // Check for compatible sizes
    if theta.nrows() != w.len() {
        bail!(
            "Number of support points and number of weights do not match. Theta: {}, w: {}",
            theta.nrows(),
            w.len()
        );
    }

    let mut mean = Array1::zeros(theta.ncols());
    let mut median = Array1::zeros(theta.ncols());

    for (i, (mn, mdn)) in mean.iter_mut().zip(&mut median).enumerate() {
        // Calculate the weighted mean
        let col = theta.column(i).to_owned() * w.to_owned();
        *mn = col.sum();

        // Calculate the median
        let ct = theta.column(i);
        let mut params = vec![];
        let mut weights = vec![];
        for (ti, wi) in ct.iter().zip(w.clone()) {
            params.push(*ti);
            weights.push(wi);
        }

        *mdn = weighted_median(&params, &weights);
    }

    Ok((mean, median))
}

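/// Calculates the posterior mean and median of each parameter for every subject.
///
/// Each row of `psi` is multiplied element-wise by `w` and normalized to sum to
/// one, giving the posterior probability of each support point for that subject.
/// These probabilities are then used as weights for the per-parameter mean and
/// weighted median. Both returned matrices have one row per subject and one
/// column per parameter.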
pub fn posterior_mean_median(
    theta: &Array2<f64>,
    psi: &Array2<f64>,
    w: &Array1<f64>,
) -> Result<(Array2<f64>, Array2<f64>)> {
    let mut mean = Array2::zeros((0, theta.ncols()));
    let mut median = Array2::zeros((0, theta.ncols()));

    let w = if w.is_empty() {
        tracing::warn!("w is empty, setting all weights to 1/n");
        Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
    } else {
        w.clone()
    };

    // Check for compatible sizes
    if theta.nrows() != w.len() || theta.nrows() != psi.ncols() || psi.ncols() != w.len() {
        bail!("Number of support points and number of weights do not match, theta.nrows(): {}, w.len(): {}, psi.ncols(): {}", theta.nrows(), w.len(), psi.ncols());
    }

    // Normalize psi to get probabilities of each spp for each id
    let mut psi_norm: Array2<f64> = Array2::zeros((0, psi.ncols()));
    for (i, row) in psi.axis_iter(Axis(0)).enumerate() {
        let row_w = row.to_owned() * w.to_owned();
        let row_sum = row_w.sum();
        let row_norm = if row_sum == 0.0 {
            tracing::warn!("Sum of row {} of psi is 0.0, setting that row to 1/n", i);
            Array1::from_elem(psi.ncols(), 1.0 / psi.ncols() as f64)
        } else {
            &row_w / row_sum
        };
        psi_norm.push_row(row_norm.view())?;
    }
    if psi_norm.iter().any(|&x| x.is_nan()) {
        dbg!(&psi);
        bail!("NaN values found in psi_norm");
    };

    // Transpose normalized psi to get ID (col) by prob (row)
    // let psi_norm_transposed = psi_norm.t();

    // For each subject..
    for probs in psi_norm.axis_iter(Axis(0)) {
        let mut post_mean: Vec<f64> = Vec::new();
        let mut post_median: Vec<f64> = Vec::new();

        // For each parameter
        for pars in theta.axis_iter(Axis(1)) {
            // Calculate the mean
            let weighted_par = &probs * &pars;
            let the_mean = weighted_par.sum();
            post_mean.push(the_mean);

            // Calculate the median
            let median = weighted_median(&pars.to_vec(), &probs.to_vec());
            post_median.push(median);
        }

        mean.push_row(Array::from(post_mean.clone()).view())?;
        median.push_row(Array::from(post_median.clone()).view())?;
    }

    Ok((mean, median))
}

/// Contains all the necessary information about an output file
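///
/// # Example (sketch)
///
/// ```ignore
/// // Illustrative only: "outputs" is an assumed folder name. The file is created
/// // (or truncated) and parent directories are created as needed.
/// let outputfile = OutputFile::new("outputs", "theta.csv")?;
/// let mut writer = csv::WriterBuilder::new().from_writer(&outputfile.file);
/// ```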
#[derive(Debug)]
pub struct OutputFile {
    pub file: File,
    pub relative_path: PathBuf,
}

impl OutputFile {
    pub fn new(folder: &str, file_name: &str) -> Result<Self> {
        let relative_path = Path::new(&folder).join(file_name);

        if let Some(parent) = relative_path.parent() {
            create_dir_all(parent)
                .with_context(|| format!("Failed to create directories for {:?}", parent))?;
        }

        let file = OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .open(&relative_path)
            .with_context(|| format!("Failed to open file: {:?}", relative_path))?;

        Ok(OutputFile {
            file,
            relative_path,
        })
    }

    pub fn get_relative_path(&self) -> &Path {
        &self.relative_path
    }
}

#[cfg(test)]
mod tests {
    use super::median;

    #[test]
    fn test_median_odd() {
        let data = vec![1.0, 3.0, 2.0];
        assert_eq!(median(data), 2.0);
    }

    #[test]
    fn test_median_even() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        assert_eq!(median(data), 2.5);
    }

    #[test]
    fn test_median_single() {
        let data = vec![42.0];
        assert_eq!(median(data), 42.0);
    }

    #[test]
    fn test_median_sorted() {
        let data = vec![5.0, 10.0, 15.0, 20.0, 25.0];
        assert_eq!(median(data), 15.0);
    }

    #[test]
    fn test_median_unsorted() {
        let data = vec![10.0, 30.0, 20.0, 50.0, 40.0];
        assert_eq!(median(data), 30.0);
    }

    #[test]
    fn test_median_with_duplicates() {
        let data = vec![1.0, 2.0, 2.0, 3.0, 4.0];
        assert_eq!(median(data), 2.0);
    }

    use super::weighted_median;

    #[test]
    fn test_weighted_median_simple() {
        let data = vec![1.0, 2.0, 3.0];
        let weights = vec![0.2, 0.5, 0.3];
        assert_eq!(weighted_median(&data, &weights), 2.0);
    }

    #[test]
    fn test_weighted_median_even_weights() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let weights = vec![0.25, 0.25, 0.25, 0.25];
        assert_eq!(weighted_median(&data, &weights), 2.5);
    }

    #[test]
    fn test_weighted_median_single_element() {
        let data = vec![42.0];
        let weights = vec![1.0];
        assert_eq!(weighted_median(&data, &weights), 42.0);
    }

    #[test]
    #[should_panic(expected = "The length of data and weights must be the same")]
    fn test_weighted_median_mismatched_lengths() {
        let data = vec![1.0, 2.0, 3.0];
        let weights = vec![0.1, 0.2];
        weighted_median(&data, &weights);
    }

    #[test]
    fn test_weighted_median_all_same_elements() {
        let data = vec![5.0, 5.0, 5.0, 5.0];
        let weights = vec![0.1, 0.2, 0.3, 0.4];
        assert_eq!(weighted_median(&data, &weights), 5.0);
    }

    #[test]
    #[should_panic(expected = "Weights must be non-negative")]
    fn test_weighted_median_negative_weights() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let weights = vec![0.2, -0.5, 0.5, 0.8];
        assert_eq!(weighted_median(&data, &weights), 4.0);
    }

    #[test]
    fn test_weighted_median_unsorted_data() {
        let data = vec![3.0, 1.0, 4.0, 2.0];
        let weights = vec![0.1, 0.3, 0.4, 0.2];
        assert_eq!(weighted_median(&data, &weights), 2.5);
    }

    #[test]
    fn test_weighted_median_with_zero_weights() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let weights = vec![0.0, 0.0, 1.0, 0.0];
        assert_eq!(weighted_median(&data, &weights), 3.0);
    }
}