From c185444b4038b29cbbb0747d11c5f26b879a042a Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 3 May 2024 14:40:42 +0200 Subject: [PATCH] Fixing some import names. --- src/jace/translator/__init__.py | 2 - src/jace/translator/_translation_context.py | 14 +-- .../translator/jaxpr_translator_driver.py | 119 +++++++++--------- .../translator/sub_translators/__init__.py | 12 +- .../a_primitive_translator.py} | 14 ++- .../sub_translators/alu_translator.py | 15 ++- src/jace/translator/translated_jaxpr_sdfg.py | 6 +- src/jace/util/debug.py | 12 +- src/jace/util/jax_helper.py | 45 ++++--- src/jace/util/traits.py | 10 +- src/jace/util/util.py | 6 +- 11 files changed, 132 insertions(+), 123 deletions(-) rename src/jace/translator/{primitive_translator.py => sub_translators/a_primitive_translator.py} (90%) diff --git a/src/jace/translator/__init__.py b/src/jace/translator/__init__.py index 8ca5476..d6bb0c7 100644 --- a/src/jace/translator/__init__.py +++ b/src/jace/translator/__init__.py @@ -10,12 +10,10 @@ from __future__ import annotations from .jaxpr_translator_driver import JaxprTranslationDriver -from .primitive_translator import PrimitiveTranslator from .translated_jaxpr_sdfg import TranslatedJaxprSDFG __all__ = [ - "PrimitiveTranslator", "JaxprTranslationDriver", "TranslatedJaxprSDFG", ] diff --git a/src/jace/translator/_translation_context.py b/src/jace/translator/_translation_context.py index a0c1854..6d9c43f 100644 --- a/src/jace/translator/_translation_context.py +++ b/src/jace/translator/_translation_context.py @@ -12,9 +12,9 @@ from collections.abc import MutableMapping, Sequence import dace -from jax import core as jcore +from jax import core as jax_core -from jace import translator as jtrans, util as jutil +from jace import translator, util class _TranslationContext: @@ -70,7 +70,7 @@ def __init__( rev_idx: The revision index of the context. name: Name of the SDFG object. """ - if isinstance(name, str) and not jutil._VALID_SDFG_OBJ_NAME.fullmatch(name): + if isinstance(name, str) and not util._VALID_SDFG_OBJ_NAME.fullmatch(name): raise ValueError(f"'{name}' is not a valid SDFG name.") self._sdfg: dace.SDFG = dace.SDFG(name=(name or f"unnamed_SDFG_{id(self)}")) @@ -78,14 +78,14 @@ def __init__( label="initial_state", is_start_block=True ) self._terminal_state: dace.SDFGState = self._start_state - self._jax_name_map: MutableMapping[jcore.Var | jutil.JaCeVar, str] = {} + self._jax_name_map: MutableMapping[jax_core.Var | util.JaCeVar, str] = {} self._inp_names: tuple[str, ...] = () self._out_names: tuple[str, ...] = () self._rev_idx: int = rev_idx - def to_translated_jaxpr_sdfg(self) -> jtrans.TranslatedJaxprSDFG: + def to_translated_jaxpr_sdfg(self) -> translator.TranslatedJaxprSDFG: """Transforms `self` into a `TranslatedJaxprSDFG`.""" - return jtrans.TranslatedJaxprSDFG( + return translator.TranslatedJaxprSDFG( sdfg=self._sdfg, start_state=self._start_state, terminal_state=self._terminal_state, @@ -114,7 +114,7 @@ def terminal_state( self._terminal_state = new_term_state @property - def jax_name_map(self) -> MutableMapping[jcore.Var | jutil.JaCeVar, str]: + def jax_name_map(self) -> MutableMapping[jax_core.Var | util.JaCeVar, str]: return self._jax_name_map @property diff --git a/src/jace/translator/jaxpr_translator_driver.py b/src/jace/translator/jaxpr_translator_driver.py index c6fdb46..5b01fd2 100644 --- a/src/jace/translator/jaxpr_translator_driver.py +++ b/src/jace/translator/jaxpr_translator_driver.py @@ -14,9 +14,10 @@ import dace import jax from dace import data as ddata, properties as dprop -from jax import core as jcore +from jax import core as jax_core -from jace import translator as jtrans, util as jutil +from jace import translator, util +from jace.translator import sub_translators class JaxprTranslationDriver: @@ -81,7 +82,7 @@ def __init__( # They are partitioned by the names of the primitive they have registered for. # This member is allocated by '_init_sub_translators()' and remains allocated # during the lifetime of the object. - self._sub_translators: dict[str, jtrans.PrimitiveTranslator] = None # type: ignore[assignment] + self._sub_translators: dict[str, translator.PrimitiveTranslator] = None # type: ignore[assignment] # These names can not be used for the automatic naming of Jax variables. # They differ from the forbidden names, that they denote valid SDFG names. @@ -104,14 +105,14 @@ def __init__( def translate_jaxpr( self, - jaxpr: jcore.ClosedJaxpr, + jaxpr: jax_core.ClosedJaxpr, *, inp_scalar_as_array: bool = False, name: str | None = None, reserved_names: str | Collection[str] | None = None, allow_empty_jaxpr: bool = False, **kwargs: Any, - ) -> jtrans.TranslatedJaxprSDFG: + ) -> translator.TranslatedJaxprSDFG: """Perform the translation of a Jaxpr into a SDFG. In case this function is called and `self` has an ongoing translation process, a new translation context will be created. @@ -134,7 +135,7 @@ def translate_jaxpr( """ if (len(jaxpr.eqns) == 0) and (not allow_empty_jaxpr): raise ValueError("Passed an empty Jaxpr, but did not allow for empty Jaxpr.") - if not isinstance(jaxpr, jcore.ClosedJaxpr): + if not isinstance(jaxpr, jax_core.ClosedJaxpr): raise TypeError(f"Expected a 'jax.core.ClosedJaxp' instance but got '{type(jaxpr)}'") if len(jaxpr.effects) != 0: raise NotImplementedError("'Jaxpr' with side effects are not supported.") @@ -162,7 +163,7 @@ def translate_jaxpr( jaxpr=jaxpr, inp_scalar_as_array=inp_scalar_as_array, ) - jsdfg: jtrans.TranslatedJaxprSDFG = self._translate_jaxpr_internal(jaxpr) + jsdfg: translator.TranslatedJaxprSDFG = self._translate_jaxpr_internal(jaxpr) # If the translation context is not cleared `self` and `jsdfg` will share the same data. # There is some legitimate use for that. @@ -195,7 +196,7 @@ def append_new_state( prev_state: Alternative `SDFGState` at which we should append the new state. """ - if isinstance(label, str) and (not jutil._VALID_SDFG_OBJ_NAME.fullmatch(label)): + if isinstance(label, str) and (not util._VALID_SDFG_OBJ_NAME.fullmatch(label)): raise ValueError(f"Can not create state with label '{label}' since it is invalid.") # Decide if appending to that state will modify the terminal state. @@ -228,7 +229,7 @@ def get_arrays(self) -> Mapping[str, ddata.Data]: def get_array( self, - name: str | jcore.Atom | jutil.JaCeVar, + name: str | jax_core.Atom | util.JaCeVar, ) -> ddata.Data: """Returns the SDFG `Data` object `name` referees to. @@ -237,7 +238,7 @@ def get_array( """ if isinstance(name, str): sdfg_name: str = name - elif isinstance(name, (jcore.Var, jutil.JaCeVar)): + elif isinstance(name, (jax_core.Var, util.JaCeVar)): sdfg_name = self.map_jax_var_to_sdfg(name) else: raise TypeError(f"Does not know how to handle '{type(name).__name__}'.") @@ -248,19 +249,19 @@ def get_array( @overload def map_jax_var_to_sdfg( self, - jax_var: str | jcore.Atom | jutil.JaCeVar, + jax_var: str | jax_core.Atom | util.JaCeVar, ) -> str: ... @overload def map_jax_var_to_sdfg( self, - jax_var: str | jcore.Atom | jutil.JaCeVar, + jax_var: str | jax_core.Atom | util.JaCeVar, allow_fail: bool, ) -> str | None: ... def map_jax_var_to_sdfg( self, - jax_var: str | jcore.Atom | jutil.JaCeVar, + jax_var: str | jax_core.Atom | util.JaCeVar, allow_fail: bool = False, ) -> str | None: """Get the _name_ of the SDFG variable to which `jax_var` is referring to. @@ -273,7 +274,7 @@ def map_jax_var_to_sdfg( """ if isinstance(jax_var, str): sdfg_name: str = jax_var - elif isinstance(jax_var, jcore.Literal): + elif isinstance(jax_var, jax_core.Literal): raise RuntimeError("There is no SDFG variable for literal '{jax_var}'.") elif jax_var in self._ctx.jax_name_map: sdfg_name = self._ctx.jax_name_map[jax_var] @@ -336,7 +337,7 @@ def get_rev_idx(self) -> int: def add_jax_name_mapping( self, - jax_var: jcore.Var | jutil.JaCeVar, + jax_var: jax_core.Var | util.JaCeVar, sdfg_name: str, ) -> JaxprTranslationDriver: """Creates a mapping between `jax_var` to `sdfg_name`. @@ -384,7 +385,7 @@ def add_reserved_names( raise TypeError(f"Does not know how to handle the type '{type(reserved_names)}'.") for rev_name in reserved_names: assert isinstance(rev_name, str) - if not jutil._VALID_SDFG_VAR_NAME.fullmatch(rev_name): + if not util._VALID_SDFG_VAR_NAME.fullmatch(rev_name): raise ValueError( f"Can not use '{rev_name}' as reserved name as it is not a valid SDFG name." ) @@ -393,7 +394,7 @@ def add_reserved_names( def add_array( self, - arg: jcore.Atom | jutil.JaCeVar, + arg: jax_core.Atom | util.JaCeVar, *, as_transient: bool = True, alt_name: str | None = None, @@ -471,8 +472,8 @@ def add_array( """ assert self.is_allocated() - shape: Sequence[int] = jutil.get_jax_var_shape(arg) - dtype = jutil.get_jax_var_dtype(arg) + shape: Sequence[int] = util.get_jax_var_shape(arg) + dtype = util.get_jax_var_dtype(arg) offset = None # i.e. no offset storage: dace.StorageType = dace.StorageType.Default # Set at later stages (optimization) is_scalar: bool = shape == () @@ -493,7 +494,7 @@ def add_array( raise ValueError( f"Specified 'force_jax_name', but passed '{name_prefix}' as 'name_prefix'." ) - alt_name = jutil._propose_jax_name(arg, self._ctx.jax_name_map) + alt_name = util._propose_jax_name(arg, self._ctx.jax_name_map) if alt_name is not None: assert isinstance( alt_name, str @@ -503,7 +504,7 @@ def add_array( raise ValueError("Passed an empty 'alt_name'.") if alt_name in self._forbidden_names: raise ValueError("'alt_name' is a forbidden name.") - if not jutil._VALID_SDFG_VAR_NAME.fullmatch(alt_name): + if not util._VALID_SDFG_VAR_NAME.fullmatch(alt_name): raise ValueError(f"The passed name 'alt_name' '{alt_name}' is invalid.") if name_prefix is not None: raise ValueError( @@ -536,8 +537,8 @@ def add_array( # Depending on the situation, we will further manipulate it. if alt_name is not None: prop_name = alt_name # Just for completion: will be ignored later - elif isinstance(arg, (jcore.Var, jutil.JaCeVar)): - prop_name = jutil._propose_jax_name(arg, self._ctx.jax_name_map) + elif isinstance(arg, (jax_core.Var, util.JaCeVar)): + prop_name = util._propose_jax_name(arg, self._ctx.jax_name_map) if prop_name.startswith("__"): raise ValueError( f"You tried to create the variable '{prop_name}' which" @@ -545,7 +546,7 @@ def add_array( ) if name_prefix is not None: prop_name = name_prefix + prop_name - elif isinstance(arg, jcore.Literal): # type: ignore[unreachable] + elif isinstance(arg, jax_core.Literal): # type: ignore[unreachable] if not allow_literals: raise NotImplementedError("Jax Literals are not supported.") if alt_name is None: @@ -590,7 +591,7 @@ def add_array( raise ValueError(f"Can't create variable '{arg_name}', name is forbidden.") if arg_name in self._ctx.sdfg.arrays: raise ValueError(f"Can't create variable '{arg_name}', variable is already created.") - if not jutil._VALID_SDFG_VAR_NAME.fullmatch(arg_name): + if not util._VALID_SDFG_VAR_NAME.fullmatch(arg_name): raise ValueError(f"The requested variable name '{arg_name}' is invalid.") # Promotion of scalar to array. @@ -642,7 +643,7 @@ def add_array( def create_jax_var_list( self, - jax_var_list: Sequence[jcore.Atom | jutil.JaCeVar], + jax_var_list: Sequence[jax_core.Atom | util.JaCeVar], prevent_creation: bool = False, only_creation: bool = False, handle_literals: bool = False, @@ -677,11 +678,11 @@ def create_jax_var_list( ret_list: list[None | str] = [] for jax_var in jax_var_list: - if isinstance(jax_var, jcore.Literal): + if isinstance(jax_var, jax_core.Literal): if not handle_literals: raise ValueError("Encountered a literal but `handle_literals` was `False`.") sdfg_name = None - elif isinstance(jax_var, (jcore.Var, jutil.JaCeVar)): + elif isinstance(jax_var, (jax_core.Var, util.JaCeVar)): mapped_sdfg_name: str | None = self.map_jax_var_to_sdfg(jax_var, allow_fail=True) if (mapped_sdfg_name is None) and prevent_creation: raise ValueError(f"'prevent_creation' given but have to create '{jax_var}'.") @@ -702,7 +703,7 @@ def create_jax_var_list( def _create_initial_input( self, - jaxpr: jcore.ClosedJaxpr, + jaxpr: jax_core.ClosedJaxpr, inp_scalar_as_array: bool, ) -> Sequence[str]: """This function will create the internal input variables that are used for the SDFG. @@ -744,7 +745,7 @@ def _create_initial_input( def _create_constants( self, - jaxpr: jcore.ClosedJaxpr, + jaxpr: jax_core.ClosedJaxpr, ) -> Sequence[str]: """Creates all constants requested by the `jaxpr`. @@ -825,21 +826,21 @@ def _init_sub_translators( The function forwards `kwargs` to the constructor of the subtranslators. However, it will remove all arguments starting with an underscore. """ - from jace.translator.sub_translators import _get_subtranslators_cls # Avoid import cycle - assert self._sub_translators is None subtrans_args = {k: v for k, v in subtrans_args.items() if not k.startswith("_")} # type: ignore[unreachable] - sub_translators: dict[str, jtrans.PrimitiveTranslator] = {} - for sub_translator_cls in _get_subtranslators_cls(): - sub_translator: jtrans.PrimitiveTranslator = sub_translator_cls.CREATE(**subtrans_args) - handled_primitives: Iterable[str] = jutil.as_sequence(sub_translator.primitive) + prim_translators: dict[str, translator.PrimitiveTranslator] = {} + for prim_translator_cls in sub_translators._get_subtranslators_cls(): + prim_translator: translator.PrimitiveTranslator = prim_translator_cls.CREATE( + **subtrans_args + ) + handled_primitives: Iterable[str] = util.as_sequence(prim_translator.primitive) for handled_primitive in handled_primitives: - if handled_primitive in sub_translators: - raise RuntimeError(f"Multiple sub_translators for '{handled_primitive}' found.") - sub_translators[handled_primitive] = sub_translator - self._sub_translators = sub_translators + if handled_primitive in prim_translators: + raise RuntimeError(f"Multiple sub translators for '{handled_primitive}' found.") + prim_translators[handled_primitive] = prim_translator + self._sub_translators = prim_translators return self @@ -872,8 +873,8 @@ def _clear_translation_ctx(self) -> JaxprTranslationDriver: def _find_sub_translator_for( self, - eqn: jcore.JaxprEqn, - ) -> jtrans.PrimitiveTranslator: + eqn: jax_core.JaxprEqn, + ) -> translator.PrimitiveTranslator: """Returns the appropriate subtranslator for equation `eqn`.""" assert self._sub_translators is not None @@ -885,8 +886,8 @@ def _find_sub_translator_for( def _translate_single_eqn( self, - jaxpr: jcore.ClosedJaxpr, - eqn: jcore.JaxprEqn, + jaxpr: jax_core.ClosedJaxpr, + eqn: jax_core.JaxprEqn, ) -> tuple[Sequence[str | None], Sequence[str]]: """Translate `eqn` into its SDFG equivalent. @@ -904,8 +905,8 @@ def _translate_single_eqn( While `jaxpr` must be a `ClosedJaxpr`, `eqn` must come from the unclosed instance. The function will perform some consistency checking after the subtranslator was called. """ - assert isinstance(eqn, jcore.JaxprEqn) - assert isinstance(jaxpr, jcore.ClosedJaxpr) + assert isinstance(eqn, jax_core.JaxprEqn) + assert isinstance(jaxpr, jax_core.ClosedJaxpr) if len(eqn.effects) != 0: raise NotImplementedError(f"Equation '{eqn}' has side effects.") @@ -925,7 +926,7 @@ def _translate_single_eqn( ) # Find the subtranslator - subtranslator: jtrans.PrimitiveTranslator = self._find_sub_translator_for(eqn) + subtranslator: translator.PrimitiveTranslator = self._find_sub_translator_for(eqn) # Create the state into which the equation should be translated last_term_state: dace.SDFGState = self.get_terminal_sdfg_state() # noqa: F841 # Will be used later @@ -963,7 +964,7 @@ def _translate_single_eqn( ) for expectedSDFGName, jax_var in zip(out_var_names, eqn.outvars, strict=True): mapped_sdfg_name = self.map_jax_var_to_sdfg(jax_var) - jax_name = jutil.get_jax_var_name(jax_var) + jax_name = util.get_jax_var_name(jax_var) if mapped_sdfg_name != expectedSDFGName: raise ValueError( f"Mapping inconsistency detected, expected that Jax variable" @@ -980,13 +981,13 @@ def _translate_single_eqn( pass elif isinstance(sdfg_var, dace.data.View): raise TypeError( - f"For Jax variable '{jutil.get_jax_var_name(jax_var)}' (SDFG: '{outVarName}')," + f"For Jax variable '{util.get_jax_var_name(jax_var)}' (SDFG: '{outVarName}')," f" which is an output, you used a View, which is not possible." " It must either be an array or a scalar." ) else: raise NotImplementedError( - f"Output variable '{jutil.get_jax_var_name(jax_var)}' (SDFG: '{outVarName}')" + f"Output variable '{util.get_jax_var_name(jax_var)}' (SDFG: '{outVarName}')" f" is of type '{type(sdfg_var).__name__}' which I does not know how to handle." ) @@ -997,8 +998,8 @@ def _translate_single_eqn( def _translate_jaxpr_internal( self, - jaxpr: jcore.ClosedJaxpr, - ) -> jtrans.TranslatedJaxprSDFG: + jaxpr: jax_core.ClosedJaxpr, + ) -> translator.TranslatedJaxprSDFG: """Performs the actual translation of the Jaxpr into an SDFG. The function assumes that the context is allocated as well as initial variables. @@ -1014,7 +1015,7 @@ def _translate_jaxpr_internal( this is used by Jax to indicate that they are never read. Such variables are included by some transformations such as `grad()`. """ - assert isinstance(jaxpr, jcore.ClosedJaxpr) + assert isinstance(jaxpr, jax_core.ClosedJaxpr) assert self.is_allocated() nb_translated_eqn: int = 0 @@ -1023,9 +1024,9 @@ def _translate_jaxpr_internal( assert len(eqn.effects) == 0 if len(eqn.outvars) == 0: # Do we need this special case. continue # Looks more like internal Jax error. - if any(jutil.is_drop_var(outVar) for outVar in eqn.outvars): + if any(util.is_drop_var(outVar) for outVar in eqn.outvars): assert (len(eqn.outvars) == 1) or all( - jutil.is_drop_var(outVar) for outVar in eqn.outvars + util.is_drop_var(outVar) for outVar in eqn.outvars ) continue _, out_var_names = self._translate_single_eqn(jaxpr=jaxpr, eqn=eqn) @@ -1038,7 +1039,7 @@ def _translate_jaxpr_internal( return self._export_context() - def _export_context(self) -> jtrans.TranslatedJaxprSDFG: + def _export_context(self) -> translator.TranslatedJaxprSDFG: """Encapsulate the translation context of `self` into a `TranslatedJaxprSDFG` object.. This function will not deallocate the internal context of `self`. @@ -1049,7 +1050,7 @@ def _export_context(self) -> jtrans.TranslatedJaxprSDFG: assert all((isinstance(x, str) and (len(x) > 0)) for x in self._ctx.inp_names) assert all((isinstance(x, str) and (len(x) > 0)) for x in self._ctx.out_names) - return jtrans.TranslatedJaxprSDFG( + return translator.TranslatedJaxprSDFG( sdfg=self._ctx.sdfg, start_state=self._ctx.start_state, terminal_state=self._ctx.terminal_state, @@ -1060,7 +1061,7 @@ def _export_context(self) -> jtrans.TranslatedJaxprSDFG: def _handle_null_jaxpr( self, - jaxpr: jcore.ClosedJaxpr, + jaxpr: jax_core.ClosedJaxpr, ) -> Sequence[str]: """This function is called in case a `Jaxpr` with zero equations is encountered. @@ -1092,7 +1093,7 @@ def _handle_null_jaxpr( # Thus we have to introduce a some fake output name and explicitly copy the data around. # Once DaCe will inline the nested SDFG it will remove this intermediate copy. for jax_out_var in jaxpr.jaxpr.outvars: - jax_inp_name = jutil.get_jax_var_name( + jax_inp_name = util.get_jax_var_name( jax_out_var ) # Since output == input their names must be the same. assert self.map_jax_var_to_sdfg(jax_inp_name, allow_fail=True) diff --git a/src/jace/translator/sub_translators/__init__.py b/src/jace/translator/sub_translators/__init__.py index 7019076..88c239c 100644 --- a/src/jace/translator/sub_translators/__init__.py +++ b/src/jace/translator/sub_translators/__init__.py @@ -4,25 +4,24 @@ # All rights reserved. # # SPDX-License-Identifier: BSD-3-Clause - """Module collecting all built-in subtranslators.""" from __future__ import annotations from collections.abc import Sequence -from jace import translator as jtrans -from jace.translator.sub_translators.alu_translator import ALUTranslator +from .a_primitive_translator import PrimitiveTranslator # has to be the first import. +from .alu_translator import ALUTranslator # List of all subtranslators that ships with JaCe. -_KNOWN_SUBTRANSLATORS: list[type[jtrans.PrimitiveTranslator]] = [ +_KNOWN_SUBTRANSLATORS: list[type[PrimitiveTranslator]] = [ ALUTranslator, ] def add_subtranslator( - subtrans: type[jtrans.PrimitiveTranslator], + subtrans: type[PrimitiveTranslator], ) -> bool: """Add `subtrans` to the externally defined subtranslators. @@ -37,7 +36,7 @@ def add_subtranslator( return True -def _get_subtranslators_cls() -> Sequence[type[jtrans.PrimitiveTranslator]]: +def _get_subtranslators_cls() -> Sequence[type[PrimitiveTranslator]]: """Returns the list of all subtranslator known to JaCe. The translators are returned in FIFO order. @@ -48,4 +47,5 @@ def _get_subtranslators_cls() -> Sequence[type[jtrans.PrimitiveTranslator]]: __all__ = [ "ALUTranslator", "add_subtranslator", + "PrimitiveTranslator", ] diff --git a/src/jace/translator/primitive_translator.py b/src/jace/translator/sub_translators/a_primitive_translator.py similarity index 90% rename from src/jace/translator/primitive_translator.py rename to src/jace/translator/sub_translators/a_primitive_translator.py index 816b709..16d2d4b 100644 --- a/src/jace/translator/primitive_translator.py +++ b/src/jace/translator/sub_translators/a_primitive_translator.py @@ -4,6 +4,14 @@ # All rights reserved. # # SPDX-License-Identifier: BSD-3-Clause +"""Contains the interface for all primitive subtranslators. + +Note the name of this file is because it has to be the first that is imported in the `__init__.py` file. +If not, we would get a cyclic import error. +However, all attempts to prevent ruff from mindlessly (rule abiding) destroying this orders failed. +Thus the name was changed to enforce this. +If you have the solution, feel free to implement it. +""" from __future__ import annotations @@ -12,7 +20,7 @@ from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable import dace -from jax import core as jcore +from jax import core as jax_core if TYPE_CHECKING: @@ -21,7 +29,7 @@ @runtime_checkable class PrimitiveTranslator(Protocol): - """Interface for all Jax primitive subtranslators. + """Interface for all Jax primitive translators, also known as subtranslator. A translator for a primitive translates a single equation of a Jaxpr into its SDFG equivalent. A type that implements this interface must fulfil the following properties: @@ -66,7 +74,7 @@ def translate_jaxeqn( driver: JaxprTranslationDriver, in_var_names: Sequence[str | None], out_var_names: Sequence[str], - eqn: jcore.JaxprEqn, + eqn: jax_core.JaxprEqn, eqn_state: dace.SDFGState, ) -> dace.SDFGState | None: """Translates the Jax primitive into its SDFG equivalent. diff --git a/src/jace/translator/sub_translators/alu_translator.py b/src/jace/translator/sub_translators/alu_translator.py index b08e474..f397bb3 100644 --- a/src/jace/translator/sub_translators/alu_translator.py +++ b/src/jace/translator/sub_translators/alu_translator.py @@ -14,14 +14,13 @@ import dace import numpy as np -from jax import core as jcore +from jax import core as jax_core from typing_extensions import override -from jace import translator as jtranslator +from jace.translator import sub_translators -class ALUTranslator(jtranslator.PrimitiveTranslator): - # class ALUTranslator(PrimitiveTranslator): +class ALUTranslator(sub_translators.PrimitiveTranslator): """This translator handles all arithmetic and logical operations.""" __slots__ = () @@ -92,10 +91,10 @@ def primitive(self) -> Sequence[str]: @override def translate_jaxeqn( self, - driver: jtranslator.JaxprTranslationDriver, + driver: sub_translators.JaxprTranslationDriver, in_var_names: Sequence[str | None], out_var_names: Sequence[str], - eqn: jcore.JaxprEqn, + eqn: jax_core.JaxprEqn, eqn_state: dace.SDFGState, ) -> None: """Perform the translation. @@ -244,7 +243,7 @@ def translate_jaxeqn( def _writeTaskletCode( self, in_var_names: Sequence[str | None], - eqn: jcore.JaxprEqn, + eqn: jax_core.JaxprEqn, ) -> str: """This function generates the Tasklet code based on a primitive. @@ -284,7 +283,7 @@ def _writeTaskletCode( if in_var_name is not None: continue - jax_in_var: jcore.Literal = cast(jcore.Literal, eqn.invars[i]) + jax_in_var: jax_core.Literal = cast(jax_core.Literal, eqn.invars[i]) if jax_in_var.aval.shape == (): t_val = jax_in_var.val if isinstance(t_val, np.ndarray): diff --git a/src/jace/translator/translated_jaxpr_sdfg.py b/src/jace/translator/translated_jaxpr_sdfg.py index bc1ffab..3a3bb6b 100644 --- a/src/jace/translator/translated_jaxpr_sdfg.py +++ b/src/jace/translator/translated_jaxpr_sdfg.py @@ -12,9 +12,9 @@ from typing import Any import dace -from jax import core as jcore +from jax import core as jax_core -from jace import util as jutil +from jace import util @dataclass(init=True, repr=True, eq=False, frozen=False, kw_only=True, slots=True) @@ -37,7 +37,7 @@ class TranslatedJaxprSDFG: """ sdfg: dace.SDFG - jax_name_map: Mapping[jcore.Var | jutil.JaCeVar, str] + jax_name_map: Mapping[jax_core.Var | util.JaCeVar, str] start_state: dace.SDFGState | None = None terminal_state: dace.SDFGState | None = None inp_names: Sequence[str] | None = None diff --git a/src/jace/util/debug.py b/src/jace/util/debug.py index b3a84f3..27c57aa 100644 --- a/src/jace/util/debug.py +++ b/src/jace/util/debug.py @@ -13,18 +13,16 @@ from __future__ import annotations from collections.abc import Callable -from typing import TYPE_CHECKING, Any +from typing import Any import dace import jax - -if TYPE_CHECKING: - from jace import translator as jtrans +from jace import translator def run_jax_sdfg( - jsdfg: jtrans.TranslatedJaxprSDFG, + jsdfg: translator.TranslatedJaxprSDFG, *args: Any, ) -> tuple[Any, ...] | Any: """Calls the SDFG that is encapsulated with the supplied arguments. @@ -97,9 +95,7 @@ def _jace_run( *args: Forwarded to the tracing and final execution of the SDFG. **kwargs: Used to construct the driver. """ - from jace.translator import JaxprTranslationDriver - jaxpr = jax.make_jaxpr(fun)(*args) - driver = JaxprTranslationDriver(**kwargs) + driver = translator.JaxprTranslationDriver(**kwargs) jsdfg = driver.translate_jaxpr(jaxpr) return run_jax_sdfg(jsdfg, *args) diff --git a/src/jace/util/jax_helper.py b/src/jace/util/jax_helper.py index 5bd7596..80018fc 100644 --- a/src/jace/util/jax_helper.py +++ b/src/jace/util/jax_helper.py @@ -17,10 +17,10 @@ from collections.abc import Mapping from dataclasses import dataclass -from typing import Any +from typing import Any, overload import dace -import jax.core as jcore +import jax.core as jax_core import numpy as np from jace import util @@ -36,14 +36,13 @@ class JaCeVar: Notes: Main intention is to test functionality. - While for a Jax `Var` object the name is rather irrelevant, `JaCeVar` use their name. If the name of a `JaCeVar` is '_' it is considered a drop variable. If the name of a `JaCeVar` is empty, the automatic naming will consider it as a Jax variable. The definition of `__hash__` and `__eq__` is in accordance how Jax variable works. """ name: str - shape: tuple[int | dace.symbol | str, ...] | int | dace.symbol | str | tuple[()] + shape: tuple[int | dace.symbol | str, ...] | tuple[()] dtype: dace.typeclass def __hash__(self) -> int: @@ -59,7 +58,7 @@ def __post_init__(self) -> None: raise ValueError("The 'shape' member of a 'JaCeVar' must be a tuple.") -def get_jax_var_name(jax_var: jcore.Atom | JaCeVar | str) -> str: +def get_jax_var_name(jax_var: jax_core.Atom | JaCeVar | str) -> str: """Returns the name of the Jax variable as a string. Args: @@ -70,19 +69,19 @@ def get_jax_var_name(jax_var: jcore.Atom | JaCeVar | str) -> str: This function is subject for removal. """ match jax_var: - case jcore.DropVar(): + case jax_core.DropVar(): return "_" case JaCeVar(): # In case of an empty name consider the jace variable as a Jax variable. # This is mostly for testing. jax_name = f"jax{id(jax_var)}" if jax_var.name == "" else jax_var.name - case jcore.Var(): + case jax_core.Var(): # This stopped working after version 0.20.4, because of some changes in Jax # See `https://github.com/google/jax/pull/10573` for more information. # The following implementation will generate stable names, however, they will be decoupled # from output of the pretty printed Jaxpr jax_name = f"jax{jax_var.count}{jax_var.suffix}" - case jcore.Literal(): + case jax_core.Literal(): raise TypeError("Can not derive a name from a Jax Literal.") case str(): jax_name = jax_var @@ -97,14 +96,24 @@ def get_jax_var_name(jax_var: jcore.Atom | JaCeVar | str) -> str: return jax_name -def get_jax_var_shape(jax_var: jcore.Atom | JaCeVar) -> tuple[int, ...]: +@overload +def get_jax_var_shape(jax_var: JaCeVar) -> tuple[int | dace.symbol | str, ...] | tuple[()]: ... + + +@overload +def get_jax_var_shape(jax_var: jax_core.Atom) -> tuple[int, ...] | tuple[()]: ... + + +def get_jax_var_shape( + jax_var: jax_core.Atom | JaCeVar, +) -> tuple[int | dace.symbol | str, ...] | tuple[()]: """Returns the shape of a Jax variable. Args: jax_var: The variable to process """ match jax_var: - case jcore.Var() | jcore.Literal(): + case jax_core.Var() | jax_core.Literal(): return jax_var.aval.shape case JaCeVar(): return jax_var.shape @@ -112,10 +121,10 @@ def get_jax_var_shape(jax_var: jcore.Atom | JaCeVar) -> tuple[int, ...]: raise TypeError(f"'get_jax_var_shape()` is not implemented for '{type(jax_var)}'.") -def get_jax_var_dtype(jax_var: jcore.Atom | JaCeVar) -> dace.typeclass: +def get_jax_var_dtype(jax_var: jax_core.Atom | JaCeVar) -> dace.typeclass: """Returns the DaCe equivalent of `jax_var`s datatype.""" match jax_var: - case jcore.Var() | jcore.Literal(): + case jax_core.Var() | jax_core.Literal(): return translate_dtype(jax_var.aval.dtype) case JaCeVar(): return translate_dtype(jax_var.dtype) @@ -148,8 +157,8 @@ def translate_dtype(dtype: Any) -> dace.typeclass: def _propose_jax_name( - jax_var: jcore.Atom | JaCeVar, - jax_name_map: Mapping[jcore.Var | JaCeVar, Any] | None = None, + jax_var: jax_core.Atom | JaCeVar, + jax_name_map: Mapping[jax_core.Var | JaCeVar, Any] | None = None, ) -> str: """Proposes a variable name for `jax_var`. @@ -167,11 +176,9 @@ def _propose_jax_name( The naming of variables are only consistent with the inner most Jaxpr a variable is defined in. Dropped variables will always be named `'_'`. """ - from jace.util.traits import is_drop_var - - if is_drop_var(jax_var): + if util.traits.is_drop_var(jax_var): return "_" - if isinstance(jax_var, jcore.Literal): + if isinstance(jax_var, jax_core.Literal): raise TypeError(f"Can not propose a name for literal '{jax_var}'.") if jax_name_map is None: return get_jax_var_name(jax_var) @@ -180,7 +187,7 @@ def _propose_jax_name( raise RuntimeError( f"Can not propose a second name for '{jax_var}', it already known as '{jax_name_map[jax_var]}'." ) - if isinstance(jax_var, jcore.Var): + if isinstance(jax_var, jax_core.Var): pass elif isinstance(jax_var, JaCeVar): # If the name of the JaCe variable is empty, then use the name proposing diff --git a/src/jace/util/traits.py b/src/jace/util/traits.py index 247c999..1e063c8 100644 --- a/src/jace/util/traits.py +++ b/src/jace/util/traits.py @@ -12,9 +12,9 @@ from collections.abc import Iterable from typing import Any, TypeGuard -from jax import core as jcore +from jax import core as jax_core -from jace import util as jutil +from jace import util class NonStringIterable(Iterable): ... @@ -24,11 +24,11 @@ def is_non_string_iterable(val: Any) -> TypeGuard[NonStringIterable]: return isinstance(val, Iterable) and not isinstance(val, str) -def is_drop_var(jax_var: jcore.Atom | jutil.JaCeVar) -> bool: +def is_drop_var(jax_var: jax_core.Atom | util.JaCeVar) -> bool: """Tests if `jax_var` is a drop variable.""" - if isinstance(jax_var, jcore.DropVar): + if isinstance(jax_var, jax_core.DropVar): return True - if isinstance(jax_var, jutil.JaCeVar): + if isinstance(jax_var, util.JaCeVar): return jax_var.name == "_" return False diff --git a/src/jace/util/util.py b/src/jace/util/util.py index 3943743..96bfa20 100644 --- a/src/jace/util/util.py +++ b/src/jace/util/util.py @@ -10,6 +10,8 @@ from collections.abc import Iterable from typing import TypeVar, cast, overload +from jace.util import traits + _T = TypeVar("_T") @@ -27,8 +29,6 @@ def as_sequence(value: _T) -> Iterable[_T]: ... def as_sequence(value: _T | Iterable[_T]) -> Iterable[_T]: - from jace.util.traits import is_non_string_iterable - - if is_non_string_iterable(value): + if traits.is_non_string_iterable(value): return value return cast(Iterable[_T], [value])