pmcore/routines/initialization/
mod.rs1use std::fs::File;
2
3use crate::structs::theta::Theta;
4use anyhow::{bail, Context, Result};
5use faer::Mat;
6use serde::{Deserialize, Serialize};
7
8use crate::routines::settings::Settings;
9
10pub mod latin;
11pub mod sobol;
12
13#[derive(Debug, Deserialize, Clone, Serialize)]
21pub enum Prior {
22 Sobol(usize, usize),
23 Latin(usize, usize),
24 File(String),
25 #[serde(skip)]
26 Theta(Theta),
27}
28
29impl Prior {
30 pub fn sobol(points: usize, seed: usize) -> Prior {
31 Prior::Sobol(points, seed)
32 }
33
34 pub fn get_points(&self) -> usize {
35 match self {
36 Prior::Sobol(points, _) => *points,
37 Prior::Latin(points, _) => *points,
38 Prior::File(_) => {
39 unimplemented!("File-based prior does not have a fixed number of points")
40 }
41 Prior::Theta(theta) => theta.nspp(),
42 }
43 }
44
45 pub fn get_seed(&self) -> usize {
46 match self {
47 Prior::Sobol(_, seed) => *seed,
48 Prior::Latin(_, seed) => *seed,
49 Prior::File(_) => unimplemented!("File-based prior does not have a fixed seed"),
50 Prior::Theta(_) => {
51 unimplemented!("Custom prior does not have a fixed seed")
52 }
53 }
54 }
55}
56
57impl Default for Prior {
58 fn default() -> Self {
59 Prior::Sobol(2028, 22)
60 }
61}
62
63pub fn sample_space(settings: &Settings) -> Result<Theta> {
65 let prior = match settings.prior() {
67 Prior::Sobol(points, seed) => sobol::generate(settings.parameters(), *points, *seed)?,
68 Prior::Latin(points, seed) => latin::generate(settings.parameters(), *points, *seed)?,
69 Prior::File(ref path) => parse_prior(path, settings)?,
70 Prior::Theta(ref theta) => {
71 return Ok(theta.clone());
73 }
74 };
75 Ok(prior)
76}
77
78pub fn parse_prior(path: &String, settings: &Settings) -> Result<Theta> {
80 tracing::info!("Reading prior from {}", path);
81 let file = File::open(path).context(format!("Unable to open the prior file '{}'", path))?;
82 let mut reader = csv::ReaderBuilder::new()
83 .has_headers(true)
84 .from_reader(file);
85
86 let mut parameter_names: Vec<String> = reader
87 .headers()?
88 .clone()
89 .into_iter()
90 .map(|s| s.trim().to_owned())
91 .collect();
92
93 if let Some(index) = parameter_names.iter().position(|name| name == "prob") {
95 parameter_names.remove(index);
96 }
97
98 let random_names: Vec<String> = settings.parameters().names();
100
101 let mut reordered_indices: Vec<usize> = Vec::new();
102 for random_name in &random_names {
103 match parameter_names.iter().position(|name| name == random_name) {
104 Some(index) => {
105 reordered_indices.push(index);
106 }
107 None => {
108 bail!("Parameter {} is not present in the CSV file.", random_name);
109 }
110 }
111 }
112
113 if parameter_names.len() > random_names.len() {
115 let extra_parameters: Vec<&String> = parameter_names.iter().collect();
116 bail!(
117 "Found parameters in the prior not present in configuration: {:?}",
118 extra_parameters
119 );
120 }
121
122 let mut theta_values = Vec::new();
124 for result in reader.records() {
125 let record = result.unwrap();
126 let values: Vec<f64> = reordered_indices
127 .iter()
128 .map(|&i| record[i].parse::<f64>().unwrap())
129 .collect();
130 theta_values.push(values);
131 }
132
133 let n_points = theta_values.len();
134 let n_params = random_names.len();
135
136 let theta_values: Vec<f64> = theta_values.into_iter().flatten().collect();
138
139 let theta_matrix: Mat<f64> =
140 Mat::from_fn(n_points, n_params, |i, j| theta_values[i * n_params + j]);
141
142 let random = settings
143 .parameters()
144 .iter()
145 .filter(|p| !p.fixed)
146 .collect::<Vec<_>>()
147 .iter()
148 .map(|p| (p.name.clone(), p.lower, p.upper))
149 .collect();
150
151 let theta = Theta::from_parts(theta_matrix, random, Vec::new());
152
153 Ok(theta)
154}