diff --git a/Cargo.lock b/Cargo.lock index 3cabc79..78c6389 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -523,9 +523,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.114" +version = "1.0.116" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5f09b1bd632ef549eaa9f60a1f8de742bdbc698e6cee2095fc84dde5f549ae0" +checksum = "3e17db7126d17feb94eb3fad46bf1a96b034e8aacbc2e775fe81505f8b0b2813" dependencies = [ "itoa", "ryu", @@ -633,6 +633,7 @@ dependencies = [ "bincode", "clap", "owo-colors", + "serde_json", "webgestalt_lib", ] diff --git a/Cargo.toml b/Cargo.toml index 3bdddfe..fa4ebe6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ repository = "https://github.com/bzhanglab/webgestalt_rust" bincode = "1.3.3" clap = { version = "4.4.15", features = ["derive"] } owo-colors = { version = "4.0.0", features = ["supports-colors"] } +serde_json = "1.0.116" webgestalt_lib = { version = "0.3.0", path = "webgestalt_lib" } [profile.release] diff --git a/src/main.rs b/src/main.rs index 99c0af2..51df1bb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,7 @@ -use bincode::deserialize_from; use clap::{Args, Parser}; use clap::{Subcommand, ValueEnum}; use owo_colors::{OwoColorize, Stream::Stdout, Style}; -use std::io::{BufReader, Write}; +use std::io::Write; use std::{fs::File, time::Instant}; use webgestalt_lib::methods::gsea::GSEAConfig; use webgestalt_lib::methods::multilist::{combine_gmts, MultiListMethod, NormalizationMethod}; @@ -10,7 +9,6 @@ use webgestalt_lib::methods::nta::NTAConfig; use webgestalt_lib::methods::ora::ORAConfig; use webgestalt_lib::readers::utils::Item; use webgestalt_lib::readers::{read_gmt_file, read_rank_file}; -use webgestalt_lib::{MalformedError, WebGestaltError}; /// WebGestalt CLI. /// ORA and GSEA enrichment tool. @@ -24,8 +22,6 @@ struct CliArgs { #[derive(Subcommand)] enum Commands { - /// Benchmark different file formats for gmt. TODO: Remove later - Benchmark, /// Run provided examples for various types of analyses Example(ExampleArgs), /// Run GSEA on the provided files @@ -34,8 +30,6 @@ enum Commands { Ora(ORAArgs), /// Run NTA on the provided files Nta(NtaArgs), - /// Run a test - Test, /// Combine multiple files into a single file Combine(CombineArgs), } @@ -65,7 +59,7 @@ struct NtaArgs { seeds: String, /// Output path for the results #[arg(short, long)] - out: String, + output: String, /// Probability of random walk resetting #[arg(short, long, default_value = "0.5")] reset_probability: f64, @@ -77,8 +71,8 @@ struct NtaArgs { neighborhood_size: usize, /// Method to use for NTA /// Options: prioritize, expand - #[arg(short, long)] - method: Option, + #[arg(short, long, default_value = "prioritize")] + method: NTAMethodClap, } #[derive(ValueEnum, Clone)] @@ -90,19 +84,29 @@ enum NTAMethodClap { #[derive(Args)] struct GseaArgs { /// Path to the GMT file of interest - gmt: Option, + #[arg(short, long)] + gmt: String, /// Path to the rank file of interest - rnk: Option, + #[arg(short, long)] + rnk: String, + /// Output path for the results + #[arg(short, long, default_value = "out.json")] + output: String, } - -#[derive(Args)] +#[derive(Parser)] struct ORAArgs { /// Path to the GMT file of interest - gmt: Option, + #[arg(short, long)] + gmt: String, /// Path to the file containing the interesting analytes - interest: Option, + #[arg(short, long)] + interest: String, + /// Output path for the results + #[arg(short, long, default_value = "out.json")] + output: String, /// Path the file containing the reference list - reference: Option, + #[arg(short, long)] + reference: String, } #[derive(Args)] @@ -146,13 +150,43 @@ struct CombineListArgs { files: Vec, } +fn prompt_yes_no(question: &str) -> bool { + loop { + print!("{} (y/n): ", question); + std::io::stdout().flush().expect("Could not flush stdout!"); // Ensure the prompt is displayed + + let mut input = String::new(); + std::io::stdin() + .read_line(&mut input) + .expect("Could not read line"); + print!("\x1B[2J\x1B[1;1H"); + std::io::stdout().flush().expect("Could not flush stdout!"); + match input.trim().to_lowercase().as_str() { + "y" => return true, + "n" => return false, + _ => println!("Invalid input. Please enter 'y' or 'n'."), + } + } +} + +fn check_and_overwrite(file_path: &str) { + // Check if the file exists + if std::path::Path::new(file_path).exists() { + // Check if the user wants to overwrite the file + if !prompt_yes_no(&format!( + "File at {} already exists. Do you want to overwrite it?", + file_path + )) { + println!("Stopping analysis."); + std::process::exit(1); + }; + } +} + fn main() { println!("WebGestalt CLI v{}", env!("CARGO_PKG_VERSION")); let args = CliArgs::parse(); match &args.command { - Some(Commands::Benchmark) => { - benchmark(); - } Some(Commands::Example(ex)) => match &ex.commands { Some(ExampleOptions::Gsea) => { let gene_list = webgestalt_lib::readers::read_rank_file( @@ -177,7 +211,7 @@ fn main() { "webgestalt_lib/data/genelist.txt".to_owned(), "webgestalt_lib/data/reference.txt".to_owned(), ); - let gmtcount = gmt.len(); + let gmt_count = gmt.len(); let start = Instant::now(); let x: Vec = webgestalt_lib::methods::ora::get_ora( @@ -187,6 +221,8 @@ fn main() { ORAConfig::default(), ); let mut count = 0; + let output_file = File::create("test.json").expect("Could not create output file!"); + serde_json::to_writer(output_file, &x).expect("Could not create JSON file!"); for i in x { if i.p < 0.05 && i.fdr < 0.05 { println!("{}: {}, {}, {}", i.set, i.p, i.fdr, i.overlap); @@ -196,7 +232,7 @@ fn main() { let duration = start.elapsed(); println!( "ORA\nTime took: {:?}\nFound {} significant pathways out of {} pathways", - duration, count, gmtcount + duration, count, gmt_count ); } _ => { @@ -204,24 +240,21 @@ fn main() { } }, Some(Commands::Gsea(gsea_args)) => { - let style = Style::new().red().bold(); - if gsea_args.gmt.is_none() || gsea_args.rnk.is_none() { - println!( - "{}: DID NOT PROVIDE PATHS FOR GMT AND RANK FILE.", - "ERROR".if_supports_color(Stdout, |text| text.style(style)) - ); - return; - } - let gene_list = webgestalt_lib::readers::read_rank_file(gsea_args.rnk.clone().unwrap()) + check_and_overwrite(&gsea_args.output); + let gene_list = webgestalt_lib::readers::read_rank_file(gsea_args.rnk.clone()) .unwrap_or_else(|_| { - panic!("File {} not found", gsea_args.rnk.clone().unwrap()); - }); - let gmt = webgestalt_lib::readers::read_gmt_file(gsea_args.gmt.clone().unwrap()) - .unwrap_or_else(|_| { - panic!("File {} not found", gsea_args.gmt.clone().unwrap()); + panic!("File {} not found", gsea_args.rnk.clone()); }); + let gmt = webgestalt_lib::readers::read_gmt_file(gsea_args.gmt.clone()).unwrap_or_else( + |_| { + panic!("File {} not found", gsea_args.gmt.clone()); + }, + ); let res = webgestalt_lib::methods::gsea::gsea(gene_list, gmt, GSEAConfig::default(), None); + let output_file = + File::create(&gsea_args.output).expect("Could not create output file!"); + serde_json::to_writer(output_file, &res).expect("Could not create JSON file!"); let mut count = 0; for i in res { if i.p < 0.05 && i.fdr < 0.05 { @@ -229,23 +262,18 @@ fn main() { count += 1; } } - println!("Done with GSEA: {}", count); + println!( + "Done with GSEA and found {} significant analyte sets", + count + ); } Some(Commands::Ora(ora_args)) => { - let style = Style::new().red().bold(); - if ora_args.gmt.is_none() || ora_args.interest.is_none() || ora_args.reference.is_none() - { - println!( - "{}: DID NOT PROVIDE PATHS FOR GMT, INTEREST, AND REFERENCE FILE.", - "ERROR".if_supports_color(Stdout, |text| text.style(style)) - ); - return; - } + check_and_overwrite(&ora_args.output); let start = Instant::now(); let (gmt, interest, reference) = webgestalt_lib::readers::read_ora_files( - ora_args.gmt.clone().unwrap(), - ora_args.interest.clone().unwrap(), - ora_args.reference.clone().unwrap(), + ora_args.gmt.clone(), + ora_args.interest.clone(), + ora_args.reference.clone(), ); println!("Reading Took {:?}", start.elapsed()); let start = Instant::now(); @@ -255,6 +283,9 @@ fn main() { gmt, ORAConfig::default(), ); + let output_file = + File::create(&ora_args.output).expect("Could not create output file!"); + serde_json::to_writer(output_file, &res).expect("Could not create JSON file!"); println!("Analysis Took {:?}", start.elapsed()); let mut count = 0; for row in res.iter() { @@ -263,30 +294,22 @@ fn main() { } } println!( - "Found {} significant pathways out of {} pathways", + "Found {} significant analyte sets out of {} sets", count, res.len() ); } - Some(Commands::Test) => will_err(1).unwrap_or_else(|x| println!("{}", x)), Some(Commands::Nta(nta_args)) => { - let style = Style::new().fg_rgb::<255, 179, 71>().bold(); + check_and_overwrite(&nta_args.output); let network = webgestalt_lib::readers::read_edge_list(nta_args.network.clone()); let start = Instant::now(); - if nta_args.method.is_none() { - println!( - "{}: DID NOT PROVIDE A METHOD FOR NTA. USING DEFAULT EXPAND METHOD.", - "WARNING".if_supports_color(Stdout, |text| text.style(style)) - ); - }; let nta_method = match nta_args.method { - Some(NTAMethodClap::Prioritize) => webgestalt_lib::methods::nta::NTAMethod::Prioritize( - nta_args.neighborhood_size, - ), - Some(NTAMethodClap::Expand) => webgestalt_lib::methods::nta::NTAMethod::Expand( - nta_args.neighborhood_size, - ), - None => webgestalt_lib::methods::nta::NTAMethod::Expand(nta_args.neighborhood_size), + NTAMethodClap::Prioritize => { + webgestalt_lib::methods::nta::NTAMethod::Prioritize(nta_args.neighborhood_size) + } + NTAMethodClap::Expand => { + webgestalt_lib::methods::nta::NTAMethod::Expand(nta_args.neighborhood_size) + } }; let config: NTAConfig = NTAConfig { edge_list: network, @@ -294,11 +317,10 @@ fn main() { reset_probability: nta_args.reset_probability, tolerance: nta_args.tolerance, method: Some(nta_method), - }; let res = webgestalt_lib::methods::nta::get_nta(config); println!("Analysis Took {:?}", start.elapsed()); - webgestalt_lib::writers::save_nta(nta_args.out.clone(), res).unwrap(); + webgestalt_lib::writers::save_nta(nta_args.output.clone(), res).unwrap(); } Some(Commands::Combine(args)) => match &args.combine_type { Some(CombineType::Gmt(gmt_args)) => { @@ -374,53 +396,3 @@ fn main() { } } } - -fn benchmark() { - let mut bin_durations: Vec = Vec::new(); - for _i in 0..1000 { - let start = Instant::now(); - let mut r = BufReader::new(File::open("test.gmt.wga").unwrap()); - let _x: Vec = deserialize_from(&mut r).unwrap(); - let duration = start.elapsed(); - bin_durations.push(duration.as_secs_f64()) - } - let mut gmt_durations: Vec = Vec::new(); - for _i in 0..1000 { - let start = Instant::now(); - let _x = webgestalt_lib::readers::read_gmt_file("webgestalt_lib/data/ktest.gmt".to_owned()) - .unwrap(); - let duration = start.elapsed(); - gmt_durations.push(duration.as_secs_f64()) - } - let gmt_avg: f64 = gmt_durations.iter().sum::() / gmt_durations.len() as f64; - let bin_avg: f64 = bin_durations.iter().sum::() / bin_durations.len() as f64; - let improvement: f64 = 100.0 * (gmt_avg - bin_avg) / gmt_avg; - println!( - " GMT time: {}\tGMT.WGA time: {}\n Improvement: {:.1}%", - gmt_avg, bin_avg, improvement - ); - let mut whole_file: Vec = Vec::new(); - whole_file.push("type\ttime".to_string()); - for line in bin_durations { - whole_file.push(format!("bin\t{:?}", line)); - } - for line in gmt_durations { - whole_file.push(format!("gmt\t{:?}", line)); - } - let mut ftsv = File::create("format_benchmarks.tsv").unwrap(); - writeln!(ftsv, "{}", whole_file.join("\n")).unwrap(); -} - -fn will_err(x: i32) -> Result<(), WebGestaltError> { - if x == 0 { - Ok(()) - } else { - Err(WebGestaltError::MalformedFile(MalformedError { - path: String::from("ExamplePath.txt"), - kind: webgestalt_lib::MalformedErrorType::WrongFormat { - found: String::from("GMT"), - expected: String::from("rank"), - }, - })) - } -} diff --git a/webgestalt_lib/src/readers.rs b/webgestalt_lib/src/readers.rs index 6cc71d5..575f9c4 100644 --- a/webgestalt_lib/src/readers.rs +++ b/webgestalt_lib/src/readers.rs @@ -132,10 +132,10 @@ pub fn read_intersection_list(path: String, ref_list: &AHashSet) -> AHas } /// Read edge list from specified path. Separated by whitespace with no support for weights -/// +/// /// # Parameters /// path - A [`String`] of the path of the edge list to read. -/// +/// /// # Returns /// A [`Vec>`] containing the edge list pub fn read_edge_list(path: String) -> Vec> {