// pmcore/routines/output/mod.rs

1use crate::algorithms::Status;
2use crate::prelude::*;
3use crate::routines::output::cycles::CycleLog;
4use crate::routines::output::predictions::NPPredictions;
5use crate::routines::settings::Settings;
6use crate::structs::psi::Psi;
7use crate::structs::theta::Theta;
8use crate::structs::weights::Weights;
9use anyhow::{bail, Context, Result};
10use csv::WriterBuilder;
11use faer::linalg::zip::IntoView;
12use faer_ext::IntoNdarray;
13use ndarray::{Array, Array1, Array2, Axis};
14use pharmsol::prelude::data::*;
15use pharmsol::prelude::simulator::Equation;
16use serde::Serialize;
17use std::fs::{create_dir_all, File, OpenOptions};
18use std::path::{Path, PathBuf};
19
20pub mod cycles;
21pub mod posterior;
22pub mod predictions;
23
24use posterior::posterior;
25
/// Defines the result objects from an NPAG run
/// An [NPResult] contains the necessary information to generate predictions and summary statistics
///
/// All fields are private; they are read through the accessor methods and the `write_*`
/// methods that produce the output files.
#[derive(Debug)]
pub struct NPResult<E: Equation> {
    /// Model equation, used to simulate subjects when writing observed/predicted values.
    equation: E,
    /// Subject data; used when writing predictions, posteriors and covariates.
    data: Data,
    /// Support points: one row per support point, one column per parameter.
    theta: Theta,
    /// Psi matrix; as used by `posterior_mean_median`, rows correspond to subjects and
    /// columns to support points.
    psi: Psi,
    /// Weights (probabilities) of the support points.
    w: Weights,
    /// Objective function value, exposed through `objf()`.
    objf: f64,
    /// Number of cycles, exposed through `cycles()`.
    cycles: usize,
    /// Final status of the run; `converged()` compares it with `Status::Converged`.
    status: Status,
    /// Parameter names taken from the settings at construction; used as the `theta.csv` header.
    par_names: Vec<String>,
    /// Run settings, including output and prediction options.
    settings: Settings,
    /// Per-cycle log, written out by `write_outputs`.
    cyclelog: CycleLog,
}
42
43#[allow(clippy::too_many_arguments)]
44impl<E: Equation> NPResult<E> {
45    /// Create a new NPResult object
46    pub fn new(
47        equation: E,
48        data: Data,
49        theta: Theta,
50        psi: Psi,
51        w: Weights,
52        objf: f64,
53        cycles: usize,
54        status: Status,
55        settings: Settings,
56        cyclelog: CycleLog,
57    ) -> Self {
58        // TODO: Add support for fixed and constant parameters
59
60        let par_names = settings.parameters().names();
61
62        Self {
63            equation,
64            data,
65            theta,
66            psi,
67            w,
68            objf,
69            cycles,
70            status,
71            par_names,
72            settings,
73            cyclelog,
74        }
75    }
76
77    pub fn cycles(&self) -> usize {
78        self.cycles
79    }
80
81    pub fn objf(&self) -> f64 {
82        self.objf
83    }
84
85    pub fn converged(&self) -> bool {
86        self.status == Status::Converged
87    }
88
89    pub fn get_theta(&self) -> &Theta {
90        &self.theta
91    }
92
93    /// Get the [Psi] structure
94    pub fn psi(&self) -> &Psi {
95        &self.psi
96    }
97
98    /// Get the weights (probabilities) of the support points
99    pub fn weights(&self) -> &Weights {
100        &self.w
101    }
102
103    pub fn write_outputs(&self) -> Result<()> {
104        if self.settings.output().write {
105            tracing::debug!("Writing outputs to {:?}", self.settings.output().path);
106            self.settings.write()?;
107            let idelta: f64 = self.settings.predictions().idelta;
108            let tad = self.settings.predictions().tad;
109            self.cyclelog.write(&self.settings)?;
110            self.write_theta().context("Failed to write theta")?;
111            self.write_covs().context("Failed to write covariates")?;
112            self.write_predictions(idelta, tad)
113                .context("Failed to write predictions")?;
114            self.write_posterior()
115                .context("Failed to write posterior")?;
116        }
117        Ok(())
118    }
119
120    /// Writes the observations and predictions to a single file
121    pub fn write_obspred(&self) -> Result<()> {
122        tracing::debug!("Writing observations and predictions...");
123
124        #[derive(Debug, Clone, Serialize)]
125        struct Row {
126            id: String,
127            time: f64,
128            outeq: usize,
129            block: usize,
130            obs: Option<f64>,
131            pop_mean: f64,
132            pop_median: f64,
133            post_mean: f64,
134            post_median: f64,
135        }
136
137        let theta: Array2<f64> = self
138            .theta
139            .matrix()
140            .clone()
141            .as_mut()
142            .into_ndarray()
143            .to_owned();
144        let w: Array1<f64> = self
145            .w
146            .weights()
147            .clone()
148            .into_view()
149            .iter()
150            .cloned()
151            .collect();
152        let psi: Array2<f64> = self.psi.matrix().as_ref().into_ndarray().to_owned();
153
154        let (post_mean, post_median) = posterior_mean_median(&theta, &psi, &w)
155            .context("Failed to calculate posterior mean and median")?;
156
157        let (pop_mean, pop_median) = population_mean_median(&theta, &w)
158            .context("Failed to calculate posterior mean and median")?;
159
160        let subjects = self.data.subjects();
161        if subjects.len() != post_mean.nrows() {
162            bail!(
163                "Number of subjects: {} and number of posterior means: {} do not match",
164                subjects.len(),
165                post_mean.nrows()
166            );
167        }
168
169        let outputfile = OutputFile::new(&self.settings.output().path, "op.csv")?;
170        let mut writer = WriterBuilder::new()
171            .has_headers(true)
172            .from_writer(&outputfile.file);
173
174        for (i, subject) in subjects.iter().enumerate() {
175            for occasion in subject.occasions() {
176                let id = subject.id();
177                let occ = occasion.index();
178
179                let subject = Subject::from_occasions(id.clone(), vec![occasion.clone()]);
180
181                // Population predictions
182                let pop_mean_pred = self
183                    .equation
184                    .simulate_subject(&subject, &pop_mean.to_vec(), None)?
185                    .0
186                    .get_predictions()
187                    .clone();
188
189                let pop_median_pred = self
190                    .equation
191                    .simulate_subject(&subject, &pop_median.to_vec(), None)?
192                    .0
193                    .get_predictions()
194                    .clone();
195
196                // Posterior predictions
197                let post_mean_spp: Vec<f64> = post_mean.row(i).to_vec();
198                let post_mean_pred = self
199                    .equation
200                    .simulate_subject(&subject, &post_mean_spp, None)?
201                    .0
202                    .get_predictions()
203                    .clone();
204                let post_median_spp: Vec<f64> = post_median.row(i).to_vec();
205                let post_median_pred = self
206                    .equation
207                    .simulate_subject(&subject, &post_median_spp, None)?
208                    .0
209                    .get_predictions()
210                    .clone();
211                assert_eq!(
212                    pop_mean_pred.len(),
213                    pop_median_pred.len(),
214                    "The number of predictions do not match (pop_mean vs pop_median)"
215                );
216
217                assert_eq!(
218                    post_mean_pred.len(),
219                    post_median_pred.len(),
220                    "The number of predictions do not match (post_mean vs post_median)"
221                );
222
223                assert_eq!(
224                    pop_mean_pred.len(),
225                    post_mean_pred.len(),
226                    "The number of predictions do not match (pop_mean vs post_mean)"
227                );
228
229                for (((pop_mean_pred, pop_median_pred), post_mean_pred), post_median_pred) in
230                    pop_mean_pred
231                        .iter()
232                        .zip(pop_median_pred.iter())
233                        .zip(post_mean_pred.iter())
234                        .zip(post_median_pred.iter())
235                {
236                    let row = Row {
237                        id: id.clone(),
238                        time: pop_mean_pred.time(),
239                        outeq: pop_mean_pred.outeq(),
240                        block: occ,
241                        obs: pop_mean_pred.observation(),
242                        pop_mean: pop_mean_pred.prediction(),
243                        pop_median: pop_median_pred.prediction(),
244                        post_mean: post_mean_pred.prediction(),
245                        post_median: post_median_pred.prediction(),
246                    };
247                    writer.serialize(row)?;
248                }
249            }
250        }
251        writer.flush()?;
252        tracing::debug!(
253            "Observations with predictions written to {:?}",
254            &outputfile.relative_path()
255        );
256        Ok(())
257    }
258
259    /// Writes theta, which contains the population support points and their associated probabilities
260    /// Each row is one support point, the last column being probability
261    pub fn write_theta(&self) -> Result<()> {
262        tracing::debug!("Writing population parameter distribution...");
263
264        let theta = &self.theta;
265        let w: Vec<f64> = self
266            .w
267            .weights()
268            .clone()
269            .into_view()
270            .iter()
271            .cloned()
272            .collect();
273
274        if w.len() != theta.matrix().nrows() {
275            bail!(
276                "Number of weights ({}) and number of support points ({}) do not match.",
277                w.len(),
278                theta.matrix().nrows()
279            );
280        }
281
282        let outputfile = OutputFile::new(&self.settings.output().path, "theta.csv")
283            .context("Failed to create output file for theta")?;
284
285        let mut writer = WriterBuilder::new()
286            .has_headers(true)
287            .from_writer(&outputfile.file);
288
289        // Create the headers
290        let mut theta_header = self.par_names.clone();
291        theta_header.push("prob".to_string());
292        writer.write_record(&theta_header)?;
293
294        // Write contents
295        for (theta_row, &w_val) in theta.matrix().row_iter().zip(w.iter()) {
296            let mut row: Vec<String> = theta_row.iter().map(|&val| val.to_string()).collect();
297            row.push(w_val.to_string());
298            writer.write_record(&row)?;
299        }
300        writer.flush()?;
301        tracing::debug!(
302            "Population parameter distribution written to {:?}",
303            &outputfile.relative_path()
304        );
305        Ok(())
306    }
307
308    /// Writes the posterior support points for each individual
309    pub fn write_posterior(&self) -> Result<()> {
310        tracing::debug!("Writing posterior parameter probabilities...");
311        let theta = &self.theta;
312        let w = &self.w;
313        let psi = &self.psi;
314
315        // Calculate the posterior probabilities
316        let posterior = posterior(psi, w)?;
317
318        // Create the output folder if it doesn't exist
319        let outputfile = match OutputFile::new(&self.settings.output().path, "posterior.csv") {
320            Ok(of) => of,
321            Err(e) => {
322                tracing::error!("Failed to create output file: {}", e);
323                return Err(e.context("Failed to create output file"));
324            }
325        };
326
327        // Create a new writer
328        let mut writer = WriterBuilder::new()
329            .has_headers(true)
330            .from_writer(&outputfile.file);
331
332        // Create the headers
333        writer.write_field("id")?;
334        writer.write_field("point")?;
335        theta.param_names().iter().for_each(|name| {
336            writer.write_field(name).unwrap();
337        });
338        writer.write_field("prob")?;
339        writer.write_record(None::<&[u8]>)?;
340
341        // Write contents
342        let subjects = self.data.subjects();
343        posterior
344            .matrix()
345            .row_iter()
346            .enumerate()
347            .for_each(|(i, row)| {
348                let subject = subjects.get(i).unwrap();
349                let id = subject.id();
350
351                row.iter().enumerate().for_each(|(spp, prob)| {
352                    writer.write_field(id.clone()).unwrap();
353                    writer.write_field(spp.to_string()).unwrap();
354
355                    theta.matrix().row(spp).iter().for_each(|val| {
356                        writer.write_field(val.to_string()).unwrap();
357                    });
358
359                    writer.write_field(prob.to_string()).unwrap();
360                    writer.write_record(None::<&[u8]>).unwrap();
361                });
362            });
363
364        writer.flush()?;
365        tracing::debug!(
366            "Posterior parameters written to {:?}",
367            &outputfile.relative_path()
368        );
369
370        Ok(())
371    }
372
373    /// Writes the predictions
374    pub fn write_predictions(&self, idelta: f64, tad: f64) -> Result<()> {
375        tracing::debug!("Writing predictions...");
376
377        let posterior = posterior(&self.psi, &self.w)?;
378
379        // Calculate the predictions
380        let predictions = NPPredictions::calculate(
381            &self.equation,
382            &self.data,
383            self.theta.clone(),
384            &self.w,
385            &posterior,
386            idelta,
387            tad,
388        )?;
389
390        // Write (full) predictions to pred.csv
391        let outputfile_pred = OutputFile::new(&self.settings.output().path, "pred.csv")?;
392        let mut writer = WriterBuilder::new()
393            .has_headers(true)
394            .from_writer(&outputfile_pred.file);
395
396        // Write each prediction row
397        for row in predictions.predictions() {
398            writer.serialize(row)?;
399        }
400
401        writer.flush()?;
402        tracing::debug!(
403            "Predictions written to {:?}",
404            &outputfile_pred.relative_path()
405        );
406
407        // Write observations and predictions to op.csv
408        let outputfile_op = OutputFile::new(&self.settings.output().path, "op.csv")?;
409        let mut writer = WriterBuilder::new()
410            .has_headers(true)
411            .from_writer(&outputfile_op.file);
412
413        // Write each prediction row
414        for row in predictions
415            .predictions()
416            .iter()
417            .filter(|r| r.obs().is_some())
418        {
419            writer.serialize(row)?;
420        }
421
422        writer.flush()?;
423        tracing::debug!(
424            "Observed-predicted values written to {:?}",
425            &outputfile_op.relative_path()
426        );
427
428        Ok(())
429    }
430
431    /// Writes the covariates
432    pub fn write_covs(&self) -> Result<()> {
433        tracing::debug!("Writing covariates...");
434        let outputfile = OutputFile::new(&self.settings.output().path, "covs.csv")?;
435        let mut writer = WriterBuilder::new()
436            .has_headers(true)
437            .from_writer(&outputfile.file);
438
439        // Collect all unique covariate names
440        let mut covariate_names = std::collections::HashSet::new();
441        for subject in self.data.subjects() {
442            for occasion in subject.occasions() {
443                let cov = occasion.covariates();
444                let covmap = cov.covariates();
445                for cov_name in covmap.keys() {
446                    covariate_names.insert(cov_name.clone());
447                }
448            }
449        }
450        let mut covariate_names: Vec<String> = covariate_names.into_iter().collect();
451        covariate_names.sort(); // Ensure consistent order
452
453        // Write the header row: id, time, block, covariate names
454        let mut headers = vec!["id", "time", "block"];
455        headers.extend(covariate_names.iter().map(|s| s.as_str()));
456        writer.write_record(&headers)?;
457
458        // Write the data rows
459        for subject in self.data.subjects() {
460            for occasion in subject.occasions() {
461                let cov = occasion.covariates();
462                let covmap = cov.covariates();
463
464                for event in occasion.iter() {
465                    let time = match event {
466                        Event::Bolus(bolus) => bolus.time(),
467                        Event::Infusion(infusion) => infusion.time(),
468                        Event::Observation(observation) => observation.time(),
469                    };
470
471                    let mut row: Vec<String> = Vec::new();
472                    row.push(subject.id().clone());
473                    row.push(time.to_string());
474                    row.push(occasion.index().to_string());
475
476                    // Add covariate values to the row
477                    for cov_name in &covariate_names {
478                        if let Some(cov) = covmap.get(cov_name) {
479                            if let Ok(value) = cov.interpolate(time) {
480                                row.push(value.to_string());
481                            } else {
482                                row.push(String::new());
483                            }
484                        } else {
485                            row.push(String::new());
486                        }
487                    }
488
489                    writer.write_record(&row)?;
490                }
491            }
492        }
493
494        writer.flush()?;
495        tracing::debug!("Covariates written to {:?}", &outputfile.relative_path());
496        Ok(())
497    }
498}
499
/// Returns the (unweighted) median of `data`.
///
/// For an even number of values, the mean of the two middle values is returned.
///
/// Degenerate inputs do not panic. An empty slice, or one containing `NaN`, returns `NaN`.
/// Previously, an empty slice underflowed the middle index and `NaN` made the sort comparator
/// panic.
pub(crate) fn median(data: &[f64]) -> f64 {
    if data.is_empty() || data.iter().any(|x| x.is_nan()) {
        return f64::NAN;
    }

    let mut sorted = data.to_vec();
    // No NaN is left, so `total_cmp` gives the ordinary numeric order without an unwrap.
    sorted.sort_unstable_by(|a, b| a.total_cmp(b));

    let mid = sorted.len() / 2;
    if sorted.len() % 2 == 0 {
        // `len >= 2` here, so `mid - 1` cannot underflow
        (sorted[mid - 1] + sorted[mid]) / 2.0
    } else {
        sorted[mid]
    }
}
514
/// Computes the weighted median of `data` using the matching non-negative `weights`.
///
/// The values are sorted, and the first value whose cumulative weight exceeds half of the total
/// weight is returned. If the cumulative weight lands on the half-way point (within
/// floating-point rounding), the result is the midpoint between that value and the next value
/// that carries positive weight. Zero-weight points therefore never affect the result.
///
/// Degenerate inputs:
/// - an empty slice, or any `NaN` in `data`, returns `NaN`;
/// - if every weight is zero, all points are treated as equally weighted (ordinary median).
///
/// # Panics
/// Panics if `data` and `weights` have different lengths, or if any weight is negative,
/// infinite or `NaN`.
fn weighted_median(data: &[f64], weights: &[f64]) -> f64 {
    // Ensure the data and weights arrays have the same length
    assert_eq!(
        data.len(),
        weights.len(),
        "The length of data and weights must be the same"
    );
    assert!(
        weights.iter().all(|&x| x.is_finite() && x >= 0.0),
        "Weights must be non-negative and finite, weights: {:?}",
        weights
    );

    if data.is_empty() || data.iter().any(|x| x.is_nan()) {
        return f64::NAN;
    }

    // Pair each value with its weight and sort by value (NaN already excluded)
    let mut weighted_data: Vec<(f64, f64)> = data
        .iter()
        .copied()
        .zip(weights.iter().copied())
        .collect();
    weighted_data.sort_by(|a, b| a.0.total_cmp(&b.0));

    // Sum in the same (sorted) order as the running sum below, so both see identical rounding
    let total_weight: f64 = weighted_data.iter().map(|&(_, w)| w).sum();
    if total_weight == 0.0 {
        // The weights carry no information: fall back to uniform weights
        return weighted_median(data, &vec![1.0; data.len()]);
    }

    let half = total_weight / 2.0;
    // Rounding-error bound for a sum of n non-negative terms; treats "almost half" as half
    let tolerance = f64::EPSILON * weighted_data.len() as f64 * total_weight;

    let mut cumulative_sum = 0.0;
    for (i, &(value, weight)) in weighted_data.iter().enumerate() {
        cumulative_sum += weight;

        if (cumulative_sum - half).abs() <= tolerance {
            // Exactly half: average with the next value that actually has weight
            return match weighted_data[i + 1..].iter().find(|&&(_, w)| w > 0.0) {
                Some(&(next, _)) => (value + next) / 2.0,
                None => value,
            };
        }
        if cumulative_sum > half {
            return value;
        }
    }

    // The final running sum equals `total_weight` (> half, finite), so the loop always returns
    unreachable!("The function should have returned a value before reaching this point.");
}
559
560pub fn population_mean_median(
561    theta: &Array2<f64>,
562    w: &Array1<f64>,
563) -> Result<(Array1<f64>, Array1<f64>)> {
564    let w = if w.is_empty() {
565        tracing::warn!("w.len() == 0, setting all weights to 1/n");
566        Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
567    } else {
568        w.clone()
569    };
570    // Check for compatible sizes
571    if theta.nrows() != w.len() {
572        bail!(
573            "Number of parameters and number of weights do not match. Theta: {}, w: {}",
574            theta.nrows(),
575            w.len()
576        );
577    }
578
579    let mut mean = Array1::zeros(theta.ncols());
580    let mut median = Array1::zeros(theta.ncols());
581
582    for (i, (mn, mdn)) in mean.iter_mut().zip(&mut median).enumerate() {
583        // Calculate the weighted mean
584        let col = theta.column(i).to_owned() * w.to_owned();
585        *mn = col.sum();
586
587        // Calculate the median
588        let ct = theta.column(i);
589        let mut params = vec![];
590        let mut weights = vec![];
591        for (ti, wi) in ct.iter().zip(w.clone()) {
592            params.push(*ti);
593            weights.push(wi);
594        }
595
596        *mdn = weighted_median(&params, &weights);
597    }
598
599    Ok((mean, median))
600}
601
602pub fn posterior_mean_median(
603    theta: &Array2<f64>,
604    psi: &Array2<f64>,
605    w: &Array1<f64>,
606) -> Result<(Array2<f64>, Array2<f64>)> {
607    let mut mean = Array2::zeros((0, theta.ncols()));
608    let mut median = Array2::zeros((0, theta.ncols()));
609
610    let w = if w.is_empty() {
611        tracing::warn!("w is empty, setting all weights to 1/n");
612        Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
613    } else {
614        w.clone()
615    };
616
617    // Check for compatible sizes
618    if theta.nrows() != w.len() || theta.nrows() != psi.ncols() || psi.ncols() != w.len() {
619        bail!("Number of parameters and number of weights do not match, theta.nrows(): {}, w.len(): {}, psi.ncols(): {}", theta.nrows(), w.len(), psi.ncols());
620    }
621
622    // Normalize psi to get probabilities of each spp for each id
623    let mut psi_norm: Array2<f64> = Array2::zeros((0, psi.ncols()));
624    for (i, row) in psi.axis_iter(Axis(0)).enumerate() {
625        let row_w = row.to_owned() * w.to_owned();
626        let row_sum = row_w.sum();
627        let row_norm = if row_sum == 0.0 {
628            tracing::warn!("Sum of row {} of psi is 0.0, setting that row to 1/n", i);
629            Array1::from_elem(psi.ncols(), 1.0 / psi.ncols() as f64)
630        } else {
631            &row_w / row_sum
632        };
633        psi_norm.push_row(row_norm.view())?;
634    }
635    if psi_norm.iter().any(|&x| x.is_nan()) {
636        dbg!(&psi);
637        bail!("NaN values found in psi_norm");
638    };
639
640    // Transpose normalized psi to get ID (col) by prob (row)
641    // let psi_norm_transposed = psi_norm.t();
642
643    // For each subject..
644    for probs in psi_norm.axis_iter(Axis(0)) {
645        let mut post_mean: Vec<f64> = Vec::new();
646        let mut post_median: Vec<f64> = Vec::new();
647
648        // For each parameter
649        for pars in theta.axis_iter(Axis(1)) {
650            // Calculate the mean
651            let weighted_par = &probs * &pars;
652            let the_mean = weighted_par.sum();
653            post_mean.push(the_mean);
654
655            // Calculate the median
656            let median = weighted_median(&pars.to_vec(), &probs.to_vec());
657            post_median.push(median);
658        }
659
660        mean.push_row(Array::from(post_mean.clone()).view())?;
661        median.push_row(Array::from(post_median.clone()).view())?;
662    }
663
664    Ok((mean, median))
665}
666
667/// Contains all the necessary information of an output file
668#[derive(Debug)]
669pub struct OutputFile {
670    file: File,
671    relative_path: PathBuf,
672}
673
674impl OutputFile {
675    pub fn new(folder: &str, file_name: &str) -> Result<Self> {
676        let relative_path = Path::new(&folder).join(file_name);
677
678        if let Some(parent) = relative_path.parent() {
679            create_dir_all(parent)
680                .with_context(|| format!("Failed to create directories for {:?}", parent))?;
681        }
682
683        let file = OpenOptions::new()
684            .write(true)
685            .create(true)
686            .truncate(true)
687            .open(&relative_path)
688            .with_context(|| format!("Failed to open file: {:?}", relative_path))?;
689
690        Ok(OutputFile {
691            file,
692            relative_path,
693        })
694    }
695
696    pub fn file(&self) -> &File {
697        &self.file
698    }
699
700    pub fn file_owned(self) -> File {
701        self.file
702    }
703
704    pub fn relative_path(&self) -> &Path {
705        &self.relative_path
706    }
707}
708
#[cfg(test)]
mod tests {
    use super::{median, weighted_median};

