Skip to content

Commit

Permalink
Revert to the original default behavior of dtype (keras-team#20014)
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 authored Jul 19, 2024
1 parent 556b34f commit 8ca7cd8
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 2 deletions.
2 changes: 1 addition & 1 deletion keras/src/losses/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def call(self, y_true, y_pred):
def __init__(self, name=None, reduction="sum_over_batch_size", dtype=None):
self.name = name or auto_name(self.__class__.__name__)
self.reduction = standardize_reduction(reduction)
self._dtype_policy = dtype_policies.get(dtype)
self._dtype_policy = dtype_policies.get(dtype or backend.floatx())
self._dtype = self._dtype_policy.compute_dtype

@property
Expand Down
25 changes: 25 additions & 0 deletions keras/src/losses/loss_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ def call(self, y_true, y_pred):


class LossTest(testing.TestCase):
def setUp(self):
self._global_dtype_policy = dtype_policies.dtype_policy.dtype_policy()
self._floatx = backend.floatx()
return super().setUp()

def tearDown(self):
dtype_policies.dtype_policy.set_dtype_policy(self._global_dtype_policy)
backend.set_floatx(self._floatx)
return super().tearDown()

def test_squeeze_or_expand(self):
x1 = ops.ones((3,))
x2 = ops.ones((3, 1))
Expand Down Expand Up @@ -262,3 +272,18 @@ def test_dtype_arg(self):
# `dtype` setter should raise AttributeError
with self.assertRaises(AttributeError):
loss.dtype = "bfloat16"

def test_default_dtype(self):
y_true = np.array([1.0, 0.0, 1.0, 0.0], dtype="float32")
y_pred = np.array([0.1, 0.2, 0.3, 0.4], dtype="float32")

# Defaults to `keras.config.floatx()` not global `dtype_policy`
dtype_policies.dtype_policy.set_dtype_policy("mixed_float16")
loss_fn = ExampleLoss()
loss = loss_fn(y_true, y_pred)
self.assertDType(loss, "float32")

backend.set_floatx("float16")
loss_fn = ExampleLoss()
loss = loss_fn(y_true, y_pred)
self.assertDType(loss, backend.floatx())
2 changes: 1 addition & 1 deletion keras/src/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def result(self):

def __init__(self, dtype=None, name=None):
self.name = name or auto_name(self.__class__.__name__)
self._dtype_policy = dtype_policies.get(dtype)
self._dtype_policy = dtype_policies.get(dtype or backend.floatx())
self._dtype = self._dtype_policy.compute_dtype
self._metrics = []
self._variables = []
Expand Down
27 changes: 27 additions & 0 deletions keras/src/metrics/metric_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ def reset_state(self):


class MetricTest(testing.TestCase):
def setUp(self):
self._global_dtype_policy = dtype_policies.dtype_policy.dtype_policy()
self._floatx = backend.floatx()
return super().setUp()

def tearDown(self):
dtype_policies.dtype_policy.set_dtype_policy(self._global_dtype_policy)
backend.set_floatx(self._floatx)
return super().tearDown()

def test_end_to_end_flow(self):
metric = ExampleMetric(name="mse")
self.assertEqual(metric.name, "mse")
Expand Down Expand Up @@ -228,3 +238,20 @@ def test_dtype_arg(self):
# `dtype` setter should raise AttributeError
with self.assertRaises(AttributeError):
metric.dtype = "bfloat16"

def test_default_dtype(self):
y_true = np.random.random((10, 3))
y_pred = np.random.random((10, 3))

# Defaults to `keras.config.floatx()` not global `dtype_policy`
dtype_policies.dtype_policy.set_dtype_policy("mixed_float16")
metric = ExampleMetric()
metric.update_state(y_true, y_pred)
result = metric.result()
self.assertDType(result, "float32")

backend.set_floatx("float16")
metric = ExampleMetric()
metric.update_state(y_true, y_pred)
result = metric.result()
self.assertDType(result, backend.floatx())

0 comments on commit 8ca7cd8

Please sign in to comment.