Skip to content

Commit

Permalink
cosmetic
Browse files Browse the repository at this point in the history
  • Loading branch information
eaton-lab committed May 14, 2024
1 parent daca880 commit 8e90140
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 48 deletions.
114 changes: 68 additions & 46 deletions toytree/utils/src/toytree_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
>>> mts = msprime.sim_mutations()
>>> tts = ToyTreeSequence(tts)
# Draw an individual tree w/ mutations.
>>> # Draw an individual tree w/ mutations.
>>> tts.draw_tree(idx=0, ...)
>>> tts.draw_tree(site=30, ...)
Expand Down Expand Up @@ -57,12 +57,12 @@ class ToyTreeSequence:
renamed from treesequence data as {pop-id}-{node-id}.
"""
def __init__(
self,
self,
tree_sequence: TreeSequence,
sample: Union[int, Iterable[int], None]=None,
seed: Optional[int]=None,
name_dict: Optional[Mapping[int,str]]=None,
):
sample: Union[int, Iterable[int], None] = None,
seed: Optional[int] = None,
name_dict: Optional[Mapping[int, str]] = None,
):
# TODO: perhaps combine name_dict and 'sample' request method?

# store user args
Expand All @@ -74,14 +74,14 @@ def __init__(
""": the number of haplotypes to sample from the TreeSequence."""
self.tree_sequence: TreeSequence = None
""": the TreeSequence object."""
self.name_dict: Mapping[int,str] = {} if name_dict is None else name_dict
self.name_dict: Mapping[int, str] = {} if name_dict is None else name_dict
self._i = 0
""": iterable counter."""

# subsample/simplify treesequence to same or smaller nsamples
if sample is None:
self.sample = [tree_sequence.sample_size]
self.tree_sequence = tree_sequence
self.tree_sequence = tree_sequence
elif isinstance(sample, int):
self.sample = [sample] * self.npopulations
self.tree_sequence = self._get_subsampled_ts(tree_sequence)
Expand Down Expand Up @@ -111,7 +111,7 @@ def _get_subsampled_ts(self, tree_sequence: TreeSequence) -> TreeSequence:
"""Return a subsampled (simplify'd) tree sequence.
Samples a maximum of `self.sample` nodes at time=0 from each
population. This is useful for working with simulations from
population. This is useful for working with simulations from
SLiM where nodes of the entire population are often present.
"""
samps = []
Expand All @@ -137,7 +137,7 @@ def at_site(self, pos: int):
tree = self.tree_sequence.at(pos)
return self._get_toytree(tree, site=pos)

def _get_toytree(self, tree: TskitTree, site: Optional[int]=None) -> ToyTree:
def _get_toytree(self, tree: TskitTree, site: Optional[int] = None) -> ToyTree:
"""Return a ToyTree with mutations for a tree selected by site or index.
It is assumed here that the tree is fully coalesced, returning
Expand All @@ -154,18 +154,22 @@ def _get_toytree(self, tree: TskitTree, site: Optional[int]=None) -> ToyTree:
# warn use if multiple roots are present
if tree.has_multiple_roots:
logger.warning(f"tree has multiple ({len(tree.roots)}) roots.")
pseudo_root = toytree.Node(name="pseudo-root")
# a special name to indicate the root is masked ancestors
pseudo_root = toytree.Node(name="MASKED-ANCESTORS")
pseudo_root.tsidx = -1
for tsidx in tree.roots:
node = toytree.Node(name=tsidx, dist=tree.branch_length(tsidx))
node.tsidx = tsidx
node._up = None
# pseudo_root._add_child(node) # children parents as None?
pseudo_root._children += (node, )
# setting disconnected root children causes more problems
# than its worth. It was easier to just hide the root node
# during visualizations if desired.
# node._up = None
# pseudo_root._children += (node, )
pseudo_root._add_child(node)
tsidx_dict[tsidx] = node
logger.debug(f"adding root: {tsidx}")

# if root has no children then
# if root has no children then
else:
tsidx = tree.roots[0]
pseudo_root = toytree.Node(name=tsidx, dist=0)
Expand All @@ -179,7 +183,7 @@ def _get_toytree(self, tree: TskitTree, site: Optional[int]=None) -> ToyTree:
# dict of {child: parent, ...} e.g., {0: 5, 1: 8, 2: 7, ...}
pdict = tree.get_parent_dict()
logger.debug(pdict)

# iterate over child, parent items
for cidx, pidx in pdict.items():
# if parent already in tsdict get existing Node
Expand All @@ -191,7 +195,7 @@ def _get_toytree(self, tree: TskitTree, site: Optional[int]=None) -> ToyTree:
pnode = toytree.Node(name=str(pidx), dist=tree.branch_length(pidx))
pnode.tsidx = pidx
tsidx_dict[pidx] = pnode
logger.debug(f"adding child: {cidx} to parent: {pidx}")
logger.debug(f"adding child: {cidx} to parent: {pidx}")

# if child node already exists use Node.
if cidx in tsidx_dict:
Expand All @@ -203,7 +207,7 @@ def _get_toytree(self, tree: TskitTree, site: Optional[int]=None) -> ToyTree:
name = self.name_dict[cidx]
else:
pop = self.tree_sequence.tables.nodes.population[cidx]
name=f"p{pop}-{cidx}"
name = f"p{pop}-{cidx}"
cnode = toytree.Node(
name=name,
dist=tree.branch_length(cidx),
Expand All @@ -213,7 +217,8 @@ def _get_toytree(self, tree: TskitTree, site: Optional[int]=None) -> ToyTree:
pnode._add_child(cnode)

# wrap TreeNode as a toytree
ttree = toytree.ToyTree(pseudo_root)#root_idx])#.ladderize()
# return pseudo_root
ttree = toytree.ToyTree(pseudo_root) # root_idx])#.ladderize()

# add mutations as metadata
if site is not None:
Expand All @@ -227,14 +232,14 @@ def _get_toytree(self, tree: TskitTree, site: Optional[int]=None) -> ToyTree:

def draw_tree_sequence(
self,
start: int=0,
max_trees: int=10,
height: int=None,
width: int=None,
start: int = 0,
max_trees: int = 10,
height: int = None,
width: int = None,
# show_mutations: bool=True,
scrollable: bool=True,
scrollable: bool = True,
**kwargs,
):
):
# -> Tuple[toyplot.Canvas, toyplot.coordinates.cartesian, :
"""
Returns a toyplot Canvas, Axes, and list of marks composing
Expand All @@ -251,7 +256,7 @@ def draw_tree_sequence(
trees = [self.at_index(i) for i in tree_range]
breaks = self.tree_sequence.breakpoints(True)[start: tmax + 1]
tsd = ToyTreeSequenceDrawing(
trees, breaks,
trees, breaks,
width=width, height=height,
scrollable=scrollable,
**kwargs,
Expand All @@ -260,18 +265,18 @@ def draw_tree_sequence(

def draw_tree(
self,
site: Optional[int]=None,
idx: Optional[int]=None,
mutation_size: Union[int, Iterable[int]]=8,
mutation_style: Dict[str,str]=None,
mutation_color_palette: Iterable['color']=None,
site: Optional[int] = None,
idx: Optional[int] = None,
mutation_size: Union[int, Iterable[int]] = 8,
mutation_style: Dict[str, str] = None,
mutation_color_palette: Iterable['color'] = None,
show_label=True,
**kwargs,
):
):
"""Return a toytree drawing of a selected tree site or index.
You can select a tree by index or site, but not both. The
selected tree is extracted from the tree_sequence as a
selected tree is extracted from the tree_sequence as a
ToyTree object with relabeled tips showing the population id.
Mutations are drawn on edges and colored by MutationType, if
info is present.
Expand Down Expand Up @@ -303,7 +308,7 @@ def draw_tree(
colors = []

# TODO: iterate over .sites() to show positions and describe
# whether multiple mutations occurred at a position. This
# whether multiple mutations occurred at a position. This
# needs to be tested for both msprime and slim data types...
for mut in tree.mutations:

Expand All @@ -319,13 +324,13 @@ def draw_tree(
# project mutation points to proper layout type
if mark.layout == "l":
ypos.append(node._x)
xpos.append(time)
xpos.append(time)
elif mark.layout == "r":
ypos.append(node._x)
xpos.append(-time)
elif mark.layout == "d":
xpos.append(node._x)
ypos.append(time)
ypos.append(time)
elif mark.layout == "u":
xpos.append(node._x)
ypos.append(-time)
Expand Down Expand Up @@ -375,20 +380,37 @@ def draw_tree(
# import pyslim
toytree.set_log_level("DEBUG")

TRE = toytree.rtree.unittree(ntips=6, treeheight=1e6, seed=123)
MOD = ipcoal.Model(tree=TRE, Ne=1e5, recomb=2e-9, store_tree_sequences=True)
MOD.sim_loci(1, 10000)

# slim_ts = pyslim.load("/tmp/test2.trees")
ts = MOD.ts_dict[0]
tts = ToyTreeSequence(ts, sample=5, seed=123)#.at_index(0)
tts.draw_tree(width=None, height=300, idxs=None, scrollable=True)
# SLIM/shadie example
import tskit
TSFILE = "/tmp/test.trees"
# full ts
ts = tskit.load(TSFILE)
# subsampled ts
tts = ToyTreeSequence(ts, sample=10)
# get a toytree
node = tts.at_index(0)
print(node)
for child in node.children:
print(child, child.up)
# print(node.draw_ascii())
toytree.ToyTree(node) # ._draw_browser()
# tts._get_toytree(tts.at_index(0))

# ...
# TRE = toytree.rtree.unittree(ntips=6, treeheight=1e6, seed=123)
# MOD = ipcoal.Model(tree=TRE, Ne=1e5, recomb=2e-9, store_tree_sequences=True)
# MOD.sim_loci(1, 10000)

# # slim_ts = pyslim.load("/tmp/test2.trees")
# ts = MOD.ts_dict[0]
# tts = ToyTreeSequence(ts, sample=5, seed=123) #.at_index(0)
# tts.draw_tree(width=None, height=300) #, idxs=None)# scrollable=True)
# tts = ToyTreeSequence(slim_ts, sample=5, seed=123).at_index(0)

# TREE = toytree.rtree.unittree(10, treeheight=1e5)
# MOD = ipcoal.Model(tree=TREE, Ne=1e4, nsamples=1)
# MOD.sim_loci(5, 2000)

# # optional: access the tree sequences directly
# TS = MOD.ts_dict[0]

Expand Down
9 changes: 7 additions & 2 deletions toytree/utils/src/toytree_sequence_drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class Box:
top: float
bottom: float


@dataclass
class MultiDrawing:
trees: List['ToyTree']
Expand Down Expand Up @@ -67,7 +68,7 @@ def __init__(
scrollable: bool=True,
axes: Optional['toyplot.coordinates.Cartesian']=None,
**kwargs,
):
):
super().__init__(trees, breakpoints, width, height, padding, margin)
self.canvas = self.get_canvas(scrollable, axes)
self.axes = self.get_axes(axes)
Expand Down Expand Up @@ -113,7 +114,7 @@ def get_axes(self, axes):
Makes a 'share' axes so that ticks appear on top of plot.
"""
if axes is not None:
return axes
return axes
axes = self.canvas.cartesian(margin=self.margin, padding=self.padding)
axes2 = axes.share('y')
axes.x.show = False
Expand Down Expand Up @@ -280,3 +281,7 @@ def set_axes_style(self):
# labels = locator.ticks(min(self.breakpoints), max(self.breakpoints))
# self.axes.x.ticks.locator = toyplot.locator.Explicit(
# ticks[0], labels[1])


if __name__ == "__main__":
pass

0 comments on commit 8e90140

Please sign in to comment.