From 096b848e63072328f8812a45061a239f52ba303c Mon Sep 17 00:00:00 2001 From: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Date: Fri, 19 Jan 2024 17:36:07 -0800 Subject: [PATCH] Allow bfloat16 default dtype (#19074) Useful for llms! The tradeoff in precision can often be worth it in memory constrained environments, and unlike float16, does not have the same overflow/underflow issues during training. --- keras/backend/config.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/keras/backend/config.py b/keras/backend/config.py index 653291eb3f0..8725e27b56f 100644 --- a/keras/backend/config.py +++ b/keras/backend/config.py @@ -20,7 +20,7 @@ def floatx(): """Return the default float type, as a string. - E.g. `'float16'`, `'float32'`, `'float64'`. + E.g. `'bfloat16'`, `'float16'`, `'float32'`, `'float64'`. Returns: String, the current default float type. @@ -45,7 +45,7 @@ def set_floatx(value): `keras.mixed_precision.set_dtype_policy('mixed_float16')`. Args: - value: String; `'float16'`, `'float32'`, or `'float64'`. + value: String; `'bfloat16'`, `'float16'`, `'float32'`, or `'float64'`. Examples: >>> keras.config.floatx() @@ -62,7 +62,7 @@ def set_floatx(value): ValueError: In case of invalid value. """ global _FLOATX - accepted_dtypes = {"float16", "float32", "float64"} + accepted_dtypes = {"bfloat16", "float16", "float32", "float64"} if value not in accepted_dtypes: raise ValueError( f"Unknown `floatx` value: {value}. "