From cb600d397aaa8d12a9bb82467457a522aab5d4c2 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 24 Sep 2024 12:37:20 +0200 Subject: [PATCH] Refactored the gather translator. It is now better configured. --- .../primitive_translators/__init__.py | 4 +- .../gather_translator.py | 333 +++++++++--------- 2 files changed, 160 insertions(+), 177 deletions(-) diff --git a/src/jace/translator/primitive_translators/__init__.py b/src/jace/translator/primitive_translators/__init__.py index 757743e..f019964 100644 --- a/src/jace/translator/primitive_translators/__init__.py +++ b/src/jace/translator/primitive_translators/__init__.py @@ -17,7 +17,7 @@ from .conditions import condition_translator from .convert_element_type_translator import ConvertElementTypeTranslator from .copy_translator import copy_translator, device_put_translator -from .gather_translator import GatherTranslator +from .gather_translator import gather_translator from .iota_translator import IotaTranslator from .pjit_translator import pjit_translator from .reshape_translator import reshape_translator @@ -30,7 +30,6 @@ "ArithmeticOperationTranslator", "BroadcastInDimTranslator", "ConvertElementTypeTranslator", - "GatherTranslator", "IotaTranslator", "LogicalOperationTranslator", "SelectNTranslator", @@ -41,6 +40,7 @@ "copy_translator", "device_put_translator", "dynamic_slicing_translator", + "gather_translator", "pjit_translator", "reshape_translator", ] diff --git a/src/jace/translator/primitive_translators/gather_translator.py b/src/jace/translator/primitive_translators/gather_translator.py index 4b58e70..8d0f60f 100644 --- a/src/jace/translator/primitive_translators/gather_translator.py +++ b/src/jace/translator/primitive_translators/gather_translator.py @@ -13,7 +13,6 @@ import dace from jax import lax as jax_lax -from typing_extensions import override from jace import translator, util @@ -24,188 +23,172 @@ from jax import core as jax_core -class GatherTranslator(translator.PrimitiveTranslator): 
@translator.register_primitive_translator()
@translator.make_primitive_translator("gather")
def gather_translator(  # noqa: PLR0914 [too-many-locals]  # Can not reduce any further.
    builder: translator.JaxprTranslationBuilder,  # noqa: ARG001  # Required by the interface.
    in_var_names: Sequence[str | None],
    out_var_names: Sequence[str],
    eqn: jax_core.JaxprEqn,
    eqn_state: dace.SDFGState,
) -> None:
    """
    Implements the `gather` primitive.

    This primitive is used to implement the `array.at[...].get()` access. In the
    end the primitive extracts patches/windows of a certain size, known as
    `slice_sizes`, from an array, which is called source or input array. The start
    points of these windows are given by another array, the so called index array.

    The translation is a single mapped tasklet. Because the window offsets are
    values read from the index array (and not symbols), they can not appear in a
    memlet; instead the tasklet receives the full source array and dereferences
    it itself.

    Args:
        builder: The builder object that is active.
        in_var_names: The names of the input variables, the first array is
            assumed as source array and the second is the index array.
        out_var_names: The names of the output variables.
        eqn: The equation to translate.
        eqn_state: The state in which we put the extraction.

    Raises:
        NotImplementedError: If the gather mode is not `PROMISE_IN_BOUNDS`.
        ValueError: If the rank of the index array is not the number of batch
            dimensions plus one.

    See Also:
        https://www.tensorflow.org/xla/operation_semantics#gather
        https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.gather.html
    """
    out_name = out_var_names[0]
    out_shape = util.get_jax_var_shape(eqn.outvars[0])
    src_name = in_var_names[0]
    src_shape = util.get_jax_var_shape(eqn.invars[0])
    idx_name = in_var_names[1]
    idx_shape = util.get_jax_var_shape(eqn.invars[1])
    dimension_numbers = eqn.params["dimension_numbers"]

    if eqn.params["mode"] != jax_lax.GatherScatterMode.PROMISE_IN_BOUNDS:
        raise NotImplementedError(f"The mode {eqn.params['mode']} is not implemented.")

    # This is the size of the slice window that is copied. Its length equals the
    #  rank of the source array; dimensions that should not show up in the output
    #  are listed in `collapsed_slice_dims`.
    slice_sizes: Sequence[int] = eqn.params["slice_sizes"]
    collapsed_slice_dims: Sequence[int] = dimension_numbers.collapsed_slice_dims
    not_collapsed_slice_dims = tuple(
        dim for dim in range(len(slice_sizes)) if dim not in collapsed_slice_dims
    )
    assert len(slice_sizes) == len(src_shape)

    # `offset_dims` are the dimensions of the _output_ that receive the window;
    #  there must be exactly one of them per non collapsed _source_ dimension.
    offset_dims: Sequence[int] = dimension_numbers.offset_dims
    assert len(not_collapsed_slice_dims) == len(offset_dims)

    # The batch dimensions are the remaining dimensions of the _output_. They are
    #  used to iterate through the slice windows, thus they index the index array,
    #  with the exception of its last dimension, see below.
    # NOTE: In pure XLA this last dimension might not be present, however, JAX
    #  adds it and our implementation relies on it.
    batch_dims = tuple(d for d in range(len(out_shape)) if d not in offset_dims)
    if (len(batch_dims) + 1) != len(idx_shape):
        raise ValueError(
            f"Expected that the index array has {len(batch_dims) + 1} dimensions, but it had {len(idx_shape)}."
        )

    # The last dimension is special, as it contains the actual start point for the
    #  slice window when the dimension is only partially copied. The
    #  `start_index_map` associates each element of the last dimension with the
    #  corresponding dimension of the source array.
    start_index_map: Sequence[int] = dimension_numbers.start_index_map
    assert len(start_index_map) == idx_shape[-1]

    # The final map has two parts. The first part iterates through all the slice
    #  windows that are given through the index array (except last dimension); its
    #  map variables use the pattern `__i{out_name}_gather{bd}`. The second part
    #  copies the slice windows themselves; its map variables follow the pattern
    #  `__i{dim}`, where `dim` is a dimension of the source array.

    # Access pattern of the source array _inside_ the tasklet.
    src_access_pattern: list[str] = []

    # The ranges of the second part of the map (the one that copies the windows).
    inside_window_map_ranges: list[tuple[str, str]] = []

    # NOTE: `dim` is a dimension of the _source_ array. It must not be compared
    #  against `batch_dims`, which are dimensions of the _output_ array; the two
    #  only coincide for simple indexing patterns.
    for dim, slice_size in enumerate(slice_sizes):
        # Order is important!
        if dim not in start_index_map:
            # This dimension has no start index, the window starts at zero.
            inside_window_map_ranges.append((f"__i{dim}", f"0:{slice_size}"))
            src_access_pattern.append(inside_window_map_ranges[-1][0])
            assert dim in not_collapsed_slice_dims

        elif dim in collapsed_slice_dims:
            # This dimension is only partially copied, but because it is collapsed,
            #  only a single element is copied. Thus the offset is only given by
            #  the index array.
            src_access_pattern.append(f"__gather_{dim}")

        else:
            # This dimension is partially copied, but _not collapsed_. The accessed
            #  element is given by the start of the window (from the index array)
            #  plus the current position inside of the window.
            inside_window_map_ranges.append((f"__i{dim}", f"0:{slice_size}"))
            src_access_pattern.append(f"__gather_{dim} + {inside_window_map_ranges[-1][0]}")

    # These are the map variables that are associated to the first part of the map
    #  (the iteration over the index array, excluding the last dimension).
    batch_map_ranges = [
        (f"__i{out_name}_gather{batch_dim}", f"0:{batch_loop_bound}")
        for batch_dim, batch_loop_bound in zip(batch_dims, idx_shape[:-1])
    ]
    assert len(batch_map_ranges) + len(inside_window_map_ranges) == len(out_shape)

    tasklet_inputs: dict[str, dace.Memlet] = {}

    # We need to pass the full array into the tasklet, however, we know that we
    #  will read only one element.
    tasklet_inputs["__arr"] = dace.Memlet.simple(
        data=src_name,
        subset_str=", ".join(f"0:{size}" for size in src_shape),
        num_accesses=1,
    )

    # The start of the slice window, which is given through the elements of the
    #  last dimension of the index array; for every element in that dimension
    #  there is one tasklet input named `__gather_{dim}`.
    for i, dim in enumerate(start_index_map):
        tasklet_inputs[f"__gather_{dim}"] = dace.Memlet.simple(
            data=idx_name,
            subset_str=(
                ", ".join(batch_loop_var for batch_loop_var, _ in batch_map_ranges) + f", {i}"
            ),
        )

    # The output shape is given by the combination of the non collapsed slice sizes
    #  and the index array (without the last dimension) with some permutation.
    #  Note that the relative order of slice sizes can not be changed, but they
    #  might be interleaved with the batch variables.
    output_memlet_pattern: list[str] = []
    dim_counter = 0
    for dim in range(len(out_shape)):
        if dim in batch_dims:
            output_memlet_pattern.append(batch_map_ranges[batch_dims.index(dim)][0])

        else:
            associated_map_idx = not_collapsed_slice_dims[dim_counter]
            dim_counter += 1
            output_memlet_pattern.append(f"__i{associated_map_idx}")
    assert dim_counter == len(not_collapsed_slice_dims)

    eqn_state.add_mapped_tasklet(
        name=f"_gather_map_{out_name}",
        map_ranges=batch_map_ranges + inside_window_map_ranges,
        inputs=tasklet_inputs,
        code="__out = __arr[" + ", ".join(src_access_pattern) + "]",
        outputs={
            "__out": dace.Memlet.simple(data=out_name, subset_str=", ".join(output_memlet_pattern))
        },
        external_edges=True,
    )