diff --git a/russell_sparse/src/stats_lin_sol.rs b/russell_sparse/src/stats_lin_sol.rs index 013265c0..98f22e65 100644 --- a/russell_sparse/src/stats_lin_sol.rs +++ b/russell_sparse/src/stats_lin_sol.rs @@ -4,7 +4,7 @@ use russell_lab::{format_nanoseconds, get_num_threads, using_intel_mkl}; use serde::{Deserialize, Serialize}; use serde_json; use std::ffi::OsStr; -use std::fs::File; +use std::fs::{self, File}; use std::io::BufReader; use std::path::Path; @@ -179,15 +179,7 @@ impl StatsLinSol { /// Gets a JSON representation of the stats structure pub fn get_json(&mut self) -> String { - self.output.openmp_num_threads = get_num_threads(); - self.time_nanoseconds.total_ifs = - self.time_nanoseconds.initialize + self.time_nanoseconds.factorize + self.time_nanoseconds.solve; - self.time_human.read_matrix = format_nanoseconds(self.time_nanoseconds.read_matrix); - self.time_human.initialize = format_nanoseconds(self.time_nanoseconds.initialize); - self.time_human.factorize = format_nanoseconds(self.time_nanoseconds.factorize); - self.time_human.solve = format_nanoseconds(self.time_nanoseconds.solve); - self.time_human.total_ifs = format_nanoseconds(self.time_nanoseconds.total_ifs); - self.time_human.verify = format_nanoseconds(self.time_nanoseconds.verify); + self.compute_derived_values(); serde_json::to_string_pretty(&self).unwrap() } @@ -206,6 +198,38 @@ impl StatsLinSol { let stat = serde_json::from_reader(buffered).map_err(|_| "cannot parse JSON file")?; Ok(stat) } + + /// Writes a JSON file with the results + /// + /// # Input + /// + /// * `full_path` -- may be a String, &str, or Path + pub fn write_json

(&mut self, full_path: &P) -> Result<(), StrError> + where + P: AsRef + ?Sized, + { + self.compute_derived_values(); + let path = Path::new(full_path).to_path_buf(); + if let Some(p) = path.parent() { + fs::create_dir_all(p).map_err(|_| "cannot create directory")?; + } + let mut file = File::create(&path).map_err(|_| "cannot create file")?; + serde_json::to_writer_pretty(&mut file, &self).map_err(|_| "cannot write file")?; + Ok(()) + } + + /// Computes derived values + fn compute_derived_values(&mut self) { + self.output.openmp_num_threads = get_num_threads(); + self.time_nanoseconds.total_ifs = + self.time_nanoseconds.initialize + self.time_nanoseconds.factorize + self.time_nanoseconds.solve; + self.time_human.read_matrix = format_nanoseconds(self.time_nanoseconds.read_matrix); + self.time_human.initialize = format_nanoseconds(self.time_nanoseconds.initialize); + self.time_human.factorize = format_nanoseconds(self.time_nanoseconds.factorize); + self.time_human.solve = format_nanoseconds(self.time_nanoseconds.solve); + self.time_human.total_ifs = format_nanoseconds(self.time_nanoseconds.total_ifs); + self.time_human.verify = format_nanoseconds(self.time_nanoseconds.verify); + } } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -275,4 +299,32 @@ mod tests { assert_eq!(stats.matrix.name, "pre2"); assert_eq!(stats.matrix.symmetry, "None"); } + + #[test] + fn write_json_works() { + let mut stats = StatsLinSol::new(); + const ONE_SECOND: u128 = 1000000000; + stats.time_nanoseconds.read_matrix = ONE_SECOND; + stats.time_nanoseconds.initialize = ONE_SECOND; + stats.time_nanoseconds.factorize = ONE_SECOND * 2; + stats.time_nanoseconds.solve = ONE_SECOND * 3; + stats.time_nanoseconds.verify = ONE_SECOND * 4; + let path = "/tmp/russell/write_json_works.json"; + stats.write_json(path).unwrap(); + let res = StatsLinSol::read_json(path).unwrap(); + assert!(res.output.openmp_num_threads > 0); + assert_eq!(res.time_nanoseconds.read_matrix, ONE_SECOND); + assert_eq!(res.time_nanoseconds.initialize, ONE_SECOND); + assert_eq!(res.time_nanoseconds.factorize, ONE_SECOND * 2); + assert_eq!(res.time_nanoseconds.solve, ONE_SECOND * 3); + assert_eq!(res.time_nanoseconds.total_ifs, ONE_SECOND * 6); + assert_eq!(res.time_nanoseconds.verify, ONE_SECOND * 4); + assert_eq!(res.time_nanoseconds.total_ifs, ONE_SECOND * 6); + assert_eq!(res.time_human.read_matrix, "1s"); + assert_eq!(res.time_human.initialize, "1s"); + assert_eq!(res.time_human.factorize, "2s"); + assert_eq!(res.time_human.solve, "3s"); + assert_eq!(res.time_human.total_ifs, "6s"); + assert_eq!(res.time_human.verify, "4s"); + } }