Skip to content

Commit

Permalink
Some more corrections.
Browse files Browse the repository at this point in the history
Now let's test if it works.
  • Loading branch information
philip-paul-mueller committed Sep 24, 2024
1 parent cb600d3 commit c29fc0d
Show file tree
Hide file tree
Showing 10 changed files with 136 additions and 96 deletions.
50 changes: 24 additions & 26 deletions src/jace/translator/mapped_operation_base_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#
# SPDX-License-Identifier: BSD-3-Clause

"""Module containing all translators related to arithmetic logical operations."""
"""Module implementing the `MappedOperationTranslatorBase` helper class."""

from __future__ import annotations

Expand Down Expand Up @@ -37,8 +37,9 @@ class MappedOperationTranslatorBase(translator.PrimitiveTranslator):
```
where `__in*` are the connector names of the Tasklet and `__out` is the
output connector. For problems such as this, the SDFG API provides the
`SDFGState.add_mapped_tasklet()` function, however, because it is very low
level and very verbose to use, this class acts as a convenience wrapper around it.
`SDFGState.add_mapped_tasklet()` function. However, because the function
operates on a very low level and is very verbose to use, this class acts
as a convenience wrapper around it.
To use this class a user has to define the abstract `write_tasklet_code()` method.
This function generates the entire code that should be put into the Tasklet,
Expand Down Expand Up @@ -160,8 +161,8 @@ def make_input_memlets( # noqa: PLR6301 [no-self-use] # Subclasses might need
in_var_names: The list of SDFG variables used as input, `None` if literal.
eqn: The equation object.
"""
out_shp = tuple(util.get_jax_var_shape(eqn.outvars[0])) # Shape of the output
out_rank = len(out_shp)
out_shape = tuple(util.get_jax_var_shape(eqn.outvars[0]))
out_rank = len(out_shape)
if any(len(util.get_jax_var_shape(invar)) not in {0, out_rank} for invar in eqn.invars):
raise NotImplementedError(
f"'MappedOperationTranslatorBase' Inputs must have the same rank as the output! "
Expand All @@ -170,29 +171,26 @@ def make_input_memlets( # noqa: PLR6301 [no-self-use] # Subclasses might need

# Now we will generate the input Memlets.
tskl_inputs: dict[str, dace.Memlet] = {}
for i, (in_var_name, inp_shp) in enumerate(
for i, (in_var_name, in_shape) in enumerate(
zip(in_var_names, (util.get_jax_var_shape(invar) for invar in eqn.invars))
):
if in_var_name is None: # Input is a literal: No Memlet needed
continue

if inp_shp == (): # Scalars
tskl_inputs[f"__in{i}"] = dace.Memlet.simple(in_var_name, "0") # Scalar
continue

# We might have to do broadcasting.
# We ensured that input and output have the same rank (JAX is doing that
# for us). So we must do broadcasting, i.e. replicating that input
# dimension, if its size is 1. We threat the case where the output has
# size 1 in that dimension as broadcasting as well.
dims_to_bcast: Sequence[int] = [dim for dim in range(out_rank) if inp_shp[dim] == 1]
tskl_inputs[f"__in{i}"] = dace.Memlet.simple(
in_var_name,
", ".join(
("0" if i in dims_to_bcast else it_var)
for i, (it_var, _) in enumerate(tskl_ranges)
),
)
if in_var_name is None:
pass

elif in_shape == ():
tskl_inputs[f"__in{i}"] = dace.Memlet.simple(in_var_name, "0")

else:
dims_to_bcast = [
dim for dim in range(out_rank) if in_shape[dim] == 1 and out_shape[dim] != 1
]
tskl_inputs[f"__in{i}"] = dace.Memlet.simple(
in_var_name,
", ".join(
("0" if i in dims_to_bcast else it_var)
for i, (it_var, _) in enumerate(tskl_ranges)
),
)
return tskl_inputs

def literal_substitution( # noqa: PLR6301 [no-self-use] # Subclasses might need it.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,12 @@ def write_tasklet_code(
in_var_names: Sequence[str | None],
eqn: jax_core.JaxprEqn,
) -> str:
return (
self._bool_tmpl
if all(util.get_jax_var_dtype(invar) is dace.bool_ for invar in eqn.invars)
else self._int_tmpl
)
if all(util.get_jax_var_dtype(invar) is dace.bool_ for invar in eqn.invars):
return self._bool_tmpl
return self._int_tmpl


# Maps the name of an arithmetic primitives to the code template that is used to
# Maps the name of an arithmetic JAX primitive to the code template that is used to
# generate the body of the mapped tasklet. These are used to instantiate the
# `ArithmeticOperationTranslator` objects.
# fmt: off
Expand Down Expand Up @@ -175,9 +173,9 @@ def write_tasklet_code(
}


# Maps the name of a logical primitive to the two code templates (first the integer
# case and second the boolean case) used to create the body of the mapped tasklet.
# They are used to instantiate the `LogicalOperationTranslator` translators.
# Maps the name of a logical primitive to the two code templates, first the integer
# case and second the boolean case, that are used to create the body of the mapped
# tasklet. They are used to instantiate the `LogicalOperationTranslator` translators.
_LOGICAL_OPERATION_TEMPLATES: Final[dict[str, tuple[str, str]]] = {
"or": ("__out = (__in0) | (__in1)", "__out = (__in0) or (__in1)"),
"not": ("__out = ~(__in0)", "__out = not (__in0)"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
@translator.register_primitive_translator()
@translator.make_primitive_translator("concatenate")
def concatenate_translator(
builder: translator.JaxprTranslationBuilder, # noqa: ARG001 # Required by the interface.
builder: translator.JaxprTranslationBuilder, # noqa: ARG001 [unused-function-argument] # Required by the interface.
in_var_names: Sequence[str | None],
out_var_names: Sequence[str],
eqn: jax_core.JaxprEqn,
Expand All @@ -36,6 +36,15 @@ def concatenate_translator(
Each source array is copied by its own map, but all maps write to the same
access node.
Args:
builder: The builder object of the translation; unused.
in_var_names: The SDFG variables used an input arguments in order as they
should be concatenated.
out_var_names: Names of SDFG variables that should be used as outputs.
eqn: The equation that should be translated, the concatenation dimensions
is read from the `dimension` parameter.
eqn_state: State into which the nested SDFG should be constructed.
"""
if any(in_var_name is None for in_var_name in in_var_names):
raise NotImplementedError("Concatenate: No literal inputs supported.")
Expand Down
25 changes: 14 additions & 11 deletions src/jace/translator/primitive_translators/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,19 @@ def condition_translator(
Args:
builder: The builder object of the translation.
in_var_names: The SDFG variables used an input arguments. First is the index,
the variable that selects the branch, the remaining ones are passed as
inputs to the branches.
in_var_names: The SDFG variables used an input arguments. First is the
selection variable. The remaining ones are passed to the branches as
inputs.
out_var_names: Names of SDFG variables that should be used as outputs.
eqn: The equation that should be translated.
eqn_state: State into which the nested SDFG should be constructed.
Notes:
According to the JAX documentation (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html)
the selector is clamped. But according to XLA (https://openxla.org/xla/operation_semantics#conditional)
an out of range selector uses the last branch. JaCe conforms to JAX semantic.
After this function the terminal state of the `builder` is unspecific.
- According to the JAX documentation (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html)
the selector is clamped. But according to XLA (https://openxla.org/xla/operation_semantics#conditional)
an out of range selector uses the last branch. JaCe conforms to JAX
semantic.
- After this function the terminal state of the `builder` is unspecific.
"""
if util.get_jax_var_dtype(eqn.invars[0]) is dace.bool_:
# XLA explicitly provides a binary form of the primitive
Expand All @@ -61,7 +62,7 @@ def condition_translator(
# integer implementation.
raise NotImplementedError("The boolean conditional primitive is not implemented.")

# To make names in the (nested) SDFG unique we use the name of the equation state
# Used as prefix to give all additional states/variables a unique name.
name_pattern = eqn_state.name

# To avoid special cases promote all symbols to constants.
Expand All @@ -80,9 +81,7 @@ def condition_translator(
literal_selection_value = str(util.get_jax_literal_value(eqn.invars[0]))
selection_symbol = f"max({len(branches)}, min(0, {literal_selection_value}))"
selection_state = eqn_state

else:
# Promotion of a scalar to a symbol through a state transition.
selection_variable_name = in_var_names[0]
selection_symbol = f"{selection_variable_name}_symb"
selection_state = builder.append_new_state(
Expand All @@ -93,12 +92,15 @@ def condition_translator(
prev_state=eqn_state,
)

# Translate the subbranches, the branches are all connected from `selection_state`.
branch_states: list[dace.SDFGState] = []
for i, branch_jaxpr in enumerate(branches):
branch_pattern = f"{name_pattern}_{{}}_branch_{i}"
branch_ctx = builder.translate_jaxpr(jaxpr=branch_jaxpr, name=branch_pattern.format("sdfg"))

# This will update the terminal state only for the first branch
# The first time it is called it will update the builder's terminal state
# but since we will return the join state it will be updated later. But
# until then the terminal state of the builder is invalid.
branch_state = builder.append_new_state(
label=branch_pattern.format("state"),
condition=f"{selection_symbol} == {i}",
Expand All @@ -113,6 +115,7 @@ def condition_translator(
)
branch_states.append(branch_state)

# Connect all branch states to the join state
join_state = builder.add_orphan_state(f"{name_pattern}__join_state")
for branch_state in branch_states:
builder.sdfg.add_edge(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def write_tasklet_code(

if in_dtype == out_dtype:
# JAX sometimes adds conversions which are not needed. In these cases
# we perform a copy.
# make a copy out of it.
# TODO(phimuell): Create a Memlet instead.
return "__out = __in0"

Expand Down
27 changes: 22 additions & 5 deletions src/jace/translator/primitive_translators/copy_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,31 @@ def copy_translator(
builder: translator.JaxprTranslationBuilder,
in_var_names: Sequence[str | None],
out_var_names: Sequence[str],
eqn: jax_core.JaxprEqn, # noqa: ARG001 # Required by the interface.
eqn: jax_core.JaxprEqn, # noqa: ARG001 [unused-function-argument] # Required by the interface.
eqn_state: dace.SDFGState,
) -> None:
"""
Implements the `copy` primitive.
The copy is implemented by creating a memlet between the source and destination.
Args:
builder: The builder object of the translation.
in_var_names: The SDFG variable that acts as source.
out_var_names: The SDFG variable that acts as destination of the copy.
eqn: The equation that should be translated; unused.
eqn_state: State into which the nested SDFG should be constructed.
Todo:
Investigate if operation should expand to a map.
"""
assert in_var_names[0] is not None
eqn_state.add_nedge(
eqn_state.add_read(in_var_names[0]),
eqn_state.add_write(out_var_names[0]),
dace.Memlet.from_array(
in_var_names[0],
builder.arrays[in_var_names[0]], # type: ignore[index] # Guaranteed to be a string
builder.arrays[in_var_names[0]],
),
)

Expand All @@ -60,9 +70,16 @@ def device_put_translator(
Implements the `device_put` primitive.
In JAX this primitive is used to copy data between the host and the device,
in DaCe Memlets can do this. However, because of the way JaCe operates, at
least in the beginning a computation is either fully on the host or on the
device this copy will essentially perform a copying.
in DaCe only memlets can do this. However, because of the way JaCe (currently)
operates (a computation is either fully on the host or on GPU), the `device_put`
primitive essentially decays to a copy.
Args:
builder: The builder object of the translation.
in_var_names: The SDFG variable that acts as source.
out_var_names: The SDFG variable that acts as destination of the copy.
eqn: The equation that should be translated.
eqn_state: State into which the nested SDFG should be constructed.
"""
if not (eqn.params["device"] is None and eqn.params["src"] is None):
raise NotImplementedError(
Expand Down
Loading

0 comments on commit c29fc0d

Please sign in to comment.