diff --git a/keras/src/losses/loss.py b/keras/src/losses/loss.py index c73690e6f62..227c43b2128 100644 --- a/keras/src/losses/loss.py +++ b/keras/src/losses/loss.py @@ -39,7 +39,7 @@ def call(self, y_true, y_pred): def __init__(self, name=None, reduction="sum_over_batch_size", dtype=None): self.name = name or auto_name(self.__class__.__name__) self.reduction = standardize_reduction(reduction) - self._dtype_policy = dtype_policies.get(dtype) + self._dtype_policy = dtype_policies.get(dtype or backend.floatx()) self._dtype = self._dtype_policy.compute_dtype @property diff --git a/keras/src/losses/loss_test.py b/keras/src/losses/loss_test.py index fd120fcb9e3..3f13bc96725 100644 --- a/keras/src/losses/loss_test.py +++ b/keras/src/losses/loss_test.py @@ -18,6 +18,16 @@ def call(self, y_true, y_pred): class LossTest(testing.TestCase): + def setUp(self): + self._global_dtype_policy = dtype_policies.dtype_policy.dtype_policy() + self._floatx = backend.floatx() + return super().setUp() + + def tearDown(self): + dtype_policies.dtype_policy.set_dtype_policy(self._global_dtype_policy) + backend.set_floatx(self._floatx) + return super().tearDown() + def test_squeeze_or_expand(self): x1 = ops.ones((3,)) x2 = ops.ones((3, 1)) @@ -262,3 +272,18 @@ def test_dtype_arg(self): # `dtype` setter should raise AttributeError with self.assertRaises(AttributeError): loss.dtype = "bfloat16" + + def test_default_dtype(self): + y_true = np.array([1.0, 0.0, 1.0, 0.0], dtype="float32") + y_pred = np.array([0.1, 0.2, 0.3, 0.4], dtype="float32") + + # Defaults to `keras.config.floatx()` not global `dtype_policy` + dtype_policies.dtype_policy.set_dtype_policy("mixed_float16") + loss_fn = ExampleLoss() + loss = loss_fn(y_true, y_pred) + self.assertDType(loss, "float32") + + backend.set_floatx("float16") + loss_fn = ExampleLoss() + loss = loss_fn(y_true, y_pred) + self.assertDType(loss, backend.floatx()) diff --git a/keras/src/metrics/metric.py b/keras/src/metrics/metric.py index 42c9c5b2f32..b9417ece200 100644 --- a/keras/src/metrics/metric.py +++ b/keras/src/metrics/metric.py @@ -91,7 +91,7 @@ def result(self): def __init__(self, dtype=None, name=None): self.name = name or auto_name(self.__class__.__name__) - self._dtype_policy = dtype_policies.get(dtype) + self._dtype_policy = dtype_policies.get(dtype or backend.floatx()) self._dtype = self._dtype_policy.compute_dtype self._metrics = [] self._variables = [] diff --git a/keras/src/metrics/metric_test.py b/keras/src/metrics/metric_test.py index 292f4bff7ce..673ee4ea0f7 100644 --- a/keras/src/metrics/metric_test.py +++ b/keras/src/metrics/metric_test.py @@ -44,6 +44,16 @@ def reset_state(self): class MetricTest(testing.TestCase): + def setUp(self): + self._global_dtype_policy = dtype_policies.dtype_policy.dtype_policy() + self._floatx = backend.floatx() + return super().setUp() + + def tearDown(self): + dtype_policies.dtype_policy.set_dtype_policy(self._global_dtype_policy) + backend.set_floatx(self._floatx) + return super().tearDown() + def test_end_to_end_flow(self): metric = ExampleMetric(name="mse") self.assertEqual(metric.name, "mse") @@ -228,3 +238,20 @@ def test_dtype_arg(self): # `dtype` setter should raise AttributeError with self.assertRaises(AttributeError): metric.dtype = "bfloat16" + + def test_default_dtype(self): + y_true = np.random.random((10, 3)) + y_pred = np.random.random((10, 3)) + + # Defaults to `keras.config.floatx()` not global `dtype_policy` + dtype_policies.dtype_policy.set_dtype_policy("mixed_float16") + metric = ExampleMetric() + metric.update_state(y_true, y_pred) + result = metric.result() + self.assertDType(result, "float32") + + backend.set_floatx("float16") + metric = ExampleMetric() + metric.update_state(y_true, y_pred) + result = metric.result() + self.assertDType(result, backend.floatx())