Adding fwd/bwd cast methods compatible with FP8.
Allow casts to be applied only on the forward or backward pass respectively, making it easier to build explicit FP8 code.
balancap committed Jun 28, 2024
1 parent c327a2f commit d06a8c4
Showing 8 changed files with 136 additions and 67 deletions.
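In a nutshell, the new ops pair a forward-only cast with a backward-only cast around a matmul. A minimal sketch of the intended usage, mirroring the MNIST example below (assuming `jax_scalify` is imported as `jsa`):

import jax.numpy as jnp
import ml_dtypes
import jax_scalify as jsa

def fp8_matmul(x, w):
    # Simulated FP8 inputs: cast applied on the forward pass only.
    x = jsa.ops.reduce_precision_on_forward(x, ml_dtypes.float8_e4m3fn)
    w = jsa.ops.reduce_precision_on_forward(w, ml_dtypes.float8_e4m3fn)
    out = jnp.dot(x, w)
    # Simulated FP8 gradient: cast applied on the backward pass only.
    return jsa.ops.reduce_precision_on_backward(out, ml_dtypes.float8_e5m2)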
16 changes: 8 additions & 8 deletions examples/mnist/mnist_classifier_from_scratch_fp8.py
@@ -59,30 +59,30 @@ def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):


def predict(params, inputs, use_fp8=True):
reduce_precision_dtype = jsa.ops.reduce_precision_dtype if use_fp8 else lambda x, d: x
reduce_precision_dtype_grad = jsa.ops.reduce_precision_dtype_grad if use_fp8 else lambda x, d: x
reduce_precision_on_forward = jsa.ops.reduce_precision_on_forward if use_fp8 else lambda x, d: x
reduce_precision_on_backward = jsa.ops.reduce_precision_on_backward if use_fp8 else lambda x, d: x

activations = inputs
for w, b in params[:-1]:
# Forward FP8 casting.
w = reduce_precision_dtype(w, ml_dtypes.float8_e4m3fn)
activations = reduce_precision_dtype(activations, ml_dtypes.float8_e4m3fn)
w = reduce_precision_on_forward(w, ml_dtypes.float8_e4m3fn)
activations = reduce_precision_on_forward(activations, ml_dtypes.float8_e4m3fn)
# Matmul
outputs = jnp.dot(activations, w)
# Backward FP8 casting
outputs = reduce_precision_dtype_grad(outputs, ml_dtypes.float8_e5m2)
outputs = reduce_precision_on_backward(outputs, ml_dtypes.float8_e5m2)

# Bias + relu
outputs = outputs + b
activations = jnp.maximum(outputs, 0)

final_w, final_b = params[-1]
# Forward FP8 casting.
# final_w = jsa.ops.reduce_precision_dtype(final_w, ml_dtypes.float8_e4m3fn)
activations = reduce_precision_dtype(activations, ml_dtypes.float8_e4m3fn)
# final_w = jsa.ops.reduce_precision_on_forward(final_w, ml_dtypes.float8_e4m3fn)
activations = reduce_precision_on_forward(activations, ml_dtypes.float8_e4m3fn)
logits = jnp.dot(activations, final_w)
# Backward FP8 casting
logits = reduce_precision_dtype_grad(logits, ml_dtypes.float8_e5m2)
logits = reduce_precision_on_backward(logits, ml_dtypes.float8_e5m2)

logits = logits + final_b

4 changes: 2 additions & 2 deletions examples/scalify-quickstart.ipynb
@@ -279,12 +279,12 @@
"source": [
"import ml_dtypes\n",
"# Minimal FP8 simulated support is provided using jax.lax.reduce_precision and ml_dtypes.\n",
"# Similarly to `dynamic_rescale`, `reduce_precision_dtype(_grad)` are available to cast in forward and backward passes\n",
"# Similarly to `dynamic_rescale`, `reduce_precision_on_forward(_grad)` are available to cast in forward and backward passes\n",
"sc = jsa.as_scaled_array(np.array([17., 19.]), scale=np.float32(2))\n",
"\n",
"@jsa.scalify\n",
"def cast_fn(v):\n",
" return jsa.ops.reduce_precision_dtype(v, ml_dtypes.float8_e4m3fn)\n",
" return jsa.ops.reduce_precision_on_forward(v, ml_dtypes.float8_e4m3fn)\n",
"\n",
"sc_fp8 = cast_fn(sc)\n",
"print(\"Scaled input in FP32:\", sc)\n",
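For context, the "fake" cast used above boils down to a single jax.lax.reduce_precision call parameterized by the target format's exponent and mantissa widths, as in jax_scalify/ops/cast.py below. A rough standalone sketch (the finfo lookup is an assumption here, since the defining line is elided in the hunk):

import jax
import ml_dtypes
import numpy as np

def fake_cast(arr, dtype):
    # Round values onto the target FP8 grid while keeping the original dtype.
    info = ml_dtypes.finfo(dtype)  # assumed; the actual line is not shown in the diff
    return jax.lax.reduce_precision(arr, exponent_bits=info.nexp, mantissa_bits=info.nmant)

x = np.array([17.0, 19.0], np.float32)
y = fake_cast(x, ml_dtypes.float8_e4m3fn)
print(y.dtype)  # float32: FP8-rounded values, unchanged dtype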
13 changes: 7 additions & 6 deletions jax_scalify/core/typing.py
@@ -4,16 +4,17 @@
# import chex
import jax
import jax.numpy as jnp
import jaxlib
import numpy as np

# Type aliasing. To be compatible with JAX 0.3 as well.
if jax.__version_info__[1] > 3:
Array = jax.Array
ArrayTypes = (jax.Array, jax.stages.ArgInfo)
if jax.__version_info__[1] < 3:
from jaxlib.xla_extension import DeviceArray as Array

ArrayTypes = (Array, jax.interpreters.partial_eval.DynamicJaxprTracer)
else:
Array = jaxlib.xla_extension.DeviceArray
ArrayTypes = (jaxlib.xla_extension.DeviceArray, jax.interpreters.partial_eval.DynamicJaxprTracer)
from jax import Array

ArrayTypes = (Array, jax.stages.ArgInfo)


def get_numpy_api(val: Any) -> Any:
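As a side note on how these aliases get consumed downstream (assumed usage, not part of the diff): `Array` is the canonical array type and `ArrayTypes` is a tuple suitable for isinstance checks covering both concrete arrays and staged-out arguments.

import jax.numpy as jnp
from jax_scalify.core.typing import ArrayTypes

x = jnp.ones((2,))
print(isinstance(x, ArrayTypes))  # True on recent JAX: jnp.ones returns a jax.Array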
7 changes: 6 additions & 1 deletion jax_scalify/ops/__init__.py
@@ -1,5 +1,10 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from .cast import reduce_precision_dtype, reduce_precision_dtype_grad # noqa: F401
from .cast import ( # noqa: F401
cast_on_backward,
cast_on_forward,
reduce_precision_on_backward,
reduce_precision_on_forward,
)
from .debug import debug_callback, debug_callback_grad, debug_print, debug_print_grad # noqa: F401
from .rescaling import ( # noqa: F401
dynamic_rescale_l1,
26 changes: 21 additions & 5 deletions jax_scalify/ops/cast.py
@@ -6,7 +6,7 @@

from jax_scalify.core import Array, DTypeLike

from .rescaling import fn_bwd_identity_fwd, fn_fwd_identity_bwd
from .utils import map_on_backward, map_on_forward


def reduce_precision_dtype_base(arr: Array, dtype: DTypeLike) -> Array:
@@ -15,11 +15,27 @@ def reduce_precision_dtype_base(arr: Array, dtype: DTypeLike) -> Array:
return jax.lax.reduce_precision(arr, exponent_bits=info.nexp, mantissa_bits=info.nmant)


def reduce_precision_dtype(arr: Array, dtype: DTypeLike) -> Array:
def reduce_precision_on_forward(arr: Array, dtype: DTypeLike) -> Array:
"""`Fake` cast to an ML dtype, on the forward pass (no-op on backward pass)."""
return partial(fn_fwd_identity_bwd, lambda v: reduce_precision_dtype_base(v, dtype))(arr)
return partial(map_on_forward, lambda v: reduce_precision_dtype_base(v, dtype))(arr)


def reduce_precision_dtype_grad(arr: Array, dtype: DTypeLike) -> Array:
def reduce_precision_on_backward(arr: Array, dtype: DTypeLike) -> Array:
"""`Fake` cast to an ML dtype on the backward pass (no-op on forward pass)."""
return partial(fn_bwd_identity_fwd, lambda v: reduce_precision_dtype_base(v, dtype))(arr)
return partial(map_on_backward, lambda v: reduce_precision_dtype_base(v, dtype))(arr)


def cast_on_forward(arr: Array, dtype: DTypeLike) -> Array:
"""Cast input array only on the forward pass (no-op on the backward pass).
Useful for implementing `DenseGeneral` FP8 matmuls.
"""
return partial(map_on_forward, lambda v: jax.lax.convert_element_type(v, dtype))(arr)


def cast_on_backward(arr: Array, dtype: DTypeLike) -> Array:
"""Cast input array only on the backward pass (no-op on the forward pass).
Useful for implementing `DenseGeneral` FP8 matmuls.
"""
return partial(map_on_backward, lambda v: jax.lax.convert_element_type(v, dtype))(arr)
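To make the distinction concrete, a small usage sketch (assuming `jax_scalify` is imported as `jsa`): reduce_precision_on_* only rounds values onto the target FP8 grid and keeps the input dtype, while cast_on_* genuinely converts the element type, and each variant acts only on the named pass.

import ml_dtypes
import numpy as np
import jax_scalify as jsa

x = np.array([17.0, -17.0, 9.0], np.float16)

# "Fake" cast: values rounded to the FP8 grid, dtype stays float16.
y = jsa.ops.reduce_precision_on_forward(x, ml_dtypes.float8_e4m3fn)
assert y.dtype == np.float16

# Real cast on the forward pass: dtype becomes float8_e4m3fn.
z = jsa.ops.cast_on_forward(x, ml_dtypes.float8_e4m3fn)
assert z.dtype == ml_dtypes.float8_e4m3fn

# Real cast on the backward pass only: forward output keeps float16;
# gradients flowing through this point would be cast to FP8 instead.
w = jsa.ops.cast_on_backward(x, ml_dtypes.float8_e5m2)
assert w.dtype == np.float16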
46 changes: 7 additions & 39 deletions jax_scalify/ops/rescaling.py
@@ -7,39 +7,7 @@
from jax_scalify.core import ScaledArray, pow2_round, pow2_round_down
from jax_scalify.lax import get_data_scale, rebalance


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fn_fwd_identity_bwd(f, arg):
"""Function with identity bwd/grad."""
return f(arg)


def fn_fwd_identity_bwd_fwd(f, arg):
return arg, None


def fn_fwd_identity_bwd_bwd(f, _, grad):
return (grad,)


fn_fwd_identity_bwd.defvjp(fn_fwd_identity_bwd_fwd, fn_fwd_identity_bwd_bwd)


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fn_bwd_identity_fwd(f, arg):
"""Apply a function on the gradient/backward pass."""
return arg


def fn_bwd_identity_fwd_fwd(f, arg):
return arg, None


def fn_bwd_identity_fwd_bwd(f, _, grad):
return (f(grad),)


fn_bwd_identity_fwd.defvjp(fn_bwd_identity_fwd_fwd, fn_bwd_identity_fwd_bwd)
from .utils import map_on_backward, map_on_forward


def dynamic_rescale_max_base(arr: ScaledArray) -> ScaledArray:
@@ -97,11 +65,11 @@ def dynamic_rescale_l2_base(arr: ScaledArray) -> ScaledArray:


# Dynamic rescale on fwd arrays.
dynamic_rescale_max = partial(fn_fwd_identity_bwd, dynamic_rescale_max_base)
dynamic_rescale_l1 = partial(fn_fwd_identity_bwd, dynamic_rescale_l1_base)
dynamic_rescale_l2 = partial(fn_fwd_identity_bwd, dynamic_rescale_l2_base)
dynamic_rescale_max = partial(map_on_forward, dynamic_rescale_max_base)
dynamic_rescale_l1 = partial(map_on_forward, dynamic_rescale_l1_base)
dynamic_rescale_l2 = partial(map_on_forward, dynamic_rescale_l2_base)

# Dynamic rescale on gradients.
dynamic_rescale_max_grad = partial(fn_bwd_identity_fwd, dynamic_rescale_max_base)
dynamic_rescale_l1_grad = partial(fn_bwd_identity_fwd, dynamic_rescale_l1_base)
dynamic_rescale_l2_grad = partial(fn_bwd_identity_fwd, dynamic_rescale_l2_base)
dynamic_rescale_max_grad = partial(map_on_backward, dynamic_rescale_max_base)
dynamic_rescale_l1_grad = partial(map_on_backward, dynamic_rescale_l1_base)
dynamic_rescale_l2_grad = partial(map_on_backward, dynamic_rescale_l2_base)
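The dynamic rescaling ops keep the same behaviour, now expressed via the shared map_on_forward / map_on_backward helpers. A hedged usage sketch under scalify (assuming `jax_scalify` is imported as `jsa`; the exact rescaling rule lives in the elided `*_base` functions above):

import numpy as np
import jax_scalify as jsa

@jsa.scalify
def rescale_fn(v):
    # Re-balance the scale from the data statistics on the forward pass;
    # the backward pass is left untouched.
    return jsa.ops.dynamic_rescale_l1(v)

sc = jsa.as_scaled_array(np.array([17.0, 19.0], np.float32), scale=np.float32(2))
out = rescale_fn(sc)
print(out.scale)  # scale updated from the data, data re-normalized accordingly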
38 changes: 38 additions & 0 deletions jax_scalify/ops/utils.py
@@ -0,0 +1,38 @@
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
from functools import partial

import jax


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def map_on_forward(f, arg):
"""Map a function on a forward pass only. No-op/identity on backward pass."""
return f(arg)


def map_on_forward_fwd(f, arg):
return arg, None


def map_on_forward_bwd(f, _, grad):
return (grad,)


map_on_forward.defvjp(map_on_forward_fwd, map_on_forward_bwd)


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def map_on_backward(f, arg):
"""Map a function on the gradient/backward pass. No-op/identity on forward."""
return arg


def map_on_backward_fwd(f, arg):
return arg, None


def map_on_backward_bwd(f, _, grad):
return (f(grad),)


map_on_backward.defvjp(map_on_backward_fwd, map_on_backward_bwd)
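A quick sanity check of the intended semantics (a sketch, importing directly from the new module): map_on_forward applies f to the values and lets gradients pass through unchanged, while map_on_backward is the identity on values and applies f to the incoming gradient.

import jax
import jax.numpy as jnp
from jax_scalify.ops.utils import map_on_backward, map_on_forward

double = lambda v: 2.0 * v

def fwd_loss(x):
    # `double` only touches the forward values; the backward rule is the identity.
    return jnp.sum(map_on_forward(double, x))

def bwd_loss(x):
    # Identity on the forward values; `double` is applied to the incoming gradient.
    return jnp.sum(map_on_backward(double, x))

x = jnp.ones((3,))
print(jax.grad(fwd_loss)(x))  # [1. 1. 1.]
print(jax.grad(bwd_loss)(x))  # [2. 2. 2.]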
53 changes: 47 additions & 6 deletions tests/ops/test_cast.py
@@ -2,25 +2,26 @@
from functools import partial

import chex
import jax
import ml_dtypes
import numpy as np
import numpy.testing as npt
from absl.testing import parameterized
from numpy.typing import NDArray

from jax_scalify.core import scaled_array, scalify
from jax_scalify.ops import reduce_precision_dtype
from jax_scalify.ops import cast_on_backward, cast_on_forward, reduce_precision_on_forward


class CastMLDtypeTests(chex.TestCase):
class ReducePrecisionDtypeTests(chex.TestCase):
@parameterized.parameters(
{"ml_dtype": ml_dtypes.float8_e4m3fn},
{"ml_dtype": ml_dtypes.float8_e5m2},
)
def test__reduce_precision_dtype__consistent_rounding_down(self, ml_dtype):
def test__reduce_precision_on_forward__consistent_rounding_down(self, ml_dtype):
# Values potentially "problematic" in FP8.
values: NDArray[np.float16] = np.array([17, -17, 8, 1, 9, 11, 18], np.float16)
out = reduce_precision_dtype(values, dtype=ml_dtype)
out = reduce_precision_on_forward(values, dtype=ml_dtype)
expected_out = values.astype(ml_dtype)
assert out.dtype == values.dtype
npt.assert_array_equal(out, expected_out)
@@ -29,10 +30,50 @@ def test__reduce_precision_dtype__consistent_rounding_down(self, ml_dtype):
{"ml_dtype": ml_dtypes.float8_e4m3fn},
{"ml_dtype": ml_dtypes.float8_e5m2},
)
def test__reduce_precision_dtype__scalify_compatiblity(self, ml_dtype):
def test__reduce_precision_on_forward__scalify_compatiblity(self, ml_dtype):
values: NDArray[np.float16] = np.array([17, -17, 8, 1, 9, 11, 18], np.float16)
arr = scaled_array(values, np.float32(1))
out = scalify(partial(reduce_precision_dtype, dtype=ml_dtype))(arr)
out = scalify(partial(reduce_precision_on_forward, dtype=ml_dtype))(arr)

npt.assert_array_equal(out.scale, arr.scale)
npt.assert_array_equal(out, np.asarray(arr.data).astype(ml_dtype))


class CastOnForwardBackwardTests(chex.TestCase):
@chex.variants(with_jit=True, without_jit=True)
@parameterized.parameters(
{"dtype": ml_dtypes.float8_e4m3fn},
{"dtype": ml_dtypes.float8_e5m2},
)
def test__cast_on_forward_backward__proper_results(self, dtype):
# Values potentially "problematic" in FP8.
values: NDArray[np.float16] = np.array([17, -17, 8, 1, 9, 11, 18], np.float16)
out_on_fwd = self.variant(partial(cast_on_forward, dtype=dtype))(values)
out_on_bwd = self.variant(partial(cast_on_backward, dtype=dtype))(values)

assert out_on_fwd.dtype == dtype
assert out_on_bwd.dtype == values.dtype
npt.assert_array_equal(out_on_fwd, jax.lax.convert_element_type(values, dtype))
npt.assert_array_equal(out_on_bwd, values)

@parameterized.parameters(
{"dtype": np.float16},
{"dtype": ml_dtypes.float8_e4m3fn},
{"dtype": ml_dtypes.float8_e5m2},
)
def test__cast_on_backward__grad__proper_results(self, dtype):
def fn(val, with_cast):
if with_cast:
val = cast_on_backward(val, dtype=dtype)
val = val * val
return jax.lax.reduce_sum_p.bind(val, axes=(0,))

# Values potentially "problematic" in FP8.
values: NDArray[np.float32] = np.array([17, -17, 8, 1, 9, 11, 18], np.float32)
# Backward pass => gradient.
grads = jax.grad(partial(fn, with_cast=True))(values)
grads_ref = jax.grad(partial(fn, with_cast=False))(values)

assert grads.dtype == dtype
assert grads_ref.dtype == values.dtype
npt.assert_array_equal(grads, jax.lax.convert_element_type(grads_ref, dtype))
