Skip to content

Commit

Permalink
Fixing some import names.
Browse files Browse the repository at this point in the history
  • Loading branch information
philip-paul-mueller committed May 3, 2024
1 parent e2b496f commit c185444
Show file tree
Hide file tree
Showing 11 changed files with 132 additions and 123 deletions.
2 changes: 0 additions & 2 deletions src/jace/translator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@
from __future__ import annotations

from .jaxpr_translator_driver import JaxprTranslationDriver
from .primitive_translator import PrimitiveTranslator
from .translated_jaxpr_sdfg import TranslatedJaxprSDFG


__all__ = [
"PrimitiveTranslator",
"JaxprTranslationDriver",
"TranslatedJaxprSDFG",
]
14 changes: 7 additions & 7 deletions src/jace/translator/_translation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from collections.abc import MutableMapping, Sequence

import dace
from jax import core as jcore
from jax import core as jax_core

from jace import translator as jtrans, util as jutil
from jace import translator, util


class _TranslationContext:
Expand Down Expand Up @@ -70,22 +70,22 @@ def __init__(
rev_idx: The revision index of the context.
name: Name of the SDFG object.
"""
if isinstance(name, str) and not jutil._VALID_SDFG_OBJ_NAME.fullmatch(name):
if isinstance(name, str) and not util._VALID_SDFG_OBJ_NAME.fullmatch(name):
raise ValueError(f"'{name}' is not a valid SDFG name.")

self._sdfg: dace.SDFG = dace.SDFG(name=(name or f"unnamed_SDFG_{id(self)}"))
self._start_state: dace.SDFGState = self._sdfg.add_state(
label="initial_state", is_start_block=True
)
self._terminal_state: dace.SDFGState = self._start_state
self._jax_name_map: MutableMapping[jcore.Var | jutil.JaCeVar, str] = {}
self._jax_name_map: MutableMapping[jax_core.Var | util.JaCeVar, str] = {}
self._inp_names: tuple[str, ...] = ()
self._out_names: tuple[str, ...] = ()
self._rev_idx: int = rev_idx

def to_translated_jaxpr_sdfg(self) -> jtrans.TranslatedJaxprSDFG:
def to_translated_jaxpr_sdfg(self) -> translator.TranslatedJaxprSDFG:
"""Transforms `self` into a `TranslatedJaxprSDFG`."""
return jtrans.TranslatedJaxprSDFG(
return translator.TranslatedJaxprSDFG(
sdfg=self._sdfg,
start_state=self._start_state,
terminal_state=self._terminal_state,
Expand Down Expand Up @@ -114,7 +114,7 @@ def terminal_state(
self._terminal_state = new_term_state

@property
def jax_name_map(self) -> MutableMapping[jcore.Var | jutil.JaCeVar, str]:
def jax_name_map(self) -> MutableMapping[jax_core.Var | util.JaCeVar, str]:
return self._jax_name_map

@property
Expand Down
Loading

0 comments on commit c185444

Please sign in to comment.