diff --git a/src/main.rs b/src/main.rs index 49477c7..ea48777 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,14 +1,17 @@ use std::fmt::format; use std::io::{BufReader, Write}; +use std::str::FromStr; use std::{fs::File, time::Instant}; use bincode::deserialize_from; -use clap::Subcommand; use clap::{Args, Parser}; +use clap::{Subcommand, ValueEnum}; use owo_colors::{OwoColorize, Stream::Stdout, Style}; use webgestalt_lib::methods::gsea::GSEAConfig; +use webgestalt_lib::methods::multiomics::NormalizationMethod; use webgestalt_lib::methods::ora::ORAConfig; +use webgestalt_lib::readers::read_rank_file; /// WebGestalt CLI. /// ORA and GSEA enrichment tool. @@ -30,6 +33,10 @@ enum Commands { Gsea(GseaArgs), /// Run ORA using the provided files Ora(ORAArgs), + /// Run a test + Test, + /// Combine multiple files into a single file + Combine(CombineArgs), } #[derive(Debug, Args)] @@ -65,6 +72,39 @@ struct ORAArgs { reference: Option, } +#[derive(Args)] +struct CombineArgs { + #[command(subcommand)] + combine_type: Option, +} + +#[derive(Subcommand)] +enum CombineType { + Gmt(CombineGmtArgs), + List(CombineListArgs), +} + +#[derive(Args)] +struct CombineGmtArgs { + out: Option, + /// Paths to the files to combine + files: Vec, +} +#[derive(ValueEnum, Clone)] +enum NormMethods { + MedianRank, + MedianValue, + MeanValue, + None, +} + +#[derive(Args)] +struct CombineListArgs { + normalization: Option, + out: Option, + files: Vec, +} + fn main() { let args = CliArgs::parse(); match &args.command { @@ -129,9 +169,13 @@ fn main() { return; } let gene_list = webgestalt_lib::readers::read_rank_file(gsea_args.rnk.clone().unwrap()) - .expect(format!("File {} not found", gsea_args.rnk.clone().unwrap()).as_str()); + .unwrap_or_else(|_| { + panic!("File {} not found", gsea_args.rnk.clone().unwrap()); + }); let gmt = webgestalt_lib::readers::read_gmt_file(gsea_args.gmt.clone().unwrap()) - .expect(format!("File {} not found", gsea_args.gmt.clone().unwrap()).as_str()); + .unwrap_or_else(|_| { + panic!("File {} not found", gsea_args.gmt.clone().unwrap()); + }); webgestalt_lib::methods::gsea::gsea(gene_list, gmt, GSEAConfig::default(), None); println!("Done with GSEA"); } @@ -172,8 +216,44 @@ fn main() { res.len() ); } + Some(Commands::Test) => { + let list1 = read_rank_file("gene.rnk".to_string()).unwrap(); + let list2 = read_rank_file("protein.rnk".to_string()).unwrap(); + let list3 = read_rank_file("metabolite.rnk".to_string()).unwrap(); + let lists = vec![list1, list2, list3]; + // let gmt1 = webgestalt_lib::readers::read_gmt_file("gene.gmt".to_string()).unwrap(); + // let gmt2 = + // webgestalt_lib::readers::read_gmt_file("metabolite.gmt".to_string()).unwrap(); + // let combined_gmt = webgestalt_lib::methods::multiomics::combine_gmts(&vec![gmt1, gmt2]); + // let mut file = File::create("combined.gmt").unwrap(); + // for row in combined_gmt { + // writeln!(file, "{}\t{}\t{}", row.id, row.url, row.parts.join("\t")).unwrap(); + // } + let mut combined_list = webgestalt_lib::methods::multiomics::combine_lists( + lists, + webgestalt_lib::methods::multiomics::MultiOmicsMethod::Mean, + webgestalt_lib::methods::multiomics::NormalizationMethod::MeanValue, + ); + combined_list.sort_by(|a, b| b.rank.partial_cmp(&a.rank).unwrap()); + let mut file = File::create("combined.rnk").unwrap(); + for row in combined_list { + writeln!(file, "{}\t{}", row.analyte, row.rank).unwrap(); + } + } + Some(Commands::Combine(args)) => match &args.combine_type { + Some(CombineType::Gmt(files)) => {} + Some(CombineType::List(files)) => { + let mut lists = Vec::new(); + for file in files.files.iter() { + lists.push(read_rank_file(file.clone()).unwrap()); + } + } + _ => { + panic!("Please select a valid combine type"); + } + }, _ => { - println!("Please select a command. Run --help for options.") + todo!("Please select a valid command. Run --help for options.") } } } diff --git a/webgestalt_lib/src/methods/multiomics.rs b/webgestalt_lib/src/methods/multiomics.rs index 897bd27..eac13be 100644 --- a/webgestalt_lib/src/methods/multiomics.rs +++ b/webgestalt_lib/src/methods/multiomics.rs @@ -47,8 +47,9 @@ pub struct ORAJob<'a> { #[derive(Copy, Clone)] pub enum NormalizationMethod { - MedianRatio, - None, + MedianRank, + MedianValue, + MeanValue, } /// Run a multiomics analysis, using iehter the max/mean median ratio or a typical meta analysis @@ -61,10 +62,10 @@ pub enum NormalizationMethod { /// - `method` - A [`MultiOmicsMethod`] enum detailing the analysis method to combine the runs /// together (meta-analysis, mean median ration, or max median ratio). pub fn multiomic_analysis( - jobs: Vec, - analysis_type: AnalysisType, + _jobs: Vec, + _analysis_type: AnalysisType, method: MultiOmicsMethod, -) -> () { +) { if let MultiOmicsMethod::Meta(meta_method) = method { } else { } @@ -94,7 +95,7 @@ fn max_combine( for list in normalized_lists { for item in list { if let Some(val) = batches.get_mut(&item.analyte) { - if item.rank > *val { + if item.rank.abs() > *val { *val = item.rank; } } else { @@ -143,21 +144,82 @@ fn mean_combine( fn normalize(list: &mut Vec, method: NormalizationMethod) -> Vec { match method { NormalizationMethod::None => list.clone(), - NormalizationMethod::MedianRatio => { + NormalizationMethod::MedianRank => { list.sort_by(|a, b| { b.rank .partial_cmp(&a.rank) - .expect("Invalid float comparison during comparison") + .expect("Invalid float comparison during normalization") }); - let median = list[list.len() / 2].rank; + let median = list.len() as f64 / 2.0; let mut final_list: Vec = Vec::new(); - for item in list { + for (i, item) in list.iter().enumerate() { final_list.push(RankListItem { analyte: item.analyte.clone(), - rank: item.rank / median, + rank: (i as f64 - median) / median, }); } final_list } + NormalizationMethod::MedianValue => { + list.sort_by(|a, b| { + b.rank + .partial_cmp(&a.rank) + .expect("Invalid float comparison during normalization") + }); + let min = list.last().unwrap().rank; + let median = list[list.len() / 2].rank - min; + let mut final_list: Vec = Vec::new(); + for item in list.iter() { + final_list.push(RankListItem { + analyte: item.analyte.clone(), + rank: (item.rank - min) / median, + }); + } + final_list + } + NormalizationMethod::MeanValue => { + list.sort_by(|a, b| { + b.rank + .partial_cmp(&a.rank) + .expect("Invalid float comparison during normalization") + }); + let min = list.last().unwrap().rank; + let mean: f64 = list.iter().map(|x| x.rank).sum::() / (list.len() as f64) - min; + let mut final_list: Vec = Vec::new(); + for item in list.iter() { + final_list.push(RankListItem { + analyte: item.analyte.clone(), + rank: (item.rank - min) / mean, + }); + } + final_list + } + } +} + +pub fn combine_gmts(gmts: &Vec>) -> Vec { + let mut combined_parts: AHashMap> = AHashMap::default(); + let mut combined_urls: AHashMap = AHashMap::default(); + for gmt in gmts { + for item in gmt { + if combined_parts.contains_key(&item.id) { + combined_parts + .get_mut(&item.id) + .unwrap() + .extend(item.parts.clone()); + } else { + combined_parts.insert(item.id.clone(), item.parts.clone()); + combined_urls.insert(item.id.clone(), item.url.clone()); + } + } + } + let mut final_gmt: Vec = Vec::new(); + for (key, parts) in combined_parts { + final_gmt.push(Item { + id: key.clone(), + parts, + url: combined_urls[&key].clone(), + }) } + final_gmt } diff --git a/webgestalt_lib/src/methods/ora.rs b/webgestalt_lib/src/methods/ora.rs index 063ca24..6cdef35 100644 --- a/webgestalt_lib/src/methods/ora.rs +++ b/webgestalt_lib/src/methods/ora.rs @@ -8,6 +8,7 @@ pub struct ORAConfig { pub min_overlap: i64, pub min_set_size: usize, pub max_set_size: usize, + pub fdr_method: stat::AdjustmentMethod, } impl Default for ORAConfig { @@ -16,6 +17,7 @@ impl Default for ORAConfig { min_overlap: 5, min_set_size: 5, max_set_size: 500, + fdr_method: stat::AdjustmentMethod::BH, } } } @@ -90,7 +92,7 @@ pub fn get_ora( }); let partials = res.lock().unwrap(); let p_vals: Vec = partials.iter().map(|x| x.p).collect(); - let fdrs: Vec = stat::adjust(&p_vals); + let fdrs: Vec = stat::adjust(&p_vals, config.fdr_method); let mut final_res = Vec::new(); for (i, row) in partials.clone().into_iter().enumerate() { final_res.push(ORAResult { diff --git a/webgestalt_lib/src/stat.rs b/webgestalt_lib/src/stat.rs index 720a2dc..c429a81 100644 --- a/webgestalt_lib/src/stat.rs +++ b/webgestalt_lib/src/stat.rs @@ -3,7 +3,19 @@ struct Carrier { original_order: usize, } -pub fn adjust(p_vals: &[f64]) -> Vec { +pub enum AdjustmentMethod { + BH, + None, +} + +pub fn adjust(p_vals: &[f64], method: AdjustmentMethod) -> Vec { + match method { + AdjustmentMethod::BH => benjamini_hochberg(p_vals), + AdjustmentMethod::None => p_vals.to_vec(), + } +} + +fn benjamini_hochberg(p_vals: &[f64]) -> Vec { let mut carriers: Vec = p_vals .iter() .enumerate()