pmcore/routines/output/
mod.rs

1use crate::algorithms::{Status, StopReason};
2use crate::prelude::*;
3use crate::routines::output::cycles::CycleLog;
4use crate::routines::output::predictions::NPPredictions;
5use crate::routines::settings::Settings;
6use crate::structs::psi::Psi;
7use crate::structs::theta::Theta;
8use crate::structs::weights::Weights;
9use anyhow::{bail, Context, Result};
10use csv::WriterBuilder;
11use faer::linalg::zip::IntoView;
12use faer_ext::IntoNdarray;
13use ndarray::{Array, Array1, Array2, Axis};
14use pharmsol::prelude::data::*;
15use pharmsol::prelude::simulator::Equation;
16use serde::Serialize;
17use std::fs::{create_dir_all, File, OpenOptions};
18use std::path::{Path, PathBuf};
19
20pub mod cycles;
21pub mod posterior;
22pub mod predictions;
23
24use posterior::posterior;
25
/// Defines the result objects from an NPAG run
/// An [NPResult] contains the necessary information to generate predictions and summary statistics
///
/// Note: with `#[derive(Serialize)]`, fields are serialized in declaration order, so
/// reordering the fields below changes the serialized output.
#[derive(Debug, Serialize)]
pub struct NPResult<E: Equation> {
    /// The model equation used for simulation; excluded from serialization.
    #[serde(skip)]
    equation: E,
    /// The subject data the run was performed on.
    data: Data,
    /// Population support points; each row is one support point.
    theta: Theta,
    /// Subject-by-support-point matrix consumed together with `w` to compute posteriors.
    psi: Psi,
    /// Weights (probabilities) of the support points in `theta`.
    w: Weights,
    /// Objective function value, as passed to `NPResult::new`.
    objf: f64,
    /// Number of cycles, as passed to `NPResult::new`.
    cycles: usize,
    /// Final status of the run; `converged()` compares it with `Status::Stop(StopReason::Converged)`.
    status: Status,
    /// Parameter names, taken from `settings.parameters().names()` in `NPResult::new`.
    par_names: Vec<String>,
    /// Settings of the run; also control whether and where outputs are written.
    settings: Settings,
    /// Per-cycle log, written out by `write_outputs`.
    cyclelog: CycleLog,
}
43
44#[allow(clippy::too_many_arguments)]
45impl<E: Equation> NPResult<E> {
46    /// Create a new NPResult object
47    pub fn new(
48        equation: E,
49        data: Data,
50        theta: Theta,
51        psi: Psi,
52        w: Weights,
53        objf: f64,
54        cycles: usize,
55        status: Status,
56        settings: Settings,
57        cyclelog: CycleLog,
58    ) -> Self {
59        // TODO: Add support for fixed and constant parameters
60
61        let par_names = settings.parameters().names();
62
63        Self {
64            equation,
65            data,
66            theta,
67            psi,
68            w,
69            objf,
70            cycles,
71            status,
72            par_names,
73            settings,
74            cyclelog,
75        }
76    }
77
78    pub fn cycles(&self) -> usize {
79        self.cycles
80    }
81
82    pub fn objf(&self) -> f64 {
83        self.objf
84    }
85
86    pub fn converged(&self) -> bool {
87        self.status == Status::Stop(StopReason::Converged)
88    }
89
90    pub fn get_theta(&self) -> &Theta {
91        &self.theta
92    }
93
94    /// Get the [Psi] structure
95    pub fn psi(&self) -> &Psi {
96        &self.psi
97    }
98
99    /// Get the weights (probabilities) of the support points
100    pub fn weights(&self) -> &Weights {
101        &self.w
102    }
103
104    pub fn write_outputs(&self) -> Result<()> {
105        if self.settings.output().write {
106            tracing::debug!("Writing outputs to {:?}", self.settings.output().path);
107            self.settings.write()?;
108            let idelta: f64 = self.settings.predictions().idelta;
109            let tad = self.settings.predictions().tad;
110            self.cyclelog.write(&self.settings)?;
111            self.write_theta().context("Failed to write theta")?;
112            self.write_covs().context("Failed to write covariates")?;
113            self.write_predictions(idelta, tad)
114                .context("Failed to write predictions")?;
115            self.write_posterior()
116                .context("Failed to write posterior")?;
117        }
118        Ok(())
119    }
120
121    /// Writes the observations and predictions to a single file
122    pub fn write_obspred(&self) -> Result<()> {
123        tracing::debug!("Writing observations and predictions...");
124
125        #[derive(Debug, Clone, Serialize)]
126        struct Row {
127            id: String,
128            time: f64,
129            outeq: usize,
130            block: usize,
131            obs: Option<f64>,
132            pop_mean: f64,
133            pop_median: f64,
134            post_mean: f64,
135            post_median: f64,
136        }
137
138        let theta: Array2<f64> = self
139            .theta
140            .matrix()
141            .clone()
142            .as_mut()
143            .into_ndarray()
144            .to_owned();
145        let w: Array1<f64> = self
146            .w
147            .weights()
148            .clone()
149            .into_view()
150            .iter()
151            .cloned()
152            .collect();
153        let psi: Array2<f64> = self.psi.matrix().as_ref().into_ndarray().to_owned();
154
155        let (post_mean, post_median) = posterior_mean_median(&theta, &psi, &w)
156            .context("Failed to calculate posterior mean and median")?;
157
158        let (pop_mean, pop_median) = population_mean_median(&theta, &w)
159            .context("Failed to calculate posterior mean and median")?;
160
161        let subjects = self.data.subjects();
162        if subjects.len() != post_mean.nrows() {
163            bail!(
164                "Number of subjects: {} and number of posterior means: {} do not match",
165                subjects.len(),
166                post_mean.nrows()
167            );
168        }
169
170        let outputfile = OutputFile::new(&self.settings.output().path, "op.csv")?;
171        let mut writer = WriterBuilder::new()
172            .has_headers(true)
173            .from_writer(&outputfile.file);
174
175        for (i, subject) in subjects.iter().enumerate() {
176            for occasion in subject.occasions() {
177                let id = subject.id();
178                let occ = occasion.index();
179
180                let subject = Subject::from_occasions(id.clone(), vec![occasion.clone()]);
181
182                // Population predictions
183                let pop_mean_pred = self
184                    .equation
185                    .simulate_subject(&subject, &pop_mean.to_vec(), None)?
186                    .0
187                    .get_predictions()
188                    .clone();
189
190                let pop_median_pred = self
191                    .equation
192                    .simulate_subject(&subject, &pop_median.to_vec(), None)?
193                    .0
194                    .get_predictions()
195                    .clone();
196
197                // Posterior predictions
198                let post_mean_spp: Vec<f64> = post_mean.row(i).to_vec();
199                let post_mean_pred = self
200                    .equation
201                    .simulate_subject(&subject, &post_mean_spp, None)?
202                    .0
203                    .get_predictions()
204                    .clone();
205                let post_median_spp: Vec<f64> = post_median.row(i).to_vec();
206                let post_median_pred = self
207                    .equation
208                    .simulate_subject(&subject, &post_median_spp, None)?
209                    .0
210                    .get_predictions()
211                    .clone();
212                assert_eq!(
213                    pop_mean_pred.len(),
214                    pop_median_pred.len(),
215                    "The number of predictions do not match (pop_mean vs pop_median)"
216                );
217
218                assert_eq!(
219                    post_mean_pred.len(),
220                    post_median_pred.len(),
221                    "The number of predictions do not match (post_mean vs post_median)"
222                );
223
224                assert_eq!(
225                    pop_mean_pred.len(),
226                    post_mean_pred.len(),
227                    "The number of predictions do not match (pop_mean vs post_mean)"
228                );
229
230                for (((pop_mean_pred, pop_median_pred), post_mean_pred), post_median_pred) in
231                    pop_mean_pred
232                        .iter()
233                        .zip(pop_median_pred.iter())
234                        .zip(post_mean_pred.iter())
235                        .zip(post_median_pred.iter())
236                {
237                    let row = Row {
238                        id: id.clone(),
239                        time: pop_mean_pred.time(),
240                        outeq: pop_mean_pred.outeq(),
241                        block: occ,
242                        obs: pop_mean_pred.observation(),
243                        pop_mean: pop_mean_pred.prediction(),
244                        pop_median: pop_median_pred.prediction(),
245                        post_mean: post_mean_pred.prediction(),
246                        post_median: post_median_pred.prediction(),
247                    };
248                    writer.serialize(row)?;
249                }
250            }
251        }
252        writer.flush()?;
253        tracing::debug!(
254            "Observations with predictions written to {:?}",
255            &outputfile.relative_path()
256        );
257        Ok(())
258    }
259
260    /// Writes theta, which contains the population support points and their associated probabilities
261    /// Each row is one support point, the last column being probability
262    pub fn write_theta(&self) -> Result<()> {
263        tracing::debug!("Writing population parameter distribution...");
264
265        let theta = &self.theta;
266        let w: Vec<f64> = self
267            .w
268            .weights()
269            .clone()
270            .into_view()
271            .iter()
272            .cloned()
273            .collect();
274
275        if w.len() != theta.matrix().nrows() {
276            bail!(
277                "Number of weights ({}) and number of support points ({}) do not match.",
278                w.len(),
279                theta.matrix().nrows()
280            );
281        }
282
283        let outputfile = OutputFile::new(&self.settings.output().path, "theta.csv")
284            .context("Failed to create output file for theta")?;
285
286        let mut writer = WriterBuilder::new()
287            .has_headers(true)
288            .from_writer(&outputfile.file);
289
290        // Create the headers
291        let mut theta_header = self.par_names.clone();
292        theta_header.push("prob".to_string());
293        writer.write_record(&theta_header)?;
294
295        // Write contents
296        for (theta_row, &w_val) in theta.matrix().row_iter().zip(w.iter()) {
297            let mut row: Vec<String> = theta_row.iter().map(|&val| val.to_string()).collect();
298            row.push(w_val.to_string());
299            writer.write_record(&row)?;
300        }
301        writer.flush()?;
302        tracing::debug!(
303            "Population parameter distribution written to {:?}",
304            &outputfile.relative_path()
305        );
306        Ok(())
307    }
308
309    /// Writes the posterior support points for each individual
310    pub fn write_posterior(&self) -> Result<()> {
311        tracing::debug!("Writing posterior parameter probabilities...");
312        let theta = &self.theta;
313        let w = &self.w;
314        let psi = &self.psi;
315
316        // Calculate the posterior probabilities
317        let posterior = posterior(psi, w)?;
318
319        // Create the output folder if it doesn't exist
320        let outputfile = match OutputFile::new(&self.settings.output().path, "posterior.csv") {
321            Ok(of) => of,
322            Err(e) => {
323                tracing::error!("Failed to create output file: {}", e);
324                return Err(e.context("Failed to create output file"));
325            }
326        };
327
328        // Create a new writer
329        let mut writer = WriterBuilder::new()
330            .has_headers(true)
331            .from_writer(&outputfile.file);
332
333        // Create the headers
334        writer.write_field("id")?;
335        writer.write_field("point")?;
336        theta.param_names().iter().for_each(|name| {
337            writer.write_field(name).unwrap();
338        });
339        writer.write_field("prob")?;
340        writer.write_record(None::<&[u8]>)?;
341
342        // Write contents
343        let subjects = self.data.subjects();
344        posterior
345            .matrix()
346            .row_iter()
347            .enumerate()
348            .for_each(|(i, row)| {
349                let subject = subjects.get(i).unwrap();
350                let id = subject.id();
351
352                row.iter().enumerate().for_each(|(spp, prob)| {
353                    writer.write_field(id.clone()).unwrap();
354                    writer.write_field(spp.to_string()).unwrap();
355
356                    theta.matrix().row(spp).iter().for_each(|val| {
357                        writer.write_field(val.to_string()).unwrap();
358                    });
359
360                    writer.write_field(prob.to_string()).unwrap();
361                    writer.write_record(None::<&[u8]>).unwrap();
362                });
363            });
364
365        writer.flush()?;
366        tracing::debug!(
367            "Posterior parameters written to {:?}",
368            &outputfile.relative_path()
369        );
370
371        Ok(())
372    }
373
374    /// Writes the predictions
375    pub fn write_predictions(&self, idelta: f64, tad: f64) -> Result<()> {
376        tracing::debug!("Writing predictions...");
377
378        let posterior = posterior(&self.psi, &self.w)?;
379
380        // Calculate the predictions
381        let predictions = NPPredictions::calculate(
382            &self.equation,
383            &self.data,
384            self.theta.clone(),
385            &self.w,
386            &posterior,
387            idelta,
388            tad,
389        )?;
390
391        // Write (full) predictions to pred.csv
392        let outputfile_pred = OutputFile::new(&self.settings.output().path, "pred.csv")?;
393        let mut writer = WriterBuilder::new()
394            .has_headers(true)
395            .from_writer(&outputfile_pred.file);
396
397        // Write each prediction row
398        for row in predictions.predictions() {
399            writer.serialize(row)?;
400        }
401
402        writer.flush()?;
403        tracing::debug!(
404            "Predictions written to {:?}",
405            &outputfile_pred.relative_path()
406        );
407
408        Ok(())
409    }
410
411    /// Writes the covariates
412    pub fn write_covs(&self) -> Result<()> {
413        tracing::debug!("Writing covariates...");
414        let outputfile = OutputFile::new(&self.settings.output().path, "covs.csv")?;
415        let mut writer = WriterBuilder::new()
416            .has_headers(true)
417            .from_writer(&outputfile.file);
418
419        // Collect all unique covariate names
420        let mut covariate_names = std::collections::HashSet::new();
421        for subject in self.data.subjects() {
422            for occasion in subject.occasions() {
423                let cov = occasion.covariates();
424                let covmap = cov.covariates();
425                for cov_name in covmap.keys() {
426                    covariate_names.insert(cov_name.clone());
427                }
428            }
429        }
430        let mut covariate_names: Vec<String> = covariate_names.into_iter().collect();
431        covariate_names.sort(); // Ensure consistent order
432
433        // Write the header row: id, time, block, covariate names
434        let mut headers = vec!["id", "time", "block"];
435        headers.extend(covariate_names.iter().map(|s| s.as_str()));
436        writer.write_record(&headers)?;
437
438        // Write the data rows
439        for subject in self.data.subjects() {
440            for occasion in subject.occasions() {
441                let cov = occasion.covariates();
442                let covmap = cov.covariates();
443
444                for event in occasion.iter() {
445                    let time = match event {
446                        Event::Bolus(bolus) => bolus.time(),
447                        Event::Infusion(infusion) => infusion.time(),
448                        Event::Observation(observation) => observation.time(),
449                    };
450
451                    let mut row: Vec<String> = Vec::new();
452                    row.push(subject.id().clone());
453                    row.push(time.to_string());
454                    row.push(occasion.index().to_string());
455
456                    // Add covariate values to the row
457                    for cov_name in &covariate_names {
458                        if let Some(cov) = covmap.get(cov_name) {
459                            if let Ok(value) = cov.interpolate(time) {
460                                row.push(value.to_string());
461                            } else {
462                                row.push(String::new());
463                            }
464                        } else {
465                            row.push(String::new());
466                        }
467                    }
468
469                    writer.write_record(&row)?;
470                }
471            }
472        }
473
474        writer.flush()?;
475        tracing::debug!("Covariates written to {:?}", &outputfile.relative_path());
476        Ok(())
477    }
478}
479
/// Returns the (unweighted) median of `data`.
///
/// With an even number of values, the two middle values are averaged.
/// Values are ordered with [`f64::total_cmp`], so NaN no longer causes a panic:
/// positive NaNs sort after every number and negative NaNs before.
/// An empty slice yields `f64::NAN` instead of panicking.
pub(crate) fn median(data: &[f64]) -> f64 {
    if data.is_empty() {
        return f64::NAN;
    }

    let mut sorted: Vec<f64> = data.to_vec();
    sorted.sort_unstable_by(f64::total_cmp);

    let mid = sorted.len() / 2;
    if sorted.len() % 2 == 0 {
        (sorted[mid - 1] + sorted[mid]) / 2.0
    } else {
        sorted[mid]
    }
}
494
/// Returns the weighted median of `data` under `weights`.
///
/// Values are sorted (NaN-safe, via [`f64::total_cmp`]) and weights accumulated in that
/// order. The first value whose cumulative weight exceeds half the total weight is the
/// median. If the cumulative weight lands on half the total, the result is the average
/// of that value and the next value that carries positive weight.
///
/// "Lands on half" is tested with a small tolerance (`EPSILON * total * n`), because the
/// accumulated sum is subject to rounding and an exact `==` would randomly miss the tie.
///
/// Degenerate inputs:
/// - if every weight is zero, all values are weighted equally (the plain median);
/// - an empty input returns `f64::NAN`.
///
/// # Panics
/// Panics if `data` and `weights` have different lengths or any weight is negative (or NaN).
fn weighted_median(data: &[f64], weights: &[f64]) -> f64 {
    // Ensure the data and weights arrays have the same length
    assert_eq!(
        data.len(),
        weights.len(),
        "The length of data and weights must be the same"
    );
    assert!(
        weights.iter().all(|&x| x >= 0.0),
        "Weights must be non-negative, weights: {:?}",
        weights
    );

    if data.is_empty() {
        return f64::NAN;
    }

    // Pair each value with its weight and sort by value
    let mut weighted_data: Vec<(f64, f64)> = data
        .iter()
        .copied()
        .zip(weights.iter().copied())
        .collect();
    weighted_data.sort_by(|a, b| a.0.total_cmp(&b.0));

    // With no weight information at all, fall back to equal weights.
    if weighted_data.iter().all(|&(_, w)| w == 0.0) {
        weighted_data.iter_mut().for_each(|(_, w)| *w = 1.0);
    }

    // Sum in sorted order so the final cumulative sum equals the total exactly.
    let total_weight: f64 = weighted_data.iter().map(|&(_, w)| w).sum();
    let half = total_weight / 2.0;
    let tolerance = f64::EPSILON * total_weight * weighted_data.len() as f64;

    let mut cumulative_sum = 0.0;
    for (i, &(value, weight)) in weighted_data.iter().enumerate() {
        cumulative_sum += weight;

        if (cumulative_sum - half).abs() <= tolerance {
            // Exactly half the mass lies at or below `value`: average with the next
            // value that actually carries weight (zero-weight points are irrelevant).
            let upper = weighted_data[i + 1..]
                .iter()
                .find(|&&(_, w)| w > 0.0)
                .map(|&(v, _)| v);
            return match upper {
                Some(next) => (value + next) / 2.0,
                None => value,
            };
        }
        if cumulative_sum > half {
            return value;
        }
    }

    // Only reachable with non-finite weights; return the largest value rather than panic.
    weighted_data.last().map_or(f64::NAN, |&(v, _)| v)
}
539
540pub fn population_mean_median(
541    theta: &Array2<f64>,
542    w: &Array1<f64>,
543) -> Result<(Array1<f64>, Array1<f64>)> {
544    let w = if w.is_empty() {
545        tracing::warn!("w.len() == 0, setting all weights to 1/n");
546        Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
547    } else {
548        w.clone()
549    };
550    // Check for compatible sizes
551    if theta.nrows() != w.len() {
552        bail!(
553            "Number of parameters and number of weights do not match. Theta: {}, w: {}",
554            theta.nrows(),
555            w.len()
556        );
557    }
558
559    let mut mean = Array1::zeros(theta.ncols());
560    let mut median = Array1::zeros(theta.ncols());
561
562    for (i, (mn, mdn)) in mean.iter_mut().zip(&mut median).enumerate() {
563        // Calculate the weighted mean
564        let col = theta.column(i).to_owned() * w.to_owned();
565        *mn = col.sum();
566
567        // Calculate the median
568        let ct = theta.column(i);
569        let mut params = vec![];
570        let mut weights = vec![];
571        for (ti, wi) in ct.iter().zip(w.clone()) {
572            params.push(*ti);
573            weights.push(wi);
574        }
575
576        *mdn = weighted_median(&params, &weights);
577    }
578
579    Ok((mean, median))
580}
581
582pub fn posterior_mean_median(
583    theta: &Array2<f64>,
584    psi: &Array2<f64>,
585    w: &Array1<f64>,
586) -> Result<(Array2<f64>, Array2<f64>)> {
587    let mut mean = Array2::zeros((0, theta.ncols()));
588    let mut median = Array2::zeros((0, theta.ncols()));
589
590    let w = if w.is_empty() {
591        tracing::warn!("w is empty, setting all weights to 1/n");
592        Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64)
593    } else {
594        w.clone()
595    };
596
597    // Check for compatible sizes
598    if theta.nrows() != w.len() || theta.nrows() != psi.ncols() || psi.ncols() != w.len() {
599        bail!("Number of parameters and number of weights do not match, theta.nrows(): {}, w.len(): {}, psi.ncols(): {}", theta.nrows(), w.len(), psi.ncols());
600    }
601
602    // Normalize psi to get probabilities of each spp for each id
603    let mut psi_norm: Array2<f64> = Array2::zeros((0, psi.ncols()));
604    for (i, row) in psi.axis_iter(Axis(0)).enumerate() {
605        let row_w = row.to_owned() * w.to_owned();
606        let row_sum = row_w.sum();
607        let row_norm = if row_sum == 0.0 {
608            tracing::warn!("Sum of row {} of psi is 0.0, setting that row to 1/n", i);
609            Array1::from_elem(psi.ncols(), 1.0 / psi.ncols() as f64)
610        } else {
611            &row_w / row_sum
612        };
613        psi_norm.push_row(row_norm.view())?;
614    }
615    if psi_norm.iter().any(|&x| x.is_nan()) {
616        dbg!(&psi);
617        bail!("NaN values found in psi_norm");
618    };
619
620    // Transpose normalized psi to get ID (col) by prob (row)
621    // let psi_norm_transposed = psi_norm.t();
622
623    // For each subject..
624    for probs in psi_norm.axis_iter(Axis(0)) {
625        let mut post_mean: Vec<f64> = Vec::new();
626        let mut post_median: Vec<f64> = Vec::new();
627
628        // For each parameter
629        for pars in theta.axis_iter(Axis(1)) {
630            // Calculate the mean
631            let weighted_par = &probs * &pars;
632            let the_mean = weighted_par.sum();
633            post_mean.push(the_mean);
634
635            // Calculate the median
636            let median = weighted_median(&pars.to_vec(), &probs.to_vec());
637            post_median.push(median);
638        }
639
640        mean.push_row(Array::from(post_mean.clone()).view())?;
641        median.push_row(Array::from(post_median.clone()).view())?;
642    }
643
644    Ok((mean, median))
645}
646
647/// Contains all the necessary information of an output file
648#[derive(Debug)]
649pub struct OutputFile {
650    file: File,
651    relative_path: PathBuf,
652}
653
654impl OutputFile {
655    pub fn new(folder: &str, file_name: &str) -> Result<Self> {
656        let relative_path = Path::new(&folder).join(file_name);
657
658        if let Some(parent) = relative_path.parent() {
659            create_dir_all(parent)
660                .with_context(|| format!("Failed to create directories for {:?}", parent))?;
661        }
662
663        let file = OpenOptions::new()
664            .write(true)
665            .create(true)
666            .truncate(true)
667            .open(&relative_path)
668            .with_context(|| format!("Failed to open file: {:?}", relative_path))?;
669
670        Ok(OutputFile {
671            file,
672            relative_path,
673        })
674    }
675
676    pub fn file(&self) -> &File {
677        &self.file
678    }
679
680    pub fn file_owned(self) -> File {
681        self.file
682    }
683
684    pub fn relative_path(&self) -> &Path {
685        &self.relative_path
686    }
687}
688
#[cfg(test)]
mod tests {
    use super::{median, weighted_median};

