Skip to content

Commit

Permalink
Merge branch 'master' into rochester_dask
Browse files Browse the repository at this point in the history
  • Loading branch information
lgray authored Aug 22, 2023
2 parents dcbc45b + 5aade66 commit 742186a
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ jobs:
python -m pip install pip hatch --upgrade
python -m hatch build -t sdist -t wheel
- name: Publish package to PyPI
uses: pypa/gh-action-pypi-publish@v1.8.8
uses: pypa/gh-action-pypi-publish@v1.8.10
with:
user: __token__
password: ${{ secrets.PYPI_TOKEN }}
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ classifiers = [
"Topic :: Utilities",
]
dependencies = [
"awkward>=2.3.1",
"awkward>=2.3.3",
"uproot>=5.0.10",
"dask[array]>=2023.4.0",
"dask-awkward>=2023.7.1,!=2023.8.0",
Expand Down
41 changes: 27 additions & 14 deletions src/coffea/lumi_tools/lumi_tools.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json

import awkward as ak
import dask_awkward as dak
from numba import types
from numba.typed import Dict

Expand Down Expand Up @@ -107,9 +108,9 @@ def __call__(self, runs, lumis):
Parameters
----------
runs : numpy.ndarray
runs : numpy.ndarray or awkward.highlevel.Array or dask_awkward.Array
Vectorized list of run numbers
lumis : numpy.ndarray
lumis : numpy.ndarray or awkward.highlevel.Array or dask_awkward.Array
Vectorized list of lumiSection numbers
Returns
Expand All @@ -118,18 +119,30 @@ def __call__(self, runs, lumis):
An array of dtype `bool` where valid (run, lumi) tuples
will have their corresponding entry set ``True``.
"""
# fill numba typed dict
_masks = Dict.empty(key_type=types.uint32, value_type=types.uint32[:])
for k, v in self._masks.items():
_masks[k] = v

if isinstance(runs, ak.highlevel.Array):
runs = ak.to_numpy(runs)
if isinstance(lumis, ak.highlevel.Array):
lumis = ak.to_numpy(lumis)
mask_out = np.zeros(dtype="bool", shape=runs.shape)
LumiMask._apply_run_lumi_mask_kernel(_masks, runs, lumis, mask_out)
return mask_out

def apply(runs, lumis):
# fill numba typed dict
_masks = Dict.empty(key_type=types.uint32, value_type=types.uint32[:])
for k, v in self._masks.items():
_masks[k] = v

runs_orig = runs
if isinstance(runs, ak.highlevel.Array):
runs = ak.to_numpy(ak.typetracer.length_zero_if_typetracer(runs))
if isinstance(lumis, ak.highlevel.Array):
lumis = ak.to_numpy(ak.typetracer.length_zero_if_typetracer(lumis))
mask_out = np.zeros(dtype="bool", shape=runs.shape)
LumiMask._apply_run_lumi_mask_kernel(_masks, runs, lumis, mask_out)
if isinstance(runs_orig, ak.Array):
mask_out = ak.Array(mask_out)
if ak.backend(runs_orig) == "typetracer":
mask_out = ak.Array(mask_out.layout.to_typetracer(forget_length=True))
return mask_out

if isinstance(runs, dak.Array):
return dak.map_partitions(apply, runs, lumis)
else:
return apply(runs, lumis)

# This could be run in parallel, but windows does not support it
@staticmethod
Expand Down
14 changes: 14 additions & 0 deletions tests/test_lumi_tools.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import awkward as ak
import cloudpickle
import dask_awkward as dak
from dask.distributed import Client

from coffea.lumi_tools import LumiData, LumiList, LumiMask
from coffea.util import numpy as np
Expand Down Expand Up @@ -56,6 +59,8 @@ def test_lumidata():


def test_lumimask():
client = Client()

lumimask = LumiMask(
"tests/samples/Cert_294927-306462_13TeV_EOY2017ReReco_Collisions17_JSON.txt"
)
Expand Down Expand Up @@ -86,6 +91,15 @@ def test_lumimask():

assert np.all(lumimask(runs, lumis) == lumimask_pickle(runs, lumis))

runs_dak = dak.from_awkward(ak.Array(runs), 1)
lumis_dak = dak.from_awkward(ak.Array(lumis), 1)
assert np.all(
client.compute(lumimask(runs_dak, lumis_dak)).result()
== lumimask_pickle(runs, lumis)
)

client.close()


def test_lumilist():
lumidata = LumiData("tests/samples/lumi_small.csv")
Expand Down

0 comments on commit 742186a

Please sign in to comment.