diff --git a/russell_sparse/src/stats_lin_sol.rs b/russell_sparse/src/stats_lin_sol.rs
index 013265c0..98f22e65 100644
--- a/russell_sparse/src/stats_lin_sol.rs
+++ b/russell_sparse/src/stats_lin_sol.rs
@@ -4,7 +4,7 @@ use russell_lab::{format_nanoseconds, get_num_threads, using_intel_mkl};
use serde::{Deserialize, Serialize};
use serde_json;
use std::ffi::OsStr;
-use std::fs::File;
+use std::fs::{self, File};
use std::io::BufReader;
use std::path::Path;
@@ -179,15 +179,7 @@ impl StatsLinSol {
/// Gets a JSON representation of the stats structure
pub fn get_json(&mut self) -> String {
- self.output.openmp_num_threads = get_num_threads();
- self.time_nanoseconds.total_ifs =
- self.time_nanoseconds.initialize + self.time_nanoseconds.factorize + self.time_nanoseconds.solve;
- self.time_human.read_matrix = format_nanoseconds(self.time_nanoseconds.read_matrix);
- self.time_human.initialize = format_nanoseconds(self.time_nanoseconds.initialize);
- self.time_human.factorize = format_nanoseconds(self.time_nanoseconds.factorize);
- self.time_human.solve = format_nanoseconds(self.time_nanoseconds.solve);
- self.time_human.total_ifs = format_nanoseconds(self.time_nanoseconds.total_ifs);
- self.time_human.verify = format_nanoseconds(self.time_nanoseconds.verify);
+ self.compute_derived_values();
serde_json::to_string_pretty(&self).unwrap()
}
@@ -206,6 +198,38 @@ impl StatsLinSol {
let stat = serde_json::from_reader(buffered).map_err(|_| "cannot parse JSON file")?;
Ok(stat)
}
+
+ /// Writes a JSON file with the results
+ ///
+ /// # Input
+ ///
+ /// * `full_path` -- may be a String, &str, or Path
+ pub fn write_json
(&mut self, full_path: &P) -> Result<(), StrError>
+ where
+ P: AsRef + ?Sized,
+ {
+ self.compute_derived_values();
+ let path = Path::new(full_path).to_path_buf();
+ if let Some(p) = path.parent() {
+ fs::create_dir_all(p).map_err(|_| "cannot create directory")?;
+ }
+ let mut file = File::create(&path).map_err(|_| "cannot create file")?;
+ serde_json::to_writer_pretty(&mut file, &self).map_err(|_| "cannot write file")?;
+ Ok(())
+ }
+
+ /// Computes derived values
+ fn compute_derived_values(&mut self) {
+ self.output.openmp_num_threads = get_num_threads();
+ self.time_nanoseconds.total_ifs =
+ self.time_nanoseconds.initialize + self.time_nanoseconds.factorize + self.time_nanoseconds.solve;
+ self.time_human.read_matrix = format_nanoseconds(self.time_nanoseconds.read_matrix);
+ self.time_human.initialize = format_nanoseconds(self.time_nanoseconds.initialize);
+ self.time_human.factorize = format_nanoseconds(self.time_nanoseconds.factorize);
+ self.time_human.solve = format_nanoseconds(self.time_nanoseconds.solve);
+ self.time_human.total_ifs = format_nanoseconds(self.time_nanoseconds.total_ifs);
+ self.time_human.verify = format_nanoseconds(self.time_nanoseconds.verify);
+ }
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -275,4 +299,32 @@ mod tests {
assert_eq!(stats.matrix.name, "pre2");
assert_eq!(stats.matrix.symmetry, "None");
}
+
+ #[test]
+ fn write_json_works() {
+ let mut stats = StatsLinSol::new();
+ const ONE_SECOND: u128 = 1000000000;
+ stats.time_nanoseconds.read_matrix = ONE_SECOND;
+ stats.time_nanoseconds.initialize = ONE_SECOND;
+ stats.time_nanoseconds.factorize = ONE_SECOND * 2;
+ stats.time_nanoseconds.solve = ONE_SECOND * 3;
+ stats.time_nanoseconds.verify = ONE_SECOND * 4;
+ let path = "/tmp/russell/write_json_works.json";
+ stats.write_json(path).unwrap();
+ let res = StatsLinSol::read_json(path).unwrap();
+ assert!(res.output.openmp_num_threads > 0);
+ assert_eq!(res.time_nanoseconds.read_matrix, ONE_SECOND);
+ assert_eq!(res.time_nanoseconds.initialize, ONE_SECOND);
+ assert_eq!(res.time_nanoseconds.factorize, ONE_SECOND * 2);
+ assert_eq!(res.time_nanoseconds.solve, ONE_SECOND * 3);
+ assert_eq!(res.time_nanoseconds.total_ifs, ONE_SECOND * 6);
+ assert_eq!(res.time_nanoseconds.verify, ONE_SECOND * 4);
+ assert_eq!(res.time_nanoseconds.total_ifs, ONE_SECOND * 6);
+ assert_eq!(res.time_human.read_matrix, "1s");
+ assert_eq!(res.time_human.initialize, "1s");
+ assert_eq!(res.time_human.factorize, "2s");
+ assert_eq!(res.time_human.solve, "3s");
+ assert_eq!(res.time_human.total_ifs, "6s");
+ assert_eq!(res.time_human.verify, "4s");
+ }
}