diff --git a/src/coffea/nanoevents/methods/physlite.py b/src/coffea/nanoevents/methods/physlite.py index c0efcdc39..751c5d03f 100644 --- a/src/coffea/nanoevents/methods/physlite.py +++ b/src/coffea/nanoevents/methods/physlite.py @@ -2,6 +2,7 @@ from numbers import Number import awkward +import dask_awkward import numpy from coffea.nanoevents.methods import base, vector @@ -38,7 +39,30 @@ def _element_link(target_collection, eventindex, index, key): return target_collection._apply_global_index(global_index) +def _element_link_method(self, link_name, target_name, _dask_array_): + if _dask_array_ is not None: + target = _dask_array_.behavior["__original_array__"]()[target_name] + links = _dask_array_[link_name] + return _element_link( + target, + _dask_array_._eventindex, + links.m_persIndex, + links.m_persKey, + ) + links = self[link_name] + return _element_link( + self._events()[target_name], + self._eventindex, + links.m_persIndex, + links.m_persKey, + ) + + def _element_link_multiple(events, obj, link_field, with_name=None): + # currently not working in dask because: + # - we don't know the resulting type beforehand + # - also not the targets, so no way to find out which columns to load? + # - could consider to treat the case of truth collections by just loading all truth columns link = obj[link_field] key = link.m_persKey index = link.m_persIndex @@ -64,22 +88,46 @@ def where(unique_keys): return out -def _get_target_offsets(offsets, event_index): +def _get_target_offsets(load_column, event_index): + if isinstance(load_column, dask_awkward.Array) and isinstance( + event_index, dask_awkward.Array + ): + # wrap in map_partitions if dask arrays + return dask_awkward.map_partitions( + _get_target_offsets, load_column, event_index + ) + + offsets = load_column.layout.offsets.data + if isinstance(event_index, Number): return offsets[event_index] + # let the necessary column optimization know that we need to load this + # column to get the offsets + if awkward.backend(load_column) == "typetracer": + awkward.typetracer.touch_data(load_column) + + # necessary to stick it into the `NumpyArray` constructor + # if typetracer is passed through + offsets = awkward.typetracer.length_zero_if_typetracer( + load_column.layout.offsets.data + ) + def descend(layout, depth, **kwargs): if layout.purelist_depth == 1: return awkward.contents.NumpyArray(offsets)[layout] - return awkward.transform(descend, event_index) + return awkward.transform(descend, event_index.layout) def _get_global_index(target, eventindex, index): - load_column = target[ - target.fields[0] - ] # awkward is eager-mode now (will need to dask this) - target_offsets = _get_target_offsets(load_column.layout.offsets, eventindex) + for field in target.fields: + # fetch first column to get offsets from + # (but try to avoid the double-jagged ones if possible) + load_column = target[field] + if load_column.ndim < 3: + break + target_offsets = _get_target_offsets(load_column, eventindex) return target_offsets + index @@ -140,12 +188,12 @@ class Muon(Particle): """ @property - def trackParticle(self): - return _element_link( - self._events().CombinedMuonTrackParticles, - self._eventindex, - self["combinedTrackParticleLink.m_persIndex"], - self["combinedTrackParticleLink.m_persKey"], + def trackParticle(self, _dask_array_=None): + return _element_link_method( + self, + "combinedTrackParticleLink", + "CombinedMuonTrackParticles", + _dask_array_, ) @@ -159,21 +207,25 @@ class Electron(Particle): """ @property - def trackParticles(self): - links = self.trackParticleLinks - return _element_link( - self._events().GSFTrackParticles, - self._eventindex, - links.m_persIndex, - links.m_persKey, + def trackParticles(self, _dask_array_=None): + return _element_link_method( + self, "trackParticleLinks", "GSFTrackParticles", _dask_array_ ) @property - def trackParticle(self): - trackParticles = self.trackParticles - return self.trackParticles[ - tuple([slice(None) for i in range(trackParticles.ndim - 1)] + [0]) - ] + def trackParticle(self, _dask_array_=None): + trackParticles = _element_link_method( + self, "trackParticleLinks", "GSFTrackParticles", _dask_array_ + ) + # Ellipsis (..., 0) slicing not supported yet by dask_awkward + slicer = tuple([slice(None) for i in range(trackParticles.ndim - 1)] + [0]) + return trackParticles[slicer] + + @property + def caloClusters(self, _dask_array_=None): + return _element_link_method( + self, "caloClusterLinks", "CaloCalTopoClusters", _dask_array_ + ) _set_repr_name("Electron") diff --git a/src/coffea/nanoevents/schemas/physlite.py b/src/coffea/nanoevents/schemas/physlite.py index c45240d6a..368dc7d8a 100644 --- a/src/coffea/nanoevents/schemas/physlite.py +++ b/src/coffea/nanoevents/schemas/physlite.py @@ -53,6 +53,7 @@ class PHYSLITESchema(BaseSchema): "GSFTrackParticles": "TrackParticle", "InDetTrackParticles": "TrackParticle", "MuonSpectrometerTrackParticles": "TrackParticle", + "CaloCalTopoClusters": "NanoCollection", } """Default configuration for mixin types, based on the collection name. diff --git a/tests/test_nanoevents_physlite.py b/tests/test_nanoevents_physlite.py index f82471198..95f58491d 100644 --- a/tests/test_nanoevents_physlite.py +++ b/tests/test_nanoevents_physlite.py @@ -1,19 +1,17 @@ import os -import numpy as np +import dask import pytest from coffea.nanoevents import NanoEventsFactory, PHYSLITESchema -pytestmark = pytest.mark.skip(reason="uproot is upset with this file...") - def _events(): path = os.path.abspath("tests/samples/DAOD_PHYSLITE_21.2.108.0.art.pool.root") factory = NanoEventsFactory.from_root( {path: "CollectionTree"}, schemaclass=PHYSLITESchema, - permit_dask=False, + permit_dask=True, ) return factory.events() @@ -23,54 +21,21 @@ def events(): return _events() +def test_load_single_field_of_linked(events): + with dask.config.set({"awkward.raise-failed-meta": True}): + events.Electrons.caloClusters.calE.compute() + + @pytest.mark.parametrize("do_slice", [False, True]) def test_electron_track_links(events, do_slice): if do_slice: - events = events[np.random.randint(2, size=len(events)).astype(bool)] - for event in events: - for electron in event.Electrons: + events = events[::2] + trackParticles = events.Electrons.trackParticles.compute() + for i, event in enumerate(events[["Electrons", "GSFTrackParticles"]].compute()): + for j, electron in enumerate(event.Electrons): for link_index, link in enumerate(electron.trackParticleLinks): track_index = link.m_persIndex - print(track_index) - print(event.GSFTrackParticles) - print(electron.trackParticleLinks) - print(electron.trackParticles) - assert ( event.GSFTrackParticles[track_index].z0 - == electron.trackParticles[link_index].z0 - ) - - -# from MetaData/EventFormat -_hash_to_target_name = { - 13267281: "TruthPhotons", - 342174277: "TruthMuons", - 368360608: "TruthNeutrinos", - 375408000: "TruthTaus", - 394100163: "TruthElectrons", - 614719239: "TruthBoson", - 660928181: "TruthTop", - 779635413: "TruthBottom", -} - - -def test_truth_links_toplevel(events): - children_px = events.TruthBoson.children.px - for i_event, event in enumerate(events): - for i_particle, particle in enumerate(event.TruthBoson): - for i_link, link in enumerate(particle.childLinks): - assert ( - event[_hash_to_target_name[link.m_persKey]][link.m_persIndex].px - == children_px[i_event][i_particle][i_link] - ) - - -def test_truth_links(events): - for i_event, event in enumerate(events): - for i_particle, particle in enumerate(event.TruthBoson): - for i_link, link in enumerate(particle.childLinks): - assert ( - event[_hash_to_target_name[link.m_persKey]][link.m_persIndex].px - == particle.children[i_link].px + == trackParticles[i][j][link_index].z0 )