From c9b0fc5be1249b95f7c04687a9908e0eb62056f7 Mon Sep 17 00:00:00 2001 From: shuw Date: Mon, 23 Sep 2024 15:26:58 -0700 Subject: [PATCH] pass test --- flax/linen/fp8_ops.py | 48 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 8 deletions(-) diff --git a/flax/linen/fp8_ops.py b/flax/linen/fp8_ops.py index 324603b20b..a28ba7034a 100644 --- a/flax/linen/fp8_ops.py +++ b/flax/linen/fp8_ops.py @@ -350,6 +350,41 @@ def _parse_dot_inputs(*args, **kwargs): x = jnp.asarray(x, comp_dtype) return x, k, dimension_numbers, comp_dtype +# Convenience wrappers for the quantize-dot-dequantize +def q_dot_dq( + lhs, + rhs, + lhs_scale, + rhs_scale, + out_grad_scale, + lhs_amax_history, + rhs_amax_history, + out_grad_amax_history, + compute_dtype, + dimension_numbers, + precision=None, + preferred_element_type=None +): + q_lhs, new_lhs_scale = in_q( + compute_dtype, jnp.float8_e4m3fn, lhs, lhs_scale, lhs_amax_history + ) + + y = one_sided_q_dot_dq( + lhs, + q_lhs, + new_lhs_scale, # actualy new lhs scale + rhs, + rhs_scale, + out_grad_scale, + rhs_amax_history, + out_grad_amax_history, + compute_dtype, + dimension_numbers, + precision, + preferred_element_type + ) + return y # type: ignore + class Fp8DotGeneralBase(module.Module): amax_history_length: int = 1024 @@ -419,23 +454,20 @@ def __call__(self, *args, **kwargs): x, k, dimension_numbers, comp_dtype = _parse_dot_inputs( *args, **kwargs ) - - q_x, new_input_scale = in_q( - comp_dtype, self.e4m3_dtype, x, self.input_scale.value, self.input_amax_history.value) - - y = one_sided_q_dot_dq( + y = q_dot_dq( x, - q_x, - new_input_scale, # actualy new lhs scale k, + self.input_scale.value, self.kernel_scale.value, self.output_grad_scale.value, + self.input_amax_history.value, self.kernel_amax_history.value, self.output_grad_amax_history.value, comp_dtype, dimension_numbers, - preferred_element_type=x.dtype + preferred_element_type=x.dtype, ) + return y # type: ignore def one_sided_q_dot_dq_impl(