diff --git a/appletree/component.py b/appletree/component.py index 2396d76d..be273a39 100644 --- a/appletree/component.py +++ b/appletree/component.py @@ -1,3 +1,4 @@ +import os from warnings import warn from functools import partial from typing import Tuple, List, Dict, Optional, Union, Set @@ -5,12 +6,13 @@ import numpy as np import pandas as pd from jax import numpy as jnp +from strax import deterministic_hash from appletree import utils from appletree.config import OMITTED from appletree.plugin import Plugin from appletree.share import _cached_configs, _cached_functions, set_global_config -from appletree.utils import exporter, load_data +from appletree.utils import exporter, get_file_path, load_data, calculate_sha256 from appletree.hist import make_hist_mesh_grid, make_hist_irreg_bin_1d, make_hist_irreg_bin_2d export, __all__ = exporter() @@ -53,13 +55,6 @@ def __init__(self, name: Optional[str] = None, llh_name: Optional[str] = None, * if "bins" in kwargs.keys() and "bins_type" in kwargs.keys(): self.set_binning(**kwargs) - if self.bins_type != "meshgrid" and self.add_eps_to_hist: - warn( - "It is empirically dangerous to have add_eps_to_hist==True, " - "when your bins_type is not meshgrid! It may lead to very bad fit with " - "lots of eff==0." - ) - def set_binning(self, **kwargs): """Set binning of component.""" if "bins" not in kwargs.keys() or "bins_type" not in kwargs.keys(): @@ -158,7 +153,8 @@ def implement_binning(self, mc, eff): raise ValueError(f"Unsupported bins_type {self.bins_type}!") if self.add_eps_to_hist: # as an uncertainty to prevent blowing up - hist = jnp.clip(hist, 1.0, jnp.inf) + # uncertainty = 1e-10 + jnp.mean(eff) + hist = jnp.clip(hist, 1e-10 + jnp.mean(eff), jnp.inf) return hist def get_normalization(self, hist, parameters, batch_size=None): @@ -187,6 +183,14 @@ def compile(self): """Hook for compiling simulation code.""" pass + @property + def lineage(self): + raise NotImplementedError + + @property + def lineage_hash(self): + return deterministic_hash(self.lineage) + @export class ComponentSim(Component): @@ -341,6 +345,8 @@ def flush_source_code( if isinstance(data_names, str): data_names = [data_names] + instances = set() + code = "" indent = " " * 4 @@ -356,6 +362,7 @@ def flush_source_code( for work in self.worksheet: plugin = work[0] instance = plugin + "_" + self.name + instances.add(instance) code += f"{instance} = {plugin}('{self.llh_name}')\n" # define functions @@ -375,6 +382,7 @@ def flush_source_code( code += f"{indent}return {output}\n" self.code = code + self.instances = instances if func_name in _cached_functions[self.llh_name].keys(): warning = f"Function name {func_name} is already cached. " @@ -480,11 +488,6 @@ def save_code(self, file_path): with open(file_path, "w") as f: f.write(self.code) - def lineage(self, data_name: str = "cs2"): - """Return lineage of plugins.""" - assert isinstance(data_name, str) - pass - def set_config(self, configs): """Set new global configuration options. @@ -565,6 +568,28 @@ def new_component(self, llh_name: Optional[str] = None, pass_binning: bool = Tru ) return component + @property + def lineage(self): + return { + **{ + "rate_name": self.rate_name, + "norm_type": self.norm_type, + "bins": ( + tuple(b.tolist() for b in self.bins) if self.bins is not None else self.bins + ), + "bins_type": self.bins_type, + "code": self.code, + }, + **{ + "instances": dict( + zip( + self.instances, + [_cached_functions[self.llh_name][p].lineage for p in self.instances], + ) + ) + }, + } + @export class ComponentFixed(Component): @@ -608,6 +633,21 @@ def simulate_weighted_data(self, parameters, *args, **kwargs): return result + @property + def lineage(self): + return { + "rate_name": self.rate_name, + "norm_type": self.norm_type, + "bins": tuple(b.tolist() for b in self.bins) if self.bins is not None else self.bins, + "bins_type": self.bins_type, + "file_path": ( + os.path.basename(self._file_name) + if not utils.FULL_PATH_LINEAGE + else get_file_path(self._file_name) + ), + "sha256": calculate_sha256(get_file_path(self._file_name)), + } + @export def add_component_extensions(module1, module2, force=False): diff --git a/appletree/config.py b/appletree/config.py index 856ea01f..b058b8f7 100644 --- a/appletree/config.py +++ b/appletree/config.py @@ -1,3 +1,4 @@ +import os from typing import Optional, Union, Any from immutabledict import immutabledict @@ -5,14 +6,17 @@ from warnings import warn import numpy as np +from strax import deterministic_hash +from appletree import utils from appletree.share import _cached_configs from appletree.utils import ( exporter, load_json, get_file_path, integrate_midpoint, - cum_integrate_midpoint, + cumulative_integrate_midpoint, + calculate_sha256, ) from appletree import interpolation from appletree.interpolation import FLOAT_POS_MIN, FLOAT_POS_MAX @@ -108,6 +112,14 @@ def build(self, llh_name: Optional[str] = None): def required_parameter(self, llh_name=None): return None + @property + def lineage(self): + raise NotImplementedError + + @property + def lineage_hash(self): + return deterministic_hash(self.lineage) + @export class Constant(Config): @@ -137,6 +149,13 @@ def build(self, llh_name: Optional[str] = None): else: self.value = value + @property + def lineage(self): + return { + "llh_name": self.llh_name, + "value": self.value, + } + @export class Map(Config): @@ -317,10 +336,23 @@ def log_pos(self, pos): def pdf_to_cdf(self, x, pdf): """Convert pdf map to cdf map.""" norm = integrate_midpoint(x, pdf) - x, cdf = cum_integrate_midpoint(x, pdf) + x, cdf = cumulative_integrate_midpoint(x, pdf) cdf /= norm return x, cdf + @property + def lineage(self): + return { + "llh_name": self.llh_name, + "method": self.method, + "file_path": ( + os.path.basename(self.file_path) + if not utils.FULL_PATH_LINEAGE + else get_file_path(self.file_path) + ), + "sha256": calculate_sha256(get_file_path(self.file_path)), + } + @export class SigmaMap(Config): @@ -472,6 +504,16 @@ def apply(self, pos, parameters): add = jnp.where(sigma > 0, add_pos, add_neg) return median + add + @property + def lineage(self): + return { + "llh_name": self.llh_name, + "method": self.method, + "median": self.median.lineage, + "lower": self.lower.lineage, + "upper": self.upper.lineage, + } + @export class ConstantSet(Config): @@ -507,7 +549,7 @@ def build(self, llh_name: Optional[str] = None): self._sanity_check() self.set_volume = len(self.value[1][0]) - self.value = {k: jnp.array(v) for k, v in zip(*self.value)} + self.value = {k: np.array(v) for k, v in zip(*self.value)} def _sanity_check(self): """Check if parameter set lengths are same.""" @@ -518,3 +560,10 @@ def _sanity_check(self): volumes = [len(v) for v in self.value[1]] mesg = "Parameter set lengths should be the same" assert np.all(np.isclose(volumes, volumes[0])), mesg + + @property + def lineage(self): + return { + "llh_name": self.llh_name, + "value": self.value, + } diff --git a/appletree/context.py b/appletree/context.py index 4c021746..2e346d1d 100644 --- a/appletree/context.py +++ b/appletree/context.py @@ -7,13 +7,14 @@ from typing import Set, Optional import numpy as np -import emcee import h5py +import emcee +from strax import deterministic_hash import appletree as apt from appletree import randgen from appletree import Parameter -from appletree.utils import load_json, get_file_path +from appletree.utils import JSON_OPTIONS, load_json, get_file_path from appletree.share import _cached_configs, set_global_config os.environ["OMP_NUM_THREADS"] = "1" @@ -302,19 +303,23 @@ def _dump_meta(self, batch_size, metadata=None): if self.backend_h5 is not None: name = self.sampler.backend.name with h5py.File(self.backend_h5, "r+") as opt: - opt[name].attrs["metadata"] = json.dumps(metadata) + opt[name].attrs["metadata"] = json.dumps(metadata, **JSON_OPTIONS) # parameters prior configuration - opt[name].attrs["par_config"] = json.dumps(self.par_manager.par_config) + opt[name].attrs["par_config"] = json.dumps( + self.par_manager.par_config, **JSON_OPTIONS + ) # max posterior parameters - opt[name].attrs["post_parameters"] = json.dumps(self.get_post_parameters()) + opt[name].attrs["post_parameters"] = json.dumps( + self.get_post_parameters(), **JSON_OPTIONS + ) # the order of parameters saved in backend opt[name].attrs["parameter_fit"] = self.par_manager.parameter_fit # instructions - opt[name].attrs["instruct"] = json.dumps(self.instruct) + opt[name].attrs["instruct"] = json.dumps(self.instruct, **JSON_OPTIONS) # configs - opt[name].attrs["config"] = json.dumps(self.config) + opt[name].attrs["config"] = json.dumps(self.config, **JSON_OPTIONS) # configurations, maybe users will manually add some maps - opt[name].attrs["_cached_configs"] = json.dumps(_cached_configs) + opt[name].attrs["_cached_configs"] = json.dumps(_cached_configs, **JSON_OPTIONS) # batch size opt[name].attrs["batch_size"] = batch_size @@ -390,7 +395,21 @@ def update_parameter_config(self, likelihoods): self.par_config.pop(p) return needed_parameters - def lineage(self, data_name: str = "cs2"): - """Return lineage of plugins.""" - assert isinstance(data_name, str) - pass + @property + def lineage(self): + return { + **self.instruct, + **{"par_config": self.par_config}, + **{ + "likelihoods": dict( + zip( + self.likelihoods.keys(), + [v.lineage for v in self.likelihoods.values()], + ) + ) + }, + } + + @property + def lineage_hash(self): + return deterministic_hash(self.lineage) diff --git a/appletree/likelihood.py b/appletree/likelihood.py index 8c1b74f6..df4f47a1 100644 --- a/appletree/likelihood.py +++ b/appletree/likelihood.py @@ -1,16 +1,23 @@ +import os from warnings import warn from typing import Type, Dict, Set, Optional, cast import inspect from copy import deepcopy import numpy as np -from jax import numpy as jnp - from scipy.stats import norm +from strax import deterministic_hash +from appletree import utils from appletree import randgen from appletree.hist import make_hist_mesh_grid, make_hist_irreg_bin_1d, make_hist_irreg_bin_2d -from appletree.utils import load_data, get_equiprob_bins_1d, get_equiprob_bins_2d +from appletree.utils import ( + get_file_path, + load_data, + get_equiprob_bins_1d, + get_equiprob_bins_2d, + calculate_sha256, +) from appletree.component import Component, ComponentSim, ComponentFixed from appletree.randgen import TwoHalfNorm, BandTwoHalfNorm @@ -108,25 +115,25 @@ def set_binning(self, config): self.component_bins_type = "meshgrid" if self._dim == 1: if isinstance(self._bins[0], int): - bins = jnp.linspace(*config["clip"], self._bins[0] + 1) + bins = np.linspace(*config["clip"], self._bins[0] + 1) else: - bins = jnp.array(self._bins[0]) + bins = np.array(self._bins[0]) if "x_clip" in config: warning = "x_clip is ignored when bins_type is meshgrid and bins is not int" warn(warning) self._bins = (bins,) elif self._dim == 2: if isinstance(self._bins[0], int): - x_bins = jnp.linspace(*config["x_clip"], self._bins[0] + 1) + x_bins = np.linspace(*config["x_clip"], self._bins[0] + 1) else: - x_bins = jnp.array(self._bins[0]) + x_bins = np.array(self._bins[0]) if "x_clip" in config: warning = "x_clip is ignored when bins_type is meshgrid and bins is not int" warn(warning) if isinstance(self._bins[1], int): - y_bins = jnp.linspace(*config["y_clip"], self._bins[1] + 1) + y_bins = np.linspace(*config["y_clip"], self._bins[1] + 1) else: - y_bins = jnp.array(self._bins[1]) + y_bins = np.array(self._bins[1]) if "y_clip" in config: warning = "y_clip is ignored when bins_type is meshgrid and bins is not int" warn(warning) @@ -134,7 +141,7 @@ def set_binning(self, config): self.data_hist = make_hist_mesh_grid( self.data, bins=self._bins, - weights=jnp.ones(len(self.data)), + weights=np.ones(len(self.data)), ) elif self._bins_type == "equiprob": if not all([isinstance(b, int) for b in self._bins]): @@ -144,13 +151,13 @@ def set_binning(self, config): self.data[:, 0], self._bins[0], clip=config["clip"], - which_np=jnp, + which_np=np, ) - self._bins = [self._bins] + self._bins = (self._bins,) self.data_hist = make_hist_irreg_bin_1d( self.data[:, 0], bins=self._bins[0], - weights=jnp.ones(len(self.data)), + weights=np.ones(len(self.data)), ) elif self._dim == 2: self._bins = get_equiprob_bins_2d( @@ -158,38 +165,38 @@ def set_binning(self, config): self._bins, x_clip=config["x_clip"], y_clip=config["y_clip"], - which_np=jnp, + which_np=np, ) self.data_hist = make_hist_irreg_bin_2d( self.data, bins_x=self._bins[0], bins_y=self._bins[1], - weights=jnp.ones(len(self.data)), + weights=np.ones(len(self.data)), ) self.component_bins_type = "irreg" elif self._bins_type == "irreg": - self._bins = [jnp.array(b) for b in self._bins] + self._bins = [np.array(b) for b in self._bins] if self._dim == 1: if isinstance(self._bins[0], int): - bins = jnp.linspace(*config["clip"], self._bins[0] + 1) + bins = np.linspace(*config["clip"], self._bins[0] + 1) else: - bins = jnp.array(self._bins[0]) + bins = np.array(self._bins[0]) if "x_clip" in config: warning = "x_clip is ignored when bins_type is meshgrid and bins is not int" warn(warning) self._bins = (bins,) elif self._dim == 2: if isinstance(self._bins[0], int): - x_bins = jnp.linspace(*config["x_clip"], self._bins[0] + 1) + x_bins = np.linspace(*config["x_clip"], self._bins[0] + 1) else: - x_bins = jnp.array(self._bins[0]) + x_bins = np.array(self._bins[0]) if "x_clip" in config: warning = "x_clip is ignored when bins_type is meshgrid and bins is not int" warn(warning) if isinstance(self._bins[1], int): - y_bins = jnp.linspace(*config["y_clip"], self._bins[1] + 1) + y_bins = np.linspace(*config["y_clip"], self._bins[1] + 1) else: - y_bins = jnp.array(self._bins[1]) + y_bins = np.array(self._bins[1]) if "y_clip" in config: warning = "y_clip is ignored when bins_type is meshgrid and bins is not int" warn(warning) @@ -212,17 +219,18 @@ def set_binning(self, config): self.data_hist = make_hist_irreg_bin_1d( self.data[:, 0], bins=self._bins[0], - weights=jnp.ones(len(self.data)), + weights=np.ones(len(self.data)), ) elif self._dim == 2: self.data_hist = make_hist_irreg_bin_2d( self.data, bins_x=self._bins[0], bins_y=self._bins[1], - weights=jnp.ones(len(self.data)), + weights=np.ones(len(self.data)), ) else: raise ValueError("'bins_type' should either be meshgrid, equiprob or irreg") + assert isinstance(self._bins, tuple), "bins should be tuple after setting binning!" def register_component( self, component_cls: Type[Component], component_name: str, file_name: Optional[str] = None @@ -293,7 +301,7 @@ def _simulate_model_hist(self, key, batch_size, parameters): parameters: dict of parameters used in simulation. """ - hist = jnp.zeros_like(self.data_hist) + hist = np.zeros_like(self.data_hist) for component_name, component in self.components.items(): if isinstance(component, ComponentSim): key, _hist = component.simulate_hist(key, batch_size, parameters) @@ -338,7 +346,7 @@ def get_log_likelihood(self, key, batch_size, parameters): """ key, model_hist = self._simulate_model_hist(key, batch_size, parameters) # Poisson likelihood - llh = jnp.sum(self.data_hist * jnp.log(model_hist) - model_hist) + llh = np.sum(self.data_hist * np.log(model_hist) - model_hist) llh = float(llh) if np.isnan(llh): llh = -np.inf @@ -404,6 +412,32 @@ def print_likelihood_summary(self, indent: str = " " * 4, short: bool = True): print("-" * 40) + @property + def lineage(self): + return { + **{ + "config": self._config, + "file_path": ( + os.path.basename(self._data_file_name) + if not utils.FULL_PATH_LINEAGE + else get_file_path(self._data_file_name) + ), + "sha256": calculate_sha256(get_file_path(self._data_file_name)), + }, + **{ + "components": dict( + zip( + self.components.keys(), + [v.lineage for v in self.components.values()], + ) + ) + }, + } + + @property + def lineage_hash(self): + return deterministic_hash(self.lineage) + class LikelihoodLit(Likelihood): """Using literature constraint to build LLH. @@ -553,3 +587,19 @@ def print_likelihood_summary(self, indent: str = " " * 4, short: bool = True): print() print("-" * 40) + + @property + def lineage(self): + return { + **{ + "config": self._config, + }, + **{ + "components": dict( + zip( + self.components.keys(), + [v.lineage_hash for v in self.components.values()], + ) + ) + }, + } diff --git a/appletree/plugin.py b/appletree/plugin.py index f796d90a..3d55941c 100644 --- a/appletree/plugin.py +++ b/appletree/plugin.py @@ -3,6 +3,7 @@ from typing import List, Tuple, Optional from immutabledict import immutabledict +from strax import deterministic_hash from appletree import utils from appletree.utils import exporter @@ -90,6 +91,28 @@ def sanity_check(self): mesg += f"Plugin {self.__class__.__name__} is insane, check dependency!" raise ValueError(mesg) + @property + def lineage(self): + return { + **{ + "depends_on": self.depends_on, + "provides": self.provides, + "parameters": self.parameters, + }, + **{ + "takes_config": dict( + zip( + self.takes_config.keys(), + [v.lineage for v in self.takes_config.values()], + ) + ) + }, + } + + @property + def lineage_hash(self): + return deterministic_hash(self.lineage) + @export def add_plugin_extensions(module1, module2, force=False): diff --git a/appletree/share.py b/appletree/share.py index 5d4e2901..e2dd33c9 100644 --- a/appletree/share.py +++ b/appletree/share.py @@ -15,7 +15,7 @@ def __setitem__(self, key, value): return super().__setitem__(key, value) def __repr__(self): - return json.dumps(self, indent=4) + return json.dumps(self, sort_keys=True, indent=4) def __str__(self): return self.__repr__() diff --git a/appletree/utils.py b/appletree/utils.py index 34c3b62c..1ddb98b4 100644 --- a/appletree/utils.py +++ b/appletree/utils.py @@ -1,6 +1,7 @@ import os import json from warnings import warn +import hashlib import importlib_resources from time import time @@ -27,6 +28,10 @@ SKIP_MONGO_DB = True +JSON_OPTIONS = dict(sort_keys=True, indent=4) + +FULL_PATH_LINEAGE = False + def exporter(export_self=False): """Export utility modified from https://stackoverflow.com/a/41895194 @@ -203,6 +208,23 @@ def get_file_path(fname): raise RuntimeError(f"Can not find {fname}, please check your file system") +@export +def calculate_sha256(file_path): + """Get sha256 hash of the file.""" + sha256_hash = hashlib.sha256() + with open(file_path, "rb") as f: + for byte_block in iter(lambda: f.read(4096), b""): + sha256_hash.update(byte_block) + return sha256_hash.hexdigest() + + +@export +def dump_lineage(file_path, entity): + """Dump lineage of whatever level into .json file.""" + with open(file_path, "w") as f: + f.write(json.dumps(entity.lineage, **JSON_OPTIONS)) + + @export def timeit(indent=""): """Use timeit as a decorator. @@ -585,11 +607,11 @@ def integrate_midpoint(x, y): y: 1D array-like, with the same length as x. """ - _, res = cum_integrate_midpoint(x, y) + _, res = cumulative_integrate_midpoint(x, y) return res[-1] -def cum_integrate_midpoint(x, y): +def cumulative_integrate_midpoint(x, y): """Calculate the cumulative integral using midpoint method. Args: diff --git a/requirements.txt b/requirements.txt index c3cd7668..e748637c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,5 @@ numpyro pandas scikit-learn scipy +strax straxen diff --git a/tests/test_component.py b/tests/test_component.py index 76ff5255..70ed06b0 100644 --- a/tests/test_component.py +++ b/tests/test_component.py @@ -1,7 +1,7 @@ import pytest +import numpy as np import pandas as pd -from jax import numpy as jnp import appletree as apt from appletree.utils import get_file_path @@ -26,7 +26,7 @@ order=[0, 1], x_clip=[0, 100], y_clip=[1e2, 1e4], - which_np=jnp, + which_np=np, ) @@ -39,6 +39,7 @@ def test_fixed_component(): ) ac.rate_name = "ac_rate" ac.deduce(data_names=["cs1", "cs2"]) + ac.lineage_hash ac.simulate_hist(parameters) ac.simulate_weighted_data(parameters) @@ -94,6 +95,7 @@ def benchmark(key): force_no_eff=True, ) er.compile() + er.lineage_hash er.simulate_hist(key, batch_size, parameters) with pytest.raises(RuntimeError): key, r = er.multiple_simulations(key, batch_size, parameters, 5, apply_eff=True) diff --git a/tests/test_context.py b/tests/test_context.py index 63ffdad4..dbcf37a2 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -8,6 +8,7 @@ def test_rn220_context(): _cached_functions.clear() _cached_configs.clear() context = apt.ContextRn220() + context.lineage_hash context.print_context_summary(short=False) context.fitting(nwalkers=100, iteration=2, batch_size=int(1e4)) @@ -38,6 +39,7 @@ def test_rn220_context_1d(): instruction["likelihoods"]["rn220_llh"]["bins_type"] = bins_type instruction["likelihoods"]["rn220_llh"]["bins"] = bins context = apt.Context(instruction) + context.lineage_hash context.print_context_summary(short=False) @@ -46,6 +48,7 @@ def test_rn220_ar37_context(): _cached_functions.clear() _cached_configs.clear() context = apt.ContextRn220Ar37() + context.lineage_hash context.print_context_summary(short=False) @@ -63,6 +66,7 @@ def test_neutron_context(): _cached_configs.clear() instruct = get_file_path("neutron_low.json") context = apt.Context(instruct) + context.lineage_hash context.print_context_summary(short=False) context.fitting(nwalkers=100, iteration=2, batch_size=int(1e4)) @@ -75,6 +79,7 @@ def test_literature_context(): _cached_configs.clear() instruct = get_file_path("literature_lyqy.json") context = apt.Context(instruct) + context.lineage_hash context.print_context_summary(short=False) @@ -93,6 +98,7 @@ def test_backend(): instruct = apt.utils.load_json("rn220.json") instruct["backend_h5"] = "test_backend.h5" context = apt.Context(instruct) + context.lineage_hash context.fitting(nwalkers=100, iteration=2, batch_size=int(1e4)) _cached_functions.clear() diff --git a/tests/test_likelihood.py b/tests/test_likelihood.py index 23e0279b..a75af44d 100644 --- a/tests/test_likelihood.py +++ b/tests/test_likelihood.py @@ -19,6 +19,7 @@ def test_er_likelihood(): llh = apt.Likelihood(**instruct) llh.register_component(apt.components.AC, "rn220_ac", "AC_Rn220.pkl") llh.register_component(apt.components.ERBand, "rn220_er") + llh.lineage_hash llh.print_likelihood_summary(short=True) # Get parameters @@ -47,6 +48,7 @@ def test_nr_likelihood(): ) llh = apt.Likelihood(**instruct) llh.register_component(apt.components.NR, "neutron_nr") + llh.lineage_hash llh.print_likelihood_summary(short=True) # Get parameters