pmcore/routines/
output.rs

1use crate::prelude::*;
2use crate::structs::psi::Psi;
3use crate::structs::theta::Theta;
4use anyhow::{bail, Context, Result};
5use csv::WriterBuilder;
6use faer::linalg::zip::IntoView;
7use faer_ext::IntoNdarray;
8use ndarray::{Array, Array1, Array2, Axis};
9use pharmsol::prelude::data::*;
10use pharmsol::prelude::simulator::Equation;
11use serde::Serialize;
12// use pharmsol::Cache;
13use crate::routines::settings::Settings;
14use faer::{Col, Mat};
15use std::fs::{create_dir_all, File, OpenOptions};
16use std::path::{Path, PathBuf};
17
/// Defines the result objects from an NPAG run
/// An [NPResult] contains the necessary information to generate predictions and summary statistics
///
/// All fields are private; they are set once through the constructor and read by the
/// accessor and `write_*` methods of this type.
#[derive(Debug)]
pub struct NPResult<E: Equation> {
    /// Model equation used to simulate predictions when writing output files
    equation: E,
    /// Subject data this result refers to
    data: Data,
    /// Support points; written out one row per support point
    theta: Theta,
    /// Matrix that is combined with `w` to compute posterior probabilities
    psi: Psi,
    /// Weight (probability) associated with each support point
    w: Col<f64>,
    /// Objective function value
    objf: f64,
    /// Number of cycles
    cycles: usize,
    /// Convergence flag
    converged: bool,
    /// Parameter names, read from `settings` at construction time
    par_names: Vec<String>,
    /// Run settings; they also decide whether and where output files are written
    settings: Settings,
    /// Per-cycle log, written out as part of the output files
    cyclelog: CycleLog,
}
34
35#[allow(clippy::too_many_arguments)]
36impl<E: Equation> NPResult<E> {
37    /// Create a new NPResult object
38    pub fn new(
39        equation: E,
40        data: Data,
41        theta: Theta,
42        psi: Psi,
43        w: Col<f64>,
44        objf: f64,
45        cycles: usize,
46        converged: bool,
47        settings: Settings,
48        cyclelog: CycleLog,
49    ) -> Self {
50        // TODO: Add support for fixed and constant parameters
51
52        let par_names = settings.parameters().names();
53
54        Self {
55            equation,
56            data,
57            theta,
58            psi,
59            w,
60            objf,
61            cycles,
62            converged,
63            par_names,
64            settings,
65            cyclelog,
66        }
67    }
68
    /// Returns the cycle count stored in this result
    pub fn cycles(&self) -> usize {
        self.cycles
    }
72
    /// Returns the objective function value stored in this result
    pub fn objf(&self) -> f64 {
        self.objf
    }
76
    /// Returns the convergence flag stored in this result
    pub fn converged(&self) -> bool {
        self.converged
    }
80
    /// Returns a reference to the stored [Theta] (the support points)
    pub fn get_theta(&self) -> &Theta {
        &self.theta
    }
84
    /// Returns a reference to the stored [Psi] matrix
    pub fn get_psi(&self) -> &Psi {
        &self.psi
    }
88
    /// Returns a reference to the stored weights, one per support point
    pub fn get_w(&self) -> &Col<f64> {
        &self.w
    }
92
93    pub fn write_outputs(&self) -> Result<()> {
94        if self.settings.output().write {
95            self.settings.write()?;
96            let idelta: f64 = self.settings.predictions().idelta;
97            let tad = self.settings.predictions().tad;
98            self.cyclelog.write(&self.settings)?;
99            self.write_obs().context("Failed to write observations")?;
100            self.write_theta().context("Failed to write theta")?;
101            self.write_obspred()
102                .context("Failed to write observed-predicted file")?;
103            self.write_pred(idelta, tad)
104                .context("Failed to write predictions")?;
105            self.write_covs().context("Failed to write covariates")?;
106            self.write_posterior()
107                .context("Failed to write posterior")?;
108        }
109        Ok(())
110    }
111
112    /// Writes the observations and predictions to a single file
113    pub fn write_obspred(&self) -> Result<()> {
114        tracing::debug!("Writing observations and predictions...");
115
116        #[derive(Debug, Clone, Serialize)]
117        struct Row {
118            id: String,
119            time: f64,
120            outeq: usize,
121            block: usize,
122            obs: f64,
123            pop_mean: f64,
124            pop_median: f64,
125            post_mean: f64,
126            post_median: f64,
127        }
128
129        let theta: Array2<f64> = self
130            .theta
131            .matrix()
132            .clone()
133            .as_mut()
134            .into_ndarray()
135            .to_owned();
136        let w: Array1<f64> = self.w.clone().into_view().iter().cloned().collect();
137        let psi: Array2<f64> = self.psi.matrix().as_ref().into_ndarray().to_owned();
138
139        let (post_mean, post_median) = posterior_mean_median(&theta, &psi, &w)
140            .context("Failed to calculate posterior mean and median")?;
141
142        let (pop_mean, pop_median) = population_mean_median(&theta, &w)
143            .context("Failed to calculate posterior mean and median")?;
144
145        let subjects = self.data.get_subjects();
146        if subjects.len() != post_mean.nrows() {
147            bail!(
148                "Number of subjects: {} and number of posterior means: {} do not match",
149                subjects.len(),
150                post_mean.nrows()
151            );
152        }
153
154        let outputfile = OutputFile::new(&self.settings.output().path, "op.csv")?;
155        let mut writer = WriterBuilder::new()
156            .has_headers(true)
157            .from_writer(&outputfile.file);
158
159        for (i, subject) in subjects.iter().enumerate() {
160            for occasion in subject.occasions() {
161                let id = subject.id();
162                let occ = occasion.index();
163
164                let subject = Subject::from_occasions(id.clone(), vec![occasion.clone()]);
165
166                // Population predictions
167                let pop_mean_pred = self
168                    .equation
169                    .simulate_subject(&subject, &pop_mean.to_vec(), None)?
170                    .0
171                    .get_predictions()
172                    .clone();
173
174                let pop_median_pred = self
175                    .equation
176                    .simulate_subject(&subject, &pop_median.to_vec(), None)?
177                    .0
178                    .get_predictions()
179                    .clone();
180
181                // Posterior predictions
182                let post_mean_spp: Vec<f64> = post_mean.row(i).to_vec();
183                let post_mean_pred = self
184                    .equation
185                    .simulate_subject(&subject, &post_mean_spp, None)?
186                    .0
187                    .get_predictions()
188                    .clone();
189                let post_median_spp: Vec<f64> = post_median.row(i).to_vec();
190                let post_median_pred = self
191                    .equation
192                    .simulate_subject(&subject, &post_median_spp, None)?
193                    .0
194                    .get_predictions()
195                    .clone();
196                assert_eq!(
197                    pop_mean_pred.len(),
198                    pop_median_pred.len(),
199                    "The number of predictions do not match (pop_mean vs pop_median)"
200                );
201
202                assert_eq!(
203                    post_mean_pred.len(),
204                    post_median_pred.len(),
205                    "The number of predictions do not match (post_mean vs post_median)"
206                );
207
208                assert_eq!(
209                    pop_mean_pred.len(),
210                    post_mean_pred.len(),
211                    "The number of predictions do not match (pop_mean vs post_mean)"
212                );
213
214                for (((pop_mean_pred, pop_median_pred), post_mean_pred), post_median_pred) in
215                    pop_mean_pred
216                        .iter()
217                        .zip(pop_median_pred.iter())
218                        .zip(post_mean_pred.iter())
219                        .zip(post_median_pred.iter())
220                {
221                    let row = Row {
222                        id: id.clone(),
223                        time: pop_mean_pred.time(),
224                        outeq: pop_mean_pred.outeq(),
225                        block: occ,
226                        obs: pop_mean_pred.observation(),
227                        pop_mean: pop_mean_pred.prediction(),
228                        pop_median: pop_median_pred.prediction(),
229                        post_mean: post_mean_pred.prediction(),
230                        post_median: post_median_pred.prediction(),
231                    };
232                    writer.serialize(row)?;
233                }
234            }
235        }
236        writer.flush()?;
237        tracing::info!(
238            "Observations with predictions written to {:?}",
239            &outputfile.get_relative_path()
240        );
241        Ok(())
242    }
243
244    /// Writes theta, which contains the population support points and their associated probabilities
245    /// Each row is one support point, the last column being probability
246    pub fn write_theta(&self) -> Result<()> {
247        tracing::debug!("Writing population parameter distribution...");
248
249        let theta = &self.theta;
250        let w: Vec<f64> = self.w.clone().into_view().iter().cloned().collect();
251        /* let w = if self.w.len() != theta.matrix().nrows() {
252                   tracing::warn!("Number of weights and number of support points do not match. Setting all weights to 0.");
253                   Array1::zeros(theta.matrix().nrows())
254               } else {
255                   self.w.clone()
256               };
257        */
258        let outputfile = OutputFile::new(&self.settings.output().path, "theta.csv")
259            .context("Failed to create output file for theta")?;
260
261        let mut writer = WriterBuilder::new()
262            .has_headers(true)
263            .from_writer(&outputfile.file);
264
265        // Create the headers
266        let mut theta_header = self.par_names.clone();
267        theta_header.push("prob".to_string());
268        writer.write_record(&theta_header)?;
269
270        // Write contents
271        for (theta_row, &w_val) in theta.matrix().row_iter().zip(w.iter()) {
272            let mut row: Vec<String> = theta_row.iter().map(|&val| val.to_string()).collect();
273            row.push(w_val.to_string());
274            writer.write_record(&row)?;
275        }
276        writer.flush()?;
277        tracing::info!(
278            "Population parameter distribution written to {:?}",
279            &outputfile.get_relative_path()
280        );
281        Ok(())
282    }
283
284    /// Writes the posterior support points for each individual
285    pub fn write_posterior(&self) -> Result<()> {
286        tracing::debug!("Writing posterior parameter probabilities...");
287        let theta = &self.theta;
288        let w = &self.w;
289        let psi = &self.psi;
290
291        // Calculate the posterior probabilities
292        let posterior = posterior(psi, w)?;
293
294        // Create the output folder if it doesn't exist
295        let outputfile = match OutputFile::new(&self.settings.output().path, "posterior.csv") {
296            Ok(of) => of,
297            Err(e) => {
298                tracing::error!("Failed to create output file: {}", e);
299                return Err(e.context("Failed to create output file"));
300            }
301        };
302
303        // Create a new writer
304        let mut writer = WriterBuilder::new()
305            .has_headers(true)
306            .from_writer(&outputfile.file);
307
308        // Create the headers
309        writer.write_field("id")?;
310        writer.write_field("point")?;
311        theta.param_names().iter().for_each(|name| {
312            writer.write_field(name).unwrap();
313        });
314        writer.write_field("prob")?;
315        writer.write_record(None::<&[u8]>)?;
316
317        // Write contents
318        let subjects = self.data.get_subjects();
319        posterior.row_iter().enumerate().for_each(|(i, row)| {
320            let subject = subjects.get(i).unwrap();
321            let id = subject.id();
322
323            row.iter().enumerate().for_each(|(spp, prob)| {
324                writer.write_field(id.clone()).unwrap();
325                writer.write_field(i.to_string()).unwrap();
326
327                theta.matrix().row(spp).iter().for_each(|val| {
328                    writer.write_field(val.to_string()).unwrap();
329                });
330
331                writer.write_field(prob.to_string()).unwrap();
332                writer.write_record(None::<&[u8]>).unwrap();
333            });
334        });
335
336        writer.flush()?;
337        tracing::info!(
338            "Posterior parameters written to {:?}",
339            &outputfile.get_relative_path()
340        );
341
342        Ok(())
343    }
344
345    /// Write the observations, which is the reformatted input data
346    pub fn write_obs(&self) -> Result<()> {
347        tracing::debug!("Writing observations...");
348        let outputfile = OutputFile::new(&self.settings.output().path, "obs.csv")?;
349        write_pmetrics_observations(&self.data, &outputfile.file)?;
350        tracing::info!(
351            "Observations written to {:?}",
352            &outputfile.get_relative_path()
353        );
354        Ok(())
355    }
356
357    /// Writes the predictions
358    pub fn write_pred(&self, idelta: f64, tad: f64) -> Result<()> {
359        tracing::debug!("Writing predictions...");
360        let data = self.data.expand(idelta, tad);
361
362        let theta: Array2<f64> = self
363            .theta
364            .matrix()
365            .clone()
366            .as_mut()
367            .into_ndarray()
368            .to_owned();
369        let w: Array1<f64> = self.w.clone().into_view().iter().cloned().collect();
370        let psi: Array2<f64> = self.psi.matrix().as_ref().into_ndarray().to_owned();
371
372        let (post_mean, post_median) = posterior_mean_median(&theta, &psi, &w)
373            .context("Failed to calculate posterior mean and median")?;
374
375        let (pop_mean, pop_median) = population_mean_median(&theta, &w)
376            .context("Failed to calculate population mean and median")?;
377
378        let subjects = data.get_subjects();
379        if subjects.len() != post_mean.nrows() {
380            bail!("Number of subjects and number of posterior means do not match");
381        }
382
383        let outputfile = OutputFile::new(&self.settings.output().path, "pred.csv")?;
384        let mut writer = WriterBuilder::new()
385            .has_headers(true)
386            .from_writer(&outputfile.file);
387
388        #[derive(Debug, Clone, Serialize)]
389        struct Row {
390            id: String,
391            time: f64,
392            outeq: usize,
393            block: usize,
394            pop_mean: f64,
395            pop_median: f64,
396            post_mean: f64,
397            post_median: f64,
398        }
399
400        for (i, subject) in subjects.iter().enumerate() {
401            for occasion in subject.occasions() {
402                let id = subject.id();
403                let block = occasion.index();
404
405                // Create a new subject with only the current occasion
406                let subject = Subject::from_occasions(id.clone(), vec![occasion.clone()]);
407
408                // Population predictions
409                let pop_mean_pred = self
410                    .equation
411                    .simulate_subject(&subject, &pop_mean.to_vec(), None)?
412                    .0
413                    .get_predictions()
414                    .clone();
415                let pop_median_pred = self
416                    .equation
417                    .simulate_subject(&subject, &pop_median.to_vec(), None)?
418                    .0
419                    .get_predictions()
420                    .clone();
421
422                // Posterior predictions
423                let post_mean_spp: Vec<f64> = post_mean.row(i).to_vec();
424                let post_mean_pred = self
425                    .equation
426                    .simulate_subject(&subject, &post_mean_spp, None)?
427                    .0
428                    .get_predictions()
429                    .clone();
430                let post_median_spp: Vec<f64> = post_median.row(i).to_vec();
431                let post_median_pred = self
432                    .equation
433                    .simulate_subject(&subject, &post_median_spp, None)?
434                    .0
435                    .get_predictions()
436                    .clone();
437
438                // Write predictions for each time point
439                for (((pop_mean, pop_median), post_mean), post_median) in pop_mean_pred
440                    .iter()
441                    .zip(pop_median_pred.iter())
442                    .zip(post_mean_pred.iter())
443                    .zip(post_median_pred.iter())
444                {
445                    let row = Row {
446                        id: id.clone(),
447                        time: pop_mean.time(),
448                        outeq: pop_mean.outeq(),
449                        block,
450                        pop_mean: pop_mean.prediction(),
451                        pop_median: pop_median.prediction(),
452                        post_mean: post_mean.prediction(),
453                        post_median: post_median.prediction(),
454                    };
455                    writer.serialize(row)?;
456                }
457            }
458        }
459        writer.flush()?;
460        tracing::info!(
461            "Predictions written to {:?}",
462            &outputfile.get_relative_path()
463        );
464        Ok(())
465    }
466
467    /// Writes the covariates
468    pub fn write_covs(&self) -> Result<()> {
469        tracing::debug!("Writing covariates...");
470        let outputfile = OutputFile::new(&self.settings.output().path, "covs.csv")?;
471        let mut writer = WriterBuilder::new()
472            .has_headers(true)
473            .from_writer(&outputfile.file);
474
475        // Collect all unique covariate names
476        let mut covariate_names = std::collections::HashSet::new();
477        for subject in self.data.get_subjects() {
478            for occasion in subject.occasions() {
479                if let Some(cov) = occasion.get_covariates() {
480                    let covmap = cov.covariates();
481                    for cov_name in covmap.keys() {
482                        covariate_names.insert(cov_name.clone());
483                    }
484                }
485            }
486        }
487        let mut covariate_names: Vec<String> = covariate_names.into_iter().collect();
488        covariate_names.sort(); // Ensure consistent order
489
490        // Write the header row: id, time, block, covariate names
491        let mut headers = vec!["id", "time", "block"];
492        headers.extend(covariate_names.iter().map(|s| s.as_str()));
493        writer.write_record(&headers)?;
494
495        // Write the data rows
496        for subject in self.data.get_subjects() {
497            for occasion in subject.occasions() {
498                if let Some(cov) = occasion.get_covariates() {
499                    let covmap = cov.covariates();
500
501                    for event in occasion.get_events(&None, &None, false) {
502                        let time = match event {
503                            Event::Bolus(bolus) => bolus.time(),
504                            Event::Infusion(infusion) => infusion.time(),
505                            Event::Observation(observation) => observation.time(),
506                        };
507
508                        let mut row: Vec<String> = Vec::new();
509                        row.push(subject.id().clone());
510                        row.push(time.to_string());
511                        row.push(occasion.index().to_string());
512
513                        // Add covariate values to the row
514                        for cov_name in &covariate_names {
515                            if let Some(cov) = covmap.get(cov_name) {
516                                if let Some(value) = cov.interpolate(time) {
517                                    row.push(value.to_string());
518                                } else {
519                                    row.push(String::new());
520                                }
521                            } else {
522                                row.push(String::new());
523                            }
524                        }
525
526                        writer.write_record(&row)?;
527                    }
528                }
529            }
530        }
531
532        writer.flush()?;
533        tracing::info!(
534            "Covariates written to {:?}",
535            &outputfile.get_relative_path()
536        );
537        Ok(())
538    }
539}
540
/// An [NPCycle] object contains the summary of a cycle
/// It holds the following information:
/// - `cycle`: The cycle number
/// - `objf`: The objective function value
/// - `gamlam`: The assay noise parameter, either gamma or lambda
/// - `theta`: The support points and their associated probabilities
/// - `nspp`: The number of support points
/// - `delta_objf`: The change in objective function value from last cycle
/// - `converged`: Whether the algorithm has reached convergence
#[derive(Debug, Clone)]
pub struct NPCycle {
    /// The cycle number
    pub cycle: usize,
    /// The objective function value
    pub objf: f64,
    /// The assay noise parameter, either gamma or lambda
    pub gamlam: f64,
    /// The support points of this cycle
    pub theta: Theta,
    /// The number of support points
    pub nspp: usize,
    /// The change in objective function value from last cycle
    pub delta_objf: f64,
    /// Whether the algorithm has reached convergence
    pub converged: bool,
}
560
561impl NPCycle {
562    pub fn new(
563        cycle: usize,
564        objf: f64,
565        gamlam: f64,
566        theta: Theta,
567        nspp: usize,
568        delta_objf: f64,
569        converged: bool,
570    ) -> Self {
571        Self {
572            cycle,
573            objf,
574            gamlam,
575            theta,
576            nspp,
577            delta_objf,
578            converged,
579        }
580    }
581
582    pub fn placeholder() -> Self {
583        Self {
584            cycle: 0,
585            objf: 0.0,
586            gamlam: 0.0,
587            theta: Theta::new(),
588            nspp: 0,
589            delta_objf: 0.0,
590            converged: false,
591        }
592    }
593}
594
/// This holds a vector of [NPCycle] objects to provide a more detailed log
#[derive(Debug, Clone)]
pub struct CycleLog {
    /// The logged cycles, in the order they were pushed
    pub cycles: Vec<NPCycle>,
}
600
601impl CycleLog {
602    pub fn new() -> Self {
603        Self { cycles: Vec::new() }
604    }
605
606    pub fn push(&mut self, cycle: NPCycle) {
607        self.cycles.push(cycle);
608    }
609
610    pub fn write(&self, settings: &Settings) -> Result<()> {
611        tracing::debug!("Writing cycles...");
612        let outputfile = OutputFile::new(&settings.output().path, "cycles.csv")?;
613        let mut writer = WriterBuilder::new()
614            .has_headers(false)
615            .from_writer(&outputfile.file);
616
617        // Write headers
618        writer.write_field("cycle")?;
619        writer.write_field("converged")?;
620        writer.write_field("neg2ll")?;
621        writer.write_field("gamlam")?;
622        writer.write_field("nspp")?;
623
624        let parameter_names = settings.parameters().names();
625        for param_name in &parameter_names {
626            writer.write_field(format!("{}.mean", param_name))?;
627            writer.write_field(format!("{}.median", param_name))?;
628            writer.write_field(format!("{}.sd", param_name))?;
629        }
630
631        writer.write_record(None::<&[u8]>)?;
632
633        for cycle in &self.cycles {
634            writer.write_field(format!("{}", cycle.cycle))?;
635            writer.write_field(format!("{}", cycle.converged))?;
636            writer.write_field(format!("{}", cycle.objf))?;
637            writer.write_field(format!("{}", cycle.gamlam))?;
638            writer
639                .write_field(format!("{}", cycle.theta.matrix().nrows()))
640                .unwrap();
641
642            for param in cycle.theta.matrix().col_iter() {
643                let param_values: Vec<f64> = param.iter().cloned().collect();
644
645                let mean: f64 = param_values.iter().sum::<f64>() / param_values.len() as f64;
646                let median = median(param_values.clone());
647                let std = param_values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
648                    / (param_values.len() as f64 - 1.0);
649
650                writer.write_field(format!("{}", mean))?;
651                writer.write_field(format!("{}", median))?;
652                writer.write_field(format!("{}", std))?;
653            }
654            writer.write_record(None::<&[u8]>)?;
655        }
656        writer.flush()?;
657        tracing::info!("Cycles written to {:?}", &outputfile.get_relative_path());
658        Ok(())
659    }
660}
661
/// The default [CycleLog] is the same as the one returned by [CycleLog::new]
impl Default for CycleLog {
    fn default() -> Self {
        Self::new()
    }
}
667
668pub fn posterior(psi: &Psi, w: &Col<f64>) -> Result<Mat<f64>> {
669    if psi.matrix().ncols() != w.nrows() {
670        bail!(
671            "Number of rows in psi ({}) and number of weights ({}) do not match.",
672            psi.matrix().nrows(),
673            w.nrows()
674        );
675    }
676
677    let psi_matrix = psi.matrix();
678    let py = psi_matrix * w;
679
680    let posterior = Mat::from_fn(psi_matrix.nrows(), psi_matrix.ncols(), |i, j| {
681        psi_matrix.get(i, j) * w.get(j) / py.get(i)
682    });
683
684    Ok(posterior)
685}
686
/// Returns the median of `data`
///
/// For an even number of values the two middle values are averaged. An empty input
/// yields `f64::NAN` instead of panicking. Values are ordered with [`f64::total_cmp`], so
/// NaN entries do not cause a panic either; they are sorted to the ends of the data.
pub fn median(mut data: Vec<f64>) -> f64 {
    if data.is_empty() {
        return f64::NAN;
    }

    // The vector is owned, so it can be sorted in place without copying
    data.sort_unstable_by(f64::total_cmp);

    let mid = data.len() / 2;
    if data.len() % 2 == 0 {
        (data[mid - 1] + data[mid]) / 2.0
    } else {
        data[mid]
    }
}
701
702fn weighted_median(data: &Array1<f64>, weights: &Array1<f64>) -> f64 {
703    // Ensure the data and weights arrays have the same length
704    assert_eq!(
705        data.len(),
706        weights.len(),
707        "The length of data and weights must be the same"
708    );
709    assert!(
710        weights.iter().all(|&x| x >= 0.0),
711        "Weights must be non-negative, weights: {:?}",
712        weights
713    );
714
715    // Create a vector of tuples (data, weight)
716    let mut weighted_data: Vec<(f64, f64)> = data
717        .iter()
718        .zip(weights.iter())
719        .map(|(&d, &w)| (d, w))
720        .collect();
721
722    // Sort the vector by the data values
723    weighted_data.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
724
725    // Calculate the cumulative sum of weights
726    let total_weight: f64 = weights.sum();
727    let mut cumulative_sum = 0.0;
728
729    for (i, &(_, weight)) in weighted_data.iter().enumerate() {
730        cumulative_sum += weight;
731
732        if cumulative_sum == total_weight / 2.0 {
733            // If the cumulative sum equals half the total weight, average this value with the next
734            if i + 1 < weighted_data.len() {
735                return (weighted_data[i].0 + weighted_data[i + 1].0) / 2.0;
736            } else {
737                return weighted_data[i].0;
738            }
739        } else if cumulative_sum > total_weight / 2.0 {
740            return weighted_data[i].0;
741        }
742    }
743
744    unreachable!("The function should have returned a value before reaching this point.");
745}
746
747pub fn population_mean_median(
748    theta: &Array2<f64>,
749    w: &Array1<f64>,
750) -> Result<(Array1<f64>, Array1<f64>)> {
751    let w = if w.is_empty() {
752        tracing::warn!("w.len() == 0, setting all weights to 1/n");
753        Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
754    } else {
755        w.clone()
756    };
757    // Check for compatible sizes
758    if theta.nrows() != w.len() {
759        bail!(
760            "Number of parameters and number of weights do not match. Theta: {}, w: {}",
761            theta.nrows(),
762            w.len()
763        );
764    }
765
766    let mut mean = Array1::zeros(theta.ncols());
767    let mut median = Array1::zeros(theta.ncols());
768
769    for (i, (mn, mdn)) in mean.iter_mut().zip(&mut median).enumerate() {
770        // Calculate the weighted mean
771        let col = theta.column(i).to_owned() * w.to_owned();
772        *mn = col.sum();
773
774        // Calculate the median
775        let ct = theta.column(i);
776        let mut params = vec![];
777        let mut weights = vec![];
778        for (ti, wi) in ct.iter().zip(w.clone()) {
779            params.push(*ti);
780            weights.push(wi);
781        }
782
783        *mdn = weighted_median(&Array::from(params), &Array::from(weights));
784    }
785
786    Ok((mean, median))
787}
788
789pub fn posterior_mean_median(
790    theta: &Array2<f64>,
791    psi: &Array2<f64>,
792    w: &Array1<f64>,
793) -> Result<(Array2<f64>, Array2<f64>)> {
794    let mut mean = Array2::zeros((0, theta.ncols()));
795    let mut median = Array2::zeros((0, theta.ncols()));
796
797    let w = if w.is_empty() {
798        tracing::warn!("w is empty, setting all weights to 1/n");
799        Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
800    } else {
801        w.clone()
802    };
803
804    // Check for compatible sizes
805    if theta.nrows() != w.len() || theta.nrows() != psi.ncols() || psi.ncols() != w.len() {
806        bail!("Number of parameters and number of weights do not match, theta.nrows(): {}, w.len(): {}, psi.ncols(): {}", theta.nrows(), w.len(), psi.ncols());
807    }
808
809    // Normalize psi to get probabilities of each spp for each id
810    let mut psi_norm: Array2<f64> = Array2::zeros((0, psi.ncols()));
811    for (i, row) in psi.axis_iter(Axis(0)).enumerate() {
812        let row_w = row.to_owned() * w.to_owned();
813        let row_sum = row_w.sum();
814        let row_norm = if row_sum == 0.0 {
815            tracing::warn!("Sum of row {} of psi is 0.0, setting that row to 1/n", i);
816            Array1::from_elem(psi.ncols(), 1.0 / psi.ncols() as f64)
817        } else {
818            &row_w / row_sum
819        };
820        psi_norm.push_row(row_norm.view())?;
821    }
822    if psi_norm.iter().any(|&x| x.is_nan()) {
823        dbg!(&psi);
824        bail!("NaN values found in psi_norm");
825    };
826
827    // Transpose normalized psi to get ID (col) by prob (row)
828    // let psi_norm_transposed = psi_norm.t();
829
830    // For each subject..
831    for probs in psi_norm.axis_iter(Axis(0)) {
832        let mut post_mean: Vec<f64> = Vec::new();
833        let mut post_median: Vec<f64> = Vec::new();
834
835        // For each parameter
836        for pars in theta.axis_iter(Axis(1)) {
837            // Calculate the mean
838            let weighted_par = &probs * &pars;
839            let the_mean = weighted_par.sum();
840            post_mean.push(the_mean);
841
842            // Calculate the median
843            let median = weighted_median(&pars.to_owned(), &probs.to_owned());
844            post_median.push(median);
845        }
846
847        mean.push_row(Array::from(post_mean.clone()).view())?;
848        median.push_row(Array::from(post_median.clone()).view())?;
849    }
850
851    Ok((mean, median))
852}
853
/// Contains all the necessary information of an output file
#[derive(Debug)]
pub struct OutputFile {
    /// Handle of the file, opened for writing
    pub file: File,
    /// Path of the file, made of the output folder joined with the file name
    pub relative_path: PathBuf,
}
860
861impl OutputFile {
862    pub fn new(folder: &str, file_name: &str) -> Result<Self> {
863        let relative_path = Path::new(&folder).join(file_name);
864
865        if let Some(parent) = relative_path.parent() {
866            create_dir_all(parent)
867                .with_context(|| format!("Failed to create directories for {:?}", parent))?;
868        }
869
870        let file = OpenOptions::new()
871            .write(true)
872            .create(true)
873            .truncate(true)
874            .open(&relative_path)
875            .with_context(|| format!("Failed to open file: {:?}", relative_path))?;
876
877        Ok(OutputFile {
878            file,
879            relative_path,
880        })
881    }
882
883    pub fn get_relative_path(&self) -> &Path {
884        &self.relative_path
885    }
886}
887
888pub fn write_pmetrics_observations(data: &Data, file: &std::fs::File) -> Result<()> {
889    let mut writer = WriterBuilder::new().has_headers(true).from_writer(file);
890
891    writer.write_record(["id", "block", "time", "out", "outeq"])?;
892    for subject in data.get_subjects() {
893        for occasion in subject.occasions() {
894            for event in occasion.get_events(&None, &None, false) {
895                if let Event::Observation(event) = event {
896                    writer.write_record([
897                        subject.id(),
898                        &occasion.index().to_string(),
899                        &event.time().to_string(),
900                        &event.value().to_string(),
901                        &event.outeq().to_string(),
902                    ])?;
903                }
904            }
905        }
906    }
907    Ok(())
908}
909
#[cfg(test)]
mod tests {
    use super::{median, weighted_median};
    use ndarray::Array1;

