From e8b0a1fd741eb380fa6ae11fba1c91091b10e5f8 Mon Sep 17 00:00:00 2001 From: IvyZX Date: Tue, 10 Sep 2024 15:52:06 -0700 Subject: [PATCH] Merge nnx.errors to flax.errors --- flax/errors.py | 9 +++++++++ flax/nnx/errors.py | 17 ----------------- flax/nnx/object.py | 2 +- flax/nnx/variables.py | 4 ++-- tests/nnx/module_test.py | 4 ++-- tests/nnx/rngs_test.py | 5 +++-- 6 files changed, 17 insertions(+), 24 deletions(-) delete mode 100644 flax/nnx/errors.py diff --git a/flax/errors.py b/flax/errors.py index 7284c6e3fb..b2ecd1be69 100644 --- a/flax/errors.py +++ b/flax/errors.py @@ -64,6 +64,15 @@ def __reduce__(self): return (FlaxError, (str(self),)) +################################################# +# NNX errors # +################################################# + + +class TraceContextError(FlaxError): + pass + + ################################################# # lazy_init.py errors # ################################################# diff --git a/flax/nnx/errors.py b/flax/nnx/errors.py deleted file mode 100644 index 41c7d4fab5..0000000000 --- a/flax/nnx/errors.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2024 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -class TraceContextError(Exception): - pass diff --git a/flax/nnx/object.py b/flax/nnx/object.py index 9e14155108..f2714ff7fd 100644 --- a/flax/nnx/object.py +++ b/flax/nnx/object.py @@ -25,13 +25,13 @@ import numpy as np from flax.nnx import ( - errors, reprlib, tracers, ) from flax.nnx import graph from flax.nnx.variables import Variable, VariableState from flax.typing import Key +from flax import errors G = tp.TypeVar('G', bound='Object') diff --git a/flax/nnx/variables.py b/flax/nnx/variables.py index 76805477f5..645e6aac68 100644 --- a/flax/nnx/variables.py +++ b/flax/nnx/variables.py @@ -22,7 +22,7 @@ import jax -from flax import nnx +from flax import errors from flax.nnx import reprlib, tracers from flax.typing import Missing import jax.tree_util as jtu @@ -259,7 +259,7 @@ def __setattr__(self, name: str, value: Any) -> None: def _setattr(self, name: str, value: tp.Any): if not self._trace_state.is_valid(): - raise nnx.errors.TraceContextError( + raise errors.TraceContextError( f'Cannot mutate {type(self).__name__} from a different trace level' ) diff --git a/tests/nnx/module_test.py b/tests/nnx/module_test.py index d5aeae08cd..a3f7bf8c22 100644 --- a/tests/nnx/module_test.py +++ b/tests/nnx/module_test.py @@ -17,7 +17,7 @@ from typing import Any, TypeVar from absl.testing import absltest -from flax import nnx +from flax import nnx, errors import jax import jax.numpy as jnp import numpy as np @@ -39,7 +39,7 @@ def test_trace_level(self): @jax.jit def f(): with self.assertRaisesRegex( - nnx.errors.TraceContextError, + errors.TraceContextError, "Cannot mutate 'Dict' from different trace level", ): m.a = 2 diff --git a/tests/nnx/rngs_test.py b/tests/nnx/rngs_test.py index 0e42918264..eeb65ccaed 100644 --- a/tests/nnx/rngs_test.py +++ b/tests/nnx/rngs_test.py @@ -21,6 +21,7 @@ from absl.testing import absltest from flax import nnx +from flax import errors class TestRngs(absltest.TestCase): @@ -58,7 +59,7 @@ def test_rng_trace_level_constraints(self): @jax.jit def f(): with self.assertRaisesRegex( - nnx.errors.TraceContextError, + errors.TraceContextError, 'Cannot call RngStream from a different trace level', ): rngs.params() @@ -76,7 +77,7 @@ def h(): self.assertIsInstance(rngs1, nnx.Rngs) with self.assertRaisesRegex( - nnx.errors.TraceContextError, + errors.TraceContextError, 'Cannot call RngStream from a different trace level', ): rngs1.params()