1use std::fs::File;
2
3use crate::structs::{theta::Theta, weights::Weights};
4use anyhow::{bail, Context, Result};
5use faer::Mat;
6use serde::{Deserialize, Serialize};
7
8use crate::routines::settings::Settings;
9
10pub mod latin;
11pub mod sobol;
12
/// Describes how the initial [`Theta`] returned by [`sample_space`] is obtained.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub enum Prior {
    /// `(number of points, seed)`, passed to `sobol::generate` by [`sample_space`].
    Sobol(usize, usize),
    /// `(number of points, seed)`, passed to `latin::generate` by [`sample_space`].
    Latin(usize, usize),
    /// Path to a CSV file of points, read with [`parse_prior`].
    File(String),
    /// An explicit set of points, returned as-is (cloned) by [`sample_space`].
    ///
    /// Excluded from (de)serialization via `#[serde(skip)]`, so it can only be
    /// constructed in code.
    #[serde(skip)]
    Theta(Theta),
}
28
29impl Prior {
30 pub fn sobol(points: usize, seed: usize) -> Prior {
31 Prior::Sobol(points, seed)
32 }
33
34 pub fn points(&self) -> Option<usize> {
40 match self {
41 Prior::Sobol(points, _) => Some(*points),
42 Prior::Latin(points, _) => Some(*points),
43 Prior::File(_) => None, Prior::Theta(theta) => Some(theta.nspp()),
45 }
46 }
47
48 pub fn seed(&self) -> Option<usize> {
54 match self {
55 Prior::Sobol(_, seed) => Some(*seed),
56 Prior::Latin(_, seed) => Some(*seed),
57 Prior::File(_) => None, Prior::Theta(_) => None, }
60 }
61}
62
63impl Default for Prior {
64 fn default() -> Self {
65 Prior::Sobol(2028, 22)
66 }
67}
68
69pub fn sample_space(settings: &Settings) -> Result<Theta> {
71 for param in settings.parameters().iter() {
73 if param.lower.is_infinite() || param.upper.is_infinite() {
74 bail!(
75 "Parameter '{}' has infinite bounds: [{}, {}]",
76 param.name,
77 param.lower,
78 param.upper
79 );
80 }
81
82 if param.lower >= param.upper {
84 bail!(
85 "Parameter '{}' has invalid bounds: [{}, {}]. Lower bound must be less than upper bound.",
86 param.name,
87 param.lower,
88 param.upper
89 );
90 }
91 }
92
93 let prior = match settings.prior() {
95 Prior::Sobol(points, seed) => sobol::generate(settings.parameters(), *points, *seed)?,
96 Prior::Latin(points, seed) => latin::generate(settings.parameters(), *points, *seed)?,
97 Prior::File(ref path) => parse_prior(path, settings)?.0,
98 Prior::Theta(ref theta) => {
99 return Ok(theta.clone());
101 }
102 };
103 Ok(prior)
104}
105
106pub fn parse_prior(path: &String, settings: &Settings) -> Result<(Theta, Option<Weights>)> {
108 tracing::info!("Reading prior from {}", path);
109 let file = File::open(path).context(format!("Unable to open the prior file '{}'", path))?;
110 let mut reader = csv::ReaderBuilder::new()
111 .has_headers(true)
112 .from_reader(file);
113
114 let mut parameter_names: Vec<String> = reader
115 .headers()?
116 .clone()
117 .into_iter()
118 .map(|s| s.trim().to_owned())
119 .collect();
120
121 let prob_index = parameter_names.iter().position(|name| name == "prob");
123
124 if let Some(index) = prob_index {
126 parameter_names.remove(index);
127 }
128
129 let random_names: Vec<String> = settings.parameters().names();
131
132 let mut reordered_indices: Vec<usize> = Vec::new();
133 for random_name in &random_names {
134 match parameter_names.iter().position(|name| name == random_name) {
135 Some(index) => {
136 let adjusted_index = if let Some(prob_idx) = prob_index {
138 if index >= prob_idx {
139 index + 1 } else {
141 index
142 }
143 } else {
144 index
145 };
146 reordered_indices.push(adjusted_index);
147 }
148 None => {
149 bail!("Parameter {} is not present in the CSV file.", random_name);
150 }
151 }
152 }
153
154 if parameter_names.len() > random_names.len() {
156 let extra_parameters: Vec<&String> = parameter_names.iter().collect();
157 bail!(
158 "Found parameters in the prior not present in configuration: {:?}",
159 extra_parameters
160 );
161 }
162
163 let mut theta_values = Vec::new();
165 let mut prob_values = Vec::new();
166
167 for result in reader.records() {
168 let record = result.unwrap();
169
170 let values: Vec<f64> = reordered_indices
172 .iter()
173 .map(|&i| record[i].parse::<f64>().unwrap())
174 .collect();
175 theta_values.push(values);
176
177 if let Some(prob_idx) = prob_index {
179 let prob_value: f64 = record[prob_idx].parse::<f64>().unwrap();
180 prob_values.push(prob_value);
181 }
182 }
183
184 let n_points = theta_values.len();
185 let n_params = random_names.len();
186
187 let theta_values: Vec<f64> = theta_values.into_iter().flatten().collect();
189
190 let theta_matrix: Mat<f64> =
191 Mat::from_fn(n_points, n_params, |i, j| theta_values[i * n_params + j]);
192
193 let theta = Theta::from_parts(theta_matrix, settings.parameters().clone())?;
194
195 let weights = if !prob_values.is_empty() {
197 Some(Weights::from_vec(prob_values))
198 } else {
199 None
200 };
201
202 Ok((theta, weights))
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::*;
    use pharmsol::{ErrorModel, ErrorModels, ErrorPoly};
    use std::fs;

    /// NPAG settings with the given parameters and a single additive error model.
    fn settings_with_parameters(parameters: Parameters) -> Settings {
        let em = ErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 2.0);
        let ems = ErrorModels::new().add(0, em).unwrap();

        Settings::builder()
            .set_algorithm(Algorithm::NPAG)
            .set_parameters(parameters)
            .set_error_models(ems)
            .build()
    }

    /// Settings with `ke` in [0.1, 1.0] and `v` in [5.0, 50.0].
    fn create_test_settings() -> Settings {
        settings_with_parameters(Parameters::new().add("ke", 0.1, 1.0).add("v", 5.0, 50.0))
    }

    /// A CSV file in the system temp directory, deleted when dropped so that it is
    /// cleaned up even when an assertion fails partway through a test.
    struct TempCsv {
        path: String,
    }

    impl TempCsv {
        fn new(content: &str) -> Self {
            let path = std::env::temp_dir()
                .join(format!(
                    "test_prior_{}_{}.csv",
                    std::process::id(),
                    rand::random::<u32>()
                ))
                .into_os_string()
                .into_string()
                .expect("temporary directory path should be valid UTF-8");
            fs::write(&path, content).unwrap();
            TempCsv { path }
        }
    }

    impl Drop for TempCsv {
        fn drop(&mut self) {
            let _ = fs::remove_file(&self.path);
        }
    }

    #[test]
    fn test_prior_sobol_creation() {
        let prior = Prior::sobol(100, 42);
        assert_eq!(prior.points(), Some(100));
        assert_eq!(prior.seed(), Some(42));
    }

    #[test]
    fn test_prior_latin_creation() {
        let prior = Prior::Latin(50, 123);
        assert_eq!(prior.points(), Some(50));
        assert_eq!(prior.seed(), Some(123));
    }

    #[test]
    fn test_prior_default() {
        let prior = Prior::default();
        assert_eq!(prior.points(), Some(2028));
        assert_eq!(prior.seed(), Some(22));
    }

    #[test]
    fn test_prior_file_points() {
        let prior = Prior::File("test.csv".to_string());
        assert_eq!(prior.points(), None);
    }

    #[test]
    fn test_prior_file_seed() {
        let prior = Prior::File("test.csv".to_string());
        assert_eq!(prior.seed(), None);
    }

    #[test]
    fn test_sample_space_sobol() {
        let mut settings = create_test_settings();
        settings.set_prior(Prior::sobol(10, 42));

        let theta = sample_space(&settings).expect("Sobol sampling should succeed");
        assert_eq!(theta.nspp(), 10);
        assert_eq!(theta.matrix().ncols(), 2);
    }

    #[test]
    fn test_sample_space_latin() {
        let mut settings = create_test_settings();
        settings.set_prior(Prior::Latin(15, 123));

        let theta = sample_space(&settings).expect("Latin sampling should succeed");
        assert_eq!(theta.nspp(), 15);
        assert_eq!(theta.matrix().ncols(), 2);
    }

    #[test]
    fn test_sample_space_custom_theta() {
        let mut settings = create_test_settings();

        let parameters = settings.parameters().clone();
        let matrix = faer::Mat::from_fn(3, 2, |i, j| (i + j) as f64);
        let custom_theta = Theta::from_parts(matrix, parameters).unwrap();

        let prior = Prior::Theta(custom_theta.clone());
        assert_eq!(prior.points(), Some(3));
        settings.set_prior(prior);

        let theta = sample_space(&settings).expect("a Theta prior should be returned as-is");
        assert_eq!(theta.nspp(), 3);
        assert_eq!(theta.matrix().ncols(), 2);
        assert_eq!(theta, custom_theta);
    }

    #[test]
    fn test_sample_space_infinite_bounds_error() {
        let parameters = Parameters::new()
            .add("ke", f64::NEG_INFINITY, 1.0)
            .add("v", 5.0, 50.0);
        let mut settings = settings_with_parameters(parameters);
        settings.set_prior(Prior::sobol(10, 42));

        let err = sample_space(&settings).expect_err("infinite bounds should be rejected");
        assert!(err.to_string().contains("infinite bounds"));
    }

    #[test]
    fn test_sample_space_invalid_bounds_error() {
        let parameters = Parameters::new()
            .add("ke", 1.0, 0.5)
            .add("v", 5.0, 50.0);
        let mut settings = settings_with_parameters(parameters);
        settings.set_prior(Prior::sobol(10, 42));

        let err = sample_space(&settings).expect_err("inverted bounds should be rejected");
        assert!(err.to_string().contains("invalid bounds"));
    }

    #[test]
    fn test_parse_prior_valid_file() {
        let csv = TempCsv::new("ke,v\n0.1,10.0\n0.2,15.0\n0.3,20.0\n");
        let settings = create_test_settings();

        let (theta, weights) = parse_prior(&csv.path, &settings).expect("valid prior file");
        assert_eq!(theta.nspp(), 3);
        assert_eq!(theta.matrix().ncols(), 2);
        assert!(weights.is_none());
    }

    #[test]
    fn test_parse_prior_with_prob_column() {
        let csv = TempCsv::new("ke,v,prob\n0.1,10.0,0.5\n0.2,15.0,0.3\n0.3,20.0,0.2\n");
        let settings = create_test_settings();

        let (theta, weights) = parse_prior(&csv.path, &settings).expect("valid prior file");
        assert_eq!(theta.nspp(), 3);
        assert_eq!(theta.matrix().ncols(), 2);

        let weights = weights.expect("a prob column should produce weights");
        assert_eq!(weights.len(), 3);
        assert!((weights[0] - 0.5).abs() < 1e-10);
        assert!((weights[1] - 0.3).abs() < 1e-10);
        assert!((weights[2] - 0.2).abs() < 1e-10);
    }

    #[test]
    fn test_parse_prior_missing_parameter() {
        let csv = TempCsv::new("ke\n0.1\n0.2\n0.3\n");
        let settings = create_test_settings();

        let err = parse_prior(&csv.path, &settings).expect_err("missing column should fail");
        assert!(err.to_string().contains("Parameter v is not present"));
    }

    #[test]
    fn test_parse_prior_extra_parameters() {
        let csv = TempCsv::new("ke,v,extra_param\n0.1,10.0,1.0\n0.2,15.0,2.0\n0.3,20.0,3.0\n");
        let settings = create_test_settings();

        let err = parse_prior(&csv.path, &settings).expect_err("unknown column should fail");
        assert!(err
            .to_string()
            .contains("Found parameters in the prior not present in configuration"));
    }

    #[test]
    fn test_parse_prior_nonexistent_file() {
        let settings = create_test_settings();
        let file_path = "nonexistent_file.csv".to_string();

        let err = parse_prior(&file_path, &settings).expect_err("missing file should fail");
        assert!(err.to_string().contains("Unable to open the prior file"));
    }

    #[test]
    fn test_parse_prior_reordered_columns() {
        let csv = TempCsv::new("v,ke\n10.0,0.1\n15.0,0.2\n20.0,0.3\n");
        let settings = create_test_settings();

        let (theta, weights) = parse_prior(&csv.path, &settings).expect("valid prior file");
        assert_eq!(theta.nspp(), 3);
        assert_eq!(theta.matrix().ncols(), 2);
        assert!(weights.is_none());

        // Columns follow the configuration order (ke, v), not the file order.
        let matrix = theta.matrix();
        assert!((matrix[(0, 0)] - 0.1).abs() < 1e-10);
        assert!((matrix[(0, 1)] - 10.0).abs() < 1e-10);
    }

    #[test]
    fn test_parse_prior_with_prob_column_reordered() {
        let csv = TempCsv::new("prob,v,ke\n0.5,10.0,0.1\n0.3,15.0,0.2\n0.2,20.0,0.3\n");
        let settings = create_test_settings();

        let (theta, weights) = parse_prior(&csv.path, &settings).expect("valid prior file");
        assert_eq!(theta.nspp(), 3);
        assert_eq!(theta.matrix().ncols(), 2);

        let weights = weights.expect("a prob column should produce weights");
        assert_eq!(weights.len(), 3);
        assert!((weights[0] - 0.5).abs() < 1e-10);
        assert!((weights[1] - 0.3).abs() < 1e-10);
        assert!((weights[2] - 0.2).abs() < 1e-10);

        let matrix = theta.matrix();
        assert!((matrix[(0, 0)] - 0.1).abs() < 1e-10);
        assert!((matrix[(0, 1)] - 10.0).abs() < 1e-10);
    }

    #[test]
    fn test_sample_space_file_based() {
        let csv = TempCsv::new("ke,v\n0.1,10.0\n0.2,15.0\n0.3,20.0\n");

        let mut settings = create_test_settings();
        settings.set_prior(Prior::File(csv.path.clone()));

        let theta = sample_space(&settings).expect("file-based prior should load");
        assert_eq!(theta.nspp(), 3);
        assert_eq!(theta.matrix().ncols(), 2);
    }

    #[test]
    fn test_prior_theta_no_seed_panic() {
        let parameters = Parameters::new().add("ke", 0.1, 1.0);
        let matrix = faer::Mat::from_fn(1, 1, |_, _| 0.5);
        let theta = Theta::from_parts(matrix, parameters).unwrap();
        let prior = Prior::Theta(theta);

        assert_eq!(prior.seed(), None, "Theta prior should not have a seed");
    }
}