From 587d8157af366b70f29ec853361ca250a3c3f29e Mon Sep 17 00:00:00 2001
From: Paul Balanca
Date: Fri, 28 Jun 2024 14:36:35 +0100
Subject: [PATCH] Adding fwd/bwd cast methods compatible with FP8.

Allowing casts to be applied only on the forward or backward pass respectively,
making it easier to build explicit FP8 code.
---
 .../mnist_classifier_from_scratch_fp8.py | 16 +++---
 examples/scalify-quickstart.ipynb        |  4 +-
 jax_scalify/core/typing.py               | 15 ++---
 jax_scalify/ops/__init__.py              |  7 ++-
 jax_scalify/ops/cast.py                  | 26 +++++++--
 jax_scalify/ops/rescaling.py             | 46 +++------------
 jax_scalify/ops/utils.py                 | 38 +++++++++++++
 tests/ops/test_cast.py                   | 57 +++++++++++++++++--
 8 files changed, 141 insertions(+), 68 deletions(-)
 create mode 100644 jax_scalify/ops/utils.py

diff --git a/examples/mnist/mnist_classifier_from_scratch_fp8.py b/examples/mnist/mnist_classifier_from_scratch_fp8.py
index 78da872..f53aa6e 100644
--- a/examples/mnist/mnist_classifier_from_scratch_fp8.py
+++ b/examples/mnist/mnist_classifier_from_scratch_fp8.py
@@ -59,18 +59,18 @@ def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
 def predict(params, inputs, use_fp8=True):
-    reduce_precision_dtype = jsa.ops.reduce_precision_dtype if use_fp8 else lambda x, d: x
-    reduce_precision_dtype_grad = jsa.ops.reduce_precision_dtype_grad if use_fp8 else lambda x, d: x
+    reduce_precision_on_forward = jsa.ops.reduce_precision_on_forward if use_fp8 else lambda x, d: x
+    reduce_precision_on_backward = jsa.ops.reduce_precision_on_backward if use_fp8 else lambda x, d: x
     activations = inputs
     for w, b in params[:-1]:
         # Forward FP8 casting.
-        w = reduce_precision_dtype(w, ml_dtypes.float8_e4m3fn)
-        activations = reduce_precision_dtype(activations, ml_dtypes.float8_e4m3fn)
+        w = reduce_precision_on_forward(w, ml_dtypes.float8_e4m3fn)
+        activations = reduce_precision_on_forward(activations, ml_dtypes.float8_e4m3fn)
         # Matmul
         outputs = jnp.dot(activations, w)
         # Backward FP8 casting
-        outputs = reduce_precision_dtype_grad(outputs, ml_dtypes.float8_e5m2)
+        outputs = reduce_precision_on_backward(outputs, ml_dtypes.float8_e5m2)
         # Bias + relu
         outputs = outputs + b
@@ -78,11 +78,11 @@ def predict(params, inputs, use_fp8=True):
     final_w, final_b = params[-1]
     # Forward FP8 casting.
-    # final_w = jsa.ops.reduce_precision_dtype(final_w, ml_dtypes.float8_e4m3fn)
-    activations = reduce_precision_dtype(activations, ml_dtypes.float8_e4m3fn)
+    # final_w = jsa.ops.reduce_precision_on_forward(final_w, ml_dtypes.float8_e4m3fn)
+    activations = reduce_precision_on_forward(activations, ml_dtypes.float8_e4m3fn)
     logits = jnp.dot(activations, final_w)
     # Backward FP8 casting
-    logits = reduce_precision_dtype_grad(logits, ml_dtypes.float8_e5m2)
+    logits = reduce_precision_on_backward(logits, ml_dtypes.float8_e5m2)
     logits = logits + final_b

diff --git a/examples/scalify-quickstart.ipynb b/examples/scalify-quickstart.ipynb
index 4bab499..0cfce92 100644
--- a/examples/scalify-quickstart.ipynb
+++ b/examples/scalify-quickstart.ipynb
@@ -279,12 +279,12 @@
    "source": [
     "import ml_dtypes\n",
     "# Minimal FP8 simulated support is provided using jax.lax.reduce_precision and ml_dtypes.\n",
-    "# Similarly to `dynamic_rescale`, `reduce_precision_dtype(_grad)` are available to cast in forward and backward passes\n",
+    "# Similarly to `dynamic_rescale`, `reduce_precision_on_forward` and `reduce_precision_on_backward` are available to cast in the forward and backward passes\n",
     "sc = jsa.as_scaled_array(np.array([17., 19.]), scale=np.float32(2))\n",
     "\n",
     "@jsa.scalify\n",
     "def cast_fn(v):\n",
-    "    return jsa.ops.reduce_precision_dtype(v, ml_dtypes.float8_e4m3fn)\n",
+    "    return jsa.ops.reduce_precision_on_forward(v, ml_dtypes.float8_e4m3fn)\n",
     "\n",
     "sc_fp8 = cast_fn(sc)\n",
     "print(\"Scaled input in FP32:\", sc)\n",

diff --git a/jax_scalify/core/typing.py b/jax_scalify/core/typing.py
index 08a92cf..e26cf95 100644
--- a/jax_scalify/core/typing.py
+++ b/jax_scalify/core/typing.py
@@ -4,16 +4,17 @@ #
 import chex
 import jax
 import jax.numpy as jnp
-import jaxlib
 import numpy as np
 
 # Type aliasing. To be compatible with JAX 0.3 as well.
-if jax.__version_info__[1] > 3:
-    Array = jax.Array
-    ArrayTypes = (jax.Array, jax.stages.ArgInfo)
-else:
-    Array = jaxlib.xla_extension.DeviceArray
-    ArrayTypes = (jaxlib.xla_extension.DeviceArray, jax.interpreters.partial_eval.DynamicJaxprTracer)
+try:
+    from jax import Array
+
+    ArrayTypes = (Array, jax.stages.ArgInfo)
+except ImportError:
+    from jaxlib.xla_extension import DeviceArray as Array
+
+    ArrayTypes = (Array, jax.interpreters.partial_eval.DynamicJaxprTracer)
 
 
 def get_numpy_api(val: Any) -> Any:

diff --git a/jax_scalify/ops/__init__.py b/jax_scalify/ops/__init__.py
index 6c163c0..7bb0557 100644
--- a/jax_scalify/ops/__init__.py
+++ b/jax_scalify/ops/__init__.py
@@ -1,5 +1,10 @@
 # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
-from .cast import reduce_precision_dtype, reduce_precision_dtype_grad  # noqa: F401
+from .cast import (  # noqa: F401
+    cast_on_backward,
+    cast_on_forward,
+    reduce_precision_on_backward,
+    reduce_precision_on_forward,
+)
 from .debug import debug_callback, debug_callback_grad, debug_print, debug_print_grad  # noqa: F401
 from .rescaling import (  # noqa: F401
     dynamic_rescale_l1,

diff --git a/jax_scalify/ops/cast.py b/jax_scalify/ops/cast.py
index 356b509..a66f2aa 100644
--- a/jax_scalify/ops/cast.py
+++ b/jax_scalify/ops/cast.py
@@ -6,7 +6,7 @@
 from jax_scalify.core import Array, DTypeLike
 
-from .rescaling import fn_bwd_identity_fwd, fn_fwd_identity_bwd
+from .utils import map_on_backward, map_on_forward
 
 
 def reduce_precision_dtype_base(arr: Array, dtype: DTypeLike) -> Array:
@@ -15,11 +15,27 @@ def reduce_precision_dtype_base(arr: Array, dtype: DTypeLike) -> Array:
     return jax.lax.reduce_precision(arr, exponent_bits=info.nexp, mantissa_bits=info.nmant)
 
 
-def reduce_precision_dtype(arr: Array, dtype: DTypeLike) -> Array:
+def reduce_precision_on_forward(arr: Array, dtype: DTypeLike) -> Array:
     """`Fake` cast to an ML dtype, on the forward pass (no-op on backward pass)."""
-    return partial(fn_fwd_identity_bwd, lambda v: reduce_precision_dtype_base(v, dtype))(arr)
+    return partial(map_on_forward, lambda v: reduce_precision_dtype_base(v, dtype))(arr)
 
 
-def reduce_precision_dtype_grad(arr: Array, dtype: DTypeLike) -> Array:
+def reduce_precision_on_backward(arr: Array, dtype: DTypeLike) -> Array:
     """`Fake` cast to an ML dtype on the backward pass (no-op on forward pass)."""
-    return partial(fn_bwd_identity_fwd, lambda v: reduce_precision_dtype_base(v, dtype))(arr)
+    return partial(map_on_backward, lambda v: reduce_precision_dtype_base(v, dtype))(arr)
+
+
+def cast_on_forward(arr: Array, dtype: DTypeLike) -> Array:
+    """Cast input array only on the forward pass (no-op on the backward pass).
+
+    Useful for implementing `DenseGeneral` FP8 matmuls.
+    """
+    return partial(map_on_forward, lambda v: jax.lax.convert_element_type(v, dtype))(arr)
+
+
+def cast_on_backward(arr: Array, dtype: DTypeLike) -> Array:
+    """Cast input array only on the backward pass (no-op on the forward pass).
+
+    Useful for implementing `DenseGeneral` FP8 matmuls.
+ """ + return partial(map_on_backward, lambda v: jax.lax.convert_element_type(v, dtype))(arr) diff --git a/jax_scalify/ops/rescaling.py b/jax_scalify/ops/rescaling.py index 62f6b44..3debbdb 100644 --- a/jax_scalify/ops/rescaling.py +++ b/jax_scalify/ops/rescaling.py @@ -7,39 +7,7 @@ from jax_scalify.core import ScaledArray, pow2_round, pow2_round_down from jax_scalify.lax import get_data_scale, rebalance - -@partial(jax.custom_vjp, nondiff_argnums=(0,)) -def fn_fwd_identity_bwd(f, arg): - """Function with identity bwd/grad.""" - return f(arg) - - -def fn_fwd_identity_bwd_fwd(f, arg): - return arg, None - - -def fn_fwd_identity_bwd_bwd(f, _, grad): - return (grad,) - - -fn_fwd_identity_bwd.defvjp(fn_fwd_identity_bwd_fwd, fn_fwd_identity_bwd_bwd) - - -@partial(jax.custom_vjp, nondiff_argnums=(0,)) -def fn_bwd_identity_fwd(f, arg): - """Apply a function on the gradient/backward pass.""" - return arg - - -def fn_bwd_identity_fwd_fwd(f, arg): - return arg, None - - -def fn_bwd_identity_fwd_bwd(f, _, grad): - return (f(grad),) - - -fn_bwd_identity_fwd.defvjp(fn_bwd_identity_fwd_fwd, fn_bwd_identity_fwd_bwd) +from .utils import map_on_backward, map_on_forward def dynamic_rescale_max_base(arr: ScaledArray) -> ScaledArray: @@ -97,11 +65,11 @@ def dynamic_rescale_l2_base(arr: ScaledArray) -> ScaledArray: # Dynamic rescale on fwd arrays. -dynamic_rescale_max = partial(fn_fwd_identity_bwd, dynamic_rescale_max_base) -dynamic_rescale_l1 = partial(fn_fwd_identity_bwd, dynamic_rescale_l1_base) -dynamic_rescale_l2 = partial(fn_fwd_identity_bwd, dynamic_rescale_l2_base) +dynamic_rescale_max = partial(map_on_forward, dynamic_rescale_max_base) +dynamic_rescale_l1 = partial(map_on_forward, dynamic_rescale_l1_base) +dynamic_rescale_l2 = partial(map_on_forward, dynamic_rescale_l2_base) # Dynamic rescale on gradients. -dynamic_rescale_max_grad = partial(fn_bwd_identity_fwd, dynamic_rescale_max_base) -dynamic_rescale_l1_grad = partial(fn_bwd_identity_fwd, dynamic_rescale_l1_base) -dynamic_rescale_l2_grad = partial(fn_bwd_identity_fwd, dynamic_rescale_l2_base) +dynamic_rescale_max_grad = partial(map_on_backward, dynamic_rescale_max_base) +dynamic_rescale_l1_grad = partial(map_on_backward, dynamic_rescale_l1_base) +dynamic_rescale_l2_grad = partial(map_on_backward, dynamic_rescale_l2_base) diff --git a/jax_scalify/ops/utils.py b/jax_scalify/ops/utils.py new file mode 100644 index 0000000..6f8ea62 --- /dev/null +++ b/jax_scalify/ops/utils.py @@ -0,0 +1,38 @@ +# Copyright (c) 2024 Graphcore Ltd. All rights reserved. +from functools import partial + +import jax + + +@partial(jax.custom_vjp, nondiff_argnums=(0,)) +def map_on_forward(f, arg): + """Map a function on a forward pass only. No-op/identity on backward pass.""" + return f(arg) + + +def map_on_forward_fwd(f, arg): + return arg, None + + +def map_on_forward_bwd(f, _, grad): + return (grad,) + + +map_on_forward.defvjp(map_on_forward_fwd, map_on_forward_bwd) + + +@partial(jax.custom_vjp, nondiff_argnums=(0,)) +def map_on_backward(f, arg): + """Map a function on the gradient/backward pass. 
+    return arg
+
+
+def map_on_backward_fwd(f, arg):
+    return arg, None
+
+
+def map_on_backward_bwd(f, _, grad):
+    return (f(grad),)
+
+
+map_on_backward.defvjp(map_on_backward_fwd, map_on_backward_bwd)

diff --git a/tests/ops/test_cast.py b/tests/ops/test_cast.py
index 973b409..4fb55fc 100644
--- a/tests/ops/test_cast.py
+++ b/tests/ops/test_cast.py
@@ -2,6 +2,8 @@
 from functools import partial
 
 import chex
+import jax
+import jax.numpy as jnp
 import ml_dtypes
 import numpy as np
 import numpy.testing as npt
@@ -9,18 +11,18 @@
 from numpy.typing import NDArray
 
 from jax_scalify.core import scaled_array, scalify
-from jax_scalify.ops import reduce_precision_dtype
+from jax_scalify.ops import cast_on_backward, cast_on_forward, reduce_precision_on_forward
 
 
-class CastMLDtypeTests(chex.TestCase):
+class ReducePrecisionDtypeTests(chex.TestCase):
     @parameterized.parameters(
         {"ml_dtype": ml_dtypes.float8_e4m3fn},
         {"ml_dtype": ml_dtypes.float8_e5m2},
     )
-    def test__reduce_precision_dtype__consistent_rounding_down(self, ml_dtype):
+    def test__reduce_precision_on_forward__consistent_rounding_down(self, ml_dtype):
         # Values potentially "problematic" in FP8.
         values: NDArray[np.float16] = np.array([17, -17, 8, 1, 9, 11, 18], np.float16)
-        out = reduce_precision_dtype(values, dtype=ml_dtype)
+        out = reduce_precision_on_forward(values, dtype=ml_dtype)
         expected_out = values.astype(ml_dtype)
         assert out.dtype == values.dtype
         npt.assert_array_equal(out, expected_out)
@@ -29,10 +31,53 @@ def test__reduce_precision_dtype__consistent_rounding_down(self, ml_dtype):
         {"ml_dtype": ml_dtypes.float8_e4m3fn},
         {"ml_dtype": ml_dtypes.float8_e5m2},
     )
-    def test__reduce_precision_dtype__scalify_compatiblity(self, ml_dtype):
+    def test__reduce_precision_on_forward__scalify_compatibility(self, ml_dtype):
         values: NDArray[np.float16] = np.array([17, -17, 8, 1, 9, 11, 18], np.float16)
         arr = scaled_array(values, np.float32(1))
-        out = scalify(partial(reduce_precision_dtype, dtype=ml_dtype))(arr)
+        out = scalify(partial(reduce_precision_on_forward, dtype=ml_dtype))(arr)
 
         npt.assert_array_equal(out.scale, arr.scale)
         npt.assert_array_equal(out, np.asarray(arr.data).astype(ml_dtype))
+
+
+class CastOnForwardBackwardTests(chex.TestCase):
+    @chex.variants(with_jit=True, without_jit=True)
+    @parameterized.parameters(
+        {"dtype": jnp.float16},
+        # TODO: uncomment when JAX 0.4+ used
+        # {"dtype": jnp.float8_e4m3fn},
+        # {"dtype": jnp.float8_e5m2},
+    )
+    def test__cast_on_forward_backward__proper_results(self, dtype):
+        # Values potentially "problematic" in FP8.
+        values: NDArray[np.float16] = np.array([17, -17, 8, 1, 9, 11, 18], np.float16)
+        out_on_fwd = self.variant(partial(cast_on_forward, dtype=dtype))(values)
+        out_on_bwd = self.variant(partial(cast_on_backward, dtype=dtype))(values)
+
+        assert out_on_fwd.dtype == dtype
+        assert out_on_bwd.dtype == values.dtype
+        npt.assert_array_equal(out_on_fwd, jax.lax.convert_element_type(values, dtype))
+        npt.assert_array_equal(out_on_bwd, values)
+
+    @parameterized.parameters(
+        {"dtype": jnp.float16},
+        # TODO: uncomment when JAX 0.4+ used
+        # {"dtype": jnp.float8_e4m3fn},
+        # {"dtype": jnp.float8_e5m2},
+    )
+    def test__cast_on_backward__grad__proper_results(self, dtype):
+        def fn(val, with_cast):
+            if with_cast:
+                val = cast_on_backward(val, dtype=dtype)
+            val = val * val
+            return jax.lax.reduce_sum_p.bind(val, axes=(0,))
+
+        # Values potentially "problematic" in FP8.
+        values: NDArray[np.float32] = np.array([17, -17, 8, 1, 9, 11, 18], np.float32)
+        # Backward pass => gradient.
+        grads = jax.grad(partial(fn, with_cast=True))(values)
+        grads_ref = jax.grad(partial(fn, with_cast=False))(values)
+
+        assert grads.dtype == dtype
+        assert grads_ref.dtype == values.dtype
+        npt.assert_array_equal(grads, jax.lax.convert_element_type(grads_ref, dtype))
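
Usage sketch (illustrative only; the helper name fp8_dense, the shapes and the
random data are assumptions, not taken from the patch): the new ops compose the
same way as in the MNIST example above, with a "fake" FP8 cast of the operands
on the forward pass and of the output gradient on the backward pass.

    import jax.numpy as jnp
    import ml_dtypes
    import numpy as np

    import jax_scalify as jsa


    def fp8_dense(x, w):
        # Forward-pass-only "fake" cast of inputs and weights to FP8 E4M3;
        # gradients pass through these casts unchanged.
        x = jsa.ops.reduce_precision_on_forward(x, ml_dtypes.float8_e4m3fn)
        w = jsa.ops.reduce_precision_on_forward(w, ml_dtypes.float8_e4m3fn)
        out = jnp.dot(x, w)
        # Backward-pass-only "fake" cast of the output gradient to FP8 E5M2;
        # the forward value is left untouched.
        out = jsa.ops.reduce_precision_on_backward(out, ml_dtypes.float8_e5m2)
        return out


    x = np.random.rand(8, 16).astype(np.float16)
    w = np.random.rand(16, 4).astype(np.float16)
    y = fp8_dense(x, w)

The `cast_on_forward` / `cast_on_backward` variants follow the same pattern but
perform a real `jax.lax.convert_element_type` instead of a precision reduction,
as exercised by the new tests above.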