Skip to content

Commit

Permalink
Merge pull request #18 from YosefLab/serialize
Browse files Browse the repository at this point in the history
serialize np in h5ad
  • Loading branch information
colganwi authored Jun 22, 2024
2 parents b90272a + f16cc9c commit 7fac360
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/treedata/_core/treedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from anndata._io import write_h5ad, write_zarr
from scipy import sparse

from treedata._utils import digraph_to_dict
from treedata._utils import digraph_to_dict, make_serializable

from .aligned_mapping import (
AxisTrees,
Expand Down Expand Up @@ -280,12 +280,13 @@ def to_adata(self) -> ad.AnnData:

def _treedata_attrs(self) -> dict:
"""Return the TreeData-specific attributes as a dictionary.

The ``obst`` and ``vart`` trees are each converted with ``digraph_to_dict`` and
stored under their key, alongside the ``label`` and ``allow_overlap`` settings.
"""
return {
attrs = {
"obst": {k: digraph_to_dict(v) for k, v in self.obst.items()},
"vart": {k: digraph_to_dict(v) for k, v in self.vart.items()},
"label": self.label,
"allow_overlap": self.allow_overlap,
}
# Pass the nested attribute values through make_serializable before returning.
return make_serializable(attrs)

def _mutated_copy(self, **kwargs):
"""Creating TreeData with attributes optionally specified via kwargs."""
Expand Down
18 changes: 18 additions & 0 deletions src/treedata/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from collections import deque

import networkx as nx
import numpy as np
import pandas as pd


def subset_tree(tree: nx.DiGraph, leaves: list[str], asview: bool) -> nx.DiGraph:
Expand Down Expand Up @@ -59,3 +61,19 @@ def dict_to_digraph(graph_dict: dict) -> nx.DiGraph:
for target, attrs in targets.items():
G.add_edge(source, target, **attrs)
return G


def make_serializable(data: dict) -> dict:
    """Recursively convert ``data`` into plain, serializable Python objects.

    The annotation describes the top-level call (a graph dictionary), but the
    function is applied recursively and accepts any nested value:

    * ``dict`` -> ``dict`` with converted values; NumPy scalar keys are
      replaced by their Python equivalents.
    * ``list`` / ``tuple`` -> ``list`` with converted elements.
    * ``set`` / ``frozenset`` -> ``list`` with converted elements, sorted when
      the elements are mutually comparable so the output is reproducible.
    * ``np.ndarray`` / ``pd.Series`` / ``pd.Index`` -> (nested) ``list``; the
      elements are converted recursively, so object-dtype containers holding
      arrays or NumPy scalars are handled too.
    * NumPy scalar (``np.generic``) -> result of ``.item()``.
    * Anything else is returned unchanged.
    """
    if isinstance(data, dict):
        return {
            (k.item() if isinstance(k, np.generic) else k): make_serializable(v)
            for k, v in data.items()
        }
    if isinstance(data, (set, frozenset)):
        items = [make_serializable(v) for v in data]
        try:
            # Set iteration order is hash-dependent; sort for stable output.
            return sorted(items)
        except TypeError:
            # Elements are not mutually comparable: keep iteration order.
            return items
    if isinstance(data, (list, tuple)):
        return [make_serializable(v) for v in data]
    if isinstance(data, (np.ndarray, pd.Series, pd.Index)):
        # tolist() leaves the elements of object-dtype containers untouched,
        # so recurse into the resulting list (or scalar, for 0-d arrays).
        return make_serializable(data.tolist())
    if isinstance(data, np.generic):
        return data.item()
    return data
25 changes: 25 additions & 0 deletions tests/test_readwrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import joblib
import networkx as nx
import numpy as np
import pandas as pd
import pytest

import treedata as td
Expand Down Expand Up @@ -50,6 +51,30 @@ def test_h5ad_readwrite(tdata, tmp_path, backed):
assert tdata2.filename == file_path


def test_h5ad_dtypes(tdata, tmp_path):
    """Node attributes of container/NumPy/pandas types round-trip through h5ad.

    Each attribute set on the ``root`` node of tree ``"1"`` is expected to be
    read back as a plain Python ``list`` or ``float`` with equal contents.
    """
    file_path = tmp_path / "test.h5ad"
    root = tdata.obst["1"].nodes["root"]
    root["list"] = [1, 2, 3]
    root["tuple"] = (1, 2, 3)
    root["set"] = {1, 2, 3}
    root["np_float"] = np.float64(1.0)
    root["np_array"] = np.array([[1, 2], [3, 4]])
    root["pd_series"] = pd.Series(["1", "2", "3"])
    tdata.write_h5ad(file_path)
    tdata2 = td.read_h5ad(file_path)
    loaded = tdata2.obst["1"].nodes["root"]
    # attribute name -> (expected value, expected Python type) after reading
    expected = {
        "list": ([1, 2, 3], list),
        "tuple": ([1, 2, 3], list),
        "np_float": (1.0, float),
        "np_array": ([[1, 2], [3, 4]], list),
        "pd_series": (["1", "2", "3"], list),
    }
    for key, (value, py_type) in expected.items():
        assert loaded[key] == value, key
        assert isinstance(loaded[key], py_type), key
    # Sets are unordered, so compare contents without relying on the order
    # in which the elements happened to be written.
    assert sorted(loaded["set"]) == [1, 2, 3]
    assert isinstance(loaded["set"], list)


def test_zarr_readwrite(tdata, tmp_path):
tdata.write_zarr(tmp_path / "test.zarr")
tdata2 = td.read_zarr(tmp_path / "test.zarr")
Expand Down

0 comments on commit 7fac360

Please sign in to comment.