    // ---- median: unweighted, mean of the two middle values for even lengths ----

    #[test]
    fn test_median_odd() {
        assert_eq!(median(&[1.0, 3.0, 2.0]), 2.0);
    }

    #[test]
    fn test_median_even() {
        assert_eq!(median(&[1.0, 2.0, 3.0, 4.0]), 2.5);
    }

    #[test]
    fn test_median_single() {
        assert_eq!(median(&[42.0]), 42.0);
    }

    #[test]
    fn test_median_sorted() {
        assert_eq!(median(&[5.0, 10.0, 15.0, 20.0, 25.0]), 15.0);
    }

    #[test]
    fn test_median_unsorted() {
        assert_eq!(median(&[10.0, 30.0, 20.0, 50.0, 40.0]), 30.0);
    }

    #[test]
    fn test_median_with_duplicates() {
        assert_eq!(median(&[1.0, 2.0, 2.0, 3.0, 4.0]), 2.0);
    }

    // ---- weighted_median ----

    #[test]
    fn test_weighted_median_simple() {
        assert_eq!(weighted_median(&[1.0, 2.0, 3.0], &[0.2, 0.5, 0.3]), 2.0);
    }

    #[test]
    fn test_weighted_median_even_weights() {
        assert_eq!(weighted_median(&[1.0, 2.0, 3.0, 4.0], &[0.25; 4]), 2.5);
    }

