pmcore/routines/initialization/
latin.rs1use anyhow::Result;
2use faer::Mat;
3use rand::prelude::*;
4use rand::rngs::StdRng;
5
6use crate::prelude::Parameters;
7use crate::structs::theta::Theta;
8
9pub fn generate(parameters: &Parameters, points: usize, seed: usize) -> Result<Theta> {
22 let params: Vec<(String, f64, f64)> = parameters
23 .iter()
24 .map(|p| (p.name.clone(), p.lower, p.upper))
25 .collect();
26
27 let mut rng = StdRng::seed_from_u64(seed as u64);
29
30 let mut intervals = Vec::new();
32 for _ in 0..params.len() {
33 let mut param_intervals: Vec<f64> = (0..points).map(|i| i as f64).collect();
34 param_intervals.shuffle(&mut rng);
35 intervals.push(param_intervals);
36 }
37
38 let rand_matrix = Mat::from_fn(points, params.len(), |i, j| {
39 let interval = intervals[j][i];
41 let random_offset = rng.random::<f64>();
42 let unscaled = (interval + random_offset) / points as f64;
44 let (_name, lower, upper) = params.get(j).unwrap(); lower + unscaled * (upper - lower)
47 });
48
49 let theta = Theta::from_parts(rand_matrix, parameters.clone())?;
50
51 Ok(theta)
52}
53
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::Parameters;
    use faer::mat;

    /// Three parameters (`a`, `b`, `c`), each bounded on the unit interval.
    fn unit_params() -> Parameters {
        Parameters::new()
            .add("a", 0.0, 1.0)
            .add("b", 0.0, 1.0)
            .add("c", 0.0, 1.0)
    }

    #[test]
    fn test_latin_hypercube() {
        let params = unit_params();

        let theta = generate(&params, 10, 22).unwrap();

        assert_eq!(theta.nspp(), 10);
        assert_eq!(theta.matrix().ncols(), 3);
    }

    /// Regression test: pins the exact samples for seed 22, guarding against accidental
    /// changes to the order in which `generate` consumes its RNG stream.
    #[test]
    fn test_latin_hypercube_values() {
        let params = unit_params();

        let theta = generate(&params, 10, 22).unwrap();

        let expected = mat![
            [0.9318592685623417, 0.5609665425179973, 0.3351914901515939],
            [0.5470144220416706, 0.13513808559222779, 0.1067962439473777],
            [0.34525902829190547, 0.4636722699673962, 0.9142146621998218],
            [0.24828355387285125, 0.8638104433695395, 0.41653980640777954],
            [0.7642037770085612, 0.6806932027789437, 0.5608053599272136],
            [0.19409389824004936, 0.9378790633419902, 0.6039530631991072],
            [0.04886813284275151, 0.7140428162864041, 0.7855069414226704],
            [0.6987026842780971, 0.32378779989236495, 0.8888807957183007],
            [0.4221279608793599, 0.08001464382386277, 0.20689573661666943],
            [0.8310112718320113, 0.29390050406905127, 0.04806137233953963],
        ];

        assert_eq!(theta.matrix().to_owned(), expected);
    }

    /// Latin hypercube property: within every column, each of the `points` equal-width
    /// strata of the unit interval contains exactly one sample.
    #[test]
    fn test_latin_hypercube_stratification() {
        let points = 10;
        let theta = generate(&unit_params(), points, 22).unwrap();
        let matrix = theta.matrix().to_owned();

        for col in 0..matrix.ncols() {
            let mut strata: Vec<usize> = (0..points)
                .map(|row| (matrix[(row, col)] * points as f64).floor() as usize)
                .collect();
            strata.sort_unstable();
            assert_eq!(strata, (0..points).collect::<Vec<_>>());
        }
    }
}