From 8e90140e0b02f7a216c9280c383aaca92579a155 Mon Sep 17 00:00:00 2001 From: Deren Date: Tue, 14 May 2024 10:20:56 -0400 Subject: [PATCH] cosmetic --- toytree/utils/src/toytree_sequence.py | 114 +++++++++++------- toytree/utils/src/toytree_sequence_drawing.py | 9 +- 2 files changed, 75 insertions(+), 48 deletions(-) diff --git a/toytree/utils/src/toytree_sequence.py b/toytree/utils/src/toytree_sequence.py index b0c66a6f..5e7ee025 100644 --- a/toytree/utils/src/toytree_sequence.py +++ b/toytree/utils/src/toytree_sequence.py @@ -14,7 +14,7 @@ >>> mts = msprime.sim_mutations() >>> tts = ToyTreeSequence(tts) -# Draw an individual tree w/ mutations. +>>> # Draw an individual tree w/ mutations. >>> tts.draw_tree(idx=0, ...) >>> tts.draw_tree(site=30, ...) @@ -57,12 +57,12 @@ class ToyTreeSequence: renamed from treesequence data as {pop-id}-{node-id}. """ def __init__( - self, + self, tree_sequence: TreeSequence, - sample: Union[int, Iterable[int], None]=None, - seed: Optional[int]=None, - name_dict: Optional[Mapping[int,str]]=None, - ): + sample: Union[int, Iterable[int], None] = None, + seed: Optional[int] = None, + name_dict: Optional[Mapping[int, str]] = None, + ): # TODO: perhaps combine name_dict and 'sample' request method? # store user args @@ -74,14 +74,14 @@ def __init__( """: the number of haplotypes to sample from the TreeSequence.""" self.tree_sequence: TreeSequence = None """: the TreeSequence object.""" - self.name_dict: Mapping[int,str] = {} if name_dict is None else name_dict + self.name_dict: Mapping[int, str] = {} if name_dict is None else name_dict self._i = 0 """: iterable counter.""" # subsample/simplify treesequence to same or smaller nsamples if sample is None: self.sample = [tree_sequence.sample_size] - self.tree_sequence = tree_sequence + self.tree_sequence = tree_sequence elif isinstance(sample, int): self.sample = [sample] * self.npopulations self.tree_sequence = self._get_subsampled_ts(tree_sequence) @@ -111,7 +111,7 @@ def _get_subsampled_ts(self, tree_sequence: TreeSequence) -> TreeSequence: """Return a subsampled (simplify'd) tree sequence. Samples a maximum of `self.sample` nodes at time=0 from each - population. This is useful for working with simulations from + population. This is useful for working with simulations from SLiM where nodes of the entire population are often present. """ samps = [] @@ -137,7 +137,7 @@ def at_site(self, pos: int): tree = self.tree_sequence.at(pos) return self._get_toytree(tree, site=pos) - def _get_toytree(self, tree: TskitTree, site: Optional[int]=None) -> ToyTree: + def _get_toytree(self, tree: TskitTree, site: Optional[int] = None) -> ToyTree: """Return a ToyTree with mutations for a tree selected by site or index. It is assumed here that the tree is fully coalesced, returning @@ -154,18 +154,22 @@ def _get_toytree(self, tree: TskitTree, site: Optional[int]=None) -> ToyTree: # warn use if multiple roots are present if tree.has_multiple_roots: logger.warning(f"tree has multiple ({len(tree.roots)}) roots.") - pseudo_root = toytree.Node(name="pseudo-root") + # a special name to indicate the root is masked ancestors + pseudo_root = toytree.Node(name="MASKED-ANCESTORS") pseudo_root.tsidx = -1 for tsidx in tree.roots: node = toytree.Node(name=tsidx, dist=tree.branch_length(tsidx)) node.tsidx = tsidx - node._up = None - # pseudo_root._add_child(node) # children parents as None? - pseudo_root._children += (node, ) + # setting disconnected root children causes more problems + # than its worth. It was easier to just hide the root node + # during visualizations if desired. + # node._up = None + # pseudo_root._children += (node, ) + pseudo_root._add_child(node) tsidx_dict[tsidx] = node logger.debug(f"adding root: {tsidx}") - # if root has no children then + # if root has no children then else: tsidx = tree.roots[0] pseudo_root = toytree.Node(name=tsidx, dist=0) @@ -179,7 +183,7 @@ def _get_toytree(self, tree: TskitTree, site: Optional[int]=None) -> ToyTree: # dict of {child: parent, ...} e.g., {0: 5, 1: 8, 2: 7, ...} pdict = tree.get_parent_dict() logger.debug(pdict) - + # iterate over child, parent items for cidx, pidx in pdict.items(): # if parent already in tsdict get existing Node @@ -191,7 +195,7 @@ def _get_toytree(self, tree: TskitTree, site: Optional[int]=None) -> ToyTree: pnode = toytree.Node(name=str(pidx), dist=tree.branch_length(pidx)) pnode.tsidx = pidx tsidx_dict[pidx] = pnode - logger.debug(f"adding child: {cidx} to parent: {pidx}") + logger.debug(f"adding child: {cidx} to parent: {pidx}") # if child node already exists use Node. if cidx in tsidx_dict: @@ -203,7 +207,7 @@ def _get_toytree(self, tree: TskitTree, site: Optional[int]=None) -> ToyTree: name = self.name_dict[cidx] else: pop = self.tree_sequence.tables.nodes.population[cidx] - name=f"p{pop}-{cidx}" + name = f"p{pop}-{cidx}" cnode = toytree.Node( name=name, dist=tree.branch_length(cidx), @@ -213,7 +217,8 @@ def _get_toytree(self, tree: TskitTree, site: Optional[int]=None) -> ToyTree: pnode._add_child(cnode) # wrap TreeNode as a toytree - ttree = toytree.ToyTree(pseudo_root)#root_idx])#.ladderize() + # return pseudo_root + ttree = toytree.ToyTree(pseudo_root) # root_idx])#.ladderize() # add mutations as metadata if site is not None: @@ -227,14 +232,14 @@ def _get_toytree(self, tree: TskitTree, site: Optional[int]=None) -> ToyTree: def draw_tree_sequence( self, - start: int=0, - max_trees: int=10, - height: int=None, - width: int=None, + start: int = 0, + max_trees: int = 10, + height: int = None, + width: int = None, # show_mutations: bool=True, - scrollable: bool=True, + scrollable: bool = True, **kwargs, - ): + ): # -> Tuple[toyplot.Canvas, toyplot.coordinates.cartesian, : """ Returns a toyplot Canvas, Axes, and list of marks composing @@ -251,7 +256,7 @@ def draw_tree_sequence( trees = [self.at_index(i) for i in tree_range] breaks = self.tree_sequence.breakpoints(True)[start: tmax + 1] tsd = ToyTreeSequenceDrawing( - trees, breaks, + trees, breaks, width=width, height=height, scrollable=scrollable, **kwargs, @@ -260,18 +265,18 @@ def draw_tree_sequence( def draw_tree( self, - site: Optional[int]=None, - idx: Optional[int]=None, - mutation_size: Union[int, Iterable[int]]=8, - mutation_style: Dict[str,str]=None, - mutation_color_palette: Iterable['color']=None, + site: Optional[int] = None, + idx: Optional[int] = None, + mutation_size: Union[int, Iterable[int]] = 8, + mutation_style: Dict[str, str] = None, + mutation_color_palette: Iterable['color'] = None, show_label=True, **kwargs, - ): + ): """Return a toytree drawing of a selected tree site or index. - + You can select a tree by index or site, but not both. The - selected tree is extracted from the tree_sequence as a + selected tree is extracted from the tree_sequence as a ToyTree object with relabeled tips showing the population id. Mutations are drawn on edges and colored by MutationType, if info is present. @@ -303,7 +308,7 @@ def draw_tree( colors = [] # TODO: iterate over .sites() to show positions and describe - # whether multiple mutations occurred at a position. This + # whether multiple mutations occurred at a position. This # needs to be tested for both msprime and slim data types... for mut in tree.mutations: @@ -319,13 +324,13 @@ def draw_tree( # project mutation points to proper layout type if mark.layout == "l": ypos.append(node._x) - xpos.append(time) + xpos.append(time) elif mark.layout == "r": ypos.append(node._x) xpos.append(-time) elif mark.layout == "d": xpos.append(node._x) - ypos.append(time) + ypos.append(time) elif mark.layout == "u": xpos.append(node._x) ypos.append(-time) @@ -375,20 +380,37 @@ def draw_tree( # import pyslim toytree.set_log_level("DEBUG") - TRE = toytree.rtree.unittree(ntips=6, treeheight=1e6, seed=123) - MOD = ipcoal.Model(tree=TRE, Ne=1e5, recomb=2e-9, store_tree_sequences=True) - MOD.sim_loci(1, 10000) - - # slim_ts = pyslim.load("/tmp/test2.trees") - ts = MOD.ts_dict[0] - tts = ToyTreeSequence(ts, sample=5, seed=123)#.at_index(0) - tts.draw_tree(width=None, height=300, idxs=None, scrollable=True) + # SLIM/shadie example + import tskit + TSFILE = "/tmp/test.trees" + # full ts + ts = tskit.load(TSFILE) + # subsampled ts + tts = ToyTreeSequence(ts, sample=10) + # get a toytree + node = tts.at_index(0) + print(node) + for child in node.children: + print(child, child.up) + # print(node.draw_ascii()) + toytree.ToyTree(node) # ._draw_browser() + # tts._get_toytree(tts.at_index(0)) + + # ... + # TRE = toytree.rtree.unittree(ntips=6, treeheight=1e6, seed=123) + # MOD = ipcoal.Model(tree=TRE, Ne=1e5, recomb=2e-9, store_tree_sequences=True) + # MOD.sim_loci(1, 10000) + + # # slim_ts = pyslim.load("/tmp/test2.trees") + # ts = MOD.ts_dict[0] + # tts = ToyTreeSequence(ts, sample=5, seed=123) #.at_index(0) + # tts.draw_tree(width=None, height=300) #, idxs=None)# scrollable=True) # tts = ToyTreeSequence(slim_ts, sample=5, seed=123).at_index(0) # TREE = toytree.rtree.unittree(10, treeheight=1e5) # MOD = ipcoal.Model(tree=TREE, Ne=1e4, nsamples=1) # MOD.sim_loci(5, 2000) - + # # optional: access the tree sequences directly # TS = MOD.ts_dict[0] diff --git a/toytree/utils/src/toytree_sequence_drawing.py b/toytree/utils/src/toytree_sequence_drawing.py index ea16f264..53ca711d 100644 --- a/toytree/utils/src/toytree_sequence_drawing.py +++ b/toytree/utils/src/toytree_sequence_drawing.py @@ -30,6 +30,7 @@ class Box: top: float bottom: float + @dataclass class MultiDrawing: trees: List['ToyTree'] @@ -67,7 +68,7 @@ def __init__( scrollable: bool=True, axes: Optional['toyplot.coordinates.Cartesian']=None, **kwargs, - ): + ): super().__init__(trees, breakpoints, width, height, padding, margin) self.canvas = self.get_canvas(scrollable, axes) self.axes = self.get_axes(axes) @@ -113,7 +114,7 @@ def get_axes(self, axes): Makes a 'share' axes so that ticks appear on top of plot. """ if axes is not None: - return axes + return axes axes = self.canvas.cartesian(margin=self.margin, padding=self.padding) axes2 = axes.share('y') axes.x.show = False @@ -280,3 +281,7 @@ def set_axes_style(self): # labels = locator.ticks(min(self.breakpoints), max(self.breakpoints)) # self.axes.x.ticks.locator = toyplot.locator.Explicit( # ticks[0], labels[1]) + + +if __name__ == "__main__": + pass