pmcore/routines/initialization/
latin.rs1use anyhow::Result;
2use faer::Mat;
3use rand::prelude::*;
4use rand::rngs::StdRng;
5use rand::Rng;
6
7use crate::prelude::Parameters;
8use crate::structs::theta::Theta;
9
10pub fn generate(parameters: &Parameters, points: usize, seed: usize) -> Result<Theta> {
23 let params: Vec<(String, f64, f64)> = parameters
24 .iter()
25 .map(|p| (p.name.clone(), p.lower, p.upper))
26 .collect();
27
28 let mut rng = StdRng::seed_from_u64(seed as u64);
30
31 let mut intervals = Vec::new();
33 for _ in 0..params.len() {
34 let mut param_intervals: Vec<f64> = (0..points).map(|i| i as f64).collect();
35 param_intervals.shuffle(&mut rng);
36 intervals.push(param_intervals);
37 }
38
39 let rand_matrix = Mat::from_fn(points, params.len(), |i, j| {
40 let interval = intervals[j][i];
42 let random_offset = rng.random::<f64>();
43 let unscaled = (interval + random_offset) / points as f64;
45 let (_name, lower, upper) = params.get(j).unwrap(); lower + unscaled * (upper - lower)
48 });
49
50 let theta = Theta::from_parts(rand_matrix, parameters.clone());
51
52 Ok(theta)
53}
54
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::Parameters;
    use faer::mat;

    #[test]
    fn test_latin_hypercube() {
        let params = Parameters::new()
            .add("a", 0.0, 1.0)
            .add("b", 0.0, 1.0)
            .add("c", 0.0, 1.0);

        let theta = generate(&params, 10, 22).unwrap();

        assert_eq!(theta.nspp(), 10);
        assert_eq!(theta.matrix().ncols(), 3);
    }

    #[test]
    fn test_latin_hypercube_values() {
        let params = Parameters::new()
            .add("a", 0.0, 1.0)
            .add("b", 0.0, 1.0)
            .add("c", 0.0, 1.0);

        let theta = generate(&params, 10, 22).unwrap();

        let expected = mat![
            [0.9318592685623417, 0.5609665425179973, 0.3351914901515939],
            [0.5470144220416706, 0.13513808559222779, 0.1067962439473777],
            [0.34525902829190547, 0.4636722699673962, 0.9142146621998218],
            [0.24828355387285125, 0.8638104433695395, 0.41653980640777954],
            [0.7642037770085612, 0.6806932027789437, 0.5608053599272136],
            [0.19409389824004936, 0.9378790633419902, 0.6039530631991072],
            [0.04886813284275151, 0.7140428162864041, 0.7855069414226704],
            [0.6987026842780971, 0.32378779989236495, 0.8888807957183007],
            [0.4221279608793599, 0.08001464382386277, 0.20689573661666943],
            [0.8310112718320113, 0.29390050406905127, 0.04806137233953963],
        ];

        assert_eq!(theta.matrix().to_owned(), expected);
    }

    /// Every column must place exactly one point in each of the `points` equal-width
    /// strata of its bounds. This is the defining Latin hypercube property.
    #[test]
    fn test_latin_hypercube_one_point_per_stratum() {
        let bounds = [(0.0, 1.0), (10.0, 20.0)];
        let params = Parameters::new()
            .add("a", bounds[0].0, bounds[0].1)
            .add("b", bounds[1].0, bounds[1].1);
        let points = 8;

        let theta = generate(&params, points, 7).unwrap();
        let matrix = theta.matrix();

        for (j, &(lower, upper)) in bounds.iter().enumerate() {
            let width = (upper - lower) / points as f64;
            let mut strata: Vec<usize> = (0..points)
                .map(|i| {
                    let value = matrix[(i, j)];
                    assert!(value >= lower && value <= upper, "value {value} out of bounds");
                    ((value - lower) / width).floor() as usize
                })
                .collect();
            strata.sort_unstable();
            assert_eq!(strata, (0..points).collect::<Vec<_>>());
        }
    }
}