Skip to content

Commit

Permalink
Merge pull request #121 from jacanchaplais/patch/py38-120
Browse files Browse the repository at this point in the history
Downgrade syntax to Python 3.8
  • Loading branch information
jacanchaplais authored Mar 31, 2023
2 parents 9cb5c90 + ae29d14 commit 7452499
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
10 changes: 6 additions & 4 deletions graphicle/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
return _array_ufunc(self, ufunc, method, *inputs, **kwargs)

def __iter__(self) -> ty.Iterator[bool]:
yield from map(bool, self.data)
return map(bool, self.data)

def copy(self) -> "MaskArray":
return self.__class__(self._data.copy())
Expand Down Expand Up @@ -705,6 +705,8 @@ def __setitem__(self, key: str, mask: base.MaskLike) -> None:
self._mask_arrays.update({key: mask})

def __bool__(self) -> bool:
if len(self) == 0:
return False
if np.shape(self.data)[0] == 0:
return False
return True
Expand Down Expand Up @@ -1176,7 +1178,7 @@ def __array_wrap__(cls, array: base.AnyVector) -> "MomentumArray":

def __iter__(self) -> ty.Iterator[MomentumElement]:
flat_vals = map(float, self._data.flatten())
elems = zip(*(flat_vals,) * 4, strict=True) # type: ignore
elems = zip(*(flat_vals,) * 4) # type: ignore
yield from it.starmap(MomentumElement, elems)

def __len__(self) -> int:
Expand Down Expand Up @@ -1379,7 +1381,7 @@ def __repr__(self) -> str:

def __iter__(self) -> ty.Iterator[ColorElement]:
flat_vals = map(int, it.chain.from_iterable(self.data))
elems = zip(*(flat_vals,) * 2, strict=True) # type: ignore
elems = zip(*(flat_vals,) * 2)
yield from it.starmap(ColorElement, elems)

@property
Expand Down Expand Up @@ -1857,7 +1859,7 @@ def __repr__(self) -> str:

def __iter__(self) -> ty.Iterator[VertexPair]:
flat_vals = map(int, it.chain.from_iterable(self._data))
elems = zip(*(flat_vals,) * 2, strict=True) # type: ignore
elems = zip(*(flat_vals,) * 2)
yield from it.starmap(VertexPair, elems)

def __len__(self) -> int:
Expand Down
14 changes: 13 additions & 1 deletion graphicle/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@
MaskGeneric = ty.TypeVar("MaskGeneric", gcl.MaskGroup, gcl.MaskArray, MaskType)


def _param_check(
param: ty.Any, name: str, expected: ty.Type
) -> ty.Optional[ty.NoReturn]:
if not isinstance(param, expected):
received = type(param)
raise ValueError(
f"Expected {name} to be {expected}. Received {received}."
)


def fastjet_clusters(
pmu: gcl.MomentumArray,
radius: float,
Expand Down Expand Up @@ -95,6 +105,7 @@ def fastjet_clusters(
``p_val`` set to ``-1`` gives **anti-kT**, ``0`` gives
**Cambridge-Aachen**, and ``1`` gives **kT** clusterings.
"""
_param_check(pmu, "pmu", gcl.MomentumArray)
pmu_pyjet = pmu.data[["e", "x", "y", "z"]]
pmu_pyjet.dtype.names = "E", "px", "py", "pz"
pmu_pyjet_idx = rfn.append_fields(
Expand Down Expand Up @@ -925,6 +936,7 @@ def centroid_prune(
Mask which retains only the particles within ``radius`` of the
centroid.
"""
_param_check(pmu, "pmu", gcl.MomentumArray)
if mask is not None:
pmu = pmu[mask]
event_mask = np.zeros_like(mask, "<?")
Expand Down Expand Up @@ -1097,7 +1109,7 @@ def clusters(
parton_pmus = map(op.getitem, it.repeat(graph.pmu), parton_masks)
parton_centroids = map(op.attrgetter("eta", "phi"), parton_pmus)
for leaf, centroid in zip(colored_leaves, parton_centroids):
leaf[...] = centroid_prune(graph.pmu, radius, leaf, centroid)
leaf.data[...] = centroid_prune(graph.pmu, radius, leaf, centroid).data
hier.recursive_drop(inplace=True)
flat_hier = hier.flatten("rise")
flat_hier_final = map(op.itemgetter(graph.final), flat_hier.values())
Expand Down

0 comments on commit 7452499

Please sign in to comment.