From c1a5f08b96620f602e6de260a1a640c9bc08c646 Mon Sep 17 00:00:00 2001 From: Matthew Feickert Date: Tue, 10 Oct 2023 14:43:13 -0500 Subject: [PATCH] split names with with _ --- src/pyhf/experimental/modifiers.py | 36 ++++++++++++++++-------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/src/pyhf/experimental/modifiers.py b/src/pyhf/experimental/modifiers.py index ae5a273214..a16a2680aa 100644 --- a/src/pyhf/experimental/modifiers.py +++ b/src/pyhf/experimental/modifiers.py @@ -57,7 +57,7 @@ def func(d: Sequence[float]) -> Any: def make_builder( - funcname: str, deps: list[str], newparams: dict[str, dict[str, Sequence[float]]] + func_name: str, deps: list[str], new_params: dict[str, dict[str, Sequence[float]]] ) -> BaseBuilder: class _builder(BaseBuilder): is_shared = False @@ -83,13 +83,13 @@ def append(self, key, channel, sample, thismod, defined_samp): moddata = self.collect(thismod, nom) self.builder_data[key][sample]["data"]["mask"] += moddata["mask"] if thismod: - if thismod["name"] != funcname: + if thismod["name"] != func_name: print(thismod) self.builder_data["funcs"].setdefault( thismod["name"], thismod["data"]["expr"] ) self.required_parsets = { - k: [_allocate_new_param(v)] for k, v in newparams.items() + k: [_allocate_new_param(v)] for k, v in new_params.items() } def finalize(self): @@ -99,10 +99,10 @@ def finalize(self): def make_applier( - funcname: str, deps: list[str], newparams: dict[str, dict[str, Sequence[float]]] + func_name: str, deps: list[str], new_params: dict[str, dict[str, Sequence[float]]] ) -> BaseApplier: class _applier(BaseApplier): - name = funcname + name = func_name op_code = "multiplication" def __init__(self, modifiers, pdfconfig, builder_data, batch_size=None): @@ -120,7 +120,7 @@ def __init__(self, modifiers, pdfconfig, builder_data, batch_size=None): self.param_viewer = ParamViewer( parfield_shape, pdfconfig.par_map, pars_for_applier ) - self._custommod_mask = [ + self._custom_mod_mask = [ [[builder_data[modname][s]["data"]["mask"]] for s in pdfconfig.samples] for modname in _modnames ] @@ -131,14 +131,14 @@ def _precompute(self): tensorlib, _ = get_backend() if not self.param_viewer.index_selection: return - self.custommod_mask = tensorlib.tile( - tensorlib.astensor(self._custommod_mask), + self.custom_mod_mask = tensorlib.tile( + tensorlib.astensor(self._custom_mod_mask), (1, 1, self.batch_size or 1, 1), ) - self.custommod_mask_bool = tensorlib.astensor( - self.custommod_mask, dtype="bool" + self.custom_mod_mask_bool = tensorlib.astensor( + self.custom_mod_mask, dtype="bool" ) - self.custommod_default = tensorlib.ones(self.custommod_mask.shape) + self.custom_mod_default = tensorlib.ones(self.custom_mod_mask.shape) def apply(self, pars): """ @@ -152,16 +152,18 @@ def apply(self, pars): deps = self.param_viewer.get(pars) print("deps", deps.shape) results = tensorlib.astensor([f(deps) for f in self.funcs]) - results = tensorlib.einsum("msab,m->msab", self.custommod_mask, results) + results = tensorlib.einsum( + "msab,m->msab", self.custom_mod_mask, results + ) else: deps = self.param_viewer.get(pars) print("deps", deps.shape) results = tensorlib.astensor([f(deps) for f in self.funcs]) results = tensorlib.einsum( - "msab,ma->msab", self.custommod_mask, results + "msab,ma->msab", self.custom_mod_mask, results ) results = tensorlib.where( - self.custommod_mask_bool, results, self.custommod_default + self.custom_mod_mask_bool, results, self.custom_mod_default ) return results @@ -169,10 +171,10 @@ def apply(self, pars): def add_custom_modifier( - funcname: str, deps: list[str], newparams: dict[str, dict[str, Sequence[float]]] + func_name: str, deps: list[str], new_params: dict[str, dict[str, Sequence[float]]] ) -> dict[str, tuple[BaseBuilder, BaseApplier]]: - _builder = make_builder(funcname, deps, newparams) - _applier = make_applier(funcname, deps, newparams) + _builder = make_builder(func_name, deps, new_params) + _applier = make_applier(func_name, deps, new_params) modifier_set = {_applier.name: (_builder, _applier)} modifier_set.update(**pyhf.modifiers.histfactory_set)