Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SpotFindingResults Serialization to Disk (and loading) #1961

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
97 changes: 95 additions & 2 deletions starfish/core/types/_spot_finding_results.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import json
import os
from dataclasses import dataclass
from typing import Any, Hashable, Mapping, MutableMapping, Optional, Sequence, Tuple
from typing import Any, Dict, Hashable, Mapping, MutableMapping, Optional, Sequence, Tuple

import xarray as xr

from starfish.core.types import Axes, Coordinates, SpotAttributes
from starfish.core.util.logging import Log


AXES_ORDER = (Axes.ROUND, Axes.CH)


Expand Down Expand Up @@ -109,6 +110,98 @@ def values(self):
"""
return self._results.values()

def save(self, output_dir_name: str) -> None:
    """Serialize spot finding results to a series of files.

    Writes one NetCDF file per physical coordinate axis, one file holding
    the encoded processing log, one NetCDF file per (round, channel) spot
    table, and a ``SpotFindingResults.json`` index tying them together.
    All paths stored in the index are relative to the target directory.

    Parameters
    ----------
    output_dir_name: str
        Location to save all files. The basename (final path component)
        is used as a filename prefix, so pass a trailing separator to
        write un-prefixed files directly into the directory.
    """
    json_data: Dict[str, Any] = {}

    # Work relative to the target directory so the JSON index stores
    # portable, relative paths. Always restore the cwd, even on failure.
    pwd = os.getcwd()
    os.chdir(os.path.dirname(output_dir_name))
    try:
        base_name = os.path.basename(output_dir_name)

        # One NetCDF file per physical coordinate axis.
        coords = {}
        for key in self.physical_coord_ranges.keys():
            path = "{}coords_{}.nc".format(base_name, key)
            coords[key] = path
            self.physical_coord_ranges[key].to_netcdf(path)
        json_data["physical_coord_ranges"] = coords

        # Provenance log, written via its own text encoding.
        log_path = "{}log.arr".format(base_name)
        json_data["log"] = {"path": log_path}
        with open(log_path, "w") as f:
            f.write(self.log.encode())

        # One NetCDF file per (round, channel) spot table; the index key
        # is encoded as "<round>_<ch>".
        spot_attrs = {}
        for key in self._results.keys():
            path = "{}spots_{}_{}.nc".format(base_name, key[0], key[1])
            spot_attrs["{}_{}".format(key[0], key[1])] = path
            self._results[key].spot_attrs.save(path)
        json_data["spot_attrs"] = spot_attrs

        # Index file mapping each component to its on-disk location.
        with open("{}SpotFindingResults.json".format(base_name), "w") as f:
            f.write(json.dumps(json_data))
    finally:
        os.chdir(pwd)

@classmethod
def load(cls, json_file: str):
    """Load serialized spot finding results written by :py:meth:`save`.

    Parameters:
    -----------
    json_file: str
        Path to the ``SpotFindingResults.json`` index file to read.

    Returns:
    --------
    SpotFindingResults:
        Object containing loaded results

    """
    # Close the index file promptly instead of leaking the handle.
    with open(json_file) as fl:
        data = json.load(fl)

    # Paths in the index are relative to the index file's directory;
    # chdir there while reading and always restore the original cwd.
    # ``or "."`` guards against a bare filename (dirname == "").
    pwd = os.getcwd()
    os.chdir(os.path.dirname(json_file) or ".")
    try:
        # The log file holds encoded JSON; round-trip through json to
        # recover the exact text form that Log.decode expects.
        with open(data["log"]["path"]) as f:
            txt = json.load(f)['log']
        log = Log.decode(json.dumps(txt))

        # Map the short on-disk axis names back to Coordinates values.
        rename_axes = {
            'x': Coordinates.X.value,
            'y': Coordinates.Y.value,
            'z': Coordinates.Z.value
        }
        coords = {
            rename_axes[coord]: xr.load_dataarray(path)
            for coord, path in data["physical_coord_ranges"].items()
        }

        # Keys were encoded as "<round>_<ch>"; rebuild the index mapping
        # and reload each spot table.
        spot_attributes_list = []
        for key, path in data["spot_attrs"].items():
            round_label, ch_label = (int(part) for part in key.split("_"))
            index = {AXES_ORDER[0]: round_label, AXES_ORDER[1]: ch_label}
            spots = SpotAttributes.load(path)
            spot_attributes_list.append(
                (PerImageSliceSpotResults(spots, extras=None), index))
    finally:
        os.chdir(pwd)

    return SpotFindingResults(
        imagestack_coords=coords,
        log=log,
        spot_attributes_list=spot_attributes_list
    )

@property
def round_labels(self):
"""
Expand Down
69 changes: 69 additions & 0 deletions starfish/core/types/test/test_saving_spots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import os
import tempfile

import numpy as np
import pandas as pd
import xarray as xr

from starfish.types import Axes, Coordinates, Features
from starfish.core.types import PerImageSliceSpotResults, SpotAttributes, SpotFindingResults
from starfish.core.util.logging import Log

def dummy_spots() -> SpotFindingResults:
    """Build a small SpotFindingResults fixture with random spot tables.

    Creates 4 rounds x 3 channels of 100 random spots each, plus dummy
    physical coordinate ranges and an empty processing log.
    """
    rounds = 4
    channels = 3
    spot_count = 100
    img_dim = {'x': 2048, 'y': 2048, 'z': 29}

    # Map the short axis names to the Coordinates enum values used as keys.
    rename_axes = {
        'x': Coordinates.X.value,
        'y': Coordinates.Y.value,
        'z': Coordinates.Z.value
    }
    # NOTE(review): np.arange(0, 1, step) with step >= 1 yields a single
    # element ([0]); np.linspace(0, 1, img_dim[dim]) may have been the
    # intent — confirm. Kept as-is to preserve fixture behavior.
    coords = {
        rename_axes[dim]: xr.DataArray(np.arange(0, 1, size))
        for dim, size in img_dim.items()
    }

    log = Log()

    spot_attributes_list = []
    for r in range(rounds):
        for c in range(channels):
            index = {Axes.ROUND: r, Axes.CH: c}
            spots = SpotAttributes(pd.DataFrame(
                np.random.randint(0, 100, size=(spot_count, 4)),
                columns=[Axes.X.value,
                         Axes.Y.value,
                         Axes.ZPLANE.value,
                         Features.SPOT_RADIUS]
            ))
            spot_attributes_list.append(
                (PerImageSliceSpotResults(spots, extras=None), index)
            )

    return SpotFindingResults(
        imagestack_coords=coords,
        log=log,
        spot_attributes_list=spot_attributes_list
    )

def test_saving_spots() -> None:
    """Round-trip SpotFindingResults through save()/load() and compare."""
    data = dummy_spots()

    # TemporaryDirectory cleans up after itself; mkdtemp left files behind.
    with tempfile.TemporaryDirectory() as tempdir:
        # save() treats the basename as a filename prefix, so the trailing
        # separator is required to write un-prefixed files into tempdir.
        data.save(tempdir + "/")

        # load back into memory
        data2 = SpotFindingResults.load(
            os.path.join(tempdir, 'SpotFindingResults.json'))

        # ensure all items are equal
        assert data.keys() == data2.keys()
        assert data._log.encode() == data2._log.encode()
        for ax in data.physical_coord_ranges.keys():
            np.testing.assert_equal(
                data.physical_coord_ranges[ax].to_numpy(),
                data2.physical_coord_ranges[ax].to_numpy())
        for k in data._results.keys():
            np.testing.assert_array_equal(
                data._results[k].spot_attrs.data,
                data2._results[k].spot_attrs.data)