Skip to content

Commit

Permalink
Merge branch 'main' into LigandNetwork_delete_edges
Browse files Browse the repository at this point in the history
  • Loading branch information
dotsdl authored Dec 17, 2024
2 parents 18809a7 + 128da4a commit 67ec107
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 28 deletions.
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ dependencies:
- openff-models >=0.0.5
- pip
- pydantic >1
- zstandard
- pytest
- pytest-cov
- pytest-xdist
Expand Down
16 changes: 16 additions & 0 deletions gufe/compression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import zstandard as zstd
from zstandard import ZstdError


def zst_compress(data: bytes):
    """Compress ``data`` with zstandard and return the result.

    A new ``ZstdCompressor`` with default settings is created for each call.
    """
    return zstd.ZstdCompressor().compress(data)


def zst_decompress(data: bytes):
    """Decompress zstandard-compressed ``data``.

    If zstandard raises ``ZstdError`` while handling the input, ``data`` is
    returned unchanged instead of propagating the error.
    """
    try:
        return zstd.ZstdDecompressor().decompress(data)
    except ZstdError:
        # need to ensure backwards compatibility for noncompressed artifacts until gufe 2.0
        return data
5 changes: 3 additions & 2 deletions gufe/custom_codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from openff.units import DEFAULT_UNIT_REGISTRY

import gufe
from gufe.compression import zst_compress, zst_decompress
from gufe.custom_json import JSONCodec
from gufe.settings.models import SettingsBaseModel

Expand Down Expand Up @@ -62,8 +63,8 @@ def is_openff_quantity_dict(dct):

BYTES_CODEC = JSONCodec(
cls=bytes,
to_dict=lambda obj: {"latin-1": obj.decode("latin-1")},
from_dict=lambda dct: dct["latin-1"].encode("latin-1"),
to_dict=lambda obj: {"latin-1": zst_compress(obj).decode("latin-1")},
from_dict=lambda dct: zst_decompress(dct["latin-1"].encode("latin-1")),
)


Expand Down
80 changes: 54 additions & 26 deletions gufe/tests/test_custom_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Portions Copyright (c) 2014-2022 the contributors to OpenPathSampling
# Permissions are the same as those listed in the gufe LICENSE

import abc
import json
import pathlib
from uuid import uuid4
Expand All @@ -24,7 +25,7 @@
SETTINGS_CODEC,
UUID_CODEC,
)
from gufe.custom_json import JSONSerializerDeserializer, custom_json_factory
from gufe.custom_json import JSONCodec, JSONSerializerDeserializer, custom_json_factory
from gufe.settings import models


Expand Down Expand Up @@ -65,17 +66,21 @@ def test_numpy_codec_order_roundtrip(obj, codecs):
assert obj.dtype == reconstructed.dtype


class CustomJSONCodingTest:
class CustomJSONCodingTest(abc.ABC):
"""Base class for testing codecs.
In ``setup_method()``, user must define the following:
* ``self.codec``: The codec to run
* ``self.objs``: A list of objects to serialize
* ``self.dcts``: A list of expected serilized forms of each object in
``self.objs``
* ``self.dcts``: A list of expected serialized forms of each object in ``self.objs``
* ``self.required_codecs``: A list of all codecs required to serialize objects of this type
"""

@abc.abstractmethod
def setup_method(self):
    """Define the fixtures used by the codec tests; subclasses must override.

    Implementations are expected to set ``self.codec``, ``self.objs``,
    ``self.dcts`` and ``self.required_codecs``.

    Raises
    ------
    NotImplementedError
        Always, when this base implementation is called directly.
    """
    # Raise (rather than return) the exception so that calling the base
    # implementation by mistake fails loudly instead of silently doing nothing.
    raise NotImplementedError

def test_default(self):
for obj, dct in zip(self.objs, self.dcts):
assert self.codec.default(obj) == dct
Expand All @@ -93,35 +98,43 @@ def _test_round_trip(self, encoder, decoder):
assert json_str == json_str_2

def test_round_trip(self):
encoder, decoder = custom_json_factory([self.codec])
encoder, decoder = custom_json_factory(self.required_codecs)
self._test_round_trip(encoder, decoder)

def test_legacy_bytes_uncompressed(self):
    """Check that the current decoder can read bytes written without compression.

    Objects are encoded with a legacy bytes codec that stores the raw
    latin-1 text, then decoded with the current ``BYTES_CODEC``.
    """
    # NOTE: this can be removed in `gufe` 2.0, but is also somewhat harmless
    legacy_bytes_codec = JSONCodec(
        cls=bytes,
        to_dict=lambda obj: {"latin-1": obj.decode("latin-1")},
        from_dict=lambda dct: dct["latin-1"].encode("latin-1"),
    )

    # Exclude the current bytes codec so each factory call below receives
    # exactly one bytes codec (legacy for encoding, current for decoding).
    required_codecs = [codec for codec in self.required_codecs if codec is not BYTES_CODEC]

    legacy_encoder, _ = custom_json_factory([legacy_bytes_codec, *required_codecs])
    _, decoder = custom_json_factory([BYTES_CODEC, *required_codecs])

    self._test_round_trip(legacy_encoder, decoder)

