pmcore/routines/initialization/mod.rs

1use std::fs::File;
2
3use crate::structs::theta::Theta;
4use anyhow::{bail, Context, Result};
5use faer::Mat;
6use serde::{Deserialize, Serialize};
7
8use crate::routines::settings::Settings;
9
10pub mod latin;
11pub mod sobol;
12
13/// The sampler used to generate the grid of support points
14///
15/// The sampler can be one of the following:
16///
17/// - `Sobol`: Generates a Sobol sequence
18/// - `Latin`: Generates a Latin hypercube
19/// - `File`: Reads the prior distribution from a CSV file
20#[derive(Debug, Deserialize, Clone, Serialize)]
21pub enum Prior {
22    Sobol(usize, usize),
23    Latin(usize, usize),
24    File(String),
25    #[serde(skip)]
26    Theta(Theta),
27}
28
29impl Prior {
30    pub fn sobol(points: usize, seed: usize) -> Prior {
31        Prior::Sobol(points, seed)
32    }
33
34    /// Get the number of initial support points
35    ///
36    /// This function returns the number of points for Sobol and Latin samplers,
37    /// and returns `None` for file-based priors since they do not have a fixed number of points.
38    /// For custom priors ([Prior::Theta]), it returns the number of support points in the original [Theta] structure.
39    pub fn points(&self) -> Option<usize> {
40        match self {
41            Prior::Sobol(points, _) => Some(*points),
42            Prior::Latin(points, _) => Some(*points),
43            Prior::File(_) => None, // File-based prior does not have a fixed number of points
44            Prior::Theta(theta) => Some(theta.nspp()),
45        }
46    }
47
48    /// Get the seed used for the random number generator
49    ///
50    /// This function returns the seed for Sobol and Latin samplers,
51    /// and returns `None` for file-based priors since they do not have a fixed seed.
52    /// For custom priors ([Prior::Theta]), it returns `None` as they do not have a fixed seed.
53    pub fn seed(&self) -> Option<usize> {
54        match self {
55            Prior::Sobol(_, seed) => Some(*seed),
56            Prior::Latin(_, seed) => Some(*seed),
57            Prior::File(_) => None, // "File-based prior does not have a fixed seed"
58            Prior::Theta(_) => None, // Custom prior does not have a fixed seed
59        }
60    }
61}
62
63impl Default for Prior {
64    fn default() -> Self {
65        Prior::Sobol(2028, 22)
66    }
67}
68
69/// This function generates the grid of support points according to the sampler specified in the [Settings]
70pub fn sample_space(settings: &Settings) -> Result<Theta> {
71    // Ensure that the parameter ranges are not infinite
72    for param in settings.parameters().iter() {
73        if param.lower.is_infinite() || param.upper.is_infinite() {
74            bail!(
75                "Parameter '{}' has infinite bounds: [{}, {}]",
76                param.name,
77                param.lower,
78                param.upper
79            );
80        }
81
82        // Ensure that the lower bound is less than the upper bound
83        if param.lower >= param.upper {
84            bail!(
85                "Parameter '{}' has invalid bounds: [{}, {}]. Lower bound must be less than upper bound.",
86                param.name,
87                param.lower,
88                param.upper
89            );
90        }
91    }
92
93    // Otherwise, parse the sampler type and generate the grid
94    let prior = match settings.prior() {
95        Prior::Sobol(points, seed) => sobol::generate(settings.parameters(), *points, *seed)?,
96        Prior::Latin(points, seed) => latin::generate(settings.parameters(), *points, *seed)?,
97        Prior::File(ref path) => parse_prior(path, settings)?,
98        Prior::Theta(ref theta) => {
99            // If a custom prior is provided, return it directly
100            return Ok(theta.clone());
101        }
102    };
103    Ok(prior)
104}
105
106/// This function reads the prior distribution from a file
107pub fn parse_prior(path: &String, settings: &Settings) -> Result<Theta> {
108    tracing::info!("Reading prior from {}", path);
109    let file = File::open(path).context(format!("Unable to open the prior file '{}'", path))?;
110    let mut reader = csv::ReaderBuilder::new()
111        .has_headers(true)
112        .from_reader(file);
113
114    let mut parameter_names: Vec<String> = reader
115        .headers()?
116        .clone()
117        .into_iter()
118        .map(|s| s.trim().to_owned())
119        .collect();
120
121    // Remove "prob" column if present
122    if let Some(index) = parameter_names.iter().position(|name| name == "prob") {
123        parameter_names.remove(index);
124    }
125
126    // Check and reorder parameters to match names in settings.parsed.random
127    let random_names: Vec<String> = settings.parameters().names();
128
129    let mut reordered_indices: Vec<usize> = Vec::new();
130    for random_name in &random_names {
131        match parameter_names.iter().position(|name| name == random_name) {
132            Some(index) => {
133                reordered_indices.push(index);
134            }
135            None => {
136                bail!("Parameter {} is not present in the CSV file.", random_name);
137            }
138        }
139    }
140
141    // Check if there are remaining parameters not present in settings.parsed.random
142    if parameter_names.len() > random_names.len() {
143        let extra_parameters: Vec<&String> = parameter_names.iter().collect();
144        bail!(
145            "Found parameters in the prior not present in configuration: {:?}",
146            extra_parameters
147        );
148    }
149
150    // Read parameter values row by row, keeping only those associated with the reordered parameters
151    let mut theta_values = Vec::new();
152    for result in reader.records() {
153        let record = result.unwrap();
154        let values: Vec<f64> = reordered_indices
155            .iter()
156            .map(|&i| record[i].parse::<f64>().unwrap())
157            .collect();
158        theta_values.push(values);
159    }
160
161    let n_points = theta_values.len();
162    let n_params = random_names.len();
163
164    // Convert nested Vec into a single Vec
165    let theta_values: Vec<f64> = theta_values.into_iter().flatten().collect();
166
167    let theta_matrix: Mat<f64> =
168        Mat::from_fn(n_points, n_params, |i, j| theta_values[i * n_params + j]);
169
170    let theta = Theta::from_parts(theta_matrix, settings.parameters().clone());
171
172    Ok(theta)
173}
174
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::*;
    use pharmsol::{ErrorModel, ErrorModels, ErrorPoly};
    use std::fs;

    /// Build NPAG settings for the given parameters with a single additive error model
    fn settings_with_parameters(parameters: Parameters) -> Settings {
        let em = ErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 2.0, None);
        let ems = ErrorModels::new().add(0, em).unwrap();

        Settings::builder()
            .set_algorithm(Algorithm::NPAG)
            .set_parameters(parameters)
            .set_error_models(ems)
            .build()
    }

    /// Settings with two valid parameters: `ke` in [0.1, 1.0] and `v` in [5.0, 50.0]
    fn create_test_settings() -> Settings {
        settings_with_parameters(Parameters::new().add("ke", 0.1, 1.0).add("v", 5.0, 50.0))
    }

    /// A CSV file in the system temporary directory that is deleted when dropped
    ///
    /// Tying cleanup to `Drop` guarantees removal even when an assertion fails and the test
    /// panics, and keeps test artifacts out of the current working directory.
    struct TempCsv {
        path: String,
    }

    impl TempCsv {
        fn new(content: &str) -> Self {
            let path = std::env::temp_dir()
                .join(format!("pmcore_test_prior_{}.csv", rand::random::<u32>()))
                .to_string_lossy()
                .into_owned();
            fs::write(&path, content).expect("failed to write temporary CSV file");
            TempCsv { path }
        }
    }

    impl Drop for TempCsv {
        fn drop(&mut self) {
            // Best effort: failing to delete a temporary file must not fail the test
            let _ = fs::remove_file(&self.path);
        }
    }

    #[test]
    fn test_prior_sobol_creation() {
        let prior = Prior::sobol(100, 42);
        assert_eq!(prior.points(), Some(100));
        assert_eq!(prior.seed(), Some(42));
    }

    #[test]
    fn test_prior_latin_creation() {
        let prior = Prior::Latin(50, 123);
        assert_eq!(prior.points(), Some(50));
        assert_eq!(prior.seed(), Some(123));
    }

    #[test]
    fn test_prior_default() {
        let prior = Prior::default();
        assert_eq!(prior.points(), Some(2028));
        assert_eq!(prior.seed(), Some(22));
    }

    #[test]
    fn test_prior_file_points() {
        let prior = Prior::File("test.csv".to_string());
        assert_eq!(prior.points(), None);
    }

    #[test]
    fn test_prior_file_seed() {
        let prior = Prior::File("test.csv".to_string());
        assert_eq!(prior.seed(), None);
    }

    #[test]
    fn test_sample_space_sobol() {
        let mut settings = create_test_settings();
        settings.set_prior(Prior::sobol(10, 42));

        let theta = sample_space(&settings).expect("Sobol sampling should succeed");
        assert_eq!(theta.nspp(), 10);
        assert_eq!(theta.matrix().ncols(), 2);
    }

    #[test]
    fn test_sample_space_latin() {
        let mut settings = create_test_settings();
        settings.set_prior(Prior::Latin(15, 123));

        let theta = sample_space(&settings).expect("Latin hypercube sampling should succeed");
        assert_eq!(theta.nspp(), 15);
        assert_eq!(theta.matrix().ncols(), 2);
    }

    #[test]
    fn test_sample_space_custom_theta() {
        let mut settings = create_test_settings();

        // Create a custom theta
        let matrix = faer::Mat::from_fn(3, 2, |i, j| (i + j) as f64);
        let custom_theta = Theta::from_parts(matrix, settings.parameters().clone());

        let prior = Prior::Theta(custom_theta.clone());
        assert_eq!(prior.points(), Some(3));
        settings.set_prior(prior);

        let theta = sample_space(&settings).expect("a custom prior should be returned as-is");
        assert_eq!(theta.nspp(), 3);
        assert_eq!(theta.matrix().ncols(), 2);
        assert_eq!(theta, custom_theta);
    }

    #[test]
    fn test_sample_space_infinite_bounds_error() {
        let parameters = Parameters::new()
            .add("ke", f64::NEG_INFINITY, 1.0) // Invalid: infinite lower bound
            .add("v", 5.0, 50.0);
        let mut settings = settings_with_parameters(parameters);
        settings.set_prior(Prior::sobol(10, 42));

        let err = sample_space(&settings).expect_err("infinite bounds must be rejected");
        assert!(err.to_string().contains("infinite bounds"));
    }

    #[test]
    fn test_sample_space_invalid_bounds_error() {
        let parameters = Parameters::new()
            .add("ke", 1.0, 0.5) // Invalid: lower bound >= upper bound
            .add("v", 5.0, 50.0);
        let mut settings = settings_with_parameters(parameters);
        settings.set_prior(Prior::sobol(10, 42));

        let err = sample_space(&settings).expect_err("inverted bounds must be rejected");
        assert!(err.to_string().contains("invalid bounds"));
    }

    #[test]
    fn test_parse_prior_valid_file() {
        let csv = TempCsv::new("ke,v\n0.1,10.0\n0.2,15.0\n0.3,20.0\n");
        let settings = create_test_settings();

        let theta = parse_prior(&csv.path, &settings).expect("a valid prior file should parse");
        assert_eq!(theta.nspp(), 3);
        assert_eq!(theta.matrix().ncols(), 2);
    }

    #[test]
    fn test_parse_prior_with_prob_column() {
        let csv = TempCsv::new("ke,v,prob\n0.1,10.0,0.5\n0.2,15.0,0.3\n0.3,20.0,0.2\n");
        let settings = create_test_settings();

        let theta = parse_prior(&csv.path, &settings).expect("the prob column should be ignored");
        assert_eq!(theta.nspp(), 3);
        assert_eq!(theta.matrix().ncols(), 2);
    }

    #[test]
    fn test_parse_prior_missing_parameter() {
        let csv = TempCsv::new("ke\n0.1\n0.2\n0.3\n");
        let settings = create_test_settings();

        let err = parse_prior(&csv.path, &settings).expect_err("a missing parameter must fail");
        assert!(err.to_string().contains("Parameter v is not present"));
    }

    #[test]
    fn test_parse_prior_extra_parameters() {
        let csv = TempCsv::new("ke,v,extra_param\n0.1,10.0,1.0\n0.2,15.0,2.0\n0.3,20.0,3.0\n");
        let settings = create_test_settings();

        let err = parse_prior(&csv.path, &settings).expect_err("an unknown column must fail");
        assert!(err
            .to_string()
            .contains("Found parameters in the prior not present in configuration"));
    }

    #[test]
    fn test_parse_prior_nonexistent_file() {
        let settings = create_test_settings();
        let file_path = "nonexistent_file.csv".to_string();

        let err = parse_prior(&file_path, &settings).expect_err("a missing file must fail");
        assert!(err.to_string().contains("Unable to open the prior file"));
    }

    #[test]
    fn test_parse_prior_reordered_columns() {
        let csv = TempCsv::new("v,ke\n10.0,0.1\n15.0,0.2\n20.0,0.3\n");
        let settings = create_test_settings();

        let theta = parse_prior(&csv.path, &settings).expect("reordered columns should parse");
        assert_eq!(theta.nspp(), 3);
        assert_eq!(theta.matrix().ncols(), 2);

        // Verify the values are correctly reordered (ke should be first, v second)
        let matrix = theta.matrix();
        assert!((matrix[(0, 0)] - 0.1).abs() < 1e-10); // First row, ke value
        assert!((matrix[(0, 1)] - 10.0).abs() < 1e-10); // First row, v value
    }

    #[test]
    fn test_sample_space_file_based() {
        let csv = TempCsv::new("ke,v\n0.1,10.0\n0.2,15.0\n0.3,20.0\n");

        let mut settings = create_test_settings();
        settings.set_prior(Prior::File(csv.path.clone()));

        let theta = sample_space(&settings).expect("a file-based prior should load");
        assert_eq!(theta.nspp(), 3);
        assert_eq!(theta.matrix().ncols(), 2);
    }

    #[test]
    fn test_prior_theta_no_seed_panic() {
        let parameters = Parameters::new().add("ke", 0.1, 1.0);
        let matrix = faer::Mat::from_fn(1, 1, |_, _| 0.5);
        let theta = Theta::from_parts(matrix, parameters);
        let prior = Prior::Theta(theta);

        assert_eq!(prior.seed(), None, "Theta prior should not have a seed");
    }
}