pmcore/routines/initialization/mod.rs

1use std::fs::File;
2
3use crate::structs::theta::Theta;
4use anyhow::{bail, Context, Result};
5use faer::Mat;
6use serde::{Deserialize, Serialize};
7
8use crate::routines::settings::Settings;
9
10pub mod latin;
11pub mod sobol;
12
13/// The sampler used to generate the grid of support points
14///
15/// The sampler can be one of the following:
16///
17/// - `Sobol`: Generates a Sobol sequence
18/// - `Latin`: Generates a Latin hypercube
19/// - `File`: Reads the prior distribution from a CSV file
20#[derive(Debug, Deserialize, Clone, Serialize)]
21pub enum Prior {
22    Sobol(usize, usize),
23    Latin(usize, usize),
24    File(String),
25    #[serde(skip)]
26    Theta(Theta),
27}
28
29impl Prior {
30    pub fn sobol(points: usize, seed: usize) -> Prior {
31        Prior::Sobol(points, seed)
32    }
33
34    pub fn get_points(&self) -> usize {
35        match self {
36            Prior::Sobol(points, _) => *points,
37            Prior::Latin(points, _) => *points,
38            Prior::File(_) => {
39                unimplemented!("File-based prior does not have a fixed number of points")
40            }
41            Prior::Theta(theta) => theta.nspp(),
42        }
43    }
44
45    pub fn get_seed(&self) -> usize {
46        match self {
47            Prior::Sobol(_, seed) => *seed,
48            Prior::Latin(_, seed) => *seed,
49            Prior::File(_) => unimplemented!("File-based prior does not have a fixed seed"),
50            Prior::Theta(_) => {
51                unimplemented!("Custom prior does not have a fixed seed")
52            }
53        }
54    }
55}
56
57impl Default for Prior {
58    fn default() -> Self {
59        Prior::Sobol(2028, 22)
60    }
61}
62
63/// This function generates the grid of support points according to the sampler specified in the [Settings]
64pub fn sample_space(settings: &Settings) -> Result<Theta> {
65    // Otherwise, parse the sampler type and generate the grid
66    let prior = match settings.prior() {
67        Prior::Sobol(points, seed) => sobol::generate(settings.parameters(), *points, *seed)?,
68        Prior::Latin(points, seed) => latin::generate(settings.parameters(), *points, *seed)?,
69        Prior::File(ref path) => parse_prior(path, settings)?,
70        Prior::Theta(ref theta) => {
71            // If a custom prior is provided, return it directly
72            return Ok(theta.clone());
73        }
74    };
75    Ok(prior)
76}
77
78/// This function reads the prior distribution from a file
79pub fn parse_prior(path: &String, settings: &Settings) -> Result<Theta> {
80    tracing::info!("Reading prior from {}", path);
81    let file = File::open(path).context(format!("Unable to open the prior file '{}'", path))?;
82    let mut reader = csv::ReaderBuilder::new()
83        .has_headers(true)
84        .from_reader(file);
85
86    let mut parameter_names: Vec<String> = reader
87        .headers()?
88        .clone()
89        .into_iter()
90        .map(|s| s.trim().to_owned())
91        .collect();
92
93    // Remove "prob" column if present
94    if let Some(index) = parameter_names.iter().position(|name| name == "prob") {
95        parameter_names.remove(index);
96    }
97
98    // Check and reorder parameters to match names in settings.parsed.random
99    let random_names: Vec<String> = settings.parameters().names();
100
101    let mut reordered_indices: Vec<usize> = Vec::new();
102    for random_name in &random_names {
103        match parameter_names.iter().position(|name| name == random_name) {
104            Some(index) => {
105                reordered_indices.push(index);
106            }
107            None => {
108                bail!("Parameter {} is not present in the CSV file.", random_name);
109            }
110        }
111    }
112
113    // Check if there are remaining parameters not present in settings.parsed.random
114    if parameter_names.len() > random_names.len() {
115        let extra_parameters: Vec<&String> = parameter_names.iter().collect();
116        bail!(
117            "Found parameters in the prior not present in configuration: {:?}",
118            extra_parameters
119        );
120    }
121
122    // Read parameter values row by row, keeping only those associated with the reordered parameters
123    let mut theta_values = Vec::new();
124    for result in reader.records() {
125        let record = result.unwrap();
126        let values: Vec<f64> = reordered_indices
127            .iter()
128            .map(|&i| record[i].parse::<f64>().unwrap())
129            .collect();
130        theta_values.push(values);
131    }
132
133    let n_points = theta_values.len();
134    let n_params = random_names.len();
135
136    // Convert nested Vec into a single Vec
137    let theta_values: Vec<f64> = theta_values.into_iter().flatten().collect();
138
139    let theta_matrix: Mat<f64> =
140        Mat::from_fn(n_points, n_params, |i, j| theta_values[i * n_params + j]);
141
142    let random = settings
143        .parameters()
144        .iter()
145        .filter(|p| !p.fixed)
146        .collect::<Vec<_>>()
147        .iter()
148        .map(|p| (p.name.clone(), p.lower, p.upper))
149        .collect();
150
151    let theta = Theta::from_parts(theta_matrix, random, Vec::new());
152
153    Ok(theta)
154}