diff --git a/CODING_GUIDELINES.md b/CODING_GUIDELINES.md index 3e7fabb..2fd6961 100644 --- a/CODING_GUIDELINES.md +++ b/CODING_GUIDELINES.md @@ -29,6 +29,18 @@ We deviate from the [Google Python Style Guide][google-style-guide] only in the - According to subsection [_3.19.12 Imports For Typing_](https://google.github.io/styleguide/pyguide.html#31912-imports-for-typing), symbols from `typing` and `collections.abc` modules used in type annotations _"can be imported directly to keep common annotations concise and match standard typing practices"_. Following the same spirit, we allow symbols to be imported directly from third-party or internal modules when they only contain a collection of frequently used typying definitions. +### Aliasing of Modules + +According to subsection [2.2](https://google.github.io/styleguide/pyguide.html#22-imports) in certain cases it is allowed to introduce an alias for an import. +Inside JaCe the following convention is applied: + +- If the module has a standard abbreviation use that, e.g. `import numpy as np`. +- For a JaCe module use: + - If the module name is only a single word use it directly, e.g. `from jace import translator`. + - If the module name consists of multiple words use the last word prefixed with the first letters of the others, e.g. `from jace.translator import post_translator as ptranslator` or `from jace import translated_jaxpr_sdfg as tjsdfg`. + - In case of a clash use your best judgment. +- For an external module use the rule above, but prefix the name with the main package's name, e.g. `from dace.codegen import compiled_sdfg as dace_csdfg`. + ### Python usage recommendations - `pass` vs `...` (`Ellipsis`) @@ -104,7 +116,7 @@ We generate the API documentation automatically from the docstrings using [Sphin Sphinx supports the [reStructuredText][sphinx-rest] (reST) markup language for defining additional formatting options in the generated documentation, however section [_3.8 Comments and Docstrings_](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings) of the Google Python Style Guide does not specify how to use markups in docstrings. As a result, we decided to forbid reST markup in docstrings, except for the following cases: - Cross-referencing other objects using Sphinx text roles for the [Python domain](https://www.sphinx-doc.org/en/master/usage/restructuredtext/domains.html#the-python-domain) (as explained [here](https://www.sphinx-doc.org/en/master/usage/restructuredtext/domains.html#python-roles)). -- Very basic formatting markup to improve _readability_ of the generated documentation without obscuring the source docstring (e.g. ``` ``literal`` ``` strings, bulleted lists). +- Very basic formatting markup to improve _readability_ of the generated documentation without obscuring the source docstring (e.g. `` `literal` `` strings, bulleted lists). We highly encourage the [doctest] format for code examples in docstrings. In fact, doctest runs code examples and makes sure they are in sync with the codebase. diff --git a/src/jace/__init__.py b/src/jace/__init__.py index 11c5d2a..ca7e0f5 100644 --- a/src/jace/__init__.py +++ b/src/jace/__init__.py @@ -9,7 +9,7 @@ from __future__ import annotations -import jace.translator.primitive_translators as _ # noqa: F401 # Populate the internal registry. +import jace.translator.primitive_translators as _ # noqa: F401 [unused-import] # Needed to populate the internal translator registry. 
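# An illustrative sketch of the import-aliasing convention from
# CODING_GUIDELINES.md above, using module paths that appear elsewhere in this
# diff (shown here only as a commented example):
#
#     import numpy as np                                    # standard abbreviation
#     from jace import translator                           # single-word module: use it directly
#     from jace import translated_jaxpr_sdfg as tjsdfg      # multi-word module: last word, prefixed with initials
#     from jace.translator import post_translation as ptranslation
#     from dace.codegen import compiled_sdfg as dace_csdfg  # external module: prefix with the main package name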
from .__about__ import __author__, __copyright__, __license__, __version__, __version_info__ from .api import grad, jacfwd, jacrev, jit diff --git a/src/jace/api.py b/src/jace/api.py index 8afc20a..35d722a 100644 --- a/src/jace/api.py +++ b/src/jace/api.py @@ -10,18 +10,30 @@ from __future__ import annotations import functools -from typing import TYPE_CHECKING, Any, Literal, overload +from collections.abc import Callable, Mapping +from typing import Literal, ParamSpec, TypedDict, TypeVar, overload from jax import grad, jacfwd, jacrev +from typing_extensions import Unpack from jace import stages, translator -if TYPE_CHECKING: - from collections.abc import Callable, Mapping +__all__ = ["JITOptions", "grad", "jacfwd", "jacrev", "jit"] +_P = ParamSpec("_P") +_R = TypeVar("_R") -__all__ = ["grad", "jacfwd", "jacrev", "jit"] + +class JITOptions(TypedDict, total=False): + """ + All known options to `jace.jit` that influence tracing. + + Note: + Currently there are no known options, but essentially it is a subset of some + of the options that are supported by `jax.jit` together with some additional + JaCe specific ones. + """ @overload @@ -29,31 +41,32 @@ def jit( fun: Literal[None] = None, /, primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None, - **kwargs: Any, -) -> Callable[[Callable], stages.JaCeWrapped]: ... + **kwargs: Unpack[JITOptions], +) -> Callable[[Callable[_P, _R]], stages.JaCeWrapped[_P, _R]]: ... @overload def jit( - fun: Callable, + fun: Callable[_P, _R], /, primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None, - **kwargs: Any, -) -> stages.JaCeWrapped: ... + **kwargs: Unpack[JITOptions], +) -> stages.JaCeWrapped[_P, _R]: ... def jit( - fun: Callable | None = None, + fun: Callable[_P, _R] | None = None, /, primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None, - **kwargs: Any, -) -> stages.JaCeWrapped | Callable[[Callable], stages.JaCeWrapped]: + **kwargs: Unpack[JITOptions], +) -> Callable[[Callable[_P, _R]], stages.JaCeWrapped[_P, _R]] | stages.JaCeWrapped[_P, _R]: """ JaCe's replacement for `jax.jit` (just-in-time) wrapper. - It works the same way as `jax.jit` does, but instead of using XLA the - computation is lowered to DaCe. In addition it accepts some JaCe specific - arguments. + It works the same way as `jax.jit` does, but instead of lowering the + computation to XLA, it is lowered to DaCe. + The function supports a subset of the arguments that are accepted by `jax.jit()`, + currently none, and some JaCe specific ones. Args: fun: Function to wrap. @@ -61,8 +74,8 @@ def jit( If not specified the translators in the global registry are used. kwargs: Jit arguments. - Notes: - After constructions any change to `primitive_translators` has no effect. + Note: + This function is the only valid way to obtain a JaCe computation. """ if kwargs: # TODO(phimuell): Add proper name verification and exception type. @@ -70,8 +83,7 @@ def jit( f"The following arguments to 'jace.jit' are not yet supported: {', '.join(kwargs)}." ) - def wrapper(f: Callable) -> stages.JaCeWrapped: - # TODO(egparedes): Improve typing. + def wrapper(f: Callable[_P, _R]) -> stages.JaCeWrapped[_P, _R]: jace_wrapper = stages.JaCeWrapped( fun=f, primitive_translators=( diff --git a/src/jace/optimization.py b/src/jace/optimization.py index b5af4fa..5dc159b 100644 --- a/src/jace/optimization.py +++ b/src/jace/optimization.py @@ -8,7 +8,8 @@ """ JaCe specific optimizations. 
-Currently just a dummy exists for the sake of providing a callable function. +Todo: + Organize this module once it is a package. """ from __future__ import annotations @@ -19,7 +20,20 @@ if TYPE_CHECKING: - from jace import translator + from jace import translated_jaxpr_sdfg as tjsdfg + + +DEFAULT_OPTIMIZATIONS: Final[CompilerOptions] = { + "auto_optimize": True, + "simplify": True, + "persistent_transients": True, +} + +NO_OPTIMIZATIONS: Final[CompilerOptions] = { + "auto_optimize": False, + "simplify": False, + "persistent_transients": False, +} class CompilerOptions(TypedDict, total=False): @@ -35,15 +49,10 @@ class CompilerOptions(TypedDict, total=False): auto_optimize: bool simplify: bool + persistent_transients: bool -# TODO(phimuell): Add a context manager to modify the default. -DEFAULT_OPTIMIZATIONS: Final[CompilerOptions] = {"auto_optimize": True, "simplify": True} - -NO_OPTIMIZATIONS: Final[CompilerOptions] = {"auto_optimize": False, "simplify": False} - - -def jace_optimize(tsdfg: translator.TranslatedJaxprSDFG, **kwargs: Unpack[CompilerOptions]) -> None: # noqa: D417 # Missing description for kwargs +def jace_optimize(tsdfg: tjsdfg.TranslatedJaxprSDFG, **kwargs: Unpack[CompilerOptions]) -> None: # noqa: D417 [undocumented-param] """ Performs optimization of the translated SDFG _in place_. @@ -55,8 +64,12 @@ def jace_optimize(tsdfg: translator.TranslatedJaxprSDFG, **kwargs: Unpack[Compil tsdfg: The translated SDFG that should be optimized. simplify: Run the simplification pipeline. auto_optimize: Run the auto optimization pipeline (currently does nothing) + persistent_transients: Set the allocation lifetime of (non register) transients + in the SDFG to `AllocationLifetime.Persistent`, i.e. keep them allocated + between different invocations. """ - # Currently this function exists primarily for the same of existing. + # TODO(phimuell): Implement the functionality. + # Currently this function exists primarily for the sake of existing. simplify = kwargs.get("simplify", False) auto_optimize = kwargs.get("auto_optimize", False) diff --git a/src/jace/stages.py b/src/jace/stages.py index 4639b11..93e0f85 100644 --- a/src/jace/stages.py +++ b/src/jace/stages.py @@ -7,64 +7,84 @@ """ Reimplementation of the `jax.stages` module. -This module reimplements the public classes of that Jax module. -However, they are a bit different, because JaCe uses DaCe as backend. +This module reimplements the public classes of that JAX module. +However, because JaCe uses DaCe as backend they differ is some small aspects. -As in Jax JaCe has different stages, the terminology is taken from -[Jax' AOT-Tutorial](https://jax.readthedocs.io/en/latest/aot.html). +As in JAX JaCe has different stages, the terminology is taken from +[JAX' AOT-Tutorial](https://jax.readthedocs.io/en/latest/aot.html). - Stage out: - In this phase an executable Python function is translated to Jaxpr. + In this phase an executable Python function is translated to a Jaxpr. - Lower: - This will transform the Jaxpr into an SDFG equivalent. As a implementation - note, currently this and the previous step are handled as a single step. + This will transform the Jaxpr into its SDFG equivalent. - Compile: - This will turn the SDFG into an executable object, see `dace.codegen.CompiledSDFG`. + This will turn the SDFG into an executable object. - Execution: This is the actual running of the computation. -As in Jax the `stages` module give access to the last three stages, but not -the first one. 
+As in JAX the in JaCe the user only has access to the last tree stages and +staging out and lowering is handled as a single step. """ from __future__ import annotations +import contextlib import copy -from typing import TYPE_CHECKING, Any +from collections.abc import Callable, Generator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Generic, ParamSpec, TypeVar, Union -import jax as _jax +from jax import tree_util as jax_tree -from jace import optimization, translator, util -from jace.optimization import CompilerOptions -from jace.translator import post_translation as ptrans -from jace.util import dace_helper, translation_cache as tcache +from jace import api, optimization, tracing, translated_jaxpr_sdfg as tjsdfg, translator, util +from jace.optimization import CompilerOptions # Reexport for compatibility with JAX. +from jace.translator import post_translation as ptranslation +from jace.util import translation_cache as tcache if TYPE_CHECKING: - from collections.abc import Callable, Mapping, Sequence - import dace + from jax import core as jax_core __all__ = [ - "CompilerOptions", # export for compatibility with Jax. + "CompilerOptions", "JaCeCompiled", "JaCeLowered", "JaCeWrapped", "Stage", + "get_compiler_options", + "set_compiler_options", ] - -class JaCeWrapped(tcache.CachingStage["JaCeLowered"]): +#: Known compilation stages in JaCe. +Stage = Union["JaCeWrapped", "JaCeLowered", "JaCeCompiled"] + +# These are used to annotated the `Stages`, however, there are some limitations. +# First, the only stage that is fully annotated is `JaCeWrapped`. Second, since +# static arguments modify the type signature of `JaCeCompiled.__call__()`, see +# [JAX](https://jax.readthedocs.io/en/latest/aot.html#lowering-with-static-arguments) +# for more, its argument can not be annotated, only its return type can. +# However, in case of scalar return values, the return type is wrong anyway, since +# JaCe and JAX for that matter, transforms scalars to arrays. Since there is no way of +# changing that, but from a semantic point they behave the same so it should not +# matter too much. +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +class JaCeWrapped(tcache.CachingStage["JaCeLowered"], Generic[_P, _R]): """ A function ready to be specialized, lowered, and compiled. This class represents the output of functions such as `jace.jit()` and is the first stage in the translation/compilation chain of JaCe. A user should never create a `JaCeWrapped` object directly, instead `jace.jit` should be - used for that. While it supports just-in-time lowering and compilation, by - just calling it, these steps can also be performed explicitly. The lowering - performed by this stage is cached, thus if a `JaCeWrapped` object is lowered - later, with the same argument the result is taken from the cache. - Furthermore, a `JaCeWrapped` object is composable with all Jax transformations. + used. While it supports just-in-time lowering and compilation, by just + calling it, these steps can also be performed explicitly. + The lowering, performed by this stage is cached, thus if a `JaCeWrapped` + object is later lowered with the same arguments the result might be taken + from the cache. + + Furthermore, a `JaCeWrapped` object is composable with all JAX transformations, + all other stages are not. Args: fun: The function that is wrapped. @@ -72,181 +92,187 @@ class JaCeWrapped(tcache.CachingStage["JaCeLowered"]): jit_options: Options to influence the jit process. Todo: - - Support pytrees. 
- - Support keyword arguments and default values of the wrapped function. + - Support default values of the wrapped function. - Support static arguments. Note: The tracing of function will always happen with enabled `x64` mode, - which is implicitly and temporary activated while tracing. + which is implicitly and temporary activated during tracing. """ - _fun: Callable + _fun: Callable[_P, _R] _primitive_translators: dict[str, translator.PrimitiveTranslator] - _jit_options: dict[str, Any] + _jit_options: api.JITOptions def __init__( self, - fun: Callable, + fun: Callable[_P, _R], primitive_translators: Mapping[str, translator.PrimitiveTranslator], - jit_options: Mapping[str, Any], + jit_options: api.JITOptions, ) -> None: super().__init__() - # We have to shallow copy both the translator and the jit options. - # This prevents that any modifications affect `self`. - # Shallow is enough since the translators themselves are immutable. self._primitive_translators = {**primitive_translators} - # TODO(phimuell): Do we need to deepcopy the options? self._jit_options = {**jit_options} self._fun = fun - def __call__(self, *args: Any, **kwargs: Any) -> Any: + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: """ Executes the wrapped function, lowering and compiling as needed in one step. - The arguments passed to this function are the same as the wrapped function uses. + This function will lower and compile in one go. The function accepts the same + arguments as the original computation and the return value is unflattened. + + Note: + This function is also aware if a JAX tracing is going on. In this + case, it will forward the computation. + Currently, this function ignores the value of `jax.disable_jit()`, + however, tracing will consider this value. """ - # If we are inside a traced context, then we forward the call to the wrapped - # function. This ensures that JaCe is composable with Jax. if util.is_tracing_ongoing(*args, **kwargs): return self._fun(*args, **kwargs) lowered = self.lower(*args, **kwargs) compiled = lowered.compile() + # TODO(phimuell): Filter out static arguments return compiled(*args, **kwargs) @tcache.cached_transition - def lower(self, *args: Any, **kwargs: Any) -> JaCeLowered: + def lower(self, *args: _P.args, **kwargs: _P.kwargs) -> JaCeLowered[_R]: """ - Lower this function explicitly for the given arguments. + Lower the wrapped computation for the given arguments. - Performs the first two steps of the AOT steps described above, i.e. - trace the wrapped function with the given arguments and stage it out - to a Jaxpr. Then translate it to SDFG. The result is encapsulated - inside a `JaCeLowered` object which can later be compiled. + Performs the first two steps of the AOT steps described above, i.e. trace the + wrapped function with the given arguments and stage it out to a Jaxpr. Then + translate it to an SDFG. The result is encapsulated inside a `JaCeLowered` + object that can later be compiled. + + It should be noted that the current lowering process will hard code the strides + and the storage location of the input inside the SDFG. Thus if the SDFG is + lowered with arrays in C order, calling the compiled SDFG with FORTRAN order + will result in an error. Note: - The call to the function is cached. As key an abstract description - of the call, similar to the tracers used by Jax, is used. The tracing is always done with activated `x64` mode. 
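        Example:
            A minimal sketch of the explicit lower/compile/run chain (the wrapped
            function `f` and the arrays used here are purely illustrative):

                >>> import jax.numpy as jnp
                >>> import jace
                >>> @jace.jit
                ... def f(a, b):
                ...     return a + b
                >>> lowered = f.lower(jnp.ones(10), jnp.ones(10))
                >>> compiled = lowered.compile()
                >>> result = compiled(jnp.ones(10), jnp.ones(10))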
""" - if len(kwargs) != 0: - raise NotImplementedError("Currently only positional arguments are supported.") - - # TODO(phimuell): Currently the SDFG that we build only supports `C_CONTIGUOUS` - # memory order. Since we support the paradigm that "everything passed to - # `lower()` should also be accepted as argument to call the result", we forbid - # other memory orders here. - if not all((not util.is_array(arg)) or arg.flags["C_CONTIGUOUS"] for arg in args): - raise NotImplementedError("Currently can not yet handle strides beside 'C_CONTIGUOUS'.") - - # In Jax `float32` is the main datatype, and they go to great lengths to avoid - # some aggressive [type promotion](https://jax.readthedocs.io/en/latest/type_promotion.html). - # However, in this case we will have problems when we call the SDFG, for some - # reasons `CompiledSDFG` does not work in that case correctly, thus we enable - # it for the tracing. - with _jax.experimental.enable_x64(): - builder = translator.JaxprTranslationBuilder( - primitive_translators=self._primitive_translators - ) - jaxpr = _jax.make_jaxpr(self._fun)(*args) - trans_ctx: translator.TranslationContext = builder.translate_jaxpr(jaxpr) - - # Perform the post processing and turn it into a `TranslatedJaxprSDFG` that can - # be compiled and called later. - # NOTE: `tsdfg` was deepcopied as a side effect of post processing. - tsdfg: translator.TranslatedJaxprSDFG = ptrans.postprocess_jaxpr_sdfg( + jaxpr_maker = tracing.make_jaxpr( + fun=self._fun, + trace_options=self._jit_options, + return_out_tree=True, + ) + jaxpr, out_tree = jaxpr_maker(*args, **kwargs) + builder = translator.JaxprTranslationBuilder( + primitive_translators=self._primitive_translators + ) + trans_ctx: translator.TranslationContext = builder.translate_jaxpr(jaxpr) + + flat_call_args = jax_tree.tree_leaves((args, kwargs)) + tsdfg: tjsdfg.TranslatedJaxprSDFG = ptranslation.postprocess_jaxpr_sdfg( trans_ctx=trans_ctx, fun=self.wrapped_fun, - call_args=args, # Already linearised, since we only accept positional args. - intree=None, # Not yet implemented. + flat_call_args=flat_call_args, ) - return JaCeLowered(tsdfg) + # NOTE: `tsdfg` is deepcopied as a side effect of post processing. + return JaCeLowered(tsdfg, out_tree, trans_ctx.jaxpr) @property def wrapped_fun(self) -> Callable: - """Returns the wrapped function.""" + """Return the underlying Python function.""" return self._fun - def _make_call_description(self, *args: Any) -> tcache.StageTransformationSpec: + def _make_call_description( + self, in_tree: jax_tree.PyTreeDef, flat_call_args: Sequence[Any] + ) -> tcache.StageTransformationSpec: """ Computes the key for the `JaCeWrapped.lower()` call inside the cache. - The function will compute a full abstract description on its argument. + For all non static arguments the function will generate an abstract description + of an argument and for all static arguments the concrete value. + + Note: + The abstract description also includes storage location, i.e. if on CPU or + on GPU, and the strides of the arrays. 
""" - call_args = tuple(tcache._AbstractCallArgument.from_value(x) for x in args) - return tcache.StageTransformationSpec(stage_id=id(self), call_args=call_args) + # TODO(phimuell): Implement static arguments + flat_call_args = tuple(tcache._AbstractCallArgument.from_value(x) for x in flat_call_args) + return tcache.StageTransformationSpec( + stage_id=id(self), flat_call_args=tuple(flat_call_args), in_tree=in_tree + ) -class JaCeLowered(tcache.CachingStage["JaCeCompiled"]): +class JaCeLowered(tcache.CachingStage["JaCeCompiled"], Generic[_R]): """ Represents the original computation as an SDFG. - This class is the output type of `JaCeWrapped.lower()` and represents the - originally wrapped computation as an SDFG. This stage is followed by the - `JaCeCompiled` stage. + This class is the output type of `JaCeWrapped.lower()` and represents the original + computation as an SDFG. This stage is followed by the `JaCeCompiled` stage, by + calling `self.compile()`. A user should never directly construct a `JaCeLowered` + object directly, instead `JaCeWrapped.lower()` should be used. + + The SDFG is optimized before the compilation, see `JaCeLowered.compile()` for how to + control the process. Args: - tsdfg: The translated SDFG object representing the computation. + tsdfg: The lowered SDFG with metadata. + out_tree: The pytree describing how to unflatten the output. + jaxpr: The Jaxpr expression that was translated into an SDFG. Intended to be + used during debugging and inspection. Note: - `self` will manage the passed `tsdfg` object. Modifying it results in - undefined behavior. Although `JaCeWrapped` is composable with Jax - transformations `JaCeLowered` is not. A user should never create such - an object, instead `JaCeWrapped.lower()` should be used. + `self` will manage the passed `tsdfg` object. Modifying it results is undefined + behavior. Although `JaCeWrapped` is composable with JAX transformations + `JaCeLowered` is not. """ - _translated_sdfg: translator.TranslatedJaxprSDFG + _translated_sdfg: tjsdfg.TranslatedJaxprSDFG + _out_tree: jax_tree.PyTreeDef + _jaxpr: jax_core.ClosedJaxpr - def __init__(self, tsdfg: translator.TranslatedJaxprSDFG) -> None: + def __init__( + self, + tsdfg: tjsdfg.TranslatedJaxprSDFG, + out_tree: jax_tree.PyTreeDef, + jaxpr: jax_core.ClosedJaxpr, + ) -> None: super().__init__() self._translated_sdfg = tsdfg + self._out_tree = out_tree + self._jaxpr = jaxpr @tcache.cached_transition - def compile(self, compiler_options: CompilerOptions | None = None) -> JaCeCompiled: + def compile(self, compiler_options: CompilerOptions | None = None) -> JaCeCompiled[_R]: """ Optimize and compile the lowered SDFG using `compiler_options`. - Returns an object that encapsulates a compiled SDFG object. To influence - the various optimizations and compile options of JaCe you can use the - `compiler_options` argument. If nothing is specified - `jace.optimization.DEFAULT_OPTIMIZATIONS` will be used. + To perform the optimizations `jace_optimize()` is used. The actual options that + are forwarded to it are obtained by passing `compiler_options` to + `get_compiler_options()`, these options are also included in the + key used to cache the result. - Note: - Before `compiler_options` is forwarded to `jace_optimize()` it - will be merged with the default arguments. + Args: + compiler_options: The optimization options to use. """ # We **must** deepcopy before we do any optimization, because all optimizations # are in place, to properly cache stages, stages needs to be immutable. 
- tsdfg: translator.TranslatedJaxprSDFG = copy.deepcopy(self._translated_sdfg) - optimization.jace_optimize(tsdfg=tsdfg, **self._make_compiler_options(compiler_options)) + tsdfg: tjsdfg.TranslatedJaxprSDFG = copy.deepcopy(self._translated_sdfg) + optimization.jace_optimize(tsdfg=tsdfg, **get_compiler_options(compiler_options)) return JaCeCompiled( - csdfg=dace_helper.compile_jax_sdfg(tsdfg), - inp_names=tsdfg.inp_names, - out_names=tsdfg.out_names, + compiled_sdfg=tjsdfg.compile_jaxpr_sdfg(tsdfg), + out_tree=self._out_tree, ) - def compiler_ir(self, dialect: str | None = None) -> translator.TranslatedJaxprSDFG: + def compiler_ir(self, dialect: str | None = None) -> tjsdfg.TranslatedJaxprSDFG: """ Returns the internal SDFG. - The function returns a `TranslatedJaxprSDFG` object. Direct modification - of the returned object is forbidden and will cause undefined behaviour. + The function returns a `TranslatedJaxprSDFG` object. Direct modification of the + returned object is forbidden and results in undefined behaviour. """ if (dialect is None) or (dialect.upper() == "SDFG"): return self._translated_sdfg raise ValueError(f"Unknown dialect '{dialect}'.") - def view(self, filename: str | None = None) -> None: - """ - Runs the `view()` method of the underlying SDFG. - - This will open a browser and display the SDFG. - """ - self.compiler_ir().sdfg.view(filename=filename, verbose=False) - def as_sdfg(self) -> dace.SDFG: """ Returns the encapsulated SDFG. @@ -256,64 +282,127 @@ def as_sdfg(self) -> dace.SDFG: return self.compiler_ir().sdfg def _make_call_description( - self, compiler_options: CompilerOptions | None = None + self, in_tree: jax_tree.PyTreeDef, flat_call_args: Sequence[Any] ) -> tcache.StageTransformationSpec: """ - This function computes the key for the `self.compile()` call inside the cache. + Creates the key for the `self.compile()` transition function. - The key that is computed by this function is based on the concrete - values of the passed compiler options. + The key will depend on the final values that were used for optimization, i.e. + they it will also include the global set of optimization options. """ - options = self._make_compiler_options(compiler_options) - call_args = tuple(sorted(options.items(), key=lambda x: x[0])) - return tcache.StageTransformationSpec(stage_id=id(self), call_args=call_args) + unflatted_args, unflatted_kwargs = jax_tree.tree_unflatten(in_tree, flat_call_args) + assert (not unflatted_kwargs) and (len(unflatted_args) <= 1) - @staticmethod - def _make_compiler_options(compiler_options: CompilerOptions | None) -> CompilerOptions: - return optimization.DEFAULT_OPTIMIZATIONS | (compiler_options or {}) + options = get_compiler_options(unflatted_args[0] if unflatted_args else None) + flat_options, option_tree = jax_tree.tree_flatten(options) + return tcache.StageTransformationSpec( + stage_id=id(self), flat_call_args=tuple(flat_options), in_tree=option_tree + ) -class JaCeCompiled: +class JaCeCompiled(Generic[_R]): """ Compiled version of the SDFG. - This is the last stage of the jit chain. A user should never create a + This is the last stage of the JaCe's jit chain. A user should never create a `JaCeCompiled` instance, instead `JaCeLowered.compile()` should be used. + Since the strides and storage location of the arguments, that where used to lower + the computation are hard coded inside the SDFG, a `JaCeCompiled` object can only be + called with compatible arguments. + Args: - csdfg: The compiled SDFG object. 
- inp_names: Names of the SDFG variables used as inputs. - out_names: Names of the SDFG variables used as outputs. + compiled_sdfg: The compiled SDFG object. + input_names: SDFG variables used as inputs. + output_names: SDFG variables used as outputs. + out_tree: Pytree describing how to unflatten the output. Note: The class assumes ownership of its input arguments. Todo: - - Handle pytrees. + - Automatic strides adaptation. """ - _csdfg: dace_helper.CompiledSDFG - _inp_names: tuple[str, ...] - _out_names: tuple[str, ...] + _compiled_sdfg: tjsdfg.CompiledJaxprSDFG + _out_tree: jax_tree.PyTreeDef def __init__( - self, csdfg: dace_helper.CompiledSDFG, inp_names: Sequence[str], out_names: Sequence[str] + self, + compiled_sdfg: tjsdfg.CompiledJaxprSDFG, + out_tree: jax_tree.PyTreeDef, ) -> None: - if (not inp_names) or (not out_names): - raise ValueError("Input and output can not be empty.") - self._csdfg = csdfg - self._inp_names = tuple(inp_names) - self._out_names = tuple(out_names) + self._compiled_sdfg = compiled_sdfg + self._out_tree = out_tree - def __call__(self, *args: Any, **kwargs: Any) -> Any: + def __call__(self, *args: Any, **kwargs: Any) -> _R: """ Calls the embedded computation. - The arguments must be the same as for the wrapped function, but with - all static arguments removed. + Note: + Unlike the `lower()` function which takes the same arguments as the original + computation, to call this function you have to remove all static arguments. + Furthermore, all arguments must have strides and storage locations that is + compatible with the ones that were used for lowering. """ - return dace_helper.run_jax_sdfg(self._csdfg, self._inp_names, self._out_names, args, kwargs) + flat_call_args = jax_tree.tree_leaves((args, kwargs)) + flat_output = self._compiled_sdfg(flat_call_args) + return jax_tree.tree_unflatten(self._out_tree, flat_output) -#: Known compilation stages in JaCe. -Stage = JaCeWrapped | JaCeLowered | JaCeCompiled +# <--------------------------- Compilation/Optimization options management + +_JACELOWERED_ACTIVE_COMPILE_OPTIONS: CompilerOptions = optimization.DEFAULT_OPTIMIZATIONS.copy() +"""Global set of currently active compilation/optimization options. + +The global set is initialized to `jace.optimization.DEFAULT_OPTIMIZATIONS`. +For modifying the set of active options the the `set_compiler_options()` +context manager is provided. +To obtain the currently active compiler options use `get_compiler_options()`. +""" + + +@contextlib.contextmanager +def set_compiler_options(compiler_options: CompilerOptions) -> Generator[None, None, None]: + """ + Temporary modifies the set of active compiler options. + + During the activation of this context the active set of active compiler option + consists of the set of option that were previously active merged with the ones + passed through `compiler_options`. + + Args: + compiler_options: Options that should be temporary merged with the currently + active options. + + See Also: + `get_compiler_options()` to get the set of active options that is + currently active. + """ + global _JACELOWERED_ACTIVE_COMPILE_OPTIONS # noqa: PLW0603 [global-statement] + previous_compiler_options = _JACELOWERED_ACTIVE_COMPILE_OPTIONS.copy() + try: + _JACELOWERED_ACTIVE_COMPILE_OPTIONS.update(compiler_options) + yield None + finally: + _JACELOWERED_ACTIVE_COMPILE_OPTIONS = previous_compiler_options + + +def get_compiler_options(compiler_options: CompilerOptions | None) -> CompilerOptions: + """ + Get the final compiler options. 
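    For example (a minimal sketch; `lowered` stands for a `JaCeLowered` object
    obtained from `JaCeWrapped.lower()`):

        >>> from jace import stages
        >>> from jace.optimization import NO_OPTIMIZATIONS
        >>> with stages.set_compiler_options(NO_OPTIMIZATIONS):
        ...     compiled = lowered.compile()                 # uses the temporarily active options
        >>> compiled = lowered.compile({"simplify": False})  # direct options take precedence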
+ + There are two different sources of optimization options. The first one is the global + set of currently active compiler options, which is returned if `None` is passed. + The second one is the options that are passed to this function, which takes + precedence. This mode is also used by `JaCeLowered.compile()` to determine the + final compiler options. + + Args: + compiler_options: The local compilation options. + + See Also: + `set_compiler_options()` to modify the currently active set of compiler + options. + """ + return _JACELOWERED_ACTIVE_COMPILE_OPTIONS | (compiler_options or {}) diff --git a/src/jace/tracing.py b/src/jace/tracing.py new file mode 100644 index 0000000..4df5d00 --- /dev/null +++ b/src/jace/tracing.py @@ -0,0 +1,118 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +""" +Implements the tracing machinery that is used to build the Jaxpr. + +JAX provides `jax.make_jaxpr()`, which is essentially a debug utility, but it does not +provide any other public way to get a Jaxpr. This module provides the necessary +functionality for this in JaCe. +""" + +from __future__ import annotations + +import inspect +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Literal, ParamSpec, TypeVar, overload + +import jax +from jax import core as jax_core, tree_util as jax_tree + + +if TYPE_CHECKING: + from jace import api + +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +@overload +def make_jaxpr( + fun: Callable[_P, _R], + trace_options: api.JITOptions, + return_out_tree: Literal[True], +) -> Callable[_P, tuple[jax_core.ClosedJaxpr, jax_tree.PyTreeDef]]: ... + + +@overload +def make_jaxpr( + fun: Callable[_P, _R], + trace_options: api.JITOptions, + return_out_tree: Literal[False] = False, +) -> Callable[_P, jax_core.ClosedJaxpr]: ... + + +def make_jaxpr( + fun: Callable[_P, Any], + trace_options: api.JITOptions, + return_out_tree: bool = False, +) -> Callable[_P, tuple[jax_core.ClosedJaxpr, jax_tree.PyTreeDef] | jax_core.ClosedJaxpr]: + """ + JaCe's replacement for `jax.make_jaxpr()`. + + Returns a callable object that produces a Jaxpr and optionally a pytree defining + the output. By default the callable will only return the Jaxpr, however, by setting + `return_out_tree` the function will also return the output tree, this is different + from the `return_shape` of `jax.make_jaxpr()`. + + Currently the tracing is always performed with an enabled `x64` mode. + + Returns: + The function returns a callable that will perform the tracing on the passed + arguments. If `return_out_tree` is `False` that callable will simply return the + generated Jaxpr. If `return_out_tree` is `True` the function will return a tuple + with the Jaxpr and a pytree object describing the structure of the output. + + Args: + fun: The original Python computation. + trace_options: The options used for tracing, the same arguments that + are supported by `jace.jit`. + return_out_tree: Also return the pytree of the output. + + Todo: + - Handle default arguments of `fun`. + - Handle static arguments. + - Turn `trace_options` into a `TypedDict` and sync with `jace.jit`. + """ + if trace_options: + raise NotImplementedError( + f"Not supported tracing options: {', '.join(f'{k}' for k in trace_options)}" + ) + # TODO(phimuell): Test if this restriction is needed. 
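    # An illustrative usage sketch of this function (`square` below is a made-up
    # example function, not part of JaCe):
    #
    #     >>> import jax.numpy as jnp
    #     >>> from jace import tracing
    #     >>> def square(x):
    #     ...     return x * x
    #     >>> jaxpr_maker = tracing.make_jaxpr(square, trace_options={}, return_out_tree=True)
    #     >>> jaxpr, out_tree = jaxpr_maker(jnp.ones((3, 3)))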
+ assert all(param.default is param.empty for param in inspect.signature(fun).parameters.values()) + + def tracer_impl( + *args: _P.args, + **kwargs: _P.kwargs, + ) -> tuple[jax_core.ClosedJaxpr, jax_tree.PyTreeDef] | jax_core.ClosedJaxpr: + # In JAX `float32` is the main datatype, and they go to great lengths to avoid + # some aggressive [type promotion](https://jax.readthedocs.io/en/latest/type_promotion.html). + # However, in this case we will have problems when we call the SDFG, for some + # reasons `CompiledSDFG` does not work in that case correctly, thus we enable + # it for the tracing. + with jax.experimental.enable_x64(): + # TODO(phimuell): copy the implementation of the real tracing + jaxpr_maker = jax.make_jaxpr( + fun, + **trace_options, + return_shape=True, + ) + jaxpr, out_shapes = jaxpr_maker( + *args, + **kwargs, + ) + + if not return_out_tree: + return jaxpr + + # Regardless what the documentation of `make_jaxpr` claims, it does not output + # a pytree but an abstract description of the shape, that we will + # transform into a pytree. + out_tree = jax_tree.tree_structure(out_shapes) + return jaxpr, out_tree + + return tracer_impl diff --git a/src/jace/translated_jaxpr_sdfg.py b/src/jace/translated_jaxpr_sdfg.py new file mode 100644 index 0000000..9cb9908 --- /dev/null +++ b/src/jace/translated_jaxpr_sdfg.py @@ -0,0 +1,225 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Extended versions of `SDFG` and `CompiledSDFG` with additional metadata.""" + +from __future__ import annotations + +import dataclasses +import pathlib +import uuid +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +import dace +from dace import data as dace_data + +from jace import util + + +if TYPE_CHECKING: + import numpy as np + from dace.codegen import compiled_sdfg as dace_csdfg + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class TranslatedJaxprSDFG: + """ + Encapsulates the SDFG generated from a Jaxpr and additional metadata. + + Contrary to the SDFG that is encapsulated inside an `TranslationContext` + object, `self` carries a proper SDFG with the following structure: + - It does not have `__return*` variables, instead all return arguments are + passed by arguments. + - All input arguments are passed through arguments mentioned in `input_names`, + while the outputs are passed through `output_names`. + - Only variables listed as in/outputs are non transient. + - The order of `input_names` and `output_names` is the same as in the Jaxpr. + - Its `arg_names` is set to `input_names + output_names`, but arguments that are + input and outputs are only listed as inputs. + - For every transient there is exactly one access node that writes to it, + except the name of the array starts with `__jace_mutable_`, which can + be written to multiple times. + + The only valid way to obtain a `TranslatedJaxprSDFG` is by passing a + `TranslationContext`, that was in turn constructed by + `JaxprTranslationBuilder.translate_jaxpr()`, to the + `finalize_translation_context()` or preferably to the `postprocess_jaxpr_sdfg()` + function. + + Attributes: + sdfg: The encapsulated SDFG object. + input_names: SDFG variables used as inputs. + output_names: SDFG variables used as outputs. + + Todo: + After the SDFG is compiled a lot of code looks strange, because there is + no container to store the compiled SDFG and the metadata. 
This class + should be extended to address this need. + """ + + sdfg: dace.SDFG + input_names: tuple[str, ...] + output_names: tuple[str, ...] + + def validate(self) -> bool: + """Validate the underlying SDFG.""" + if any(self.sdfg.arrays[inp].transient for inp in self.input_names): + raise dace.sdfg.InvalidSDFGError( + f"Found transient inputs: {(inp for inp in self.input_names if self.sdfg.arrays[inp].transient)}", + self.sdfg, + self.sdfg.node_id(self.sdfg.start_state), + ) + if any(self.sdfg.arrays[out].transient for out in self.output_names): + raise dace.sdfg.InvalidSDFGError( + f"Found transient outputs: {(out for out in self.output_names if self.sdfg.arrays[out].transient)}", + self.sdfg, + self.sdfg.node_id(self.sdfg.start_state), + ) + if self.sdfg.free_symbols: # This is a simplification that makes our life simple. + raise dace.sdfg.InvalidSDFGError( + f"Found free symbols: {self.sdfg.free_symbols}", + self.sdfg, + self.sdfg.node_id(self.sdfg.start_state), + ) + if (self.output_names is not None and self.input_names is not None) and ( + set(self.output_names).intersection(self.input_names) + ): + raise dace.sdfg.InvalidSDFGError( + f"Inputs can not be outputs: {set(self.output_names).intersection(self.input_names)}.", + self.sdfg, + None, + ) + self.sdfg.validate() + return True + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class CompiledJaxprSDFG: + """ + Compiled version of a `TranslatedJaxprSDFG` instance. + + Essentially this class is a wrapper around DaCe's `CompiledSDFG` object, that + supports the calling convention used inside JaCe, as in `DaCe` it is callable. + The only valid way to obtain a `CompiledJaxprSDFG` instance is through + `compile_jaxpr_sdfg()`. + + Args: + compiled_sdfg: The `CompiledSDFG` object. + input_names: Names of the SDFG variables used as inputs. + output_names: Names of the SDFG variables used as outputs. + + Attributes: + compiled_sdfg: The `CompiledSDFG` object. + sdfg: SDFG object used to generate/compile `self.compiled_sdfg`. + input_names: Names of the SDFG variables used as inputs. + output_names: Names of the SDFG variables used as outputs. + + Note: + Currently the strides of the input arguments must match the ones that were used + for lowering the SDFG. + In DaCe the return values are allocated on a per `CompiledSDFG` basis. Thus + every call to a compiled SDFG will override the value of the last call, in JaCe + the memory is allocated on every call. In addition scalars are returned as + arrays of length one. + """ + + compiled_sdfg: dace_csdfg.CompiledSDFG + input_names: tuple[str, ...] + output_names: tuple[str, ...] + + @property + def sdfg(self) -> dace.SDFG: # noqa: D102 [undocumented-public-method] + return self.compiled_sdfg.sdfg + + def __call__( + self, + flat_call_args: Sequence[Any], + ) -> list[np.ndarray]: + """ + Run the compiled SDFG using the flattened input. + + The function will not perform flattening of its input nor unflattening of + the output. + + Args: + flat_call_args: Flattened input arguments. + """ + if len(self.input_names) != len(flat_call_args): + raise RuntimeError( + f"Expected {len(self.input_names)} flattened arguments, but got {len(flat_call_args)}." + ) + + sdfg_call_args: dict[str, Any] = {} + for in_name, in_val in zip(self.input_names, flat_call_args): + # TODO(phimuell): Implement a stride matching process. 
+ if util.is_jax_array(in_val): + if not util.is_fully_addressable(in_val): + raise ValueError(f"Passed a not fully addressable JAX array as '{in_name}'") + in_val = in_val.__array__() # noqa: PLW2901 [redefined-loop-name] # JAX arrays do not expose the __array_interface__. + sdfg_call_args[in_name] = in_val + + arrays = self.sdfg.arrays + for output_name in self.output_names: + sdfg_call_args[output_name] = dace_data.make_array_from_descriptor(arrays[output_name]) + + assert len(sdfg_call_args) == len(self.compiled_sdfg.argnames), ( + "Failed to construct the call arguments," + f" expected {len(self.compiled_sdfg.argnames)} but got {len(flat_call_args)}." + f"\nExpected: {self.compiled_sdfg.argnames}\nGot: {list(sdfg_call_args.keys())}" + ) + + # Calling the SDFG + with dace.config.temporary_config(): + dace.Config.set("compiler", "allow_view_arguments", value=True) + self.compiled_sdfg(**sdfg_call_args) + + return [sdfg_call_args[output_name] for output_name in self.output_names] + + +def compile_jaxpr_sdfg(tsdfg: TranslatedJaxprSDFG) -> dace_csdfg.CompiledJaxprSDFG: + """Compile `tsdfg` and return a `CompiledJaxprSDFG` object with the result.""" + if any( # We do not support the DaCe return mechanism + array_name.startswith("__return") for array_name in tsdfg.sdfg.arrays + ): + raise ValueError("Only support SDFGs without '__return' members.") + if tsdfg.sdfg.free_symbols: # This is a simplification that makes our life simple. + raise NotImplementedError(f"No free symbols allowed, found: {tsdfg.sdfg.free_symbols}") + if not (tsdfg.output_names or tsdfg.input_names): + raise ValueError("No input nor output.") + + # To ensure that the SDFG is compiled and to get rid of a warning we must modify + # some settings of the SDFG. But we also have to fake an immutable SDFG + sdfg = tsdfg.sdfg + original_sdfg_name = sdfg.name + original_recompile = sdfg._recompile + original_regenerate_code = sdfg._regenerate_code + + try: + # We need to give the SDFG another name, this is needed to prevent a DaCe + # error/warning. This happens if we compile the same lowered SDFG multiple + # times with different options. + sdfg.name = f"{sdfg.name}__{str(uuid.uuid1()).replace('-', '_')}" + assert len(sdfg.name) < 255 # noqa: PLR2004 [magic-value-comparison] # 255 maximal file name size on UNIX. + + with dace.config.temporary_config(): + dace.Config.set("compiler", "use_cache", value=False) + # TODO(egparedes/phimuell): Add a configuration option. 
+ dace.Config.set("cache", value="name") + dace.Config.set("default_build_folder", value=pathlib.Path(".jacecache").resolve()) + sdfg._recompile = True + sdfg._regenerate_code = True + compiled_sdfg: dace_csdfg.CompiledSDFG = sdfg.compile() + + finally: + sdfg.name = original_sdfg_name + sdfg._recompile = original_recompile + sdfg._regenerate_code = original_regenerate_code + + return CompiledJaxprSDFG( + compiled_sdfg=compiled_sdfg, input_names=tsdfg.input_names, output_names=tsdfg.output_names + ) diff --git a/src/jace/translator/__init__.py b/src/jace/translator/__init__.py index 2f184a0..9cd3dfd 100644 --- a/src/jace/translator/__init__.py +++ b/src/jace/translator/__init__.py @@ -22,14 +22,12 @@ make_primitive_translator, register_primitive_translator, ) -from .translated_jaxpr_sdfg import TranslatedJaxprSDFG __all__ = [ "JaxprTranslationBuilder", "PrimitiveTranslator", "PrimitiveTranslatorCallable", - "TranslatedJaxprSDFG", "TranslationContext", "get_registered_primitive_translators", "make_primitive_translator", diff --git a/src/jace/translator/jaxpr_translator_builder.py b/src/jace/translator/jaxpr_translator_builder.py index da2e68f..3d7d04c 100644 --- a/src/jace/translator/jaxpr_translator_builder.py +++ b/src/jace/translator/jaxpr_translator_builder.py @@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Any, Literal, cast, overload import dace -from dace import data as ddata, properties as dprop +from dace import data as dace_data, properties as dace_properties from jax import core as jax_core from jace import util @@ -31,10 +31,12 @@ class JaxprTranslationBuilder: canonical. The main features of such an SDFG are: - the SDFG is a list of states, - it has a single source and sink state. - - all variable names are derived from Jax names, + - all variable names are derived from JAX names, - there are only transient variables inside the SDFG, - - It lacks the special `__return` variable, - - the `arg_names` parameter is not set. + - it lacks the special `__return` variable, + - the `arg_names` parameter is not set, + - for all scalar values a ` Scalar` SDFG variable is used, thus they cannot + be used to return anything. For these reasons the SDFG is not directly usable, and further manipulations have to be performed. Especially, DaCe's validation function will fail and @@ -64,10 +66,9 @@ class JaxprTranslationBuilder: Args: primitive_translators: Primitive translators to use in the translation. - Notes: + Note: After a translation has been performed the translator object can be used - again. Currently the builder will generate only Array as SDFG variables, - however, this is a temporary solution, see `add_array()`. + again. """ _primitive_translators: Mapping[str, translator.PrimitiveTranslatorCallable] @@ -80,7 +81,7 @@ def __init__( # Maps name of primitives to the associated translator. self._primitive_translators = {**primitive_translators} - # Maps Jax variables to the name of its SDFG equivalent. + # Maps JAX variables to the name of its SDFG equivalent. # Shared between all translation contexts, to ensure consecutive variable # naming as seen as in a pretty printed Jaxpr. Will be cleared by # `_clear_translation_ctx()` at the end of the root translation. @@ -119,7 +120,7 @@ def translate_jaxpr( # context. Thus the builder will start to translate a second (nested) # SDFG. Also note that there is no mechanism that forces the integration # of the nested SDFG/Jaxpr, this must be done manually. 
- self._allocate_translation_ctx(name=name) + self._allocate_translation_ctx(name=name, jaxpr=jaxpr) self._create_constants(jaxpr=jaxpr) self._create_initial_input(jaxpr=jaxpr) @@ -128,7 +129,7 @@ def translate_jaxpr( def append_new_state( self, label: str | None = None, - condition: dprop.CodeBlock | None = None, + condition: dace_properties.CodeBlock | None = None, assignments: Mapping[str, Any] | None = None, prev_state: dace.SDFGState | None = None, ) -> dace.SDFGState: @@ -146,7 +147,7 @@ def append_new_state( assignments: Symbol assignments on the `InterstateEdge`. prev_state: Alternative state at which we append. - Notes: + Note: It is potentially dangerous to not append to the current terminal state, as a canonical SDFG only has one sink state. If this is done the user has to ensure, that at the end of the processing the SDFG @@ -175,22 +176,22 @@ def append_new_state( return new_state @property - def arrays(self) -> Mapping[str, ddata.Data]: + def arrays(self) -> Mapping[str, dace_data.Data]: """ Get all data descriptors that are currently known to the SDFG. - Notes: + Note: Essentially a shorthand and preferred way for `self.sdfg.arrays`. For getting a specific data descriptor use `self.get_array()`. """ - return cast(Mapping[str, ddata.Data], self._ctx.sdfg.arrays) + return cast(Mapping[str, dace_data.Data], self._ctx.sdfg.arrays) - def get_array(self, name: str | jax_core.Atom | util.JaCeVar) -> ddata.Data: + def get_array(self, name: str | jax_core.Atom | util.JaCeVar) -> dace_data.Data: """ Returns the SDFG `Data` object `name` referees to. `name` can either be a string, in which case it is interpreted as a - verbatim SDFG name. If it is a Jax or JaCe variable, the function will + verbatim SDFG name. If it is a JAX or JaCe variable, the function will first perform a lookup using `self.map_jax_var_to_sdfg(name)`. """ if isinstance(name, (jax_core.Var, util.JaCeVar)): @@ -220,7 +221,7 @@ def map_jax_var_to_sdfg( Get the name of the SDFG variable to which `jax_var` is referring to. Args: - jax_var: The Jax variable to look up. + jax_var: The JAX variable to look up. allow_fail: Return `None` instead of raising a `KeyError`. """ if isinstance(jax_var, jax_core.Literal): @@ -230,10 +231,10 @@ def map_jax_var_to_sdfg( elif allow_fail: return None else: - raise KeyError(f"The Jax variable '{jax_var}' was never registered.") + raise KeyError(f"The JAX variable '{jax_var}' was never registered.") if sdfg_name not in self._ctx.sdfg.arrays: raise KeyError( - f"Jax variable '{jax_var}' was supposed to map to '{sdfg_name}'," + f"JAX variable '{jax_var}' was supposed to map to '{sdfg_name}'," " but no such SDFG variable is known." ) return sdfg_name @@ -271,11 +272,11 @@ def add_jax_name_mapping( is not able to delete a variable mapping that was established before. Args: - jax_var: The Jax variable. + jax_var: The JAX variable. sdfg_name: The name of the corresponding SDFG variable. """ - assert sdfg_name - + if not sdfg_name: + raise ValueError("Supplied 'sdfg_name' is empty.") if jax_var in self._jax_name_map: raise ValueError( f"Cannot change the mapping of '{jax_var}' from" @@ -297,7 +298,7 @@ def add_array( update_var_mapping: bool = False, ) -> str: """ - Creates an SDFG variable for Jax variable `arg` and returns its SDFG name. + Creates an SDFG variable for JAX variable `arg` and returns its SDFG name. The SDFG object is always created as a transient. Furthermore, the function will not update the internal variable mapping, by default. 
@@ -309,17 +310,9 @@ def add_array( should be used. Args: - arg: The Jax object for which a SDFG equivalent should be created. + arg: The JAX object for which a SDFG equivalent should be created. name_prefix: If given it will be used as prefix for the name. update_var_mapping: Update the internal variable mapping. - - Notes: - As a temporary fix for handling scalar return values, the function - will always generate arrays, even if `arg` is a scalar. According to - the DaCe developer, the majority of the backend, i.e. optimization - pipeline, should be able to handle it. But there are some special - parts that might explicitly want a scalar, it also might block - certain compiler optimization. """ if isinstance(arg, jax_core.Literal): raise TypeError(f"Can not generate an SDFG variable for literal '{arg}'.") @@ -331,9 +324,6 @@ def add_array( as_transient = True strides = None - # Temporary fix for handling DaCe scalars, see above for more. - shape = shape or (1,) - # Propose a name and if needed extend it. arg_name = util.propose_jax_name(arg, self._jax_name_map) if name_prefix: @@ -347,15 +337,20 @@ def add_array( if arg_name in util.FORBIDDEN_SDFG_VAR_NAMES: raise ValueError(f"add_array({arg}): The proposed name '{arg_name}', is forbidden.") - self._ctx.sdfg.add_array( - name=arg_name, - shape=shape, - strides=strides, - offset=offset, - storage=storage, - dtype=dtype, - transient=as_transient, - ) + if shape == (): + self._ctx.sdfg.add_scalar( + name=arg_name, storage=storage, dtype=dtype, transient=as_transient + ) + else: + self._ctx.sdfg.add_array( + name=arg_name, + shape=shape, + strides=strides, + offset=offset, + storage=storage, + dtype=dtype, + transient=as_transient, + ) if update_var_mapping: try: @@ -396,9 +391,9 @@ def create_jax_var_list( # type: ignore[misc] **kwargs: Any, ) -> list[None | str]: """ - Create SDFG variables from the passed Jax variables. + Create SDFG variables from the passed JAX variables. - If a Jax variable already has a SDFG equivalent then the function will + If a JAX variable already has a SDFG equivalent then the function will use this variable. If no corresponding SDFG variable is known the function will create one using `add_array()`. @@ -412,7 +407,7 @@ def create_jax_var_list( # type: ignore[misc] to `True` literals will will be included in the output with the value `None`. Args: - jax_var_list: The list of Jax variables that should be processed. + jax_var_list: The list of JAX variables that should be processed. prevent_creation: Never create a variable, all must already be known. only_creation: Always create a variable. handle_literals: Allow the processing of literals. @@ -448,11 +443,10 @@ def _create_initial_input(self, jaxpr: jax_core.ClosedJaxpr) -> None: """ Creates the input variables of `jaxpr`. - Notes: - The function will populate the `inp_names` member of the current context. + Note: + The function will populate the `input_names` member of the current context. """ - assert self.is_allocated(), "Builder is not allocated, can not create constants." 
- assert self._ctx.inp_names is None + assert self._ctx.input_names is None # Handle the initial input arguments init_in_var_names: Sequence[str] = self.create_jax_var_list( @@ -464,7 +458,7 @@ def _create_initial_input(self, jaxpr: jax_core.ClosedJaxpr) -> None: self.sdfg.arg_names = [] # The output list is populated by `self._translate_jaxpr_internal()` - self._ctx.inp_names = tuple(init_in_var_names) + self._ctx.input_names = tuple(init_in_var_names) def _create_constants(self, jaxpr: jax_core.ClosedJaxpr) -> None: """ @@ -473,7 +467,6 @@ def _create_constants(self, jaxpr: jax_core.ClosedJaxpr) -> None: The function will create an SDFG variable and add them as constant to the SDFG. Their value is deepcopied. """ - assert self.is_allocated(), "Builder is not allocated, can not create constants." if len(jaxpr.consts) == 0: return @@ -489,14 +482,17 @@ def _create_constants(self, jaxpr: jax_core.ClosedJaxpr) -> None: sdfg_name, copy.deepcopy(const_value), self._ctx.sdfg.arrays[sdfg_name] ) - def _allocate_translation_ctx(self, name: str | None = None) -> JaxprTranslationBuilder: + def _allocate_translation_ctx( + self, name: str | None, jaxpr: jax_core.ClosedJaxpr + ) -> JaxprTranslationBuilder: """ Allocate a new context and activate it. Args: name: The name of the SDFG. + jaxpr: The Jaxpr that should be translated. """ - self._ctx_stack.append(TranslationContext(name=name)) + self._ctx_stack.append(TranslationContext(name=name, jaxpr=jaxpr)) return self @property @@ -590,7 +586,7 @@ def _translate_jaxpr_internal(self, jaxpr: jax_core.ClosedJaxpr) -> TranslationC Args: jaxpr: The Jaxpr to translate. - Notes: + Note: Equations that store into drop variables, i.e. with name `_`, will be ignored. """ @@ -612,7 +608,7 @@ def _translate_jaxpr_internal(self, jaxpr: jax_core.ClosedJaxpr) -> TranslationC jaxpr.jaxpr.outvars, prevent_creation=True, handle_literals=False ) - self._ctx.out_names = tuple(out_var_names) + self._ctx.output_names = tuple(out_var_names) return cast(TranslationContext, self._clear_translation_ctx()) @@ -635,11 +631,12 @@ def _handle_null_jaxpr(self, jaxpr: jax_core.ClosedJaxpr) -> list[str]: - Handle the case if if the output is a literal. Note: - The function will _not_ update the `out_names` field of the current context. + The function will _not_ update the `output_names` field of the current + context. """ assert self._ctx.terminal_state is self._ctx.start_state - assert self._ctx.inp_names - assert self._ctx.out_names is None + assert isinstance(self._ctx.input_names, tuple) + assert self._ctx.output_names is None # There is not output so we do not have to copy anything around. if not jaxpr.out_avals: @@ -657,7 +654,7 @@ def _handle_null_jaxpr(self, jaxpr: jax_core.ClosedJaxpr) -> list[str]: sdfg_in_name: str = self.map_jax_var_to_sdfg(jax_out_var) # Now we create a variable that serves as true output, however, since the - # Jax variable is already known we can not update the variable mapping and + # JAX variable is already known we can not update the variable mapping and # must use another name. sdfg_out_name = self.add_array( jax_out_var, name_prefix="_zero_equation_output_for_", update_var_mapping=False @@ -666,10 +663,10 @@ def _handle_null_jaxpr(self, jaxpr: jax_core.ClosedJaxpr) -> list[str]: # Now we perform the copy from the input variable in the newly created # output variable. 
- inp_acc = self._start_state.add_read(sdfg_in_name) + input_acc = self._start_state.add_read(sdfg_in_name) out_acc = self._start_state.add_write(sdfg_out_name) self._start_state.add_nedge( - src=inp_acc, + src=input_acc, dst=out_acc, data=dace.Memlet.from_array(sdfg_in_name, self.get_array(sdfg_in_name)), ) @@ -707,10 +704,12 @@ class TranslationContext: Attributes: sdfg: The encapsulated SDFG object. - inp_names: A list of the SDFG variables that are used as input - out_names: A list of the SDFG variables that are used as output. + input_names: A list of the SDFG variables that are used as input + output_names: A list of the SDFG variables that are used as output. start_state: The first state in the SDFG state machine. terminal_state: The (currently) last state in the state machine. + jaxpr: The Jaxpr expression that was translated into an SDFG. Intended to be + used during debugging and inspection. Args: name: The name of the SDFG. @@ -718,23 +717,27 @@ class TranslationContext: Note: Access of any attribute of this class by an outside user is considered undefined behaviour. + Furthermore, the encapsulated SDFG should be seen as a verbatim translation + of the initial Jaxpr. """ sdfg: dace.SDFG - inp_names: tuple[str, ...] | None - out_names: tuple[str, ...] | None + input_names: tuple[str, ...] | None + output_names: tuple[str, ...] | None start_state: dace.SDFGState terminal_state: dace.SDFGState + jaxpr: jax_core.ClosedJaxpr - def __init__(self, name: str | None = None) -> None: + def __init__(self, name: str | None, jaxpr: jax_core.ClosedJaxpr) -> None: if isinstance(name, str) and not util.VALID_SDFG_OBJ_NAME.fullmatch(name): raise ValueError(f"'{name}' is not a valid SDFG name.") self.sdfg = dace.SDFG(name=(name or f"unnamed_SDFG_{id(self)}")) - self.inp_names = None - self.out_names = None + self.input_names = None + self.output_names = None self.start_state = self.sdfg.add_state(label="initial_state", is_start_block=True) self.terminal_state = self.start_state + self.jaxpr = jaxpr def validate(self) -> bool: """ @@ -748,13 +751,31 @@ def validate(self) -> bool: f"Expected to find '{self.start_state}' as start state," f" but instead found '{self.sdfg.start_block}'.", self.sdfg, - self.sdfg.node_id(self.start_state), + None, ) if {self.terminal_state} != set(self.sdfg.sink_nodes()): raise dace.sdfg.InvalidSDFGError( f"Expected to find as terminal state '{self.terminal_state}'," f" but instead found '{self.sdfg.sink_nodes()}'.", self.sdfg, - self.sdfg.node_id(self.terminal_state), + None, + ) + if not ( + self.input_names is None + or all(input_name in self.sdfg.arrays for input_name in self.input_names) + ): + raise dace.sdfg.InvalidSDFGError( + f"Missing input arguments: {(input_name for input_name in self.input_names if input_name not in self.sdfg.arrays)}", + self.sdfg, + None, + ) + if not ( + self.output_names is None + or all(output_name in self.sdfg.arrays for output_name in self.output_names) + ): + raise dace.sdfg.InvalidSDFGError( + f"Missing output arguments: {(output_name for output_name in self.output_names if output_name not in self.sdfg.arrays)}", + self.sdfg, + None, ) return True diff --git a/src/jace/translator/post_translation.py b/src/jace/translator/post_translation.py index ec445e9..a00b651 100644 --- a/src/jace/translator/post_translation.py +++ b/src/jace/translator/post_translation.py @@ -5,104 +5,232 @@ # # SPDX-License-Identifier: BSD-3-Clause -""" -This module contains all functions that are related to post processing the SDFG. 
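# --- Stand-alone sketch of the access-node copy used by `_handle_null_jaxpr()`
# --- above: a read of the input container is connected to a write of the new
# --- output container by a memlet over the whole array. Names are illustrative.
import dace

sdfg = dace.SDFG("null_jaxpr_copy_sketch")
sdfg.add_array("inp", shape=(8,), dtype=dace.float64, transient=True)
sdfg.add_array("out", shape=(8,), dtype=dace.float64, transient=True)
state = sdfg.add_state("copy_state", is_start_block=True)
state.add_nedge(
    state.add_read("inp"),
    state.add_write("out"),
    dace.Memlet.from_array("inp", sdfg.arrays["inp"]),
)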
- -Most of them operate on `TranslatedJaxprSDFG` objects. -Currently they mostly exist for the sake of existing. -""" +"""Functions for the pre and post processing during the translation.""" from __future__ import annotations import copy +from collections.abc import Callable, Sequence from typing import TYPE_CHECKING, Any -from jace import translator +import dace + +from jace import translated_jaxpr_sdfg as tjsdfg, util if TYPE_CHECKING: - from collections.abc import Callable, Sequence + from jace import translator def postprocess_jaxpr_sdfg( trans_ctx: translator.TranslationContext, - fun: Callable, # noqa: ARG001 # Currently unused - call_args: Sequence[Any], # noqa: ARG001 # Currently unused - intree: None, # noqa: ARG001 # Currently unused -) -> translator.TranslatedJaxprSDFG: + fun: Callable, # noqa: ARG001 [unused-function-argument] # Currently unused. + flat_call_args: Sequence[Any], + validate: bool = True, +) -> tjsdfg.TranslatedJaxprSDFG: """ - Perform the final post processing steps on the `TranslationContext` _in place_. + Final post processing steps on the `TranslationContext`. - The function will perform post processing stages on the context in place. - However, the function will return a decoupled `TranslatedJaxprSDFG` object. + While the function performs the post processing on the context in place, the + returned `TranslatedJaxprSDFG` will be decoupled from the input. Args: trans_ctx: The `TranslationContext` obtained from a `translate_jaxpr()` call. fun: The original function that was translated. - call_args: The linearized input arguments. - intree: The pytree describing the inputs. + flat_call_args: The flattened input arguments. + validate: Perform validation. Todo: - - Setting correct input names (layer that does not depend on JAX). - - Setting the correct strides & storage properties. - Fixing the scalar input problem on GPU. + - Fixing stride problem of the input. + - Make it such that the context is not modified as a side effect. """ - # Currently we do nothing except finalizing. - trans_ctx.validate() + trans_ctx.validate() # Always validate, it is cheap. + create_input_output_stages(trans_ctx=trans_ctx, flat_call_args=flat_call_args) + return finalize_translation_context(trans_ctx, validate=validate) - # - # Assume some post processing here. - # - return finalize_translation_context(trans_ctx, validate=True) +def create_input_output_stages( + trans_ctx: translator.TranslationContext, flat_call_args: Sequence[Any] +) -> None: + """ + Creates an input and output state inside the SDFG in place. + + See `_create_input_state()` and `_create_output_state()` for more information. + + Args: + trans_ctx: The translation context that should be modified. + flat_call_args: The flattened call arguments that should be used. + + Note: + The processed SDFG will remain canonical. + """ + _create_input_state(trans_ctx, flat_call_args) + _create_output_state(trans_ctx) + + +def _create_output_state(trans_ctx: translator.TranslationContext) -> None: + """ + Creates the output processing stage for the SDFG in place. + + The function will create a new terminal state, in which all outputs, denoted + in `trans_ctx.output_names`, will be written into new SDFG variables. In case the + output variable is a scalar, the output will be replaced by an array of length one. + This behaviour is consistent with JAX. + + Args: + trans_ctx: The translation context to process. 
+ """ + assert trans_ctx.input_names is not None and trans_ctx.output_names is not None + + output_pattern = "__jace_output_{}" + sdfg = trans_ctx.sdfg + new_output_state: dace.SDFGState = sdfg.add_state("output_processing_stage") + new_output_names: list[str] = [] + + for i, org_output_name in enumerate(trans_ctx.output_names): + new_output_name = output_pattern.format(i) + org_output_desc: dace.data.Data = sdfg.arrays[org_output_name] + assert org_output_desc.transient + assert ( + new_output_name not in sdfg.arrays + ), f"Final output variable '{new_output_name}' is already present." + + if isinstance(org_output_desc, dace.data.Scalar): + _, new_output_desc = sdfg.add_array( + new_output_name, + dtype=org_output_desc.dtype, + shape=(1,), + transient=True, # Needed for an canonical SDFG + ) + memlet = dace.Memlet.simple(new_output_name, subset_str="0", other_subset_str="0") + + else: + new_output_desc = org_output_desc.clone() + sdfg.add_datadesc(new_output_name, new_output_desc) + memlet = dace.Memlet.from_array(org_output_name, org_output_desc) + + new_output_state.add_nedge( + new_output_state.add_read(org_output_name), + new_output_state.add_write(new_output_name), + memlet, + ) + new_output_names.append(new_output_name) + + sdfg.add_edge(trans_ctx.terminal_state, new_output_state, dace.InterstateEdge()) + trans_ctx.terminal_state = new_output_state + trans_ctx.output_names = tuple(new_output_names) + + +def _create_input_state( + trans_ctx: translator.TranslationContext, flat_call_args: Sequence[Any] +) -> None: + """ + Creates the input processing state for the SDFG in place. + + The function will create a new set of variables that are exposed as inputs. This + variables are based on the example input arguments passed through `flat_call_args`. + This process will hard code the memory location and strides into the SDFG. + The assignment is performed inside a new state, which is put at the beginning. + + Args: + trans_ctx: The translation context that should be modified. + flat_call_args: The flattened call arguments for which the input + state should be specialized. + + Todo: + Handle transfer of scalar input in GPU mode. + """ + assert trans_ctx.input_names is not None and trans_ctx.output_names is not None + + if len(flat_call_args) != len(trans_ctx.input_names): + raise ValueError(f"Expected {len(trans_ctx.input_names)}, but got {len(flat_call_args)}.") + + sdfg = trans_ctx.sdfg + new_input_state: dace.SDFGState = sdfg.add_state(f"{sdfg.name}__start_state") + new_input_names: list[str] = [] + input_pattern = "__jace_input_{}" + + for i, (org_input_name, call_arg) in enumerate(zip(trans_ctx.input_names, flat_call_args)): + org_input_desc: dace.data.Data = sdfg.arrays[org_input_name] + new_input_name = input_pattern.format(i) + + if isinstance(org_input_desc, dace.data.Scalar): + # TODO(phimuell): In GPU mode: scalar -> GPU_ARRAY -> Old input name + new_input_desc: dace.data.Scalar = org_input_desc.clone() + sdfg.add_datadesc(new_input_name, new_input_desc) + memlet = dace.Memlet.simple(new_input_name, subset_str="0", other_subset_str="0") + + else: + _, new_input_desc = sdfg.add_array( + name=new_input_name, + shape=org_input_desc.shape, + dtype=org_input_desc.dtype, + strides=util.get_strides_for_dace(call_arg), + transient=True, # For canonical SDFG. 
+ storage=( + dace.StorageType.GPU_Global + if util.is_on_device(call_arg) + else dace.StorageType.CPU_Heap + ), + ) + memlet = dace.Memlet.from_array(new_input_name, new_input_desc) + + new_input_state.add_nedge( + new_input_state.add_read(new_input_name), + new_input_state.add_write(org_input_name), + memlet, + ) + new_input_names.append(new_input_name) + + sdfg.add_edge(new_input_state, trans_ctx.start_state, dace.InterstateEdge()) + sdfg.start_block = sdfg.node_id(new_input_state) + trans_ctx.start_state = new_input_state + trans_ctx.input_names = tuple(new_input_names) def finalize_translation_context( - trans_ctx: translator.TranslationContext, validate: bool = True -) -> translator.TranslatedJaxprSDFG: + trans_ctx: translator.TranslationContext, + validate: bool = True, +) -> tjsdfg.TranslatedJaxprSDFG: """ - Finalizes the supplied translation context `trans_ctx`. + Finalizes the translation context and returns a `TranslatedJaxprSDFG` object. - The function will process the SDFG that is encapsulated inside the context, - i.e. a canonical one, into a proper SDFG, as it is described in - `TranslatedJaxprSDFG`. It is important to realize that this function does - not perform any optimization of the underlying SDFG itself, instead it - prepares an SDFG such that it can be passed to the optimization pipeline. + The function will process the SDFG that is encapsulated inside the context, i.e. a + canonical one, into a proper SDFG, as it is described in `TranslatedJaxprSDFG`. It + is important to realize that this function does not perform any optimization of the + underlying SDFG itself, instead it prepares an SDFG such that it can be passed to + the optimization pipeline. - The function will not mutate the passed translation context and the output - is always decoupled from its output. + The returned object is fully decoupled from its input and `trans_ctx` is not + modified. Args: trans_ctx: The context that should be finalized. validate: Call the validate function after the finalizing. """ trans_ctx.validate() - if trans_ctx.inp_names is None: + if trans_ctx.input_names is None: raise ValueError("Input names are not specified.") - if trans_ctx.out_names is None: + if trans_ctx.output_names is None: raise ValueError("Output names are not specified.") + if not (trans_ctx.output_names or trans_ctx.input_names): + raise ValueError("No input nor output.") # We guarantee decoupling - tsdfg = translator.TranslatedJaxprSDFG( + tsdfg = tjsdfg.TranslatedJaxprSDFG( sdfg=copy.deepcopy(trans_ctx.sdfg), - inp_names=trans_ctx.inp_names, - out_names=trans_ctx.out_names, + input_names=trans_ctx.input_names, + output_names=trans_ctx.output_names, ) # Make inputs and outputs to globals. sdfg_arg_names: list[str] = [] - for glob_name in tsdfg.inp_names + tsdfg.out_names: - if glob_name in sdfg_arg_names: - continue - tsdfg.sdfg.arrays[glob_name].transient = False - sdfg_arg_names.append(glob_name) - - # This forces the signature of the SDFG to include all arguments in order they - # appear. If an argument is used as input and output then it is only listed as - # input. 
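# --- Small, self-contained check of the stride convention that the input state
# --- above hard-codes into the SDFG: NumPy reports strides in bytes, while DaCe
# --- (and `util.get_strides_for_dace()`) work with strides in elements.
import numpy as np

x = np.zeros((4, 3), dtype=np.float64, order="F")
element_strides = tuple(s // x.itemsize for s in x.strides)
print(x.strides)        # (8, 32): byte strides of the Fortran-ordered array
print(element_strides)  # (1, 4): what would end up in the SDFG descriptor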
+ for arg_name in tsdfg.input_names + tsdfg.output_names: + tsdfg.sdfg.arrays[arg_name].transient = False + sdfg_arg_names.append(arg_name) tsdfg.sdfg.arg_names = sdfg_arg_names if validate: tsdfg.validate() - return tsdfg diff --git a/src/jace/translator/primitive_translator.py b/src/jace/translator/primitive_translator.py index dc3bd74..ab84c5d 100644 --- a/src/jace/translator/primitive_translator.py +++ b/src/jace/translator/primitive_translator.py @@ -14,21 +14,16 @@ from __future__ import annotations import abc +from collections.abc import Callable, Sequence from typing import TYPE_CHECKING, Literal, Protocol, cast, overload, runtime_checkable if TYPE_CHECKING: - from collections.abc import Callable, Sequence - import dace from jax import core as jax_core from jace import translator -#: Global registry of the active primitive translators. -#: The `dict` maps the name of a primitive to its associated translators. -_PRIMITIVE_TRANSLATORS_REGISTRY: dict[str, translator.PrimitiveTranslator] = {} - class PrimitiveTranslatorCallable(Protocol): """Callable version of the primitive translators.""" @@ -43,7 +38,7 @@ def __call__( eqn_state: dace.SDFGState, ) -> dace.SDFGState | None: """ - Translates the Jax primitive into its SDFG equivalent. + Translates the JAX primitive into its SDFG equivalent. Before the builder calls this function it will perform the following preparatory tasks: @@ -82,7 +77,7 @@ def __call__( SDFG for the inpts or `None` in case of a literal. out_var_names: List of the names of the arrays created inside the SDFG for the outputs. - eqn: The Jax primitive that should be translated. + eqn: The JAX primitive that should be translated. eqn_state: State into which the primitive`s SDFG representation should be constructed. """ @@ -92,7 +87,7 @@ def __call__( @runtime_checkable class PrimitiveTranslator(PrimitiveTranslatorCallable, Protocol): """ - Interface for all Jax primitive translators. + Interface for all JAX primitive translators. A translator for a primitive translates a single equation of a Jaxpr into its SDFG equivalent. For satisfying this interface a concrete implementation @@ -111,7 +106,7 @@ class PrimitiveTranslator(PrimitiveTranslatorCallable, Protocol): @property @abc.abstractmethod def primitive(self) -> str: - """Returns the name of the Jax primitive that `self` is able to handle.""" + """Returns the name of the JAX primitive that `self` is able to handle.""" ... @@ -140,7 +135,7 @@ def make_primitive_translator( that it satisfy the `PrimitiveTranslator` protocol. However, it does not add it to the registry, for that `register_primitive_translator()` has to be used. - Notes: + Note: This function can also be used as decorator. """ @@ -158,6 +153,17 @@ def wrapper( return wrapper if primitive_translator is None else wrapper(primitive_translator) +# <--------------------------- Managing translators + + +_PRIMITIVE_TRANSLATORS_REGISTRY: dict[str, translator.PrimitiveTranslator] = {} +"""Global registry of the active primitive translators. + +Use `register_primitive_translator()` to add a translator to the registry and +`get_registered_primitive_translators()` get the current active set. +""" + + @overload def register_primitive_translator( primitive_translator: Literal[None] = None, overwrite: bool = False @@ -192,9 +198,9 @@ def register_primitive_translator( overwrite: Replace the current primitive translator with `primitive_translator`. Note: - To add a `primitive` property use the `@make_primitive_translator` decorator. 
- This function returns `primitive_translator` unmodified, which allows it to be - used as decorator. + To add a `primitive` property use the `@make_primitive_translator` + decorator. This function returns `primitive_translator` unmodified, + which allows it to be used as decorator. """ def wrapper( diff --git a/src/jace/translator/primitive_translators/alu_translator.py b/src/jace/translator/primitive_translators/alu_translator.py index d865ee8..f217924 100644 --- a/src/jace/translator/primitive_translators/alu_translator.py +++ b/src/jace/translator/primitive_translators/alu_translator.py @@ -10,7 +10,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Final, cast +from collections.abc import Sequence +from typing import Any, Final, cast import dace import numpy as np @@ -20,10 +21,6 @@ from jace import translator, util -if TYPE_CHECKING: - from collections.abc import Sequence - - class ALUTranslator(translator.PrimitiveTranslator): """ This translator handles all arithmetic and logical operations. @@ -61,15 +58,15 @@ def __call__( builder: The builder object of the translation. in_var_names: List of the names of the arrays created inside the SDFG for the inpts or 'None' in case of a literal. out_var_names: List of the names of the arrays created inside the SDFG for the outputs. - eqn: The Jax equation that is translated. + eqn: The JAX equation that is translated. eqn_state: State into which the primitive's SDFG representation is constructed. """ assert self._prim_name == eqn.primitive.name # Determine what kind of input we got and how we should proceed. is_scalar = len(util.get_jax_var_shape(eqn.outvars[0])) == 0 - inp_scalars = [len(util.get_jax_var_shape(Inp)) == 0 for i, Inp in enumerate(eqn.invars)] - has_scalars_as_inputs = any(inp_scalars) + input_scalars = [len(util.get_jax_var_shape(Inp)) == 0 for i, Inp in enumerate(eqn.invars)] + has_scalars_as_inputs = any(input_scalars) has_some_literals = any(x is None for x in in_var_names) inps_same_shape = all( util.get_jax_var_shape(eqn.invars[0]) == util.get_jax_var_shape(eqn.invars[i]) @@ -101,19 +98,23 @@ def __call__( else: # This is the general broadcasting case # We assume that both inputs and the output have the same rank but different sizes in each dimension. - # It seems that Jax ensures this. + # It seems that JAX ensures this. # We further assume that if the size in a dimension differs then one must have size 1. # This is the size we broadcast over, i.e. conceptually replicated. out_shps = tuple(util.get_jax_var_shape(eqn.outvars[0])) # Shape of the output - inp_shpl = tuple(util.get_jax_var_shape(eqn.invars[0])) # Shape of the left/first input - inp_shpr = tuple( + input_shpl = tuple( + util.get_jax_var_shape(eqn.invars[0]) + ) # Shape of the left/first input + input_shpr = tuple( util.get_jax_var_shape(eqn.invars[1]) ) # Shape of the right/second input - if not ((len(inp_shpl) == len(inp_shpr)) and (len(out_shps) == len(inp_shpr))): + if not ((len(input_shpl) == len(input_shpr)) and (len(out_shps) == len(input_shpr))): raise NotImplementedError("Can not broadcast over different ranks.") - for dim, (shp_lft, shp_rgt, out_shp) in enumerate(zip(inp_shpl, inp_shpr, out_shps)): + for dim, (shp_lft, shp_rgt, out_shp) in enumerate( + zip(input_shpl, input_shpr, out_shps) + ): if shp_lft == shp_rgt: assert out_shp == shp_lft elif shp_lft == 1: @@ -139,7 +140,7 @@ def __call__( if in_var_names[i] is None: # Literal: No input needed. 
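# --- Hedged sketch of how the two decorators from the note above are meant to
# --- compose. The primitive name is hypothetical and the import path is an
# --- assumption; the body is a stub, not a real lowering.
from jace import translator

@translator.register_primitive_translator()
@translator.make_primitive_translator("my_hypothetical_primitive")
def _stub_translator(builder, in_var_names, out_var_names, eqn, eqn_state):
    # A real translator would build the primitive's dataflow inside `eqn_state`.
    return eqn_state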
tskl_inputs.append((None, None)) continue - if inp_scalars[i]: # Scalar + if input_scalars[i]: # Scalar assert len(dims_to_bcast) == 0 i_memlet = dace.Memlet.simple(in_var_names[i], "0") else: # Array: We may have to broadcast diff --git a/src/jace/translator/translated_jaxpr_sdfg.py b/src/jace/translator/translated_jaxpr_sdfg.py deleted file mode 100644 index afa91ff..0000000 --- a/src/jace/translator/translated_jaxpr_sdfg.py +++ /dev/null @@ -1,71 +0,0 @@ -# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) -# -# Copyright (c) 2024, ETH Zurich -# All rights reserved. -# -# SPDX-License-Identifier: BSD-3-Clause - -"""Container for storing a translated SDFG.""" - -from __future__ import annotations - -import dataclasses - -import dace - - -@dataclasses.dataclass(kw_only=True, frozen=True) -class TranslatedJaxprSDFG: - """ - Encapsulates a translated SDFG with additional the metadata. - - Contrary to the SDFG that is encapsulated inside the `TranslationContext` - object, `self` carries a proper SDFG, however: - - It does not have `__return*` variables, instead all return arguments are - passed by arguments. - - All input arguments are passed through arguments mentioned in `inp_names`, - while the outputs are passed through `out_names`. - - Only variables listed as in/outputs are non transient. - - The order inside `inp_names` and `out_names` is the same as in the original Jaxpr. - - If an input is used as outputs it appears in both `inp_names` and `out_names`. - - Its `arg_names` is set to `inp_names + out_names`, but arguments that are - input and outputs are only listed as inputs. - - The only valid way to obtain a `TranslatedJaxprSDFG` is by passing a - `TranslationContext`, that was in turn constructed by - `JaxprTranslationBuilder.translate_jaxpr()`, to the - `finalize_translation_context()` or preferably to the `postprocess_jaxpr_sdfg()` - function. - - Attributes: - sdfg: The encapsulated SDFG object. - inp_names: A list of the SDFG variables that are used as input - out_names: A list of the SDFG variables that are used as output. - """ - - sdfg: dace.SDFG - inp_names: tuple[str, ...] - out_names: tuple[str, ...] - - def validate(self) -> bool: - """Validate the underlying SDFG.""" - if any(self.sdfg.arrays[inp].transient for inp in self.inp_names): - raise dace.sdfg.InvalidSDFGError( - f"Found transient inputs: {(inp for inp in self.inp_names if self.sdfg.arrays[inp].transient)}", - self.sdfg, - self.sdfg.node_id(self.sdfg.start_state), - ) - if any(self.sdfg.arrays[out].transient for out in self.out_names): - raise dace.sdfg.InvalidSDFGError( - f"Found transient outputs: {(out for out in self.out_names if self.sdfg.arrays[out].transient)}", - self.sdfg, - self.sdfg.node_id(self.sdfg.start_state), - ) - if self.sdfg.free_symbols: # This is a simplification that makes our life simple. 
- raise dace.sdfg.InvalidSDFGError( - f"Found free symbols: {self.sdfg.free_symbols}", - self.sdfg, - self.sdfg.node_id(self.sdfg.start_state), - ) - self.sdfg.validate() - return True diff --git a/src/jace/util/__init__.py b/src/jace/util/__init__.py index ab73e4e..9532454 100644 --- a/src/jace/util/__init__.py +++ b/src/jace/util/__init__.py @@ -12,6 +12,7 @@ from .definitions import FORBIDDEN_SDFG_VAR_NAMES, VALID_SDFG_OBJ_NAME, VALID_SDFG_VAR_NAME from .jax_helper import ( JaCeVar, + get_jax_literal_value, get_jax_var_dtype, get_jax_var_name, get_jax_var_shape, @@ -20,7 +21,9 @@ translate_dtype, ) from .traits import ( + get_strides_for_dace, is_array, + is_c_contiguous, is_drop_var, is_fully_addressable, is_jax_array, @@ -34,10 +37,13 @@ "VALID_SDFG_OBJ_NAME", "VALID_SDFG_VAR_NAME", "JaCeVar", + "get_jax_literal_value", "get_jax_var_dtype", "get_jax_var_name", "get_jax_var_shape", + "get_strides_for_dace", "is_array", + "is_c_contiguous", "is_drop_var", "is_fully_addressable", "is_jax_array", diff --git a/src/jace/util/dace_helper.py b/src/jace/util/dace_helper.py deleted file mode 100644 index 1828fac..0000000 --- a/src/jace/util/dace_helper.py +++ /dev/null @@ -1,144 +0,0 @@ -# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) -# -# Copyright (c) 2024, ETH Zurich -# All rights reserved. -# -# SPDX-License-Identifier: BSD-3-Clause - -"""Implements all utility functions that are related to DaCe.""" - -from __future__ import annotations - -import time -from typing import TYPE_CHECKING, Any - -import dace -import numpy as np -from dace import data as dace_data - -# The compiled SDFG is not available in the dace namespace or anywhere else -# Thus we import it here directly -from dace.codegen.compiled_sdfg import CompiledSDFG - -from jace import util - - -if TYPE_CHECKING: - from collections.abc import Mapping, Sequence - - from jace import translator - -__all__ = ["CompiledSDFG", "compile_jax_sdfg", "run_jax_sdfg"] - - -def compile_jax_sdfg(tsdfg: translator.TranslatedJaxprSDFG) -> CompiledSDFG: - """Compiles the embedded SDFG and return the resulting `CompiledSDFG` object.""" - if any( # We do not support the DaCe return mechanism - array_name.startswith("__return") - for array_name in tsdfg.sdfg.arrays.keys() # noqa: SIM118 # we can not use `in` because we are also interested in `__return_`! - ): - raise ValueError("Only support SDFGs without '__return' members.") - - # To ensure that the SDFG is compiled and to get rid of a warning we must modify - # some settings of the SDFG. To fake an immutable SDFG, we will restore them later. - sdfg = tsdfg.sdfg - org_sdfg_name = sdfg.name - org_recompile = sdfg._recompile - org_regenerate_code = sdfg._regenerate_code - - try: - # We need to give the SDFG another name, this is needed to prevent a DaCe - # error/warning. This happens if we compile the same lowered SDFG multiple - # times with different options. - sdfg.name = f"{sdfg.name}__comp_{int(time.time() * 1000)}" - - with dace.config.temporary_config(): - sdfg._recompile = True - sdfg._regenerate_code = True - dace.Config.set("compiler", "use_cache", value=False) - csdfg: CompiledSDFG = sdfg.compile() - - finally: - sdfg.name = org_sdfg_name - sdfg._recompile = org_recompile - sdfg._regenerate_code = org_regenerate_code - - return csdfg - - -def run_jax_sdfg( - csdfg: CompiledSDFG, - inp_names: Sequence[str], - out_names: Sequence[str], - call_args: Sequence[Any], - call_kwargs: Mapping[str, Any], -) -> tuple[Any, ...] | Any: - """ - Run the compiled SDFG. 
- - The function assumes that the SDFG was finalized and then compiled by - `compile_jax_sdfg()`. For running the SDFG you also have to pass the input - names (`inp_names`) and output names (`out_names`) that were inside the - `TranslatedJaxprSDFG` from which `csdfg` was compiled from. - - Args: - csdfg: The `CompiledSDFG` object. - inp_names: List of names of the input arguments. - out_names: List of names of the output arguments. - call_args: All positional arguments of the call. - call_kwargs: All keyword arguments of the call. - - Note: - There is no pytree mechanism jet, thus the return values are returned - inside a `tuple` or in case of one value, directly, in the order - determined by Jax. Furthermore, DaCe does not support scalar return - values, thus they are silently converted into arrays of length 1, the - same holds for inputs. - - Todo: - - Implement non C strides. - """ - sdfg: dace.SDFG = csdfg.sdfg - - if len(call_kwargs) != 0: - raise NotImplementedError("No kwargs are supported yet.") - if len(inp_names) != len(call_args): - raise RuntimeError("Wrong number of arguments.") - if sdfg.free_symbols: # This is a simplification that makes our life simple. - raise NotImplementedError( - f"No externally defined symbols are allowed, found: {sdfg.free_symbols}" - ) - - # Build the argument list that we will pass to the compiled object. - sdfg_call_args: dict[str, Any] = {} - for in_name, in_val in zip(inp_names, call_args, strict=True): - if util.is_scalar(in_val): - # Currently the translator makes scalar into arrays, this has to be - # reflected here - in_val = np.array([in_val]) # noqa: PLW2901 # Loop variable is intentionally modified. - sdfg_call_args[in_name] = in_val - - for out_name, sdfg_array in ((out_name, sdfg.arrays[out_name]) for out_name in out_names): - if out_name in sdfg_call_args: - if util.is_jax_array(sdfg_call_args[out_name]): - # Jax arrays are immutable, so they can not be return values too. - raise ValueError("Passed a Jax array as output.") - else: - sdfg_call_args[out_name] = dace_data.make_array_from_descriptor(sdfg_array) - - assert len(sdfg_call_args) == len(csdfg.argnames), ( - "Failed to construct the call arguments," - f" expected {len(csdfg.argnames)} but got {len(call_args)}." - f"\nExpected: {csdfg.argnames}\nGot: {list(sdfg_call_args.keys())}" - ) - - # Calling the SDFG - with dace.config.temporary_config(): - dace.Config.set("compiler", "allow_view_arguments", value=True) - csdfg(**sdfg_call_args) - - # Handling the output (pytrees are missing) - if not out_names: - return None - ret_val: tuple[Any] = tuple(sdfg_call_args[out_name] for out_name in out_names) - return ret_val[0] if len(out_names) == 1 else ret_val diff --git a/src/jace/util/jax_helper.py b/src/jace/util/jax_helper.py index 175671f..bc2de21 100644 --- a/src/jace/util/jax_helper.py +++ b/src/jace/util/jax_helper.py @@ -6,9 +6,9 @@ # SPDX-License-Identifier: BSD-3-Clause """ -Implements all utility functions that are related to Jax. +Implements all utility functions that are related to JAX. -Most of the functions defined here allow an unified access to Jax' internal in +Most of the functions defined here allow an unified access to JAX' internal in a consistent and stable way. 
""" @@ -16,17 +16,18 @@ import dataclasses import itertools -from typing import TYPE_CHECKING, Any +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, overload import dace +import jax import jax.core as jax_core -import numpy as np from jace import util if TYPE_CHECKING: - from collections.abc import Mapping + import numpy as np @dataclasses.dataclass(repr=True, frozen=True, eq=False) @@ -36,12 +37,12 @@ class JaCeVar: This class can be seen as some kind of substitute `jax.core.Var`. The main intention of this class is as an internal representation of values, as they - are used in Jax, but without the Jax machinery. As abstract values in Jax + are used in JAX, but without the JAX machinery. As abstract values in JAX this class has a datatype, which is a `dace.typeclass` instance and a shape. In addition it has an optional name, which allows to create variables with a certain name using `JaxprTranslationBuilder.add_array()`. - If it is expected that code must handle both Jax variables and `JaCeVar` + If it is expected that code must handle both JAX variables and `JaCeVar` then the `get_jax_var_*()` functions should be used. Args: @@ -51,7 +52,7 @@ class JaCeVar: Note: If the name of a `JaCeVar` is '_' it is considered a drop variable. The - definitions of `__hash__` and `__eq__` are in accordance with how Jax + definitions of `__hash__` and `__eq__` are in accordance with how JAX variable works. Todo: @@ -93,7 +94,7 @@ def get_jax_var_name(jax_var: jax_core.Atom | JaCeVar) -> str: # but leads to stable and valid names. return f"jax{jax_var.count}{jax_var.suffix}" case jax_core.Literal(): - raise TypeError("Can not derive a name from a Jax Literal.") + raise TypeError("Can not derive a name from a JAX Literal.") case _: raise TypeError( f"Does not know how to transform '{jax_var}' (type: '{type(jax_var).__name__}') " @@ -101,6 +102,14 @@ def get_jax_var_name(jax_var: jax_core.Atom | JaCeVar) -> str: ) +@overload +def get_jax_var_shape(jax_var: jax_core.Atom) -> tuple[int, ...]: ... + + +@overload +def get_jax_var_shape(jax_var: JaCeVar) -> tuple[int | dace.symbol | str, ...]: ... + + def get_jax_var_shape(jax_var: jax_core.Atom | JaCeVar) -> tuple[int | dace.symbol | str, ...]: """Returns the shape of `jax_var`.""" match jax_var: @@ -132,14 +141,24 @@ def is_tracing_ongoing(*args: Any, **kwargs: Any) -> bool: While a return value `True` guarantees that a translation is ongoing, a value of `False` does not guarantees that no tracing is ongoing. """ - # The current implementation only checks the arguments if it contains tracers. - if (len(args) == 0) and (len(kwargs) == 0): - raise RuntimeError("Failed to determine if tracing is ongoing.") - return any(isinstance(x, jax_core.Tracer) for x in itertools.chain(args, kwargs.values())) + # To detect if there is tracing ongoing, we check the internal tracing stack of JAX. + # Note that this is highly internal and depends on the precise implementation of + # JAX. For that reason we first look at all arguments and check if they are + # tracers. Furthermore, it seems that JAX always have a bottom interpreter on the + # stack, thus it is empty if `len(...) == 1`! 
+ # See also: https://github.com/google/jax/pull/3370 + if any(isinstance(x, jax_core.Tracer) for x in itertools.chain(args, kwargs.values())): + return True + trace_stack_height = len(jax._src.core.thread_local_state.trace_state.trace_stack.stack) + if trace_stack_height == 1: + return False + if trace_stack_height > 1: + return True + raise RuntimeError("Failed to determine if tracing is ongoing.") def translate_dtype(dtype: Any) -> dace.typeclass: - """Turns a Jax datatype into a DaCe datatype.""" + """Turns a JAX datatype into a DaCe datatype.""" if dtype is None: raise NotImplementedError # Handling a special case in DaCe. if isinstance(dtype, dace.typeclass): @@ -169,7 +188,7 @@ def propose_jax_name( Args: jax_var: The variable for which a name to propose. - jax_name_map: A mapping of all Jax variables that were already named. + jax_name_map: A mapping of all JAX variables that were already named. Note: The function guarantees that the returned name passes `VALID_SDFG_VAR_NAME` @@ -185,7 +204,7 @@ def propose_jax_name( if isinstance(jax_var, JaCeVar) and (jax_var.name is not None): return jax_var.name - # This code is taken from Jax so it will generate similar ways, the difference is + # This code is taken from JAX so it will generate similar ways, the difference is # that we do the counting differently. # Note that `z` is followed by `ba` and not `aa` as it is in Excel. c = len(jax_name_map) @@ -209,9 +228,12 @@ def get_jax_literal_value(lit: jax_core.Atom) -> bool | float | int | np.generic if not isinstance(lit, jax_core.Literal): raise TypeError(f"Can only extract literals not '{type(lit)}'.") val = lit.val - if isinstance(val, np.ndarray): + # In previous versions of JAX literals were always 0-dim arrays, but it seems + # that in newer versions the values are either arrays or scalars. + # I saw both thus we have to keep both branches. + if util.is_array(val): assert val.shape == () - return val.max() - if isinstance(val, (bool, float, int)): + return val.dtype.type(val.max()) + if util.is_scalar(val): return val - raise TypeError(f"Failed to extract value from '{lit}'.") + raise TypeError(f"Failed to extract value from '{lit}' ('{val}' type: {type(val).__name__}).") diff --git a/src/jace/util/traits.py b/src/jace/util/traits.py index a8e6bc8..af2b290 100644 --- a/src/jace/util/traits.py +++ b/src/jace/util/traits.py @@ -30,18 +30,19 @@ def is_drop_var(jax_var: jax_core.Atom | util.JaCeVar) -> TypeGuard[jax_core.Dro def is_jax_array(obj: Any) -> TypeGuard[jax.Array]: """ - Tests if `obj` is a Jax array. + Tests if `obj` is a JAX array. Note: - Jax arrays are special as they can not be mutated. Furthermore, they always - allocate on the CPU _and_ on the GPU, if present. + JAX arrays are special as they can not be mutated. Furthermore, they always + allocate on the CPU _and_ on the GPU, if present. """ return isinstance(obj, jax.Array) -def is_array(obj: Any) -> bool: - """Identifies arrays, this also includes Jax arrays.""" - return dace.is_array(obj) or is_jax_array(obj) +def is_array(obj: Any) -> TypeGuard[jax.Array]: + """Identifies arrays, this also includes JAX arrays.""" + # `dace.is_array()` does not seem to recognise shape zero arrays. + return isinstance(obj, np.ndarray) or dace.is_array(obj) or is_jax_array(obj) def is_scalar(obj: Any) -> bool: @@ -74,11 +75,40 @@ def is_scalar(obj: Any) -> bool: return type(obj) in known_types +def get_strides_for_dace(obj: Any) -> tuple[int, ...] | None: + """ + Get the strides of `obj` in a DaCe compatible format. 
+ + The function returns the strides in number of elements, as it is used inside + DaCe and not in bytes as it is inside NumPy. As in DaCe `None` is returned to + indicate standard C order. + + Note: + If `obj` is not array like an error is generated. + """ + if not is_array(obj): + raise TypeError(f"Passed '{obj}' ({type(obj).__name__}) is not array like.") + + if is_jax_array(obj): + if not is_fully_addressable(obj): + raise NotImplementedError("Sharded jax arrays are not supported.") + obj = obj.__array__() + assert hasattr(obj, "strides") + + if obj.strides is None: + return None + if not hasattr(obj, "itemsize"): + # No `itemsize` member so we assume that it is already in elements. + return obj.strides + + return tuple(stride // obj.itemsize for stride in obj.strides) + + def is_on_device(obj: Any) -> bool: """ Tests if `obj` is on a device. - Jax arrays are always on the CPU and GPU (if there is one). Thus for Jax + JAX arrays are always on the CPU and GPU (if there is one). Thus for JAX arrays this function is more of a test, if there is a GPU at all. """ if is_jax_array(obj): @@ -91,3 +121,12 @@ def is_fully_addressable(obj: Any) -> bool: if is_jax_array(obj): return obj.is_fully_addressable return True + + +def is_c_contiguous(obj: Any) -> bool: + """Tests if `obj` is in C order.""" + if not is_array(obj): + return False + if is_jax_array(obj): + obj = obj.__array__() + return obj.flags["C_CONTIGUOUS"] diff --git a/src/jace/util/translation_cache.py b/src/jace/util/translation_cache.py index f6366bd..bbb214c 100644 --- a/src/jace/util/translation_cache.py +++ b/src/jace/util/translation_cache.py @@ -21,11 +21,11 @@ import collections import dataclasses import functools -from collections.abc import Callable, Hashable +from collections.abc import Callable, Hashable, Sequence from typing import TYPE_CHECKING, Any, Concatenate, Generic, ParamSpec, TypeAlias, TypeVar, cast import dace -from jax import core as jax_core +from jax import core as jax_core, tree_util as jax_tree from jace import util @@ -33,31 +33,33 @@ if TYPE_CHECKING: from jace import stages -#: Caches used to store the state transition. -#: The caches are on a per stage and not per instant basis. _TRANSLATION_CACHES: dict[type[CachingStage], StageCache] = {} +"""Caches used to store the state transition. + +The caches are on a per stage and not per instant basis. +""" -# Denotes the stage that follows the current one. -# Used by the `NextStage` Mixin. +# Type annotation for the caching. +P = ParamSpec("P") NextStage = TypeVar("NextStage", bound="stages.Stage") +TransitionFunction: TypeAlias = "Callable[Concatenate[CachingStage[NextStage], P], NextStage]" +CachingStageT = TypeVar("CachingStageT", bound="CachingStage") + +# Type to describe a single argument either in an abstract or concrete way. +CallArgsSpec: TypeAlias = tuple["_AbstractCallArgument | Hashable"] class CachingStage(Generic[NextStage]): """ Annotates a stage whose transition to the next stage is cacheable. - To make the transition of a stage cacheable, the stage must be derived from - this class, and its initialization must call `CachingStage.__init__()`. - Furthermore, its transition function must be annotated by the - `@cached_transition` decorator. - - A class must implement the `_make_call_description()` to compute an abstract - description of the call. This is needed to operate the cache to store the - stage transitions. - - Notes: - The `__init__()` function must explicitly be called to fully setup `self`. 
+ To make a transition cacheable, a stage must: + - be derived from this class. + - its `__init__()` function must explicitly call `CachingStage.__init__()`. + - the transition function must be annotated by `@cached_transition`. + - it must implement the `_make_call_description()` to create the key. + - the stage object must be immutable. Todo: - Handle eviction from the cache due to collecting of unused predecessor stages. @@ -70,40 +72,47 @@ def __init__(self) -> None: @abc.abstractmethod def _make_call_description( - self: CachingStage, *args: Any, **kwargs: Any + self: CachingStage, in_tree: jax_tree.PyTreeDef, flat_call_args: Sequence[Any] ) -> StageTransformationSpec: - """Generates the key that is used to store/locate the call in the cache.""" - ... + """ + Computes the key used to represent the call. + This function is used by the `@cached_transition` decorator to perform + the lookup inside the cache. It should return a description of the call + that is encapsulated inside a `StageTransformationSpec` object, see + there for more information. -# Type annotation for the caching. -P = ParamSpec("P") -TransitionFunction = Callable[Concatenate[CachingStage[NextStage], P], NextStage] -CachingStageType = TypeVar("CachingStageType", bound=CachingStage) + Args: + in_tree: Pytree object describing how the input arguments were flattened. + flat_call_args: The flattened arguments that were passed to the + annotated function. + """ + ... def cached_transition( - transition: Callable[Concatenate[CachingStageType, P], NextStage], + transition: Callable[Concatenate[CachingStageT, P], NextStage], ) -> Callable[Concatenate[CachingStage[NextStage], P], NextStage]: """ Decorator for making the transition function of the stage cacheable. - In order to work, the stage must be derived from `CachingStage`. For computing - the key of a call the function will use the `_make_call_description()` - function of the cache. + See the description of `CachingStage` for the requirements. + The function will use `_make_call_description()` to decide if the call is + already known and if so it will return the cached object. If the call is + not known it will call the wrapped transition function and record its + return value inside the cache, before returning it. Todo: - Implement a way to temporary disable the cache. """ @functools.wraps(transition) - def transition_wrapper(self: CachingStageType, *args: P.args, **kwargs: P.kwargs) -> NextStage: - key: StageTransformationSpec = self._make_call_description(*args, **kwargs) - if key in self._cache: - return self._cache[key] - next_stage = transition(self, *args, **kwargs) - self._cache[key] = next_stage - return next_stage + def transition_wrapper(self: CachingStageT, *args: P.args, **kwargs: P.kwargs) -> NextStage: + flat_call_args, in_tree = jax_tree.tree_flatten((args, kwargs)) + key = self._make_call_description(flat_call_args=flat_call_args, in_tree=in_tree) + if key not in self._cache: + self._cache[key] = transition(self, *args, **kwargs) + return self._cache[key] return cast(TransitionFunction, transition_wrapper) @@ -129,17 +138,18 @@ class _AbstractCallArgument: As noted in `StageTransformationSpec` there are two ways to describe an argument, either by using its concrete value or an abstract description, - which is similar to tracers in Jax. This class represents the second way. + which is similar to tracers in JAX. This class represents the second way. To create an instance you should use `_AbstractCallArgument.from_value()`. 
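# --- Sketch of the flattening step that `@cached_transition` performs before
# --- building the cache key: `(args, kwargs)` is flattened into leaves plus a
# --- `PyTreeDef`. The argument values below are arbitrary examples.
import jax.tree_util as jax_tree
import jax.numpy as jnp

args = (jnp.ones((2, 2)), {"scale": 3.0})
kwargs = {"flag": True}
flat_call_args, in_tree = jax_tree.tree_flatten((args, kwargs))
print(len(flat_call_args))  # 3 leaves: the array, 3.0 and True
print(in_tree)              # PyTreeDef describing how to rebuild (args, kwargs)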
- Its description is limited to scalars and arrays. To describe more complex - types, they should be processed by pytrees first. - Attributes: shape: In case of an array its shape, in case of a scalar the empty tuple. dtype: The DaCe type of the argument. strides: The strides of the argument, or `None` if they are unknown or a scalar. storage: The storage type where the argument is stored. + + Note: + This class is only able to describe scalars and arrays, thus it should + only be used after the arguments were flattened. """ shape: tuple[int, ...] @@ -153,15 +163,15 @@ def from_value(cls, value: Any) -> _AbstractCallArgument: if not util.is_fully_addressable(value): raise NotImplementedError("Distributed arrays are not addressed yet.") if isinstance(value, jax_core.Literal): - raise TypeError("Jax Literals are not supported as cache keys.") + raise TypeError("JAX Literals are not supported as cache keys.") if util.is_array(value): if util.is_jax_array(value): value = value.__array__() # Passing `copy=False` leads to error in NumPy. shape = value.shape dtype = util.translate_dtype(value.dtype) - strides = getattr(value, "strides", None) - # Is `CPU_Heap` always okay? There would also be `CPU_Pinned`. + strides = util.get_strides_for_dace(value) + # TODO(phimuell): `CPU_Heap` vs. `CPU_Pinned`. storage = ( dace.StorageType.GPU_Global if util.is_on_device(value) @@ -182,74 +192,74 @@ def from_value(cls, value: Any) -> _AbstractCallArgument: raise TypeError(f"Can not make 'an abstract description from '{type(value).__name__}'.") -#: This type is the abstract description of a function call. -#: It is part of the key used in the cache. -CallArgsSpec: TypeAlias = tuple[ - _AbstractCallArgument | Hashable | tuple[str, _AbstractCallArgument | Hashable], ... -] - - @dataclasses.dataclass(frozen=True) class StageTransformationSpec: """ Represents the entire call to a state transformation function of a stage. State transition functions are annotated with `@cached_transition` and their - result may be cached. They key to locate them inside the cache is represented + result is cached. They key to locate them inside the cache is represented by this class and computed by the `CachingStage._make_call_description()` - function. The actual key is consists of two parts, `stage_id` and `call_args`. + function. The actual key is consists of three parts, `stage_id`, `call_args` + and `in_tree`, see below for more. Args: stage_id: Origin of the call, for which the id of the stage object should be used. - call_args: Description of the arguments of the call. There are two ways - to describe the arguments: + flat_call_args: Flat representation of the arguments of the call. Each element + describes a single argument. To describe an argument there are two ways: - Abstract description: In this way, the actual value of the argument - is irrelevant, only the structure of them are important, similar - to the tracers used in Jax. - - Concrete description: Here one caches on the actual value of the - argument. The only requirement is that they can be hashed. + is irrelevant, its structure is important, similar to the tracers + used in JAX. To represent it, use `_AbstractCallArgument`. + - Concrete description: Here the actual value of the argument is + considered, this is similar to how static arguments in JAX works. + The only requirement is that they can be hashed. + in_tree: A pytree structure that describes how the input was flatten. 
""" stage_id: int - call_args: CallArgsSpec + flat_call_args: CallArgsSpec + in_tree: jax_tree.PyTreeDef -# Denotes the stage that is stored inside the cache. -StageType = TypeVar("StageType", bound="stages.Stage") +#: Denotes the stage that is stored inside the cache. +StageT = TypeVar("StageT", bound="stages.Stage") -class StageCache(Generic[StageType]): +class StageCache(Generic[StageT]): """ Simple LRU cache to cache the results of the stage transition function. Args: - size: The size of the cache, defaults to 256. + capacity: The size of the cache, defaults to 256. """ # The most recently used entry is at the end of the `OrderedDict`. - _memory: collections.OrderedDict[StageTransformationSpec, StageType] - _size: int - - def __init__(self, size: int = 256) -> None: + _memory: collections.OrderedDict[StageTransformationSpec, StageT] + _capacity: int + + def __init__( + self, + capacity: int = 256, + ) -> None: + self._capacity = capacity self._memory = collections.OrderedDict() - self._size = size def __contains__(self, key: StageTransformationSpec) -> bool: return key in self._memory - def __getitem__(self, key: StageTransformationSpec) -> StageType: + def __getitem__(self, key: StageTransformationSpec) -> StageT: if key not in self: raise KeyError(f"Key '{key}' is unknown.") self._memory.move_to_end(key, last=True) return self._memory[key] - def __setitem__(self, key: StageTransformationSpec, res: StageType) -> None: + def __setitem__(self, key: StageTransformationSpec, res: StageT) -> None: if key in self: self._memory.move_to_end(key, last=True) self._memory[key] = res else: - if len(self._memory) == self._size: + if len(self._memory) == self._capacity: self.popitem(None) self._memory[key] = res @@ -267,8 +277,19 @@ def popitem(self, key: StageTransformationSpec | None) -> None: self._memory.move_to_end(key, last=False) self._memory.popitem(last=False) - def clear(self) -> None: # noqa: D102 # Missing description. + def clear(self) -> None: # noqa: D102 [undocumented-public-method] self._memory.clear() + def __len__(self) -> int: + return len(self._memory) + + @property + def capacity(self) -> int: # noqa: D102 [undocumented-public-method] + return self._capacity + + def front(self) -> tuple[StageTransformationSpec, StageT]: + """Returns the front of the cache, i.e. its newest entry.""" + return next(reversed(self._memory.items())) + def __repr__(self) -> str: - return f"StageCache({len(self._memory)} / {self._size} || {', '.join('[' + repr(k) + ']' for k in self._memory)})" + return f"StageCache({len(self._memory)} / {self._capacity} || {', '.join('[' + repr(k) + ']' for k in self._memory)})" diff --git a/tests/test_caching.py b/tests/test_caching.py index bc0e44c..01fabc9 100644 --- a/tests/test_caching.py +++ b/tests/test_caching.py @@ -11,7 +11,6 @@ from __future__ import annotations import itertools as it -import re import numpy as np import pytest @@ -192,8 +191,11 @@ def jaceWrapped(A: np.ndarray, B: np.ndarray) -> np.ndarray: # Because of the way how things work the optimized must have more than the # unoptimized. If there is sharing, then this would not be the case. 
assert unoptiCompiled is not optiCompiled - assert optiCompiled._csdfg.sdfg.number_of_nodes() == 1 - assert optiCompiled._csdfg.sdfg.number_of_nodes() < unoptiCompiled._csdfg.sdfg.number_of_nodes() + assert optiCompiled._compiled_sdfg.sdfg.number_of_nodes() == 1 + assert ( + optiCompiled._compiled_sdfg.sdfg.number_of_nodes() + < unoptiCompiled._compiled_sdfg.sdfg.number_of_nodes() + ) def test_caching_dtype(): @@ -244,13 +246,9 @@ def wrapped(A: np.ndarray) -> np.ndarray: # But the cache is aware of this, which helps catch some nasty bugs. F_lower = None # Remove later F_res = C_res.copy() # Remove later - with pytest.raises( # noqa: PT012 # Multiple calls - expected_exception=NotImplementedError, - match=re.escape("Currently can not yet handle strides beside 'C_CONTIGUOUS'."), - ): - F_lower = wrapped.lower(F) - F_res = wrapped(F) - assert F_lower is None # Remove later. + F_lower = wrapped.lower(F) + F_res = wrapped(F) + assert F_lower is not C_lower assert C_res is not F_res assert np.allclose(F_res, C_res) assert F_lower is not C_lower diff --git a/tests/test_jaxpr_translator_builder.py b/tests/test_jaxpr_translator_builder.py deleted file mode 100644 index efc6657..0000000 --- a/tests/test_jaxpr_translator_builder.py +++ /dev/null @@ -1,539 +0,0 @@ -# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) -# -# Copyright (c) 2024, ETH Zurich -# All rights reserved. -# -# SPDX-License-Identifier: BSD-3-Clause - -"""Implements some tests of the subtranslator builder.""" - -from __future__ import annotations - -import re - -import dace -import jax -import numpy as np -import pytest -from dace.data import Array -from jax import core as jax_core - -import jace -from jace import translator, util -from jace.util import JaCeVar - - -# These are some JaCe variables that we use inside the tests -# Unnamed arrays -array1 = JaCeVar((10, 12), dace.float64) -array2 = JaCeVar((10, 13), dace.float32) -array3 = JaCeVar((11, 16), dace.int64) - -# Unnamed scalars -scal1 = JaCeVar((), dace.float16) -scal2 = JaCeVar((), dace.float32) -scal3 = JaCeVar((), dace.int64) - -# Named variables -narray = JaCeVar((10,), dace.float16, "narr") -nscal = JaCeVar((), dace.int32, "nscal") - - -@pytest.fixture() -def translation_builder(): - """Returns an allocated builder instance.""" - name = "fixture_builder" - builder = translator.JaxprTranslationBuilder( - primitive_translators=translator.get_registered_primitive_translators() - ) - builder._allocate_translation_ctx(name=name) - return builder - - -def test_builder_alloc() -> None: - """Tests the state right after allocation. - - Does not use the fixture because it does it on its own. - """ - builder = translator.JaxprTranslationBuilder( - primitive_translators=translator.get_registered_primitive_translators() - ) - assert not builder.is_allocated(), "Builder was created allocated." - assert len(builder._ctx_stack) == 0 - - # The reserved names will be tested in `test_builder_fork()`. 
- sdfg_name = "qwertzuiopasdfghjkl" - builder._allocate_translation_ctx(name=sdfg_name) - assert len(builder._ctx_stack) == 1 - assert builder.is_root_translator() - - sdfg: dace.SDFG = builder.sdfg - - assert builder._ctx.sdfg is sdfg - assert builder.sdfg.name == sdfg_name - assert sdfg.number_of_nodes() == 1 - assert sdfg.number_of_edges() == 0 - assert sdfg.start_block is builder._ctx.start_state - assert builder._terminal_sdfg_state is builder._ctx.start_state - - -def test_builder_variable_alloc_auto_naming( - translation_builder: translator.JaxprTranslationBuilder, -) -> None: - """Tests simple variable allocation.""" - for i, var in enumerate([array1, array2, scal1, array3, scal2, scal3]): - sdfg_name = translation_builder.add_array(var, update_var_mapping=True) - sdfg_var = translation_builder.get_array(sdfg_name) - assert sdfg_name == chr(97 + i) - assert isinstance(sdfg_var, Array) # Everything is now an array - assert sdfg_var.shape == ((1,) if var.shape == () else var.shape) - assert sdfg_var.dtype == var.dtype - - -def test_builder_variable_alloc_mixed_naming( - translation_builder: translator.JaxprTranslationBuilder, -) -> None: - """Tests the naming in a mixed setting. - - If `update_var_mapping=True` is given, then the naming will skip variables, - see also `test_builder_variable_alloc_mixed_naming2()`. - """ - # * b c d * f g - for i, var in enumerate([narray, array1, array2, scal1, nscal, scal2, scal3]): - sdfg_name = translation_builder.add_array(var, update_var_mapping=True) - sdfg_var = translation_builder.get_array(sdfg_name) - if var.name is None: - assert sdfg_name == chr(97 + i) - else: - assert sdfg_name == var.name - assert isinstance(sdfg_var, Array) # Everything is now an array - assert sdfg_var.shape == ((1,) if var.shape == () else var.shape) - assert sdfg_var.dtype == var.dtype - - -def test_builder_variable_alloc_mixed_naming2( - translation_builder: translator.JaxprTranslationBuilder, -) -> None: - """Tests the naming in a mixed setting. - - This time we do not use `update_var_mapping=True`, instead it now depends on the - name. This means that automatic naming will now again include all, letters, but not - in a linear order. - """ - letoff = 0 - # * a b c * d e - for var in [narray, array1, array2, scal1, nscal, scal2, scal3]: - sdfg_name = translation_builder.add_array(var, update_var_mapping=var.name is None) - sdfg_var = translation_builder.get_array(sdfg_name) - if var.name is None: - assert sdfg_name == chr(97 + letoff) - letoff += 1 - else: - assert sdfg_name == var.name - assert isinstance(sdfg_var, Array) # Everything is now an array - assert sdfg_var.shape == ((1,) if var.shape == () else var.shape) - assert sdfg_var.dtype == var.dtype - - -def test_builder_variable_alloc_prefix_naming( - translation_builder: translator.JaxprTranslationBuilder, -) -> None: - """Using the prefix to name variables.""" - prefix_1 = "__my_special_prefix" - exp_name_1 = prefix_1 + "a" - sdfg_name_1 = translation_builder.add_array( - array1, name_prefix=prefix_1, update_var_mapping=False - ) - assert exp_name_1 == sdfg_name_1 - - # Because `update_var_mapping` is `False` above, 'a' will be reused. - prefix_2 = "__my_special_prefix_second_" - exp_name_2 = prefix_2 + "a" - sdfg_name_2 = translation_builder.add_array( - array1, name_prefix=prefix_2, update_var_mapping=False - ) - assert exp_name_2 == sdfg_name_2 - - # Now we use a named variables, which are also affected. 
- prefix_3 = "__my_special_prefix_third_named_" - exp_name_3 = prefix_3 + nscal.name # type: ignore[operator] # `.name` is not `None`. - sdfg_name_3 = translation_builder.add_array( - nscal, name_prefix=prefix_3, update_var_mapping=False - ) - assert exp_name_3 == sdfg_name_3 - - -def test_builder_variable_alloc_auto_naming_wrapped( - translation_builder: translator.JaxprTranslationBuilder, -) -> None: - """Tests the variable naming if we have more than 26 variables.""" - single_letters = [chr(x) for x in range(97, 123)] - i = 0 - for let1 in ["", *single_letters[1:]]: # Note `z` is followed by `ba` and not by `aa`. - for let2 in single_letters: - i += 1 - # Create a variable and enter it into the variable naming. - var = JaCeVar(shape=(19, 19), dtype=dace.float64) - sdfg_name = translation_builder.add_array(arg=var, update_var_mapping=True) - mapped_name = translation_builder.map_jax_var_to_sdfg(var) - assert ( - sdfg_name == mapped_name - ), f"Mapping for '{var}' failed, expected '{sdfg_name}' got '{mapped_name}'." - - # Get the name that we really expect, we must also handle some situations. - exp_name = let1 + let2 - if exp_name in util.FORBIDDEN_SDFG_VAR_NAMES: - exp_name = "__jace_forbidden_" + exp_name - assert ( - exp_name == sdfg_name - ), f"Automated naming failed, expected '{exp_name}' but got '{sdfg_name}'." - - -def test_builder_nested(translation_builder: translator.JaxprTranslationBuilder) -> None: - """Tests the ability of the nesting of the builder.""" - - # Now add a variable to the current subtext. - name_1 = translation_builder.add_array(array1, update_var_mapping=True) - assert name_1 == "a" - assert translation_builder.map_jax_var_to_sdfg(array1) == name_1 - - # For the sake of doing it add a new state to the SDFG. - translation_builder.append_new_state("sake_state") - assert translation_builder.sdfg.number_of_nodes() == 2 - assert translation_builder.sdfg.number_of_edges() == 1 - - # Now we go one subcontext deeper - translation_builder._allocate_translation_ctx("builder") - assert len(translation_builder._ctx_stack) == 2 - assert translation_builder.sdfg.name == "builder" - assert translation_builder.sdfg.number_of_nodes() == 1 - assert translation_builder.sdfg.number_of_edges() == 0 - assert not translation_builder.is_root_translator() - - # Because we have a new SDFG the mapping to previous SDFG does not work, - # regardless the fact that it still exists. - with pytest.raises( - expected_exception=KeyError, - match=re.escape( - f"Jax variable '{array1}' was supposed to map to '{name_1}', but no such SDFG variable is known." - ), - ): - _ = translation_builder.map_jax_var_to_sdfg(array1) - - # Because the SDFGs are distinct it is possible to add `array1` to the nested one. - # However, it is not able to update the mapping. - with pytest.raises( - expected_exception=ValueError, - match=re.escape(f"Cannot change the mapping of '{array1}' from '{name_1}' to '{name_1}'."), - ): - _ = translation_builder.add_array(array1, update_var_mapping=True) - assert name_1 not in translation_builder.sdfg.arrays - - # Without updating the mapping it is possible create the variable. - assert name_1 == translation_builder.add_array(array1, update_var_mapping=False) - - # Now add a new variable, the map is shared, so a new name will be generated. - name_2 = translation_builder.add_array(array2, update_var_mapping=True) - assert name_2 == "b" - assert name_2 == translation_builder.map_jax_var_to_sdfg(array2) - - # Now we go one stack level back. 
-    translation_builder._clear_translation_ctx()
-    assert len(translation_builder._ctx_stack) == 1
-    assert translation_builder.sdfg.number_of_nodes() == 2
-    assert translation_builder.sdfg.number_of_edges() == 1
-
-    # Again, the variable that was declared in the nested context is no longer present.
-    # Note that if the nested SDFG had been integrated into the parent SDFG, it would
-    # be accessible.
-    with pytest.raises(
-        expected_exception=KeyError,
-        match=re.escape(
-            f"Jax variable '{array2}' was supposed to map to '{name_2}', but no such SDFG variable is known."
-        ),
-    ):
-        _ = translation_builder.map_jax_var_to_sdfg(array2)
-    assert name_2 == translation_builder._jax_name_map[array2]
-
-    # Now add a new variable; since the map is shared, we get the next name.
-    name_3 = translation_builder.add_array(array3, update_var_mapping=True)
-    assert name_3 == "c"
-    assert name_3 == translation_builder.map_jax_var_to_sdfg(array3)
-
-
-def test_builder_append_state(translation_builder: translator.JaxprTranslationBuilder) -> None:
-    """Tests the functionality of appending states."""
-    sdfg: dace.SDFG = translation_builder.sdfg
-
-    terminal_state_1: dace.SDFGState = translation_builder.append_new_state("terminal_state_1")
-    assert sdfg.number_of_nodes() == 2
-    assert sdfg.number_of_edges() == 1
-    assert terminal_state_1 is translation_builder._terminal_sdfg_state
-    assert translation_builder._terminal_sdfg_state is translation_builder._ctx.terminal_state
-    assert translation_builder._ctx.start_state is sdfg.start_block
-    assert translation_builder._ctx.start_state is not terminal_state_1
-    assert next(iter(sdfg.edges())).src is sdfg.start_block
-    assert next(iter(sdfg.edges())).dst is terminal_state_1
-
-    # Specifying an explicit previous state that is the terminal state should also
-    # update the terminal state of the builder.
-    terminal_state_2: dace.SDFGState = translation_builder.append_new_state(
-        "terminal_state_2", prev_state=terminal_state_1
-    )
-    assert sdfg.number_of_nodes() == 3
-    assert sdfg.number_of_edges() == 2
-    assert terminal_state_2 is translation_builder._terminal_sdfg_state
-    assert sdfg.out_degree(terminal_state_1) == 1
-    assert sdfg.out_degree(terminal_state_2) == 0
-    assert sdfg.in_degree(terminal_state_2) == 1
-    assert next(iter(sdfg.in_edges(terminal_state_2))).src is terminal_state_1
-
-    # Specifying a previous state that is not the terminal state should not update
-    # the builder's terminal state.
-    non_terminal_state: dace.SDFGState = translation_builder.append_new_state(
-        "non_terminal_state", prev_state=terminal_state_1
-    )
-    assert translation_builder._terminal_sdfg_state is not non_terminal_state
-    assert sdfg.in_degree(non_terminal_state) == 1
-    assert sdfg.out_degree(non_terminal_state) == 0
-    assert next(iter(sdfg.in_edges(non_terminal_state))).src is terminal_state_1
-
-
-def test_builder_variable_multiple_variables(
-    translation_builder: translator.JaxprTranslationBuilder,
-) -> None:
-    """Add an already known variable, but with a different name."""
-    # First add `array1`, then try different ways of adding it again.
-    narray1: str = translation_builder.add_array(array1, update_var_mapping=True)
-
-    # Using a prefix fails, because we also want to update the mapping.
-    prefix = "__jace_prefix"
-    prefix_expected_name = prefix + narray1
-    with pytest.raises(
-        expected_exception=ValueError,
-        match=re.escape(
-            f"Cannot change the mapping of '{array1}' from '{translation_builder.map_jax_var_to_sdfg(array1)}' to '{prefix_expected_name}'."
-        ),
-    ):
-        _ = translation_builder.add_array(array1, update_var_mapping=True, name_prefix=prefix)
-    assert prefix_expected_name not in translation_builder.sdfg.arrays
-
-    # But if we do not want to update the mapping, it works.
-    prefix_sdfg_name = translation_builder.add_array(
-        array1, update_var_mapping=False, name_prefix=prefix
-    )
-    assert prefix_expected_name == prefix_sdfg_name
-    assert prefix_expected_name in translation_builder.sdfg.arrays
-    assert narray1 == translation_builder.map_jax_var_to_sdfg(array1)
-
-
-def test_builder_variable_invalid_prefix(
-    translation_builder: translator.JaxprTranslationBuilder,
-) -> None:
-    """Use an invalid prefix."""
-    # Each of these prefixes leads to an invalid name, so `add_array()` must fail.
-    for iprefix in ["0_", "_ja ", "_!"]:
-        with pytest.raises(
-            expected_exception=ValueError,
-            match=re.escape(f"add_array({array1}): The proposed name '{iprefix}a', is invalid."),
-        ):
-            _ = translation_builder.add_array(array1, update_var_mapping=False, name_prefix=iprefix)
-        assert len(translation_builder.sdfg.arrays) == 0
-
-
-def test_builder_variable_alloc_list(
-    translation_builder: translator.JaxprTranslationBuilder,
-) -> None:
-    """Tests part of the `JaxprTranslationBuilder.create_jax_var_list()` API."""
-    var_list_1 = [array1, nscal, scal2]
-    exp_names_1 = ["a", nscal.name, "c"]
-
-    res_names_1 = translation_builder.create_jax_var_list(var_list_1, update_var_mapping=True)
-    assert len(translation_builder.arrays) == 3
-    assert res_names_1 == exp_names_1
-
-    # Now a mixture of collection and creation.
-    var_list_2 = [array2, nscal, scal1]
-    exp_names_2 = ["d", nscal.name, "e"]
-
-    res_names_2 = translation_builder.create_jax_var_list(var_list_2, update_var_mapping=True)
-    assert res_names_2 == exp_names_2
-    assert len(translation_builder.arrays) == 5
-
-
-@pytest.mark.skip(reason="'create_jax_var_list()' does not clean up in case of an error.")
-def test_builder_variable_alloc_list_cleaning(
-    translation_builder: translator.JaxprTranslationBuilder,
-) -> None:
-    """Tests part of the `JaxprTranslationBuilder.create_jax_var_list()` API.
-
-    It will fail because `update_var_mapping=False` is used; thus the third variable
-    causes an error, because its proposed name 'a' is already taken.
-    """
-    var_list = [array1, nscal, scal2]
-
-    with pytest.raises(
-        expected_exception=ValueError,
-        match=re.escape(f"add_array({scal2}): The proposed name 'a', is used."),
-    ):
-        _ = translation_builder.create_jax_var_list(var_list)
-
-    assert len(translation_builder.arrays) == 0
-
-
-def test_builder_variable_alloc_list_prevent_creation(
-    translation_builder: translator.JaxprTranslationBuilder,
-) -> None:
-    """Tests part of the `JaxprTranslationBuilder.create_jax_var_list()` API.
-
-    It will test the `prevent_creation` flag.
-    """
-    # First create a variable.
-    translation_builder.add_array(array1, update_var_mapping=True)
-    assert len(translation_builder.arrays) == 1
-
-    # Now try to create the variables.
-    var_list = [array1, array2]
-
-    with pytest.raises(
-        expected_exception=ValueError,
-        match=re.escape(f"'prevent_creation' given but have to create '{array2}'."),
-    ):
-        translation_builder.create_jax_var_list(var_list, prevent_creation=True)
-    assert len(translation_builder.arrays) == 1
-    assert translation_builder.map_jax_var_to_sdfg(array1) == "a"
-
-
-@pytest.mark.skip(reason="'create_jax_var_list()' does not clean up in case of an error.")
-def test_builder_variable_alloc_list_only_creation(
-    translation_builder: translator.JaxprTranslationBuilder,
-) -> None:
-    """Tests part of the `JaxprTranslationBuilder.create_jax_var_list()` API.
-
-    It will test the `only_creation` flag.
-    """
-    # First create a variable.
-    translation_builder.add_array(array1, update_var_mapping=True)
-    assert len(translation_builder.arrays) == 1
-
-    # Now try to create the variables.
-    var_list = [array2, array1]
-
-    with pytest.raises(
-        expected_exception=ValueError,
-        match=re.escape(f"'only_creation' given '{array1}' already exists."),
-    ):
-        translation_builder.create_jax_var_list(var_list, only_creation=True)
-    assert len(translation_builder.arrays) == 1
-    assert translation_builder.map_jax_var_to_sdfg(array1) == "a"
-
-
-def test_builder_variable_alloc_list_handle_literal(
-    translation_builder: translator.JaxprTranslationBuilder,
-) -> None:
-    """Tests part of the `JaxprTranslationBuilder.create_jax_var_list()` API.
-
-    It will test the `handle_literals` flag.
-    """
-
-    val = np.array(1)
-    aval = jax_core.get_aval(val)
-    lit = jax_core.Literal(val, aval)
-    var_list = [lit]
-
-    with pytest.raises(
-        expected_exception=ValueError,
-        match=re.escape("Encountered a literal but `handle_literals` was `False`."),
-    ):
-        translation_builder.create_jax_var_list(var_list, handle_literals=False)
-    assert len(translation_builder.arrays) == 0
-
-    name_list = translation_builder.create_jax_var_list(var_list, handle_literals=True)
-    assert len(translation_builder.arrays) == 0
-    assert name_list == [None]
-
-
-def test_builder_constants(translation_builder: translator.JaxprTranslationBuilder) -> None:
-    """Tests part of the `JaxprTranslationBuilder._create_constants()` API.
-
-    See also the `test_subtranslators_alu.py::test_add3` test.
-    """
-    # Create the Jaxpr that we need.
-    constant = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
-    jaxpr = jax.make_jaxpr(lambda A: A + jax.numpy.array(constant))(1.0)
-
-    # We have to manually allocate the builder context.
-    # Normally you should not do that.
-    translation_builder._allocate_translation_ctx(name="Manual_test")
-
-    # Now create the constants.
-    translation_builder._create_constants(jaxpr)
-
-    # Test if the constant was created with the correct value.
-    assert len(translation_builder.arrays) == 1
-    assert len(translation_builder._jax_name_map) == 1
-    assert next(iter(translation_builder._jax_name_map.values())) == "__const_a"
-    assert len(translation_builder.sdfg.constants) == 1
-    assert np.all(translation_builder.sdfg.constants["__const_a"] == constant)
-
-
-def test_builder_scalar_return_value() -> None:
-    """Tests if scalars can be returned directly."""
-
-    def scalar_ops(A: float) -> float:
-        return A + A - A * A
-
-    lower_cnt = [0]
-
-    @jace.jit
-    def wrapped(A: float) -> float:
-        lower_cnt[0] += 1
-        return scalar_ops(A)
-
-    vals = np.random.random(100)  # noqa: NPY002
-    for i in range(vals.size):
-        res = wrapped(vals[i])
-        ref = scalar_ops(vals[i])
-        assert np.allclose(res, ref)
-    assert lower_cnt[0] == 1
-
-
-@pytest.mark.skip(reason="Currently 'scalar' return values are actually shape '(1,)' arrays.")
-def test_builder_scalar_return_type() -> None:
-    """Tests if the type is preserved in case of a scalar return value."""
-
-    @jace.jit
-    def wrapped(A: np.float64) -> np.float64:
-        return A + A - A * A
-
-    A = np.float64(1.0)
-    res = wrapped(A)
-    assert type(res) is np.float64, f"Expected type 'np.float64', but got '{type(res).__name__}'."
-
-
-def test_builder_jace_var() -> None:
-    """Simple tests of the `JaCeVar` objects."""
-    for iname in ["do", "", "_ _", "9al", "_!"]:
-        with pytest.raises(
-            expected_exception=ValueError, match=re.escape(f"Supplied the invalid name '{iname}'.")
-        ):
-            _ = JaCeVar((), dace.int8, name=iname)
-
-
-def test_builder_F_strides() -> None:
-    """Tests if we can lower an array that does not have standard (C) strides.
-
-    Notes:
-        This tests that the restriction is currently in place.
-        See also `tests/test_caching.py::test_caching_strides`.
-    """
-
-    @jace.jit
-    def testee(A: np.ndarray) -> np.ndarray:
-        return A + 10.0
-
-    F = np.full((4, 3), 10, dtype=np.float64, order="F")
-
-    with pytest.raises(
-        expected_exception=NotImplementedError,
-        match=re.escape("Currently can not yet handle strides beside 'C_CONTIGUOUS'."),
-    ):
-        _ = testee(F)
diff --git a/tests/test_subtranslator_helper.py b/tests/test_subtranslator_helper.py
index 56b30fb..a4c4ad9 100644
--- a/tests/test_subtranslator_helper.py
+++ b/tests/test_subtranslator_helper.py
@@ -177,7 +177,11 @@ def foo(A):
         D = C + 1
         return D + 1
-    _ = foo.lower(1)
+    with pytest.warns(
+        UserWarning,
+        match=re.escape('Use of uninitialized transient "e" in state output_processing_stage'),
+    ):
+        _ = foo.lower(1)
     assert trans_cnt[0] == 4