pmcore/routines/initialization/
mod.rs

1use std::fs::File;
2
3use crate::structs::theta::Theta;
4use anyhow::{bail, Context, Result};
5use faer::Mat;
6use serde::{Deserialize, Serialize};
7
8use crate::routines::settings::Settings;
9
10pub mod latin;
11pub mod sobol;
12
13/// The sampler used to generate the grid of support points
14///
15/// The sampler can be one of the following:
16///
17/// - `Sobol`: Generates a Sobol sequence
18/// - `Latin`: Generates a Latin hypercube
19/// - `File`: Reads the prior distribution from a CSV file
20#[derive(Debug, Deserialize, Clone, Serialize)]
21pub enum Prior {
22    Sobol(usize, usize),
23    Latin(usize, usize),
24    File(String),
25    #[serde(skip)]
26    Theta(Theta),
27}
28
29impl Prior {
30    pub fn sobol(points: usize, seed: usize) -> Prior {
31        Prior::Sobol(points, seed)
32    }
33
34    /// Get the number of initial support points
35    pub fn points(&self) -> usize {
36        match self {
37            Prior::Sobol(points, _) => *points,
38            Prior::Latin(points, _) => *points,
39            Prior::File(_) => {
40                unimplemented!("File-based prior does not have a fixed number of points")
41            }
42            Prior::Theta(theta) => theta.nspp(),
43        }
44    }
45
46    /// Get the seed used for the random number generator
47    pub fn seed(&self) -> usize {
48        match self {
49            Prior::Sobol(_, seed) => *seed,
50            Prior::Latin(_, seed) => *seed,
51            Prior::File(_) => unimplemented!("File-based prior does not have a fixed seed"),
52            Prior::Theta(_) => {
53                unimplemented!("Custom prior does not have a fixed seed")
54            }
55        }
56    }
57}
58
59impl Default for Prior {
60    fn default() -> Self {
61        Prior::Sobol(2028, 22)
62    }
63}
64
65/// This function generates the grid of support points according to the sampler specified in the [Settings]
66pub fn sample_space(settings: &Settings) -> Result<Theta> {
67    // Ensure that the parameter ranges are not infinite
68    for param in settings.parameters().iter() {
69        if param.lower.is_infinite() || param.upper.is_infinite() {
70            bail!(
71                "Parameter '{}' has infinite bounds: [{}, {}]",
72                param.name,
73                param.lower,
74                param.upper
75            );
76        }
77
78        // Ensure that the lower bound is less than the upper bound
79        if param.lower >= param.upper {
80            bail!(
81                "Parameter '{}' has invalid bounds: [{}, {}]. Lower bound must be less than upper bound.",
82                param.name,
83                param.lower,
84                param.upper
85            );
86        }
87    }
88
89    // Otherwise, parse the sampler type and generate the grid
90    let prior = match settings.prior() {
91        Prior::Sobol(points, seed) => sobol::generate(settings.parameters(), *points, *seed)?,
92        Prior::Latin(points, seed) => latin::generate(settings.parameters(), *points, *seed)?,
93        Prior::File(ref path) => parse_prior(path, settings)?,
94        Prior::Theta(ref theta) => {
95            // If a custom prior is provided, return it directly
96            return Ok(theta.clone());
97        }
98    };
99    Ok(prior)
100}
101
102/// This function reads the prior distribution from a file
103pub fn parse_prior(path: &String, settings: &Settings) -> Result<Theta> {
104    tracing::info!("Reading prior from {}", path);
105    let file = File::open(path).context(format!("Unable to open the prior file '{}'", path))?;
106    let mut reader = csv::ReaderBuilder::new()
107        .has_headers(true)
108        .from_reader(file);
109
110    let mut parameter_names: Vec<String> = reader
111        .headers()?
112        .clone()
113        .into_iter()
114        .map(|s| s.trim().to_owned())
115        .collect();
116
117    // Remove "prob" column if present
118    if let Some(index) = parameter_names.iter().position(|name| name == "prob") {
119        parameter_names.remove(index);
120    }
121
122    // Check and reorder parameters to match names in settings.parsed.random
123    let random_names: Vec<String> = settings.parameters().names();
124
125    let mut reordered_indices: Vec<usize> = Vec::new();
126    for random_name in &random_names {
127        match parameter_names.iter().position(|name| name == random_name) {
128            Some(index) => {
129                reordered_indices.push(index);
130            }
131            None => {
132                bail!("Parameter {} is not present in the CSV file.", random_name);
133            }
134        }
135    }
136
137    // Check if there are remaining parameters not present in settings.parsed.random
138    if parameter_names.len() > random_names.len() {
139        let extra_parameters: Vec<&String> = parameter_names.iter().collect();
140        bail!(
141            "Found parameters in the prior not present in configuration: {:?}",
142            extra_parameters
143        );
144    }
145
146    // Read parameter values row by row, keeping only those associated with the reordered parameters
147    let mut theta_values = Vec::new();
148    for result in reader.records() {
149        let record = result.unwrap();
150        let values: Vec<f64> = reordered_indices
151            .iter()
152            .map(|&i| record[i].parse::<f64>().unwrap())
153            .collect();
154        theta_values.push(values);
155    }
156
157    let n_points = theta_values.len();
158    let n_params = random_names.len();
159
160    // Convert nested Vec into a single Vec
161    let theta_values: Vec<f64> = theta_values.into_iter().flatten().collect();
162
163    let theta_matrix: Mat<f64> =
164        Mat::from_fn(n_points, n_params, |i, j| theta_values[i * n_params + j]);
165
166    let theta = Theta::from_parts(theta_matrix, settings.parameters().clone());
167
168    Ok(theta)
169}