From a23c2bb14580e16a79597262d65a2fe0aa283090 Mon Sep 17 00:00:00 2001 From: hertschuh <1091026+hertschuh@users.noreply.github.com> Date: Tue, 10 Sep 2024 21:02:47 -0700 Subject: [PATCH 1/2] `TestCase`'s `assertAllClose` and `assertAlmostEqual` now report the provided error message. (#20248) --- keras/src/testing/test_case.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras/src/testing/test_case.py b/keras/src/testing/test_case.py index d5a8f7d779f..0d930f46dc4 100644 --- a/keras/src/testing/test_case.py +++ b/keras/src/testing/test_case.py @@ -43,7 +43,7 @@ def assertAllClose(self, x1, x2, atol=1e-6, rtol=1e-6, msg=None): x1 = backend.convert_to_numpy(x1) if not isinstance(x2, np.ndarray): x2 = backend.convert_to_numpy(x2) - np.testing.assert_allclose(x1, x2, atol=atol, rtol=rtol) + np.testing.assert_allclose(x1, x2, atol=atol, rtol=rtol, err_msg=msg) def assertNotAllClose(self, x1, x2, atol=1e-6, rtol=1e-6, msg=None): try: @@ -62,7 +62,7 @@ def assertAlmostEqual(self, x1, x2, decimal=3, msg=None): x1 = backend.convert_to_numpy(x1) if not isinstance(x2, np.ndarray): x2 = backend.convert_to_numpy(x2) - np.testing.assert_almost_equal(x1, x2, decimal=decimal) + np.testing.assert_almost_equal(x1, x2, decimal=decimal, err_msg=msg) def assertAllEqual(self, x1, x2, msg=None): self.assertEqual(len(x1), len(x2), msg=msg) From 698cc2f1486b5f80be31ce4fdcdc42a3e6b0fc91 Mon Sep 17 00:00:00 2001 From: Sanskar Modi Date: Thu, 12 Sep 2024 01:23:19 +0530 Subject: [PATCH 2/2] added validation checks in Group, Layer, Batch Normalization layers (#20246) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * added validation checks in Group, Layer, Batch Normalization layers for the compute_output_shape function * Update batch_normalization.py * Update group_normalization.py * Update layer_normalization.py --------- Co-authored-by: François Chollet --- FE | 0 .../layers/normalization/batch_normalization.py | 12 ++++++++++++ .../layers/normalization/group_normalization.py | 12 ++++++++++++ .../layers/normalization/layer_normalization.py | 14 +++++++++++++- 4 files changed, 37 insertions(+), 1 deletion(-) create mode 100644 FE diff --git a/FE b/FE new file mode 100644 index 00000000000..e69de29bb2d diff --git a/keras/src/layers/normalization/batch_normalization.py b/keras/src/layers/normalization/batch_normalization.py index 8b0160344ee..67f12cab956 100644 --- a/keras/src/layers/normalization/batch_normalization.py +++ b/keras/src/layers/normalization/batch_normalization.py @@ -219,6 +219,18 @@ def build(self, input_shape): self.built = True def compute_output_shape(self, input_shape): + if isinstance(self.axis, int): + axes = [self.axis] + else: + axes = self.axis + + for axis in axes: + if axis >= len(input_shape) or axis < -len(input_shape): + raise ValueError( + f"Axis {axis} is out of bounds for " + f"input shape {input_shape}. " + f"Received: axis={self.axis}" + ) return input_shape def call(self, inputs, training=None, mask=None): diff --git a/keras/src/layers/normalization/group_normalization.py b/keras/src/layers/normalization/group_normalization.py index f70fb69f3ed..c547c99a6b9 100644 --- a/keras/src/layers/normalization/group_normalization.py +++ b/keras/src/layers/normalization/group_normalization.py @@ -199,6 +199,18 @@ def _create_broadcast_shape(self, input_shape): return broadcast_shape def compute_output_shape(self, input_shape): + if isinstance(self.axis, int): + axes = [self.axis] + else: + axes = self.axis + + for axis in axes: + if axis >= len(input_shape) or axis < -len(input_shape): + raise ValueError( + f"Axis {axis} is out of bounds for " + f"input shape {input_shape}. " + f"Received: axis={self.axis}" + ) return input_shape def get_config(self): diff --git a/keras/src/layers/normalization/layer_normalization.py b/keras/src/layers/normalization/layer_normalization.py index 73a8956fd8f..6df9fd1df89 100644 --- a/keras/src/layers/normalization/layer_normalization.py +++ b/keras/src/layers/normalization/layer_normalization.py @@ -117,7 +117,7 @@ def __init__( gamma_regularizer=None, beta_constraint=None, gamma_constraint=None, - **kwargs + **kwargs, ): super().__init__(**kwargs) if isinstance(axis, (list, tuple)): @@ -235,6 +235,18 @@ def _broadcast(v): return ops.cast(outputs, input_dtype) def compute_output_shape(self, input_shape): + if isinstance(self.axis, int): + axes = [self.axis] + else: + axes = self.axis + + for axis in axes: + if axis >= len(input_shape) or axis < -len(input_shape): + raise ValueError( + f"Axis {axis} is out of bounds for " + f"input shape {input_shape}. " + f"Received: axis={self.axis}" + ) return input_shape def get_config(self):