Skip to content

Commit

Permalink
Allow bfloat16 default dtype (#19074)
Browse files Browse the repository at this point in the history
Useful for llms! The tradeoff in precision can often be worth it in
memory constrained environments, and unlike float16, does not have
the same overflow/underflow issues during training.
  • Loading branch information
mattdangerw committed Jan 20, 2024
1 parent 2feb430 commit 096b848
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions keras/backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
def floatx():
"""Return the default float type, as a string.
E.g. `'float16'`, `'float32'`, `'float64'`.
E.g. `'bfloat16'`, `'float16'`, `'float32'`, `'float64'`.
Returns:
String, the current default float type.
Expand All @@ -45,7 +45,7 @@ def set_floatx(value):
`keras.mixed_precision.set_dtype_policy('mixed_float16')`.
Args:
value: String; `'float16'`, `'float32'`, or `'float64'`.
value: String; `'bfloat16'`, `'float16'`, `'float32'`, or `'float64'`.
Examples:
>>> keras.config.floatx()
Expand All @@ -62,7 +62,7 @@ def set_floatx(value):
ValueError: In case of invalid value.
"""
global _FLOATX
accepted_dtypes = {"float16", "float32", "float64"}
accepted_dtypes = {"bfloat16", "float16", "float32", "float64"}
if value not in accepted_dtypes:
raise ValueError(
f"Unknown `floatx` value: {value}. "
Expand Down

0 comments on commit 096b848

Please sign in to comment.