diff --git a/coffea/jetmet_tools/CorrectedJetsFactory.py b/coffea/jetmet_tools/CorrectedJetsFactory.py index 6e1059bac..fa3e58931 100644 --- a/coffea/jetmet_tools/CorrectedJetsFactory.py +++ b/coffea/jetmet_tools/CorrectedJetsFactory.py @@ -1,9 +1,9 @@ import awkward import numpy import warnings -from functools import partial +from functools import partial, reduce import operator - +from topcoffea.modules.JECStack import JECStack _stack_parts = ["jec", "junc", "jer", "jersf"] _MIN_JET_ENERGY = numpy.array(1e-2, dtype=numpy.float32) @@ -17,14 +17,11 @@ "primitive": "float32", } - -# we're gonna assume that the first record array we encounter is the flattened data def rewrap_recordarray(layout, depth, data): if isinstance(layout, awkward.layout.RecordArray): return lambda: data return None - def awkward_rewrap(arr, like_what, gfunc): behavior = awkward._util.behaviorof(like_what) func = partial(gfunc, data=arr.layout) @@ -32,7 +29,6 @@ def awkward_rewrap(arr, like_what, gfunc): newlayout = awkward._util.recursively_apply(layout, func) return awkward._util.wrap(newlayout, behavior=behavior) - def rand_gauss(item, randomstate): def getfunction(layout, depth): if isinstance(layout, awkward.layout.NumpyArray) or not isinstance( @@ -42,14 +38,12 @@ def getfunction(layout, depth): randomstate.normal(size=len(layout)).astype(numpy.float32) ) return None - out = awkward._util.recursively_apply( awkward.operations.convert.to_layout(item), getfunction ) assert out is not None return awkward._util.wrap(out, awkward._util.behaviorof(item)) - def jer_smear( variation, forceStochastic, @@ -61,10 +55,8 @@ def jer_smear( jet_energy_resolution_scale_factor, ): pt_gen = pt_gen if not forceStochastic else None - if not isinstance(jetPt, awkward.highlevel.Array): raise Exception("'jetPt' must be an awkward array of some kind!") - if forceStochastic: pt_gen = awkward.without_parameters(awkward.zeros_like(jetPt)) @@ -72,13 +64,11 @@ def jer_smear( jersf = jet_energy_resolution_scale_factor[:, variation] deltaPtRel = (jetPt - pt_gen) / jetPt doHybrid = (pt_gen > 0) & (numpy.abs(deltaPtRel) < 3 * jet_energy_resolution) - detSmear = 1 + (jersf - 1) * deltaPtRel stochSmear = 1 + numpy.sqrt(numpy.maximum(jersf**2 - 1, 0)) * jersmear min_jet_pt = _MIN_JET_ENERGY / numpy.cosh(etaJet) min_jet_pt_corr = min_jet_pt / jetPt - smearfact = awkward.where(doHybrid, detSmear, stochSmear) smearfact = awkward.where( (smearfact * jetPt) < min_jet_pt, min_jet_pt_corr, smearfact @@ -97,12 +87,47 @@ def getfunction(layout, depth): smearfact = awkward._util.wrap(smearfact, awkward._util.behaviorof(jetPt)) return smearfact +# Wrapper function to apply jec corrections +def rawvar_jec(jecval, rawvar, lazy_cache): + return awkward.virtual( + operator.mul, + args=(jecval, rawvar), + cache=lazy_cache, + ) + +def get_corr_inputs(jets, corr_obj, name_map, cache=None, corrections=None): + """ + Helper function for getting values of input variables + given a dictionary and a correction object. + """ + + if corrections is None: + input_values = [awkward.flatten(jets[name_map[inp.name]]) for inp in corr_obj.inputs if (inp.name != "systematic")] + else: + ## This is needed to propagate the previous level of corrections, before applying the next one + input_values = [] + for inp in corr_obj.inputs: + if inp.name == "systematic": + continue + elif inp.name == "JetPt": + rawvar = awkward.flatten(jets[name_map[inp.name]]) + init_input_value = partial(rawvar_jec, rawvar=rawvar, lazy_cache=cache) + input_value = init_input_value(jecval=corrections) + else: + input_value = awkward.flatten(jets[name_map[inp.name]]) + input_values.append(input_value) + return input_values + class CorrectedJetsFactory(object): def __init__(self, name_map, jec_stack): - # from PhysicsTools/PatUtils/interface/SmearedJetProducerT.h#L283 + if not isinstance(jec_stack, JECStack): + raise TypeError("jec_stack must be an instance of JECStack") + + self.tool = "clib" if jec_stack.use_clib else "jecstack" self.forceStochastic = False + # Handle name map for raw pt and mass if "ptRaw" not in name_map or name_map["ptRaw"] is None: warnings.warn( "There is no name mapping for ptRaw," @@ -114,23 +139,19 @@ def __init__(self, name_map, jec_stack): if "massRaw" not in name_map or name_map["massRaw"] is None: warnings.warn( "There is no name mapping for massRaw," - " CorrectedJets will assume that .mass is raw pt!" + " CorrectedJets will assume that .mass is raw mass!" ) - name_map["ptRaw"] = name_map["JetMass"] + "_raw" + name_map["massRaw"] = name_map["JetMass"] + "_raw" - total_signature = set() - for part in _stack_parts: - attr = getattr(jec_stack, part) - if attr is not None: - total_signature.update(attr.signature) + self.jec_stack = jec_stack + self.name_map = name_map - missing = total_signature - set(name_map.keys()) - if len(missing) > 0: - raise Exception( - f"Missing mapping of {missing} in name_map!" - + " Cannot evaluate jet corrections!" - + " Please supply mappings for these variables!" - ) + if self.jec_stack.use_clib: + # For clib scenario, load corrections from json_path + self.load_corrections_clib() + else: + # For non-clib scenario, use the provided corrections (e.g., JEC/JER) + self.load_corrections_jecstack() if "ptGenJet" not in name_map: warnings.warn( @@ -138,43 +159,51 @@ def __init__(self, name_map, jec_stack): ) self.forceStochastic = True - self.real_sig = [v for k, v in name_map.items()] - self.name_map = name_map - self.jec_stack = jec_stack + def load_corrections_clib(self): + """Load the corrections from correctionlib using the json_path in JECStack.""" + self.corrections = self.jec_stack.corrections - def uncertainties(self): - out = ["JER"] if self.jec_stack.jer is not None else [] - if self.jec_stack.junc is not None: - out.extend(["JES_{0}".format(unc) for unc in self.jec_stack.junc.levels]) - return out + def load_corrections_jecstack(self): + """Use the corrections provided in the JECStack for non-clib scenario.""" + self.corrections = self.jec_stack.corrections - def build(self, jets, lazy_cache): - if lazy_cache is None: + # Ensure all required inputs have mappings + total_signature = set() + for part in _stack_parts: + attr = getattr(self.jec_stack, part) + if attr is not None: + total_signature.update(attr.signature) + + missing = total_signature - set(self.name_map.keys()) + if len(missing) > 0: raise Exception( - "CorrectedJetsFactory requires a awkward-array cache to function correctly." + f"Missing mapping of {missing} in name_map!" + + " Cannot evaluate jet corrections!" + + " Please supply mappings for these variables!" ) + + def build(self, jets, lazy_cache): + if lazy_cache is None: + raise Exception("CorrectedJetsFactory requires an awkward-array cache to function correctly.") lazy_cache = awkward._util.MappingProxy.maybe_wrap(lazy_cache) if not isinstance(jets, awkward.highlevel.Array): raise Exception("'jets' must be an awkward > 1.0.0 array of some kind!") + + # THESE ARE THE ATTRIBUTES OF THE JET COLLECTION fields = awkward.fields(jets) if len(fields) == 0: - raise Exception( - "Empty record, please pass a jet object with at least {self.real_sig} defined!" - ) + raise Exception("Empty record, please pass a jet object with at least {self.real_sig} defined!") + out = awkward.flatten(jets) wrap = partial(awkward_rewrap, like_what=jets, gfunc=rewrap_recordarray) - scalar_form = awkward.without_parameters( - out[self.name_map["ptRaw"]] - ).layout.form + scalar_form = awkward.without_parameters(out[self.name_map["ptRaw"]]).layout.form in_dict = {field: out[field] for field in fields} out_dict = dict(in_dict) - # take care of nominal JEC (no JER if available) + # Add original values out_dict[self.name_map["JetPt"] + "_orig"] = out_dict[self.name_map["JetPt"]] - out_dict[self.name_map["JetMass"] + "_orig"] = out_dict[ - self.name_map["JetMass"] - ] + out_dict[self.name_map["JetMass"] + "_orig"] = out_dict[self.name_map["JetMass"]] if self.treat_pt_as_raw: out_dict[self.name_map["ptRaw"]] = out_dict[self.name_map["JetPt"]] out_dict[self.name_map["massRaw"]] = out_dict[self.name_map["JetMass"]] @@ -182,19 +211,66 @@ def build(self, jets, lazy_cache): jec_name_map = dict(self.name_map) jec_name_map["JetPt"] = jec_name_map["ptRaw"] jec_name_map["JetMass"] = jec_name_map["massRaw"] - if self.jec_stack.jec is not None: - jec_args = { - k: out_dict[jec_name_map[k]] for k in self.jec_stack.jec.signature - } - out_dict["jet_energy_correction"] = self.jec_stack.jec.getCorrection( - **jec_args, form=scalar_form, lazy_cache=lazy_cache - ) - else: - out_dict["jet_energy_correction"] = awkward.without_parameters( - awkward.ones_like(out_dict[self.name_map["JetPt"]]) - ) - # finally the lazy binding to the JEC + # Apply JEC corrections based on scenario + total_correction = None + if self.tool == "jecstack": + if self.jec_stack.jec is not None: + jec_args = { + k: out_dict[jec_name_map[k]] for k in self.jec_stack.jec.signature + } + total_correction = self.jec_stack.jec.getCorrection( + **jec_args, form=scalar_form, lazy_cache=lazy_cache + ) + else: + total_correction = awkward.ones_like(out_dict[self.name_map["JetPt"]]) + + elif self.tool == "clib": + corrections_list = [] + + for lvl in self.jec_stack.jec_names_clib: + cumCorr = None + if len(corrections_list) > 0: + ones = numpy.ones_like(corrections_list[-1], dtype=numpy.float32) + cumCorr = reduce(lambda x, y: y * x, corrections_list, ones).astype(dtype=numpy.float32) + + sf = self.corrections.get(lvl, None) + if sf is None: + raise ValueError(f"Correction {lvl} not found in self.corrections") + + ## This automatically apply the previous levels of correction, when needed + inputs = get_corr_inputs(jets=jets, corr_obj=sf, name_map=jec_name_map, cache=lazy_cache, corrections=cumCorr) + correction = sf.evaluate(*inputs).astype(dtype=numpy.float32) + corrections_list.append(correction) + if total_correction is None: + total_correction = numpy.ones_like(correction, dtype=numpy.float32) + total_correction *= correction + + if self.jec_stack.savecorr: + jec_lvl_tag = "_jec_" + lvl + + out_dict[f"jet_energy_correction_{lvl}"] = correction + init_pt_lvl = partial( + awkward.virtual, + operator.mul, + args=(out_dict[f"jet_energy_correction_{lvl}"], out_dict[self.name_map["ptRaw"]]), + cache=lazy_cache, + ) + init_mass_lvl = partial( + awkward.virtual, + operator.mul, + args=(out_dict[f"jet_energy_correction_{lvl}"], out_dict[self.name_map["massRaw"]]), + cache=lazy_cache, + ) + out_dict[self.name_map["JetPt"] + f"_{lvl}"] = init_pt_lvl(length=len(out), form=scalar_form) + out_dict[self.name_map["JetMass"] + f"_{lvl}"] = init_mass_lvl(length=len(out), form=scalar_form) + + out_dict[self.name_map["JetPt"] + jec_lvl_tag] = out_dict[self.name_map["JetPt"] + f"_{lvl}"] + out_dict[self.name_map["JetMass"] + jec_lvl_tag] = out_dict[self.name_map["JetMass"] + f"_{lvl}"] + + out_dict["jet_energy_correction"] = total_correction + + # Finally, the lazy binding to the JEC init_pt = partial( awkward.virtual, operator.mul, @@ -212,40 +288,75 @@ def build(self, jets, lazy_cache): ) out_dict[self.name_map["JetPt"]] = init_pt(length=len(out), form=scalar_form) - out_dict[self.name_map["JetMass"]] = init_mass( - length=len(out), form=scalar_form - ) + out_dict[self.name_map["JetMass"]] = init_mass(length=len(out), form=scalar_form) out_dict[self.name_map["JetPt"] + "_jec"] = out_dict[self.name_map["JetPt"]] out_dict[self.name_map["JetMass"] + "_jec"] = out_dict[self.name_map["JetMass"]] - # in jer we need to have a stash for the intermediate JEC products has_jer = False - if self.jec_stack.jer is not None and self.jec_stack.jersf is not None: - has_jer = True + if self.tool == "jecstack": + if self.jec_stack.jer is not None and self.jec_stack.jersf is not None: + has_jer = True + elif self.tool == "clib": + has_jer = len(self.jec_stack.jer_names_clib) > 0 + + if has_jer: jer_name_map = dict(self.name_map) jer_name_map["JetPt"] = jer_name_map["JetPt"] + "_jec" jer_name_map["JetMass"] = jer_name_map["JetMass"] + "_jec" - jerargs = { - k: out_dict[jer_name_map[k]] for k in self.jec_stack.jer.signature - } - out_dict["jet_energy_resolution"] = self.jec_stack.jer.getResolution( - **jerargs, form=scalar_form, lazy_cache=lazy_cache - ) + if self.tool == "jecstack": + jer_args = { + k: out_dict[jer_name_map[k]] for k in self.jec_stack.jer.signature + } + out_dict["jet_energy_resolution"] = self.jec_stack.jer.getResolution( + **jer_args, form=scalar_form, lazy_cache=lazy_cache + ) - jersfargs = { - k: out_dict[jer_name_map[k]] for k in self.jec_stack.jersf.signature - } - out_dict[ - "jet_energy_resolution_scale_factor" - ] = self.jec_stack.jersf.getScaleFactor( - **jersfargs, form=_JERSF_FORM, lazy_cache=lazy_cache - ) + jersf_args = { + k: out_dict[jer_name_map[k]] for k in self.jec_stack.jersf.signature + } + out_dict["jet_energy_resolution_scale_factor"] = self.jec_stack.jersf.getScaleFactor( + **jersf_args, form=_JERSF_FORM, lazy_cache=lazy_cache + ) - seeds = numpy.array(out_dict[self.name_map["JetPt"] + "_orig"])[ - [0, -1] - ].view("i4") + elif self.tool == "clib": + # Prepare for clib-based corrections + jer_out_parms = out.layout.parameters + jer_out_parms["corrected"] = True + jer_out = awkward.zip( + out_dict, depth_limit=1, parameters=jer_out_parms, behavior=out.behavior + ) + jerjets = wrap(jer_out) + + for jer_entry in self.jec_stack.jer_names_clib: + outtag = "jet_energy_resolution" + jer_entry = jer_entry.replace("SF", "ScaleFactor") + sf = self.corrections[jer_entry] + inputs = get_corr_inputs(jets=jerjets, corr_obj=sf, name_map=jer_name_map) + if "ScaleFactor" in jer_entry: + outtag += "_scale_factor" + correction = awkward.Array([ + sf.evaluate(*inputs, "nom").astype(dtype=numpy.float32), + sf.evaluate(*inputs, "up").astype(dtype=numpy.float32), + sf.evaluate(*inputs, "down").astype(dtype=numpy.float32), + ]) + correction = awkward.concatenate([ + correction[0][:, numpy.newaxis], + correction[1][:, numpy.newaxis], + correction[2][:, numpy.newaxis] + ], axis=1) + else: + correction = awkward.Array( + sf.evaluate(*inputs).astype(dtype=numpy.float32), + ) + + out_dict[outtag] = correction + + del jerjets + + # Gaussian smearing + seeds = numpy.array(out_dict[self.name_map["JetPt"] + "_orig"])[[0, -1]].view("i4") out_dict["jet_resolution_rand_gauss"] = awkward.virtual( rand_gauss, args=( @@ -263,48 +374,35 @@ def build(self, jets, lazy_cache): args=( 0, self.forceStochastic, - out_dict[jer_name_map["ptGenJet"]], - out_dict[jer_name_map["JetPt"]], - out_dict[jer_name_map["JetEta"]], - out_dict["jet_energy_resolution"], - out_dict["jet_resolution_rand_gauss"], - out_dict["jet_energy_resolution_scale_factor"], + awkward.values_astype(out_dict[jer_name_map["ptGenJet"]], numpy.float32), + awkward.values_astype(out_dict[jer_name_map["JetPt"]], numpy.float32), + awkward.values_astype(out_dict[jer_name_map["JetEta"]], numpy.float32), + awkward.values_astype(out_dict["jet_energy_resolution"], numpy.float32), + awkward.values_astype(out_dict["jet_resolution_rand_gauss"], numpy.float32), + awkward.values_astype(out_dict["jet_energy_resolution_scale_factor"], numpy.float32), ), cache=lazy_cache, ) - out_dict["jet_energy_resolution_correction"] = init_jerc( - length=len(out), form=scalar_form - ) + out_dict["jet_energy_resolution_correction"] = init_jerc(length=len(out), form=scalar_form) init_pt_jer = partial( awkward.virtual, operator.mul, - args=( - out_dict["jet_energy_resolution_correction"], - out_dict[jer_name_map["JetPt"]], - ), + args=(out_dict["jet_energy_resolution_correction"], out_dict[jer_name_map["JetPt"]]), cache=lazy_cache, ) init_mass_jer = partial( awkward.virtual, operator.mul, - args=( - out_dict["jet_energy_resolution_correction"], - out_dict[jer_name_map["JetMass"]], - ), + args=(out_dict["jet_energy_resolution_correction"], out_dict[jer_name_map["JetMass"]]), cache=lazy_cache, ) - out_dict[self.name_map["JetPt"]] = init_pt_jer( - length=len(out), form=scalar_form - ) - out_dict[self.name_map["JetMass"]] = init_mass_jer( - length=len(out), form=scalar_form - ) + + out_dict[self.name_map["JetPt"]] = init_pt_jer(length=len(out), form=scalar_form) + out_dict[self.name_map["JetMass"]] = init_mass_jer(length=len(out), form=scalar_form) out_dict[self.name_map["JetPt"] + "_jer"] = out_dict[self.name_map["JetPt"]] - out_dict[self.name_map["JetMass"] + "_jer"] = out_dict[ - self.name_map["JetMass"] - ] + out_dict[self.name_map["JetMass"] + "_jer"] = out_dict[self.name_map["JetMass"]] # JER systematics jerc_up = partial( @@ -313,12 +411,12 @@ def build(self, jets, lazy_cache): args=( 1, self.forceStochastic, - out_dict[jer_name_map["ptGenJet"]], - out_dict[jer_name_map["JetPt"]], - out_dict[jer_name_map["JetEta"]], - out_dict["jet_energy_resolution"], - out_dict["jet_resolution_rand_gauss"], - out_dict["jet_energy_resolution_scale_factor"], + awkward.values_astype(out_dict[jer_name_map["ptGenJet"]], numpy.float32), + awkward.values_astype(out_dict[jer_name_map["JetPt"]], numpy.float32), + awkward.values_astype(out_dict[jer_name_map["JetEta"]], numpy.float32), + awkward.values_astype(out_dict["jet_energy_resolution"], numpy.float32), + awkward.values_astype(out_dict["jet_resolution_rand_gauss"], numpy.float32), + awkward.values_astype(out_dict["jet_energy_resolution_scale_factor"], numpy.float32), ), cache=lazy_cache, ) @@ -355,12 +453,12 @@ def build(self, jets, lazy_cache): args=( 2, self.forceStochastic, - out_dict[jer_name_map["ptGenJet"]], - out_dict[jer_name_map["JetPt"]], - out_dict[jer_name_map["JetEta"]], - out_dict["jet_energy_resolution"], - out_dict["jet_resolution_rand_gauss"], - out_dict["jet_energy_resolution_scale_factor"], + awkward.values_astype(out_dict[jer_name_map["ptGenJet"]], numpy.float32), + awkward.values_astype(out_dict[jer_name_map["JetPt"]], numpy.float32), + awkward.values_astype(out_dict[jer_name_map["JetEta"]], numpy.float32), + awkward.values_astype(out_dict["jet_energy_resolution"], numpy.float32), + awkward.values_astype(out_dict["jet_resolution_rand_gauss"], numpy.float32), + awkward.values_astype(out_dict["jet_energy_resolution_scale_factor"], numpy.float32), ), cache=lazy_cache, ) @@ -396,19 +494,56 @@ def build(self, jets, lazy_cache): {"up": up, "down": down}, depth_limit=1, with_name="JetSystematic" ) - if self.jec_stack.junc is not None: - juncnames = {} - juncnames.update(self.name_map) + # Apply uncertainties (JES) + has_junc = self.jec_stack.junc is not None + if self.tool == "clib": + has_junc = len(self.jec_stack.jec_uncsources_clib) > 0 + + if has_junc: + junc_name_map = dict(self.name_map) if has_jer: - juncnames["JetPt"] = juncnames["JetPt"] + "_jer" - juncnames["JetMass"] = juncnames["JetMass"] + "_jer" + junc_name_map["JetPt"] = junc_name_map["JetPt"] + "_jer" + junc_name_map["JetMass"] = junc_name_map["JetMass"] + "_jer" else: - juncnames["JetPt"] = juncnames["JetPt"] + "_jec" - juncnames["JetMass"] = juncnames["JetMass"] + "_jec" - juncargs = { - k: out_dict[juncnames[k]] for k in self.jec_stack.junc.signature - } - juncs = self.jec_stack.junc.getUncertainty(**juncargs) + junc_name_map["JetPt"] = junc_name_map["JetPt"] + "_jec" + junc_name_map["JetMass"] = junc_name_map["JetMass"] + "_jec" + + if self.tool == "jecstack": + junc_args = { + k: out_dict[junc_name_map[k]] for k in self.jec_stack.junc.signature + } + juncs = self.jec_stack.junc.getUncertainty(**junc_args) + + elif self.tool == "clib": + junc_out_parms = out.layout.parameters + junc_out_parms["corrected"] = True + junc_out = awkward.zip( + out_dict, depth_limit=1, parameters=junc_out_parms, behavior=out.behavior + ) + juncjets = wrap(junc_out) + + uncnames, uncvalues = [], [] + for junc_name in self.jec_stack.jec_uncsources_clib: + sf = self.corrections[junc_name] + if sf is None: + raise ValueError(f"Correction {junc_name} not found in self.corrections") + + inputs = get_corr_inputs(jets=juncjets, corr_obj=sf, name_map=junc_name_map) + unc = awkward.values_astype(sf.evaluate(*inputs), numpy.float32) + central = awkward.ones_like(out_dict[self.name_map["JetPt"]]) + unc_up = central + unc + unc_down = central - unc + uncnames.append(junc_name.split("_")[-2]) + uncvalues.append([unc_up, unc_down]) + del juncjets + + # Combine the up and down values into pairs + combined_uncvalues = [ + awkward.Array([[up, down] for up, down in zip(unc_up, unc_down)]) + for unc_up, unc_down in uncvalues + ] + + juncs = zip(uncnames, combined_uncvalues) def junc_smeared_val(uncvals, up_down, variable): return awkward.materialized(uncvals[:, up_down] * variable) @@ -418,7 +553,7 @@ def build_variation(unc, jetpt, jetpt_orig, jetmass, jetmass_orig, updown): var_dict[jetpt] = awkward.virtual( junc_smeared_val, args=( - unc, + awkward.to_numpy(awkward.values_astype(unc, numpy.float32)), updown, jetpt_orig, ), @@ -429,7 +564,7 @@ def build_variation(unc, jetpt, jetpt_orig, jetmass, jetmass_orig, updown): var_dict[jetmass] = awkward.virtual( junc_smeared_val, args=( - unc, + awkward.to_numpy(awkward.values_astype(unc, numpy.float32)), updown, jetmass_orig, ), @@ -437,34 +572,25 @@ def build_variation(unc, jetpt, jetpt_orig, jetmass, jetmass_orig, updown): form=scalar_form, cache=lazy_cache, ) - return awkward.zip( - var_dict, - depth_limit=1, - parameters=out.layout.parameters, - behavior=out.behavior, - ) + return awkward.zip(var_dict, depth_limit=1, parameters=out.layout.parameters, behavior=out.behavior) def build_variant(unc, jetpt, jetpt_orig, jetmass, jetmass_orig): up = build_variation(unc, jetpt, jetpt_orig, jetmass, jetmass_orig, 0) down = build_variation(unc, jetpt, jetpt_orig, jetmass, jetmass_orig, 1) - return awkward.zip( - {"up": up, "down": down}, depth_limit=1, with_name="JetSystematic" - ) + return awkward.zip({"up": up, "down": down}, depth_limit=1, with_name="JetSystematic") for name, func in juncs: out_dict[f"jet_energy_uncertainty_{name}"] = func out_dict[f"JES_{name}"] = build_variant( func, self.name_map["JetPt"], - out_dict[juncnames["JetPt"]], + out_dict[junc_name_map["JetPt"]], self.name_map["JetMass"], - out_dict[juncnames["JetMass"]], + out_dict[junc_name_map["JetMass"]], ) out_parms = out.layout.parameters out_parms["corrected"] = True - out = awkward.zip( - out_dict, depth_limit=1, parameters=out_parms, behavior=out.behavior - ) + out = awkward.zip(out_dict, depth_limit=1, parameters=out_parms, behavior=out.behavior) return wrap(out) diff --git a/coffea/jetmet_tools/JECStack.py b/coffea/jetmet_tools/JECStack.py index c5fefda83..b551f911d 100644 --- a/coffea/jetmet_tools/JECStack.py +++ b/coffea/jetmet_tools/JECStack.py @@ -1,105 +1,119 @@ +from dataclasses import dataclass, field +from typing import List, Dict, Optional from coffea.jetmet_tools.FactorizedJetCorrector import FactorizedJetCorrector, _levelre from coffea.jetmet_tools.JetResolution import JetResolution from coffea.jetmet_tools.JetResolutionScaleFactor import JetResolutionScaleFactor from coffea.jetmet_tools.JetCorrectionUncertainty import JetCorrectionUncertainty +import correctionlib as clib -_singletons = ["jer", "jersf"] -_nicenames = ["Jet Resolution Calculator", "Jet Resolution Scale Factor Calculator"] +@dataclass +class JECStack: + """Handles both JEC and clib cases with conditional attributes.""" + # Common fields for both scenarios + corrections: Dict[str, any] = field(default_factory=dict) + use_clib: bool = False # Set to True if useclib is needed + # Fields for the clib scenario (useclib=True) + jec_tag: Optional[str] = None + jec_levels: Optional[List[str]] = field(default_factory=list) + jer_tag: Optional[str] = None + jet_algo: Optional[str] = None + junc_types: Optional[List[str]] = field(default_factory=list) + json_path: Optional[str] = None + savecorr: bool = False -class JECStack(object): - def __init__(self, corrections, jec=None, junc=None, jer=None, jersf=None): - """ - corrections is a dict-like of function names and functions - we expect JEC names to be formatted as their filenames - jecs, etc. can be overridden by passing in the appropriate corrector class. - """ - self._jec = None - self._junc = None - self._jer = None - self._jersf = None + # Fields for the usejecstack scenario (useclib=False) + jec: Optional[FactorizedJetCorrector] = None + junc: Optional[JetCorrectionUncertainty] = None + jer: Optional[JetResolution] = None + jersf: Optional[JetResolutionScaleFactor] = None + + def __post_init__(self): + """Handle initialization based on use_clib flag.""" + if self.use_clib: + self._initialize_clib() + else: + self._initialize_jecstack() + + def _initialize_clib(self): + """Initialize the clib-based correction tools.""" + if not self.json_path: + raise ValueError("json_path is required for clib initialization.") + + # Load corrections directly from the JSON path + self.cset = clib.CorrectionSet.from_file(self.json_path) + + # Construct lists for jec, jer, and uncertainties + self.jec_names_clib = [f"{self.jec_tag}_{level}_{self.jet_algo}" for level in self.jec_levels] + self.jer_names_clib = [] + self.jec_uncsources_clib = [] + + if self.jer_tag is not None: + self.jer_names_clib = [ + f"{self.jer_tag}_ScaleFactor_{self.jet_algo}", + f"{self.jer_tag}_PtResolution_{self.jet_algo}" + ] + + if self.junc_types: + self.jec_uncsources_clib = [f"{self.jec_tag}_{junc_type}_{self.jet_algo}" for junc_type in self.junc_types] + + # Combine requested corrections + requested_corrections = self.jec_names_clib + self.jer_names_clib + self.jec_uncsources_clib + available_corrections = list(self.cset.keys()) + missing_corrections = [name for name in requested_corrections if name not in available_corrections] + + if missing_corrections: + raise ValueError( + f"\nMissing corrections in the CorrectionSet: {missing_corrections}. " + f"\n\nAvailable corrections are: {available_corrections}. " + f"\n\nRequested corrections are: {requested_corrections}" + ) + + # Store corrections directly in the JECStack for easy access + self.corrections = {name: self.cset[name] for name in requested_corrections} + + def _initialize_jecstack(self): + """Initialize the JECStack tools for the non-clib scenario.""" + assembled = self.assemble_corrections() + + if len(assembled["jec"]) > 0: + self.jec = FactorizedJetCorrector(**assembled["jec"]) + if len(assembled["junc"]) > 0: + self.junc = JetCorrectionUncertainty(**assembled["junc"]) + if len(assembled["jer"]) > 0: + self.jer = JetResolution(**assembled["jer"]) + if len(assembled["jersf"]) > 0: + self.jersf = JetResolutionScaleFactor(**assembled["jersf"]) + + if (self.jer is None) != (self.jersf is None): + raise ValueError("Cannot apply JER-SF without an input JER, and vice-versa!") + + def to_list(self): + """Convert to list for clib case.""" + return self.jec_names_clib + self.jer_names_clib + self.jec_uncsources_clib + [self.json_path, self.savecorr] + + def assemble_corrections(self): + """Assemble corrections for both scenarios.""" assembled = {"jec": {}, "junc": {}, "jer": {}, "jersf": {}} - for key in corrections.keys(): + + for key in self.corrections.keys(): if "Uncertainty" in key: - assembled["junc"][key] = corrections[key] - elif "SF" in key: - assembled["jersf"][key] = corrections[key] - elif "Resolution" in key and "SF" not in key: - assembled["jer"][key] = corrections[key] + assembled["junc"][key] = self.corrections[key] + elif ("ScaleFactor" in key or "SF" in key): + assembled["jersf"][key] = self.corrections[key] + elif "Resolution" in key and not ("ScaleFactor" in key or "SF" in key): + assembled["jer"][key] = self.corrections[key] elif len(_levelre.findall(key)) > 0: - assembled["jec"][key] = corrections[key] - - for corrtype, nname in zip(_singletons, _nicenames): - Noftype = len(assembled[corrtype]) - if Noftype > 1: - raise Exception( - f"JEC Stack has at most one {nname}, {Noftype} are present" - ) - - if jec is None: - if len(assembled["jec"]) == 0: - self._jec = None # allow for no JEC + assembled["jec"][key] = self.corrections[key] else: - self._jec = FactorizedJetCorrector( - **{name: corrections[name] for name in assembled["jec"]} - ) - else: - if isinstance(jec, FactorizedJetCorrector): - self._jec = jec - else: - raise Exception( - 'JECStack needs a FactorizedJetCorrector passed as "jec"' - + " got object of type {}".format(type(jec)) - ) - - if junc is None: - if len(assembled["junc"]) > 0: - self._junc = JetCorrectionUncertainty( - **{name: corrections[name] for name in assembled["junc"]} - ) - else: - if isinstance(junc, JetCorrectionUncertainty): - self._junc = junc - else: - raise Exception( - 'JECStack needs a JetCorrectionUncertainty passed as "junc"' - + " got object of type {}".format(type(junc)) - ) - - if jer is None: - if len(assembled["jer"]) > 0: - self._jer = JetResolution( - **{name: corrections[name] for name in assembled["jer"]} - ) - else: - if isinstance(jer, JetResolution): - self._jer = jer - else: - raise Exception( - '"jer" must be of type "JetResolution"' - + " got {}".format(type(jer)) - ) - - if jersf is None: - if len(assembled["jersf"]) > 0: - self._jersf = JetResolutionScaleFactor( - **{name: corrections[name] for name in assembled["jersf"]} - ) - else: - if isinstance(jer, JetResolutionScaleFactor): - self._jersf = jersf - else: - raise Exception( - '"jer" must be of type "JetResolutionScaleFactor"' - + " got {}".format(type(jer)) - ) + print(f"Unknown correction type for key: {key}") - if (self.jer is None) != (self.jersf is None): - raise Exception("Cannot apply JER-SF without an input JER, and vice-versa!") + return assembled @property def blank_name_map(self): + """Returns a blank name map for corrections.""" out = { "massRaw", "ptRaw", @@ -111,32 +125,16 @@ def blank_name_map(self): "UnClusteredEnergyDeltaX", "UnClusteredEnergyDeltaY", } - if self._jec is not None: - for name in self._jec.signature: + if self.jec is not None: + for name in self.jec.signature: out.add(name) - if self._junc is not None: - for name in self._junc.signature: + if self.junc is not None: + for name in self.junc.signature: out.add(name) - if self._jer is not None: - for name in self._jer.signature: + if self.jer is not None: + for name in self.jer.signature: out.add(name) - if self._jersf is not None: - for name in self._jersf.signature: + if self.jersf is not None: + for name in self.jersf.signature: out.add(name) return {name: None for name in out} - - @property - def jec(self): - return self._jec - - @property - def junc(self): - return self._junc - - @property - def jer(self): - return self._jer - - @property - def jersf(self): - return self._jersf