diff --git a/src/jace/translator/jaxpr_translator_driver.py b/src/jace/translator/jaxpr_translator_driver.py index 6ce6e4c..2fac9c4 100644 --- a/src/jace/translator/jaxpr_translator_driver.py +++ b/src/jace/translator/jaxpr_translator_driver.py @@ -104,7 +104,6 @@ def translate_jaxpr( name: str | None = None, reserved_names: str | Collection[str] | None = None, allow_empty_jaxpr: bool = False, - **kwargs: Any, ) -> translator.TranslatedJaxprSDFG: """Perform the translation of a Jaxpr into a SDFG. @@ -136,10 +135,6 @@ def translate_jaxpr( if not jax.config.read("jax_enable_x64"): raise NotImplementedError("The translation only works if 'jax_enable_x64' is enabled.") - # The point of this flag is, that one can have the translator, but still have access - # the the function of self, such as `add_array()` (is needed in later stages). - _clear_translation_ctx: bool = kwargs.pop("_clear_translation_ctx", True) - # NOTE: If `self` is already allocated, i.e. has an ongoing translation process, # the `_allocate_translation_ctx()` function will start a new context. # Thus the driver will start to translate a second (nested) SDFG. @@ -158,8 +153,7 @@ def translate_jaxpr( ) # Note that `self` and `jsdfg` still share the same underlying memory, i.e. context. jsdfg: translator.TranslatedJaxprSDFG = self._translate_jaxpr_internal(jaxpr) - if _clear_translation_ctx: - self._clear_translation_ctx() + self._clear_translation_ctx() return jsdfg