Skip to content

Commit

Permalink
feat: Finalizing the initial infrastructure of JaCe (#18)
Browse files Browse the repository at this point in the history
This PR introduces some missing pieces into mainline JaCe.

While the changes are very related to each other, the features this PR
introduces were not split.
A short summary of what it introduces:
- The stages now support pytrees as input and output.
- Inside the translators scalars and arrays are now distinguished.
- Support for arrays that are not in continuous C order.
- Type annotation in the wrapped objects (there are technical
limitations, see note in `src/jace/stages.py` for more).
- General changes in the organization of the code.
- Possibility to globally control the optimization levels (currently not
that useful).

However, this commit only introduces the state of the development branch
regarding the basic infrastructure, i.e. (mostly) `src/jace`, but leaves
out the translators that the development branch has and its tests, to
keep the PR small.
  • Loading branch information
philip-paul-mueller authored Jul 2, 2024
1 parent 1dfa79a commit 19c89b0
Show file tree
Hide file tree
Showing 21 changed files with 1,155 additions and 1,196 deletions.
14 changes: 13 additions & 1 deletion CODING_GUIDELINES.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,18 @@ We deviate from the [Google Python Style Guide][google-style-guide] only in the

- According to subsection [_3.19.12 Imports For Typing_](https://google.github.io/styleguide/pyguide.html#31912-imports-for-typing), symbols from `typing` and `collections.abc` modules used in type annotations _"can be imported directly to keep common annotations concise and match standard typing practices"_. Following the same spirit, we allow symbols to be imported directly from third-party or internal modules when they only contain a collection of frequently used typying definitions.

### Aliasing of Modules

According to subsection [2.2](https://google.github.io/styleguide/pyguide.html#22-imports) in certain cases it is allowed to introduce an alias for an import.
Inside JaCe the following convention is applied:

- If the module has a standard abbreviation use that, e.g. `import numpy as np`.
- For a JaCe module use:
- If the module name is only a single word use it directly, e.g. `from jace import translator`.
- If the module name consists of multiple words use the last word prefixed with the first letters of the others, e.g. `from jace.translator import post_translator as ptranslator` or `from jace import translated_jaxpr_sdfg as tjsdfg`.
- In case of a clash use your best judgment.
- For an external module use the rule above, but prefix the name with the main package's name, e.g. `from dace.codegen import compiled_sdfg as dace_csdfg`.

### Python usage recommendations

- `pass` vs `...` (`Ellipsis`)
Expand Down Expand Up @@ -104,7 +116,7 @@ We generate the API documentation automatically from the docstrings using [Sphin
Sphinx supports the [reStructuredText][sphinx-rest] (reST) markup language for defining additional formatting options in the generated documentation, however section [_3.8 Comments and Docstrings_](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings) of the Google Python Style Guide does not specify how to use markups in docstrings. As a result, we decided to forbid reST markup in docstrings, except for the following cases:

- Cross-referencing other objects using Sphinx text roles for the [Python domain](https://www.sphinx-doc.org/en/master/usage/restructuredtext/domains.html#the-python-domain) (as explained [here](https://www.sphinx-doc.org/en/master/usage/restructuredtext/domains.html#python-roles)).
- Very basic formatting markup to improve _readability_ of the generated documentation without obscuring the source docstring (e.g. ``` ``literal`` ``` strings, bulleted lists).
- Very basic formatting markup to improve _readability_ of the generated documentation without obscuring the source docstring (e.g. `` `literal` `` strings, bulleted lists).

We highly encourage the [doctest] format for code examples in docstrings. In fact, doctest runs code examples and makes sure they are in sync with the codebase.

Expand Down
2 changes: 1 addition & 1 deletion src/jace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from __future__ import annotations

import jace.translator.primitive_translators as _ # noqa: F401 # Populate the internal registry.
import jace.translator.primitive_translators as _ # noqa: F401 [unused-import] # Needed to populate the internal translator registry.

from .__about__ import __author__, __copyright__, __license__, __version__, __version_info__
from .api import grad, jacfwd, jacrev, jit
Expand Down
50 changes: 31 additions & 19 deletions src/jace/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,68 +10,80 @@
from __future__ import annotations

import functools
from typing import TYPE_CHECKING, Any, Literal, overload
from collections.abc import Callable, Mapping
from typing import Literal, ParamSpec, TypedDict, TypeVar, overload

from jax import grad, jacfwd, jacrev
from typing_extensions import Unpack

from jace import stages, translator


if TYPE_CHECKING:
from collections.abc import Callable, Mapping
__all__ = ["JITOptions", "grad", "jacfwd", "jacrev", "jit"]

_P = ParamSpec("_P")
_R = TypeVar("_R")

__all__ = ["grad", "jacfwd", "jacrev", "jit"]

class JITOptions(TypedDict, total=False):
"""
All known options to `jace.jit` that influence tracing.
Note:
Currently there are no known options, but essentially it is a subset of some
of the options that are supported by `jax.jit` together with some additional
JaCe specific ones.
"""


@overload
def jit(
fun: Literal[None] = None,
/,
primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
**kwargs: Any,
) -> Callable[[Callable], stages.JaCeWrapped]: ...
**kwargs: Unpack[JITOptions],
) -> Callable[[Callable[_P, _R]], stages.JaCeWrapped[_P, _R]]: ...


@overload
def jit(
fun: Callable,
fun: Callable[_P, _R],
/,
primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
**kwargs: Any,
) -> stages.JaCeWrapped: ...
**kwargs: Unpack[JITOptions],
) -> stages.JaCeWrapped[_P, _R]: ...


def jit(
fun: Callable | None = None,
fun: Callable[_P, _R] | None = None,
/,
primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
**kwargs: Any,
) -> stages.JaCeWrapped | Callable[[Callable], stages.JaCeWrapped]:
**kwargs: Unpack[JITOptions],
) -> Callable[[Callable[_P, _R]], stages.JaCeWrapped[_P, _R]] | stages.JaCeWrapped[_P, _R]:
"""
JaCe's replacement for `jax.jit` (just-in-time) wrapper.
It works the same way as `jax.jit` does, but instead of using XLA the
computation is lowered to DaCe. In addition it accepts some JaCe specific
arguments.
It works the same way as `jax.jit` does, but instead of lowering the
computation to XLA, it is lowered to DaCe.
The function supports a subset of the arguments that are accepted by `jax.jit()`,
currently none, and some JaCe specific ones.
Args:
fun: Function to wrap.
primitive_translators: Use these primitive translators for the lowering to SDFG.
If not specified the translators in the global registry are used.
kwargs: Jit arguments.
Notes:
After constructions any change to `primitive_translators` has no effect.
Note:
This function is the only valid way to obtain a JaCe computation.
"""
if kwargs:
# TODO(phimuell): Add proper name verification and exception type.
raise NotImplementedError(
f"The following arguments to 'jace.jit' are not yet supported: {', '.join(kwargs)}."
)

def wrapper(f: Callable) -> stages.JaCeWrapped:
# TODO(egparedes): Improve typing.
def wrapper(f: Callable[_P, _R]) -> stages.JaCeWrapped[_P, _R]:
jace_wrapper = stages.JaCeWrapped(
fun=f,
primitive_translators=(
Expand Down
33 changes: 23 additions & 10 deletions src/jace/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
"""
JaCe specific optimizations.
Currently just a dummy exists for the sake of providing a callable function.
Todo:
Organize this module once it is a package.
"""

from __future__ import annotations
Expand All @@ -19,7 +20,20 @@


if TYPE_CHECKING:
from jace import translator
from jace import translated_jaxpr_sdfg as tjsdfg


DEFAULT_OPTIMIZATIONS: Final[CompilerOptions] = {
"auto_optimize": True,
"simplify": True,
"persistent_transients": True,
}

NO_OPTIMIZATIONS: Final[CompilerOptions] = {
"auto_optimize": False,
"simplify": False,
"persistent_transients": False,
}


class CompilerOptions(TypedDict, total=False):
Expand All @@ -35,15 +49,10 @@ class CompilerOptions(TypedDict, total=False):

auto_optimize: bool
simplify: bool
persistent_transients: bool


# TODO(phimuell): Add a context manager to modify the default.
DEFAULT_OPTIMIZATIONS: Final[CompilerOptions] = {"auto_optimize": True, "simplify": True}

NO_OPTIMIZATIONS: Final[CompilerOptions] = {"auto_optimize": False, "simplify": False}


def jace_optimize(tsdfg: translator.TranslatedJaxprSDFG, **kwargs: Unpack[CompilerOptions]) -> None: # noqa: D417 # Missing description for kwargs
def jace_optimize(tsdfg: tjsdfg.TranslatedJaxprSDFG, **kwargs: Unpack[CompilerOptions]) -> None: # noqa: D417 [undocumented-param]
"""
Performs optimization of the translated SDFG _in place_.
Expand All @@ -55,8 +64,12 @@ def jace_optimize(tsdfg: translator.TranslatedJaxprSDFG, **kwargs: Unpack[Compil
tsdfg: The translated SDFG that should be optimized.
simplify: Run the simplification pipeline.
auto_optimize: Run the auto optimization pipeline (currently does nothing)
persistent_transients: Set the allocation lifetime of (non register) transients
in the SDFG to `AllocationLifetime.Persistent`, i.e. keep them allocated
between different invocations.
"""
# Currently this function exists primarily for the same of existing.
# TODO(phimuell): Implement the functionality.
# Currently this function exists primarily for the sake of existing.

simplify = kwargs.get("simplify", False)
auto_optimize = kwargs.get("auto_optimize", False)
Expand Down
Loading

0 comments on commit 19c89b0

Please sign in to comment.