Skip to content

Commit

Permalink
caching: add global cache for confocal objects
Browse files Browse the repository at this point in the history
  • Loading branch information
JoepVanlier committed Nov 13, 2024
1 parent 075117c commit 7a8b91f
Show file tree
Hide file tree
Showing 8 changed files with 166 additions and 118 deletions.
60 changes: 59 additions & 1 deletion lumicks/pylake/detail/caching.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import sys

import numpy as np
from cachetools import LRUCache, cached
from cachetools import LRUCache, keys, cached, cachedmethod

global_cache = False

Expand Down Expand Up @@ -94,3 +96,59 @@ def read_lazy_cache(self, key, src_field):
self._cache[key] = np.asarray(src_field)

return self._cache[key]


def _getsize(x):
return x.nbytes if isinstance(x, np.ndarray) else sys.getsizeof(x)


# Module-level LRU cache that `method_cache` uses when the global cache is enabled and the
# instance has a location. Entries are weighed with `_getsize`, so `maxsize` is a budget in
# bytes rather than a number of entries (1 << 30 bytes = 1 GiB).
_method_cache = LRUCache(maxsize=1 << 30, getsizeof=_getsize) # 1 GB of cache


def method_cache(name):
    """A small convenience decorator to incorporate some really basic instance method memoization

    Note: When used on properties, this one should be included _after_ the @property decorator.

    Data is stored in the `_cache` variable of the instance, unless the module-level
    `global_cache` flag is set and the instance has a non-empty `_location`. In that case the
    result is stored in the shared module-level `_method_cache` instead. Instances without a
    `_location` attribute are treated as having no location and are always cached locally.

    Parameters
    ----------
    name : str
        Name of the instance method to memo-ize. Suggestion: the instance method name.

    Examples
    --------
    ::

        class Test:
            def __init__(self):
                self._cache = {}
                self._location = None  # No location: always cached on the instance

            ...

            @property
            @method_cache("example_property")
            def example_property(self):
                return 10

            @method_cache("example_method")
            def example_method(self, arguments):
                return 5

        test = Test()
        test.example_property
        test.example_method("hi")
        test._cache
        # test._cache will now show
        # {(None, 'example_property'): 10, (None, 'example_method', 'hi'): 5}
    """

    def location_of(instance):
        # Tolerate classes that only provide `_cache` (the pre-global-cache contract).
        return getattr(instance, "_location", None)

    # cachetools>=5.0.0 passes self as first argument to the key function. We don't want to bump
    # the reference count by including a reference to the object we're about to store the cache
    # into, so self itself never ends up in the key; only its location does. We can't use the
    # default key, since it doesn't hash in the method name (nor the location).
    def key(self, *args, **kwargs):
        return keys.hashkey(location_of(self), name, *args, **kwargs)

    def select_cache(self):
        # `global_cache` is looked up on every call, so toggling it takes effect immediately.
        return _method_cache if global_cache and location_of(self) else self._cache

    return cachedmethod(select_cache, key=key)
15 changes: 11 additions & 4 deletions lumicks/pylake/detail/confocal.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@

from .image import reconstruct_image, reconstruct_image_sum
from .mixin import PhotonCounts, ExcitationLaserPower
from .caching import method_cache
from .plotting import parse_color_channel
from .utilities import method_cache, could_sum_overflow
from .utilities import could_sum_overflow
from ..adjustments import no_adjustment
from .imaging_mixins import TiffExport

