Skip to content

Commit

Permalink
Removed the functionality that would allow multiple translators for a…
Browse files Browse the repository at this point in the history
… single primitive.
  • Loading branch information
philip-paul-mueller committed Apr 23, 2024
1 parent ee9d7b3 commit 54b55e1
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 289 deletions.
160 changes: 7 additions & 153 deletions src/jace/translator/jace_subtranslator_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from __future__ import annotations

from collections.abc import Collection, Sequence
from typing import TYPE_CHECKING, Any, Final, final
from typing import TYPE_CHECKING, Any

import dace
from jax import core as jcore
Expand Down Expand Up @@ -36,40 +36,13 @@ class JaCeSubTranslatorInterface:
In the end this implements the delegation pattern.
A subtranslator uses its `get_handled_primitives()` function to indicate
for which Jax primitives it want to register. It is important that a
subtranslator can register for as many primitive it wants. At the same
time, it is possible that multiple subtranslators have registered for a
single primitive.
If multiple subtranslator have registered for the same primitive they
will be ordered by driver. There are two ways how a subtranslator can
influence this order. The first one is by implementing `get_priority()`,
the driver will then put them in ascending order.
I.e. the lower its priority the earlier a subtranslator is checked.
However, if a subtranslator returns the special value
`JaCeSubTranslatorInterface.DEFAULT_PRIORITY` it will be always put at the
end, in unspecific order if multiple translator are involved.
The second possibility is to override the '__lt__()' function,
and establish a strict weak order. If a subtranslator overrides this
function it should also override `get_priority()` to return `NotImplemented`.
To decide which subtranslator should be used for a specific equation
the driver will use their 'can_translate_jaxeqn()' function.
The first subtranslator that returns 'True' will then be used.
Todo:
Also come up with a way how to avoid that instances are allowed to access
some private members of the driver; Possibly by composition.
Come up with a better way of ordering; maybe introduce fixed priority level.
And then allows to sort them according to `__lt__()` within the level.
for which Jax primitives it want to register. It is important that there
is no limits on the number of primitives a subtranslator can register itself.
However, only one subtranslator can be registered for a primitive.
"""

__slots__ = ()

# Default value for the priority of primitive translators.
DEFAULT_PRIORITY: Final = int("1" * 64, base=2)

def __init__(
self,
*args: Any,
Expand All @@ -84,11 +57,7 @@ def get_handled_primitives(self) -> Collection[str] | str:
"""Returns the names of all Jax primitives that `self` is able to handle.
There is no limit on the number of primitives for which a subtranslator
can register. It is possible that several translators can be registered
for the same name.
See Also:
`self.can_translate_jaxeqn()` and `self.get_priority()`.
can register.
Notes:
In case a string is returned it is interpreted as 1 element collection.
Expand All @@ -97,35 +66,6 @@ def get_handled_primitives(self) -> Collection[str] | str:
"Class '{type(self).__name__}' does not implement 'get_handled_primitives()'."
)

def can_translate_jaxeqn(
self,
driver: JaxprTranslationDriver,
in_var_names: Sequence[str | None],
out_var_names: Sequence[str],
eqn: jcore.JaxprEqn,
) -> bool:
"""Tests if `self` is able to translate the Jax primitive passed as `eqn`.
This function is used by the driver to determine which of the subtranslators,
that have registered for a certain type of primitive, should be used.
For a more detailed description of the arguments see
`self.translate_jaxeqn()` function.
Args:
driver: The driver object of the translation.
in_var_names: Names of the SDFG variables used as inputs for the primitive.
out_var_names: Names of the SDFG variables used as outputs for the primitive.
eqn: The `jcore.JaxprEqn` instance that is currently being handled.
Notes:
In case there is only one subtranslator registered for a certain primitive,
it is unspecific if this function will be called at all `self.translate_jaxeqn()`.
This function will never be called for a primitive for which it has not registered itself.
"""
raise NotImplementedError(
"Class '{type(self).__name__}' does not implement 'can_translate_jaxeqn()'."
)

