From f4dcd809b3891f161cf8c000b2e203001a5f5d55 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Fri, 28 Jun 2024 13:54:16 +0100 Subject: [PATCH] Renaming `cast_ml_dtype` methods to `reduce_precision_dtype`. Avoiding confusion with proper `cast/astype` methods when using FP8 hardware. --- .../mnist/mnist_classifier_from_scratch_fp8.py | 16 ++++++++-------- examples/scalify-quickstart.ipynb | 4 ++-- jax_scalify/ops/__init__.py | 2 +- jax_scalify/ops/ml_dtypes.py | 10 +++++----- tests/ops/test_ml_dtypes.py | 10 +++++----- 5 files changed, 21 insertions(+), 21 deletions(-) diff --git a/examples/mnist/mnist_classifier_from_scratch_fp8.py b/examples/mnist/mnist_classifier_from_scratch_fp8.py index 0de9219..78da872 100644 --- a/examples/mnist/mnist_classifier_from_scratch_fp8.py +++ b/examples/mnist/mnist_classifier_from_scratch_fp8.py @@ -59,18 +59,18 @@ def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)): def predict(params, inputs, use_fp8=True): - cast_ml_dtype = jsa.ops.cast_ml_dtype if use_fp8 else lambda x, d: x - cast_ml_dtype_grad = jsa.ops.cast_ml_dtype_grad if use_fp8 else lambda x, d: x + reduce_precision_dtype = jsa.ops.reduce_precision_dtype if use_fp8 else lambda x, d: x + reduce_precision_dtype_grad = jsa.ops.reduce_precision_dtype_grad if use_fp8 else lambda x, d: x activations = inputs for w, b in params[:-1]: # Forward FP8 casting. - w = cast_ml_dtype(w, ml_dtypes.float8_e4m3fn) - activations = cast_ml_dtype(activations, ml_dtypes.float8_e4m3fn) + w = reduce_precision_dtype(w, ml_dtypes.float8_e4m3fn) + activations = reduce_precision_dtype(activations, ml_dtypes.float8_e4m3fn) # Matmul outputs = jnp.dot(activations, w) # Backward FP8 casting - outputs = cast_ml_dtype_grad(outputs, ml_dtypes.float8_e5m2) + outputs = reduce_precision_dtype_grad(outputs, ml_dtypes.float8_e5m2) # Bias + relu outputs = outputs + b @@ -78,11 +78,11 @@ def predict(params, inputs, use_fp8=True): final_w, final_b = params[-1] # Forward FP8 casting. 
- # final_w = jsa.ops.cast_ml_dtype(final_w, ml_dtypes.float8_e4m3fn) - activations = cast_ml_dtype(activations, ml_dtypes.float8_e4m3fn) + # final_w = jsa.ops.reduce_precision_dtype(final_w, ml_dtypes.float8_e4m3fn) + activations = reduce_precision_dtype(activations, ml_dtypes.float8_e4m3fn) logits = jnp.dot(activations, final_w) # Backward FP8 casting - logits = cast_ml_dtype_grad(logits, ml_dtypes.float8_e5m2) + logits = reduce_precision_dtype_grad(logits, ml_dtypes.float8_e5m2) logits = logits + final_b diff --git a/examples/scalify-quickstart.ipynb b/examples/scalify-quickstart.ipynb index 32781ba..4bab499 100644 --- a/examples/scalify-quickstart.ipynb +++ b/examples/scalify-quickstart.ipynb @@ -279,12 +279,12 @@ "source": [ "import ml_dtypes\n", "# Minimal FP8 simulated support is provided using jax.lax.reduce_precision and ml_dtypes.\n", - "# Similarly to `dynamic_rescale`, `cast_ml_dtype(_grad)` are available to cast in forward and backward passes\n", + "# Similarly to `dynamic_rescale`, `reduce_precision_dtype(_grad)` are available to cast in forward and backward passes\n", "sc = jsa.as_scaled_array(np.array([17., 19.]), scale=np.float32(2))\n", "\n", "@jsa.scalify\n", "def cast_fn(v):\n", - " return jsa.ops.cast_ml_dtype(v, ml_dtypes.float8_e4m3fn)\n", + " return jsa.ops.reduce_precision_dtype(v, ml_dtypes.float8_e4m3fn)\n", "\n", "sc_fp8 = cast_fn(sc)\n", "print(\"Scaled input in FP32:\", sc)\n", diff --git a/jax_scalify/ops/__init__.py b/jax_scalify/ops/__init__.py index ea132b8..8283991 100644 --- a/jax_scalify/ops/__init__.py +++ b/jax_scalify/ops/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. from .debug import debug_callback, debug_callback_grad, debug_print, debug_print_grad # noqa: F401 -from .ml_dtypes import cast_ml_dtype, cast_ml_dtype_grad # noqa: F401 +from .ml_dtypes import reduce_precision_dtype, reduce_precision_dtype_grad # noqa: F401 from .rescaling import ( # noqa: F401 dynamic_rescale_l1, dynamic_rescale_l1_grad, diff --git a/jax_scalify/ops/ml_dtypes.py b/jax_scalify/ops/ml_dtypes.py index 0766e6e..356b509 100644 --- a/jax_scalify/ops/ml_dtypes.py +++ b/jax_scalify/ops/ml_dtypes.py @@ -9,17 +9,17 @@ from .rescaling import fn_bwd_identity_fwd, fn_fwd_identity_bwd -def cast_ml_dtype_base(arr: Array, dtype: DTypeLike) -> Array: +def reduce_precision_dtype_base(arr: Array, dtype: DTypeLike) -> Array: """`Fake` cast to an ML dtype (e.g. 
FP8), using JAX LAX `reduce_precision` operator.""" info = ml_dtypes.finfo(dtype) return jax.lax.reduce_precision(arr, exponent_bits=info.nexp, mantissa_bits=info.nmant) -def cast_ml_dtype(arr: Array, dtype: DTypeLike) -> Array: +def reduce_precision_dtype(arr: Array, dtype: DTypeLike) -> Array: """`Fake` cast to an ML dtype, on the forward pass (no-op on backward pass).""" - return partial(fn_fwd_identity_bwd, lambda v: cast_ml_dtype_base(v, dtype))(arr) + return partial(fn_fwd_identity_bwd, lambda v: reduce_precision_dtype_base(v, dtype))(arr) -def cast_ml_dtype_grad(arr: Array, dtype: DTypeLike) -> Array: +def reduce_precision_dtype_grad(arr: Array, dtype: DTypeLike) -> Array: """`Fake` cast to an ML dtype on the backward pass (no-op on forward pass).""" - return partial(fn_bwd_identity_fwd, lambda v: cast_ml_dtype_base(v, dtype))(arr) + return partial(fn_bwd_identity_fwd, lambda v: reduce_precision_dtype_base(v, dtype))(arr) diff --git a/tests/ops/test_ml_dtypes.py b/tests/ops/test_ml_dtypes.py index 75cee9d..973b409 100644 --- a/tests/ops/test_ml_dtypes.py +++ b/tests/ops/test_ml_dtypes.py @@ -9,7 +9,7 @@ from numpy.typing import NDArray from jax_scalify.core import scaled_array, scalify -from jax_scalify.ops import cast_ml_dtype +from jax_scalify.ops import reduce_precision_dtype class CastMLDtypeTests(chex.TestCase): @@ -17,10 +17,10 @@ class CastMLDtypeTests(chex.TestCase): {"ml_dtype": ml_dtypes.float8_e4m3fn}, {"ml_dtype": ml_dtypes.float8_e5m2}, ) - def test__cast_ml_dtype__consistent_rounding_down(self, ml_dtype): + def test__reduce_precision_dtype__consistent_rounding_down(self, ml_dtype): # Values potentially "problematic" in FP8. values: NDArray[np.float16] = np.array([17, -17, 8, 1, 9, 11, 18], np.float16) - out = cast_ml_dtype(values, dtype=ml_dtype) + out = reduce_precision_dtype(values, dtype=ml_dtype) expected_out = values.astype(ml_dtype) assert out.dtype == values.dtype npt.assert_array_equal(out, expected_out) @@ -29,10 +29,10 @@ def test__cast_ml_dtype__consistent_rounding_down(self, ml_dtype): {"ml_dtype": ml_dtypes.float8_e4m3fn}, {"ml_dtype": ml_dtypes.float8_e5m2}, ) - def test__cast_ml_dtype__scalify_compatiblity(self, ml_dtype): + def test__reduce_precision_dtype__scalify_compatiblity(self, ml_dtype): values: NDArray[np.float16] = np.array([17, -17, 8, 1, 9, 11, 18], np.float16) arr = scaled_array(values, np.float32(1)) - out = scalify(partial(cast_ml_dtype, dtype=ml_dtype))(arr) + out = scalify(partial(reduce_precision_dtype, dtype=ml_dtype))(arr) npt.assert_array_equal(out.scale, arr.scale) npt.assert_array_equal(out, np.asarray(arr.data).astype(ml_dtype))
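
Usage sketch (not part of the patch): a minimal example of the renamed API, assuming `jax_scalify` is installed at this revision together with `jax`, `numpy` and `ml_dtypes`, and imported as `jsa` as in the repository examples. Array values are illustrative only.

    # Minimal sketch of the renamed ops; assumes jax_scalify at this revision.
    import jax
    import jax.numpy as jnp
    import ml_dtypes
    import numpy as np

    import jax_scalify as jsa

    x = jnp.array([17.0, -17.0, 8.0, 1.0, 9.0], dtype=jnp.float16)

    # Forward-pass "fake" cast: values are rounded to float8_e4m3fn precision
    # via jax.lax.reduce_precision, but the array keeps its float16 dtype.
    y = jsa.ops.reduce_precision_dtype(x, ml_dtypes.float8_e4m3fn)
    print(y.dtype)        # float16
    print(np.asarray(y))  # values match x.astype(float8_e4m3fn)

    # Backward-pass variant: identity on the forward pass, rounds the incoming
    # gradient to float8_e5m2 precision on the backward pass.
    def loss(v):
        v = jsa.ops.reduce_precision_dtype_grad(v, ml_dtypes.float8_e5m2)
        return jnp.sum(v * v)

    grads = jax.grad(loss)(x)
    print(np.asarray(grads))  # 2 * x, with the gradient rounding applied

    # Both ops also compose with scalify-traced functions, as in the quickstart
    # notebook: the scale is propagated unchanged while the data is rounded.
    sc = jsa.as_scaled_array(np.array([17.0, 19.0], np.float32), scale=np.float32(2))

    @jsa.scalify
    def cast_fn(v):
        return jsa.ops.reduce_precision_dtype(v, ml_dtypes.float8_e4m3fn)

    sc_fp8 = cast_fn(sc)
    print("Scaled input in FP32:", sc)
    print("Scaled output in simulated FP8:", sc_fp8)

As the new `reduce_precision_dtype_base` shows, the rounding is driven entirely by `ml_dtypes.finfo(dtype)` (exponent and mantissa bit counts passed to `jax.lax.reduce_precision`), so the output dtype is unchanged; only the representable values are restricted to the target FP8 format.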