From 3ee8dad963fb1f45a26622ffc9e896c505ef164d Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 24 Sep 2024 14:37:42 +0200 Subject: [PATCH] Made it run again. --- src/jace/translator/jaxpr_translator_builder.py | 3 +-- .../primitive_translators/broadcast_in_dim_translator.py | 2 +- src/jace/translator/primitive_translators/conditions.py | 4 ++-- .../translator/primitive_translators/gather_translator.py | 3 --- 4 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/jace/translator/jaxpr_translator_builder.py b/src/jace/translator/jaxpr_translator_builder.py index c82c277..288593f 100644 --- a/src/jace/translator/jaxpr_translator_builder.py +++ b/src/jace/translator/jaxpr_translator_builder.py @@ -598,10 +598,9 @@ def _translate_single_eqn(self, eqn: jax_core.JaxprEqn) -> None: prev_terminal_state, new_sdfg_term_state, ) - self._ctx.validate() - # Modify terminal root state of 'self' self._ctx.terminal_state = new_sdfg_term_state + self._ctx.validate() def _translate_jaxpr_internal(self, jaxpr: jax_core.ClosedJaxpr) -> TranslationContext: """ diff --git a/src/jace/translator/primitive_translators/broadcast_in_dim_translator.py b/src/jace/translator/primitive_translators/broadcast_in_dim_translator.py index 964a2f6..d8bd388 100644 --- a/src/jace/translator/primitive_translators/broadcast_in_dim_translator.py +++ b/src/jace/translator/primitive_translators/broadcast_in_dim_translator.py @@ -56,7 +56,7 @@ def make_input_memlets( subset_str = ( ", ".join(tskl_ranges[bdim][0] for bdim in eqn.params["broadcast_dimensions"]) if eqn.params["broadcast_dimensions"] - else "0", + else "0" ) return {"__in0": dace.Memlet.simple(in_var_names[0], subset_str)} diff --git a/src/jace/translator/primitive_translators/conditions.py b/src/jace/translator/primitive_translators/conditions.py index 6e37a7a..945baf1 100644 --- a/src/jace/translator/primitive_translators/conditions.py +++ b/src/jace/translator/primitive_translators/conditions.py @@ -79,7 +79,7 @@ def condition_translator( # Make sure that the selection variable is a DaCe symbol. if in_var_names[0] is None: literal_selection_value = str(util.get_jax_literal_value(eqn.invars[0])) - selection_symbol = f"max({len(branches)}, min(0, {literal_selection_value}))" + selection_symbol = f"min({len(branches)}, max(0, {literal_selection_value}))" selection_state = eqn_state else: selection_variable_name = in_var_names[0] @@ -87,7 +87,7 @@ def condition_translator( selection_state = builder.append_new_state( label=f"{name_pattern}_fork", assignments={ - selection_symbol: f"max({len(branches)}, min(0, {selection_variable_name}[0]))" + selection_symbol: f"min({len(branches)}, max(0, {selection_variable_name}))" }, prev_state=eqn_state, ) diff --git a/src/jace/translator/primitive_translators/gather_translator.py b/src/jace/translator/primitive_translators/gather_translator.py index daacb56..4f459d9 100644 --- a/src/jace/translator/primitive_translators/gather_translator.py +++ b/src/jace/translator/primitive_translators/gather_translator.py @@ -123,14 +123,12 @@ def gather_translator( # noqa: PLR0914 [too-many-locals] # Can not reduce any inside_window_map_ranges.append((f"__i{dim}", f"0:{slice_size}")) src_access_pattern.append(inside_window_map_ranges[-1][0]) assert dim in not_collapsed_slice_dims - assert dim not in batch_dims elif dim in collapsed_slice_dims: # This dimension is only partially copied, but because it is collapsed, # only a single element is copied. Thus the offset is only given by the # what we read from the index array. src_access_pattern.append(f"__gather_{dim}") - assert dim in batch_dims else: # This dimension is partially copied, but _not colapsed_. This the element @@ -138,7 +136,6 @@ def gather_translator( # noqa: PLR0914 [too-many-locals] # Can not reduce any # current position within the slicing window. inside_window_map_ranges.append((f"__i{dim}", f"0:{slice_size}")) src_access_pattern.append(f"__gather_{dim} + {inside_window_map_ranges[-1][0]}") - assert dim in batch_dims assert dim in not_collapsed_slice_dims # These are the map variables that are associated to the first implicit loop (the