Refactored the gather translator.
It is now better configured.
philip-paul-mueller committed Sep 24, 2024
1 parent c8b7d86 commit cb600d3
Showing 2 changed files with 160 additions and 177 deletions.
4 changes: 2 additions & 2 deletions src/jace/translator/primitive_translators/__init__.py
@@ -17,7 +17,7 @@
from .conditions import condition_translator
from .convert_element_type_translator import ConvertElementTypeTranslator
from .copy_translator import copy_translator, device_put_translator
from .gather_translator import GatherTranslator
from .gather_translator import gather_translator
from .iota_translator import IotaTranslator
from .pjit_translator import pjit_translator
from .reshape_translator import reshape_translator
@@ -30,7 +30,6 @@
"ArithmeticOperationTranslator",
"BroadcastInDimTranslator",
"ConvertElementTypeTranslator",
"GatherTranslator",
"IotaTranslator",
"LogicalOperationTranslator",
"SelectNTranslator",
@@ -41,6 +40,7 @@
"copy_translator",
"device_put_translator",
"dynamic_slicing_translator",
"gather_translator",
"pjit_translator",
"reshape_translator",
]
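
For context, the refactor replaces the class-based `GatherTranslator` with a plain function that is registered through decorators, matching the other translators in this module. Below is a minimal sketch of that registration pattern, using a hypothetical `my_prim` primitive and placeholder parameter names for illustration only; the actual gather implementation follows in the next file.

from jace import translator

@translator.register_primitive_translator()
@translator.make_primitive_translator("my_prim")  # hypothetical primitive, for illustration only
def my_prim_translator(builder, in_var_names, out_var_names, eqn, eqn_state) -> None:
    # Translate a single `my_prim` equation of the jaxpr into `eqn_state`;
    # the decorators attach the primitive name and register the function
    # with the translator registry.
    ...
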
333 changes: 158 additions & 175 deletions src/jace/translator/primitive_translators/gather_translator.py
@@ -13,7 +13,6 @@

import dace
from jax import lax as jax_lax
from typing_extensions import override

from jace import translator, util

@@ -24,188 +23,172 @@
from jax import core as jax_core


class GatherTranslator(translator.PrimitiveTranslator):
@translator.register_primitive_translator()
@translator.make_primitive_translator("gather")
def gather_translator( # noqa: PLR0914 [too-many-locals] # Can not reduce any further.
builder: translator.JaxprTranslationBuilder, # noqa: ARG001 # Required by the interface.
in_var_names: Sequence[str | None],
out_var_names: Sequence[str],
eqn: jax_core.JaxprEqn,
eqn_state: dace.SDFGState,
) -> None:
"""
Gather Translator.
Implements the `gather` primitive.
The gather operation extracts patches of a certain size, known as `slice_size`,
from an array, called the source or input array. Where these patches start is
given by another array, the index array.
This primitive is used to implement the `array.at[...].get()` access. In the
end the primitive extracts patches/windows of a certain size, known as
`slice_size`, from an array, which is called the source or input array. The start
points of these windows are given by another array, the so-called index array.
Args:
builder: The builder object that is active.
in_var_names: The names of the input variables, the first array is
assumed as source array and the second is the index array.
out_var_names: The names of the output variables.
eqn: The equation to translate.
eqn_state: The state in which we put the extraction.
See Also:
https://www.tensorflow.org/xla/operation_semantics#gather
https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.gather.html
"""

@property
def primitive(self) -> str: # noqa: D102 # No docstring needed.
return "gather"

@override
def __call__( # noqa: PLR0914, PLR0915 # Just ported from the prototype, cleanup postponed.
self,
builder: translator.JaxprTranslationBuilder,
in_var_names: Sequence[str | None],
out_var_names: Sequence[str],
eqn: jax_core.JaxprEqn,
eqn_state: dace.SDFGState,
) -> None:
"""
Performs the gather operation.
Args:
builder: The builder object that is active.
in_var_names: The names of the input variables, the first array is
assumed as source array and the second is the index array.
out_var_names: The names of the output variables.
eqn: The equation to translate.
eqn_state: The state in which we put the extraction.
"""
assert len(eqn.invars) == 2 # noqa: PLR2004 # XLA supports more inputs.

out_name = out_var_names[0]
out_shape = util.get_jax_var_shape(eqn.outvars[0])

src_name = in_var_names[0]
src_shape = util.get_jax_var_shape(eqn.invars[0])

idx_name = in_var_names[1]
idx_shape = util.get_jax_var_shape(eqn.invars[1])

dimension_numbers = eqn.params["dimension_numbers"]
offset_dims: Sequence[int] = dimension_numbers.offset_dims
collapsed_slice_dims: Sequence[int] = dimension_numbers.collapsed_slice_dims
start_index_map: Sequence[int] = dimension_numbers.start_index_map
slice_sizes: Sequence[int] = eqn.params["slice_sizes"]
mode: jax_lax.GatherScatterMode = eqn.params["mode"]
assert len(start_index_map) == idx_shape[-1]

if mode != jax_lax.GatherScatterMode.PROMISE_IN_BOUNDS:
raise NotImplementedError(f"The mode {mode} is not implemented.")

# The copy of the patches goes over these dimensions.
batch_dims = tuple(d for d in range(len(out_shape)) if d not in offset_dims)

# Every batch dimension is associated with one dimension of the index
# array, but the index array always has one additional dimension. This
# dimension contains the start indexes of the slice; it is not strictly
# necessary if only a single index should be loaded, but JAX
# currently adds it implicitly, probably to make life easier.
if (len(batch_dims) + 1) != len(idx_shape):
raise ValueError(
f"Expected that the index array has {len(batch_dims) + 1} dimensions, but it had {len(idx_shape)}."
)

# These are the dimensions (of the input) for which a map index is created.
# Note that we exclude collapsed dimensions here.
src_dim_with_map_idx = tuple(
dim for dim in range(len(slice_sizes)) if dim not in collapsed_slice_dims
out_name = out_var_names[0]
out_shape = util.get_jax_var_shape(eqn.outvars[0])
src_name = in_var_names[0]
src_shape = util.get_jax_var_shape(eqn.invars[0])
idx_name = in_var_names[1]
idx_shape = util.get_jax_var_shape(eqn.invars[1])
dimension_numbers = eqn.params["dimension_numbers"]

if eqn.params["mode"] != jax_lax.GatherScatterMode.PROMISE_IN_BOUNDS:
raise NotImplementedError(f"The mode {eqn.params['mode']} is not implemented.")

# This is the size of the slice window that is copied. Its length equals the rank
# of the source array; dimensions that should not be copied are listed in
# `collapsed_slice_dims`.
slice_sizes: Sequence[int] = eqn.params["slice_sizes"]
collapsed_slice_dims: Sequence[int] = dimension_numbers.collapsed_slice_dims
not_collapsed_slice_dims = tuple(
dim for dim in range(len(slice_sizes)) if dim not in collapsed_slice_dims
)
assert len(slice_sizes) == len(src_shape)

# The batch dimensions are used to iterate through the slice windows and thus
# to access the index array, with the exception of its last dimension, see below.
# NOTE: In pure XLA this last dimension might not be present; however, JAX
# adds it and our implementation relies on it.
batch_dims = tuple(d for d in range(len(out_shape)) if d not in dimension_numbers.offset_dims)
if (len(batch_dims) + 1) != len(idx_shape):
raise ValueError(
f"Expected that the index array has {len(batch_dims) + 1} dimensions, but it had {len(idx_shape)}."
)
assert len(src_dim_with_map_idx) == len(offset_dims)

# The final map is the composition of two loops. The first map iterates over
# the index array, except the last dimension, and is used to "copy" the
# different patches from the source to the output array. These map parameters
# follow the pattern `__i{out_name}_gather{bd}`, where `bd` is a batch
# dimension. These variables are used to access the index array.
# The second loop performs the actual copy of the slices. For these,
# variables of the form `__i{i}` are used, where the `i` are known as offset
# dimensions.
# What makes this a bit difficult is that the actual access/dereferencing of the
# source array is done within the tasklet.

# Access pattern of the source array _within_ the tasklet.
src_access_pattern: list[str] = []

# These are the map ranges for the copying of the slices.
slice_map_ranges: list[tuple[str, str]] = []

# Compute the access pattern within the tasklet.
# As a side effect we also compute the map ranges, but only for the slices.
for dim, slice_size in enumerate(slice_sizes):
# Order is important!
if dim not in start_index_map:
# This dimension is fully copied
slice_map_ranges.append((f"__i{dim}", f"0:{slice_size}"))
src_access_pattern.append(slice_map_ranges[-1][0])
assert dim in src_dim_with_map_idx
assert slice_size == src_shape[dim]

elif dim in collapsed_slice_dims:
# This dimension is only partially copied, however, since the
# dimension is collapsed, only a single element is copied that
# comes from the index array.
src_access_pattern.append(f"__gather_{dim}")

else:
# This dimension is partially copied, but is _not collapsed_, so we need
# a map index to copy the range. However, an offset from the index array
# is also involved in the copy.
slice_map_ranges.append((f"__i{dim}", f"0:{slice_size}"))
src_access_pattern.append(f"__gather_{dim} + {slice_map_ranges[-1][0]}")
assert dim in src_dim_with_map_idx
assert slice_size <= src_shape[dim]

# These are the map variables that go over the index array.
patch_loop_vars = tuple(f"__i{out_name}_gather{bd}" for bd in batch_dims)
patch_map_ranges = [
(map_param, f"0:{patch_loop_bound}")
for map_param, patch_loop_bound in zip(patch_loop_vars, idx_shape[:-1])
]

# Create the input memlet that allows us to access the source array from
# inside the tasklet and makes it accessible through the name `__arr`. At
# this point it is not possible to tell where we access, because the
# index variables are missing; they only become available inside the
# tasklet (see below). However, we know that we will access only one
# element of the array.
tasklet_inputs: dict[str, dace.Memlet] = {
"__arr": dace.Memlet.simple(
data=src_name,
subset_str=", ".join(f"0:{size}" for size in src_shape),
num_accesses=1,

# The last dimension is special, as it contains the actual start point for the
# slice window when the dimension is only partially copied. The `start_index_map`
# associates each element position in the last dimension with the corresponding
# dimension of the source array.
start_index_map: Sequence[int] = dimension_numbers.start_index_map
assert len(start_index_map) == idx_shape[-1]

# The final map has two parts. The first part iterates through all the slice
# windows that are given through the index array (except its last dimension).
# If a dimension is not fully copied then the start index of the window is
# given through the elements of the last dimension of the index array.
# Map variables that are used for this follow the pattern `__i{out_name}_gather{bd}`.
# The second loop is used to copy the slice windows themselves; their map
# variables follow the pattern `__i{i}`.

# Because the offsets of the slice window (which are given by the elements of
# the last dimension of the index array) are variables and not symbols, they
# cannot be included in the memlets. Instead we generate a tasklet that
# performs an indirect access and gets all elements of the last dimension of the
# index array (with the names `__gather_{dim}`), together with the full source
# array, as input.

# Access pattern of the source array _inside_ the tasklet.
src_access_pattern: list[str] = []

# The ranges of the second part implicit loop (the one that copies the windows).
inside_window_map_ranges: list[tuple[str, str]] = []

for dim, slice_size in enumerate(slice_sizes):
# Order is important!
if dim not in start_index_map:
# This dimension is fully copied
inside_window_map_ranges.append((f"__i{dim}", f"0:{slice_size}"))
src_access_pattern.append(inside_window_map_ranges[-1][0])
assert dim in not_collapsed_slice_dims
assert dim not in batch_dims

elif dim in collapsed_slice_dims:
# This dimension is only partially copied, but because it is collapsed,
# only a single element is copied. Thus the offset is only given by the
# index array.
src_access_pattern.append(f"__gather_{dim}")
assert dim in batch_dims

else:
# This dimension is partially copied, but _not collapsed_. This creates a
# slice index, and the offset (of a single element) is given by the static
# start of the window and the current position inside the window.
inside_window_map_ranges.append((f"__i{dim}", f"0:{slice_size}"))
src_access_pattern.append(f"__gather_{dim} + {inside_window_map_ranges[-1][0]}")
assert dim in batch_dims
assert dim in not_collapsed_slice_dims

# These are the map variables that are associated with the first implicit loop (the
# iteration over the index array, excluding the last dimension).
batch_map_ranges = [
(f"__i{out_name}_gather{batch_dim}", f"0:{batch_loop_bound}")
for batch_dim, batch_loop_bound in zip(batch_dims, idx_shape[:-1])
]
assert len(batch_map_ranges) + len(inside_window_map_ranges) == len(out_shape)

tasklet_inputs: dict[str, dace.Memlet] = {}

# We need to pass the full array into the tasklet; however, we know that we
# will read only one element.
tasklet_inputs["__arr"] = dace.Memlet.simple(
data=src_name,
subset_str=", ".join(f"0:{size}" for size in src_shape),
num_accesses=1,
)

# The static offset of the slice window is given through the elements
# of the last dimension of the index array; for every element in that dimension
# there is an input.
for i, dim in enumerate(start_index_map):
tasklet_inputs[f"__gather_{dim}"] = dace.Memlet.simple(
data=idx_name,
subset_str=(
", ".join(batch_loop_var for batch_loop_var, _ in batch_map_ranges) + f", {i}"
),
}

# Now we are creating the memlets that access the index array.
for i, dim in enumerate(start_index_map):
tasklet_inputs[f"__gather_{dim}"] = dace.Memlet.simple(
data=idx_name, subset_str=(", ".join(patch_loop_vars) + f", {i}")
)

# The tasklet code.
tasklet_code = "__out = __arr[" + ", ".join(src_access_pattern) + "]"

# Now we generate the output memlet.
outpt_access_pattern: list[str] = []
dim_counter = 0
for dim in range(len(out_shape)):
if dim in batch_dims:
# This is a batch dimension, thus a loop variable is used for it.
patch_loop_var = patch_loop_vars[batch_dims.index(dim)]
outpt_access_pattern.append(str(patch_loop_var))

else:
# This is a dimension for copying the slices.
assert dim_counter <= len(src_dim_with_map_idx)
associated_map_idx = src_dim_with_map_idx[dim_counter]
dim_counter += 1
outpt_access_pattern.append(f"__i{associated_map_idx}")
assert dim_counter == len(src_dim_with_map_idx)

tasklet_outputs: dict[str, dace.Memlet] = {
"__out": dace.Memlet.simple(data=out_name, subset_str=", ".join(outpt_access_pattern))
}
assert len(patch_map_ranges) + len(slice_map_ranges) == len(out_shape)

eqn_state.add_mapped_tasklet(
name=f"_gather_map_{out_name}",
map_ranges=patch_map_ranges + slice_map_ranges,
inputs=tasklet_inputs,
code=tasklet_code,
outputs=tasklet_outputs,
external_edges=True,
)


_ = translator.register_primitive_translator(GatherTranslator())
# The output shape is given by the combination of the non-collapsed slice sizes
# and the index array (without the last dimension), with some permutation.
# Note that the relative order of the slice sizes cannot be changed, but they
# might be interleaved with the batch variables.
output_memlet_pattern: list[str] = []
dim_counter = 0
for dim in range(len(out_shape)):
if dim in batch_dims:
batch_loop_var = batch_map_ranges[batch_dims.index(dim)][0]
output_memlet_pattern.append(str(batch_loop_var))

else:
associated_map_idx = not_collapsed_slice_dims[dim_counter]
dim_counter += 1
output_memlet_pattern.append(f"__i{associated_map_idx}")
assert dim_counter == len(not_collapsed_slice_dims)

eqn_state.add_mapped_tasklet(
name=f"_gather_map_{out_name}",
map_ranges=batch_map_ranges + inside_window_map_ranges,
inputs=tasklet_inputs,
code="__out = __arr[" + ", ".join(src_access_pattern) + "]",
outputs={
"__out": dace.Memlet.simple(data=out_name, subset_str=", ".join(output_memlet_pattern))
},
external_edges=True,
)
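
As a rough usage illustration (not part of this commit, only a sketch): the kind of jaxpr equation this translator consumes can be produced with plain JAX indexing, for example:

import jax
import jax.numpy as jnp

def take_rows(x, idx):
    # Integer-array indexing is implemented through the `gather` primitive;
    # mode="promise_in_bounds" matches the only mode the translator supports.
    return x.at[idx].get(mode="promise_in_bounds")

x = jnp.arange(12.0).reshape(4, 3)
idx = jnp.array([0, 2])
print(jax.make_jaxpr(take_rows)(x, idx))
# The printed jaxpr contains a `gather` equation; its `dimension_numbers`,
# `slice_sizes` and `mode` parameters are what the translator reads from
# `eqn.params`.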
