pmcore/routines/initialization/
mod.rs1use std::fs::File;
2
3use crate::structs::theta::Theta;
4use anyhow::{bail, Context, Result};
5use faer::Mat;
6use serde::{Deserialize, Serialize};
7
8use crate::routines::settings::Settings;
9
10pub mod latin;
11pub mod sobol;
12
13#[derive(Debug, Deserialize, Clone, Serialize)]
21pub enum Prior {
22 Sobol(usize, usize),
23 Latin(usize, usize),
24 File(String),
25 #[serde(skip)]
26 Theta(Theta),
27}
28
29impl Prior {
30 pub fn sobol(points: usize, seed: usize) -> Prior {
31 Prior::Sobol(points, seed)
32 }
33
34 pub fn points(&self) -> usize {
36 match self {
37 Prior::Sobol(points, _) => *points,
38 Prior::Latin(points, _) => *points,
39 Prior::File(_) => {
40 unimplemented!("File-based prior does not have a fixed number of points")
41 }
42 Prior::Theta(theta) => theta.nspp(),
43 }
44 }
45
46 pub fn seed(&self) -> usize {
48 match self {
49 Prior::Sobol(_, seed) => *seed,
50 Prior::Latin(_, seed) => *seed,
51 Prior::File(_) => unimplemented!("File-based prior does not have a fixed seed"),
52 Prior::Theta(_) => {
53 unimplemented!("Custom prior does not have a fixed seed")
54 }
55 }
56 }
57}
58
59impl Default for Prior {
60 fn default() -> Self {
61 Prior::Sobol(2028, 22)
62 }
63}
64
65pub fn sample_space(settings: &Settings) -> Result<Theta> {
67 for param in settings.parameters().iter() {
69 if param.lower.is_infinite() || param.upper.is_infinite() {
70 bail!(
71 "Parameter '{}' has infinite bounds: [{}, {}]",
72 param.name,
73 param.lower,
74 param.upper
75 );
76 }
77
78 if param.lower >= param.upper {
80 bail!(
81 "Parameter '{}' has invalid bounds: [{}, {}]. Lower bound must be less than upper bound.",
82 param.name,
83 param.lower,
84 param.upper
85 );
86 }
87 }
88
89 let prior = match settings.prior() {
91 Prior::Sobol(points, seed) => sobol::generate(settings.parameters(), *points, *seed)?,
92 Prior::Latin(points, seed) => latin::generate(settings.parameters(), *points, *seed)?,
93 Prior::File(ref path) => parse_prior(path, settings)?,
94 Prior::Theta(ref theta) => {
95 return Ok(theta.clone());
97 }
98 };
99 Ok(prior)
100}
101
102pub fn parse_prior(path: &String, settings: &Settings) -> Result<Theta> {
104 tracing::info!("Reading prior from {}", path);
105 let file = File::open(path).context(format!("Unable to open the prior file '{}'", path))?;
106 let mut reader = csv::ReaderBuilder::new()
107 .has_headers(true)
108 .from_reader(file);
109
110 let mut parameter_names: Vec<String> = reader
111 .headers()?
112 .clone()
113 .into_iter()
114 .map(|s| s.trim().to_owned())
115 .collect();
116
117 if let Some(index) = parameter_names.iter().position(|name| name == "prob") {
119 parameter_names.remove(index);
120 }
121
122 let random_names: Vec<String> = settings.parameters().names();
124
125 let mut reordered_indices: Vec<usize> = Vec::new();
126 for random_name in &random_names {
127 match parameter_names.iter().position(|name| name == random_name) {
128 Some(index) => {
129 reordered_indices.push(index);
130 }
131 None => {
132 bail!("Parameter {} is not present in the CSV file.", random_name);
133 }
134 }
135 }
136
137 if parameter_names.len() > random_names.len() {
139 let extra_parameters: Vec<&String> = parameter_names.iter().collect();
140 bail!(
141 "Found parameters in the prior not present in configuration: {:?}",
142 extra_parameters
143 );
144 }
145
146 let mut theta_values = Vec::new();
148 for result in reader.records() {
149 let record = result.unwrap();
150 let values: Vec<f64> = reordered_indices
151 .iter()
152 .map(|&i| record[i].parse::<f64>().unwrap())
153 .collect();
154 theta_values.push(values);
155 }
156
157 let n_points = theta_values.len();
158 let n_params = random_names.len();
159
160 let theta_values: Vec<f64> = theta_values.into_iter().flatten().collect();
162
163 let theta_matrix: Mat<f64> =
164 Mat::from_fn(n_points, n_params, |i, j| theta_values[i * n_params + j]);
165
166 let theta = Theta::from_parts(theta_matrix, settings.parameters().clone());
167
168 Ok(theta)
169}