Expand Down Expand Up @@ -208,9 +209,11 @@ class BaseScan(PhotonCounts, ExcitationLaserPower):
End point in the relevant info wave.
metadata : ScanMetaData
Metadata.
location : str | None
Path of the confocal object.
"""

def __init__(self, name, file, start, stop, metadata):
def __init__(self, name, file, start, stop, metadata, location):
self.start = start
self.stop = stop
self.name = name
Expand All @@ -220,6 +223,7 @@ def __init__(self, name, file, start, stop, metadata):
self._timestamp_factory = _default_timestamp_factory
self._pixelsize_factory = _default_pixelsize_factory
self._pixelcount_factory = _default_pixelcount_factory
self._location = location
self._cache = {}

def _has_default_factories(self):
Expand All @@ -243,12 +247,13 @@ def from_dataset(cls, h5py_dset, file):
start = h5py_dset.attrs["Start time (ns)"]
stop = h5py_dset.attrs["Stop time (ns)"]
name = h5py_dset.name.split("/")[-1]
location = file.h5.filename + h5py_dset.name
try:
metadata = ScanMetaData.from_json(h5py_dset[()])
except KeyError:
raise KeyError(f"{cls.__name__} '{name}' is missing metadata and cannot be loaded")

return cls(name, file, start, stop, metadata)
return cls(name, file, start, stop, metadata, location)

@property
def file(self):
Expand All @@ -269,6 +274,9 @@ def __copy__(self):
start=self.start,
stop=self.stop,
metadata=self._metadata,
# If it has no location, it will be cached only locally. This is safer than implicitly
# caching it under the same location as the parent.
location=None,
)

# Preserve custom factories
Expand Down Expand Up @@ -512,5 +520,4 @@ def get_image(self, channel="rgb") -> np.ndarray:
if channel not in ("red", "green", "blue"):
return np.stack([self.get_image(color) for color in ("red", "green", "blue")], axis=-1)
else:
# Make sure we don't return a reference to our cache
return self._image(channel)
73 changes: 73 additions & 0 deletions lumicks/pylake/detail/tests/test_caching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import pytest

from lumicks.pylake.detail import caching


@pytest.mark.parametrize(
    "location, use_global_cache",
    [
        (None, False),
        (None, True),
        ("test", False),
        ("test", True),
    ],
)
def test_cache_method(location, use_global_cache):
    """Verify that `method_cache` memoizes properties and methods in the expected cache

    The global cache should only be used when it is enabled and the object has a location. In
    all other cases the results should end up in the `_cache` of the instance.
    """
    calls = 0

    def call():
        nonlocal calls
        calls += 1

    class Test:
        def __init__(self, location):
            self._cache = {}
            self._location = location

        @property
        @caching.method_cache("example_property")
        def example_property(self):
            call()
            return 10

        @caching.method_cache("example_method")
        def example_method(self, argument=5):
            call()
            return argument

    old_cache = caching.global_cache
    caching.set_cache_enabled(use_global_cache)
    caching._method_cache.clear()

    # Restore the global state even when an assertion fails, so that a failure here cannot leak
    # a modified cache setting (or stale entries) into other tests.
    try:
        test = Test(location=location)

        cache_location = caching._method_cache if use_global_cache and location else test._cache

        assert len(cache_location) == 0
        assert test.example_property == 10
        assert len(cache_location) == 1
        assert calls == 1
        assert test.example_property == 10
        assert calls == 1
        assert len(cache_location) == 1

        assert test.example_method() == 5
        assert calls == 2
        assert len(cache_location) == 2

        assert test.example_method() == 5
        assert calls == 2
        assert len(cache_location) == 2

        assert test.example_method(6) == 6
        assert calls == 3
        assert len(cache_location) == 3

        assert test.example_method(6) == 6
        assert calls == 3
        assert len(cache_location) == 3

        # Going back to the first argument should still hit the cache
        assert test.example_method() == 5
        assert calls == 3
        assert len(cache_location) == 3
    finally:
        caching.set_cache_enabled(old_cache)
        caching._method_cache.clear()
54 changes: 0 additions & 54 deletions lumicks/pylake/detail/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,60 +2,6 @@
import contextlib

import numpy as np
import cachetools


def method_cache(name):
    """Decorator providing basic memoization of instance methods

    Results are stored in the `_cache` variable of the instance, keyed on `name` followed by the
    arguments of the call.

    Note: When used on properties, this one should be included _after_ the @property decorator.

    Parameters
    ----------
    name : str
        Name of the instance method to memo-ize. Suggestion: the instance method name.

    Examples
    --------
    ::

        class Test:
            def __init__(self):
                self._cache = {}

            ...

            @property
            @method_cache("example_property")
            def example_property(self):
                return 10

            @method_cache("example_method")
            def example_method(self, arguments):
                return 5

        test = Test()
        test.example_property
        test.example_method("hi")
        test._cache
        # test._cache will now show {('example_property',): 10, ('example_method', 'hi'): 5}
    """
    major_version = int(cachetools.__version__.split(".")[0])

    def key_from_arguments(*args, **kwargs):
        return cachetools.keys.hashkey(name, *args, **kwargs)

    # cachetools>=5.0.0 started passing self as first argument. We don't want to bump the
    # reference count by including a reference to the object we're about to store the cache
    # into, so we explicitly drop the first argument. Note that for the default key, they
    # do the same in the package, but we can't use the default key, since it doesn't hash
    # in the method name.
    def key_without_self(_, *args, **kwargs):
        return cachetools.keys.hashkey(name, *args, **kwargs)

    key = key_from_arguments if major_version < 5 else key_without_self

    return cachetools.cachedmethod(lambda self: self._cache, key=key)


def use_docstring_from(copy_func):
Expand Down
18 changes: 15 additions & 3 deletions lumicks/pylake/kymo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
seek_timestamp_next_line,
first_pixel_sample_indices,
)
from .detail.caching import method_cache
from .detail.confocal import ScanAxis, ScanMetaData, ConfocalImage
from .detail.plotting import get_axes, show_image
from .detail.timeindex import to_timestamp
from .detail.utilities import method_cache
from .detail.bead_cropping import find_beads_template, find_beads_brightness


Expand Down Expand Up @@ -83,10 +83,22 @@ class Kymo(ConfocalImage):
Coordinate position offset with respect to the original raw data.
calibration : PositionCalibration
Class defining calibration from microns to desired position units.
location : str | None
Path of the Kymo.
"""

