From f139a180a7e15f373602f2d1f8f445fc870a8df0 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 6 May 2024 07:19:58 +0200 Subject: [PATCH] Updated the translation context class. It is now less complex. --- src/jace/translator/_translation_context.py | 101 ++++---------------- 1 file changed, 21 insertions(+), 80 deletions(-) diff --git a/src/jace/translator/_translation_context.py b/src/jace/translator/_translation_context.py index 6d9c43f..15dd47f 100644 --- a/src/jace/translator/_translation_context.py +++ b/src/jace/translator/_translation_context.py @@ -9,7 +9,7 @@ from __future__ import annotations -from collections.abc import MutableMapping, Sequence +from collections.abc import MutableMapping import dace from jax import core as jax_core @@ -50,13 +50,13 @@ class _TranslationContext: """ __slots__ = ( - "_sdfg", - "_start_state", - "_terminal_state", - "_jax_name_map", - "_inp_names", - "_out_names", - "_rev_idx", + "sdfg", + "start_state", + "terminal_state", + "jax_name_map", + "inp_names", + "out_names", + "rev_idx", ) def __init__( @@ -73,82 +73,23 @@ def __init__( if isinstance(name, str) and not util._VALID_SDFG_OBJ_NAME.fullmatch(name): raise ValueError(f"'{name}' is not a valid SDFG name.") - self._sdfg: dace.SDFG = dace.SDFG(name=(name or f"unnamed_SDFG_{id(self)}")) - self._start_state: dace.SDFGState = self._sdfg.add_state( + self.sdfg: dace.SDFG = dace.SDFG(name=(name or f"unnamed_SDFG_{id(self)}")) + self.start_state: dace.SDFGState = self.sdfg.add_state( label="initial_state", is_start_block=True ) - self._terminal_state: dace.SDFGState = self._start_state - self._jax_name_map: MutableMapping[jax_core.Var | util.JaCeVar, str] = {} - self._inp_names: tuple[str, ...] = () - self._out_names: tuple[str, ...] = () - self._rev_idx: int = rev_idx + self.terminal_state: dace.SDFGState = self.start_state + self.jax_name_map: MutableMapping[jax_core.Var | util.JaCeVar, str] = {} + self.inp_names: tuple[str, ...] = () + self.out_names: tuple[str, ...] = () + self.rev_idx: int = rev_idx def to_translated_jaxpr_sdfg(self) -> translator.TranslatedJaxprSDFG: """Transforms `self` into a `TranslatedJaxprSDFG`.""" return translator.TranslatedJaxprSDFG( - sdfg=self._sdfg, - start_state=self._start_state, - terminal_state=self._terminal_state, - jax_name_map=self._jax_name_map, - inp_names=self._inp_names, - out_names=self._out_names, + sdfg=self.sdfg, + start_state=self.start_state, + terminal_state=self.terminal_state, + jax_name_map=self.jax_name_map, + inp_names=self.inp_names, + out_names=self.out_names, ) - - @property - def sdfg(self) -> dace.SDFG: - return self._sdfg - - @property - def start_state(self) -> dace.SDFGState: - return self._start_state - - @property - def terminal_state(self) -> dace.SDFGState: - return self._terminal_state - - @terminal_state.setter - def terminal_state( - self, - new_term_state: dace.SDFGState, - ) -> None: - self._terminal_state = new_term_state - - @property - def jax_name_map(self) -> MutableMapping[jax_core.Var | util.JaCeVar, str]: - return self._jax_name_map - - @property - def inp_names(self) -> tuple[str, ...]: - return self._inp_names - - @inp_names.setter - def inp_names( - self, - inp_names: Sequence[str], - ) -> None: - if isinstance(inp_names, str): - self._inp_names = (inp_names,) - elif isinstance(inp_names, tuple): - self._inp_names = inp_names - else: - self._inp_names = tuple(inp_names) - - @property - def out_names(self) -> tuple[str, ...]: - return self._out_names - - @out_names.setter - def out_names( - self, - out_names: Sequence[str], - ) -> None: - if isinstance(out_names, str): - self._out_names = (out_names,) - elif isinstance(out_names, tuple): - self._out_names = out_names - else: - self._out_names = tuple(out_names) - - @property - def rev_idx(self) -> int: - return self._rev_idx