def test_not_mine(self):
# test that the default behavior is obeyed
obj = {"test": 5}
json_str = '{"test": 5}'
encoder, decoder = custom_json_factory([self.codec])
encoder, decoder = custom_json_factory(self.required_codecs)
assert json.dumps(obj, cls=encoder) == json_str
assert json.loads(json_str, cls=decoder) == obj


class TestNumpyCoding(CustomJSONCodingTest):
def setup_method(self):
self.codec = NUMPY_CODEC
self.required_codecs = [self.codec, BYTES_CODEC]
self.objs = [
np.array([[1.0, 0.0], [2.0, 3.2]]),
np.array([1, 0]),
np.array([1.0, 2.0, 3.0], dtype=np.float32),
]
shapes = [
[2, 2],
[
2,
],
[
3,
],
]
shapes = [[2, 2], [2], [3]] # fmt: skip
dtypes = [str(arr.dtype) for arr in self.objs] # may change by system?
byte_reps = [arr.tobytes() for arr in self.objs]
self.dcts = [
Expand All @@ -142,8 +155,7 @@ def test_object_hook(self):
reconstructed = self.codec.object_hook(dct)
npt.assert_array_equal(reconstructed, obj)

def test_round_trip(self):
encoder, decoder = custom_json_factory([self.codec, BYTES_CODEC])
def _test_round_trip(self, encoder, decoder):
for obj, dct in zip(self.objs, self.dcts):
json_str = json.dumps(obj, cls=encoder)
reconstructed = json.loads(json_str, cls=decoder)
Expand All @@ -156,6 +168,7 @@ def test_round_trip(self):
class TestNumpyGenericCodec(TestNumpyCoding):
def setup_method(self):
self.codec = NPY_DTYPE_CODEC
self.required_codecs = [self.codec, BYTES_CODEC]
# Note that np.float64 is treated as a float by the
# default json encode (and so returns a float not a numpy
# object).
Expand Down Expand Up @@ -184,9 +197,25 @@ def setup_method(self):
]


class TestBytesCodec(CustomJSONCodingTest):
    """Codec tests for ``BYTES_CODEC`` using a single short byte string."""

    def setup_method(self):
        # Expected serialized form of the byte string in ``self.objs``.
        expected = {
            ":is_custom:": True,
            "__class__": "bytes",
            "__module__": "builtins",
            "latin-1": "(µ/ý \ri\x00\x00a test string",
        }
        self.codec = BYTES_CODEC
        self.required_codecs = [self.codec]
        self.objs = [b"a test string"]
        self.dcts = [expected]


class TestPathCodec(CustomJSONCodingTest):
def setup_method(self):
self.codec = PATH_CODEC
self.required_codecs = [self.codec]
self.objs = [
pathlib.PosixPath("foo/bar"),
]
Expand All @@ -203,6 +232,11 @@ def setup_method(self):
class TestSettingsCodec(CustomJSONCodingTest):
def setup_method(self):
self.codec = SETTINGS_CODEC
self.required_codecs = [
self.codec,
OPENFF_QUANTITY_CODEC,
OPENFF_UNIT_CODEC,
]
self.objs = [
models.Settings.get_defaults(),
]
Expand Down Expand Up @@ -260,15 +294,6 @@ def setup_method(self):
},
}
]
self.required_codecs = [
self.codec,
OPENFF_QUANTITY_CODEC,
OPENFF_UNIT_CODEC,
]

def test_round_trip(self):
encoder, decoder = custom_json_factory(self.required_codecs)
self._test_round_trip(encoder, decoder)

def test_full_dump(self):
encoder, _ = custom_json_factory(self.required_codecs)
Expand All @@ -281,6 +306,7 @@ def test_full_dump(self):
class TestOpenFFQuantityCodec(CustomJSONCodingTest):
def setup_method(self):
self.codec = OPENFF_QUANTITY_CODEC
self.required_codecs = [self.codec]
self.objs = [
openff.units.DEFAULT_UNIT_REGISTRY("1.0 * kg meter per second squared"),
]
Expand Down Expand Up @@ -308,6 +334,7 @@ def test_openff_quantity_array_roundtrip():
class TestOpenFFUnitCodec(CustomJSONCodingTest):
def setup_method(self):
self.codec = OPENFF_UNIT_CODEC
self.required_codecs = [self.codec]
self.objs = [
openff.units.unit.amu,
]
Expand All @@ -323,6 +350,7 @@ def setup_method(self):
class TestUUIDCodec(CustomJSONCodingTest):
def setup_method(self):
self.codec = UUID_CODEC
self.required_codecs = [self.codec]
self.objs = [uuid4()]
self.dcts = [
{
Expand Down
23 changes: 23 additions & 0 deletions news/zstd_compression.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
**Added:**

* JSON encoder now uses `zstandard compression <https://github.com/OpenFreeEnergy/gufe/pull/438>`_.

**Changed:**

* <news item>

**Deprecated:**

* <news item>

**Removed:**

* <news item>

**Fixed:**

* <news item>

**Security:**

* <news item>

0 comments on commit 67ec107

Please sign in to comment.