diff --git a/jax/experimental/mosaic/gpu/examples/matmul.py b/jax/experimental/mosaic/gpu/examples/matmul.py index 52d403cd0131..775b7c2ea898 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul.py +++ b/jax/experimental/mosaic/gpu/examples/matmul.py @@ -132,7 +132,7 @@ def build_kernel( if stages < 2: raise ValueError(f"Need at least 2 stages, but got {stages=}") if not rhs_transpose and jnp.dtype(rhs_dtype).itemsize != 2: - raise ValueError("Transpose only supported for only happen for 16bit types") + raise ValueError(f"Transpose only supported for 16bit types (got: {rhs_transpose=}, {rhs_dtype=})") if swizzle not in {32, 64, 128}: raise ValueError(f"swizzle must be 32, 64, or 128, but got {swizzle=}")