Skip to content

Commit

Permalink
pass test
Browse files Browse the repository at this point in the history
  • Loading branch information
wenscarl committed Sep 23, 2024
1 parent 920099e commit c9b0fc5
Showing 1 changed file with 40 additions and 8 deletions.
48 changes: 40 additions & 8 deletions flax/linen/fp8_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,41 @@ def _parse_dot_inputs(*args, **kwargs):
x = jnp.asarray(x, comp_dtype)
return x, k, dimension_numbers, comp_dtype

# Convenience wrappers for the quantize-dot-dequantize pattern.
def q_dot_dq(
  lhs,
  rhs,
  lhs_scale,
  rhs_scale,
  out_grad_scale,
  lhs_amax_history,
  rhs_amax_history,
  out_grad_amax_history,
  compute_dtype,
  dimension_numbers,
  precision=None,
  preferred_element_type=None
):
  """Quantize ``lhs`` and delegate the dot product to ``one_sided_q_dot_dq``.

  ``lhs`` is first passed through ``in_q`` with ``jnp.float8_e4m3fn`` as the
  target dtype, together with ``lhs_scale`` and ``lhs_amax_history``. The
  two values returned by ``in_q`` (used here as the quantized lhs and its
  refreshed scale) are then forwarded, along with the original ``lhs`` and
  every rhs / output-gradient argument, to ``one_sided_q_dot_dq``.

  Args:
    lhs: Left-hand operand of the dot product.
    rhs: Right-hand operand of the dot product.
    lhs_scale: Scale handed to ``in_q`` for ``lhs``.
    rhs_scale: Scale for ``rhs``; forwarded unchanged.
    out_grad_scale: Scale for the output gradient; forwarded unchanged.
    lhs_amax_history: Amax history handed to ``in_q`` for ``lhs``.
    rhs_amax_history: Amax history for ``rhs``; forwarded unchanged.
    out_grad_amax_history: Amax history for the output gradient; forwarded
      unchanged.
    compute_dtype: Compute dtype given to both ``in_q`` and
      ``one_sided_q_dot_dq``.
    dimension_numbers: Forwarded unchanged to ``one_sided_q_dot_dq``.
    precision: Forwarded unchanged to ``one_sided_q_dot_dq``.
    preferred_element_type: Forwarded unchanged to ``one_sided_q_dot_dq``.

  Returns:
    Whatever ``one_sided_q_dot_dq`` returns for these arguments.
  """
  # NOTE(review): the quantization dtype is fixed to float8_e4m3fn here and
  # is not configurable by the caller -- confirm this is intended.
  quantized_lhs, updated_lhs_scale = in_q(
    compute_dtype,
    jnp.float8_e4m3fn,
    lhs,
    lhs_scale,
    lhs_amax_history,
  )

  # Arguments stay positional, in the order one_sided_q_dot_dq is called with
  # in the original code; the third argument is the refreshed lhs scale
  # produced by in_q, not the incoming lhs_scale.
  return one_sided_q_dot_dq(  # type: ignore
    lhs,
    quantized_lhs,
    updated_lhs_scale,
    rhs,
    rhs_scale,
    out_grad_scale,
    rhs_amax_history,
    out_grad_amax_history,
    compute_dtype,
    dimension_numbers,
    precision,
    preferred_element_type,
  )


class Fp8DotGeneralBase(module.Module):
amax_history_length: int = 1024
Expand Down Expand Up @@ -419,23 +454,20 @@ def __call__(self, *args, **kwargs):
x, k, dimension_numbers, comp_dtype = _parse_dot_inputs(
*args, **kwargs
)

q_x, new_input_scale = in_q(
comp_dtype, self.e4m3_dtype, x, self.input_scale.value, self.input_amax_history.value)

y = one_sided_q_dot_dq(
y = q_dot_dq(
x,
q_x,
new_input_scale, # actualy new lhs scale
k,
self.input_scale.value,
self.kernel_scale.value,
self.output_grad_scale.value,
self.input_amax_history.value,
self.kernel_amax_history.value,
self.output_grad_amax_history.value,
comp_dtype,
dimension_numbers,
preferred_element_type=x.dtype
preferred_element_type=x.dtype,
)

return y # type: ignore

def one_sided_q_dot_dq_impl(
Expand Down

0 comments on commit c9b0fc5

Please sign in to comment.