diff --git a/appletree/component.py b/appletree/component.py
index c49d3e43..925302ce 100644
--- a/appletree/component.py
+++ b/appletree/component.py
@@ -5,12 +5,13 @@
 import numpy as np
 import pandas as pd
 from jax import numpy as jnp
+from strax import deterministic_hash
 
 from appletree import utils
 from appletree.config import OMITTED
 from appletree.plugin import Plugin
 from appletree.share import _cached_configs, _cached_functions, set_global_config
-from appletree.utils import exporter, load_data
+from appletree.utils import exporter, get_file_path, load_data, calculate_sha256
 from appletree.hist import make_hist_mesh_grid, make_hist_irreg_bin_1d, make_hist_irreg_bin_2d
 
 export, __all__ = exporter()
@@ -181,6 +182,10 @@ def compile(self):
         """Hook for compiling simulation code."""
         pass
 
+    @property
+    def lineage_hash(self):
+        raise NotImplementedError
+
 
 @export
 class ComponentSim(Component):
@@ -335,6 +340,8 @@ def flush_source_code(
         if isinstance(data_names, str):
             data_names = [data_names]
 
+        instances = set()
+
         code = ""
         indent = " " * 4
 
@@ -350,6 +357,7 @@ def flush_source_code(
         for work in self.worksheet:
             plugin = work[0]
             instance = plugin + "_" + self.name
+            instances.add(instance)
             code += f"{instance} = {plugin}('{self.llh_name}')\n"
 
         # define functions
@@ -369,6 +377,7 @@ def flush_source_code(
         code += f"{indent}return {output}\n"
 
         self.code = code
+        self.instances = instances
 
         if func_name in _cached_functions[self.llh_name].keys():
             warning = f"Function name {func_name} is already cached. "
@@ -474,11 +483,6 @@ def save_code(self, file_path):
         with open(file_path, "w") as f:
             f.write(self.code)
 
-    def lineage(self, data_name: str = "cs2"):
-        """Return lineage of plugins."""
-        assert isinstance(data_name, str)
-        pass
-
     def set_config(self, configs):
         """Set new global configuration options.
 
@@ -559,6 +563,26 @@ def new_component(self, llh_name: Optional[str] = None, pass_binning: bool = Tru
         )
         return component
 
+    @property
+    def lineage_hash(self):
+        return deterministic_hash(
+            {
+                **{
+                    "rate_name": self.rate_name,
+                    "norm_type": self.norm_type,
+                    "bins": self.bins,
+                    "bins_type": self.bins_type,
+                    "code": self.code,
+                },
+                **dict(
+                    zip(
+                        self.instances,
+                        [_cached_functions[self.llh_name][p].lineage_hash for p in self.instances],
+                    )
+                ),
+            }
+        )
+
 
 @export
 class ComponentFixed(Component):
@@ -602,6 +626,18 @@ def simulate_weighted_data(self, parameters, *args, **kwargs):
 
         return result
 
+    @property
+    def lineage_hash(self):
+        return deterministic_hash(
+            {
+                "rate_name": self.rate_name,
+                "norm_type": self.norm_type,
+                "bins": self.bins,
+                "bins_type": self.bins_type,
+                "file_name": calculate_sha256(get_file_path(self._file_name)),
+            }
+        )
+
 
 @export
 def add_component_extensions(module1, module2, force=False):
diff --git a/appletree/config.py b/appletree/config.py
index 856ea01f..cc5076da 100644
--- a/appletree/config.py
+++ b/appletree/config.py
@@ -1,3 +1,4 @@
+import os
 from typing import Optional, Union, Any
 
 from immutabledict import immutabledict
@@ -5,6 +6,7 @@
 from warnings import warn
 
 import numpy as np
+from strax import deterministic_hash
 
 from appletree.share import _cached_configs
 from appletree.utils import (
@@ -12,7 +14,8 @@
     load_json,
     get_file_path,
     integrate_midpoint,
-    cum_integrate_midpoint,
+    cumulative_integrate_midpoint,
+    calculate_sha256,
 )
 from appletree import interpolation
 from appletree.interpolation import FLOAT_POS_MIN, FLOAT_POS_MAX
@@ -108,6 +111,10 @@ def build(self, llh_name: Optional[str] = None):
     def required_parameter(self, llh_name=None):
         return None
 
+    @property
+    def lineage_hash(self):
+        raise NotImplementedError
+
 
 @export
 class Constant(Config):
@@ -137,6 +144,15 @@ def build(self, llh_name: Optional[str] = None):
         else:
             self.value = value
 
+    @property
+    def lineage_hash(self):
+        return deterministic_hash(
+            {
+                "llh_name": self.llh_name,
+                "value": self.value,
+            }
+        )
+
 
 @export
 class Map(Config):
@@ -317,10 +333,21 @@ def log_pos(self, pos):
     def pdf_to_cdf(self, x, pdf):
         """Convert pdf map to cdf map."""
         norm = integrate_midpoint(x, pdf)
-        x, cdf = cum_integrate_midpoint(x, pdf)
+        x, cdf = cumulative_integrate_midpoint(x, pdf)
         cdf /= norm
         return x, cdf
 
+    @property
+    def lineage_hash(self):
+        return deterministic_hash(
+            {
+                "llh_name": self.llh_name,
+                "method": self.method,
+                "file_path": os.path.basename(self.file_path),
+                "sha256": calculate_sha256(get_file_path(self.file_path)),
+            }
+        )
+
 
 @export
 class SigmaMap(Config):
@@ -472,6 +499,18 @@ def apply(self, pos, parameters):
         add = jnp.where(sigma > 0, add_pos, add_neg)
         return median + add
 
+    @property
+    def lineage_hash(self):
+        return deterministic_hash(
+            {
+                "llh_name": self.llh_name,
+                "method": self.method,
+                "median": self.median.lineage_hash,
+                "lower": self.lower.lineage_hash,
+                "upper": self.upper.lineage_hash,
+            }
+        )
+
 
 @export
 class ConstantSet(Config):
@@ -507,7 +546,7 @@ def build(self, llh_name: Optional[str] = None):
         self._sanity_check()
         self.set_volume = len(self.value[1][0])
 
-        self.value = {k: jnp.array(v) for k, v in zip(*self.value)}
+        self.value = {k: np.array(v) for k, v in zip(*self.value)}
 
     def _sanity_check(self):
         """Check if parameter set lengths are same."""
@@ -518,3 +557,12 @@
         volumes = [len(v) for v in self.value[1]]
         mesg = "Parameter set lengths should be the same"
         assert np.all(np.isclose(volumes, volumes[0])), mesg
+
+    @property
+    def lineage_hash(self):
+        return deterministic_hash(
+            {
+                "llh_name": self.llh_name,
+                "value": self.value,
+            }
+        )
diff --git a/appletree/context.py b/appletree/context.py
index 4c021746..d56e85cc 100644
--- a/appletree/context.py
+++ b/appletree/context.py
@@ -7,8 +7,9 @@
 from typing import Set, Optional
 
 import numpy as np
-import emcee
 import h5py
+import emcee
+from strax import deterministic_hash
 
 import appletree as apt
 from appletree import randgen
@@ -390,7 +391,17 @@ def update_parameter_config(self, likelihoods):
                 self.par_config.pop(p)
         return needed_parameters
 
-    def lineage(self, data_name: str = "cs2"):
-        """Return lineage of plugins."""
-        assert isinstance(data_name, str)
-        pass
+    @property
+    def lineage_hash(self):
+        return deterministic_hash(
+            {
+                **self.instruct,
+                **self.par_config,
+                **dict(
+                    zip(
+                        self.likelihoods.keys(),
+                        [v.lineage_hash for v in self.likelihoods.values()],
+                    )
+                ),
+            }
+        )
diff --git a/appletree/likelihood.py b/appletree/likelihood.py
index 8c1b74f6..f1b0ff2c 100644
--- a/appletree/likelihood.py
+++ b/appletree/likelihood.py
@@ -4,13 +4,18 @@
 from copy import deepcopy
 
 import numpy as np
-from jax import numpy as jnp
-
 from scipy.stats import norm
+from strax import deterministic_hash
 
 from appletree import randgen
 from appletree.hist import make_hist_mesh_grid, make_hist_irreg_bin_1d, make_hist_irreg_bin_2d
-from appletree.utils import load_data, get_equiprob_bins_1d, get_equiprob_bins_2d
+from appletree.utils import (
+    get_file_path,
+    load_data,
+    get_equiprob_bins_1d,
+    get_equiprob_bins_2d,
+    calculate_sha256,
+)
 from appletree.component import Component, ComponentSim, ComponentFixed
 from appletree.randgen import TwoHalfNorm, BandTwoHalfNorm
 
@@ -108,25 +113,25 @@ def set_binning(self, config):
             self.component_bins_type = "meshgrid"
             if self._dim == 1:
                 if isinstance(self._bins[0], int):
-                    bins = jnp.linspace(*config["clip"], self._bins[0] + 1)
+                    bins = np.linspace(*config["clip"], self._bins[0] + 1)
                 else:
-                    bins = jnp.array(self._bins[0])
+                    bins = np.array(self._bins[0])
                     if "x_clip" in config:
                         warning = "x_clip is ignored when bins_type is meshgrid and bins is not int"
                         warn(warning)
                 self._bins = (bins,)
             elif self._dim == 2:
                 if isinstance(self._bins[0], int):
-                    x_bins = jnp.linspace(*config["x_clip"], self._bins[0] + 1)
+                    x_bins = np.linspace(*config["x_clip"], self._bins[0] + 1)
                 else:
-                    x_bins = jnp.array(self._bins[0])
+                    x_bins = np.array(self._bins[0])
                     if "x_clip" in config:
                         warning = "x_clip is ignored when bins_type is meshgrid and bins is not int"
                         warn(warning)
                 if isinstance(self._bins[1], int):
-                    y_bins = jnp.linspace(*config["y_clip"], self._bins[1] + 1)
+                    y_bins = np.linspace(*config["y_clip"], self._bins[1] + 1)
                 else:
-                    y_bins = jnp.array(self._bins[1])
+                    y_bins = np.array(self._bins[1])
                     if "y_clip" in config:
                         warning = "y_clip is ignored when bins_type is meshgrid and bins is not int"
                         warn(warning)
@@ -134,7 +139,7 @@ def set_binning(self, config):
             self.data_hist = make_hist_mesh_grid(
                 self.data,
                 bins=self._bins,
-                weights=jnp.ones(len(self.data)),
+                weights=np.ones(len(self.data)),
             )
         elif self._bins_type == "equiprob":
             if not all([isinstance(b, int) for b in self._bins]):
@@ -144,13 +149,13 @@ def set_binning(self, config):
                     self.data[:, 0],
                     self._bins[0],
                     clip=config["clip"],
-                    which_np=jnp,
+                    which_np=np,
                 )
                 self._bins = [self._bins]
                 self.data_hist = make_hist_irreg_bin_1d(
                     self.data[:, 0],
                     bins=self._bins[0],
-                    weights=jnp.ones(len(self.data)),
+                    weights=np.ones(len(self.data)),
                 )
             elif self._dim == 2:
                 self._bins = get_equiprob_bins_2d(
@@ -158,38 +163,38 @@ def set_binning(self, config):
                     self._bins,
                     x_clip=config["x_clip"],
                     y_clip=config["y_clip"],
-                    which_np=jnp,
+                    which_np=np,
                 )
                 self.data_hist = make_hist_irreg_bin_2d(
                     self.data,
                     bins_x=self._bins[0],
                     bins_y=self._bins[1],
-                    weights=jnp.ones(len(self.data)),
+                    weights=np.ones(len(self.data)),
                 )
             self.component_bins_type = "irreg"
         elif self._bins_type == "irreg":
-            self._bins = [jnp.array(b) for b in self._bins]
+            self._bins = [np.array(b) for b in self._bins]
             if self._dim == 1:
                 if isinstance(self._bins[0], int):
-                    bins = jnp.linspace(*config["clip"], self._bins[0] + 1)
+                    bins = np.linspace(*config["clip"], self._bins[0] + 1)
                 else:
-                    bins = jnp.array(self._bins[0])
+                    bins = np.array(self._bins[0])
                     if "x_clip" in config:
                         warning = "x_clip is ignored when bins_type is meshgrid and bins is not int"
                         warn(warning)
                 self._bins = (bins,)
             elif self._dim == 2:
                 if isinstance(self._bins[0], int):
-                    x_bins = jnp.linspace(*config["x_clip"], self._bins[0] + 1)
+                    x_bins = np.linspace(*config["x_clip"], self._bins[0] + 1)
                 else:
-                    x_bins = jnp.array(self._bins[0])
+                    x_bins = np.array(self._bins[0])
                     if "x_clip" in config:
                         warning = "x_clip is ignored when bins_type is meshgrid and bins is not int"
                         warn(warning)
                 if isinstance(self._bins[1], int):
-                    y_bins = jnp.linspace(*config["y_clip"], self._bins[1] + 1)
+                    y_bins = np.linspace(*config["y_clip"], self._bins[1] + 1)
                 else:
-                    y_bins = jnp.array(self._bins[1])
+                    y_bins = np.array(self._bins[1])
                     if "y_clip" in config:
                         warning = "y_clip is ignored when bins_type is meshgrid and bins is not int"
                         warn(warning)
@@ -212,14 +217,14 @@ def set_binning(self, config):
                 self.data_hist = make_hist_irreg_bin_1d(
                     self.data[:, 0],
                     bins=self._bins[0],
-                    weights=jnp.ones(len(self.data)),
+                    weights=np.ones(len(self.data)),
                 )
             elif self._dim == 2:
                 self.data_hist = make_hist_irreg_bin_2d(
                     self.data,
                     bins_x=self._bins[0],
                     bins_y=self._bins[1],
-                    weights=jnp.ones(len(self.data)),
+                    weights=np.ones(len(self.data)),
                 )
         else:
             raise ValueError("'bins_type' should either be meshgrid, equiprob or irreg")
@@ -293,7 +298,7 @@ def _simulate_model_hist(self, key, batch_size, parameters):
             parameters: dict of parameters used in simulation.
 
         """
-        hist = jnp.zeros_like(self.data_hist)
+        hist = np.zeros_like(self.data_hist)
         for component_name, component in self.components.items():
             if isinstance(component, ComponentSim):
                 key, _hist = component.simulate_hist(key, batch_size, parameters)
@@ -338,7 +343,7 @@ def get_log_likelihood(self, key, batch_size, parameters):
         """
         key, model_hist = self._simulate_model_hist(key, batch_size, parameters)
         # Poisson likelihood
-        llh = jnp.sum(self.data_hist * jnp.log(model_hist) - model_hist)
+        llh = np.sum(self.data_hist * np.log(model_hist) - model_hist)
         llh = float(llh)
         if np.isnan(llh):
             llh = -np.inf
@@ -404,6 +409,23 @@ def print_likelihood_summary(self, indent: str = " " * 4, short: bool = True):
 
         print("-" * 40)
 
+    @property
+    def lineage_hash(self):
+        return deterministic_hash(
+            {
+                **{
+                    "config": self._config,
+                    "sha256": calculate_sha256(get_file_path(self._data_file_name)),
+                },
+                **dict(
+                    zip(
+                        self.components.keys(),
+                        [v.lineage_hash for v in self.components.values()],
+                    )
+                ),
+            }
+        )
+
 
 class LikelihoodLit(Likelihood):
     """Using literature constraint to build LLH.
@@ -553,3 +575,19 @@ def print_likelihood_summary(self, indent: str = " " * 4, short: bool = True):
 
         print()
         print("-" * 40)
+
+    @property
+    def lineage_hash(self):
+        return deterministic_hash(
+            {
+                **{
+                    "config": self._config,
+                },
+                **dict(
+                    zip(
+                        self.components.keys(),
+                        [v.lineage_hash for v in self.components.values()],
+                    )
+                ),
+            }
+        )
diff --git a/appletree/plugin.py b/appletree/plugin.py
index f796d90a..19e494e6 100644
--- a/appletree/plugin.py
+++ b/appletree/plugin.py
@@ -3,6 +3,7 @@
 from typing import List, Tuple, Optional
 
 from immutabledict import immutabledict
+from strax import deterministic_hash
 
 from appletree import utils
 from appletree.utils import exporter
@@ -90,6 +91,24 @@ def sanity_check(self):
             mesg += f"Plugin {self.__class__.__name__} is insane, check dependency!"
             raise ValueError(mesg)
 
+    @property
+    def lineage_hash(self):
+        return deterministic_hash(
+            {
+                **{
+                    "depends_on": self.depends_on,
+                    "provides": self.provides,
+                    "parameters": self.parameters,
+                },
+                **dict(
+                    zip(
+                        self.takes_config.keys(),
+                        [v.lineage_hash for v in self.takes_config.values()],
+                    )
+                ),
+            }
+        )
+
 
 @export
 def add_plugin_extensions(module1, module2, force=False):
diff --git a/appletree/utils.py b/appletree/utils.py
index 34c3b62c..4e6bd050 100644
--- a/appletree/utils.py
+++ b/appletree/utils.py
@@ -1,6 +1,7 @@
 import os
 import json
 from warnings import warn
+import hashlib
 
 import importlib_resources
 from time import time
@@ -203,6 +204,15 @@ def get_file_path(fname):
     raise RuntimeError(f"Can not find {fname}, please check your file system")
 
 
+@export
+def calculate_sha256(file_path):
+    sha256_hash = hashlib.sha256()
+    with open(file_path, "rb") as f:
+        for byte_block in iter(lambda: f.read(4096), b""):
+            sha256_hash.update(byte_block)
+    return sha256_hash.hexdigest()
+
+
 @export
 def timeit(indent=""):
     """Use timeit as a decorator.
@@ -585,11 +595,11 @@ def integrate_midpoint(x, y):
         y: 1D array-like, with the same length as x.
 
     """
-    _, res = cum_integrate_midpoint(x, y)
+    _, res = cumulative_integrate_midpoint(x, y)
     return res[-1]
 
 
-def cum_integrate_midpoint(x, y):
+def cumulative_integrate_midpoint(x, y):
     """Calculate the cumulative integral using midpoint method.
 
     Args:
diff --git a/requirements.txt b/requirements.txt
index c3cd7668..e748637c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,4 +12,5 @@ numpyro
 pandas
 scikit-learn
 scipy
+strax
 straxen
diff --git a/tests/test_component.py b/tests/test_component.py
index 76ff5255..70ed06b0 100644
--- a/tests/test_component.py
+++ b/tests/test_component.py
@@ -1,7 +1,7 @@
 import pytest
 
+import numpy as np
 import pandas as pd
-from jax import numpy as jnp
 
 import appletree as apt
 from appletree.utils import get_file_path
@@ -26,7 +26,7 @@
     order=[0, 1],
     x_clip=[0, 100],
     y_clip=[1e2, 1e4],
-    which_np=jnp,
+    which_np=np,
 )
 
 
@@ -39,6 +39,7 @@ def test_fixed_component():
     )
     ac.rate_name = "ac_rate"
     ac.deduce(data_names=["cs1", "cs2"])
+    ac.lineage_hash
     ac.simulate_hist(parameters)
     ac.simulate_weighted_data(parameters)
 
@@ -94,6 +95,7 @@ def benchmark(key):
         force_no_eff=True,
     )
     er.compile()
+    er.lineage_hash
     er.simulate_hist(key, batch_size, parameters)
     with pytest.raises(RuntimeError):
         key, r = er.multiple_simulations(key, batch_size, parameters, 5, apply_eff=True)
diff --git a/tests/test_context.py b/tests/test_context.py
index 63ffdad4..dbcf37a2 100644
--- a/tests/test_context.py
+++ b/tests/test_context.py
@@ -8,6 +8,7 @@ def test_rn220_context():
     _cached_functions.clear()
     _cached_configs.clear()
     context = apt.ContextRn220()
+    context.lineage_hash
     context.print_context_summary(short=False)
 
     context.fitting(nwalkers=100, iteration=2, batch_size=int(1e4))
@@ -38,6 +39,7 @@ def test_rn220_context_1d():
         instruction["likelihoods"]["rn220_llh"]["bins_type"] = bins_type
         instruction["likelihoods"]["rn220_llh"]["bins"] = bins
         context = apt.Context(instruction)
+        context.lineage_hash
         context.print_context_summary(short=False)
 
 
@@ -46,6 +48,7 @@ def test_rn220_ar37_context():
     _cached_functions.clear()
     _cached_configs.clear()
    context = apt.ContextRn220Ar37()
+    context.lineage_hash
     context.print_context_summary(short=False)
 
 
@@ -63,6 +66,7 @@ def test_neutron_context():
     _cached_configs.clear()
     instruct = get_file_path("neutron_low.json")
     context = apt.Context(instruct)
+    context.lineage_hash
     context.print_context_summary(short=False)
 
     context.fitting(nwalkers=100, iteration=2, batch_size=int(1e4))
@@ -75,6 +79,7 @@ def test_literature_context():
     _cached_configs.clear()
     instruct = get_file_path("literature_lyqy.json")
     context = apt.Context(instruct)
+    context.lineage_hash
     context.print_context_summary(short=False)
 
 
@@ -93,6 +98,7 @@ def test_backend():
     instruct = apt.utils.load_json("rn220.json")
     instruct["backend_h5"] = "test_backend.h5"
     context = apt.Context(instruct)
+    context.lineage_hash
     context.fitting(nwalkers=100, iteration=2, batch_size=int(1e4))
 
     _cached_functions.clear()
diff --git a/tests/test_likelihood.py b/tests/test_likelihood.py
index 23e0279b..a75af44d 100644
--- a/tests/test_likelihood.py
+++ b/tests/test_likelihood.py
@@ -19,6 +19,7 @@ def test_er_likelihood():
     llh = apt.Likelihood(**instruct)
     llh.register_component(apt.components.AC, "rn220_ac", "AC_Rn220.pkl")
     llh.register_component(apt.components.ERBand, "rn220_er")
+    llh.lineage_hash
     llh.print_likelihood_summary(short=True)
 
     # Get parameters
@@ -47,6 +48,7 @@ def test_nr_likelihood():
     )
     llh = apt.Likelihood(**instruct)
     llh.register_component(apt.components.NR, "neutron_nr")
+    llh.lineage_hash
     llh.print_likelihood_summary(short=True)
 
     # Get parameters
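
Usage sketch (illustration only, not part of the patch): the new lineage_hash properties are read-only attributes, so a whole fitting setup can be fingerprinted the same way the updated tests do. The snippet below reuses only names that appear in this diff; neutron_low.json is the instruction file the neutron test already loads, standing in here for any valid instruction file.

    import appletree as apt
    from appletree.utils import calculate_sha256, get_file_path

    # SHA-256 of a single input file, the same fingerprint ComponentFixed and the Map config use
    print(calculate_sha256(get_file_path("neutron_low.json")))

    # Hash of a full context: instruction, parameter config, and the lineage of
    # every likelihood, component, plugin and config underneath it
    context = apt.Context(get_file_path("neutron_low.json"))
    print(context.lineage_hash)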