// pmcore/routines/initialization/sobol.rs
use anyhow::{Context, Result};
use faer::Mat;
use sobol_burley::sample;

use crate::prelude::Parameters;
use crate::structs::theta::Theta;
9pub fn generate(parameters: &Parameters, points: usize, seed: usize) -> Result<Theta> {
24 let seed = seed as u32;
25 let params: Vec<(String, f64, f64)> = parameters
26 .iter()
27 .map(|p| (p.name.clone(), p.lower, p.upper))
28 .collect();
29
30 let rand_matrix = Mat::from_fn(points, params.len(), |i, j| {
31 let unscaled = sample((i).try_into().unwrap(), j.try_into().unwrap(), seed) as f64;
32 let (_name, lower, upper) = params.get(j).unwrap();
33 lower + unscaled * (upper - lower)
34 });
35
36 let theta = Theta::from_parts(rand_matrix, parameters.clone());
37 Ok(theta)
38}
39
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::Parameters;

    /// Three parameters bounded to `[0, 1]`, so generated values are the raw Sobol samples.
    fn unit_cube() -> Parameters {
        Parameters::new()
            .add("a", 0.0, 1.0)
            .add("b", 0.0, 1.0)
            .add("c", 0.0, 1.0)
    }

    #[test]
    fn test_sobol() {
        let theta = generate(&unit_cube(), 10, 22).unwrap();

        assert_eq!(theta.nspp(), 10);
        assert_eq!(theta.matrix().ncols(), 3);
    }

    #[test]
    fn test_sobol_ranges() {
        let theta = generate(&unit_cube(), 10, 22).unwrap();

        for row in theta.matrix().row_iter() {
            for &value in row.iter() {
                assert!(
                    (0.0..=1.0).contains(&value),
                    "value {value} lies outside [0, 1]"
                );
            }
        }
    }

    #[test]
    fn test_sobol_scales_to_bounds() {
        // With non-unit bounds, each value must be lower + u * (upper - lower). Here u is
        // the sample the unit-bounded generator produces for the same seed, row and column.
        let bounds = [(0.1, 0.5), (-3.0, 7.0)];
        let scaled_params = Parameters::new().add("ke", 0.1, 0.5).add("v", -3.0, 7.0);
        let unit_params = Parameters::new().add("ke", 0.0, 1.0).add("v", 0.0, 1.0);

        let scaled = generate(&scaled_params, 10, 22).unwrap();
        let unit = generate(&unit_params, 10, 22).unwrap();

        assert_eq!(scaled.nspp(), 10);
        assert_eq!(scaled.matrix().ncols(), 2);

        for (unit_row, scaled_row) in unit.matrix().row_iter().zip(scaled.matrix().row_iter()) {
            for ((&u, &value), &(lower, upper)) in
                unit_row.iter().zip(scaled_row.iter()).zip(bounds.iter())
            {
                assert_eq!(value, lower + u * (upper - lower));
            }
        }
    }

    #[test]
    fn test_sobol_values() {
        use faer::mat;

        let theta = generate(&unit_cube(), 10, 22).unwrap();

        let expected = mat![
            [0.05276215076446533, 0.609707236289978, 0.29471302032470703],
            [0.6993427276611328, 0.4142681360244751, 0.6447571516036987],
            [0.860404372215271, 0.769607663154602, 0.1742185354232788],
            [0.3863574266433716, 0.07018685340881348, 0.9825305938720703],
            [0.989533543586731, 0.19934570789337158, 0.4716176986694336],
            [0.29962968826293945, 0.899970293045044, 0.5400241613388062],
            [0.5577576160430908, 0.6990838050842285, 0.859503984451294],
            [0.19194257259368896, 0.31645333766937256, 0.04242682456970215],
            [0.8874167203903198, 0.5214653015136719, 0.5899909734725952],
            [0.35627472400665283, 0.4780532121658325, 0.42954015731811523],
        ];

        assert_eq!(theta.matrix().to_owned(), expected);
    }
}