pmcore/routines/
output.rs

1use crate::algorithms::Status;
2use crate::prelude::*;
3use crate::routines::settings::Settings;
4use crate::structs::psi::Psi;
5use crate::structs::theta::Theta;
6use anyhow::{bail, Context, Result};
7use csv::WriterBuilder;
8use faer::linalg::zip::IntoView;
9use faer::{Col, Mat};
10use faer_ext::IntoNdarray;
11use ndarray::{Array, Array1, Array2, Axis};
12use pharmsol::prelude::data::*;
13use pharmsol::prelude::simulator::Equation;
14use serde::Serialize;
15use std::fs::{create_dir_all, File, OpenOptions};
16use std::path::{Path, PathBuf};
17
/// Defines the result objects from an NPAG run
/// An [NPResult] contains the necessary information to generate predictions and summary statistics
#[derive(Debug)]
pub struct NPResult<E: Equation> {
    /// Model equation used to simulate subjects when predictions are written
    equation: E,
    /// Data set whose subjects are iterated when the output files are written
    data: Data,
    /// Population support points (one row per support point)
    theta: Theta,
    /// The [Psi] structure, combined with `w` to compute posterior probabilities
    psi: Psi,
    /// Weights (probabilities) of the support points
    w: Col<f64>,
    /// Objective function value
    objf: f64,
    /// Cycle count, as reported by `cycles()`
    cycles: usize,
    /// Status of the run; `converged()` compares it against `Status::Converged`
    status: Status,
    /// Parameter names read from the settings, used as the header of `theta.csv`
    par_names: Vec<String>,
    /// Settings that supply the output path and the prediction options
    settings: Settings,
    /// Per-cycle log, written to `cycles.csv` by `write_outputs`
    cyclelog: CycleLog,
}
34
35#[allow(clippy::too_many_arguments)]
36impl<E: Equation> NPResult<E> {
37    /// Create a new NPResult object
38    pub fn new(
39        equation: E,
40        data: Data,
41        theta: Theta,
42        psi: Psi,
43        w: Col<f64>,
44        objf: f64,
45        cycles: usize,
46        status: Status,
47        settings: Settings,
48        cyclelog: CycleLog,
49    ) -> Self {
50        // TODO: Add support for fixed and constant parameters
51
52        let par_names = settings.parameters().names();
53
54        Self {
55            equation,
56            data,
57            theta,
58            psi,
59            w,
60            objf,
61            cycles,
62            status,
63            par_names,
64            settings,
65            cyclelog,
66        }
67    }
68
    /// Get the cycle count stored in this result
    pub fn cycles(&self) -> usize {
        self.cycles
    }
72
    /// Get the objective function value stored in this result
    pub fn objf(&self) -> f64 {
        self.objf
    }
76
    /// Returns `true` when the stored status equals `Status::Converged`
    pub fn converged(&self) -> bool {
        self.status == Status::Converged
    }
80
    /// Get the [Theta] structure, which holds the population support points
    pub fn get_theta(&self) -> &Theta {
        &self.theta
    }
84
    /// Get a reference to the [Psi] structure stored in this result
    pub fn psi(&self) -> &Psi {
        &self.psi
    }
89
    /// Get the weights (probabilities) of the support points
    ///
    /// The returned column is borrowed from this result.
    pub fn w(&self) -> &Col<f64> {
        &self.w
    }
94
95    pub fn write_outputs(&self) -> Result<()> {
96        if self.settings.output().write {
97            self.settings.write()?;
98            let idelta: f64 = self.settings.predictions().idelta;
99            let tad = self.settings.predictions().tad;
100            self.cyclelog.write(&self.settings)?;
101            self.write_obs().context("Failed to write observations")?;
102            self.write_theta().context("Failed to write theta")?;
103            self.write_obspred()
104                .context("Failed to write observed-predicted file")?;
105            self.write_pred(idelta, tad)
106                .context("Failed to write predictions")?;
107            self.write_covs().context("Failed to write covariates")?;
108            self.write_posterior()
109                .context("Failed to write posterior")?;
110        }
111        Ok(())
112    }
113
114    /// Writes the observations and predictions to a single file
115    pub fn write_obspred(&self) -> Result<()> {
116        tracing::debug!("Writing observations and predictions...");
117
118        #[derive(Debug, Clone, Serialize)]
119        struct Row {
120            id: String,
121            time: f64,
122            outeq: usize,
123            block: usize,
124            obs: f64,
125            pop_mean: f64,
126            pop_median: f64,
127            post_mean: f64,
128            post_median: f64,
129        }
130
131        let theta: Array2<f64> = self
132            .theta
133            .matrix()
134            .clone()
135            .as_mut()
136            .into_ndarray()
137            .to_owned();
138        let w: Array1<f64> = self.w.clone().into_view().iter().cloned().collect();
139        let psi: Array2<f64> = self.psi.matrix().as_ref().into_ndarray().to_owned();
140
141        let (post_mean, post_median) = posterior_mean_median(&theta, &psi, &w)
142            .context("Failed to calculate posterior mean and median")?;
143
144        let (pop_mean, pop_median) = population_mean_median(&theta, &w)
145            .context("Failed to calculate posterior mean and median")?;
146
147        let subjects = self.data.get_subjects();
148        if subjects.len() != post_mean.nrows() {
149            bail!(
150                "Number of subjects: {} and number of posterior means: {} do not match",
151                subjects.len(),
152                post_mean.nrows()
153            );
154        }
155
156        let outputfile = OutputFile::new(&self.settings.output().path, "op.csv")?;
157        let mut writer = WriterBuilder::new()
158            .has_headers(true)
159            .from_writer(&outputfile.file);
160
161        for (i, subject) in subjects.iter().enumerate() {
162            for occasion in subject.occasions() {
163                let id = subject.id();
164                let occ = occasion.index();
165
166                let subject = Subject::from_occasions(id.clone(), vec![occasion.clone()]);
167
168                // Population predictions
169                let pop_mean_pred = self
170                    .equation
171                    .simulate_subject(&subject, &pop_mean.to_vec(), None)?
172                    .0
173                    .get_predictions()
174                    .clone();
175
176                let pop_median_pred = self
177                    .equation
178                    .simulate_subject(&subject, &pop_median.to_vec(), None)?
179                    .0
180                    .get_predictions()
181                    .clone();
182
183                // Posterior predictions
184                let post_mean_spp: Vec<f64> = post_mean.row(i).to_vec();
185                let post_mean_pred = self
186                    .equation
187                    .simulate_subject(&subject, &post_mean_spp, None)?
188                    .0
189                    .get_predictions()
190                    .clone();
191                let post_median_spp: Vec<f64> = post_median.row(i).to_vec();
192                let post_median_pred = self
193                    .equation
194                    .simulate_subject(&subject, &post_median_spp, None)?
195                    .0
196                    .get_predictions()
197                    .clone();
198                assert_eq!(
199                    pop_mean_pred.len(),
200                    pop_median_pred.len(),
201                    "The number of predictions do not match (pop_mean vs pop_median)"
202                );
203
204                assert_eq!(
205                    post_mean_pred.len(),
206                    post_median_pred.len(),
207                    "The number of predictions do not match (post_mean vs post_median)"
208                );
209
210                assert_eq!(
211                    pop_mean_pred.len(),
212                    post_mean_pred.len(),
213                    "The number of predictions do not match (pop_mean vs post_mean)"
214                );
215
216                for (((pop_mean_pred, pop_median_pred), post_mean_pred), post_median_pred) in
217                    pop_mean_pred
218                        .iter()
219                        .zip(pop_median_pred.iter())
220                        .zip(post_mean_pred.iter())
221                        .zip(post_median_pred.iter())
222                {
223                    let row = Row {
224                        id: id.clone(),
225                        time: pop_mean_pred.time(),
226                        outeq: pop_mean_pred.outeq(),
227                        block: occ,
228                        obs: pop_mean_pred.observation(),
229                        pop_mean: pop_mean_pred.prediction(),
230                        pop_median: pop_median_pred.prediction(),
231                        post_mean: post_mean_pred.prediction(),
232                        post_median: post_median_pred.prediction(),
233                    };
234                    writer.serialize(row)?;
235                }
236            }
237        }
238        writer.flush()?;
239        tracing::info!(
240            "Observations with predictions written to {:?}",
241            &outputfile.get_relative_path()
242        );
243        Ok(())
244    }
245
246    /// Writes theta, which contains the population support points and their associated probabilities
247    /// Each row is one support point, the last column being probability
248    pub fn write_theta(&self) -> Result<()> {
249        tracing::debug!("Writing population parameter distribution...");
250
251        let theta = &self.theta;
252        let w: Vec<f64> = self.w.clone().into_view().iter().cloned().collect();
253        /* let w = if self.w.len() != theta.matrix().nrows() {
254                   tracing::warn!("Number of weights and number of support points do not match. Setting all weights to 0.");
255                   Array1::zeros(theta.matrix().nrows())
256               } else {
257                   self.w.clone()
258               };
259        */
260        let outputfile = OutputFile::new(&self.settings.output().path, "theta.csv")
261            .context("Failed to create output file for theta")?;
262
263        let mut writer = WriterBuilder::new()
264            .has_headers(true)
265            .from_writer(&outputfile.file);
266
267        // Create the headers
268        let mut theta_header = self.par_names.clone();
269        theta_header.push("prob".to_string());
270        writer.write_record(&theta_header)?;
271
272        // Write contents
273        for (theta_row, &w_val) in theta.matrix().row_iter().zip(w.iter()) {
274            let mut row: Vec<String> = theta_row.iter().map(|&val| val.to_string()).collect();
275            row.push(w_val.to_string());
276            writer.write_record(&row)?;
277        }
278        writer.flush()?;
279        tracing::info!(
280            "Population parameter distribution written to {:?}",
281            &outputfile.get_relative_path()
282        );
283        Ok(())
284    }
285
286    /// Writes the posterior support points for each individual
287    pub fn write_posterior(&self) -> Result<()> {
288        tracing::debug!("Writing posterior parameter probabilities...");
289        let theta = &self.theta;
290        let w = &self.w;
291        let psi = &self.psi;
292
293        // Calculate the posterior probabilities
294        let posterior = posterior(psi, w)?;
295
296        // Create the output folder if it doesn't exist
297        let outputfile = match OutputFile::new(&self.settings.output().path, "posterior.csv") {
298            Ok(of) => of,
299            Err(e) => {
300                tracing::error!("Failed to create output file: {}", e);
301                return Err(e.context("Failed to create output file"));
302            }
303        };
304
305        // Create a new writer
306        let mut writer = WriterBuilder::new()
307            .has_headers(true)
308            .from_writer(&outputfile.file);
309
310        // Create the headers
311        writer.write_field("id")?;
312        writer.write_field("point")?;
313        theta.param_names().iter().for_each(|name| {
314            writer.write_field(name).unwrap();
315        });
316        writer.write_field("prob")?;
317        writer.write_record(None::<&[u8]>)?;
318
319        // Write contents
320        let subjects = self.data.get_subjects();
321        posterior.row_iter().enumerate().for_each(|(i, row)| {
322            let subject = subjects.get(i).unwrap();
323            let id = subject.id();
324
325            row.iter().enumerate().for_each(|(spp, prob)| {
326                writer.write_field(id.clone()).unwrap();
327                writer.write_field(i.to_string()).unwrap();
328
329                theta.matrix().row(spp).iter().for_each(|val| {
330                    writer.write_field(val.to_string()).unwrap();
331                });
332
333                writer.write_field(prob.to_string()).unwrap();
334                writer.write_record(None::<&[u8]>).unwrap();
335            });
336        });
337
338        writer.flush()?;
339        tracing::info!(
340            "Posterior parameters written to {:?}",
341            &outputfile.get_relative_path()
342        );
343
344        Ok(())
345    }
346
347    /// Write the observations, which is the reformatted input data
348    pub fn write_obs(&self) -> Result<()> {
349        tracing::debug!("Writing observations...");
350        let outputfile = OutputFile::new(&self.settings.output().path, "obs.csv")?;
351        write_pmetrics_observations(&self.data, &outputfile.file)?;
352        tracing::info!(
353            "Observations written to {:?}",
354            &outputfile.get_relative_path()
355        );
356        Ok(())
357    }
358
359    /// Writes the predictions
360    pub fn write_pred(&self, idelta: f64, tad: f64) -> Result<()> {
361        tracing::debug!("Writing predictions...");
362        let data = self.data.expand(idelta, tad);
363
364        let theta: Array2<f64> = self
365            .theta
366            .matrix()
367            .clone()
368            .as_mut()
369            .into_ndarray()
370            .to_owned();
371        let w: Array1<f64> = self.w.clone().into_view().iter().cloned().collect();
372        let psi: Array2<f64> = self.psi.matrix().as_ref().into_ndarray().to_owned();
373
374        let (post_mean, post_median) = posterior_mean_median(&theta, &psi, &w)
375            .context("Failed to calculate posterior mean and median")?;
376
377        let (pop_mean, pop_median) = population_mean_median(&theta, &w)
378            .context("Failed to calculate population mean and median")?;
379
380        let subjects = data.get_subjects();
381        if subjects.len() != post_mean.nrows() {
382            bail!("Number of subjects and number of posterior means do not match");
383        }
384
385        let outputfile = OutputFile::new(&self.settings.output().path, "pred.csv")?;
386        let mut writer = WriterBuilder::new()
387            .has_headers(true)
388            .from_writer(&outputfile.file);
389
390        #[derive(Debug, Clone, Serialize)]
391        struct Row {
392            id: String,
393            time: f64,
394            outeq: usize,
395            block: usize,
396            pop_mean: f64,
397            pop_median: f64,
398            post_mean: f64,
399            post_median: f64,
400        }
401
402        for (i, subject) in subjects.iter().enumerate() {
403            for occasion in subject.occasions() {
404                let id = subject.id();
405                let block = occasion.index();
406
407                // Create a new subject with only the current occasion
408                let subject = Subject::from_occasions(id.clone(), vec![occasion.clone()]);
409
410                // Population predictions
411                let pop_mean_pred = self
412                    .equation
413                    .simulate_subject(&subject, &pop_mean.to_vec(), None)?
414                    .0
415                    .get_predictions()
416                    .clone();
417                let pop_median_pred = self
418                    .equation
419                    .simulate_subject(&subject, &pop_median.to_vec(), None)?
420                    .0
421                    .get_predictions()
422                    .clone();
423
424                // Posterior predictions
425                let post_mean_spp: Vec<f64> = post_mean.row(i).to_vec();
426                let post_mean_pred = self
427                    .equation
428                    .simulate_subject(&subject, &post_mean_spp, None)?
429                    .0
430                    .get_predictions()
431                    .clone();
432                let post_median_spp: Vec<f64> = post_median.row(i).to_vec();
433                let post_median_pred = self
434                    .equation
435                    .simulate_subject(&subject, &post_median_spp, None)?
436                    .0
437                    .get_predictions()
438                    .clone();
439
440                // Write predictions for each time point
441                for (((pop_mean, pop_median), post_mean), post_median) in pop_mean_pred
442                    .iter()
443                    .zip(pop_median_pred.iter())
444                    .zip(post_mean_pred.iter())
445                    .zip(post_median_pred.iter())
446                {
447                    let row = Row {
448                        id: id.clone(),
449                        time: pop_mean.time(),
450                        outeq: pop_mean.outeq(),
451                        block,
452                        pop_mean: pop_mean.prediction(),
453                        pop_median: pop_median.prediction(),
454                        post_mean: post_mean.prediction(),
455                        post_median: post_median.prediction(),
456                    };
457                    writer.serialize(row)?;
458                }
459            }
460        }
461        writer.flush()?;
462        tracing::info!(
463            "Predictions written to {:?}",
464            &outputfile.get_relative_path()
465        );
466        Ok(())
467    }
468
469    /// Writes the covariates
470    pub fn write_covs(&self) -> Result<()> {
471        tracing::debug!("Writing covariates...");
472        let outputfile = OutputFile::new(&self.settings.output().path, "covs.csv")?;
473        let mut writer = WriterBuilder::new()
474            .has_headers(true)
475            .from_writer(&outputfile.file);
476
477        // Collect all unique covariate names
478        let mut covariate_names = std::collections::HashSet::new();
479        for subject in self.data.get_subjects() {
480            for occasion in subject.occasions() {
481                let cov = occasion.covariates();
482                let covmap = cov.covariates();
483                for cov_name in covmap.keys() {
484                    covariate_names.insert(cov_name.clone());
485                }
486            }
487        }
488        let mut covariate_names: Vec<String> = covariate_names.into_iter().collect();
489        covariate_names.sort(); // Ensure consistent order
490
491        // Write the header row: id, time, block, covariate names
492        let mut headers = vec!["id", "time", "block"];
493        headers.extend(covariate_names.iter().map(|s| s.as_str()));
494        writer.write_record(&headers)?;
495
496        // Write the data rows
497        for subject in self.data.get_subjects() {
498            for occasion in subject.occasions() {
499                let cov = occasion.covariates();
500                let covmap = cov.covariates();
501
502                for event in occasion.get_events(&None, &None, false) {
503                    let time = match event {
504                        Event::Bolus(bolus) => bolus.time(),
505                        Event::Infusion(infusion) => infusion.time(),
506                        Event::Observation(observation) => observation.time(),
507                    };
508
509                    let mut row: Vec<String> = Vec::new();
510                    row.push(subject.id().clone());
511                    row.push(time.to_string());
512                    row.push(occasion.index().to_string());
513
514                    // Add covariate values to the row
515                    for cov_name in &covariate_names {
516                        if let Some(cov) = covmap.get(cov_name) {
517                            if let Some(value) = cov.interpolate(time) {
518                                row.push(value.to_string());
519                            } else {
520                                row.push(String::new());
521                            }
522                        } else {
523                            row.push(String::new());
524                        }
525                    }
526
527                    writer.write_record(&row)?;
528                }
529            }
530        }
531
532        writer.flush()?;
533        tracing::info!(
534            "Covariates written to {:?}",
535            &outputfile.get_relative_path()
536        );
537        Ok(())
538    }
539}
540
/// An [NPCycle] object contains the summary of a cycle
/// It holds the following information:
/// - `cycle`: The cycle number
/// - `objf`: The objective function value
/// - `error_models`: The error models of the cycle; their scalars are written as the `gamlam` columns of the cycle log
/// - `theta`: The support points
/// - `nspp`: The number of support points
/// - `delta_objf`: The change in objective function value from last cycle
/// - `status`: The [Status] of the algorithm at this cycle, e.g. whether it has converged
#[derive(Debug, Clone)]
pub struct NPCycle {
    pub cycle: usize,
    pub objf: f64,
    pub error_models: ErrorModels,
    pub theta: Theta,
    pub nspp: usize,
    pub delta_objf: f64,
    pub status: Status,
}
560
561impl NPCycle {
562    pub fn new(
563        cycle: usize,
564        objf: f64,
565        error_models: ErrorModels,
566        theta: Theta,
567        nspp: usize,
568        delta_objf: f64,
569        status: Status,
570    ) -> Self {
571        Self {
572            cycle,
573            objf,
574            error_models,
575            theta,
576            nspp,
577            delta_objf,
578            status,
579        }
580    }
581
582    pub fn placeholder() -> Self {
583        Self {
584            cycle: 0,
585            objf: 0.0,
586            error_models: ErrorModels::default(),
587            theta: Theta::new(),
588            nspp: 0,
589            delta_objf: 0.0,
590            status: Status::Starting,
591        }
592    }
593}
594
/// This holds a vector of [NPCycle] objects to provide a more detailed log
#[derive(Debug, Clone)]
pub struct CycleLog {
    /// The logged cycles, in the order in which they were pushed
    pub cycles: Vec<NPCycle>,
}
600
601impl CycleLog {
602    pub fn new() -> Self {
603        Self { cycles: Vec::new() }
604    }
605
606    pub fn push(&mut self, cycle: NPCycle) {
607        self.cycles.push(cycle);
608    }
609
610    pub fn write(&self, settings: &Settings) -> Result<()> {
611        tracing::debug!("Writing cycles...");
612        let outputfile = OutputFile::new(&settings.output().path, "cycles.csv")?;
613        let mut writer = WriterBuilder::new()
614            .has_headers(false)
615            .from_writer(&outputfile.file);
616
617        // Write headers
618        writer.write_field("cycle")?;
619        writer.write_field("converged")?;
620        writer.write_field("status")?;
621        writer.write_field("neg2ll")?;
622        writer.write_field("nspp")?;
623        if let Some(first_cycle) = self.cycles.first() {
624            first_cycle.error_models.iter().try_for_each(
625                |(outeq, errmod): (usize, &ErrorModel)| -> Result<(), csv::Error> {
626                    match errmod {
627                        ErrorModel::Additive { .. } => {
628                            writer.write_field(format!("gamlam.{}", outeq))?;
629                        }
630                        ErrorModel::Proportional { .. } => {
631                            writer.write_field(format!("gamlam.{}", outeq))?;
632                        }
633                        ErrorModel::None { .. } => {}
634                    }
635                    Ok(())
636                },
637            )?;
638        }
639
640        let parameter_names = settings.parameters().names();
641        for param_name in &parameter_names {
642            writer.write_field(format!("{}.mean", param_name))?;
643            writer.write_field(format!("{}.median", param_name))?;
644            writer.write_field(format!("{}.sd", param_name))?;
645        }
646
647        writer.write_record(None::<&[u8]>)?;
648
649        for cycle in &self.cycles {
650            writer.write_field(format!("{}", cycle.cycle))?;
651            writer.write_field(format!("{}", cycle.status == Status::Converged))?;
652            writer.write_field(format!("{}", cycle.status))?;
653            writer.write_field(format!("{}", cycle.objf))?;
654            writer
655                .write_field(format!("{}", cycle.theta.nspp()))
656                .unwrap();
657
658            // Write the error models
659            cycle.error_models.iter().try_for_each(
660                |(_, errmod): (usize, &ErrorModel)| -> Result<()> {
661                    match errmod {
662                        ErrorModel::Additive { .. } => {
663                            writer.write_field(format!("{:.5}", errmod.scalar()?))?;
664                        }
665                        ErrorModel::Proportional { .. } => {
666                            writer.write_field(format!("{:.5}", errmod.scalar()?))?;
667                        }
668                        ErrorModel::None { .. } => {}
669                    }
670                    Ok(())
671                },
672            )?;
673
674            for param in cycle.theta.matrix().col_iter() {
675                let param_values: Vec<f64> = param.iter().cloned().collect();
676
677                let mean: f64 = param_values.iter().sum::<f64>() / param_values.len() as f64;
678                let median = median(param_values.clone());
679                let std = param_values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
680                    / (param_values.len() as f64 - 1.0);
681
682                writer.write_field(format!("{}", mean))?;
683                writer.write_field(format!("{}", median))?;
684                writer.write_field(format!("{}", std))?;
685            }
686            writer.write_record(None::<&[u8]>)?;
687        }
688        writer.flush()?;
689        tracing::info!("Cycles written to {:?}", &outputfile.get_relative_path());
690        Ok(())
691    }
692}
693
/// The default [CycleLog] is the empty log returned by [CycleLog::new]
impl Default for CycleLog {
    fn default() -> Self {
        Self::new()
    }
}
699
700pub fn posterior(psi: &Psi, w: &Col<f64>) -> Result<Mat<f64>> {
701    if psi.matrix().ncols() != w.nrows() {
702        bail!(
703            "Number of rows in psi ({}) and number of weights ({}) do not match.",
704            psi.matrix().nrows(),
705            w.nrows()
706        );
707    }
708
709    let psi_matrix = psi.matrix();
710    let py = psi_matrix * w;
711
712    let posterior = Mat::from_fn(psi_matrix.nrows(), psi_matrix.ncols(), |i, j| {
713        psi_matrix.get(i, j) * w.get(j) / py.get(i)
714    });
715
716    Ok(posterior)
717}
718
/// Calculates the median of `data`
///
/// The values are sorted with a total order, so NaN values do not cause a
/// panic. For an even number of values the mean of the two middle values is
/// returned. An empty input yields `f64::NAN`.
pub fn median(data: Vec<f64>) -> f64 {
    // The vector is owned, so it can be sorted in place without cloning
    let mut sorted = data;
    if sorted.is_empty() {
        return f64::NAN;
    }
    sorted.sort_by(f64::total_cmp);

    let mid = sorted.len() / 2;
    if sorted.len() % 2 == 0 {
        (sorted[mid - 1] + sorted[mid]) / 2.0
    } else {
        sorted[mid]
    }
}
733
734fn weighted_median(data: &Array1<f64>, weights: &Array1<f64>) -> f64 {
735    // Ensure the data and weights arrays have the same length
736    assert_eq!(
737        data.len(),
738        weights.len(),
739        "The length of data and weights must be the same"
740    );
741    assert!(
742        weights.iter().all(|&x| x >= 0.0),
743        "Weights must be non-negative, weights: {:?}",
744        weights
745    );
746
747    // Create a vector of tuples (data, weight)
748    let mut weighted_data: Vec<(f64, f64)> = data
749        .iter()
750        .zip(weights.iter())
751        .map(|(&d, &w)| (d, w))
752        .collect();
753
754    // Sort the vector by the data values
755    weighted_data.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
756
757    // Calculate the cumulative sum of weights
758    let total_weight: f64 = weights.sum();
759    let mut cumulative_sum = 0.0;
760
761    for (i, &(_, weight)) in weighted_data.iter().enumerate() {
762        cumulative_sum += weight;
763
764        if cumulative_sum == total_weight / 2.0 {
765            // If the cumulative sum equals half the total weight, average this value with the next
766            if i + 1 < weighted_data.len() {
767                return (weighted_data[i].0 + weighted_data[i + 1].0) / 2.0;
768            } else {
769                return weighted_data[i].0;
770            }
771        } else if cumulative_sum > total_weight / 2.0 {
772            return weighted_data[i].0;
773        }
774    }
775
776    unreachable!("The function should have returned a value before reaching this point.");
777}
778
779pub fn population_mean_median(
780    theta: &Array2<f64>,
781    w: &Array1<f64>,
782) -> Result<(Array1<f64>, Array1<f64>)> {
783    let w = if w.is_empty() {
784        tracing::warn!("w.len() == 0, setting all weights to 1/n");
785        Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
786    } else {
787        w.clone()
788    };
789    // Check for compatible sizes
790    if theta.nrows() != w.len() {
791        bail!(
792            "Number of parameters and number of weights do not match. Theta: {}, w: {}",
793            theta.nrows(),
794            w.len()
795        );
796    }
797
798    let mut mean = Array1::zeros(theta.ncols());
799    let mut median = Array1::zeros(theta.ncols());
800
801    for (i, (mn, mdn)) in mean.iter_mut().zip(&mut median).enumerate() {
802        // Calculate the weighted mean
803        let col = theta.column(i).to_owned() * w.to_owned();
804        *mn = col.sum();
805
806        // Calculate the median
807        let ct = theta.column(i);
808        let mut params = vec![];
809        let mut weights = vec![];
810        for (ti, wi) in ct.iter().zip(w.clone()) {
811            params.push(*ti);
812            weights.push(wi);
813        }
814
815        *mdn = weighted_median(&Array::from(params), &Array::from(weights));
816    }
817
818    Ok((mean, median))
819}
820
821pub fn posterior_mean_median(
822    theta: &Array2<f64>,
823    psi: &Array2<f64>,
824    w: &Array1<f64>,
825) -> Result<(Array2<f64>, Array2<f64>)> {
826    let mut mean = Array2::zeros((0, theta.ncols()));
827    let mut median = Array2::zeros((0, theta.ncols()));
828
829    let w = if w.is_empty() {
830        tracing::warn!("w is empty, setting all weights to 1/n");
831        Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
832    } else {
833        w.clone()
834    };
835
836    // Check for compatible sizes
837    if theta.nrows() != w.len() || theta.nrows() != psi.ncols() || psi.ncols() != w.len() {
838        bail!("Number of parameters and number of weights do not match, theta.nrows(): {}, w.len(): {}, psi.ncols(): {}", theta.nrows(), w.len(), psi.ncols());
839    }
840
841    // Normalize psi to get probabilities of each spp for each id
842    let mut psi_norm: Array2<f64> = Array2::zeros((0, psi.ncols()));
843    for (i, row) in psi.axis_iter(Axis(0)).enumerate() {
844        let row_w = row.to_owned() * w.to_owned();
845        let row_sum = row_w.sum();
846        let row_norm = if row_sum == 0.0 {
847            tracing::warn!("Sum of row {} of psi is 0.0, setting that row to 1/n", i);
848            Array1::from_elem(psi.ncols(), 1.0 / psi.ncols() as f64)
849        } else {
850            &row_w / row_sum
851        };
852        psi_norm.push_row(row_norm.view())?;
853    }
854    if psi_norm.iter().any(|&x| x.is_nan()) {
855        dbg!(&psi);
856        bail!("NaN values found in psi_norm");
857    };
858
859    // Transpose normalized psi to get ID (col) by prob (row)
860    // let psi_norm_transposed = psi_norm.t();
861
862    // For each subject..
863    for probs in psi_norm.axis_iter(Axis(0)) {
864        let mut post_mean: Vec<f64> = Vec::new();
865        let mut post_median: Vec<f64> = Vec::new();
866
867        // For each parameter
868        for pars in theta.axis_iter(Axis(1)) {
869            // Calculate the mean
870            let weighted_par = &probs * &pars;
871            let the_mean = weighted_par.sum();
872            post_mean.push(the_mean);
873
874            // Calculate the median
875            let median = weighted_median(&pars.to_owned(), &probs.to_owned());
876            post_median.push(median);
877        }
878
879        mean.push_row(Array::from(post_mean.clone()).view())?;
880        median.push_row(Array::from(post_median.clone()).view())?;
881    }
882
883    Ok((mean, median))
884}
885
886/// Contains all the necessary information of an output file
887#[derive(Debug)]
888pub struct OutputFile {
889    pub file: File,
890    pub relative_path: PathBuf,
891}
892
893impl OutputFile {
894    pub fn new(folder: &str, file_name: &str) -> Result<Self> {
895        let relative_path = Path::new(&folder).join(file_name);
896
897        if let Some(parent) = relative_path.parent() {
898            create_dir_all(parent)
899                .with_context(|| format!("Failed to create directories for {:?}", parent))?;
900        }
901
902        let file = OpenOptions::new()
903            .write(true)
904            .create(true)
905            .truncate(true)
906            .open(&relative_path)
907            .with_context(|| format!("Failed to open file: {:?}", relative_path))?;
908
909        Ok(OutputFile {
910            file,
911            relative_path,
912        })
913    }
914
915    pub fn get_relative_path(&self) -> &Path {
916        &self.relative_path
917    }
918}
919
920pub fn write_pmetrics_observations(data: &Data, file: &std::fs::File) -> Result<()> {
921    let mut writer = WriterBuilder::new().has_headers(true).from_writer(file);
922
923    writer.write_record(["id", "block", "time", "out", "outeq"])?;
924    for subject in data.get_subjects() {
925        for occasion in subject.occasions() {
926            for event in occasion.get_events(&None, &None, false) {
927                if let Event::Observation(event) = event {
928                    writer.write_record([
929                        subject.id(),
930                        &occasion.index().to_string(),
931                        &event.time().to_string(),
932                        &event.value().to_string(),
933                        &event.outeq().to_string(),
934                    ])?;
935                }
936            }
937        }
938    }
939    Ok(())
940}
941
#[cfg(test)]
mod tests {
    use super::median;
    use super::weighted_median;
    use ndarray::Array1;

