1use std::fs::File;
2
3use crate::structs::theta::Theta;
4use anyhow::{bail, Context, Result};
5use faer::Mat;
6use serde::{Deserialize, Serialize};
7
8use crate::routines::settings::Settings;
9
10pub mod latin;
11pub mod sobol;
12
/// Source of the initial support points ("prior") used to start a run.
///
/// The active variant is read from the settings (`Settings::prior`) and turned
/// into a [`Theta`] by [`sample_space`].
#[derive(Debug, Deserialize, Clone, Serialize)]
pub enum Prior {
    /// Sobol sequence: `(points, seed)`, generated by `sobol::generate`.
    Sobol(usize, usize),
    /// Latin hypercube sampling: `(points, seed)`, generated by `latin::generate`.
    Latin(usize, usize),
    /// Path to a CSV file with one column per configured parameter, read by [`parse_prior`].
    File(String),
    /// A pre-built set of support points, returned unchanged by [`sample_space`].
    ///
    /// Marked `#[serde(skip)]`: this variant is never serialized or deserialized.
    #[serde(skip)]
    Theta(Theta),
}
28
29impl Prior {
30 pub fn sobol(points: usize, seed: usize) -> Prior {
31 Prior::Sobol(points, seed)
32 }
33
34 pub fn points(&self) -> Option<usize> {
40 match self {
41 Prior::Sobol(points, _) => Some(*points),
42 Prior::Latin(points, _) => Some(*points),
43 Prior::File(_) => None, Prior::Theta(theta) => Some(theta.nspp()),
45 }
46 }
47
48 pub fn seed(&self) -> Option<usize> {
54 match self {
55 Prior::Sobol(_, seed) => Some(*seed),
56 Prior::Latin(_, seed) => Some(*seed),
57 Prior::File(_) => None, Prior::Theta(_) => None, }
60 }
61}
62
63impl Default for Prior {
64 fn default() -> Self {
65 Prior::Sobol(2028, 22)
66 }
67}
68
69pub fn sample_space(settings: &Settings) -> Result<Theta> {
71 for param in settings.parameters().iter() {
73 if param.lower.is_infinite() || param.upper.is_infinite() {
74 bail!(
75 "Parameter '{}' has infinite bounds: [{}, {}]",
76 param.name,
77 param.lower,
78 param.upper
79 );
80 }
81
82 if param.lower >= param.upper {
84 bail!(
85 "Parameter '{}' has invalid bounds: [{}, {}]. Lower bound must be less than upper bound.",
86 param.name,
87 param.lower,
88 param.upper
89 );
90 }
91 }
92
93 let prior = match settings.prior() {
95 Prior::Sobol(points, seed) => sobol::generate(settings.parameters(), *points, *seed)?,
96 Prior::Latin(points, seed) => latin::generate(settings.parameters(), *points, *seed)?,
97 Prior::File(ref path) => parse_prior(path, settings)?,
98 Prior::Theta(ref theta) => {
99 return Ok(theta.clone());
101 }
102 };
103 Ok(prior)
104}
105
106pub fn parse_prior(path: &String, settings: &Settings) -> Result<Theta> {
108 tracing::info!("Reading prior from {}", path);
109 let file = File::open(path).context(format!("Unable to open the prior file '{}'", path))?;
110 let mut reader = csv::ReaderBuilder::new()
111 .has_headers(true)
112 .from_reader(file);
113
114 let mut parameter_names: Vec<String> = reader
115 .headers()?
116 .clone()
117 .into_iter()
118 .map(|s| s.trim().to_owned())
119 .collect();
120
121 if let Some(index) = parameter_names.iter().position(|name| name == "prob") {
123 parameter_names.remove(index);
124 }
125
126 let random_names: Vec<String> = settings.parameters().names();
128
129 let mut reordered_indices: Vec<usize> = Vec::new();
130 for random_name in &random_names {
131 match parameter_names.iter().position(|name| name == random_name) {
132 Some(index) => {
133 reordered_indices.push(index);
134 }
135 None => {
136 bail!("Parameter {} is not present in the CSV file.", random_name);
137 }
138 }
139 }
140
141 if parameter_names.len() > random_names.len() {
143 let extra_parameters: Vec<&String> = parameter_names.iter().collect();
144 bail!(
145 "Found parameters in the prior not present in configuration: {:?}",
146 extra_parameters
147 );
148 }
149
150 let mut theta_values = Vec::new();
152 for result in reader.records() {
153 let record = result.unwrap();
154 let values: Vec<f64> = reordered_indices
155 .iter()
156 .map(|&i| record[i].parse::<f64>().unwrap())
157 .collect();
158 theta_values.push(values);
159 }
160
161 let n_points = theta_values.len();
162 let n_params = random_names.len();
163
164 let theta_values: Vec<f64> = theta_values.into_iter().flatten().collect();
166
167 let theta_matrix: Mat<f64> =
168 Mat::from_fn(n_points, n_params, |i, j| theta_values[i * n_params + j]);
169
170 let theta = Theta::from_parts(theta_matrix, settings.parameters().clone());
171
172 Ok(theta)
173}
174
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::*;
    use pharmsol::{ErrorModel, ErrorModels, ErrorPoly};
    use std::fs;

    /// NPAG settings with a single additive error model and the given parameters.
    fn settings_with_parameters(parameters: Parameters) -> Settings {
        let em = ErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 2.0, None);
        let ems = ErrorModels::new().add(0, em).unwrap();

        Settings::builder()
            .set_algorithm(Algorithm::NPAG)
            .set_parameters(parameters)
            .set_error_models(ems)
            .build()
    }

    /// Settings with `ke` in [0.1, 1.0] and `v` in [5.0, 50.0].
    fn create_test_settings() -> Settings {
        settings_with_parameters(Parameters::new().add("ke", 0.1, 1.0).add("v", 5.0, 50.0))
    }

    /// A CSV file in the system temp directory that is removed when dropped.
    /// It is cleaned up even when an assertion panics, and keeps the working
    /// directory free of test files.
    struct TempCsv {
        path: String,
    }

    impl TempCsv {
        fn new(content: &str) -> Self {
            let path = std::env::temp_dir()
                .join(format!(
                    "test_temp_{}_{}.csv",
                    std::process::id(),
                    rand::random::<u32>()
                ))
                .to_string_lossy()
                .into_owned();
            fs::write(&path, content).unwrap();
            TempCsv { path }
        }
    }

    impl Drop for TempCsv {
        fn drop(&mut self) {
            // Best effort: a missing file is not a test failure.
            let _ = fs::remove_file(&self.path);
        }
    }

    #[test]
    fn test_prior_sobol_creation() {
        let prior = Prior::sobol(100, 42);
        assert_eq!(prior.points(), Some(100));
        assert_eq!(prior.seed(), Some(42));
    }

    #[test]
    fn test_prior_latin_creation() {
        let prior = Prior::Latin(50, 123);
        assert_eq!(prior.points(), Some(50));
        assert_eq!(prior.seed(), Some(123));
    }

    #[test]
    fn test_prior_default() {
        let prior = Prior::default();
        assert_eq!(prior.points(), Some(2028));
        assert_eq!(prior.seed(), Some(22));
    }

    #[test]
    fn test_prior_file_points() {
        let prior = Prior::File("test.csv".to_string());
        assert_eq!(prior.points(), None);
    }

    #[test]
    fn test_prior_file_seed() {
        let prior = Prior::File("test.csv".to_string());
        assert_eq!(prior.seed(), None);
    }

    #[test]
    fn test_sample_space_sobol() {
        let mut settings = create_test_settings();
        settings.set_prior(Prior::sobol(10, 42));

        let theta = sample_space(&settings).expect("sobol prior should be generated");
        assert_eq!(theta.nspp(), 10);
        assert_eq!(theta.matrix().ncols(), 2);
    }

    #[test]
    fn test_sample_space_latin() {
        let mut settings = create_test_settings();
        settings.set_prior(Prior::Latin(15, 123));

        let theta = sample_space(&settings).expect("latin prior should be generated");
        assert_eq!(theta.nspp(), 15);
        assert_eq!(theta.matrix().ncols(), 2);
    }

    #[test]
    fn test_sample_space_custom_theta() {
        let mut settings = create_test_settings();

        let parameters = settings.parameters().clone();
        let matrix = faer::Mat::from_fn(3, 2, |i, j| (i + j) as f64);
        let custom_theta = Theta::from_parts(matrix, parameters);

        let prior = Prior::Theta(custom_theta.clone());
        assert_eq!(prior.points(), Some(3));
        settings.set_prior(prior);

        let theta = sample_space(&settings).expect("custom theta should be returned");
        assert_eq!(theta.nspp(), 3);
        assert_eq!(theta.matrix().ncols(), 2);
        assert_eq!(theta, custom_theta);
    }

    #[test]
    fn test_sample_space_infinite_bounds_error() {
        let mut settings = settings_with_parameters(
            Parameters::new()
                .add("ke", f64::NEG_INFINITY, 1.0)
                .add("v", 5.0, 50.0),
        );
        settings.set_prior(Prior::sobol(10, 42));

        let err = sample_space(&settings).expect_err("infinite bounds must be rejected");
        assert!(err.to_string().contains("infinite bounds"));
    }

    #[test]
    fn test_sample_space_invalid_bounds_error() {
        let mut settings =
            settings_with_parameters(Parameters::new().add("ke", 1.0, 0.5).add("v", 5.0, 50.0));
        settings.set_prior(Prior::sobol(10, 42));

        let err = sample_space(&settings).expect_err("inverted bounds must be rejected");
        assert!(err.to_string().contains("invalid bounds"));
    }

    #[test]
    fn test_parse_prior_valid_file() {
        let csv = TempCsv::new("ke,v\n0.1,10.0\n0.2,15.0\n0.3,20.0\n");
        let settings = create_test_settings();

        let theta = parse_prior(&csv.path, &settings).expect("valid prior file");
        assert_eq!(theta.nspp(), 3);
        assert_eq!(theta.matrix().ncols(), 2);
    }

    #[test]
    fn test_parse_prior_with_prob_column() {
        let csv = TempCsv::new("ke,v,prob\n0.1,10.0,0.5\n0.2,15.0,0.3\n0.3,20.0,0.2\n");
        let settings = create_test_settings();

        let theta = parse_prior(&csv.path, &settings).expect("prob column is ignored");
        assert_eq!(theta.nspp(), 3);
        assert_eq!(theta.matrix().ncols(), 2);
    }

    #[test]
    fn test_parse_prior_missing_parameter() {
        let csv = TempCsv::new("ke\n0.1\n0.2\n0.3\n");
        let settings = create_test_settings();

        let err = parse_prior(&csv.path, &settings).expect_err("v is missing");
        assert!(err.to_string().contains("Parameter v is not present"));
    }

    #[test]
    fn test_parse_prior_extra_parameters() {
        let csv = TempCsv::new("ke,v,extra_param\n0.1,10.0,1.0\n0.2,15.0,2.0\n0.3,20.0,3.0\n");
        let settings = create_test_settings();

        let err = parse_prior(&csv.path, &settings).expect_err("extra_param is unknown");
        assert!(err
            .to_string()
            .contains("Found parameters in the prior not present in configuration"));
    }

    #[test]
    fn test_parse_prior_nonexistent_file() {
        let settings = create_test_settings();
        let file_path = "nonexistent_file.csv".to_string();

        let err = parse_prior(&file_path, &settings).expect_err("file does not exist");
        assert!(err.to_string().contains("Unable to open the prior file"));
    }

    #[test]
    fn test_parse_prior_reordered_columns() {
        let csv = TempCsv::new("v,ke\n10.0,0.1\n15.0,0.2\n20.0,0.3\n");
        let settings = create_test_settings();

        let theta = parse_prior(&csv.path, &settings).expect("columns in any order");
        assert_eq!(theta.nspp(), 3);
        assert_eq!(theta.matrix().ncols(), 2);

        // Columns must follow the configuration order (ke, v), not the file order.
        let matrix = theta.matrix();
        let expected = [(0.1, 10.0), (0.2, 15.0), (0.3, 20.0)];
        for (row, &(ke, v)) in expected.iter().enumerate() {
            assert!((matrix[(row, 0)] - ke).abs() < 1e-10);
            assert!((matrix[(row, 1)] - v).abs() < 1e-10);
        }
    }

    #[test]
    fn test_sample_space_file_based() {
        let csv = TempCsv::new("ke,v\n0.1,10.0\n0.2,15.0\n0.3,20.0\n");

        let mut settings = create_test_settings();
        settings.set_prior(Prior::File(csv.path.clone()));

        let theta = sample_space(&settings).expect("file prior should be read");
        assert_eq!(theta.nspp(), 3);
        assert_eq!(theta.matrix().ncols(), 2);
    }

    #[test]
    fn test_prior_theta_no_seed_panic() {
        let parameters = Parameters::new().add("ke", 0.1, 1.0);
        let matrix = faer::Mat::from_fn(1, 1, |_, _| 0.5);
        let theta = Theta::from_parts(matrix, parameters);
        let prior = Prior::Theta(theta);

        assert_eq!(prior.seed(), None, "Theta prior should not have a seed");
    }
}