diff --git a/appletree/component.py b/appletree/component.py index de7c5d6..7568d14 100644 --- a/appletree/component.py +++ b/appletree/component.py @@ -268,14 +268,14 @@ def dependencies_deduce( self, data_names: Union[List[str], Tuple[str]] = ["cs1", "cs2", "eff"], dependencies: Optional[List[Dict]] = None, - nodep_data_name: str = "batch_size", + nodep_data_names: Union[List[str], Tuple[str]] = ["batch_size"], ) -> list: """Deduce dependencies. Args: data_names: data names that simulation will output. dependencies: dependency tree. - nodep_data_name: data_name without dependency will not be deduced. + nodep_data_names: data_name without dependency will not be deduced. """ if dependencies is None: @@ -283,7 +283,7 @@ def dependencies_deduce( for data_name in data_names: # usually `batch_size` have no dependency - if data_name == nodep_data_name: + if data_name in nodep_data_names: continue try: dependencies.append( @@ -298,12 +298,12 @@ def dependencies_deduce( for data_name in data_names: # `batch_size` has no dependency - if data_name == nodep_data_name: + if data_name in nodep_data_names: continue dependencies = self.dependencies_deduce( data_names=self._plugin_class_registry[data_name].depends_on, dependencies=dependencies, - nodep_data_name=nodep_data_name, + nodep_data_names=nodep_data_names, ) return dependencies @@ -335,7 +335,7 @@ def flush_source_code( self, data_names: Union[List[str], Tuple[str]] = ["cs1", "cs2", "eff"], func_name: str = "simulate", - nodep_data_name: str = "batch_size", + nodep_data_names: Union[List[str], Tuple[str]] = ["batch_size"], ): """Infer the simulation code from the dependency tree.""" self.func_name = func_name @@ -345,6 +345,13 @@ def flush_source_code( if isinstance(data_names, str): data_names = [data_names] + if not isinstance(nodep_data_names, (list, str)): + raise RuntimeError( + f"nodep_data_names must be list or str, but given {type(nodep_data_names)}" + ) + if isinstance(nodep_data_names, str): + nodep_data_names = [nodep_data_names] + instances = set() code = "" @@ -367,11 +374,19 @@ def flush_source_code( # define functions code += "\n" - if nodep_data_name == "batch_size": - code += "@partial(jit, static_argnums=(1, ))\n" + batch_size_index = next( + ( + index + for index, data_name in enumerate(nodep_data_names) + if data_name == "batch_size" + ), + -1, + ) + if batch_size_index >= 0: + code += f"@partial(jit, static_argnums=({batch_size_index + 1}, ))\n" else: code += "@jit\n" - code += f"def {func_name}(key, {nodep_data_name}, parameters):\n" + code += f"def {func_name}(key, {', '.join(nodep_data_names)}, parameters):\n" for work in self.worksheet: provides = "key, " + ", ".join(work[1]) @@ -404,7 +419,7 @@ def deduce( self, data_names: Union[List[str], Tuple[str]] = ["cs1", "cs2"], func_name: str = "simulate", - nodep_data_name: str = "batch_size", + nodep_data_names: Union[List[str], Tuple[str]] = ["batch_size"], force_no_eff: bool = False, ): """Deduce workflow and code. @@ -412,12 +427,14 @@ def deduce( Args: data_names: data names that simulation will output. func_name: name of the simulation function, used to cache it. - nodep_data_name: data_name without dependency will not be deduced. + nodep_data_names: data_name without dependency will not be deduced. force_no_eff: force to ignore the efficiency, used in yield prediction. """ if not isinstance(data_names, (list, tuple)): raise ValueError(f"Unsupported data_names type {type(data_names)}!") + if not isinstance(nodep_data_names, (list, tuple)): + raise ValueError(f"Unsupported nodep_data_names type {type(nodep_data_names)}!") # make sure that 'eff' is the last data_name data_names = list(data_names) if "eff" in data_names: @@ -428,9 +445,9 @@ def deduce( # track status of component self.force_no_eff = True - dependencies = self.dependencies_deduce(data_names, nodep_data_name=nodep_data_name) + dependencies = self.dependencies_deduce(data_names, nodep_data_names=nodep_data_names) self.dependencies_simplify(dependencies) - self.flush_source_code(data_names, func_name, nodep_data_name) + self.flush_source_code(data_names, func_name, nodep_data_names) def compile(self): """Build simulation function and cache it to share._cached_functions.""" @@ -510,7 +527,7 @@ def show_config(self, data_names: Union[List[str], Tuple[str]] = ["cs1", "cs2", """ dependencies = self.dependencies_deduce( data_names, - nodep_data_name="batch_size", + nodep_data_names=["batch_size"], ) r = [] seen = [] diff --git a/tests/test_component.py b/tests/test_component.py index 70ed06b..bb95a62 100644 --- a/tests/test_component.py +++ b/tests/test_component.py @@ -99,3 +99,35 @@ def benchmark(key): er.simulate_hist(key, batch_size, parameters) with pytest.raises(RuntimeError): key, r = er.multiple_simulations(key, batch_size, parameters, 5, apply_eff=True) + + # test for multiple nodep_data_names + with pytest.raises(RuntimeError): + er.deduce( + data_names=("cs1", "cs2"), + func_name="er_sim", + nodep_data_names=["x", "y", "z", "num_photon", "num_electron"], + force_no_eff=True, + ) + _cached_functions.clear() + er.deduce( + data_names=("cs1", "cs2"), + func_name="er_sim", + nodep_data_names=["x", "y", "z", "num_photon", "num_electron"], + force_no_eff=True, + ) + er.compile() + er.lineage_hash + x, y, z = np.zeros((3, batch_size)) + num_photon = np.random.randint(1, 100) + num_electron = np.random.randint(1, 100) + er.simulate(key, x, y, z, num_photon, num_electron, parameters) + + _cached_functions.clear() + er.deduce( + data_names=("cs1", "cs2"), + func_name="er_sim", + nodep_data_names=["batch_size", "num_photon", "num_electron"], + force_no_eff=True, + ) + er.compile() + er.simulate(key, batch_size, num_photon, num_electron, parameters)