
Commit

remove flax dependency in test.
wenscarl committed Nov 4, 2024
1 parent 10ad8cb commit 8e353b3
Showing 1 changed file with 28 additions and 9 deletions.
37 changes: 28 additions & 9 deletions tests/fused_attention_stablehlo_test.py
@@ -34,7 +34,6 @@
     MaskType,
     AttentionLayout,
 )
-from flax.linen.fp8_ops import qdq, quantize
 
 config.parse_flags_with_absl()
 Array = jnp.ndarray
@@ -48,7 +47,27 @@
 
 fp8_metas = {name: jnp.ones((1, 1, 1, 1), dtype=jnp.float32) for name in fp8_meta_names}
 
-cast_to_representable = partial(qdq, scale=jnp.ones((1,)), compute_dtype=jnp.bfloat16)
+
+def quantize_to_fp8(x, q_dtype, scale, compute_dtype):
+  # Explicitly cast the max value to the compute dtype to avoid unnecessary
+  # casting to FP32 during the subsequent math operations.
+  assert q_dtype in (jnp.float8_e4m3fn, jnp.float8_e5m2,
+                     jnp.float8_e4m3fnuz, jnp.float8_e5m2fnuz)
+  dtype_max = jnp.finfo(q_dtype).max.astype(compute_dtype)
+  scaled_x = x / jnp.broadcast_to(scale.astype(compute_dtype), x.shape)
+  clipped_x = jnp.clip(scaled_x, -dtype_max, dtype_max)
+  return clipped_x.astype(q_dtype)
+
+def quantize_dequantize_fp8(x, q_dtype, scale, compute_dtype):
+  qx = quantize_to_fp8(x, q_dtype, scale, compute_dtype)
+  out = qx.astype(x.dtype) * jnp.broadcast_to(scale.astype(x.dtype), qx.shape)
+  return out
+
+cast_to_representable = partial(
+    quantize_dequantize_fp8, scale=jnp.ones((1,)), compute_dtype=jnp.bfloat16
+)
+
+quantize = partial(quantize_to_fp8, scale=jnp.ones((1,)))
 
 def sdpa_train(query: Array,
                key: Array,
@@ -617,10 +636,10 @@ def test_sdpa_fp8(self, batch_size: int, seq_len: int, num_heads: int,
     value = cast_to_representable(value_h, jnp.float8_e4m3fn)
     grad = cast_to_representable(grad_h, jnp.float8_e4m3fn)
 
-    query_quantized = quantize(query, jnp.float8_e4m3fn, jnp.ones((1,)), jnp.float32)
-    key_quantized = quantize(key, jnp.float8_e4m3fn, jnp.ones((1,)), jnp.float32)
-    value_quantized = quantize(value, jnp.float8_e4m3fn, jnp.ones((1,)), jnp.float32)
-    grad_quantized = quantize(grad, jnp.float8_e4m3fn, jnp.ones((1,)), jnp.float32)
+    query_quantized = quantize(query, jnp.float8_e4m3fn, compute_dtype=jnp.float32)
+    key_quantized = quantize(key, jnp.float8_e4m3fn, compute_dtype=jnp.float32)
+    value_quantized = quantize(value, jnp.float8_e4m3fn, compute_dtype=jnp.float32)
+    grad_quantized = quantize(grad, jnp.float8_e4m3fn, compute_dtype=jnp.float32)
 
     sdpa_train_fp8_p = partial(sdpa_train_fp8, scale=scale, mask_type=mask_type)
     jitted_sdpa_train_fp8 = jax.jit(sdpa_train_fp8_p)
@@ -666,9 +685,9 @@ def test_sdpa_fp8_inference(self, batch_size: int, seq_len: int, num_heads: int,
     key = cast_to_representable(key_h, jnp.float8_e4m3fn)
     value = cast_to_representable(value_h, jnp.float8_e4m3fn)
 
-    query_quantized = quantize(query, jnp.float8_e4m3fn, jnp.ones((1,)), jnp.float32)
-    key_quantized = quantize(key, jnp.float8_e4m3fn, jnp.ones((1,)), jnp.float32)
-    value_quantized = quantize(value, jnp.float8_e4m3fn, jnp.ones((1,)), jnp.float32)
+    query_quantized = quantize(query, jnp.float8_e4m3fn, compute_dtype=jnp.float32)
+    key_quantized = quantize(key, jnp.float8_e4m3fn, compute_dtype=jnp.float32)
+    value_quantized = quantize(value, jnp.float8_e4m3fn, compute_dtype=jnp.float32)
 
     def dot_product_attention_fp8(query, key, value, fp8_metas):
       f_p = partial(
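
For reference, a minimal, self-contained sketch of how the new helpers stand in for flax.linen.fp8_ops.qdq and quantize. The tensor shape and PRNG key below are illustrative assumptions, not values from the test file; note that compute_dtype must be passed by keyword at the call sites because the partial already binds scale.

# Usage sketch for the helpers introduced in the diff above.
from functools import partial

import jax
import jax.numpy as jnp

def quantize_to_fp8(x, q_dtype, scale, compute_dtype):
  # Cast the dtype max to the compute dtype up front so the clip below
  # does not silently promote to FP32.
  dtype_max = jnp.finfo(q_dtype).max.astype(compute_dtype)
  scaled_x = x / jnp.broadcast_to(scale.astype(compute_dtype), x.shape)
  return jnp.clip(scaled_x, -dtype_max, dtype_max).astype(q_dtype)

def quantize_dequantize_fp8(x, q_dtype, scale, compute_dtype):
  # Quantize to FP8, then scale back up in the original dtype.
  qx = quantize_to_fp8(x, q_dtype, scale, compute_dtype)
  return qx.astype(x.dtype) * jnp.broadcast_to(scale.astype(x.dtype), qx.shape)

cast_to_representable = partial(
    quantize_dequantize_fp8, scale=jnp.ones((1,)), compute_dtype=jnp.bfloat16)
quantize = partial(quantize_to_fp8, scale=jnp.ones((1,)))

# Illustrative input (shape and key are assumptions for this sketch).
x = jax.random.normal(jax.random.PRNGKey(0), (4, 16), dtype=jnp.bfloat16)
# Round-trip through FP8 so the reference inputs are exactly representable.
x_repr = cast_to_representable(x, jnp.float8_e4m3fn)
# Quantize for the fused FP8 attention path; keyword avoids a positional
# collision with the already-bound scale argument.
x_q = quantize(x_repr, jnp.float8_e4m3fn, compute_dtype=jnp.float32)
assert x_q.dtype == jnp.float8_e4m3fn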
