channel: add global cache for channels
JoepVanlier committed Nov 13, 2024
1 parent fccb0b1 commit 075117c
Showing 9 changed files with 235 additions and 50 deletions.
1 change: 1 addition & 0 deletions lumicks/pylake/__init__.py
@@ -19,6 +19,7 @@
from .fitting.fit import FdFit
from .image_stack import ImageStack, CorrelatedStack
from .file_download import *
from .detail.caching import set_cache_enabled
from .fitting.models import *
from .fitting.parameter_trace import parameter_trace
from .kymotracker.kymotracker import *
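The one-line change above re-exports `set_cache_enabled` from the top-level `lumicks.pylake` namespace. A minimal usage sketch (the file name is hypothetical; the channel path matches the one used in the tests further down):

import lumicks.pylake as lk

lk.set_cache_enabled(True)  # route all Slice reads through one shared LRU cache

f = lk.File("example.h5")  # hypothetical file
force = f["Force HF"]["Force 1x"]  # a Slice; no data has been read yet
first = force.data  # first access reads from disk and fills the cache
second = force.data  # served from the cache; no second disk read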
84 changes: 38 additions & 46 deletions lumicks/pylake/channel.py
@@ -6,6 +6,7 @@
import numpy as np
import numpy.typing as npt

from .detail import caching
from .detail.plotting import _annotate
from .detail.timeindex import to_seconds, to_timestamp
from .detail.utilities import downsample
@@ -657,7 +658,7 @@ def range_selector(self, show=True, **kwargs) -> SliceRangeSelectorWidget:
return SliceRangeSelectorWidget(self, show=show, **kwargs)


class Continuous:
class Continuous(caching.LazyCacheMixin):
"""A source of continuous data for a timeline slice
Parameters
@@ -671,8 +672,8 @@ class Continuous:
"""

def __init__(self, data, start, dt):
super().__init__()
self._src_data = data
self._cached_data = None
self.start = start
self.stop = start + len(data) * dt
self.dt = dt # ns
@@ -693,7 +694,7 @@ def from_dataset(dset, y_label="y", calibration=None):
start = dset.attrs["Start time (ns)"]
dt = int(1e9 / dset.attrs["Sample rate (Hz)"]) # ns
return Slice(
Continuous(dset, start, dt),
Continuous(caching.from_h5py(dset), start, dt),
labels={"title": dset.name.strip("/"), "y": y_label},
calibration=calibration,
)
@@ -719,9 +720,7 @@ def to_dataset(self, parent, name, **kwargs):

@property
def data(self) -> npt.ArrayLike:
if self._cached_data is None:
self._cached_data = np.asarray(self._src_data)
return self._cached_data
return self.read_lazy_cache("data", self._src_data)

@property
def timestamps(self) -> npt.ArrayLike:
@@ -755,7 +754,7 @@ def downsampled_by(self, factor, reduce):
)


class TimeSeries:
class TimeSeries(caching.LazyCacheMixin):
"""A source of time series data for a timeline slice
Parameters
@@ -778,10 +777,9 @@ def __init__(self, data, timestamps):
f"({len(timestamps)})."
)

super().__init__()
self._src_data = data
self._cached_data = None
self._src_timestamps = timestamps
self._cached_timestamps = None

def __len__(self):
return len(self._src_data)
@@ -796,32 +794,8 @@ def _apply_mask(self, mask):

@staticmethod
def from_dataset(dset, y_label="y", calibration=None) -> Slice:
class LazyLoadedCompoundField:
"""Wrapper to enable lazy loading of HDF5 compound datasets
Notes
-----
We only need to support the methods `__array__()` and `__len__()`, as we only access
`LazyLoadedCompoundField` via the properties `TimeSeries.data`, `timestamps` and the
method `__len__()`.
`LazyLoadCompoundField` might be replaced with `dset.fields(fieldname)` if and when the
returned `FieldsWrapper` object provides an `__array__()` method itself"""

def __init__(self, dset, fieldname):
self._dset = dset
self._fieldname = fieldname

def __array__(self):
"""Get the data of the field as an array"""
return self._dset[self._fieldname]

def __len__(self):
"""Get the length of the underlying dataset"""
return len(self._dset)

data = LazyLoadedCompoundField(dset, "Value")
timestamps = LazyLoadedCompoundField(dset, "Timestamp")
data = caching.from_h5py(dset, field="Value")
timestamps = caching.from_h5py(dset, field="Timestamp")
return Slice(
TimeSeries(data, timestamps),
labels={"title": dset.name.strip("/"), "y": y_label},
@@ -850,15 +824,11 @@ def to_dataset(self, parent, name, **kwargs):

@property
def data(self) -> npt.ArrayLike:
if self._cached_data is None:
self._cached_data = np.asarray(self._src_data)
return self._cached_data
return self.read_lazy_cache("data", self._src_data)

@property
def timestamps(self) -> npt.ArrayLike:
if self._cached_timestamps is None:
self._cached_timestamps = np.asarray(self._src_timestamps)
return self._cached_timestamps
return self.read_lazy_cache("timestamps", self._src_timestamps)

@property
def start(self):
@@ -893,7 +863,7 @@ def downsampled_by(self, factor, reduce):
raise NotImplementedError("Downsampling is currently not available for time series data")


class TimeTags:
class TimeTags(caching.LazyCacheMixin):
"""A source of time tag data for a timeline slice
Parameters
@@ -907,13 +877,32 @@
"""

def __init__(self, data, start=None, stop=None):
self.data = np.asarray(data, dtype=np.int64)
self.start = start if start is not None else (self.data[0] if self.data.size > 0 else 0)
self.stop = stop if stop is not None else (self.data[-1] + 1 if self.data.size > 0 else 0)
super().__init__()
self._src_data = data
self._start = start
self._stop = stop

def __len__(self):
return self.data.size

@property
def start(self):
return (
self._start if self._start is not None else (self.data[0] if self.data.size > 0 else 0)
)

@property
def stop(self):
return (
self._stop
if self._stop is not None
else (self.data[-1] + 1 if self.data.size > 0 else 0)
)

@property
def data(self):
return self.read_lazy_cache("data", self._src_data)

def _with_data(self, data):
raise NotImplementedError("Time tags do not currently support this operation")

@@ -922,7 +911,10 @@ def _apply_mask(self, mask):

@staticmethod
def from_dataset(dset, y_label="y"):
return Slice(TimeTags(dset))
return Slice(
TimeTags(caching.from_h5py(dset)),
labels={"title": dset.name.strip("/"), "y": y_label},
)

def to_dataset(self, parent, name, **kwargs):
"""Save this to an h5 dataset
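Taken together, `from_dataset` now hands the channel classes a caching-aware wrapper instead of the raw dataset, and the `data`/`timestamps` properties funnel through `read_lazy_cache`. A rough sketch of the read path with the global cache enabled (illustration only, not part of the diff; `h5_file` is an open `h5py.File` and the sample rate is an example value):

dset = h5_file["Force HF"]["Force 1x"]  # h5py.Dataset; nothing read yet
src = caching.from_h5py(dset)  # LazyCache wrapper; still nothing read
cont = Continuous(src, start=dset.attrs["Start time (ns)"], dt=12_800)  # 78125 Hz
data = cont.data  # read_lazy_cache -> np.asarray(src) -> LazyCache.__array__
# -> _get_array: LRU lookup keyed on file path + dataset name; disk read only on a miss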
96 changes: 96 additions & 0 deletions lumicks/pylake/detail/caching.py
@@ -0,0 +1,96 @@
import numpy as np
from cachetools import LRUCache, cached

global_cache = False


def set_cache_enabled(enabled):
"""Enable or disable the global cache
Pylake offers a global cache. When the global cache is enabled, all `Slice` objects come from
the same cache.
Parameters
----------
enabled : bool
Whether caching should be enabled (it is disabled by default)
"""
global global_cache
global_cache = enabled


@cached(LRUCache(maxsize=1 << 30, getsizeof=lambda x: x.nbytes), info=True) # 1 GB of cache
def _get_array(cache_object):
return cache_object.read_array()


class LazyCache:
def __init__(self, location, dset):
"""A lazy globally cached wrapper around an object that is convertible to a numpy array"""
self._location = location
self._dset = dset

def __len__(self):
return len(self._dset)

def __hash__(self):
return hash(self._location)

@staticmethod
def from_h5py_dset(dset, field=None):
location = f"{dset.file.filename}{dset.name}"
if field:
location = f"{location}.{field}"
dset = dset.fields(field)

return LazyCache(location, dset)

def read_array(self):
# Note: we deliberately do _not_ allow additional arguments to asarray, since we would
# have to hash those as well, and they would needlessly grow the cache (e.g. when a
# caller defensively passes an explicit dtype). It's better to raise in that case and
# end up at this comment.
arr = np.asarray(self._dset)
arr.flags.writeable = False
return arr

def __eq__(self, other):
return self._location == other._location

def __array__(self):
return _get_array(self)


def from_h5py(dset, field=None):
global global_cache
return (
LazyCache.from_h5py_dset(dset, field=field)
if global_cache
else dset.fields(field) if field else dset
)


class LazyCacheMixin:
def __init__(self):
self._cache = {}

def read_lazy_cache(self, key, src_field):
"""A small convenience decorator to incorporate a lazy cache for properties.
Data will be stored in the `_cache` variable of the instance.
Parameters
----------
key : str
Key to use when caching this data
src_field : LazyCache or array-like
Source field to read from, e.g. an `h5py` dataset or field wrapper
"""
global global_cache

if global_cache:
return np.asarray(src_field)

if key not in self._cache:
self._cache[key] = np.asarray(src_field)

return self._cache[key]
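Since `LazyCache` hashes and compares on its `location` string (absolute file path plus dataset name, plus the field name for compound datasets), two wrappers around the same dataset share one entry in the LRU cache. A small sketch under the assumption of a file `example.h5` containing the channel used elsewhere on this page:

import h5py
import numpy as np

from lumicks.pylake.detail import caching

caching.set_cache_enabled(True)

with h5py.File("example.h5", "r") as f:  # hypothetical file
    a = caching.from_h5py(f["Force HF/Force 1x"])
    b = caching.from_h5py(f["Force HF/Force 1x"])
    assert a == b and hash(a) == hash(b)  # same location -> same cache slot

    arr1 = np.asarray(a)  # miss: reads from disk and stores a read-only array
    arr2 = np.asarray(b)  # hit: returns the very same cached array
    assert np.shares_memory(arr1, arr2)
    assert not arr1.flags.writeable  # cached arrays are frozen

# `info=True` on the cachetools decorator exposes hit/miss statistics
print(caching._get_array.cache_info())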
3 changes: 2 additions & 1 deletion lumicks/pylake/file.py
@@ -1,3 +1,4 @@
import pathlib
import warnings
from typing import Dict

@@ -50,7 +51,7 @@ class File(Group, Force, DownsampledFD, BaselineCorrectedForce, PhotonCounts, Ph
def __init__(self, filename, *, rgb_to_detectors=None):
import h5py

super().__init__(h5py.File(filename, "r"), lk_file=self)
super().__init__(h5py.File(pathlib.Path(filename).absolute(), "r"), lk_file=self)
self._check_file_format()
self._rgb_to_detectors = self._get_detector_mapping(rgb_to_detectors)

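A plausible reading of this change (my inference; the commit message does not say): `LazyCache.from_h5py_dset` builds its cache key from `dset.file.filename`, and h5py reports the filename as it was passed in, so normalizing to an absolute path keeps cache keys stable and unambiguous:

import lumicks.pylake as lk

# hypothetical file; both spellings construct the same absolute path, so datasets
# from either handle hash to the same global cache entries
f1 = lk.File("example.h5")
f2 = lk.File("./example.h5")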
14 changes: 14 additions & 0 deletions lumicks/pylake/tests/test_channels/conftest.py
@@ -31,3 +31,17 @@ def channel_h5_file(tmpdir_factory, request):
mock_file.make_continuous_channel("Photon count", "Red", np.int64(20e9), freq, counts)

return mock_file.file


@pytest.fixture(autouse=True, scope="module", params=[False, True])
def cache_setting(request):
from copy import deepcopy

from lumicks.pylake.detail.caching import global_cache, set_cache_enabled

old_value = deepcopy(global_cache)
try:
set_cache_enabled(request.param)
yield
finally:
set_cache_enabled(old_value)
6 changes: 5 additions & 1 deletion lumicks/pylake/tests/test_channels/test_channels.py
@@ -8,6 +8,7 @@
import matplotlib as mpl

from lumicks.pylake import channel
from lumicks.pylake.detail import caching
from lumicks.pylake.low_level import make_continuous_slice
from lumicks.pylake.calibration import ForceCalibrationList

@@ -893,7 +894,10 @@ def test_annotation_bad_item():

def test_regression_lazy_loading(channel_h5_file):
ch = channel.Continuous.from_dataset(channel_h5_file["Force HF"]["Force 1x"])
assert isinstance(ch._src._src_data, h5py.Dataset)
if caching.global_cache:
assert isinstance(ch._src._src_data._dset, h5py.Dataset)
else:
assert isinstance(ch._src._src_data, h5py.Dataset)


@pytest.mark.parametrize(
12 changes: 12 additions & 0 deletions lumicks/pylake/tests/test_file/conftest.py
@@ -284,3 +284,15 @@ def h5_two_colors(tmpdir_factory, request):
mock_file.make_continuous_channel("Photon count", "Blue", np.int64(20e9), freq, counts)
mock_file.make_continuous_channel("Info wave", "Info wave", np.int64(20e9), freq, infowave)
return mock_file.file


@pytest.fixture(autouse=True, scope="module", params=[False, True])
def cache_setting(request):
from lumicks.pylake.detail import caching

old_value = caching.global_cache
try:
caching.set_cache_enabled(request.param)
yield
finally:
caching.set_cache_enabled(old_value)