Skip to content

Commit

Permalink
Merge branch 'master' into pre-commit-ci-update-config
Browse files Browse the repository at this point in the history
  • Loading branch information
dachengx authored Nov 11, 2023
2 parents c599fbf + 3ac9739 commit efbf91e
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 70 deletions.
5 changes: 3 additions & 2 deletions .readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ sphinx:
configuration: docs/source/conf.py

build:
image: latest
os: ubuntu-22.04
tools:
python: "3.9"

python:
version: "3.8"
install:
- requirements: requirements.txt
- requirements: docs/requirements-docs.txt
Expand Down
19 changes: 16 additions & 3 deletions appletree/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,14 +271,23 @@ def dependencies_simplify(self, dependencies):
"""Simplify the dependencies."""
already_seen = []
self.worksheet = []
self.needed_parameters.add(self.rate_name)
for plugin in dependencies[::-1]:
plugin = plugin["plugin"]
# Reinitialize needed_parameters
# because sometimes user will deduce(& compile) after changing configs
self.needed_parameters: Set[str] = set()
# Add rate_name to needed_parameters only when it's not empty
if self.rate_name != "":
self.needed_parameters.add(self.rate_name)
for _plugin in dependencies[::-1]:
plugin = _plugin["plugin"]
if plugin.__name__ in already_seen:
continue
self.worksheet.append([plugin.__name__, plugin.provides, plugin.depends_on])
already_seen.append(plugin.__name__)
self.needed_parameters |= set(plugin.parameters)
# Add needed_parameters from config
for config in plugin.takes_config.values():
if config.required_parameter is not None:
self.needed_parameters |= {config.required_parameter}