    /// Calls `weighted_median` with arrays built from the given slices
    fn wmedian(data: &[f64], weights: &[f64]) -> f64 {
        weighted_median(
            &Array1::from(data.to_vec()),
            &Array1::from(weights.to_vec()),
        )
    }

    /// Odd number of values: the middle value is expected
    #[test]
    fn test_median_odd() {
        assert_eq!(median(vec![1.0, 3.0, 2.0]), 2.0);
    }

    /// Even number of values: the expected result lies halfway between the two middle values
    #[test]
    fn test_median_even() {
        assert_eq!(median(vec![1.0, 2.0, 3.0, 4.0]), 2.5);
    }

    /// A single value is expected to be returned as-is
    #[test]
    fn test_median_single() {
        assert_eq!(median(vec![42.0]), 42.0);
    }

    /// Input that is already in ascending order
    #[test]
    fn test_median_sorted() {
        assert_eq!(median(vec![5.0, 10.0, 15.0, 20.0, 25.0]), 15.0);
    }

    /// Input that is not in ascending order
    #[test]
    fn test_median_unsorted() {
        assert_eq!(median(vec![10.0, 30.0, 20.0, 50.0, 40.0]), 30.0);
    }

    /// Input containing a repeated value
    #[test]
    fn test_median_with_duplicates() {
        assert_eq!(median(vec![1.0, 2.0, 2.0, 3.0, 4.0]), 2.0);
    }

