// pmcore/routines/output/mod.rs

1use crate::algorithms::{Status, StopReason};
2use crate::prelude::*;
3use crate::routines::output::cycles::CycleLog;
4use crate::routines::output::posterior::Posterior;
5use crate::routines::output::predictions::NPPredictions;
6use crate::routines::settings::Settings;
7use crate::structs::psi::Psi;
8use crate::structs::theta::Theta;
9use crate::structs::weights::Weights;
10use anyhow::{bail, Context, Result};
11use csv::WriterBuilder;
12use ndarray::{Array, Array1, Array2, Axis};
13use pharmsol::prelude::data::*;
14use pharmsol::prelude::simulator::Equation;
15use serde::Serialize;
16use std::fs::{create_dir_all, File, OpenOptions};
17use std::path::{Path, PathBuf};
18
19pub mod cycles;
20pub mod posterior;
21pub mod predictions;
22
23use posterior::posterior;
24
25/// Defines the result objects from an NPAG run
26/// An [NPResult] contains the necessary information to generate predictions and summary statistics
27#[derive(Debug, Serialize)]
28pub struct NPResult<E: Equation> {
29    #[serde(skip)]
30    equation: E,
31    data: Data,
32    theta: Theta,
33    psi: Psi,
34    w: Weights,
35    objf: f64,
36    cycles: usize,
37    status: Status,
38    settings: Settings,
39    cyclelog: CycleLog,
40    predictions: Option<NPPredictions>,
41    posterior: Posterior,
42}
43
44#[allow(clippy::too_many_arguments)]
45impl<E: Equation> NPResult<E> {
46    /// Create a new NPResult object
47    ///
48    /// This will also calculate the [Posterior] structure and add it to the NPResult
49    pub(crate) fn new(
50        equation: E,
51        data: Data,
52        theta: Theta,
53        psi: Psi,
54        w: Weights,
55        objf: f64,
56        cycles: usize,
57        status: Status,
58        settings: Settings,
59        cyclelog: CycleLog,
60    ) -> Result<Self> {
61        // Calculate the posterior probabilities
62        let posterior = posterior(&psi, &w)
63            .context("Failed to calculate posterior during initialization of NPResult")?;
64
65        let result = Self {
66            equation,
67            data,
68            theta,
69            psi,
70            w,
71            objf,
72            cycles,
73            status,
74            settings,
75            cyclelog,
76            predictions: None,
77            posterior,
78        };
79
80        Ok(result)
81    }
82
83    pub fn cycles(&self) -> usize {
84        self.cycles
85    }
86
87    pub fn objf(&self) -> f64 {
88        self.objf
89    }
90
91    pub fn converged(&self) -> bool {
92        self.status == Status::Stop(StopReason::Converged)
93    }
94
95    pub fn get_theta(&self) -> &Theta {
96        &self.theta
97    }
98
99    pub fn data(&self) -> &Data {
100        &self.data
101    }
102
103    pub fn cycle_log(&self) -> &CycleLog {
104        &self.cyclelog
105    }
106
107    pub fn settings(&self) -> &Settings {
108        &self.settings
109    }
110
111    /// Get the [Psi] structure
112    pub fn psi(&self) -> &Psi {
113        &self.psi
114    }
115
116    /// Get the weights (probabilities) of the support points
117    pub fn weights(&self) -> &Weights {
118        &self.w
119    }
120
121    /// Calculate and store the [NPPredictions] in the [NPResult]
122    ///
123    /// This will overwrite any existing predictions stored in the result!
124    pub fn calculate_predictions(&mut self, idelta: f64, tad: f64) -> Result<()> {
125        let predictions = NPPredictions::calculate(
126            &self.equation,
127            &self.data,
128            &self.theta,
129            &self.w,
130            &self.posterior,
131            idelta,
132            tad,
133        )?;
134        self.predictions = Some(predictions);
135        Ok(())
136    }
137
138    pub fn write_outputs(&mut self) -> Result<()> {
139        if self.settings.output().write {
140            tracing::debug!("Writing outputs to {:?}", self.settings.output().path);
141            self.settings.write()?;
142            let idelta: f64 = self.settings.predictions().idelta;
143            let tad = self.settings.predictions().tad;
144            self.cyclelog.write(&self.settings)?;
145            self.write_theta().context("Failed to write theta")?;
146            self.write_covs().context("Failed to write covariates")?;
147            self.write_predictions(idelta, tad)
148                .context("Failed to write predictions")?;
149            self.write_posterior()
150                .context("Failed to write posterior")?;
151        }
152        Ok(())
153    }
154
155    /// Writes the observations and predictions to a single file
156    pub fn write_obspred(&self) -> Result<()> {
157        tracing::debug!("Writing observations and predictions...");
158
159        #[derive(Debug, Clone, Serialize)]
160        struct Row {
161            id: String,
162            time: f64,
163            outeq: usize,
164            block: usize,
165            obs: Option<f64>,
166            pop_mean: f64,
167            pop_median: f64,
168            post_mean: f64,
169            post_median: f64,
170        }
171
172        let tm = self.theta.matrix();
173        let theta = Array2::from_shape_fn((tm.nrows(), tm.ncols()), |(i, j)| tm[(i, j)]);
174        let w: Array1<f64> = self.w.iter().collect();
175        let pm = self.psi.matrix();
176        let psi = Array2::from_shape_fn((pm.nrows(), pm.ncols()), |(i, j)| pm[(i, j)]);
177
178        let (post_mean, post_median) = posterior_mean_median(&theta, &psi, &w)
179            .context("Failed to calculate posterior mean and median")?;
180
181        let (pop_mean, pop_median) = population_mean_median(&theta, &w)
182            .context("Failed to calculate posterior mean and median")?;
183
184        let subjects = self.data.subjects();
185        if subjects.len() != post_mean.nrows() {
186            bail!(
187                "Number of subjects: {} and number of posterior means: {} do not match",
188                subjects.len(),
189                post_mean.nrows()
190            );
191        }
192
193        let outputfile = OutputFile::new(&self.settings.output().path, "op.csv")?;
194        let mut writer = WriterBuilder::new()
195            .has_headers(true)
196            .from_writer(&outputfile.file);
197
198        for (i, subject) in subjects.iter().enumerate() {
199            for occasion in subject.occasions() {
200                let id = subject.id();
201                let occ = occasion.index();
202
203                let subject = Subject::from_occasions(id.clone(), vec![occasion.clone()]);
204
205                // Population predictions
206                let pop_mean_pred = self
207                    .equation
208                    .simulate_subject(&subject, &pop_mean.to_vec(), None)?
209                    .0
210                    .get_predictions()
211                    .clone();
212
213                let pop_median_pred = self
214                    .equation
215                    .simulate_subject(&subject, &pop_median.to_vec(), None)?
216                    .0
217                    .get_predictions()
218                    .clone();
219
220                // Posterior predictions
221                let post_mean_spp: Vec<f64> = post_mean.row(i).to_vec();
222                let post_mean_pred = self
223                    .equation
224                    .simulate_subject(&subject, &post_mean_spp, None)?
225                    .0
226                    .get_predictions()
227                    .clone();
228                let post_median_spp: Vec<f64> = post_median.row(i).to_vec();
229                let post_median_pred = self
230                    .equation
231                    .simulate_subject(&subject, &post_median_spp, None)?
232                    .0
233                    .get_predictions()
234                    .clone();
235                assert_eq!(
236                    pop_mean_pred.len(),
237                    pop_median_pred.len(),
238                    "The number of predictions do not match (pop_mean vs pop_median)"
239                );
240
241                assert_eq!(
242                    post_mean_pred.len(),
243                    post_median_pred.len(),
244                    "The number of predictions do not match (post_mean vs post_median)"
245                );
246
247                assert_eq!(
248                    pop_mean_pred.len(),
249                    post_mean_pred.len(),
250                    "The number of predictions do not match (pop_mean vs post_mean)"
251                );
252
253                for (((pop_mean_pred, pop_median_pred), post_mean_pred), post_median_pred) in
254                    pop_mean_pred
255                        .iter()
256                        .zip(pop_median_pred.iter())
257                        .zip(post_mean_pred.iter())
258                        .zip(post_median_pred.iter())
259                {
260                    let row = Row {
261                        id: id.clone(),
262                        time: pop_mean_pred.time(),
263                        outeq: pop_mean_pred.outeq(),
264                        block: occ,
265                        obs: pop_mean_pred.observation(),
266                        pop_mean: pop_mean_pred.prediction(),
267                        pop_median: pop_median_pred.prediction(),
268                        post_mean: post_mean_pred.prediction(),
269                        post_median: post_median_pred.prediction(),
270                    };
271                    writer.serialize(row)?;
272                }
273            }
274        }
275        writer.flush()?;
276        tracing::debug!(
277            "Observations with predictions written to {:?}",
278            &outputfile.relative_path()
279        );
280        Ok(())
281    }
282
283    /// Writes theta, which contains the population support points and their associated probabilities
284    /// Each row is one support point, the last column being probability
285    pub fn write_theta(&self) -> Result<()> {
286        tracing::debug!("Writing population parameter distribution...");
287
288        let theta = &self.theta;
289        let w: Vec<f64> = self.w.to_vec();
290
291        if w.len() != theta.matrix().nrows() {
292            bail!(
293                "Number of weights ({}) and number of support points ({}) do not match.",
294                w.len(),
295                theta.matrix().nrows()
296            );
297        }
298
299        let outputfile = OutputFile::new(&self.settings.output().path, "theta.csv")
300            .context("Failed to create output file for theta")?;
301
302        let mut writer = WriterBuilder::new()
303            .has_headers(true)
304            .from_writer(&outputfile.file);
305
306        // Create the headers
307        let mut theta_header = self.settings.parameters().names();
308        theta_header.push("prob".to_string());
309        writer.write_record(&theta_header)?;
310
311        // Write contents
312        for (theta_row, &w_val) in theta.matrix().row_iter().zip(w.iter()) {
313            let mut row: Vec<String> = theta_row.iter().map(|&val| val.to_string()).collect();
314            row.push(w_val.to_string());
315            writer.write_record(&row)?;
316        }
317        writer.flush()?;
318        tracing::debug!(
319            "Population parameter distribution written to {:?}",
320            &outputfile.relative_path()
321        );
322        Ok(())
323    }
324
325    /// Writes the posterior support points for each individual
326    pub fn write_posterior(&self) -> Result<()> {
327        tracing::debug!("Writing posterior parameter probabilities...");
328        let theta = &self.theta;
329
330        // Calculate the posterior probabilities
331        let posterior = self.posterior.clone();
332
333        // Create the output folder if it doesn't exist
334        let outputfile = match OutputFile::new(&self.settings.output().path, "posterior.csv") {
335            Ok(of) => of,
336            Err(e) => {
337                tracing::error!("Failed to create output file: {}", e);
338                return Err(e.context("Failed to create output file"));
339            }
340        };
341
342        // Create a new writer
343        let mut writer = WriterBuilder::new()
344            .has_headers(true)
345            .from_writer(&outputfile.file);
346
347        // Create the headers
348        writer.write_field("id")?;
349        writer.write_field("point")?;
350        theta.param_names().iter().for_each(|name| {
351            writer.write_field(name).unwrap();
352        });
353        writer.write_field("prob")?;
354        writer.write_record(None::<&[u8]>)?;
355
356        // Write contents
357        let subjects = self.data.subjects();
358        posterior
359            .matrix()
360            .row_iter()
361            .enumerate()
362            .for_each(|(i, row)| {
363                let subject = subjects.get(i).unwrap();
364                let id = subject.id();
365
366                row.iter().enumerate().for_each(|(spp, prob)| {
367                    writer.write_field(id.clone()).unwrap();
368                    writer.write_field(spp.to_string()).unwrap();
369
370                    theta.matrix().row(spp).iter().for_each(|val| {
371                        writer.write_field(val.to_string()).unwrap();
372                    });
373
374                    writer.write_field(prob.to_string()).unwrap();
375                    writer.write_record(None::<&[u8]>).unwrap();
376                });
377            });
378
379        writer.flush()?;
380        tracing::debug!(
381            "Posterior parameters written to {:?}",
382            &outputfile.relative_path()
383        );
384
385        Ok(())
386    }
387
388    /// Writes the predictions
389    pub fn write_predictions(&mut self, idelta: f64, tad: f64) -> Result<()> {
390        tracing::debug!("Writing predictions...");
391
392        self.calculate_predictions(idelta, tad)?;
393
394        let predictions = self
395            .predictions
396            .as_ref()
397            .expect("Predictions should have been calculated, but are of type None.");
398
399        // Write (full) predictions to pred.csv
400        let outputfile_pred = OutputFile::new(&self.settings.output().path, "pred.csv")?;
401        let mut writer = WriterBuilder::new()
402            .has_headers(true)
403            .from_writer(&outputfile_pred.file);
404
405        // Write each prediction row
406        for row in predictions.predictions() {
407            writer.serialize(row)?;
408        }
409
410        writer.flush()?;
411        tracing::debug!(
412            "Predictions written to {:?}",
413            &outputfile_pred.relative_path()
414        );
415
416        Ok(())
417    }
418
419    /// Writes the covariates
420    pub fn write_covs(&self) -> Result<()> {
421        tracing::debug!("Writing covariates...");
422        let outputfile = OutputFile::new(&self.settings.output().path, "covs.csv")?;
423        let mut writer = WriterBuilder::new()
424            .has_headers(true)
425            .from_writer(&outputfile.file);
426
427        // Collect all unique covariate names
428        let mut covariate_names = std::collections::HashSet::new();
429        for subject in self.data.subjects() {
430            for occasion in subject.occasions() {
431                let cov = occasion.covariates();
432                let covmap = cov.covariates();
433                for cov_name in covmap.keys() {
434                    covariate_names.insert(cov_name.clone());
435                }
436            }
437        }
438        let mut covariate_names: Vec<String> = covariate_names.into_iter().collect();
439        covariate_names.sort(); // Ensure consistent order
440
441        // Write the header row: id, time, block, covariate names
442        let mut headers = vec!["id", "time", "block"];
443        headers.extend(covariate_names.iter().map(|s| s.as_str()));
444        writer.write_record(&headers)?;
445
446        // Write the data rows
447        for subject in self.data.subjects() {
448            for occasion in subject.occasions() {
449                let cov = occasion.covariates();
450                let covmap = cov.covariates();
451
452                for event in occasion.iter() {
453                    let time = match event {
454                        Event::Bolus(bolus) => bolus.time(),
455                        Event::Infusion(infusion) => infusion.time(),
456                        Event::Observation(observation) => observation.time(),
457                    };
458
459                    let mut row: Vec<String> = Vec::new();
460                    row.push(subject.id().clone());
461                    row.push(time.to_string());
462                    row.push(occasion.index().to_string());
463
464                    // Add covariate values to the row
465                    for cov_name in &covariate_names {
466                        if let Some(cov) = covmap.get(cov_name) {
467                            if let Ok(value) = cov.interpolate(time) {
468                                row.push(value.to_string());
469                            } else {
470                                row.push(String::new());
471                            }
472                        } else {
473                            row.push(String::new());
474                        }
475                    }
476
477                    writer.write_record(&row)?;
478                }
479            }
480        }
481
482        writer.flush()?;
483        tracing::debug!("Covariates written to {:?}", &outputfile.relative_path());
484        Ok(())
485    }
486}
487
/// Compute the (unweighted) median of `data`.
///
/// For an even number of values the two middle values are averaged. An empty slice has no
/// median and yields `f64::NAN` instead of panicking. Values are ordered with
/// [`f64::total_cmp`], so NaN inputs do not cause a panic (positive NaNs sort after every number,
/// negative NaNs before).
pub(crate) fn median(data: &[f64]) -> f64 {
    if data.is_empty() {
        return f64::NAN;
    }

    let mut sorted = data.to_vec();
    sorted.sort_unstable_by(f64::total_cmp);

    let mid = sorted.len() / 2;
    if sorted.len() % 2 == 0 {
        (sorted[mid - 1] + sorted[mid]) / 2.0
    } else {
        sorted[mid]
    }
}
502
/// Compute the weighted median of `data` using the non-negative `weights`.
///
/// The values are sorted, and the first value at which the cumulative weight exceeds half of
/// the total weight is returned. If the cumulative weight lands on the halfway point (up to
/// floating-point rounding), the result is the average of that value and the next value that
/// carries positive weight. This mirrors the even-length rule of an ordinary median.
///
/// Degenerate inputs:
/// - an empty slice yields `f64::NAN`;
/// - if every weight is zero, all values are weighted equally (the unweighted median).
///
/// Values are ordered with [`f64::total_cmp`], so NaN data does not cause a panic.
///
/// # Panics
///
/// Panics if `data` and `weights` differ in length, or if any weight is negative or NaN.
fn weighted_median(data: &[f64], weights: &[f64]) -> f64 {
    assert_eq!(
        data.len(),
        weights.len(),
        "The length of data and weights must be the same"
    );
    assert!(
        weights.iter().all(|&x| x >= 0.0),
        "Weights must be non-negative, weights: {:?}",
        weights
    );

    if data.is_empty() {
        return f64::NAN;
    }

    let raw_total: f64 = weights.iter().sum();
    // Without any weight there is nothing to prefer, so fall back to equal weights.
    let uniform = raw_total == 0.0;
    let total_weight = if uniform { data.len() as f64 } else { raw_total };

    let mut weighted_data: Vec<(f64, f64)> = data
        .iter()
        .zip(weights.iter())
        .map(|(&d, &w)| (d, if uniform { 1.0 } else { w }))
        .collect();
    weighted_data.sort_by(|a, b| a.0.total_cmp(&b.0));

    let half = total_weight / 2.0;
    // The total and the running sum add the same weights in different orders. This is an
    // upper bound on the rounding error that can make an exact tie look like a miss.
    let tolerance = 2.0 * data.len() as f64 * f64::EPSILON * total_weight;

    let mut cumulative_sum = 0.0;
    for (i, &(value, weight)) in weighted_data.iter().enumerate() {
        cumulative_sum += weight;

        if (cumulative_sum - half).abs() <= tolerance {
            // Exactly half of the weight lies at or below `value`: average with the next value
            // that actually carries weight (zero-weight values must not pull the median).
            return match weighted_data[i + 1..].iter().find(|&&(_, w)| w > 0.0) {
                Some(&(next_value, _)) => (value + next_value) / 2.0,
                None => value,
            };
        }
        if cumulative_sum > half {
            return value;
        }
    }

    // Only reachable through pathological rounding; the largest value is the safe answer.
    weighted_data[weighted_data.len() - 1].0
}
547
548pub fn population_mean_median(
549    theta: &Array2<f64>,
550    w: &Array1<f64>,
551) -> Result<(Array1<f64>, Array1<f64>)> {
552    let w = if w.is_empty() {
553        tracing::warn!("w.len() == 0, setting all weights to 1/n");
554        Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
555    } else {
556        w.clone()
557    };
558    // Check for compatible sizes
559    if theta.nrows() != w.len() {
560        bail!(
561            "Number of parameters and number of weights do not match. Theta: {}, w: {}",
562            theta.nrows(),
563            w.len()
564        );
565    }
566
567    let mut mean = Array1::zeros(theta.ncols());
568    let mut median = Array1::zeros(theta.ncols());
569
570    for (i, (mn, mdn)) in mean.iter_mut().zip(&mut median).enumerate() {
571        // Calculate the weighted mean
572        let col = theta.column(i).to_owned() * w.to_owned();
573        *mn = col.sum();
574
575        // Calculate the median
576        let ct = theta.column(i);
577        let mut params = vec![];
578        let mut weights = vec![];
579        for (ti, wi) in ct.iter().zip(w.clone()) {
580            params.push(*ti);
581            weights.push(wi);
582        }
583
584        *mdn = weighted_median(&params, &weights);
585    }
586
587    Ok((mean, median))
588}
589
590pub fn posterior_mean_median(
591    theta: &Array2<f64>,
592    psi: &Array2<f64>,
593    w: &Array1<f64>,
594) -> Result<(Array2<f64>, Array2<f64>)> {
595    let mut mean = Array2::zeros((0, theta.ncols()));
596    let mut median = Array2::zeros((0, theta.ncols()));
597
598    let w = if w.is_empty() {
599        tracing::warn!("w is empty, setting all weights to 1/n");
600        Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
601    } else {
602        w.clone()
603    };
604
605    // Check for compatible sizes
606    if theta.nrows() != w.len() || theta.nrows() != psi.ncols() || psi.ncols() != w.len() {
607        bail!("Number of parameters and number of weights do not match, theta.nrows(): {}, w.len(): {}, psi.ncols(): {}", theta.nrows(), w.len(), psi.ncols());
608    }
609
610    // Normalize psi to get probabilities of each spp for each id
611    let mut psi_norm: Array2<f64> = Array2::zeros((0, psi.ncols()));
612    for (i, row) in psi.axis_iter(Axis(0)).enumerate() {
613        let row_w = row.to_owned() * w.to_owned();
614        let row_sum = row_w.sum();
615        let row_norm = if row_sum == 0.0 {
616            tracing::warn!("Sum of row {} of psi is 0.0, setting that row to 1/n", i);
617            Array1::from_elem(psi.ncols(), 1.0 / psi.ncols() as f64)
618        } else {
619            &row_w / row_sum
620        };
621        psi_norm.push_row(row_norm.view())?;
622    }
623    if psi_norm.iter().any(|&x| x.is_nan()) {
624        dbg!(&psi);
625        bail!("NaN values found in psi_norm");
626    };
627
628    // Transpose normalized psi to get ID (col) by prob (row)
629    // let psi_norm_transposed = psi_norm.t();
630
631    // For each subject..
632    for probs in psi_norm.axis_iter(Axis(0)) {
633        let mut post_mean: Vec<f64> = Vec::new();
634        let mut post_median: Vec<f64> = Vec::new();
635
636        // For each parameter
637        for pars in theta.axis_iter(Axis(1)) {
638            // Calculate the mean
639            let weighted_par = &probs * &pars;
640            let the_mean = weighted_par.sum();
641            post_mean.push(the_mean);
642
643            // Calculate the median
644            let median = weighted_median(&pars.to_vec(), &probs.to_vec());
645            post_median.push(median);
646        }
647
648        mean.push_row(Array::from(post_mean.clone()).view())?;
649        median.push_row(Array::from(post_median.clone()).view())?;
650    }
651
652    Ok((mean, median))
653}
654
/// Contains all the necessary information of an output file
///
/// Bundles the opened file handle with the path it was opened at, so writers can log where
/// their output ended up.
#[derive(Debug)]
pub struct OutputFile {
    /// Handle to the opened file.
    file: File,
    /// Path of the file: the output folder joined with the file name.
    relative_path: PathBuf,
}
661
662impl OutputFile {
663    pub fn new(folder: &str, file_name: &str) -> Result<Self> {
664        let relative_path = Path::new(&folder).join(file_name);
665
666        if let Some(parent) = relative_path.parent() {
667            create_dir_all(parent)
668                .with_context(|| format!("Failed to create directories for {:?}", parent))?;
669        }
670
671        let file = OpenOptions::new()
672            .write(true)
673            .create(true)
674            .truncate(true)
675            .open(&relative_path)
676            .with_context(|| format!("Failed to open file: {:?}", relative_path))?;
677
678        Ok(OutputFile {
679            file,
680            relative_path,
681        })
682    }
683
684    pub fn file(&self) -> &File {
685        &self.file
686    }
687
688    pub fn file_owned(self) -> File {
689        self.file
690    }
691
692    pub fn relative_path(&self) -> &Path {
693        &self.relative_path
694    }
695}
696
#[cfg(test)]
mod tests {
    use super::{median, weighted_median};