def flush_source_code(
self,
Expand Down Expand Up @@ -436,6 +445,10 @@ def set_config(self, configs):
"""
set_global_config(configs)
warn(
"New config is set, please run deduce() "
"and compile() again to update the simulation code."
)

def show_config(self, data_names: Union[List[str], Tuple[str]] = ["cs1", "cs2", "eff"]):
"""Return configuration options that affect data_names.
Expand Down
91 changes: 59 additions & 32 deletions appletree/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ class Constant(Config):
"""Constant is a special config which takes only certain value."""

value = None
required_parameter = None

def build(self, llh_name: Optional[str] = None):
"""Set value of Constant."""
Expand Down Expand Up @@ -139,6 +140,8 @@ class Map(Config):
"""

required_parameter = None

def build(self, llh_name: Optional[str] = None):
"""Cache the map to jnp.array."""

Expand Down Expand Up @@ -287,6 +290,41 @@ class SigmaMap(Config):
def build(self, llh_name: Optional[str] = None):
"""Read maps."""
self.llh_name = llh_name
_configs = self.get_configs()

_configs_default = self.get_default()

maps = dict()
sigmas = ["median", "lower", "upper"]
for i, sigma in enumerate(sigmas):
if isinstance(_configs_default, list):
maps[sigma] = Map(name=self.name + f"_{sigma}", default=_configs_default[i])
else:
# If only one file is given, then use the same file for all sigmas
maps[sigma] = Map(name=self.name + f"_{sigma}", default=_configs_default)

if maps[sigma].name not in _cached_configs.keys():
_cached_configs[maps[sigma].name] = dict()

# In case some plugins only use the median
# and may already update the map name in `_cached_configs`
if isinstance(_cached_configs[maps[sigma].name], dict):
if isinstance(_configs, list):
_cached_configs[maps[sigma].name].update({self.llh_name: _configs[i]})
else:
# If only one file is given, then use the same file for all sigmas
_cached_configs[maps[sigma].name].update({self.llh_name: _configs})

setattr(self, sigma, maps[sigma])

self.median.build(llh_name=self.llh_name) # type: ignore
self.lower.build(llh_name=self.llh_name) # type: ignore
self.upper.build(llh_name=self.llh_name) # type: ignore

if isinstance(_configs, list) and len(_configs) > 4:
raise ValueError(f"You give too much information in {self.name} configs.")

def get_configs(self):
if self.name in _cached_configs:
_configs = _cached_configs[self.name]
else:
Expand All @@ -296,51 +334,38 @@ def build(self, llh_name: Optional[str] = None):

if isinstance(_configs, dict):
try:
self._configs = _configs[llh_name]
return _configs[self.llh_name]
except KeyError:
mesg = (
f"You specified {self.name} as a dictionary. "
f"The key of it should be the name of one "
f"of the likelihood, but it is {llh_name}."
f"of the likelihood, but it is {self.llh_name}."
)
raise ValueError(mesg)
else:
self._configs = _configs

self._configs_default = self.get_default()

maps = dict()
sigmas = ["median", "lower", "upper"]
for i, sigma in enumerate(sigmas):
maps[sigma] = Map(name=self.name + f"_{sigma}", default=self._configs_default[i])
if maps[sigma].name not in _cached_configs.keys():
_cached_configs[maps[sigma].name] = dict()
if isinstance(_cached_configs[maps[sigma].name], dict):
# In case some plugins only use the median
# and may already update the map name in `_cached_configs`
_cached_configs[maps[sigma].name].update({self.llh_name: self._configs[i]})
setattr(self, sigma, maps[sigma])

self.median.build(llh_name=self.llh_name) # type: ignore
self.lower.build(llh_name=self.llh_name) # type: ignore
self.upper.build(llh_name=self.llh_name) # type: ignore

if len(self._configs) > 4:
raise ValueError(f"You give too much information in {self.name} configs.")
return _configs

@property
def required_parameter(self):
_configs = self.get_configs()
# Find required parameter
if len(self._configs) == 4:
self.required_parameter = self._configs[-1]
print(
f"{self.llh_name} is using the parameter "
f"{self.required_parameter} in {self.name} map."
)
if isinstance(_configs, list):
if len(_configs) == 4:
print(
f"{self.llh_name} is using the parameter " f"{_configs[-1]} in {self.name} map."
)
return _configs[-1]
else:
return self.name + "_sigma"
else:
self.required_parameter = self.name + "_sigma"
return None

def apply(self, pos, parameters):
"""Apply SigmaMap with sigma and position."""
sigma = parameters[self.required_parameter]
if self.required_parameter is None:
sigma = 1.0
else:
sigma = parameters[self.required_parameter]
median = self.median.apply(pos)
lower = self.lower.apply(pos)
upper = self.upper.apply(pos)
Expand All @@ -360,6 +385,8 @@ class ConstantSet(Config):
"""

required_parameter = None

def build(self, llh_name: Optional[str] = None):
"""Set value of Constant."""
if self.name in _cached_configs:
Expand Down
33 changes: 3 additions & 30 deletions appletree/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ def __init__(self, instruct, par_config=None):
if "url_base" in instruct.keys():
self.update_url_base(instruct["url_base"])

self.set_instruct(instruct)
if "configs" in instruct.keys():
self.set_config(instruct["configs"])
self.instruct = instruct
self.config = instruct.get("configs", {})
set_global_config(self.config)

self.backend_h5 = instruct.get("backend_h5", None)

Expand Down Expand Up @@ -358,33 +358,6 @@ def update_parameter_config(self, likelihoods):
self.par_config.pop(p)
return needed_parameters

def set_instruct(self, instructs):
"""Set instruction.
:param instructs: dict, instruction file name or dictionary
"""
if not hasattr(self, "instruct"):
self.instruct = dict()

# update instructuration only in this Context
self.instruct.update(instructs)

def set_config(self, configs):
"""Set new configuration options.
:param configs: dict, configuration file name or dictionary
"""
if not hasattr(self, "config"):
self.config = dict()

# update configuration only in this Context
self.config.update(configs)

# also store required configurations to appletree.share
set_global_config(configs)

def lineage(self, data_name: str = "cs2"):
"""Return lineage of plugins."""
assert isinstance(data_name, str)
Expand Down
6 changes: 3 additions & 3 deletions appletree/plugins/efficiency.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def simulate(self, key, parameters, s2_area):
class S1ReconEff(Plugin):
depends_on = ["num_s1_phd"]
provides = ["acc_s1_recon_eff"]
parameters = ("s1_eff_3f_sigma",)
# parameters = ("s1_eff_3f_sigma",)

@partial(jit, static_argnums=(0,))
def simulate(self, key, parameters, num_s1_phd):
Expand All @@ -51,7 +51,7 @@ def simulate(self, key, parameters, num_s1_phd):
class S1CutAccept(Plugin):
depends_on = ["s1_area"]
provides = ["cut_acc_s1"]
parameters = ("s1_cut_acc_sigma",)
# parameters = ("s1_cut_acc_sigma",)

@partial(jit, static_argnums=(0,))
def simulate(self, key, parameters, s1_area):
Expand All @@ -71,7 +71,7 @@ def simulate(self, key, parameters, s1_area):
class S2CutAccept(Plugin):
depends_on = ["s2_area"]
provides = ["cut_acc_s2"]
parameters = ("s2_cut_acc_sigma",)
# parameters = ("s2_cut_acc_sigma",)

@partial(jit, static_argnums=(0,))
def simulate(self, key, parameters, s2_area):
Expand Down

0 comments on commit efbf91e

Please sign in to comment.