Skip to content

Commit

Permalink
Updated the translation context class.
Browse files Browse the repository at this point in the history
It is now less complex.
  • Loading branch information
philip-paul-mueller committed May 6, 2024
1 parent c185444 commit f139a18
Showing 1 changed file with 21 additions and 80 deletions.
101 changes: 21 additions & 80 deletions src/jace/translator/_translation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from __future__ import annotations

from collections.abc import MutableMapping, Sequence
from collections.abc import MutableMapping

import dace
from jax import core as jax_core
Expand Down Expand Up @@ -50,13 +50,13 @@ class _TranslationContext:
"""

__slots__ = (
"_sdfg",
"_start_state",
"_terminal_state",
"_jax_name_map",
"_inp_names",
"_out_names",
"_rev_idx",
"sdfg",
"start_state",
"terminal_state",
"jax_name_map",
"inp_names",
"out_names",
"rev_idx",
)

def __init__(
Expand All @@ -73,82 +73,23 @@ def __init__(
if isinstance(name, str) and not util._VALID_SDFG_OBJ_NAME.fullmatch(name):
raise ValueError(f"'{name}' is not a valid SDFG name.")

self._sdfg: dace.SDFG = dace.SDFG(name=(name or f"unnamed_SDFG_{id(self)}"))
self._start_state: dace.SDFGState = self._sdfg.add_state(
self.sdfg: dace.SDFG = dace.SDFG(name=(name or f"unnamed_SDFG_{id(self)}"))
self.start_state: dace.SDFGState = self.sdfg.add_state(
label="initial_state", is_start_block=True
)
self._terminal_state: dace.SDFGState = self._start_state
self._jax_name_map: MutableMapping[jax_core.Var | util.JaCeVar, str] = {}
self._inp_names: tuple[str, ...] = ()
self._out_names: tuple[str, ...] = ()
self._rev_idx: int = rev_idx
self.terminal_state: dace.SDFGState = self.start_state
self.jax_name_map: MutableMapping[jax_core.Var | util.JaCeVar, str] = {}
self.inp_names: tuple[str, ...] = ()
self.out_names: tuple[str, ...] = ()
self.rev_idx: int = rev_idx

def to_translated_jaxpr_sdfg(self) -> translator.TranslatedJaxprSDFG:
"""Transforms `self` into a `TranslatedJaxprSDFG`."""
return translator.TranslatedJaxprSDFG(
sdfg=self._sdfg,
start_state=self._start_state,
terminal_state=self._terminal_state,
jax_name_map=self._jax_name_map,
inp_names=self._inp_names,
out_names=self._out_names,
sdfg=self.sdfg,
start_state=self.start_state,
terminal_state=self.terminal_state,
jax_name_map=self.jax_name_map,
inp_names=self.inp_names,
out_names=self.out_names,
)

@property
def sdfg(self) -> dace.SDFG:
return self._sdfg

@property
def start_state(self) -> dace.SDFGState:
return self._start_state

@property
def terminal_state(self) -> dace.SDFGState:
return self._terminal_state

@terminal_state.setter
def terminal_state(
self,
new_term_state: dace.SDFGState,
) -> None:
self._terminal_state = new_term_state

@property
def jax_name_map(self) -> MutableMapping[jax_core.Var | util.JaCeVar, str]:
return self._jax_name_map

@property
def inp_names(self) -> tuple[str, ...]:
return self._inp_names

@inp_names.setter
def inp_names(
self,
inp_names: Sequence[str],
) -> None:
if isinstance(inp_names, str):
self._inp_names = (inp_names,)
elif isinstance(inp_names, tuple):
self._inp_names = inp_names
else:
self._inp_names = tuple(inp_names)

@property
def out_names(self) -> tuple[str, ...]:
return self._out_names

@out_names.setter
def out_names(
self,
out_names: Sequence[str],
) -> None:
if isinstance(out_names, str):
self._out_names = (out_names,)
elif isinstance(out_names, tuple):
self._out_names = out_names
else:
self._out_names = tuple(out_names)

@property
def rev_idx(self) -> int:
return self._rev_idx

0 comments on commit f139a18

Please sign in to comment.