Commit 609a086: batch size

dagou committed Aug 12, 2024
1 parent 999916c

Showing 7 changed files with 212 additions and 115 deletions.
10 changes: 7 additions & 3 deletions kr2r/src/args.rs
@@ -35,7 +35,7 @@ pub struct Build {
     pub threads: usize,
 }
 
-const BATCH_SIZE: usize = 16 * 1024 * 1024;
+const BUFFER_SIZE: usize = 16 * 1024 * 1024;
 
 /// Command line arguments for the classify program.
 ///
@@ -84,8 +84,12 @@ pub struct ClassifyArgs {
     #[clap(short = 'p', long = "num-threads", value_parser, default_value_t = num_cpus::get())]
     pub num_threads: usize,
 
-    #[clap(long, default_value_t = BATCH_SIZE)]
-    pub batch_size: usize,
+    #[clap(long, default_value_t = BUFFER_SIZE)]
+    pub buffer_size: usize,
+
+    /// The size of each batch for processing taxid match results, used to control memory usage
+    #[clap(long, default_value_t = 16)]
+    pub batch_size: u32,
 
     /// Confidence score threshold
     #[clap(
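For context, the change above uses the standard clap derive pattern: default_value_t takes a typed constant, and the --buffer-size / --batch-size flag names are derived from the field names. A minimal standalone sketch of that pattern (assuming clap 4 with the derive feature; the Demo struct and main are illustrative, not part of the commit):

use clap::Parser;

const BUFFER_SIZE: usize = 16 * 1024 * 1024;

#[derive(Parser, Debug)]
struct Demo {
    /// I/O buffer size in bytes; the flag name --buffer-size comes from the field name
    #[clap(long, default_value_t = BUFFER_SIZE)]
    buffer_size: usize,

    /// Number of bins per chunk file for taxid match results
    #[clap(long, default_value_t = 16)]
    batch_size: u32,
}

fn main() {
    let args = Demo::parse();
    println!("buffer_size={} batch_size={}", args.buffer_size, args.batch_size);
}

Judging from how process_batch uses the two values below, buffer_size bounds the bytes read per parallel batch, while batch_size is now a small bin count (default 16) rather than a byte size.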
101 changes: 76 additions & 25 deletions kr2r/src/bin/annotate.rs
@@ -9,7 +9,7 @@ use std::path::Path;
 use std::path::PathBuf;
 use std::time::Instant;
 // Defines the number of Slots processed per batch
-pub const BATCH_SIZE: usize = 8 * 1024 * 1024;
+pub const BUFFER_SIZE: usize = 8 * 1024 * 1024;
 
 /// Command line arguments for the splitr program.
 ///
@@ -30,8 +30,12 @@ pub struct Args {
     #[clap(long)]
     pub chunk_dir: PathBuf,
 
-    #[clap(long, default_value_t = BATCH_SIZE)]
-    pub batch_size: usize,
+    #[clap(long, default_value_t = BUFFER_SIZE)]
+    pub buffer_size: usize,
+
+    /// The size of each batch for processing taxid match results, used to control memory usage
+    #[clap(long, default_value_t = 16)]
+    pub batch_size: u32,
 
     /// The number of threads to use.
     #[clap(short = 'p', long = "num-threads", value_parser, default_value_t = num_cpus::get())]
@@ -57,7 +61,7 @@ fn read_chunk_header<R: Read>(reader: &mut R) -> io::Result<(usize, usize)> {
     Ok((index as usize, chunk_size as usize))
 }
 
-fn write_to_file(
+fn _write_to_file(
     file_index: u64,
     bytes: &[u8],
     last_file_index: &mut Option<u64>,
@@ -87,21 +91,65 @@ fn write_to_file(
     Ok(())
 }
 
+fn write_to_file(
+    file_index: u64,
+    seq_id_mod: u32,
+    bytes: &[u8],
+    writers: &mut HashMap<(u64, u32), BufWriter<File>>,
+    chunk_dir: &PathBuf,
+) -> io::Result<()> {
+    // Check whether a writer for this file already exists; if not, create one
+    let writer = writers.entry((file_index, seq_id_mod)).or_insert_with(|| {
+        let file_name = format!("sample_file_{}_{}.bin", file_index, seq_id_mod);
+        let file_path = chunk_dir.join(file_name);
+        let file = OpenOptions::new()
+            .create(true)
+            .append(true)
+            .open(&file_path)
+            .expect("failed to open file");
+        BufWriter::new(file)
+    });
+
+    writer.write_all(bytes)?;
+
+    Ok(())
+}
+
+fn clean_up_writers(
+    writers: &mut HashMap<(u64, u32), BufWriter<File>>,
+    current_file_index: u64,
+) -> io::Result<()> {
+    let keys_to_remove: Vec<(u64, u32)> = writers
+        .keys()
+        .cloned()
+        .filter(|(idx, _)| *idx != current_file_index)
+        .collect();
+
+    for key in keys_to_remove {
+        if let Some(mut writer) = writers.remove(&key) {
+            writer.flush()?; // flush and clean up
+        }
+    }
+
+    Ok(())
+}
+
 fn process_batch<R>(
     reader: &mut R,
     hash_config: &HashConfig,
     chtm: &CHTable,
     chunk_dir: PathBuf,
-    batch_size: usize,
+    buffer_size: usize,
+    bin_threads: u32,
     page_index: usize,
     num_threads: usize,
 ) -> std::io::Result<()>
 where
     R: Read + Send,
 {
     let row_size = std::mem::size_of::<Row>();
-    let mut last_file_index: Option<u64> = None;
-    let mut writer: Option<BufWriter<File>> = None;
+    let mut writers: HashMap<(u64, u32), BufWriter<File>> = HashMap::new();
+    let mut current_file_index: Option<u64> = None;
 
     let value_mask = hash_config.get_value_mask();
     let value_bits = hash_config.get_value_bits();
@@ -111,9 +159,9 @@ where
     buffer_read_parallel(
         reader,
         num_threads,
-        batch_size,
+        buffer_size,
         |dataset: Vec<Slot<u64>>| {
-            let mut results: HashMap<u64, Vec<u8>> = HashMap::new();
+            let mut results: HashMap<(u64, u32), Vec<u8>> = HashMap::new();
             for slot in dataset {
                 let indx = slot.idx & idx_mask;
                 let compacted = slot.value.left(value_bits) as u32;
@@ -127,9 +175,10 @@ where
                 let high = u32::combined(left, taxid, value_bits);
                 let row = Row::new(high, seq_id, kmer_id as u32);
                 let value_bytes = row.as_slice(row_size);
+                let seq_id_mod = seq_id % bin_threads;
 
                 results
-                    .entry(file_index)
+                    .entry((file_index, seq_id_mod))
                     .or_insert_with(Vec::new)
                     .extend(value_bytes);
             }
@@ -138,28 +187,29 @@ where
         },
         |result| {
             while let Some(Some(res)) = result.next() {
-                let mut file_indices: Vec<_> = res.keys().cloned().collect();
-                file_indices.sort_unstable(); // sort by file_index
-
-                for file_index in file_indices {
-                    if let Some(bytes) = res.get(&file_index) {
-                        write_to_file(
-                            file_index,
-                            bytes,
-                            &mut last_file_index,
-                            &mut writer,
-                            &chunk_dir,
-                        )
-                        .expect("write to file error");
+                let mut file_keys: Vec<_> = res.keys().cloned().collect();
+                file_keys.sort_unstable(); // sort by (file_index, seq_id_mod)
+
+                for (file_index, seq_id_mod) in file_keys {
+                    if let Some(bytes) = res.get(&(file_index, seq_id_mod)) {
+                        // If the file_index being processed has changed, clean up non-current writers
+                        if current_file_index != Some(file_index) {
+                            clean_up_writers(&mut writers, file_index).expect("clean writer");
+                            current_file_index = Some(file_index);
+                        }
+
+                        write_to_file(file_index, seq_id_mod, bytes, &mut writers, &chunk_dir)
+                            .expect("write to file error");
                     }
                 }
             }
         },
     )
     .expect("failed");
 
-    if let Some(w) = writer.as_mut() {
-        w.flush()?;
+    // After the final batch completes, flush all writers
+    for writer in writers.values_mut() {
+        writer.flush()?;
     }
 
     Ok(())
@@ -190,6 +240,7 @@ fn process_chunk_file<P: AsRef<Path>>(
         &config,
         &chtm,
         args.chunk_dir.clone(),
+        args.buffer_size,
        args.batch_size,
         page_index,
         args.num_threads,
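Taken together, the annotate.rs changes bucket output rows by (file_index, seq_id % bin_threads) instead of by file_index alone, sort the keys so that all bins of one file_index are written consecutively, and evict writers for other file indexes to bound the number of simultaneously open files. A standalone simulation of that ordering (illustrative values only, not part of the commit):

fn main() {
    let bin_threads: u32 = 16; // matches the new batch_size default
    // (file_index, seq_id) pairs, as produced upstream in process_batch
    let rows: [(u64, u32); 4] = [(1, 3), (0, 5), (1, 18), (0, 21)];

    let mut keys: Vec<(u64, u32)> = rows
        .iter()
        .map(|&(file_index, seq_id)| (file_index, seq_id % bin_threads))
        .collect();
    keys.sort_unstable(); // all bins of one file_index become consecutive
    keys.dedup();

    let mut current_file_index: Option<u64> = None;
    for (file_index, seq_id_mod) in keys {
        if current_file_index != Some(file_index) {
            println!("file_index {file_index}: evict writers for other indexes");
            current_file_index = Some(file_index);
        }
        println!("append rows to sample_file_{file_index}_{seq_id_mod}.bin");
    }
}

Because write_to_file opens the bin files in append mode, a writer evicted mid-run can safely be recreated for a later batch; eviction only costs a flush and a reopen.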
2 changes: 1 addition & 1 deletion kr2r/src/bin/kun.rs
@@ -81,6 +81,7 @@ impl From<ClassifyArgs> for annotate::Args {
             database: item.database,
             chunk_dir: item.chunk_dir,
             batch_size: item.batch_size,
+            buffer_size: item.buffer_size,
             num_threads: item.num_threads,
         }
     }
@@ -91,7 +92,6 @@ impl From<ClassifyArgs> for resolve::Args {
         Self {
             database: item.database,
             chunk_dir: item.chunk_dir,
-            batch_size: item.batch_size,
             confidence_threshold: item.confidence_threshold,
             minimum_hit_groups: item.minimum_hit_groups,
             kraken_output_dir: item.kraken_output_dir,
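The kun.rs diff shows the tail of a plain From fan-out: the combined ClassifyArgs struct is converted into per-stage argument structs, so after this commit buffer_size and batch_size flow to the annotate stage while the resolve stage drops batch_size. A compressed sketch of the pattern with hypothetical stand-in structs (the real structs carry many more fields, and the real impls take ClassifyArgs by value):

struct ClassifyArgs {
    buffer_size: usize,
    batch_size: u32,
    confidence_threshold: f64,
}

struct AnnotateArgs {
    buffer_size: usize,
    batch_size: u32,
}

struct ResolveArgs {
    confidence_threshold: f64,
}

impl From<&ClassifyArgs> for AnnotateArgs {
    fn from(item: &ClassifyArgs) -> Self {
        Self {
            buffer_size: item.buffer_size,
            batch_size: item.batch_size,
        }
    }
}

impl From<&ClassifyArgs> for ResolveArgs {
    fn from(item: &ClassifyArgs) -> Self {
        Self {
            confidence_threshold: item.confidence_threshold,
        }
    }
}

fn main() {
    let c = ClassifyArgs { buffer_size: 16 << 20, batch_size: 16, confidence_threshold: 0.0 };
    let a = AnnotateArgs::from(&c);
    let r = ResolveArgs::from(&c);
    println!("annotate gets buffer_size={} batch_size={}", a.buffer_size, a.batch_size);
    println!("resolve gets confidence_threshold={}", r.confidence_threshold);
}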
2 changes: 1 addition & 1 deletion kr2r/src/bin/merge_fna.rs
@@ -16,7 +16,7 @@ use std::time::Instant;
 #[clap(version, about = "A tool for processing genomic files")]
 pub struct Args {
     /// Directory to store downloaded files
-    #[arg(short, long, default_value = "lib")]
+    #[arg(short, long, required = true)]
     pub download_dir: PathBuf,
 
     /// ncbi library fna database directory
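With required = true replacing the "lib" default, merge_fna now refuses to run without an explicit download directory. A minimal sketch of the resulting behavior (assuming clap 4 with the derive feature; DemoArgs is a hypothetical stand-in):

use clap::Parser;
use std::path::PathBuf;

#[derive(Parser, Debug)]
struct DemoArgs {
    /// Directory to store downloaded files
    #[arg(short, long, required = true)]
    download_dir: PathBuf,
}

fn main() {
    // With no default value, parsing an empty command line now fails
    // instead of silently falling back to "lib".
    let err = DemoArgs::try_parse_from(["merge_fna"]).unwrap_err();
    println!("{err}");
}

In clap 4 a non-Option field without a default is implicitly required, so the explicit required = true appears mainly to document intent.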
(Diffs for the remaining 3 changed files were not loaded.)
