diff --git a/coffea/jetmet_tools/CorrectedJetsFactory.py b/coffea/jetmet_tools/CorrectedJetsFactory.py index f6d07a171..fa3e58931 100644 --- a/coffea/jetmet_tools/CorrectedJetsFactory.py +++ b/coffea/jetmet_tools/CorrectedJetsFactory.py @@ -3,7 +3,7 @@ import warnings from functools import partial, reduce import operator -import correctionlib as clib +from topcoffea.modules.JECStack import JECStack _stack_parts = ["jec", "junc", "jer", "jersf"] _MIN_JET_ENERGY = numpy.array(1e-2, dtype=numpy.float32) @@ -17,14 +17,11 @@ "primitive": "float32", } - -# we're gonna assume that the first record array we encounter is the flattened data def rewrap_recordarray(layout, depth, data): if isinstance(layout, awkward.layout.RecordArray): return lambda: data return None - def awkward_rewrap(arr, like_what, gfunc): behavior = awkward._util.behaviorof(like_what) func = partial(gfunc, data=arr.layout) @@ -32,7 +29,6 @@ def awkward_rewrap(arr, like_what, gfunc): newlayout = awkward._util.recursively_apply(layout, func) return awkward._util.wrap(newlayout, behavior=behavior) - def rand_gauss(item, randomstate): def getfunction(layout, depth): if isinstance(layout, awkward.layout.NumpyArray) or not isinstance( @@ -42,14 +38,12 @@ def getfunction(layout, depth): randomstate.normal(size=len(layout)).astype(numpy.float32) ) return None - out = awkward._util.recursively_apply( awkward.operations.convert.to_layout(item), getfunction ) assert out is not None return awkward._util.wrap(out, awkward._util.behaviorof(item)) - def jer_smear( variation, forceStochastic, @@ -61,10 +55,8 @@ def jer_smear( jet_energy_resolution_scale_factor, ): pt_gen = pt_gen if not forceStochastic else None - if not isinstance(jetPt, awkward.highlevel.Array): raise Exception("'jetPt' must be an awkward array of some kind!") - if forceStochastic: pt_gen = awkward.without_parameters(awkward.zeros_like(jetPt)) @@ -72,13 +64,11 @@ def jer_smear( jersf = jet_energy_resolution_scale_factor[:, variation] deltaPtRel = (jetPt - pt_gen) / jetPt doHybrid = (pt_gen > 0) & (numpy.abs(deltaPtRel) < 3 * jet_energy_resolution) - detSmear = 1 + (jersf - 1) * deltaPtRel stochSmear = 1 + numpy.sqrt(numpy.maximum(jersf**2 - 1, 0)) * jersmear min_jet_pt = _MIN_JET_ENERGY / numpy.cosh(etaJet) min_jet_pt_corr = min_jet_pt / jetPt - smearfact = awkward.where(doHybrid, detSmear, stochSmear) smearfact = awkward.where( (smearfact * jetPt) < min_jet_pt, min_jet_pt_corr, smearfact @@ -97,7 +87,6 @@ def getfunction(layout, depth): smearfact = awkward._util.wrap(smearfact, awkward._util.behaviorof(jetPt)) return smearfact - # Wrapper function to apply jec corrections def rawvar_jec(jecval, rawvar, lazy_cache): return awkward.virtual( @@ -106,7 +95,6 @@ def rawvar_jec(jecval, rawvar, lazy_cache): cache=lazy_cache, ) - def get_corr_inputs(jets, corr_obj, name_map, cache=None, corrections=None): """ Helper function for getting values of input variables @@ -114,12 +102,9 @@ def get_corr_inputs(jets, corr_obj, name_map, cache=None, corrections=None): """ if corrections is None: - input_values = [ - awkward.flatten(jets[name_map[inp.name]]) - for inp in corr_obj.inputs - if inp.name != "systematic" - ] + input_values = [awkward.flatten(jets[name_map[inp.name]]) for inp in corr_obj.inputs if (inp.name != "systematic")] else: + ## This is needed to propagate the previous level of corrections, before applying the next one input_values = [] for inp in corr_obj.inputs: if inp.name == "systematic": @@ -128,30 +113,21 @@ def get_corr_inputs(jets, corr_obj, name_map, cache=None, corrections=None): rawvar = awkward.flatten(jets[name_map[inp.name]]) init_input_value = partial(rawvar_jec, rawvar=rawvar, lazy_cache=cache) input_value = init_input_value(jecval=corrections) - else: input_value = awkward.flatten(jets[name_map[inp.name]]) input_values.append(input_value) - return input_values class CorrectedJetsFactory(object): def __init__(self, name_map, jec_stack): - # from PhysicsTools/PatUtils/interface/SmearedJetProducerT.h#L283 - if isinstance(jec_stack, list) and isinstance(jec_stack[-1], bool): - self.tool = "clib" - elif ( - type(jec_stack).__name__ == "JECStack" - and type(jec_stack).__module__ == "coffea.jetmet_tools.JECStack" - ): - self.tool = "jecstack" - else: - raise TypeError( - "jec_stack need to be either an instance of JECStack or a list containing correction-lib setup!" - ) + if not isinstance(jec_stack, JECStack): + raise TypeError("jec_stack must be an instance of JECStack") + self.tool = "clib" if jec_stack.use_clib else "jecstack" self.forceStochastic = False + + # Handle name map for raw pt and mass if "ptRaw" not in name_map or name_map["ptRaw"] is None: warnings.warn( "There is no name mapping for ptRaw," @@ -163,95 +139,71 @@ def __init__(self, name_map, jec_stack): if "massRaw" not in name_map or name_map["massRaw"] is None: warnings.warn( "There is no name mapping for massRaw," - " CorrectedJets will assume that .mass is raw pt!" + " CorrectedJets will assume that .mass is raw mass!" ) - name_map["ptRaw"] = name_map["JetMass"] + "_raw" - - if self.tool == "jecstack": - total_signature = set() - for part in _stack_parts: - attr = getattr(jec_stack, part) - if attr is not None: - total_signature.update(attr.signature) - - missing = total_signature - set(name_map.keys()) - if len(missing) > 0: - raise Exception( - f"Missing mapping of {missing} in name_map!" - + " Cannot evaluate jet corrections!" - + " Please supply mappings for these variables!" - ) + name_map["massRaw"] = name_map["JetMass"] + "_raw" self.jec_stack = jec_stack + self.name_map = name_map + + if self.jec_stack.use_clib: + # For clib scenario, load corrections from json_path + self.load_corrections_clib() + else: + # For non-clib scenario, use the provided corrections (e.g., JEC/JER) + self.load_corrections_jecstack() + if "ptGenJet" not in name_map: warnings.warn( 'Input JaggedCandidateArray must have "ptGenJet" in order to apply hybrid JER smearing method. Stochastic smearing will be applied.' ) self.forceStochastic = True - if self.tool == "clib": - self.separated = self.jec_stack.pop() - self.json_path = self.jec_stack.pop() + def load_corrections_clib(self): + """Load the corrections from correctionlib using the json_path in JECStack.""" + self.corrections = self.jec_stack.corrections - self.real_sig = [v for k, v in name_map.items()] - self.name_map = name_map + def load_corrections_jecstack(self): + """Use the corrections provided in the JECStack for non-clib scenario.""" + self.corrections = self.jec_stack.corrections - if self.tool == "clib": - self.jer_names = [ - name - for name in self.jec_stack - if isinstance(name, str) and ("Resolution" in name or "SF" in name) - ] - self.junc_names = [ - name - for name in self.jec_stack - if isinstance(name, str) and ("Uncertainty" in name) - ] - self.jec_names = [ - name - for name in self.jec_stack - if (name not in self.jer_names and name not in self.junc_names) - ] - ## General setup to use correction-lib - self.cset = clib.CorrectionSet.from_file(self.json_path) - - def uncertainties(self): - out = ["JER"] if self.jec_stack.jer is not None else [] - if self.jec_stack.junc is not None: - out.extend(["JES_{0}".format(unc) for unc in self.jec_stack.junc.levels]) - return out + # Ensure all required inputs have mappings + total_signature = set() + for part in _stack_parts: + attr = getattr(self.jec_stack, part) + if attr is not None: + total_signature.update(attr.signature) - def build(self, jets, lazy_cache): - if lazy_cache is None: + missing = total_signature - set(self.name_map.keys()) + if len(missing) > 0: raise Exception( - "CorrectedJetsFactory requires a awkward-array cache to function correctly." + f"Missing mapping of {missing} in name_map!" + + " Cannot evaluate jet corrections!" + + " Please supply mappings for these variables!" ) + + def build(self, jets, lazy_cache): + if lazy_cache is None: + raise Exception("CorrectedJetsFactory requires an awkward-array cache to function correctly.") lazy_cache = awkward._util.MappingProxy.maybe_wrap(lazy_cache) if not isinstance(jets, awkward.highlevel.Array): raise Exception("'jets' must be an awkward > 1.0.0 array of some kind!") + # THESE ARE THE ATTRIBUTES OF THE JET COLLECTION fields = awkward.fields(jets) - if len(fields) == 0: - raise Exception( - "Empty record, please pass a jet object with at least {self.real_sig} defined!" - ) + raise Exception("Empty record, please pass a jet object with at least {self.real_sig} defined!") out = awkward.flatten(jets) wrap = partial(awkward_rewrap, like_what=jets, gfunc=rewrap_recordarray) - scalar_form = awkward.without_parameters( - out[self.name_map["ptRaw"]] - ).layout.form + scalar_form = awkward.without_parameters(out[self.name_map["ptRaw"]]).layout.form in_dict = {field: out[field] for field in fields} out_dict = dict(in_dict) - # take care of nominal JEC (no JER if available) - + # Add original values out_dict[self.name_map["JetPt"] + "_orig"] = out_dict[self.name_map["JetPt"]] - out_dict[self.name_map["JetMass"] + "_orig"] = out_dict[ - self.name_map["JetMass"] - ] + out_dict[self.name_map["JetMass"] + "_orig"] = out_dict[self.name_map["JetMass"]] if self.treat_pt_as_raw: out_dict[self.name_map["ptRaw"]] = out_dict[self.name_map["JetPt"]] out_dict[self.name_map["massRaw"]] = out_dict[self.name_map["JetMass"]] @@ -260,6 +212,7 @@ def build(self, jets, lazy_cache): jec_name_map["JetPt"] = jec_name_map["ptRaw"] jec_name_map["JetMass"] = jec_name_map["massRaw"] + # Apply JEC corrections based on scenario total_correction = None if self.tool == "jecstack": if self.jec_stack.jec is not None: @@ -270,76 +223,54 @@ def build(self, jets, lazy_cache): **jec_args, form=scalar_form, lazy_cache=lazy_cache ) else: - total_correction = awkward.without_parameters( - awkward.ones_like(out_dict[self.name_map["JetPt"]]) - ) + total_correction = awkward.ones_like(out_dict[self.name_map["JetPt"]]) elif self.tool == "clib": corrections_list = [] - for lvl in self.jec_names: - if "Uncertainty" in lvl: - continue - + for lvl in self.jec_stack.jec_names_clib: cumCorr = None if len(corrections_list) > 0: ones = numpy.ones_like(corrections_list[-1], dtype=numpy.float32) - cumCorr = reduce(lambda x, y: y * x, corrections_list, ones).astype( - dtype=numpy.float32 - ) - sf = self.cset[lvl] - inputs = get_corr_inputs( - jets=jets, - corr_obj=sf, - name_map=jec_name_map, - cache=lazy_cache, - corrections=cumCorr, - ) + cumCorr = reduce(lambda x, y: y * x, corrections_list, ones).astype(dtype=numpy.float32) + + sf = self.corrections.get(lvl, None) + if sf is None: + raise ValueError(f"Correction {lvl} not found in self.corrections") + + ## This automatically apply the previous levels of correction, when needed + inputs = get_corr_inputs(jets=jets, corr_obj=sf, name_map=jec_name_map, cache=lazy_cache, corrections=cumCorr) correction = sf.evaluate(*inputs).astype(dtype=numpy.float32) corrections_list.append(correction) if total_correction is None: total_correction = numpy.ones_like(correction, dtype=numpy.float32) total_correction *= correction - jec_lvl_tag = "_jec_" + lvl + if self.jec_stack.savecorr: + jec_lvl_tag = "_jec_" + lvl - out_dict[f"jet_energy_correction_{lvl}"] = correction - init_pt_lvl = partial( - awkward.virtual, - operator.mul, - args=( - out_dict[f"jet_energy_correction_{lvl}"], - out_dict[self.name_map["ptRaw"]], - ), - cache=lazy_cache, - ) - init_mass_lvl = partial( - awkward.virtual, - operator.mul, - args=( - out_dict[f"jet_energy_correction_{lvl}"], - out_dict[self.name_map["massRaw"]], - ), - cache=lazy_cache, - ) - - out_dict[self.name_map["JetPt"] + f"_{lvl}"] = init_pt_lvl( - length=len(out), form=scalar_form - ) - out_dict[self.name_map["JetMass"] + f"_{lvl}"] = init_mass_lvl( - length=len(out), form=scalar_form - ) + out_dict[f"jet_energy_correction_{lvl}"] = correction + init_pt_lvl = partial( + awkward.virtual, + operator.mul, + args=(out_dict[f"jet_energy_correction_{lvl}"], out_dict[self.name_map["ptRaw"]]), + cache=lazy_cache, + ) + init_mass_lvl = partial( + awkward.virtual, + operator.mul, + args=(out_dict[f"jet_energy_correction_{lvl}"], out_dict[self.name_map["massRaw"]]), + cache=lazy_cache, + ) + out_dict[self.name_map["JetPt"] + f"_{lvl}"] = init_pt_lvl(length=len(out), form=scalar_form) + out_dict[self.name_map["JetMass"] + f"_{lvl}"] = init_mass_lvl(length=len(out), form=scalar_form) - out_dict[self.name_map["JetPt"] + jec_lvl_tag] = out_dict[ - self.name_map["JetPt"] + f"_{lvl}" - ] - out_dict[self.name_map["JetMass"] + jec_lvl_tag] = out_dict[ - self.name_map["JetMass"] + f"_{lvl}" - ] + out_dict[self.name_map["JetPt"] + jec_lvl_tag] = out_dict[self.name_map["JetPt"] + f"_{lvl}"] + out_dict[self.name_map["JetMass"] + jec_lvl_tag] = out_dict[self.name_map["JetMass"] + f"_{lvl}"] out_dict["jet_energy_correction"] = total_correction - # finally the lazy binding to the JEC + # Finally, the lazy binding to the JEC init_pt = partial( awkward.virtual, operator.mul, @@ -357,20 +288,17 @@ def build(self, jets, lazy_cache): ) out_dict[self.name_map["JetPt"]] = init_pt(length=len(out), form=scalar_form) - out_dict[self.name_map["JetMass"]] = init_mass( - length=len(out), form=scalar_form - ) + out_dict[self.name_map["JetMass"]] = init_mass(length=len(out), form=scalar_form) out_dict[self.name_map["JetPt"] + "_jec"] = out_dict[self.name_map["JetPt"]] out_dict[self.name_map["JetMass"] + "_jec"] = out_dict[self.name_map["JetMass"]] - # in jer we need to have a stash for the intermediate JEC products + has_jer = False if self.tool == "jecstack": - has_jer = False if self.jec_stack.jer is not None and self.jec_stack.jersf is not None: has_jer = True elif self.tool == "clib": - has_jer = True + has_jer = len(self.jec_stack.jer_names_clib) > 0 if has_jer: jer_name_map = dict(self.name_map) @@ -378,59 +306,46 @@ def build(self, jets, lazy_cache): jer_name_map["JetMass"] = jer_name_map["JetMass"] + "_jec" if self.tool == "jecstack": - jerargs = { + jer_args = { k: out_dict[jer_name_map[k]] for k in self.jec_stack.jer.signature } out_dict["jet_energy_resolution"] = self.jec_stack.jer.getResolution( - **jerargs, form=scalar_form, lazy_cache=lazy_cache + **jer_args, form=scalar_form, lazy_cache=lazy_cache ) - jersfargs = { + jersf_args = { k: out_dict[jer_name_map[k]] for k in self.jec_stack.jersf.signature } - out_dict[ - "jet_energy_resolution_scale_factor" - ] = self.jec_stack.jersf.getScaleFactor( - **jersfargs, form=_JERSF_FORM, lazy_cache=lazy_cache + out_dict["jet_energy_resolution_scale_factor"] = self.jec_stack.jersf.getScaleFactor( + **jersf_args, form=_JERSF_FORM, lazy_cache=lazy_cache ) + elif self.tool == "clib": - ## needed to attach to jets the JECs + # Prepare for clib-based corrections jer_out_parms = out.layout.parameters jer_out_parms["corrected"] = True jer_out = awkward.zip( - out_dict, - depth_limit=1, - parameters=jer_out_parms, - behavior=out.behavior, + out_dict, depth_limit=1, parameters=jer_out_parms, behavior=out.behavior ) jerjets = wrap(jer_out) - for jer_entry in self.jer_names: + for jer_entry in self.jec_stack.jer_names_clib: outtag = "jet_energy_resolution" jer_entry = jer_entry.replace("SF", "ScaleFactor") - sf = self.cset[jer_entry] - inputs = get_corr_inputs( - jets=jerjets, corr_obj=sf, name_map=jer_name_map - ) + sf = self.corrections[jer_entry] + inputs = get_corr_inputs(jets=jerjets, corr_obj=sf, name_map=jer_name_map) if "ScaleFactor" in jer_entry: outtag += "_scale_factor" - correction = awkward.Array( - [ - sf.evaluate(*inputs, "nom").astype(dtype=numpy.float32), - sf.evaluate(*inputs, "up").astype(dtype=numpy.float32), - sf.evaluate(*inputs, "down").astype( - dtype=numpy.float32 - ), - ] - ) - correction = awkward.concatenate( - [ - correction[0][:, numpy.newaxis], - correction[1][:, numpy.newaxis], - correction[2][:, numpy.newaxis], - ], - axis=1, - ) + correction = awkward.Array([ + sf.evaluate(*inputs, "nom").astype(dtype=numpy.float32), + sf.evaluate(*inputs, "up").astype(dtype=numpy.float32), + sf.evaluate(*inputs, "down").astype(dtype=numpy.float32), + ]) + correction = awkward.concatenate([ + correction[0][:, numpy.newaxis], + correction[1][:, numpy.newaxis], + correction[2][:, numpy.newaxis] + ], axis=1) else: correction = awkward.Array( sf.evaluate(*inputs).astype(dtype=numpy.float32), @@ -440,9 +355,8 @@ def build(self, jets, lazy_cache): del jerjets - seeds = numpy.array(out_dict[self.name_map["JetPt"] + "_orig"])[ - [0, -1] - ].view("i4") + # Gaussian smearing + seeds = numpy.array(out_dict[self.name_map["JetPt"] + "_orig"])[[0, -1]].view("i4") out_dict["jet_resolution_rand_gauss"] = awkward.virtual( rand_gauss, args=( @@ -460,60 +374,35 @@ def build(self, jets, lazy_cache): args=( 0, self.forceStochastic, - awkward.values_astype( - out_dict[jer_name_map["ptGenJet"]], numpy.float32 - ), - awkward.values_astype( - out_dict[jer_name_map["JetPt"]], numpy.float32 - ), - awkward.values_astype( - out_dict[jer_name_map["JetEta"]], numpy.float32 - ), - awkward.values_astype( - out_dict["jet_energy_resolution"], numpy.float32 - ), - awkward.values_astype( - out_dict["jet_resolution_rand_gauss"], numpy.float32 - ), - awkward.values_astype( - out_dict["jet_energy_resolution_scale_factor"], numpy.float32 - ), + awkward.values_astype(out_dict[jer_name_map["ptGenJet"]], numpy.float32), + awkward.values_astype(out_dict[jer_name_map["JetPt"]], numpy.float32), + awkward.values_astype(out_dict[jer_name_map["JetEta"]], numpy.float32), + awkward.values_astype(out_dict["jet_energy_resolution"], numpy.float32), + awkward.values_astype(out_dict["jet_resolution_rand_gauss"], numpy.float32), + awkward.values_astype(out_dict["jet_energy_resolution_scale_factor"], numpy.float32), ), cache=lazy_cache, ) - out_dict["jet_energy_resolution_correction"] = init_jerc( - length=len(out), form=scalar_form - ) + out_dict["jet_energy_resolution_correction"] = init_jerc(length=len(out), form=scalar_form) init_pt_jer = partial( awkward.virtual, operator.mul, - args=( - out_dict["jet_energy_resolution_correction"], - out_dict[jer_name_map["JetPt"]], - ), + args=(out_dict["jet_energy_resolution_correction"], out_dict[jer_name_map["JetPt"]]), cache=lazy_cache, ) init_mass_jer = partial( awkward.virtual, operator.mul, - args=( - out_dict["jet_energy_resolution_correction"], - out_dict[jer_name_map["JetMass"]], - ), + args=(out_dict["jet_energy_resolution_correction"], out_dict[jer_name_map["JetMass"]]), cache=lazy_cache, ) - out_dict[self.name_map["JetPt"]] = init_pt_jer( - length=len(out), form=scalar_form - ) - out_dict[self.name_map["JetMass"]] = init_mass_jer( - length=len(out), form=scalar_form - ) + + out_dict[self.name_map["JetPt"]] = init_pt_jer(length=len(out), form=scalar_form) + out_dict[self.name_map["JetMass"]] = init_mass_jer(length=len(out), form=scalar_form) out_dict[self.name_map["JetPt"] + "_jer"] = out_dict[self.name_map["JetPt"]] - out_dict[self.name_map["JetMass"] + "_jer"] = out_dict[ - self.name_map["JetMass"] - ] + out_dict[self.name_map["JetMass"] + "_jer"] = out_dict[self.name_map["JetMass"]] # JER systematics jerc_up = partial( @@ -522,24 +411,12 @@ def build(self, jets, lazy_cache): args=( 1, self.forceStochastic, - awkward.values_astype( - out_dict[jer_name_map["ptGenJet"]], numpy.float32 - ), - awkward.values_astype( - out_dict[jer_name_map["JetPt"]], numpy.float32 - ), - awkward.values_astype( - out_dict[jer_name_map["JetEta"]], numpy.float32 - ), - awkward.values_astype( - out_dict["jet_energy_resolution"], numpy.float32 - ), - awkward.values_astype( - out_dict["jet_resolution_rand_gauss"], numpy.float32 - ), - awkward.values_astype( - out_dict["jet_energy_resolution_scale_factor"], numpy.float32 - ), + awkward.values_astype(out_dict[jer_name_map["ptGenJet"]], numpy.float32), + awkward.values_astype(out_dict[jer_name_map["JetPt"]], numpy.float32), + awkward.values_astype(out_dict[jer_name_map["JetEta"]], numpy.float32), + awkward.values_astype(out_dict["jet_energy_resolution"], numpy.float32), + awkward.values_astype(out_dict["jet_resolution_rand_gauss"], numpy.float32), + awkward.values_astype(out_dict["jet_energy_resolution_scale_factor"], numpy.float32), ), cache=lazy_cache, ) @@ -576,24 +453,12 @@ def build(self, jets, lazy_cache): args=( 2, self.forceStochastic, - awkward.values_astype( - out_dict[jer_name_map["ptGenJet"]], numpy.float32 - ), - awkward.values_astype( - out_dict[jer_name_map["JetPt"]], numpy.float32 - ), - awkward.values_astype( - out_dict[jer_name_map["JetEta"]], numpy.float32 - ), - awkward.values_astype( - out_dict["jet_energy_resolution"], numpy.float32 - ), - awkward.values_astype( - out_dict["jet_resolution_rand_gauss"], numpy.float32 - ), - awkward.values_astype( - out_dict["jet_energy_resolution_scale_factor"], numpy.float32 - ), + awkward.values_astype(out_dict[jer_name_map["ptGenJet"]], numpy.float32), + awkward.values_astype(out_dict[jer_name_map["JetPt"]], numpy.float32), + awkward.values_astype(out_dict[jer_name_map["JetEta"]], numpy.float32), + awkward.values_astype(out_dict["jet_energy_resolution"], numpy.float32), + awkward.values_astype(out_dict["jet_resolution_rand_gauss"], numpy.float32), + awkward.values_astype(out_dict["jet_energy_resolution_scale_factor"], numpy.float32), ), cache=lazy_cache, ) @@ -629,71 +494,54 @@ def build(self, jets, lazy_cache): {"up": up, "down": down}, depth_limit=1, with_name="JetSystematic" ) - has_junc = False - if self.tool == "jecstack": - if self.jec_stack.junc is not None: - has_junc = True - elif self.tool == "clib": - has_junc = True + # Apply uncertainties (JES) + has_junc = self.jec_stack.junc is not None + if self.tool == "clib": + has_junc = len(self.jec_stack.jec_uncsources_clib) > 0 if has_junc: - juncnames = {} - juncnames.update(self.name_map) + junc_name_map = dict(self.name_map) if has_jer: - juncnames["JetPt"] = juncnames["JetPt"] + "_jer" - juncnames["JetMass"] = juncnames["JetMass"] + "_jer" + junc_name_map["JetPt"] = junc_name_map["JetPt"] + "_jer" + junc_name_map["JetMass"] = junc_name_map["JetMass"] + "_jer" else: - juncnames["JetPt"] = juncnames["JetPt"] + "_jec" - juncnames["JetMass"] = juncnames["JetMass"] + "_jec" + junc_name_map["JetPt"] = junc_name_map["JetPt"] + "_jec" + junc_name_map["JetMass"] = junc_name_map["JetMass"] + "_jec" if self.tool == "jecstack": - juncargs = { - k: out_dict[juncnames[k]] for k in self.jec_stack.junc.signature + junc_args = { + k: out_dict[junc_name_map[k]] for k in self.jec_stack.junc.signature } - juncs_list = list(self.jec_stack.junc.getUncertainty(**juncargs)) - juncs = self.jec_stack.junc.getUncertainty(**juncargs) + juncs = self.jec_stack.junc.getUncertainty(**junc_args) elif self.tool == "clib": junc_out_parms = out.layout.parameters junc_out_parms["corrected"] = True junc_out = awkward.zip( - out_dict, - depth_limit=1, - parameters=junc_out_parms, - behavior=out.behavior, + out_dict, depth_limit=1, parameters=junc_out_parms, behavior=out.behavior ) juncjets = wrap(junc_out) - self.junc_names = [ - junc_name.replace("Quad_", "").replace( - "UncertaintySources_AK4PFchs_", "" - ) - + "_AK4PFchs" - for junc_name in self.junc_names - ] - uncnames, uncvalues = [], [] - for junc_name in self.junc_names: - sf = self.cset[junc_name] - inputs = get_corr_inputs( - jets=juncjets, corr_obj=sf, name_map=juncnames - ) + for junc_name in self.jec_stack.jec_uncsources_clib: + sf = self.corrections[junc_name] + if sf is None: + raise ValueError(f"Correction {junc_name} not found in self.corrections") + + inputs = get_corr_inputs(jets=juncjets, corr_obj=sf, name_map=junc_name_map) unc = awkward.values_astype(sf.evaluate(*inputs), numpy.float32) central = awkward.ones_like(out_dict[self.name_map["JetPt"]]) unc_up = central + unc unc_down = central - unc uncnames.append(junc_name.split("_")[-2]) uncvalues.append([unc_up, unc_down]) - del juncjets # Combine the up and down values into pairs - combined_uncvalues = [] - for unc_up, unc_down in uncvalues: - combined = awkward.Array( - [[up, down] for up, down in zip(unc_up, unc_down)] - ) - combined_uncvalues.append(combined) + combined_uncvalues = [ + awkward.Array([[up, down] for up, down in zip(unc_up, unc_down)]) + for unc_up, unc_down in uncvalues + ] juncs = zip(uncnames, combined_uncvalues) @@ -705,9 +553,7 @@ def build_variation(unc, jetpt, jetpt_orig, jetmass, jetmass_orig, updown): var_dict[jetpt] = awkward.virtual( junc_smeared_val, args=( - awkward.to_numpy( - awkward.values_astype(unc, numpy.float32) - ), # this is needed for the clib variation + awkward.to_numpy(awkward.values_astype(unc, numpy.float32)), updown, jetpt_orig, ), @@ -718,9 +564,7 @@ def build_variation(unc, jetpt, jetpt_orig, jetmass, jetmass_orig, updown): var_dict[jetmass] = awkward.virtual( junc_smeared_val, args=( - awkward.to_numpy( - awkward.values_astype(unc, numpy.float32) - ), # this is needed for the clib variation + awkward.to_numpy(awkward.values_astype(unc, numpy.float32)), updown, jetmass_orig, ), @@ -728,34 +572,25 @@ def build_variation(unc, jetpt, jetpt_orig, jetmass, jetmass_orig, updown): form=scalar_form, cache=lazy_cache, ) - return awkward.zip( - var_dict, - depth_limit=1, - parameters=out.layout.parameters, - behavior=out.behavior, - ) + return awkward.zip(var_dict, depth_limit=1, parameters=out.layout.parameters, behavior=out.behavior) def build_variant(unc, jetpt, jetpt_orig, jetmass, jetmass_orig): up = build_variation(unc, jetpt, jetpt_orig, jetmass, jetmass_orig, 0) down = build_variation(unc, jetpt, jetpt_orig, jetmass, jetmass_orig, 1) - return awkward.zip( - {"up": up, "down": down}, depth_limit=1, with_name="JetSystematic" - ) + return awkward.zip({"up": up, "down": down}, depth_limit=1, with_name="JetSystematic") for name, func in juncs: out_dict[f"jet_energy_uncertainty_{name}"] = func out_dict[f"JES_{name}"] = build_variant( func, self.name_map["JetPt"], - out_dict[juncnames["JetPt"]], + out_dict[junc_name_map["JetPt"]], self.name_map["JetMass"], - out_dict[juncnames["JetMass"]], + out_dict[junc_name_map["JetMass"]], ) out_parms = out.layout.parameters out_parms["corrected"] = True - out = awkward.zip( - out_dict, depth_limit=1, parameters=out_parms, behavior=out.behavior - ) + out = awkward.zip(out_dict, depth_limit=1, parameters=out_parms, behavior=out.behavior) return wrap(out) diff --git a/coffea/jetmet_tools/JECStack.py b/coffea/jetmet_tools/JECStack.py index c5fefda83..b551f911d 100644 --- a/coffea/jetmet_tools/JECStack.py +++ b/coffea/jetmet_tools/JECStack.py @@ -1,105 +1,119 @@ +from dataclasses import dataclass, field +from typing import List, Dict, Optional from coffea.jetmet_tools.FactorizedJetCorrector import FactorizedJetCorrector, _levelre from coffea.jetmet_tools.JetResolution import JetResolution from coffea.jetmet_tools.JetResolutionScaleFactor import JetResolutionScaleFactor from coffea.jetmet_tools.JetCorrectionUncertainty import JetCorrectionUncertainty +import correctionlib as clib -_singletons = ["jer", "jersf"] -_nicenames = ["Jet Resolution Calculator", "Jet Resolution Scale Factor Calculator"] +@dataclass +class JECStack: + """Handles both JEC and clib cases with conditional attributes.""" + # Common fields for both scenarios + corrections: Dict[str, any] = field(default_factory=dict) + use_clib: bool = False # Set to True if useclib is needed + # Fields for the clib scenario (useclib=True) + jec_tag: Optional[str] = None + jec_levels: Optional[List[str]] = field(default_factory=list) + jer_tag: Optional[str] = None + jet_algo: Optional[str] = None + junc_types: Optional[List[str]] = field(default_factory=list) + json_path: Optional[str] = None + savecorr: bool = False -class JECStack(object): - def __init__(self, corrections, jec=None, junc=None, jer=None, jersf=None): - """ - corrections is a dict-like of function names and functions - we expect JEC names to be formatted as their filenames - jecs, etc. can be overridden by passing in the appropriate corrector class. - """ - self._jec = None - self._junc = None - self._jer = None - self._jersf = None + # Fields for the usejecstack scenario (useclib=False) + jec: Optional[FactorizedJetCorrector] = None + junc: Optional[JetCorrectionUncertainty] = None + jer: Optional[JetResolution] = None + jersf: Optional[JetResolutionScaleFactor] = None + + def __post_init__(self): + """Handle initialization based on use_clib flag.""" + if self.use_clib: + self._initialize_clib() + else: + self._initialize_jecstack() + + def _initialize_clib(self): + """Initialize the clib-based correction tools.""" + if not self.json_path: + raise ValueError("json_path is required for clib initialization.") + + # Load corrections directly from the JSON path + self.cset = clib.CorrectionSet.from_file(self.json_path) + + # Construct lists for jec, jer, and uncertainties + self.jec_names_clib = [f"{self.jec_tag}_{level}_{self.jet_algo}" for level in self.jec_levels] + self.jer_names_clib = [] + self.jec_uncsources_clib = [] + + if self.jer_tag is not None: + self.jer_names_clib = [ + f"{self.jer_tag}_ScaleFactor_{self.jet_algo}", + f"{self.jer_tag}_PtResolution_{self.jet_algo}" + ] + + if self.junc_types: + self.jec_uncsources_clib = [f"{self.jec_tag}_{junc_type}_{self.jet_algo}" for junc_type in self.junc_types] + + # Combine requested corrections + requested_corrections = self.jec_names_clib + self.jer_names_clib + self.jec_uncsources_clib + available_corrections = list(self.cset.keys()) + missing_corrections = [name for name in requested_corrections if name not in available_corrections] + + if missing_corrections: + raise ValueError( + f"\nMissing corrections in the CorrectionSet: {missing_corrections}. " + f"\n\nAvailable corrections are: {available_corrections}. " + f"\n\nRequested corrections are: {requested_corrections}" + ) + + # Store corrections directly in the JECStack for easy access + self.corrections = {name: self.cset[name] for name in requested_corrections} + + def _initialize_jecstack(self): + """Initialize the JECStack tools for the non-clib scenario.""" + assembled = self.assemble_corrections() + + if len(assembled["jec"]) > 0: + self.jec = FactorizedJetCorrector(**assembled["jec"]) + if len(assembled["junc"]) > 0: + self.junc = JetCorrectionUncertainty(**assembled["junc"]) + if len(assembled["jer"]) > 0: + self.jer = JetResolution(**assembled["jer"]) + if len(assembled["jersf"]) > 0: + self.jersf = JetResolutionScaleFactor(**assembled["jersf"]) + + if (self.jer is None) != (self.jersf is None): + raise ValueError("Cannot apply JER-SF without an input JER, and vice-versa!") + + def to_list(self): + """Convert to list for clib case.""" + return self.jec_names_clib + self.jer_names_clib + self.jec_uncsources_clib + [self.json_path, self.savecorr] + + def assemble_corrections(self): + """Assemble corrections for both scenarios.""" assembled = {"jec": {}, "junc": {}, "jer": {}, "jersf": {}} - for key in corrections.keys(): + + for key in self.corrections.keys(): if "Uncertainty" in key: - assembled["junc"][key] = corrections[key] - elif "SF" in key: - assembled["jersf"][key] = corrections[key] - elif "Resolution" in key and "SF" not in key: - assembled["jer"][key] = corrections[key] + assembled["junc"][key] = self.corrections[key] + elif ("ScaleFactor" in key or "SF" in key): + assembled["jersf"][key] = self.corrections[key] + elif "Resolution" in key and not ("ScaleFactor" in key or "SF" in key): + assembled["jer"][key] = self.corrections[key] elif len(_levelre.findall(key)) > 0: - assembled["jec"][key] = corrections[key] - - for corrtype, nname in zip(_singletons, _nicenames): - Noftype = len(assembled[corrtype]) - if Noftype > 1: - raise Exception( - f"JEC Stack has at most one {nname}, {Noftype} are present" - ) - - if jec is None: - if len(assembled["jec"]) == 0: - self._jec = None # allow for no JEC + assembled["jec"][key] = self.corrections[key] else: - self._jec = FactorizedJetCorrector( - **{name: corrections[name] for name in assembled["jec"]} - ) - else: - if isinstance(jec, FactorizedJetCorrector): - self._jec = jec - else: - raise Exception( - 'JECStack needs a FactorizedJetCorrector passed as "jec"' - + " got object of type {}".format(type(jec)) - ) - - if junc is None: - if len(assembled["junc"]) > 0: - self._junc = JetCorrectionUncertainty( - **{name: corrections[name] for name in assembled["junc"]} - ) - else: - if isinstance(junc, JetCorrectionUncertainty): - self._junc = junc - else: - raise Exception( - 'JECStack needs a JetCorrectionUncertainty passed as "junc"' - + " got object of type {}".format(type(junc)) - ) - - if jer is None: - if len(assembled["jer"]) > 0: - self._jer = JetResolution( - **{name: corrections[name] for name in assembled["jer"]} - ) - else: - if isinstance(jer, JetResolution): - self._jer = jer - else: - raise Exception( - '"jer" must be of type "JetResolution"' - + " got {}".format(type(jer)) - ) - - if jersf is None: - if len(assembled["jersf"]) > 0: - self._jersf = JetResolutionScaleFactor( - **{name: corrections[name] for name in assembled["jersf"]} - ) - else: - if isinstance(jer, JetResolutionScaleFactor): - self._jersf = jersf - else: - raise Exception( - '"jer" must be of type "JetResolutionScaleFactor"' - + " got {}".format(type(jer)) - ) + print(f"Unknown correction type for key: {key}") - if (self.jer is None) != (self.jersf is None): - raise Exception("Cannot apply JER-SF without an input JER, and vice-versa!") + return assembled @property def blank_name_map(self): + """Returns a blank name map for corrections.""" out = { "massRaw", "ptRaw", @@ -111,32 +125,16 @@ def blank_name_map(self): "UnClusteredEnergyDeltaX", "UnClusteredEnergyDeltaY", } - if self._jec is not None: - for name in self._jec.signature: + if self.jec is not None: + for name in self.jec.signature: out.add(name) - if self._junc is not None: - for name in self._junc.signature: + if self.junc is not None: + for name in self.junc.signature: out.add(name) - if self._jer is not None: - for name in self._jer.signature: + if self.jer is not None: + for name in self.jer.signature: out.add(name) - if self._jersf is not None: - for name in self._jersf.signature: + if self.jersf is not None: + for name in self.jersf.signature: out.add(name) return {name: None for name in out} - - @property - def jec(self): - return self._jec - - @property - def junc(self): - return self._junc - - @property - def jer(self): - return self._jer - - @property - def jersf(self): - return self._jersf