From a4a22c454e20a5d914ff3d6bd694132df06cf882 Mon Sep 17 00:00:00 2001 From: psj1997 <1176894699@qq.com> Date: Mon, 12 Feb 2024 13:59:45 +0800 Subject: [PATCH 01/26] Update aemb module --- src/aln.cpp | 131 +++++++++++++++++++++++++++++++++++++++++++----- src/aln.hpp | 7 ++- src/cmdline.cpp | 2 + src/cmdline.hpp | 1 + src/main.cpp | 40 +++++++++++---- src/pc.cpp | 10 ++-- src/pc.hpp | 3 +- 7 files changed, 164 insertions(+), 30 deletions(-) diff --git a/src/aln.cpp b/src/aln.cpp index b5105bb1..7b88d3b1 100644 --- a/src/aln.cpp +++ b/src/aln.cpp @@ -868,12 +868,19 @@ inline void get_best_map_location( std::vector &nams2, InsertSizeDistribution &isize_est, Nam &best_nam1, - Nam &best_nam2 + Nam &best_nam2, + std::vector &abundance, + int read1_len, + int read2_len, + bool output_abundance ) { std::vector nam_pairs = get_best_scoring_nam_pairs(nams1, nams2, isize_est.mu, isize_est.sigma); best_nam1.ref_start = -1; //Unmapped until proven mapped best_nam2.ref_start = -1; //Unmapped until proven mapped + std::vector best_ref1; + std::vector best_ref2; + if (nam_pairs.empty()) { return; } @@ -903,6 +910,73 @@ inline void get_best_map_location( if (score_joint > score_indiv) { // joint score is better than individual best_nam1 = n1_joint_max; best_nam2 = n2_joint_max; + + if (output_abundance){ + for (auto &[score, n1, n2] : nam_pairs){ + if ((n1.score + n2.score) == score_joint){ + best_ref1.push_back(n1); + best_ref2.push_back(n2); + }else{ + break; + } + } + + int ref_size1 = best_ref1.size(); + for (auto &t: best_ref1){ + if (t.ref_start < 0) { + continue; + } + abundance[t.ref_id] += float(read1_len) / float(ref_size1); + } + + int ref_size2 = best_ref2.size(); + for (auto &t: best_ref2){ + if (t.ref_start < 0) { + continue; + } + abundance[t.ref_id] += float(read2_len) / float(ref_size2); + } + } + + } + else{ + if (output_abundance){ + if (!nams1.empty()){ + for (auto &t : nams1){ + if (t.score == nams1[0].score){ + best_ref1.push_back(t); + }else{ + break; + } + } + + int ref_size1 = best_ref1.size(); + for (auto &t: best_ref1){ + if (t.ref_start < 0) { + continue; + } + abundance[t.ref_id] += float(read1_len) / float(ref_size1); + } + } + + if (!nams2.empty()){ + for (auto &t : nams2){ + if (t.score == nams2[0].score){ + best_ref2.push_back(t); + }else{ + break; + } + } + + int ref_size2 = best_ref2.size(); + for (auto &t: best_ref2){ + if (t.ref_start < 0) { + continue; + } + abundance[t.ref_id] += float(read2_len) / float(ref_size2); + } + } + } } if (isize_est.sample_size < 400 && score_joint > score_indiv) { @@ -957,7 +1031,8 @@ void align_or_map_paired( const IndexParameters& index_parameters, const References& references, const StrobemerIndex& index, - std::minstd_rand& random_engine + std::minstd_rand& random_engine, + std::vector &abundance ) { std::array details; std::array, 2> nams_pair; @@ -991,12 +1066,19 @@ void align_or_map_paired( } Timer extend_timer; + if (map_param.is_abundance_out){ + Nam nam_read1; + Nam nam_read2; + get_best_map_location(nams_pair[0], nams_pair[1], isize_est,nam_read1, nam_read2, abundance, record1.seq.length(), record2.seq.length(), map_param.is_abundance_out); + } + else{ if (!map_param.is_sam_out) { Nam nam_read1; Nam nam_read2; get_best_map_location(nams_pair[0], nams_pair[1], isize_est, nam_read1, - nam_read2); + nam_read2, abundance, record1.seq.length(), + record2.seq.length(), map_param.is_abundance_out); output_hits_paf_PE(outstring, nam_read1, record1.name, references, record1.seq.length()); @@ -1066,6 +1148,7 @@ void align_or_map_paired( 
); } } + } statistics.tot_extend += extend_timer.duration(); statistics += details[0]; statistics += details[1]; @@ -1082,7 +1165,8 @@ void align_or_map_single( const IndexParameters& index_parameters, const References& references, const StrobemerIndex& index, - std::minstd_rand& random_engine + std::minstd_rand& random_engine, + std::vector &abundance ) { Details details; Timer strobe_timer; @@ -1111,15 +1195,36 @@ void align_or_map_single( Timer extend_timer; - if (!map_param.is_sam_out) { - output_hits_paf(outstring, nams, record.name, references, - record.seq.length()); - } else { - align_single( - aligner, sam, nams, record, index_parameters.syncmer.k, - references, details, map_param.dropoff_threshold, map_param.max_tries, - map_param.max_secondary, random_engine - ); + std::vector best_ref; + if (map_param.is_abundance_out){ + if (!nams.empty()){ + for (auto &t : nams){ + if (t.score == nams[0].score){ + best_ref.push_back(t); + }else{ + break; + } + } + int ref_size = best_ref.size(); + for (auto &t: best_ref){ + if (t.ref_start < 0) { + continue; + } + abundance[t.ref_id] += float(record.seq.length()) / float(ref_size); + } + } + } + else{ + if (!map_param.is_sam_out) { + output_hits_paf(outstring, nams, record.name, references, + record.seq.length()); + } else { + align_single( + aligner, sam, nams, record, index_parameters.syncmer.k, + references, details, map_param.dropoff_threshold, map_param.max_tries, + map_param.max_secondary, random_engine + ); + } } statistics.tot_extend += extend_timer.duration(); statistics += details; diff --git a/src/aln.hpp b/src/aln.hpp index f8bb69bf..99bcfedb 100644 --- a/src/aln.hpp +++ b/src/aln.hpp @@ -64,6 +64,7 @@ struct MappingParameters { int max_tries { 20 }; int rescue_cutoff; bool is_sam_out { true }; + bool is_abundance_out {false}; CigarOps cigar_ops{CigarOps::M}; bool output_unmapped { true }; bool details{false}; @@ -88,7 +89,8 @@ void align_or_map_paired( const IndexParameters& index_parameters, const References& references, const StrobemerIndex& index, - std::minstd_rand& random_engine + std::minstd_rand& random_engine, + std::vector &abundance ); void align_or_map_single( @@ -101,7 +103,8 @@ void align_or_map_single( const IndexParameters& index_parameters, const References& references, const StrobemerIndex& index, - std::minstd_rand& random_engine + std::minstd_rand& random_engine, + std::vector &abundance ); // Private declarations, only needed for tests diff --git a/src/cmdline.cpp b/src/cmdline.cpp index c00c5479..510821b6 100644 --- a/src/cmdline.cpp +++ b/src/cmdline.cpp @@ -31,6 +31,7 @@ CommandLineOptions parse_command_line_arguments(int argc, char **argv) { args::ValueFlag index_statistics(parser, "PATH", "Print statistics of indexing to PATH", {"index-statistics"}); args::Flag i(parser, "index", "Do not map reads; only generate the strobemer index and write it to disk. 
If read files are provided, they are used to estimate read length", {"create-index", 'i'}); args::Flag use_index(parser, "use_index", "Use a pre-generated index previously written with --create-index.", { "use-index" }); + args::Flag aemb(parser, "aemb", "Only output abundance value of contigs for metagenomic binning", {"aemb"}); args::Group sam(parser, "SAM output:"); args::Flag eqx(parser, "eqx", "Emit =/X instead of M CIGAR operations", {"eqx"}); @@ -97,6 +98,7 @@ CommandLineOptions parse_command_line_arguments(int argc, char **argv) { if (index_statistics) { opt.logfile_name = args::get(index_statistics); } if (i) { opt.only_gen_index = true; } if (use_index) { opt.use_index = true; } + if (aemb) {opt.is_abundance_out = true; } // SAM output if (eqx) { opt.cigar_eqx = true; } diff --git a/src/cmdline.hpp b/src/cmdline.hpp index 2ae3186e..9f5b4a05 100644 --- a/src/cmdline.hpp +++ b/src/cmdline.hpp @@ -18,6 +18,7 @@ struct CommandLineOptions { bool only_gen_index { false }; bool use_index { false }; bool is_sam_out { true }; + bool is_abundance_out {false}; // SAM output bool cigar_eqx { false }; diff --git a/src/main.cpp b/src/main.cpp index dcc96774..662798e1 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -188,6 +188,7 @@ int run_strobealign(int argc, char **argv) { map_param.output_unmapped = opt.output_unmapped; map_param.details = opt.details; map_param.fastq_comments = opt.fastq_comments; + map_param.is_abundance_out = opt.is_abundance_out; map_param.verify(); log_parameters(index_parameters, map_param, aln_params); @@ -284,31 +285,33 @@ int run_strobealign(int argc, char **argv) { std::ostream out(buf); - if (map_param.is_sam_out) { - std::stringstream cmd_line; - for(int i = 0; i < argc; ++i) { - cmd_line << argv[i] << " "; - } + if (!map_param.is_abundance_out){ + if (map_param.is_sam_out) { + std::stringstream cmd_line; + for(int i = 0; i < argc; ++i) { + cmd_line << argv[i] << " "; + } - out << sam_header(references, opt.read_group_id, opt.read_group_fields); - if (opt.pg_header) { - out << pg_header(cmd_line.str()); + out << sam_header(references, opt.read_group_id, opt.read_group_fields); + if (opt.pg_header) { + out << pg_header(cmd_line.str()); + } } } std::vector log_stats_vec(opt.n_threads); - + logger.info() << "Running in " << (opt.is_SE ? 
"single-end" : "paired-end") << " mode" << std::endl; OutputBuffer output_buffer(out); - std::vector workers; std::vector worker_done(opt.n_threads); // each thread sets its entry to 1 when it’s done + std::vector> align_abundances(opt.n_threads, std::vector(references.size(), 0)); for (int i = 0; i < opt.n_threads; ++i) { std::thread consumer(perform_task, std::ref(input_buffer), std::ref(output_buffer), std::ref(log_stats_vec[i]), std::ref(worker_done[i]), std::ref(aln_params), std::ref(map_param), std::ref(index_parameters), std::ref(references), - std::ref(index), std::ref(opt.read_group_id)); + std::ref(index), std::ref(opt.read_group_id), std::ref(align_abundances[i])); workers.push_back(std::move(consumer)); } if (opt.show_progress && isatty(2)) { @@ -324,6 +327,21 @@ int run_strobealign(int argc, char **argv) { tot_statistics += it; } + if (map_param.is_abundance_out){ + std::vector abundances(references.size(), 0); + std::vector abundances_norm(references.size(), 0); + for (size_t i = 0; i < align_abundances.size(); ++i) { + for (size_t j = 0; j < align_abundances[i].size(); ++j) { + abundances[j] += align_abundances[i][j]; + } + } + + for (size_t i = 0; i < references.size(); ++i) { + std::cout << references.names[i] << '\t' << std::fixed << std::setprecision(6) << abundances[i] / float(references.sequences[i].size()) << std::endl; + } + } + + logger.info() << "Total mapping sites tried: " << tot_statistics.tot_all_tried << std::endl << "Total calls to ssw: " << tot_statistics.tot_aligner_calls << std::endl << "Inconsistent NAM ends: " << tot_statistics.inconsistent_nams << std::endl diff --git a/src/pc.cpp b/src/pc.cpp index 5011bf21..34da9c97 100644 --- a/src/pc.cpp +++ b/src/pc.cpp @@ -139,7 +139,8 @@ void perform_task( const IndexParameters& index_parameters, const References& references, const StrobemerIndex& index, - const std::string& read_group_id + const std::string& read_group_id, + std::vector &abundance ) { bool eof = false; Aligner aligner{aln_params}; @@ -170,16 +171,19 @@ void perform_task( to_uppercase(record1.seq); to_uppercase(record2.seq); align_or_map_paired(record1, record2, sam, sam_out, statistics, isize_est, aligner, - map_param, index_parameters, references, index, random_engine); + map_param, index_parameters, references, index, random_engine, abundance); statistics.n_reads += 2; } for (size_t i = 0; i < records3.size(); ++i) { auto record = records3[i]; - align_or_map_single(record, sam, sam_out, statistics, aligner, map_param, index_parameters, references, index, random_engine); + align_or_map_single(record, sam, sam_out, statistics, aligner, map_param, index_parameters, references, index, random_engine, abundance); statistics.n_reads++; } + + if (!map_param.is_abundance_out){ output_buffer.output_records(std::move(sam_out), chunk_index); assert(sam_out == ""); + } } statistics.tot_aligner_calls += aligner.calls_count(); done = true; diff --git a/src/pc.hpp b/src/pc.hpp index 703de209..b0fa6ad8 100644 --- a/src/pc.hpp +++ b/src/pc.hpp @@ -64,7 +64,8 @@ class OutputBuffer { void perform_task(InputBuffer &input_buffer, OutputBuffer &output_buffer, AlignmentStatistics& statistics, int& done, const AlignmentParameters &aln_params, - const MappingParameters &map_param, const IndexParameters& index_parameters, const References& references, const StrobemerIndex& index, const std::string& read_group_id); + const MappingParameters &map_param, const IndexParameters& index_parameters, + const References& references, const StrobemerIndex& index, const std::string& 
read_group_id, std::vector &abundance); bool same_name(const std::string& n1, const std::string& n2); From f113da2f3e94bbb06579391d2cfd7526ac701b21 Mon Sep 17 00:00:00 2001 From: psj1997 <1176894699@qq.com> Date: Tue, 20 Feb 2024 22:48:33 +0800 Subject: [PATCH 02/26] Code improve; documentation update; testing update --- README.md | 6 +++ src/aln.cpp | 86 +++++++++++++++++++++--------------------- src/aln.hpp | 4 +- src/cmdline.cpp | 2 +- src/main.cpp | 36 ++++++++++-------- src/pc.cpp | 6 +-- src/pc.hpp | 2 +- tests/phix.abun.pe.txt | 1 + tests/phix.abun.se.txt | 1 + tests/run.sh | 10 +++++ 10 files changed, 90 insertions(+), 64 deletions(-) create mode 100644 tests/phix.abun.pe.txt create mode 100644 tests/phix.abun.se.txt diff --git a/README.md b/README.md index f7ec6e05..804e8e7c 100644 --- a/README.md +++ b/README.md @@ -113,6 +113,11 @@ strobealign ref.fa reads.1.fastq.gz reads.2.fastq.gz | samtools sort -o sorted.b This is usually faster than doing the two steps separately because fewer intermediate files are created. +To output only the estimated abundance of every contig, use `--aemb`. Each line of the output has the format `contig_id \t abundance_value`: +``` +strobealign ref.fa reads.fq --aemb > abundance.txt # Single-end reads +strobealign ref.fa reads1.fq reads2.fq --aemb > abundance.txt # Paired-end reads +``` ## Command-line options @@ -127,6 +132,7 @@ options. Some important ones are: * `--eqx`: Emit `=` and `X` CIGAR operations instead of `M`. * `-x`: Only map reads, do not do base-level alignment. This switches the output format from SAM to [PAF](https://github.com/lh3/miniasm/blob/master/PAF.md). +* `--aemb`: Output the estimated abundance of every contig instead of alignments. Each line of the output has the format `contig_id \t abundance_value`. * `--rg-id=ID`: Add RG tag to each SAM record. * `--rg=TAG:VALUE`: Add read group metadata to the SAM header. This can be specified multiple times. Example: `--rg-id=1 --rg=SM:mysample --rg=LB:mylibrary`.
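The src/aln.cpp and src/main.cpp changes below implement the `--aemb` estimate roughly as follows: every reference that ties for the best score receives an equal share of the read length (each mate of a pair contributes its own length), and after all reads are processed the per-thread totals are summed and divided by the contig length. A minimal sketch of the accumulation rule, with illustrative names that are not part of the patch:

```
#include <cstddef>
#include <vector>

// Simplified sketch of the --aemb accumulation rule (illustrative only):
// each reference that ties for the best score gets an equal share of the
// read length; the real code also skips unmapped hits (ref_start < 0).
void add_read_share(const std::vector<size_t>& tied_ref_ids, int read_len,
                    std::vector<float>& abundances) {
    if (tied_ref_ids.empty()) {
        return;
    }
    const float share = float(read_len) / float(tied_ref_ids.size());
    for (size_t ref_id : tied_ref_ids) {
        abundances[ref_id] += share;
    }
}
```

For example, a 150 bp read that maps equally well to two contigs adds 75 to each of their accumulators; the value finally printed per contig is the accumulated total divided by the contig length, written as one `contig_id \t abundance_value` line with six decimals.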
diff --git a/src/aln.cpp b/src/aln.cpp index 7b88d3b1..e66f9039 100644 --- a/src/aln.cpp +++ b/src/aln.cpp @@ -869,7 +869,7 @@ inline void get_best_map_location( InsertSizeDistribution &isize_est, Nam &best_nam1, Nam &best_nam2, - std::vector &abundance, + std::vector &abundances, int read1_len, int read2_len, bool output_abundance @@ -878,8 +878,8 @@ inline void get_best_map_location( best_nam1.ref_start = -1; //Unmapped until proven mapped best_nam2.ref_start = -1; //Unmapped until proven mapped - std::vector best_ref1; - std::vector best_ref2; + std::vector best_contig1; + std::vector best_contig2; if (nam_pairs.empty()) { return; @@ -912,68 +912,70 @@ inline void get_best_map_location( best_nam2 = n2_joint_max; if (output_abundance){ + // find all NAM pairs that have the same best score for (auto &[score, n1, n2] : nam_pairs){ if ((n1.score + n2.score) == score_joint){ - best_ref1.push_back(n1); - best_ref2.push_back(n2); + best_contig1.push_back(n1); + best_contig2.push_back(n2); }else{ break; } } - int ref_size1 = best_ref1.size(); - for (auto &t: best_ref1){ + size_t contig_size1 = best_contig1.size(); + for (auto &t: best_contig1){ if (t.ref_start < 0) { continue; } - abundance[t.ref_id] += float(read1_len) / float(ref_size1); + abundances[t.ref_id] += float(read1_len) / float(contig_size1); } - int ref_size2 = best_ref2.size(); - for (auto &t: best_ref2){ + int contig_size2 = best_contig2.size(); + for (auto &t: best_contig2){ if (t.ref_start < 0) { continue; } - abundance[t.ref_id] += float(read2_len) / float(ref_size2); + abundances[t.ref_id] += float(read2_len) / float(contig_size2); } - } - + } } else{ if (output_abundance){ if (!nams1.empty()){ + // find all NAM1 that have the same best score for (auto &t : nams1){ if (t.score == nams1[0].score){ - best_ref1.push_back(t); + best_contig1.push_back(t); }else{ break; } } - int ref_size1 = best_ref1.size(); - for (auto &t: best_ref1){ + size_t contig_size1 = best_contig1.size(); + for (auto &t: best_contig1){ if (t.ref_start < 0) { continue; } - abundance[t.ref_id] += float(read1_len) / float(ref_size1); + abundances[t.ref_id] += float(read1_len) / float(contig_size1); } } if (!nams2.empty()){ + // find all NAM2 that have the same best score for (auto &t : nams2){ if (t.score == nams2[0].score){ - best_ref2.push_back(t); + best_contig2.push_back(t); }else{ break; } } - int ref_size2 = best_ref2.size(); - for (auto &t: best_ref2){ + int contig_size2 = best_contig2.size(); + for (auto &t: best_contig2){ if (t.ref_start < 0) { continue; } - abundance[t.ref_id] += float(read2_len) / float(ref_size2); + abundances[t.ref_id] += float(read2_len) / float(contig_size2); } } } @@ -1032,7 +1034,7 @@ void align_or_map_paired( const References& references, const StrobemerIndex& index, std::minstd_rand& random_engine, - std::vector &abundance + std::vector &abundances ) { std::array details; std::array, 2> nams_pair; @@ -1069,22 +1071,22 @@ void align_or_map_paired( if (map_param.is_abundance_out){ Nam nam_read1; Nam nam_read2; - get_best_map_location(nams_pair[0], nams_pair[1], isize_est,nam_read1, nam_read2, abundance, record1.seq.length(), record2.seq.length(), map_param.is_abundance_out); + get_best_map_location(nams_pair[0], nams_pair[1], isize_est,nam_read1, nam_read2, abundances, record1.seq.length(), record2.seq.length(), map_param.is_abundance_out); } else{ - if (!map_param.is_sam_out) { - Nam nam_read1; - Nam nam_read2; - get_best_map_location(nams_pair[0], nams_pair[1], isize_est, - nam_read1, - nam_read2, abundance, 
record1.seq.length(), - record2.seq.length(), map_param.is_abundance_out); - output_hits_paf_PE(outstring, nam_read1, record1.name, - references, - record1.seq.length()); - output_hits_paf_PE(outstring, nam_read2, record2.name, - references, - record2.seq.length()); + if (!map_param.is_sam_out) { + Nam nam_read1; + Nam nam_read2; + get_best_map_location(nams_pair[0], nams_pair[1], isize_est, + nam_read1, + nam_read2, abundances, record1.seq.length(), + record2.seq.length(), map_param.is_abundance_out); + output_hits_paf_PE(outstring, nam_read1, record1.name, + references, + record1.seq.length()); + output_hits_paf_PE(outstring, nam_read2, record2.name, + references, + record2.seq.length()); } else { Read read1(record1.seq); Read read2(record2.seq); @@ -1166,7 +1168,7 @@ void align_or_map_single( const References& references, const StrobemerIndex& index, std::minstd_rand& random_engine, - std::vector &abundance + std::vector &abundances ) { Details details; Timer strobe_timer; @@ -1195,22 +1197,22 @@ void align_or_map_single( Timer extend_timer; - std::vector best_ref; + std::vector best_contig; if (map_param.is_abundance_out){ if (!nams.empty()){ for (auto &t : nams){ if (t.score == nams[0].score){ - best_ref.push_back(t); + best_contig.push_back(t); }else{ break; } } - int ref_size = best_ref.size(); - for (auto &t: best_ref){ + int contig_size = best_contig.size(); + for (auto &t: best_contig){ if (t.ref_start < 0) { continue; } - abundance[t.ref_id] += float(record.seq.length()) / float(ref_size); + abundances[t.ref_id] += float(record.seq.length()) / float(contig_size); } } } diff --git a/src/aln.hpp b/src/aln.hpp index 99bcfedb..a9f24e27 100644 --- a/src/aln.hpp +++ b/src/aln.hpp @@ -90,7 +90,7 @@ void align_or_map_paired( const References& references, const StrobemerIndex& index, std::minstd_rand& random_engine, - std::vector &abundance + std::vector &abundances ); void align_or_map_single( @@ -104,7 +104,7 @@ void align_or_map_single( const References& references, const StrobemerIndex& index, std::minstd_rand& random_engine, - std::vector &abundance + std::vector &abundances ); // Private declarations, only needed for tests diff --git a/src/cmdline.cpp b/src/cmdline.cpp index 510821b6..5ad83821 100644 --- a/src/cmdline.cpp +++ b/src/cmdline.cpp @@ -27,11 +27,11 @@ CommandLineOptions parse_command_line_arguments(int argc, char **argv) { args::Flag v(parser, "v", "Verbose output", {'v'}); args::Flag no_progress(parser, "no-progress", "Disable progress report (enabled by default if output is a terminal)", {"no-progress"}); args::Flag x(parser, "x", "Only map reads, no base level alignment (produces PAF file)", {'x'}); + args::Flag aemb(parser, "aemb", "Output the estimated abundance value of contigs, the format of output file is: contig_id \t abundance_value", {"aemb"}); args::Flag interleaved(parser, "interleaved", "Interleaved input", {"interleaved"}); args::ValueFlag index_statistics(parser, "PATH", "Print statistics of indexing to PATH", {"index-statistics"}); args::Flag i(parser, "index", "Do not map reads; only generate the strobemer index and write it to disk. 
If read files are provided, they are used to estimate read length", {"create-index", 'i'}); args::Flag use_index(parser, "use_index", "Use a pre-generated index previously written with --create-index.", { "use-index" }); - args::Flag aemb(parser, "aemb", "Only output abundance value of contigs for metagenomic binning", {"aemb"}); args::Group sam(parser, "SAM output:"); args::Flag eqx(parser, "eqx", "Emit =/X instead of M CIGAR operations", {"eqx"}); diff --git a/src/main.cpp b/src/main.cpp index 662798e1..964d223f 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -104,6 +104,12 @@ InputBuffer get_input_buffer(const CommandLineOptions& opt) { } } +void output_abundance(std::vector abundances, References references){ + for (size_t i = 0; i < references.size(); ++i) { + std::cout << references.names[i] << '\t' << std::fixed << std::setprecision(6) << abundances[i] / double(references.sequences[i].size()) << std::endl; + } +} + void show_progress_until_done(std::vector& worker_done, std::vector& stats) { Timer timer; bool reported = false; @@ -154,6 +160,11 @@ int run_strobealign(int argc, char **argv) { if (opt.c >= 64 || opt.c <= 0) { throw BadParameter("c must be greater than 0 and less than 64"); } + + if (!opt.is_sam_out && opt.is_abundance_out){ + throw BadParameter("Can not use -x and --aemb at the same time"); + } + InputBuffer input_buffer = get_input_buffer(opt); if (!opt.r_set && !opt.reads_filename1.empty()) { opt.r = estimate_read_length(input_buffer); @@ -285,18 +296,15 @@ int run_strobealign(int argc, char **argv) { std::ostream out(buf); - if (!map_param.is_abundance_out){ - if (map_param.is_sam_out) { + if (map_param.is_sam_out && !map_param.is_abundance_out){ std::stringstream cmd_line; for(int i = 0; i < argc; ++i) { cmd_line << argv[i] << " "; } - out << sam_header(references, opt.read_group_id, opt.read_group_fields); if (opt.pg_header) { out << pg_header(cmd_line.str()); } - } } std::vector log_stats_vec(opt.n_threads); @@ -306,12 +314,12 @@ int run_strobealign(int argc, char **argv) { OutputBuffer output_buffer(out); std::vector workers; std::vector worker_done(opt.n_threads); // each thread sets its entry to 1 when it’s done - std::vector> align_abundances(opt.n_threads, std::vector(references.size(), 0)); + std::vector> worker_abundances(opt.n_threads, std::vector(references.size(), 0)); for (int i = 0; i < opt.n_threads; ++i) { std::thread consumer(perform_task, std::ref(input_buffer), std::ref(output_buffer), std::ref(log_stats_vec[i]), std::ref(worker_done[i]), std::ref(aln_params), std::ref(map_param), std::ref(index_parameters), std::ref(references), - std::ref(index), std::ref(opt.read_group_id), std::ref(align_abundances[i])); + std::ref(index), std::ref(opt.read_group_id), std::ref(worker_abundances[i])); workers.push_back(std::move(consumer)); } if (opt.show_progress && isatty(2)) { @@ -328,17 +336,15 @@ int run_strobealign(int argc, char **argv) { } if (map_param.is_abundance_out){ - std::vector abundances(references.size(), 0); - std::vector abundances_norm(references.size(), 0); - for (size_t i = 0; i < align_abundances.size(); ++i) { - for (size_t j = 0; j < align_abundances[i].size(); ++j) { - abundances[j] += align_abundances[i][j]; - } + std::vector abundances(references.size(), 0); + for (size_t i = 0; i < worker_abundances.size(); ++i) { + for (size_t j = 0; j < worker_abundances[i].size(); ++j) { + abundances[j] += worker_abundances[i][j]; } + } - for (size_t i = 0; i < references.size(); ++i) { - std::cout << references.names[i] << '\t' << std::fixed 
<< std::setprecision(6) << abundances[i] / float(references.sequences[i].size()) << std::endl; - } + // output the abundance file + output_abundance(abundances, references); } diff --git a/src/pc.cpp b/src/pc.cpp index 34da9c97..417717ea 100644 --- a/src/pc.cpp +++ b/src/pc.cpp @@ -140,7 +140,7 @@ void perform_task( const References& references, const StrobemerIndex& index, const std::string& read_group_id, - std::vector &abundance + std::vector &abundances ) { bool eof = false; Aligner aligner{aln_params}; @@ -171,12 +171,12 @@ void perform_task( to_uppercase(record1.seq); to_uppercase(record2.seq); align_or_map_paired(record1, record2, sam, sam_out, statistics, isize_est, aligner, - map_param, index_parameters, references, index, random_engine, abundance); + map_param, index_parameters, references, index, random_engine, abundances); statistics.n_reads += 2; } for (size_t i = 0; i < records3.size(); ++i) { auto record = records3[i]; - align_or_map_single(record, sam, sam_out, statistics, aligner, map_param, index_parameters, references, index, random_engine, abundance); + align_or_map_single(record, sam, sam_out, statistics, aligner, map_param, index_parameters, references, index, random_engine, abundances); statistics.n_reads++; } diff --git a/src/pc.hpp b/src/pc.hpp index b0fa6ad8..87e4e873 100644 --- a/src/pc.hpp +++ b/src/pc.hpp @@ -65,7 +65,7 @@ class OutputBuffer { void perform_task(InputBuffer &input_buffer, OutputBuffer &output_buffer, AlignmentStatistics& statistics, int& done, const AlignmentParameters &aln_params, const MappingParameters &map_param, const IndexParameters& index_parameters, - const References& references, const StrobemerIndex& index, const std::string& read_group_id, std::vector &abundance); + const References& references, const StrobemerIndex& index, const std::string& read_group_id, std::vector &abundances); bool same_name(const std::string& n1, const std::string& n2); diff --git a/tests/phix.abun.pe.txt b/tests/phix.abun.pe.txt new file mode 100644 index 00000000..7e86894b --- /dev/null +++ b/tests/phix.abun.pe.txt @@ -0,0 +1 @@ +NC_001422.1 4.572291 diff --git a/tests/phix.abun.se.txt b/tests/phix.abun.se.txt new file mode 100644 index 00000000..62b75067 --- /dev/null +++ b/tests/phix.abun.se.txt @@ -0,0 +1 @@ +NC_001422.1 2.313690 diff --git a/tests/run.sh b/tests/run.sh index f82dbb5b..d7d6f064 100755 --- a/tests/run.sh +++ b/tests/run.sh @@ -58,6 +58,16 @@ strobealign -x tests/phix.fasta tests/phix.1.fastq tests/phix.2.fastq | tail -n diff tests/phix.pe.paf phix.pe.paf rm phix.pe.paf +# Single-end abundance estimation +strobealign --aemb tests/phix.fasta tests/phix.1.fastq > phix.abun.se.txt +diff tests/phix.abun.se.txt phix.abun.se.txt +rm phix.abun.se.txt + +# Paired-end abundance estimation +strobealign --aemb tests/phix.fasta tests/phix.1.fastq tests/phix.2.fastq > phix.abun.pe.txt +diff tests/phix.abun.pe.txt phix.abun.pe.txt +rm phix.abun.pe.txt + # Build a separate index strobealign --no-PG -r 150 tests/phix.fasta tests/phix.1.fastq > without-sti.sam strobealign -r 150 -i tests/phix.fasta From c09d3678460d76124b4382343deb6154fa6c104d Mon Sep 17 00:00:00 2001 From: psj1997 <1176894699@qq.com> Date: Tue, 20 Feb 2024 23:29:13 +0800 Subject: [PATCH 03/26] Update single-end abundance testing file --- tests/phix.abun.se.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/phix.abun.se.txt b/tests/phix.abun.se.txt index 62b75067..55e31ba6 100644 --- a/tests/phix.abun.se.txt +++ b/tests/phix.abun.se.txt @@ -1 +1 @@ -NC_001422.1 
2.313690 +NC_001422.1 2.347196 From 0719ec5d396e086d4e9279fa37d823cd532bddc2 Mon Sep 17 00:00:00 2001 From: psj1997 <1176894699@qq.com> Date: Tue, 20 Feb 2024 23:31:32 +0800 Subject: [PATCH 04/26] Update paired-end abundance testing file --- tests/phix.abun.pe.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/phix.abun.pe.txt b/tests/phix.abun.pe.txt index 7e86894b..900b9de5 100644 --- a/tests/phix.abun.pe.txt +++ b/tests/phix.abun.pe.txt @@ -1 +1 @@ -NC_001422.1 4.572291 +NC_001422.1 4.638507 From 7418bb60ae3ac439c63ce76bc24da423a231d609 Mon Sep 17 00:00:00 2001 From: Marcel Martin Date: Mon, 22 Jan 2024 15:08:14 +0100 Subject: [PATCH 05/26] Vendor zstr instead of fetching it from Git at build time This way, the strobealign sources are more self-contained and it should not be necessary to have internet access at build time. --- CHANGES.md | 3 + CMakeLists.txt | 6 +- ext/README.md | 7 + ext/zstr/CMakeLists.txt | 46 +++ ext/zstr/LICENSE | 21 ++ ext/zstr/README.org | 103 +++++++ ext/zstr/src/strict_fstream.hpp | 237 +++++++++++++++ ext/zstr/src/zstr.hpp | 502 ++++++++++++++++++++++++++++++++ 8 files changed, 920 insertions(+), 5 deletions(-) create mode 100644 ext/zstr/CMakeLists.txt create mode 100644 ext/zstr/LICENSE create mode 100644 ext/zstr/README.org create mode 100644 ext/zstr/src/strict_fstream.hpp create mode 100644 ext/zstr/src/zstr.hpp diff --git a/CHANGES.md b/CHANGES.md index 5e653588..266b5832 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -10,6 +10,9 @@ * #378: Added `-C` option for appending the FASTA or FASTQ comment to SAM output. (Idea and name of the option taken from BWA-MEM.) * #371: Added `--no-PG` option for not outputting the PG SAM header +* Include [ZStr](https://github.com/mateidavid/zstr/) in our own repository + instead of downloading it at build time. This should make it possible to + build strobealign without internet access. ## v0.12.0 (2023-11-23) diff --git a/CMakeLists.txt b/CMakeLists.txt index 19ace7a2..988b1235 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,11 +25,7 @@ endif() message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") add_compile_options(-Wall -Wextra -Werror=maybe-uninitialized) -FetchContent_Declare(ZStrGitRepo - GIT_REPOSITORY "https://github.com/mateidavid/zstr" - GIT_TAG "755da7890ea22478a702e3139092e6c964fab1f5" -) -FetchContent_MakeAvailable(ZStrGitRepo) +add_subdirectory(ext/zstr) # Obtain version from Git or fall back to PROJECT_VERSION if not building # from a Git repository diff --git a/ext/README.md b/ext/README.md index e473fdf5..e8316d26 100644 --- a/ext/README.md +++ b/ext/README.md @@ -47,3 +47,10 @@ License: See ssw/README.md Homepage: https://www.xxhash.com Version: ? License: See xxhash.c + + +## zstr + +Homepage: https://github.com/mateidavid/zstr +Commit used: 755da7890ea22478a702e3139092e6c964fab1f5 +License: See zstr/LICENSE diff --git a/ext/zstr/CMakeLists.txt b/ext/zstr/CMakeLists.txt new file mode 100644 index 00000000..8a015618 --- /dev/null +++ b/ext/zstr/CMakeLists.txt @@ -0,0 +1,46 @@ +cmake_minimum_required(VERSION 3.10 FATAL_ERROR) + +project(zstr LANGUAGES CXX) + +if (${CMAKE_VERSION} VERSION_GREATER_EQUAL 3.12) + cmake_policy(SET CMP0074 NEW) # find_package uses _ROOT variables +endif() + +if(${CMAKE_VERSION} VERSION_LESS 3.13) + message(WARNING + "Interface library targets are not well supported before cmake 3.13 .... 
" + "You may need to add \${ZSTR_INCLUDE_DIRS} to your include directories\n" + "target_include_directories(YourTarget PRIVATE \${ZSTR_INCLUDE_DIRS}) " + ) +endif() + +# -- locate zlib + +find_package(ZLIB 1.2.3 REQUIRED) # defines imported target ZLIB::ZLIB +message(STATUS "zstr - found ZLIB (version: ${ZLIB_VERSION_STRING})") + +# -- add target + +add_library(zstr INTERFACE) +add_library(zstr::zstr ALIAS zstr) + +# -- set target properties + +target_include_directories(zstr INTERFACE "${PROJECT_SOURCE_DIR}/src") +target_link_libraries(zstr INTERFACE ZLIB::ZLIB) +target_compile_features(zstr INTERFACE cxx_std_11) # require c++11 flag + +# -- set cache variables + +# NOTE: these vars are mostly useful to people using cmake < 3.13 +set(ZSTR_INCLUDE_DIRS "${PROJECT_SOURCE_DIR}/src;${ZLIB_INCLUDE_DIRS}" CACHE PATH "" FORCE) +set(ZSTR_LIBRARIES "${ZLIB_LIBRARIES}" CACHE PATH "" FORCE) + +# -- print target summary + +message(STATUS + "zstr - added INTERFACE target 'zstr::zstr' + includes : ${ZSTR_INCLUDE_DIRS} + libraries: ZLIB::ZLIB + features : cxx_std_11" +) diff --git a/ext/zstr/LICENSE b/ext/zstr/LICENSE new file mode 100644 index 00000000..841c7214 --- /dev/null +++ b/ext/zstr/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2015 Matei David, Ontario Institute for Cancer Research + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/ext/zstr/README.org b/ext/zstr/README.org new file mode 100644 index 00000000..dea53589 --- /dev/null +++ b/ext/zstr/README.org @@ -0,0 +1,103 @@ +# -*- mode:org; mode:visual-line; coding:utf-8; -*- + +** A C++ ZLib wrapper + +[[http://travis-ci.org/mateidavid/zstr][http://travis-ci.org/mateidavid/zstr.svg?branch=master]] [[https://tldrlegal.com/license/mit-license][http://img.shields.io/:license-mit-blue.svg]] + +This C++ header-only library enables the use of C++ standard iostreams to access ZLib-compressed streams. + +For input access (decompression), the compression format is auto-detected, and multiple concatenated compressed streams are decompressed seamlessly. + +For output access (compression), the only parameter exposed by this API is the compression level. + +Alternatives to this library include: + +- The original [[http://www.zlib.net/][ZLib]], through its [[http://www.zlib.net/manual.html][C API]]. This does not interact nicely with C++ iostreams. + +- The [[http://www.cs.unc.edu/Research/compgeom/gzstream/][GZStream]] library. This library does not auto-detect input compression, and it cannot wrap streams (only files). 
+ +- The [[http://www.boost.org/doc/libs/release/libs/iostreams/][Boost IOStreams]] library. The library does not auto-detect input compression (by default, though that can be easily implemented with filters), and more importantly, it is not a header-only Boost library. + +- The [[https://github.com/tmaklin/bxzstr][bxzstr]] library, if you want support for BZ2 and/or LZMA as well. + +For an example usage, see [[examples/ztxtpipe.cpp]] and [[examples/zc.cpp]]. + +It is compatible with [[https://github.com/richgel999/miniz][miniz]] in case you don't want to get frustrated with zlib e. g. on Windows. + +**** Input Auto-detection + +For input access, the library seamlessly auto-detects whether the source stream is compressed or not. The following compressed streams are detected: + +- GZip header, when stream starts with =1F 8B=. See [[http://en.wikipedia.org/wiki/Gzip][GZip format]]. + +- ZLib header, when stream starts with =78 01=, =78 9C=, and =78 DA=. See [[http://stackoverflow.com/a/17176881][answer here]]. + +If none of these formats are detected, the library assumes the input is not compressed, and it produces a plain copy of the source stream. + +**** Classes + +The package provides 6 classes for accessing ZLib streams: + +- =zstr::istreambuf= is the core decompression class. This is constructed from an existing =std::streambuf= that contains source data. The =zstr::istreambuf= constructor accepts explicit settings for the internal buffer size (default: 1 MB) and the auto-detection option (default: on). ZLib errors cause exceptions to be thrown. + +- =zstr::ostreambuf= is the core compression class. This is constructed from an existing =std::streambuf= that contains sink data. The =zstr::ostreambuf= constructor accepts explicit settings for the internal buffer size (default: 1 MB) and the compression option (default: ZLib default). ZLib errors cause exceptions to be thrown. + +- =zstr::istream= is a wrapper for a =zstr::istreambuf= that accesses an /external/ =std::streambuf=. It can be constructed from an existing =std::istream= (such as =std::cin=) or =std::streambuf=. + +- =zstr::ostream= is a wrapper for a =zstr::ostreambuf= that accesses an /external/ =std::streambuf=. It can be constructed from an existing =std::ostream= (such as =std::cout=) or =std::streambuf=. + +- =zstr::ifstream= is a wrapper for a =zstr::istreambuf= that accesses an /internal/ =std::ifstream=. This can be used to open a file and read decompressed data from it. + +- =zstr::ofstream= is a wrapper for a =zstr::ostreambuf= that accesses an /internal/ =std::ofstream=. This can be used to open a file and write compressed data to it. + +For all stream objects, the =badbit= of their exception mask is turned on in order to propagate exceptions. + +**** CMake + +There are three simple ways to add zstr to a CMake project. + +Method 1. Add zstr as a subdirectory and link to the =zstr::zstr= target + + #+BEGIN_SRC cmake + add_subdirectory(zstr) # defines INTERFACE target 'zstr::zstr' + + add_executable(YourTarget main.cpp) + target_link_libraries(YourTarget PRIVATE zstr::zstr) + # if using cmake < 3.13 you may also need the following line + # target_include_directories(YourTarget PRIVATE ${ZSTR_INCLUDE_DIRS}) + #+END_SRC + +Method 2. 
Fetch a copy of zstr from an external repository and link to the =zstr::zstr= target + + /NOTE: The FetchContent functions shown here were introduced in CMake 3.14/ + + #+BEGIN_SRC cmake + include(FetchContent) + FetchContent_Declare(ZStrGitRepo + GIT_REPOSITORY "https://github.com/mateidavid/zstr" # can also be a local filesystem path! + GIT_TAG "master" + ) + FetchContent_MakeAvailable(ZStrGitRepo) # defines INTERFACE target 'zstr::zstr' + + add_executable(YourTarget main.cpp) + target_link_libraries(YourTarget PRIVATE zstr::zstr) + #+END_SRC + +Method 3. Add path containing 'zstr.hpp' to your target's include directories + + /NOTE: With this method you're responsible for finding and linking to ZLIB !/ + + #+BEGIN_SRC cmake + find_package(ZLIB REQUIRED) + add_executable(YourTarget main.cpp) + target_link_libraries(YourTarget PRIVATE ZLIB::ZLIB) + target_include_directories(YourTarget PRIVATE /path/to/zstr/src) + #+END_SRC + +**** Requisites + +If you use GCC and want to use the `fs.open()` function, you need to deploy at least GCC version 5.1. + +**** License + +Released under the [[file:LICENSE][MIT license]]. diff --git a/ext/zstr/src/strict_fstream.hpp b/ext/zstr/src/strict_fstream.hpp new file mode 100644 index 00000000..7d03ea66 --- /dev/null +++ b/ext/zstr/src/strict_fstream.hpp @@ -0,0 +1,237 @@ +#pragma once + +#include +#include +#include +#include +#include + +/** + * This namespace defines wrappers for std::ifstream, std::ofstream, and + * std::fstream objects. The wrappers perform the following steps: + * - check the open modes make sense + * - check that the call to open() is successful + * - (for input streams) check that the opened file is peek-able + * - turn on the badbit in the exception mask + */ +namespace strict_fstream +{ + +// Help people out a bit, it seems like this is a common recommenation since +// musl breaks all over the place. +#if defined(__NEED_size_t) && !defined(__MUSL__) +#warning "It seems to be recommended to patch in a define for __MUSL__ if you use musl globally: https://www.openwall.com/lists/musl/2013/02/10/5" +#define __MUSL__ +#endif + +// Workaround for broken musl implementation +// Since musl insists that they are perfectly compatible, ironically enough, +// they don't officially have a __musl__ or similar. But __NEED_size_t is defined in their +// relevant header (and not in working implementations), so we can use that. +#ifdef __MUSL__ +#warning "Working around broken strerror_r() implementation in musl, remove when musl is fixed" +#endif + +// Non-gnu variants of strerror_* don't necessarily null-terminate if +// truncating, so we have to do things manually. +inline std::string trim_to_null(const std::vector &buff) +{ + std::string ret(buff.begin(), buff.end()); + + const std::string::size_type pos = ret.find('\0'); + if (pos == std::string::npos) { + ret += " [...]"; // it has been truncated + } else { + ret.resize(pos); + } + return ret; +} + +/// Overload of error-reporting function, to enable use with VS and non-GNU +/// POSIX libc's +/// Ref: +/// - http://stackoverflow.com/a/901316/717706 +static std::string strerror() +{ + // Can't use std::string since we're pre-C++17 + std::vector buff(256, '\0'); + +#ifdef _WIN32 + // Since strerror_s might set errno itself, we need to store it. 
+ const int err_num = errno; + if (strerror_s(buff.data(), buff.size(), err_num) != 0) { + return trim_to_null(buff); + } else { + return "Unknown error (" + std::to_string(err_num) + ")"; + } +#elif ((_POSIX_C_SOURCE >= 200112L || _XOPEN_SOURCE >= 600 || defined(__APPLE__) || defined(__FreeBSD__)) && ! _GNU_SOURCE) || defined(__MUSL__) +// XSI-compliant strerror_r() + const int err_num = errno; // See above + if (strerror_r(err_num, buff.data(), buff.size()) == 0) { + return trim_to_null(buff); + } else { + return "Unknown error (" + std::to_string(err_num) + ")"; + } +#else +// GNU-specific strerror_r() + char * p = strerror_r(errno, &buff[0], buff.size()); + return std::string(p, std::strlen(p)); +#endif +} + +/// Exception class thrown by failed operations. +class Exception + : public std::exception +{ +public: + Exception(const std::string& msg) : _msg(msg) {} + const char * what() const noexcept { return _msg.c_str(); } +private: + std::string _msg; +}; // class Exception + +namespace detail +{ + +struct static_method_holder +{ + static std::string mode_to_string(std::ios_base::openmode mode) + { + static const int n_modes = 6; + static const std::ios_base::openmode mode_val_v[n_modes] = + { + std::ios_base::in, + std::ios_base::out, + std::ios_base::app, + std::ios_base::ate, + std::ios_base::trunc, + std::ios_base::binary + }; + + static const char * mode_name_v[n_modes] = + { + "in", + "out", + "app", + "ate", + "trunc", + "binary" + }; + std::string res; + for (int i = 0; i < n_modes; ++i) + { + if (mode & mode_val_v[i]) + { + res += (! res.empty()? "|" : ""); + res += mode_name_v[i]; + } + } + if (res.empty()) res = "none"; + return res; + } + static void check_mode(const std::string& filename, std::ios_base::openmode mode) + { + if ((mode & std::ios_base::trunc) && ! (mode & std::ios_base::out)) + { + throw Exception(std::string("strict_fstream: open('") + filename + "'): mode error: trunc and not out"); + } + else if ((mode & std::ios_base::app) && ! 
(mode & std::ios_base::out)) + { + throw Exception(std::string("strict_fstream: open('") + filename + "'): mode error: app and not out"); + } + else if ((mode & std::ios_base::trunc) && (mode & std::ios_base::app)) + { + throw Exception(std::string("strict_fstream: open('") + filename + "'): mode error: trunc and app"); + } + } + static void check_open(std::ios * s_p, const std::string& filename, std::ios_base::openmode mode) + { + if (s_p->fail()) + { + throw Exception(std::string("strict_fstream: open('") + + filename + "'," + mode_to_string(mode) + "): open failed: " + + strerror()); + } + } + static void check_peek(std::istream * is_p, const std::string& filename, std::ios_base::openmode mode) + { + bool peek_failed = true; + try + { + is_p->peek(); + peek_failed = is_p->fail(); + } + catch (const std::ios_base::failure &) {} + if (peek_failed) + { + throw Exception(std::string("strict_fstream: open('") + + filename + "'," + mode_to_string(mode) + "): peek failed: " + + strerror()); + } + is_p->clear(); + } +}; // struct static_method_holder + +} // namespace detail + +class ifstream + : public std::ifstream +{ +public: + ifstream() = default; + ifstream(const std::string& filename, std::ios_base::openmode mode = std::ios_base::in) + { + open(filename, mode); + } + void open(const std::string& filename, std::ios_base::openmode mode = std::ios_base::in) + { + mode |= std::ios_base::in; + exceptions(std::ios_base::badbit); + detail::static_method_holder::check_mode(filename, mode); + std::ifstream::open(filename, mode); + detail::static_method_holder::check_open(this, filename, mode); + detail::static_method_holder::check_peek(this, filename, mode); + } +}; // class ifstream + +class ofstream + : public std::ofstream +{ +public: + ofstream() = default; + ofstream(const std::string& filename, std::ios_base::openmode mode = std::ios_base::out) + { + open(filename, mode); + } + void open(const std::string& filename, std::ios_base::openmode mode = std::ios_base::out) + { + mode |= std::ios_base::out; + exceptions(std::ios_base::badbit); + detail::static_method_holder::check_mode(filename, mode); + std::ofstream::open(filename, mode); + detail::static_method_holder::check_open(this, filename, mode); + } +}; // class ofstream + +class fstream + : public std::fstream +{ +public: + fstream() = default; + fstream(const std::string& filename, std::ios_base::openmode mode = std::ios_base::in) + { + open(filename, mode); + } + void open(const std::string& filename, std::ios_base::openmode mode = std::ios_base::in) + { + if (! 
(mode & std::ios_base::out)) mode |= std::ios_base::in; + exceptions(std::ios_base::badbit); + detail::static_method_holder::check_mode(filename, mode); + std::fstream::open(filename, mode); + detail::static_method_holder::check_open(this, filename, mode); + detail::static_method_holder::check_peek(this, filename, mode); + } +}; // class fstream + +} // namespace strict_fstream + diff --git a/ext/zstr/src/zstr.hpp b/ext/zstr/src/zstr.hpp new file mode 100644 index 00000000..bd330ea1 --- /dev/null +++ b/ext/zstr/src/zstr.hpp @@ -0,0 +1,502 @@ +//--------------------------------------------------------- +// Copyright 2015 Ontario Institute for Cancer Research +// Written by Matei David (matei@cs.toronto.edu) +//--------------------------------------------------------- + +// Reference: +// http://stackoverflow.com/questions/14086417/how-to-write-custom-input-stream-in-c + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "strict_fstream.hpp" + +#if defined(__GNUC__) && !defined(__clang__) +#if (__GNUC__ > 5) || (__GNUC__ == 5 && __GNUC_MINOR__>0) +#define CAN_MOVE_IOSTREAM +#endif +#else +#define CAN_MOVE_IOSTREAM +#endif + +namespace zstr +{ + +static const std::size_t default_buff_size = static_cast(1 << 20); + +/// Exception class thrown by failed zlib operations. +class Exception + : public std::ios_base::failure +{ +public: + static std::string error_to_message(z_stream * zstrm_p, int ret) + { + std::string msg = "zlib: "; + switch (ret) + { + case Z_STREAM_ERROR: + msg += "Z_STREAM_ERROR: "; + break; + case Z_DATA_ERROR: + msg += "Z_DATA_ERROR: "; + break; + case Z_MEM_ERROR: + msg += "Z_MEM_ERROR: "; + break; + case Z_VERSION_ERROR: + msg += "Z_VERSION_ERROR: "; + break; + case Z_BUF_ERROR: + msg += "Z_BUF_ERROR: "; + break; + default: + std::ostringstream oss; + oss << ret; + msg += "[" + oss.str() + "]: "; + break; + } + if (zstrm_p->msg) { + msg += zstrm_p->msg; + } + msg += " (" + "next_in: " + + std::to_string(uintptr_t(zstrm_p->next_in)) + + ", avail_in: " + + std::to_string(uintptr_t(zstrm_p->avail_in)) + + ", next_out: " + + std::to_string(uintptr_t(zstrm_p->next_out)) + + ", avail_out: " + + std::to_string(uintptr_t(zstrm_p->avail_out)) + + ")"; + return msg; + } + + Exception(z_stream * zstrm_p, int ret) + : std::ios_base::failure(error_to_message(zstrm_p, ret)) + { + } +}; // class Exception + +namespace detail +{ + +class z_stream_wrapper + : public z_stream +{ +public: + z_stream_wrapper(bool _is_input, int _level, int _window_bits) + : is_input(_is_input) + { + this->zalloc = nullptr;//Z_NULL + this->zfree = nullptr;//Z_NULL + this->opaque = nullptr;//Z_NULL + int ret; + if (is_input) + { + this->avail_in = 0; + this->next_in = nullptr;//Z_NULL + ret = inflateInit2(this, _window_bits ? _window_bits : 15+32); + } + else + { + ret = deflateInit2(this, _level, Z_DEFLATED, _window_bits ? 
_window_bits : 15+16, 8, Z_DEFAULT_STRATEGY); + } + if (ret != Z_OK) throw Exception(this, ret); + } + ~z_stream_wrapper() + { + if (is_input) + { + inflateEnd(this); + } + else + { + deflateEnd(this); + } + } +private: + bool is_input; +}; // class z_stream_wrapper + +} // namespace detail + +class istreambuf + : public std::streambuf +{ +public: + istreambuf(std::streambuf * _sbuf_p, + std::size_t _buff_size = default_buff_size, bool _auto_detect = true, int _window_bits = 0) + : sbuf_p(_sbuf_p), + in_buff(), + in_buff_start(nullptr), + in_buff_end(nullptr), + out_buff(), + zstrm_p(nullptr), + buff_size(_buff_size), + auto_detect(_auto_detect), + auto_detect_run(false), + is_text(false), + window_bits(_window_bits) + { + assert(sbuf_p); + in_buff = std::unique_ptr(new char[buff_size]); + in_buff_start = in_buff.get(); + in_buff_end = in_buff.get(); + out_buff = std::unique_ptr(new char[buff_size]); + setg(out_buff.get(), out_buff.get(), out_buff.get()); + } + + istreambuf(const istreambuf &) = delete; + istreambuf & operator = (const istreambuf &) = delete; + + pos_type seekoff(off_type off, std::ios_base::seekdir dir, + std::ios_base::openmode which) override + { + if (off != 0 || dir != std::ios_base::cur) { + return std::streambuf::seekoff(off, dir, which); + } + + if (!zstrm_p) { + return 0; + } + + return static_cast(zstrm_p->total_out - static_cast(in_avail())); + } + + std::streambuf::int_type underflow() override + { + if (this->gptr() == this->egptr()) + { + // pointers for free region in output buffer + char * out_buff_free_start = out_buff.get(); + int tries = 0; + do + { + if (++tries > 1000) { + throw std::ios_base::failure("Failed to fill buffer after 1000 tries"); + } + + // read more input if none available + if (in_buff_start == in_buff_end) + { + // empty input buffer: refill from the start + in_buff_start = in_buff.get(); + std::streamsize sz = sbuf_p->sgetn(in_buff.get(), static_cast(buff_size)); + in_buff_end = in_buff_start + sz; + if (in_buff_end == in_buff_start) break; // end of input + } + // auto detect if the stream contains text or deflate data + if (auto_detect && ! auto_detect_run) + { + auto_detect_run = true; + unsigned char b0 = *reinterpret_cast< unsigned char * >(in_buff_start); + unsigned char b1 = *reinterpret_cast< unsigned char * >(in_buff_start + 1); + // Ref: + // http://en.wikipedia.org/wiki/Gzip + // http://stackoverflow.com/questions/9050260/what-does-a-zlib-header-look-like + is_text = ! (in_buff_start + 2 <= in_buff_end + && ((b0 == 0x1F && b1 == 0x8B) // gzip header + || (b0 == 0x78 && (b1 == 0x01 // zlib header + || b1 == 0x9C + || b1 == 0xDA)))); + } + if (is_text) + { + // simply swap in_buff and out_buff, and adjust pointers + assert(in_buff_start == in_buff.get()); + std::swap(in_buff, out_buff); + out_buff_free_start = in_buff_end; + in_buff_start = in_buff.get(); + in_buff_end = in_buff.get(); + } + else + { + // run inflate() on input + if (! 
zstrm_p) zstrm_p = std::unique_ptr(new detail::z_stream_wrapper(true, Z_DEFAULT_COMPRESSION, window_bits)); + zstrm_p->next_in = reinterpret_cast< decltype(zstrm_p->next_in) >(in_buff_start); + zstrm_p->avail_in = uint32_t(in_buff_end - in_buff_start); + zstrm_p->next_out = reinterpret_cast< decltype(zstrm_p->next_out) >(out_buff_free_start); + zstrm_p->avail_out = uint32_t((out_buff.get() + buff_size) - out_buff_free_start); + int ret = inflate(zstrm_p.get(), Z_NO_FLUSH); + // process return code + if (ret != Z_OK && ret != Z_STREAM_END) throw Exception(zstrm_p.get(), ret); + // update in&out pointers following inflate() + in_buff_start = reinterpret_cast< decltype(in_buff_start) >(zstrm_p->next_in); + in_buff_end = in_buff_start + zstrm_p->avail_in; + out_buff_free_start = reinterpret_cast< decltype(out_buff_free_start) >(zstrm_p->next_out); + assert(out_buff_free_start + zstrm_p->avail_out == out_buff.get() + buff_size); + + if (ret == Z_STREAM_END) { + // if stream ended, deallocate inflator + zstrm_p.reset(); + } + } + } while (out_buff_free_start == out_buff.get()); + // 2 exit conditions: + // - end of input: there might or might not be output available + // - out_buff_free_start != out_buff: output available + this->setg(out_buff.get(), out_buff.get(), out_buff_free_start); + } + return this->gptr() == this->egptr() + ? traits_type::eof() + : traits_type::to_int_type(*this->gptr()); + } +private: + std::streambuf * sbuf_p; + std::unique_ptr in_buff; + char * in_buff_start; + char * in_buff_end; + std::unique_ptr out_buff; + std::unique_ptr zstrm_p; + std::size_t buff_size; + bool auto_detect; + bool auto_detect_run; + bool is_text; + int window_bits; + +}; // class istreambuf + +class ostreambuf + : public std::streambuf +{ +public: + ostreambuf(std::streambuf * _sbuf_p, + std::size_t _buff_size = default_buff_size, int _level = Z_DEFAULT_COMPRESSION, int _window_bits = 0) + : sbuf_p(_sbuf_p), + in_buff(), + out_buff(), + zstrm_p(new detail::z_stream_wrapper(false, _level, _window_bits)), + buff_size(_buff_size) + { + assert(sbuf_p); + in_buff = std::unique_ptr(new char[buff_size]); + out_buff = std::unique_ptr(new char[buff_size]); + setp(in_buff.get(), in_buff.get() + buff_size); + } + + ostreambuf(const ostreambuf &) = delete; + ostreambuf & operator = (const ostreambuf &) = delete; + + int deflate_loop(int flush) + { + while (true) + { + zstrm_p->next_out = reinterpret_cast< decltype(zstrm_p->next_out) >(out_buff.get()); + zstrm_p->avail_out = uint32_t(buff_size); + int ret = deflate(zstrm_p.get(), flush); + if (ret != Z_OK && ret != Z_STREAM_END && ret != Z_BUF_ERROR) { + failed = true; + throw Exception(zstrm_p.get(), ret); + } + std::streamsize sz = sbuf_p->sputn(out_buff.get(), reinterpret_cast< decltype(out_buff.get()) >(zstrm_p->next_out) - out_buff.get()); + if (sz != reinterpret_cast< decltype(out_buff.get()) >(zstrm_p->next_out) - out_buff.get()) + { + // there was an error in the sink stream + return -1; + } + if (ret == Z_STREAM_END || ret == Z_BUF_ERROR || sz == 0) + { + break; + } + } + return 0; + } + + virtual ~ostreambuf() + { + // flush the zlib stream + // + // NOTE: Errors here (sync() return value not 0) are ignored, because we + // cannot throw in a destructor. This mirrors the behaviour of + // std::basic_filebuf::~basic_filebuf(). To see an exception on error, + // close the ofstream with an explicit call to close(), and do not rely + // on the implicit call in the destructor. + // + if (!failed) try { + sync(); + } catch (...) 
{} + } + std::streambuf::int_type overflow(std::streambuf::int_type c = traits_type::eof()) override + { + zstrm_p->next_in = reinterpret_cast< decltype(zstrm_p->next_in) >(pbase()); + zstrm_p->avail_in = uint32_t(pptr() - pbase()); + while (zstrm_p->avail_in > 0) + { + int r = deflate_loop(Z_NO_FLUSH); + if (r != 0) + { + setp(nullptr, nullptr); + return traits_type::eof(); + } + } + setp(in_buff.get(), in_buff.get() + buff_size); + return traits_type::eq_int_type(c, traits_type::eof()) ? traits_type::eof() : sputc(char_type(c)); + } + int sync() override + { + // first, call overflow to clear in_buff + overflow(); + if (! pptr()) return -1; + // then, call deflate asking to finish the zlib stream + zstrm_p->next_in = nullptr; + zstrm_p->avail_in = 0; + if (deflate_loop(Z_FINISH) != 0) return -1; + deflateReset(zstrm_p.get()); + return 0; + } +private: + std::streambuf * sbuf_p = nullptr; + std::unique_ptr in_buff; + std::unique_ptr out_buff; + std::unique_ptr zstrm_p; + std::size_t buff_size; + bool failed = false; + +}; // class ostreambuf + +class istream + : public std::istream +{ +public: + istream(std::istream & is, + std::size_t _buff_size = default_buff_size, bool _auto_detect = true, int _window_bits = 0) + : std::istream(new istreambuf(is.rdbuf(), _buff_size, _auto_detect, _window_bits)) + { + exceptions(std::ios_base::badbit); + } + explicit istream(std::streambuf * sbuf_p) + : std::istream(new istreambuf(sbuf_p)) + { + exceptions(std::ios_base::badbit); + } + virtual ~istream() + { + delete rdbuf(); + } +}; // class istream + +class ostream + : public std::ostream +{ +public: + ostream(std::ostream & os, + std::size_t _buff_size = default_buff_size, int _level = Z_DEFAULT_COMPRESSION, int _window_bits = 0) + : std::ostream(new ostreambuf(os.rdbuf(), _buff_size, _level, _window_bits)) + { + exceptions(std::ios_base::badbit); + } + explicit ostream(std::streambuf * sbuf_p) + : std::ostream(new ostreambuf(sbuf_p)) + { + exceptions(std::ios_base::badbit); + } + virtual ~ostream() + { + delete rdbuf(); + } +}; // class ostream + +namespace detail +{ + +template < typename FStream_Type > +struct strict_fstream_holder +{ + strict_fstream_holder(const std::string& filename, std::ios_base::openmode mode = std::ios_base::in) + : _fs(filename, mode) + {} + strict_fstream_holder() = default; + FStream_Type _fs {}; +}; // class strict_fstream_holder + +} // namespace detail + +class ifstream + : private detail::strict_fstream_holder< strict_fstream::ifstream >, + public std::istream +{ +public: + explicit ifstream(const std::string filename, std::ios_base::openmode mode = std::ios_base::in, size_t buff_size = default_buff_size) + : detail::strict_fstream_holder< strict_fstream::ifstream >(filename, mode), + std::istream(new istreambuf(_fs.rdbuf(), buff_size)) + { + exceptions(std::ios_base::badbit); + } + explicit ifstream(): detail::strict_fstream_holder< strict_fstream::ifstream >(), std::istream(new istreambuf(_fs.rdbuf())){} + void close() { + _fs.close(); + } + #ifdef CAN_MOVE_IOSTREAM + void open(const std::string filename, std::ios_base::openmode mode = std::ios_base::in) { + _fs.open(filename, mode); + std::istream::operator=(std::istream(new istreambuf(_fs.rdbuf()))); + } + #endif + bool is_open() const { + return _fs.is_open(); + } + virtual ~ifstream() + { + if (_fs.is_open()) close(); + if (rdbuf()) delete rdbuf(); + } + + /// Return the position within the compressed file (wrapped filestream) + std::streampos compressed_tellg() + { + return _fs.tellg(); + } +}; // class 
ifstream + +class ofstream + : private detail::strict_fstream_holder< strict_fstream::ofstream >, + public std::ostream +{ +public: + explicit ofstream(const std::string filename, std::ios_base::openmode mode = std::ios_base::out, + int level = Z_DEFAULT_COMPRESSION, size_t buff_size = default_buff_size) + : detail::strict_fstream_holder< strict_fstream::ofstream >(filename, mode | std::ios_base::binary), + std::ostream(new ostreambuf(_fs.rdbuf(), buff_size, level)) + { + exceptions(std::ios_base::badbit); + } + explicit ofstream(): detail::strict_fstream_holder< strict_fstream::ofstream >(), std::ostream(new ostreambuf(_fs.rdbuf())){} + void close() { + std::ostream::flush(); + _fs.close(); + } + #ifdef CAN_MOVE_IOSTREAM + void open(const std::string filename, std::ios_base::openmode mode = std::ios_base::out, int level = Z_DEFAULT_COMPRESSION) { + flush(); + _fs.open(filename, mode | std::ios_base::binary); + std::ostream::operator=(std::ostream(new ostreambuf(_fs.rdbuf(), default_buff_size, level))); + } + #endif + bool is_open() const { + return _fs.is_open(); + } + ofstream& flush() { + std::ostream::flush(); + _fs.flush(); + return *this; + } + virtual ~ofstream() + { + if (_fs.is_open()) close(); + if (rdbuf()) delete rdbuf(); + } + + // Return the position within the compressed file (wrapped filestream) + std::streampos compressed_tellp() + { + return _fs.tellp(); + } +}; // class ofstream + +} // namespace zstr + From e9c453ccef9e09e30bee62ceb8585af4f1dfb046 Mon Sep 17 00:00:00 2001 From: Marcel Martin Date: Thu, 7 Dec 2023 10:49:15 +0100 Subject: [PATCH 06/26] Use poolstl to sort randstrobes in parallel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit See https://github.com/alugowski/poolSTL Sorting the randstrobes is currently a bottleneck in index generation as it does not run in parallel. This is an attempt to parallelize it. poolstl’s sort uses regular std::sort under the hood. We currently use pdqsort_branchless, which is about twice as fast as std::sort, so parallel sorting breaks even with pdqsort_branchless at about 2-3 threads. It gets faster with more threads, but not as much as one would perhaps expect. Here are the sorting runtimes for CHM13: - 31 s with pdqsort_branchless - 59 s with std::sort - 34 s with parallel sort, 2 threads - 24 s with parallel sort, 4 threads - 23 s with parallel sort, 8 threads Another issue is that sorting is no longer in place, so memory usage goes up by a couple of gigabytes, which is another reason for me not to make this change. --- ext/poolstl/poolstl.hpp | 1697 +++++++++++++++++++++++++++++++++++++++ src/index.cpp | 12 +- src/index.hpp | 2 +- 3 files changed, 1707 insertions(+), 4 deletions(-) create mode 100644 ext/poolstl/poolstl.hpp diff --git a/ext/poolstl/poolstl.hpp b/ext/poolstl/poolstl.hpp new file mode 100644 index 00000000..ea79146e --- /dev/null +++ b/ext/poolstl/poolstl.hpp @@ -0,0 +1,1697 @@ +// SPDX-License-Identifier: BSD-2-Clause OR MIT OR BSL-1.0 +/** + * @brief Thread pool-based implementation of parallel standard library algorithms. Single-file version. + * @see https://github.com/alugowski/poolSTL + * @author Adam Lugowski + * @copyright Copyright (C) 2023 Adam Lugowski. 
+ * Licensed under any of the following open-source licenses: + * BSD-2-Clause license, MIT license, Boost Software License 1.0 + * + * + * BSD-2-Clause license: + * + * Copyright (C) 2023 Adam Lugowski + * Redistribution and use in source and binary forms, with or without modification, + * are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF + * THE POSSIBILITY OF SUCH DAMAGE. + * + * + * + * MIT License: + * + * Copyright (c) 2023 Adam Lugowski + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + * + * + * Boost Software License 1.0: + * + * Permission is hereby granted, free of charge, to any person or organization + * obtaining a copy of the software and accompanying documentation covered by + * this license (the "Software") to use, reproduce, display, distribute, execute, + * and transmit the Software, and to prepare derivative works of the Software, + * and to permit third-parties to whom the Software is furnished to do so, + * all subject to the following: + * + * The copyright notices in the Software and this entire statement, including + * the above license grant, this restriction and the following disclaimer, must + * be included in all copies of the Software, in whole or in part, and all + * derivative works of the Software, unless such copies or derivative works + * are solely in the form of machine-executable object code generated by a + * source language processor. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT + * SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE + * FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, + * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#ifndef POOLSTL_HPP +#define POOLSTL_HPP + + +#ifndef POOLSTL_EXECUTION_HPP +#define POOLSTL_EXECUTION_HPP + +#include +#include +#include + + +#ifndef AL_TASK_THREAD_POOL_HPP +#define AL_TASK_THREAD_POOL_HPP + +// Version macros. +#define TASK_THREAD_POOL_VERSION_MAJOR 1 +#define TASK_THREAD_POOL_VERSION_MINOR 0 +#define TASK_THREAD_POOL_VERSION_PATCH 9 + +#include +#include +#include +#include +#include +#include +#include + +// MSVC does not correctly set the __cplusplus macro by default, so we must read it from _MSVC_LANG +// See https://devblogs.microsoft.com/cppblog/msvc-now-correctly-reports-__cplusplus/ +#if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) +#define TTP_CXX17 1 +#else +#define TTP_CXX17 0 +#endif + +#if TTP_CXX17 +#define TTP_NODISCARD [[nodiscard]] +#else +#define TTP_NODISCARD +#endif + +namespace task_thread_pool { + +#if !TTP_CXX17 + /** + * A reimplementation of std::decay_t, which is only available since C++14. + */ + template + using decay_t = typename std::decay::type; +#endif + + /** + * A fast and lightweight thread pool that uses C++11 threads. + */ + class task_thread_pool { + public: + /** + * Create a task_thread_pool and start worker threads. + * + * @param num_threads Number of worker threads. If 0 then number of threads is equal to the + * number of physical cores on the machine, as given by std::thread::hardware_concurrency(). + */ + explicit task_thread_pool(unsigned int num_threads = 0) { + if (num_threads < 1) { + num_threads = std::thread::hardware_concurrency(); + if (num_threads < 1) { num_threads = 1; } + } + start_threads(num_threads); + } + + /** + * Finish all tasks left in the queue then shut down worker threads. + * If the pool is currently paused then it is resumed. + */ + ~task_thread_pool() { + unpause(); + wait_for_queued_tasks(); + stop_all_threads(); + } + + /** + * Drop all tasks that have been submitted but not yet started by a worker. + * + * Tasks already in progress continue executing. 
+ */ + void clear_task_queue() { + const std::lock_guard tasks_lock(task_mutex); + tasks = {}; + } + + /** + * Get number of enqueued tasks. + * + * @return Number of tasks that have been enqueued but not yet started. + */ + TTP_NODISCARD size_t get_num_queued_tasks() const { + const std::lock_guard tasks_lock(task_mutex); + return tasks.size(); + } + + /** + * Get number of in-progress tasks. + * + * @return Approximate number of tasks currently being processed by worker threads. + */ + TTP_NODISCARD size_t get_num_running_tasks() const { + const std::lock_guard tasks_lock(task_mutex); + return num_inflight_tasks; + } + + /** + * Get total number of tasks in the pool. + * + * @return Approximate number of tasks both enqueued and running. + */ + TTP_NODISCARD size_t get_num_tasks() const { + const std::lock_guard tasks_lock(task_mutex); + return tasks.size() + num_inflight_tasks; + } + + /** + * Get number of worker threads. + * + * @return Number of worker threads. + */ + TTP_NODISCARD unsigned int get_num_threads() const { + const std::lock_guard threads_lock(thread_mutex); + return static_cast(threads.size()); + } + + /** + * Set number of worker threads. Will start or stop worker threads as necessary. + * + * @param num_threads Number of worker threads. If 0 then number of threads is equal to the + * number of physical cores on the machine, as given by std::thread::hardware_concurrency(). + * @return Previous number of worker threads. + */ + unsigned int set_num_threads(unsigned int num_threads) { + const std::lock_guard threads_lock(thread_mutex); + unsigned int previous_num_threads = get_num_threads(); + + if (num_threads < 1) { + num_threads = std::thread::hardware_concurrency(); + if (num_threads < 1) { num_threads = 1; } + } + + if (previous_num_threads <= num_threads) { + // expanding the thread pool + start_threads(num_threads - previous_num_threads); + } else { + // contracting the thread pool + stop_all_threads(); + { + const std::lock_guard tasks_lock(task_mutex); + pool_running = true; + } + start_threads(num_threads); + } + + return previous_num_threads; + } + + /** + * Stop executing queued tasks. Use `unpause()` to resume. Note: Destroying the pool will implicitly unpause. + * + * Any in-progress tasks continue executing. + */ + void pause() { + const std::lock_guard tasks_lock(task_mutex); + pool_paused = true; + } + + /** + * Resume executing queued tasks. + */ + void unpause() { + const std::lock_guard tasks_lock(task_mutex); + pool_paused = false; + task_cv.notify_all(); + } + + /** + * Check whether the pool is paused. + * + * @return true if pause() has been called without an intervening unpause(). + */ + TTP_NODISCARD bool is_paused() const { + const std::lock_guard tasks_lock(task_mutex); + return pool_paused; + } + + /** + * Submit a Callable for the pool to execute and return a std::future. + * + * @param func The Callable to execute. Can be a function, a lambda, std::packaged_task, std::function, etc. + * @param args Arguments for func. Optional. + * @return std::future that can be used to get func's return value or thrown exception. + */ + template , std::decay_t...> +#else + typename R = typename std::result_of(decay_t...)>::type +#endif + > + TTP_NODISCARD std::future submit(F&& func, A&&... args) { + std::shared_ptr> ptask = + std::make_shared>(std::bind(std::forward(func), std::forward(args)...)); + submit_detach([ptask] { (*ptask)(); }); + return ptask->get_future(); + } + + /** + * Submit a zero-argument Callable for the pool to execute. 
+ * + * @param func The Callable to execute. Can be a function, a lambda, std::packaged_task, std::function, etc. + */ + template + void submit_detach(F&& func) { + const std::lock_guard tasks_lock(task_mutex); + tasks.emplace(std::forward(func)); + task_cv.notify_one(); + } + + /** + * Submit a Callable with arguments for the pool to execute. + * + * @param func The Callable to execute. Can be a function, a lambda, std::packaged_task, std::function, etc. + */ + template + void submit_detach(F&& func, A&&... args) { + const std::lock_guard tasks_lock(task_mutex); + tasks.emplace(std::bind(std::forward(func), std::forward(args)...)); + task_cv.notify_one(); + } + + /** + * Block until the task queue is empty. Some tasks may be in-progress when this method returns. + */ + void wait_for_queued_tasks() { + std::unique_lock tasks_lock(task_mutex); + notify_task_finish = true; + task_finished_cv.wait(tasks_lock, [&] { return tasks.empty(); }); + notify_task_finish = false; + } + + /** + * Block until all tasks have finished. + */ + void wait_for_tasks() { + std::unique_lock tasks_lock(task_mutex); + notify_task_finish = true; + task_finished_cv.wait(tasks_lock, [&] { return tasks.empty() && num_inflight_tasks == 0; }); + notify_task_finish = false; + } + + protected: + + /** + * Main function for worker threads. + */ + void worker_main() { + bool finished_task = false; + + while (true) { + std::unique_lock tasks_lock(task_mutex); + + if (finished_task) { + --num_inflight_tasks; + if (notify_task_finish) { + task_finished_cv.notify_all(); + } + } + + task_cv.wait(tasks_lock, [&]() { return !pool_running || (!pool_paused && !tasks.empty()); }); + + if (!pool_running) { + break; + } + + // Must mean that (!pool_paused && !tasks.empty()) is true + + std::packaged_task task{std::move(tasks.front())}; + tasks.pop(); + ++num_inflight_tasks; + tasks_lock.unlock(); + + try { + task(); + } catch (...) { + // std::packaged_task::operator() may throw in some error conditions, such as if the task + // had already been run. Nothing that the pool can do anything about. + } + + finished_task = true; + } + } + + /** + * Start worker threads. + * + * @param num_threads How many threads to start. + */ + void start_threads(const unsigned int num_threads) { + const std::lock_guard threads_lock(thread_mutex); + + for (unsigned int i = 0; i < num_threads; ++i) { + threads.emplace_back(&task_thread_pool::worker_main, this); + } + } + + /** + * Stop, join, and destroy all worker threads. + */ + void stop_all_threads() { + const std::lock_guard threads_lock(thread_mutex); + + { + const std::lock_guard tasks_lock(task_mutex); + pool_running = false; + task_cv.notify_all(); + } + + for (auto& thread : threads) { + if (thread.joinable()) { + thread.join(); + } + } + threads.clear(); + } + + /** + * The worker threads. + * + * Access protected by thread_mutex + */ + std::vector threads; + + /** + * A mutex for methods that start/stop threads. + */ + mutable std::recursive_mutex thread_mutex; + + /** + * The task queue. + * + * Access protected by task_mutex. + */ + std::queue> tasks = {}; + + /** + * A mutex for all variables related to tasks. + */ + mutable std::mutex task_mutex; + + /** + * Used to notify changes to the task queue, such as a new task added, pause/unpause, etc. + */ + std::condition_variable task_cv; + + /** + * Used to notify of finished tasks. + */ + std::condition_variable task_finished_cv; + + /** + * A signal for worker threads that the pool is either running or shutting down. 
+ * + * Access protected by task_mutex. + */ + bool pool_running = true; + + /** + * A signal for worker threads to not pull new tasks from the queue. + * + * Access protected by task_mutex. + */ + bool pool_paused = false; + + /** + * A signal for worker threads that they should notify task_finished_cv when they finish a task. + * + * Access protected by task_mutex. + */ + bool notify_task_finish = false; + + /** + * A counter of the number of tasks in-progress by worker threads. + * Incremented when a task is popped off the task queue and decremented when that task is complete. + * + * Access protected by task_mutex. + */ + int num_inflight_tasks = 0; + }; +} + +// clean up +#undef TTP_NODISCARD +#undef TTP_CXX17 + +#endif + +#ifndef POOLSTL_INTERNAL_UTILS_HPP +#define POOLSTL_INTERNAL_UTILS_HPP + +// Version macros. +#define POOLSTL_VERSION_MAJOR 0 +#define POOLSTL_VERSION_MINOR 3 +#define POOLSTL_VERSION_PATCH 1 + +#include +#include + +#if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) +#define POOLSTL_HAVE_CXX17 1 +#define POOLSTL_NO_DISCARD [[nodiscard]] +#else +#define POOLSTL_HAVE_CXX17 0 +#define POOLSTL_NO_DISCARD +#endif + +#if POOLSTL_HAVE_CXX17 && (!defined(_GLIBCXX_RELEASE) || _GLIBCXX_RELEASE >= 9) +#define POOLSTL_HAVE_CXX17_LIB 1 +#else +#define POOLSTL_HAVE_CXX17_LIB 0 +#endif + +#if __cplusplus >= 201402L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201402L) +#define POOLSTL_HAVE_CXX14 1 +#else +#define POOLSTL_HAVE_CXX14 0 +#endif + +namespace poolstl { + namespace internal { + + inline constexpr std::size_t get_chunk_size(std::size_t num_steps, unsigned int num_threads) { + return (num_steps / num_threads) + ((num_steps % num_threads) > 0 ? 1 : 0); + } + + template + constexpr typename std::iterator_traits::difference_type + get_chunk_size(Iterator first, Iterator last, unsigned int num_threads) { + using diff_t = typename std::iterator_traits::difference_type; + return static_cast(get_chunk_size((std::size_t)std::distance(first, last), num_threads)); + } + + template + constexpr typename std::iterator_traits::difference_type + get_iter_chunk_size(const Iterator& iter, const Iterator& last, + typename std::iterator_traits::difference_type chunk_size) { + return std::min(chunk_size, std::distance(iter, last)); + } + + template + Iterator advanced(Iterator iter, typename std::iterator_traits::difference_type offset) { + Iterator ret = iter; + std::advance(ret, offset); + return ret; + } + + /** + * An iterator wrapper that calls std::future<>::get(). 
+ * @tparam Iterator + */ + template + class getting_iter : public Iterator { + public: + using value_type = decltype((*std::declval()).get()); + using difference_type = typename std::iterator_traits::difference_type; + using pointer = value_type*; + using reference = value_type&; + explicit getting_iter(Iterator iter) : iter(iter) {} + + getting_iter operator++() { ++iter; return *this; } + getting_iter operator++(int) { getting_iter ret(*this); ++iter; return ret; } + + value_type operator*() { return (*iter).get(); } + value_type operator[](difference_type offset) { return iter[offset].get(); } + + bool operator==(const getting_iter &other) const { return iter == other.iter; } + bool operator!=(const getting_iter &other) const { return iter != other.iter; } + + protected: + Iterator iter; + }; + + template + getting_iter get_wrap(Iterator iter) { + return getting_iter(iter); + } + + template + void get_futures(Container& futures) { + for (auto &future: futures) { + future.get(); + } + } + + /* + * Some methods are only available with C++17 and up. Reimplement on older standards. + */ +#if POOLSTL_HAVE_CXX17_LIB + namespace cpp17 = std; +#else + namespace cpp17 { + + // std::reduce + + template + Tp reduce(InputIt first, InputIt last, Tp init, BinOp b) { + for (; first != last; ++first) + init = b(init, *first); + return init; + } + + template + typename std::iterator_traits::value_type reduce(InputIt first, InputIt last) { + return reduce(first, last, + typename std::iterator_traits::value_type{}, + std::plus::value_type>()); + } + + // std::transform + + template + OutputIt transform(InputIt first1, InputIt last1, OutputIt d_first, + UnaryOperation unary_op) { + while (first1 != last1) { + *d_first++ = unary_op(*first1++); + } + + return d_first; + } + + template + OutputIt transform(InputIt1 first1, InputIt1 last1, + InputIt2 first2, OutputIt d_first, + BinaryOperation binary_op) { + while (first1 != last1) { + *d_first++ = binary_op(*first1++, *first2++); + } + + return d_first; + } + } +#endif + } +} + +#endif + +#if POOLSTL_HAVE_CXX17 +#include +#endif + +namespace poolstl { + + namespace ttp = task_thread_pool; + + namespace execution { + namespace internal { + /** + * Holds the thread pool used by par. + */ + inline std::shared_ptr get_default_pool() { + static std::shared_ptr pool; + static std::once_flag flag; + std::call_once(flag, [&](){ pool = std::make_shared(); }); + return pool; + } + } + + /** + * A sequential policy that simply forwards to the non-policy overload. + */ + struct sequenced_policy {}; + + /** + * A parallel policy that can use a user-specified thread pool or a default one. + */ + struct parallel_policy { + parallel_policy() = default; + explicit parallel_policy(ttp::task_thread_pool& on_pool): on_pool(&on_pool) {} + + parallel_policy on(ttp::task_thread_pool& pool) const { + return parallel_policy{pool}; + } + + POOLSTL_NO_DISCARD ttp::task_thread_pool& pool() const { + if (on_pool) { + return *on_pool; + } else { + return *(internal::get_default_pool()); + } + } + + protected: + ttp::task_thread_pool *on_pool = nullptr; + }; + + constexpr sequenced_policy seq{}; + constexpr parallel_policy par{}; + + +#if POOLSTL_HAVE_CXX17 + /** + * A policy that allows selecting a policy at runtime. + * + * @tparam Variant std::variant<> of policy options. 
+ */ + template + struct variant_policy { + explicit variant_policy(const Variant& policy): var(policy) {} + Variant var; + }; + + namespace internal { + using poolstl_policy_variant = std::variant< + poolstl::execution::parallel_policy, + poolstl::execution::sequenced_policy>; + } + + /** + * Choose parallel or sequential at runtime. + * + * @param call_par Whether to use a parallel policy. + * @return `par` if call_par is true, else `seq`. + */ + inline variant_policy par_if(bool call_par) { + if (call_par) { + return variant_policy(internal::poolstl_policy_variant(par)); + } else { + return variant_policy(internal::poolstl_policy_variant(seq)); + } + } + + /** + * Choose parallel or sequential at runtime, with pool selection. + * + * @param call_par Whether to use a parallel policy. + * @return `par.on(pool)` if call_par is true, else `seq`. + */ + inline variant_policy par_if(bool call_par, ttp::task_thread_pool& pool) { + if (call_par) { + return variant_policy(internal::poolstl_policy_variant(par.on(pool))); + } else { + return variant_policy(internal::poolstl_policy_variant(seq)); + } + } +#endif + } + + using execution::seq; + using execution::par; +#if POOLSTL_HAVE_CXX17 + using execution::variant_policy; + using execution::par_if; +#endif + + namespace internal { + /** + * To enable/disable seq overload resolution + */ + template + using enable_if_seq = + typename std::enable_if< + std::is_same::type>::type>::value, + Tp>::type; + + /** + * To enable/disable par overload resolution + */ + template + using enable_if_par = + typename std::enable_if< + std::is_same::type>::type>::value, + Tp>::type; + +#if POOLSTL_HAVE_CXX17 + /** + * Helper for enable_if_poolstl_variant + */ + template struct is_poolstl_variant_policy : std::false_type {}; + template struct is_poolstl_variant_policy< + ::poolstl::execution::variant_policy> :std::true_type {}; + + /** + * To enable/disable variant_policy (for par_if) overload resolution + */ + template + using enable_if_poolstl_variant = + typename std::enable_if< + is_poolstl_variant_policy< + typename std::remove_cv::type>::type>::value, + Tp>::type; +#endif + } +} + +#endif + +#ifndef POOLSTL_ALGORITHM_HPP +#define POOLSTL_ALGORITHM_HPP + +#include + + +#ifndef POOLSTL_INTERNAL_TTP_IMPL_HPP +#define POOLSTL_INTERNAL_TTP_IMPL_HPP + +#include +#include +#include +#include + + +namespace poolstl { + namespace internal { + +#if POOLSTL_HAVE_CXX17_LIB + /** + * Call std::apply in parallel. + */ + template + std::vector> + parallel_apply(ExecPolicy &&policy, Op op, const ArgContainer& args_list) { + std::vector> futures; + auto& task_pool = policy.pool(); + + for (const auto& args : args_list) { + futures.emplace_back(task_pool.submit([op](const auto& args_fwd) { std::apply(op, args_fwd); }, args)); + } + + return futures; + } +#endif + + /** + * Chunk a single range. + */ + template + std::vector()(std::declval(), std::declval()))>> + parallel_chunk_for(ExecPolicy &&policy, RandIt first, RandIt last, Chunk chunk, int extra_split_factor = 1) { + std::vector()(std::declval(), std::declval())) + >> futures; + auto& task_pool = policy.pool(); + auto chunk_size = get_chunk_size(first, last, extra_split_factor * task_pool.get_num_threads()); + + while (first < last) { + auto iter_chunk_size = get_iter_chunk_size(first, last, chunk_size); + RandIt loop_end = advanced(first, iter_chunk_size); + + futures.emplace_back(task_pool.submit(chunk, first, loop_end)); + + first = loop_end; + } + + return futures; + } + + /** + * Element-wise chunk two ranges. 
+ */ + template + std::vector()( + std::declval(), + std::declval(), + std::declval()))>> + parallel_chunk_for(ExecPolicy &&policy, RandIt1 first1, RandIt1 last1, RandIt2 first2, Chunk chunk) { + std::vector()( + std::declval(), + std::declval(), + std::declval())) + >> futures; + auto& task_pool = policy.pool(); + auto chunk_size = get_chunk_size(first1, last1, task_pool.get_num_threads()); + + while (first1 < last1) { + auto iter_chunk_size = get_iter_chunk_size(first1, last1, chunk_size); + RandIt1 loop_end = advanced(first1, iter_chunk_size); + + futures.emplace_back(task_pool.submit(chunk, first1, loop_end, first2)); + + first1 = loop_end; + std::advance(first2, iter_chunk_size); + } + + return futures; + } + + /** + * Element-wise chunk three ranges. + */ + template + std::vector()( + std::declval(), + std::declval(), + std::declval(), + std::declval()))>> + parallel_chunk_for(ExecPolicy &&policy, RandIt1 first1, RandIt1 last1, RandIt2 first2, RandIt3 first3, + Chunk chunk) { + std::vector()( + std::declval(), + std::declval(), + std::declval(), + std::declval())) + >> futures; + auto& task_pool = policy.pool(); + auto chunk_size = get_chunk_size(first1, last1, task_pool.get_num_threads()); + + while (first1 < last1) { + auto iter_chunk_size = get_iter_chunk_size(first1, last1, chunk_size); + RandIt1 loop_end = advanced(first1, iter_chunk_size); + + futures.emplace_back(task_pool.submit(chunk, first1, loop_end, first2, first3)); + + first1 = loop_end; + std::advance(first2, iter_chunk_size); + std::advance(first3, iter_chunk_size); + } + + return futures; + } + + /** + * Sort a range in parallel. + * + * @param stable Whether to use std::stable_sort or std::sort + */ + template + void parallel_sort(ExecPolicy &&policy, RandIt first, RandIt last, Compare comp, bool stable) { + if (first == last) { + return; + } + + // Sort chunks in parallel + auto futures = parallel_chunk_for(std::forward(policy), first, last, + [&comp, stable] (RandIt chunk_first, RandIt chunk_last) { + if (stable) { + std::stable_sort(chunk_first, chunk_last, comp); + } else { + std::sort(chunk_first, chunk_last, comp); + } + return std::make_pair(chunk_first, chunk_last); + }); + + // Merge the sorted ranges + using SortedRange = std::pair; + auto& task_pool = policy.pool(); + std::vector subranges; + do { + for (auto& future : futures) { + subranges.emplace_back(future.get()); + } + futures.clear(); + + for (std::size_t i = 0; i < subranges.size(); ++i) { + if (i + 1 < subranges.size()) { + // pair up and merge + auto& lhs = subranges[i]; + auto& rhs = subranges[i + 1]; + futures.emplace_back(task_pool.submit([&comp] (RandIt chunk_first, RandIt chunk_middle, + RandIt chunk_last) { + std::inplace_merge(chunk_first, chunk_middle, chunk_last, comp); + return std::make_pair(chunk_first, chunk_last); + }, lhs.first, lhs.second, rhs.second)); + ++i; + } else { + // forward the final extra range + std::promise p; + futures.emplace_back(p.get_future()); + p.set_value(subranges[i]); + } + } + + subranges.clear(); + } while (futures.size() > 1); + futures.front().get(); + } + } +} + +#endif + +namespace std { + + /** + * NOTE: Iterators are expected to be random access. 
+ * See std::copy https://en.cppreference.com/w/cpp/algorithm/copy + */ + template + poolstl::internal::enable_if_par + copy(ExecPolicy &&policy, RandIt1 first, RandIt1 last, RandIt2 dest) { + auto futures = poolstl::internal::parallel_chunk_for(std::forward(policy), first, last, dest, + [](RandIt1 chunk_first, RandIt1 chunk_last, RandIt2 chunk_dest) { + std::copy(chunk_first, chunk_last, chunk_dest); + }); + poolstl::internal::get_futures(futures); + return poolstl::internal::advanced(dest, std::distance(first, last)); + } + + /** + * NOTE: Iterators are expected to be random access. + * See std::copy_n https://en.cppreference.com/w/cpp/algorithm/copy_n + */ + template + poolstl::internal::enable_if_par + copy_n(ExecPolicy &&policy, RandIt1 first, Size n, RandIt2 dest) { + if (n <= 0) { + return dest; + } + RandIt1 last = poolstl::internal::advanced(first, n); + std::copy(std::forward(policy), first, last, dest); + return poolstl::internal::advanced(dest, n); + } + + /** + * NOTE: Iterators are expected to be random access. + * See std::count_if https://en.cppreference.com/w/cpp/algorithm/count_if + */ + template + poolstl::internal::enable_if_par::difference_type> + count_if(ExecPolicy&& policy, RandIt first, RandIt last, UnaryPredicate p) { + using T = typename iterator_traits::difference_type; + + auto futures = poolstl::internal::parallel_chunk_for(std::forward(policy), first, last, + [&p](RandIt chunk_first, RandIt chunk_last) { + return std::count_if(chunk_first, chunk_last, p); + }); + + return poolstl::internal::cpp17::reduce( + poolstl::internal::get_wrap(futures.begin()), + poolstl::internal::get_wrap(futures.end()), (T)0, std::plus()); + } + + /** + * NOTE: Iterators are expected to be random access. + * See std::count https://en.cppreference.com/w/cpp/algorithm/count + */ + template + poolstl::internal::enable_if_par::difference_type> + count(ExecPolicy&& policy, RandIt first, RandIt last, const T& value) { + return std::count_if(std::forward(policy), first, last, + [&value](const T& test) { return test == value; }); + } + + /** + * NOTE: Iterators are expected to be random access. + * See std::fill https://en.cppreference.com/w/cpp/algorithm/fill + */ + template + poolstl::internal::enable_if_par + fill(ExecPolicy &&policy, RandIt first, RandIt last, const Tp& value) { + auto futures = poolstl::internal::parallel_chunk_for(std::forward(policy), first, last, + [&value](RandIt chunk_first, RandIt chunk_last) { + std::fill(chunk_first, chunk_last, value); + }); + poolstl::internal::get_futures(futures); + } + + /** + * NOTE: Iterators are expected to be random access. + * See std::fill_n https://en.cppreference.com/w/cpp/algorithm/fill_n + */ + template + poolstl::internal::enable_if_par + fill_n(ExecPolicy &&policy, RandIt first, Size n, const Tp& value) { + if (n <= 0) { + return first; + } + RandIt last = poolstl::internal::advanced(first, n); + std::fill(std::forward(policy), first, last, value); + return last; + } + + /** + * NOTE: Iterators are expected to be random access. 
+ * See std::find_if https://en.cppreference.com/w/cpp/algorithm/find_if + */ + template + poolstl::internal::enable_if_par + find_if(ExecPolicy &&policy, RandIt first, RandIt last, UnaryPredicate p) { + using diff_t = typename std::iterator_traits::difference_type; + diff_t n = std::distance(first, last); + std::atomic extremum(n); + + auto futures = poolstl::internal::parallel_chunk_for(std::forward(policy), first, last, + [&first, &extremum, &p](RandIt chunk_first, RandIt chunk_last) { + if (std::distance(first, chunk_first) > extremum) { + // already found by another task + return; + } + + RandIt chunk_res = std::find_if(chunk_first, chunk_last, p); + if (chunk_res != chunk_last) { + // Found, update exremum using a priority update CAS, as discussed in + // "Reducing Contention Through Priority Updates", PPoPP '13 + const diff_t k = std::distance(first, chunk_res); + for (diff_t old = extremum; k < old; old = extremum) { + extremum.compare_exchange_weak(old, k); + } + } + }, 8); // use small tasks so later ones may exit early if item is already found + poolstl::internal::get_futures(futures); + return extremum == n ? last : first + extremum; + } + + /** + * NOTE: Iterators are expected to be random access. + * See std::find_if_not https://en.cppreference.com/w/cpp/algorithm/find_if_not + */ + template + poolstl::internal::enable_if_par + find_if_not(ExecPolicy &&policy, RandIt first, RandIt last, UnaryPredicate p) { + return std::find_if(std::forward(policy), first, last, + [&p](const typename std::iterator_traits::value_type& test) { return !p(test); }); + } + + /** + * NOTE: Iterators are expected to be random access. + * See std::find https://en.cppreference.com/w/cpp/algorithm/find + */ + template + poolstl::internal::enable_if_par + find(ExecPolicy &&policy, RandIt first, RandIt last, const T& value) { + return std::find_if(std::forward(policy), first, last, + [&value](const T& test) { return value == test; }); + } + + /** + * NOTE: Iterators are expected to be random access. + * See std::for_each https://en.cppreference.com/w/cpp/algorithm/for_each + */ + template + poolstl::internal::enable_if_par + for_each(ExecPolicy &&policy, RandIt first, RandIt last, UnaryFunction f) { + auto futures = poolstl::internal::parallel_chunk_for(std::forward(policy), first, last, + [&f](RandIt chunk_first, RandIt chunk_last) { + // std::for_each(chunk_first, chunk_last, f); + for (; chunk_first != chunk_last; ++chunk_first) { + f(*chunk_first); + } + }); + poolstl::internal::get_futures(futures); + } + + /** + * NOTE: Iterators are expected to be random access. + * See std::for_each_n https://en.cppreference.com/w/cpp/algorithm/for_each_n + */ + template + poolstl::internal::enable_if_par + for_each_n(ExecPolicy &&policy, RandIt first, Size n, UnaryFunction f) { + RandIt last = poolstl::internal::advanced(first, n); + std::for_each(std::forward(policy), first, last, f); + return last; + } + + /** + * NOTE: Iterators are expected to be random access. + * See std::sort https://en.cppreference.com/w/cpp/algorithm/sort + */ + template + poolstl::internal::enable_if_par + sort(ExecPolicy &&policy, RandIt first, RandIt last, Compare comp) { + poolstl::internal::parallel_sort(std::forward(policy), first, last, comp, false); + } + + /** + * NOTE: Iterators are expected to be random access. 
+ * See std::sort https://en.cppreference.com/w/cpp/algorithm/sort + */ + template + poolstl::internal::enable_if_par + sort(ExecPolicy &&policy, RandIt first, RandIt last) { + using T = typename std::iterator_traits::value_type; + poolstl::internal::parallel_sort(std::forward(policy), first, last, std::less(), false); + } + + /** + * NOTE: Iterators are expected to be random access. + * See std::stable_sort https://en.cppreference.com/w/cpp/algorithm/stable_sort + */ + template + poolstl::internal::enable_if_par + stable_sort(ExecPolicy &&policy, RandIt first, RandIt last, Compare comp) { + poolstl::internal::parallel_sort(std::forward(policy), first, last, comp, true); + } + + /** + * NOTE: Iterators are expected to be random access. + * See std::stable_sort https://en.cppreference.com/w/cpp/algorithm/stable_sort + */ + template + poolstl::internal::enable_if_par + stable_sort(ExecPolicy &&policy, RandIt first, RandIt last) { + using T = typename std::iterator_traits::value_type; + poolstl::internal::parallel_sort(std::forward(policy), first, last, std::less(), true); + } + + /** + * NOTE: Iterators are expected to be random access. + * See std::transform https://en.cppreference.com/w/cpp/algorithm/transform + */ + template + poolstl::internal::enable_if_par + transform(ExecPolicy&& policy, RandIt1 first1, RandIt1 last1, + RandIt2 dest, UnaryOperation unary_op) { + + auto futures = poolstl::internal::parallel_chunk_for(std::forward(policy), first1, last1, dest, + [&unary_op](RandIt1 chunk_first1, RandIt1 chunk_last1, RandIt2 dest_first) { + return poolstl::internal::cpp17::transform(chunk_first1, chunk_last1, dest_first, unary_op); + }); + poolstl::internal::get_futures(futures); + return dest + std::distance(first1, last1); + } + + /** + * NOTE: Iterators are expected to be random access. + * See std::transform https://en.cppreference.com/w/cpp/algorithm/transform + */ + template + poolstl::internal::enable_if_par + transform(ExecPolicy&& policy, RandIt1 first1, RandIt1 last1, + RandIt2 first2, RandIt3 dest, BinaryOperation binary_op) { + + auto futures = poolstl::internal::parallel_chunk_for(std::forward(policy), first1, last1, + first2, dest, + [&binary_op](RandIt1 chunk_first1, RandIt1 chunk_last1, RandIt1 chunk_first2, RandIt3 dest_first) { + return poolstl::internal::cpp17::transform(chunk_first1, chunk_last1, + chunk_first2, dest_first, binary_op); + }); + poolstl::internal::get_futures(futures); + return dest + std::distance(first1, last1); + } + + /** + * NOTE: Iterators are expected to be random access. + * See std::all_of https://en.cppreference.com/w/cpp/algorithm/all_of + */ + template + poolstl::internal::enable_if_par + all_of(ExecPolicy&& policy, RandIt first, RandIt last, Predicate pred) { + return last == std::find_if_not(std::forward(policy), first, last, pred); + } + + /** + * NOTE: Iterators are expected to be random access. + * See std::none_of https://en.cppreference.com/w/cpp/algorithm/none_of + */ + template + poolstl::internal::enable_if_par + none_of(ExecPolicy&& policy, RandIt first, RandIt last, Predicate pred) { + return last == std::find_if(std::forward(policy), first, last, pred); + } + + /** + * NOTE: Iterators are expected to be random access. 
+ * See std::any_of https://en.cppreference.com/w/cpp/algorithm/any_of + */ + template + poolstl::internal::enable_if_par + any_of(ExecPolicy&& policy, RandIt first, RandIt last, Predicate pred) { + return !std::none_of(std::forward(policy), first, last, pred); + } +} + +namespace poolstl { + + template + void for_each_chunk(RandIt first, RandIt last, ChunkConstructor construct, UnaryFunction f) { + if (first == last) { + return; + } + + auto chunk_data = construct(); + for (; first != last; ++first) { + f(*first, chunk_data); + } + } + + /** + * NOTE: Iterators are expected to be random access. + * + * Like `std::for_each`, but exposes the chunking. The `construct` method is called once per parallel chunk and + * its output is passed to `f`. + * + * Useful for cases where an expensive workspace can be shared between loop iterations + * but cannot be shared by all parallel iterations. + */ + template + poolstl::internal::enable_if_par + for_each_chunk(ExecPolicy&& policy, RandIt first, RandIt last, ChunkConstructor construct, UnaryFunction f) { + auto futures = poolstl::internal::parallel_chunk_for(std::forward(policy), first, last, + [&construct, &f](RandIt chunk_first, RandIt chunk_last) { + for_each_chunk(chunk_first, chunk_last, construct, f); + }); + poolstl::internal::get_futures(futures); + } +} + +#endif + +#ifndef POOLSTL_NUMERIC_HPP +#define POOLSTL_NUMERIC_HPP + +#include + + +namespace std { + +#if POOLSTL_HAVE_CXX17_LIB + /** + * NOTE: Iterators are expected to be random access. + * See std::exclusive_scan https://en.cppreference.com/w/cpp/algorithm/exclusive_scan + */ + template + poolstl::internal::enable_if_par + exclusive_scan(ExecPolicy &&policy, RandIt1 first, RandIt1 last, RandIt2 dest, T init, BinaryOp binop) { + if (first == last) { + return dest; + } + + // Pass 1: Chunk the input and find the sum of each chunk + auto futures = poolstl::internal::parallel_chunk_for(std::forward(policy), first, last, + [binop](RandIt1 chunk_first, RandIt1 chunk_last) { + auto sum = std::accumulate(chunk_first, chunk_last, T{}, binop); + return std::make_tuple(std::make_pair(chunk_first, chunk_last), sum); + }); + + std::vector> ranges; + std::vector sums; + + for (auto& future : futures) { + auto res = future.get(); + ranges.push_back(std::get<0>(res)); + sums.push_back(std::get<1>(res)); + } + + // find initial values for each range + std::exclusive_scan(sums.begin(), sums.end(), sums.begin(), init, binop); + + // Pass 2: perform exclusive scan of each chunk, using the sum of previous chunks as init + std::vector> args; + for (std::size_t i = 0; i < sums.size(); ++i) { + auto chunk_first = std::get<0>(ranges[i]); + args.emplace_back(std::make_tuple( + chunk_first, std::get<1>(ranges[i]), + dest + (chunk_first - first), + sums[i])); + } + + auto futures2 = poolstl::internal::parallel_apply(std::forward(policy), + [binop](RandIt1 chunk_first, RandIt1 chunk_last, RandIt2 chunk_dest, T chunk_init){ + std::exclusive_scan(chunk_first, chunk_last, chunk_dest, chunk_init, binop); + }, args); + + poolstl::internal::get_futures(futures2); + return dest + (last - first); + } + + /** + * NOTE: Iterators are expected to be random access. 
+ * See std::exclusive_scan https://en.cppreference.com/w/cpp/algorithm/exclusive_scan + */ + template + poolstl::internal::enable_if_par + exclusive_scan(ExecPolicy &&policy, RandIt1 first, RandIt1 last, RandIt2 dest, T init) { + return std::exclusive_scan(std::forward(policy), first, last, dest, init, std::plus()); + } +#endif + + /** + * NOTE: Iterators are expected to be random access. + * See std::reduce https://en.cppreference.com/w/cpp/algorithm/reduce + */ + template + poolstl::internal::enable_if_par + reduce(ExecPolicy &&policy, RandIt first, RandIt last, T init, BinaryOp binop) { + auto futures = poolstl::internal::parallel_chunk_for(std::forward(policy), first, last, + [init, binop](RandIt chunk_first, RandIt chunk_last) { + return poolstl::internal::cpp17::reduce(chunk_first, chunk_last, init, binop); + }); + + return poolstl::internal::cpp17::reduce( + poolstl::internal::get_wrap(futures.begin()), + poolstl::internal::get_wrap(futures.end()), init, binop); + } + + /** + * NOTE: Iterators are expected to be random access. + * See std::reduce https://en.cppreference.com/w/cpp/algorithm/reduce + */ + template + poolstl::internal::enable_if_par + reduce(ExecPolicy &&policy, RandIt first, RandIt last, T init) { + return std::reduce(std::forward(policy), first, last, init, std::plus()); + } + + /** + * NOTE: Iterators are expected to be random access. + * See std::reduce https://en.cppreference.com/w/cpp/algorithm/reduce + */ + template + poolstl::internal::enable_if_par< + ExecPolicy, typename std::iterator_traits::value_type> + reduce(ExecPolicy &&policy, RandIt first, RandIt last) { + return std::reduce(std::forward(policy), first, last, + typename std::iterator_traits::value_type{}); + } + +#if POOLSTL_HAVE_CXX17_LIB + /** + * NOTE: Iterators are expected to be random access. + * See std::transform_reduce https://en.cppreference.com/w/cpp/algorithm/transform_reduce + */ + template + poolstl::internal::enable_if_par + transform_reduce(ExecPolicy&& policy, RandIt1 first1, RandIt1 last1, T init, + BinaryReductionOp reduce_op, UnaryTransformOp transform_op) { + + auto futures = poolstl::internal::parallel_chunk_for(std::forward(policy), first1, last1, + [&init, &reduce_op, &transform_op](RandIt1 chunk_first1, RandIt1 chunk_last1) { + return std::transform_reduce(chunk_first1, chunk_last1, init, reduce_op, transform_op); + }); + + return poolstl::internal::cpp17::reduce( + poolstl::internal::get_wrap(futures.begin()), + poolstl::internal::get_wrap(futures.end()), init, reduce_op); + } + + /** + * NOTE: Iterators are expected to be random access. + * See std::transform_reduce https://en.cppreference.com/w/cpp/algorithm/transform_reduce + */ + template + poolstl::internal::enable_if_par + transform_reduce(ExecPolicy&& policy, RandIt1 first1, RandIt1 last1, RandIt2 first2, T init, + BinaryReductionOp reduce_op, BinaryTransformOp transform_op) { + + auto futures = poolstl::internal::parallel_chunk_for(std::forward(policy), first1, last1, first2, + [&init, &reduce_op, &transform_op](RandIt1 chunk_first1, RandIt1 chunk_last1, RandIt2 chunk_first2) { + return std::transform_reduce(chunk_first1, chunk_last1, chunk_first2, init, reduce_op, transform_op); + }); + + return poolstl::internal::cpp17::reduce( + poolstl::internal::get_wrap(futures.begin()), + poolstl::internal::get_wrap(futures.end()), init, reduce_op); + } + + /** + * NOTE: Iterators are expected to be random access. 
+ * See std::transform_reduce https://en.cppreference.com/w/cpp/algorithm/transform_reduce + */ + template< class ExecPolicy, class RandIt1, class RandIt2, class T > + poolstl::internal::enable_if_par + transform_reduce(ExecPolicy&& policy, RandIt1 first1, RandIt1 last1, RandIt2 first2, T init ) { + return transform_reduce(std::forward(policy), + first1, last1, first2, init, std::plus<>(), std::multiplies<>()); + } +#endif + +} + +#endif + +#ifndef POOLSTL_SEQ_FWD_HPP +#define POOLSTL_SEQ_FWD_HPP + + +/* + * Forward poolstl::seq to the native sequential (no policy) method. + */ + +#define POOLSTL_DEFINE_SEQ_FWD(NS, FNAME) \ + template \ + auto FNAME(EP&&, ARGS&&...args) -> \ + poolstl::internal::enable_if_seq(args)...))> { \ + return NS::FNAME(std::forward(args)...); \ + } + +#define POOLSTL_DEFINE_SEQ_FWD_VOID(NS, FNAME) \ + template \ + poolstl::internal::enable_if_seq FNAME(EP&&, ARGS&&... args) { \ + NS::FNAME(std::forward(args)...); \ + } + +#if POOLSTL_HAVE_CXX17 + +/* + * Dynamically choose policy from a std::variant. + * Useful to choose between parallel and sequential policies at runtime via par_if. + */ + +#define POOLSTL_DEFINE_PAR_IF_FWD_VOID(NS, FNAME) \ + template \ + poolstl::internal::enable_if_poolstl_variant FNAME(EP&& policy, ARGS&&...args) { \ + std::visit([&](auto&& pol) { NS::FNAME(pol, std::forward(args)...); }, policy.var); \ + } + +#define POOLSTL_DEFINE_PAR_IF_FWD(NS, FNAME) \ + template \ + auto FNAME(EP&& policy, ARGS&&...args) -> \ + poolstl::internal::enable_if_poolstl_variant(args)...))> { \ + return std::visit([&](auto&& pol) { return NS::FNAME(pol, std::forward(args)...); }, policy.var); \ + } + +#else +#define POOLSTL_DEFINE_PAR_IF_FWD_VOID(NS, FNAME) +#define POOLSTL_DEFINE_PAR_IF_FWD(NS, FNAME) +#endif +/* + * Define both the sequential forward and dynamic chooser. + */ +#define POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(NS, FNAME) \ + POOLSTL_DEFINE_SEQ_FWD(NS, FNAME) \ + POOLSTL_DEFINE_PAR_IF_FWD(NS, FNAME) + +#define POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF_VOID(NS, FNAME) \ + POOLSTL_DEFINE_SEQ_FWD_VOID(NS, FNAME) \ + POOLSTL_DEFINE_PAR_IF_FWD_VOID(NS, FNAME) + +namespace std { + POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, all_of) + POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, any_of) + POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, none_of) + + POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, count) + POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, count_if) + + POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, copy) + POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, copy_n) + + POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF_VOID(std, fill) + POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, fill_n) + + POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, find) + POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, find_if) + POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, find_if_not) + + POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF_VOID(std, for_each) +#if POOLSTL_HAVE_CXX17_LIB + POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, for_each_n) +#endif + + POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, transform) + +#if POOLSTL_HAVE_CXX17_LIB + POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, exclusive_scan) + POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, reduce) + POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, transform_reduce) +#endif +} + +namespace poolstl { + POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF_VOID(poolstl, for_each_chunk) +} + +#endif + +// Note that iota_iter.hpp is self-contained in its own right. 
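+
+/*
+ * A minimal usage sketch of the parallel overloads defined above (assuming the
+ * definitions in this header; not an official poolSTL example, and `v` is a
+ * placeholder input):
+ *
+ * \code{.cpp}
+ * #include "poolstl/poolstl.hpp"
+ * #include <algorithm>
+ * #include <vector>
+ *
+ * int main() {
+ *     std::vector<int> v{5, 3, 1, 4, 2};
+ *
+ *     // Sort using the default internal thread pool.
+ *     std::sort(poolstl::par, v.begin(), v.end());
+ *
+ *     // Or supply an explicit pool, as the src/index.cpp change in this patch does.
+ *     task_thread_pool::task_thread_pool pool(4);
+ *     std::sort(poolstl::par.on(pool), v.begin(), v.end());
+ * }
+ * \endcode
+ */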
+ +#ifndef POOLSTL_IOTA_ITER_HPP +#define POOLSTL_IOTA_ITER_HPP + +#include +#include + +namespace poolstl { + + /** + * An iterator over the integers. + * + * Effectively a view on a fictional vector populated by std::iota, but without materializing anything. + * + * Useful to parallelize loops that are not over a container, like this: + * + * \code{.cpp} + * for (int i = 0; i < 10; ++i) { + * } + *\endcode + * + * Becomes: + * \code{.cpp} + * std::for_each(iota_iter(0), iota_iter(10), [](int i) { + * }); + * \endcode + * + * @tparam T A type that acts as an integer. + */ + template + class iota_iter { + public: + using value_type = T; + using difference_type = std::ptrdiff_t; + using pointer = T *; + using reference = T; + using iterator_category = std::random_access_iterator_tag; + + iota_iter() : value{} {} + explicit iota_iter(T rhs) : value(rhs) {} + iota_iter(const iota_iter &rhs) : value(rhs.value) {} + + iota_iter &operator=(T rhs) { value = rhs; return *this; } + iota_iter &operator=(const iota_iter &rhs) { value = rhs.value; return *this; } + + reference operator*() const { return value; } + reference operator[](difference_type rhs) const { return value + rhs; } + // operator-> has no meaning in this application + + bool operator==(const iota_iter &rhs) const { return value == rhs.value; } + bool operator!=(const iota_iter &rhs) const { return value != rhs.value; } + bool operator<(const iota_iter &rhs) const { return value < rhs.value; } + bool operator>(const iota_iter &rhs) const { return value > rhs.value; } + bool operator<=(const iota_iter &rhs) const { return value <= rhs.value; } + bool operator>=(const iota_iter &rhs) const { return value >= rhs.value; } + + iota_iter &operator+=(difference_type rhs) { value += rhs; return *this; } + iota_iter &operator-=(difference_type rhs) { value -= rhs; return *this; } + + iota_iter &operator++() { ++value; return *this; } + iota_iter &operator--() { --value; return *this; } + iota_iter operator++(int) { iota_iter ret(value); ++value; return ret; } + iota_iter operator--(int) { iota_iter ret(value); --value; return ret; } + + difference_type operator-(const iota_iter &rhs) const { return value - rhs.value; } + iota_iter operator-(difference_type rhs) const { return iota_iter(value - rhs); } + iota_iter operator+(difference_type rhs) const { return iota_iter(value + rhs); } + + friend inline iota_iter operator+(difference_type lhs, const iota_iter &rhs) { + return iota_iter(lhs + rhs.value); + } + + protected: + T value; + }; +} + +namespace std { + /** + * Specialize std::iterator_traits for poolstl::iota_iter. + */ + template + struct iterator_traits> { + using value_type = typename poolstl::iota_iter::value_type; + using difference_type = typename poolstl::iota_iter::difference_type; + using pointer = typename poolstl::iota_iter::pointer; + using reference = typename poolstl::iota_iter::reference; + using iterator_category = typename poolstl::iota_iter::iterator_category; + }; +} + +#endif + +/* + * Optionally alias `poolstl::par` as `std::execution::par` to enable poolSTL to fill in for missing compiler support. + * + * USE AT YOUR OWN RISK! + * + * To use this define POOLSTL_STD_SUPPLEMENT=1 before including poolstl.hpp. + * + * Attempts to autodetect native support by checking for , including it if it exists, and then checking for + * the __cpp_lib_parallel_algorithm feature macro. + * + * If native support is not found then the standard execution policies are declared as forwards to poolSTL. 
+ * + * GCC and Clang: TBB is required if is #included. If you'd like to use the poolSTL supplement in cases + * that TBB is not available, have your build system define POOLSTL_STD_SUPPLEMENT_NO_INCLUDE if TBB is not found. + * PoolSTL will then not include and the supplement will kick in. + * Your code must not #include . + * + * MinGW: the compiler declares support, but actual performance is sequential (see poolSTL benchmark). To use + * the supplement anyway define POOLSTL_STD_SUPPLEMENT_FORCE to override the autodetection. + * Your code must not #include . + * + * Define POOLSTL_ALLOW_SUPPLEMENT=0 to override POOLSTL_STD_SUPPLEMENT and disable this feature. + */ +#ifndef POOLSTL_ALLOW_SUPPLEMENT +#define POOLSTL_ALLOW_SUPPLEMENT 1 +#endif + +#if POOLSTL_ALLOW_SUPPLEMENT && defined(POOLSTL_STD_SUPPLEMENT) + +#if __cplusplus >= 201603L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201603L) +#if __has_include() +#ifndef POOLSTL_STD_SUPPLEMENT_NO_INCLUDE +#endif +#endif +#endif + +#if !defined(__cpp_lib_parallel_algorithm) || defined(POOLSTL_STD_SUPPLEMENT_FORCE) +namespace std { + namespace execution { + using ::poolstl::execution::sequenced_policy; + using ::poolstl::execution::seq; + using ::poolstl::execution::parallel_policy; + using ::poolstl::execution::par; + using parallel_unsequenced_policy = ::poolstl::execution::parallel_policy; + constexpr parallel_unsequenced_policy par_unseq{}; + } +} + +#endif +#endif + +#endif diff --git a/src/index.cpp b/src/index.cpp index 7773e509..9b907257 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -11,6 +11,7 @@ #include #include #include "pdqsort/pdqsort.h" +#include "poolstl/poolstl.hpp" #include #include #include @@ -138,7 +139,7 @@ int StrobemerIndex::pick_bits(size_t size) const { return std::clamp(static_cast(log2(estimated_number_of_randstrobes)) - 1, 8, 31); } -void StrobemerIndex::populate(float f, size_t n_threads) { +void StrobemerIndex::populate(float f, unsigned n_threads) { Timer count_hash; auto randstrobe_counts = count_all_randstrobes(references, parameters, n_threads); stats.elapsed_counting_hashes = count_hash.duration(); @@ -164,8 +165,13 @@ void StrobemerIndex::populate(float f, size_t n_threads) { Timer sorting_timer; logger.debug() << " Sorting ...\n"; - // sort by hash values - pdqsort_branchless(randstrobes.begin(), randstrobes.end()); + if (true) { + task_thread_pool::task_thread_pool pool{n_threads}; + std::sort(poolstl::par.on(pool), randstrobes.begin(), randstrobes.end()); + } else { + // sort by hash values + pdqsort_branchless(randstrobes.begin(), randstrobes.end()); + } stats.elapsed_sorting_seeds = sorting_timer.duration(); Timer hash_index_timer; diff --git a/src/index.hpp b/src/index.hpp index 941db7de..a6b9f003 100644 --- a/src/index.hpp +++ b/src/index.hpp @@ -51,7 +51,7 @@ struct StrobemerIndex { void write(const std::string& filename) const; void read(const std::string& filename); - void populate(float f, size_t n_threads); + void populate(float f, unsigned n_threads); void print_diagnostics(const std::string& logfile_name, int k) const; int pick_bits(size_t size) const; size_t find(randstrobe_hash_t key) const { From c689d895e9bc5ebd1ed55aa13ce31c6711006839 Mon Sep 17 00:00:00 2001 From: Marcel Martin Date: Sun, 14 Jan 2024 16:00:48 +0100 Subject: [PATCH 07/26] Bump poolSTL and use poolstl::pluggable_sort --- ext/README.md | 6 + ext/poolstl/poolstl.hpp | 1476 +++++++++++++++++++++++++++++++-------- src/index.cpp | 9 +- 3 files changed, 1179 insertions(+), 312 deletions(-) diff --git a/ext/README.md 
b/ext/README.md index e8316d26..d80e5a2d 100644 --- a/ext/README.md +++ b/ext/README.md @@ -27,6 +27,12 @@ Homepage: https://github.com/orlp/pdqsort Commit used: b1ef26a55cdb60d236a5cb199c4234c704f46726 License: See pdqsort/license.txt +## poolstl + +Homepage: https://github.com/alugowski/poolSTL/ +Downloaded file: https://github.com/alugowski/poolSTL/releases/download/v0.3.3/poolstl.hpp +Version: 0.3.3 +License: See poolstl.hpp ## robin_hood diff --git a/ext/poolstl/poolstl.hpp b/ext/poolstl/poolstl.hpp index ea79146e..d1340a1e 100644 --- a/ext/poolstl/poolstl.hpp +++ b/ext/poolstl/poolstl.hpp @@ -84,6 +84,7 @@ * DEALINGS IN THE SOFTWARE. */ + #ifndef POOLSTL_HPP #define POOLSTL_HPP @@ -93,6 +94,767 @@ #include #include +#include +#include + + +#ifndef AL_TASK_THREAD_POOL_HPP +#define AL_TASK_THREAD_POOL_HPP + +// Version macros. +#define TASK_THREAD_POOL_VERSION_MAJOR 1 +#define TASK_THREAD_POOL_VERSION_MINOR 0 +#define TASK_THREAD_POOL_VERSION_PATCH 10 + +#include +#include +#include +#include +#include +#include +#include + +// MSVC does not correctly set the __cplusplus macro by default, so we must read it from _MSVC_LANG +// See https://devblogs.microsoft.com/cppblog/msvc-now-correctly-reports-__cplusplus/ +#if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) +#define TTP_CXX17 1 +#else +#define TTP_CXX17 0 +#endif + +#if TTP_CXX17 +#define TTP_NODISCARD [[nodiscard]] +#else +#define TTP_NODISCARD +#endif + +namespace task_thread_pool { + +#if !TTP_CXX17 + /** + * A reimplementation of std::decay_t, which is only available since C++14. + */ + template + using decay_t = typename std::decay::type; +#endif + + /** + * A fast and lightweight thread pool that uses C++11 threads. + */ + class task_thread_pool { + public: + /** + * Create a task_thread_pool and start worker threads. + * + * @param num_threads Number of worker threads. If 0 then number of threads is equal to the + * number of physical cores on the machine, as given by std::thread::hardware_concurrency(). + */ + explicit task_thread_pool(unsigned int num_threads = 0) { + if (num_threads < 1) { + num_threads = std::thread::hardware_concurrency(); + if (num_threads < 1) { num_threads = 1; } + } + start_threads(num_threads); + } + + /** + * Finish all tasks left in the queue then shut down worker threads. + * If the pool is currently paused then it is resumed. + */ + ~task_thread_pool() { + unpause(); + wait_for_queued_tasks(); + stop_all_threads(); + } + + /** + * Drop all tasks that have been submitted but not yet started by a worker. + * + * Tasks already in progress continue executing. + */ + void clear_task_queue() { + const std::lock_guard tasks_lock(task_mutex); + tasks = {}; + } + + /** + * Get number of enqueued tasks. + * + * @return Number of tasks that have been enqueued but not yet started. + */ + TTP_NODISCARD size_t get_num_queued_tasks() const { + const std::lock_guard tasks_lock(task_mutex); + return tasks.size(); + } + + /** + * Get number of in-progress tasks. + * + * @return Approximate number of tasks currently being processed by worker threads. + */ + TTP_NODISCARD size_t get_num_running_tasks() const { + const std::lock_guard tasks_lock(task_mutex); + return num_inflight_tasks; + } + + /** + * Get total number of tasks in the pool. + * + * @return Approximate number of tasks both enqueued and running. 
+ */ + TTP_NODISCARD size_t get_num_tasks() const { + const std::lock_guard tasks_lock(task_mutex); + return tasks.size() + num_inflight_tasks; + } + + /** + * Get number of worker threads. + * + * @return Number of worker threads. + */ + TTP_NODISCARD unsigned int get_num_threads() const { + const std::lock_guard threads_lock(thread_mutex); + return static_cast(threads.size()); + } + + /** + * Set number of worker threads. Will start or stop worker threads as necessary. + * + * @param num_threads Number of worker threads. If 0 then number of threads is equal to the + * number of physical cores on the machine, as given by std::thread::hardware_concurrency(). + * @return Previous number of worker threads. + */ + unsigned int set_num_threads(unsigned int num_threads) { + const std::lock_guard threads_lock(thread_mutex); + unsigned int previous_num_threads = get_num_threads(); + + if (num_threads < 1) { + num_threads = std::thread::hardware_concurrency(); + if (num_threads < 1) { num_threads = 1; } + } + + if (previous_num_threads <= num_threads) { + // expanding the thread pool + start_threads(num_threads - previous_num_threads); + } else { + // contracting the thread pool + stop_all_threads(); + { + const std::lock_guard tasks_lock(task_mutex); + pool_running = true; + } + start_threads(num_threads); + } + + return previous_num_threads; + } + + /** + * Stop executing queued tasks. Use `unpause()` to resume. Note: Destroying the pool will implicitly unpause. + * + * Any in-progress tasks continue executing. + */ + void pause() { + const std::lock_guard tasks_lock(task_mutex); + pool_paused = true; + } + + /** + * Resume executing queued tasks. + */ + void unpause() { + const std::lock_guard tasks_lock(task_mutex); + pool_paused = false; + task_cv.notify_all(); + } + + /** + * Check whether the pool is paused. + * + * @return true if pause() has been called without an intervening unpause(). + */ + TTP_NODISCARD bool is_paused() const { + const std::lock_guard tasks_lock(task_mutex); + return pool_paused; + } + + /** + * Submit a Callable for the pool to execute and return a std::future. + * + * @param func The Callable to execute. Can be a function, a lambda, std::packaged_task, std::function, etc. + * @param args Arguments for func. Optional. + * @return std::future that can be used to get func's return value or thrown exception. + */ + template , std::decay_t...> +#else + typename R = typename std::result_of(decay_t...)>::type +#endif + > + TTP_NODISCARD std::future submit(F&& func, A&&... args) { +#if defined(_MSC_VER) + // MSVC's packaged_task is not movable even though it should be. + // Discussion about this bug and its future fix: + // https://developercommunity.visualstudio.com/t/unable-to-move-stdpackaged-task-into-any-stl-conta/108672 + std::shared_ptr> ptask = + std::make_shared>(std::bind(std::forward(func), std::forward(args)...)); + submit_detach([ptask] { (*ptask)(); }); + return ptask->get_future(); +#else + std::packaged_task task(std::bind(std::forward(func), std::forward(args)...)); + auto ret = task.get_future(); + submit_detach(std::move(task)); + return ret; +#endif + } + + /** + * Submit a zero-argument Callable for the pool to execute. + * + * @param func The Callable to execute. Can be a function, a lambda, std::packaged_task, std::function, etc. 
+ */ + template + void submit_detach(F&& func) { + const std::lock_guard tasks_lock(task_mutex); + tasks.emplace(std::forward(func)); + task_cv.notify_one(); + } + + /** + * Submit a Callable with arguments for the pool to execute. + * + * @param func The Callable to execute. Can be a function, a lambda, std::packaged_task, std::function, etc. + */ + template + void submit_detach(F&& func, A&&... args) { + const std::lock_guard tasks_lock(task_mutex); + tasks.emplace(std::bind(std::forward(func), std::forward(args)...)); + task_cv.notify_one(); + } + + /** + * Block until the task queue is empty. Some tasks may be in-progress when this method returns. + */ + void wait_for_queued_tasks() { + std::unique_lock tasks_lock(task_mutex); + notify_task_finish = true; + task_finished_cv.wait(tasks_lock, [&] { return tasks.empty(); }); + notify_task_finish = false; + } + + /** + * Block until all tasks have finished. + */ + void wait_for_tasks() { + std::unique_lock tasks_lock(task_mutex); + notify_task_finish = true; + task_finished_cv.wait(tasks_lock, [&] { return tasks.empty() && num_inflight_tasks == 0; }); + notify_task_finish = false; + } + + protected: + + /** + * Main function for worker threads. + */ + void worker_main() { + bool finished_task = false; + + while (true) { + std::unique_lock tasks_lock(task_mutex); + + if (finished_task) { + --num_inflight_tasks; + if (notify_task_finish) { + task_finished_cv.notify_all(); + } + } + + task_cv.wait(tasks_lock, [&]() { return !pool_running || (!pool_paused && !tasks.empty()); }); + + if (!pool_running) { + break; + } + + // Must mean that (!pool_paused && !tasks.empty()) is true + + std::packaged_task task{std::move(tasks.front())}; + tasks.pop(); + ++num_inflight_tasks; + tasks_lock.unlock(); + + try { + task(); + } catch (...) { + // std::packaged_task::operator() may throw in some error conditions, such as if the task + // had already been run. Nothing that the pool can do anything about. + } + + finished_task = true; + } + } + + /** + * Start worker threads. + * + * @param num_threads How many threads to start. + */ + void start_threads(const unsigned int num_threads) { + const std::lock_guard threads_lock(thread_mutex); + + for (unsigned int i = 0; i < num_threads; ++i) { + threads.emplace_back(&task_thread_pool::worker_main, this); + } + } + + /** + * Stop, join, and destroy all worker threads. + */ + void stop_all_threads() { + const std::lock_guard threads_lock(thread_mutex); + + { + const std::lock_guard tasks_lock(task_mutex); + pool_running = false; + task_cv.notify_all(); + } + + for (auto& thread : threads) { + if (thread.joinable()) { + thread.join(); + } + } + threads.clear(); + } + + /** + * The worker threads. + * + * Access protected by thread_mutex + */ + std::vector threads; + + /** + * A mutex for methods that start/stop threads. + */ + mutable std::recursive_mutex thread_mutex; + + /** + * The task queue. + * + * Access protected by task_mutex. + */ + std::queue> tasks = {}; + + /** + * A mutex for all variables related to tasks. + */ + mutable std::mutex task_mutex; + + /** + * Used to notify changes to the task queue, such as a new task added, pause/unpause, etc. + */ + std::condition_variable task_cv; + + /** + * Used to notify of finished tasks. + */ + std::condition_variable task_finished_cv; + + /** + * A signal for worker threads that the pool is either running or shutting down. + * + * Access protected by task_mutex. 
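// Illustrative sketch (not part of the patch): when no result is needed,
// submit_detach() skips the std::future machinery and wait_for_tasks() acts as a
// barrier. Names below are invented for the example.
#include <atomic>
#include "poolstl/poolstl.hpp"

void detach_demo() {
    task_thread_pool::task_thread_pool pool;           // 0 threads requested => one per hardware thread
    std::atomic<int> done{0};
    for (int i = 0; i < 100; ++i) {
        pool.submit_detach([&done] { ++done; });        // fire-and-forget, no future allocated
    }
    pool.wait_for_tasks();                              // blocks until queued *and* running tasks finish
    // done.load() == 100 here
}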
+ */ + bool pool_running = true; + + /** + * A signal for worker threads to not pull new tasks from the queue. + * + * Access protected by task_mutex. + */ + bool pool_paused = false; + + /** + * A signal for worker threads that they should notify task_finished_cv when they finish a task. + * + * Access protected by task_mutex. + */ + bool notify_task_finish = false; + + /** + * A counter of the number of tasks in-progress by worker threads. + * Incremented when a task is popped off the task queue and decremented when that task is complete. + * + * Access protected by task_mutex. + */ + int num_inflight_tasks = 0; + }; +} + +// clean up +#undef TTP_NODISCARD +#undef TTP_CXX17 + +#endif + +#ifndef POOLSTL_INTERNAL_UTILS_HPP +#define POOLSTL_INTERNAL_UTILS_HPP + +// Version macros. +#define POOLSTL_VERSION_MAJOR 0 +#define POOLSTL_VERSION_MINOR 3 +#define POOLSTL_VERSION_PATCH 3 + +#include +#include + +#if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) +#define POOLSTL_HAVE_CXX17 1 +#define POOLSTL_NO_DISCARD [[nodiscard]] +#else +#define POOLSTL_HAVE_CXX17 0 +#define POOLSTL_NO_DISCARD +#endif + +#if POOLSTL_HAVE_CXX17 && (!defined(_GLIBCXX_RELEASE) || _GLIBCXX_RELEASE >= 9) +#define POOLSTL_HAVE_CXX17_LIB 1 +#else +#define POOLSTL_HAVE_CXX17_LIB 0 +#endif + +#if __cplusplus >= 201402L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201402L) +#define POOLSTL_HAVE_CXX14 1 +#else +#define POOLSTL_HAVE_CXX14 0 +#endif + +namespace poolstl { + namespace internal { + + inline constexpr std::size_t get_chunk_size(std::size_t num_steps, unsigned int num_threads) { + return (num_steps / num_threads) + ((num_steps % num_threads) > 0 ? 1 : 0); + } + + template + constexpr typename std::iterator_traits::difference_type + get_chunk_size(Iterator first, Iterator last, unsigned int num_threads) { + using diff_t = typename std::iterator_traits::difference_type; + return static_cast(get_chunk_size((std::size_t)std::distance(first, last), num_threads)); + } + + template + constexpr typename std::iterator_traits::difference_type + get_iter_chunk_size(const Iterator& iter, const Iterator& last, + typename std::iterator_traits::difference_type chunk_size) { + return std::min(chunk_size, std::distance(iter, last)); + } + + template + Iterator advanced(Iterator iter, typename std::iterator_traits::difference_type offset) { + Iterator ret = iter; + std::advance(ret, offset); + return ret; + } + + /** + * An iterator wrapper that calls std::future<>::get(). 
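// Illustrative sketch (not part of the patch): get_chunk_size() above is a ceiling
// division, so the final chunk may be short. For example, 10 elements split across
// 4 threads gives a chunk size of 3 and chunks of 3, 3, 3 and 1 elements. The
// functions are internal helpers and are called here only to show the arithmetic.
#include <vector>
#include "poolstl/poolstl.hpp"

void chunking_demo() {
    std::vector<int> v(10);
    auto chunk = poolstl::internal::get_chunk_size(v.begin(), v.end(), 4);                 // == 3
    auto first_chunk = poolstl::internal::get_iter_chunk_size(v.begin(), v.end(), chunk);  // == 3
    (void)first_chunk;
}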
+ * @tparam Iterator + */ + template + class getting_iter : public Iterator { + public: + using value_type = decltype((*std::declval()).get()); + using difference_type = typename std::iterator_traits::difference_type; + using pointer = value_type*; + using reference = value_type&; + explicit getting_iter(Iterator iter) : iter(iter) {} + + getting_iter operator++() { ++iter; return *this; } + getting_iter operator++(int) { getting_iter ret(*this); ++iter; return ret; } + + value_type operator*() { return (*iter).get(); } + value_type operator[](difference_type offset) { return iter[offset].get(); } + + bool operator==(const getting_iter &other) const { return iter == other.iter; } + bool operator!=(const getting_iter &other) const { return iter != other.iter; } + + protected: + Iterator iter; + }; + + template + getting_iter get_wrap(Iterator iter) { + return getting_iter(iter); + } + + template + void get_futures(Container& futures) { + for (auto &future: futures) { + future.get(); + } + } + + /* + * Some methods are only available with C++17 and up. Reimplement on older standards. + */ +#if POOLSTL_HAVE_CXX17_LIB + namespace cpp17 = std; +#else + namespace cpp17 { + + // std::reduce + + template + Tp reduce(InputIt first, InputIt last, Tp init, BinOp b) { + for (; first != last; ++first) + init = b(init, *first); + return init; + } + + template + typename std::iterator_traits::value_type reduce(InputIt first, InputIt last) { + return reduce(first, last, + typename std::iterator_traits::value_type{}, + std::plus::value_type>()); + } + + // std::transform + + template + OutputIt transform(InputIt first1, InputIt last1, OutputIt d_first, + UnaryOperation unary_op) { + while (first1 != last1) { + *d_first++ = unary_op(*first1++); + } + + return d_first; + } + + template + OutputIt transform(InputIt1 first1, InputIt1 last1, + InputIt2 first2, OutputIt d_first, + BinaryOperation binary_op) { + while (first1 != last1) { + *d_first++ = binary_op(*first1++, *first2++); + } + + return d_first; + } + } +#endif + } +} + +#endif + +namespace poolstl { + + namespace ttp = task_thread_pool; + + namespace execution { + namespace internal { + /** + * Holds the thread pool used by par. + */ + inline std::shared_ptr get_default_pool() { + static std::shared_ptr pool; + static std::once_flag flag; + std::call_once(flag, [&](){ pool = std::make_shared(); }); + return pool; + } + } + + /** + * Base class for all poolSTL policies. + */ + struct poolstl_policy { + }; + + /** + * A sequential policy that simply forwards to the non-policy overload. + */ + struct sequenced_policy : public poolstl_policy { + POOLSTL_NO_DISCARD ttp::task_thread_pool* pool() const { + // never called, but must exist for C++11 support + throw std::runtime_error("poolSTL: requested thread pool for seq policy."); + } + + POOLSTL_NO_DISCARD bool par_allowed() const { + return false; + } + }; + + /** + * A parallel policy that can use a user-specified thread pool or a default one. 
+ */ + struct parallel_policy : public poolstl_policy { + parallel_policy() = default; + explicit parallel_policy(ttp::task_thread_pool* on_pool, bool par_ok): on_pool(on_pool), par_ok(par_ok) {} + + parallel_policy on(ttp::task_thread_pool& pool) const { + return parallel_policy{&pool, par_ok}; + } + + parallel_policy par_if(bool call_par) const { + return parallel_policy{on_pool, call_par}; + } + + POOLSTL_NO_DISCARD ttp::task_thread_pool* pool() const { + if (on_pool) { + return on_pool; + } else { + return internal::get_default_pool().get(); + } + } + + POOLSTL_NO_DISCARD bool par_allowed() const { + return par_ok; + } + + protected: + ttp::task_thread_pool *on_pool = nullptr; + bool par_ok = true; + }; + + constexpr sequenced_policy seq{}; + constexpr parallel_policy par{}; + + /** + * EXPERIMENTAL: Subject to significant changes or removal. + * Use pure threads for each operation instead of a shared thread pool. + * + * Advantage: + * - Fewer symbols (no packaged_task with its operators, destructors, vtable, etc) means smaller binary + * which can mean a lot when there are many calls. + * - No thread pool to manage. + * + * Disadvantages: + * - Threads are started and joined for every operation, so it is harder to amortize that cost. + * - Barely any algorithms are supported. + */ + struct pure_threads_policy : public poolstl_policy { + explicit pure_threads_policy(unsigned int num_threads, bool par_ok): num_threads(num_threads), + par_ok(par_ok) {} + + POOLSTL_NO_DISCARD unsigned int get_num_threads() const { + if (num_threads == 0) { + return std::thread::hardware_concurrency(); + } + return num_threads; + } + + POOLSTL_NO_DISCARD bool par_allowed() const { + return par_ok; + } + + protected: + unsigned int num_threads = 1; + bool par_ok = true; + }; + + /** + * Choose parallel or sequential at runtime. + * + * @param call_par Whether to use a parallel policy. + * @return `par` if call_par is true, else a sequential policy (like `seq`). + */ + inline parallel_policy par_if(bool call_par) { + return parallel_policy{nullptr, call_par}; + } + + /** + * Choose parallel or sequential at runtime, with pool selection. + * + * @param call_par Whether to use a parallel policy. + * @return `par.on(pool)` if call_par is true, else a sequential policy (like `seq`). + */ + inline parallel_policy par_if(bool call_par, ttp::task_thread_pool& pool) { + return parallel_policy{&pool, call_par}; + } + + /** + * EXPERIMENTAL: Subject to significant changes or removal. See `pure_threads_policy`. + * Choose parallel or sequential at runtime, with thread count selection. + * + * @param call_par Whether to use a parallel policy. + * @return `par.on(pool)` if call_par is true, else `seq`. 
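// Illustrative sketch (not part of the patch): how these policies are used from
// caller code. They are passed as the first argument of the std:: algorithm
// overloads that poolSTL provides further below in this header.
#include <algorithm>
#include <vector>
#include "poolstl/poolstl.hpp"

void policy_demo(std::vector<int>& v, bool big_enough) {
    task_thread_pool::task_thread_pool pool{8};

    std::sort(poolstl::par, v.begin(), v.end());              // shared default pool
    std::sort(poolstl::par.on(pool), v.begin(), v.end());     // caller-owned pool
    std::sort(poolstl::par_if(big_enough, pool),              // parallel only when it pays off
              v.begin(), v.end());
}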
+ */ + inline pure_threads_policy par_if_threads(bool call_par, unsigned int num_threads) { + return pure_threads_policy{num_threads, call_par}; + } + } + + using execution::seq; + using execution::par; + using execution::par_if; + + namespace internal { + /** + * To enable/disable seq overload resolution + */ + template + using enable_if_seq = + typename std::enable_if< + std::is_same::type>::type>::value, + Tp>::type; + + /** + * To enable/disable par overload resolution + */ + template + using enable_if_par = + typename std::enable_if< + std::is_same::type>::type>::value, + Tp>::type; + + /** + * To enable/disable par overload resolution + */ + template + using enable_if_poolstl_policy = + typename std::enable_if< + std::is_base_of::type>::type>::value, + Tp>::type; + + template + bool is_seq(const ExecPolicy& policy) { + return !policy.par_allowed(); + } + + template + using is_pure_threads_policy = std::is_same::type>::type>; + } +} + +#endif + +#ifndef POOLSTL_ALGORITHM_HPP +#define POOLSTL_ALGORITHM_HPP + +#include + + +#ifndef POOLSTL_INTERNAL_TTP_IMPL_HPP +#define POOLSTL_INTERNAL_TTP_IMPL_HPP + +#include +#include +#include +#include + + +#ifndef POOLSTL_EXECUTION_HPP +#define POOLSTL_EXECUTION_HPP + +#include +#include +#include #include @@ -102,7 +864,7 @@ // Version macros. #define TASK_THREAD_POOL_VERSION_MAJOR 1 #define TASK_THREAD_POOL_VERSION_MINOR 0 -#define TASK_THREAD_POOL_VERSION_PATCH 9 +#define TASK_THREAD_POOL_VERSION_PATCH 10 #include #include @@ -289,12 +1051,22 @@ namespace task_thread_pool { #else typename R = typename std::result_of(decay_t...)>::type #endif - > + > TTP_NODISCARD std::future submit(F&& func, A&&... args) { +#if defined(_MSC_VER) + // MSVC's packaged_task is not movable even though it should be. + // Discussion about this bug and its future fix: + // https://developercommunity.visualstudio.com/t/unable-to-move-stdpackaged-task-into-any-stl-conta/108672 std::shared_ptr> ptask = std::make_shared>(std::bind(std::forward(func), std::forward(args)...)); submit_detach([ptask] { (*ptask)(); }); return ptask->get_future(); +#else + std::packaged_task task(std::bind(std::forward(func), std::forward(args)...)); + auto ret = task.get_future(); + submit_detach(std::move(task)); + return ret; +#endif } /** @@ -493,7 +1265,7 @@ namespace task_thread_pool { // Version macros. #define POOLSTL_VERSION_MAJOR 0 #define POOLSTL_VERSION_MINOR 3 -#define POOLSTL_VERSION_PATCH 1 +#define POOLSTL_VERSION_PATCH 3 #include #include @@ -637,10 +1409,6 @@ namespace poolstl { #endif -#if POOLSTL_HAVE_CXX17 -#include -#endif - namespace poolstl { namespace ttp = task_thread_pool; @@ -658,92 +1426,129 @@ namespace poolstl { } } + /** + * Base class for all poolSTL policies. + */ + struct poolstl_policy { + }; + /** * A sequential policy that simply forwards to the non-policy overload. */ - struct sequenced_policy {}; + struct sequenced_policy : public poolstl_policy { + POOLSTL_NO_DISCARD ttp::task_thread_pool* pool() const { + // never called, but must exist for C++11 support + throw std::runtime_error("poolSTL: requested thread pool for seq policy."); + } + + POOLSTL_NO_DISCARD bool par_allowed() const { + return false; + } + }; /** * A parallel policy that can use a user-specified thread pool or a default one. 
*/ - struct parallel_policy { + struct parallel_policy : public poolstl_policy { parallel_policy() = default; - explicit parallel_policy(ttp::task_thread_pool& on_pool): on_pool(&on_pool) {} + explicit parallel_policy(ttp::task_thread_pool* on_pool, bool par_ok): on_pool(on_pool), par_ok(par_ok) {} parallel_policy on(ttp::task_thread_pool& pool) const { - return parallel_policy{pool}; + return parallel_policy{&pool, par_ok}; + } + + parallel_policy par_if(bool call_par) const { + return parallel_policy{on_pool, call_par}; } - POOLSTL_NO_DISCARD ttp::task_thread_pool& pool() const { + POOLSTL_NO_DISCARD ttp::task_thread_pool* pool() const { if (on_pool) { - return *on_pool; + return on_pool; } else { - return *(internal::get_default_pool()); + return internal::get_default_pool().get(); } } + POOLSTL_NO_DISCARD bool par_allowed() const { + return par_ok; + } + protected: ttp::task_thread_pool *on_pool = nullptr; + bool par_ok = true; }; constexpr sequenced_policy seq{}; constexpr parallel_policy par{}; - -#if POOLSTL_HAVE_CXX17 /** - * A policy that allows selecting a policy at runtime. + * EXPERIMENTAL: Subject to significant changes or removal. + * Use pure threads for each operation instead of a shared thread pool. * - * @tparam Variant std::variant<> of policy options. + * Advantage: + * - Fewer symbols (no packaged_task with its operators, destructors, vtable, etc) means smaller binary + * which can mean a lot when there are many calls. + * - No thread pool to manage. + * + * Disadvantages: + * - Threads are started and joined for every operation, so it is harder to amortize that cost. + * - Barely any algorithms are supported. */ - template - struct variant_policy { - explicit variant_policy(const Variant& policy): var(policy) {} - Variant var; - }; + struct pure_threads_policy : public poolstl_policy { + explicit pure_threads_policy(unsigned int num_threads, bool par_ok): num_threads(num_threads), + par_ok(par_ok) {} - namespace internal { - using poolstl_policy_variant = std::variant< - poolstl::execution::parallel_policy, - poolstl::execution::sequenced_policy>; - } + POOLSTL_NO_DISCARD unsigned int get_num_threads() const { + if (num_threads == 0) { + return std::thread::hardware_concurrency(); + } + return num_threads; + } + + POOLSTL_NO_DISCARD bool par_allowed() const { + return par_ok; + } + + protected: + unsigned int num_threads = 1; + bool par_ok = true; + }; /** * Choose parallel or sequential at runtime. * * @param call_par Whether to use a parallel policy. - * @return `par` if call_par is true, else `seq`. + * @return `par` if call_par is true, else a sequential policy (like `seq`). */ - inline variant_policy par_if(bool call_par) { - if (call_par) { - return variant_policy(internal::poolstl_policy_variant(par)); - } else { - return variant_policy(internal::poolstl_policy_variant(seq)); - } + inline parallel_policy par_if(bool call_par) { + return parallel_policy{nullptr, call_par}; } /** * Choose parallel or sequential at runtime, with pool selection. * * @param call_par Whether to use a parallel policy. + * @return `par.on(pool)` if call_par is true, else a sequential policy (like `seq`). + */ + inline parallel_policy par_if(bool call_par, ttp::task_thread_pool& pool) { + return parallel_policy{&pool, call_par}; + } + + /** + * EXPERIMENTAL: Subject to significant changes or removal. See `pure_threads_policy`. + * Choose parallel or sequential at runtime, with thread count selection. + * + * @param call_par Whether to use a parallel policy. 
* @return `par.on(pool)` if call_par is true, else `seq`. */ - inline variant_policy par_if(bool call_par, ttp::task_thread_pool& pool) { - if (call_par) { - return variant_policy(internal::poolstl_policy_variant(par.on(pool))); - } else { - return variant_policy(internal::poolstl_policy_variant(seq)); - } + inline pure_threads_policy par_if_threads(bool call_par, unsigned int num_threads) { + return pure_threads_policy{num_threads, call_par}; } -#endif } using execution::seq; using execution::par; -#if POOLSTL_HAVE_CXX17 - using execution::variant_policy; using execution::par_if; -#endif namespace internal { /** @@ -766,44 +1571,29 @@ namespace poolstl { typename std::remove_cv::type>::type>::value, Tp>::type; -#if POOLSTL_HAVE_CXX17 - /** - * Helper for enable_if_poolstl_variant - */ - template struct is_poolstl_variant_policy : std::false_type {}; - template struct is_poolstl_variant_policy< - ::poolstl::execution::variant_policy> :std::true_type {}; - /** - * To enable/disable variant_policy (for par_if) overload resolution + * To enable/disable par overload resolution */ template - using enable_if_poolstl_variant = + using enable_if_poolstl_policy = typename std::enable_if< - is_poolstl_variant_policy< + std::is_base_of::type>::type>::value, Tp>::type; -#endif + + template + bool is_seq(const ExecPolicy& policy) { + return !policy.par_allowed(); + } + + template + using is_pure_threads_policy = std::is_same::type>::type>; } } #endif -#ifndef POOLSTL_ALGORITHM_HPP -#define POOLSTL_ALGORITHM_HPP - -#include - - -#ifndef POOLSTL_INTERNAL_TTP_IMPL_HPP -#define POOLSTL_INTERNAL_TTP_IMPL_HPP - -#include -#include -#include -#include - - namespace poolstl { namespace internal { @@ -815,33 +1605,61 @@ namespace poolstl { std::vector> parallel_apply(ExecPolicy &&policy, Op op, const ArgContainer& args_list) { std::vector> futures; - auto& task_pool = policy.pool(); + auto& task_pool = *policy.pool(); for (const auto& args : args_list) { - futures.emplace_back(task_pool.submit([op](const auto& args_fwd) { std::apply(op, args_fwd); }, args)); + futures.emplace_back(task_pool.submit([](Op op, const auto& args_fwd) { + std::apply(op, args_fwd); + }, op, args)); } return futures; } #endif + /** + * Chunk a single range, with autodetected return types. + */ + template ()(std::declval(), std::declval()))> + std::vector> + parallel_chunk_for_gen(ExecPolicy &&policy, RandIt first, RandIt last, Chunk chunk, + ChunkRet* = (decltype(std::declval()(std::declval(), + std::declval()))*)nullptr, + int extra_split_factor = 1) { + std::vector> futures; + auto& task_pool = *policy.pool(); + auto chunk_size = get_chunk_size(first, last, extra_split_factor * task_pool.get_num_threads()); + + while (first < last) { + auto iter_chunk_size = get_iter_chunk_size(first, last, chunk_size); + RandIt loop_end = advanced(first, iter_chunk_size); + + futures.emplace_back(task_pool.submit(std::forward(chunk), first, loop_end)); + + first = loop_end; + } + + return futures; + } + /** * Chunk a single range. */ - template - std::vector()(std::declval(), std::declval()))>> - parallel_chunk_for(ExecPolicy &&policy, RandIt first, RandIt last, Chunk chunk, int extra_split_factor = 1) { - std::vector()(std::declval(), std::declval())) - >> futures; - auto& task_pool = policy.pool(); + template + std::vector> + parallel_chunk_for_1(ExecPolicy &&policy, RandIt first, RandIt last, + Chunk chunk, ChunkRet*, int extra_split_factor, A&&... 
chunk_args) { + std::vector> futures; + auto& task_pool = *policy.pool(); auto chunk_size = get_chunk_size(first, last, extra_split_factor * task_pool.get_num_threads()); while (first < last) { auto iter_chunk_size = get_iter_chunk_size(first, last, chunk_size); RandIt loop_end = advanced(first, iter_chunk_size); - futures.emplace_back(task_pool.submit(chunk, first, loop_end)); + futures.emplace_back(task_pool.submit(std::forward(chunk), first, loop_end, + std::forward(chunk_args)...)); first = loop_end; } @@ -849,28 +1667,36 @@ namespace poolstl { return futures; } + /** + * Chunk a single range. + */ + template + typename std::enable_if::value, void>::type + parallel_chunk_for_1_wait(ExecPolicy &&policy, RandIt first, RandIt last, + Chunk chunk, ChunkRet* rettype, int extra_split_factor, A&&... chunk_args) { + auto futures = parallel_chunk_for_1(std::forward(policy), first, last, + std::forward(chunk), rettype, extra_split_factor, + std::forward(chunk_args)...); + get_futures(futures); + } + /** * Element-wise chunk two ranges. */ - template - std::vector()( - std::declval(), - std::declval(), - std::declval()))>> - parallel_chunk_for(ExecPolicy &&policy, RandIt1 first1, RandIt1 last1, RandIt2 first2, Chunk chunk) { - std::vector()( - std::declval(), - std::declval(), - std::declval())) - >> futures; - auto& task_pool = policy.pool(); + template + std::vector> + parallel_chunk_for_2(ExecPolicy &&policy, RandIt1 first1, RandIt1 last1, RandIt2 first2, + Chunk chunk, ChunkRet*, A&&... chunk_args) { + std::vector> futures; + auto& task_pool = *policy.pool(); auto chunk_size = get_chunk_size(first1, last1, task_pool.get_num_threads()); while (first1 < last1) { auto iter_chunk_size = get_iter_chunk_size(first1, last1, chunk_size); RandIt1 loop_end = advanced(first1, iter_chunk_size); - futures.emplace_back(task_pool.submit(chunk, first1, loop_end, first2)); + futures.emplace_back(task_pool.submit(std::forward(chunk), first1, loop_end, first2, + std::forward(chunk_args)...)); first1 = loop_end; std::advance(first2, iter_chunk_size); @@ -882,28 +1708,21 @@ namespace poolstl { /** * Element-wise chunk three ranges. */ - template - std::vector()( - std::declval(), - std::declval(), - std::declval(), - std::declval()))>> - parallel_chunk_for(ExecPolicy &&policy, RandIt1 first1, RandIt1 last1, RandIt2 first2, RandIt3 first3, - Chunk chunk) { - std::vector()( - std::declval(), - std::declval(), - std::declval(), - std::declval())) - >> futures; - auto& task_pool = policy.pool(); + template + std::vector> + parallel_chunk_for_3(ExecPolicy &&policy, RandIt1 first1, RandIt1 last1, RandIt2 first2, RandIt3 first3, + Chunk chunk, ChunkRet*, A&&... chunk_args) { + std::vector> futures; + auto& task_pool = *policy.pool(); auto chunk_size = get_chunk_size(first1, last1, task_pool.get_num_threads()); while (first1 < last1) { auto iter_chunk_size = get_iter_chunk_size(first1, last1, chunk_size); RandIt1 loop_end = advanced(first1, iter_chunk_size); - futures.emplace_back(task_pool.submit(chunk, first1, loop_end, first2, first3)); + futures.emplace_back(task_pool.submit(std::forward(chunk), first1, loop_end, first2, first3, + std::forward(chunk_args)...)); first1 = loop_end; std::advance(first2, iter_chunk_size); @@ -916,28 +1735,26 @@ namespace poolstl { /** * Sort a range in parallel. 
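// Illustrative sketch (not part of the patch): the parallel sort below follows a
// sort-then-merge plan -- sort each chunk, then merge sorted neighbours pairwise.
// The same idea written sequentially, for orientation (two chunks only):
#include <algorithm>
#include <iterator>
#include <vector>

void sort_then_merge_demo(std::vector<int>& v) {
    auto mid = std::next(v.begin(), static_cast<std::ptrdiff_t>(v.size() / 2));
    std::sort(v.begin(), mid);                       // chunk 1: one pool task in the real code
    std::sort(mid, v.end());                         // chunk 2: another pool task
    std::inplace_merge(v.begin(), mid, v.end());     // pairwise merge of the sorted halves
}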
* - * @param stable Whether to use std::stable_sort or std::sort + * @param sort_func Sequential sort method, like std::sort or std::stable_sort + * @param merge_func Sequential merge method, like std::inplace_merge */ - template - void parallel_sort(ExecPolicy &&policy, RandIt first, RandIt last, Compare comp, bool stable) { + template + void parallel_sort(ExecPolicy &&policy, RandIt first, RandIt last, + Compare comp, SortFunc sort_func, MergeFunc merge_func) { if (first == last) { return; } // Sort chunks in parallel - auto futures = parallel_chunk_for(std::forward(policy), first, last, - [&comp, stable] (RandIt chunk_first, RandIt chunk_last) { - if (stable) { - std::stable_sort(chunk_first, chunk_last, comp); - } else { - std::sort(chunk_first, chunk_last, comp); - } + auto futures = parallel_chunk_for_gen(std::forward(policy), first, last, + [&comp, sort_func] (RandIt chunk_first, RandIt chunk_last) { + sort_func(chunk_first, chunk_last, comp); return std::make_pair(chunk_first, chunk_last); }); // Merge the sorted ranges using SortedRange = std::pair; - auto& task_pool = policy.pool(); + auto& task_pool = *policy.pool(); std::vector subranges; do { for (auto& future : futures) { @@ -950,9 +1767,10 @@ namespace poolstl { // pair up and merge auto& lhs = subranges[i]; auto& rhs = subranges[i + 1]; - futures.emplace_back(task_pool.submit([&comp] (RandIt chunk_first, RandIt chunk_middle, - RandIt chunk_last) { - std::inplace_merge(chunk_first, chunk_middle, chunk_last, comp); + futures.emplace_back(task_pool.submit([&comp, merge_func] (RandIt chunk_first, + RandIt chunk_middle, + RandIt chunk_last) { + merge_func(chunk_first, chunk_middle, chunk_last, comp); return std::make_pair(chunk_first, chunk_last); }, lhs.first, lhs.second, rhs.second)); ++i; @@ -973,6 +1791,56 @@ namespace poolstl { #endif +#ifndef POOLSTL_INTERNAL_THREAD_IMPL_HPP +#define POOLSTL_INTERNAL_THREAD_IMPL_HPP + +/** + * EXPERIMENTAL: Subject to significant changes or removal. + * An implementation using only std::thread and no thread pool at all. + * + * Advantage: + * - Fewer symbols (no packaged_task with its operators, destructors, vtable, etc) means smaller binary + * which can mean a lot when there are many calls like with many templates. + * - No thread pool to manage. + * + * Disadvantages: + * - Threads are started and joined for every operation, so it is harder to amortize that cost. + * - Barely any algorithms are supported. + */ + + + +namespace poolstl { + namespace internal { + + template + typename std::enable_if::value, void>::type + parallel_chunk_for_1_wait(ExecPolicy &&policy, RandIt first, RandIt last, + Chunk chunk, ChunkRet*, int extra_split_factor, A&&... 
chunk_args) { + std::vector threads; + auto chunk_size = get_chunk_size(first, last, extra_split_factor * policy.get_num_threads()); + + while (first < last) { + auto iter_chunk_size = get_iter_chunk_size(first, last, chunk_size); + RandIt loop_end = advanced(first, iter_chunk_size); + + threads.emplace_back(std::thread(std::forward(chunk), first, loop_end, + std::forward(chunk_args)...)); + + first = loop_end; + } + + for (auto& thread : threads) { + if (thread.joinable()) { + thread.join(); + } + } + } + } +} + +#endif + namespace std { /** @@ -980,12 +1848,14 @@ namespace std { * See std::copy https://en.cppreference.com/w/cpp/algorithm/copy */ template - poolstl::internal::enable_if_par + poolstl::internal::enable_if_poolstl_policy copy(ExecPolicy &&policy, RandIt1 first, RandIt1 last, RandIt2 dest) { - auto futures = poolstl::internal::parallel_chunk_for(std::forward(policy), first, last, dest, - [](RandIt1 chunk_first, RandIt1 chunk_last, RandIt2 chunk_dest) { - std::copy(chunk_first, chunk_last, chunk_dest); - }); + if (poolstl::internal::is_seq(policy)) { + return std::copy(first, last, dest); + } + + auto futures = poolstl::internal::parallel_chunk_for_2(std::forward(policy), first, last, dest, + std::copy, (RandIt2*)nullptr); poolstl::internal::get_futures(futures); return poolstl::internal::advanced(dest, std::distance(first, last)); } @@ -995,7 +1865,7 @@ namespace std { * See std::copy_n https://en.cppreference.com/w/cpp/algorithm/copy_n */ template - poolstl::internal::enable_if_par + poolstl::internal::enable_if_poolstl_policy copy_n(ExecPolicy &&policy, RandIt1 first, Size n, RandIt2 dest) { if (n <= 0) { return dest; @@ -1010,14 +1880,17 @@ namespace std { * See std::count_if https://en.cppreference.com/w/cpp/algorithm/count_if */ template - poolstl::internal::enable_if_par::difference_type> + poolstl::internal::enable_if_poolstl_policy::difference_type> count_if(ExecPolicy&& policy, RandIt first, RandIt last, UnaryPredicate p) { + if (poolstl::internal::is_seq(policy)) { + return std::count_if(first, last, p); + } + using T = typename iterator_traits::difference_type; - auto futures = poolstl::internal::parallel_chunk_for(std::forward(policy), first, last, - [&p](RandIt chunk_first, RandIt chunk_last) { - return std::count_if(chunk_first, chunk_last, p); - }); + auto futures = poolstl::internal::parallel_chunk_for_1(std::forward(policy), first, last, + std::count_if, + (T*)nullptr, 1, p); return poolstl::internal::cpp17::reduce( poolstl::internal::get_wrap(futures.begin()), @@ -1029,7 +1902,7 @@ namespace std { * See std::count https://en.cppreference.com/w/cpp/algorithm/count */ template - poolstl::internal::enable_if_par::difference_type> + poolstl::internal::enable_if_poolstl_policy::difference_type> count(ExecPolicy&& policy, RandIt first, RandIt last, const T& value) { return std::count_if(std::forward(policy), first, last, [&value](const T& test) { return test == value; }); @@ -1040,13 +1913,15 @@ namespace std { * See std::fill https://en.cppreference.com/w/cpp/algorithm/fill */ template - poolstl::internal::enable_if_par + poolstl::internal::enable_if_poolstl_policy fill(ExecPolicy &&policy, RandIt first, RandIt last, const Tp& value) { - auto futures = poolstl::internal::parallel_chunk_for(std::forward(policy), first, last, - [&value](RandIt chunk_first, RandIt chunk_last) { - std::fill(chunk_first, chunk_last, value); - }); - poolstl::internal::get_futures(futures); + if (poolstl::internal::is_seq(policy)) { + std::fill(first, last, value); + return; + } + + 
poolstl::internal::parallel_chunk_for_1_wait(std::forward(policy), first, last, + std::fill, (void*)nullptr, 1, value); } /** @@ -1054,7 +1929,7 @@ namespace std { * See std::fill_n https://en.cppreference.com/w/cpp/algorithm/fill_n */ template - poolstl::internal::enable_if_par + poolstl::internal::enable_if_poolstl_policy fill_n(ExecPolicy &&policy, RandIt first, Size n, const Tp& value) { if (n <= 0) { return first; @@ -1069,13 +1944,17 @@ namespace std { * See std::find_if https://en.cppreference.com/w/cpp/algorithm/find_if */ template - poolstl::internal::enable_if_par + poolstl::internal::enable_if_poolstl_policy find_if(ExecPolicy &&policy, RandIt first, RandIt last, UnaryPredicate p) { + if (poolstl::internal::is_seq(policy)) { + return std::find_if(first, last, p); + } + using diff_t = typename std::iterator_traits::difference_type; diff_t n = std::distance(first, last); std::atomic extremum(n); - auto futures = poolstl::internal::parallel_chunk_for(std::forward(policy), first, last, + poolstl::internal::parallel_chunk_for_1_wait(std::forward(policy), first, last, [&first, &extremum, &p](RandIt chunk_first, RandIt chunk_last) { if (std::distance(first, chunk_first) > extremum) { // already found by another task @@ -1091,8 +1970,8 @@ namespace std { extremum.compare_exchange_weak(old, k); } } - }, 8); // use small tasks so later ones may exit early if item is already found - poolstl::internal::get_futures(futures); + }, (void*)nullptr, + 8); // use small tasks so later ones may exit early if item is already found return extremum == n ? last : first + extremum; } @@ -1101,10 +1980,15 @@ namespace std { * See std::find_if_not https://en.cppreference.com/w/cpp/algorithm/find_if_not */ template - poolstl::internal::enable_if_par + poolstl::internal::enable_if_poolstl_policy find_if_not(ExecPolicy &&policy, RandIt first, RandIt last, UnaryPredicate p) { return std::find_if(std::forward(policy), first, last, - [&p](const typename std::iterator_traits::value_type& test) { return !p(test); }); +#if POOLSTL_HAVE_CXX17_LIB + std::not_fn(p) +#else + [&p](const typename std::iterator_traits::value_type& test) { return !p(test); } +#endif + ); } /** @@ -1112,7 +1996,7 @@ namespace std { * See std::find https://en.cppreference.com/w/cpp/algorithm/find */ template - poolstl::internal::enable_if_par + poolstl::internal::enable_if_poolstl_policy find(ExecPolicy &&policy, RandIt first, RandIt last, const T& value) { return std::find_if(std::forward(policy), first, last, [&value](const T& test) { return value == test; }); @@ -1123,16 +2007,23 @@ namespace std { * See std::for_each https://en.cppreference.com/w/cpp/algorithm/for_each */ template - poolstl::internal::enable_if_par + poolstl::internal::enable_if_poolstl_policy for_each(ExecPolicy &&policy, RandIt first, RandIt last, UnaryFunction f) { - auto futures = poolstl::internal::parallel_chunk_for(std::forward(policy), first, last, - [&f](RandIt chunk_first, RandIt chunk_last) { - // std::for_each(chunk_first, chunk_last, f); - for (; chunk_first != chunk_last; ++chunk_first) { - f(*chunk_first); - } - }); - poolstl::internal::get_futures(futures); + // Using a lambda instead of just calling the non-policy std::for_each because it appears to + // result in a smaller binary. 
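// Illustrative sketch (not part of the patch): the find_if above stops early by
// having every chunk task record the smallest matching index in a shared atomic and
// skip chunks that start beyond it. The per-chunk logic in isolation; names are
// invented for the sketch and `best` is assumed to start out equal to v.size():
#include <atomic>
#include <cstddef>
#include <vector>

void scan_chunk(const std::vector<int>& v, int needle,
                std::size_t chunk_begin, std::size_t chunk_end,
                std::atomic<std::size_t>& best) {
    if (chunk_begin > best.load()) return;               // an earlier chunk already matched
    for (std::size_t i = chunk_begin; i < chunk_end; ++i) {
        if (v[i] == needle) {
            std::size_t old = best.load();
            while (i < old && !best.compare_exchange_weak(old, i)) {
                // lost the race; `old` was refreshed, re-check and retry
            }
            return;
        }
    }
}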
+ auto chunk_func = [&f](RandIt chunk_first, RandIt chunk_last) { + for (; chunk_first != chunk_last; ++chunk_first) { + f(*chunk_first); + } + }; + + if (poolstl::internal::is_seq(policy)) { + chunk_func(first, last); + return; + } + + poolstl::internal::parallel_chunk_for_1_wait(std::forward(policy), first, last, + chunk_func, (void*)nullptr, 1); } /** @@ -1140,7 +2031,7 @@ namespace std { * See std::for_each_n https://en.cppreference.com/w/cpp/algorithm/for_each_n */ template - poolstl::internal::enable_if_par + poolstl::internal::enable_if_poolstl_policy for_each_n(ExecPolicy &&policy, RandIt first, Size n, UnaryFunction f) { RandIt last = poolstl::internal::advanced(first, n); std::for_each(std::forward(policy), first, last, f); @@ -1152,9 +2043,15 @@ namespace std { * See std::sort https://en.cppreference.com/w/cpp/algorithm/sort */ template - poolstl::internal::enable_if_par + poolstl::internal::enable_if_poolstl_policy sort(ExecPolicy &&policy, RandIt first, RandIt last, Compare comp) { - poolstl::internal::parallel_sort(std::forward(policy), first, last, comp, false); + if (poolstl::internal::is_seq(policy)) { + std::sort(first, last, comp); + return; + } + + poolstl::internal::parallel_sort(std::forward(policy), first, last, comp, + std::sort, std::inplace_merge); } /** @@ -1162,10 +2059,10 @@ namespace std { * See std::sort https://en.cppreference.com/w/cpp/algorithm/sort */ template - poolstl::internal::enable_if_par + poolstl::internal::enable_if_poolstl_policy sort(ExecPolicy &&policy, RandIt first, RandIt last) { using T = typename std::iterator_traits::value_type; - poolstl::internal::parallel_sort(std::forward(policy), first, last, std::less(), false); + std::sort(std::forward(policy), first, last, std::less()); } /** @@ -1173,9 +2070,15 @@ namespace std { * See std::stable_sort https://en.cppreference.com/w/cpp/algorithm/stable_sort */ template - poolstl::internal::enable_if_par + poolstl::internal::enable_if_poolstl_policy stable_sort(ExecPolicy &&policy, RandIt first, RandIt last, Compare comp) { - poolstl::internal::parallel_sort(std::forward(policy), first, last, comp, true); + if (poolstl::internal::is_seq(policy)) { + std::stable_sort(first, last, comp); + return; + } + + poolstl::internal::parallel_sort(std::forward(policy), first, last, comp, + std::stable_sort, std::inplace_merge); } /** @@ -1183,10 +2086,10 @@ namespace std { * See std::stable_sort https://en.cppreference.com/w/cpp/algorithm/stable_sort */ template - poolstl::internal::enable_if_par + poolstl::internal::enable_if_poolstl_policy stable_sort(ExecPolicy &&policy, RandIt first, RandIt last) { using T = typename std::iterator_traits::value_type; - poolstl::internal::parallel_sort(std::forward(policy), first, last, std::less(), true); + std::stable_sort(std::forward(policy), first, last, std::less()); } /** @@ -1194,14 +2097,17 @@ namespace std { * See std::transform https://en.cppreference.com/w/cpp/algorithm/transform */ template - poolstl::internal::enable_if_par + poolstl::internal::enable_if_poolstl_policy transform(ExecPolicy&& policy, RandIt1 first1, RandIt1 last1, RandIt2 dest, UnaryOperation unary_op) { + if (poolstl::internal::is_seq(policy)) { + return poolstl::internal::cpp17::transform(first1, last1, dest, unary_op); + } - auto futures = poolstl::internal::parallel_chunk_for(std::forward(policy), first1, last1, dest, - [&unary_op](RandIt1 chunk_first1, RandIt1 chunk_last1, RandIt2 dest_first) { - return poolstl::internal::cpp17::transform(chunk_first1, chunk_last1, dest_first, unary_op); 
- }); + auto futures = poolstl::internal::parallel_chunk_for_2(std::forward(policy), first1, last1, dest, + poolstl::internal::cpp17::transform, + (RandIt2*)nullptr, unary_op); poolstl::internal::get_futures(futures); return dest + std::distance(first1, last1); } @@ -1211,16 +2117,18 @@ namespace std { * See std::transform https://en.cppreference.com/w/cpp/algorithm/transform */ template - poolstl::internal::enable_if_par + poolstl::internal::enable_if_poolstl_policy transform(ExecPolicy&& policy, RandIt1 first1, RandIt1 last1, RandIt2 first2, RandIt3 dest, BinaryOperation binary_op) { + if (poolstl::internal::is_seq(policy)) { + return poolstl::internal::cpp17::transform(first1, last1, first2, dest, binary_op); + } - auto futures = poolstl::internal::parallel_chunk_for(std::forward(policy), first1, last1, - first2, dest, - [&binary_op](RandIt1 chunk_first1, RandIt1 chunk_last1, RandIt1 chunk_first2, RandIt3 dest_first) { - return poolstl::internal::cpp17::transform(chunk_first1, chunk_last1, - chunk_first2, dest_first, binary_op); - }); + auto futures = poolstl::internal::parallel_chunk_for_3(std::forward(policy), first1, last1, + first2, dest, + poolstl::internal::cpp17::transform, + (RandIt3*)nullptr, binary_op); poolstl::internal::get_futures(futures); return dest + std::distance(first1, last1); } @@ -1230,7 +2138,7 @@ namespace std { * See std::all_of https://en.cppreference.com/w/cpp/algorithm/all_of */ template - poolstl::internal::enable_if_par + poolstl::internal::enable_if_poolstl_policy all_of(ExecPolicy&& policy, RandIt first, RandIt last, Predicate pred) { return last == std::find_if_not(std::forward(policy), first, last, pred); } @@ -1240,7 +2148,7 @@ namespace std { * See std::none_of https://en.cppreference.com/w/cpp/algorithm/none_of */ template - poolstl::internal::enable_if_par + poolstl::internal::enable_if_poolstl_policy none_of(ExecPolicy&& policy, RandIt first, RandIt last, Predicate pred) { return last == std::find_if(std::forward(policy), first, last, pred); } @@ -1250,7 +2158,7 @@ namespace std { * See std::any_of https://en.cppreference.com/w/cpp/algorithm/any_of */ template - poolstl::internal::enable_if_par + poolstl::internal::enable_if_poolstl_policy any_of(ExecPolicy&& policy, RandIt first, RandIt last, Predicate pred) { return !std::none_of(std::forward(policy), first, last, pred); } @@ -1280,13 +2188,52 @@ namespace poolstl { * but cannot be shared by all parallel iterations. */ template - poolstl::internal::enable_if_par + poolstl::internal::enable_if_poolstl_policy for_each_chunk(ExecPolicy&& policy, RandIt first, RandIt last, ChunkConstructor construct, UnaryFunction f) { - auto futures = poolstl::internal::parallel_chunk_for(std::forward(policy), first, last, - [&construct, &f](RandIt chunk_first, RandIt chunk_last) { - for_each_chunk(chunk_first, chunk_last, construct, f); - }); - poolstl::internal::get_futures(futures); + if (poolstl::internal::is_seq(policy)) { + for_each_chunk(first, last, construct, f); + return; + } + + poolstl::internal::parallel_chunk_for_1_wait(std::forward(policy), first, last, + for_each_chunk , + (void*)nullptr, 1, construct, f); + } + + /** + * NOTE: Iterators are expected to be random access. + * + * Like `std::sort`, but allows specifying the sequential sort and merge methods. These methods must have the + * same signature as the comparator versions of `std::sort` and `std::inplace_merge`, respectively. 
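// Illustrative sketch (not part of the patch): this is the hook the strobealign
// change later in this series relies on -- pdqsort_branchless becomes the per-chunk
// sequential sort while the merging stays std::inplace_merge. A plain uint64_t
// vector stands in for the real randstrobe array here.
#include <cstdint>
#include <vector>
#include "pdqsort/pdqsort.h"
#include "poolstl/poolstl.hpp"

void pluggable_sort_demo(std::vector<uint64_t>& keys, unsigned n_threads) {
    task_thread_pool::task_thread_pool pool{n_threads};
    poolstl::pluggable_sort(poolstl::par.on(pool), keys.begin(), keys.end(),
                            pdqsort_branchless);     // sequential sort per chunk
}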
+ */ + template + poolstl::internal::enable_if_poolstl_policy + pluggable_sort(ExecPolicy &&policy, RandIt first, RandIt last, Compare comp, + void (sort_func)(RandIt, RandIt, Compare) = std::sort, + void (merge_func)(RandIt, RandIt, RandIt, Compare) = std::inplace_merge) { + if (poolstl::internal::is_seq(policy)) { + sort_func(first, last, comp); + return; + } + + poolstl::internal::parallel_sort(std::forward(policy), first, last, comp, sort_func, merge_func); + } + + /** + * NOTE: Iterators are expected to be random access. + * + * Like `std::sort`, but allows specifying the sequential sort and merge methods. These methods must have the + * same signature as the comparator versions of `std::sort` and `std::inplace_merge`, respectively. + */ + template + poolstl::internal::enable_if_poolstl_policy + pluggable_sort(ExecPolicy &&policy, RandIt first, RandIt last, + void (sort_func)(RandIt, RandIt, + std::less::value_type>) = std::sort, + void (merge_func)(RandIt, RandIt, RandIt, + std::less::value_type>) = std::inplace_merge){ + using T = typename std::iterator_traits::value_type; + pluggable_sort(std::forward(policy), first, last, std::less(), sort_func, merge_func); } } @@ -1306,14 +2253,18 @@ namespace std { * See std::exclusive_scan https://en.cppreference.com/w/cpp/algorithm/exclusive_scan */ template - poolstl::internal::enable_if_par + poolstl::internal::enable_if_poolstl_policy exclusive_scan(ExecPolicy &&policy, RandIt1 first, RandIt1 last, RandIt2 dest, T init, BinaryOp binop) { if (first == last) { return dest; } + if (poolstl::internal::is_seq(policy)) { + return std::exclusive_scan(first, last, dest, init, binop); + } + // Pass 1: Chunk the input and find the sum of each chunk - auto futures = poolstl::internal::parallel_chunk_for(std::forward(policy), first, last, + auto futures = poolstl::internal::parallel_chunk_for_gen(std::forward(policy), first, last, [binop](RandIt1 chunk_first, RandIt1 chunk_last) { auto sum = std::accumulate(chunk_first, chunk_last, T{}, binop); return std::make_tuple(std::make_pair(chunk_first, chunk_last), sum); @@ -1355,7 +2306,7 @@ namespace std { * See std::exclusive_scan https://en.cppreference.com/w/cpp/algorithm/exclusive_scan */ template - poolstl::internal::enable_if_par + poolstl::internal::enable_if_poolstl_policy exclusive_scan(ExecPolicy &&policy, RandIt1 first, RandIt1 last, RandIt2 dest, T init) { return std::exclusive_scan(std::forward(policy), first, last, dest, init, std::plus()); } @@ -1366,12 +2317,15 @@ namespace std { * See std::reduce https://en.cppreference.com/w/cpp/algorithm/reduce */ template - poolstl::internal::enable_if_par + poolstl::internal::enable_if_poolstl_policy reduce(ExecPolicy &&policy, RandIt first, RandIt last, T init, BinaryOp binop) { - auto futures = poolstl::internal::parallel_chunk_for(std::forward(policy), first, last, - [init, binop](RandIt chunk_first, RandIt chunk_last) { - return poolstl::internal::cpp17::reduce(chunk_first, chunk_last, init, binop); - }); + if (poolstl::internal::is_seq(policy)) { + return poolstl::internal::cpp17::reduce(first, last, init, binop); + } + + auto futures = poolstl::internal::parallel_chunk_for_1(std::forward(policy), first, last, + poolstl::internal::cpp17::reduce, + (T*)nullptr, 1, init, binop); return poolstl::internal::cpp17::reduce( poolstl::internal::get_wrap(futures.begin()), @@ -1383,7 +2337,7 @@ namespace std { * See std::reduce https://en.cppreference.com/w/cpp/algorithm/reduce */ template - poolstl::internal::enable_if_par + 
poolstl::internal::enable_if_poolstl_policy reduce(ExecPolicy &&policy, RandIt first, RandIt last, T init) { return std::reduce(std::forward(policy), first, last, init, std::plus()); } @@ -1393,7 +2347,7 @@ namespace std { * See std::reduce https://en.cppreference.com/w/cpp/algorithm/reduce */ template - poolstl::internal::enable_if_par< + poolstl::internal::enable_if_poolstl_policy< ExecPolicy, typename std::iterator_traits::value_type> reduce(ExecPolicy &&policy, RandIt first, RandIt last) { return std::reduce(std::forward(policy), first, last, @@ -1406,14 +2360,17 @@ namespace std { * See std::transform_reduce https://en.cppreference.com/w/cpp/algorithm/transform_reduce */ template - poolstl::internal::enable_if_par + poolstl::internal::enable_if_poolstl_policy transform_reduce(ExecPolicy&& policy, RandIt1 first1, RandIt1 last1, T init, BinaryReductionOp reduce_op, UnaryTransformOp transform_op) { + if (poolstl::internal::is_seq(policy)) { + return std::transform_reduce(first1, last1, init, reduce_op, transform_op); + } - auto futures = poolstl::internal::parallel_chunk_for(std::forward(policy), first1, last1, - [&init, &reduce_op, &transform_op](RandIt1 chunk_first1, RandIt1 chunk_last1) { - return std::transform_reduce(chunk_first1, chunk_last1, init, reduce_op, transform_op); - }); + auto futures = poolstl::internal::parallel_chunk_for_1(std::forward(policy), first1, last1, + std::transform_reduce, + (T*)nullptr, 1, init, reduce_op, transform_op); return poolstl::internal::cpp17::reduce( poolstl::internal::get_wrap(futures.begin()), @@ -1425,14 +2382,17 @@ namespace std { * See std::transform_reduce https://en.cppreference.com/w/cpp/algorithm/transform_reduce */ template - poolstl::internal::enable_if_par + poolstl::internal::enable_if_poolstl_policy transform_reduce(ExecPolicy&& policy, RandIt1 first1, RandIt1 last1, RandIt2 first2, T init, BinaryReductionOp reduce_op, BinaryTransformOp transform_op) { + if (poolstl::internal::is_seq(policy)) { + return std::transform_reduce(first1, last1, first2, init, reduce_op, transform_op); + } - auto futures = poolstl::internal::parallel_chunk_for(std::forward(policy), first1, last1, first2, - [&init, &reduce_op, &transform_op](RandIt1 chunk_first1, RandIt1 chunk_last1, RandIt2 chunk_first2) { - return std::transform_reduce(chunk_first1, chunk_last1, chunk_first2, init, reduce_op, transform_op); - }); + auto futures = poolstl::internal::parallel_chunk_for_2(std::forward(policy), first1, last1, first2, + std::transform_reduce, + (T*)nullptr, init, reduce_op, transform_op); return poolstl::internal::cpp17::reduce( poolstl::internal::get_wrap(futures.begin()), @@ -1444,7 +2404,7 @@ namespace std { * See std::transform_reduce https://en.cppreference.com/w/cpp/algorithm/transform_reduce */ template< class ExecPolicy, class RandIt1, class RandIt2, class T > - poolstl::internal::enable_if_par + poolstl::internal::enable_if_poolstl_policy transform_reduce(ExecPolicy&& policy, RandIt1 first1, RandIt1 last1, RandIt2 first2, T init ) { return transform_reduce(std::forward(policy), first1, last1, first2, init, std::plus<>(), std::multiplies<>()); @@ -1455,100 +2415,6 @@ namespace std { #endif -#ifndef POOLSTL_SEQ_FWD_HPP -#define POOLSTL_SEQ_FWD_HPP - - -/* - * Forward poolstl::seq to the native sequential (no policy) method. 
- */ - -#define POOLSTL_DEFINE_SEQ_FWD(NS, FNAME) \ - template \ - auto FNAME(EP&&, ARGS&&...args) -> \ - poolstl::internal::enable_if_seq(args)...))> { \ - return NS::FNAME(std::forward(args)...); \ - } - -#define POOLSTL_DEFINE_SEQ_FWD_VOID(NS, FNAME) \ - template \ - poolstl::internal::enable_if_seq FNAME(EP&&, ARGS&&... args) { \ - NS::FNAME(std::forward(args)...); \ - } - -#if POOLSTL_HAVE_CXX17 - -/* - * Dynamically choose policy from a std::variant. - * Useful to choose between parallel and sequential policies at runtime via par_if. - */ - -#define POOLSTL_DEFINE_PAR_IF_FWD_VOID(NS, FNAME) \ - template \ - poolstl::internal::enable_if_poolstl_variant FNAME(EP&& policy, ARGS&&...args) { \ - std::visit([&](auto&& pol) { NS::FNAME(pol, std::forward(args)...); }, policy.var); \ - } - -#define POOLSTL_DEFINE_PAR_IF_FWD(NS, FNAME) \ - template \ - auto FNAME(EP&& policy, ARGS&&...args) -> \ - poolstl::internal::enable_if_poolstl_variant(args)...))> { \ - return std::visit([&](auto&& pol) { return NS::FNAME(pol, std::forward(args)...); }, policy.var); \ - } - -#else -#define POOLSTL_DEFINE_PAR_IF_FWD_VOID(NS, FNAME) -#define POOLSTL_DEFINE_PAR_IF_FWD(NS, FNAME) -#endif -/* - * Define both the sequential forward and dynamic chooser. - */ -#define POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(NS, FNAME) \ - POOLSTL_DEFINE_SEQ_FWD(NS, FNAME) \ - POOLSTL_DEFINE_PAR_IF_FWD(NS, FNAME) - -#define POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF_VOID(NS, FNAME) \ - POOLSTL_DEFINE_SEQ_FWD_VOID(NS, FNAME) \ - POOLSTL_DEFINE_PAR_IF_FWD_VOID(NS, FNAME) - -namespace std { - POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, all_of) - POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, any_of) - POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, none_of) - - POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, count) - POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, count_if) - - POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, copy) - POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, copy_n) - - POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF_VOID(std, fill) - POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, fill_n) - - POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, find) - POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, find_if) - POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, find_if_not) - - POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF_VOID(std, for_each) -#if POOLSTL_HAVE_CXX17_LIB - POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, for_each_n) -#endif - - POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, transform) - -#if POOLSTL_HAVE_CXX17_LIB - POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, exclusive_scan) - POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, reduce) - POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF(std, transform_reduce) -#endif -} - -namespace poolstl { - POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF_VOID(poolstl, for_each_chunk) -} - -#endif - // Note that iota_iter.hpp is self-contained in its own right. 
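// Illustrative sketch (not part of the patch): why the std::variant/std::visit
// forwarding above could be dropped. par_if() now returns an ordinary parallel_policy
// whose par_allowed() flag makes each algorithm take its own sequential branch, so a
// single overload covers both cases. From the caller's side it looks like this:
#include <algorithm>
#include <vector>
#include "poolstl/poolstl.hpp"

void maybe_parallel_sort(std::vector<int>& v) {
    // Runtime decision, no variant involved; small inputs run the plain std::sort path.
    std::sort(poolstl::par_if(v.size() > 100000), v.begin(), v.end());
}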
#ifndef POOLSTL_IOTA_ITER_HPP diff --git a/src/index.cpp b/src/index.cpp index 9b907257..9b3fbfab 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -165,13 +165,8 @@ void StrobemerIndex::populate(float f, unsigned n_threads) { Timer sorting_timer; logger.debug() << " Sorting ...\n"; - if (true) { - task_thread_pool::task_thread_pool pool{n_threads}; - std::sort(poolstl::par.on(pool), randstrobes.begin(), randstrobes.end()); - } else { - // sort by hash values - pdqsort_branchless(randstrobes.begin(), randstrobes.end()); - } + task_thread_pool::task_thread_pool pool{n_threads}; + poolstl::pluggable_sort(poolstl::par.on(pool), randstrobes.begin(), randstrobes.end(), pdqsort_branchless); stats.elapsed_sorting_seeds = sorting_timer.duration(); Timer hash_index_timer; From bc96aa77aff9764147e9eb48a4374815dedf14bf Mon Sep 17 00:00:00 2001 From: Marcel Martin Date: Mon, 22 Jan 2024 13:46:09 +0100 Subject: [PATCH 08/26] Bump to poolSTL 0.3.4 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This version indeed gives us a very nice speed improvement without using extra memory: Sorting-only runtimes: * 1 thread: 32.8 s * 2 threads: 20.3 s * 4 threads: 15.7 s * 8 threads: 14.8 s Overall indexing runtimes (before/after): * 1 thread: 151 s → 153 s * 2 threads: 100 s → 88 s * 4 threads: 73 s → 57 s * 8 threads: 63 s → 47 s --- CHANGES.md | 4 + README.md | 2 +- ext/README.md | 4 +- ext/poolstl/poolstl.hpp | 284 +++++++++++++++++++++++++++++++++++++--- 4 files changed, 275 insertions(+), 19 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 266b5832..644e9644 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,10 @@ ## development version +* #386: Parallelize indexing even more by using @alugowski’s + [poolSTL](https://github.com/alugowski/) `pluggable_sort`. + Indexing a human reference (measured on CHM13) now takes only ~45 s on a + recent machine (using 8 threads). * #376: Improve accuracy for read length 50 by optimizing the default indexing parameters. Paired-end accuracy increases by 0.3 percentage points on average. Single-end accuracy increases by 1 percentage point. diff --git a/README.md b/README.md index f7ec6e05..1ef03a51 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ Strobealign is a read mapper that is typically significantly faster than other r - Map single-end and paired-end reads - Multithreading support -- Fast indexing (1-2 minutes for a human-sized reference genome using four cores) +- Fast indexing (<1 minute for a human-sized reference genome using four cores) - On-the-fly indexing by default. Optionally create an on-disk index. - Output in standard SAM format or produce even faster results by writing PAF (without alignments) - Strobealign is most suited for read lengths between 100 and 500 bp diff --git a/ext/README.md b/ext/README.md index d80e5a2d..a3480467 100644 --- a/ext/README.md +++ b/ext/README.md @@ -30,8 +30,8 @@ License: See pdqsort/license.txt ## poolstl Homepage: https://github.com/alugowski/poolSTL/ -Downloaded file: https://github.com/alugowski/poolSTL/releases/download/v0.3.3/poolstl.hpp -Version: 0.3.3 +Downloaded file: https://github.com/alugowski/poolSTL/releases/download/v0.3.4/poolstl.hpp +Version: 0.3.4 License: See poolstl.hpp ## robin_hood diff --git a/ext/poolstl/poolstl.hpp b/ext/poolstl/poolstl.hpp index d1340a1e..8d569ecb 100644 --- a/ext/poolstl/poolstl.hpp +++ b/ext/poolstl/poolstl.hpp @@ -505,7 +505,7 @@ namespace task_thread_pool { // Version macros. 
#define POOLSTL_VERSION_MAJOR 0 #define POOLSTL_VERSION_MINOR 3 -#define POOLSTL_VERSION_PATCH 3 +#define POOLSTL_VERSION_PATCH 4 #include #include @@ -596,6 +596,28 @@ namespace poolstl { } } + /** + * Identify a pivot element for quicksort. Chooses the middle element of the range. + */ + template + typename std::iterator_traits::value_type quicksort_pivot(Iterator first, Iterator last) { + return *(std::next(first, std::distance(first, last) / 2)); + } + + /** + * Predicate for std::partition (for quicksort) + */ + template + struct pivot_predicate { + pivot_predicate(Compare comp, const T& pivot) : comp(comp), pivot(pivot) {} + + bool operator()(const T& em) { + return comp(em, pivot); + } + Compare comp; + const T pivot; + }; + /* * Some methods are only available with C++17 and up. Reimplement on older standards. */ @@ -1265,7 +1287,7 @@ namespace task_thread_pool { // Version macros. #define POOLSTL_VERSION_MAJOR 0 #define POOLSTL_VERSION_MINOR 3 -#define POOLSTL_VERSION_PATCH 3 +#define POOLSTL_VERSION_PATCH 4 #include #include @@ -1356,6 +1378,28 @@ namespace poolstl { } } + /** + * Identify a pivot element for quicksort. Chooses the middle element of the range. + */ + template + typename std::iterator_traits::value_type quicksort_pivot(Iterator first, Iterator last) { + return *(std::next(first, std::distance(first, last) / 2)); + } + + /** + * Predicate for std::partition (for quicksort) + */ + template + struct pivot_predicate { + pivot_predicate(Compare comp, const T& pivot) : comp(comp), pivot(pivot) {} + + bool operator()(const T& em) { + return comp(em, pivot); + } + Compare comp; + const T pivot; + }; + /* * Some methods are only available with C++17 and up. Reimplement on older standards. */ @@ -1739,8 +1783,8 @@ namespace poolstl { * @param merge_func Sequential merge method, like std::inplace_merge */ template - void parallel_sort(ExecPolicy &&policy, RandIt first, RandIt last, - Compare comp, SortFunc sort_func, MergeFunc merge_func) { + void parallel_mergesort(ExecPolicy &&policy, RandIt first, RandIt last, + Compare comp, SortFunc sort_func, MergeFunc merge_func) { if (first == last) { return; } @@ -1786,6 +1830,103 @@ namespace poolstl { } while (futures.size() > 1); futures.front().get(); } + + /** + * Quicksort worker function. + */ + template + void quicksort_impl(task_thread_pool::task_thread_pool* task_pool, const RandIt first, const RandIt last, + Compare comp, SortFunc sort_func, PartFunc part_func, PivotFunc pivot_func, + std::ptrdiff_t target_leaf_size, + std::vector>* futures, std::mutex* mutex, + std::condition_variable* cv, int* inflight_spawns) { + using T = typename std::iterator_traits::value_type; + + auto partition_size = std::distance(first, last); + + if (partition_size > target_leaf_size) { + // partition the range + auto mid = part_func(first, last, pivot_predicate(comp, pivot_func(first, last))); + + if (mid != first && mid != last) { + // was able to partition the range, so recurse + std::lock_guard guard(*mutex); + ++(*inflight_spawns); + + futures->emplace_back(task_pool->submit( + quicksort_impl, + task_pool, first, mid, comp, sort_func, part_func, pivot_func, target_leaf_size, + futures, mutex, cv, inflight_spawns)); + + futures->emplace_back(task_pool->submit( + quicksort_impl, + task_pool, mid, last, comp, sort_func, part_func, pivot_func, target_leaf_size, + futures, mutex, cv, inflight_spawns)); + return; + } + } + + // Range does not need to be subdivided (or was unable to subdivide). Run the sequential sort. 
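The two helpers added above are easiest to see on a concrete range: `quicksort_pivot` picks the middle element, and `pivot_predicate` wraps the comparator and pivot so that `std::partition` can split the range around it. A toy rehearsal with plain `std::` calls (illustrative only, not the poolSTL internals themselves):

```
#include <algorithm>
#include <cassert>
#include <functional>
#include <vector>

int main() {
    std::vector<int> v{9, 3, 7, 5, 8, 2, 6};

    // quicksort_pivot: the element in the middle of the range (here: 5).
    int pivot = *(v.begin() + v.size() / 2);

    // pivot_predicate: "is this element ordered before the pivot?"
    std::less<int> comp;
    auto pred = [&](int x) { return comp(x, pivot); };

    // quicksort_impl partitions with this predicate, then either recurses on
    // the two halves or, once they are small enough, hands them to sort_func.
    auto mid = std::partition(v.begin(), v.end(), pred);

    assert(std::all_of(v.begin(), mid, pred));   // {3, 2} in some order
    assert(std::none_of(mid, v.end(), pred));    // {9, 7, 5, 8, 6} in some order
}
```

If the partition comes back empty on one side (`mid == first` or `mid == last`), the worker above skips the recursion and simply runs the sequential sort on the whole range.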
+ { + // notify main thread that partitioning may be finished + std::lock_guard guard(*mutex); + --(*inflight_spawns); + } + cv->notify_one(); + + sort_func(first, last, comp); + } + + /** + * Sort a range in parallel using quicksort. + * + * @param sort_func Sequential sort method, like std::sort or std::stable_sort + * @param part_func Method that partitions a range, like std::partition or std::stable_partition + * @param pivot_func Method that identifies the pivot + */ + template + void parallel_quicksort(ExecPolicy &&policy, RandIt first, RandIt last, + Compare comp, SortFunc sort_func, PartFunc part_func, PivotFunc pivot_func) { + if (first == last) { + return; + } + + auto& task_pool = *policy.pool(); + + // Target partition size. Range will be recursively partitioned into partitions no bigger than this + // size. Target approximately twice as many partitions as threads to reduce impact of uneven pivot + // selection. + std::ptrdiff_t target_leaf_size = std::max(std::distance(first, last) / (task_pool.get_num_threads() * 2), + (std::ptrdiff_t)5); + + // task_thread_pool does not support creating task DAGs, so organize the code such that + // all parallel tasks are independent. The parallel tasks can spawn additional parallel tasks, and they + // record their "child" task's std::future into a common vector to be waited on by the main thread. + std::mutex mutex; + + // Futures of parallel tasks. Access protected by mutex. + std::vector> futures; + + // For signaling that all partitioning has been completed and futures vector is complete. Uses mutex. + std::condition_variable cv; + + // Number of `quicksort_impl` calls that haven't finished yet. Nonzero value means futures vector may + // still be modified. Access protected by mutex. + int inflight_spawns = 1; + + // Root task. + quicksort_impl(&task_pool, first, last, comp, sort_func, part_func, pivot_func, target_leaf_size, + &futures, &mutex, &cv, &inflight_spawns); + + // Wait for all partitioning to finish. + { + std::unique_lock lock(mutex); + cv.wait(lock, [&] { return inflight_spawns == 0; }); + } + + // Wait on all the parallel tasks. + get_futures(futures); + } } } @@ -2050,8 +2191,11 @@ namespace std { return; } - poolstl::internal::parallel_sort(std::forward(policy), first, last, comp, - std::sort, std::inplace_merge); + poolstl::internal::parallel_quicksort(std::forward(policy), first, last, comp, + std::sort, + std::partition::value_type>>, + poolstl::internal::quicksort_pivot); } /** @@ -2077,8 +2221,11 @@ namespace std { return; } - poolstl::internal::parallel_sort(std::forward(policy), first, last, comp, - std::stable_sort, std::inplace_merge); + poolstl::internal::parallel_quicksort(std::forward(policy), first, last, comp, + std::stable_sort, + std::stable_partition::value_type>>, + poolstl::internal::quicksort_pivot); } /** @@ -2203,37 +2350,142 @@ namespace poolstl { /** * NOTE: Iterators are expected to be random access. * - * Like `std::sort`, but allows specifying the sequential sort and merge methods. These methods must have the - * same signature as the comparator versions of `std::sort` and `std::inplace_merge`, respectively. + * Like `std::sort`, but allows specifying the sequential sort method, which must have the + * same signature as the comparator version of `std::sort`. + * + * Implemented as a high-level quicksort that delegates to `sort_func`, in parallel, once the range has been + * sufficiently partitioned. 
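The `target_leaf_size` heuristic in `parallel_quicksort` above aims for roughly twice as many leaf partitions as worker threads, with a floor of 5 elements per leaf. A compile-time restatement of that arithmetic (an illustrative copy of the formula, not poolSTL code):

```
#include <algorithm>
#include <cstddef>

constexpr std::ptrdiff_t target_leaf_size(std::ptrdiff_t n, unsigned num_threads) {
    return std::max(n / std::ptrdiff_t(num_threads * 2), std::ptrdiff_t(5));
}

// 1,000,000 elements on 8 threads: leaves of at most 62,500 elements,
// i.e. about 16 leaves, twice the thread count.
static_assert(target_leaf_size(1'000'000, 8) == 62'500, "");
// Tiny ranges never drop below 5 elements per leaf.
static_assert(target_leaf_size(12, 8) == 5, "");
```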
*/ template poolstl::internal::enable_if_poolstl_policy pluggable_sort(ExecPolicy &&policy, RandIt first, RandIt last, Compare comp, - void (sort_func)(RandIt, RandIt, Compare) = std::sort, - void (merge_func)(RandIt, RandIt, RandIt, Compare) = std::inplace_merge) { + void (sort_func)(RandIt, RandIt, Compare) = std::sort) { if (poolstl::internal::is_seq(policy)) { sort_func(first, last, comp); return; } - poolstl::internal::parallel_sort(std::forward(policy), first, last, comp, sort_func, merge_func); + poolstl::internal::parallel_quicksort(std::forward(policy), first, last, comp, sort_func, + std::partition::value_type>>, + poolstl::internal::quicksort_pivot); } /** * NOTE: Iterators are expected to be random access. * - * Like `std::sort`, but allows specifying the sequential sort and merge methods. These methods must have the - * same signature as the comparator versions of `std::sort` and `std::inplace_merge`, respectively. + * Like `std::sort`, but allows specifying the sequential sort method, which must have the + * same signature as the comparator version of `std::sort`. + * + * Implemented as a parallel high-level quicksort that delegates to `sort_func` once the range has been + * sufficiently partitioned. */ template poolstl::internal::enable_if_poolstl_policy pluggable_sort(ExecPolicy &&policy, RandIt first, RandIt last, + void (sort_func)(RandIt, RandIt, + std::less::value_type>) = std::sort){ + using T = typename std::iterator_traits::value_type; + pluggable_sort(std::forward(policy), first, last, std::less(), sort_func); + } + + /** + * NOTE: Iterators are expected to be random access. + * + * Parallel merge sort. + * + * @param comp Comparator. + * @param sort_func Sequential sort method. Must have the same signature as the comparator version of `std::sort`. + * @param merge_func Sequential merge method. Must have the same signature as `std::inplace_merge`. + */ + template + poolstl::internal::enable_if_poolstl_policy + pluggable_mergesort(ExecPolicy &&policy, RandIt first, RandIt last, Compare comp, + void (sort_func)(RandIt, RandIt, Compare) = std::sort, + void (merge_func)(RandIt, RandIt, RandIt, Compare) = std::inplace_merge) { + if (poolstl::internal::is_seq(policy)) { + sort_func(first, last, comp); + return; + } + + poolstl::internal::parallel_mergesort(std::forward(policy), + first, last, comp, sort_func, merge_func); + } + + /** + * NOTE: Iterators are expected to be random access. + * + * Parallel merge sort. + * + * Uses `std::less` comparator. + * + * @param sort_func Sequential sort method. Must have the same signature as the comparator version of `std::sort`. + * @param merge_func Sequential merge method. Must have the same signature as `std::inplace_merge`. + */ + template + poolstl::internal::enable_if_poolstl_policy + pluggable_mergesort(ExecPolicy &&policy, RandIt first, RandIt last, void (sort_func)(RandIt, RandIt, std::less::value_type>) = std::sort, void (merge_func)(RandIt, RandIt, RandIt, std::less::value_type>) = std::inplace_merge){ using T = typename std::iterator_traits::value_type; - pluggable_sort(std::forward(policy), first, last, std::less(), sort_func, merge_func); + pluggable_mergesort(std::forward(policy), first, last, std::less(), sort_func, merge_func); + } + + /** + * NOTE: Iterators are expected to be random access. + * + * Parallel quicksort that allows specifying the sequential sort and partition methods. + * + * @param comp Comparator. + * @param sort_func Sequential sort method to use once range is sufficiently partitioned. 
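For callers that still want the original merge-based strategy, the renamed `pluggable_mergesort` above keeps it available: chunks are sorted with `sort_func` and then combined with `merge_func`. A sketch of a call with the arguments spelled out (the stable-sort choice is just an illustration; the pool/policy setup is the same pattern as before, with an assumed include path):

```
#include <algorithm>
#include <functional>
#include <vector>

#include "poolstl/poolstl.hpp"  // vendored under ext/; exact path assumed

void stable_parallel_sort(std::vector<int>& v, unsigned n_threads) {
    task_thread_pool::task_thread_pool pool{n_threads};
    // std::stable_sort per chunk plus std::inplace_merge keeps the relative
    // order of equal elements, which the quicksort-based pluggable_sort does
    // not guarantee.
    poolstl::pluggable_mergesort(poolstl::par.on(pool), v.begin(), v.end(),
                                 std::less<int>(),
                                 std::stable_sort, std::inplace_merge);
}
```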
Must have the same + * signature as the comparator version of `std::sort`. + * @param part_func Sequential partition method. Must have the same signature as `std::partition`. + * @param pivot_func Method that identifies the pivot element + */ + template + poolstl::internal::enable_if_poolstl_policy + pluggable_quicksort(ExecPolicy &&policy, RandIt first, RandIt last, Compare comp, + void (sort_func)(RandIt, RandIt, Compare) = std::sort, + RandIt (part_func)(RandIt, RandIt, poolstl::internal::pivot_predicate::value_type>) = std::partition, + typename std::iterator_traits::value_type (pivot_func)(RandIt, RandIt) = + poolstl::internal::quicksort_pivot) { + if (poolstl::internal::is_seq(policy)) { + sort_func(first, last, comp); + return; + } + + poolstl::internal::parallel_quicksort(std::forward(policy), + first, last, comp, sort_func, part_func, pivot_func); + } + + /** + * NOTE: Iterators are expected to be random access. + * + * Parallel quicksort that allows specifying the sequential sort and partition methods. + * + * Uses `std::less` comparator. + * + * @param sort_func Sequential sort method to use once range is sufficiently partitioned. Must have the same + * signature as the comparator version of `std::sort`. + * @param part_func Sequential partition method. Must have the same signature as `std::partition`. + * @param pivot_func Method that identifies the pivot element + */ + template + poolstl::internal::enable_if_poolstl_policy + pluggable_quicksort(ExecPolicy &&policy, RandIt first, RandIt last, + void (sort_func)(RandIt, RandIt, + std::less::value_type>) = std::sort, + RandIt (part_func)(RandIt, RandIt, poolstl::internal::pivot_predicate< + std::less::value_type>, + typename std::iterator_traits::value_type>) = std::partition, + typename std::iterator_traits::value_type (pivot_func)(RandIt, RandIt) = + poolstl::internal::quicksort_pivot) { + using T = typename std::iterator_traits::value_type; + pluggable_quicksort(std::forward(policy), first, last, std::less(), + sort_func, part_func, pivot_func); } } From 1c6376705abd8c83221f831d8d0f7c6ec6df612a Mon Sep 17 00:00:00 2001 From: Marcel Martin Date: Mon, 22 Jan 2024 20:45:06 +0100 Subject: [PATCH 09/26] Update baseline commit --- tests/baseline-commit.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/baseline-commit.txt b/tests/baseline-commit.txt index 9c1e9b61..e1ff8c67 100644 --- a/tests/baseline-commit.txt +++ b/tests/baseline-commit.txt @@ -1 +1 @@ -0ced9903276834e6b9bfe095a255952f0616d330 +a905f7bdd2dcc2e843b0cbac23b51912adadfe7a From 6e07fccf29dd5f1423d594ac40fc3cbc046a6050 Mon Sep 17 00:00:00 2001 From: Marcel Martin Date: Wed, 31 Jan 2024 09:26:05 +0100 Subject: [PATCH 10/26] Bump poolSTL to 0.3.5 --- ext/README.md | 4 +- ext/poolstl/poolstl.hpp | 185 +++++++++++++++++++++++++++------------- 2 files changed, 126 insertions(+), 63 deletions(-) diff --git a/ext/README.md b/ext/README.md index a3480467..55e874b7 100644 --- a/ext/README.md +++ b/ext/README.md @@ -30,8 +30,8 @@ License: See pdqsort/license.txt ## poolstl Homepage: https://github.com/alugowski/poolSTL/ -Downloaded file: https://github.com/alugowski/poolSTL/releases/download/v0.3.4/poolstl.hpp -Version: 0.3.4 +Downloaded file: https://github.com/alugowski/poolSTL/releases/download/v0.3.5/poolstl.hpp +Version: 0.3.5 License: See poolstl.hpp ## robin_hood diff --git a/ext/poolstl/poolstl.hpp b/ext/poolstl/poolstl.hpp index 8d569ecb..77c3e7a0 100644 --- a/ext/poolstl/poolstl.hpp +++ b/ext/poolstl/poolstl.hpp @@ -505,7 +505,7 @@ namespace 
task_thread_pool { // Version macros. #define POOLSTL_VERSION_MAJOR 0 #define POOLSTL_VERSION_MINOR 3 -#define POOLSTL_VERSION_PATCH 4 +#define POOLSTL_VERSION_PATCH 5 #include #include @@ -1287,7 +1287,7 @@ namespace task_thread_pool { // Version macros. #define POOLSTL_VERSION_MAJOR 0 #define POOLSTL_VERSION_MINOR 3 -#define POOLSTL_VERSION_PATCH 4 +#define POOLSTL_VERSION_PATCH 5 #include #include @@ -1679,7 +1679,7 @@ namespace poolstl { auto iter_chunk_size = get_iter_chunk_size(first, last, chunk_size); RandIt loop_end = advanced(first, iter_chunk_size); - futures.emplace_back(task_pool.submit(std::forward(chunk), first, loop_end)); + futures.emplace_back(task_pool.submit(chunk, first, loop_end)); first = loop_end; } @@ -1702,8 +1702,7 @@ namespace poolstl { auto iter_chunk_size = get_iter_chunk_size(first, last, chunk_size); RandIt loop_end = advanced(first, iter_chunk_size); - futures.emplace_back(task_pool.submit(std::forward(chunk), first, loop_end, - std::forward(chunk_args)...)); + futures.emplace_back(task_pool.submit(chunk, first, loop_end, chunk_args...)); first = loop_end; } @@ -1719,8 +1718,7 @@ namespace poolstl { parallel_chunk_for_1_wait(ExecPolicy &&policy, RandIt first, RandIt last, Chunk chunk, ChunkRet* rettype, int extra_split_factor, A&&... chunk_args) { auto futures = parallel_chunk_for_1(std::forward(policy), first, last, - std::forward(chunk), rettype, extra_split_factor, - std::forward(chunk_args)...); + chunk, rettype, extra_split_factor, chunk_args...); get_futures(futures); } @@ -1739,8 +1737,7 @@ namespace poolstl { auto iter_chunk_size = get_iter_chunk_size(first1, last1, chunk_size); RandIt1 loop_end = advanced(first1, iter_chunk_size); - futures.emplace_back(task_pool.submit(std::forward(chunk), first1, loop_end, first2, - std::forward(chunk_args)...)); + futures.emplace_back(task_pool.submit(chunk, first1, loop_end, first2, chunk_args...)); first1 = loop_end; std::advance(first2, iter_chunk_size); @@ -1765,8 +1762,7 @@ namespace poolstl { auto iter_chunk_size = get_iter_chunk_size(first1, last1, chunk_size); RandIt1 loop_end = advanced(first1, iter_chunk_size); - futures.emplace_back(task_pool.submit(std::forward(chunk), first1, loop_end, first2, first3, - std::forward(chunk_args)...)); + futures.emplace_back(task_pool.submit(chunk, first1, loop_end, first2, first3, chunk_args...)); first1 = loop_end; std::advance(first2, iter_chunk_size); @@ -1896,9 +1892,14 @@ namespace poolstl { // Target partition size. Range will be recursively partitioned into partitions no bigger than this // size. Target approximately twice as many partitions as threads to reduce impact of uneven pivot // selection. - std::ptrdiff_t target_leaf_size = std::max(std::distance(first, last) / (task_pool.get_num_threads() * 2), + auto num_threads = task_pool.get_num_threads(); + std::ptrdiff_t target_leaf_size = std::max(std::distance(first, last) / (num_threads * 2), (std::ptrdiff_t)5); + if (num_threads == 1) { + target_leaf_size = std::distance(first, last); + } + // task_thread_pool does not support creating task DAGs, so organize the code such that // all parallel tasks are independent. The parallel tasks can spawn additional parallel tasks, and they // record their "child" task's std::future into a common vector to be waited on by the main thread. @@ -1927,6 +1928,39 @@ namespace poolstl { // Wait on all the parallel tasks. get_futures(futures); } + + /** + * Partition range according to predicate. Unstable. 
+ * + * This implementation only parallelizes with p=2; will spawn and wait for only one task. + */ + template + RandIt partition_p2(task_thread_pool::task_thread_pool &task_pool, RandIt first, RandIt last, Predicate pred) { + auto range_size = std::distance(first, last); + if (range_size < 4) { + return std::partition(first, last, pred); + } + + // approach should be generalizable to arbitrary p + + RandIt mid = std::next(first + range_size / 2); + + // partition left and right halves in parallel + auto left_future = task_pool.submit(std::partition, first, mid, pred); + RandIt right_mid = std::partition(mid, last, pred); + RandIt left_mid = left_future.get(); + + // merge the two partitioned halves + auto left_highs_size = std::distance(left_mid, mid); + auto right_lows_size = std::distance(mid, right_mid); + if (left_highs_size <= right_lows_size) { + std::swap_ranges(left_mid, mid, right_mid - left_highs_size); + return right_mid - left_highs_size; + } else { + std::swap_ranges(mid, right_mid, left_mid); + return left_mid + right_lows_size; + } + } } } @@ -1965,8 +1999,7 @@ namespace poolstl { auto iter_chunk_size = get_iter_chunk_size(first, last, chunk_size); RandIt loop_end = advanced(first, iter_chunk_size); - threads.emplace_back(std::thread(std::forward(chunk), first, loop_end, - std::forward(chunk_args)...)); + threads.emplace_back(std::thread(chunk, first, loop_end, chunk_args...)); first = loop_end; } @@ -1982,6 +2015,66 @@ namespace poolstl { #endif +namespace poolstl { + /** + * NOTE: Iterators are expected to be random access. + * + * Like `std::sort`, but allows specifying the sequential sort method, which must have the + * same signature as the comparator version of `std::sort`. + * + * Implemented as a high-level quicksort that delegates to `sort_func`, in parallel, once the range has been + * sufficiently partitioned. + */ + template + poolstl::internal::enable_if_poolstl_policy + pluggable_sort(ExecPolicy &&policy, RandIt first, RandIt last, Compare comp, + void (sort_func)(RandIt, RandIt, Compare) = std::sort) { + if (poolstl::internal::is_seq(policy)) { + sort_func(first, last, comp); + return; + } + + // Parallel partition. + // The partition_p2 method spawns and waits for its own child task. A deadlock is possible if all worker + // threads are waiting for tasks that in turn have to workers to execute them. This is only an issue because + // our thread pool does not have the concept of dependencies. + // So ensure + auto& task_pool = *policy.pool(); + std::atomic allowed_parallel_partitions{(int)task_pool.get_num_threads() / 2}; + + auto part_func = [&task_pool, &allowed_parallel_partitions](RandIt chunk_first, RandIt chunk_last, + poolstl::internal::pivot_predicate::value_type> pred) { + if (allowed_parallel_partitions.fetch_sub(1) > 0) { + return poolstl::internal::partition_p2(task_pool, chunk_first, chunk_last, pred); + } else { + return std::partition(chunk_first, chunk_last, pred); + } + }; + + poolstl::internal::parallel_quicksort(std::forward(policy), first, last, comp, sort_func, part_func, + poolstl::internal::quicksort_pivot); + } + + /** + * NOTE: Iterators are expected to be random access. + * + * Like `std::sort`, but allows specifying the sequential sort method, which must have the + * same signature as the comparator version of `std::sort`. + * + * Implemented as a parallel high-level quicksort that delegates to `sort_func` once the range has been + * sufficiently partitioned. 
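`partition_p2` above is what makes the partitioning step itself parallel: the left half of the range is partitioned by a pool task while the caller partitions the right half, and the two results are merged by swapping whichever misplaced block is smaller across the middle. A single-threaded rehearsal of that merge on a concrete vector (illustrative only; the real function submits the left half to the `task_thread_pool`):

```
#include <algorithm>
#include <cassert>
#include <vector>

int main() {
    std::vector<int> v{5, 1, 8, 2, 9, 3, 7, 4};
    auto pred = [](int x) { return x < 5; };

    auto first = v.begin(), last = v.end();
    auto mid = first + v.size() / 2;

    // In partition_p2 the left half runs on a pool task while the caller
    // handles the right half; here both run on the calling thread.
    auto left_mid = std::partition(first, mid, pred);
    auto right_mid = std::partition(mid, last, pred);

    // Merge step: move the smaller misplaced block across the middle.
    auto left_highs = mid - left_mid;    // elements failing pred, stuck in the left half
    auto right_lows = right_mid - mid;   // elements passing pred, stuck in the right half
    std::vector<int>::iterator boundary;
    if (left_highs <= right_lows) {
        std::swap_ranges(left_mid, mid, right_mid - left_highs);
        boundary = right_mid - left_highs;
    } else {
        std::swap_ranges(mid, right_mid, left_mid);
        boundary = left_mid + right_lows;
    }

    assert(std::all_of(first, boundary, pred));
    assert(std::none_of(boundary, last, pred));
}
```

The `allowed_parallel_partitions` counter in the new `pluggable_sort` above (initialized to half the pool size) then limits how many of these two-thread partitions run at once: each one parks the submitting task while it waits on its child, and since the pool has no notion of task dependencies, letting every task do that could deadlock.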
+ */ + template + poolstl::internal::enable_if_poolstl_policy + pluggable_sort(ExecPolicy &&policy, RandIt first, RandIt last, + void (sort_func)(RandIt, RandIt, + std::less::value_type>) = std::sort){ + using T = typename std::iterator_traits::value_type; + pluggable_sort(std::forward(policy), first, last, std::less(), sort_func); + } +} + namespace std { /** @@ -2179,6 +2272,22 @@ namespace std { return last; } + /** + * NOTE: Iterators are expected to be random access. + * See std::partition https://en.cppreference.com/w/cpp/algorithm/partition + * + * Current implementation uses at most 2 threads. + */ + template + poolstl::internal::enable_if_poolstl_policy + partition(ExecPolicy &&policy, RandIt first, RandIt last, Predicate pred) { + if (poolstl::internal::is_seq(policy)) { + return std::partition(first, last, pred); + } + + return poolstl::internal::partition_p2(*policy.pool(), first, last, pred); + } + /** * NOTE: Iterators are expected to be random access. * See std::sort https://en.cppreference.com/w/cpp/algorithm/sort @@ -2191,11 +2300,7 @@ namespace std { return; } - poolstl::internal::parallel_quicksort(std::forward(policy), first, last, comp, - std::sort, - std::partition::value_type>>, - poolstl::internal::quicksort_pivot); + poolstl::pluggable_sort(std::forward(policy), first, last, comp, std::sort); } /** @@ -2347,48 +2452,6 @@ namespace poolstl { (void*)nullptr, 1, construct, f); } - /** - * NOTE: Iterators are expected to be random access. - * - * Like `std::sort`, but allows specifying the sequential sort method, which must have the - * same signature as the comparator version of `std::sort`. - * - * Implemented as a high-level quicksort that delegates to `sort_func`, in parallel, once the range has been - * sufficiently partitioned. - */ - template - poolstl::internal::enable_if_poolstl_policy - pluggable_sort(ExecPolicy &&policy, RandIt first, RandIt last, Compare comp, - void (sort_func)(RandIt, RandIt, Compare) = std::sort) { - if (poolstl::internal::is_seq(policy)) { - sort_func(first, last, comp); - return; - } - - poolstl::internal::parallel_quicksort(std::forward(policy), first, last, comp, sort_func, - std::partition::value_type>>, - poolstl::internal::quicksort_pivot); - } - - /** - * NOTE: Iterators are expected to be random access. - * - * Like `std::sort`, but allows specifying the sequential sort method, which must have the - * same signature as the comparator version of `std::sort`. - * - * Implemented as a parallel high-level quicksort that delegates to `sort_func` once the range has been - * sufficiently partitioned. - */ - template - poolstl::internal::enable_if_poolstl_policy - pluggable_sort(ExecPolicy &&policy, RandIt first, RandIt last, - void (sort_func)(RandIt, RandIt, - std::less::value_type>) = std::sort){ - using T = typename std::iterator_traits::value_type; - pluggable_sort(std::forward(policy), first, last, std::less(), sort_func); - } - /** * NOTE: Iterators are expected to be random access. 
* From 3540334a1b53586aeac72430ad260ca2ad21f3fa Mon Sep 17 00:00:00 2001 From: Marcel Martin Date: Wed, 31 Jan 2024 09:29:27 +0100 Subject: [PATCH 11/26] Update baseline commit --- tests/baseline-commit.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/baseline-commit.txt b/tests/baseline-commit.txt index e1ff8c67..d2baa623 100644 --- a/tests/baseline-commit.txt +++ b/tests/baseline-commit.txt @@ -1 +1 @@ -a905f7bdd2dcc2e843b0cbac23b51912adadfe7a +cc6928611965881a6b533d05483fb93ff18752b3 From af350a63fe1d44f828a83d6bfae468ee21930e95 Mon Sep 17 00:00:00 2001 From: Marcel Martin Date: Wed, 14 Feb 2024 20:45:31 +0100 Subject: [PATCH 12/26] Ensure sorting of randstrobes is reproducible MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ... by sorting them also by position. This way, it does not matter in which way the randstrobes vector is partitioned during sort. Otherwise, the order would depend on the number of threads used to create the index, making mapping results not reproducible across runs that do not use the same no. of threads. Note: There is still a tiny chance for collisions/nondeterminism because we ignore RefRandstrobe::m_packed. For efficiency, this uses a branchless comparison function inspired by @alugowski’s comment in PR #386. Runtimes for sorting with one thread 32 s - only by hash 45 s - by hash and position using std::tie 42 s - by hash and position using branchless_compare from the PR 35 s - by hash and position using __uint128_t (this commit) Runtimes for sorting (four cores with hyperthreading) threads | sorting time | index creation time -|-|- 1 | 35 s | 154 s 2 | 25 s | 95 s 4 | 16 s | 60 s 8 | 15 s | 48 s --- src/randstrobes.hpp | 8 +++++++- tests/baseline-commit.txt | 2 +- tests/compare-baseline.sh | 8 ++++++-- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/randstrobes.hpp b/src/randstrobes.hpp index 7a117b3c..0c72aaed 100644 --- a/src/randstrobes.hpp +++ b/src/randstrobes.hpp @@ -30,7 +30,13 @@ struct RefRandstrobe { , m_packed(packed) { } bool operator< (const RefRandstrobe& other) const { - return hash < other.hash; + // Compare both hash and position to ensure that the order of the + // RefRandstrobes in the index is reproducible no matter which sorting + // function is used. This branchless comparison is faster than the + // equivalent one using std::tie. 
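The commit message above reports the packed comparison as the fastest of the reproducible variants (35 s single-threaded versus 45 s with `std::tie`), and it orders identically because `hash` occupies the high 64 bits of the 128-bit value and `position` the low bits. A small self-check of that equivalence (field widths assumed to mirror `RefRandstrobe`; `__uint128_t` is a GCC/Clang extension, as in the patch itself):

```
#include <cassert>
#include <cstdint>
#include <tuple>

struct Strobe {        // stand-in for RefRandstrobe's ordering fields
    uint64_t hash;
    uint32_t position;
};

bool less_tie(const Strobe& a, const Strobe& b) {
    return std::tie(a.hash, a.position) < std::tie(b.hash, b.position);
}

bool less_packed(const Strobe& a, const Strobe& b) {
    __uint128_t lhs = (static_cast<__uint128_t>(a.hash) << 64) | a.position;
    __uint128_t rhs = (static_cast<__uint128_t>(b.hash) << 64) | b.position;
    return lhs < rhs;
}

int main() {
    Strobe a{5, 7}, b{5, 9}, c{6, 0};
    // Equal hashes: position breaks the tie. Different hashes: hash dominates.
    assert(less_tie(a, b) && less_packed(a, b));
    assert(less_tie(b, c) && less_packed(b, c));
    assert(!less_tie(b, a) && !less_packed(b, a));
}
```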
+ __uint128_t lhs = (static_cast<__uint128_t>(hash) << 64) | position; + __uint128_t rhs = (static_cast<__uint128_t>(other.hash) << 64) | other.position; + return lhs < rhs; } int reference_index() const { diff --git a/tests/baseline-commit.txt b/tests/baseline-commit.txt index d2baa623..9ffa9ec8 100644 --- a/tests/baseline-commit.txt +++ b/tests/baseline-commit.txt @@ -1 +1 @@ -cc6928611965881a6b533d05483fb93ff18752b3 +24995f4168108232528fbe5132441c9f1d0401b3 diff --git a/tests/compare-baseline.sh b/tests/compare-baseline.sh index 57f9faac..4bc2223b 100755 --- a/tests/compare-baseline.sh +++ b/tests/compare-baseline.sh @@ -13,8 +13,12 @@ set -euo pipefail python3 -c 'import pysam' ends="pe" -while getopts "s" opt; do +threads=4 +while getopts "st:" opt; do case "${opt}" in + t) + threads=$OPTARG + ;; s) ends=se # single-end reads ;; @@ -38,7 +42,7 @@ baseline_commit=$(< tests/baseline-commit.txt) baseline_bam=baseline/bam/${baseline_commit}.${ends}.bam baseline_binary=baseline/strobealign-${baseline_commit} cmake_options=-DCMAKE_BUILD_TYPE=RelWithDebInfo -strobealign_options="-t 4" +strobealign_options="-t ${threads}" # Generate the baseline BAM if necessary mkdir -p baseline/bam From f1de1e2613da2670aa99dbf9b00884026e96fe87 Mon Sep 17 00:00:00 2001 From: Marcel Martin Date: Wed, 14 Feb 2024 21:41:22 +0100 Subject: [PATCH 13/26] Update baseline commit --- tests/baseline-commit.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/baseline-commit.txt b/tests/baseline-commit.txt index 9ffa9ec8..471fbbbb 100644 --- a/tests/baseline-commit.txt +++ b/tests/baseline-commit.txt @@ -1 +1 @@ -24995f4168108232528fbe5132441c9f1d0401b3 +2e4ff9500e68d6e465735dd276d362cf71851dcd From dad60c84ca3b0891891ea7ee0ffdaddea0d01a1c Mon Sep 17 00:00:00 2001 From: Marcel Martin Date: Fri, 16 Feb 2024 15:10:08 +0100 Subject: [PATCH 14/26] Introduce canonical read length 75 Closes #395 --- CHANGES.md | 4 ++++ README.md | 6 +++--- src/indexparameters.cpp | 3 ++- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 644e9644..2d907665 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -9,6 +9,10 @@ * #376: Improve accuracy for read length 50 by optimizing the default indexing parameters. Paired-end accuracy increases by 0.3 percentage points on average. Single-end accuracy increases by 1 percentage point. +* #395: Previously, read length 75 used the same indexing parameters as length + 50, but the improved settings for length 50 are not the best for length 75. + To avoid a decrease in accuracy, we introduced a new set of pre-defined + indexing parameters for read length 75 (a new canonical read length). * If `--details` is used, output `X0:i` SAM tag with the number of identically-scored best alignments * #378: Added `-C` option for appending the FASTA or FASTQ comment to SAM diff --git a/README.md b/README.md index 1ef03a51..f2d0f8c1 100644 --- a/README.md +++ b/README.md @@ -145,9 +145,9 @@ options. Some important ones are: Strobealign needs to build an index (strobemer index) of the reference before it can map reads to it. The optimal indexing parameters depend on the length of the input reads. -There are currently seven different pre-defined sets of parameters that are -optimized for different read lengths. These *canonical read lengths* are -50, 100, 125, 150, 250 and 400. When deciding which of the pre-defined +There are pre-defined sets of parameters that are optimized for different read +lengths. 
These *canonical read lengths* are +50, 75, 100, 125, 150, 250 and 400. When deciding which of the pre-defined indexing parameter sets to use, strobealign chooses one whose canonical read length is close to the average read length of the input. diff --git a/src/indexparameters.cpp b/src/indexparameters.cpp index 2f634e9f..0c655903 100644 --- a/src/indexparameters.cpp +++ b/src/indexparameters.cpp @@ -35,7 +35,8 @@ struct Profile { static auto max{std::numeric_limits::max()}; static std::vector profiles = { - Profile{ 50, 90, 18, -4, -2, 1}, + Profile{ 50, 70, 18, -4, -2, 1}, + Profile{ 75, 90, 20, -4, -3, 2}, Profile{100, 110, 20, -4, -2, 2}, Profile{125, 135, 20, -4, -1, 4}, Profile{150, 175, 20, -4, 1, 7}, From 3eebfe9058ffff04b6451f5e874c15cd4f6294e8 Mon Sep 17 00:00:00 2001 From: Luis Pedro Coelho Date: Fri, 9 Feb 2024 15:19:38 +1000 Subject: [PATCH 15/26] Explicit error if too many sequences are used MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Right now, strobealign only supports up to 2²⁴ sequences. If the user tries more, it would silently accept it, but later crash. This was triggered when trying to map to the Greengenes database https://ftp.microbio.me/greengenes_release/2022.10/ https://ftp.microbio.me/greengenes_release/2022.10/2022.10.seqs.fna.gz Even a single read like the one below trigger a crash ``` @M05314:127:000000000-BWLLJ:1:1101:15267:1654 2:N:0:1 CCTGTTCGCTCCCCACGCTTTCGTCCCTCAGCGTCAATATTGTGCCAGAATGCTGCCTTCGCCATTGGTGTTCCTCCTGATATCTACGCATGTCACCGCTACACCAGGAATTCCACATTCCTCTCACATATTCTATTTTATCAGTTTTGAT + AAA1AF@1>AAAGG1A0EAFGGEHAAEGFCG1AAEE/F2FG2F2FF1CA0FBDED1BGFGFFE?AF1BFFCFHDGFFHB1FFGFGEEFE/?/BF2F@/EGEEB00/0//0BFG1>B1BGFEFHHGGFFD12BGH2FDFFFGG22GDD>@/F ``` --- src/main.cpp | 5 +++++ src/randstrobes.hpp | 4 ++++ 2 files changed, 9 insertions(+) diff --git a/src/main.cpp b/src/main.cpp index dcc96774..f9b8f65a 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -25,6 +25,7 @@ #include "timer.hpp" #include "readlen.hpp" #include "version.hpp" +#include "randstrobes.hpp" #include "buildconfig.hpp" @@ -209,6 +210,10 @@ int run_strobealign(int argc, char **argv) { throw InvalidFasta("No reference sequences found"); } + if (references.size() > RefRandstrobe::max_number_of_references) { + throw InvalidFasta("Too many reference sequences. 
Current maximum is " + std::to_string(RefRandstrobe::max_number_of_references)); + } + StrobemerIndex index(references, index_parameters, opt.bits); if (opt.use_index) { // Read the index from a file diff --git a/src/randstrobes.hpp b/src/randstrobes.hpp index 0c72aaed..933fc89f 100644 --- a/src/randstrobes.hpp +++ b/src/randstrobes.hpp @@ -47,10 +47,14 @@ struct RefRandstrobe { return m_packed & mask; } + private: static constexpr int bit_alloc = 8; static constexpr int mask = (1 << bit_alloc) - 1; packed_t m_packed; // packed representation of ref_index and strobe offset + +public: + static constexpr uint32_t max_number_of_references = (1 << (32 - bit_alloc)) - 1; }; struct QueryRandstrobe { From 9a79d897c56e2b3edc1e257c18d3875e2456064d Mon Sep 17 00:00:00 2001 From: psj1997 <1176894699@qq.com> Date: Mon, 12 Feb 2024 13:59:45 +0800 Subject: [PATCH 16/26] Update aemb module --- src/aln.cpp | 131 +++++++++++++++++++++++++++++++++++++++++++----- src/aln.hpp | 7 ++- src/cmdline.cpp | 2 + src/cmdline.hpp | 1 + src/main.cpp | 40 +++++++++++---- src/pc.cpp | 10 ++-- src/pc.hpp | 3 +- 7 files changed, 164 insertions(+), 30 deletions(-) diff --git a/src/aln.cpp b/src/aln.cpp index b5105bb1..7b88d3b1 100644 --- a/src/aln.cpp +++ b/src/aln.cpp @@ -868,12 +868,19 @@ inline void get_best_map_location( std::vector &nams2, InsertSizeDistribution &isize_est, Nam &best_nam1, - Nam &best_nam2 + Nam &best_nam2, + std::vector &abundance, + int read1_len, + int read2_len, + bool output_abundance ) { std::vector nam_pairs = get_best_scoring_nam_pairs(nams1, nams2, isize_est.mu, isize_est.sigma); best_nam1.ref_start = -1; //Unmapped until proven mapped best_nam2.ref_start = -1; //Unmapped until proven mapped + std::vector best_ref1; + std::vector best_ref2; + if (nam_pairs.empty()) { return; } @@ -903,6 +910,73 @@ inline void get_best_map_location( if (score_joint > score_indiv) { // joint score is better than individual best_nam1 = n1_joint_max; best_nam2 = n2_joint_max; + + if (output_abundance){ + for (auto &[score, n1, n2] : nam_pairs){ + if ((n1.score + n2.score) == score_joint){ + best_ref1.push_back(n1); + best_ref2.push_back(n2); + }else{ + break; + } + } + + int ref_size1 = best_ref1.size(); + for (auto &t: best_ref1){ + if (t.ref_start < 0) { + continue; + } + abundance[t.ref_id] += float(read1_len) / float(ref_size1); + } + + int ref_size2 = best_ref2.size(); + for (auto &t: best_ref2){ + if (t.ref_start < 0) { + continue; + } + abundance[t.ref_id] += float(read2_len) / float(ref_size2); + } + } + + } + else{ + if (output_abundance){ + if (!nams1.empty()){ + for (auto &t : nams1){ + if (t.score == nams1[0].score){ + best_ref1.push_back(t); + }else{ + break; + } + } + + int ref_size1 = best_ref1.size(); + for (auto &t: best_ref1){ + if (t.ref_start < 0) { + continue; + } + abundance[t.ref_id] += float(read1_len) / float(ref_size1); + } + } + + if (!nams2.empty()){ + for (auto &t : nams2){ + if (t.score == nams2[0].score){ + best_ref2.push_back(t); + }else{ + break; + } + } + + int ref_size2 = best_ref2.size(); + for (auto &t: best_ref2){ + if (t.ref_start < 0) { + continue; + } + abundance[t.ref_id] += float(read2_len) / float(ref_size2); + } + } + } } if (isize_est.sample_size < 400 && score_joint > score_indiv) { @@ -957,7 +1031,8 @@ void align_or_map_paired( const IndexParameters& index_parameters, const References& references, const StrobemerIndex& index, - std::minstd_rand& random_engine + std::minstd_rand& random_engine, + std::vector &abundance ) { std::array details; std::array, 2> nams_pair; 
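The abundance bookkeeping introduced above follows one rule: each read's length is split evenly over all equally-best candidates, so a read of length L with n best-scoring hits adds L/n to each hit's accumulator (per mate, in the paired-end case). A small standalone rendition of that rule (hypothetical names; the real code walks `Nam` scores as shown above):

```
#include <cassert>
#include <vector>

// Spread one read's length over the references that tie for the best score.
void add_abundance(std::vector<float>& abundance,
                   const std::vector<int>& best_ref_ids, int read_len) {
    if (best_ref_ids.empty()) {
        return;
    }
    float share = float(read_len) / float(best_ref_ids.size());
    for (int ref_id : best_ref_ids) {
        abundance[ref_id] += share;
    }
}

int main() {
    std::vector<float> abundance(3, 0.0f);
    add_abundance(abundance, {0, 2}, 150);  // a 150 bp read tied between ref 0 and ref 2
    assert(abundance[0] == 75.0f && abundance[1] == 0.0f && abundance[2] == 75.0f);
}
```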
@@ -991,12 +1066,19 @@ void align_or_map_paired( } Timer extend_timer; + if (map_param.is_abundance_out){ + Nam nam_read1; + Nam nam_read2; + get_best_map_location(nams_pair[0], nams_pair[1], isize_est,nam_read1, nam_read2, abundance, record1.seq.length(), record2.seq.length(), map_param.is_abundance_out); + } + else{ if (!map_param.is_sam_out) { Nam nam_read1; Nam nam_read2; get_best_map_location(nams_pair[0], nams_pair[1], isize_est, nam_read1, - nam_read2); + nam_read2, abundance, record1.seq.length(), + record2.seq.length(), map_param.is_abundance_out); output_hits_paf_PE(outstring, nam_read1, record1.name, references, record1.seq.length()); @@ -1066,6 +1148,7 @@ void align_or_map_paired( ); } } + } statistics.tot_extend += extend_timer.duration(); statistics += details[0]; statistics += details[1]; @@ -1082,7 +1165,8 @@ void align_or_map_single( const IndexParameters& index_parameters, const References& references, const StrobemerIndex& index, - std::minstd_rand& random_engine + std::minstd_rand& random_engine, + std::vector &abundance ) { Details details; Timer strobe_timer; @@ -1111,15 +1195,36 @@ void align_or_map_single( Timer extend_timer; - if (!map_param.is_sam_out) { - output_hits_paf(outstring, nams, record.name, references, - record.seq.length()); - } else { - align_single( - aligner, sam, nams, record, index_parameters.syncmer.k, - references, details, map_param.dropoff_threshold, map_param.max_tries, - map_param.max_secondary, random_engine - ); + std::vector best_ref; + if (map_param.is_abundance_out){ + if (!nams.empty()){ + for (auto &t : nams){ + if (t.score == nams[0].score){ + best_ref.push_back(t); + }else{ + break; + } + } + int ref_size = best_ref.size(); + for (auto &t: best_ref){ + if (t.ref_start < 0) { + continue; + } + abundance[t.ref_id] += float(record.seq.length()) / float(ref_size); + } + } + } + else{ + if (!map_param.is_sam_out) { + output_hits_paf(outstring, nams, record.name, references, + record.seq.length()); + } else { + align_single( + aligner, sam, nams, record, index_parameters.syncmer.k, + references, details, map_param.dropoff_threshold, map_param.max_tries, + map_param.max_secondary, random_engine + ); + } } statistics.tot_extend += extend_timer.duration(); statistics += details; diff --git a/src/aln.hpp b/src/aln.hpp index f8bb69bf..99bcfedb 100644 --- a/src/aln.hpp +++ b/src/aln.hpp @@ -64,6 +64,7 @@ struct MappingParameters { int max_tries { 20 }; int rescue_cutoff; bool is_sam_out { true }; + bool is_abundance_out {false}; CigarOps cigar_ops{CigarOps::M}; bool output_unmapped { true }; bool details{false}; @@ -88,7 +89,8 @@ void align_or_map_paired( const IndexParameters& index_parameters, const References& references, const StrobemerIndex& index, - std::minstd_rand& random_engine + std::minstd_rand& random_engine, + std::vector &abundance ); void align_or_map_single( @@ -101,7 +103,8 @@ void align_or_map_single( const IndexParameters& index_parameters, const References& references, const StrobemerIndex& index, - std::minstd_rand& random_engine + std::minstd_rand& random_engine, + std::vector &abundance ); // Private declarations, only needed for tests diff --git a/src/cmdline.cpp b/src/cmdline.cpp index c00c5479..510821b6 100644 --- a/src/cmdline.cpp +++ b/src/cmdline.cpp @@ -31,6 +31,7 @@ CommandLineOptions parse_command_line_arguments(int argc, char **argv) { args::ValueFlag index_statistics(parser, "PATH", "Print statistics of indexing to PATH", {"index-statistics"}); args::Flag i(parser, "index", "Do not map reads; only generate the 
strobemer index and write it to disk. If read files are provided, they are used to estimate read length", {"create-index", 'i'}); args::Flag use_index(parser, "use_index", "Use a pre-generated index previously written with --create-index.", { "use-index" }); + args::Flag aemb(parser, "aemb", "Only output abundance value of contigs for metagenomic binning", {"aemb"}); args::Group sam(parser, "SAM output:"); args::Flag eqx(parser, "eqx", "Emit =/X instead of M CIGAR operations", {"eqx"}); @@ -97,6 +98,7 @@ CommandLineOptions parse_command_line_arguments(int argc, char **argv) { if (index_statistics) { opt.logfile_name = args::get(index_statistics); } if (i) { opt.only_gen_index = true; } if (use_index) { opt.use_index = true; } + if (aemb) {opt.is_abundance_out = true; } // SAM output if (eqx) { opt.cigar_eqx = true; } diff --git a/src/cmdline.hpp b/src/cmdline.hpp index 2ae3186e..9f5b4a05 100644 --- a/src/cmdline.hpp +++ b/src/cmdline.hpp @@ -18,6 +18,7 @@ struct CommandLineOptions { bool only_gen_index { false }; bool use_index { false }; bool is_sam_out { true }; + bool is_abundance_out {false}; // SAM output bool cigar_eqx { false }; diff --git a/src/main.cpp b/src/main.cpp index f9b8f65a..c012ecfb 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -189,6 +189,7 @@ int run_strobealign(int argc, char **argv) { map_param.output_unmapped = opt.output_unmapped; map_param.details = opt.details; map_param.fastq_comments = opt.fastq_comments; + map_param.is_abundance_out = opt.is_abundance_out; map_param.verify(); log_parameters(index_parameters, map_param, aln_params); @@ -289,31 +290,33 @@ int run_strobealign(int argc, char **argv) { std::ostream out(buf); - if (map_param.is_sam_out) { - std::stringstream cmd_line; - for(int i = 0; i < argc; ++i) { - cmd_line << argv[i] << " "; - } + if (!map_param.is_abundance_out){ + if (map_param.is_sam_out) { + std::stringstream cmd_line; + for(int i = 0; i < argc; ++i) { + cmd_line << argv[i] << " "; + } - out << sam_header(references, opt.read_group_id, opt.read_group_fields); - if (opt.pg_header) { - out << pg_header(cmd_line.str()); + out << sam_header(references, opt.read_group_id, opt.read_group_fields); + if (opt.pg_header) { + out << pg_header(cmd_line.str()); + } } } std::vector log_stats_vec(opt.n_threads); - + logger.info() << "Running in " << (opt.is_SE ? 
"single-end" : "paired-end") << " mode" << std::endl; OutputBuffer output_buffer(out); - std::vector workers; std::vector worker_done(opt.n_threads); // each thread sets its entry to 1 when it’s done + std::vector> align_abundances(opt.n_threads, std::vector(references.size(), 0)); for (int i = 0; i < opt.n_threads; ++i) { std::thread consumer(perform_task, std::ref(input_buffer), std::ref(output_buffer), std::ref(log_stats_vec[i]), std::ref(worker_done[i]), std::ref(aln_params), std::ref(map_param), std::ref(index_parameters), std::ref(references), - std::ref(index), std::ref(opt.read_group_id)); + std::ref(index), std::ref(opt.read_group_id), std::ref(align_abundances[i])); workers.push_back(std::move(consumer)); } if (opt.show_progress && isatty(2)) { @@ -329,6 +332,21 @@ int run_strobealign(int argc, char **argv) { tot_statistics += it; } + if (map_param.is_abundance_out){ + std::vector abundances(references.size(), 0); + std::vector abundances_norm(references.size(), 0); + for (size_t i = 0; i < align_abundances.size(); ++i) { + for (size_t j = 0; j < align_abundances[i].size(); ++j) { + abundances[j] += align_abundances[i][j]; + } + } + + for (size_t i = 0; i < references.size(); ++i) { + std::cout << references.names[i] << '\t' << std::fixed << std::setprecision(6) << abundances[i] / float(references.sequences[i].size()) << std::endl; + } + } + + logger.info() << "Total mapping sites tried: " << tot_statistics.tot_all_tried << std::endl << "Total calls to ssw: " << tot_statistics.tot_aligner_calls << std::endl << "Inconsistent NAM ends: " << tot_statistics.inconsistent_nams << std::endl diff --git a/src/pc.cpp b/src/pc.cpp index 5011bf21..34da9c97 100644 --- a/src/pc.cpp +++ b/src/pc.cpp @@ -139,7 +139,8 @@ void perform_task( const IndexParameters& index_parameters, const References& references, const StrobemerIndex& index, - const std::string& read_group_id + const std::string& read_group_id, + std::vector &abundance ) { bool eof = false; Aligner aligner{aln_params}; @@ -170,16 +171,19 @@ void perform_task( to_uppercase(record1.seq); to_uppercase(record2.seq); align_or_map_paired(record1, record2, sam, sam_out, statistics, isize_est, aligner, - map_param, index_parameters, references, index, random_engine); + map_param, index_parameters, references, index, random_engine, abundance); statistics.n_reads += 2; } for (size_t i = 0; i < records3.size(); ++i) { auto record = records3[i]; - align_or_map_single(record, sam, sam_out, statistics, aligner, map_param, index_parameters, references, index, random_engine); + align_or_map_single(record, sam, sam_out, statistics, aligner, map_param, index_parameters, references, index, random_engine, abundance); statistics.n_reads++; } + + if (!map_param.is_abundance_out){ output_buffer.output_records(std::move(sam_out), chunk_index); assert(sam_out == ""); + } } statistics.tot_aligner_calls += aligner.calls_count(); done = true; diff --git a/src/pc.hpp b/src/pc.hpp index 703de209..b0fa6ad8 100644 --- a/src/pc.hpp +++ b/src/pc.hpp @@ -64,7 +64,8 @@ class OutputBuffer { void perform_task(InputBuffer &input_buffer, OutputBuffer &output_buffer, AlignmentStatistics& statistics, int& done, const AlignmentParameters &aln_params, - const MappingParameters &map_param, const IndexParameters& index_parameters, const References& references, const StrobemerIndex& index, const std::string& read_group_id); + const MappingParameters &map_param, const IndexParameters& index_parameters, + const References& references, const StrobemerIndex& index, const std::string& 
read_group_id, std::vector &abundance); bool same_name(const std::string& n1, const std::string& n2); From 358a5b23e7e146e347cb4d6dc6004714512fbb5c Mon Sep 17 00:00:00 2001 From: psj1997 <1176894699@qq.com> Date: Tue, 20 Feb 2024 22:48:33 +0800 Subject: [PATCH 17/26] Code improve; documentation update; testing update --- README.md | 6 +++ src/aln.cpp | 86 +++++++++++++++++++++--------------------- src/aln.hpp | 4 +- src/cmdline.cpp | 2 +- src/main.cpp | 36 ++++++++++-------- src/pc.cpp | 6 +-- src/pc.hpp | 2 +- tests/phix.abun.pe.txt | 1 + tests/phix.abun.se.txt | 1 + tests/run.sh | 10 +++++ 10 files changed, 90 insertions(+), 64 deletions(-) create mode 100644 tests/phix.abun.pe.txt create mode 100644 tests/phix.abun.se.txt diff --git a/README.md b/README.md index f2d0f8c1..5d736fa4 100644 --- a/README.md +++ b/README.md @@ -113,6 +113,11 @@ strobealign ref.fa reads.1.fastq.gz reads.2.fastq.gz | samtools sort -o sorted.b This is usually faster than doing the two steps separately because fewer intermediate files are created. +To output the estimated abundance of every contig, the format of output file is: contig_id \t abundance_value: +``` +strobealign ref.fa reads.fq --aemb > abundance.txt # Single-end reads +strobealign ref.fa reads1.fq reads2.fq --aemb > abundance.txt # Paired-end reads +``` ## Command-line options @@ -127,6 +132,7 @@ options. Some important ones are: * `--eqx`: Emit `=` and `X` CIGAR operations instead of `M`. * `-x`: Only map reads, do not do no base-level alignment. This switches the output format from SAM to [PAF](https://github.com/lh3/miniasm/blob/master/PAF.md). +* `--aemb`: Output the estimated abundance value of every contig, the format of output file is: contig_id \t abundance_value. * `--rg-id=ID`: Add RG tag to each SAM record. * `--rg=TAG:VALUE`: Add read group metadata to the SAM header. This can be specified multiple times. Example: `--rg-id=1 --rg=SM:mysamle --rg=LB:mylibrary`. 
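To make the `--aemb` output concrete: the number reported per contig is the accumulated read length assigned to it (as sketched earlier) divided by the contig's length, which amounts to an estimated per-base coverage. A worked version of that final division, in the spirit of the `NC_001422.1` lines in the phix test files added by this patch (the read total below is invented; 5,386 bp is the phiX174 genome length):

```
#include <cstdio>

int main() {
    // Invented input: 25,000 assigned bases accumulated on a 5,386 bp contig.
    const char* contig_id = "NC_001422.1";
    double assigned_bases = 25000.0;  // sum of read-length shares from mapping
    double contig_length = 5386.0;    // length of the contig sequence
    // Same formatting as the output code: fixed notation, six decimals.
    std::printf("%s\t%.6f\n", contig_id, assigned_bases / contig_length);
    // Prints: NC_001422.1<TAB>4.641664
}
```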
diff --git a/src/aln.cpp b/src/aln.cpp index 7b88d3b1..e66f9039 100644 --- a/src/aln.cpp +++ b/src/aln.cpp @@ -869,7 +869,7 @@ inline void get_best_map_location( InsertSizeDistribution &isize_est, Nam &best_nam1, Nam &best_nam2, - std::vector &abundance, + std::vector &abundances, int read1_len, int read2_len, bool output_abundance @@ -878,8 +878,8 @@ inline void get_best_map_location( best_nam1.ref_start = -1; //Unmapped until proven mapped best_nam2.ref_start = -1; //Unmapped until proven mapped - std::vector best_ref1; - std::vector best_ref2; + std::vector best_contig1; + std::vector best_contig2; if (nam_pairs.empty()) { return; @@ -912,68 +912,70 @@ inline void get_best_map_location( best_nam2 = n2_joint_max; if (output_abundance){ + // find all NAM pairs that have the same best score for (auto &[score, n1, n2] : nam_pairs){ if ((n1.score + n2.score) == score_joint){ - best_ref1.push_back(n1); - best_ref2.push_back(n2); + best_contig1.push_back(n1); + best_contig2.push_back(n2); }else{ break; } } - int ref_size1 = best_ref1.size(); - for (auto &t: best_ref1){ + size_t contig_size1 = best_contig1.size(); + for (auto &t: best_contig1){ if (t.ref_start < 0) { continue; } - abundance[t.ref_id] += float(read1_len) / float(ref_size1); + abundances[t.ref_id] += float(read1_len) / float(contig_size1); } - int ref_size2 = best_ref2.size(); - for (auto &t: best_ref2){ + int contig_size2 = best_contig2.size(); + for (auto &t: best_contig2){ if (t.ref_start < 0) { continue; } - abundance[t.ref_id] += float(read2_len) / float(ref_size2); + abundances[t.ref_id] += float(read2_len) / float(contig_size2); } - } - + } } else{ if (output_abundance){ if (!nams1.empty()){ + // find all NAM1 that have the same best score for (auto &t : nams1){ if (t.score == nams1[0].score){ - best_ref1.push_back(t); + best_contig1.push_back(t); }else{ break; } } - int ref_size1 = best_ref1.size(); - for (auto &t: best_ref1){ + size_t contig_size1 = best_contig1.size(); + for (auto &t: best_contig1){ if (t.ref_start < 0) { continue; } - abundance[t.ref_id] += float(read1_len) / float(ref_size1); + abundances[t.ref_id] += float(read1_len) / float(contig_size1); } } if (!nams2.empty()){ + // find all NAM2 that have the same best score for (auto &t : nams2){ if (t.score == nams2[0].score){ - best_ref2.push_back(t); + best_contig2.push_back(t); }else{ break; } } - int ref_size2 = best_ref2.size(); - for (auto &t: best_ref2){ + int contig_size2 = best_contig2.size(); + for (auto &t: best_contig2){ if (t.ref_start < 0) { continue; } - abundance[t.ref_id] += float(read2_len) / float(ref_size2); + abundances[t.ref_id] += float(read2_len) / float(contig_size2); } } } @@ -1032,7 +1034,7 @@ void align_or_map_paired( const References& references, const StrobemerIndex& index, std::minstd_rand& random_engine, - std::vector &abundance + std::vector &abundances ) { std::array details; std::array, 2> nams_pair; @@ -1069,22 +1071,22 @@ void align_or_map_paired( if (map_param.is_abundance_out){ Nam nam_read1; Nam nam_read2; - get_best_map_location(nams_pair[0], nams_pair[1], isize_est,nam_read1, nam_read2, abundance, record1.seq.length(), record2.seq.length(), map_param.is_abundance_out); + get_best_map_location(nams_pair[0], nams_pair[1], isize_est,nam_read1, nam_read2, abundances, record1.seq.length(), record2.seq.length(), map_param.is_abundance_out); } else{ - if (!map_param.is_sam_out) { - Nam nam_read1; - Nam nam_read2; - get_best_map_location(nams_pair[0], nams_pair[1], isize_est, - nam_read1, - nam_read2, abundance, 
record1.seq.length(), - record2.seq.length(), map_param.is_abundance_out); - output_hits_paf_PE(outstring, nam_read1, record1.name, - references, - record1.seq.length()); - output_hits_paf_PE(outstring, nam_read2, record2.name, - references, - record2.seq.length()); + if (!map_param.is_sam_out) { + Nam nam_read1; + Nam nam_read2; + get_best_map_location(nams_pair[0], nams_pair[1], isize_est, + nam_read1, + nam_read2, abundances, record1.seq.length(), + record2.seq.length(), map_param.is_abundance_out); + output_hits_paf_PE(outstring, nam_read1, record1.name, + references, + record1.seq.length()); + output_hits_paf_PE(outstring, nam_read2, record2.name, + references, + record2.seq.length()); } else { Read read1(record1.seq); Read read2(record2.seq); @@ -1166,7 +1168,7 @@ void align_or_map_single( const References& references, const StrobemerIndex& index, std::minstd_rand& random_engine, - std::vector &abundance + std::vector &abundances ) { Details details; Timer strobe_timer; @@ -1195,22 +1197,22 @@ void align_or_map_single( Timer extend_timer; - std::vector best_ref; + std::vector best_contig; if (map_param.is_abundance_out){ if (!nams.empty()){ for (auto &t : nams){ if (t.score == nams[0].score){ - best_ref.push_back(t); + best_contig.push_back(t); }else{ break; } } - int ref_size = best_ref.size(); - for (auto &t: best_ref){ + int contig_size = best_contig.size(); + for (auto &t: best_contig){ if (t.ref_start < 0) { continue; } - abundance[t.ref_id] += float(record.seq.length()) / float(ref_size); + abundances[t.ref_id] += float(record.seq.length()) / float(contig_size); } } } diff --git a/src/aln.hpp b/src/aln.hpp index 99bcfedb..a9f24e27 100644 --- a/src/aln.hpp +++ b/src/aln.hpp @@ -90,7 +90,7 @@ void align_or_map_paired( const References& references, const StrobemerIndex& index, std::minstd_rand& random_engine, - std::vector &abundance + std::vector &abundances ); void align_or_map_single( @@ -104,7 +104,7 @@ void align_or_map_single( const References& references, const StrobemerIndex& index, std::minstd_rand& random_engine, - std::vector &abundance + std::vector &abundances ); // Private declarations, only needed for tests diff --git a/src/cmdline.cpp b/src/cmdline.cpp index 510821b6..5ad83821 100644 --- a/src/cmdline.cpp +++ b/src/cmdline.cpp @@ -27,11 +27,11 @@ CommandLineOptions parse_command_line_arguments(int argc, char **argv) { args::Flag v(parser, "v", "Verbose output", {'v'}); args::Flag no_progress(parser, "no-progress", "Disable progress report (enabled by default if output is a terminal)", {"no-progress"}); args::Flag x(parser, "x", "Only map reads, no base level alignment (produces PAF file)", {'x'}); + args::Flag aemb(parser, "aemb", "Output the estimated abundance value of contigs, the format of output file is: contig_id \t abundance_value", {"aemb"}); args::Flag interleaved(parser, "interleaved", "Interleaved input", {"interleaved"}); args::ValueFlag index_statistics(parser, "PATH", "Print statistics of indexing to PATH", {"index-statistics"}); args::Flag i(parser, "index", "Do not map reads; only generate the strobemer index and write it to disk. 
If read files are provided, they are used to estimate read length", {"create-index", 'i'}); args::Flag use_index(parser, "use_index", "Use a pre-generated index previously written with --create-index.", { "use-index" }); - args::Flag aemb(parser, "aemb", "Only output abundance value of contigs for metagenomic binning", {"aemb"}); args::Group sam(parser, "SAM output:"); args::Flag eqx(parser, "eqx", "Emit =/X instead of M CIGAR operations", {"eqx"}); diff --git a/src/main.cpp b/src/main.cpp index c012ecfb..811420d6 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -105,6 +105,12 @@ InputBuffer get_input_buffer(const CommandLineOptions& opt) { } } +void output_abundance(std::vector abundances, References references){ + for (size_t i = 0; i < references.size(); ++i) { + std::cout << references.names[i] << '\t' << std::fixed << std::setprecision(6) << abundances[i] / double(references.sequences[i].size()) << std::endl; + } +} + void show_progress_until_done(std::vector& worker_done, std::vector& stats) { Timer timer; bool reported = false; @@ -155,6 +161,11 @@ int run_strobealign(int argc, char **argv) { if (opt.c >= 64 || opt.c <= 0) { throw BadParameter("c must be greater than 0 and less than 64"); } + + if (!opt.is_sam_out && opt.is_abundance_out){ + throw BadParameter("Can not use -x and --aemb at the same time"); + } + InputBuffer input_buffer = get_input_buffer(opt); if (!opt.r_set && !opt.reads_filename1.empty()) { opt.r = estimate_read_length(input_buffer); @@ -290,18 +301,15 @@ int run_strobealign(int argc, char **argv) { std::ostream out(buf); - if (!map_param.is_abundance_out){ - if (map_param.is_sam_out) { + if (map_param.is_sam_out && !map_param.is_abundance_out){ std::stringstream cmd_line; for(int i = 0; i < argc; ++i) { cmd_line << argv[i] << " "; } - out << sam_header(references, opt.read_group_id, opt.read_group_fields); if (opt.pg_header) { out << pg_header(cmd_line.str()); } - } } std::vector log_stats_vec(opt.n_threads); @@ -311,12 +319,12 @@ int run_strobealign(int argc, char **argv) { OutputBuffer output_buffer(out); std::vector workers; std::vector worker_done(opt.n_threads); // each thread sets its entry to 1 when it’s done - std::vector> align_abundances(opt.n_threads, std::vector(references.size(), 0)); + std::vector> worker_abundances(opt.n_threads, std::vector(references.size(), 0)); for (int i = 0; i < opt.n_threads; ++i) { std::thread consumer(perform_task, std::ref(input_buffer), std::ref(output_buffer), std::ref(log_stats_vec[i]), std::ref(worker_done[i]), std::ref(aln_params), std::ref(map_param), std::ref(index_parameters), std::ref(references), - std::ref(index), std::ref(opt.read_group_id), std::ref(align_abundances[i])); + std::ref(index), std::ref(opt.read_group_id), std::ref(worker_abundances[i])); workers.push_back(std::move(consumer)); } if (opt.show_progress && isatty(2)) { @@ -333,17 +341,15 @@ int run_strobealign(int argc, char **argv) { } if (map_param.is_abundance_out){ - std::vector abundances(references.size(), 0); - std::vector abundances_norm(references.size(), 0); - for (size_t i = 0; i < align_abundances.size(); ++i) { - for (size_t j = 0; j < align_abundances[i].size(); ++j) { - abundances[j] += align_abundances[i][j]; - } + std::vector abundances(references.size(), 0); + for (size_t i = 0; i < worker_abundances.size(); ++i) { + for (size_t j = 0; j < worker_abundances[i].size(); ++j) { + abundances[j] += worker_abundances[i][j]; } + } - for (size_t i = 0; i < references.size(); ++i) { - std::cout << references.names[i] << '\t' << std::fixed 
<< std::setprecision(6) << abundances[i] / float(references.sequences[i].size()) << std::endl; - } + // output the abundance file + output_abundance(abundances, references); } diff --git a/src/pc.cpp b/src/pc.cpp index 34da9c97..417717ea 100644 --- a/src/pc.cpp +++ b/src/pc.cpp @@ -140,7 +140,7 @@ void perform_task( const References& references, const StrobemerIndex& index, const std::string& read_group_id, - std::vector &abundance + std::vector &abundances ) { bool eof = false; Aligner aligner{aln_params}; @@ -171,12 +171,12 @@ void perform_task( to_uppercase(record1.seq); to_uppercase(record2.seq); align_or_map_paired(record1, record2, sam, sam_out, statistics, isize_est, aligner, - map_param, index_parameters, references, index, random_engine, abundance); + map_param, index_parameters, references, index, random_engine, abundances); statistics.n_reads += 2; } for (size_t i = 0; i < records3.size(); ++i) { auto record = records3[i]; - align_or_map_single(record, sam, sam_out, statistics, aligner, map_param, index_parameters, references, index, random_engine, abundance); + align_or_map_single(record, sam, sam_out, statistics, aligner, map_param, index_parameters, references, index, random_engine, abundances); statistics.n_reads++; } diff --git a/src/pc.hpp b/src/pc.hpp index b0fa6ad8..87e4e873 100644 --- a/src/pc.hpp +++ b/src/pc.hpp @@ -65,7 +65,7 @@ class OutputBuffer { void perform_task(InputBuffer &input_buffer, OutputBuffer &output_buffer, AlignmentStatistics& statistics, int& done, const AlignmentParameters &aln_params, const MappingParameters &map_param, const IndexParameters& index_parameters, - const References& references, const StrobemerIndex& index, const std::string& read_group_id, std::vector &abundance); + const References& references, const StrobemerIndex& index, const std::string& read_group_id, std::vector &abundances); bool same_name(const std::string& n1, const std::string& n2); diff --git a/tests/phix.abun.pe.txt b/tests/phix.abun.pe.txt new file mode 100644 index 00000000..7e86894b --- /dev/null +++ b/tests/phix.abun.pe.txt @@ -0,0 +1 @@ +NC_001422.1 4.572291 diff --git a/tests/phix.abun.se.txt b/tests/phix.abun.se.txt new file mode 100644 index 00000000..62b75067 --- /dev/null +++ b/tests/phix.abun.se.txt @@ -0,0 +1 @@ +NC_001422.1 2.313690 diff --git a/tests/run.sh b/tests/run.sh index f82dbb5b..d7d6f064 100755 --- a/tests/run.sh +++ b/tests/run.sh @@ -58,6 +58,16 @@ strobealign -x tests/phix.fasta tests/phix.1.fastq tests/phix.2.fastq | tail -n diff tests/phix.pe.paf phix.pe.paf rm phix.pe.paf +# Single-end abundance estimation +strobealign --aemb tests/phix.fasta tests/phix.1.fastq > phix.abun.se.txt +diff tests/phix.abun.se.txt phix.abun.se.txt +rm phix.abun.se.txt + +# Paired-end abundance estimation +strobealign --aemb tests/phix.fasta tests/phix.1.fastq tests/phix.2.fastq > phix.abun.pe.txt +diff tests/phix.abun.pe.txt phix.abun.pe.txt +rm phix.abun.pe.txt + # Build a separate index strobealign --no-PG -r 150 tests/phix.fasta tests/phix.1.fastq > without-sti.sam strobealign -r 150 -i tests/phix.fasta From cec26e8b4b60a5fbe82de27463105825a0d5bd3e Mon Sep 17 00:00:00 2001 From: psj1997 <1176894699@qq.com> Date: Tue, 20 Feb 2024 23:29:13 +0800 Subject: [PATCH 18/26] Update single-end abundance testing file --- tests/phix.abun.se.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/phix.abun.se.txt b/tests/phix.abun.se.txt index 62b75067..55e31ba6 100644 --- a/tests/phix.abun.se.txt +++ b/tests/phix.abun.se.txt @@ -1 +1 @@ -NC_001422.1 
2.313690 +NC_001422.1 2.347196 From 7be8e2c592081479b82299e6e0eae61c80240095 Mon Sep 17 00:00:00 2001 From: psj1997 <1176894699@qq.com> Date: Tue, 20 Feb 2024 23:31:32 +0800 Subject: [PATCH 19/26] Update paired-end abundance testing file --- tests/phix.abun.pe.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/phix.abun.pe.txt b/tests/phix.abun.pe.txt index 7e86894b..900b9de5 100644 --- a/tests/phix.abun.pe.txt +++ b/tests/phix.abun.pe.txt @@ -1 +1 @@ -NC_001422.1 4.572291 +NC_001422.1 4.638507 From 47770eed948f7a2d6fcf987c56cf4dd2f418aab8 Mon Sep 17 00:00:00 2001 From: Luis Pedro Coelho Date: Wed, 21 Feb 2024 14:35:10 +1000 Subject: [PATCH 20/26] RFCT Simplify code by merging duplications --- src/aln.cpp | 97 ++++++++++++++++++++--------------------------------- 1 file changed, 36 insertions(+), 61 deletions(-) diff --git a/src/aln.cpp b/src/aln.cpp index e66f9039..a9f7ac27 100644 --- a/src/aln.cpp +++ b/src/aln.cpp @@ -869,18 +869,15 @@ inline void get_best_map_location( InsertSizeDistribution &isize_est, Nam &best_nam1, Nam &best_nam2, - std::vector &abundances, - int read1_len, + std::vector &abundances, + int read1_len, int read2_len, - bool output_abundance + bool output_abundance ) { std::vector nam_pairs = get_best_scoring_nam_pairs(nams1, nams2, isize_est.mu, isize_est.sigma); best_nam1.ref_start = -1; //Unmapped until proven mapped best_nam2.ref_start = -1; //Unmapped until proven mapped - std::vector best_contig1; - std::vector best_contig2; - if (nam_pairs.empty()) { return; } @@ -912,71 +909,49 @@ inline void get_best_map_location( best_nam2 = n2_joint_max; if (output_abundance){ - // find all NAM pairs that have the same best score + // we loop twice because we need to count the number of best pairs + size_t n_best = 0; for (auto &[score, n1, n2] : nam_pairs){ if ((n1.score + n2.score) == score_joint){ - best_contig1.push_back(n1); - best_contig2.push_back(n2); - }else{ + ++n_best; + } else { break; } } - size_t contig_size1 = best_contig1.size(); - for (auto &t: best_contig1){ - if (t.ref_start < 0) { - continue; - } - abundances[t.ref_id] += float(read1_len) / float(contig_size1); - } - - int contig_size2 = best_contig2.size(); - for (auto &t: best_contig2){ - if (t.ref_start < 0) { - continue; - } - abundances[t.ref_id] += float(read2_len) / float(contig_size2); - } - } - } - else{ - if (output_abundance){ - if (!nams1.empty()){ - // find all NAM1 that have the same best score - for (auto &t : nams1){ - if (t.score == nams1[0].score){ - best_contig1.push_back(t); - }else{ - break; + for (auto &[score, n1, n2] : nam_pairs){ + if ((n1.score + n2.score) == score_joint){ + if (n1.ref_start >= 0) { + abundances[n1.ref_id] += float(read1_len) / float(n_best); } - } - - size_t contig_size1 = best_contig1.size(); - for (auto &t: best_contig1){ - if (t.ref_start < 0) { - continue; + if (n2.ref_start >= 0) { + abundances[n2.ref_id] += float(read2_len) / float(n_best); } - abundances[t.ref_id] += float(read1_len) / float(contig_size1); + } else { + break; } + } } - - if (!nams2.empty()){ - // find all NAM2 that have the same best score - for (auto &t : nams2){ - if (t.score == nams2[0].score){ - best_contig2.push_back(t); - }else{ - break; - } + } else if (output_abundance) { + for (auto &[nams, read_len]: { std::make_pair(std::cref(nams1), read1_len), + std::make_pair(std::cref(nams2), read2_len) }) { + size_t best_score = 0; + // We loop twice because we need to count the number of NAMs with best score + for (auto &t : nams) { + if (t.score == 
nams[0].score){ + ++best_score; + } else { + break; } - - int contig_size2 = best_contig2.size(); - for (auto &t: best_contig2){ - if (t.ref_start < 0) { - continue; - } - abundances[t.ref_id] += float(read2_len) / float(contig_size2); + } + for (auto &t: nams) { + if (t.ref_start < 0) { + continue; + } + if (t.score != nams[0].score){ + break; } + abundances[t.ref_id] += float(read_len) / float(best_score); } } } @@ -1033,7 +1008,7 @@ void align_or_map_paired( const IndexParameters& index_parameters, const References& references, const StrobemerIndex& index, - std::minstd_rand& random_engine, + std::minstd_rand& random_engine, std::vector &abundances ) { std::array details; @@ -1079,7 +1054,7 @@ void align_or_map_paired( Nam nam_read2; get_best_map_location(nams_pair[0], nams_pair[1], isize_est, nam_read1, - nam_read2, abundances, record1.seq.length(), + nam_read2, abundances, record1.seq.length(), record2.seq.length(), map_param.is_abundance_out); output_hits_paf_PE(outstring, nam_read1, record1.name, references, From 677edcc66c982932d13eda41fb1a69bea2d4ab0e Mon Sep 17 00:00:00 2001 From: Luis Pedro Coelho Date: Wed, 21 Feb 2024 14:36:05 +1000 Subject: [PATCH 21/26] RFCT Simplify output_format handling This is 3-state variable: SAM/PAF/Abundance --- src/aln.cpp | 71 ++++++++++++++++++++++++++-------------------------- src/aln.hpp | 9 +++++-- src/main.cpp | 10 +++++--- src/pc.cpp | 6 ++--- 4 files changed, 51 insertions(+), 45 deletions(-) diff --git a/src/aln.cpp b/src/aln.cpp index a9f7ac27..2b092684 100644 --- a/src/aln.cpp +++ b/src/aln.cpp @@ -1043,25 +1043,23 @@ void align_or_map_paired( } Timer extend_timer; - if (map_param.is_abundance_out){ + if (map_param.output_format == OutputFormat::Abundance) { Nam nam_read1; Nam nam_read2; - get_best_map_location(nams_pair[0], nams_pair[1], isize_est,nam_read1, nam_read2, abundances, record1.seq.length(), record2.seq.length(), map_param.is_abundance_out); - } - else{ - if (!map_param.is_sam_out) { - Nam nam_read1; - Nam nam_read2; - get_best_map_location(nams_pair[0], nams_pair[1], isize_est, - nam_read1, - nam_read2, abundances, record1.seq.length(), - record2.seq.length(), map_param.is_abundance_out); - output_hits_paf_PE(outstring, nam_read1, record1.name, - references, - record1.seq.length()); - output_hits_paf_PE(outstring, nam_read2, record2.name, - references, - record2.seq.length()); + get_best_map_location(nams_pair[0], nams_pair[1], isize_est,nam_read1, nam_read2, abundances, record1.seq.length(), record2.seq.length(), true); + } else if (map_param.output_format == OutputFormat::PAF) { + Nam nam_read1; + Nam nam_read2; + get_best_map_location(nams_pair[0], nams_pair[1], isize_est, + nam_read1, + nam_read2, abundances, record1.seq.length(), + record2.seq.length(), false); + output_hits_paf_PE(outstring, nam_read1, record1.name, + references, + record1.seq.length()); + output_hits_paf_PE(outstring, nam_read2, record2.name, + references, + record2.seq.length()); } else { Read read1(record1.seq); Read read2(record2.seq); @@ -1125,7 +1123,6 @@ void align_or_map_paired( ); } } - } statistics.tot_extend += extend_timer.duration(); statistics += details[0]; statistics += details[1]; @@ -1173,35 +1170,37 @@ void align_or_map_single( Timer extend_timer; std::vector best_contig; - if (map_param.is_abundance_out){ - if (!nams.empty()){ - for (auto &t : nams){ - if (t.score == nams[0].score){ - best_contig.push_back(t); - }else{ - break; + switch (map_param.output_format) { + case OutputFormat::Abundance: { + if (!nams.empty()){ + for (auto &t 
: nams){ + if (t.score == nams[0].score){ + best_contig.push_back(t); + }else{ + break; + } } - } - int contig_size = best_contig.size(); - for (auto &t: best_contig){ - if (t.ref_start < 0) { - continue; + int contig_size = best_contig.size(); + for (auto &t: best_contig){ + if (t.ref_start < 0) { + continue; + } + abundances[t.ref_id] += float(record.seq.length()) / float(contig_size); } - abundances[t.ref_id] += float(record.seq.length()) / float(contig_size); } } - } - else{ - if (!map_param.is_sam_out) { + break; + case OutputFormat::PAF: output_hits_paf(outstring, nams, record.name, references, record.seq.length()); - } else { + break; + case OutputFormat::SAM: align_single( aligner, sam, nams, record, index_parameters.syncmer.k, references, details, map_param.dropoff_threshold, map_param.max_tries, map_param.max_secondary, random_engine ); - } + break; } statistics.tot_extend += extend_timer.duration(); statistics += details; diff --git a/src/aln.hpp b/src/aln.hpp index a9f24e27..0a60fb41 100644 --- a/src/aln.hpp +++ b/src/aln.hpp @@ -56,6 +56,12 @@ struct AlignmentStatistics { } }; +enum class OutputFormat { + SAM, + PAF, + Abundance +}; + struct MappingParameters { int r { 150 }; int max_secondary { 0 }; @@ -63,8 +69,7 @@ struct MappingParameters { int rescue_level { 2 }; int max_tries { 20 }; int rescue_cutoff; - bool is_sam_out { true }; - bool is_abundance_out {false}; + OutputFormat output_format {OutputFormat::SAM}; CigarOps cigar_ops{CigarOps::M}; bool output_unmapped { true }; bool details{false}; diff --git a/src/main.cpp b/src/main.cpp index 811420d6..b8e5eb00 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -195,12 +195,14 @@ int run_strobealign(int argc, char **argv) { map_param.dropoff_threshold = opt.dropoff_threshold; map_param.rescue_level = opt.rescue_level; map_param.max_tries = opt.max_tries; - map_param.is_sam_out = opt.is_sam_out; + map_param.output_format = ( + opt.is_abundance_out ? OutputFormat::Abundance : + opt.is_sam_out ? OutputFormat::SAM : + OutputFormat::PAF); map_param.cigar_ops = opt.cigar_eqx ? 
CigarOps::EQX : CigarOps::M; map_param.output_unmapped = opt.output_unmapped; map_param.details = opt.details; map_param.fastq_comments = opt.fastq_comments; - map_param.is_abundance_out = opt.is_abundance_out; map_param.verify(); log_parameters(index_parameters, map_param, aln_params); @@ -301,7 +303,7 @@ int run_strobealign(int argc, char **argv) { std::ostream out(buf); - if (map_param.is_sam_out && !map_param.is_abundance_out){ + if (map_param.output_format == OutputFormat::SAM) { std::stringstream cmd_line; for(int i = 0; i < argc; ++i) { cmd_line << argv[i] << " "; @@ -340,7 +342,7 @@ int run_strobealign(int argc, char **argv) { tot_statistics += it; } - if (map_param.is_abundance_out){ + if (map_param.output_format == OutputFormat::Abundance) { std::vector abundances(references.size(), 0); for (size_t i = 0; i < worker_abundances.size(); ++i) { for (size_t j = 0; j < worker_abundances[i].size(); ++j) { diff --git a/src/pc.cpp b/src/pc.cpp index 417717ea..eec33316 100644 --- a/src/pc.cpp +++ b/src/pc.cpp @@ -180,9 +180,9 @@ void perform_task( statistics.n_reads++; } - if (!map_param.is_abundance_out){ - output_buffer.output_records(std::move(sam_out), chunk_index); - assert(sam_out == ""); + if (map_param.output_format != OutputFormat::Abundance) { + output_buffer.output_records(std::move(sam_out), chunk_index); + assert(sam_out == ""); } } statistics.tot_aligner_calls += aligner.calls_count(); From ff7222a811d91d7af48fc7e9daa1f419fbe9d129 Mon Sep 17 00:00:00 2001 From: Luis Pedro Coelho Date: Wed, 21 Feb 2024 14:42:22 +1000 Subject: [PATCH 22/26] RFCT Merge duplicated code --- src/aln.cpp | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/src/aln.cpp b/src/aln.cpp index 2b092684..6df39b43 100644 --- a/src/aln.cpp +++ b/src/aln.cpp @@ -862,16 +862,16 @@ std::vector align_paired( return high_scores; } -// Only used for PAF output +// Used for PAF and abundances output inline void get_best_map_location( std::vector &nams1, std::vector &nams2, InsertSizeDistribution &isize_est, Nam &best_nam1, Nam &best_nam2, - std::vector &abundances, int read1_len, int read2_len, + std::vector &abundances, bool output_abundance ) { std::vector nam_pairs = get_best_scoring_nam_pairs(nams1, nams2, isize_est.mu, isize_est.sigma); @@ -1043,23 +1043,24 @@ void align_or_map_paired( } Timer extend_timer; - if (map_param.output_format == OutputFormat::Abundance) { + if (map_param.output_format != OutputFormat::SAM) { // PAF or abundance Nam nam_read1; Nam nam_read2; - get_best_map_location(nams_pair[0], nams_pair[1], isize_est,nam_read1, nam_read2, abundances, record1.seq.length(), record2.seq.length(), true); - } else if (map_param.output_format == OutputFormat::PAF) { - Nam nam_read1; - Nam nam_read2; - get_best_map_location(nams_pair[0], nams_pair[1], isize_est, - nam_read1, - nam_read2, abundances, record1.seq.length(), - record2.seq.length(), false); - output_hits_paf_PE(outstring, nam_read1, record1.name, - references, - record1.seq.length()); - output_hits_paf_PE(outstring, nam_read2, record2.name, - references, - record2.seq.length()); + get_best_map_location( + nams_pair[0], nams_pair[1], + isize_est, + nam_read1, nam_read2, + record1.seq.length(), record2.seq.length(), + abundances, + map_param.output_format == OutputFormat::Abundance); + if (map_param.output_format == OutputFormat::PAF) { + output_hits_paf_PE(outstring, nam_read1, record1.name, + references, + record1.seq.length()); + output_hits_paf_PE(outstring, nam_read2, record2.name, + 
references, + record2.seq.length()); + } } else { Read read1(record1.seq); Read read2(record2.seq); From 372b868fccc0eddfa3e2f41a79b5dab78f910572 Mon Sep 17 00:00:00 2001 From: psj1997 <1176894699@qq.com> Date: Wed, 21 Feb 2024 14:58:17 +0800 Subject: [PATCH 23/26] BUG fix of RFCT aemb --- src/aln.cpp | 1 - src/main.cpp | 1 - 2 files changed, 2 deletions(-) diff --git a/src/aln.cpp b/src/aln.cpp index fec9bbd5..2615a1ef 100644 --- a/src/aln.cpp +++ b/src/aln.cpp @@ -1126,7 +1126,6 @@ void align_or_map_paired( ); } } - } statistics.tot_extend += extend_timer.duration(); statistics += details[0]; statistics += details[1]; diff --git a/src/main.cpp b/src/main.cpp index 1a69f120..09da1071 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -203,7 +203,6 @@ int run_strobealign(int argc, char **argv) { map_param.output_unmapped = opt.output_unmapped; map_param.details = opt.details; map_param.fastq_comments = opt.fastq_comments; - map_param.is_abundance_out = opt.is_abundance_out; map_param.verify(); log_parameters(index_parameters, map_param, aln_params); From 383515d3c3b100ab7d28412ca8c53b076845b53f Mon Sep 17 00:00:00 2001 From: psj1997 <1176894699@qq.com> Date: Wed, 21 Feb 2024 15:10:18 +0800 Subject: [PATCH 24/26] Remove unused vector in single-end mode --- src/aln.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/aln.cpp b/src/aln.cpp index 2615a1ef..c9959daf 100644 --- a/src/aln.cpp +++ b/src/aln.cpp @@ -878,9 +878,6 @@ inline void get_best_map_location( best_nam1.ref_start = -1; //Unmapped until proven mapped best_nam2.ref_start = -1; //Unmapped until proven mapped - std::vector best_contig1; - std::vector best_contig2; - if (nam_pairs.empty()) { return; } @@ -1172,23 +1169,26 @@ void align_or_map_single( Timer extend_timer; - std::vector best_contig; + size_t n_best = 0; switch (map_param.output_format) { case OutputFormat::Abundance: { if (!nams.empty()){ for (auto &t : nams){ if (t.score == nams[0].score){ - best_contig.push_back(t); + ++n_best; }else{ break; } } - int contig_size = best_contig.size(); - for (auto &t: best_contig){ + + for (auto &t: nams) { if (t.ref_start < 0) { continue; } - abundances[t.ref_id] += float(record.seq.length()) / float(contig_size); + if (t.score != nams[0].score){ + break; + } + abundances[t.ref_id] += float(record.seq.length()) / float(n_best); } } } From 8c32350931bedcbe467c08e9d702ec092a813f8a Mon Sep 17 00:00:00 2001 From: psj1997 <1176894699@qq.com> Date: Mon, 26 Feb 2024 14:48:19 +0800 Subject: [PATCH 25/26] Minor fix --- src/aln.cpp | 20 ++++++++++---------- src/main.cpp | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/aln.cpp b/src/aln.cpp index c9959daf..eeabd292 100644 --- a/src/aln.cpp +++ b/src/aln.cpp @@ -936,21 +936,21 @@ inline void get_best_map_location( std::make_pair(std::cref(nams2), read2_len) }) { size_t best_score = 0; // We loop twice because we need to count the number of NAMs with best score - for (auto &t : nams) { - if (t.score == nams[0].score){ + for (auto &nam : nams) { + if (nam.score == nams[0].score){ ++best_score; } else { break; } } - for (auto &t: nams) { - if (t.ref_start < 0) { + for (auto &nam: nams) { + if (nam.ref_start < 0) { continue; } - if (t.score != nams[0].score){ + if (nam.score != nams[0].score){ break; } - abundances[t.ref_id] += float(read_len) / float(best_score); + abundances[nam.ref_id] += float(read_len) / float(best_score); } } } @@ -1181,14 +1181,14 @@ void align_or_map_single( } } - for (auto &t: nams) { - if (t.ref_start < 0) { + 
for (auto &nam: nams) { + if (nam.ref_start < 0) { continue; } - if (t.score != nams[0].score){ + if (nam.score != nams[0].score){ break; } - abundances[t.ref_id] += float(record.seq.length()) / float(n_best); + abundances[nam.ref_id] += float(record.seq.length()) / float(n_best); } } } diff --git a/src/main.cpp b/src/main.cpp index 09da1071..5525cd3f 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -105,7 +105,7 @@ InputBuffer get_input_buffer(const CommandLineOptions& opt) { } } -void output_abundance(std::vector abundances, References references){ +void output_abundance(const std::vector abundances, const References references){ for (size_t i = 0; i < references.size(); ++i) { std::cout << references.names[i] << '\t' << std::fixed << std::setprecision(6) << abundances[i] / double(references.sequences[i].size()) << std::endl; } From 5a1ab9e61fe3683961e87ee176e52bb655f95aed Mon Sep 17 00:00:00 2001 From: psj1997 <1176894699@qq.com> Date: Mon, 26 Feb 2024 22:41:38 +0800 Subject: [PATCH 26/26] Update src/main.cpp Co-authored-by: Marcel Martin --- src/main.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main.cpp b/src/main.cpp index 5525cd3f..b2c06004 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -105,7 +105,7 @@ InputBuffer get_input_buffer(const CommandLineOptions& opt) { } } -void output_abundance(const std::vector abundances, const References references){ +void output_abundance(const std::vector& abundances, const References& references){ for (size_t i = 0; i < references.size(); ++i) { std::cout << references.names[i] << '\t' << std::fixed << std::setprecision(6) << abundances[i] / double(references.sequences[i].size()) << std::endl; }
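
The per-contig value that --aemb writes (cf. tests/phix.abun.pe.txt above) is the sum, over all reads, of the read length divided by the number of equally best-scoring contigs for that read, normalized at the end by contig length in output_abundance(). The standalone sketch below illustrates only that accounting; it is not code from these patches, and the Hit struct and the float types are illustrative stand-ins for strobealign's Nam objects and abundance vectors.

    // Minimal sketch of the --aemb accounting, assuming hits are sorted by
    // score in descending order (as the best-scoring NAMs are in the patch).
    // Names and types here are illustrative, not the patch's actual API.
    #include <cstdio>
    #include <string>
    #include <vector>

    struct Hit { int ref_id; float score; };  // stand-in for a best-scoring hit

    // Split the read length evenly over all hits sharing the best score.
    void accumulate(std::vector<float>& abundances, const std::vector<Hit>& hits, int read_len) {
        if (hits.empty()) return;
        size_t n_best = 0;
        for (const auto& h : hits) {
            if (h.score == hits[0].score) ++n_best; else break;
        }
        for (const auto& h : hits) {
            if (h.score != hits[0].score) break;
            abundances[h.ref_id] += float(read_len) / float(n_best);
        }
    }

    int main() {
        std::vector<std::string> names{"contig1", "contig2"};
        std::vector<size_t> lengths{5386, 3000};
        std::vector<float> abundances(names.size(), 0.0f);

        accumulate(abundances, {{0, 50.0f}, {1, 50.0f}}, 150);  // read maps equally well to both contigs
        accumulate(abundances, {{0, 80.0f}}, 150);              // read maps best to contig1 only

        // Final per-contig value: accumulated sum divided by contig length
        for (size_t i = 0; i < names.size(); ++i)
            std::printf("%s\t%.6f\n", names[i].c_str(), abundances[i] / double(lengths[i]));
    }

Compiled and run, this prints one tab-separated name/value pair per contig, matching the format of the phix.abun.se.txt and phix.abun.pe.txt test files checked by tests/run.sh.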