From c29fc0dbf0377e67baea172d282e6b1da0995c7a Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 24 Sep 2024 14:15:02 +0200 Subject: [PATCH] Some more corrections. Now let's test if it works. --- .../mapped_operation_base_translator.py | 50 +++++++-------- .../arithmetic_logical_translators.py | 16 ++--- .../concatenate_translator.py | 11 +++- .../primitive_translators/conditions.py | 25 ++++---- .../convert_element_type_translator.py | 2 +- .../primitive_translators/copy_translator.py | 27 ++++++-- .../gather_translator.py | 64 ++++++++++--------- .../primitive_translators/pjit_translator.py | 17 +++-- .../reshape_translator.py | 18 +++++- .../select_n_translator.py | 2 +- 10 files changed, 136 insertions(+), 96 deletions(-) diff --git a/src/jace/translator/mapped_operation_base_translator.py b/src/jace/translator/mapped_operation_base_translator.py index 17a5c35..508ad13 100644 --- a/src/jace/translator/mapped_operation_base_translator.py +++ b/src/jace/translator/mapped_operation_base_translator.py @@ -5,7 +5,7 @@ # # SPDX-License-Identifier: BSD-3-Clause -"""Module containing all translators related to arithmetic logical operations.""" +"""Module implementing the `MappedOperationTranslatorBase` helper class.""" from __future__ import annotations @@ -37,8 +37,9 @@ class MappedOperationTranslatorBase(translator.PrimitiveTranslator): ``` where `__in*` are the connector names of the Tasklet and `__out` is the output connector. For problems such as this, the SDFG API provides the - `SDFGState.add_mapped_tasklet()` function, however, because it is very low - level and very verbose to use, this class acts as a convenience wrapper around it. + `SDFGState.add_mapped_tasklet()` function. However, because the function + operates on a very low level and is very verbose to use, this class acts + as a convenience wrapper around it. To use this class a user has to define the abstract `write_tasklet_code()` method. 
This function generates the entire code that should be put into the Tasklet, @@ -160,8 +161,8 @@ def make_input_memlets( # noqa: PLR6301 [no-self-use] # Subclasses might need in_var_names: The list of SDFG variables used as input, `None` if literal. eqn: The equation object. """ - out_shp = tuple(util.get_jax_var_shape(eqn.outvars[0])) # Shape of the output - out_rank = len(out_shp) + out_shape = tuple(util.get_jax_var_shape(eqn.outvars[0])) + out_rank = len(out_shape) if any(len(util.get_jax_var_shape(invar)) not in {0, out_rank} for invar in eqn.invars): raise NotImplementedError( f"'MappedOperationTranslatorBase' Inputs must have the same rank as the output! " @@ -170,29 +171,26 @@ def make_input_memlets( # noqa: PLR6301 [no-self-use] # Subclasses might need # Now we will generate the input Memlets. tskl_inputs: dict[str, dace.Memlet] = {} - for i, (in_var_name, inp_shp) in enumerate( + for i, (in_var_name, in_shape) in enumerate( zip(in_var_names, (util.get_jax_var_shape(invar) for invar in eqn.invars)) ): - if in_var_name is None: # Input is a literal: No Memlet needed - continue - - if inp_shp == (): # Scalars - tskl_inputs[f"__in{i}"] = dace.Memlet.simple(in_var_name, "0") # Scalar - continue - - # We might have to do broadcasting. - # We ensured that input and output have the same rank (JAX is doing that - # for us). So we must do broadcasting, i.e. replicating that input - # dimension, if its size is 1. We threat the case where the output has - # size 1 in that dimension as broadcasting as well. 
- dims_to_bcast: Sequence[int] = [dim for dim in range(out_rank) if inp_shp[dim] == 1] - tskl_inputs[f"__in{i}"] = dace.Memlet.simple( - in_var_name, - ", ".join( - ("0" if i in dims_to_bcast else it_var) - for i, (it_var, _) in enumerate(tskl_ranges) - ), - ) + if in_var_name is None: + pass + + elif in_shape == (): + tskl_inputs[f"__in{i}"] = dace.Memlet.simple(in_var_name, "0") + + else: + dims_to_bcast = [ + dim for dim in range(out_rank) if in_shape[dim] == 1 and out_shape[dim] != 1 + ] + tskl_inputs[f"__in{i}"] = dace.Memlet.simple( + in_var_name, + ", ".join( + ("0" if i in dims_to_bcast else it_var) + for i, (it_var, _) in enumerate(tskl_ranges) + ), + ) return tskl_inputs def literal_substitution( # noqa: PLR6301 [no-self-use] # Subclasses might need it. diff --git a/src/jace/translator/primitive_translators/arithmetic_logical_translators.py b/src/jace/translator/primitive_translators/arithmetic_logical_translators.py index 667e1ac..7cf321f 100644 --- a/src/jace/translator/primitive_translators/arithmetic_logical_translators.py +++ b/src/jace/translator/primitive_translators/arithmetic_logical_translators.py @@ -100,14 +100,12 @@ def write_tasklet_code( in_var_names: Sequence[str | None], eqn: jax_core.JaxprEqn, ) -> str: - return ( - self._bool_tmpl - if all(util.get_jax_var_dtype(invar) is dace.bool_ for invar in eqn.invars) - else self._int_tmpl - ) + if all(util.get_jax_var_dtype(invar) is dace.bool_ for invar in eqn.invars): + return self._bool_tmpl + return self._int_tmpl -# Maps the name of an arithmetic primitives to the code template that is used to +# Maps the name of an arithmetic JAX primitive to the code template that is used to # generate the body of the mapped tasklet. These are used to instantiate the # `ArithmeticOperationTranslator` objects. 
# fmt: off @@ -175,9 +173,9 @@ def write_tasklet_code( } -# Maps the name of a logical primitive to the two code templates (first the integer -# case and second the boolean case) used to create the body of the mapped tasklet. -# They are used to instantiate the `LogicalOperationTranslator` translators. +# Maps the name of a logical primitive to the two code templates, first the integer +# case and second the boolean case, that are used to create the body of the mapped +# tasklet. They are used to instantiate the `LogicalOperationTranslator` translators. _LOGICAL_OPERATION_TEMPLATES: Final[dict[str, tuple[str, str]]] = { "or": ("__out = (__in0) | (__in1)", "__out = (__in0) or (__in1)"), "not": ("__out = ~(__in0)", "__out = not (__in0)"), diff --git a/src/jace/translator/primitive_translators/concatenate_translator.py b/src/jace/translator/primitive_translators/concatenate_translator.py index 1b5f679..b327bde 100644 --- a/src/jace/translator/primitive_translators/concatenate_translator.py +++ b/src/jace/translator/primitive_translators/concatenate_translator.py @@ -25,7 +25,7 @@ @translator.register_primitive_translator() @translator.make_primitive_translator("concatenate") def concatenate_translator( - builder: translator.JaxprTranslationBuilder, # noqa: ARG001 # Required by the interface. + builder: translator.JaxprTranslationBuilder, # noqa: ARG001 [unused-function-argument] # Required by the interface. in_var_names: Sequence[str | None], out_var_names: Sequence[str], eqn: jax_core.JaxprEqn, @@ -36,6 +36,15 @@ def concatenate_translator( Each source array is copied by its own map, but all maps write to the same access node. + + Args: + builder: The builder object of the translation; unused. + in_var_names: The SDFG variables used as input arguments in order as they + should be concatenated. + out_var_names: Names of SDFG variables that should be used as outputs.
+ eqn: The equation that should be translated; the concatenation dimension + is read from the `dimension` parameter. + eqn_state: State into which the nested SDFG should be constructed. """ if any(in_var_name is None for in_var_name in in_var_names): raise NotImplementedError("Concatenate: No literal inputs supported.") diff --git a/src/jace/translator/primitive_translators/conditions.py b/src/jace/translator/primitive_translators/conditions.py index 4f7363d..6e37a7a 100644 --- a/src/jace/translator/primitive_translators/conditions.py +++ b/src/jace/translator/primitive_translators/conditions.py @@ -41,18 +41,19 @@ def condition_translator( Args: builder: The builder object of the translation. - in_var_names: The SDFG variables used an input arguments. First is the index, - the variable that selects the branch, the remaining ones are passed as - inputs to the branches. + in_var_names: The SDFG variables used as input arguments. First is the + selection variable. The remaining ones are passed to the branches as + inputs. out_var_names: Names of SDFG variables that should be used as outputs. eqn: The equation that should be translated. eqn_state: State into which the nested SDFG should be constructed. Notes: - According to the JAX documentation (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html) - the selector is clamped. But according to XLA (https://openxla.org/xla/operation_semantics#conditional) - an out of range selector uses the last branch. JaCe conforms to JAX semantic. - After this function the terminal state of the `builder` is unspecific. + - According to the JAX documentation (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html) + the selector is clamped. But according to XLA (https://openxla.org/xla/operation_semantics#conditional) + an out of range selector uses the last branch. JaCe conforms to JAX + semantics. + - After this function the terminal state of the `builder` is unspecified.
""" if util.get_jax_var_dtype(eqn.invars[0]) is dace.bool_: # XLA explicitly provides a binary form of the primitive @@ -61,7 +62,7 @@ def condition_translator( # integer implementation. raise NotImplementedError("The boolean conditional primitive is not implemented.") - # To make names in the (nested) SDFG unique we use the name of the equation state + # Used as prefix to give all additional states/variables a unique name. name_pattern = eqn_state.name # To avoid special cases promote all symbols to constants. @@ -80,9 +81,7 @@ def condition_translator( literal_selection_value = str(util.get_jax_literal_value(eqn.invars[0])) selection_symbol = f"max({len(branches)}, min(0, {literal_selection_value}))" selection_state = eqn_state - else: - # Promotion of a scalar to a symbol through a state transition. selection_variable_name = in_var_names[0] selection_symbol = f"{selection_variable_name}_symb" selection_state = builder.append_new_state( @@ -93,12 +92,15 @@ def condition_translator( prev_state=eqn_state, ) + # Translate the subbranches, the branches are all connected from `selection_state`. branch_states: list[dace.SDFGState] = [] for i, branch_jaxpr in enumerate(branches): branch_pattern = f"{name_pattern}_{{}}_branch_{i}" branch_ctx = builder.translate_jaxpr(jaxpr=branch_jaxpr, name=branch_pattern.format("sdfg")) - # This will update the terminal state only for the first branch + # The first time it is called it will update the builder's terminal state + # but since we will return the join state it will be updated later. But + # until then the terminal state of the builder is invalid. 
branch_state = builder.append_new_state( label=branch_pattern.format("state"), condition=f"{selection_symbol} == {i}", @@ -113,6 +115,7 @@ def condition_translator( ) branch_states.append(branch_state) + # Connect all branch states to the join state join_state = builder.add_orphan_state(f"{name_pattern}__join_state") for branch_state in branch_states: builder.sdfg.add_edge( diff --git a/src/jace/translator/primitive_translators/convert_element_type_translator.py b/src/jace/translator/primitive_translators/convert_element_type_translator.py index 118f4e3..a9f179c 100644 --- a/src/jace/translator/primitive_translators/convert_element_type_translator.py +++ b/src/jace/translator/primitive_translators/convert_element_type_translator.py @@ -56,7 +56,7 @@ def write_tasklet_code( if in_dtype == out_dtype: # JAX sometimes adds conversions which are not needed. In these cases - # we perform a copy. + # make a copy out of it. # TODO(phimuell): Create a Memlet instead. return "__out = __in0" diff --git a/src/jace/translator/primitive_translators/copy_translator.py b/src/jace/translator/primitive_translators/copy_translator.py index 5cc0d3c..9e0d2d1 100644 --- a/src/jace/translator/primitive_translators/copy_translator.py +++ b/src/jace/translator/primitive_translators/copy_translator.py @@ -28,21 +28,31 @@ def copy_translator( builder: translator.JaxprTranslationBuilder, in_var_names: Sequence[str | None], out_var_names: Sequence[str], - eqn: jax_core.JaxprEqn, # noqa: ARG001 # Required by the interface. + eqn: jax_core.JaxprEqn, # noqa: ARG001 [unused-function-argument] # Required by the interface. eqn_state: dace.SDFGState, ) -> None: """ Implements the `copy` primitive. + The copy is implemented by creating a memlet between the source and destination. + + Args: + builder: The builder object of the translation. + in_var_names: The SDFG variable that acts as source. + out_var_names: The SDFG variable that acts as destination of the copy. 
+ eqn: The equation that should be translated; unused. + eqn_state: State into which the nested SDFG should be constructed. + Todo: Investigate if operation should expand to a map. """ + assert in_var_names[0] is not None eqn_state.add_nedge( eqn_state.add_read(in_var_names[0]), eqn_state.add_write(out_var_names[0]), dace.Memlet.from_array( in_var_names[0], - builder.arrays[in_var_names[0]], # type: ignore[index] # Guaranteed to be a string + builder.arrays[in_var_names[0]], ), ) @@ -60,9 +70,16 @@ def device_put_translator( Implements the `device_put` primitive. In JAX this primitive is used to copy data between the host and the device, - in DaCe Memlets can do this. However, because of the way JaCe operates, at - least in the beginning a computation is either fully on the host or on the - device this copy will essentially perform a copying. + in DaCe only memlets can do this. However, because of the way JaCe (currently) + operates (a computation is either fully on the host or on GPU), the `device_put` + primitive essentially decays to a copy. + + Args: + builder: The builder object of the translation. + in_var_names: The SDFG variable that acts as source. + out_var_names: The SDFG variable that acts as destination of the copy. + eqn: The equation that should be translated. + eqn_state: State into which the nested SDFG should be constructed. """ if not (eqn.params["device"] is None and eqn.params["src"] is None): raise NotImplementedError( diff --git a/src/jace/translator/primitive_translators/gather_translator.py b/src/jace/translator/primitive_translators/gather_translator.py index 8d0f60f..daacb56 100644 --- a/src/jace/translator/primitive_translators/gather_translator.py +++ b/src/jace/translator/primitive_translators/gather_translator.py @@ -26,7 +26,7 @@ @translator.register_primitive_translator() @translator.make_primitive_translator("gather") def gather_translator( # noqa: PLR0914 [too-many-locals] # Can not reduce any further. 
- builder: translator.JaxprTranslationBuilder, # noqa: ARG001 # Required by the interface. + builder: translator.JaxprTranslationBuilder, # noqa: ARG001 [unused-function-argument] # Required by the interface. in_var_names: Sequence[str | None], out_var_names: Sequence[str], eqn: jax_core.JaxprEqn, @@ -64,8 +64,8 @@ def gather_translator( # noqa: PLR0914 [too-many-locals] # Can not reduce any raise NotImplementedError(f"The mode {eqn.params['mode']} is not implemented.") # This is the size of the slice window that is copied. Its length equal the rank - # of the source array, dimensions that should not be copied are listed in - # `collapsed_slice_dims`. + # of the source array, dimensions that are excluded from copying are listed + # in `collapsed_slice_dims`. slice_sizes: Sequence[int] = eqn.params["slice_sizes"] collapsed_slice_dims: Sequence[int] = dimension_numbers.collapsed_slice_dims not_collapsed_slice_dims = tuple( @@ -73,34 +73,38 @@ def gather_translator( # noqa: PLR0914 [too-many-locals] # Can not reduce any ) assert len(slice_sizes) == len(src_shape) - # The batch dimensions are used to iterate through the slice windows, thus access - # the index array, with the exception of the last dimension, see below. - # NOTE: In pure XLA this last dimension might not be present, however, JAX - # adds it and our implementation relies on it. + # The batch dimensions are used to iterate through the different slice windows + # (not inside them) thus they access the index array, with the exception of the + # last dimension, see below. + # NOTE: In pure XLA this last dimension is in certain cases optional, however, + # JAX adds it and our implementation relies on it. batch_dims = tuple(d for d in range(len(out_shape)) if d not in dimension_numbers.offset_dims) if (len(batch_dims) + 1) != len(idx_shape): raise ValueError( f"Expected that the index array has {len(batch_dims) + 1} dimensions, but it had {len(idx_shape)}." 
) - # The last dimension is special, as it contains the actual start point for the - # slice window when the dimension is only partially copied. The `start_index_map` - # associates each position element in the last dimension with the corresponding + # The last dimension of the index array is special, as it contains the actual + # start point for the slice windows when the dimension is only partially copied. + # Thus the last dimension must be seen as a list of start indexes and the other + # dimensions are used to enumerate the slice windows. The `start_index_map` + # associates each position in the last dimension with the corresponding # dimension of the source array. start_index_map: Sequence[int] = dimension_numbers.start_index_map assert len(start_index_map) == idx_shape[-1] - # The final map has two parts. The first part iterates through all the slice - # windows that are given through the index array (except last dimension). - # If a dimension is not fully copied then the start index of the window is - # given through the elements of the last dimensions of the index array. - # Map variables that are used for this use the pattern `__i{out_name}_gather{bd}`. - # The second loop is used to copy the slice windows themselves, their map - # variables follow the pattern `__i{i}`. + # The iteration variable of the final map can be divided into two parts or + # categories. The first part iterates through all the slice windows that are + # given through the index array. If a dimension is not fully copied then the + # start index of the window is given through the elements of the last dimensions + # of the index array. Map variables that are used for this use the pattern + # `__i{out_name}_gather{bd}`. The second kind of variables are used to copy the + # content of the slice windows themselves, these map variables follow the + # pattern `__i{i}`. 
# Because the offsets of the slice window (which are given by the elements of - # the last dimension of the index array) are variables and not symbols, it - # can not be included in the memlets. Instead we generate an tasklet that + # the last dimension of the index array) are variables and not symbols, they + # can not be included in the memlets. Instead we generate a tasklet that # performs an indirect access and get all elements of the last dimension of the # index array (with the names `__gather_{dim}`), together with the full source # array as input. @@ -108,7 +112,8 @@ def gather_translator( # noqa: PLR0914 [too-many-locals] # Can not reduce any # Access pattern of the source array _inside_ the tasklet. src_access_pattern: list[str] = [] - # The ranges of the second part implicit loop (the one that copies the windows). + # The map variables and their ranges of the second part implicit loop; the one + # that copies the content inside the window. inside_window_map_ranges: list[tuple[str, str]] = [] for dim, slice_size in enumerate(slice_sizes): @@ -123,14 +128,14 @@ def gather_translator( # noqa: PLR0914 [too-many-locals] # Can not reduce any elif dim in collapsed_slice_dims: # This dimension is only partially copied, but because it is collapsed, # only a single element is copied. Thus the offset is only given by the - # index array. + # value we read from the index array. src_access_pattern.append(f"__gather_{dim}") assert dim in batch_dims else: - # This dimension is partially copied, but _not colapsed_. This creates a - # slice index and the offset (of a single element) is given by the static - # start of the window and the current position inside of the window. + # This dimension is partially copied, but _not collapsed_. Thus the element + # that is read depends on the (static) offset of this window and the + # current position within the slicing window.
inside_window_map_ranges.append((f"__i{dim}", f"0:{slice_size}")) src_access_pattern.append(f"__gather_{dim} + {inside_window_map_ranges[-1][0]}") assert dim in batch_dims @@ -154,9 +159,8 @@ def gather_translator( # noqa: PLR0914 [too-many-locals] # Can not reduce any num_accesses=1, ) - # The static offset of the slice window, which is given through the elements - # of the last dimensions of the index array, for every element in that dimension - # there is an input. + # The static offsets of the slice window, are given through the elements of the + # last dimensions of the index array. for i, dim in enumerate(start_index_map): tasklet_inputs[f"__gather_{dim}"] = dace.Memlet.simple( data=idx_name, @@ -165,10 +169,10 @@ def gather_translator( # noqa: PLR0914 [too-many-locals] # Can not reduce any ), ) - # The output shape is given by the combination of the non collapsed slice sizes + # The output shape is given by the combination of the not collapsed slice sizes # and the index array (without the last dimension) with some permutation. - # Note that the relative order of slice sizes can not be changed, but they - # might be interleaved with the batch variables. + # While the relative order of slice window does not change, `start_index_map` + # already applied a permutation, it might be interleaved with batch dimensions. 
output_memlet_pattern: list[str] = [] dim_counter = 0 for dim in range(len(out_shape)): diff --git a/src/jace/translator/primitive_translators/pjit_translator.py b/src/jace/translator/primitive_translators/pjit_translator.py index b3b9d97..43bc3ea 100644 --- a/src/jace/translator/primitive_translators/pjit_translator.py +++ b/src/jace/translator/primitive_translators/pjit_translator.py @@ -11,7 +11,7 @@ import re from collections.abc import Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from jax._src import sharding_impls as jax_sharding # noqa: PLC2701 [import-private-name] @@ -54,13 +54,12 @@ def pjit_translator( The translator ignores the `donated_invars`, the `keep_unused` and the `inline` parameter and let's DaCe handle it. """ - params: dict[str, Any] = eqn.params - nested_jaxpr: jax_core.ClosedJaxpr = params["jaxpr"] - in_shardings = params["in_shardings"] - out_shardings = params["out_shardings"] - _ = params["donated_invars"] # Always ignored - _ = params["keep_unused"] - _ = params["inline"] + nested_jaxpr: jax_core.ClosedJaxpr = eqn.params["jaxpr"] + in_shardings = eqn.params["in_shardings"] + out_shardings = eqn.params["out_shardings"] + _ = eqn.params["donated_invars"] # Always ignored + _ = eqn.params["keep_unused"] + _ = eqn.params["inline"] if not all(in_sharding is jax_sharding.UNSPECIFIED for in_sharding in in_shardings): raise NotImplementedError("Currently 'pjit' does not support sharding in its input.") @@ -68,7 +67,7 @@ def pjit_translator( raise NotImplementedError("Currently 'pjit' does not support sharding in its output.") # TODO(phimuell): Controlflow region and name - pjit_name = params["name"] + pjit_name = eqn.params["name"] # Name in SDFG must be unique, thus we mangle it, furthermore, we have to clean it. 
sdfg_name = f"pjit_{re.subn('[^a-zA-Z0-9_]', '_', pjit_name)[0]}__{'_'.join(out_var_names)}" diff --git a/src/jace/translator/primitive_translators/reshape_translator.py b/src/jace/translator/primitive_translators/reshape_translator.py index 1bcbc5a..79b9bb0 100644 --- a/src/jace/translator/primitive_translators/reshape_translator.py +++ b/src/jace/translator/primitive_translators/reshape_translator.py @@ -25,17 +25,29 @@ @translator.register_primitive_translator() @translator.make_primitive_translator("reshape") def reshape_translator( - builder: translator.JaxprTranslationBuilder, # noqa: ARG001 # Required by the interface. + builder: translator.JaxprTranslationBuilder, # noqa: ARG001 [unused-function-argument] # Required by the interface. in_var_names: Sequence[str | None], out_var_names: Sequence[str], eqn: jax_core.JaxprEqn, eqn_state: dace.SDFGState, ) -> None: """ - Implements the `reshape` primitive, through a memlet. + Implements the `reshape` primitive. + + The function creates a memlet between the input (old shape) and output (final + shape). Because of this, it is best if both arrays do not have any paddings. + + Args: + builder: The builder object of the translation. + in_var_names: Name of the SDFG variable of the source array, + with the old shape. + out_var_names: Name of SDFG variable that acts as destination, + with the new shape. + eqn: The equation that contains the `reshape` primitive. + eqn_state: State into which the nested SDFG should be constructed. Note: - The optional `dimensions` parameters which allows to permute the input + The optional `dimensions` parameter, which allows to permute the input, is not supported.
""" if eqn.params["dimensions"] is not None: diff --git a/src/jace/translator/primitive_translators/select_n_translator.py b/src/jace/translator/primitive_translators/select_n_translator.py index 0b9a0d1..aa96922 100644 --- a/src/jace/translator/primitive_translators/select_n_translator.py +++ b/src/jace/translator/primitive_translators/select_n_translator.py @@ -48,7 +48,7 @@ def write_tasklet_code( in_var_names: Sequence[str | None], eqn: jax_core.JaxprEqn, ) -> str: - if len(in_var_names) == 3: # noqa: PLR2004 # Ternary conditional expression. + if len(in_var_names) == 3: # noqa: PLR2004 [magic-value-comparison] # Ternary conditional expression. # The order is correct, since `False` is interpreted as `0`, # which means "the first case". return "__out = __in1 if __cond else __in0"