Skip to content

Commit

Permalink
Fixed some errors.
Browse files Browse the repository at this point in the history
  • Loading branch information
philip-paul-mueller committed Sep 24, 2024
1 parent c29fc0d commit 846a345
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 8 deletions.
3 changes: 1 addition & 2 deletions src/jace/translator/jaxpr_translator_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,10 +598,9 @@ def _translate_single_eqn(self, eqn: jax_core.JaxprEqn) -> None:
prev_terminal_state,
new_sdfg_term_state,
)
self._ctx.validate()

# Modify terminal root state of 'self'
self._ctx.terminal_state = new_sdfg_term_state
self._ctx.validate()

def _translate_jaxpr_internal(self, jaxpr: jax_core.ClosedJaxpr) -> TranslationContext:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def make_input_memlets(
subset_str = (
", ".join(tskl_ranges[bdim][0] for bdim in eqn.params["broadcast_dimensions"])
if eqn.params["broadcast_dimensions"]
else "0",
else "0"
)
return {"__in0": dace.Memlet.simple(in_var_names[0], subset_str)}

Expand Down
4 changes: 2 additions & 2 deletions src/jace/translator/primitive_translators/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,15 @@ def condition_translator(
# Make sure that the selection variable is a DaCe symbol.
if in_var_names[0] is None:
literal_selection_value = str(util.get_jax_literal_value(eqn.invars[0]))
selection_symbol = f"max({len(branches)}, min(0, {literal_selection_value}))"
selection_symbol = f"min({len(branches)}, max(0, {literal_selection_value}))"
selection_state = eqn_state
else:
selection_variable_name = in_var_names[0]
selection_symbol = f"{selection_variable_name}_symb"
selection_state = builder.append_new_state(
label=f"{name_pattern}_fork",
assignments={
selection_symbol: f"max({len(branches)}, min(0, {selection_variable_name}[0]))"
selection_symbol: f"min({len(branches)}, max(0, {selection_variable_name}))"
},
prev_state=eqn_state,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,22 +123,19 @@ def gather_translator( # noqa: PLR0914 [too-many-locals] # Can not reduce any
inside_window_map_ranges.append((f"__i{dim}", f"0:{slice_size}"))
src_access_pattern.append(inside_window_map_ranges[-1][0])
assert dim in not_collapsed_slice_dims
assert dim not in batch_dims

elif dim in collapsed_slice_dims:
# This dimension is only partially copied, but because it is collapsed,
# only a single element is copied. Thus the offset is only given by the
# what we read from the index array.
src_access_pattern.append(f"__gather_{dim}")
assert dim in batch_dims

else:
# This dimension is partially copied, but _not colapsed_. This the element
# that is read depends on the (static) offset of this window and the
# current position within the slicing window.
inside_window_map_ranges.append((f"__i{dim}", f"0:{slice_size}"))
src_access_pattern.append(f"__gather_{dim} + {inside_window_map_ranges[-1][0]}")
assert dim in batch_dims
assert dim in not_collapsed_slice_dims

# These are the map variables that are associated to the first implicit loop (the
Expand Down

0 comments on commit 846a345

Please sign in to comment.