Renaming cast_ml_dtype methods to reduce_precision_dtype.
Avoiding confusion with proper `cast/astype` methods when using FP8 hardware.
balancap committed Jun 28, 2024
1 parent 8e73c15 commit d0009f5
Showing 5 changed files with 21 additions and 21 deletions.
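The motivation behind the rename, in a minimal sketch: a proper `cast`/`astype` converts the storage dtype (what real FP8 hardware does), whereas `reduce_precision_dtype` only rounds values to FP8 precision and keeps the original dtype. This assumes the post-rename `jax_scalify.ops` API from this commit and the `jsa` import alias used in the repository examples.

import ml_dtypes
import numpy as np
import jax_scalify as jsa  # assumed import alias, as in the examples

x = np.array([17.0, 19.0], np.float32)

# A proper cast changes the storage dtype (the `cast`/`astype` semantics on FP8 hardware).
x_fp8 = x.astype(ml_dtypes.float8_e4m3fn)

# reduce_precision_dtype only rounds values to FP8 precision; the dtype stays float32.
x_rounded = jsa.ops.reduce_precision_dtype(x, ml_dtypes.float8_e4m3fn)
assert x_rounded.dtype == np.float32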
16 changes: 8 additions & 8 deletions examples/mnist/mnist_classifier_from_scratch_fp8.py
@@ -59,30 +59,30 @@ def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):


def predict(params, inputs, use_fp8=True):
-    cast_ml_dtype = jsa.ops.cast_ml_dtype if use_fp8 else lambda x, d: x
-    cast_ml_dtype_grad = jsa.ops.cast_ml_dtype_grad if use_fp8 else lambda x, d: x
+    reduce_precision_dtype = jsa.ops.reduce_precision_dtype if use_fp8 else lambda x, d: x
+    reduce_precision_dtype_grad = jsa.ops.reduce_precision_dtype_grad if use_fp8 else lambda x, d: x

    activations = inputs
    for w, b in params[:-1]:
        # Forward FP8 casting.
-        w = cast_ml_dtype(w, ml_dtypes.float8_e4m3fn)
-        activations = cast_ml_dtype(activations, ml_dtypes.float8_e4m3fn)
+        w = reduce_precision_dtype(w, ml_dtypes.float8_e4m3fn)
+        activations = reduce_precision_dtype(activations, ml_dtypes.float8_e4m3fn)
        # Matmul
        outputs = jnp.dot(activations, w)
        # Backward FP8 casting
-        outputs = cast_ml_dtype_grad(outputs, ml_dtypes.float8_e5m2)
+        outputs = reduce_precision_dtype_grad(outputs, ml_dtypes.float8_e5m2)

        # Bias + relu
        outputs = outputs + b
        activations = jnp.maximum(outputs, 0)

    final_w, final_b = params[-1]
    # Forward FP8 casting.
-    # final_w = jsa.ops.cast_ml_dtype(final_w, ml_dtypes.float8_e4m3fn)
-    activations = cast_ml_dtype(activations, ml_dtypes.float8_e4m3fn)
+    # final_w = jsa.ops.reduce_precision_dtype(final_w, ml_dtypes.float8_e4m3fn)
+    activations = reduce_precision_dtype(activations, ml_dtypes.float8_e4m3fn)
    logits = jnp.dot(activations, final_w)
    # Backward FP8 casting
-    logits = cast_ml_dtype_grad(logits, ml_dtypes.float8_e5m2)
+    logits = reduce_precision_dtype_grad(logits, ml_dtypes.float8_e5m2)

    logits = logits + final_b

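The pattern in the example above, isolated to a single layer: inputs and weights are rounded to FP8 E4M3 on the forward pass, while E5M2 rounding applies only to gradients flowing backward through the marked point. A sketch under the assumption that these ops carry standard custom gradients usable with plain jax.grad, as the training loop in this example relies on.

import jax
import jax.numpy as jnp
import ml_dtypes
import jax_scalify as jsa  # assumed import alias, as in the examples

def layer(w, x):
    # Forward-pass rounding to FP8 E4M3 (no effect on the backward pass).
    w = jsa.ops.reduce_precision_dtype(w, ml_dtypes.float8_e4m3fn)
    x = jsa.ops.reduce_precision_dtype(x, ml_dtypes.float8_e4m3fn)
    out = jnp.dot(x, w)
    # Backward-pass rounding to FP8 E5M2 (no effect on the forward pass).
    return jsa.ops.reduce_precision_dtype_grad(out, ml_dtypes.float8_e5m2)

w = jnp.full((4, 2), 0.5, jnp.float32)
x = jnp.full((3, 4), 1.5, jnp.float32)
grads = jax.grad(lambda w: jnp.sum(layer(w, x)))(w)
print(grads.dtype)  # float32: only precision is reduced, never the storage dtype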
4 changes: 2 additions & 2 deletions examples/scalify-quickstart.ipynb
@@ -279,12 +279,12 @@
"source": [
"import ml_dtypes\n",
"# Minimal FP8 simulated support is provided using jax.lax.reduce_precision and ml_dtypes.\n",
-"# Similarly to `dynamic_rescale`, `cast_ml_dtype(_grad)` are available to cast in forward and backward passes\n",
+"# Similarly to `dynamic_rescale`, `reduce_precision_dtype(_grad)` are available to cast in forward and backward passes\n",
"sc = jsa.as_scaled_array(np.array([17., 19.]), scale=np.float32(2))\n",
"\n",
"@jsa.scalify\n",
"def cast_fn(v):\n",
-"    return jsa.ops.cast_ml_dtype(v, ml_dtypes.float8_e4m3fn)\n",
+"    return jsa.ops.reduce_precision_dtype(v, ml_dtypes.float8_e4m3fn)\n",
"\n",
"sc_fp8 = cast_fn(sc)\n",
"print(\"Scaled input in FP32:\", sc)\n",
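As the comment in the cell notes, a `_grad` counterpart handles the backward pass. A small follow-on sketch in the same scalify setting (assuming the post-rename API): the forward output is unchanged, since the op is a no-op on the forward pass and would only round gradients to FP8 E5M2.

import ml_dtypes
import numpy as np
import jax_scalify as jsa  # assumed import alias, as in the notebook

sc = jsa.as_scaled_array(np.array([17.0, 19.0]), scale=np.float32(2))

@jsa.scalify
def cast_grad_fn(v):
    # No-op on the forward pass; only gradients flowing back through this
    # point would be rounded to FP8 E5M2.
    return jsa.ops.reduce_precision_dtype_grad(v, ml_dtypes.float8_e5m2)

sc_out = cast_grad_fn(sc)  # forward values identical to the input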
2 changes: 1 addition & 1 deletion jax_scalify/ops/__init__.py
@@ -1,6 +1,6 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
+from .cast import reduce_precision_dtype, reduce_precision_dtype_grad  # noqa: F401
from .debug import debug_callback, debug_callback_grad, debug_print, debug_print_grad  # noqa: F401
-from .ml_dtypes import cast_ml_dtype, cast_ml_dtype_grad  # noqa: F401
from .rescaling import (  # noqa: F401
    dynamic_rescale_l1,
    dynamic_rescale_l1_grad,
10 changes: 5 additions & 5 deletions jax_scalify/ops/ml_dtypes.py → jax_scalify/ops/cast.py
@@ -9,17 +9,17 @@
from .rescaling import fn_bwd_identity_fwd, fn_fwd_identity_bwd


-def cast_ml_dtype_base(arr: Array, dtype: DTypeLike) -> Array:
+def reduce_precision_dtype_base(arr: Array, dtype: DTypeLike) -> Array:
    """`Fake` cast to an ML dtype (e.g. FP8), using JAX LAX `reduce_precision` operator."""
    info = ml_dtypes.finfo(dtype)
    return jax.lax.reduce_precision(arr, exponent_bits=info.nexp, mantissa_bits=info.nmant)


-def cast_ml_dtype(arr: Array, dtype: DTypeLike) -> Array:
+def reduce_precision_dtype(arr: Array, dtype: DTypeLike) -> Array:
    """`Fake` cast to an ML dtype, on the forward pass (no-op on backward pass)."""
-    return partial(fn_fwd_identity_bwd, lambda v: cast_ml_dtype_base(v, dtype))(arr)
+    return partial(fn_fwd_identity_bwd, lambda v: reduce_precision_dtype_base(v, dtype))(arr)


-def cast_ml_dtype_grad(arr: Array, dtype: DTypeLike) -> Array:
+def reduce_precision_dtype_grad(arr: Array, dtype: DTypeLike) -> Array:
    """`Fake` cast to an ML dtype on the backward pass (no-op on forward pass)."""
-    return partial(fn_bwd_identity_fwd, lambda v: cast_ml_dtype_base(v, dtype))(arr)
+    return partial(fn_bwd_identity_fwd, lambda v: reduce_precision_dtype_base(v, dtype))(arr)
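For reference, what the base helper boils down to, spelled out with the underlying public APIs (ml_dtypes.finfo and jax.lax.reduce_precision); a standalone sketch outside of any scalify machinery.

import jax
import ml_dtypes
import numpy as np

# float8_e4m3fn has 4 exponent bits and 3 mantissa bits.
info = ml_dtypes.finfo(ml_dtypes.float8_e4m3fn)
x = np.array([17.0, 19.0, 1.25], np.float32)
# Values are rounded to FP8 precision, but the array keeps its float32 dtype.
y = jax.lax.reduce_precision(x, exponent_bits=info.nexp, mantissa_bits=info.nmant)
print(y.dtype)  # float32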
10 changes: 5 additions & 5 deletions tests/ops/test_ml_dtypes.py → tests/ops/test_cast.py
@@ -9,18 +9,18 @@
from numpy.typing import NDArray

from jax_scalify.core import scaled_array, scalify
-from jax_scalify.ops import cast_ml_dtype
+from jax_scalify.ops import reduce_precision_dtype


class CastMLDtypeTests(chex.TestCase):
    @parameterized.parameters(
        {"ml_dtype": ml_dtypes.float8_e4m3fn},
        {"ml_dtype": ml_dtypes.float8_e5m2},
    )
-    def test__cast_ml_dtype__consistent_rounding_down(self, ml_dtype):
+    def test__reduce_precision_dtype__consistent_rounding_down(self, ml_dtype):
        # Values potentially "problematic" in FP8.
        values: NDArray[np.float16] = np.array([17, -17, 8, 1, 9, 11, 18], np.float16)
-        out = cast_ml_dtype(values, dtype=ml_dtype)
+        out = reduce_precision_dtype(values, dtype=ml_dtype)
        expected_out = values.astype(ml_dtype)
        assert out.dtype == values.dtype
        npt.assert_array_equal(out, expected_out)
@@ -29,10 +29,10 @@ def test__cast_ml_dtype__consistent_rounding_down(self, ml_dtype):
        {"ml_dtype": ml_dtypes.float8_e4m3fn},
        {"ml_dtype": ml_dtypes.float8_e5m2},
    )
-    def test__cast_ml_dtype__scalify_compatiblity(self, ml_dtype):
+    def test__reduce_precision_dtype__scalify_compatiblity(self, ml_dtype):
        values: NDArray[np.float16] = np.array([17, -17, 8, 1, 9, 11, 18], np.float16)
        arr = scaled_array(values, np.float32(1))
-        out = scalify(partial(cast_ml_dtype, dtype=ml_dtype))(arr)
+        out = scalify(partial(reduce_precision_dtype, dtype=ml_dtype))(arr)

        npt.assert_array_equal(out.scale, arr.scale)
        npt.assert_array_equal(out, np.asarray(arr.data).astype(ml_dtype))
