Commit
Merge pull request #3435 from kaixih:update_fp8_support
PiperOrigin-RevId: 576718576
Flax Authors committed Oct 26, 2023
2 parents 5557649 + d24db61 commit 738078c
Showing 5 changed files with 224 additions and 262 deletions.
6 changes: 1 addition & 5 deletions flax/linen/__init__.py
@@ -69,11 +69,7 @@
make_causal_mask as make_causal_mask,
)
from .combinators import Sequential as Sequential
from .fp8_ops import (
compute_scale as fp8_compute_scale,
quantize_dequantize as fp8_quantize_dequantize,
Fp8DenseGeneralOp as Fp8DenseGeneralOp,
)
from .fp8_ops import Fp8DotGeneralOp as Fp8DotGeneralOp
from .initializers import (
ones_init as ones_init,
ones as ones,
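The net effect of this file's change is that the old FP8 helpers (fp8_compute_scale, fp8_quantize_dequantize, Fp8DenseGeneralOp) are no longer re-exported and the op is now exposed as flax.linen.Fp8DotGeneralOp. Below is a minimal usage sketch, not part of this commit; it assumes nn.Dense exposes a dot_general_cls hook for swapping in the FP8 op, and the shapes are illustrative.

import jax
import jax.numpy as jnp
import flax.linen as nn

# Sketch: plug the FP8 op into a Dense layer via the assumed `dot_general_cls` hook.
model = nn.Dense(features=64, dot_general_cls=nn.Fp8DotGeneralOp)
x = jnp.ones((16, 32))
variables = model.init(jax.random.PRNGKey(0), x)
y = model.apply(variables, x)
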
145 changes: 67 additions & 78 deletions flax/linen/fp8_ops.py
@@ -15,44 +15,39 @@
from functools import partial

from flax.linen import initializers
from flax.linen.module import Module
from flax.linen import module
from jax import custom_vjp
from jax import lax
from jax import numpy as jnp
from jax import random


# Type annotations
Array = jnp.ndarray
Dtype = jnp.dtype
PRNGKey = jnp.ndarray
OVERWRITE_WITH_GRADIENT = '_overwrite_with_gradient'

class FP8Helper:
FP8_COLLECTION_NAME: str = "fp8_params"

def get_fp8_max(fp8_dtype, out_dtype):
assert fp8_dtype in (jnp.float8_e4m3fn, jnp.float8_e5m2)
return jnp.finfo(fp8_dtype).max.astype(out_dtype)


def quantize(x, q_dtype, scale, compute_dtype):
# We need to explicitly cast the max value to compute_dtype, otherwise the jax
# dtype promotion will cast the scaled_x to fp32 in the following ops, which
# would violate the fp8-matmul pattern matching.
# Explicitly cast the max values to the compute dtype to avoid unnecessary
  # casting to FP32 during the subsequent math operations.
dtype_max = get_fp8_max(q_dtype, compute_dtype)

scaled_x = x / jnp.broadcast_to(scale.astype(compute_dtype), x.shape)

clipped_x = jnp.clip(scaled_x, -dtype_max, dtype_max)

return clipped_x.astype(q_dtype)


def dequantize(x, dq_dtype, scale):
return x.astype(dq_dtype) * jnp.broadcast_to(scale.astype(dq_dtype), x.shape)


def quantize_dequantize(x, q_dtype, scale, compute_dtype):
qx = quantize(x, q_dtype, scale, compute_dtype)
return dequantize(qx, x.dtype, scale)


def compute_scale(amax, scale, fp8_max, margin=0):
"""Default function to convert amax to scaling factor."""
# This function copied from the TransformerEngine is used to compute its
@@ -66,38 +61,44 @@ def compute_scale(amax, scale, fp8_max, margin=0):
sf = jnp.where(exp < 0, 1.0 / sf, sf)
return 1.0 / sf
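For background on the recipe the docstring refers to (a paraphrase for context, not the code elided by the hunk): TransformerEngine's delayed scaling computes exp = floor(log2(fp8_max / amax)) - margin and sf = 2^|exp|, falls back to the previous scale when amax is zero or non-finite, inverts sf when exp is negative (the jnp.where line above), and this port finally returns 1.0 / sf, i.e. the reciprocal of that scale factor.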


def compute_scale_and_amax_history(x, q_dtype, scale, amax_history):
dtype_max = get_fp8_max(q_dtype, jnp.float32)

amax_update = jnp.max(jnp.abs(x)).astype(scale.dtype)
new_amax_history = \
jnp.roll(amax_history, shift=-1, axis=0).at[0].set(amax_update)

amax_from_history = jnp.max(new_amax_history, axis=0)
new_history = jnp.roll(amax_history, shift=-1, axis=0).at[0].set(amax_update)
amax_from_history = jnp.max(new_history, axis=0)
new_scale = compute_scale(amax_from_history, scale, dtype_max)
return new_scale, new_amax_history
return new_scale, new_history
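A small sketch of the bookkeeping this function performs, with illustrative values:

import jax.numpy as jnp
from flax.linen.fp8_ops import compute_scale_and_amax_history

x = jnp.array([-3.0, 0.5, 2.0])         # current tensor; its amax is 3.0
scale = jnp.array(1.0, jnp.float32)     # previous (dequantization) scale
history = jnp.zeros((4,), jnp.float32)  # rolling window of recent amax values

new_scale, new_history = compute_scale_and_amax_history(
    x, jnp.float8_e4m3fn, scale, history)
# new_history[0] == 3.0 and the older slots shift down; new_scale is derived
# from max(new_history) and the float8_e4m3fn max (448) via compute_scale above.
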


def qdq_and_return(x, q_dtype, scale, amax_history, compute_dtype):
qx = quantize_dequantize(x, q_dtype, scale, compute_dtype)
new_scale, new_amax_history = compute_scale_and_amax_history(
x, q_dtype, scale, amax_history)
return qx, new_scale, new_amax_history
new_scale, new_history = compute_scale_and_amax_history(
x, q_dtype, scale, amax_history
)
return qx, new_scale, new_history


@partial(custom_vjp, nondiff_argnums=(0,))
def in_qdq(compute_dtype, inp, scale, amax_history):
qin, _, _ = qdq_and_return(
inp, jnp.float8_e4m3fn, scale, amax_history, compute_dtype)
inp, jnp.float8_e4m3fn, scale, amax_history, compute_dtype
)
return qin


def in_qdq_fwd(compute_dtype, inp, scale, amax_history):
qin, new_scale, new_amax_history = qdq_and_return(
inp, jnp.float8_e4m3fn, scale, amax_history, compute_dtype)
return qin, (new_scale, new_amax_history)
qin, new_scale, new_history = qdq_and_return(
inp, jnp.float8_e4m3fn, scale, amax_history, compute_dtype
)
return qin, (new_scale, new_history)


def in_qdq_bwd(compute_dtype, res, g):
new_scale, new_amax_history = res
new_scale, new_history = res
q_g = g
return q_g, new_scale, new_amax_history
return q_g, new_scale, new_history


in_qdq.defvjp(in_qdq_fwd, in_qdq_bwd)
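The custom_vjp wiring above is the core trick: the forward pass fake-quantizes the input to e4m3, while the backward pass returns the cotangent unchanged and hands back the refreshed scale and amax history in place of real gradients, so the updated statistics ride along the gradient pytree. A quick sketch of that behavior, with illustrative shapes:

import jax
import jax.numpy as jnp
from flax.linen.fp8_ops import in_qdq

x = jnp.ones((4, 4), jnp.float32)
scale = jnp.array(1.0, jnp.float32)
history = jnp.zeros((16,), jnp.float32)

# Differentiate the sum of the fake-quantized input w.r.t. all three tensor args.
f = lambda x, s, h: jnp.sum(in_qdq(jnp.float32, x, s, h))
dx, d_scale, d_history = jax.grad(f, argnums=(0, 1, 2))(x, scale, history)
# dx is all ones (straight-through estimator); d_scale and d_history are not
# gradients at all, they are the updated scale and amax history.
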

@@ -106,34 +107,23 @@ def in_qdq_bwd(compute_dtype, res, g):
def out_qdq(compute_dtype, out, scale, amax_history):
return out


def out_qdq_fwd(compute_dtype, out, scale, amax_history):
return out, (scale, amax_history)


def out_qdq_bwd(compute_dtype, res, g):
scale, amax_history = res
q_g, new_scale, new_amax_history = qdq_and_return(
g, jnp.float8_e5m2, scale, amax_history, compute_dtype)
return q_g, new_scale, new_amax_history

out_qdq.defvjp(out_qdq_fwd, out_qdq_bwd)

def fp8_dot_general(lhs, rhs, dimension_numbers, precision, compute_dtype,
lhs_scale, lhs_amax_history, rhs_scale, rhs_amax_history,
dout_scale, dout_amax_history):
"""Perform dot_general. """
q_g, new_scale, new_history = qdq_and_return(
g, jnp.float8_e5m2, scale, amax_history, compute_dtype
)
return q_g, new_scale, new_history

lhs_qdq = in_qdq(compute_dtype, lhs, lhs_scale, lhs_amax_history)

rhs_qdq = in_qdq(compute_dtype, rhs, rhs_scale, rhs_amax_history)

output_qdq = lax.dot_general(lhs_qdq, rhs_qdq, dimension_numbers, precision)

out = out_qdq(compute_dtype, output_qdq, dout_scale, dout_amax_history)

return out
out_qdq.defvjp(out_qdq_fwd, out_qdq_bwd)


class Fp8DenseGeneralOp(Module):
class Fp8DotGeneralOp(module.Module):
amax_history_length: int = 1024

def setup(self) -> None:
@@ -151,47 +141,46 @@ def setup(self) -> None:
)

self.input_amax_history = self.variable(
FP8Helper.FP8_COLLECTION_NAME,
'input_amax_history',
*amax_history_args)
OVERWRITE_WITH_GRADIENT, 'input_amax_history', *amax_history_args)
self.kernel_amax_history = self.variable(
FP8Helper.FP8_COLLECTION_NAME,
'kernel_amax_history',
*amax_history_args)
OVERWRITE_WITH_GRADIENT, 'kernel_amax_history', *amax_history_args)
self.output_grad_amax_history = self.variable(
FP8Helper.FP8_COLLECTION_NAME,
'output_grad_amax_history',
*amax_history_args)
OVERWRITE_WITH_GRADIENT, 'output_grad_amax_history', *amax_history_args)

self.input_scale = self.variable(
FP8Helper.FP8_COLLECTION_NAME,
'input_scale',
*scale_args)
OVERWRITE_WITH_GRADIENT, 'input_scale', *scale_args)
self.kernel_scale = self.variable(
FP8Helper.FP8_COLLECTION_NAME,
'kernel_scale',
*scale_args)
OVERWRITE_WITH_GRADIENT, 'kernel_scale', *scale_args)
self.output_grad_scale = self.variable(
FP8Helper.FP8_COLLECTION_NAME,
'output_grad_scale',
*scale_args)
OVERWRITE_WITH_GRADIENT, 'output_grad_scale', *scale_args)


def __call__(self, *args, **kwargs) -> Array:
def __call__(self, *args, **kwargs) -> jnp.ndarray:

assert len(args) == 3
inputs = args[0]
kernel = args[1]
x = args[0]
k = args[1]
dimension_numbers = args[2]
precision = kwargs['precision']
comp_dtype = kernel.dtype
inputs = jnp.asarray(inputs, comp_dtype)

out = fp8_dot_general(inputs, kernel, dimension_numbers, precision,
comp_dtype, self.input_scale.value,
self.input_amax_history.value,
self.kernel_scale.value, self.kernel_amax_history.value,
self.output_grad_scale.value,
self.output_grad_amax_history.value)
return out

# Use the `k.dtype` since it aligns with the `dtype` of its layers,
# namely, the computation data type.
comp_dtype = k.dtype
x = jnp.asarray(x, comp_dtype)

x_qdq = in_qdq(
comp_dtype, x, self.input_scale.value, self.input_amax_history.value
)
k_qdq = in_qdq(
comp_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value
)
y_qdq = lax.dot_general(x_qdq, k_qdq, dimension_numbers, precision)
y = out_qdq(
comp_dtype,
y_qdq,
self.output_grad_scale.value,
self.output_grad_amax_history.value
)

return y
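Standalone, the op can also be exercised directly; its state lives in the '_overwrite_with_gradient' collection rather than 'params'. A sketch, with illustrative shapes and dimension numbers:

import jax
import jax.numpy as jnp
from flax.linen.fp8_ops import Fp8DotGeneralOp

op = Fp8DotGeneralOp()
x = jnp.ones((8, 32))
k = jnp.ones((32, 16))
dn = (((1,), (0,)), ((), ()))  # plain matmul contraction

variables = op.init(jax.random.PRNGKey(0), x, k, dn, precision=None)
print(list(variables.keys()))  # expected: ['_overwrite_with_gradient'], no 'params'
y = op.apply(variables, x, k, dn, precision=None)
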

66 changes: 26 additions & 40 deletions flax/training/train_state.py
@@ -16,6 +16,7 @@

from flax import core
from flax import struct
from flax.linen.fp8_ops import OVERWRITE_WITH_GRADIENT
import optax


@@ -71,8 +72,27 @@ def apply_gradients(self, *, grads, **kwargs):
and `opt_state` updated by applying `grads`, and additional attributes
replaced as specified by `kwargs`.
"""
updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params)
new_params = optax.apply_updates(self.params, updates)
if OVERWRITE_WITH_GRADIENT in grads:
grads_with_opt = grads['params']
params_with_opt = self.params['params']
else:
grads_with_opt = grads
params_with_opt = self.params

updates, new_opt_state = self.tx.update(
grads_with_opt, self.opt_state, params_with_opt
)
new_params_with_opt = optax.apply_updates(params_with_opt, updates)

# As implied by the OWG name, the gradients are used directly to update the
# parameters.
if OVERWRITE_WITH_GRADIENT in grads:
new_params = {
'params': new_params_with_opt,
OVERWRITE_WITH_GRADIENT: grads[OVERWRITE_WITH_GRADIENT]
}
else:
new_params = new_params_with_opt
return self.replace(
step=self.step + 1,
params=new_params,
@@ -83,45 +103,11 @@ def apply_gradients(self, *, grads, **kwargs):
@classmethod
def create(cls, *, apply_fn, params, tx, **kwargs):
"""Creates a new instance with `step=0` and initialized `opt_state`."""
opt_state = tx.init(params)
return cls(
step=0,
apply_fn=apply_fn,
params=params,
tx=tx,
opt_state=opt_state,
**kwargs,
)

class Fp8TrainState(TrainState):
"""Customized train state for Fp8."""

def apply_gradients(self, *, grads, **kwargs):
assert 'fp8_params' in grads
updates, new_opt_state = self.tx.update(grads['params'], self.opt_state,
self.params['params'])
new_non_fp8_params = optax.apply_updates(self.params['params'], updates)

# self.param is structured as
# {'param': {'kernel:...,'}, 'fp8_params': {...}}. For the fp8 variables
# in the fp8-params collection, we will simply replace them with their
# grads, because their grads are actually new values defined in the
# custom_vjp functions.
new_params = {'params': new_non_fp8_params,
'fp8_params': grads['fp8_params']}

return self.replace(
step=self.step + 1,
params=new_params,
opt_state=new_opt_state,
**kwargs,
# We exclude OWG params when present because they do not need opt states.
params_with_opt = (
params['params'] if OVERWRITE_WITH_GRADIENT in params else params
)

@classmethod
def create(cls, *, apply_fn, params, tx, **kwargs):
assert 'fp8_params' in params
opt_state = tx.init(params['params'])

opt_state = tx.init(params_with_opt)
return cls(
step=0,
apply_fn=apply_fn,
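Putting the pieces together: with this change the whole variables dict, including the '_overwrite_with_gradient' collection, can be passed as params. The optimizer only ever sees the 'params' subtree, while the OWG entries are overwritten with the values returned through the custom_vjp backward passes. A rough end-to-end sketch; the model, optimizer, and shapes are illustrative, and it again assumes the dot_general_cls hook on nn.Dense:

import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from flax.training import train_state

model = nn.Dense(features=16, dot_general_cls=nn.Fp8DotGeneralOp)
x = jnp.ones((4, 8))
variables = model.init(jax.random.PRNGKey(0), x)  # {'params': ..., '_overwrite_with_gradient': ...}

state = train_state.TrainState.create(
    apply_fn=model.apply, params=variables, tx=optax.adam(1e-3))

def loss_fn(params):
  return jnp.sum(model.apply(params, x) ** 2)

grads = jax.grad(loss_fn)(state.params)
# grads['params'] feeds optax as usual; grads['_overwrite_with_gradient'] holds
# the new scales and amax histories and is copied into params verbatim.
state = state.apply_gradients(grads=grads)
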