Skip to content

Commit

Permalink
Merge pull request #173 from jacanchaplais/feature/maskgroup-from-nes…
Browse files Browse the repository at this point in the history
…ted-172

Instantiate MaskGroups from nested mappings
  • Loading branch information
jacanchaplais authored Feb 28, 2024
2 parents f375988 + 9b83b8d commit 6428304
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 1 deletion.
47 changes: 46 additions & 1 deletion graphicle/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,9 @@ def _mask_neq(mask1: base.MaskLike, mask2: base.MaskLike) -> MaskArray:
return MaskArray(np.not_equal(mask1, mask2))


_IN_MASK_DICT = ty.OrderedDict[str, ty.Union[MaskArray, base.BoolVector]]
_IN_MASK_DICT = ty.Mapping[
str, ty.Union[MaskArray, base.BoolVector, ty.Iterable[bool]]
]
_MASK_DICT = ty.OrderedDict[str, MaskArray]


Expand All @@ -651,12 +653,32 @@ def _mask_dict_convert(masks: _IN_MASK_DICT) -> _MASK_DICT:
for key, val in masks.items():
if isinstance(val, MaskArray) or isinstance(val, MaskGroup):
mask = val
elif isinstance(val, cla.Mapping):
mask = MaskGroup(_mask_dict_convert(val))
else:
mask = MaskArray(val)
out_masks[key] = mask
return out_masks


def _maskgroup_equal(
group_1: "MaskGroup", group_2: "MaskGroup", check_order: bool
) -> bool:
key_struct = tuple if check_order else set
if key_struct(group_1) != key_struct(group_2):
return False
for key in group_1:
mask_1, mask_2 = group_1[key], group_2[key]
if type(mask_1) != type(mask_2):
return False
if isinstance(mask_1, MaskGroup):
if not _maskgroup_equal(mask_1, mask_2, check_order):
return False
elif not np.array_equal(mask_1.data, mask_2.data):
return False
return True


class MaskAggOp(Enum):
AND = "and"
OR = "or"
Expand Down Expand Up @@ -980,6 +1002,29 @@ def serialize(self) -> ty.Dict[str, ty.Any]:
"""
return {key: val.serialize() for key, val in self._mask_arrays.items()}

def equal_to(self, other: "MaskGroup", check_order: bool = False) -> bool:
"""Checks whether this instance is identical to ``other``
``MaskGroup``, comparing keys at all levels of nesting, and
boolean array data at the leaf level.
.. versionadded:: 0.3.9
Parameters
----------
other : MaskGroup
Other instance, against which to compare for equality.
check_order : bool
If ``True``, will check that the ordering of elements is
identical. Default is ``False``.
Returns
-------
bool
``True`` if instance is identical to ``other``, ``False``
otherwise.
"""
return _maskgroup_equal(self, other, check_order)


@define(eq=False)
class PdgArray(base.ArrayBase):
Expand Down
62 changes: 62 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import dataclasses as dc
import math
import random
import string

import numpy as np
import pytest
Expand All @@ -18,6 +19,13 @@
ZERO_TOL = 1.0e-10 # absolute tolerance for detecting zero-values


def random_alphanum(length: int) -> str:
return "".join(
random.choice(string.ascii_letters + string.digits)
for _ in range(length)
)


@dc.dataclass
class MomentumExample:
"""Dataclass for representing the four-momentum of a particle, based
Expand Down Expand Up @@ -111,3 +119,57 @@ def test_pmu_zero_pt() -> None:
with pytest.warns(gcl.base.NumericalStabilityWarning):
phi_invalid = math.isnan(pmu_zero_pt.phi.item())
assert phi_invalid, "Azimuth is not NaN when pT is low"


def generate_tree(
max_width: int, max_depth: int, leaf_length: int
) -> gcl.MaskGroup:
"""Generates a nested MaskGroup tree structure, with random branch
widths and depths.
Parameters
----------
max_width, max_depth : int
Maximum limits on the nested tree structure.
leaf_length : int
The number of elements in the leaf MaskArrays.
Returns
-------
MaskGroup
Tree structure, with random structure, and random MaskArrays at
the leaf levels.
"""
rng = np.random.default_rng()

def generate_branch(depth: int) -> gcl.base.MaskBase:
if (depth == 0) or (
random.choice((True, False)) and not (depth == max_depth)
):
return gcl.MaskArray(
rng.integers(0, 2, size=leaf_length, dtype=np.bool_)
)
num_children = random.randint(1, max_width)
return gcl.MaskGroup(
{
random_alphanum(10): generate_branch(depth - 1)
for _ in range(num_children)
}
)

return generate_branch(max_depth)


def test_maskgroup_serialize_inverse() -> None:
"""Tests that instantiating a MaskGroup from its serialization
yields identical results to the original.
"""
invertible = True
for _ in range(10):
maskgroup = generate_tree(
max_width=5,
max_depth=10,
leaf_length=random.randint(0, 1_000),
)
invertible &= maskgroup.equal_to(gcl.MaskGroup(maskgroup.serialize()))
assert invertible, "Serializing MaskGroups is not invertible."

0 comments on commit 6428304

Please sign in to comment.