From fbfc5b612a0269a45ed929e5b4be9cbe0391c3a4 Mon Sep 17 00:00:00 2001 From: dachengx Date: Tue, 14 Nov 2023 11:17:19 -0600 Subject: [PATCH] Set required_parameter as method of Config --- appletree/component.py | 5 ++-- appletree/config.py | 64 ++++++++++++++++++++++++++++-------------- 2 files changed, 46 insertions(+), 23 deletions(-) diff --git a/appletree/component.py b/appletree/component.py index 95c84f13..f4af7ee4 100644 --- a/appletree/component.py +++ b/appletree/component.py @@ -286,8 +286,9 @@ def dependencies_simplify(self, dependencies): self.needed_parameters |= set(plugin.parameters) # Add needed_parameters from config for config in plugin.takes_config.values(): - if config.required_parameter is not None: - self.needed_parameters |= {config.required_parameter} + required_parameter = config.required_parameter(self.llh_name) + if required_parameter is not None: + self.needed_parameters |= {required_parameter} def flush_source_code( self, diff --git a/appletree/config.py b/appletree/config.py index 3ef1c1d9..7d2d22be 100644 --- a/appletree/config.py +++ b/appletree/config.py @@ -63,6 +63,8 @@ def wrapped(plugin_class): class Config: """Configuration option taken by a appletree plugin.""" + llh_name: Optional[str] = None + def __init__( self, name: str, @@ -98,13 +100,15 @@ def build(self, llh_name: Optional[str] = None): """Build configuration, set attributes to Config instance.""" raise NotImplementedError + def required_parameter(self, llh_name=None): + return None + @export class Constant(Config): """Constant is a special config which takes only certain value.""" value = None - required_parameter = None def build(self, llh_name: Optional[str] = None): """Set value of Constant.""" @@ -140,8 +144,6 @@ class Map(Config): """ - required_parameter = None - def build(self, llh_name: Optional[str] = None): """Cache the map to jnp.array.""" @@ -312,7 +314,7 @@ def build(self, llh_name: Optional[str] = None): for i, sigma in enumerate(sigmas): # propagate _configs_default to Map instances if isinstance(_configs_default, list): - maps[sigma] = Map(name=self.name + f"_{sigma}", default=_configs_default[i]) + default = _configs_default[i] else: if not isinstance(_configs_default, str): raise ValueError( @@ -320,16 +322,20 @@ def build(self, llh_name: Optional[str] = None): "then it should be a string." ) # If only one file is given, then use the same file for all sigmas - maps[sigma] = Map(name=self.name + f"_{sigma}", default=_configs_default) + default = _configs_default + maps[sigma] = Map(name=self.name + f"_{sigma}", default=default) - if maps[sigma].name not in _cached_configs.keys(): - _cached_configs[maps[sigma].name] = dict() + if self.llh_name is None: + # if llh_name is not specified, no need to update _cached_configs + continue # In case some plugins only use the median # and may already update the map name in `_cached_configs` + if maps[sigma].name not in _cached_configs.keys(): + _cached_configs[maps[sigma].name] = dict() if isinstance(_cached_configs[maps[sigma].name], dict): if isinstance(_configs, list): - _cached_configs[maps[sigma].name].update({self.llh_name: _configs[i]}) + value = _configs[i] else: if not isinstance(_configs, str): raise ValueError( @@ -337,7 +343,14 @@ def build(self, llh_name: Optional[str] = None): "then it should be a string." ) # If only one file is given, then use the same file for all sigmas - _cached_configs[maps[sigma].name].update({self.llh_name: _configs}) + value = _configs + _value = _cached_configs[maps[sigma].name].get(self.llh_name, value) + if _value != value: + raise ValueError( + f"You give different values for {self.name} in " + f"configs, find {_value} and {value}." + ) + _cached_configs[maps[sigma].name].update({self.llh_name: value}) setattr(self, sigma, maps[sigma]) @@ -345,15 +358,21 @@ def build(self, llh_name: Optional[str] = None): self.lower.build(llh_name=self.llh_name) # type: ignore self.upper.build(llh_name=self.llh_name) # type: ignore - if self.required_parameter is not None: + required_parameter = self.required_parameter() + if required_parameter is not None: print( f"{self.llh_name}'s map {self.name} is using " - f"the parameter {self.required_parameter}." + f"the parameter {required_parameter}." ) else: print(f"{self.llh_name}'s map {self.name} is static and not using any parameter.") - def get_configs(self): + def get_configs(self, llh_name=None): + """Get configs of SigmaMap.""" + # if llh_name is not specified, use the attribute of SigmaMap + if llh_name is None and self.llh_name is not None: + llh_name = self.llh_name + if self.name in _cached_configs: _configs = _cached_configs[self.name] else: @@ -362,21 +381,26 @@ def get_configs(self): _cached_configs[self.name] = _configs if isinstance(_configs, dict): + if llh_name is None: + raise ValueError( + f"You specified {self.name} as a dictionary in _cached_configs. " + "The key of it should be the name of one of the likelihood, but it is None." + ) try: - return _configs[self.llh_name] + return _configs[llh_name] except KeyError: mesg = ( f"You specified {self.name} as a dictionary. " f"The key of it should be the name of one " - f"of the likelihood, but it is {self.llh_name}." + f"of the likelihood, but it is {llh_name}." ) raise ValueError(mesg) else: return _configs - @property - def required_parameter(self): - _configs = self.get_configs() + def required_parameter(self, llh_name=None): + """Get required parameter of SigmaMap.""" + _configs = self.get_configs(llh_name=llh_name) # Find required parameter if isinstance(_configs, list): if len(_configs) == 4: @@ -388,10 +412,10 @@ def required_parameter(self): def apply(self, pos, parameters): """Apply SigmaMap with sigma and position.""" - if self.required_parameter is None: + if self.required_parameter() is None: sigma = 1.0 else: - sigma = parameters[self.required_parameter] + sigma = parameters[self.required_parameter()] median = self.median.apply(pos) lower = self.lower.apply(pos) upper = self.upper.apply(pos) @@ -411,8 +435,6 @@ class ConstantSet(Config): """ - required_parameter = None - def build(self, llh_name: Optional[str] = None): """Set value of Constant.""" if self.name in _cached_configs: