Skip to content

Commit

Permalink
Merge branch 'master' into pre-commit-ci-update-config
Browse files Browse the repository at this point in the history
  • Loading branch information
dachengx authored Jul 31, 2024
2 parents 8629ea5 + 099e9c6 commit 67c826d
Show file tree
Hide file tree
Showing 11 changed files with 274 additions and 60 deletions.
68 changes: 54 additions & 14 deletions appletree/component.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import os
from warnings import warn
from functools import partial
from typing import Tuple, List, Dict, Optional, Union, Set

import numpy as np
import pandas as pd
from jax import numpy as jnp
from strax import deterministic_hash

from appletree import utils
from appletree.config import OMITTED
from appletree.plugin import Plugin
from appletree.share import _cached_configs, _cached_functions, set_global_config
from appletree.utils import exporter, load_data
from appletree.utils import exporter, get_file_path, load_data, calculate_sha256
from appletree.hist import make_hist_mesh_grid, make_hist_irreg_bin_1d, make_hist_irreg_bin_2d

export, __all__ = exporter()
Expand Down Expand Up @@ -53,13 +55,6 @@ def __init__(self, name: Optional[str] = None, llh_name: Optional[str] = None, *
if "bins" in kwargs.keys() and "bins_type" in kwargs.keys():
self.set_binning(**kwargs)

if self.bins_type != "meshgrid" and self.add_eps_to_hist:
warn(
"It is empirically dangerous to have add_eps_to_hist==True, "
"when your bins_type is not meshgrid! It may lead to very bad fit with "
"lots of eff==0."
)

def set_binning(self, **kwargs):
"""Set binning of component."""
if "bins" not in kwargs.keys() or "bins_type" not in kwargs.keys():
Expand Down Expand Up @@ -158,7 +153,8 @@ def implement_binning(self, mc, eff):
raise ValueError(f"Unsupported bins_type {self.bins_type}!")
if self.add_eps_to_hist:
# as an uncertainty to prevent blowing up
hist = jnp.clip(hist, 1.0, jnp.inf)
# uncertainty = 1e-10 + jnp.mean(eff)
hist = jnp.clip(hist, 1e-10 + jnp.mean(eff), jnp.inf)
return hist

def get_normalization(self, hist, parameters, batch_size=None):
Expand Down Expand Up @@ -187,6 +183,14 @@ def compile(self):
"""Hook for compiling simulation code."""
pass

@property
def lineage(self):
    """Lineage of this object; this is the value that `lineage_hash` hashes.

    This implementation provides no lineage and unconditionally raises.
    Presumably concrete subclasses override it (TODO confirm).

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError

@property
def lineage_hash(self):
    """Return `deterministic_hash` applied to `self.lineage`."""
    lineage = self.lineage
    return deterministic_hash(lineage)


@export
class ComponentSim(Component):
Expand Down Expand Up @@ -341,6 +345,8 @@ def flush_source_code(
if isinstance(data_names, str):
data_names = [data_names]

instances = set()

code = ""
indent = " " * 4

Expand All @@ -356,6 +362,7 @@ def flush_source_code(
for work in self.worksheet:
plugin = work[0]
instance = plugin + "_" + self.name
instances.add(instance)
code += f"{instance} = {plugin}('{self.llh_name}')\n"

# define functions
Expand All @@ -375,6 +382,7 @@ def flush_source_code(
code += f"{indent}return {output}\n"

self.code = code
self.instances = instances

if func_name in _cached_functions[self.llh_name].keys():
warning = f"Function name {func_name} is already cached. "
Expand Down Expand Up @@ -480,11 +488,6 @@ def save_code(self, file_path):
with open(file_path, "w") as f:
f.write(self.code)

def lineage(self, data_name: str = "cs2"):
"""Return lineage of plugins."""
assert isinstance(data_name, str)
pass

def set_config(self, configs):
"""Set new global configuration options.
Expand Down Expand Up @@ -565,6 +568,28 @@ def new_component(self, llh_name: Optional[str] = None, pass_binning: bool = Tru
)
return component

@property
def lineage(self):
    """Return the lineage of this component as a plain dictionary.

    The dictionary holds the rate name, normalization type, binning
    (converted to nested lists when set), binning type, the stored `code`
    string, and, under ``"instances"``, the lineage of every entry of
    `self.instances` looked up in `_cached_functions[self.llh_name]`.

    Returns:
        dict: the keys ``rate_name``, ``norm_type``, ``bins``,
        ``bins_type``, ``code`` and ``instances``.
    """
    bins = self.bins
    if bins is not None:
        # Assumes every entry of `bins` offers `.tolist()` (array-like) -- TODO confirm.
        bins = tuple(b.tolist() for b in bins)

    # `self.instances` is filled as a set elsewhere in this file, so its
    # iteration order is not stable between interpreter runs. Sorting keeps
    # the key order of the "instances" mapping reproducible.
    cached = _cached_functions[self.llh_name]
    instances = {name: cached[name].lineage for name in sorted(self.instances)}

    return {
        "rate_name": self.rate_name,
        "norm_type": self.norm_type,
        "bins": bins,
        "bins_type": self.bins_type,
        "code": self.code,
        "instances": instances,
    }


@export
class ComponentFixed(Component):
Expand Down Expand Up @@ -608,6 +633,21 @@ def simulate_weighted_data(self, parameters, *args, **kwargs):

return result

@property
def lineage(self):
    """Return the lineage of this file-backed component as a dictionary.

    ``file_path`` is the base name of `self._file_name` unless
    `utils.FULL_PATH_LINEAGE` is truthy, in which case it is the result of
    `get_file_path`. ``sha256`` is always computed from the path returned
    by `get_file_path`.
    """
    if self.bins is None:
        bins = self.bins
    else:
        # Assumes every entry of `bins` offers `.tolist()` -- TODO confirm.
        bins = tuple(b.tolist() for b in self.bins)

    if utils.FULL_PATH_LINEAGE:
        file_path = get_file_path(self._file_name)
    else:
        file_path = os.path.basename(self._file_name)

    lineage = {
        "rate_name": self.rate_name,
        "norm_type": self.norm_type,
        "bins": bins,
        "bins_type": self.bins_type,
        "file_path": file_path,
    }
    lineage["sha256"] = calculate_sha256(get_file_path(self._file_name))
    return lineage


@export
def add_component_extensions(module1, module2, force=False):
Expand Down
55 changes: 52 additions & 3 deletions appletree/config.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
import os
from typing import Optional, Union, Any

from immutabledict import immutabledict
from jax import numpy as jnp
from warnings import warn

import numpy as np
from strax import deterministic_hash

from appletree import utils
from appletree.share import _cached_configs
from appletree.utils import (
exporter,
load_json,
get_file_path,
integrate_midpoint,
cum_integrate_midpoint,
cumulative_integrate_midpoint,
calculate_sha256,
)
from appletree import interpolation
from appletree.interpolation import FLOAT_POS_MIN, FLOAT_POS_MAX
Expand Down Expand Up @@ -108,6 +112,14 @@ def build(self, llh_name: Optional[str] = None):
def required_parameter(self, llh_name=None):
return None

@property
def lineage(self):
    """Lineage of this config; this is the value that `lineage_hash` hashes.

    This implementation provides no lineage and unconditionally raises.
    Presumably concrete config classes override it (TODO confirm).

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError

@property
def lineage_hash(self):
    """Return `deterministic_hash` applied to `self.lineage`."""
    lineage = self.lineage
    return deterministic_hash(lineage)


@export
class Constant(Config):
Expand Down Expand Up @@ -137,6 +149,13 @@ def build(self, llh_name: Optional[str] = None):
else:
self.value = value

@property
def lineage(self):
    """Return the lineage of this config: its `llh_name` and its `value`."""
    return dict(llh_name=self.llh_name, value=self.value)


@export
class Map(Config):
Expand Down Expand Up @@ -317,10 +336,23 @@ def log_pos(self, pos):
def pdf_to_cdf(self, x, pdf):
"""Convert pdf map to cdf map."""
norm = integrate_midpoint(x, pdf)
x, cdf = cum_integrate_midpoint(x, pdf)
x, cdf = cumulative_integrate_midpoint(x, pdf)
cdf /= norm
return x, cdf

@property
def lineage(self):
    """Return the lineage of this file-backed config as a dictionary.

    ``file_path`` is the base name of `self.file_path` unless
    `utils.FULL_PATH_LINEAGE` is truthy, in which case it is the result of
    `get_file_path`. ``sha256`` is always computed from the path returned
    by `get_file_path`.
    """
    if utils.FULL_PATH_LINEAGE:
        recorded_path = get_file_path(self.file_path)
    else:
        recorded_path = os.path.basename(self.file_path)

    lineage = {
        "llh_name": self.llh_name,
        "method": self.method,
        "file_path": recorded_path,
    }
    lineage["sha256"] = calculate_sha256(get_file_path(self.file_path))
    return lineage


@export
class SigmaMap(Config):
Expand Down Expand Up @@ -472,6 +504,16 @@ def apply(self, pos, parameters):
add = jnp.where(sigma > 0, add_pos, add_neg)
return median + add

@property
def lineage(self):
    """Return the lineage of this config together with that of its three sub-maps.

    The result carries `llh_name` and `method`, followed by the `lineage`
    of `self.median`, `self.lower` and `self.upper` under those same keys.
    """
    lineage = {"llh_name": self.llh_name, "method": self.method}
    for part in ("median", "lower", "upper"):
        lineage[part] = getattr(self, part).lineage
    return lineage


@export
class ConstantSet(Config):
Expand Down Expand Up @@ -507,7 +549,7 @@ def build(self, llh_name: Optional[str] = None):

self._sanity_check()
self.set_volume = len(self.value[1][0])
self.value = {k: jnp.array(v) for k, v in zip(*self.value)}
self.value = {k: np.array(v) for k, v in zip(*self.value)}

def _sanity_check(self):
"""Check if parameter set lengths are same."""
Expand All @@ -518,3 +560,10 @@ def _sanity_check(self):
volumes = [len(v) for v in self.value[1]]
mesg = "Parameter set lengths should be the same"
assert np.all(np.isclose(volumes, volumes[0])), mesg

@property
def lineage(self):
    """Return the lineage of this config: its `llh_name` and its `value`."""
    fields = ("llh_name", "value")
    return {field: getattr(self, field) for field in fields}
43 changes: 31 additions & 12 deletions appletree/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
from typing import Set, Optional

import numpy as np
import emcee
import h5py
import emcee
from strax import deterministic_hash

import appletree as apt
from appletree import randgen
from appletree import Parameter
from appletree.utils import load_json, get_file_path
from appletree.utils import JSON_OPTIONS, load_json, get_file_path
from appletree.share import _cached_configs, set_global_config

os.environ["OMP_NUM_THREADS"] = "1"
Expand Down Expand Up @@ -302,19 +303,23 @@ def _dump_meta(self, batch_size, metadata=None):
if self.backend_h5 is not None:
name = self.sampler.backend.name
with h5py.File(self.backend_h5, "r+") as opt:
opt[name].attrs["metadata"] = json.dumps(metadata)
opt[name].attrs["metadata"] = json.dumps(metadata, **JSON_OPTIONS)
# parameters prior configuration
opt[name].attrs["par_config"] = json.dumps(self.par_manager.par_config)
opt[name].attrs["par_config"] = json.dumps(
self.par_manager.par_config, **JSON_OPTIONS
)
# max posterior parameters
opt[name].attrs["post_parameters"] = json.dumps(self.get_post_parameters())
opt[name].attrs["post_parameters"] = json.dumps(
self.get_post_parameters(), **JSON_OPTIONS
)
# the order of parameters saved in backend
opt[name].attrs["parameter_fit"] = self.par_manager.parameter_fit
# instructions
opt[name].attrs["instruct"] = json.dumps(self.instruct)
opt[name].attrs["instruct"] = json.dumps(self.instruct, **JSON_OPTIONS)
# configs
opt[name].attrs["config"] = json.dumps(self.config)
opt[name].attrs["config"] = json.dumps(self.config, **JSON_OPTIONS)
# configurations, maybe users will manually add some maps
opt[name].attrs["_cached_configs"] = json.dumps(_cached_configs)
opt[name].attrs["_cached_configs"] = json.dumps(_cached_configs, **JSON_OPTIONS)
# batch size
opt[name].attrs["batch_size"] = batch_size

Expand Down Expand Up @@ -390,7 +395,21 @@ def update_parameter_config(self, likelihoods):
self.par_config.pop(p)
return needed_parameters

def lineage(self, data_name: str = "cs2"):
"""Return lineage of plugins."""
assert isinstance(data_name, str)
pass
@property
def lineage(self):
    """Return the lineage of this object as a dictionary.

    Starts from the entries of `self.instruct`, then sets ``par_config`` to
    `self.par_config` and ``likelihoods`` to a mapping from each key of
    `self.likelihoods` to that likelihood's `lineage`. The two later keys
    override entries of the same name coming from `self.instruct`.
    """
    likelihood_lineages = {
        name: likelihood.lineage for name, likelihood in self.likelihoods.items()
    }
    return {
        **self.instruct,
        "par_config": self.par_config,
        "likelihoods": likelihood_lineages,
    }

@property
def lineage_hash(self):
    """Return `deterministic_hash` applied to `self.lineage`."""
    lineage = self.lineage
    return deterministic_hash(lineage)
Loading

0 comments on commit 67c826d

Please sign in to comment.