From 64a0f70e1c1ed28a03a4180a305bc42923b0553d Mon Sep 17 00:00:00 2001 From: colganwi Date: Wed, 21 Feb 2024 20:50:04 -0500 Subject: [PATCH] working merge --- src/treedata/_core/merge.py | 96 ++++++++++++++++++++++++++++++- src/treedata/_core/treedata.py | 14 ----- src/treedata/_utils.py | 14 +++++ tests/test_merge.py | 101 +++++++++++++++++++++++++++++++++ 4 files changed, 209 insertions(+), 16 deletions(-) create mode 100755 tests/test_merge.py diff --git a/src/treedata/_core/merge.py b/src/treedata/_core/merge.py index fa3a46b..dd18cf3 100755 --- a/src/treedata/_core/merge.py +++ b/src/treedata/_core/merge.py @@ -7,8 +7,15 @@ Callable, Collection, ) +from functools import reduce from typing import Any, Literal +import anndata as ad +import pandas as pd +from anndata._core.merge import _resolve_dim, resolve_merge_strategy + +from treedata._utils import combine_trees + from .treedata import TreeData StrategiesLiteral = Literal["same", "unique", "first", "only"] @@ -23,8 +30,93 @@ def concat( uns_merge: StrategiesLiteral | Callable | None = None, label: str | None = None, keys: Collection | None = None, - index_unique: str | None = None, fill_value: Any | None = None, pairwise: bool = False, ) -> TreeData: - raise NotImplementedError("Concatenation not yet implemented") + """Concatenates TreeData objects along an axis. + + Params + ------ + tdatas + The objects to be concatenated. If a Mapping is passed, keys are used for the `keys` + argument and values are concatenated. + axis + Which axis to concatenate along. + join + How to align values when concatenating. If "outer", the union of the other axis + is taken. If "inner", the intersection. + See `anndata.concat` for more. + merge + How elements not aligned to the axis being concatenated along are selected. + Currently implemented strategies include: + + * `None`: No elements are kept. + * `"same"`: Elements that are the same in each of the objects. + * `"unique"`: Elements for which there is only one possible value. 
+ * `"first"`: The first element seen at each position. + * `"only"`: Elements that show up in only one of the objects. + uns_merge + How the elements of `.uns` are selected. Uses the same set of strategies as + the `merge` argument, except applied recursively. + label + Column in axis annotation (i.e. `.obs` or `.var`) to place batch information in. + If it's None, no column is added. + keys + Names for each object being added. These values are used for column values for + `label`; `index_unique` is fixed to `None`, so they are never appended to the index. + Defaults to incrementing integer labels. + fill_value + When `join="outer"`, this is the value that will be used to fill the introduced + indices. By default, sparse arrays are padded with zeros, while dense arrays and + DataFrames are padded with missing values. + pairwise + Whether pairwise elements along the concatenated dimension should be included. + This is False by default, since the resulting arrays are often not meaningful. + """ + axis, dim = _resolve_dim(axis=axis) + alt_axis, alt_dim = _resolve_dim(axis=1 - axis) + merge = resolve_merge_strategy(merge) + + # Check indices + concat_indices = pd.concat([pd.Series(getattr(t, f"{dim}_names")) for t in tdatas], ignore_index=True) + if not concat_indices.is_unique: + raise ValueError(f"{dim}_names must be unique to concatenate along axis {axis}") + alt_indices = [getattr(t, f"{alt_dim}_names") for t in tdatas] + if join == "inner": + alt_indices = reduce(lambda x, y: x.intersection(y), alt_indices) + else: + alt_indices = reduce(lambda x, y: x.union(y), alt_indices) + + # Concatenate anndata + adata = ad.concat( + tdatas, + axis=axis, + join=join, + merge=merge, + uns_merge=uns_merge, + label=label, + keys=keys, + index_unique=None, + fill_value=fill_value, + pairwise=pairwise, + ) + tdata = TreeData(adata, allow_overlap=True) + + # Trees for concatenation axis + concat_trees = [getattr(t, f"{dim}t") for t in tdatas] + unique_keys = {key for mapping in 
concat_trees for key in mapping.keys()} + for key in unique_keys: + trees = [mapping[key] for mapping in concat_trees if key in mapping] + tree = combine_trees(trees) + getattr(tdata, f"{dim}t")[key] = tree + + # Trees for other axis + if join == "inner" and alt_axis == 0: + tdatas = [t[alt_indices, :] for t in tdatas] + elif join == "inner" and alt_axis == 1: + tdatas = [t[:, alt_indices] for t in tdatas] + alt_trees = merge([getattr(t, f"{alt_dim}t") for t in tdatas]) + for key, tree in alt_trees.items(): + getattr(tdata, f"{alt_dim}t")[key] = tree + + return tdata diff --git a/src/treedata/_core/treedata.py b/src/treedata/_core/treedata.py index 2fffc0f..758c66d 100755 --- a/src/treedata/_core/treedata.py +++ b/src/treedata/_core/treedata.py @@ -1,6 +1,5 @@ from __future__ import annotations -import warnings from collections.abc import Iterable, Mapping, Sequence from typing import ( TYPE_CHECKING, @@ -171,11 +170,6 @@ def _init_as_actual( # init from scratch else: - if label is not None: - for attr in ["obs", "var"]: - if label in getattr(self, attr).columns: - warnings.warn(f"label {label} already present in .{attr} overwriting it", stacklevel=2) - getattr(self, attr)[label] = pd.NA self._tree_label = label self._allow_overlap = allow_overlap self._obst = AxisTrees(self, 0, vals=obst) @@ -274,14 +268,6 @@ def to_adata(self) -> ad.AnnData: def copy(self) -> TreeData: """Full copy of the object.""" adata = super().copy() - - # remove label from obs and var - if self.label is not None: - if self.label in adata.obs.columns: - adata.obs.drop(columns=self.label, inplace=True) - if self.label in adata.var.columns: - adata.var.drop(columns=self.label, inplace=True) - # create a new TreeData object treedata_copy = TreeData( adata, obst=self.obst.copy(), diff --git a/src/treedata/_utils.py b/src/treedata/_utils.py index b21a6c5..9f360fd 100755 --- a/src/treedata/_utils.py +++ b/src/treedata/_utils.py @@ -20,3 +20,17 @@ def subset_tree(tree: nx.DiGraph, leaves: 
list[str], asview: bool) -> nx.DiGraph return tree.subgraph(keep_nodes) else: return tree.subgraph(keep_nodes).copy() + + +def combine_trees(subsets: list[nx.DiGraph]) -> nx.DiGraph: + """Combine two or more subsets of a tree into a single tree.""" + # Initialize a new directed graph for the combined tree + combined_tree = nx.DiGraph() + + # Iterate through each subset and add its nodes and edges to the combined tree + for subset in subsets: + combined_tree.add_nodes_from(subset.nodes(data=True)) + combined_tree.add_edges_from(subset.edges(data=True)) + + # The combined_tree now contains all nodes and edges from the subsets + return combined_tree diff --git a/tests/test_merge.py b/tests/test_merge.py new file mode 100755 index 0000000..11d7a1b --- /dev/null +++ b/tests/test_merge.py @@ -0,0 +1,101 @@ +import networkx as nx +import numpy as np +import pandas as pd +import pytest + +import treedata as td + + +@pytest.fixture +def tree(): + tree = nx.balanced_tree(r=2, h=3, create_using=nx.DiGraph) + tree = nx.relabel_nodes(tree, {i: str(i) for i in tree.nodes}) + depths = nx.single_source_shortest_path_length(tree, "0") + nx.set_node_attributes(tree, values=depths, name="depth") + yield tree + + +@pytest.fixture +def tdata(tree): + df = pd.DataFrame({"anno": range(8)}, index=[str(i) for i in range(7, 15)]) + yield td.TreeData( + X=np.zeros((8, 8)), obst={"0": tree}, vart={"0": tree}, obs=df, var=df, allow_overlap=True, label="tree" + ) + + +@pytest.fixture +def tdata_list(tdata): + other_tree = nx.DiGraph() + other_tree.add_edges_from([("0", "7"), ("0", "8")]) + tdata_1 = tdata[:2, :].copy() + tdata_1.obst["1"] = other_tree + tdata_1.vart["1"] = other_tree + yield [tdata_1, tdata[2:4, :].copy(), tdata[4:, :4].copy()] + + +def test_concat(tdata_list): + # outer join + tdata = td.concat(tdata_list, axis=0, label="subset", join="outer") + print(tdata) + assert list(tdata.obs["subset"]) == ["0"] * 2 + ["1"] * 2 + ["2"] * 4 + assert tdata.obst["0"].number_of_nodes() == 15 
+ assert tdata.obst["1"].number_of_nodes() == 3 + assert tdata.shape == (8, 8) + # inner join + tdata = td.concat(tdata_list, axis=0, label="subset", join="inner") + assert list(tdata.obs["subset"]) == ["0"] * 2 + ["1"] * 2 + ["2"] * 4 + assert tdata.shape == (8, 4) + + +def test_merge_outer(tdata_list): + # None + tdata = td.concat(tdata_list, axis=0, join="outer", merge=None) + assert list(tdata.vart.keys()) == [] + # same + tdata = td.concat(tdata_list, axis=0, join="outer", merge="same") + assert list(tdata.vart.keys()) == [] + # unique + tdata = td.concat(tdata_list, axis=0, join="outer", merge="first") + assert list(tdata.vart.keys()) == ["0", "1"] + # only + tdata = td.concat(tdata_list, axis=0, join="outer", merge="only") + assert list(tdata.vart.keys()) == ["1"] + # first + tdata = td.concat(tdata_list, axis=0, join="outer", merge="first") + assert list(tdata.vart.keys()) == ["0", "1"] + assert tdata.vart["0"].number_of_nodes() == 15 + assert tdata.vart["1"].number_of_nodes() == 3 + + +def test_merge_inner(tdata_list): + # None + tdata = td.concat(tdata_list, axis=0, join="inner", merge=None) + assert list(tdata.vart.keys()) == [] + # same + tdata = td.concat(tdata_list, axis=0, join="inner", merge="same") + assert list(tdata.vart.keys()) == ["0"] + # unique + tdata = td.concat(tdata_list, axis=0, join="inner", merge="first") + assert list(tdata.vart.keys()) == ["0", "1"] + # only + tdata = td.concat(tdata_list, axis=0, join="inner", merge="only") + assert list(tdata.vart.keys()) == ["1"] + # first + tdata = td.concat(tdata_list, axis=0, join="inner", merge="first") + assert list(tdata.vart.keys()) == ["0", "1"] + assert tdata.vart["0"].number_of_nodes() == 8 + assert tdata.vart["1"].number_of_nodes() == 3 + + +def test_concat_bad_index(tdata_list): + tdata_list[0].obs.index = tdata_list[1].obs.index + with pytest.raises(ValueError): + td.concat(tdata_list, axis=0, join="outer") + + +def test_concat_bad_tree(tdata_list): + bad_tree = nx.DiGraph() + 
bad_tree.add_edges_from([("bad", "7")]) + tdata_list[0].obst["0"] = bad_tree + with pytest.raises(ValueError): + td.concat(tdata_list, axis=0, join="outer")