Skip to content

Commit

Permalink
Removed the type annotation of the return value in the api and stage.
Browse files Browse the repository at this point in the history
As explained elsewhere the annotation was wrong.
Simple example:
```python
@jace.jit
def foo(a: np.ndarray) -> np.float64:
	return np.float64(a.sum())
```
The actual return value would be `np.ndarray`, although one dimension.
This limitation is also present in JAX.
  • Loading branch information
philip-paul-mueller committed Oct 2, 2024
1 parent fab9131 commit 951f857
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 26 deletions.
15 changes: 7 additions & 8 deletions src/jace/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import functools
from collections.abc import Callable, Mapping
from typing import Literal, ParamSpec, TypedDict, TypeVar, overload
from typing import Any, Literal, ParamSpec, TypedDict, overload

from jax import grad, jacfwd, jacrev
from typing_extensions import Unpack
Expand All @@ -22,7 +22,6 @@
__all__ = ["JITOptions", "grad", "jacfwd", "jacrev", "jit"]

_P = ParamSpec("_P")
_R = TypeVar("_R")


class JITOptions(TypedDict, total=False):
Expand All @@ -46,24 +45,24 @@ def jit(
/,
primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
**kwargs: Unpack[JITOptions],
) -> Callable[[Callable[_P, _R]], stages.JaCeWrapped[_P, _R]]: ...
) -> Callable[[Callable[_P, Any]], stages.JaCeWrapped[_P]]: ...


@overload
def jit(
fun: Callable[_P, _R],
fun: Callable[_P, Any],
/,
primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
**kwargs: Unpack[JITOptions],
) -> stages.JaCeWrapped[_P, _R]: ...
) -> stages.JaCeWrapped[_P]: ...


def jit(
fun: Callable[_P, _R] | None = None,
fun: Callable[_P, Any] | None = None,
/,
primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
**kwargs: Unpack[JITOptions],
) -> Callable[[Callable[_P, _R]], stages.JaCeWrapped[_P, _R]] | stages.JaCeWrapped[_P, _R]:
) -> Callable[[Callable[_P, Any]], stages.JaCeWrapped[_P]] | stages.JaCeWrapped[_P]:
"""
JaCe's replacement for `jax.jit` (just-in-time) wrapper.
Expand All @@ -88,7 +87,7 @@ def jit(
f"The following arguments to 'jace.jit' are not yet supported: {', '.join(not_supported_jit_keys)}."
)

def wrapper(f: Callable[_P, _R]) -> stages.JaCeWrapped[_P, _R]:
def wrapper(f: Callable[_P, Any]) -> stages.JaCeWrapped[_P]:
jace_wrapper = stages.JaCeWrapped(
fun=f,
primitive_translators=(
Expand Down
35 changes: 17 additions & 18 deletions src/jace/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import contextlib
import copy
from collections.abc import Callable, Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, ParamSpec, TypeVar, Union
from typing import TYPE_CHECKING, Any, Generic, ParamSpec, Union

from jax import tree_util as jax_tree

Expand All @@ -57,20 +57,19 @@
#: Known compilation stages in JaCe.
Stage = Union["JaCeWrapped", "JaCeLowered", "JaCeCompiled"]

# These are used to annotated the `Stages`, however, there are some limitations.
# First, the only stage that is fully annotated is `JaCeWrapped`. Second, since
# static arguments modify the type signature of `JaCeCompiled.__call__()`, see
# Used to annotated the `Stages`, however, there are some limitations. First, the
# only stage that is fully annotated is `JaCeWrapped`. The reason is because static
# arguments modify the type signature of `JaCeCompiled.__call__()`, see
# [JAX](https://jax.readthedocs.io/en/latest/aot.html#lowering-with-static-arguments)
# for more, its argument can not be annotated, only its return type can.
# However, in case of scalar return values, the return type is wrong anyway, since
# JaCe and JAX for that matter, transforms scalars to arrays. Since there is no way of
# changing that, but from a semantic point they behave the same so it should not
# matter too much.
# Furthermore, the return value is not annotated. The reason for this is that JAX
# turns all arrays (NumPy, CuPy, ...) into JAX arrays, even scalar, the original
# annotation is wrong for the annotated function. It is even wrong if it would be a
# pytree, since `dict[np.ndarray]` would be translated into `dict[jax.Array]`.
_P = ParamSpec("_P")
_R = TypeVar("_R")


class JaCeWrapped(tcache.CachingStage["JaCeLowered"], Generic[_P, _R]):
class JaCeWrapped(tcache.CachingStage["JaCeLowered"], Generic[_P]):
"""
A function ready to be specialized, lowered, and compiled.
Expand Down Expand Up @@ -100,13 +99,13 @@ class JaCeWrapped(tcache.CachingStage["JaCeLowered"], Generic[_P, _R]):
which is implicitly and temporary activated during tracing.
"""

_fun: Callable[_P, _R]
_fun: Callable[_P, Any]
_primitive_translators: dict[str, translator.PrimitiveTranslator]
_jit_options: api.JITOptions

def __init__(
self,
fun: Callable[_P, _R],
fun: Callable[_P, Any],
primitive_translators: Mapping[str, translator.PrimitiveTranslator],
jit_options: api.JITOptions,
) -> None:
Expand All @@ -115,7 +114,7 @@ def __init__(
self._jit_options = {**jit_options}
self._fun = fun

def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> Any:
"""
Executes the wrapped function, lowering and compiling as needed in one step.
Expand All @@ -137,7 +136,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
return compiled(*args, **kwargs)

@tcache.cached_transition
def lower(self, *args: _P.args, **kwargs: _P.kwargs) -> JaCeLowered[_R]:
def lower(self, *args: _P.args, **kwargs: _P.kwargs) -> JaCeLowered:
"""
Lower the wrapped computation for the given arguments.
Expand Down Expand Up @@ -207,7 +206,7 @@ def _make_call_description(
)


class JaCeLowered(tcache.CachingStage["JaCeCompiled"], Generic[_R]):
class JaCeLowered(tcache.CachingStage["JaCeCompiled"]):
"""
Represents the original computation as an SDFG.
Expand Down Expand Up @@ -251,7 +250,7 @@ def __init__(
self._device = util.parse_backend_jit_option(device)

@tcache.cached_transition
def compile(self, compiler_options: CompilerOptions | None = None) -> JaCeCompiled[_R]:
def compile(self, compiler_options: CompilerOptions | None = None) -> JaCeCompiled:
"""
Optimize and compile the lowered SDFG using `compiler_options`.
Expand Down Expand Up @@ -313,7 +312,7 @@ def _make_call_description(
)


class JaCeCompiled(Generic[_R]):
class JaCeCompiled:
"""
Compiled version of the SDFG.
Expand Down Expand Up @@ -348,7 +347,7 @@ def __init__(
self._compiled_sdfg = compiled_sdfg
self._out_tree = out_tree

def __call__(self, *args: Any, **kwargs: Any) -> _R:
def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""
Calls the embedded computation.
Expand Down

0 comments on commit 951f857

Please sign in to comment.