    // ---- median -------------------------------------------------------------

    #[test]
    fn test_median_odd() {
        assert_eq!(median(&[1.0, 3.0, 2.0]), 2.0);
    }

    #[test]
    fn test_median_even() {
        assert_eq!(median(&[1.0, 2.0, 3.0, 4.0]), 2.5);
    }

    #[test]
    fn test_median_single() {
        assert_eq!(median(&[42.0]), 42.0);
    }

    #[test]
    fn test_median_sorted() {
        assert_eq!(median(&[5.0, 10.0, 15.0, 20.0, 25.0]), 15.0);
    }

    #[test]
    fn test_median_unsorted() {
        assert_eq!(median(&[10.0, 30.0, 20.0, 50.0, 40.0]), 30.0);
    }

    #[test]
    fn test_median_with_duplicates() {
        assert_eq!(median(&[1.0, 2.0, 2.0, 3.0, 4.0]), 2.0);
    }

    // ---- weighted_median ----------------------------------------------------

    #[test]
    fn test_weighted_median_simple() {
        let (values, probs) = ([1.0, 2.0, 3.0], [0.2, 0.5, 0.3]);
        assert_eq!(weighted_median(&values, &probs), 2.0);
    }

    #[test]
    fn test_weighted_median_even_weights() {
        let (values, probs) = ([1.0, 2.0, 3.0, 4.0], [0.25; 4]);
        assert_eq!(weighted_median(&values, &probs), 2.5);
    }

