Skip to content

Commit

Permalink
Merge branch 'master' of github.com:keras-team/keras
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Sep 11, 2024
2 parents c20c69a + 698cc2f commit 359ef98
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 3 deletions.
Empty file added FE
Empty file.
12 changes: 12 additions & 0 deletions keras/src/layers/normalization/batch_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,18 @@ def build(self, input_shape):
self.built = True

def compute_output_shape(self, input_shape):
if isinstance(self.axis, int):
axes = [self.axis]
else:
axes = self.axis

for axis in axes:
if axis >= len(input_shape) or axis < -len(input_shape):
raise ValueError(
f"Axis {axis} is out of bounds for "
f"input shape {input_shape}. "
f"Received: axis={self.axis}"
)
return input_shape

def call(self, inputs, training=None, mask=None):
Expand Down
12 changes: 12 additions & 0 deletions keras/src/layers/normalization/group_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,18 @@ def _create_broadcast_shape(self, input_shape):
return broadcast_shape

def compute_output_shape(self, input_shape):
if isinstance(self.axis, int):
axes = [self.axis]
else:
axes = self.axis

for axis in axes:
if axis >= len(input_shape) or axis < -len(input_shape):
raise ValueError(
f"Axis {axis} is out of bounds for "
f"input shape {input_shape}. "
f"Received: axis={self.axis}"
)
return input_shape

def get_config(self):
Expand Down
14 changes: 13 additions & 1 deletion keras/src/layers/normalization/layer_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def __init__(
gamma_regularizer=None,
beta_constraint=None,
gamma_constraint=None,
**kwargs
**kwargs,
):
super().__init__(**kwargs)
if isinstance(axis, (list, tuple)):
Expand Down Expand Up @@ -235,6 +235,18 @@ def _broadcast(v):
return ops.cast(outputs, input_dtype)

def compute_output_shape(self, input_shape):
if isinstance(self.axis, int):
axes = [self.axis]
else:
axes = self.axis

for axis in axes:
if axis >= len(input_shape) or axis < -len(input_shape):
raise ValueError(
f"Axis {axis} is out of bounds for "
f"input shape {input_shape}. "
f"Received: axis={self.axis}"
)
return input_shape

def get_config(self):
Expand Down
4 changes: 2 additions & 2 deletions keras/src/testing/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def assertAllClose(self, x1, x2, atol=1e-6, rtol=1e-6, msg=None):
x1 = backend.convert_to_numpy(x1)
if not isinstance(x2, np.ndarray):
x2 = backend.convert_to_numpy(x2)
np.testing.assert_allclose(x1, x2, atol=atol, rtol=rtol)
np.testing.assert_allclose(x1, x2, atol=atol, rtol=rtol, err_msg=msg)

def assertNotAllClose(self, x1, x2, atol=1e-6, rtol=1e-6, msg=None):
try:
Expand All @@ -62,7 +62,7 @@ def assertAlmostEqual(self, x1, x2, decimal=3, msg=None):
x1 = backend.convert_to_numpy(x1)
if not isinstance(x2, np.ndarray):
x2 = backend.convert_to_numpy(x2)
np.testing.assert_almost_equal(x1, x2, decimal=decimal)
np.testing.assert_almost_equal(x1, x2, decimal=decimal, err_msg=msg)

def assertAllEqual(self, x1, x2, msg=None):
self.assertEqual(len(x1), len(x2), msg=msg)
Expand Down

0 comments on commit 359ef98

Please sign in to comment.