Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set required_parameter as method of Config #119

Merged
merged 1 commit into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions appletree/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,9 @@ def dependencies_simplify(self, dependencies):
self.needed_parameters |= set(plugin.parameters)
# Add needed_parameters from config
for config in plugin.takes_config.values():
if config.required_parameter is not None:
self.needed_parameters |= {config.required_parameter}
required_parameter = config.required_parameter(self.llh_name)
if required_parameter is not None:
self.needed_parameters |= {required_parameter}

def flush_source_code(
self,
Expand Down
64 changes: 43 additions & 21 deletions appletree/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def wrapped(plugin_class):
class Config:
"""Configuration option taken by a appletree plugin."""

llh_name: Optional[str] = None

def __init__(
self,
name: str,
Expand Down Expand Up @@ -98,13 +100,15 @@ def build(self, llh_name: Optional[str] = None):
"""Build configuration, set attributes to Config instance."""
raise NotImplementedError

def required_parameter(self, llh_name=None):
return None


@export
class Constant(Config):
"""Constant is a special config which takes only certain value."""

value = None
required_parameter = None

def build(self, llh_name: Optional[str] = None):
"""Set value of Constant."""
Expand Down Expand Up @@ -140,8 +144,6 @@ class Map(Config):

"""

required_parameter = None

def build(self, llh_name: Optional[str] = None):
"""Cache the map to jnp.array."""

Expand Down Expand Up @@ -312,48 +314,65 @@ def build(self, llh_name: Optional[str] = None):
for i, sigma in enumerate(sigmas):
# propagate _configs_default to Map instances
if isinstance(_configs_default, list):
maps[sigma] = Map(name=self.name + f"_{sigma}", default=_configs_default[i])
default = _configs_default[i]
else:
if not isinstance(_configs_default, str):
raise ValueError(
f"If {self.name}'s default configuration is not a list, "
"then it should be a string."
)
# If only one file is given, then use the same file for all sigmas
maps[sigma] = Map(name=self.name + f"_{sigma}", default=_configs_default)
default = _configs_default
maps[sigma] = Map(name=self.name + f"_{sigma}", default=default)

if maps[sigma].name not in _cached_configs.keys():
_cached_configs[maps[sigma].name] = dict()
if self.llh_name is None:
# if llh_name is not specified, no need to update _cached_configs
continue

# In case some plugins only use the median
# and may already update the map name in `_cached_configs`
if maps[sigma].name not in _cached_configs.keys():
_cached_configs[maps[sigma].name] = dict()
if isinstance(_cached_configs[maps[sigma].name], dict):
if isinstance(_configs, list):
_cached_configs[maps[sigma].name].update({self.llh_name: _configs[i]})
value = _configs[i]
else:
if not isinstance(_configs, str):
raise ValueError(
f"If {self.name}'s configuration is not a list, "
"then it should be a string."
)
# If only one file is given, then use the same file for all sigmas
_cached_configs[maps[sigma].name].update({self.llh_name: _configs})
value = _configs
_value = _cached_configs[maps[sigma].name].get(self.llh_name, value)
if _value != value:
raise ValueError(
f"You give different values for {self.name} in "
f"configs, find {_value} and {value}."
)
_cached_configs[maps[sigma].name].update({self.llh_name: value})

setattr(self, sigma, maps[sigma])

self.median.build(llh_name=self.llh_name) # type: ignore
self.lower.build(llh_name=self.llh_name) # type: ignore
self.upper.build(llh_name=self.llh_name) # type: ignore

if self.required_parameter is not None:
required_parameter = self.required_parameter()
if required_parameter is not None:
print(
f"{self.llh_name}'s map {self.name} is using "
f"the parameter {self.required_parameter}."
f"the parameter {required_parameter}."
)
else:
print(f"{self.llh_name}'s map {self.name} is static and not using any parameter.")

def get_configs(self):
def get_configs(self, llh_name=None):
"""Get configs of SigmaMap."""
# if llh_name is not specified, use the attribute of SigmaMap
if llh_name is None and self.llh_name is not None:
llh_name = self.llh_name

if self.name in _cached_configs:
_configs = _cached_configs[self.name]
else:
Expand All @@ -362,21 +381,26 @@ def get_configs(self):
_cached_configs[self.name] = _configs

if isinstance(_configs, dict):
if llh_name is None:
raise ValueError(
f"You specified {self.name} as a dictionary in _cached_configs. "
"The key of it should be the name of one of the likelihood, but it is None."
)
try:
return _configs[self.llh_name]
return _configs[llh_name]
except KeyError:
mesg = (
f"You specified {self.name} as a dictionary. "
f"The key of it should be the name of one "
f"of the likelihood, but it is {self.llh_name}."
f"of the likelihood, but it is {llh_name}."
)
raise ValueError(mesg)
else:
return _configs

@property
def required_parameter(self):
_configs = self.get_configs()
def required_parameter(self, llh_name=None):
"""Get required parameter of SigmaMap."""
_configs = self.get_configs(llh_name=llh_name)
# Find required parameter
if isinstance(_configs, list):
if len(_configs) == 4:
Expand All @@ -388,10 +412,10 @@ def required_parameter(self):

def apply(self, pos, parameters):
"""Apply SigmaMap with sigma and position."""
if self.required_parameter is None:
if self.required_parameter() is None:
sigma = 1.0
else:
sigma = parameters[self.required_parameter]
sigma = parameters[self.required_parameter()]
median = self.median.apply(pos)
lower = self.lower.apply(pos)
upper = self.upper.apply(pos)
Expand All @@ -411,8 +435,6 @@ class ConstantSet(Config):

"""

required_parameter = None

def build(self, llh_name: Optional[str] = None):
"""Set value of Constant."""
if self.name in _cached_configs:
Expand Down
Loading