    // Unweighted median

    #[test]
    fn test_median_odd() {
        assert_eq!(median(&[1.0, 3.0, 2.0]), 2.0);
    }

    #[test]
    fn test_median_even() {
        assert_eq!(median(&[1.0, 2.0, 3.0, 4.0]), 2.5);
    }

    #[test]
    fn test_median_single() {
        assert_eq!(median(&[42.0]), 42.0);
    }

    #[test]
    fn test_median_sorted() {
        assert_eq!(median(&[5.0, 10.0, 15.0, 20.0, 25.0]), 15.0);
    }

    #[test]
    fn test_median_unsorted() {
        assert_eq!(median(&[10.0, 30.0, 20.0, 50.0, 40.0]), 30.0);
    }

    #[test]
    fn test_median_with_duplicates() {
        assert_eq!(median(&[1.0, 2.0, 2.0, 3.0, 4.0]), 2.0);
    }

    // Weighted median

    #[test]
    fn test_weighted_median_simple() {
        assert_eq!(weighted_median(&[1.0, 2.0, 3.0], &[0.2, 0.5, 0.3]), 2.0);
    }

    #[test]
    fn test_weighted_median_even_weights() {
        // Cumulative weight hits exactly one half after 2.0, so 2.0 and 3.0 are averaged.
        assert_eq!(
            weighted_median(&[1.0, 2.0, 3.0, 4.0], &[0.25, 0.25, 0.25, 0.25]),
            2.5
        );
    }

