Updates raxml to allow for floating point site weights #113

Open · wants to merge 4 commits into base: master
3 changes: 2 additions & 1 deletion .gitmodules
@@ -1,6 +1,7 @@
[submodule "libs/pll-modules"]
path = libs/pll-modules
url = https://github.com/ddarriba/pll-modules.git
url = https://github.com/computations/pll-modules.git
branch = floating_weights
[submodule "libs/terraphast"]
path = libs/terraphast
url = https://github.com/amkozlov/terraphast-one
2 changes: 1 addition & 1 deletion libs/pll-modules
8 changes: 4 additions & 4 deletions src/MSA.cpp
@@ -130,13 +130,13 @@ void MSA::compress_patterns(const pll_state_t * charmap, bool store_backmap)
_length = _pll_msa->length;

if (_weights.empty())
_weights = WeightVector(w, w + _pll_msa->length);
_weights = FloatWeightVector(w, w + _pll_msa->length);
else
{
/* external weights specified -> use site_pattern_map to generate a compressed weight vector */
assert(_weights.size() == uncompressed_length);
assert(!_site_pattern_map.empty());
WeightVector new_weights(_length, 0);
FloatWeightVector new_weights(_length, 0);
for (size_t i = 0; i < _site_pattern_map.size(); ++i)
new_weights[_site_pattern_map[i]] += _weights[i];
_weights = std::move(new_weights);
@@ -326,13 +326,13 @@ void MSA::update_num_sites()
_num_sites = std::accumulate(_weights.begin(), _weights.end(), 0);
}

void MSA::weights(const WeightVector& v)
void MSA::weights(const FloatWeightVector& v)
{
_weights = v;
update_num_sites();
}

void MSA::weights(WeightVector&& v)
void MSA::weights(FloatWeightVector&& v)
{
_weights = std::move(v);
update_num_sites();
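
Note: types.hpp itself is not included in this diff, so the exact definitions behind the new weight types are not shown. Judging from their usage throughout the changes, FloatWeightType, FloatWeightVector and FloatWeightVectorList presumably reduce to aliases along the following lines (an assumption, not the PR's code):

/* Assumed aliases; types.hpp is not included in this diff */
#include <vector>

typedef double                         FloatWeightType;
typedef std::vector<FloatWeightType>   FloatWeightVector;
typedef std::vector<FloatWeightVector> FloatWeightVectorList;
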
8 changes: 4 additions & 4 deletions src/MSA.hpp
@@ -46,7 +46,7 @@ class MSA
size_t length() const { return _length; }
size_t num_sites() const { return _num_sites; }
size_t num_patterns() const { return _weights.size(); }
const WeightVector& weights() const {return _weights; }
const FloatWeightVector& weights() const {return _weights; }
const NameIdMap& label_id_map() const { return _label_id_map; }
const WeightVector& site_pattern_map() const { return _site_pattern_map; }
const pll_msa_t * pll_msa() const;
@@ -71,8 +71,8 @@
doubleVector state_freqs() const;

void num_sites(const unsigned int sites) { _num_sites = sites; }
void weights(const WeightVector& v);
void weights(WeightVector&& v);
void weights(const FloatWeightVector& v);
void weights(FloatWeightVector&& v);

void remove_sites(const std::vector<size_t>& site_indices);

@@ -97,7 +97,7 @@
container _sequences;
container _labels;
NameIdMap _label_id_map;
WeightVector _weights;
FloatWeightVector _weights;
WeightVector _site_pattern_map;
ProbVectorList _probs;
RangeList _local_seq_ranges;
4 changes: 2 additions & 2 deletions src/Model.hpp
@@ -139,7 +139,7 @@ class Model
bool param_estimated(int param) const;

AscBiasCorrection ascbias_type() const { return _ascbias_type; }
const WeightVector& ascbias_weights() const { return _ascbias_weights; }
const FloatWeightVector& ascbias_weights() const { return _ascbias_weights; }

/* per alignment site, given in elements (NOT in bytes) */
size_t clv_entry_size() const { return _num_states * _num_ratecats; }
@@ -192,7 +192,7 @@ class Model
double _brlen_scaler;

AscBiasCorrection _ascbias_type;
WeightVector _ascbias_weights;
FloatWeightVector _ascbias_weights;

std::vector<SubstitutionModel> _submodels;

2 changes: 1 addition & 1 deletion src/PartitionInfo.cpp
@@ -81,7 +81,7 @@ void PartitionInfo::compress_patterns(bool store_backmap)

pllmod_msa_stats_t * PartitionInfo::compute_stats(unsigned long stats_mask) const
{
const unsigned int * weights = _msa.weights().empty() ? nullptr : _msa.weights().data();
const double * weights = _msa.weights().empty() ? nullptr : _msa.weights().data();
pllmod_msa_stats_t * stats = pllmod_msa_compute_stats(_msa.pll_msa(), _model.num_states(),
_model.charmap(), weights, stats_mask);

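
Passing a const double* above implies that pllmod_msa_compute_stats on the computations/pll-modules floating_weights branch now takes double weights. The prototype below is inferred from this call site alone (an assumption; the pll-modules side is not part of this diff):

/* Assumed prototype on the floating_weights pll-modules branch,
 * inferred from the call site above; not shown in this diff */
pllmod_msa_stats_t * pllmod_msa_compute_stats(const pll_msa_t * msa,
                                              unsigned int states,
                                              const pll_state_t * charmap,
                                              const double * weights,
                                              unsigned long stats_mask);
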
2 changes: 1 addition & 1 deletion src/PartitionedMSA.cpp
@@ -165,7 +165,7 @@ void PartitionedMSA::split_msa()
if (!_full_msa.weights().empty())
{
auto& msa = _part_list[p].msa();
WeightVector w(msa.length());
FloatWeightVector w(msa.length());
const auto full_weights = _full_msa.weights();
assert(full_weights.size() == site_part_map().size());

6 changes: 3 additions & 3 deletions src/PartitionedMSAView.cpp
@@ -97,23 +97,23 @@ const IDSet& PartitionedMSAView::exclude_sites(size_t part_id) const
return _excluded_sites.at(part_id);
}

void PartitionedMSAView::site_weights(const WeightVectorList& weights)
void PartitionedMSAView::site_weights(const FloatWeightVectorList& weights)
{
if (weights.size() != part_count())
throw runtime_error("PartitionedMSAView: invalid weight vector size");

_site_weights.clear();
_site_weights.resize(part_count());
size_t part_id = 0;
for (const WeightVector& v: weights)
for (const auto& v: weights)
{
site_weights(part_id, v);
part_id++;
}
assert(_site_weights.size() == part_count());
}

void PartitionedMSAView::site_weights(size_t part_id, const WeightVector& weights)
void PartitionedMSAView::site_weights(size_t part_id, const FloatWeightVector& weights)
{
_site_weights.resize(part_count());

6 changes: 3 additions & 3 deletions src/PartitionedMSAView.hpp
@@ -36,16 +36,16 @@ class PartitionedMSAView
void exclude_sites(size_t part_id, IDVector site_ids);
const IDSet& exclude_sites(size_t part_id) const;

void site_weights(const WeightVectorList& site_weights);
void site_weights(size_t part_id, const WeightVector& site_weights);
void site_weights(const FloatWeightVectorList& site_weights);
void site_weights(size_t part_id, const FloatWeightVector& site_weights);

private:
const PartitionedMSA * _parted_msa;
std::shared_ptr<const PartitionedMSA> _parted_msa_sptr;
NameMap _taxon_name_map;
IDSet _excluded_taxa;
std::vector<IDSet> _excluded_sites;
WeightVectorList _site_weights;
FloatWeightVectorList _site_weights;

mutable IDVector _orig_taxon_ids;

17 changes: 9 additions & 8 deletions src/TreeInfo.cpp
@@ -2,28 +2,29 @@

#include "TreeInfo.hpp"
#include "ParallelContext.hpp"
#include "types.hpp"

using namespace std;

TreeInfo::TreeInfo (const Options &opts, const Tree& tree, const PartitionedMSA& parted_msa,
const IDVector& tip_msa_idmap,
const PartitionAssignment& part_assign)
{
init(opts, tree, parted_msa, tip_msa_idmap, part_assign, std::vector<uintVector>());
init(opts, tree, parted_msa, tip_msa_idmap, part_assign, FloatWeightVectorList());
}

TreeInfo::TreeInfo (const Options &opts, const Tree& tree, const PartitionedMSA& parted_msa,
const IDVector& tip_msa_idmap,
const PartitionAssignment& part_assign,
const std::vector<uintVector>& site_weights)
const FloatWeightVectorList& site_weights)
{
init(opts, tree, parted_msa, tip_msa_idmap, part_assign, site_weights);
}

void TreeInfo::init(const Options &opts, const Tree& tree, const PartitionedMSA& parted_msa,
const IDVector& tip_msa_idmap,
const PartitionAssignment& part_assign,
const std::vector<uintVector>& site_weights)
const FloatWeightVectorList& site_weights)
{
_brlen_min = opts.brlen_min;
_brlen_max = opts.brlen_max;
@@ -458,7 +459,7 @@ void assign(Model& model, const TreeInfo& treeinfo, size_t partition_id)
model.brlen_scaler(pll_treeinfo.brlen_scalers[partition_id]);
}

void build_clv(ProbVector::const_iterator probs, size_t sites, WeightVector::const_iterator weights,
void build_clv(ProbVector::const_iterator probs, size_t sites, FloatWeightVector::const_iterator weights,
pll_partition_t* partition, bool normalize, std::vector<double>& clv)
{
const auto states = partition->states;
@@ -535,7 +536,7 @@ void set_partition_tips(const Options& opts, const MSA& msa, const IDVector& tip
void set_partition_tips(const Options& opts, const MSA& msa, const IDVector& tip_msa_idmap,
const PartitionRange& part_region,
pll_partition_t* partition, const pll_state_t * charmap,
const WeightVector& weights)
const FloatWeightVector& weights)
{
assert(!weights.empty());

@@ -544,7 +545,7 @@ void set_partition_tips(const Options& opts, const MSA& msa, const IDVector& tip
const auto pend = pstart + plen;

/* compress weights array by removing all zero entries */
uintVector comp_weights;
FloatWeightVector comp_weights;
for (size_t j = pstart; j < pend; ++j)
{
if (weights[j] > 0)
@@ -595,7 +596,7 @@ void set_partition_tips(const Options& opts, const MSA& msa, const IDVector& tip

pll_partition_t* create_pll_partition(const Options& opts, const PartitionInfo& pinfo,
const IDVector& tip_msa_idmap,
const PartitionRange& part_region, const uintVector& weights)
const PartitionRange& part_region, const FloatWeightVector& weights)
{
const MSA& msa = pinfo.msa();
const Model& model = pinfo.model();
@@ -607,7 +608,7 @@ pll_partition_t* create_pll_partition(const Options& opts, const PartitionInfo&
const size_t part_length = weights.empty() ? part_region.length :
std::count_if(weights.begin() + pstart,
weights.begin() + pstart + part_region.length,
[](uintVector::value_type w) -> bool
[](FloatWeightVector::value_type w) -> bool
{ return w > 0; }
);

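
The pattern used twice above, skipping zero-weight sites and counting the rest, carries over to floating-point weights unchanged, because only exact zeros are excluded. A minimal self-contained sketch of that logic with illustrative names (not the PR's code):

#include <algorithm>
#include <vector>

/* Count the sites that contribute to a partition, i.e. those with a
 * non-zero (possibly fractional) weight in [pstart, pstart + plen). */
size_t nonzero_weight_sites(const std::vector<double>& weights,
                            size_t pstart, size_t plen)
{
  return std::count_if(weights.begin() + pstart,
                       weights.begin() + pstart + plen,
                       [](double w) { return w > 0.; });
}
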
7 changes: 4 additions & 3 deletions src/TreeInfo.hpp
@@ -6,6 +6,7 @@
#include "Options.hpp"
#include "AncestralStates.hpp"
#include "loadbalance/PartitionAssignment.hpp"
#include "types.hpp"

struct spr_round_params
{
@@ -31,7 +32,7 @@ class TreeInfo
const IDVector& tip_msa_idmap, const PartitionAssignment& part_assign);
TreeInfo (const Options &opts, const Tree& tree, const PartitionedMSA& parted_msa,
const IDVector& tip_msa_idmap, const PartitionAssignment& part_assign,
const std::vector<uintVector>& site_weights);
const FloatWeightVectorList& site_weights);
virtual
~TreeInfo ();

@@ -74,7 +75,7 @@ class TreeInfo

void init(const Options &opts, const Tree& tree, const PartitionedMSA& parted_msa,
const IDVector& tip_msa_idmap, const PartitionAssignment& part_assign,
const std::vector<uintVector>& site_weights);
const FloatWeightVectorList& site_weights);

void assert_lh_improvement(double old_lh, double new_lh, const std::string& where = "");
};
@@ -85,6 +86,6 @@ void assign(Model& model, const TreeInfo& treeinfo, size_t partition_id);

pll_partition_t* create_pll_partition(const Options& opts, const PartitionInfo& pinfo,
const IDVector& tip_msa_idmap,
const PartitionRange& part_region, const uintVector& weights);
const PartitionRange& part_region, const FloatWeightVector& weights);

#endif /* RAXML_TREEINFO_HPP_ */
8 changes: 4 additions & 4 deletions src/bootstrap/BootstrapGenerator.cpp
@@ -24,19 +24,19 @@ BootstrapReplicate BootstrapGenerator::generate(const PartitionedMSA& parted_msa
return result;
}

WeightVector BootstrapGenerator::generate(const MSA& msa, unsigned long random_seed)
FloatWeightVector BootstrapGenerator::generate(const MSA& msa, unsigned long random_seed)
{
RandomGenerator gen(random_seed);

return generate(msa, gen);
}

WeightVector BootstrapGenerator::generate(const MSA& msa, RandomGenerator& gen)
FloatWeightVector BootstrapGenerator::generate(const MSA& msa, RandomGenerator& gen)
{
unsigned int orig_len = msa.num_sites();
unsigned int comp_len = msa.length();

WeightVector w_buf(orig_len, 0);
FloatWeightVector w_buf(orig_len, 0);

std::uniform_int_distribution<unsigned int> distr(0, orig_len-1);

@@ -50,7 +50,7 @@ WeightVector BootstrapGenerator::generate(const MSA& msa, RandomGenerator& gen)
return w_buf;
else
{
WeightVector result(comp_len, 0);
FloatWeightVector result(comp_len, 0);
auto orig_weights = msa.weights();

assert(!orig_weights.empty());
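
With fractional per-site weights, a bootstrap replicate accumulates the original (possibly non-integer) weights of the resampled sites rather than plain counts. The sketch below illustrates that idea only; it is not BootstrapGenerator's actual implementation, part of which is collapsed in this diff:

#include <random>
#include <vector>

/* Illustrative only: resample sites with replacement and add their original
 * fractional weights onto the compressed patterns they map to. */
std::vector<double> bootstrap_pattern_weights(const std::vector<double>& site_weights,
                                              const std::vector<size_t>& site_pattern_map,
                                              size_t pattern_count,
                                              std::mt19937& gen)
{
  std::vector<double> result(pattern_count, 0.);
  std::uniform_int_distribution<size_t> distr(0, site_weights.size() - 1);
  for (size_t i = 0; i < site_weights.size(); ++i)
  {
    const size_t site = distr(gen);
    result[site_pattern_map[site]] += site_weights[site];
  }
  return result;
}
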
6 changes: 3 additions & 3 deletions src/bootstrap/BootstrapGenerator.hpp
@@ -5,7 +5,7 @@

struct BootstrapReplicate
{
WeightVectorList site_weights;
FloatWeightVectorList site_weights;
};

typedef std::vector<BootstrapReplicate> BootstrapReplicateList;
@@ -20,10 +20,10 @@ class BootstrapGenerator
~BootstrapGenerator ();

BootstrapReplicate generate(const PartitionedMSA& parted_msa, unsigned long random_seed);
WeightVector generate(const MSA& msa, unsigned long random_seed);
FloatWeightVector generate(const MSA& msa, unsigned long random_seed);

private:
WeightVector generate(const MSA& msa, RandomGenerator& gen);
FloatWeightVector generate(const MSA& msa, RandomGenerator& gen);
};

#endif /* RAXML_BOOTSTRAP_BOOTSTRAPGENERATOR_HPP_ */
4 changes: 2 additions & 2 deletions src/io/binary_io.cpp
@@ -312,7 +312,7 @@ BasicBinaryStream& operator>>(BasicBinaryStream& stream, MSA& m)

m = MSA(pat_count);

m.weights(stream.get<WeightVector>());
m.weights(stream.get<FloatWeightVector>());

std::string seq(pat_count, 0);
for (size_t i = 0; i < taxa_count; ++i)
@@ -355,7 +355,7 @@ BasicBinaryStream& operator>>(BasicBinaryStream& stream, MSARange mr)
assert(local_len <= pat_count);
assert(local_len <= weight_count);

WeightVector w(local_len);
FloatWeightVector w(local_len);
read_vector_range(stream, &w[0], rl, pat_count);

m.weights(std::move(w));
16 changes: 8 additions & 8 deletions src/main.cpp
@@ -985,12 +985,12 @@ void load_msa_weights(MSA& msa, const Options& opts)
if (!f)
throw runtime_error("Unable to open site weights file: " + opts.weights_file);

WeightVector w;
FloatWeightVector w;
w.reserve(msa.length());
const auto maxw = std::numeric_limits<WeightVector::value_type>::max();
const auto maxw = std::numeric_limits<FloatWeightVector::value_type>::max();
int fres;
intmax_t x;
while ((fres = fscanf(f,"%jd", &x)) != EOF)
double x;
while ((fres = fscanf(f,"%lf", &x)) != EOF)
{
if (!fres)
{
@@ -1000,7 +1000,7 @@
fclose(f);
throw runtime_error("Invalid site weight entry found near: " + string(buf));
}
else if (x <= 0)
else if (x < 0)
{
fclose(f);
throw runtime_error("Non-positive site weight found: " + to_string(x) +
Expand All @@ -1013,7 +1013,7 @@ void load_msa_weights(MSA& msa, const Options& opts)
" (max: " + to_string(maxw) + ")");
}
else
w.push_back((WeightType) x);
w.push_back((FloatWeightType) x);
}
fclose(f);

@@ -1589,7 +1589,7 @@ void balance_load(RaxmlInstance& instance)
LOG_VERB << endl << instance.proc_part_assign;
}

PartitionAssignmentList balance_load(RaxmlInstance& instance, WeightVectorList part_site_weights)
PartitionAssignmentList balance_load(RaxmlInstance& instance, FloatWeightVectorList part_site_weights)
{
/* This function is used to re-distribute sites across processes for each bootstrap replicate.
* Since during bootstrapping alignment sites are sampled with replacement, some sites will be
@@ -1600,7 +1600,7 @@

PartitionAssignmentList assign_list;
PartitionAssignment part_sizes;
WeightVectorList comp_pos_map(part_site_weights.size());
FloatWeightVectorList comp_pos_map(part_site_weights.size());

/* init list of partition sizes */
size_t i = 0;
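
For reference, the reworked load_msa_weights loop above now parses weights with fscanf("%lf") and, since the check changed from x <= 0 to x < 0, zero weights are accepted while negative ones are still rejected. A standalone sketch of the same idea, stripped of raxml-ng's error handling (illustrative, not the PR's code):

#include <cstdio>
#include <vector>

/* Illustrative only: read whitespace-separated floating-point site weights
 * from a file, keeping non-negative values. */
std::vector<double> read_float_weights(const char* path)
{
  std::vector<double> w;
  if (FILE* f = std::fopen(path, "r"))
  {
    double x;
    while (std::fscanf(f, "%lf", &x) == 1)
    {
      if (x >= 0.)
        w.push_back(x);
    }
    std::fclose(f);
  }
  return w;
}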