diff --git a/src/checkpoint.cpp b/src/checkpoint.cpp index 6d1c873..4af6ced 100644 --- a/src/checkpoint.cpp +++ b/src/checkpoint.cpp @@ -2,11 +2,11 @@ #include "debug.h" #include "util.hpp" #include +#include #include #include #include #include -#include template <> size_t write(int fd, const std::string &str) { size_t string_size = str.size(); @@ -70,6 +70,7 @@ template <> size_t write(int fd, const cli_options_t &options) { total_written += write(fd, options.echo); total_written += write(fd, options.invariant_sites); total_written += write(fd, options.early_stop); + total_written += write(fd, options.initial_root_strategy); debug_print( EMIT_LEVEL_DEBUG, "Wrote %lu bytes for cli_options", total_written); return total_written; @@ -101,6 +102,7 @@ template <> size_t read(int fd, cli_options_t &options) { total_read += read(fd, options.echo); total_read += read(fd, options.invariant_sites); total_read += read(fd, options.early_stop); + total_read += read(fd, options.initial_root_strategy); debug_print(EMIT_LEVEL_DEBUG, "Read %lu bytes for cli_options", total_read); return total_read; } diff --git a/src/main.cpp b/src/main.cpp index 74efb24..16b775b 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -7,12 +7,12 @@ extern "C" { #include #include #include +#include #include #include #include #include #include -#include #ifdef MPI_BUILD #include #endif @@ -118,6 +118,14 @@ static void print_usage() { << " Tolerance for the BFGS steps. Default is 1e-7\n" << " --factor [NUMBER]\n" << " Factor for the BFGS steps. Default is 1e2\n" + << " --initial-root-strategy {random, midpoint, modified-mad}\n" + << " The strategy to pick the initial branches for rooting.\n" + << " Random is the default, and simply picks the branches at\n" + << " random. Midpoint uses the midpoint and similar branches to\n" + << " start. Modified MAD will use a modified version of mad to\n" + << " pick the starting branches. This can actually drive the\n" + << " so care should be taken when selecting this option.\n" + << " Default is random\n" << " --threads [NUMBER]\n" << " Number of threads to use\n" << " --silent\n" @@ -137,33 +145,34 @@ static void print_usage() { cli_options_t parse_options(int argv, char **argc) { static struct option long_opts[] = { - {"msa", required_argument, 0, 0}, /* 0 */ - {"tree", required_argument, 0, 0}, /* 1 */ - {"model", required_argument, 0, 0}, /* 2 */ - {"seed", required_argument, 0, 0}, /* 3 */ - {"verbose", no_argument, 0, 0}, /* 4 */ - {"silent", no_argument, 0, 0}, /* 5 */ - {"min-roots", required_argument, 0, 0}, /* 6 */ - {"root-ratio", required_argument, 0, 0}, /* 7 */ - {"atol", required_argument, 0, 0}, /* 8 */ - {"brtol", required_argument, 0, 0}, /* 9 */ - {"bfgstol", required_argument, 0, 0}, /* 10 */ - {"factor", required_argument, 0, 0}, /* 11 */ - {"partition", required_argument, 0, 0}, /* 12 */ - {"prefix", required_argument, 0, 0}, /* 13 */ - {"exhaustive", no_argument, 0, 0}, /* 14 */ - {"early-stop", no_argument, 0, 0}, /* 15 */ - {"no-early-stop", no_argument, 0, 0}, /* 16 */ - {"rate-cats", required_argument, 0, 0}, /* 17 */ - {"rate-cats-type", required_argument, 0, 0}, /* 18 */ - {"invariant-sites", no_argument, 0, 0}, /* 19 */ - {"threads", required_argument, 0, 0}, /* 20 */ - {"version", no_argument, 0, 0}, /* 21 */ - {"debug", no_argument, 0, 0}, /* 22 */ - {"mpi-debug", no_argument, 0, 0}, /* 23 */ - {"clean", no_argument, 0, 0}, /* 24 */ - {"echo", no_argument, 0, 0}, /* 25 */ - {"help", no_argument, 0, 0}, /* 26 */ + {"msa", required_argument, 0, 0}, /* 0 */ + {"tree", required_argument, 0, 0}, /* 1 */ + {"model", required_argument, 0, 0}, /* 2 */ + {"seed", required_argument, 0, 0}, /* 3 */ + {"verbose", no_argument, 0, 0}, /* 4 */ + {"silent", no_argument, 0, 0}, /* 5 */ + {"min-roots", required_argument, 0, 0}, /* 6 */ + {"root-ratio", required_argument, 0, 0}, /* 7 */ + {"atol", required_argument, 0, 0}, /* 8 */ + {"brtol", required_argument, 0, 0}, /* 9 */ + {"bfgstol", required_argument, 0, 0}, /* 10 */ + {"factor", required_argument, 0, 0}, /* 11 */ + {"partition", required_argument, 0, 0}, /* 12 */ + {"prefix", required_argument, 0, 0}, /* 13 */ + {"exhaustive", no_argument, 0, 0}, /* 14 */ + {"early-stop", no_argument, 0, 0}, /* 15 */ + {"no-early-stop", no_argument, 0, 0}, /* 16 */ + {"rate-cats", required_argument, 0, 0}, /* 17 */ + {"rate-cats-type", required_argument, 0, 0}, /* 18 */ + {"invariant-sites", no_argument, 0, 0}, /* 19 */ + {"initial-root-strategy", required_argument, 0, 0}, /* 20 */ + {"threads", required_argument, 0, 0}, /* 21 */ + {"version", no_argument, 0, 0}, /* 22 */ + {"debug", no_argument, 0, 0}, /* 23 */ + {"mpi-debug", no_argument, 0, 0}, /* 24 */ + {"clean", no_argument, 0, 0}, /* 25 */ + {"echo", no_argument, 0, 0}, /* 26 */ + {"help", no_argument, 0, 0}, /* 27 */ {0, 0, 0, 0}, }; @@ -241,25 +250,36 @@ cli_options_t parse_options(int argv, char **argc) { case 19: // invariant-sites cli_options.invariant_sites = true; break; - case 20: // threads + case 20: // initial-root-strategy + if (strcmp(optarg, "random") == 0) { + cli_options.initial_root_strategy = {initial_root_strategy_t::random}; + } else if (strcmp(optarg, "midpoint") == 0) { + cli_options.initial_root_strategy = {initial_root_strategy_t::midpoint}; + } else if (strcmp(optarg, "modified-mad") == 0) { + cli_options.initial_root_strategy = { + initial_root_strategy_t::modified_mad}; + } + + break; + case 21: // threads cli_options.threads = {(size_t)atol(optarg)}; break; - case 21: // version + case 22: // version print_version(); std::exit(0); - case 22: // debug + case 23: // debug __VERBOSE__ = EMIT_LEVEL_DEBUG; break; - case 23: // mpi-debug + case 24: // mpi-debug __VERBOSE__ = EMIT_LEVEL_MPI_DEBUG; break; - case 24: // clean + case 25: // clean cli_options.clean = true; break; - case 25: // echo + case 26: // echo cli_options.echo = true; break; - case 26: // help + case 27: // help print_usage(); std::exit(0); break; @@ -526,6 +546,7 @@ int wrapped_main(int argv, char **argc) { cli_options.root_ratio, static_cast(__MPI_RANK__), static_cast(__MPI_NUM_TASKS__), + cli_options.initial_root_strategy, checkpoint); #ifdef MPI_VERSION MPI_Barrier(MPI_COMM_WORLD); diff --git a/src/model.cpp b/src/model.cpp index d57effa..0a1eceb 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -1816,12 +1816,34 @@ void model_t::assign_indicies_by_rank_search(size_t min_roots, size_t rank, size_t num_tasks, checkpoint_t &checkpoint) { - auto completed_work = checkpoint.completed_indicies(); - // auto shuffled_idx = shuffle_root_indicies(); - // auto shuffled_idx = suggest_root_indicies_midpoint(); - // auto shuffled_idx = suggest_root_indicies_length(); - auto shuffled_idx = suggest_root_indicies_modified_mad(); - size_t root_count = std::min( + assign_indicies_by_rank_search(min_roots, + root_ratio, + rank, + num_tasks, + initial_root_strategy_t::random, + checkpoint); +} + +void model_t::assign_indicies_by_rank_search(size_t min_roots, + double root_ratio, + size_t rank, + size_t num_tasks, + initial_root_strategy_t init_root, + checkpoint_t &checkpoint) { + auto completed_work = checkpoint.completed_indicies(); + std::vector shuffled_idx; + + if (init_root == initial_root_strategy_t::random) { + shuffled_idx = shuffle_root_indicies(); + } else if (init_root == initial_root_strategy_t::midpoint) { + shuffled_idx = suggest_root_indicies_midpoint(); + } else if (init_root == initial_root_strategy_t::modified_mad) { + shuffled_idx = suggest_root_indicies_modified_mad(); + } else { + throw std::runtime_error{"The initial root strategy was not recognized"}; + } + + size_t root_count = std::min( std::max(static_cast(_tree.root_count() * root_ratio), min_roots), _tree.root_count()); diff --git a/src/model.hpp b/src/model.hpp index 099c8ad..4252177 100644 --- a/src/model.hpp +++ b/src/model.hpp @@ -130,10 +130,16 @@ class model_t { void assign_indicies(size_t beg, size_t end, std::vector idx); void assign_indicies(); - void assign_indicies_by_rank_search(size_t min_roots, - double root_ratio, - size_t rank, - size_t num_tasks, + void assign_indicies_by_rank_search(size_t min_roots, + double root_ratio, + size_t rank, + size_t num_tasks, + checkpoint_t &checkpoint); + void assign_indicies_by_rank_search(size_t min_roots, + double root_ratio, + size_t rank, + size_t num_tasks, + initial_root_strategy_t init_root, checkpoint_t &); void assign_indicies_by_rank_exhaustive(size_t rank, size_t num_tasks, diff --git a/src/util.hpp b/src/util.hpp index f796777..dfe7914 100644 --- a/src/util.hpp +++ b/src/util.hpp @@ -62,6 +62,27 @@ namespace asc_bias_type { enum asc_bias_type_e { lewis, fels, stam }; } +struct initial_root_strategy_t { + enum initial_root_strategy_e { random, midpoint, modified_mad }; + initial_root_strategy_e strategy; + + initial_root_strategy_t(initial_root_strategy_e s) : strategy{s} {} + + bool operator==(const initial_root_strategy_t &other) const { + return strategy == other.strategy; + } + + bool operator==(const initial_root_strategy_e &other) const { + return strategy == other; + } +}; + +/* +namespace initial_root_strategy_t { +enum initial_root_strategy_e { random, midpoint, modified_mad }; +} +*/ + struct asc_bias_opts_t { asc_bias_type::asc_bias_type_e type; double fels_weight; @@ -158,6 +179,9 @@ struct cli_options_t { bool clean = false; initialized_flag_t early_stop; + initial_root_strategy_t initial_root_strategy = { + initial_root_strategy_t::random}; + bool operator==(const cli_options_t &other) const { return msa_filename == other.msa_filename && tree_filename == other.tree_filename && prefix == other.prefix @@ -173,7 +197,8 @@ struct cli_options_t { && br_tolerance == other.br_tolerance && bfgs_tol == other.bfgs_tol && states == other.states && exhaustive == other.exhaustive && echo == other.echo && invariant_sites == other.invariant_sites - && early_stop == other.early_stop; + && early_stop == other.early_stop + && initial_root_strategy == other.initial_root_strategy; } bool operator!=(const cli_options_t &other) const { return !(*this == other);