    /// Builds an owned 1-D array from a slice of values
    fn arr1(values: &[f64]) -> Array1<f64> {
        Array1::from(values.to_vec())
    }

    // ---- median ----

    #[test]
    fn test_median_odd() {
        assert_eq!(median(vec![1.0, 3.0, 2.0]), 2.0);
    }

    #[test]
    fn test_median_even() {
        assert_eq!(median(vec![1.0, 2.0, 3.0, 4.0]), 2.5);
    }

    #[test]
    fn test_median_single() {
        assert_eq!(median(vec![42.0]), 42.0);
    }

    #[test]
    fn test_median_sorted() {
        assert_eq!(median(vec![5.0, 10.0, 15.0, 20.0, 25.0]), 15.0);
    }

    #[test]
    fn test_median_unsorted() {
        assert_eq!(median(vec![10.0, 30.0, 20.0, 50.0, 40.0]), 30.0);
    }

    #[test]
    fn test_median_with_duplicates() {
        assert_eq!(median(vec![1.0, 2.0, 2.0, 3.0, 4.0]), 2.0);
    }

    // ---- weighted_median ----

    #[test]
    fn test_weighted_median_simple() {
        let values = arr1(&[1.0, 2.0, 3.0]);
        let wts = arr1(&[0.2, 0.5, 0.3]);
        assert_eq!(weighted_median(&values, &wts), 2.0);
    }

