Skip to content

Commit

Permalink
working merge
Browse files Browse the repository at this point in the history
  • Loading branch information
colganwi committed Feb 22, 2024
1 parent d0ff01d commit 64a0f70
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 16 deletions.
96 changes: 94 additions & 2 deletions src/treedata/_core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,15 @@
Callable,
Collection,
)
from functools import reduce
from typing import Any, Literal

import anndata as ad
import pandas as pd
from anndata._core.merge import _resolve_dim, resolve_merge_strategy

from treedata._utils import combine_trees

from .treedata import TreeData

StrategiesLiteral = Literal["same", "unique", "first", "only"]
Expand All @@ -23,8 +30,93 @@ def concat(
uns_merge: StrategiesLiteral | Callable | None = None,
label: str | None = None,
keys: Collection | None = None,
index_unique: str | None = None,
fill_value: Any | None = None,
pairwise: bool = False,
) -> TreeData:
raise NotImplementedError("Concatenation not yet implemented")
"""Concatenates TreeData objects along an axis.
Params
------
tdatas
The objects to be concatenated. If a Mapping is passed, keys are used for the `keys`
argument and values are concatenated.
axis
Which axis to concatenate along.
join
How to align values when concatenating. If "outer", the union of the other axis
is taken. If "inner", the intersection.
for more.
merge
How elements not aligned to the axis being concatenated along are selected.
Currently implemented strategies include:
* `None`: No elements are kept.
* `"same"`: Elements that are the same in each of the objects.
* `"unique"`: Elements for which there is only one possible value.
* `"first"`: The first element seen at each from each position.
* `"only"`: Elements that show up in only one of the objects.
uns_merge
How the elements of `.uns` are selected. Uses the same set of strategies as
the `merge` argument, except applied recursively.
label
Column in axis annotation (i.e. `.obs` or `.var`) to place batch information in.
If it's None, no column is added.
keys
Names for each object being added. These values are used for column values for
`label` or appended to the index if `index_unique` is not `None`. Defaults to
incrementing integer labels.
fill_value
When `join="outer"`, this is the value that will be used to fill the introduced
indices. By default, sparse arrays are padded with zeros, while dense arrays and
DataFrames are padded with missing values.
pairwise
Whether pairwise elements along the concatenated dimension should be included.
This is False by default, since the resulting arrays are often not meaningful.
"""
axis, dim = _resolve_dim(axis=axis)
alt_axis, alt_dim = _resolve_dim(axis=1 - axis)
merge = resolve_merge_strategy(merge)

# Check indices
concat_indices = pd.concat([pd.Series(getattr(t, f"{dim}_names")) for t in tdatas], ignore_index=True)
if not concat_indices.is_unique:
raise ValueError(f"{dim}_names must be unique to concatenate along axis {axis}")
alt_indices = [getattr(t, f"{alt_dim}_names") for t in tdatas]
if join == "inner":
alt_indices = reduce(lambda x, y: x.intersection(y), alt_indices)
else:
alt_indices = reduce(lambda x, y: x.union(y), alt_indices)

# Concatenate anndata
adata = ad.concat(
tdatas,
axis=axis,
join=join,
merge=merge,
uns_merge=uns_merge,
label=label,
keys=keys,
index_unique=None,
fill_value=fill_value,
pairwise=pairwise,
)
tdata = TreeData(adata, allow_overlap=True)

# Trees for concatenation axis
concat_trees = [getattr(t, f"{dim}t") for t in tdatas]
unique_keys = {key for mapping in concat_trees for key in mapping.keys()}
for key in unique_keys:
trees = [mapping[key] for mapping in concat_trees if key in mapping]
tree = combine_trees(trees)
getattr(tdata, f"{dim}t")[key] = tree