    #[test]
    fn test_weighted_median_single_element() {
        assert_eq!(weighted_median(&[42.0], &[1.0]), 42.0);
    }

    #[test]
    #[should_panic(expected = "The length of data and weights must be the same")]
    fn test_weighted_median_mismatched_lengths() {
        weighted_median(&[1.0, 2.0, 3.0], &[0.1, 0.2]);
    }

    #[test]
    fn test_weighted_median_all_same_elements() {
        assert_eq!(
            weighted_median(&[5.0, 5.0, 5.0, 5.0], &[0.1, 0.2, 0.3, 0.4]),
            5.0
        );
    }

    #[test]
    #[should_panic(expected = "Weights must be non-negative")]
    fn test_weighted_median_negative_weights() {
        let _ = weighted_median(&[1.0, 2.0, 3.0, 4.0], &[0.2, -0.5, 0.5, 0.8]);
    }

    #[test]
    fn test_weighted_median_unsorted_data() {
        // Sorted: 1.0 (0.3), 2.0 (0.2), 3.0 (0.1), 4.0 (0.4) -> half the weight ends at 2.0.
        assert_eq!(
            weighted_median(&[3.0, 1.0, 4.0, 2.0], &[0.1, 0.3, 0.4, 0.2]),
            2.5
        );
    }

    #[test]
    fn test_weighted_median_with_zero_weights() {
        assert_eq!(
            weighted_median(&[1.0, 2.0, 3.0, 4.0], &[0.0, 0.0, 1.0, 0.0]),
            3.0
        );
    }
}