    /// Three values with unequal weights
    #[test]
    fn test_weighted_median_simple() {
        assert_eq!(wmedian(&[1.0, 2.0, 3.0], &[0.2, 0.5, 0.3]), 2.0);
    }

    /// Four values sharing the same weight
    #[test]
    fn test_weighted_median_even_weights() {
        assert_eq!(
            wmedian(&[1.0, 2.0, 3.0, 4.0], &[0.25, 0.25, 0.25, 0.25]),
            2.5
        );
    }

    /// A single value carrying all the weight
    #[test]
    fn test_weighted_median_single_element() {
        assert_eq!(wmedian(&[42.0], &[1.0]), 42.0);
    }

    /// Data and weights of different lengths are expected to panic
    #[test]
    #[should_panic(expected = "The length of data and weights must be the same")]
    fn test_weighted_median_mismatched_lengths() {
        wmedian(&[1.0, 2.0, 3.0], &[0.1, 0.2]);
    }

    /// Identical values with differing weights
    #[test]
    fn test_weighted_median_all_same_elements() {
        assert_eq!(wmedian(&[5.0, 5.0, 5.0, 5.0], &[0.1, 0.2, 0.3, 0.4]), 5.0);
    }

    /// A negative weight is expected to panic
    #[test]
    #[should_panic(expected = "Weights must be non-negative")]
    fn test_weighted_median_negative_weights() {
        assert_eq!(wmedian(&[1.0, 2.0, 3.0, 4.0], &[0.2, -0.5, 0.5, 0.8]), 4.0);
    }

    /// Values supplied out of order, each with its own weight
    #[test]
    fn test_weighted_median_unsorted_data() {
        assert_eq!(wmedian(&[3.0, 1.0, 4.0, 2.0], &[0.1, 0.3, 0.4, 0.2]), 2.5);
    }

    /// All weight placed on a single value, the others weighted zero
    #[test]
    fn test_weighted_median_with_zero_weights() {
        assert_eq!(wmedian(&[1.0, 2.0, 3.0, 4.0], &[0.0, 0.0, 1.0, 0.0]), 3.0);
    }
}