# Trees for other axis
if join == "inner" and alt_axis == 0:
tdatas = [t[alt_indices, :] for t in tdatas]
elif join == "inner" and alt_axis == 1:
tdatas = [t[:, alt_indices] for t in tdatas]
alt_trees = merge([getattr(t, f"{alt_dim}t") for t in tdatas])
for key, tree in alt_trees.items():
getattr(tdata, f"{alt_dim}t")[key] = tree

return tdata
14 changes: 0 additions & 14 deletions src/treedata/_core/treedata.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import warnings
from collections.abc import Iterable, Mapping, Sequence
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -171,11 +170,6 @@ def _init_as_actual(

# init from scratch
else:
if label is not None:
for attr in ["obs", "var"]:
if label in getattr(self, attr).columns:
warnings.warn(f"label {label} already present in .{attr} overwriting it", stacklevel=2)
getattr(self, attr)[label] = pd.NA
self._tree_label = label
self._allow_overlap = allow_overlap
self._obst = AxisTrees(self, 0, vals=obst)
Expand Down Expand Up @@ -274,14 +268,6 @@ def to_adata(self) -> ad.AnnData:
def copy(self) -> TreeData:
"""Full copy of the object."""
adata = super().copy()

# remove label from obs and var
if self.label is not None:
if self.label in adata.obs.columns:
adata.obs.drop(columns=self.label, inplace=True)
if self.label in adata.var.columns:
adata.var.drop(columns=self.label, inplace=True)
# create a new TreeData object
treedata_copy = TreeData(
adata,
obst=self.obst.copy(),
Expand Down
14 changes: 14 additions & 0 deletions src/treedata/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,17 @@ def subset_tree(tree: nx.DiGraph, leaves: list[str], asview: bool) -> nx.DiGraph
return tree.subgraph(keep_nodes)
else:
return tree.subgraph(keep_nodes).copy()


def combine_trees(subsets: list[nx.DiGraph]) -> nx.DiGraph:
    """Merge several directed graphs into one new directed graph.

    Parameters
    ----------
    subsets
        Graphs whose nodes and edges (including their attribute data) are
        copied, in order, into the result.

    Returns
    -------
    A fresh ``nx.DiGraph`` holding every node and edge found in ``subsets``.
    An empty ``subsets`` yields an empty graph.
    """
    merged = nx.DiGraph()
    for piece in subsets:
        # Nodes go in before edges so node attribute data is carried over
        # rather than nodes being created bare by the edge insertion.
        node_view = piece.nodes(data=True)
        edge_view = piece.edges(data=True)
        merged.add_nodes_from(node_view)
        merged.add_edges_from(edge_view)
    return merged
101 changes: 101 additions & 0 deletions tests/test_merge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import networkx as nx
import numpy as np
import pandas as pd
import pytest

import treedata as td


@pytest.fixture
def tree():
    """Yield a balanced directed tree (r=2, h=3) with string node labels and a ``depth`` node attribute."""
    int_labelled = nx.balanced_tree(r=2, h=3, create_using=nx.DiGraph)
    # Node names are converted to strings.
    as_strings = {node: str(node) for node in int_labelled.nodes}
    labelled = nx.relabel_nodes(int_labelled, as_strings)
    # "depth" is the shortest-path length from node "0" to each node.
    nx.set_node_attributes(
        labelled,
        values=nx.single_source_shortest_path_length(labelled, "0"),
        name="depth",
    )
    yield labelled


@pytest.fixture
def tdata(tree):
    """Yield an 8x8 TreeData whose obs and var share one annotation frame and the same tree under key "0"."""
    index_labels = [str(i) for i in range(7, 15)]
    annotations = pd.DataFrame({"anno": range(8)}, index=index_labels)
    treedata = td.TreeData(
        X=np.zeros((8, 8)),
        obst={"0": tree},
        vart={"0": tree},
        obs=annotations,
        var=annotations,
        allow_overlap=True,
        label="tree",
    )
    yield treedata


@pytest.fixture
def tdata_list(tdata):
    """Yield three copies sliced from ``tdata``: rows 0-1, rows 2-3, and rows 4+ restricted to the first 4 columns.

    The first copy additionally carries a small second tree under key "1"
    in both ``obst`` and ``vart``.
    """
    extra_tree = nx.DiGraph()
    extra_tree.add_edges_from([("0", "7"), ("0", "8")])

    first = tdata[:2, :].copy()
    for axis_trees in (first.obst, first.vart):
        axis_trees["1"] = extra_tree

    second = tdata[2:4, :].copy()
    third = tdata[4:, :4].copy()
    yield [first, second, third]


def test_concat(tdata_list):
    """Concatenate along axis 0 with outer and inner joins; check batch labels, obst trees and shape."""
    # Batch labels written to obs["subset"]: 2, 2 and 4 rows from the three inputs.
    expected_labels = ["0"] * 2 + ["1"] * 2 + ["2"] * 4

    # outer join
    tdata = td.concat(tdata_list, axis=0, label="subset", join="outer")
    assert list(tdata.obs["subset"]) == expected_labels
    assert tdata.obst["0"].number_of_nodes() == 15
    assert tdata.obst["1"].number_of_nodes() == 3
    assert tdata.shape == (8, 8)

    # inner join
    tdata = td.concat(tdata_list, axis=0, label="subset", join="inner")
    assert list(tdata.obs["subset"]) == expected_labels
    assert tdata.shape == (8, 4)


def test_merge_outer(tdata_list):
    """Check which ``vart`` keys survive each merge strategy under an outer join along axis 0."""
    # (merge strategy, expected vart keys), run in this order.
    # NOTE(review): the third case was labelled "unique" in the original test
    # but passes merge="first", so "unique" is never exercised here - confirm
    # the intended strategy and expected keys.
    cases = [
        (None, []),
        ("same", []),
        ("first", ["0", "1"]),
        ("only", ["1"]),
        ("first", ["0", "1"]),
    ]
    for strategy, expected_keys in cases:
        tdata = td.concat(tdata_list, axis=0, join="outer", merge=strategy)
        assert list(tdata.vart.keys()) == expected_keys

    # Tree sizes are checked on the result of the final "first" run.
    assert tdata.vart["0"].number_of_nodes() == 15
    assert tdata.vart["1"].number_of_nodes() == 3


def test_merge_inner(tdata_list):
    """Check which ``vart`` keys survive each merge strategy under an inner join along axis 0."""
    # (merge strategy, expected vart keys), run in this order.
    # NOTE(review): the third case was labelled "unique" in the original test
    # but passes merge="first", so "unique" is never exercised here - confirm
    # the intended strategy and expected keys.
    cases = [
        (None, []),
        ("same", ["0"]),
        ("first", ["0", "1"]),
        ("only", ["1"]),
        ("first", ["0", "1"]),
    ]
    for strategy, expected_keys in cases:
        tdata = td.concat(tdata_list, axis=0, join="inner", merge=strategy)
        assert list(tdata.vart.keys()) == expected_keys

    # Tree sizes are checked on the result of the final "first" run.
    assert tdata.vart["0"].number_of_nodes() == 8
    assert tdata.vart["1"].number_of_nodes() == 3


def test_concat_bad_index(tdata_list):
    """Expect ValueError from ``concat`` when the first object reuses the second object's obs index."""
    first, second = tdata_list[0], tdata_list[1]
    duplicated_index = second.obs.index
    first.obs.index = duplicated_index
    with pytest.raises(ValueError):
        td.concat(tdata_list, axis=0, join="outer")


def test_concat_bad_tree(tdata_list):
    """Expect ValueError from ``concat`` after obst["0"] of the first object is replaced by a one-edge tree ("bad" -> "7")."""
    mismatched = nx.DiGraph()
    mismatched.add_edges_from([("bad", "7")])
    first = tdata_list[0]
    first.obst["0"] = mismatched
    with pytest.raises(ValueError):
        td.concat(tdata_list, axis=0, join="outer")

0 comments on commit 64a0f70

Please sign in to comment.