pmcore/routines/initialization/
mod.rs

1use std::fs::File;
2
3use crate::structs::{theta::Theta, weights::Weights};
4use anyhow::{bail, Context, Result};
5use faer::Mat;
6use serde::{Deserialize, Serialize};
7
8use crate::routines::settings::Settings;
9
10pub mod latin;
11pub mod sobol;
12
/// The sampler used to generate the grid of support points
///
/// The sampler can be one of the following:
///
/// - `Sobol`: Generates a Sobol sequence
/// - `Latin`: Generates a Latin hypercube
/// - `File`: Reads the prior distribution from a CSV file
/// - `Theta`: Uses an already constructed [Theta] as the prior
#[derive(Debug, Deserialize, Clone, Serialize)]
pub enum Prior {
    /// Sobol sequence; the fields are `(points, seed)`, as read back by
    /// [Prior::points] and [Prior::seed].
    Sobol(usize, usize),
    /// Latin hypercube; the fields are `(points, seed)`, as read back by
    /// [Prior::points] and [Prior::seed].
    Latin(usize, usize),
    /// Path of the CSV file the prior is read from (passed to [parse_prior]).
    File(String),
    /// A custom, in-memory prior. [sample_space] returns a clone of it unchanged.
    /// The variant is marked `#[serde(skip)]`, so it is never serialized or deserialized.
    #[serde(skip)]
    Theta(Theta),
}
28
29impl Prior {
30    pub fn sobol(points: usize, seed: usize) -> Prior {
31        Prior::Sobol(points, seed)
32    }
33
34    /// Get the number of initial support points
35    ///
36    /// This function returns the number of points for Sobol and Latin samplers,
37    /// and returns `None` for file-based priors since they do not have a fixed number of points.
38    /// For custom priors ([Prior::Theta]), it returns the number of support points in the original [Theta] structure.
39    pub fn points(&self) -> Option<usize> {
40        match self {
41            Prior::Sobol(points, _) => Some(*points),
42            Prior::Latin(points, _) => Some(*points),
43            Prior::File(_) => None, // File-based prior does not have a fixed number of points
44            Prior::Theta(theta) => Some(theta.nspp()),
45        }
46    }
47
48    /// Get the seed used for the random number generator
49    ///
50    /// This function returns the seed for Sobol and Latin samplers,
51    /// and returns `None` for file-based priors since they do not have a fixed seed.
52    /// For custom priors ([Prior::Theta]), it returns `None` as they do not have a fixed seed.
53    pub fn seed(&self) -> Option<usize> {
54        match self {
55            Prior::Sobol(_, seed) => Some(*seed),
56            Prior::Latin(_, seed) => Some(*seed),
57            Prior::File(_) => None, // "File-based prior does not have a fixed seed"
58            Prior::Theta(_) => None, // Custom prior does not have a fixed seed
59        }
60    }
61}
62
63impl Default for Prior {
64    fn default() -> Self {
65        Prior::Sobol(2028, 22)
66    }
67}
68
69/// This function generates the grid of support points according to the sampler specified in the [Settings]
70pub fn sample_space(settings: &Settings) -> Result<Theta> {
71    // Ensure that the parameter ranges are not infinite
72    for param in settings.parameters().iter() {
73        if param.lower.is_infinite() || param.upper.is_infinite() {
74            bail!(
75                "Parameter '{}' has infinite bounds: [{}, {}]",
76                param.name,
77                param.lower,
78                param.upper
79            );
80        }
81
82        // Ensure that the lower bound is less than the upper bound
83        if param.lower >= param.upper {
84            bail!(
85                "Parameter '{}' has invalid bounds: [{}, {}]. Lower bound must be less than upper bound.",
86                param.name,
87                param.lower,
88                param.upper
89            );
90        }
91    }
92
93    // Otherwise, parse the sampler type and generate the grid
94    let prior = match settings.prior() {
95        Prior::Sobol(points, seed) => sobol::generate(settings.parameters(), *points, *seed)?,
96        Prior::Latin(points, seed) => latin::generate(settings.parameters(), *points, *seed)?,
97        Prior::File(ref path) => parse_prior(path, settings)?.0,
98        Prior::Theta(ref theta) => {
99            // If a custom prior is provided, return it directly
100            return Ok(theta.clone());
101        }
102    };
103    Ok(prior)
104}
105
106/// This function reads the prior distribution from a file
107pub fn parse_prior(path: &String, settings: &Settings) -> Result<(Theta, Option<Weights>)> {
108    tracing::info!("Reading prior from {}", path);
109    let file = File::open(path).context(format!("Unable to open the prior file '{}'", path))?;
110    let mut reader = csv::ReaderBuilder::new()
111        .has_headers(true)
112        .from_reader(file);
113
114    let mut parameter_names: Vec<String> = reader
115        .headers()?
116        .clone()
117        .into_iter()
118        .map(|s| s.trim().to_owned())
119        .collect();
120
121    // Check if "prob" column is present and get its index
122    let prob_index = parameter_names.iter().position(|name| name == "prob");
123
124    // Remove "prob" column from parameter_names if present
125    if let Some(index) = prob_index {
126        parameter_names.remove(index);
127    }
128
129    // Check and reorder parameters to match names in settings.parsed.random
130    let random_names: Vec<String> = settings.parameters().names();
131
132    let mut reordered_indices: Vec<usize> = Vec::new();
133    for random_name in &random_names {
134        match parameter_names.iter().position(|name| name == random_name) {
135            Some(index) => {
136                // Adjust index if prob column was present and came before this parameter
137                let adjusted_index = if let Some(prob_idx) = prob_index {
138                    if index >= prob_idx {
139                        index + 1 // Add 1 back since we removed prob from parameter_names
140                    } else {
141                        index
142                    }
143                } else {
144                    index
145                };
146                reordered_indices.push(adjusted_index);
147            }
148            None => {
149                bail!("Parameter {} is not present in the CSV file.", random_name);
150            }
151        }
152    }
153
154    // Check if there are remaining parameters not present in settings.parsed.random
155    if parameter_names.len() > random_names.len() {
156        let extra_parameters: Vec<&String> = parameter_names.iter().collect();
157        bail!(
158            "Found parameters in the prior not present in configuration: {:?}",
159            extra_parameters
160        );
161    }
162
163    // Read parameter values and probabilities row by row
164    let mut theta_values = Vec::new();
165    let mut prob_values = Vec::new();
166
167    for result in reader.records() {
168        let record = result.unwrap();
169
170        // Extract parameter values using reordered indices
171        let values: Vec<f64> = reordered_indices
172            .iter()
173            .map(|&i| record[i].parse::<f64>().unwrap())
174            .collect();
175        theta_values.push(values);
176
177        // Extract probability value if prob column exists
178        if let Some(prob_idx) = prob_index {
179            let prob_value: f64 = record[prob_idx].parse::<f64>().unwrap();
180            prob_values.push(prob_value);
181        }
182    }
183
184    let n_points = theta_values.len();
185    let n_params = random_names.len();
186
187    // Convert nested Vec into a single Vec
188    let theta_values: Vec<f64> = theta_values.into_iter().flatten().collect();
189
190    let theta_matrix: Mat<f64> =
191        Mat::from_fn(n_points, n_params, |i, j| theta_values[i * n_params + j]);
192
193    let theta = Theta::from_parts(theta_matrix, settings.parameters().clone())?;
194
195    // Create weights if prob column was present
196    let weights = if !prob_values.is_empty() {
197        Some(Weights::from_vec(prob_values))
198    } else {
199        None
200    };
201
202    Ok((theta, weights))
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::*;
    use pharmsol::{ErrorModel, ErrorModels, ErrorPoly};
    use std::fs;

    /// Builds NPAG settings for the given parameters with a single additive error model.
    fn build_settings(parameters: Parameters) -> Settings {
        let em = ErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 2.0);
        let ems = ErrorModels::new().add(0, em).unwrap();

        Settings::builder()
            .set_algorithm(Algorithm::NPAG)
            .set_parameters(parameters)
            .set_error_models(ems)
            .build()
    }

    /// Settings with two parameters: `ke` in [0.1, 1.0] and `v` in [5.0, 50.0].
    fn create_test_settings() -> Settings {
        build_settings(Parameters::new().add("ke", 0.1, 1.0).add("v", 5.0, 50.0))
    }

    /// A CSV file in the system temporary directory that is removed on drop.
    ///
    /// Removing the file in `Drop` cleans it up even when an assertion fails and the
    /// test unwinds, and keeps the working directory free of stray test files.
    struct TempCsv {
        path: String,
    }

    impl TempCsv {
        /// Writes `content` to a uniquely named file and returns its guard.
        fn new(content: &str) -> Self {
            let file_name = format!(
                "pmcore_prior_test_{}_{}.csv",
                std::process::id(),
                rand::random::<u32>()
            );
            let path = std::env::temp_dir()
                .join(file_name)
                .into_os_string()
                .into_string()
                .expect("temporary directory path should be valid UTF-8");
            fs::write(&path, content).expect("failed to write temporary CSV file");
            Self { path }
        }
    }

    impl Drop for TempCsv {
        fn drop(&mut self) {
            // Best effort: a leftover temporary file must not fail the test
            let _ = fs::remove_file(&self.path);
        }
    }

    /// Asserts that two floating point values agree up to a small tolerance.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-10,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_prior_sobol_creation() {
        let prior = Prior::sobol(100, 42);
        assert_eq!(prior.points(), Some(100));
        assert_eq!(prior.seed(), Some(42));
    }

    #[test]
    fn test_prior_latin_creation() {
        let prior = Prior::Latin(50, 123);
        assert_eq!(prior.points(), Some(50));
        assert_eq!(prior.seed(), Some(123));
    }

    #[test]
    fn test_prior_default() {
        let prior = Prior::default();
        assert_eq!(prior.points(), Some(2028));
        assert_eq!(prior.seed(), Some(22));
    }

    #[test]
    fn test_prior_file_points() {
        let prior = Prior::File("test.csv".to_string());
        assert_eq!(prior.points(), None);
    }

    #[test]
    fn test_prior_file_seed() {
        let prior = Prior::File("test.csv".to_string());
        assert_eq!(prior.seed(), None);
    }

    #[test]
    fn test_sample_space_sobol() {
        let mut settings = create_test_settings();
        settings.set_prior(Prior::sobol(10, 42));

        let theta = sample_space(&settings).expect("Sobol sampling should succeed");
        assert_eq!(theta.nspp(), 10);
        assert_eq!(theta.matrix().ncols(), 2);
    }

    #[test]
    fn test_sample_space_latin() {
        let mut settings = create_test_settings();
        settings.set_prior(Prior::Latin(15, 123));

        let theta = sample_space(&settings).expect("Latin hypercube sampling should succeed");
        assert_eq!(theta.nspp(), 15);
        assert_eq!(theta.matrix().ncols(), 2);
    }

    #[test]
    fn test_sample_space_custom_theta() {
        let mut settings = create_test_settings();

        // Create a custom theta
        let parameters = settings.parameters().clone();
        let matrix = faer::Mat::from_fn(3, 2, |i, j| (i + j) as f64);
        let custom_theta = Theta::from_parts(matrix, parameters).unwrap();

        let prior = Prior::Theta(custom_theta.clone());
        assert_eq!(prior.points(), Some(3));
        settings.set_prior(prior);

        let theta = sample_space(&settings).expect("custom prior should be accepted");
        assert_eq!(theta.nspp(), 3);
        assert_eq!(theta.matrix().ncols(), 2);
        assert_eq!(theta, custom_theta);
    }

    #[test]
    fn test_sample_space_infinite_bounds_error() {
        let parameters = Parameters::new()
            .add("ke", f64::NEG_INFINITY, 1.0) // Invalid: infinite lower bound
            .add("v", 5.0, 50.0);

        let mut settings = build_settings(parameters);
        settings.set_prior(Prior::sobol(10, 42));

        let result = sample_space(&settings);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("infinite bounds"));
    }

    #[test]
    fn test_sample_space_invalid_bounds_error() {
        let parameters = Parameters::new()
            .add("ke", 1.0, 0.5) // Invalid: lower bound >= upper bound
            .add("v", 5.0, 50.0);

        let mut settings = build_settings(parameters);
        settings.set_prior(Prior::sobol(10, 42));

        let result = sample_space(&settings);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("invalid bounds"));
    }

    #[test]
    fn test_parse_prior_valid_file() {
        let csv = TempCsv::new("ke,v\n0.1,10.0\n0.2,15.0\n0.3,20.0\n");
        let settings = create_test_settings();

        let (theta, weights) =
            parse_prior(&csv.path, &settings).expect("a valid prior file should be parsed");
        assert_eq!(theta.nspp(), 3);
        assert_eq!(theta.matrix().ncols(), 2);
        assert!(weights.is_none()); // No prob column, so no weights
    }

    #[test]
    fn test_parse_prior_with_prob_column() {
        let csv = TempCsv::new("ke,v,prob\n0.1,10.0,0.5\n0.2,15.0,0.3\n0.3,20.0,0.2\n");
        let settings = create_test_settings();

        let (theta, weights) =
            parse_prior(&csv.path, &settings).expect("a prior file with weights should be parsed");
        assert_eq!(theta.nspp(), 3);
        assert_eq!(theta.matrix().ncols(), 2);

        // Verify that weights were read correctly
        let weights = weights.expect("the prob column should produce weights");
        assert_eq!(weights.len(), 3);
        assert_close(weights[0], 0.5);
        assert_close(weights[1], 0.3);
        assert_close(weights[2], 0.2);
    }

    #[test]
    fn test_parse_prior_missing_parameter() {
        let csv = TempCsv::new("ke\n0.1\n0.2\n0.3\n");
        let settings = create_test_settings();

        let result = parse_prior(&csv.path, &settings);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Parameter v is not present"));
    }

    #[test]
    fn test_parse_prior_extra_parameters() {
        let csv = TempCsv::new("ke,v,extra_param\n0.1,10.0,1.0\n0.2,15.0,2.0\n0.3,20.0,3.0\n");
        let settings = create_test_settings();

        let result = parse_prior(&csv.path, &settings);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Found parameters in the prior not present in configuration"));
    }

    #[test]
    fn test_parse_prior_nonexistent_file() {
        let settings = create_test_settings();
        let file_path = "nonexistent_file.csv".to_string();

        let result = parse_prior(&file_path, &settings);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Unable to open the prior file"));
    }

    #[test]
    fn test_parse_prior_reordered_columns() {
        let csv = TempCsv::new("v,ke\n10.0,0.1\n15.0,0.2\n20.0,0.3\n");
        let settings = create_test_settings();

        let (theta, weights) =
            parse_prior(&csv.path, &settings).expect("reordered columns should be accepted");
        assert_eq!(theta.nspp(), 3);
        assert_eq!(theta.matrix().ncols(), 2);
        assert!(weights.is_none()); // No prob column, so no weights

        // Verify the values are correctly reordered (ke should be first, v second)
        let matrix = theta.matrix();
        assert_close(matrix[(0, 0)], 0.1); // First row, ke value
        assert_close(matrix[(0, 1)], 10.0); // First row, v value
    }

    #[test]
    fn test_parse_prior_with_prob_column_reordered() {
        let csv = TempCsv::new("prob,v,ke\n0.5,10.0,0.1\n0.3,15.0,0.2\n0.2,20.0,0.3\n");
        let settings = create_test_settings();

        let (theta, weights) = parse_prior(&csv.path, &settings)
            .expect("reordered columns with weights should be accepted");
        assert_eq!(theta.nspp(), 3);
        assert_eq!(theta.matrix().ncols(), 2);

        // Verify that weights were read correctly
        let weights = weights.expect("the prob column should produce weights");
        assert_eq!(weights.len(), 3);
        assert_close(weights[0], 0.5);
        assert_close(weights[1], 0.3);
        assert_close(weights[2], 0.2);

        // Verify the parameter values are correctly reordered (ke should be first, v second)
        let matrix = theta.matrix();
        assert_close(matrix[(0, 0)], 0.1); // First row, ke value
        assert_close(matrix[(0, 1)], 10.0); // First row, v value
    }

    #[test]
    fn test_sample_space_file_based() {
        let csv = TempCsv::new("ke,v\n0.1,10.0\n0.2,15.0\n0.3,20.0\n");

        let mut settings = create_test_settings();
        settings.set_prior(Prior::File(csv.path.clone()));

        let theta = sample_space(&settings).expect("a file-based prior should be sampled");
        assert_eq!(theta.nspp(), 3);
        assert_eq!(theta.matrix().ncols(), 2);
    }

    #[test]
    fn test_prior_theta_no_seed_panic() {
        let parameters = Parameters::new().add("ke", 0.1, 1.0);
        let matrix = faer::Mat::from_fn(1, 1, |_, _| 0.5);
        let theta = Theta::from_parts(matrix, parameters).unwrap();
        let prior = Prior::Theta(theta);

        assert_eq!(prior.seed(), None, "Theta prior should not have a seed");
    }
}