diff --git a/.readthedocs.yml b/.readthedocs.yml index b961408b..eceb71f7 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -7,10 +7,11 @@ sphinx: configuration: docs/source/conf.py build: - image: latest + os: ubuntu-22.04 + tools: + python: "3.9" python: - version: "3.8" install: - requirements: requirements.txt - requirements: docs/requirements-docs.txt diff --git a/appletree/component.py b/appletree/component.py index 4dc6a2ce..95c84f13 100644 --- a/appletree/component.py +++ b/appletree/component.py @@ -271,14 +271,23 @@ def dependencies_simplify(self, dependencies): """Simplify the dependencies.""" already_seen = [] self.worksheet = [] - self.needed_parameters.add(self.rate_name) - for plugin in dependencies[::-1]: - plugin = plugin["plugin"] + # Reinitialize needed_parameters + # because sometimes user will deduce(& compile) after changing configs + self.needed_parameters: Set[str] = set() + # Add rate_name to needed_parameters only when it's not empty + if self.rate_name != "": + self.needed_parameters.add(self.rate_name) + for _plugin in dependencies[::-1]: + plugin = _plugin["plugin"] if plugin.__name__ in already_seen: continue self.worksheet.append([plugin.__name__, plugin.provides, plugin.depends_on]) already_seen.append(plugin.__name__) self.needed_parameters |= set(plugin.parameters) + # Add needed_parameters from config + for config in plugin.takes_config.values(): + if config.required_parameter is not None: + self.needed_parameters |= {config.required_parameter} def flush_source_code( self, @@ -436,6 +445,10 @@ def set_config(self, configs): """ set_global_config(configs) + warn( + "New config is set, please run deduce() " + "and compile() again to update the simulation code." + ) def show_config(self, data_names: Union[List[str], Tuple[str]] = ["cs1", "cs2", "eff"]): """Return configuration options that affect data_names. diff --git a/appletree/config.py b/appletree/config.py index cdc3e764..0a3c697f 100644 --- a/appletree/config.py +++ b/appletree/config.py @@ -104,6 +104,7 @@ class Constant(Config): """Constant is a special config which takes only certain value.""" value = None + required_parameter = None def build(self, llh_name: Optional[str] = None): """Set value of Constant.""" @@ -139,6 +140,8 @@ class Map(Config): """ + required_parameter = None + def build(self, llh_name: Optional[str] = None): """Cache the map to jnp.array.""" @@ -287,6 +290,41 @@ class SigmaMap(Config): def build(self, llh_name: Optional[str] = None): """Read maps.""" self.llh_name = llh_name + _configs = self.get_configs() + + _configs_default = self.get_default() + + maps = dict() + sigmas = ["median", "lower", "upper"] + for i, sigma in enumerate(sigmas): + if isinstance(_configs_default, list): + maps[sigma] = Map(name=self.name + f"_{sigma}", default=_configs_default[i]) + else: + # If only one file is given, then use the same file for all sigmas + maps[sigma] = Map(name=self.name + f"_{sigma}", default=_configs_default) + + if maps[sigma].name not in _cached_configs.keys(): + _cached_configs[maps[sigma].name] = dict() + + # In case some plugins only use the median + # and may already update the map name in `_cached_configs` + if isinstance(_cached_configs[maps[sigma].name], dict): + if isinstance(_configs, list): + _cached_configs[maps[sigma].name].update({self.llh_name: _configs[i]}) + else: + # If only one file is given, then use the same file for all sigmas + _cached_configs[maps[sigma].name].update({self.llh_name: _configs}) + + setattr(self, sigma, maps[sigma]) + + self.median.build(llh_name=self.llh_name) # type: ignore + self.lower.build(llh_name=self.llh_name) # type: ignore + self.upper.build(llh_name=self.llh_name) # type: ignore + + if isinstance(_configs, list) and len(_configs) > 4: + raise ValueError(f"You give too much information in {self.name} configs.") + + def get_configs(self): if self.name in _cached_configs: _configs = _cached_configs[self.name] else: @@ -296,51 +334,38 @@ def build(self, llh_name: Optional[str] = None): if isinstance(_configs, dict): try: - self._configs = _configs[llh_name] + return _configs[self.llh_name] except KeyError: mesg = ( f"You specified {self.name} as a dictionary. " f"The key of it should be the name of one " - f"of the likelihood, but it is {llh_name}." + f"of the likelihood, but it is {self.llh_name}." ) raise ValueError(mesg) else: - self._configs = _configs - - self._configs_default = self.get_default() - - maps = dict() - sigmas = ["median", "lower", "upper"] - for i, sigma in enumerate(sigmas): - maps[sigma] = Map(name=self.name + f"_{sigma}", default=self._configs_default[i]) - if maps[sigma].name not in _cached_configs.keys(): - _cached_configs[maps[sigma].name] = dict() - if isinstance(_cached_configs[maps[sigma].name], dict): - # In case some plugins only use the median - # and may already update the map name in `_cached_configs` - _cached_configs[maps[sigma].name].update({self.llh_name: self._configs[i]}) - setattr(self, sigma, maps[sigma]) - - self.median.build(llh_name=self.llh_name) # type: ignore - self.lower.build(llh_name=self.llh_name) # type: ignore - self.upper.build(llh_name=self.llh_name) # type: ignore - - if len(self._configs) > 4: - raise ValueError(f"You give too much information in {self.name} configs.") + return _configs + @property + def required_parameter(self): + _configs = self.get_configs() # Find required parameter - if len(self._configs) == 4: - self.required_parameter = self._configs[-1] - print( - f"{self.llh_name} is using the parameter " - f"{self.required_parameter} in {self.name} map." - ) + if isinstance(_configs, list): + if len(_configs) == 4: + print( + f"{self.llh_name} is using the parameter " f"{_configs[-1]} in {self.name} map." + ) + return _configs[-1] + else: + return self.name + "_sigma" else: - self.required_parameter = self.name + "_sigma" + return None def apply(self, pos, parameters): """Apply SigmaMap with sigma and position.""" - sigma = parameters[self.required_parameter] + if self.required_parameter is None: + sigma = 1.0 + else: + sigma = parameters[self.required_parameter] median = self.median.apply(pos) lower = self.lower.apply(pos) upper = self.upper.apply(pos) @@ -360,6 +385,8 @@ class ConstantSet(Config): """ + required_parameter = None + def build(self, llh_name: Optional[str] = None): """Set value of Constant.""" if self.name in _cached_configs: diff --git a/appletree/context.py b/appletree/context.py index 58756803..e75cd584 100644 --- a/appletree/context.py +++ b/appletree/context.py @@ -35,9 +35,9 @@ def __init__(self, instruct, par_config=None): if "url_base" in instruct.keys(): self.update_url_base(instruct["url_base"]) - self.set_instruct(instruct) - if "configs" in instruct.keys(): - self.set_config(instruct["configs"]) + self.instruct = instruct + self.config = instruct.get("configs", {}) + set_global_config(self.config) self.backend_h5 = instruct.get("backend_h5", None) @@ -358,33 +358,6 @@ def update_parameter_config(self, likelihoods): self.par_config.pop(p) return needed_parameters - def set_instruct(self, instructs): - """Set instruction. - - :param instructs: dict, instruction file name or dictionary - - """ - if not hasattr(self, "instruct"): - self.instruct = dict() - - # update instructuration only in this Context - self.instruct.update(instructs) - - def set_config(self, configs): - """Set new configuration options. - - :param configs: dict, configuration file name or dictionary - - """ - if not hasattr(self, "config"): - self.config = dict() - - # update configuration only in this Context - self.config.update(configs) - - # also store required configurations to appletree.share - set_global_config(configs) - def lineage(self, data_name: str = "cs2"): """Return lineage of plugins.""" assert isinstance(data_name, str) diff --git a/appletree/plugins/efficiency.py b/appletree/plugins/efficiency.py index 55a8a248..36d91b12 100644 --- a/appletree/plugins/efficiency.py +++ b/appletree/plugins/efficiency.py @@ -31,7 +31,7 @@ def simulate(self, key, parameters, s2_area): class S1ReconEff(Plugin): depends_on = ["num_s1_phd"] provides = ["acc_s1_recon_eff"] - parameters = ("s1_eff_3f_sigma",) + # parameters = ("s1_eff_3f_sigma",) @partial(jit, static_argnums=(0,)) def simulate(self, key, parameters, num_s1_phd): @@ -51,7 +51,7 @@ def simulate(self, key, parameters, num_s1_phd): class S1CutAccept(Plugin): depends_on = ["s1_area"] provides = ["cut_acc_s1"] - parameters = ("s1_cut_acc_sigma",) + # parameters = ("s1_cut_acc_sigma",) @partial(jit, static_argnums=(0,)) def simulate(self, key, parameters, s1_area): @@ -71,7 +71,7 @@ def simulate(self, key, parameters, s1_area): class S2CutAccept(Plugin): depends_on = ["s2_area"] provides = ["cut_acc_s2"] - parameters = ("s2_cut_acc_sigma",) + # parameters = ("s2_cut_acc_sigma",) @partial(jit, static_argnums=(0,)) def simulate(self, key, parameters, s2_area):