Skip to content

Commit

Permalink
Merge pull request #4186 from IvyZX:errors
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 673510178
  • Loading branch information
Flax Authors committed Sep 11, 2024
2 parents 61ea8a6 + e8b0a1f commit b967964
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 24 deletions.
9 changes: 9 additions & 0 deletions flax/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ def __reduce__(self):
return (FlaxError, (str(self),))


#################################################
# NNX errors #
#################################################


class TraceContextError(FlaxError):
pass


#################################################
# lazy_init.py errors #
#################################################
Expand Down
17 changes: 0 additions & 17 deletions flax/nnx/errors.py

This file was deleted.

2 changes: 1 addition & 1 deletion flax/nnx/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@
import numpy as np

from flax.nnx import (
errors,
reprlib,
tracers,
)
from flax.nnx import graph
from flax.nnx.variables import Variable, VariableState
from flax.typing import Key
from flax import errors

G = tp.TypeVar('G', bound='Object')

Expand Down
4 changes: 2 additions & 2 deletions flax/nnx/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import jax

from flax import nnx
from flax import errors
from flax.nnx import reprlib, tracers
from flax.typing import Missing
import jax.tree_util as jtu
Expand Down Expand Up @@ -235,7 +235,7 @@ def __setattr__(self, name: str, value: Any) -> None:

def _setattr(self, name: str, value: tp.Any):
if not self._trace_state.is_valid():
raise nnx.errors.TraceContextError(
raise errors.TraceContextError(
f'Cannot mutate {type(self).__name__} from a different trace level'
)

Expand Down
4 changes: 2 additions & 2 deletions tests/nnx/module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Any, TypeVar

from absl.testing import absltest
from flax import nnx
from flax import nnx, errors
import jax
import jax.numpy as jnp
import numpy as np
Expand All @@ -39,7 +39,7 @@ def test_trace_level(self):
@jax.jit
def f():
with self.assertRaisesRegex(
nnx.errors.TraceContextError,
errors.TraceContextError,
"Cannot mutate 'Dict' from different trace level",
):
m.a = 2
Expand Down
5 changes: 3 additions & 2 deletions tests/nnx/rngs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from absl.testing import absltest

from flax import nnx
from flax import errors


class TestRngs(absltest.TestCase):
Expand Down Expand Up @@ -58,7 +59,7 @@ def test_rng_trace_level_constraints(self):
@jax.jit
def f():
with self.assertRaisesRegex(
nnx.errors.TraceContextError,
errors.TraceContextError,
'Cannot call RngStream from a different trace level',
):
rngs.params()
Expand All @@ -76,7 +77,7 @@ def h():

self.assertIsInstance(rngs1, nnx.Rngs)
with self.assertRaisesRegex(
nnx.errors.TraceContextError,
errors.TraceContextError,
'Cannot call RngStream from a different trace level',
):
rngs1.params()
Expand Down

0 comments on commit b967964

Please sign in to comment.