def translate_jaxeqn(
self,
driver: JaxprTranslationDriver,
Expand All @@ -152,7 +92,7 @@ def translate_jaxeqn(
`translator.get_terminal_sdfg_state() is eqn_state` holds.
Then the subtranslator is called. Usually a subtranslator should
construct the dataflow graph inside it. It is allowed that the
construct the dataflow graph inside `eqn_state`. It is allowed that the
subtranslators creates more states if needed, but this state machine
has to have a single terminal state, which must be returned
and reachable from `eqn_state`.
Expand Down Expand Up @@ -185,68 +125,14 @@ def translate_jaxeqn(
"Class '{type(self).__name__}' does not implement 'translate_jaxeqn()'."
)

def get_priority(self) -> int:
"""Returns the priority of this translator.
The value returned by this function is used by the driver to order the
subtranslators that have registered for the same primitive.
The _smaller_ the value the earlier it is checked.
See Also:
`self.can_translate_jaxeqn()` and `self.get_handled_primitives()`.
Notes:
By default the function returns `self.DEFAULT_PRIORITY`, which is
handled specially, i.e. it is put at the end.
If a subtranslator instead overrides `__lt__()` this function
should return `NotImplemented`.
"""
return self.DEFAULT_PRIORITY

def has_default_priority(self) -> bool:
"""Checks if `self` has default priority.
Notes:
It is allowed, but not advised to override this function.
However, it has to be consistent with `self.get_priority()`.
"""
try:
x = self.get_priority()
except NotImplementedError:
return False
if x is NotImplemented:
return False
return x == self.DEFAULT_PRIORITY

def __lt__(
self,
other: JaCeSubTranslatorInterface,
) -> bool:
"""Tests if `self` should be checked before `other` in the selection process.
As outlined in the class description this is the second possibility to
influence the order of the subtranslator. This function should return
`True` if `self` should be checked for applicability _before_ `other`.
Notes:
If this function is overridden `get_priority()` should return `NotImplemented`.
This function is never called if either `self` or `other` have default priority.
"""
return self.get_priority() < other.get_priority()

def __eq__(
self,
other: Any,
) -> bool:
"""Tests if two subtranslators are equal.
The default implementation checks if `self` and `other` have the same
type. However, if the behaviour of a subtranslator strongly depend on
its configuration this function should be overridden.
Notes:
If you override this function you should also override
`self.__hash__()` to make the two consistent.
type.
"""
if not isinstance(other, JaCeSubTranslatorInterface):
return NotImplemented
Expand All @@ -258,37 +144,5 @@ def __hash__(self) -> int:
The default implementation return a hash that is based on the class.
Thus all instances of a particular subtranslator will have the same
hash value.
Notes:
If you override this function you should also override
`self.__eq__()` to make the two consistent.
"""
return id(self.__class__)

@final
def __ne__(
self,
other: Any,
) -> bool:
return NotImplemented

@final
def __le__(
self,
other: Any,
) -> bool:
return NotImplemented

@final
def __ge__(
self,
other: Any,
) -> bool:
return NotImplemented

@final
def __gt__(
self,
other: Any,
) -> bool:
return NotImplemented
74 changes: 18 additions & 56 deletions src/jace/translator/jaxpr_translator_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,9 @@ def __init__(

# Contains all the subtranslators that we need.
# They are partitioned by the names of the primitive they have registered for.
# Inside a partition they are ordered by priority, lowest first, more important.
# This member is allocated by '_init_sub_translators()' and remains allocated
# during the lifetime of the object.
self._sub_translators: dict[str, list[jtrans.JaCeSubTranslatorInterface]] = None # type: ignore[assignment]
self._sub_translators: dict[str, jtrans.JaCeSubTranslatorInterface] = None # type: ignore[assignment]

# The SDFG object that we are currently constructing.
# Only allocated during an ongoing translation.
Expand Down Expand Up @@ -923,36 +922,29 @@ def _allocate_translation_ctx(

def _init_sub_translators(
self,
kwargs: Mapping[str, Any],
subtrans_args: Mapping[str, Any],
) -> JaxprTranslationDriver:
"""This function initializes the subtranslator.
The function forwards `kwargs` to the constructor of the subtranslators.
However, it will remove all arguments starting with an underscore.
"""
if isinstance(self._sub_translators, dict):
raise RuntimeError("Tried to allocate the internal subtranslators twice.")
assert self._sub_translators is None # type: ignore[unreachable]
assert self._sub_translators is None

kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
subtrans_args = {k: v for k, v in subtrans_args.items() if not k.startswith("_")} # type: ignore[unreachable]

# First we will create all subtranslators and partition them.
subtranslators: dict[str, list[jtrans.JaCeSubTranslatorInterface]] = {}
for subtranslator_cls in jtsubt._get_subtranslators_cls():
subtranslator: jtrans.JaCeSubTranslatorInterface = subtranslator_cls(**kwargs)
sub_translators: dict[str, jtrans.JaCeSubTranslatorInterface] = {}
for sub_translator_cls in jtsubt._get_subtranslators_cls():
sub_translator: jtrans.JaCeSubTranslatorInterface = sub_translator_cls(**subtrans_args)
handled_primitives: Iterable[str] = jutil.ensure_iterability(
subtranslator.get_handled_primitives()
sub_translator.get_handled_primitives()
)

# Now add the subtranslator to the primitives it requests, we will sort them later into the correct order.
for handled_primitive in handled_primitives:
subtranslators.setdefault(handled_primitive, []).append(subtranslator)
if handled_primitive in sub_translators:
raise RuntimeError(f"Multiple sub_translators for '{handled_primitive}' found.")
sub_translators[handled_primitive] = sub_translator
self._sub_translators = sub_translators

# Now we order the subtranslators for the primitives.
self._sub_translators = {
prim_name: jtrutil.sort_subtranslators(primSubTranslators)
for prim_name, primSubTranslators in subtranslators.items()
}
return self

def _clear_translation_ctx(self) -> JaxprTranslationDriver:
Expand Down Expand Up @@ -992,41 +984,16 @@ def _clear_translation_ctx(self) -> JaxprTranslationDriver:

def _find_sub_translator_for(
self,
in_var_names: Sequence[str | None],
out_var_names: Sequence[str],
eqn: jcore.JaxprEqn,
) -> jtrans.JaCeSubTranslatorInterface:
"""Returns the appropriate subtranslator for equation `eqn`.
The subtranslators are checked for applicability in the order of their priority.
The fist one that accepts the translation will be taken.
Notes:
The arguments are the same as for `JaCeSubTranslatorInterface.can_translate_jaxeqn()`.
"""
"""Returns the appropriate subtranslator for equation `eqn`."""
assert self._sub_translators is not None

prim_name: str = eqn.primitive.name
if prim_name not in self._sub_translators:
raise NotImplementedError(f"No subtranslators known to handle '{prim_name}'.")
subtranslator_canidates = self._sub_translators[prim_name]
assert len(subtranslator_canidates) > 0

subtranslator: jtrans.JaCeSubTranslatorInterface = None # type: ignore[assignment]
if len(subtranslator_canidates) == 1:
subtranslator = next(iter(subtranslator_canidates))
assert subtranslator.can_translate_jaxeqn(
in_var_names=in_var_names, out_var_names=out_var_names, driver=self, eqn=eqn
)
else:
for subtranslatorCanidate in subtranslator_canidates:
if subtranslatorCanidate.can_translate_jaxeqn(
driver=self, in_var_names=in_var_names, out_var_names=out_var_names, eqn=eqn
):
subtranslator = subtranslatorCanidate
else:
raise NotImplementedError(f"No subtranslator found for handling '{eqn}'.")
return subtranslator

return self._sub_translators[prim_name]

def _translate_single_eqn(
self,
Expand All @@ -1044,7 +1011,6 @@ def _translate_single_eqn(
Returns:
The SDFG names that were used as input and output are returned.
The inputs might contain `None` which indicates that that input was a Jax Literal.
For more information see `JaCeSubTranslatorInterface.can_translate_jaxeqn()`.
Notes:
While `jaxpr` must be a `ClosedJaxpr`, `eqn` must come from the unclosed instance.
Expand All @@ -1070,24 +1036,20 @@ def _translate_single_eqn(
)

# Find the subtranslator
subtranslator: jtrans.JaCeSubTranslatorInterface = self._find_sub_translator_for(
in_var_names=in_var_names,
out_var_names=out_var_names,
eqn=eqn,
)
subtranslator: jtrans.JaCeSubTranslatorInterface = self._find_sub_translator_for(eqn)

# Create the state into which the equation should be translated
last_term_state: dace.SDFGState = self.get_terminal_sdfg_state() # noqa: F841 # Will be used later
eqn_state = self.append_new_state(
label=f"{eqn.primitive.name}_{out_var_names[0]}",
prev_state=None,
prev_state=None, # forces terminal state
)

# Now perform the actual translation of the equation.
new_sdfg_term_state = subtranslator.translate_jaxeqn(
driver=self,
in_var_names=in_var_names,
out_var_names=out_var_names, # Might be modified by subtranslator!
out_var_names=out_var_names, # Might be modified by the subtranslator!
eqn=eqn,
eqn_state=eqn_state,
)
Expand Down
Loading

0 comments on commit 54b55e1

Please sign in to comment.