    #[test]
    fn test_weighted_median_even_weights() {
        let values = arr1(&[1.0, 2.0, 3.0, 4.0]);
        let wts = arr1(&[0.25, 0.25, 0.25, 0.25]);
        assert_eq!(weighted_median(&values, &wts), 2.5);
    }

    #[test]
    fn test_weighted_median_single_element() {
        let values = arr1(&[42.0]);
        let wts = arr1(&[1.0]);
        assert_eq!(weighted_median(&values, &wts), 42.0);
    }

    /// Data and weights of different lengths must be rejected with a panic
    #[test]
    #[should_panic(expected = "The length of data and weights must be the same")]
    fn test_weighted_median_mismatched_lengths() {
        let values = arr1(&[1.0, 2.0, 3.0]);
        let wts = arr1(&[0.1, 0.2]);
        weighted_median(&values, &wts);
    }

    #[test]
    fn test_weighted_median_all_same_elements() {
        let values = arr1(&[5.0, 5.0, 5.0, 5.0]);
        let wts = arr1(&[0.1, 0.2, 0.3, 0.4]);
        assert_eq!(weighted_median(&values, &wts), 5.0);
    }

    /// A negative weight must be rejected with a panic
    #[test]
    #[should_panic(expected = "Weights must be non-negative")]
    fn test_weighted_median_negative_weights() {
        let values = arr1(&[1.0, 2.0, 3.0, 4.0]);
        let wts = arr1(&[0.2, -0.5, 0.5, 0.8]);
        assert_eq!(weighted_median(&values, &wts), 4.0);
    }

    #[test]
    fn test_weighted_median_unsorted_data() {
        let values = arr1(&[3.0, 1.0, 4.0, 2.0]);
        let wts = arr1(&[0.1, 0.3, 0.4, 0.2]);
        assert_eq!(weighted_median(&values, &wts), 2.5);
    }

    #[test]
    fn test_weighted_median_with_zero_weights() {
        let values = arr1(&[1.0, 2.0, 3.0, 4.0]);
        let wts = arr1(&[0.0, 0.0, 1.0, 0.0]);
        assert_eq!(weighted_median(&values, &wts), 3.0);
    }
}