    #[test]
    fn test_weighted_median_single_element() {
        assert_eq!(weighted_median(&[42.0], &[1.0]), 42.0);
    }

    #[test]
    #[should_panic(expected = "The length of data and weights must be the same")]
    fn test_weighted_median_mismatched_lengths() {
        let (values, probs) = ([1.0, 2.0, 3.0], [0.1, 0.2]);
        weighted_median(&values, &probs);
    }

    #[test]
    fn test_weighted_median_all_same_elements() {
        let (values, probs) = ([5.0; 4], [0.1, 0.2, 0.3, 0.4]);
        assert_eq!(weighted_median(&values, &probs), 5.0);
    }

    #[test]
    #[should_panic(expected = "Weights must be non-negative")]
    fn test_weighted_median_negative_weights() {
        let (values, probs) = ([1.0, 2.0, 3.0, 4.0], [0.2, -0.5, 0.5, 0.8]);
        assert_eq!(weighted_median(&values, &probs), 4.0);
    }

    #[test]
    fn test_weighted_median_unsorted_data() {
        let (values, probs) = ([3.0, 1.0, 4.0, 2.0], [0.1, 0.3, 0.4, 0.2]);
        assert_eq!(weighted_median(&values, &probs), 2.5);
    }

    #[test]
    fn test_weighted_median_with_zero_weights() {
        let (values, probs) = ([1.0, 2.0, 3.0, 4.0], [0.0, 0.0, 1.0, 0.0]);
        assert_eq!(weighted_median(&values, &probs), 3.0);
    }
}