diff --git a/src/jace/util/debug.py b/src/jace/util/debug.py index 27c57aa..23aee23 100644 --- a/src/jace/util/debug.py +++ b/src/jace/util/debug.py @@ -63,7 +63,7 @@ def run_jax_sdfg( # Canonical SDFGs do not have global memory, so we must transform it. # We will afterwards undo it. - for glob_name in jsdfg.inp_names + jsdfg.out_names: # type: ignore[operator] # concatenation + for glob_name in jsdfg.inp_names + jsdfg.out_names: jsdfg.sdfg.arrays[glob_name].transient = False try: @@ -80,7 +80,7 @@ def run_jax_sdfg( return ret_val finally: - for name in jsdfg.inp_names + jsdfg.out_names: # type: ignore[operator] # concatenation + for name in jsdfg.inp_names + jsdfg.out_names: jsdfg.sdfg.arrays[name].transient = True @@ -94,6 +94,9 @@ def _jace_run( Args: *args: Forwarded to the tracing and final execution of the SDFG. **kwargs: Used to construct the driver. + + Notes: + This function will be removed soon. """ jaxpr = jax.make_jaxpr(fun)(*args) driver = translator.JaxprTranslationDriver(**kwargs) diff --git a/src/jace/util/jax_helper.py b/src/jace/util/jax_helper.py index fb4619a..0ea37a1 100644 --- a/src/jace/util/jax_helper.py +++ b/src/jace/util/jax_helper.py @@ -53,10 +53,6 @@ def __eq__(self, other: Any) -> bool: return NotImplemented return id(self) == id(other) - def __post_init__(self) -> None: - if not isinstance(self.shape, tuple): - raise ValueError("The 'shape' member of a 'JaCeVar' must be a tuple.") - def get_jax_var_name(jax_var: jax_core.Atom | JaCeVar | str) -> str: """Returns the name of the Jax variable as a string.