Skip to content

Commit

Permalink
new consensus arg updates
Browse files Browse the repository at this point in the history
  • Loading branch information
eaton-lab committed Dec 2, 2024
1 parent e34a79f commit 13bae88
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 12 deletions.
18 changes: 12 additions & 6 deletions toytree/core/multitree.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def write(
# ---------------------------------------------------------------#
# see toytree.infer.get_consensus_tree()
# ---------------------------------------------------------------#
def get_consensus_tree(self, min_freq: float, **kwargs) -> ToyTree:
def get_consensus_tree(self, min_freq: float = 0.0, **kwargs) -> ToyTree:
"""Return an exteded majority-rule consensus tree from a list of trees.
The trees must contain the same set of tips. The returned tree will
Expand Down Expand Up @@ -318,7 +318,7 @@ def get_consensus_tree(self, min_freq: float, **kwargs) -> ToyTree:
return toytree.infer.get_consensus_tree(self.treelist, min_freq=min_freq)


def get_consensus_features(self, tree: ToyTree, features: list[str], ultrametric: bool = False, conditional: bool = True) -> ToyTree:
def get_consensus_features(self, tree: ToyTree, features: list[str] | None = None, ultrametric: bool = False, conditional: bool = True) -> ToyTree:
"""Return tree with feature data mapped to each bipartition from
a set of trees that may or may not share the same bipartitions.
Expand Down Expand Up @@ -745,10 +745,16 @@ def get_tip_labels(self) -> Sequence[str]:
# for tree in mtree:
# print(repr(tree))
# print(mtree.get_unique_topologies())
ctree = mtree.get_consensus_tree()
print(ctree.get_node_data())

ctree = mtree.get_consensus_features(ctree)
print(ctree.get_node_data())
# ctree._draw_browser(node_sizes=16, node_markers="r1.5x1", node_labels="support", tmpdir="~", scale_bar=True, ts='u')

# c, a, m = mtree.draw_cloud_tree()
# toytree.utils.show(c)
mtree[1][-1].children[0]._dist = 100
mtree[1]._update()
c, a, m = mtree.draw(scale_bar=True, node_sizes=0, shared_axes=True, layout='d')
toytree.utils.show(c, tmpdir="~")
# mtree[1][-1].children[0]._dist = 100
# mtree[1]._update()
# c, a, m = mtree.draw(scale_bar=True, node_sizes=0, shared_axes=True, layout='d')
# toytree.utils.show(c, tmpdir="~")
13 changes: 7 additions & 6 deletions toytree/infer/src/consensus.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ def get_consensus_clades(clades: dict[frozenset, float], min_freq: float) -> dic

# remove any clades that conflicted w/ equal frequency
for clade in mark:
keep.pop(clade)
if clade in keep:
keep.pop(clade)
return keep


Expand Down Expand Up @@ -305,7 +306,7 @@ def get_consensus_features(
tree: ToyTree,
trees: MultiTree,
features: list[str] = None,
rooted: bool = False,
ultrametric: bool = False,
conditional: bool = False,
) -> ToyTree:
"""Return tree with feature data mapped to each bipartition from
Expand All @@ -331,7 +332,7 @@ def get_consensus_features(
wish to have summarized across the nodes of the consensus tree.
For quantitative features this will record min, max, mean, std
conditional on the existence of the node in the tree.
rooted: bool
ultrametric: bool
If trees are rooted and ultrametric then set this option to
True to summarize stats of node heights instead of node dists.
Note: this forces conditional=True.
Expand All @@ -342,7 +343,7 @@ def get_consensus_features(
calculated from tip nodes across all input trees. This only
affects tip node dists.
"""
if rooted:
if ultrametric:
# mtree.all_tree_tips_aligned()
return map_rooted_tree_supports_and_heights_to_rooted_tree(tree, trees, features)
return map_unrooted_tree_supports_and_dists_to_unrooted_tree(tree, trees, features, conditional)
Expand Down Expand Up @@ -424,7 +425,7 @@ def map_unrooted_tree_supports_and_dists_to_unrooted_tree(
setattr(node, "dist_max", np.max(data[bipart]["dist"]))
setattr(node, "dist_std", np.std(data[bipart]["dist"]))
node._dist = node.dist_mean
node.name = set(node.get_leaf_names())
# node.name = set(node.get_leaf_names())

for name in tips:
node = tree.get_nodes(name)[0]
Expand All @@ -439,7 +440,7 @@ def map_unrooted_tree_supports_and_dists_to_unrooted_tree(
tree.edge_features.add("dist_min")
tree.edge_features.add("dist_max")
tree.edge_features.add("dist_std")
tree = tree.mod.root_on_minimal_ancestor_deviation(*outg)
# tree = tree.mod.root_on_minimal_ancestor_deviation(*outg)
return tree


Expand Down

0 comments on commit 13bae88

Please sign in to comment.