// pmcore/routines/initialization/sobol.rs

1use crate::structs::theta::Theta;
2use anyhow::Result;
3use faer::Mat;
4
5use sobol_burley::sample;
6
7use crate::prelude::Parameters;
8
9/// Generates an instance of [Theta] from a Sobol sequence.
10///
11/// The sequence samples [0, 1), and the values are scaled to the parameter ranges.
12///
13/// # Arguments
14///
15/// * `parameters` - The [Parameters] struct, which contains the parameters to be sampled.
16/// * `points` - The number of points to generate, i.e. the number of rows in the matrix.
17/// * `seed` - The seed for the Sobol sequence generator.
18///
19/// # Returns
20///
21/// [Theta], a structure that holds the support point matrix
22///
23pub fn generate(parameters: &Parameters, points: usize, seed: usize) -> Result<Theta> {
24    let seed = seed as u32;
25    let params: Vec<(String, f64, f64, bool)> = parameters
26        .iter()
27        .map(|p| (p.name.clone(), p.lower, p.upper, p.fixed))
28        .collect();
29
30    // Random parameters are sampled from the Sobol sequence
31    let random_params: Vec<(String, f64, f64)> = params
32        .iter()
33        .filter(|(_, _, _, fixed)| !fixed)
34        .map(|(name, lower, upper, _)| (name.clone(), *lower, *upper))
35        .collect();
36
37    let rand_matrix = Mat::from_fn(points, random_params.len(), |i, j| {
38        let unscaled = sample((i).try_into().unwrap(), j.try_into().unwrap(), seed) as f64;
39        let (_name, lower, upper) = random_params.get(j).unwrap();
40        lower + unscaled * (upper - lower)
41    });
42
43    // Fixed parameters are initialized to the middle of their range
44    let fixed_params: Vec<(String, f64)> = params
45        .iter()
46        .filter(|(_, _, _, fixed)| *fixed)
47        .map(|(name, lower, upper, _)| (name.clone(), (upper - lower) / 2.0))
48        .collect();
49
50    let theta = Theta::from_parts(rand_matrix, random_params, fixed_params);
51    Ok(theta)
52}
53
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::Parameters;

    /// Three free parameters `a`, `b`, `c`, each bounded to `[0, 1]`.
    fn three_free_unit_parameters() -> Parameters {
        Parameters::new()
            .add("a", 0.0, 1.0, false)
            .add("b", 0.0, 1.0, false)
            .add("c", 0.0, 1.0, false)
    }

    /// One row per requested point, one column per free parameter.
    #[test]
    fn test_sobol() {
        let theta = generate(&three_free_unit_parameters(), 10, 22).unwrap();

        assert_eq!(theta.nspp(), 10);
        assert_eq!(theta.matrix().ncols(), 3);
    }

    /// Fixed parameters do not get a column in the sampled matrix.
    #[test]
    fn test_sobol_with_fixed() {
        let parameters = Parameters::new()
            .add("a", 0.0, 1.0, false)
            .add("b", 0.0, 1.0, false)
            .add("c", 0.0, 1.0, true);

        let theta = generate(&parameters, 10, 22).unwrap();

        assert_eq!(theta.nspp(), 10);
        assert_eq!(theta.matrix().ncols(), 2);
    }

    /// Every sampled value lies inside its parameter's `[0, 1]` bounds.
    #[test]
    fn test_sobol_ranges() {
        let theta = generate(&three_free_unit_parameters(), 10, 22).unwrap();

        for row in theta.matrix().row_iter() {
            for &value in row.iter() {
                assert!((0.0..=1.0).contains(&value));
            }
        }
    }

    /// The sequence is deterministic for a given seed; pin the exact values.
    #[test]
    fn test_sobol_values() {
        use faer::mat;

        let theta = generate(&three_free_unit_parameters(), 10, 22).unwrap();
        let actual = theta.matrix().to_owned();

        let expected = mat![
            [0.05276215076446533, 0.609707236289978, 0.29471302032470703], //
            [0.6993427276611328, 0.4142681360244751, 0.6447571516036987],  //
            [0.860404372215271, 0.769607663154602, 0.1742185354232788],    //
            [0.3863574266433716, 0.07018685340881348, 0.9825305938720703], //
            [0.989533543586731, 0.19934570789337158, 0.4716176986694336],  //
            [0.29962968826293945, 0.899970293045044, 0.5400241613388062],  //
            [0.5577576160430908, 0.6990838050842285, 0.859503984451294],   //
            [0.19194257259368896, 0.31645333766937256, 0.042426824569702150], //
            [0.8874167203903198, 0.5214653015136719, 0.5899909734725952],  //
            [0.35627472400665283, 0.4780532121658325, 0.42954015731811523]  //
        ];

        assert_eq!(actual, expected);
    }
}
128        assert_eq!(theta.matrix().to_owned(), expected);
129    }
130}