From d6265bc55a7516563400a4063907fc60d526be7c Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 13 Sep 2024 15:03:41 +0200 Subject: [PATCH 1/9] Added all translators from teh development branch. I just copied them and did not do a merge, which is not so nice. Furthermore, the tests are not yet there, in my view it makes sense to first have something that can be checked. --- .../translator/jaxpr_translator_builder.py | 59 +++- .../mapped_operation_base_translator.py | 214 +++++++++++++ src/jace/translator/post_translation.py | 92 ++++++ src/jace/translator/primitive_translator.py | 3 + .../primitive_translators/__init__.py | 35 ++- .../primitive_translators/alu_translator.py | 287 ------------------ .../arithmetic_logical_translators.py | 200 ++++++++++++ .../broadcast_in_dim_translator.py | 67 ++++ .../concatenate_translator.py | 87 ++++++ .../primitive_translators/conditions.py | 182 +++++++++++ .../convert_element_type_translator.py | 85 ++++++ .../primitive_translators/copy_translator.py | 92 ++++++ .../gather_translator.py | 211 +++++++++++++ .../primitive_translators/iota_translator.py | 56 ++++ .../primitive_translators/pjit_translator.py | 147 +++++++++ .../reshape_translator.py | 67 ++++ .../select_n_translator.py | 94 ++++++ .../primitive_translators/slicing.py | 198 ++++++++++++ .../squeeze_translator.py | 69 +++++ src/jace/util/jax_helper.py | 21 ++ 20 files changed, 1973 insertions(+), 293 deletions(-) create mode 100644 src/jace/translator/mapped_operation_base_translator.py delete mode 100644 src/jace/translator/primitive_translators/alu_translator.py create mode 100644 src/jace/translator/primitive_translators/arithmetic_logical_translators.py create mode 100644 src/jace/translator/primitive_translators/broadcast_in_dim_translator.py create mode 100644 src/jace/translator/primitive_translators/concatenate_translator.py create mode 100644 src/jace/translator/primitive_translators/conditions.py create mode 100644 src/jace/translator/primitive_translators/convert_element_type_translator.py create mode 100644 src/jace/translator/primitive_translators/copy_translator.py create mode 100644 src/jace/translator/primitive_translators/gather_translator.py create mode 100644 src/jace/translator/primitive_translators/iota_translator.py create mode 100644 src/jace/translator/primitive_translators/pjit_translator.py create mode 100644 src/jace/translator/primitive_translators/reshape_translator.py create mode 100644 src/jace/translator/primitive_translators/select_n_translator.py create mode 100644 src/jace/translator/primitive_translators/slicing.py create mode 100644 src/jace/translator/primitive_translators/squeeze_translator.py diff --git a/src/jace/translator/jaxpr_translator_builder.py b/src/jace/translator/jaxpr_translator_builder.py index 3d7d04c..9b76407 100644 --- a/src/jace/translator/jaxpr_translator_builder.py +++ b/src/jace/translator/jaxpr_translator_builder.py @@ -14,6 +14,7 @@ import dace from dace import data as dace_data, properties as dace_properties +from dace.sdfg import propagation as dace_propagation from jax import core as jax_core from jace import util @@ -35,8 +36,11 @@ class JaxprTranslationBuilder: - there are only transient variables inside the SDFG, - it lacks the special `__return` variable, - the `arg_names` parameter is not set, - - for all scalar values a ` Scalar` SDFG variable is used, thus they cannot - be used to return anything. + - for all scalar values a `Scalar` SDFG variable is used, thus they cannot + be used for return values, + - for every transient there is exactly one access node that writes to it, + except the name of the array starts with `__jace_mutable_`, which can + be written to multiple times. For these reasons the SDFG is not directly usable, and further manipulations have to be performed. Especially, DaCe's validation function will fail and @@ -550,6 +554,7 @@ def _translate_single_eqn(self, eqn: jax_core.JaxprEqn) -> None: translator = self._primitive_translators[primitive_name] # Create the state into which the equation should be translated + prev_terminal_state = self._ctx.terminal_state eqn_state = self.append_new_state( label=f"{primitive_name}_{'_'.join(out_var_names)}", prev_state=None, # forces the creation of a new terminal state @@ -569,8 +574,13 @@ def _translate_single_eqn(self, eqn: jax_core.JaxprEqn) -> None: if eqn_state is not self._ctx.terminal_state: raise RuntimeError("Inconsistent terminal state was detected.") new_sdfg_term_state = eqn_state - if not self._ctx.validate(): - raise RuntimeError("Detected an invalid SDFG under construction.") + + # Propagate the Memlets through the newly created state machine + self._propagate_memlets_in_new_states( + prev_terminal_state, + new_sdfg_term_state, + ) + self._ctx.validate() # Modify terminal root state of 'self' self._ctx.terminal_state = new_sdfg_term_state @@ -680,6 +690,47 @@ def _handle_null_jaxpr(self, jaxpr: jax_core.ClosedJaxpr) -> list[str]: return out_var_names + def _propagate_memlets_in_new_states( + self, + prev_terminal_state: dace.SDFGState, + new_terminal_state: dace.SDFGState, + ) -> None: + """ + Propagate the Memlets inside the newly added parts of the state machine. + + This function performs BFS starting at `prev_terminal_state` that is bound + by `new_terminal_state`. + + Args: + prev_terminal_state: Terminal state before the expansion of the + state machine. + new_terminal_state: Terminal state after the expansion. + """ + seen: set[dace.SDFGState] = {prev_terminal_state} + nodes_to_process: list[dace.SDFGState] = [ + edge.dst for edge in self.sdfg.out_edges(prev_terminal_state) + ] + + while nodes_to_process: + currently_processing = nodes_to_process.pop(-1) + if ( + self.sdfg.out_degree(currently_processing) == 0 + and currently_processing != new_terminal_state + ): + raise dace.sdfg.InvalidSDFGError( + f"Found leaf node '{currently_processing}' that is not the terminal node.", + self.sdfg, + self.sdfg.node_id(currently_processing), + ) + + seen.add(currently_processing) + dace_propagation.propagate_memlets_state(self.sdfg, currently_processing) + nodes_to_process.extend( + edge.dst + for edge in self.sdfg.out_edges(currently_processing) + if edge.dst not in seen + ) + @property def _start_state(self) -> dace.SDFGState: return cast(dace.SDFGState, self._ctx.start_state) diff --git a/src/jace/translator/mapped_operation_base_translator.py b/src/jace/translator/mapped_operation_base_translator.py new file mode 100644 index 0000000..9f0f402 --- /dev/null +++ b/src/jace/translator/mapped_operation_base_translator.py @@ -0,0 +1,214 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Module containing all translators related to arithmetic logical operations.""" + +from __future__ import annotations + +from abc import abstractmethod +from typing import TYPE_CHECKING + +import dace +from typing_extensions import final, override + +from jace import translator, util + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from jax import core as jax_core + + +class MappedOperationTranslatorBase(translator.PrimitiveTranslator): + """ + Implements the base for all "mapped base operations". + + A mapped base operation `f` is an operation that has several inputs arrays + that are elementwise combined to a single output array. A prime example for + this would be the addition of two arrays. Essentially it assumes that the + Tasklet code can be written as: + ``` + __out = f(__in0, __in1, __in3, ...) + ``` + where `__in*` are the connector names of the Tasklet and `__out` is the + output connector. For problems such as this, the SDFG API provides the + `SDFGState.add_mapped_tasklet()` function, however, in most cases it can not + be directly used, for various reasons. Thus this class acts like a + convenience wrapper around it. + + To use this class a user has to overwrite the `write_tasklet_code()` function. + This function generates the entire code that should be put into the Tasklet, + include the assignment to `__out`. If needed the translator will perform + literal substitution on the returned code and broadcast the inputs to match + the outputs. + + If needed a subclass can also override the `make_input_memlets()` function + to generate custom input Memlets, such as adding an offset. + + Args: + primitive_name: The name of the primitive `self` should bind to. + + Note: + This class will always generate a mapped Tasklet, even if a scalar is handled. + """ + + def __init__(self, primitive_name: str) -> None: + self._prim_name = primitive_name + + @property + def primitive(self) -> str: + """Returns the primitive that should be translated.""" + return self._prim_name + + @final + @override + def __call__( + self, + builder: translator.JaxprTranslationBuilder, + in_var_names: Sequence[str | None], + out_var_names: Sequence[str], + eqn: jax_core.JaxprEqn, + eqn_state: dace.SDFGState, + ) -> None: + """ + Create the mapped Tasklet. + + The function will create the map ranges and based on the shape of the + output array. It will then call `make_input_memlets()` to get the input + Memlets. After that it calls `write_tasklet_code()` to get the Tasklet + code and perform literal substitution by forwarding it to + `self.literal_substitution()`. After that it will create the mapped Tasklet. + + Note: + For a description of the arguments see `PrimitiveTranslatorCallable`. + """ + assert len(out_var_names) == 1 + if util.get_jax_var_shape(eqn.outvars[0]) != (): + tskl_ranges: list[tuple[str, str]] = [ + (f"__i{dim}", f"0:{N}") + for dim, N in enumerate(util.get_jax_var_shape(eqn.outvars[0])) + ] + tskl_output: dict[str, dace.Memlet] = { + "__out": dace.Memlet.simple( + out_var_names[0], ", ".join(name for name, _ in tskl_ranges) + ) + } + + else: + # If we have a scalar we will generate a Map, but it will be trivial. + tskl_ranges = [("__jace_iterator_SCALAR", "0:1")] + tskl_output = {"__out": dace.Memlet.simple(out_var_names[0], "0")} + + tskl_inputs: dict[str, dace.Memlet] = self.make_input_memlets( + tskl_ranges, in_var_names, eqn + ) + tskl_name = f"{self.primitive}_{out_var_names[0]}" + tskl_code = self.write_tasklet_code(tskl_ranges, in_var_names, eqn) + tskl_code = self.literal_substitution(tskl_code, in_var_names, eqn) + + eqn_state.add_mapped_tasklet( + name=tskl_name, + map_ranges=tskl_ranges, + inputs=tskl_inputs, + code=tskl_code, + outputs=tskl_output, + external_edges=True, + ) + + return eqn_state + + @abstractmethod + def write_tasklet_code( + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> str: + """ + Return the (Python) code that should be put inside the Tasklet. + + This also includes the assignment statement, i.e. `__out`. + However, the base will do literal substitution on the returned object. + + Args: + tskl_ranges: List of pairs used as map parameter, first element + is the name iteration index of the dimension, second is its range. + in_var_names: The list of SDFG variables used as input, `None` if literal. + eqn: The equation. + """ + ... + + def make_input_memlets( # noqa: PLR6301 # Subclasses might need them. + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> dict[str, dace.Memlet]: + """ + Generate the input Memlets for the non literal operators of the primitive. + + The returned `dict` maps the input connector of the Tasklet to the Memlet + that is used to connect it to the Map entry node. + + Args: + tskl_ranges: List of pairs used as map parameter, first element + is the name iteration index of the dimension, second is its range + in_var_names: The list of SDFG variables used as input, `None` if literal. + eqn: The equation object. + """ + out_shp = tuple(util.get_jax_var_shape(eqn.outvars[0])) # Shape of the output + out_rank = len(out_shp) + if any(len(util.get_jax_var_shape(invar)) not in {0, out_rank} for invar in eqn.invars): + raise NotImplementedError( + f"'MappedOperationTranslatorBase' Inputs must have the same rank as the output! " + f"Eqn: {eqn} || {tuple(util.get_jax_var_shape(eqn.outvars[0]))}" + ) + + # Now we will generate the input Memlets. + tskl_inputs: dict[str, dace.Memlet] = {} + for i, (in_var_name, inp_shp) in enumerate( + zip(in_var_names, (util.get_jax_var_shape(invar) for invar in eqn.invars)) + ): + if in_var_name is None: # Input is a literal: No Memlet needed + continue + + if inp_shp == (): # Scalars + tskl_inputs[f"__in{i}"] = dace.Memlet.simple(in_var_name, "0") # Scalar + continue + + # We have to to broadcasting (combine yes and no together) + dims_to_bcast: Sequence[int] = [dim for dim in range(out_rank) if inp_shp[dim] == 1] + tskl_inputs[f"__in{i}"] = dace.Memlet.simple( + in_var_name, + ", ".join( + ("0" if i in dims_to_bcast else it_var) + for i, (it_var, _) in enumerate(tskl_ranges) + ), + ) + return tskl_inputs + + def literal_substitution( # noqa: PLR6301 # Subclasses might need it. + self, tskl_code: str, in_var_names: Sequence[str | None], eqn: jax_core.JaxprEqn + ) -> str: + """ + Perform literal substitution on the proto Tasklet code `tskl_code`. + + Args: + tskl_code: The proto Tasklet code with literal. + in_var_names: The list of SDFG variables used as input. + eqn: The equation. + + Note: + It is allowed but not recommended to override this function. + """ + for i, in_var_name in enumerate(in_var_names): + if in_var_name is not None: + continue + t_val = util.get_jax_literal_value(eqn.invars[i]) + tskl_code = tskl_code.replace(f"__in{i}", str(t_val)) + return tskl_code diff --git a/src/jace/translator/post_translation.py b/src/jace/translator/post_translation.py index a00b651..9831f35 100644 --- a/src/jace/translator/post_translation.py +++ b/src/jace/translator/post_translation.py @@ -19,6 +19,8 @@ if TYPE_CHECKING: + from dace.sdfg import nodes as dace_nodes + from jace import translator @@ -234,3 +236,93 @@ def finalize_translation_context( if validate: tsdfg.validate() return tsdfg + + +def add_nested_sdfg( + state: dace.SDFGState, + child_ctx: translator.TranslationContext, + parent_ctx: translator.TranslationContext, + in_var_names: Sequence[str], + out_var_names: Sequence[str], +) -> dace_nodes.NestedSDFG: + """ + Adds the SDFG in `child_ctx` as nested SDFG at state `state` in `parent_ctx`. + + The function is a convenience wrapper that operates directly on translation + contexts instead of SDFGs. The function will also create the necessary Memlet + connections. + + Args: + state: The state at which the nested SDFG should be inserted. + Must be part of `parent_ctx`. + child_ctx: The translation context representing the SDFG that should be added. + parent_ctx: The parent SDFG to which `child_ctx` should be added as nested + SDFG in state `state`. + in_var_names: Names of the variables in `parent_ctx` that are used as inputs for + the nested SDFG, must have the same order as `child_ctx.input_names`. + out_var_names: Names of the variables in `parent_ctx` that are used as outputs + for the nested SDFG, must have the same order as `child_ctx.output_names`. + + Returns: + The nested SDFG object. + + Note: + The function will not add `child_ctx` directly as nested SDFG. Instead it + will first pass it to `finalize_translation_context()` and operates on the + return values. This means that `child_ctx` will be modified in place, and + a copy will be added to `parent_ctx`. + It is highly recommended that `state` is empty. + """ + if child_ctx.sdfg.free_symbols: + raise NotImplementedError("Symbol Mapping is not implemented.") + assert not (child_ctx.input_names is None or child_ctx.output_names is None) # Silence mypy + assert len(child_ctx.input_names) == len(in_var_names) + assert len(child_ctx.output_names) == len(out_var_names) + assert state in parent_ctx.sdfg.nodes() + assert not set(in_var_names).intersection(out_var_names) + + if any(input_name.startswith("__jace_mutable_") for input_name in in_var_names): + raise NotImplementedError( + "'__jace_mutable_' variables are not yet handled in 'add_nested_sdfg()'." + ) + if len(set(in_var_names)) != len(in_var_names): + raise ValueError( + f"An input can only be passed once, but { {in_var_name for in_var_name in in_var_names if in_var_names.count(in_var_name) > 1} } were passed multiple times." + ) + if len(set(out_var_names)) != len(out_var_names): + raise NotImplementedError( + f"Tried to write multiple times to variables: { {out_var_name for out_var_name in out_var_names if out_var_names.count(out_var_name) > 1} }." + ) + + final_child_ctx = finalize_translation_context(child_ctx) + nested_sdfg: dace_nodes.NestedSDFG = state.add_nested_sdfg( + sdfg=final_child_ctx.sdfg, + parent=parent_ctx.sdfg, + # Bug in DaCe must be a set. + inputs=set(final_child_ctx.input_names), + outputs=set(final_child_ctx.output_names), + ) + + # Now create the connections for the input. + for outer_name, inner_name in zip(in_var_names, final_child_ctx.input_names): + outer_array = parent_ctx.sdfg.arrays[outer_name] + state.add_edge( + state.add_read(outer_name), + None, + nested_sdfg, + inner_name, + dace.Memlet.from_array(outer_name, outer_array), + ) + + # Now we create the output connections. + for outer_name, inner_name in zip(out_var_names, final_child_ctx.output_names): + outer_array = parent_ctx.sdfg.arrays[outer_name] + state.add_edge( + nested_sdfg, + inner_name, + state.add_write(outer_name), + None, + dace.Memlet.from_array(outer_name, outer_array), + ) + + return nested_sdfg diff --git a/src/jace/translator/primitive_translator.py b/src/jace/translator/primitive_translator.py index ab84c5d..2000731 100644 --- a/src/jace/translator/primitive_translator.py +++ b/src/jace/translator/primitive_translator.py @@ -64,6 +64,9 @@ def __call__( primitive translator was able to fully construct the dataflow graph within `eqn_state`. + After the primitive translator returns, the builder will propagate the + Memlets in all states that were newly created. + A primitive translator has to use the passed input variables, `in_var_names` and must write its output into the variables indicated by `out_var_names`. But it is allowed that a primitive translator diff --git a/src/jace/translator/primitive_translators/__init__.py b/src/jace/translator/primitive_translators/__init__.py index 65f9153..9e2fec0 100644 --- a/src/jace/translator/primitive_translators/__init__.py +++ b/src/jace/translator/primitive_translators/__init__.py @@ -8,7 +8,38 @@ from __future__ import annotations -from .alu_translator import ALUTranslator +from .arithmetic_logical_translators import ( + ArithmeticOperationTranslator, + LogicalOperationTranslator, +) +from .broadcast_in_dim_translator import BroadcastInDimTranslator +from .concatenate_translator import ConcatenateTranslator +from .conditions import condition_translator +from .convert_element_type_translator import ConvertElementTypeTranslator +from .copy_translator import CopyTranslator, DevicePutTranslator +from .gather_translator import GatherTranslator +from .iota_translator import IotaTranslator +from .pjit_translator import PJITTranslator +from .reshape_translator import ReshapeTranslator +from .select_n_translator import SelectNTranslator +from .slicing import SlicingTranslator +from .squeeze_translator import SqueezeTranslator -__all__ = ["ALUTranslator"] +__all__ = [ + "ArithmeticOperationTranslator", + "BroadcastInDimTranslator", + "ConcatenateTranslator", + "ConvertElementTypeTranslator", + "CopyTranslator", + "DevicePutTranslator", + "GatherTranslator", + "IotaTranslator", + "LogicalOperationTranslator", + "PJITTranslator", + "ReshapeTranslator", + "SelectNTranslator", + "SlicingTranslator", + "SqueezeTranslator", + "condition_translator", +] diff --git a/src/jace/translator/primitive_translators/alu_translator.py b/src/jace/translator/primitive_translators/alu_translator.py deleted file mode 100644 index f217924..0000000 --- a/src/jace/translator/primitive_translators/alu_translator.py +++ /dev/null @@ -1,287 +0,0 @@ -# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) -# -# Copyright (c) 2024, ETH Zurich -# All rights reserved. -# -# SPDX-License-Identifier: BSD-3-Clause - -"""This module contains the `ALUTranslator` which translates all arithmetic and logic primitives.""" -# ruff: noqa: W505 PLR0912 C901 PLR0914 PLR0915 D417 - -from __future__ import annotations - -from collections.abc import Sequence -from typing import Any, Final, cast - -import dace -import numpy as np -from jax import core as jax_core -from typing_extensions import override - -from jace import translator, util - - -class ALUTranslator(translator.PrimitiveTranslator): - """ - This translator handles all arithmetic and logical operations. - - This translator will be reworked soon, it just exists that the initial PR can do anything at all!! - """ - - def __init__(self, prim_name: str, prim_tmpl: str) -> None: - """Initialize the `ALUTranslator`.""" - self._prim_name = prim_name - self._prim_tmpl = prim_tmpl - - @property - @override - def primitive(self) -> str: - return self._prim_name - - @override - def __call__( - self, - builder: translator.JaxprTranslationBuilder, - in_var_names: Sequence[str | None], - out_var_names: Sequence[str], - eqn: jax_core.JaxprEqn, - eqn_state: dace.SDFGState, - ) -> None: - """ - Perform the translation. - - Deepening on the shapes of the input the function will either create a Tasklet or a mapped Tasklet. - The translator is able to handle broadcasting with NumPy rules. - The function will always perform the translation inside the provided state. - - Args: - builder: The builder object of the translation. - in_var_names: List of the names of the arrays created inside the SDFG for the inpts or 'None' in case of a literal. - out_var_names: List of the names of the arrays created inside the SDFG for the outputs. - eqn: The JAX equation that is translated. - eqn_state: State into which the primitive's SDFG representation is constructed. - """ - assert self._prim_name == eqn.primitive.name - - # Determine what kind of input we got and how we should proceed. - is_scalar = len(util.get_jax_var_shape(eqn.outvars[0])) == 0 - input_scalars = [len(util.get_jax_var_shape(Inp)) == 0 for i, Inp in enumerate(eqn.invars)] - has_scalars_as_inputs = any(input_scalars) - has_some_literals = any(x is None for x in in_var_names) - inps_same_shape = all( - util.get_jax_var_shape(eqn.invars[0]) == util.get_jax_var_shape(eqn.invars[i]) - for i in range(1, len(eqn.invars)) - ) - - # We will now look which dimensions have to be broadcasted on which operator. - # I.e. in the dimensions in the lists below there will be no map iteration index. - dims_to_bcastl: list[int] = [] - dims_to_bcastr: list[int] = [] - - # Determine if and how we have to broadcast. - if inps_same_shape or is_scalar: - pass - - elif has_some_literals or has_scalars_as_inputs: - # This is essentially an array plus a scalar, that is eitehr a literal or a variable. - assert (not has_some_literals) or all( - util.get_jax_var_shape(invar) == util.get_jax_var_shape(eqn.outvars[0]) - for (invar, x) in zip(eqn.invars, in_var_names, strict=False) - if x is not None - ) - assert (not has_scalars_as_inputs) or all( - util.get_jax_var_shape(invar) in {util.get_jax_var_shape(eqn.outvars[0]), ()} - for (invar, x) in zip(eqn.invars, in_var_names, strict=False) - if x is not None - ) - - else: - # This is the general broadcasting case - # We assume that both inputs and the output have the same rank but different sizes in each dimension. - # It seems that JAX ensures this. - # We further assume that if the size in a dimension differs then one must have size 1. - # This is the size we broadcast over, i.e. conceptually replicated. - out_shps = tuple(util.get_jax_var_shape(eqn.outvars[0])) # Shape of the output - input_shpl = tuple( - util.get_jax_var_shape(eqn.invars[0]) - ) # Shape of the left/first input - input_shpr = tuple( - util.get_jax_var_shape(eqn.invars[1]) - ) # Shape of the right/second input - - if not ((len(input_shpl) == len(input_shpr)) and (len(out_shps) == len(input_shpr))): - raise NotImplementedError("Can not broadcast over different ranks.") - - for dim, (shp_lft, shp_rgt, out_shp) in enumerate( - zip(input_shpl, input_shpr, out_shps) - ): - if shp_lft == shp_rgt: - assert out_shp == shp_lft - elif shp_lft == 1: - assert shp_rgt == out_shp - dims_to_bcastl.append(dim) - elif shp_rgt == 1: - assert shp_lft == out_shp - dims_to_bcastr.append(dim) - else: - raise ValueError(f"Invalid shapes in dimension {dim} for broadcasting.") - - # Now we create the Tasklet in which the calculation is performed. - tskl_code: str = self._write_tasklet_code(in_var_names, eqn) - tskl_name: str = eqn.primitive.name - tskl_map_ranges: list[tuple[str, str]] = [ - (f"__i{dim}", f"0:{N}") for dim, N in enumerate(util.get_jax_var_shape(eqn.outvars[0])) - ] - tskl_output: tuple[str, dace.Memlet] = None # type: ignore[assignment] - tskl_inputs: list[tuple[str, dace.Memlet] | tuple[None, None]] = [] - - # Generate the Memlets for the input. - for i, dims_to_bcast in zip(range(len(in_var_names)), [dims_to_bcastl, dims_to_bcastr]): - if in_var_names[i] is None: # Literal: No input needed. - tskl_inputs.append((None, None)) - continue - if input_scalars[i]: # Scalar - assert len(dims_to_bcast) == 0 - i_memlet = dace.Memlet.simple(in_var_names[i], "0") - else: # Array: We may have to broadcast - inputs_: list[str] = [] - for dim, (map_var, _) in enumerate(tskl_map_ranges): - if dim in dims_to_bcast: - inputs_.append("0") - else: - inputs_.append(map_var) - i_memlet = dace.Memlet.simple(in_var_names[i], ", ".join(inputs_)) - del inputs_ - tskl_inputs.append((f"__in{i}", i_memlet)) - - # Now generate the Memlets for the output - if is_scalar: - tskl_output = ("__out0", dace.Memlet.simple(out_var_names[0], "0")) - else: - tskl_output = ( - "__out0", - dace.Memlet.simple(out_var_names[0], ", ".join([X[0] for X in tskl_map_ranges])), - ) - - if is_scalar: - tskl_tasklet = eqn_state.add_tasklet( - tskl_name, - _list_to_dict(tskl_inputs).keys(), - _list_to_dict([tskl_output]).keys(), - tskl_code, - ) - for in_var, (in_connector, in_memlet) in zip(in_var_names, tskl_inputs, strict=False): - if in_var is None: # So access node for literal - continue - eqn_state.add_edge( - eqn_state.add_read(in_var), None, tskl_tasklet, in_connector, in_memlet - ) - eqn_state.add_edge( - tskl_tasklet, - tskl_output[0], - eqn_state.add_write(out_var_names[0]), - None, - tskl_output[1], - ) - else: - eqn_state.add_mapped_tasklet( - name=tskl_name, - map_ranges=_list_to_dict(tskl_map_ranges), - inputs=_list_to_dict(tskl_inputs), - code=tskl_code, - outputs=_list_to_dict([tskl_output]), - external_edges=True, - ) - - return eqn_state - - def _write_tasklet_code( - self, in_var_names: Sequence[str | None], eqn: jax_core.JaxprEqn - ) -> str: - """ - This function generates the Tasklet code based on a primitive. - - The function will also perform literal substitution and parameter handling. - - Args: - in_var_names: The list of SDFG variables used as input. - """ - t_code = self._prim_tmpl - - # Now we handle Literal substitution - for i, in_var_name in enumerate(in_var_names): - if in_var_name is not None: - continue - - jax_in_var: jax_core.Literal = cast(jax_core.Literal, eqn.invars[i]) - if util.get_jax_var_shape(jax_in_var) == (): - t_val = jax_in_var.val - if isinstance(t_val, np.ndarray): - t_val = jax_in_var.val.max() # I do not know a better way in that case - t_code = t_code.replace(f"__in{i}", str(t_val)) - else: - raise ValueError( - f"Can not handle the literal case of shape: {util.get_jax_var_shape(jax_in_var)}" - ) - - # Now replace the parameters - if len(eqn.params) != 0: - t_code = t_code.format(**eqn.params) - - return t_code - - -def _list_to_dict(inp: Sequence[tuple[None | Any, Any]]) -> dict[Any, Any]: - """ - This method turns a `list` of pairs into a `dict` and applies a `None` filter. - - The function will only include pairs whose key, i.e. first element is not `None`. - """ - return {k: v for k, v in inp if k is not None} - - -# Contains all the templates for ALU operations. -_ALU_OPS_TASKLET_TEMPLATES: Final[dict[str, str]] = { - # Unary operations - "pos": "__out0 = +(__in0)", - "neg": "__out0 = -(__in0)", - "not": "__out0 = not (__in0)", - "floor": "__out0 = floor(__in0)", - "ceil": "__out0 = ceil(__in0)", - "round": "__out0 = round(__in0)", - "abs": "__out0 = abs(__in0)", - "sign": "__out0 = sign(__in0)", - "sqrt": "__out0 = sqrt(__in0)", - "log": "__out0 = log(__in0)", - "exp": "__out0 = exp(__in0)", - "integer_pow": "__out0 = (__in0)**({y})", # 'y' is a parameter of the primitive - "sin": "__out0 = sin(__in0)", - "asin": "__out0 = asin(__in0)", - "cos": "__out0 = cos(__in0)", - "acos": "__out0 = acos(__in0)", - "tan": "__out0 = tan(__in0)", - "atan": "__out0 = atan(__in0)", - "tanh": "__out0 = tanh(__in0)", - # Binary operations - "add": "__out0 = (__in0)+(__in1)", - "add_any": "__out0 = (__in0)+(__in1)", # No idea what makes `add_any` differ from `add` - "sub": "__out0 = (__in0)-(__in1)", - "mul": "__out0 = (__in0)*(__in1)", - "div": "__out0 = (__in0)/(__in1)", - "rem": "__out0 = (__in0)%(__in1)", - "and": "__out0 = (__in0) and (__in1)", - "or": "__out0 = (__in0) or (__in1)", - "pow": "__out0 = (__in0)**(__in1)", - "ipow": "__out0 = (__in0)**(int(__in1))", - "min": "__out0 = min(__in0, __in1)", - "max": "__out0 = max(__in0, __in1)", - "eq": "__out0 = __in0 == __in1", - "ne": "__out0 = __in0 != __in1", - "ge": "__out0 = __in0 >= __in1", - "gt": "__out0 = __in0 > __in1", - "le": "__out0 = __in0 <= __in1", - "lt": "__out0 = __in0 < __in1", -} - -for prim_name, prim_tmpl in _ALU_OPS_TASKLET_TEMPLATES.items(): - translator.register_primitive_translator(ALUTranslator(prim_name, prim_tmpl)) diff --git a/src/jace/translator/primitive_translators/arithmetic_logical_translators.py b/src/jace/translator/primitive_translators/arithmetic_logical_translators.py new file mode 100644 index 0000000..c9c0a35 --- /dev/null +++ b/src/jace/translator/primitive_translators/arithmetic_logical_translators.py @@ -0,0 +1,200 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +""" +Module containing all translators related to arithmetic and logical operations. + +Todo: + - Hijack Jax to inject a proper modulo operation. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Final + +import dace +from typing_extensions import override + +from jace import translator, util +from jace.translator import mapped_operation_base_translator as mapped_base + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from jax import core as jax_core + + +class ArithmeticOperationTranslator(mapped_base.MappedOperationTranslatorBase): + """ + Translator for all arithmetic operations. + + The class is derived from `MappedOperationTranslatorBase` and overwrites the + `write_tasklet_code()` function for the Tasklet code. + + Args: + prim_name: The name of the primitive that should be handled. + tskl_tmpl: Template used for generating the Tasklet code. + + Note: + - It does not implement the logical operations, they are implemented by + the `LogicalOperationTranslator` class. + - Despite its name this class also provides the comparison operators. + - It does not implement `mod` nor `fmod` as they are translated to some + nested `pjit` implementation by Jax for unknown reasons. + """ + + def __init__(self, prim_name: str, tskl_tmpl: str) -> None: + super().__init__(primitive_name=prim_name) + self._tskl_tmpl = tskl_tmpl + + @override + def write_tasklet_code( + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> str: + """Returns the code for the Tasklet, with all parameters replaced.""" + tskl_code = self._tskl_tmpl + if len(eqn.params) != 0: + tskl_code = tskl_code.format(**eqn.params) + return tskl_code + + +class LogicalOperationTranslator(mapped_base.MappedOperationTranslatorBase): + """ + Translator for all logical operations. + + The reason why the logical operations are separated from the arithmetic + operations is quite complicated and in fact the whole thing is harder than + it should be. NumPy has two kinds of these operations, i.e. + `logical_{and, or, xor, not}()` and `bitwise_{and, or, xor, not}()`, but Jax + has only a single kind of logical operation, that operate in bitwise mode. + The first idea would be to use `ArithmeticOperationTranslator` with a template + such as `__out = __in0 & __in1` or `__out = ~__in0`. Since DaCe eventually + generates C++ code and C++ has a native bool type, and `true` is guaranteed + to be `1` and `false` equals `0`, this works for all operations except `not`, + as `~true` in C++ is essentially `~1`, which is again `true`! + Thus the `not` primitive must be handled separately. + + The solution to the problem is, to introduce two templates, one used for the + bool context and one used in the integer context. This works because depending + if the `logical_*()` or `bitwise_*()` functions are used the input is either + of type bool or an integer. + + Args: + prim_name: The name of the primitive that should be handled. + int_tmpl: The template used for the integer case. + bool_tmpl: The template used for the bool case. + + Note: + Since it does not make sense to single out `not` and keep the other + logical operations in `ArithmeticOperationTranslator` all of them are + handled by this class. + """ + + def __init__(self, prim_name: str, int_tmpl: str, bool_tmpl: str) -> None: + super().__init__(primitive_name=prim_name) + self._int_tmpl = int_tmpl + self._bool_tmpl = bool_tmpl + + @override + def write_tasklet_code( + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> str: + if all(util.get_jax_var_dtype(invar) is dace.bool_ for invar in eqn.invars): + return self._bool_tmpl + return self._int_tmpl + + +# Contains the code templates for all supported arithmetic operations. +# fmt: off +_ARITMETIC_OPERATION_TEMPLATES: Final[dict[str, str]] = { + # Unary operations + "pos": "__out = +(__in0)", + "neg": "__out = -(__in0)", + + "floor": "__out = floor(__in0)", + "ceil": "__out = ceil(__in0)", + "round": "__out = round(__in0)", + + "abs": "__out = abs(__in0)", + "sign": "__out = sign(__in0)", + "exp": "__out = exp(__in0)", + "exp2": "__out = exp2(__in0)", + "expm1": "__out = expm1(__in0)", + "log": "__out = log(__in0)", + "log1p": "__out = log1p(__in0)", + "conj": "__out = conj(__in0)", + "sqrt": "__out = sqrt(__in0)", + "cbrt": "__out = cbrt(__in0)", + + "integer_pow": "__out = (__in0)**({y})", # 'y' is a parameter of the primitive + "is_finite": "__out = isfinite(__in0)", + + "sin": "__out = sin(__in0)", + "asin": "__out = asin(__in0)", + "cos": "__out = cos(__in0)", + "acos": "__out = acos(__in0)", + "tan": "__out = tan(__in0)", + "atan": "__out = atan(__in0)", + + "sinh": "__out = sinh(__in0)", + "asinh": "__out = asinh(__in0)", + "cosh": "__out = cosh(__in0)", + "acosh": "__out = acosh(__in0)", + "tanh": "__out = tanh(__in0)", + "atanh": "__out = atanh(__in0)", + + # Binary operations + "add": "__out = (__in0)+(__in1)", + "add_any": "__out = (__in0)+(__in1)", # No idea what makes `add_any` differ from `add` + "sub": "__out = (__in0)-(__in1)", + "mul": "__out = (__in0)*(__in1)", + "div": "__out = (__in0)/(__in1)", + "rem": "__out = (__in0)%(__in1)", + "pow": "__out = (__in0)**(__in1)", + "min": "__out = min((__in0), (__in1))", + "max": "__out = max((__in0), (__in1))", + + "eq": "__out = (__in0) == (__in1)", + "ne": "__out = (__in0) != (__in1)", + "ge": "__out = (__in0) >= (__in1)", + "gt": "__out = (__in0) > (__in1)", + "le": "__out = (__in0) <= (__in1)", + "lt": "__out = (__in0) < (__in1)", + + "atan2": "__out = atan2((__in0), (__in1))", + + "nextafter": "__out = nextafter((__in0), (__in1))", + + # Ternary operations + "clamp": "__out = (__in0 if __in1 < __in0 else (__in1 if __in1 < __in2 else __in2))" +} + + +# Contains the code templates for all logical operations. +# The first one is for the integer case, the second for the bool case. +_LOGICAL_OPERATION_TEMPLATES: Final[dict[str, tuple[str, str]]] = { + "or": ("__out = (__in0) | (__in1)", "__out = (__in0) or (__in1)"), + "not": ("__out = ~(__in0)", "__out = not (__in0)"), + "and": ("__out = (__in0) & (__in1)", "__out = (__in0) and (__in1)"), + "xor": ("__out = (__in0) ^ (__in1)", "__out = (__in0) != (__in1)"), +} + + +# Create the arithmetic translators +for pname, ptmpl in _ARITMETIC_OPERATION_TEMPLATES.items(): + translator.register_primitive_translator(ArithmeticOperationTranslator(pname, ptmpl)) + +# Create the logical translators. +for pname, (itmpl, btmpl) in _LOGICAL_OPERATION_TEMPLATES.items(): + translator.register_primitive_translator(LogicalOperationTranslator(pname, itmpl, btmpl)) diff --git a/src/jace/translator/primitive_translators/broadcast_in_dim_translator.py b/src/jace/translator/primitive_translators/broadcast_in_dim_translator.py new file mode 100644 index 0000000..7f24160 --- /dev/null +++ b/src/jace/translator/primitive_translators/broadcast_in_dim_translator.py @@ -0,0 +1,67 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""This implements the `broadcast_in_dim` primitive.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import dace +from typing_extensions import override + +from jace import translator +from jace.translator import mapped_operation_base_translator as mapped_base + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from jax import core as jax_core + + +class BroadcastInDimTranslator(mapped_base.MappedOperationTranslatorBase): + """ + Implements the `broadcast_in_dim` primitive. + + The primitive is implemented through the `MappedOperationTranslatorBase` base. + Essentially it creates a copy, but also creates special Memlets that replicate + the content of the input. + """ + + def __init__(self) -> None: + super().__init__(primitive_name="broadcast_in_dim") + + @override + def write_tasklet_code( + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> str: + return "__out = __in0" + + @override + def make_input_memlets( + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> dict[str, dace.Memlet]: + if in_var_names[0] is None: + return {} + return { + "__in0": dace.Memlet.simple( + in_var_names[0], + ", ".join(tskl_ranges[bdim][0] for bdim in eqn.params["broadcast_dimensions"]) + if eqn.params["broadcast_dimensions"] + else "0", + ) + } + + +translator.register_primitive_translator(BroadcastInDimTranslator()) diff --git a/src/jace/translator/primitive_translators/concatenate_translator.py b/src/jace/translator/primitive_translators/concatenate_translator.py new file mode 100644 index 0000000..e8bd144 --- /dev/null +++ b/src/jace/translator/primitive_translators/concatenate_translator.py @@ -0,0 +1,87 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Implements the concatenation primitive.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import dace +from typing_extensions import override + +from jace import translator, util + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from jax import core as jax_core + + +class ConcatenateTranslator(translator.PrimitiveTranslator): + """ + Implements the `concatenate` primitive. + + It is implemented by a series of map that writes to the same access node. + It is probably the largest stretch of "written once" in the entire core. + """ + + @property + def primitive(self) -> str: # noqa: D102 # No docstring needed. + return "concatenate" + + @override + def __call__( + self, + builder: translator.JaxprTranslationBuilder, + in_var_names: Sequence[str | None], + out_var_names: Sequence[str], + eqn: jax_core.JaxprEqn, + eqn_state: dace.SDFGState, + ) -> None: + if any(in_var_name is None for in_var_name in in_var_names): + raise NotImplementedError("Concatenate: No literal inputs supported.") + + # Dimension along we concatenate. + cat_dim = eqn.params["dimension"] + + # Offset counter for write back. + already_copied = 0 + + # This is the access node we use for the output + # Is inside a dict for input to `add_mapped_tasklet()`. + output_nodes = {out_var_names[0]: eqn_state.add_write(out_var_names[0])} + + # Now going over each input and copying the input in the correct location + # of the output array. + for i, in_var_name in enumerate(in_var_names): + input_shape = util.get_jax_var_shape(eqn.invars[i]) + + tskl_range = [(f"__dim{d}", f"0:{dim_size}") for d, dim_size in enumerate(input_shape)] + tskl_input_access = [it_var for it_var, _ in tskl_range] + + tskl_output_access = tskl_input_access.copy() + tskl_output_access[cat_dim] = f"{tskl_output_access[cat_dim]} + {already_copied}" + + eqn_state.add_mapped_tasklet( + f"_concatenate_{out_var_names[0]}_{in_var_name}", + map_ranges=tskl_range, + inputs={"__in": dace.Memlet.simple(in_var_name, ", ".join(tskl_input_access))}, + code="__out = __in", + outputs={ + "__out": dace.Memlet.simple(out_var_names[0], ",".join(tskl_output_access)) + }, + output_nodes=output_nodes, + external_edges=True, + ) + + # Update the counter that we have copied + already_copied += input_shape[cat_dim] + + +_ = translator.register_primitive_translator(ConcatenateTranslator()) diff --git a/src/jace/translator/primitive_translators/conditions.py b/src/jace/translator/primitive_translators/conditions.py new file mode 100644 index 0000000..d291016 --- /dev/null +++ b/src/jace/translator/primitive_translators/conditions.py @@ -0,0 +1,182 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Implements all conditions that are supported in JAX.""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING + +import dace + +from jace import translator, util +from jace.translator import post_translation as ptranslation +from jace.translator.primitive_translators import pjit_translator as pjit + + +if TYPE_CHECKING: + from jax._src import core as jax_core + + +@translator.register_primitive_translator() +@translator.make_primitive_translator("cond") +def condition_translator( + builder: translator.JaxprTranslationBuilder, + in_var_names: Sequence[str | None], + out_var_names: Sequence[str], + eqn: jax_core.JaxprEqn, + eqn_state: dace.SDFGState, +) -> dace.SDFGState: + """ + Implements the translation of the `cond` primitive, i.e. a scalar if. + + XLA, JAX' backend, supports two versions, one in which the selector, i.e. the + variable indicating which branch should be executed is an integer or a boolean. + + Args: + builder: The builder object of the translation. + in_var_names: The SDFG variables used an input arguments. First is the index, + the variable that selects the branch, the remaining ones are passed as + inputs to the branches. + out_var_names: Names of SDFG variables that should be used as outputs. + eqn: The equation that should be translated. + eqn_state: State into which the nested SDFG should be constructed. + + Returns: + Because of the nature of this primitive, the translator has to construct + new states and will return the new SDFG state that serves as terminal state. + + Note: + The implementation assumes that the selector, i.e. the variables indicating + which branch should be taken is inside its bound. + """ + if util.get_jax_var_dtype(eqn.invars[0]) is dace.bool_: + return _cond_primitive_boolean_impl( + builder=builder, + in_var_names=in_var_names, + out_var_names=out_var_names, + eqn=eqn, + eqn_state=eqn_state, + ) + return _cond_primitive_multi_switch_impl( + builder=builder, + in_var_names=in_var_names, + out_var_names=out_var_names, + eqn=eqn, + eqn_state=eqn_state, + ) + + +def _cond_primitive_multi_switch_impl( + builder: translator.JaxprTranslationBuilder, + in_var_names: Sequence[str | None], + out_var_names: Sequence[str], + eqn: jax_core.JaxprEqn, + eqn_state: dace.SDFGState, +) -> dace.SDFGState: + """ + Implements the integer version of the conditional primitive. + + For arguments see `ConditionTranslator`. + + This [version](https://openxla.org/xla/operation_semantics#conditional) is + essentially a C switch statement without a default branch. + """ + # To make names in the SDFG unique we use the name of the equation state + name_pattern = eqn_state.name + + # Promote all inputs to the branches to variables, this are all except the first + # which is the selection variable. + branch_input_variable_names: list[str] = pjit._promote_literals_to_constants( + builder=builder, + var_names=in_var_names[1:], + jax_vars=eqn.invars[1:], + name_pattern=name_pattern, + ) + + if in_var_names[0] is None: + # The selection variable is a literal, so we will now pretend it is a symbol. + # This also means that we do not need a state transition to promote the + # variable to a symbol. + selection_symbol = str(util.get_jax_literal_value(eqn.invars[0])) + selection_state = eqn_state + + else: + # The selection variable is an input. + # For the implementation of the condition we need to promote the selection + # variable to a symbol, for which we need an interstate edge. + # As a side effect it will update the terminal state. + selection_variable_name = in_var_names[0] + selection_symbol = f"{selection_variable_name}_symb" + + selection_state = builder.append_new_state( + label=f"{name_pattern}_fork", + assignments={selection_symbol: selection_variable_name}, + prev_state=eqn_state, + ) + + # Now iterate through all branches, translate them and integrate them + # for each branch we will generate a dedicated state. + branch_states: list[dace.SDFGState] = [] + for i, branch_jaxpr in enumerate(eqn.params["branches"]): + branch_pattern = f"{name_pattern}_{{}}_branch_{i}" + branch_ctx = builder.translate_jaxpr(jaxpr=branch_jaxpr, name=branch_pattern.format("sdfg")) + + # This will update the terminal state only the first time. + branch_state = builder.append_new_state( + label=branch_pattern.format("state"), + condition=f"{selection_symbol} == {i}", + prev_state=selection_state, + ) + + # Integrating it. + ptranslation.add_nested_sdfg( + state=branch_state, + child_ctx=branch_ctx, + parent_ctx=builder._ctx, + in_var_names=branch_input_variable_names, + out_var_names=out_var_names, + ) + branch_states.append(branch_state) + + # Now we have to generate a join state that will serve as new terminal state. + # We append it to the first branch state, which is the current terminal state. + assert builder._terminal_sdfg_state is branch_states[0] + terminal_state = builder.append_new_state( + label=f"{name_pattern}_join", + prev_state=branch_states[0], + ) + for branch_state in branch_states[1:]: + builder.sdfg.add_edge( + branch_state, + terminal_state, + dace.sdfg.InterstateEdge(), + ) + + # We return it, because otherwise the builder will assume that `eqn_state` was used. + return terminal_state + + +def _cond_primitive_boolean_impl( + builder: translator.JaxprTranslationBuilder, # noqa: ARG001 [unused-function-argument] + in_var_names: Sequence[str | None], # noqa: ARG001 [unused-function-argument] + out_var_names: Sequence[str], # noqa: ARG001 [unused-function-argument] + eqn: jax_core.JaxprEqn, # noqa: ARG001 [unused-function-argument] + eqn_state: dace.SDFGState, # noqa: ARG001 [unused-function-argument] +) -> dace.SDFGState: + """ + Implements the case the selector of the primitive is a bool. + + XLA explicitly provides this + [form of the primitive](https://openxla.org/xla/operation_semantics#conditional) + JAX however, does not seem to use it and instead forwards it to the integer + implementation. + JaCe will not implement it and instead generate an error. + """ + # NOTE: This is mostly to notice if JAX decided to implement that branch. + raise NotImplementedError("The boolean conditional primitive is not implemented.") diff --git a/src/jace/translator/primitive_translators/convert_element_type_translator.py b/src/jace/translator/primitive_translators/convert_element_type_translator.py new file mode 100644 index 0000000..ee05a2a --- /dev/null +++ b/src/jace/translator/primitive_translators/convert_element_type_translator.py @@ -0,0 +1,85 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Implements the translator for the `convert_element_type` primitive.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import dace +from typing_extensions import override + +from jace import translator, util +from jace.translator import mapped_operation_base_translator as mapped_base + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from jax import core as jax_core + + +class ConvertElementTypeTranslator(mapped_base.MappedOperationTranslatorBase): + """ + Implements the `convert_element_type` primitive. + + The primitive will expand to a "copy Map", however, the Tasklet will not + simply copy the input to the output, but also perform type conversion. + However, in cases where the input type is the same as the output type, + the Tasklet will just be a copy Tasklet, that can then be removed by DaCe. + + Note: + This translator ignores the `new_dtype` and `weak_type` parameters of + the equation and only performs the casting based on the type of the fields. + """ + + def __init__(self) -> None: + super().__init__(primitive_name="convert_element_type") + + @override + def write_tasklet_code( + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> str: + if in_var_names[0] is None: + raise NotImplementedError("'convert_element_type' is not supported for literals.") + + in_dtype = util.get_jax_var_dtype(eqn.invars[0]).type + in_dtype_s: str = in_dtype.__name__ + out_dtype = util.get_jax_var_dtype(eqn.outvars[0]).type + out_dtype_s: str = out_dtype.__name__ + + # This is the base of the template that we use for conversion. You should notice + # that the Tasklet `__out = __in0` will fail, see commit `f5aabc3` of the + # prototype. Thus we have to do it in this way. + conv_code = "__in0" + + if in_dtype == out_dtype: + # For some reason Jax sometimes adds conversions where no are needed. In + # these cases we explicitly create a copy Tasklet, which is trivial and can + # be removed by DaCe. + # TODO(phimuell): Create a Memlet instead. + return f"__out = {conv_code}" + + if in_dtype_s.startswith("bool"): + # Interestingly `__out = int(__in0)` will not work. + conv_code = f"(1 if {conv_code} else 0)" + if out_dtype_s.startswith("bool"): + conv_code = f"dace.bool_({conv_code})" + elif hasattr(dace.dtypes, out_dtype_s): + conv_code = f"dace.{out_dtype_s}({conv_code})" + else: + raise NotImplementedError( + f"Cannot convert '{in_dtype}' to '{out_dtype}' as this type is not known to DaCe." + ) + return f"__out = {conv_code}" + + +_ = translator.register_primitive_translator(ConvertElementTypeTranslator()) diff --git a/src/jace/translator/primitive_translators/copy_translator.py b/src/jace/translator/primitive_translators/copy_translator.py new file mode 100644 index 0000000..6de5ab9 --- /dev/null +++ b/src/jace/translator/primitive_translators/copy_translator.py @@ -0,0 +1,92 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Implements the translator related to data movement.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import dace +from typing_extensions import override + +from jace import translator + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from jax import core as jax_core + + +class CopyTranslator(translator.PrimitiveTranslator): + """ + Implements the `copy` primitive. + + The translator is implemented by using a Memlet. + """ + + @property + def primitive(self) -> str: # noqa: D102 # No docstring needed. + return "copy" + + def __call__( # noqa: D102 # No docstring + self, + builder: translator.JaxprTranslationBuilder, + in_var_names: Sequence[str | None], + out_var_names: Sequence[str], + eqn: jax_core.JaxprEqn, # noqa: ARG002 + eqn_state: dace.SDFGState, + ) -> None: + eqn_state.add_nedge( + eqn_state.add_read(in_var_names[0]), + eqn_state.add_write(out_var_names[0]), + dace.Memlet.from_array( + in_var_names[0], + builder.arrays[in_var_names[0]], # type: ignore[index] # Guaranteed to be a string + ), + ) + + +class DevicePutTranslator(CopyTranslator): + """ + Implements the `device_put` primitive. + + In Jax this primitive is used to copy data between the host and the device, + in DaCe Memlets can do this. However, because of the way JaCe operates, at + least in the beginning a computation is either fully on the host or on the + device this copy will essentially perform a copying. + """ + + @property + def primitive(self) -> str: # noqa: D102 # No docstring + return "device_put" + + @override + def __call__( # No docstring + self, + builder: translator.JaxprTranslationBuilder, + in_var_names: Sequence[str | None], + out_var_names: Sequence[str], + eqn: jax_core.JaxprEqn, + eqn_state: dace.SDFGState, + ) -> None: + if not (eqn.params["device"] is None and eqn.params["src"] is None): + raise NotImplementedError( + f"Can only copy on the host, but not from {eqn.params['src']} to {eqn.params['device']}." + ) + return super().__call__( + builder=builder, + in_var_names=in_var_names, + out_var_names=out_var_names, + eqn=eqn, + eqn_state=eqn_state, + ) + + +_ = translator.register_primitive_translator(CopyTranslator()) +_ = translator.register_primitive_translator(DevicePutTranslator()) diff --git a/src/jace/translator/primitive_translators/gather_translator.py b/src/jace/translator/primitive_translators/gather_translator.py new file mode 100644 index 0000000..343ee15 --- /dev/null +++ b/src/jace/translator/primitive_translators/gather_translator.py @@ -0,0 +1,211 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Implements the translator for the `gather` primitive.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import dace +from jax import lax as jax_lax +from typing_extensions import override + +from jace import translator, util + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from jax import core as jax_core + + +class GatherTranslator(translator.PrimitiveTranslator): + """ + Garther Translator. + + The gather operation extracts patches of a certain size, known as `slice_size`, + from an array, called source or input array. Where these patches starts is + given by another array, the index array. + + See Also: + https://www.tensorflow.org/xla/operation_semantics#gather + https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.gather.html + """ + + @property + def primitive(self) -> str: # noqa: D102 # No docstring needed. + return "gather" + + @override + def __call__( # noqa: PLR0914, PLR0915 # Just ported from the prototype, cleanup postponed. + self, + builder: translator.JaxprTranslationBuilder, + in_var_names: Sequence[str | None], + out_var_names: Sequence[str], + eqn: jax_core.JaxprEqn, + eqn_state: dace.SDFGState, + ) -> None: + """ + Performs the gather operation. + + Args: + builder: The builder object that is active. + in_var_names: The names of the input variables, the first array is + assumed as source array and the second is the index array. + out_var_names: The names of the output variables. + eqn: The equation to translate. + eqn_state: The state in which we put the extraction. + """ + assert len(eqn.invars) == 2 # noqa: PLR2004 # XLA supports more inputs. + + out_name = out_var_names[0] + out_shape = util.get_jax_var_shape(eqn.outvars[0]) + + src_name = in_var_names[0] + src_shape = util.get_jax_var_shape(eqn.invars[0]) + + idx_name = in_var_names[1] + idx_shape = util.get_jax_var_shape(eqn.invars[1]) + + dimension_numbers = eqn.params["dimension_numbers"] + offset_dims: Sequence[int] = dimension_numbers.offset_dims + collapsed_slice_dims: Sequence[int] = dimension_numbers.collapsed_slice_dims + start_index_map: Sequence[int] = dimension_numbers.start_index_map + slice_sizes: Sequence[int] = eqn.params["slice_sizes"] + mode: jax_lax.GatherScatterMode = eqn.params["mode"] + assert len(start_index_map) == idx_shape[-1] + + if mode != jax_lax.GatherScatterMode.PROMISE_IN_BOUNDS: + raise NotImplementedError(f"The mode {mode} is not implemented.") + + # Over these dimensions the copy of the patches goes. + batch_dims = tuple(d for d in range(len(out_shape)) if d not in offset_dims) + + # Every batch dimension is associated with one dimension of of the index + # array, but there is always one dimension more in the index array. This + # dimension contains the start indexes of the slice, if there is only + # one index that should be loaded is not strictly necessary, but Jax + # (currently adds) it implicitly, probably to make life easier. + if (len(batch_dims) + 1) != len(idx_shape): + raise ValueError( + f"Expected that the index array has {len(batch_dims) + 1} dimensions, but it had {len(idx_shape)}." + ) + + # These are the dimensions (of the input) for which a map index is created. + # Note that we exclude collapsed dimensions here. + src_dim_with_map_idx = tuple( + dim for dim in range(len(slice_sizes)) if dim not in collapsed_slice_dims + ) + assert len(src_dim_with_map_idx) == len(offset_dims) + + # The final map is the composition of two loops. The first map iterates over + # the index array, except the last dimension, and is used to "copy" the + # different patches from the source to the output array. These map parameters + # follow the pattern `__i{out_name}_gather{bd}`, where `bd` is a batch + # dimension. These variables are used to access the index array. + # The second loop performs the actual copy of the slices. For these + # the variables `__i{i}` is used were, these are known as offset + # dimensions. + # What is a bit difficult, that the actual access/dereferencing of the source + # array is done within the tasklet. + + # Access pattern of the source array _within_ the tasklet. + src_access_pattern: list[str] = [] + + # These are the map ranges for the coying of the slicing. + slice_map_ranges: list[tuple[str, str]] = [] + + # Compute the access pattern within the tasklet. + # As a side effect we also compute the map ranges, but only for the slices. + for dim, slice_size in enumerate(slice_sizes): + # Order is important! + if dim not in start_index_map: + # This dimension is fully copied + slice_map_ranges.append((f"__i{dim}", f"0:{slice_size}")) + src_access_pattern.append(slice_map_ranges[-1][0]) + assert dim in src_dim_with_map_idx + assert slice_size == src_shape[dim] + + elif dim in collapsed_slice_dims: + # This dimension is only partially copied, however, since the + # dimension is collapsed, only a single element is copied that + # comes from the index array. + src_access_pattern.append(f"__gather_{dim}") + + else: + # This dimension is partially copied, but is _not colapsed_, we need + # a map index to copy the range. However, there is also an offset + # that is involved from copying. + slice_map_ranges.append((f"__i{dim}", f"0:{slice_size}")) + src_access_pattern.append(f"__gather_{dim} + {slice_map_ranges[-1][0]}") + assert dim in src_dim_with_map_idx + assert slice_size <= src_shape[dim] + + # These are the map variable that go over the index array. + patch_loop_vars = tuple(f"__i{out_name}_gather{bd}" for bd in batch_dims) + patch_map_ranges = [ + (map_param, f"0:{patch_loop_bound}") + for map_param, patch_loop_bound in zip(patch_loop_vars, idx_shape[:-1]) + ] + + # Creating the input memlet that allows us to access the source array from + # inside the tasklet and make it accessible through the name `__arr`. At + # this point it is not possible to tell where we access, because we are + # missing a index variables, they will only be accessible inside the + # tasklet (see below), however, we know that we will access only one + # element from the array. + tasklet_inputs: dict[str, dace.Memlet] = { + "__arr": dace.Memlet.simple( + data=src_name, + subset_str=", ".join(f"0:{size}" for size in src_shape), + num_accesses=1, + ), + } + + # Now we are creating the memlets that access the index array. + for i, dim in enumerate(start_index_map): + tasklet_inputs[f"__gather_{dim}"] = dace.Memlet.simple( + data=idx_name, subset_str=(", ".join(patch_loop_vars) + f", {i}") + ) + + # The tasklet code. + tasklet_code = "__out = __arr[" + ", ".join(src_access_pattern) + "]" + + # Now we generate the output memlet. + outpt_access_pattern: list[str] = [] + dim_counter = 0 + for dim in range(len(out_shape)): + if dim in batch_dims: + # This is a batch dimension, thus a loop variable is used for it. + patch_loop_var = patch_loop_vars[batch_dims.index(dim)] + outpt_access_pattern.append(str(patch_loop_var)) + + else: + # This is a dimension for copying the slices. + assert dim_counter <= len(src_dim_with_map_idx) + associated_map_idx = src_dim_with_map_idx[dim_counter] + dim_counter += 1 + outpt_access_pattern.append(f"__i{associated_map_idx}") + assert dim_counter == len(src_dim_with_map_idx) + + tasklet_outputs: dict[str, dace.Memlet] = { + "__out": dace.Memlet.simple(data=out_name, subset_str=", ".join(outpt_access_pattern)) + } + assert len(patch_map_ranges) + len(slice_map_ranges) == len(out_shape) + + eqn_state.add_mapped_tasklet( + name=f"_gather_map_{out_name}", + map_ranges=patch_map_ranges + slice_map_ranges, + inputs=tasklet_inputs, + code=tasklet_code, + outputs=tasklet_outputs, + external_edges=True, + ) + + +_ = translator.register_primitive_translator(GatherTranslator()) diff --git a/src/jace/translator/primitive_translators/iota_translator.py b/src/jace/translator/primitive_translators/iota_translator.py new file mode 100644 index 0000000..ce0d99f --- /dev/null +++ b/src/jace/translator/primitive_translators/iota_translator.py @@ -0,0 +1,56 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""This implements the `iota` primitive.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from typing_extensions import override + +from jace import translator +from jace.translator import mapped_operation_base_translator as mapped_base + + +if TYPE_CHECKING: + from collections.abc import Sequence + + import dace + from jax import core as jax_core + + +class IotaTranslator(mapped_base.MappedOperationTranslatorBase): + """ + Implements the `iota` primitive. + + Essentially, a very general `jnp.arange()` function. + """ + + def __init__(self) -> None: + super().__init__(primitive_name="iota") + + @override + def write_tasklet_code( + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> str: + return f"__out = {tskl_ranges[eqn.params['dimension']][0]}" + + @override + def make_input_memlets( + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> dict[str, dace.Memlet]: + return {} + + +translator.register_primitive_translator(IotaTranslator()) diff --git a/src/jace/translator/primitive_translators/pjit_translator.py b/src/jace/translator/primitive_translators/pjit_translator.py new file mode 100644 index 0000000..59bfd7e --- /dev/null +++ b/src/jace/translator/primitive_translators/pjit_translator.py @@ -0,0 +1,147 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Implements the `pjit` translator, i.e. nested Jaxpr expressions.""" + +from __future__ import annotations + +import re +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +from jax._src import sharding_impls as jax_sharding # noqa: PLC2701 [import-private-name] + +from jace import translator, util +from jace.translator import post_translation as ptranslation + + +if TYPE_CHECKING: + import dace + from jax._src import core as jax_core + + +def _promote_literals_to_constants( + builder: translator.JaxprTranslationBuilder, + var_names: Sequence[str | None], + jax_vars: Sequence[jax_core.Atom], + name_pattern: str, +) -> list[str]: + """ + Promotes all literals in `var_names` to DaCe constants and add them to the SDFG. + + The function assumes that `var_names` are the SDFG variables equivalents of + `jax_vars`, as by convention `None` indicates a literal. The function will create + a constant for each literal and return `var_names` cleared of all literals. + For naming the variables the function will use `name_pattern`. + + Args: + builder: The builder that is used for translation. + var_names: Names of the SDFG variables, `None` indicates a literal. + jax_vars: The JAX variables, in the same order than `var_names`. + name_pattern: A pattern to generate a unique name for the variables. + + Todo: + Is a constant the right idea or should we generate a symbol? + """ + promoted_var_names: list[str] = [] + for i, var_name in enumerate(var_names): + if var_name is None: + promoted_var_name = f"__const_{name_pattern}_literal_promotion_{i}" + jax_var = jax_vars[i] + promoted_jace_var = util.JaCeVar.from_atom( + jax_var=jax_var, + name=promoted_var_name, + ) + builder.add_array(promoted_jace_var) + builder.sdfg.add_constant( + promoted_var_name, + util.get_jax_literal_value(jax_var), + builder.arrays[promoted_var_name], + ) + + else: + # Already an SDFG variable, so nothing to do. + promoted_var_name = var_name + promoted_var_names.append(promoted_var_name) + return promoted_var_names + + +@translator.register_primitive_translator() +@translator.make_primitive_translator("pjit") +def PJITTranslator( # noqa: N802 [invalid-function-name] + builder: translator.JaxprTranslationBuilder, + in_var_names: Sequence[str | None], + out_var_names: Sequence[str], + eqn: jax_core.JaxprEqn, + eqn_state: dace.SDFGState, +) -> None: + """ + Implements the `pjit` translator that handles nested Jaxpr. + + `pjit` primitives in JAX represents nested calls, for example the body of a scan + is inside a nested Jaxpr. However, `pjit` is used to indicate that a computation + should be done on the device or on sharded memory. + + However, due to the current state and working of JaCe, this aspect is essentially + ignored and the computation is always inlined. + + In case an input is a literal the translator will create a constant for it. + + Args: + builder: The builder object of the translation. + in_var_names: Names of the SDFG variables that should be used as inputs + inside the parent SDFG. + out_var_names: Names of SDFG variables that should be used as outputs + inside the parent SDFG. + eqn: The equation that contains the `pjit` primitive. + eqn_state: State into which the nested SDFG should be constructed. + """ + params: dict[str, Any] = eqn.params + nested_jaxpr: jax_core.ClosedJaxpr = params["jaxpr"] + in_shardings = params["in_shardings"] + out_shardings = params["out_shardings"] + _ = params["donated_invars"] # Always ignored + _ = params["keep_unused"] + _ = params["inline"] + + if not all(in_sharding is jax_sharding.UNSPECIFIED for in_sharding in in_shardings): + raise NotImplementedError("Currently 'pjit' does not support sharding in its input.") + if not all(out_sharding is jax_sharding.UNSPECIFIED for out_sharding in out_shardings): + raise NotImplementedError("Currently 'pjit' does not support sharding in its output.") + + # TODO(phimuell): Controlflow region and name + pjit_name = params["name"] + + # TODO(phimuell): Controlflow region and name + # They will introduce a feature like that to address them in optimizations. + pjit_name = params["name"] + + # Name in SDFG must be unique, thus we mangle it, furthermore, we have to clean it. + sdfg_name = f"pjit_{re.subn('[^a-zA-Z0-9_]', '_', pjit_name)[0]}__{'_'.join(out_var_names)}" + + # Ensure that all inputs are SDFG variables + final_input_names = _promote_literals_to_constants( + builder=builder, + var_names=in_var_names, + jax_vars=eqn.invars, + name_pattern=sdfg_name, + ) + + # Now get the translated SDFG. + nested_context: translator.TranslationContext = builder.translate_jaxpr( + jaxpr=nested_jaxpr, + name=sdfg_name, + ) + + # Now lets add the nested SDFG + ptranslation.add_nested_sdfg( + state=eqn_state, + child_ctx=nested_context, + parent_ctx=builder._ctx, + in_var_names=final_input_names, + out_var_names=out_var_names, + ) diff --git a/src/jace/translator/primitive_translators/reshape_translator.py b/src/jace/translator/primitive_translators/reshape_translator.py new file mode 100644 index 0000000..241cc94 --- /dev/null +++ b/src/jace/translator/primitive_translators/reshape_translator.py @@ -0,0 +1,67 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Implements the translator for the `reshape` primitive.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import dace +from typing_extensions import override + +from jace import translator, util + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from jax import core as jax_core + + +class ReshapeTranslator(translator.PrimitiveTranslator): + """ + Implements the `reshape` primitive. + + The current implementation uses a Memlet for this and essentially acts as + an optimization barrier. Furthermore the Jax primitive also has the optional + `dimensions` parameters which allows to permute the input, this is not + supported. + """ + + @property + def primitive(self) -> str: # noqa: D102 # No docstring needed. + return "reshape" + + @override + def __call__( + self, + builder: translator.JaxprTranslationBuilder, + in_var_names: Sequence[str | None], + out_var_names: Sequence[str], + eqn: jax_core.JaxprEqn, + eqn_state: dace.SDFGState, + ) -> None: + """ + Performs the reshaping. + + Currently a copy using a Memlet is performed. + """ + if eqn.params["dimensions"] is not None: + raise NotImplementedError("Currently 'dimensions' must be 'None'.") + eqn_state.add_nedge( + eqn_state.add_read(in_var_names[0]), + eqn_state.add_write(out_var_names[0]), + dace.Memlet( + data=in_var_names[0], + subset=", ".join(f"0:{size}" for size in util.get_jax_var_shape(eqn.invars[0])), + other_subset=", ".join(f"0:{size}" for size in eqn.params["new_sizes"]), + ), + ) + + +translator.register_primitive_translator(ReshapeTranslator()) diff --git a/src/jace/translator/primitive_translators/select_n_translator.py b/src/jace/translator/primitive_translators/select_n_translator.py new file mode 100644 index 0000000..51b27b3 --- /dev/null +++ b/src/jace/translator/primitive_translators/select_n_translator.py @@ -0,0 +1,94 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Implements `select_n`.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import dace +from typing_extensions import override + +from jace import translator, util +from jace.translator import mapped_operation_base_translator as mapped_base + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from jax import core as jax_core + + +class SelectNTranslator(mapped_base.MappedOperationTranslatorBase): + """ + Implements the `select_n` primitive. + + The `select_n` primitive is a generalization of `np.where`, that can take an + arbitrary number of branches, which are selected by an integer predicate. + The behaviour is undefined if the predicate is out of bound. + + Note: + For a better understanding this function renames its input connectors. + The first one, which is the predicate, is renamed to `__cond` and the + others are renamed again to `__in{i}`, starting with zero. + + Todo: + Implement the primitive as a nested SDFG. + """ + + def __init__(self) -> None: + super().__init__(primitive_name="select_n") + + @override + def write_tasklet_code( + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> str: + if len(in_var_names) == 3: # noqa: PLR2004 # `3` is not magic. + # This order is correct, since `False` is interpreted as `0`, which means + # the first case. DaCe seems to have some problems with bools and integer + # casting around, so we handle the bool case explicitly here. + # See also `ConvertElementTypeTranslator`. + return "__out = __in1 if __cond else __in0" + + return "\n".join( + ["if __cond == 0: __out = __in0"] + + [f"elif __cond == {i}: __out = __in{i}" for i in range(1, len(in_var_names) - 1)] + ) + + @override + def make_input_memlets( + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> dict[str, dace.Memlet]: + return { + f"__in{i - 1}" if i else "__cond": dace.Memlet.simple( + in_var_name, ", ".join(f"{it_idx}" for it_idx, _ in tskl_ranges) + ) + for i, in_var_name in enumerate(in_var_names) + if in_var_name + } + + @override + def literal_substitution( + self, tskl_code: str, in_var_names: Sequence[str | None], eqn: jax_core.JaxprEqn + ) -> str: + assert in_var_names[0] # Condition can never be a literal. + for i, in_var_name in enumerate(in_var_names[1:]): + if in_var_name is not None: + continue + t_val = util.get_jax_literal_value(eqn.invars[i + 1]) + tskl_code = tskl_code.replace(f"__in{i}", str(t_val)) + return tskl_code + + +translator.register_primitive_translator(SelectNTranslator()) diff --git a/src/jace/translator/primitive_translators/slicing.py b/src/jace/translator/primitive_translators/slicing.py new file mode 100644 index 0000000..ae4f167 --- /dev/null +++ b/src/jace/translator/primitive_translators/slicing.py @@ -0,0 +1,198 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Implements slicing.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import dace +from typing_extensions import override + +from jace import translator, util +from jace.translator import mapped_operation_base_translator as mapped_base + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from jax import core as jax_core + + +class SlicingTranslator(mapped_base.MappedOperationTranslatorBase): + """ + Implements the `slice` primitive. + + This is the classical slicing operation which extracts a fixed sized window + from a fixed initial position. The slicing is implemented using a partial copy. + + Note: + Slices are essentially optimization barriers as they can not be fused + with Maps before them. + """ + + def __init__(self) -> None: + super().__init__(primitive_name="slice") + + @override + def write_tasklet_code( + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> str: + return "__out = __in0" + + @override + def make_input_memlets( + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> dict[str, dace.Memlet]: + """We have to add the offsets to the Memlet accesses.""" + strides: Sequence[int] = ( + ((1,) * len(tskl_ranges)) if eqn.params["strides"] is None else eqn.params["strides"] + ) + start_indices: Sequence[int] = eqn.params["start_indices"] # Fist index to slice + return { + "__in0": dace.Memlet.simple( + in_var_names[0], + ", ".join( + f"{start_index} + {it_idx} * {stride}" + for (it_idx, _), start_index, stride in zip(tskl_ranges, start_indices, strides) + ), + ) + } + + +class DynamicSlicingTranslator(translator.PrimitiveTranslator): + """ + Implements the `dynamic_slice` primitive. + + [Dynamic slicing](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dynamic_slice.html) + performs a slicing of a _fixed_ window, but the start of the window is + not fix, instead it is passed by variables. Furthermore, (as it is in Jax), + if the window would overrun the start indexes are adjusted. + + Todo: + - Prevent that the modified start indexes are promoted to symbols, + to ensure mergability. + """ + + @property + def primitive(self) -> str: # noqa: D102 # No docstring needed. + return "dynamic_slice" + + @override + def __call__( + self, + builder: translator.JaxprTranslationBuilder, + in_var_names: Sequence[str | None], + out_var_names: Sequence[str], + eqn: jax_core.JaxprEqn, + eqn_state: dace.SDFGState, + ) -> None: + assert in_var_names[0] + assert len(in_var_names) == len(util.get_jax_var_shape(eqn.invars[0])) + 1 + + # This is the sizes of the slice window. + window_sizes: Sequence[int] = eqn.params["slice_sizes"] + + # Maps the variable name, that stores the start index of the window in one + # dimensions to the access node, that holds the value. The variable name + # is also used as dynamic range offset. + # Only present if the index is not a literal. + in_access: dict[str, dace.nodes.AccessNode] = {} + + # Name of the variable from where we get the start index of the window + # or the value itself, if it is a literal; in the order of the dimension. + # If the value is `None` then the literal was not yet processed. + window_start_indices: list[str | None] = list(in_var_names[1:]) + + # We will always adapt the start indexes and not check if it is needed. + for dim, (window_start_index, dim_size, window_size) in enumerate( + zip(window_start_indices, util.get_jax_var_shape(eqn.invars[0]), window_sizes) + ): + if window_start_index is None: + # Jax does not adjust the literals on its own + raw_window_start = int(util.get_jax_literal_value(eqn.invars[dim + 1])) # type: ignore[arg-type] # type confusion + adjusted_window_start = min(dim_size, raw_window_start + window_size) - window_size + window_start_indices[dim] = str(adjusted_window_start) + continue + + # We do not use a symbol for the start of the window but a Tasklet, as + # a symbol would need an interstage edge, which is an optimization barrier. + tasklet = dace.nodes.Tasklet( + label=f"adjustment_of_slice_start_{window_start_index}_for_{out_var_names[0]}", + inputs={"unadjusted_start_idx": None}, + outputs={"adjusted_start_idx": None}, + code=f"adjusted_start_idx = min(unadjusted_start_idx + {window_size}, {dim_size}) - {window_size}", + ) + new_start_idx_var_name = builder.add_array( + eqn.invars[dim + 1], name_prefix="__jace_adapted_start_idx_" + ) + new_start_idx_acc = eqn_state.add_access(new_start_idx_var_name) + + eqn_state.add_edge( + eqn_state.add_read(window_start_index), + None, + tasklet, + "unadjusted_start_idx", + dace.Memlet.simple(window_start_index, "0"), + ) + eqn_state.add_edge( + tasklet, + "adjusted_start_idx", + new_start_idx_acc, + None, + dace.Memlet.simple(new_start_idx_var_name, "0"), + ) + # Update the name of the start index, and store the access + # node for later use. + window_start_indices[dim] = new_start_idx_var_name + in_access[new_start_idx_var_name] = new_start_idx_acc + + tskl_ranges: list[tuple[str, str]] = [ + (f"__i{dim}", f"0:{N}") for dim, N in enumerate(util.get_jax_var_shape(eqn.outvars[0])) + ] + + memlet_accesses: list[str] = [] + + for (it_var, _), offset_symbol_name in zip(tskl_ranges, window_start_indices): + assert offset_symbol_name is not None + memlet_accesses.append(f"{it_var} + {offset_symbol_name}") + + tskl_input = dace.Memlet.simple(in_var_names[0], ", ".join(memlet_accesses)) + tskl_output = dace.Memlet.simple( + out_var_names[0], ", ".join(name for name, _ in tskl_ranges) + ) + _, map_entry, _ = eqn_state.add_mapped_tasklet( + name=f"{self.primitive}_{out_var_names[0]}", + map_ranges=tskl_ranges, + inputs={"__in": tskl_input}, + code="__out = __in", + outputs={"__out": tskl_output}, + external_edges=True, + ) + + # Creating the inputs for the dynamic map ranges. We have to use the same + # access nodes as above, to ensure a single order of computation. + for window_start_index_name, windows_start_access_node in in_access.items(): + eqn_state.add_edge( + windows_start_access_node, + None, + map_entry, + window_start_index_name, + dace.Memlet.simple(window_start_index_name, "0"), + ) + map_entry.add_in_connector(window_start_index_name) + + +translator.register_primitive_translator(SlicingTranslator()) +translator.register_primitive_translator(DynamicSlicingTranslator()) diff --git a/src/jace/translator/primitive_translators/squeeze_translator.py b/src/jace/translator/primitive_translators/squeeze_translator.py new file mode 100644 index 0000000..de6f1f4 --- /dev/null +++ b/src/jace/translator/primitive_translators/squeeze_translator.py @@ -0,0 +1,69 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Implements the `squeeze` primitive.""" + +from __future__ import annotations + +import itertools +from typing import TYPE_CHECKING + +import dace +from typing_extensions import override + +from jace import translator, util +from jace.translator import mapped_operation_base_translator as mapped_base + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from jax import core as jax_core + + +class SqueezeTranslator(mapped_base.MappedOperationTranslatorBase): + """ + Implements the `squeeze` primitive. + + The primitives allows to remove dimensions of size one. Essentially + equivalent to `np.squeeze` and the inverse to `np.expand_dims()`. + """ + + def __init__(self) -> None: + super().__init__(primitive_name="squeeze") + + @override + def write_tasklet_code( + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> str: + return "__out = __in0" + + @override + def make_input_memlets( + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> dict[str, dace.Memlet]: + dims_to_delete: Sequence[str] = eqn.params["dimensions"] + in_rank: int = len(util.get_jax_var_shape(eqn.invars[0])) + cnt = itertools.count(0) + return { + "__in0": dace.Memlet.simple( + in_var_names[0], + ", ".join( + "0" if dim in dims_to_delete else tskl_ranges[next(cnt)][0] + for dim in range(in_rank) + ), + ) + } + + +translator.register_primitive_translator(SqueezeTranslator()) diff --git a/src/jace/util/jax_helper.py b/src/jace/util/jax_helper.py index bc2de21..7c9f2f0 100644 --- a/src/jace/util/jax_helper.py +++ b/src/jace/util/jax_helper.py @@ -81,6 +81,27 @@ def __eq__(self, other: Any) -> bool: return NotImplemented return id(self) == id(other) + @classmethod + def from_atom( + cls, + jax_var: jax_core.Atom, + name: str | None, + ) -> JaCeVar: + """ + Generates a `JaCeVar` from the JAX variable `jax_var`. + + If `jax_var` is a literal its value is ignored. + + Args: + jax_var: The variable to process. + name: The optional name of the variable. + """ + return cls( + shape=get_jax_var_shape(jax_var), + dtype=get_jax_var_dtype(jax_var), + name=name, + ) + def get_jax_var_name(jax_var: jax_core.Atom | JaCeVar) -> str: """Returns the name of `jax_var` as a string.""" From 25616ae13f12a4d4d6574e9740b632496b5e757e Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 13 Sep 2024 15:20:35 +0200 Subject: [PATCH 2/9] Modified the condition primitive. Before it was implementewd as a switch, for the case JAX would use the bool overload of XLA. The check for this was now essentially moved inside the function. --- .../primitive_translators/conditions.py | 58 ++----------------- 1 file changed, 6 insertions(+), 52 deletions(-) diff --git a/src/jace/translator/primitive_translators/conditions.py b/src/jace/translator/primitive_translators/conditions.py index d291016..38ba2c2 100644 --- a/src/jace/translator/primitive_translators/conditions.py +++ b/src/jace/translator/primitive_translators/conditions.py @@ -52,41 +52,15 @@ def condition_translator( new states and will return the new SDFG state that serves as terminal state. Note: - The implementation assumes that the selector, i.e. the variables indicating - which branch should be taken is inside its bound. + This function essentially implements a C `switch` statement, however, there + is no default branch. """ if util.get_jax_var_dtype(eqn.invars[0]) is dace.bool_: - return _cond_primitive_boolean_impl( - builder=builder, - in_var_names=in_var_names, - out_var_names=out_var_names, - eqn=eqn, - eqn_state=eqn_state, - ) - return _cond_primitive_multi_switch_impl( - builder=builder, - in_var_names=in_var_names, - out_var_names=out_var_names, - eqn=eqn, - eqn_state=eqn_state, - ) - + # XLA explicitly provides this [form of the primitive](https://openxla.org/xla/operation_semantics#conditional) + # JAX however, does not seem to use it at the moment and instead forwards it + # to the integer implementation. + raise NotImplementedError("The boolean conditional primitive is not implemented.") -def _cond_primitive_multi_switch_impl( - builder: translator.JaxprTranslationBuilder, - in_var_names: Sequence[str | None], - out_var_names: Sequence[str], - eqn: jax_core.JaxprEqn, - eqn_state: dace.SDFGState, -) -> dace.SDFGState: - """ - Implements the integer version of the conditional primitive. - - For arguments see `ConditionTranslator`. - - This [version](https://openxla.org/xla/operation_semantics#conditional) is - essentially a C switch statement without a default branch. - """ # To make names in the SDFG unique we use the name of the equation state name_pattern = eqn_state.name @@ -160,23 +134,3 @@ def _cond_primitive_multi_switch_impl( # We return it, because otherwise the builder will assume that `eqn_state` was used. return terminal_state - - -def _cond_primitive_boolean_impl( - builder: translator.JaxprTranslationBuilder, # noqa: ARG001 [unused-function-argument] - in_var_names: Sequence[str | None], # noqa: ARG001 [unused-function-argument] - out_var_names: Sequence[str], # noqa: ARG001 [unused-function-argument] - eqn: jax_core.JaxprEqn, # noqa: ARG001 [unused-function-argument] - eqn_state: dace.SDFGState, # noqa: ARG001 [unused-function-argument] -) -> dace.SDFGState: - """ - Implements the case the selector of the primitive is a bool. - - XLA explicitly provides this - [form of the primitive](https://openxla.org/xla/operation_semantics#conditional) - JAX however, does not seem to use it and instead forwards it to the integer - implementation. - JaCe will not implement it and instead generate an error. - """ - # NOTE: This is mostly to notice if JAX decided to implement that branch. - raise NotImplementedError("The boolean conditional primitive is not implemented.") From 24d97fb47ee759fd601c29c9191ade23fbc0b01e Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 13 Sep 2024 15:23:18 +0200 Subject: [PATCH 3/9] Nobody needs that thest in this form. --- tests/test_subtranslator_helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_subtranslator_helper.py b/tests/test_subtranslator_helper.py index a4c4ad9..52672b0 100644 --- a/tests/test_subtranslator_helper.py +++ b/tests/test_subtranslator_helper.py @@ -75,7 +75,7 @@ def fake_add_translator(*args: Any, **kwargs: Any) -> None: # noqa: ARG001 def test_are_subtranslators_imported(): """Tests if something is inside the list of subtranslators.""" # Must be adapted if new primitives are implemented. - assert len(get_registered_primitive_translators()) == 37 + assert len(get_registered_primitive_translators()) > 0 @pytest.mark.usefixtures("no_builtin_translators") From c8b7d86e5f7a6158a7e02768ec1ac6627a538817 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 24 Sep 2024 09:00:01 +0200 Subject: [PATCH 4/9] First batch of Enrique's suggestions, but still not done. --- .../translator/jaxpr_translator_builder.py | 29 ++- .../mapped_operation_base_translator.py | 50 ++--- src/jace/translator/post_translation.py | 51 +++++- src/jace/translator/primitive_translator.py | 2 +- .../primitive_translators/__init__.py | 21 ++- .../arithmetic_logical_translators.py | 54 +++--- .../broadcast_in_dim_translator.py | 23 +-- .../concatenate_translator.py | 99 ++++------ .../primitive_translators/conditions.py | 80 ++++---- .../convert_element_type_translator.py | 31 ++-- .../primitive_translators/copy_translator.py | 99 +++++----- .../gather_translator.py | 2 +- .../primitive_translators/iota_translator.py | 2 +- .../primitive_translators/pjit_translator.py | 74 ++------ .../reshape_translator.py | 66 +++---- .../select_n_translator.py | 22 +-- .../primitive_translators/slicing.py | 171 ++++++++---------- .../squeeze_translator.py | 2 +- 18 files changed, 401 insertions(+), 477 deletions(-) diff --git a/src/jace/translator/jaxpr_translator_builder.py b/src/jace/translator/jaxpr_translator_builder.py index 9b76407..c82c277 100644 --- a/src/jace/translator/jaxpr_translator_builder.py +++ b/src/jace/translator/jaxpr_translator_builder.py @@ -37,10 +37,10 @@ class JaxprTranslationBuilder: - it lacks the special `__return` variable, - the `arg_names` parameter is not set, - for all scalar values a `Scalar` SDFG variable is used, thus they cannot - be used for return values, + be used for returning values, - for every transient there is exactly one access node that writes to it, - except the name of the array starts with `__jace_mutable_`, which can - be written to multiple times. + except if the name of the array starts with `__jace_mutable_`, in which case + it can be written to multiple times. For these reasons the SDFG is not directly usable, and further manipulations have to be performed. Especially, DaCe's validation function will fail and @@ -179,6 +179,24 @@ def append_new_state( self._ctx.terminal_state = new_state return new_state + def add_orphan_state( + self, + label: str, + ) -> dace.SDFGState: + """ + Add a new orphan state to the SDFG. + + The state is not connected to any other state, nor it is the new start state. + Except you know what you are doing you should not use this function and + instead use `self.append_new_state()`. + + Args: + label: The name of the state. + """ + if not self.is_allocated(): + raise RuntimeError("Builder is not allocated.") + return self._ctx.sdfg.add_state(label=label, is_start_block=False) + @property def arrays(self) -> Mapping[str, dace_data.Data]: """ @@ -712,7 +730,7 @@ def _propagate_memlets_in_new_states( ] while nodes_to_process: - currently_processing = nodes_to_process.pop(-1) + currently_processing = nodes_to_process.pop() if ( self.sdfg.out_degree(currently_processing) == 0 and currently_processing != new_terminal_state @@ -790,7 +808,7 @@ def __init__(self, name: str | None, jaxpr: jax_core.ClosedJaxpr) -> None: self.terminal_state = self.start_state self.jaxpr = jaxpr - def validate(self) -> bool: + def validate(self) -> None: """ Validate internal state of `self`. @@ -829,4 +847,3 @@ def validate(self) -> bool: self.sdfg, None, ) - return True diff --git a/src/jace/translator/mapped_operation_base_translator.py b/src/jace/translator/mapped_operation_base_translator.py index 9f0f402..17a5c35 100644 --- a/src/jace/translator/mapped_operation_base_translator.py +++ b/src/jace/translator/mapped_operation_base_translator.py @@ -37,11 +37,10 @@ class MappedOperationTranslatorBase(translator.PrimitiveTranslator): ``` where `__in*` are the connector names of the Tasklet and `__out` is the output connector. For problems such as this, the SDFG API provides the - `SDFGState.add_mapped_tasklet()` function, however, in most cases it can not - be directly used, for various reasons. Thus this class acts like a - convenience wrapper around it. + `SDFGState.add_mapped_tasklet()` function, however, because it is very low + level and very verbose to use, this class acts as a convenience wrapper around it. - To use this class a user has to overwrite the `write_tasklet_code()` function. + To use this class a user has to define the abstract `write_tasklet_code()` method. This function generates the entire code that should be put into the Tasklet, include the assignment to `__out`. If needed the translator will perform literal substitution on the returned code and broadcast the inputs to match @@ -51,7 +50,7 @@ class MappedOperationTranslatorBase(translator.PrimitiveTranslator): to generate custom input Memlets, such as adding an offset. Args: - primitive_name: The name of the primitive `self` should bind to. + primitive_name: The name of the primitive `self` should bind to. Note: This class will always generate a mapped Tasklet, even if a scalar is handled. @@ -78,7 +77,7 @@ def __call__( """ Create the mapped Tasklet. - The function will create the map ranges and based on the shape of the + The function will create the map ranges based on the shape of the output array. It will then call `make_input_memlets()` to get the input Memlets. After that it calls `write_tasklet_code()` to get the Tasklet code and perform literal substitution by forwarding it to @@ -88,7 +87,7 @@ def __call__( For a description of the arguments see `PrimitiveTranslatorCallable`. """ assert len(out_var_names) == 1 - if util.get_jax_var_shape(eqn.outvars[0]) != (): + if util.get_jax_var_shape(eqn.outvars[0]): tskl_ranges: list[tuple[str, str]] = [ (f"__i{dim}", f"0:{N}") for dim, N in enumerate(util.get_jax_var_shape(eqn.outvars[0])) @@ -130,20 +129,20 @@ def write_tasklet_code( eqn: jax_core.JaxprEqn, ) -> str: """ - Return the (Python) code that should be put inside the Tasklet. + Return the Python code that should be put inside the Tasklet. This also includes the assignment statement, i.e. `__out`. However, the base will do literal substitution on the returned object. Args: - tskl_ranges: List of pairs used as map parameter, first element + tskl_ranges: List of pairs used as map parameter, first element is the name iteration index of the dimension, second is its range. - in_var_names: The list of SDFG variables used as input, `None` if literal. - eqn: The equation. + in_var_names: The list of SDFG variables used as input, `None` if literal. + eqn: The equation. """ ... - def make_input_memlets( # noqa: PLR6301 # Subclasses might need them. + def make_input_memlets( # noqa: PLR6301 [no-self-use] # Subclasses might need them. self, tskl_ranges: Sequence[tuple[str, str]], in_var_names: Sequence[str | None], @@ -156,10 +155,10 @@ def make_input_memlets( # noqa: PLR6301 # Subclasses might need them. that is used to connect it to the Map entry node. Args: - tskl_ranges: List of pairs used as map parameter, first element + tskl_ranges: List of pairs used as map parameter, first element is the name iteration index of the dimension, second is its range - in_var_names: The list of SDFG variables used as input, `None` if literal. - eqn: The equation object. + in_var_names: The list of SDFG variables used as input, `None` if literal. + eqn: The equation object. """ out_shp = tuple(util.get_jax_var_shape(eqn.outvars[0])) # Shape of the output out_rank = len(out_shp) @@ -181,7 +180,11 @@ def make_input_memlets( # noqa: PLR6301 # Subclasses might need them. tskl_inputs[f"__in{i}"] = dace.Memlet.simple(in_var_name, "0") # Scalar continue - # We have to to broadcasting (combine yes and no together) + # We might have to do broadcasting. + # We ensured that input and output have the same rank (JAX is doing that + # for us). So we must do broadcasting, i.e. replicating that input + # dimension, if its size is 1. We threat the case where the output has + # size 1 in that dimension as broadcasting as well. dims_to_bcast: Sequence[int] = [dim for dim in range(out_rank) if inp_shp[dim] == 1] tskl_inputs[f"__in{i}"] = dace.Memlet.simple( in_var_name, @@ -192,23 +195,22 @@ def make_input_memlets( # noqa: PLR6301 # Subclasses might need them. ) return tskl_inputs - def literal_substitution( # noqa: PLR6301 # Subclasses might need it. + def literal_substitution( # noqa: PLR6301 [no-self-use] # Subclasses might need it. self, tskl_code: str, in_var_names: Sequence[str | None], eqn: jax_core.JaxprEqn ) -> str: """ Perform literal substitution on the proto Tasklet code `tskl_code`. Args: - tskl_code: The proto Tasklet code with literal. - in_var_names: The list of SDFG variables used as input. - eqn: The equation. + tskl_code: The proto Tasklet code with literal. + in_var_names: The list of SDFG variables used as input. + eqn: The equation. Note: It is allowed but not recommended to override this function. """ for i, in_var_name in enumerate(in_var_names): - if in_var_name is not None: - continue - t_val = util.get_jax_literal_value(eqn.invars[i]) - tskl_code = tskl_code.replace(f"__in{i}", str(t_val)) + if in_var_name is None: + t_val = util.get_jax_literal_value(eqn.invars[i]) + tskl_code = tskl_code.replace(f"__in{i}", str(t_val)) return tskl_code diff --git a/src/jace/translator/post_translation.py b/src/jace/translator/post_translation.py index 9831f35..6b27b4f 100644 --- a/src/jace/translator/post_translation.py +++ b/src/jace/translator/post_translation.py @@ -20,6 +20,7 @@ if TYPE_CHECKING: from dace.sdfg import nodes as dace_nodes + from jax import core as jax_core from jace import translator @@ -271,7 +272,8 @@ def add_nested_sdfg( will first pass it to `finalize_translation_context()` and operates on the return values. This means that `child_ctx` will be modified in place, and a copy will be added to `parent_ctx`. - It is highly recommended that `state` is empty. + It is highly recommended that `state` is empty, this makes subsequent + inlining of the nested SDFG simpler. """ if child_ctx.sdfg.free_symbols: raise NotImplementedError("Symbol Mapping is not implemented.") @@ -298,7 +300,6 @@ def add_nested_sdfg( nested_sdfg: dace_nodes.NestedSDFG = state.add_nested_sdfg( sdfg=final_child_ctx.sdfg, parent=parent_ctx.sdfg, - # Bug in DaCe must be a set. inputs=set(final_child_ctx.input_names), outputs=set(final_child_ctx.output_names), ) @@ -326,3 +327,49 @@ def add_nested_sdfg( ) return nested_sdfg + + +def promote_literals_to_constants( + builder: translator.JaxprTranslationBuilder, + var_names: Sequence[str | None], + jax_vars: Sequence[jax_core.Atom], + name_pattern: str, +) -> list[str]: + """ + Promotes all literals in `var_names` to DaCe constants and add them to the SDFG. + + The function assumes that `var_names` are the SDFG variables equivalents of + `jax_vars`, as by convention `None` indicates a literal. The function will create + a constant for each literal and return `var_names` cleared of all literals. + For naming the variables the function will use `name_pattern`. + + Args: + builder: The builder that is used for translation. + var_names: Names of the SDFG variables, `None` indicates a literal. + jax_vars: The JAX variables, in the same order than `var_names`. + name_pattern: A pattern to generate a unique name for the variables. + + Todo: + Is a constant the right idea or should we generate a symbol? + """ + promoted_var_names: list[str] = [] + for i, var_name in enumerate(var_names): + if var_name is None: + promoted_var_name = f"__const_{name_pattern}_literal_promotion_{i}" + jax_var = jax_vars[i] + promoted_jace_var = util.JaCeVar.from_atom( + jax_var=jax_var, + name=promoted_var_name, + ) + builder.add_array(promoted_jace_var) + builder.sdfg.add_constant( + promoted_var_name, + util.get_jax_literal_value(jax_var), + builder.arrays[promoted_var_name], + ) + + else: + # Already an SDFG variable, so nothing to do. + promoted_var_name = var_name + promoted_var_names.append(promoted_var_name) + return promoted_var_names diff --git a/src/jace/translator/primitive_translator.py b/src/jace/translator/primitive_translator.py index 2000731..71aa067 100644 --- a/src/jace/translator/primitive_translator.py +++ b/src/jace/translator/primitive_translator.py @@ -77,7 +77,7 @@ def __call__( Args: builder: The builder object of the translation. in_var_names: List of the names of the arrays created inside the - SDFG for the inpts or `None` in case of a literal. + SDFG for the inputs or `None` in case of a literal. out_var_names: List of the names of the arrays created inside the SDFG for the outputs. eqn: The JAX primitive that should be translated. diff --git a/src/jace/translator/primitive_translators/__init__.py b/src/jace/translator/primitive_translators/__init__.py index 9e2fec0..757743e 100644 --- a/src/jace/translator/primitive_translators/__init__.py +++ b/src/jace/translator/primitive_translators/__init__.py @@ -13,33 +13,34 @@ LogicalOperationTranslator, ) from .broadcast_in_dim_translator import BroadcastInDimTranslator -from .concatenate_translator import ConcatenateTranslator +from .concatenate_translator import concatenate_translator from .conditions import condition_translator from .convert_element_type_translator import ConvertElementTypeTranslator -from .copy_translator import CopyTranslator, DevicePutTranslator +from .copy_translator import copy_translator, device_put_translator from .gather_translator import GatherTranslator from .iota_translator import IotaTranslator -from .pjit_translator import PJITTranslator -from .reshape_translator import ReshapeTranslator +from .pjit_translator import pjit_translator +from .reshape_translator import reshape_translator from .select_n_translator import SelectNTranslator -from .slicing import SlicingTranslator +from .slicing import SlicingTranslator, dynamic_slicing_translator from .squeeze_translator import SqueezeTranslator __all__ = [ "ArithmeticOperationTranslator", "BroadcastInDimTranslator", - "ConcatenateTranslator", "ConvertElementTypeTranslator", - "CopyTranslator", - "DevicePutTranslator", "GatherTranslator", "IotaTranslator", "LogicalOperationTranslator", - "PJITTranslator", - "ReshapeTranslator", "SelectNTranslator", "SlicingTranslator", "SqueezeTranslator", + "concatenate_translator", "condition_translator", + "copy_translator", + "device_put_translator", + "dynamic_slicing_translator", + "pjit_translator", + "reshape_translator", ] diff --git a/src/jace/translator/primitive_translators/arithmetic_logical_translators.py b/src/jace/translator/primitive_translators/arithmetic_logical_translators.py index c9c0a35..667e1ac 100644 --- a/src/jace/translator/primitive_translators/arithmetic_logical_translators.py +++ b/src/jace/translator/primitive_translators/arithmetic_logical_translators.py @@ -6,7 +6,7 @@ # SPDX-License-Identifier: BSD-3-Clause """ -Module containing all translators related to arithmetic and logical operations. +Primitive translators related to all arithmetic, logical and comparison operations. Todo: - Hijack Jax to inject a proper modulo operation. @@ -31,21 +31,14 @@ class ArithmeticOperationTranslator(mapped_base.MappedOperationTranslatorBase): """ - Translator for all arithmetic operations. - - The class is derived from `MappedOperationTranslatorBase` and overwrites the - `write_tasklet_code()` function for the Tasklet code. + Translator for all arithmetic operations and comparisons. Args: - prim_name: The name of the primitive that should be handled. - tskl_tmpl: Template used for generating the Tasklet code. + prim_name: The name of the primitive that should be handled. + tskl_tmpl: Template used for generating the Tasklet code. Note: - - It does not implement the logical operations, they are implemented by - the `LogicalOperationTranslator` class. - - Despite its name this class also provides the comparison operators. - - It does not implement `mod` nor `fmod` as they are translated to some - nested `pjit` implementation by Jax for unknown reasons. + Logical and bitwise operations are implemented by `LogicalOperationTranslator`. """ def __init__(self, prim_name: str, tskl_tmpl: str) -> None: @@ -60,10 +53,7 @@ def write_tasklet_code( eqn: jax_core.JaxprEqn, ) -> str: """Returns the code for the Tasklet, with all parameters replaced.""" - tskl_code = self._tskl_tmpl - if len(eqn.params) != 0: - tskl_code = tskl_code.format(**eqn.params) - return tskl_code + return self._tskl_tmpl.format(**eqn.params) class LogicalOperationTranslator(mapped_base.MappedOperationTranslatorBase): @@ -82,15 +72,15 @@ class LogicalOperationTranslator(mapped_base.MappedOperationTranslatorBase): as `~true` in C++ is essentially `~1`, which is again `true`! Thus the `not` primitive must be handled separately. - The solution to the problem is, to introduce two templates, one used for the + The solution to the problem is to introduce two templates, one used for the bool context and one used in the integer context. This works because depending if the `logical_*()` or `bitwise_*()` functions are used the input is either of type bool or an integer. Args: - prim_name: The name of the primitive that should be handled. - int_tmpl: The template used for the integer case. - bool_tmpl: The template used for the bool case. + prim_name: The name of the primitive that should be handled. + int_tmpl: The template used for the integer case. + bool_tmpl: The template used for the bool case. Note: Since it does not make sense to single out `not` and keep the other @@ -110,12 +100,16 @@ def write_tasklet_code( in_var_names: Sequence[str | None], eqn: jax_core.JaxprEqn, ) -> str: - if all(util.get_jax_var_dtype(invar) is dace.bool_ for invar in eqn.invars): - return self._bool_tmpl - return self._int_tmpl + return ( + self._bool_tmpl + if all(util.get_jax_var_dtype(invar) is dace.bool_ for invar in eqn.invars) + else self._int_tmpl + ) -# Contains the code templates for all supported arithmetic operations. +# Maps the name of an arithmetic primitives to the code template that is used to +# generate the body of the mapped tasklet. These are used to instantiate the +# `ArithmeticOperationTranslator` objects. # fmt: off _ARITMETIC_OPERATION_TEMPLATES: Final[dict[str, str]] = { # Unary operations @@ -177,24 +171,24 @@ def write_tasklet_code( "nextafter": "__out = nextafter((__in0), (__in1))", # Ternary operations - "clamp": "__out = (__in0 if __in1 < __in0 else (__in1 if __in1 < __in2 else __in2))" + "clamp": "__out = ((__in0) if (__in1) < (__in0) else ((__in1) if (__in1) < (__in2) else (__in2)))" } -# Contains the code templates for all logical operations. -# The first one is for the integer case, the second for the bool case. +# Maps the name of a logical primitive to the two code templates (first the integer +# case and second the boolean case) used to create the body of the mapped tasklet. +# They are used to instantiate the `LogicalOperationTranslator` translators. _LOGICAL_OPERATION_TEMPLATES: Final[dict[str, tuple[str, str]]] = { "or": ("__out = (__in0) | (__in1)", "__out = (__in0) or (__in1)"), "not": ("__out = ~(__in0)", "__out = not (__in0)"), "and": ("__out = (__in0) & (__in1)", "__out = (__in0) and (__in1)"), "xor": ("__out = (__in0) ^ (__in1)", "__out = (__in0) != (__in1)"), } +# fmt: on -# Create the arithmetic translators +# Instantiate the arithmetic and logical translators from the templates. for pname, ptmpl in _ARITMETIC_OPERATION_TEMPLATES.items(): translator.register_primitive_translator(ArithmeticOperationTranslator(pname, ptmpl)) - -# Create the logical translators. for pname, (itmpl, btmpl) in _LOGICAL_OPERATION_TEMPLATES.items(): translator.register_primitive_translator(LogicalOperationTranslator(pname, itmpl, btmpl)) diff --git a/src/jace/translator/primitive_translators/broadcast_in_dim_translator.py b/src/jace/translator/primitive_translators/broadcast_in_dim_translator.py index 7f24160..964a2f6 100644 --- a/src/jace/translator/primitive_translators/broadcast_in_dim_translator.py +++ b/src/jace/translator/primitive_translators/broadcast_in_dim_translator.py @@ -5,7 +5,7 @@ # # SPDX-License-Identifier: BSD-3-Clause -"""This implements the `broadcast_in_dim` primitive.""" +"""Primitive translator for broadcasting operations.""" from __future__ import annotations @@ -28,9 +28,8 @@ class BroadcastInDimTranslator(mapped_base.MappedOperationTranslatorBase): """ Implements the `broadcast_in_dim` primitive. - The primitive is implemented through the `MappedOperationTranslatorBase` base. - Essentially it creates a copy, but also creates special Memlets that replicate - the content of the input. + Essentially creates a copy tasklet, however, the memlets are made in such a + way that some dimensions are replicated. """ def __init__(self) -> None: @@ -52,16 +51,14 @@ def make_input_memlets( in_var_names: Sequence[str | None], eqn: jax_core.JaxprEqn, ) -> dict[str, dace.Memlet]: - if in_var_names[0] is None: + if in_var_names[0] is None: # Broadcast a literal (scalar) to a matrix. return {} - return { - "__in0": dace.Memlet.simple( - in_var_names[0], - ", ".join(tskl_ranges[bdim][0] for bdim in eqn.params["broadcast_dimensions"]) - if eqn.params["broadcast_dimensions"] - else "0", - ) - } + subset_str = ( + ", ".join(tskl_ranges[bdim][0] for bdim in eqn.params["broadcast_dimensions"]) + if eqn.params["broadcast_dimensions"] + else "0", + ) + return {"__in0": dace.Memlet.simple(in_var_names[0], subset_str)} translator.register_primitive_translator(BroadcastInDimTranslator()) diff --git a/src/jace/translator/primitive_translators/concatenate_translator.py b/src/jace/translator/primitive_translators/concatenate_translator.py index e8bd144..1b5f679 100644 --- a/src/jace/translator/primitive_translators/concatenate_translator.py +++ b/src/jace/translator/primitive_translators/concatenate_translator.py @@ -5,14 +5,13 @@ # # SPDX-License-Identifier: BSD-3-Clause -"""Implements the concatenation primitive.""" +"""Primitive translator for concatenation operations.""" from __future__ import annotations from typing import TYPE_CHECKING import dace -from typing_extensions import override from jace import translator, util @@ -23,65 +22,45 @@ from jax import core as jax_core -class ConcatenateTranslator(translator.PrimitiveTranslator): +@translator.register_primitive_translator() +@translator.make_primitive_translator("concatenate") +def concatenate_translator( + builder: translator.JaxprTranslationBuilder, # noqa: ARG001 # Required by the interface. + in_var_names: Sequence[str | None], + out_var_names: Sequence[str], + eqn: jax_core.JaxprEqn, + eqn_state: dace.SDFGState, +) -> None: """ Implements the `concatenate` primitive. - It is implemented by a series of map that writes to the same access node. - It is probably the largest stretch of "written once" in the entire core. + Each source array is copied by its own map, but all maps write to the same + access node. """ - - @property - def primitive(self) -> str: # noqa: D102 # No docstring needed. - return "concatenate" - - @override - def __call__( - self, - builder: translator.JaxprTranslationBuilder, - in_var_names: Sequence[str | None], - out_var_names: Sequence[str], - eqn: jax_core.JaxprEqn, - eqn_state: dace.SDFGState, - ) -> None: - if any(in_var_name is None for in_var_name in in_var_names): - raise NotImplementedError("Concatenate: No literal inputs supported.") - - # Dimension along we concatenate. - cat_dim = eqn.params["dimension"] - - # Offset counter for write back. - already_copied = 0 - - # This is the access node we use for the output - # Is inside a dict for input to `add_mapped_tasklet()`. - output_nodes = {out_var_names[0]: eqn_state.add_write(out_var_names[0])} - - # Now going over each input and copying the input in the correct location - # of the output array. - for i, in_var_name in enumerate(in_var_names): - input_shape = util.get_jax_var_shape(eqn.invars[i]) - - tskl_range = [(f"__dim{d}", f"0:{dim_size}") for d, dim_size in enumerate(input_shape)] - tskl_input_access = [it_var for it_var, _ in tskl_range] - - tskl_output_access = tskl_input_access.copy() - tskl_output_access[cat_dim] = f"{tskl_output_access[cat_dim]} + {already_copied}" - - eqn_state.add_mapped_tasklet( - f"_concatenate_{out_var_names[0]}_{in_var_name}", - map_ranges=tskl_range, - inputs={"__in": dace.Memlet.simple(in_var_name, ", ".join(tskl_input_access))}, - code="__out = __in", - outputs={ - "__out": dace.Memlet.simple(out_var_names[0], ",".join(tskl_output_access)) - }, - output_nodes=output_nodes, - external_edges=True, - ) - - # Update the counter that we have copied - already_copied += input_shape[cat_dim] - - -_ = translator.register_primitive_translator(ConcatenateTranslator()) + if any(in_var_name is None for in_var_name in in_var_names): + raise NotImplementedError("Concatenate: No literal inputs supported.") + + # Access node that is used by all maps. + output_nodes = {out_var_names[0]: eqn_state.add_write(out_var_names[0])} + + cat_dim = eqn.params["dimension"] + cat_offset = 0 + for i, in_var_name in enumerate(in_var_names): + input_shape = util.get_jax_var_shape(eqn.invars[i]) + + tskl_range = [(f"__dim{d}", f"0:{dim_size}") for d, dim_size in enumerate(input_shape)] + tskl_input_access = [it_var for it_var, _ in tskl_range] + + tskl_output_access = tskl_input_access.copy() + tskl_output_access[cat_dim] = f"{tskl_output_access[cat_dim]} + {cat_offset}" + + eqn_state.add_mapped_tasklet( + f"_concatenate_{out_var_names[0]}_{in_var_name}", + map_ranges=tskl_range, + inputs={"__in": dace.Memlet.simple(in_var_name, ", ".join(tskl_input_access))}, + code="__out = __in", + outputs={"__out": dace.Memlet.simple(out_var_names[0], ",".join(tskl_output_access))}, + output_nodes=output_nodes, + external_edges=True, + ) + cat_offset += input_shape[cat_dim] diff --git a/src/jace/translator/primitive_translators/conditions.py b/src/jace/translator/primitive_translators/conditions.py index 38ba2c2..4f7363d 100644 --- a/src/jace/translator/primitive_translators/conditions.py +++ b/src/jace/translator/primitive_translators/conditions.py @@ -5,7 +5,7 @@ # # SPDX-License-Identifier: BSD-3-Clause -"""Implements all conditions that are supported in JAX.""" +"""Primitive translator for condition operations, i.e. scalar `if` and `switch`.""" from __future__ import annotations @@ -16,7 +16,6 @@ from jace import translator, util from jace.translator import post_translation as ptranslation -from jace.translator.primitive_translators import pjit_translator as pjit if TYPE_CHECKING: @@ -33,10 +32,12 @@ def condition_translator( eqn_state: dace.SDFGState, ) -> dace.SDFGState: """ - Implements the translation of the `cond` primitive, i.e. a scalar if. + Implements the translation of scalar conditional branches. - XLA, JAX' backend, supports two versions, one in which the selector, i.e. the - variable indicating which branch should be executed is an integer or a boolean. + This translator handles both `jax.lax.cond()` and `jax.lax.switch()` cases. + The sub expression of the branches are each translated into a separate nested + SDFG, each located in their own state. These state are then connected to the + joint state which is returned. Args: builder: The builder object of the translation. @@ -47,68 +48,62 @@ def condition_translator( eqn: The equation that should be translated. eqn_state: State into which the nested SDFG should be constructed. - Returns: - Because of the nature of this primitive, the translator has to construct - new states and will return the new SDFG state that serves as terminal state. - - Note: - This function essentially implements a C `switch` statement, however, there - is no default branch. + Notes: + According to the JAX documentation (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html) + the selector is clamped. But according to XLA (https://openxla.org/xla/operation_semantics#conditional) + an out of range selector uses the last branch. JaCe conforms to JAX semantic. + After this function the terminal state of the `builder` is unspecific. """ if util.get_jax_var_dtype(eqn.invars[0]) is dace.bool_: - # XLA explicitly provides this [form of the primitive](https://openxla.org/xla/operation_semantics#conditional) - # JAX however, does not seem to use it at the moment and instead forwards it - # to the integer implementation. + # XLA explicitly provides a binary form of the primitive + # (https://openxla.org/xla/operation_semantics#conditional) JAX however, + # does not seem to use it at the moment and instead forwards it to the + # integer implementation. raise NotImplementedError("The boolean conditional primitive is not implemented.") - # To make names in the SDFG unique we use the name of the equation state + # To make names in the (nested) SDFG unique we use the name of the equation state name_pattern = eqn_state.name - # Promote all inputs to the branches to variables, this are all except the first - # which is the selection variable. - branch_input_variable_names: list[str] = pjit._promote_literals_to_constants( + # To avoid special cases promote all symbols to constants. + branch_input_variable_names: list[str] = ptranslation.promote_literals_to_constants( builder=builder, var_names=in_var_names[1:], jax_vars=eqn.invars[1:], name_pattern=name_pattern, ) + # expressions of the branches. + branches: list[jax_core.ClosedJaxpr] = eqn.params["branches"] + + # Make sure that the selection variable is a DaCe symbol. if in_var_names[0] is None: - # The selection variable is a literal, so we will now pretend it is a symbol. - # This also means that we do not need a state transition to promote the - # variable to a symbol. - selection_symbol = str(util.get_jax_literal_value(eqn.invars[0])) + literal_selection_value = str(util.get_jax_literal_value(eqn.invars[0])) + selection_symbol = f"max({len(branches)}, min(0, {literal_selection_value}))" selection_state = eqn_state else: - # The selection variable is an input. - # For the implementation of the condition we need to promote the selection - # variable to a symbol, for which we need an interstate edge. - # As a side effect it will update the terminal state. + # Promotion of a scalar to a symbol through a state transition. selection_variable_name = in_var_names[0] selection_symbol = f"{selection_variable_name}_symb" - selection_state = builder.append_new_state( label=f"{name_pattern}_fork", - assignments={selection_symbol: selection_variable_name}, + assignments={ + selection_symbol: f"max({len(branches)}, min(0, {selection_variable_name}[0]))" + }, prev_state=eqn_state, ) - # Now iterate through all branches, translate them and integrate them - # for each branch we will generate a dedicated state. branch_states: list[dace.SDFGState] = [] - for i, branch_jaxpr in enumerate(eqn.params["branches"]): + for i, branch_jaxpr in enumerate(branches): branch_pattern = f"{name_pattern}_{{}}_branch_{i}" branch_ctx = builder.translate_jaxpr(jaxpr=branch_jaxpr, name=branch_pattern.format("sdfg")) - # This will update the terminal state only the first time. + # This will update the terminal state only for the first branch branch_state = builder.append_new_state( label=branch_pattern.format("state"), condition=f"{selection_symbol} == {i}", prev_state=selection_state, ) - - # Integrating it. ptranslation.add_nested_sdfg( state=branch_state, child_ctx=branch_ctx, @@ -118,19 +113,12 @@ def condition_translator( ) branch_states.append(branch_state) - # Now we have to generate a join state that will serve as new terminal state. - # We append it to the first branch state, which is the current terminal state. - assert builder._terminal_sdfg_state is branch_states[0] - terminal_state = builder.append_new_state( - label=f"{name_pattern}_join", - prev_state=branch_states[0], - ) - for branch_state in branch_states[1:]: + join_state = builder.add_orphan_state(f"{name_pattern}__join_state") + for branch_state in branch_states: builder.sdfg.add_edge( branch_state, - terminal_state, + join_state, dace.sdfg.InterstateEdge(), ) - # We return it, because otherwise the builder will assume that `eqn_state` was used. - return terminal_state + return join_state diff --git a/src/jace/translator/primitive_translators/convert_element_type_translator.py b/src/jace/translator/primitive_translators/convert_element_type_translator.py index ee05a2a..118f4e3 100644 --- a/src/jace/translator/primitive_translators/convert_element_type_translator.py +++ b/src/jace/translator/primitive_translators/convert_element_type_translator.py @@ -5,7 +5,7 @@ # # SPDX-License-Identifier: BSD-3-Clause -"""Implements the translator for the `convert_element_type` primitive.""" +"""Primitive translator for type casting operations.""" from __future__ import annotations @@ -28,14 +28,12 @@ class ConvertElementTypeTranslator(mapped_base.MappedOperationTranslatorBase): """ Implements the `convert_element_type` primitive. - The primitive will expand to a "copy Map", however, the Tasklet will not - simply copy the input to the output, but also perform type conversion. - However, in cases where the input type is the same as the output type, - the Tasklet will just be a copy Tasklet, that can then be removed by DaCe. + The primitive is implemented as a copy operation. However, the tasklet body + will perform the type conversion operation. Note: - This translator ignores the `new_dtype` and `weak_type` parameters of - the equation and only performs the casting based on the type of the fields. + The type to cast to id inferred from the output variable and the `new_dtype` + parameter of the equation is ignored. """ def __init__(self) -> None: @@ -56,20 +54,19 @@ def write_tasklet_code( out_dtype = util.get_jax_var_dtype(eqn.outvars[0]).type out_dtype_s: str = out_dtype.__name__ - # This is the base of the template that we use for conversion. You should notice - # that the Tasklet `__out = __in0` will fail, see commit `f5aabc3` of the - # prototype. Thus we have to do it in this way. - conv_code = "__in0" - if in_dtype == out_dtype: - # For some reason Jax sometimes adds conversions where no are needed. In - # these cases we explicitly create a copy Tasklet, which is trivial and can - # be removed by DaCe. + # JAX sometimes adds conversions which are not needed. In these cases + # we perform a copy. # TODO(phimuell): Create a Memlet instead. - return f"__out = {conv_code}" + return "__out = __in0" + + # A simple copy tasklet `__out = __in0` and rely on the implicit type + # conversion of the C++ compiler, is not enough. Due to a bug in DaCe + # (see https://github.com/spcl/dace/issues/1665) this conversion might be + # lost, thus we have to perform the conversion explicitly in the tasklet. + conv_code = "__in0" if in_dtype_s.startswith("bool"): - # Interestingly `__out = int(__in0)` will not work. conv_code = f"(1 if {conv_code} else 0)" if out_dtype_s.startswith("bool"): conv_code = f"dace.bool_({conv_code})" diff --git a/src/jace/translator/primitive_translators/copy_translator.py b/src/jace/translator/primitive_translators/copy_translator.py index 6de5ab9..5cc0d3c 100644 --- a/src/jace/translator/primitive_translators/copy_translator.py +++ b/src/jace/translator/primitive_translators/copy_translator.py @@ -5,14 +5,13 @@ # # SPDX-License-Identifier: BSD-3-Clause -"""Implements the translator related to data movement.""" +"""Primitive translators related to data movement operations.""" from __future__ import annotations from typing import TYPE_CHECKING import dace -from typing_extensions import override from jace import translator @@ -23,70 +22,56 @@ from jax import core as jax_core -class CopyTranslator(translator.PrimitiveTranslator): +@translator.register_primitive_translator() +@translator.make_primitive_translator("copy") +def copy_translator( + builder: translator.JaxprTranslationBuilder, + in_var_names: Sequence[str | None], + out_var_names: Sequence[str], + eqn: jax_core.JaxprEqn, # noqa: ARG001 # Required by the interface. + eqn_state: dace.SDFGState, +) -> None: """ Implements the `copy` primitive. - The translator is implemented by using a Memlet. + Todo: + Investigate if operation should expand to a map. """ - - @property - def primitive(self) -> str: # noqa: D102 # No docstring needed. - return "copy" - - def __call__( # noqa: D102 # No docstring - self, - builder: translator.JaxprTranslationBuilder, - in_var_names: Sequence[str | None], - out_var_names: Sequence[str], - eqn: jax_core.JaxprEqn, # noqa: ARG002 - eqn_state: dace.SDFGState, - ) -> None: - eqn_state.add_nedge( - eqn_state.add_read(in_var_names[0]), - eqn_state.add_write(out_var_names[0]), - dace.Memlet.from_array( - in_var_names[0], - builder.arrays[in_var_names[0]], # type: ignore[index] # Guaranteed to be a string - ), - ) - - -class DevicePutTranslator(CopyTranslator): + eqn_state.add_nedge( + eqn_state.add_read(in_var_names[0]), + eqn_state.add_write(out_var_names[0]), + dace.Memlet.from_array( + in_var_names[0], + builder.arrays[in_var_names[0]], # type: ignore[index] # Guaranteed to be a string + ), + ) + + +@translator.register_primitive_translator() +@translator.make_primitive_translator("device_put") +def device_put_translator( + builder: translator.JaxprTranslationBuilder, + in_var_names: Sequence[str | None], + out_var_names: Sequence[str], + eqn: jax_core.JaxprEqn, + eqn_state: dace.SDFGState, +) -> None: """ Implements the `device_put` primitive. - In Jax this primitive is used to copy data between the host and the device, + In JAX this primitive is used to copy data between the host and the device, in DaCe Memlets can do this. However, because of the way JaCe operates, at least in the beginning a computation is either fully on the host or on the device this copy will essentially perform a copying. """ - - @property - def primitive(self) -> str: # noqa: D102 # No docstring - return "device_put" - - @override - def __call__( # No docstring - self, - builder: translator.JaxprTranslationBuilder, - in_var_names: Sequence[str | None], - out_var_names: Sequence[str], - eqn: jax_core.JaxprEqn, - eqn_state: dace.SDFGState, - ) -> None: - if not (eqn.params["device"] is None and eqn.params["src"] is None): - raise NotImplementedError( - f"Can only copy on the host, but not from {eqn.params['src']} to {eqn.params['device']}." - ) - return super().__call__( - builder=builder, - in_var_names=in_var_names, - out_var_names=out_var_names, - eqn=eqn, - eqn_state=eqn_state, + if not (eqn.params["device"] is None and eqn.params["src"] is None): + raise NotImplementedError( + f"Can only copy on the host, but not from {eqn.params['src']} to {eqn.params['device']}." ) - - -_ = translator.register_primitive_translator(CopyTranslator()) -_ = translator.register_primitive_translator(DevicePutTranslator()) + copy_translator( + builder=builder, + in_var_names=in_var_names, + out_var_names=out_var_names, + eqn=eqn, + eqn_state=eqn_state, + ) diff --git a/src/jace/translator/primitive_translators/gather_translator.py b/src/jace/translator/primitive_translators/gather_translator.py index 343ee15..4b58e70 100644 --- a/src/jace/translator/primitive_translators/gather_translator.py +++ b/src/jace/translator/primitive_translators/gather_translator.py @@ -5,7 +5,7 @@ # # SPDX-License-Identifier: BSD-3-Clause -"""Implements the translator for the `gather` primitive.""" +"""Primitive translator for indexing operations.""" from __future__ import annotations diff --git a/src/jace/translator/primitive_translators/iota_translator.py b/src/jace/translator/primitive_translators/iota_translator.py index ce0d99f..035caf7 100644 --- a/src/jace/translator/primitive_translators/iota_translator.py +++ b/src/jace/translator/primitive_translators/iota_translator.py @@ -5,7 +5,7 @@ # # SPDX-License-Identifier: BSD-3-Clause -"""This implements the `iota` primitive.""" +"""Primitive translator for the `iota` primitive.""" from __future__ import annotations diff --git a/src/jace/translator/primitive_translators/pjit_translator.py b/src/jace/translator/primitive_translators/pjit_translator.py index 59bfd7e..b3b9d97 100644 --- a/src/jace/translator/primitive_translators/pjit_translator.py +++ b/src/jace/translator/primitive_translators/pjit_translator.py @@ -5,7 +5,7 @@ # # SPDX-License-Identifier: BSD-3-Clause -"""Implements the `pjit` translator, i.e. nested Jaxpr expressions.""" +"""Primitive translator related handling nested Jaxpr operations.""" from __future__ import annotations @@ -15,7 +15,7 @@ from jax._src import sharding_impls as jax_sharding # noqa: PLC2701 [import-private-name] -from jace import translator, util +from jace import translator from jace.translator import post_translation as ptranslation @@ -24,55 +24,9 @@ from jax._src import core as jax_core -def _promote_literals_to_constants( - builder: translator.JaxprTranslationBuilder, - var_names: Sequence[str | None], - jax_vars: Sequence[jax_core.Atom], - name_pattern: str, -) -> list[str]: - """ - Promotes all literals in `var_names` to DaCe constants and add them to the SDFG. - - The function assumes that `var_names` are the SDFG variables equivalents of - `jax_vars`, as by convention `None` indicates a literal. The function will create - a constant for each literal and return `var_names` cleared of all literals. - For naming the variables the function will use `name_pattern`. - - Args: - builder: The builder that is used for translation. - var_names: Names of the SDFG variables, `None` indicates a literal. - jax_vars: The JAX variables, in the same order than `var_names`. - name_pattern: A pattern to generate a unique name for the variables. - - Todo: - Is a constant the right idea or should we generate a symbol? - """ - promoted_var_names: list[str] = [] - for i, var_name in enumerate(var_names): - if var_name is None: - promoted_var_name = f"__const_{name_pattern}_literal_promotion_{i}" - jax_var = jax_vars[i] - promoted_jace_var = util.JaCeVar.from_atom( - jax_var=jax_var, - name=promoted_var_name, - ) - builder.add_array(promoted_jace_var) - builder.sdfg.add_constant( - promoted_var_name, - util.get_jax_literal_value(jax_var), - builder.arrays[promoted_var_name], - ) - - else: - # Already an SDFG variable, so nothing to do. - promoted_var_name = var_name - promoted_var_names.append(promoted_var_name) - return promoted_var_names - - @translator.register_primitive_translator() @translator.make_primitive_translator("pjit") -def PJITTranslator( # noqa: N802 [invalid-function-name] +def pjit_translator( builder: translator.JaxprTranslationBuilder, in_var_names: Sequence[str | None], out_var_names: Sequence[str], @@ -82,13 +36,9 @@ def PJITTranslator( # noqa: N802 [invalid-function-name] """ Implements the `pjit` translator that handles nested Jaxpr. - `pjit` primitives in JAX represents nested calls, for example the body of a scan - is inside a nested Jaxpr. However, `pjit` is used to indicate that a computation - should be done on the device or on sharded memory. - - However, due to the current state and working of JaCe, this aspect is essentially - ignored and the computation is always inlined. - + `pjit` primitives in JAX represents nested calls, for example the branches of a + conditional are nested Jaxpr. However, in JAX `pjit` is also used to indicate that + a computation should be done on the device or on sharded memory. In case an input is a literal the translator will create a constant for it. Args: @@ -99,6 +49,10 @@ def PJITTranslator( # noqa: N802 [invalid-function-name] inside the parent SDFG. eqn: The equation that contains the `pjit` primitive. eqn_state: State into which the nested SDFG should be constructed. + + Note: + The translator ignores the `donated_invars`, the `keep_unused` and the + `inline` parameter and let's DaCe handle it. """ params: dict[str, Any] = eqn.params nested_jaxpr: jax_core.ClosedJaxpr = params["jaxpr"] @@ -116,22 +70,18 @@ def PJITTranslator( # noqa: N802 [invalid-function-name] # TODO(phimuell): Controlflow region and name pjit_name = params["name"] - # TODO(phimuell): Controlflow region and name - # They will introduce a feature like that to address them in optimizations. - pjit_name = params["name"] - # Name in SDFG must be unique, thus we mangle it, furthermore, we have to clean it. sdfg_name = f"pjit_{re.subn('[^a-zA-Z0-9_]', '_', pjit_name)[0]}__{'_'.join(out_var_names)}" # Ensure that all inputs are SDFG variables - final_input_names = _promote_literals_to_constants( + final_input_names = ptranslation.promote_literals_to_constants( builder=builder, var_names=in_var_names, jax_vars=eqn.invars, name_pattern=sdfg_name, ) - # Now get the translated SDFG. + # Translate the nested expression nested_context: translator.TranslationContext = builder.translate_jaxpr( jaxpr=nested_jaxpr, name=sdfg_name, diff --git a/src/jace/translator/primitive_translators/reshape_translator.py b/src/jace/translator/primitive_translators/reshape_translator.py index 241cc94..1bcbc5a 100644 --- a/src/jace/translator/primitive_translators/reshape_translator.py +++ b/src/jace/translator/primitive_translators/reshape_translator.py @@ -5,14 +5,13 @@ # # SPDX-License-Identifier: BSD-3-Clause -"""Implements the translator for the `reshape` primitive.""" +"""Primitive translator for reshaping operations.""" from __future__ import annotations from typing import TYPE_CHECKING import dace -from typing_extensions import override from jace import translator, util @@ -23,45 +22,30 @@ from jax import core as jax_core -class ReshapeTranslator(translator.PrimitiveTranslator): +@translator.register_primitive_translator() +@translator.make_primitive_translator("reshape") +def reshape_translator( + builder: translator.JaxprTranslationBuilder, # noqa: ARG001 # Required by the interface. + in_var_names: Sequence[str | None], + out_var_names: Sequence[str], + eqn: jax_core.JaxprEqn, + eqn_state: dace.SDFGState, +) -> None: """ - Implements the `reshape` primitive. + Implements the `reshape` primitive, through a memlet. - The current implementation uses a Memlet for this and essentially acts as - an optimization barrier. Furthermore the Jax primitive also has the optional - `dimensions` parameters which allows to permute the input, this is not - supported. + Note: + The optional `dimensions` parameters which allows to permute the input + is not supported. """ - - @property - def primitive(self) -> str: # noqa: D102 # No docstring needed. - return "reshape" - - @override - def __call__( - self, - builder: translator.JaxprTranslationBuilder, - in_var_names: Sequence[str | None], - out_var_names: Sequence[str], - eqn: jax_core.JaxprEqn, - eqn_state: dace.SDFGState, - ) -> None: - """ - Performs the reshaping. - - Currently a copy using a Memlet is performed. - """ - if eqn.params["dimensions"] is not None: - raise NotImplementedError("Currently 'dimensions' must be 'None'.") - eqn_state.add_nedge( - eqn_state.add_read(in_var_names[0]), - eqn_state.add_write(out_var_names[0]), - dace.Memlet( - data=in_var_names[0], - subset=", ".join(f"0:{size}" for size in util.get_jax_var_shape(eqn.invars[0])), - other_subset=", ".join(f"0:{size}" for size in eqn.params["new_sizes"]), - ), - ) - - -translator.register_primitive_translator(ReshapeTranslator()) + if eqn.params["dimensions"] is not None: + raise NotImplementedError("Currently 'dimensions' must be 'None'.") + eqn_state.add_nedge( + eqn_state.add_read(in_var_names[0]), + eqn_state.add_write(out_var_names[0]), + dace.Memlet( + data=in_var_names[0], + subset=", ".join(f"0:{size}" for size in util.get_jax_var_shape(eqn.invars[0])), + other_subset=", ".join(f"0:{size}" for size in eqn.params["new_sizes"]), + ), + ) diff --git a/src/jace/translator/primitive_translators/select_n_translator.py b/src/jace/translator/primitive_translators/select_n_translator.py index 51b27b3..0b9a0d1 100644 --- a/src/jace/translator/primitive_translators/select_n_translator.py +++ b/src/jace/translator/primitive_translators/select_n_translator.py @@ -5,7 +5,7 @@ # # SPDX-License-Identifier: BSD-3-Clause -"""Implements `select_n`.""" +"""Primitive translator for select operations, i.e. generalized `np.where()`.""" from __future__ import annotations @@ -29,16 +29,13 @@ class SelectNTranslator(mapped_base.MappedOperationTranslatorBase): Implements the `select_n` primitive. The `select_n` primitive is a generalization of `np.where`, that can take an - arbitrary number of branches, which are selected by an integer predicate. + arbitrary number of cases, which are selected by an integer predicate. The behaviour is undefined if the predicate is out of bound. Note: For a better understanding this function renames its input connectors. The first one, which is the predicate, is renamed to `__cond` and the others are renamed again to `__in{i}`, starting with zero. - - Todo: - Implement the primitive as a nested SDFG. """ def __init__(self) -> None: @@ -51,11 +48,9 @@ def write_tasklet_code( in_var_names: Sequence[str | None], eqn: jax_core.JaxprEqn, ) -> str: - if len(in_var_names) == 3: # noqa: PLR2004 # `3` is not magic. - # This order is correct, since `False` is interpreted as `0`, which means - # the first case. DaCe seems to have some problems with bools and integer - # casting around, so we handle the bool case explicitly here. - # See also `ConvertElementTypeTranslator`. + if len(in_var_names) == 3: # noqa: PLR2004 # Ternary conditional expression. + # The order is correct, since `False` is interpreted as `0`, + # which means "the first case". return "__out = __in1 if __cond else __in0" return "\n".join( @@ -84,10 +79,9 @@ def literal_substitution( ) -> str: assert in_var_names[0] # Condition can never be a literal. for i, in_var_name in enumerate(in_var_names[1:]): - if in_var_name is not None: - continue - t_val = util.get_jax_literal_value(eqn.invars[i + 1]) - tskl_code = tskl_code.replace(f"__in{i}", str(t_val)) + if in_var_name is None: + t_val = util.get_jax_literal_value(eqn.invars[i + 1]) + tskl_code = tskl_code.replace(f"__in{i}", str(t_val)) return tskl_code diff --git a/src/jace/translator/primitive_translators/slicing.py b/src/jace/translator/primitive_translators/slicing.py index ae4f167..c53c3d0 100644 --- a/src/jace/translator/primitive_translators/slicing.py +++ b/src/jace/translator/primitive_translators/slicing.py @@ -5,7 +5,7 @@ # # SPDX-License-Identifier: BSD-3-Clause -"""Implements slicing.""" +"""Primitive translators for slicing operations.""" from __future__ import annotations @@ -28,12 +28,13 @@ class SlicingTranslator(mapped_base.MappedOperationTranslatorBase): """ Implements the `slice` primitive. - This is the classical slicing operation which extracts a fixed sized window - from a fixed initial position. The slicing is implemented using a partial copy. + The `slice` primitive represents the static case of slicing, i.e. a fixed + window starting from a fixed starting point. + The slicing is implemented by performing a partial copy. Note: Slices are essentially optimization barriers as they can not be fused - with Maps before them. + with Maps _before_ them. """ def __init__(self) -> None: @@ -55,7 +56,6 @@ def make_input_memlets( in_var_names: Sequence[str | None], eqn: jax_core.JaxprEqn, ) -> dict[str, dace.Memlet]: - """We have to add the offsets to the Memlet accesses.""" strides: Sequence[int] = ( ((1,) * len(tskl_ranges)) if eqn.params["strides"] is None else eqn.params["strides"] ) @@ -64,76 +64,71 @@ def make_input_memlets( "__in0": dace.Memlet.simple( in_var_names[0], ", ".join( - f"{start_index} + {it_idx} * {stride}" + f"{start_index} + ({it_idx} * {stride})" for (it_idx, _), start_index, stride in zip(tskl_ranges, start_indices, strides) ), ) } -class DynamicSlicingTranslator(translator.PrimitiveTranslator): +@translator.register_primitive_translator() +@translator.make_primitive_translator("dynamic_slice") +def dynamic_slicing_translator( + builder: translator.JaxprTranslationBuilder, + in_var_names: Sequence[str | None], + out_var_names: Sequence[str], + eqn: jax_core.JaxprEqn, + eqn_state: dace.SDFGState, +) -> None: """ Implements the `dynamic_slice` primitive. - [Dynamic slicing](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dynamic_slice.html) - performs a slicing of a _fixed_ window, but the start of the window is - not fix, instead it is passed by variables. Furthermore, (as it is in Jax), - if the window would overrun the start indexes are adjusted. + Dynamic slicing (see: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dynamic_slice.html) + performs a slicing of a _fixed_ window, but the start of the window is defined + through some input variables. Furthermore, if the window would overrun the + start indexes are adjusted. Todo: - Prevent that the modified start indexes are promoted to symbols, to ensure mergability. """ - - @property - def primitive(self) -> str: # noqa: D102 # No docstring needed. - return "dynamic_slice" - - @override - def __call__( - self, - builder: translator.JaxprTranslationBuilder, - in_var_names: Sequence[str | None], - out_var_names: Sequence[str], - eqn: jax_core.JaxprEqn, - eqn_state: dace.SDFGState, - ) -> None: - assert in_var_names[0] - assert len(in_var_names) == len(util.get_jax_var_shape(eqn.invars[0])) + 1 - - # This is the sizes of the slice window. - window_sizes: Sequence[int] = eqn.params["slice_sizes"] - - # Maps the variable name, that stores the start index of the window in one - # dimensions to the access node, that holds the value. The variable name - # is also used as dynamic range offset. - # Only present if the index is not a literal. - in_access: dict[str, dace.nodes.AccessNode] = {} - - # Name of the variable from where we get the start index of the window - # or the value itself, if it is a literal; in the order of the dimension. - # If the value is `None` then the literal was not yet processed. - window_start_indices: list[str | None] = list(in_var_names[1:]) - - # We will always adapt the start indexes and not check if it is needed. - for dim, (window_start_index, dim_size, window_size) in enumerate( - zip(window_start_indices, util.get_jax_var_shape(eqn.invars[0]), window_sizes) - ): - if window_start_index is None: - # Jax does not adjust the literals on its own - raw_window_start = int(util.get_jax_literal_value(eqn.invars[dim + 1])) # type: ignore[arg-type] # type confusion - adjusted_window_start = min(dim_size, raw_window_start + window_size) - window_size - window_start_indices[dim] = str(adjusted_window_start) - continue - - # We do not use a symbol for the start of the window but a Tasklet, as - # a symbol would need an interstage edge, which is an optimization barrier. + assert in_var_names[0] + assert len(in_var_names) == len(util.get_jax_var_shape(eqn.invars[0])) + 1 + + window_sizes: Sequence[int] = eqn.params["slice_sizes"] + + # Maps the variable name, that stores the _adjusted_ start index of the window + # of a dimension to the access node that holds the value. Needed to ensure the + # correct order of computation. + in_access: dict[str, dace.nodes.AccessNode] = {} + + # Name of the variables (DaCe arrays) from where we get the start index of the + # window or the value itself if it is a literal (`None` means not yet processed). + # The first input argument is always the array we slice from. + window_start_indices: list[str | None] = list(in_var_names[1:]) + + for dim, (window_start_index, dim_size, window_size) in enumerate( + zip(window_start_indices, util.get_jax_var_shape(eqn.invars[0]), window_sizes) + ): + if window_start_index is None: + # The start is a literal value. + # Jax does not adjust the literals on its own so we have to do it. + raw_window_start = int(util.get_jax_literal_value(eqn.invars[dim + 1])) # type: ignore[arg-type] # type confusion + adjusted_window_start = min(dim_size, raw_window_start + window_size) - window_size + window_start_indices[dim] = str(adjusted_window_start) + + else: tasklet = dace.nodes.Tasklet( label=f"adjustment_of_slice_start_{window_start_index}_for_{out_var_names[0]}", inputs={"unadjusted_start_idx": None}, outputs={"adjusted_start_idx": None}, code=f"adjusted_start_idx = min(unadjusted_start_idx + {window_size}, {dim_size}) - {window_size}", ) + # Name of the variable holding the (adjusted) start of the window. + # It is important that this name is also used for the dynamic map range + # symbols created below. This prevents some errors if DaCe promotes them + # to symbols and does not handle the DMR correctly. + # (see https://github.com/spcl/dace/issues/1665) new_start_idx_var_name = builder.add_array( eqn.invars[dim + 1], name_prefix="__jace_adapted_start_idx_" ) @@ -153,46 +148,40 @@ def __call__( None, dace.Memlet.simple(new_start_idx_var_name, "0"), ) - # Update the name of the start index, and store the access - # node for later use. window_start_indices[dim] = new_start_idx_var_name in_access[new_start_idx_var_name] = new_start_idx_acc - tskl_ranges: list[tuple[str, str]] = [ - (f"__i{dim}", f"0:{N}") for dim, N in enumerate(util.get_jax_var_shape(eqn.outvars[0])) - ] - - memlet_accesses: list[str] = [] - - for (it_var, _), offset_symbol_name in zip(tskl_ranges, window_start_indices): - assert offset_symbol_name is not None - memlet_accesses.append(f"{it_var} + {offset_symbol_name}") - - tskl_input = dace.Memlet.simple(in_var_names[0], ", ".join(memlet_accesses)) - tskl_output = dace.Memlet.simple( - out_var_names[0], ", ".join(name for name, _ in tskl_ranges) + tskl_ranges: list[tuple[str, str]] = [ + (f"__i{dim}", f"0:{N}") for dim, N in enumerate(util.get_jax_var_shape(eqn.outvars[0])) + ] + tskl_input = dace.Memlet.simple( + in_var_names[0], + ", ".join( + f"{it_var} + {offset_symbol_name}" + for (it_var, _), offset_symbol_name in zip(tskl_ranges, window_start_indices) + ), + ) + tskl_output = dace.Memlet.simple(out_var_names[0], ", ".join(name for name, _ in tskl_ranges)) + _, map_entry, _ = eqn_state.add_mapped_tasklet( + name=f"dynamic_slice_{out_var_names[0]}", + map_ranges=tskl_ranges, + inputs={"__in": tskl_input}, + code="__out = __in", + outputs={"__out": tskl_output}, + external_edges=True, + ) + + # Create the dynamic ranges, i.e. read the start indexes for the window + # from variable and create symbols out of it, without an interstate edge. + for window_start_index_name, windows_start_access_node in in_access.items(): + eqn_state.add_edge( + windows_start_access_node, + None, + map_entry, + window_start_index_name, + dace.Memlet.simple(window_start_index_name, "0"), ) - _, map_entry, _ = eqn_state.add_mapped_tasklet( - name=f"{self.primitive}_{out_var_names[0]}", - map_ranges=tskl_ranges, - inputs={"__in": tskl_input}, - code="__out = __in", - outputs={"__out": tskl_output}, - external_edges=True, - ) - - # Creating the inputs for the dynamic map ranges. We have to use the same - # access nodes as above, to ensure a single order of computation. - for window_start_index_name, windows_start_access_node in in_access.items(): - eqn_state.add_edge( - windows_start_access_node, - None, - map_entry, - window_start_index_name, - dace.Memlet.simple(window_start_index_name, "0"), - ) - map_entry.add_in_connector(window_start_index_name) + map_entry.add_in_connector(window_start_index_name) translator.register_primitive_translator(SlicingTranslator()) -translator.register_primitive_translator(DynamicSlicingTranslator()) diff --git a/src/jace/translator/primitive_translators/squeeze_translator.py b/src/jace/translator/primitive_translators/squeeze_translator.py index de6f1f4..dbaa548 100644 --- a/src/jace/translator/primitive_translators/squeeze_translator.py +++ b/src/jace/translator/primitive_translators/squeeze_translator.py @@ -5,7 +5,7 @@ # # SPDX-License-Identifier: BSD-3-Clause -"""Implements the `squeeze` primitive.""" +"""Primitive translator for squeezing (the removal of size 1 dimensions) operations.""" from __future__ import annotations From cb600d397aaa8d12a9bb82467457a522aab5d4c2 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 24 Sep 2024 12:37:20 +0200 Subject: [PATCH 5/9] Refactored the gather translator. It is now better confiugured. --- .../primitive_translators/__init__.py | 4 +- .../gather_translator.py | 333 +++++++++--------- 2 files changed, 160 insertions(+), 177 deletions(-) diff --git a/src/jace/translator/primitive_translators/__init__.py b/src/jace/translator/primitive_translators/__init__.py index 757743e..f019964 100644 --- a/src/jace/translator/primitive_translators/__init__.py +++ b/src/jace/translator/primitive_translators/__init__.py @@ -17,7 +17,7 @@ from .conditions import condition_translator from .convert_element_type_translator import ConvertElementTypeTranslator from .copy_translator import copy_translator, device_put_translator -from .gather_translator import GatherTranslator +from .gather_translator import gather_translator from .iota_translator import IotaTranslator from .pjit_translator import pjit_translator from .reshape_translator import reshape_translator @@ -30,7 +30,6 @@ "ArithmeticOperationTranslator", "BroadcastInDimTranslator", "ConvertElementTypeTranslator", - "GatherTranslator", "IotaTranslator", "LogicalOperationTranslator", "SelectNTranslator", @@ -41,6 +40,7 @@ "copy_translator", "device_put_translator", "dynamic_slicing_translator", + "gather_translator", "pjit_translator", "reshape_translator", ] diff --git a/src/jace/translator/primitive_translators/gather_translator.py b/src/jace/translator/primitive_translators/gather_translator.py index 4b58e70..8d0f60f 100644 --- a/src/jace/translator/primitive_translators/gather_translator.py +++ b/src/jace/translator/primitive_translators/gather_translator.py @@ -13,7 +13,6 @@ import dace from jax import lax as jax_lax -from typing_extensions import override from jace import translator, util @@ -24,188 +23,172 @@ from jax import core as jax_core -class GatherTranslator(translator.PrimitiveTranslator): +@translator.register_primitive_translator() +@translator.make_primitive_translator("gather") +def gather_translator( # noqa: PLR0914 [too-many-locals] # Can not reduce any further. + builder: translator.JaxprTranslationBuilder, # noqa: ARG001 # Required by the interface. + in_var_names: Sequence[str | None], + out_var_names: Sequence[str], + eqn: jax_core.JaxprEqn, + eqn_state: dace.SDFGState, +) -> None: """ - Garther Translator. + Implements the `gather` primitive. - The gather operation extracts patches of a certain size, known as `slice_size`, - from an array, called source or input array. Where these patches starts is - given by another array, the index array. + These primitive is used to implement the `array.at[...].get()` access. In the + end the primitive extracts patches/windows of a certain size, known as + `slice_size`, from an array, which is called source or input array. The start + points of these windows are given by another array, the so called index array. + + Args: + builder: The builder object that is active. + in_var_names: The names of the input variables, the first array is + assumed as source array and the second is the index array. + out_var_names: The names of the output variables. + eqn: The equation to translate. + eqn_state: The state in which we put the extraction. See Also: https://www.tensorflow.org/xla/operation_semantics#gather https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.gather.html """ - - @property - def primitive(self) -> str: # noqa: D102 # No docstring needed. - return "gather" - - @override - def __call__( # noqa: PLR0914, PLR0915 # Just ported from the prototype, cleanup postponed. - self, - builder: translator.JaxprTranslationBuilder, - in_var_names: Sequence[str | None], - out_var_names: Sequence[str], - eqn: jax_core.JaxprEqn, - eqn_state: dace.SDFGState, - ) -> None: - """ - Performs the gather operation. - - Args: - builder: The builder object that is active. - in_var_names: The names of the input variables, the first array is - assumed as source array and the second is the index array. - out_var_names: The names of the output variables. - eqn: The equation to translate. - eqn_state: The state in which we put the extraction. - """ - assert len(eqn.invars) == 2 # noqa: PLR2004 # XLA supports more inputs. - - out_name = out_var_names[0] - out_shape = util.get_jax_var_shape(eqn.outvars[0]) - - src_name = in_var_names[0] - src_shape = util.get_jax_var_shape(eqn.invars[0]) - - idx_name = in_var_names[1] - idx_shape = util.get_jax_var_shape(eqn.invars[1]) - - dimension_numbers = eqn.params["dimension_numbers"] - offset_dims: Sequence[int] = dimension_numbers.offset_dims - collapsed_slice_dims: Sequence[int] = dimension_numbers.collapsed_slice_dims - start_index_map: Sequence[int] = dimension_numbers.start_index_map - slice_sizes: Sequence[int] = eqn.params["slice_sizes"] - mode: jax_lax.GatherScatterMode = eqn.params["mode"] - assert len(start_index_map) == idx_shape[-1] - - if mode != jax_lax.GatherScatterMode.PROMISE_IN_BOUNDS: - raise NotImplementedError(f"The mode {mode} is not implemented.") - - # Over these dimensions the copy of the patches goes. - batch_dims = tuple(d for d in range(len(out_shape)) if d not in offset_dims) - - # Every batch dimension is associated with one dimension of of the index - # array, but there is always one dimension more in the index array. This - # dimension contains the start indexes of the slice, if there is only - # one index that should be loaded is not strictly necessary, but Jax - # (currently adds) it implicitly, probably to make life easier. - if (len(batch_dims) + 1) != len(idx_shape): - raise ValueError( - f"Expected that the index array has {len(batch_dims) + 1} dimensions, but it had {len(idx_shape)}." - ) - - # These are the dimensions (of the input) for which a map index is created. - # Note that we exclude collapsed dimensions here. - src_dim_with_map_idx = tuple( - dim for dim in range(len(slice_sizes)) if dim not in collapsed_slice_dims + out_name = out_var_names[0] + out_shape = util.get_jax_var_shape(eqn.outvars[0]) + src_name = in_var_names[0] + src_shape = util.get_jax_var_shape(eqn.invars[0]) + idx_name = in_var_names[1] + idx_shape = util.get_jax_var_shape(eqn.invars[1]) + dimension_numbers = eqn.params["dimension_numbers"] + + if eqn.params["mode"] != jax_lax.GatherScatterMode.PROMISE_IN_BOUNDS: + raise NotImplementedError(f"The mode {eqn.params['mode']} is not implemented.") + + # This is the size of the slice window that is copied. Its length equal the rank + # of the source array, dimensions that should not be copied are listed in + # `collapsed_slice_dims`. + slice_sizes: Sequence[int] = eqn.params["slice_sizes"] + collapsed_slice_dims: Sequence[int] = dimension_numbers.collapsed_slice_dims + not_collapsed_slice_dims = tuple( + dim for dim in range(len(slice_sizes)) if dim not in collapsed_slice_dims + ) + assert len(slice_sizes) == len(src_shape) + + # The batch dimensions are used to iterate through the slice windows, thus access + # the index array, with the exception of the last dimension, see below. + # NOTE: In pure XLA this last dimension might not be present, however, JAX + # adds it and our implementation relies on it. + batch_dims = tuple(d for d in range(len(out_shape)) if d not in dimension_numbers.offset_dims) + if (len(batch_dims) + 1) != len(idx_shape): + raise ValueError( + f"Expected that the index array has {len(batch_dims) + 1} dimensions, but it had {len(idx_shape)}." ) - assert len(src_dim_with_map_idx) == len(offset_dims) - - # The final map is the composition of two loops. The first map iterates over - # the index array, except the last dimension, and is used to "copy" the - # different patches from the source to the output array. These map parameters - # follow the pattern `__i{out_name}_gather{bd}`, where `bd` is a batch - # dimension. These variables are used to access the index array. - # The second loop performs the actual copy of the slices. For these - # the variables `__i{i}` is used were, these are known as offset - # dimensions. - # What is a bit difficult, that the actual access/dereferencing of the source - # array is done within the tasklet. - - # Access pattern of the source array _within_ the tasklet. - src_access_pattern: list[str] = [] - - # These are the map ranges for the coying of the slicing. - slice_map_ranges: list[tuple[str, str]] = [] - - # Compute the access pattern within the tasklet. - # As a side effect we also compute the map ranges, but only for the slices. - for dim, slice_size in enumerate(slice_sizes): - # Order is important! - if dim not in start_index_map: - # This dimension is fully copied - slice_map_ranges.append((f"__i{dim}", f"0:{slice_size}")) - src_access_pattern.append(slice_map_ranges[-1][0]) - assert dim in src_dim_with_map_idx - assert slice_size == src_shape[dim] - - elif dim in collapsed_slice_dims: - # This dimension is only partially copied, however, since the - # dimension is collapsed, only a single element is copied that - # comes from the index array. - src_access_pattern.append(f"__gather_{dim}") - - else: - # This dimension is partially copied, but is _not colapsed_, we need - # a map index to copy the range. However, there is also an offset - # that is involved from copying. - slice_map_ranges.append((f"__i{dim}", f"0:{slice_size}")) - src_access_pattern.append(f"__gather_{dim} + {slice_map_ranges[-1][0]}") - assert dim in src_dim_with_map_idx - assert slice_size <= src_shape[dim] - - # These are the map variable that go over the index array. - patch_loop_vars = tuple(f"__i{out_name}_gather{bd}" for bd in batch_dims) - patch_map_ranges = [ - (map_param, f"0:{patch_loop_bound}") - for map_param, patch_loop_bound in zip(patch_loop_vars, idx_shape[:-1]) - ] - - # Creating the input memlet that allows us to access the source array from - # inside the tasklet and make it accessible through the name `__arr`. At - # this point it is not possible to tell where we access, because we are - # missing a index variables, they will only be accessible inside the - # tasklet (see below), however, we know that we will access only one - # element from the array. - tasklet_inputs: dict[str, dace.Memlet] = { - "__arr": dace.Memlet.simple( - data=src_name, - subset_str=", ".join(f"0:{size}" for size in src_shape), - num_accesses=1, + + # The last dimension is special, as it contains the actual start point for the + # slice window when the dimension is only partially copied. The `start_index_map` + # associates each position element in the last dimension with the corresponding + # dimension of the source array. + start_index_map: Sequence[int] = dimension_numbers.start_index_map + assert len(start_index_map) == idx_shape[-1] + + # The final map has two parts. The first part iterates through all the slice + # windows that are given through the index array (except last dimension). + # If a dimension is not fully copied then the start index of the window is + # given through the elements of the last dimensions of the index array. + # Map variables that are used for this use the pattern `__i{out_name}_gather{bd}`. + # The second loop is used to copy the slice windows themselves, their map + # variables follow the pattern `__i{i}`. + + # Because the offsets of the slice window (which are given by the elements of + # the last dimension of the index array) are variables and not symbols, it + # can not be included in the memlets. Instead we generate an tasklet that + # performs an indirect access and get all elements of the last dimension of the + # index array (with the names `__gather_{dim}`), together with the full source + # array as input. + + # Access pattern of the source array _inside_ the tasklet. + src_access_pattern: list[str] = [] + + # The ranges of the second part implicit loop (the one that copies the windows). + inside_window_map_ranges: list[tuple[str, str]] = [] + + for dim, slice_size in enumerate(slice_sizes): + # Order is important! + if dim not in start_index_map: + # This dimension is fully copied + inside_window_map_ranges.append((f"__i{dim}", f"0:{slice_size}")) + src_access_pattern.append(inside_window_map_ranges[-1][0]) + assert dim in not_collapsed_slice_dims + assert dim not in batch_dims + + elif dim in collapsed_slice_dims: + # This dimension is only partially copied, but because it is collapsed, + # only a single element is copied. Thus the offset is only given by the + # index array. + src_access_pattern.append(f"__gather_{dim}") + assert dim in batch_dims + + else: + # This dimension is partially copied, but _not colapsed_. This creates a + # slice index and the offset (of a single element) is given by the static + # start of the window and the current position inside of the window. + inside_window_map_ranges.append((f"__i{dim}", f"0:{slice_size}")) + src_access_pattern.append(f"__gather_{dim} + {inside_window_map_ranges[-1][0]}") + assert dim in batch_dims + assert dim in not_collapsed_slice_dims + + # These are the map variables that are associated to the first implicit loop (the + # iteration over the index array, excluding the last dimension). + batch_map_ranges = [ + (f"__i{out_name}_gather{batch_dim}", f"0:{batch_loop_bound}") + for batch_dim, batch_loop_bound in zip(batch_dims, idx_shape[:-1]) + ] + assert len(batch_map_ranges) + len(inside_window_map_ranges) == len(out_shape) + + tasklet_inputs: dict[str, dace.Memlet] = {} + + # We need to pass the full array into the tasklet, however, we know that we + # will read only one element. + tasklet_inputs["__arr"] = dace.Memlet.simple( + data=src_name, + subset_str=", ".join(f"0:{size}" for size in src_shape), + num_accesses=1, + ) + + # The static offset of the slice window, which is given through the elements + # of the last dimensions of the index array, for every element in that dimension + # there is an input. + for i, dim in enumerate(start_index_map): + tasklet_inputs[f"__gather_{dim}"] = dace.Memlet.simple( + data=idx_name, + subset_str=( + ", ".join(batch_loop_var for batch_loop_var, _ in batch_map_ranges) + f", {i}" ), - } - - # Now we are creating the memlets that access the index array. - for i, dim in enumerate(start_index_map): - tasklet_inputs[f"__gather_{dim}"] = dace.Memlet.simple( - data=idx_name, subset_str=(", ".join(patch_loop_vars) + f", {i}") - ) - - # The tasklet code. - tasklet_code = "__out = __arr[" + ", ".join(src_access_pattern) + "]" - - # Now we generate the output memlet. - outpt_access_pattern: list[str] = [] - dim_counter = 0 - for dim in range(len(out_shape)): - if dim in batch_dims: - # This is a batch dimension, thus a loop variable is used for it. - patch_loop_var = patch_loop_vars[batch_dims.index(dim)] - outpt_access_pattern.append(str(patch_loop_var)) - - else: - # This is a dimension for copying the slices. - assert dim_counter <= len(src_dim_with_map_idx) - associated_map_idx = src_dim_with_map_idx[dim_counter] - dim_counter += 1 - outpt_access_pattern.append(f"__i{associated_map_idx}") - assert dim_counter == len(src_dim_with_map_idx) - - tasklet_outputs: dict[str, dace.Memlet] = { - "__out": dace.Memlet.simple(data=out_name, subset_str=", ".join(outpt_access_pattern)) - } - assert len(patch_map_ranges) + len(slice_map_ranges) == len(out_shape) - - eqn_state.add_mapped_tasklet( - name=f"_gather_map_{out_name}", - map_ranges=patch_map_ranges + slice_map_ranges, - inputs=tasklet_inputs, - code=tasklet_code, - outputs=tasklet_outputs, - external_edges=True, ) - -_ = translator.register_primitive_translator(GatherTranslator()) + # The output shape is given by the combination of the non collapsed slice sizes + # and the index array (without the last dimension) with some permutation. + # Note that the relative order of slice sizes can not be changed, but they + # might be interleaved with the batch variables. + output_memlet_pattern: list[str] = [] + dim_counter = 0 + for dim in range(len(out_shape)): + if dim in batch_dims: + batch_loop_var = batch_map_ranges[batch_dims.index(dim)][0] + output_memlet_pattern.append(str(batch_loop_var)) + + else: + associated_map_idx = not_collapsed_slice_dims[dim_counter] + dim_counter += 1 + output_memlet_pattern.append(f"__i{associated_map_idx}") + assert dim_counter == len(not_collapsed_slice_dims) + + eqn_state.add_mapped_tasklet( + name=f"_gather_map_{out_name}", + map_ranges=batch_map_ranges + inside_window_map_ranges, + inputs=tasklet_inputs, + code="__out = __arr[" + ", ".join(src_access_pattern) + "]", + outputs={ + "__out": dace.Memlet.simple(data=out_name, subset_str=", ".join(output_memlet_pattern)) + }, + external_edges=True, + ) From c29fc0dbf0377e67baea172d282e6b1da0995c7a Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 24 Sep 2024 14:15:02 +0200 Subject: [PATCH 6/9] Some more corrections. Now let's test if it works. --- .../mapped_operation_base_translator.py | 50 +++++++-------- .../arithmetic_logical_translators.py | 16 ++--- .../concatenate_translator.py | 11 +++- .../primitive_translators/conditions.py | 25 ++++---- .../convert_element_type_translator.py | 2 +- .../primitive_translators/copy_translator.py | 27 ++++++-- .../gather_translator.py | 64 ++++++++++--------- .../primitive_translators/pjit_translator.py | 17 +++-- .../reshape_translator.py | 18 +++++- .../select_n_translator.py | 2 +- 10 files changed, 136 insertions(+), 96 deletions(-) diff --git a/src/jace/translator/mapped_operation_base_translator.py b/src/jace/translator/mapped_operation_base_translator.py index 17a5c35..508ad13 100644 --- a/src/jace/translator/mapped_operation_base_translator.py +++ b/src/jace/translator/mapped_operation_base_translator.py @@ -5,7 +5,7 @@ # # SPDX-License-Identifier: BSD-3-Clause -"""Module containing all translators related to arithmetic logical operations.""" +"""Module implementing the `MappedOperationTranslatorBase` helper class.""" from __future__ import annotations @@ -37,8 +37,9 @@ class MappedOperationTranslatorBase(translator.PrimitiveTranslator): ``` where `__in*` are the connector names of the Tasklet and `__out` is the output connector. For problems such as this, the SDFG API provides the - `SDFGState.add_mapped_tasklet()` function, however, because it is very low - level and very verbose to use, this class acts as a convenience wrapper around it. + `SDFGState.add_mapped_tasklet()` function. However, because the function + operates on a very low level and is very verbose to use, this class acts + as a convenience wrapper around it. To use this class a user has to define the abstract `write_tasklet_code()` method. This function generates the entire code that should be put into the Tasklet, @@ -160,8 +161,8 @@ def make_input_memlets( # noqa: PLR6301 [no-self-use] # Subclasses might need in_var_names: The list of SDFG variables used as input, `None` if literal. eqn: The equation object. """ - out_shp = tuple(util.get_jax_var_shape(eqn.outvars[0])) # Shape of the output - out_rank = len(out_shp) + out_shape = tuple(util.get_jax_var_shape(eqn.outvars[0])) + out_rank = len(out_shape) if any(len(util.get_jax_var_shape(invar)) not in {0, out_rank} for invar in eqn.invars): raise NotImplementedError( f"'MappedOperationTranslatorBase' Inputs must have the same rank as the output! " @@ -170,29 +171,26 @@ def make_input_memlets( # noqa: PLR6301 [no-self-use] # Subclasses might need # Now we will generate the input Memlets. tskl_inputs: dict[str, dace.Memlet] = {} - for i, (in_var_name, inp_shp) in enumerate( + for i, (in_var_name, in_shape) in enumerate( zip(in_var_names, (util.get_jax_var_shape(invar) for invar in eqn.invars)) ): - if in_var_name is None: # Input is a literal: No Memlet needed - continue - - if inp_shp == (): # Scalars - tskl_inputs[f"__in{i}"] = dace.Memlet.simple(in_var_name, "0") # Scalar - continue - - # We might have to do broadcasting. - # We ensured that input and output have the same rank (JAX is doing that - # for us). So we must do broadcasting, i.e. replicating that input - # dimension, if its size is 1. We threat the case where the output has - # size 1 in that dimension as broadcasting as well. - dims_to_bcast: Sequence[int] = [dim for dim in range(out_rank) if inp_shp[dim] == 1] - tskl_inputs[f"__in{i}"] = dace.Memlet.simple( - in_var_name, - ", ".join( - ("0" if i in dims_to_bcast else it_var) - for i, (it_var, _) in enumerate(tskl_ranges) - ), - ) + if in_var_name is None: + pass + + elif in_shape == (): + tskl_inputs[f"__in{i}"] = dace.Memlet.simple(in_var_name, "0") + + else: + dims_to_bcast = [ + dim for dim in range(out_rank) if in_shape[dim] == 1 and out_shape[dim] != 1 + ] + tskl_inputs[f"__in{i}"] = dace.Memlet.simple( + in_var_name, + ", ".join( + ("0" if i in dims_to_bcast else it_var) + for i, (it_var, _) in enumerate(tskl_ranges) + ), + ) return tskl_inputs def literal_substitution( # noqa: PLR6301 [no-self-use] # Subclasses might need it. diff --git a/src/jace/translator/primitive_translators/arithmetic_logical_translators.py b/src/jace/translator/primitive_translators/arithmetic_logical_translators.py index 667e1ac..7cf321f 100644 --- a/src/jace/translator/primitive_translators/arithmetic_logical_translators.py +++ b/src/jace/translator/primitive_translators/arithmetic_logical_translators.py @@ -100,14 +100,12 @@ def write_tasklet_code( in_var_names: Sequence[str | None], eqn: jax_core.JaxprEqn, ) -> str: - return ( - self._bool_tmpl - if all(util.get_jax_var_dtype(invar) is dace.bool_ for invar in eqn.invars) - else self._int_tmpl - ) + if all(util.get_jax_var_dtype(invar) is dace.bool_ for invar in eqn.invars): + return self._bool_tmpl + return self._int_tmpl -# Maps the name of an arithmetic primitives to the code template that is used to +# Maps the name of an arithmetic JAX primitive to the code template that is used to # generate the body of the mapped tasklet. These are used to instantiate the # `ArithmeticOperationTranslator` objects. # fmt: off @@ -175,9 +173,9 @@ def write_tasklet_code( } -# Maps the name of a logical primitive to the two code templates (first the integer -# case and second the boolean case) used to create the body of the mapped tasklet. -# They are used to instantiate the `LogicalOperationTranslator` translators. +# Maps the name of a logical primitive to the two code templates, first the integer +# case and second the boolean case, that are used to create the body of the mapped +# tasklet. They are used to instantiate the `LogicalOperationTranslator` translators. _LOGICAL_OPERATION_TEMPLATES: Final[dict[str, tuple[str, str]]] = { "or": ("__out = (__in0) | (__in1)", "__out = (__in0) or (__in1)"), "not": ("__out = ~(__in0)", "__out = not (__in0)"), diff --git a/src/jace/translator/primitive_translators/concatenate_translator.py b/src/jace/translator/primitive_translators/concatenate_translator.py index 1b5f679..b327bde 100644 --- a/src/jace/translator/primitive_translators/concatenate_translator.py +++ b/src/jace/translator/primitive_translators/concatenate_translator.py @@ -25,7 +25,7 @@ @translator.register_primitive_translator() @translator.make_primitive_translator("concatenate") def concatenate_translator( - builder: translator.JaxprTranslationBuilder, # noqa: ARG001 # Required by the interface. + builder: translator.JaxprTranslationBuilder, # noqa: ARG001 [unused-function-argument] # Required by the interface. in_var_names: Sequence[str | None], out_var_names: Sequence[str], eqn: jax_core.JaxprEqn, @@ -36,6 +36,15 @@ def concatenate_translator( Each source array is copied by its own map, but all maps write to the same access node. + + Args: + builder: The builder object of the translation; unused. + in_var_names: The SDFG variables used an input arguments in order as they + should be concatenated. + out_var_names: Names of SDFG variables that should be used as outputs. + eqn: The equation that should be translated, the concatenation dimensions + is read from the `dimension` parameter. + eqn_state: State into which the nested SDFG should be constructed. """ if any(in_var_name is None for in_var_name in in_var_names): raise NotImplementedError("Concatenate: No literal inputs supported.") diff --git a/src/jace/translator/primitive_translators/conditions.py b/src/jace/translator/primitive_translators/conditions.py index 4f7363d..6e37a7a 100644 --- a/src/jace/translator/primitive_translators/conditions.py +++ b/src/jace/translator/primitive_translators/conditions.py @@ -41,18 +41,19 @@ def condition_translator( Args: builder: The builder object of the translation. - in_var_names: The SDFG variables used an input arguments. First is the index, - the variable that selects the branch, the remaining ones are passed as - inputs to the branches. + in_var_names: The SDFG variables used an input arguments. First is the + selection variable. The remaining ones are passed to the branches as + inputs. out_var_names: Names of SDFG variables that should be used as outputs. eqn: The equation that should be translated. eqn_state: State into which the nested SDFG should be constructed. Notes: - According to the JAX documentation (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html) - the selector is clamped. But according to XLA (https://openxla.org/xla/operation_semantics#conditional) - an out of range selector uses the last branch. JaCe conforms to JAX semantic. - After this function the terminal state of the `builder` is unspecific. + - According to the JAX documentation (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html) + the selector is clamped. But according to XLA (https://openxla.org/xla/operation_semantics#conditional) + an out of range selector uses the last branch. JaCe conforms to JAX + semantic. + - After this function the terminal state of the `builder` is unspecific. """ if util.get_jax_var_dtype(eqn.invars[0]) is dace.bool_: # XLA explicitly provides a binary form of the primitive @@ -61,7 +62,7 @@ def condition_translator( # integer implementation. raise NotImplementedError("The boolean conditional primitive is not implemented.") - # To make names in the (nested) SDFG unique we use the name of the equation state + # Used as prefix to give all additional states/variables a unique name. name_pattern = eqn_state.name # To avoid special cases promote all symbols to constants. @@ -80,9 +81,7 @@ def condition_translator( literal_selection_value = str(util.get_jax_literal_value(eqn.invars[0])) selection_symbol = f"max({len(branches)}, min(0, {literal_selection_value}))" selection_state = eqn_state - else: - # Promotion of a scalar to a symbol through a state transition. selection_variable_name = in_var_names[0] selection_symbol = f"{selection_variable_name}_symb" selection_state = builder.append_new_state( @@ -93,12 +92,15 @@ def condition_translator( prev_state=eqn_state, ) + # Translate the subbranches, the branches are all connected from `selection_state`. branch_states: list[dace.SDFGState] = [] for i, branch_jaxpr in enumerate(branches): branch_pattern = f"{name_pattern}_{{}}_branch_{i}" branch_ctx = builder.translate_jaxpr(jaxpr=branch_jaxpr, name=branch_pattern.format("sdfg")) - # This will update the terminal state only for the first branch + # The first time it is called it will update the builder's terminal state + # but since we will return the join state it will be updated later. But + # until then the terminal state of the builder is invalid. branch_state = builder.append_new_state( label=branch_pattern.format("state"), condition=f"{selection_symbol} == {i}", @@ -113,6 +115,7 @@ def condition_translator( ) branch_states.append(branch_state) + # Connect all branch states to the join state join_state = builder.add_orphan_state(f"{name_pattern}__join_state") for branch_state in branch_states: builder.sdfg.add_edge( diff --git a/src/jace/translator/primitive_translators/convert_element_type_translator.py b/src/jace/translator/primitive_translators/convert_element_type_translator.py index 118f4e3..a9f179c 100644 --- a/src/jace/translator/primitive_translators/convert_element_type_translator.py +++ b/src/jace/translator/primitive_translators/convert_element_type_translator.py @@ -56,7 +56,7 @@ def write_tasklet_code( if in_dtype == out_dtype: # JAX sometimes adds conversions which are not needed. In these cases - # we perform a copy. + # make a copy out of it. # TODO(phimuell): Create a Memlet instead. return "__out = __in0" diff --git a/src/jace/translator/primitive_translators/copy_translator.py b/src/jace/translator/primitive_translators/copy_translator.py index 5cc0d3c..9e0d2d1 100644 --- a/src/jace/translator/primitive_translators/copy_translator.py +++ b/src/jace/translator/primitive_translators/copy_translator.py @@ -28,21 +28,31 @@ def copy_translator( builder: translator.JaxprTranslationBuilder, in_var_names: Sequence[str | None], out_var_names: Sequence[str], - eqn: jax_core.JaxprEqn, # noqa: ARG001 # Required by the interface. + eqn: jax_core.JaxprEqn, # noqa: ARG001 [unused-function-argument] # Required by the interface. eqn_state: dace.SDFGState, ) -> None: """ Implements the `copy` primitive. + The copy is implemented by creating a memlet between the source and destination. + + Args: + builder: The builder object of the translation. + in_var_names: The SDFG variable that acts as source. + out_var_names: The SDFG variable that acts as destination of the copy. + eqn: The equation that should be translated; unused. + eqn_state: State into which the nested SDFG should be constructed. + Todo: Investigate if operation should expand to a map. """ + assert in_var_names[0] is not None eqn_state.add_nedge( eqn_state.add_read(in_var_names[0]), eqn_state.add_write(out_var_names[0]), dace.Memlet.from_array( in_var_names[0], - builder.arrays[in_var_names[0]], # type: ignore[index] # Guaranteed to be a string + builder.arrays[in_var_names[0]], ), ) @@ -60,9 +70,16 @@ def device_put_translator( Implements the `device_put` primitive. In JAX this primitive is used to copy data between the host and the device, - in DaCe Memlets can do this. However, because of the way JaCe operates, at - least in the beginning a computation is either fully on the host or on the - device this copy will essentially perform a copying. + in DaCe only memlets can do this. However, because of the way JaCe (currently) + operates (a computation is either fully on the host or on GPU), the `device_put` + primitive essentially decays to a copy. + + Args: + builder: The builder object of the translation. + in_var_names: The SDFG variable that acts as source. + out_var_names: The SDFG variable that acts as destination of the copy. + eqn: The equation that should be translated. + eqn_state: State into which the nested SDFG should be constructed. """ if not (eqn.params["device"] is None and eqn.params["src"] is None): raise NotImplementedError( diff --git a/src/jace/translator/primitive_translators/gather_translator.py b/src/jace/translator/primitive_translators/gather_translator.py index 8d0f60f..daacb56 100644 --- a/src/jace/translator/primitive_translators/gather_translator.py +++ b/src/jace/translator/primitive_translators/gather_translator.py @@ -26,7 +26,7 @@ @translator.register_primitive_translator() @translator.make_primitive_translator("gather") def gather_translator( # noqa: PLR0914 [too-many-locals] # Can not reduce any further. - builder: translator.JaxprTranslationBuilder, # noqa: ARG001 # Required by the interface. + builder: translator.JaxprTranslationBuilder, # noqa: ARG001 [unused-function-argument] # Required by the interface. in_var_names: Sequence[str | None], out_var_names: Sequence[str], eqn: jax_core.JaxprEqn, @@ -64,8 +64,8 @@ def gather_translator( # noqa: PLR0914 [too-many-locals] # Can not reduce any raise NotImplementedError(f"The mode {eqn.params['mode']} is not implemented.") # This is the size of the slice window that is copied. Its length equal the rank - # of the source array, dimensions that should not be copied are listed in - # `collapsed_slice_dims`. + # of the source array, dimensions that are excluded from copying are listed + # in `collapsed_slice_dims`. slice_sizes: Sequence[int] = eqn.params["slice_sizes"] collapsed_slice_dims: Sequence[int] = dimension_numbers.collapsed_slice_dims not_collapsed_slice_dims = tuple( @@ -73,34 +73,38 @@ def gather_translator( # noqa: PLR0914 [too-many-locals] # Can not reduce any ) assert len(slice_sizes) == len(src_shape) - # The batch dimensions are used to iterate through the slice windows, thus access - # the index array, with the exception of the last dimension, see below. - # NOTE: In pure XLA this last dimension might not be present, however, JAX - # adds it and our implementation relies on it. + # The batch dimensions are used to iterate through the different slice windows + # (not inside them) thus they access the index array, with the exception of the + # last dimension, see below. + # NOTE: In pure XLA this last dimension is in certain cases optional, however, + # JAX adds it and our implementation relies on it. batch_dims = tuple(d for d in range(len(out_shape)) if d not in dimension_numbers.offset_dims) if (len(batch_dims) + 1) != len(idx_shape): raise ValueError( f"Expected that the index array has {len(batch_dims) + 1} dimensions, but it had {len(idx_shape)}." ) - # The last dimension is special, as it contains the actual start point for the - # slice window when the dimension is only partially copied. The `start_index_map` - # associates each position element in the last dimension with the corresponding + # The last dimension of the index array is special, as it contains the actual + # start point for the slice windows when the dimension is only partially copied. + # Thus the last dimension must be seen as a list of start indexes and the other + # dimensions are used to enumerate the slice windows. The `start_index_map` + # associates each position in the last dimension with the corresponding # dimension of the source array. start_index_map: Sequence[int] = dimension_numbers.start_index_map assert len(start_index_map) == idx_shape[-1] - # The final map has two parts. The first part iterates through all the slice - # windows that are given through the index array (except last dimension). - # If a dimension is not fully copied then the start index of the window is - # given through the elements of the last dimensions of the index array. - # Map variables that are used for this use the pattern `__i{out_name}_gather{bd}`. - # The second loop is used to copy the slice windows themselves, their map - # variables follow the pattern `__i{i}`. + # The iteration variable of the final map can be divided into two parts or + # categories. The first part iterates through all the slice windows that are + # given through the index array. If a dimension is not fully copied then the + # start index of the window is given through the elements of the last dimensions + # of the index array. Map variables that are used for this use the pattern + # `__i{out_name}_gather{bd}`. The second kind of variables are used to copy the + # content of the slice windows themselves, these map variables follow the + # pattern `__i{i}`. # Because the offsets of the slice window (which are given by the elements of - # the last dimension of the index array) are variables and not symbols, it - # can not be included in the memlets. Instead we generate an tasklet that + # the last dimension of the index array) are variables and not symbols, they + # can not be included in the memlets. Instead we generate a tasklet that # performs an indirect access and get all elements of the last dimension of the # index array (with the names `__gather_{dim}`), together with the full source # array as input. @@ -108,7 +112,8 @@ def gather_translator( # noqa: PLR0914 [too-many-locals] # Can not reduce any # Access pattern of the source array _inside_ the tasklet. src_access_pattern: list[str] = [] - # The ranges of the second part implicit loop (the one that copies the windows). + # The map variables and their ranges of the second part implicit loop; the one + # that copy the content inside the window. inside_window_map_ranges: list[tuple[str, str]] = [] for dim, slice_size in enumerate(slice_sizes): @@ -123,14 +128,14 @@ def gather_translator( # noqa: PLR0914 [too-many-locals] # Can not reduce any elif dim in collapsed_slice_dims: # This dimension is only partially copied, but because it is collapsed, # only a single element is copied. Thus the offset is only given by the - # index array. + # what we read from the index array. src_access_pattern.append(f"__gather_{dim}") assert dim in batch_dims else: - # This dimension is partially copied, but _not colapsed_. This creates a - # slice index and the offset (of a single element) is given by the static - # start of the window and the current position inside of the window. + # This dimension is partially copied, but _not colapsed_. This the element + # that is read depends on the (static) offset of this window and the + # current position within the slicing window. inside_window_map_ranges.append((f"__i{dim}", f"0:{slice_size}")) src_access_pattern.append(f"__gather_{dim} + {inside_window_map_ranges[-1][0]}") assert dim in batch_dims @@ -154,9 +159,8 @@ def gather_translator( # noqa: PLR0914 [too-many-locals] # Can not reduce any num_accesses=1, ) - # The static offset of the slice window, which is given through the elements - # of the last dimensions of the index array, for every element in that dimension - # there is an input. + # The static offsets of the slice window, are given through the elements of the + # last dimensions of the index array. for i, dim in enumerate(start_index_map): tasklet_inputs[f"__gather_{dim}"] = dace.Memlet.simple( data=idx_name, @@ -165,10 +169,10 @@ def gather_translator( # noqa: PLR0914 [too-many-locals] # Can not reduce any ), ) - # The output shape is given by the combination of the non collapsed slice sizes + # The output shape is given by the combination of the not collapsed slice sizes # and the index array (without the last dimension) with some permutation. - # Note that the relative order of slice sizes can not be changed, but they - # might be interleaved with the batch variables. + # While the relative order of slice window does not change, `start_index_map` + # already applied a permutation, it might be interleaved with batch dimensions. output_memlet_pattern: list[str] = [] dim_counter = 0 for dim in range(len(out_shape)): diff --git a/src/jace/translator/primitive_translators/pjit_translator.py b/src/jace/translator/primitive_translators/pjit_translator.py index b3b9d97..43bc3ea 100644 --- a/src/jace/translator/primitive_translators/pjit_translator.py +++ b/src/jace/translator/primitive_translators/pjit_translator.py @@ -11,7 +11,7 @@ import re from collections.abc import Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from jax._src import sharding_impls as jax_sharding # noqa: PLC2701 [import-private-name] @@ -54,13 +54,12 @@ def pjit_translator( The translator ignores the `donated_invars`, the `keep_unused` and the `inline` parameter and let's DaCe handle it. """ - params: dict[str, Any] = eqn.params - nested_jaxpr: jax_core.ClosedJaxpr = params["jaxpr"] - in_shardings = params["in_shardings"] - out_shardings = params["out_shardings"] - _ = params["donated_invars"] # Always ignored - _ = params["keep_unused"] - _ = params["inline"] + nested_jaxpr: jax_core.ClosedJaxpr = eqn.params["jaxpr"] + in_shardings = eqn.params["in_shardings"] + out_shardings = eqn.params["out_shardings"] + _ = eqn.params["donated_invars"] # Always ignored + _ = eqn.params["keep_unused"] + _ = eqn.params["inline"] if not all(in_sharding is jax_sharding.UNSPECIFIED for in_sharding in in_shardings): raise NotImplementedError("Currently 'pjit' does not support sharding in its input.") @@ -68,7 +67,7 @@ def pjit_translator( raise NotImplementedError("Currently 'pjit' does not support sharding in its output.") # TODO(phimuell): Controlflow region and name - pjit_name = params["name"] + pjit_name = eqn.params["name"] # Name in SDFG must be unique, thus we mangle it, furthermore, we have to clean it. sdfg_name = f"pjit_{re.subn('[^a-zA-Z0-9_]', '_', pjit_name)[0]}__{'_'.join(out_var_names)}" diff --git a/src/jace/translator/primitive_translators/reshape_translator.py b/src/jace/translator/primitive_translators/reshape_translator.py index 1bcbc5a..79b9bb0 100644 --- a/src/jace/translator/primitive_translators/reshape_translator.py +++ b/src/jace/translator/primitive_translators/reshape_translator.py @@ -25,17 +25,29 @@ @translator.register_primitive_translator() @translator.make_primitive_translator("reshape") def reshape_translator( - builder: translator.JaxprTranslationBuilder, # noqa: ARG001 # Required by the interface. + builder: translator.JaxprTranslationBuilder, # noqa: ARG001 [unused-function-argument] # Required by the interface. in_var_names: Sequence[str | None], out_var_names: Sequence[str], eqn: jax_core.JaxprEqn, eqn_state: dace.SDFGState, ) -> None: """ - Implements the `reshape` primitive, through a memlet. + Implements the `reshape` primitive. + + The function creates a memlet between the input (old shape) and output (final + shape). Because of this, it is best if both arrays do not have any paddings. + + Args: + builder: The builder object of the translation. + in_var_names: Name of the SDFG variable of the source array, + with the old shape. + out_var_names: Name of SDFG variable that acts as destination, + with the new shape. + eqn: The equation that contains the `pjit` primitive. + eqn_state: State into which the nested SDFG should be constructed. Note: - The optional `dimensions` parameters which allows to permute the input + The optional `dimensions` parameters, which allows to permute the input, is not supported. """ if eqn.params["dimensions"] is not None: diff --git a/src/jace/translator/primitive_translators/select_n_translator.py b/src/jace/translator/primitive_translators/select_n_translator.py index 0b9a0d1..aa96922 100644 --- a/src/jace/translator/primitive_translators/select_n_translator.py +++ b/src/jace/translator/primitive_translators/select_n_translator.py @@ -48,7 +48,7 @@ def write_tasklet_code( in_var_names: Sequence[str | None], eqn: jax_core.JaxprEqn, ) -> str: - if len(in_var_names) == 3: # noqa: PLR2004 # Ternary conditional expression. + if len(in_var_names) == 3: # noqa: PLR2004 [magic-value-comparison] # Ternary conditional expression. # The order is correct, since `False` is interpreted as `0`, # which means "the first case". return "__out = __in1 if __cond else __in0" From 846a34512c1ba674c040c41ce28cf5a349703dec Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 24 Sep 2024 14:37:42 +0200 Subject: [PATCH 7/9] Fixed some errors. --- src/jace/translator/jaxpr_translator_builder.py | 3 +-- .../primitive_translators/broadcast_in_dim_translator.py | 2 +- src/jace/translator/primitive_translators/conditions.py | 4 ++-- .../translator/primitive_translators/gather_translator.py | 3 --- 4 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/jace/translator/jaxpr_translator_builder.py b/src/jace/translator/jaxpr_translator_builder.py index c82c277..288593f 100644 --- a/src/jace/translator/jaxpr_translator_builder.py +++ b/src/jace/translator/jaxpr_translator_builder.py @@ -598,10 +598,9 @@ def _translate_single_eqn(self, eqn: jax_core.JaxprEqn) -> None: prev_terminal_state, new_sdfg_term_state, ) - self._ctx.validate() - # Modify terminal root state of 'self' self._ctx.terminal_state = new_sdfg_term_state + self._ctx.validate() def _translate_jaxpr_internal(self, jaxpr: jax_core.ClosedJaxpr) -> TranslationContext: """ diff --git a/src/jace/translator/primitive_translators/broadcast_in_dim_translator.py b/src/jace/translator/primitive_translators/broadcast_in_dim_translator.py index 964a2f6..d8bd388 100644 --- a/src/jace/translator/primitive_translators/broadcast_in_dim_translator.py +++ b/src/jace/translator/primitive_translators/broadcast_in_dim_translator.py @@ -56,7 +56,7 @@ def make_input_memlets( subset_str = ( ", ".join(tskl_ranges[bdim][0] for bdim in eqn.params["broadcast_dimensions"]) if eqn.params["broadcast_dimensions"] - else "0", + else "0" ) return {"__in0": dace.Memlet.simple(in_var_names[0], subset_str)} diff --git a/src/jace/translator/primitive_translators/conditions.py b/src/jace/translator/primitive_translators/conditions.py index 6e37a7a..945baf1 100644 --- a/src/jace/translator/primitive_translators/conditions.py +++ b/src/jace/translator/primitive_translators/conditions.py @@ -79,7 +79,7 @@ def condition_translator( # Make sure that the selection variable is a DaCe symbol. if in_var_names[0] is None: literal_selection_value = str(util.get_jax_literal_value(eqn.invars[0])) - selection_symbol = f"max({len(branches)}, min(0, {literal_selection_value}))" + selection_symbol = f"min({len(branches)}, max(0, {literal_selection_value}))" selection_state = eqn_state else: selection_variable_name = in_var_names[0] @@ -87,7 +87,7 @@ def condition_translator( selection_state = builder.append_new_state( label=f"{name_pattern}_fork", assignments={ - selection_symbol: f"max({len(branches)}, min(0, {selection_variable_name}[0]))" + selection_symbol: f"min({len(branches)}, max(0, {selection_variable_name}))" }, prev_state=eqn_state, ) diff --git a/src/jace/translator/primitive_translators/gather_translator.py b/src/jace/translator/primitive_translators/gather_translator.py index daacb56..4f459d9 100644 --- a/src/jace/translator/primitive_translators/gather_translator.py +++ b/src/jace/translator/primitive_translators/gather_translator.py @@ -123,14 +123,12 @@ def gather_translator( # noqa: PLR0914 [too-many-locals] # Can not reduce any inside_window_map_ranges.append((f"__i{dim}", f"0:{slice_size}")) src_access_pattern.append(inside_window_map_ranges[-1][0]) assert dim in not_collapsed_slice_dims - assert dim not in batch_dims elif dim in collapsed_slice_dims: # This dimension is only partially copied, but because it is collapsed, # only a single element is copied. Thus the offset is only given by the # what we read from the index array. src_access_pattern.append(f"__gather_{dim}") - assert dim in batch_dims else: # This dimension is partially copied, but _not colapsed_. This the element @@ -138,7 +136,6 @@ def gather_translator( # noqa: PLR0914 [too-many-locals] # Can not reduce any # current position within the slicing window. inside_window_map_ranges.append((f"__i{dim}", f"0:{slice_size}")) src_access_pattern.append(f"__gather_{dim} + {inside_window_map_ranges[-1][0]}") - assert dim in batch_dims assert dim in not_collapsed_slice_dims # These are the map variables that are associated to the first implicit loop (the From 2c7b3c8819daaf74bc721c43a2fe08e2ac51b5ec Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 26 Sep 2024 15:35:16 +0200 Subject: [PATCH 8/9] Applied Enriques primarly fixes. --- .../translator/jaxpr_translator_builder.py | 21 +--------- .../arithmetic_logical_translators.py | 40 ++++++++++++------- .../primitive_translators/conditions.py | 2 +- .../primitive_translators/slicing.py | 2 +- 4 files changed, 30 insertions(+), 35 deletions(-) diff --git a/src/jace/translator/jaxpr_translator_builder.py b/src/jace/translator/jaxpr_translator_builder.py index 288593f..3e48964 100644 --- a/src/jace/translator/jaxpr_translator_builder.py +++ b/src/jace/translator/jaxpr_translator_builder.py @@ -179,24 +179,6 @@ def append_new_state( self._ctx.terminal_state = new_state return new_state - def add_orphan_state( - self, - label: str, - ) -> dace.SDFGState: - """ - Add a new orphan state to the SDFG. - - The state is not connected to any other state, nor it is the new start state. - Except you know what you are doing you should not use this function and - instead use `self.append_new_state()`. - - Args: - label: The name of the state. - """ - if not self.is_allocated(): - raise RuntimeError("Builder is not allocated.") - return self._ctx.sdfg.add_state(label=label, is_start_block=False) - @property def arrays(self) -> Mapping[str, dace_data.Data]: """ @@ -520,7 +502,8 @@ def _allocate_translation_ctx( @property def _ctx(self) -> TranslationContext: """Returns the currently active translation context.""" - assert len(self._ctx_stack) != 0, "No context is active." + if not self.is_allocated(): + raise RuntimeError("The context is not allocated.") return self._ctx_stack[-1] def _clear_translation_ctx(self) -> TranslationContext | None: diff --git a/src/jace/translator/primitive_translators/arithmetic_logical_translators.py b/src/jace/translator/primitive_translators/arithmetic_logical_translators.py index 7cf321f..28f9a3a 100644 --- a/src/jace/translator/primitive_translators/arithmetic_logical_translators.py +++ b/src/jace/translator/primitive_translators/arithmetic_logical_translators.py @@ -79,8 +79,8 @@ class LogicalOperationTranslator(mapped_base.MappedOperationTranslatorBase): Args: prim_name: The name of the primitive that should be handled. - int_tmpl: The template used for the integer case. - bool_tmpl: The template used for the bool case. + bitwise_tmpl: The template used for the bitwise case. + logical_tmpl: The template used for the logical case. Note: Since it does not make sense to single out `not` and keep the other @@ -88,10 +88,10 @@ class LogicalOperationTranslator(mapped_base.MappedOperationTranslatorBase): handled by this class. """ - def __init__(self, prim_name: str, int_tmpl: str, bool_tmpl: str) -> None: + def __init__(self, prim_name: str, bitwise_tmpl: str, logical_tmpl: str) -> None: super().__init__(primitive_name=prim_name) - self._int_tmpl = int_tmpl - self._bool_tmpl = bool_tmpl + self._bitwise_tmpl = bitwise_tmpl + self._logical_tmpl = logical_tmpl @override def write_tasklet_code( @@ -101,8 +101,8 @@ def write_tasklet_code( eqn: jax_core.JaxprEqn, ) -> str: if all(util.get_jax_var_dtype(invar) is dace.bool_ for invar in eqn.invars): - return self._bool_tmpl - return self._int_tmpl + return self._logical_tmpl + return self._bitwise_tmpl # Maps the name of an arithmetic JAX primitive to the code template that is used to @@ -176,11 +176,23 @@ def write_tasklet_code( # Maps the name of a logical primitive to the two code templates, first the integer # case and second the boolean case, that are used to create the body of the mapped # tasklet. They are used to instantiate the `LogicalOperationTranslator` translators. -_LOGICAL_OPERATION_TEMPLATES: Final[dict[str, tuple[str, str]]] = { - "or": ("__out = (__in0) | (__in1)", "__out = (__in0) or (__in1)"), - "not": ("__out = ~(__in0)", "__out = not (__in0)"), - "and": ("__out = (__in0) & (__in1)", "__out = (__in0) and (__in1)"), - "xor": ("__out = (__in0) ^ (__in1)", "__out = (__in0) != (__in1)"), +_LOGICAL_OPERATION_TEMPLATES: Final[dict[str, dict[str, str]]] = { + "or": { + "bitwise_tmpl": "__out = (__in0) | (__in1)", + "logical_tmpl": "__out = (__in0) or (__in1)", + }, + "not": { + "bitwise_tmpl": "__out = ~(__in0)", + "logical_tmpl": "__out = not (__in0)", + }, + "and": { + "bitwise_tmpl": "__out = (__in0) & (__in1)", + "logical_tmpl": "__out = (__in0) and (__in1)", + }, + "xor": { + "bitwise_tmpl": "__out = (__in0) ^ (__in1)", + "logical_tmpl": "__out = (__in0) != (__in1)", + }, } # fmt: on @@ -188,5 +200,5 @@ def write_tasklet_code( # Instantiate the arithmetic and logical translators from the templates. for pname, ptmpl in _ARITMETIC_OPERATION_TEMPLATES.items(): translator.register_primitive_translator(ArithmeticOperationTranslator(pname, ptmpl)) -for pname, (itmpl, btmpl) in _LOGICAL_OPERATION_TEMPLATES.items(): - translator.register_primitive_translator(LogicalOperationTranslator(pname, itmpl, btmpl)) +for pname, ptmpl in _LOGICAL_OPERATION_TEMPLATES.items(): # type: ignore[assignment] # Type confusion + translator.register_primitive_translator(LogicalOperationTranslator(pname, **ptmpl)) # type: ignore[arg-type] # Type confusion diff --git a/src/jace/translator/primitive_translators/conditions.py b/src/jace/translator/primitive_translators/conditions.py index 945baf1..e13920b 100644 --- a/src/jace/translator/primitive_translators/conditions.py +++ b/src/jace/translator/primitive_translators/conditions.py @@ -116,7 +116,7 @@ def condition_translator( branch_states.append(branch_state) # Connect all branch states to the join state - join_state = builder.add_orphan_state(f"{name_pattern}__join_state") + join_state = builder._ctx.sdfg.add_state(label=f"{name_pattern}__join_state") for branch_state in branch_states: builder.sdfg.add_edge( branch_state, diff --git a/src/jace/translator/primitive_translators/slicing.py b/src/jace/translator/primitive_translators/slicing.py index c53c3d0..6d9ae26 100644 --- a/src/jace/translator/primitive_translators/slicing.py +++ b/src/jace/translator/primitive_translators/slicing.py @@ -57,7 +57,7 @@ def make_input_memlets( eqn: jax_core.JaxprEqn, ) -> dict[str, dace.Memlet]: strides: Sequence[int] = ( - ((1,) * len(tskl_ranges)) if eqn.params["strides"] is None else eqn.params["strides"] + eqn.params["strides"] if eqn.params["strides"] else ((1,) * len(tskl_ranges)) ) start_indices: Sequence[int] = eqn.params["start_indices"] # Fist index to slice return { From 18bdee9326ba2307100d08badb45f9f7fdd12442 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Thu, 26 Sep 2024 16:27:38 +0200 Subject: [PATCH 9/9] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Enrique González Paredes --- .../primitive_translators/convert_element_type_translator.py | 2 +- .../translator/primitive_translators/gather_translator.py | 2 +- src/jace/translator/primitive_translators/pjit_translator.py | 4 +--- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/jace/translator/primitive_translators/convert_element_type_translator.py b/src/jace/translator/primitive_translators/convert_element_type_translator.py index a9f179c..e1fb8e5 100644 --- a/src/jace/translator/primitive_translators/convert_element_type_translator.py +++ b/src/jace/translator/primitive_translators/convert_element_type_translator.py @@ -32,7 +32,7 @@ class ConvertElementTypeTranslator(mapped_base.MappedOperationTranslatorBase): will perform the type conversion operation. Note: - The type to cast to id inferred from the output variable and the `new_dtype` + The type to cast to is inferred from the output variable and the `new_dtype` parameter of the equation is ignored. """ diff --git a/src/jace/translator/primitive_translators/gather_translator.py b/src/jace/translator/primitive_translators/gather_translator.py index 4f459d9..51f5730 100644 --- a/src/jace/translator/primitive_translators/gather_translator.py +++ b/src/jace/translator/primitive_translators/gather_translator.py @@ -63,7 +63,7 @@ def gather_translator( # noqa: PLR0914 [too-many-locals] # Can not reduce any if eqn.params["mode"] != jax_lax.GatherScatterMode.PROMISE_IN_BOUNDS: raise NotImplementedError(f"The mode {eqn.params['mode']} is not implemented.") - # This is the size of the slice window that is copied. Its length equal the rank + # This is the size of the slice window that is copied. Its length is the rank # of the source array, dimensions that are excluded from copying are listed # in `collapsed_slice_dims`. slice_sizes: Sequence[int] = eqn.params["slice_sizes"] diff --git a/src/jace/translator/primitive_translators/pjit_translator.py b/src/jace/translator/primitive_translators/pjit_translator.py index 43bc3ea..95cb3d4 100644 --- a/src/jace/translator/primitive_translators/pjit_translator.py +++ b/src/jace/translator/primitive_translators/pjit_translator.py @@ -57,9 +57,7 @@ def pjit_translator( nested_jaxpr: jax_core.ClosedJaxpr = eqn.params["jaxpr"] in_shardings = eqn.params["in_shardings"] out_shardings = eqn.params["out_shardings"] - _ = eqn.params["donated_invars"] # Always ignored - _ = eqn.params["keep_unused"] - _ = eqn.params["inline"] + # "donated_invars", "keep_unused", "inline" parameters are just ignored if not all(in_sharding is jax_sharding.UNSPECIFIED for in_sharding in in_shardings): raise NotImplementedError("Currently 'pjit' does not support sharding in its input.")