Skip to content

Commit

Permalink
Small updates.
Browse files Browse the repository at this point in the history
  • Loading branch information
philip-paul-mueller committed May 13, 2024
1 parent 75d0823 commit d39ba9b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
7 changes: 5 additions & 2 deletions src/jace/util/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def run_jax_sdfg(

# Canonical SDFGs do not have global memory, so we must transform it.
# We will afterwards undo it.
for glob_name in jsdfg.inp_names + jsdfg.out_names: # type: ignore[operator] # concatenation
for glob_name in jsdfg.inp_names + jsdfg.out_names:
jsdfg.sdfg.arrays[glob_name].transient = False

try:
Expand All @@ -80,7 +80,7 @@ def run_jax_sdfg(
return ret_val

finally:
for name in jsdfg.inp_names + jsdfg.out_names: # type: ignore[operator] # concatenation
for name in jsdfg.inp_names + jsdfg.out_names:
jsdfg.sdfg.arrays[name].transient = True


Expand All @@ -94,6 +94,9 @@ def _jace_run(
Args:
*args: Forwarded to the tracing and final execution of the SDFG.
**kwargs: Used to construct the driver.
Notes:
This function will be removed soon.
"""
jaxpr = jax.make_jaxpr(fun)(*args)
driver = translator.JaxprTranslationDriver(**kwargs)
Expand Down
4 changes: 0 additions & 4 deletions src/jace/util/jax_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,6 @@ def __eq__(self, other: Any) -> bool:
return NotImplemented
return id(self) == id(other)

def __post_init__(self) -> None:
if not isinstance(self.shape, tuple):
raise ValueError("The 'shape' member of a 'JaCeVar' must be a tuple.")


def get_jax_var_name(jax_var: jax_core.Atom | JaCeVar | str) -> str:
"""Returns the name of the Jax variable as a string.
Expand Down

0 comments on commit d39ba9b

Please sign in to comment.