pmcore/routines/output/
mod.rs

use crate::algorithms::{Status, StopReason};
use crate::prelude::*;
use crate::routines::output::cycles::CycleLog;
use crate::routines::output::posterior::Posterior;
use crate::routines::output::predictions::NPPredictions;
use crate::routines::settings::Settings;
use crate::structs::psi::Psi;
use crate::structs::theta::Theta;
use crate::structs::weights::Weights;
use anyhow::{bail, Context, Result};
use csv::WriterBuilder;
use faer::linalg::zip::IntoView;
use faer_ext::IntoNdarray;
use ndarray::{Array, Array1, Array2, Axis};
use pharmsol::prelude::data::*;
use pharmsol::prelude::simulator::Equation;
use serde::Serialize;
use std::fs::{create_dir_all, File, OpenOptions};
use std::path::{Path, PathBuf};

pub mod cycles;
pub mod posterior;
pub mod predictions;

use posterior::posterior;

/// Defines the result objects from an NPAG run
/// An [NPResult] contains the necessary information to generate predictions and summary statistics
#[derive(Debug, Serialize)]
pub struct NPResult<E: Equation> {
    #[serde(skip)]
    equation: E,
    data: Data,
    theta: Theta,
    psi: Psi,
    w: Weights,
    objf: f64,
    cycles: usize,
    status: Status,
    settings: Settings,
    cyclelog: CycleLog,
    predictions: Option<NPPredictions>,
    posterior: Posterior,
}

#[allow(clippy::too_many_arguments)]
impl<E: Equation> NPResult<E> {
    /// Create a new NPResult object
    ///
    /// This will also calculate the [Posterior] structure and add it to the NPResult
    pub(crate) fn new(
        equation: E,
        data: Data,
        theta: Theta,
        psi: Psi,
        w: Weights,
        objf: f64,
        cycles: usize,
        status: Status,
        settings: Settings,
        cyclelog: CycleLog,
    ) -> Result<Self> {
        // Calculate the posterior probabilities
        let posterior = posterior(&psi, &w)
            .context("Failed to calculate posterior during initialization of NPResult")?;

        let result = Self {
            equation,
            data,
            theta,
            psi,
            w,
            objf,
            cycles,
            status,
            settings,
            cyclelog,
            predictions: None,
            posterior,
        };

        Ok(result)
    }

    pub fn cycles(&self) -> usize {
        self.cycles
    }

    pub fn objf(&self) -> f64 {
        self.objf
    }

    pub fn converged(&self) -> bool {
        self.status == Status::Stop(StopReason::Converged)
    }

    pub fn get_theta(&self) -> &Theta {
        &self.theta
    }

    pub fn data(&self) -> &Data {
        &self.data
    }

    pub fn cycle_log(&self) -> &CycleLog {
        &self.cyclelog
    }

    pub fn settings(&self) -> &Settings {
        &self.settings
    }

    /// Get the [Psi] structure
    pub fn psi(&self) -> &Psi {
        &self.psi
    }

    /// Get the weights (probabilities) of the support points
    pub fn weights(&self) -> &Weights {
        &self.w
    }

    /// Calculate and store the [NPPredictions] in the [NPResult]
    ///
    /// This will overwrite any existing predictions stored in the result!
    pub fn calculate_predictions(&mut self, idelta: f64, tad: f64) -> Result<()> {
        let predictions = NPPredictions::calculate(
            &self.equation,
            &self.data,
            &self.theta,
            &self.w,
            &self.posterior,
            idelta,
            tad,
        )?;
        self.predictions = Some(predictions);
        Ok(())
    }

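    /// Write all enabled outputs (settings, cycle log, theta, covariates,
    /// predictions, and posterior) to the configured output directory.
    ///
    /// This is a no-op if output writing is disabled in the settings.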
    pub fn write_outputs(&mut self) -> Result<()> {
        if self.settings.output().write {
            tracing::debug!("Writing outputs to {:?}", self.settings.output().path);
            self.settings.write()?;
            let idelta: f64 = self.settings.predictions().idelta;
            let tad = self.settings.predictions().tad;
            self.cyclelog.write(&self.settings)?;
            self.write_theta().context("Failed to write theta")?;
            self.write_covs().context("Failed to write covariates")?;
            self.write_predictions(idelta, tad)
                .context("Failed to write predictions")?;
            self.write_posterior()
                .context("Failed to write posterior")?;
        }
        Ok(())
    }

    /// Writes the observations and predictions to a single file
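    ///
    /// Each row of `op.csv` holds one prediction: subject `id`, `time`, output
    /// equation (`outeq`), occasion (`block`), the observation (if any), and the
    /// population and posterior mean/median predictions.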
    pub fn write_obspred(&self) -> Result<()> {
        tracing::debug!("Writing observations and predictions...");

        #[derive(Debug, Clone, Serialize)]
        struct Row {
            id: String,
            time: f64,
            outeq: usize,
            block: usize,
            obs: Option<f64>,
            pop_mean: f64,
            pop_median: f64,
            post_mean: f64,
            post_median: f64,
        }

        let theta: Array2<f64> = self
            .theta
            .matrix()
            .clone()
            .as_mut()
            .into_ndarray()
            .to_owned();
        let w: Array1<f64> = self
            .w
            .weights()
            .clone()
            .into_view()
            .iter()
            .cloned()
            .collect();
        let psi: Array2<f64> = self.psi.matrix().as_ref().into_ndarray().to_owned();

        let (post_mean, post_median) = posterior_mean_median(&theta, &psi, &w)
            .context("Failed to calculate posterior mean and median")?;

        let (pop_mean, pop_median) = population_mean_median(&theta, &w)
            .context("Failed to calculate population mean and median")?;

        let subjects = self.data.subjects();
        if subjects.len() != post_mean.nrows() {
            bail!(
                "Number of subjects: {} and number of posterior means: {} do not match",
                subjects.len(),
                post_mean.nrows()
            );
        }

        let outputfile = OutputFile::new(&self.settings.output().path, "op.csv")?;
        let mut writer = WriterBuilder::new()
            .has_headers(true)
            .from_writer(&outputfile.file);

        for (i, subject) in subjects.iter().enumerate() {
            for occasion in subject.occasions() {
                let id = subject.id();
                let occ = occasion.index();

                let subject = Subject::from_occasions(id.clone(), vec![occasion.clone()]);

                // Population predictions
                let pop_mean_pred = self
                    .equation
                    .simulate_subject(&subject, &pop_mean.to_vec(), None)?
                    .0
                    .get_predictions()
                    .clone();

                let pop_median_pred = self
                    .equation
                    .simulate_subject(&subject, &pop_median.to_vec(), None)?
                    .0
                    .get_predictions()
                    .clone();

                // Posterior predictions
                let post_mean_spp: Vec<f64> = post_mean.row(i).to_vec();
                let post_mean_pred = self
                    .equation
                    .simulate_subject(&subject, &post_mean_spp, None)?
                    .0
                    .get_predictions()
                    .clone();
                let post_median_spp: Vec<f64> = post_median.row(i).to_vec();
                let post_median_pred = self
                    .equation
                    .simulate_subject(&subject, &post_median_spp, None)?
                    .0
                    .get_predictions()
                    .clone();
                assert_eq!(
                    pop_mean_pred.len(),
                    pop_median_pred.len(),
                    "The number of predictions does not match (pop_mean vs pop_median)"
                );

                assert_eq!(
                    post_mean_pred.len(),
                    post_median_pred.len(),
                    "The number of predictions does not match (post_mean vs post_median)"
                );

                assert_eq!(
                    pop_mean_pred.len(),
                    post_mean_pred.len(),
                    "The number of predictions does not match (pop_mean vs post_mean)"
                );

                for (((pop_mean_pred, pop_median_pred), post_mean_pred), post_median_pred) in
                    pop_mean_pred
                        .iter()
                        .zip(pop_median_pred.iter())
                        .zip(post_mean_pred.iter())
                        .zip(post_median_pred.iter())
                {
                    let row = Row {
                        id: id.clone(),
                        time: pop_mean_pred.time(),
                        outeq: pop_mean_pred.outeq(),
                        block: occ,
                        obs: pop_mean_pred.observation(),
                        pop_mean: pop_mean_pred.prediction(),
                        pop_median: pop_median_pred.prediction(),
                        post_mean: post_mean_pred.prediction(),
                        post_median: post_median_pred.prediction(),
                    };
                    writer.serialize(row)?;
                }
            }
        }
        writer.flush()?;
        tracing::debug!(
            "Observations with predictions written to {:?}",
            &outputfile.relative_path()
        );
        Ok(())
    }

    /// Writes theta, which contains the population support points and their associated probabilities.
    /// Each row is one support point; the last column is its probability.
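    ///
    /// As an illustration (the parameter names `ke` and `v` are hypothetical),
    /// the resulting `theta.csv` has the shape:
    ///
    /// ```text
    /// ke,v,prob
    /// 0.1,50.0,0.4
    /// 0.2,75.0,0.6
    /// ```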
    pub fn write_theta(&self) -> Result<()> {
        tracing::debug!("Writing population parameter distribution...");

        let theta = &self.theta;
        let w: Vec<f64> = self
            .w
            .weights()
            .clone()
            .into_view()
            .iter()
            .cloned()
            .collect();

        if w.len() != theta.matrix().nrows() {
            bail!(
                "Number of weights ({}) and number of support points ({}) do not match.",
                w.len(),
                theta.matrix().nrows()
            );
        }

        let outputfile = OutputFile::new(&self.settings.output().path, "theta.csv")
            .context("Failed to create output file for theta")?;

        let mut writer = WriterBuilder::new()
            .has_headers(true)
            .from_writer(&outputfile.file);

        // Create the headers
        let mut theta_header = self.settings.parameters().names();
        theta_header.push("prob".to_string());
        writer.write_record(&theta_header)?;

        // Write contents
        for (theta_row, &w_val) in theta.matrix().row_iter().zip(w.iter()) {
            let mut row: Vec<String> = theta_row.iter().map(|&val| val.to_string()).collect();
            row.push(w_val.to_string());
            writer.write_record(&row)?;
        }
        writer.flush()?;
        tracing::debug!(
            "Population parameter distribution written to {:?}",
            &outputfile.relative_path()
        );
        Ok(())
    }

    /// Writes the posterior support points for each individual
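    ///
    /// Each row of `posterior.csv` pairs one subject (`id`) with one support point
    /// (`point`), followed by the parameter values of that point and its posterior
    /// probability (`prob`).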
    pub fn write_posterior(&self) -> Result<()> {
        tracing::debug!("Writing posterior parameter probabilities...");
        let theta = &self.theta;

        // Use the posterior probabilities computed at initialization
        let posterior = self.posterior.clone();

        // Create the output folder if it doesn't exist
        let outputfile = match OutputFile::new(&self.settings.output().path, "posterior.csv") {
            Ok(of) => of,
            Err(e) => {
                tracing::error!("Failed to create output file: {}", e);
                return Err(e.context("Failed to create output file"));
            }
        };

        // Create a new writer
        let mut writer = WriterBuilder::new()
            .has_headers(true)
            .from_writer(&outputfile.file);

        // Create the headers
        writer.write_field("id")?;
        writer.write_field("point")?;
        theta.param_names().iter().for_each(|name| {
            writer.write_field(name).unwrap();
        });
        writer.write_field("prob")?;
        writer.write_record(None::<&[u8]>)?;

        // Write contents
        let subjects = self.data.subjects();
        posterior
            .matrix()
            .row_iter()
            .enumerate()
            .for_each(|(i, row)| {
                let subject = subjects.get(i).unwrap();
                let id = subject.id();

                row.iter().enumerate().for_each(|(spp, prob)| {
                    writer.write_field(id.clone()).unwrap();
                    writer.write_field(spp.to_string()).unwrap();

                    theta.matrix().row(spp).iter().for_each(|val| {
                        writer.write_field(val.to_string()).unwrap();
                    });

                    writer.write_field(prob.to_string()).unwrap();
                    writer.write_record(None::<&[u8]>).unwrap();
                });
            });

        writer.flush()?;
        tracing::debug!(
            "Posterior parameters written to {:?}",
            &outputfile.relative_path()
        );

        Ok(())
    }

    /// Writes the predictions
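    ///
    /// Recalculates the predictions with the given interpolation step (`idelta`)
    /// and time-after-dose (`tad`) before writing them to `pred.csv`.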
    pub fn write_predictions(&mut self, idelta: f64, tad: f64) -> Result<()> {
        tracing::debug!("Writing predictions...");

        self.calculate_predictions(idelta, tad)?;

        let predictions = self
            .predictions
            .as_ref()
            .expect("Predictions should have been calculated, but are of type None.");

        // Write (full) predictions to pred.csv
        let outputfile_pred = OutputFile::new(&self.settings.output().path, "pred.csv")?;
        let mut writer = WriterBuilder::new()
            .has_headers(true)
            .from_writer(&outputfile_pred.file);

        // Write each prediction row
        for row in predictions.predictions() {
            writer.serialize(row)?;
        }

        writer.flush()?;
        tracing::debug!(
            "Predictions written to {:?}",
            &outputfile_pred.relative_path()
        );

        Ok(())
    }

    /// Writes the covariates
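    ///
    /// Produces `covs.csv` with one row per event (`id`, `time`, `block`) and one
    /// column per covariate, interpolated at the event time where possible; missing
    /// values are left empty.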
    pub fn write_covs(&self) -> Result<()> {
        tracing::debug!("Writing covariates...");
        let outputfile = OutputFile::new(&self.settings.output().path, "covs.csv")?;
        let mut writer = WriterBuilder::new()
            .has_headers(true)
            .from_writer(&outputfile.file);

        // Collect all unique covariate names
        let mut covariate_names = std::collections::HashSet::new();
        for subject in self.data.subjects() {
            for occasion in subject.occasions() {
                let cov = occasion.covariates();
                let covmap = cov.covariates();
                for cov_name in covmap.keys() {
                    covariate_names.insert(cov_name.clone());
                }
            }
        }
        let mut covariate_names: Vec<String> = covariate_names.into_iter().collect();
        covariate_names.sort(); // Ensure consistent order

        // Write the header row: id, time, block, covariate names
        let mut headers = vec!["id", "time", "block"];
        headers.extend(covariate_names.iter().map(|s| s.as_str()));
        writer.write_record(&headers)?;

        // Write the data rows
        for subject in self.data.subjects() {
            for occasion in subject.occasions() {
                let cov = occasion.covariates();
                let covmap = cov.covariates();

                for event in occasion.iter() {
                    let time = match event {
                        Event::Bolus(bolus) => bolus.time(),
                        Event::Infusion(infusion) => infusion.time(),
                        Event::Observation(observation) => observation.time(),
                    };

                    let mut row: Vec<String> = Vec::new();
                    row.push(subject.id().clone());
                    row.push(time.to_string());
                    row.push(occasion.index().to_string());

                    // Add covariate values to the row
                    for cov_name in &covariate_names {
                        if let Some(cov) = covmap.get(cov_name) {
                            if let Ok(value) = cov.interpolate(time) {
                                row.push(value.to_string());
                            } else {
                                row.push(String::new());
                            }
                        } else {
                            row.push(String::new());
                        }
                    }

                    writer.write_record(&row)?;
                }
            }
        }

        writer.flush()?;
        tracing::debug!("Covariates written to {:?}", &outputfile.relative_path());
        Ok(())
    }
}

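/// Compute the median of a slice by sorting a copy; for an even number of
/// elements, the two middle values are averaged.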
pub(crate) fn median(data: &[f64]) -> f64 {
    let mut data: Vec<f64> = data.to_vec();
    data.sort_by(|a, b| a.partial_cmp(b).unwrap());

    let size = data.len();
    match size {
        even if even % 2 == 0 => {
            let fst = data.get(even / 2 - 1).unwrap();
            let snd = data.get(even / 2).unwrap();
            (fst + snd) / 2.0
        }
        odd => *data.get(odd / 2_usize).unwrap(),
    }
}

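/// Compute the weighted median: the first value (in sorted order) whose cumulative
/// weight exceeds half of the total weight; if the cumulative weight hits exactly
/// half, that value is averaged with the next one. Weights must be non-negative.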
fn weighted_median(data: &[f64], weights: &[f64]) -> f64 {
    // Ensure the data and weights arrays have the same length
    assert_eq!(
        data.len(),
        weights.len(),
        "The length of data and weights must be the same"
    );
    assert!(
        weights.iter().all(|&x| x >= 0.0),
        "Weights must be non-negative, weights: {:?}",
        weights
    );

    // Create a vector of tuples (data, weight)
    let mut weighted_data: Vec<(f64, f64)> = data
        .iter()
        .zip(weights.iter())
        .map(|(&d, &w)| (d, w))
        .collect();

    // Sort the vector by the data values
    weighted_data.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

    // Calculate the cumulative sum of weights
    let total_weight: f64 = weights.iter().sum();
    let mut cumulative_sum = 0.0;

    for (i, &(_, weight)) in weighted_data.iter().enumerate() {
        cumulative_sum += weight;

        if cumulative_sum == total_weight / 2.0 {
            // If the cumulative sum equals half the total weight, average this value with the next
            if i + 1 < weighted_data.len() {
                return (weighted_data[i].0 + weighted_data[i + 1].0) / 2.0;
            } else {
                return weighted_data[i].0;
            }
        } else if cumulative_sum > total_weight / 2.0 {
            return weighted_data[i].0;
        }
    }

    unreachable!("The function should have returned a value before reaching this point.");
}

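/// Calculate the weighted mean and weighted median of each parameter (column of
/// `theta`), using the support-point probabilities `w`; empty weights fall back to
/// a uniform 1/n.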
pub fn population_mean_median(
    theta: &Array2<f64>,
    w: &Array1<f64>,
) -> Result<(Array1<f64>, Array1<f64>)> {
    let w = if w.is_empty() {
        tracing::warn!("w.len() == 0, setting all weights to 1/n");
        Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
    } else {
        w.clone()
    };
    // Check for compatible sizes
    if theta.nrows() != w.len() {
        bail!(
            "Number of support points and number of weights do not match. Theta rows: {}, w: {}",
            theta.nrows(),
            w.len()
        );
    }

    let mut mean = Array1::zeros(theta.ncols());
    let mut median = Array1::zeros(theta.ncols());

    for (i, (mn, mdn)) in mean.iter_mut().zip(&mut median).enumerate() {
        // Calculate the weighted mean
        let col = theta.column(i).to_owned() * w.to_owned();
        *mn = col.sum();

        // Calculate the median
        let ct = theta.column(i);
        let mut params = vec![];
        let mut weights = vec![];
        for (ti, wi) in ct.iter().zip(w.clone()) {
            params.push(*ti);
            weights.push(wi);
        }

        *mdn = weighted_median(&params, &weights);
    }

    Ok((mean, median))
}

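/// Calculate the posterior mean and median of each parameter for every subject:
/// each row of `psi` is reweighted by `w` and normalized to a probability vector
/// over the support points, which is then used to summarize the columns of `theta`.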
pub fn posterior_mean_median(
    theta: &Array2<f64>,
    psi: &Array2<f64>,
    w: &Array1<f64>,
) -> Result<(Array2<f64>, Array2<f64>)> {
    let mut mean = Array2::zeros((0, theta.ncols()));
    let mut median = Array2::zeros((0, theta.ncols()));

    let w = if w.is_empty() {
        tracing::warn!("w is empty, setting all weights to 1/n");
        Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
    } else {
        w.clone()
    };

    // Check for compatible sizes
    if theta.nrows() != w.len() || theta.nrows() != psi.ncols() || psi.ncols() != w.len() {
        bail!("Number of support points, weights, and psi columns do not match: theta.nrows(): {}, w.len(): {}, psi.ncols(): {}", theta.nrows(), w.len(), psi.ncols());
    }

    // Normalize psi to get probabilities of each spp for each id
    let mut psi_norm: Array2<f64> = Array2::zeros((0, psi.ncols()));
    for (i, row) in psi.axis_iter(Axis(0)).enumerate() {
        let row_w = row.to_owned() * w.to_owned();
        let row_sum = row_w.sum();
        let row_norm = if row_sum == 0.0 {
            tracing::warn!("Sum of row {} of psi is 0.0, setting that row to 1/n", i);
            Array1::from_elem(psi.ncols(), 1.0 / psi.ncols() as f64)
        } else {
            &row_w / row_sum
        };
        psi_norm.push_row(row_norm.view())?;
    }
    if psi_norm.iter().any(|&x| x.is_nan()) {
        tracing::error!("NaN values found in psi_norm; psi: {:?}", psi);
        bail!("NaN values found in psi_norm");
    };

    // For each subject...
    for probs in psi_norm.axis_iter(Axis(0)) {
        let mut post_mean: Vec<f64> = Vec::new();
        let mut post_median: Vec<f64> = Vec::new();

        // For each parameter
        for pars in theta.axis_iter(Axis(1)) {
            // Calculate the mean
            let weighted_par = &probs * &pars;
            let the_mean = weighted_par.sum();
            post_mean.push(the_mean);

            // Calculate the median
            let median = weighted_median(&pars.to_vec(), &probs.to_vec());
            post_median.push(median);
        }

        mean.push_row(Array::from(post_mean.clone()).view())?;
        median.push_row(Array::from(post_median.clone()).view())?;
    }

    Ok((mean, median))
}

/// Contains the file handle and relative path of an output file
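///
/// A minimal usage sketch (error handling elided; the surrounding crate paths are
/// assumed):
///
/// ```ignore
/// let out = OutputFile::new("outputs", "example.csv")?;
/// let mut writer = csv::WriterBuilder::new().from_writer(out.file());
/// writer.write_record(&["a", "b"])?;
/// writer.flush()?;
/// println!("Wrote {:?}", out.relative_path());
/// ```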
#[derive(Debug)]
pub struct OutputFile {
    file: File,
    relative_path: PathBuf,
}

impl OutputFile {
    pub fn new(folder: &str, file_name: &str) -> Result<Self> {
        let relative_path = Path::new(&folder).join(file_name);

        if let Some(parent) = relative_path.parent() {
            create_dir_all(parent)
                .with_context(|| format!("Failed to create directories for {:?}", parent))?;
        }

        let file = OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .open(&relative_path)
            .with_context(|| format!("Failed to open file: {:?}", relative_path))?;

        Ok(OutputFile {
            file,
            relative_path,
        })
    }

    pub fn file(&self) -> &File {
        &self.file
    }

    pub fn file_owned(self) -> File {
        self.file
    }

    pub fn relative_path(&self) -> &Path {
        &self.relative_path
    }
}

#[cfg(test)]
mod tests {
    use super::median;

    #[test]
    fn test_median_odd() {
        let data = vec![1.0, 3.0, 2.0];
        assert_eq!(median(&data), 2.0);
    }

    #[test]
    fn test_median_even() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        assert_eq!(median(&data), 2.5);
    }

    #[test]
    fn test_median_single() {
        let data = vec![42.0];
        assert_eq!(median(&data), 42.0);
    }

    #[test]
    fn test_median_sorted() {
        let data = vec![5.0, 10.0, 15.0, 20.0, 25.0];
        assert_eq!(median(&data), 15.0);
    }

    #[test]
    fn test_median_unsorted() {
        let data = vec![10.0, 30.0, 20.0, 50.0, 40.0];
        assert_eq!(median(&data), 30.0);
    }

    #[test]
    fn test_median_with_duplicates() {
        let data = vec![1.0, 2.0, 2.0, 3.0, 4.0];
        assert_eq!(median(&data), 2.0);
    }

    use super::weighted_median;

    #[test]
    fn test_weighted_median_simple() {
        let data = vec![1.0, 2.0, 3.0];
        let weights = vec![0.2, 0.5, 0.3];
        assert_eq!(weighted_median(&data, &weights), 2.0);
    }

    #[test]
    fn test_weighted_median_even_weights() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let weights = vec![0.25, 0.25, 0.25, 0.25];
        assert_eq!(weighted_median(&data, &weights), 2.5);
    }

    #[test]
    fn test_weighted_median_single_element() {
        let data = vec![42.0];
        let weights = vec![1.0];
        assert_eq!(weighted_median(&data, &weights), 42.0);
    }

    #[test]
    #[should_panic(expected = "The length of data and weights must be the same")]
    fn test_weighted_median_mismatched_lengths() {
        let data = vec![1.0, 2.0, 3.0];
        let weights = vec![0.1, 0.2];
        weighted_median(&data, &weights);
    }

    #[test]
    fn test_weighted_median_all_same_elements() {
        let data = vec![5.0, 5.0, 5.0, 5.0];
        let weights = vec![0.1, 0.2, 0.3, 0.4];
        assert_eq!(weighted_median(&data, &weights), 5.0);
    }

    #[test]
    #[should_panic(expected = "Weights must be non-negative")]
    fn test_weighted_median_negative_weights() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let weights = vec![0.2, -0.5, 0.5, 0.8];
        weighted_median(&data, &weights);
    }

    #[test]
    fn test_weighted_median_unsorted_data() {
        let data = vec![3.0, 1.0, 4.0, 2.0];
        let weights = vec![0.1, 0.3, 0.4, 0.2];
        assert_eq!(weighted_median(&data, &weights), 2.5);
    }

    #[test]
    fn test_weighted_median_with_zero_weights() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let weights = vec![0.0, 0.0, 1.0, 0.0];
        assert_eq!(weighted_median(&data, &weights), 3.0);
    }
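
    // Illustrative checks for the weighted summary helpers, using a tiny hand-built
    // grid (two support points, two parameters) with values chosen so the expected
    // results are exact; these are not taken from a real model run.
    use super::{population_mean_median, posterior_mean_median};
    use ndarray::array;

    #[test]
    fn test_population_mean_median_two_points() {
        let theta = array![[1.0, 10.0], [3.0, 30.0]];
        let w = array![0.5, 0.5];
        let (mean, median) = population_mean_median(&theta, &w).unwrap();
        assert_eq!(mean, array![2.0, 20.0]);
        assert_eq!(median, array![2.0, 20.0]);
    }

    #[test]
    fn test_posterior_mean_median_single_subject() {
        // With equal likelihoods for both support points, the posterior summary
        // coincides with the population summary.
        let theta = array![[1.0, 10.0], [3.0, 30.0]];
        let psi = array![[1.0, 1.0]];
        let w = array![0.5, 0.5];
        let (mean, median) = posterior_mean_median(&theta, &psi, &w).unwrap();
        assert_eq!(mean, array![[2.0, 20.0]]);
        assert_eq!(median, array![[2.0, 20.0]]);
    }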
}