Skip to content

Commit

Permalink
Add lineage hash for all levels: Config, Plugin, Component, and…
Browse files Browse the repository at this point in the history
… `Context` (#178)

* Add lineage_hash of config, plugin, and component

* Add lineage_hash for likelihood and context, include sha256 of files into lineage_hash

* Update tests

* Debug

* Debug
  • Loading branch information
dachengx authored Jul 30, 2024
1 parent 08fe243 commit 105859b
Show file tree
Hide file tree
Showing 10 changed files with 216 additions and 43 deletions.
48 changes: 42 additions & 6 deletions appletree/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
import numpy as np
import pandas as pd
from jax import numpy as jnp
from strax import deterministic_hash

from appletree import utils
from appletree.config import OMITTED
from appletree.plugin import Plugin
from appletree.share import _cached_configs, _cached_functions, set_global_config
from appletree.utils import exporter, load_data
from appletree.utils import exporter, get_file_path, load_data, calculate_sha256
from appletree.hist import make_hist_mesh_grid, make_hist_irreg_bin_1d, make_hist_irreg_bin_2d

export, __all__ = exporter()
Expand Down Expand Up @@ -181,6 +182,10 @@ def compile(self):
"""Hook for compiling simulation code."""
pass

@property
def lineage_hash(self):
    """Lineage hash of the component; not implemented at this level.

    ``ComponentSim`` and ``ComponentFixed`` each define their own
    ``lineage_hash``; accessing this version always fails.

    Raises:
        NotImplementedError: unconditionally.
    """
    raise NotImplementedError


@export
class ComponentSim(Component):
Expand Down Expand Up @@ -335,6 +340,8 @@ def flush_source_code(
if isinstance(data_names, str):
data_names = [data_names]

instances = set()

code = ""
indent = " " * 4

Expand All @@ -350,6 +357,7 @@ def flush_source_code(
for work in self.worksheet:
plugin = work[0]
instance = plugin + "_" + self.name
instances.add(instance)
code += f"{instance} = {plugin}('{self.llh_name}')\n"

# define functions
Expand All @@ -369,6 +377,7 @@ def flush_source_code(
code += f"{indent}return {output}\n"

self.code = code
self.instances = instances

if func_name in _cached_functions[self.llh_name].keys():
warning = f"Function name {func_name} is already cached. "
Expand Down Expand Up @@ -474,11 +483,6 @@ def save_code(self, file_path):
with open(file_path, "w") as f:
f.write(self.code)

def lineage(self, data_name: str = "cs2"):
    """Return lineage of plugins.

    Placeholder: the body only checks that ``data_name`` is a string. No
    lineage is computed and ``None`` is returned implicitly.

    Raises:
        AssertionError: if ``data_name`` is not a ``str``.
    """
    # NOTE: ``assert`` is stripped under ``python -O``, so the type check is
    # not enforced in optimized mode.
    assert isinstance(data_name, str)
    pass

def set_config(self, configs):
"""Set new global configuration options.
Expand Down Expand Up @@ -559,6 +563,26 @@ def new_component(self, llh_name: Optional[str] = None, pass_binning: bool = Tru
)
return component

@property
def lineage_hash(self):
    """Return the deterministic hash describing this component's lineage.

    The hashed mapping holds the rate name, normalization type, binning,
    binning type and the code string in ``self.code``, followed by one
    entry per name in ``self.instances`` whose value is the
    ``lineage_hash`` of the object stored under that name in
    ``_cached_functions[self.llh_name]``.
    """
    lineage = {
        "rate_name": self.rate_name,
        "norm_type": self.norm_type,
        "bins": self.bins,
        "bins_type": self.bins_type,
        "code": self.code,
    }
    # Instance entries are merged after the fixed keys, so an instance whose
    # name equals one of the fixed keys would override it.
    for instance in self.instances:
        lineage[instance] = _cached_functions[self.llh_name][instance].lineage_hash
    return deterministic_hash(lineage)


@export
class ComponentFixed(Component):
Expand Down Expand Up @@ -602,6 +626,18 @@ def simulate_weighted_data(self, parameters, *args, **kwargs):

return result

@property
def lineage_hash(self):
    """Return the deterministic hash describing this component's lineage.

    The hashed mapping holds the rate name, normalization type, binning and
    binning type, plus the result of ``calculate_sha256`` on the path that
    ``get_file_path`` resolves for ``self._file_name``.
    """
    lineage = {
        "rate_name": self.rate_name,
        "norm_type": self.norm_type,
        "bins": self.bins,
        "bins_type": self.bins_type,
    }
    # Despite its name, the "file_name" entry stores the value returned by
    # calculate_sha256 for the resolved file, not the file name itself.
    lineage["file_name"] = calculate_sha256(get_file_path(self._file_name))
    return deterministic_hash(lineage)


@export
def add_component_extensions(module1, module2, force=False):
Expand Down
54 changes: 51 additions & 3 deletions appletree/config.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
import os
from typing import Optional, Union, Any

from immutabledict import immutabledict
from jax import numpy as jnp
from warnings import warn

import numpy as np
from strax import deterministic_hash

from appletree.share import _cached_configs
from appletree.utils import (
exporter,
load_json,
get_file_path,
integrate_midpoint,
cum_integrate_midpoint,
cumulative_integrate_midpoint,
calculate_sha256,
)
from appletree import interpolation
from appletree.interpolation import FLOAT_POS_MIN, FLOAT_POS_MAX
Expand Down Expand Up @@ -108,6 +111,10 @@ def build(self, llh_name: Optional[str] = None):
def required_parameter(self, llh_name=None):
return None

@property
def lineage_hash(self):
    """Lineage hash of the config; not implemented at this level.

    The ``Constant``, ``Map``, ``SigmaMap`` and ``ConstantSet`` configs in
    this module each define their own ``lineage_hash``; accessing this
    version always fails.

    Raises:
        NotImplementedError: unconditionally.
    """
    raise NotImplementedError


@export
class Constant(Config):
Expand Down Expand Up @@ -137,6 +144,15 @@ def build(self, llh_name: Optional[str] = None):
else:
self.value = value

@property
def lineage_hash(self):
    """Return the deterministic hash of ``llh_name`` and the built ``value``."""
    lineage = {}
    lineage["llh_name"] = self.llh_name
    lineage["value"] = self.value
    return deterministic_hash(lineage)


@export
class Map(Config):
Expand Down Expand Up @@ -317,10 +333,21 @@ def log_pos(self, pos):
def pdf_to_cdf(self, x, pdf):
    """Convert pdf map to cdf map.

    Args:
        x: positions at which ``pdf`` is sampled (passed straight to the
            midpoint-integration helpers).
        pdf: pdf values at ``x``.

    Returns:
        Tuple ``(x, cdf)``: the positions returned by
        ``cumulative_integrate_midpoint`` and the cumulative integral
        divided by the total integral from ``integrate_midpoint``.
    """
    norm = integrate_midpoint(x, pdf)
    # Integrate cumulatively exactly once: feeding the result into a second
    # cumulative pass would integrate the CDF itself rather than the PDF.
    x, cdf = cumulative_integrate_midpoint(x, pdf)
    # Normalize in place by the total integral.
    cdf /= norm
    return x, cdf

@property
def lineage_hash(self):
    """Return the deterministic hash describing this map's lineage.

    The hashed mapping holds ``llh_name``, ``method``, the base name of
    ``file_path`` and the result of ``calculate_sha256`` on the path that
    ``get_file_path`` resolves for ``file_path``.
    """
    lineage = {
        "llh_name": self.llh_name,
        "method": self.method,
    }
    # Only the base name is hashed, so the directory part of file_path does
    # not influence the result.
    lineage["file_path"] = os.path.basename(self.file_path)
    lineage["sha256"] = calculate_sha256(get_file_path(self.file_path))
    return deterministic_hash(lineage)


@export
class SigmaMap(Config):
Expand Down Expand Up @@ -472,6 +499,18 @@ def apply(self, pos, parameters):
add = jnp.where(sigma > 0, add_pos, add_neg)
return median + add

@property
def lineage_hash(self):
    """Return the deterministic hash describing this sigma map's lineage.

    The hashed mapping holds ``llh_name`` and ``method`` together with the
    ``lineage_hash`` of the ``median``, ``lower`` and ``upper`` members.
    """
    lineage = {
        "llh_name": self.llh_name,
        "method": self.method,
    }
    # Same key order as the literal form: median, lower, upper.
    lineage["median"] = self.median.lineage_hash
    lineage["lower"] = self.lower.lineage_hash
    lineage["upper"] = self.upper.lineage_hash
    return deterministic_hash(lineage)


@export
class ConstantSet(Config):
Expand Down Expand Up @@ -507,7 +546,7 @@ def build(self, llh_name: Optional[str] = None):

self._sanity_check()
self.set_volume = len(self.value[1][0])
self.value = {k: jnp.array(v) for k, v in zip(*self.value)}
self.value = {k: np.array(v) for k, v in zip(*self.value)}

def _sanity_check(self):
"""Check if parameter set lengths are same."""
Expand All @@ -518,3 +557,12 @@ def _sanity_check(self):
volumes = [len(v) for v in self.value[1]]
mesg = "Parameter set lengths should be the same"
assert np.all(np.isclose(volumes, volumes[0])), mesg

@property
def lineage_hash(self):
    """Return the deterministic hash of ``llh_name`` and the built ``value``."""
    lineage = {}
    lineage["llh_name"] = self.llh_name
    lineage["value"] = self.value
    return deterministic_hash(lineage)
21 changes: 16 additions & 5 deletions appletree/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from typing import Set, Optional

import numpy as np
import emcee
import h5py
import emcee
from strax import deterministic_hash

import appletree as apt
from appletree import randgen
Expand Down Expand Up @@ -390,7 +391,17 @@ def update_parameter_config(self, likelihoods):
self.par_config.pop(p)
return needed_parameters

def lineage(self, data_name: str = "cs2"):
    """Return lineage of plugins.

    Placeholder: the body only checks that ``data_name`` is a string. No
    lineage is computed and ``None`` is returned implicitly.

    Raises:
        AssertionError: if ``data_name`` is not a ``str``.
    """
    # NOTE: ``assert`` is stripped under ``python -O``, so the type check is
    # not enforced in optimized mode.
    assert isinstance(data_name, str)
    pass
@property
def lineage_hash(self):
    """Return the deterministic hash describing this context's lineage.

    The hashed mapping is the union of ``self.instruct``,
    ``self.par_config`` and one entry per likelihood mapping its key in
    ``self.likelihoods`` to that likelihood's ``lineage_hash``.
    """
    # All three groups share one flat namespace: on a key clash the later
    # group wins (par_config over instruct, likelihoods over both).
    lineage = {**self.instruct, **self.par_config}
    for name, likelihood in self.likelihoods.items():
        lineage[name] = likelihood.lineage_hash
    return deterministic_hash(lineage)
Loading

0 comments on commit 105859b

Please sign in to comment.