Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multiple nodep_data_names support #187

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
41 changes: 29 additions & 12 deletions appletree/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,22 +268,22 @@ def dependencies_deduce(
self,
data_names: Union[List[str], Tuple[str]] = ["cs1", "cs2", "eff"],
dependencies: Optional[List[Dict]] = None,
nodep_data_name: str = "batch_size",
nodep_data_names: Union[List[str], Tuple[str]] = "batch_size",
mhliu0001 marked this conversation as resolved.
Show resolved Hide resolved
) -> list:
"""Deduce dependencies.

Args:
data_names: data names that simulation will output.
dependencies: dependency tree.
nodep_data_name: data_name without dependency will not be deduced.
nodep_data_names: data_name without dependency will not be deduced.

"""
if dependencies is None:
dependencies = []

for data_name in data_names:
# usually `batch_size` have no dependency
if data_name == nodep_data_name:
if data_name in nodep_data_names:
continue
try:
dependencies.append(
Expand All @@ -298,12 +298,12 @@ def dependencies_deduce(

for data_name in data_names:
# `batch_size` has no dependency
if data_name == nodep_data_name:
if data_name in nodep_data_names:
continue
dependencies = self.dependencies_deduce(
data_names=self._plugin_class_registry[data_name].depends_on,
dependencies=dependencies,
nodep_data_name=nodep_data_name,
nodep_data_names=nodep_data_names,
)

return dependencies
Expand Down Expand Up @@ -335,7 +335,7 @@ def flush_source_code(
self,
data_names: Union[List[str], Tuple[str]] = ["cs1", "cs2", "eff"],
func_name: str = "simulate",
nodep_data_name: str = "batch_size",
nodep_data_names: Union[List[str], Tuple[str]] = ["batch_size"],
):
"""Infer the simulation code from the dependency tree."""
self.func_name = func_name
Expand All @@ -345,6 +345,13 @@ def flush_source_code(
if isinstance(data_names, str):
data_names = [data_names]

if not isinstance(nodep_data_names, (list, str)):
raise RuntimeError(
f"nodep_data_names must be list or str, but given {type(nodep_data_names)}"
)
if isinstance(nodep_data_names, str):
nodep_data_names = [nodep_data_names]

instances = set()

code = ""
Expand All @@ -367,11 +374,19 @@ def flush_source_code(

# define functions
code += "\n"
if nodep_data_name == "batch_size":
code += "@partial(jit, static_argnums=(1, ))\n"
batch_size_index = next(
(
index
for index, data_name in enumerate(nodep_data_names)
if data_name == "batch_size"
),
-1,
)
if batch_size_index >= 0:
code += f"@partial(jit, static_argnums=({batch_size_index + 1}, ))\n"
else:
code += "@jit\n"
code += f"def {func_name}(key, {nodep_data_name}, parameters):\n"
code += f"def {func_name}(key, {', '.join(nodep_data_names)}, parameters):\n"

for work in self.worksheet:
provides = "key, " + ", ".join(work[1])
Expand Down Expand Up @@ -404,7 +419,7 @@ def deduce(
self,
data_names: Union[List[str], Tuple[str]] = ["cs1", "cs2"],
func_name: str = "simulate",
nodep_data_name: str = "batch_size",
nodep_data_names: Union[List[str], Tuple[str]] = ["batch_size"],
force_no_eff: bool = False,
):
"""Deduce workflow and code.
Expand All @@ -418,6 +433,8 @@ def deduce(
"""
if not isinstance(data_names, (list, tuple)):
raise ValueError(f"Unsupported data_names type {type(data_names)}!")
if not isinstance(nodep_data_names, (list, tuple)):
raise ValueError(f"Unsupported nodep_data_names type {type(nodep_data_names)}!")
# make sure that 'eff' is the last data_name
data_names = list(data_names)
if "eff" in data_names:
Expand All @@ -428,9 +445,9 @@ def deduce(
# track status of component
self.force_no_eff = True

dependencies = self.dependencies_deduce(data_names, nodep_data_name=nodep_data_name)
dependencies = self.dependencies_deduce(data_names, nodep_data_names=nodep_data_names)
self.dependencies_simplify(dependencies)
self.flush_source_code(data_names, func_name, nodep_data_name)
self.flush_source_code(data_names, func_name, nodep_data_names)

def compile(self):
"""Build simulation function and cache it to share._cached_functions."""
Expand Down
Loading