def __init__(self, name, file, start, stop, metadata, position_offset=0, calibration=None):
super().__init__(name, file, start, stop, metadata)
def __init__(
self,
name,
file,
start,
stop,
metadata,
location=None,
position_offset=0,
calibration=None,
):
super().__init__(name, file, start, stop, metadata, location)
self._line_time_factory = _default_line_time_factory
self._line_timestamp_ranges_factory = _default_line_timestamp_ranges_factory
self._position_offset = position_offset
Expand Down
4 changes: 3 additions & 1 deletion lumicks/pylake/low_level/low_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ def create_confocal_object(
metadata = ScanMetaData.from_json(json_metadata)
file = ConfocalFileProxy(infowave, red_channel, green_channel, blue_channel)
confocal_cls = {0: PointScan, 1: Kymo, 2: Scan}
return confocal_cls[metadata.num_axes](name, file, infowave.start, infowave.stop, metadata)
return confocal_cls[metadata.num_axes](
name, file, infowave.start, infowave.stop, metadata, location=None
)


def make_continuous_slice(data, start, dt, y_label="y", name="") -> Slice:
Expand Down
8 changes: 5 additions & 3 deletions lumicks/pylake/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

from .adjustments import colormaps, no_adjustment
from .detail.image import make_image_title, reconstruct_num_frames, first_pixel_sample_indices
from .detail.caching import method_cache
from .detail.confocal import ConfocalImage
from .detail.plotting import get_axes, show_image
from .detail.utilities import method_cache
from .detail.imaging_mixins import FrameIndex, VideoExport


Expand All @@ -26,10 +26,12 @@ class Scan(ConfocalImage, VideoExport, FrameIndex):
End point in the relevant info wave.
metadata : ScanMetaData
Metadata.
location : str | None
Path of the Scan.
"""

def __init__(self, name, file, start, stop, metadata):
super().__init__(name, file, start, stop, metadata)
def __init__(self, name, file, start, stop, metadata, location=None):
super().__init__(name, file, start, stop, metadata, location)
if self._metadata.num_axes == 1:
raise RuntimeError("1D scans are not supported")
if self._metadata.num_axes > 2:
Expand Down
Loading

0 comments on commit 7a8b91f

Please sign in to comment.