    #[test]
    fn test_weighted_median_single_element() {
        assert_eq!(weighted_median(&[42.0], &[1.0]), 42.0);
    }

    #[test]
    #[should_panic(expected = "The length of data and weights must be the same")]
    fn test_weighted_median_mismatched_lengths() {
        weighted_median(&[1.0, 2.0, 3.0], &[0.1, 0.2]);
    }

    #[test]
    fn test_weighted_median_all_same_elements() {
        assert_eq!(weighted_median(&[5.0; 4], &[0.1, 0.2, 0.3, 0.4]), 5.0);
    }

    #[test]
    #[should_panic(expected = "Weights must be non-negative")]
    fn test_weighted_median_negative_weights() {
        let _ = weighted_median(&[1.0, 2.0, 3.0, 4.0], &[0.2, -0.5, 0.5, 0.8]);
    }

    #[test]
    fn test_weighted_median_unsorted_data() {
        assert_eq!(
            weighted_median(&[3.0, 1.0, 4.0, 2.0], &[0.1, 0.3, 0.4, 0.2]),
            2.5
        );
    }

    #[test]
    fn test_weighted_median_with_zero_weights() {
        assert_eq!(
            weighted_median(&[1.0, 2.0, 3.0, 4.0], &[0.0, 0.0, 1.0, 0.0]),
            3.0
        );
    }
}