Skip to content

Commit

Permalink
implement parallelized EM iterations
Browse files Browse the repository at this point in the history
  • Loading branch information
rob-p committed Feb 17, 2024
1 parent 8a389ab commit b340c62
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 13 deletions.
15 changes: 11 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 14 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,22 @@
name = "piscem-infer"
version = "0.5.1"
edition = "2021"
license = "BSD-3-Clause-Attribution"
keywords = ["science", "RNA-seq", "RAD-file"]
readme = "README.md"
license-file = "LICENSE"
keywords = ["science", "RNA-seq", "quantification", "RAD-file"]
categories = ["command-line-utilities", "science"]
repository = "https://github.com/COMBINE-lab/piscem-infer/"
homepage = "https://github.com/COMBINE-lab/piscem-infer/"
readme = "README.md"
homepage = "https://piscem-infer.readthedocs.io/"
description = "A flexible tool to perform target quantification from bulk-sequencing data"
authors = ["Rob Patro", "Rob Patro <rob@cs.umd.edu>"]
include = [
"/src/*.rs",
"/src/utils/*.rs",
"/Cargo.toml",
"/Cargo.lock",
"/README.md",
"/LICENSE"
]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
Expand All @@ -19,7 +27,7 @@ distrs = "0.2.1"
anyhow = "1.0.79"
bincode = "1.3.3"
bstr = "1.9.0"
clap = { version = "4.5.0", features = ["derive", "wrap_help", "cargo", "help", "usage", "error-context"] }
clap = { version = "4.5.1", features = ["derive", "wrap_help", "cargo", "help", "usage", "error-context"] }
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", default-features = true, features = ["env-filter"] }
tabled = "0.15.0"
Expand All @@ -33,6 +41,7 @@ arrow2 = { version = "0.18.0", features = ["io_parquet", "io_parquet_gzip", "io_
scroll = "0.12.0"
snap = "1.1.1"
path-tools = "0.1.0"
atomic_float = "0.1.0"

[[bin]]
name = "piscem-infer"
Expand Down
27 changes: 27 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# MkDocs configuration for the piscem-infer documentation site.
site_name: Documentation for piscem-infer

# Theme selection and the light/dark colour palettes offered by the toggle.
theme:
name: material
palette:

# Palette toggle for light mode
- scheme: default
toggle:
icon: material/brightness-7
name: Switch to dark mode

# Palette toggle for dark mode
- scheme: slate
toggle:
icon: material/brightness-4
name: Switch to light mode

# Markdown extensions enabled for the site (footnotes and math via arithmatex).
markdown_extensions:
- footnotes
- pymdownx.arithmatex:
generic: true

# Extra scripts loaded on every page; these provide MathJax rendering.
# NOTE(review): the polyfill.io entry below pulls a script from a third-party
# domain that is not pinned to a version; confirm it is still required and
# trustworthy (MathJax 3 itself is loaded from jsdelivr on the following line).
extra_javascript:
- javascripts/mathjax.js
- https://polyfill.io/v3/polyfill.min.js?features=es6
- https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js
11 changes: 8 additions & 3 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use tracing::{error, info, warn, Level};

mod utils;
use utils::custom_rad_utils::MetaChunk;
use utils::em::{adjust_ref_lengths, conditional_means, em, EMInfo, EqLabel, EqMap};
use utils::em::{adjust_ref_lengths, conditional_means, em, em_par, EMInfo, EqLabel, EqMap};

use crate::utils::em::conditional_means_from_params;
use crate::utils::em::do_bootstrap;
Expand Down Expand Up @@ -241,7 +241,7 @@ pub struct QuantOpts {
/// number of bootstrap replicates to perform.
#[arg(long, default_value_t = 0)]
pub num_bootstraps: usize,
/// number of threads to use (only used for bootstrapping)
/// number of threads to use (used during the EM and for bootstrapping)
#[arg(long, default_value_t = 16)]
pub num_threads: usize,
}
Expand Down Expand Up @@ -480,7 +480,12 @@ fn main() -> anyhow::Result<()> {
max_iter,
convergence_thresh,
};
let em_res = em(&eminfo);

let em_res = if num_threads > 1 {
em_par(&eminfo, num_threads)
} else {
em(&eminfo)
};

let quant_output = output.with_additional_extension(".quant");
io::write_results(&quant_output, &hdr, &em_res, &ref_lengths, &eff_lengths)?;
Expand Down
135 changes: 134 additions & 1 deletion src/utils/em.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use ahash::AHashMap;
use atomic_float::AtomicF64;
use rand::prelude::*;
use rand::thread_rng;
use rand_distr::WeightedAliasIndex;
use rayon::prelude::*;
use std::sync::atomic::Ordering;
use tracing::info;

pub enum OrientationProperty {
Expand Down Expand Up @@ -116,6 +118,8 @@ pub struct PackedEqLabelIter<'a> {

impl<'a> Iterator for PackedEqLabelIter<'a> {
type Item = &'a [u32];

#[inline]
fn next(&mut self) -> Option<Self::Item> {
let c = self.counter as usize;
if c < self.underlying_packed_map.len() {
Expand All @@ -139,6 +143,8 @@ pub struct PackedEqEntryIter<'a> {

impl<'a> Iterator for PackedEqEntryIter<'a> {
type Item = (&'a [u32], &'a usize);

#[inline]
fn next(&mut self) -> Option<Self::Item> {
let c = self.counter as usize;
if c < self.underlying_packed_map.len() {
Expand Down Expand Up @@ -196,6 +202,8 @@ pub struct EqEntryIter<'a> {

impl<'a> Iterator for EqEntryIter<'a> {
type Item = (&'a [u32], &'a usize);

#[inline]
fn next(&mut self) -> Option<Self::Item> {
match self.underlying_iter.next() {
Some((k, v)) => Some((k.target_labels(self.contains_ori), v)),
Expand Down Expand Up @@ -286,6 +294,38 @@ pub fn adjust_ref_lengths(ref_lens: &[u32], cond_means: &[f64]) -> Vec<f64> {
/// Counts not above this value are treated as absent: `em_par` skips them when
/// measuring relative change and zeroes counts below it before its final step.
const ABSENCE_THRESH: f64 = 1e-8;
/// The EM loop in `em_par` stops once the largest per-target relative change
/// in an iteration falls below this value.
const RELDIFF_THRESH: f64 = 1e-3;

/// Parallel M-step of the EM algorithm over a list of equivalence classes.
///
/// For every `(labels, count)` entry of `eq_iterates`, the class count is
/// divided among the targets listed in `labels` in proportion to
/// `prev_count[t] * inv_eff_lens[t]`, and each share is atomically added to
/// `curr_counts[t]`. A class whose total weight is not above `1e-8`
/// contributes nothing.
///
/// * `eq_iterates` - equivalence classes as (target ids, read count) pairs.
/// * `prev_count` - abundance estimates from the previous iteration (only read).
/// * `inv_eff_lens` - per-target reciprocal effective lengths.
/// * `curr_counts` - accumulators for the new estimates. This function only
///   adds to the values already stored there; it never resets them.
///
/// # Panics
///
/// Panics if a target id in `eq_iterates` is out of bounds for `prev_count`,
/// `inv_eff_lens` or `curr_counts`.
#[inline]
fn m_step_par(
    eq_iterates: &[(&[u32], &usize)],
    prev_count: &mut [AtomicF64],
    inv_eff_lens: &[f64],
    curr_counts: &mut [AtomicF64],
) {
    // Only shared access is required: every update goes through an atomic op.
    let prev_count: &[AtomicF64] = prev_count;
    let curr_counts: &[AtomicF64] = curr_counts;

    // `for_each_init` builds one scratch buffer per rayon job. Cloning a
    // single pre-sized Vec (as `for_each_with` does) would not help, because
    // `Vec::clone` does not carry over spare capacity.
    eq_iterates.par_iter().for_each_init(
        || Vec::with_capacity(64),
        |weights: &mut Vec<f64>, &(labels, count)| {
            weights.clear();
            let count = *count as f64;

            // Unnormalized weight of each target in this class.
            let mut denom = 0.0_f64;
            for &target_id in labels {
                let t = target_id as usize;
                let w = prev_count[t].load(Ordering::Relaxed) * inv_eff_lens[t];
                weights.push(w);
                denom += w;
            }

            // Skip classes with (numerically) no mass to avoid dividing by ~0.
            if denom > 1e-8 {
                let count_over_denom = count / denom;
                for (&target_id, &w) in labels.iter().zip(weights.iter()) {
                    let inc = count_over_denom * w;
                    curr_counts[target_id as usize].fetch_add(inc, Ordering::AcqRel);
                }
            }
        },
    );
}

#[inline]
fn m_step(
eq_map: &PackedEqMap,
Expand All @@ -294,7 +334,8 @@ fn m_step(
inv_eff_lens: &[f64],
curr_counts: &mut [f64],
) {
// TODO: is there a better way to set this capacity?
// TODO: is there a better way to set the capacity on
// this Vec?
let mut weights: Vec<f64> = Vec::with_capacity(64);

for (k, v) in eq_map.iter_labels().zip(eq_counts.iter()) {
Expand Down Expand Up @@ -404,6 +445,98 @@ pub fn do_bootstrap(em_info: &EMInfo, num_boot: usize) -> Vec<Vec<f64>> {
.collect()
}

/// Runs the EM algorithm with the M-step parallelized over equivalence classes.
///
/// A dedicated rayon thread pool with `nthreads` threads is created for the
/// duration of the call. Every target starts with the same count (the total
/// read count divided by the number of targets). The loop then runs until
/// either `em_info.max_iter` iterations have completed or the largest
/// per-target relative change drops below `RELDIFF_THRESH`. Afterwards,
/// estimates below `ABSENCE_THRESH` are set to zero and one last M-step
/// produces the returned counts.
///
/// Returns one estimated count per entry of `em_info.eff_lens`.
///
/// NOTE(review): the stopping rule uses the module constant `RELDIFF_THRESH`;
/// the `convergence_thresh` that `main` stores in `EMInfo` is not consulted
/// here. Confirm that this is intended.
///
/// # Panics
///
/// Panics if the rayon thread pool cannot be built.
pub fn em_par(em_info: &EMInfo, nthreads: usize) -> Vec<f64> {
    let eq_map = em_info.eq_map;
    let eff_lens = em_info.eff_lens;
    let num_targets = eff_lens.len();

    // Reciprocal effective lengths; a non-finite reciprocal (e.g. from a zero
    // length) is replaced by 0 so that such a target receives no weight.
    let inv_eff_lens: Vec<f64> = eff_lens
        .iter()
        .map(|&len| {
            let inv = 1.0_f64 / len;
            if inv.is_finite() {
                inv
            } else {
                0_f64
            }
        })
        .collect();
    let max_iter = em_info.max_iter;
    let total_weight: f64 = eq_map.counts.iter().sum::<usize>() as f64;

    // init: uniform starting estimate, zeroed accumulators
    let avg = total_weight / (num_targets as f64);
    let mut prev_counts: Vec<AtomicF64> =
        (0..num_targets).map(|_| AtomicF64::new(avg)).collect();
    let mut curr_counts: Vec<AtomicF64> =
        (0..num_targets).map(|_| AtomicF64::new(0.0_f64)).collect();

    // Materialize the (labels, count) pairs once so rayon can split them.
    let eq_iterates: Vec<(&[u32], &usize)> = eq_map.iter_labels().zip(&eq_map.counts).collect();

    let mut rel_diff = 0.0_f64;
    let mut niter = 0_u32;

    let pool = rayon::ThreadPoolBuilder::new()
        .num_threads(nthreads)
        .build()
        .expect("failed to build the rayon thread pool for the parallel EM");

    pool.install(|| {
        while niter < max_iter {
            m_step_par(
                &eq_iterates,
                &mut prev_counts,
                &inv_eff_lens,
                &mut curr_counts,
            );

            // Largest relative change over all targets that are present.
            // The magnitude is used so that a shrinking estimate counts
            // toward non-convergence just like a growing one.
            // NOTE(review): keep the stopping rule of the sequential `em`
            // consistent with this one.
            for (prev, curr) in prev_counts.iter().zip(curr_counts.iter()) {
                let pci = prev.load(Ordering::Relaxed);
                if pci > ABSENCE_THRESH {
                    let cci = curr.load(Ordering::Relaxed);
                    let rd = ((cci - pci) / pci).abs();
                    rel_diff = if rel_diff > rd { rel_diff } else { rd };
                }
            }

            // The new estimates become the "previous" ones for the next round.
            std::mem::swap(&mut prev_counts, &mut curr_counts);
            // zero out the vector
            curr_counts
                .par_iter()
                .for_each(|x| x.store(0.0f64, Ordering::Relaxed));

            if rel_diff < RELDIFF_THRESH {
                break;
            }
            niter += 1;
            if niter % 100 == 0 {
                info!("iteration {}; rel diff {:.3}", niter, rel_diff);
            }
            rel_diff = 0.0_f64;
        }

        // Treat vanishingly small estimates as absent before the final pass.
        for x in prev_counts.iter() {
            if x.load(Ordering::Relaxed) < ABSENCE_THRESH {
                x.store(0.0, Ordering::Relaxed);
            }
        }
        // `curr_counts` was zeroed at the end of the last loop iteration (or
        // is still at its initial zeros), so this yields the final counts.
        m_step_par(
            &eq_iterates,
            &mut prev_counts,
            &inv_eff_lens,
            &mut curr_counts,
        );
    });

    curr_counts
        .iter()
        .map(|x| x.load(Ordering::Relaxed))
        .collect::<Vec<f64>>()
}

pub fn em(em_info: &EMInfo) -> Vec<f64> {
let eq_map = em_info.eq_map;
let eff_lens = em_info.eff_lens;
Expand Down

0 comments on commit b340c62

Please sign in to comment.