From dcbdfc5e10ad3d3a3f01dbfce6205dab83bcfb65 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Thu, 25 Jul 2024 13:22:54 +0000 Subject: [PATCH] wip --- docs/JAX FP8 matmul tutorial.ipynb | 577 ++++++++++++++++++++++++++++- 1 file changed, 576 insertions(+), 1 deletion(-) diff --git a/docs/JAX FP8 matmul tutorial.ipynb b/docs/JAX FP8 matmul tutorial.ipynb index 93934ad..1f05ee9 100644 --- a/docs/JAX FP8 matmul tutorial.ipynb +++ b/docs/JAX FP8 matmul tutorial.ipynb @@ -1,9 +1,584 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "c29689be-e7f4-40fb-9942-8f8944364239", + "metadata": {}, + "source": [ + "# JAX FP8 (fused) matmul tutorial" + ] + }, + { + "cell_type": "markdown", + "id": "f878aaba-ce22-42d3-89e3-ad7f22b6f75c", + "metadata": {}, + "source": [ + "## FP8 in machine learning quickstart\n", + "\n", + "* Two FP8 datatypes;\n", + "* Why?\n", + "* Papers;" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "51775bad-18ad-49b7-9371-930b3704a294", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "13fdbeaf-c0e7-4c10-8fe9-9109a49eefe2", + "metadata": {}, + "source": [ + "## FP8 matrix multiplication API" + ] + }, + { + "cell_type": "markdown", + "id": "b8234ebd-6f63-4f71-bb6c-b0d9b7cd2ddc", + "metadata": {}, + "source": [ + "## FP8 matmul in JAX: from simple to complicated!" + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "id": "e238ba4d-d749-477b-9ce9-f457a17a75a7", + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "from jax_scalify.utils import print_hlo_module" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "id": "2893288d-7f2a-42e1-8541-afeed1d63a85", + "metadata": {}, + "outputs": [], + "source": [ + "# Starting with the most simple matmul!\n", + "def matmul_fn(a_fp8, b_fp8):\n", + " # FP8 x FP8 -> FP8 matmul\n", + " return jax.lax.dot(a_fp8, b_fp8)" + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "id": "a6142f8d-08ee-4fa6-962f-2b85a1bcecb6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "HloModule jit_matmul_fn, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[64,128]{1,0})->f8e4m3fn[32,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs=\"9fadaa690799c7afd9b2ce5c8db8224a\"}\n", + "\n", + "%wrapped_transpose_computation (param_0: f8e4m3fn[64,128]) -> f8e4m3fn[128,64] {\n", + " %param_0 = f8e4m3fn[64,128]{1,0} parameter(0)\n", + " ROOT %transpose.1.1 = f8e4m3fn[128,64]{1,0} transpose(f8e4m3fn[64,128]{1,0} %param_0), dimensions={1,0}\n", + "}\n", + "\n", + "ENTRY %main.4 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[64,128]) -> f8e4m3fn[32,128] {\n", + " %constant_1 = f32[] constant(1)\n", + " %Arg_1.2.0 = f8e4m3fn[64,128]{1,0} parameter(1)\n", + " %Arg_0.1.0 = f8e4m3fn[32,64]{1,0} parameter(0)\n", + " %wrapped_transpose = f8e4m3fn[128,64]{1,0} fusion(f8e4m3fn[64,128]{1,0} %Arg_1.2.0), kind=kInput, calls=%wrapped_transpose_computation\n", + " %cublas-gemm.1.0 = (f8e4m3fn[32,128]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[32,64]{1,0} %Arg_0.1.0, f8e4m3fn[128,64]{1,0} %wrapped_transpose, f32[] %constant_1, f32[] %constant_1, f32[] %constant_1, /*index=5*/f32[] %constant_1), custom_call_target=\"__cublas$lt$matmul$f8\"\n", + " backend_cfg: {\n", + " \"operation_queue_id\": \"0\",\n", + " \"wait_on_operation_queues\": [],\n", + " \"gemm_backend_config\": {\n", + " \"alpha_real\": 1,\n", + " \"alpha_imag\": 0,\n", + " \"beta\": 0,\n", + " \"dot_dimension_numbers\": {\n", + " \"lhs_contracting_dimensions\": [\n", + " \"1\"\n", + " ],\n", + " \"rhs_contracting_dimensions\": [\n", + " \"1\"\n", + " ],\n", + " \"lhs_batch_dimensions\": [],\n", + " \"rhs_batch_dimensions\": []\n", + " },\n", + " \"precision_config\": {\n", + " \"operand_precision\": [\n", + " \"DEFAULT\",\n", + " \"DEFAULT\"\n", + " ],\n", + " \"algorithm\": \"ALG_UNSET\"\n", + " },\n", + " \"epilogue\": \"DEFAULT\",\n", + " \"damax_output\": false,\n", + " \"selected_algorithm\": \"2\",\n", + " \"lhs_stride\": \"2048\",\n", + " \"rhs_stride\": \"8192\",\n", + " \"grad_x\": false,\n", + " \"grad_y\": false\n", + " },\n", + " \"force_earliest_schedule\": false\n", + " }\n", + " ROOT %get-tuple-element.1 = f8e4m3fn[32,128]{1,0} get-tuple-element((f8e4m3fn[32,128]{1,0}, s8[33554432]{0}) %cublas-gemm.1.0), index=0\n", + "}\n", + "\n", + "\n" + ] + } + ], + "source": [ + "a_aval = jax.core.ShapedArray((32, 64), jnp.float8_e4m3fn)\n", + "b_aval = jax.core.ShapedArray((64, 128), jnp.float8_e4m3fn)\n", + "\n", + "# AOT compilation with JAX, inspecting the (final) HLO module generated.\n", + "fn_compiled = jax.jit(matmul_fn).lower(a_aval, b_aval).compile()\n", + "# (Human readable) optimized Hlo module generated by XLA.\n", + "print_hlo_module(fn_compiled, backend_cfg=True)" + ] + }, + { + "cell_type": "markdown", + "id": "89cc3c24-70bb-4b03-b207-4ac304621579", + "metadata": {}, + "source": [ + "**A couple of things to remark:**\n", + "* XLA generates a `custom-call` to `__cublas$lt$matmul$f8` after recognizing the FP8 matmul.\n", + "* Only last axis reduction is supported by `cublasLtMatmul`, hence why an additional transpose is added.\n", + "* `__cublas$lt$matmul$f8` has 4 additional `f32` arguments (set here to a constant `1`) corresponding to scale factors.\n", + "\n", + "**How to use scaling factors from JAX?**" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "id": "7a1b4871-dc21-497c-a894-ba5d266c8b08", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ERROR: Input dtypes ('float32', 'float8_e4m3fn') have no available implicit dtype promotion path. To avoid unintended promotion, 8-bit floats do not support implicit promotion. If you'd like your inputs to be promoted to another type, you can do so explicitly using e.g. x.astype('float32')\n" + ] + } + ], + "source": [ + "def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, c_scale):\n", + " # First try: just scale the input.\n", + " a_fp8 = a_fp8 * a_scale\n", + " out = jax.lax.dot(a_fp8, b_fp8.T)\n", + " return out\n", + "\n", + "# `cublasLtMatmul` expecting FP32 scales.\n", + "scale_aval = jax.core.ShapedArray((), jnp.float32)\n", + "try:\n", + " fn_compiled = jax.jit(matmul_fn_with_scale).lower(a_aval, b_aval, scale_aval, scale_aval, scale_aval).compile()\n", + "except Exception as e:\n", + " # Issue: do not support implicit mixed-multiplication FP8 x FP32\n", + " print(\"ERROR:\", e)" + ] + }, + { + "cell_type": "code", + "execution_count": 85, + "id": "6abbed6c-2a06-4942-becb-da9c7b7b79e7", + "metadata": {}, + "outputs": [], + "source": [ + "# XLA requires a \"dequantize/quantize\" pattern to properly support scaled FP8 inputs/outputs. \n", + "def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, d_scale):\n", + " # Dequantize x and y\n", + " a_fp32 = a_fp8.astype(jnp.float32) * a_scale\n", + " b_fp32 = b_fp8.astype(jnp.float32) * b_scale\n", + " \n", + " # Do the matmul (NOTE: adding transpose to simplify HLO).\n", + " d_fp32 = jax.lax.dot(a_fp32, b_fp32.transpose())\n", + " \n", + " # Rescale & clamp to -max/+max FP8 E4M3 values.\n", + " d_fp32 = d_fp32 * d_scale\n", + " # NOTE: clamping is NOT optional for proper pattern matching!\n", + " d_fp32 = jax.lax.clamp(jnp.float32(-448), d_fp32, jnp.float32(448))\n", + " # (Re)Quantize the scaled matmul output.\n", + " return d_fp32.astype(jnp.float8_e4m3fn)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ce0ad419-d73d-4b7f-b834-3edc8d4ddbf7", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 86, + "id": "fd42a40d-5e61-4417-b425-8fb251cd2171", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "HloModule jit_matmul_fn_with_scale, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[128,64]{1,0}, f32[], f32[], f32[])->f8e4m3fn[32,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs=\"0ae55bc5ea38f0523b45347dc49424c4\"}\n", + "\n", + "ENTRY %main.22 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64], Arg_2.3.0: f32[], Arg_3.4.0: f32[], Arg_4.5.0: f32[]) -> f8e4m3fn[32,128] {\n", + " %constant_1 = f32[] constant(1)\n", + " %Arg_4.5.0 = f32[] parameter(4)\n", + " %Arg_3.4.0 = f32[] parameter(3)\n", + " %Arg_2.3.0 = f32[] parameter(2)\n", + " %Arg_1.2.0 = f8e4m3fn[128,64]{1,0} parameter(1)\n", + " %Arg_0.1.0 = f8e4m3fn[32,64]{1,0} parameter(0)\n", + " %cublas-gemm.clone.1.0 = (f8e4m3fn[32,128]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[32,64]{1,0} %Arg_0.1.0, f8e4m3fn[128,64]{1,0} %Arg_1.2.0, f32[] %Arg_2.3.0, f32[] %Arg_3.4.0, f32[] %constant_1, /*index=5*/f32[] %Arg_4.5.0), custom_call_target=\"__cublas$lt$matmul$f8\"\n", + " backend_cfg: {\n", + " \"operation_queue_id\": \"0\",\n", + " \"wait_on_operation_queues\": [],\n", + " \"gemm_backend_config\": {\n", + " \"alpha_real\": 1,\n", + " \"alpha_imag\": 0,\n", + " \"beta\": 0,\n", + " \"dot_dimension_numbers\": {\n", + " \"lhs_contracting_dimensions\": [\n", + " \"1\"\n", + " ],\n", + " \"rhs_contracting_dimensions\": [\n", + " \"1\"\n", + " ],\n", + " \"lhs_batch_dimensions\": [],\n", + " \"rhs_batch_dimensions\": []\n", + " },\n", + " \"precision_config\": {\n", + " \"operand_precision\": [\n", + " \"DEFAULT\",\n", + " \"DEFAULT\"\n", + " ],\n", + " \"algorithm\": \"ALG_UNSET\"\n", + " },\n", + " \"epilogue\": \"DEFAULT\",\n", + " \"damax_output\": false,\n", + " \"selected_algorithm\": \"2\",\n", + " \"lhs_stride\": \"2048\",\n", + " \"rhs_stride\": \"8192\",\n", + " \"grad_x\": false,\n", + " \"grad_y\": false\n", + " },\n", + " \"force_earliest_schedule\": false\n", + " }\n", + " ROOT %get-tuple-element.1 = f8e4m3fn[32,128]{1,0} get-tuple-element((f8e4m3fn[32,128]{1,0}, s8[33554432]{0}) %cublas-gemm.clone.1.0), index=0\n", + "}\n", + "\n", + "\n" + ] + } + ], + "source": [ + "a_aval = jax.core.ShapedArray((32, 64), jnp.float8_e4m3fn)\n", + "b_aval = jax.core.ShapedArray((128, 64), jnp.float8_e4m3fn)\n", + "# `cublasLtMatmul` expecting F32 scales.\n", + "scale_aval = jax.core.ShapedArray((), jnp.float32)\n", + "\n", + "# AOT compilation with JAX, inspecting the (final) HLO module generated.\n", + "fn_compiled = jax.jit(matmul_fn_with_scale).lower(a_aval, b_aval, scale_aval, scale_aval, scale_aval).compile()\n", + "# (Human readable) optimized Hlo module generated by XLA.\n", + "print_hlo_module(fn_compiled, backend_cfg=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 163, + "id": "17652aac-36ae-41e9-892c-9d58c8e76283", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "# cdf(x) = 0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))\n", + "def gelu(x_arr):\n", + " # sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x_arr.dtype)\n", + " two_over_pi = np.array(2 / np.pi).astype(x_arr.dtype)\n", + " sqrt_2_over_pi = np.sqrt(two_over_pi)\n", + " sqrt_2_over_pi = np.array(0.797884583, x_arr.dtype)\n", + " cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x_arr + 0.044715 * (x_arr ** 3))))\n", + " return x_arr * cdf" + ] + }, + { + "cell_type": "code", + "execution_count": 172, + "id": "a3dfbb39-6867-4de7-a329-c39c824bcedf", + "metadata": {}, + "outputs": [], + "source": [ + "# XLA requires a \"dequantize/quantize\" pattern to properly support scaled FP8 inputs/outputs. \n", + "def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, d_scale):\n", + " dtype = jnp.bfloat16\n", + " # Dequantize x and y\n", + " a_fp32 = a_fp8.astype(dtype) * a_scale.astype(dtype)\n", + " b_fp32 = b_fp8.astype(dtype) * b_scale.astype(dtype)\n", + " \n", + " # Do the matmul (NOTE: adding transpose to simplify HLO).\n", + " d_fp32 = jax.lax.dot(a_fp32, b_fp32.transpose())\n", + "\n", + " print(d_fp32)\n", + "\n", + " # d_fp32 = jax.nn.relu(d_fp32)\n", + " d_fp32 = gelu(d_fp32)\n", + " return d_fp32\n", + "\n", + " \n", + " # Rescale & clamp to -max/+max FP8 E4M3 values.\n", + " d_fp32 = d_fp32 * d_scale\n", + " # NOTE: clamping is NOT optional for proper pattern matching!\n", + " d_fp32 = jax.lax.clamp(dtype(-448), d_fp32, dtype(448))\n", + " # (Re)Quantize the scaled matmul output.\n", + " return d_fp32.astype(jnp.float8_e4m3fn)" + ] + }, + { + "cell_type": "code", + "execution_count": 173, + "id": "492d5394-4363-49c7-ab97-c4f454a4ddfb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tracedwith\n", + "HloModule jit_matmul_fn_with_scale, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[128,64]{1,0}, f32[], f32[])->bf16[32,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs=\"a07a42366a4e4fa29c13ac7bf8758a5f\"}\n", + "\n", + "%fused_multiply (param_0.10: bf16[32,128]) -> bf16[32,128] {\n", + " %param_0.10 = bf16[32,128]{1,0} parameter(0)\n", + " %multiply.8.3 = bf16[32,128]{1,0} multiply(bf16[32,128]{1,0} %param_0.10, bf16[32,128]{1,0} %param_0.10)\n", + " %multiply.9.3 = bf16[32,128]{1,0} multiply(bf16[32,128]{1,0} %multiply.8.3, bf16[32,128]{1,0} %param_0.10)\n", + " %constant_11_1 = bf16[] constant(0.04468)\n", + " %broadcast.7.1 = bf16[32,128]{1,0} broadcast(bf16[] %constant_11_1), dimensions={}\n", + " %multiply.10.1 = bf16[32,128]{1,0} multiply(bf16[32,128]{1,0} %multiply.9.3, bf16[32,128]{1,0} %broadcast.7.1)\n", + " %add.2.1 = bf16[32,128]{1,0} add(bf16[32,128]{1,0} %param_0.10, bf16[32,128]{1,0} %multiply.10.1)\n", + " %constant_9_1 = bf16[] constant(0.7969)\n", + " %broadcast.9.1 = bf16[32,128]{1,0} broadcast(bf16[] %constant_9_1), dimensions={}\n", + " %multiply.11.1 = bf16[32,128]{1,0} multiply(bf16[32,128]{1,0} %add.2.1, bf16[32,128]{1,0} %broadcast.9.1)\n", + " %convert.10.1 = f32[32,128]{1,0} convert(bf16[32,128]{1,0} %multiply.11.1)\n", + " %tanh.1.7 = f32[32,128]{1,0} tanh(f32[32,128]{1,0} %convert.10.1)\n", + " %convert.11.7 = bf16[32,128]{1,0} convert(f32[32,128]{1,0} %tanh.1.7)\n", + " %constant_7_1 = bf16[] constant(1)\n", + " %broadcast.11.1 = bf16[32,128]{1,0} broadcast(bf16[] %constant_7_1), dimensions={}\n", + " %add.3.5 = bf16[32,128]{1,0} add(bf16[32,128]{1,0} %convert.11.7, bf16[32,128]{1,0} %broadcast.11.1)\n", + " %constant_5_1 = bf16[] constant(0.5)\n", + " %broadcast.13.1 = bf16[32,128]{1,0} broadcast(bf16[] %constant_5_1), dimensions={}\n", + " %multiply.12.3 = bf16[32,128]{1,0} multiply(bf16[32,128]{1,0} %add.3.5, bf16[32,128]{1,0} %broadcast.13.1)\n", + " ROOT %multiply.13.1 = bf16[32,128]{1,0} multiply(bf16[32,128]{1,0} %param_0.10, bf16[32,128]{1,0} %multiply.12.3)\n", + "}\n", + "\n", + "ENTRY %main.32 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64], Arg_2.3.0: f32[], Arg_3.4.0: f32[]) -> bf16[32,128] {\n", + " %constant_1 = f32[] constant(1)\n", + " %Arg_3.4.0 = f32[] parameter(3)\n", + " %Arg_2.3.0 = f32[] parameter(2)\n", + " %Arg_1.2.0 = f8e4m3fn[128,64]{1,0} parameter(1)\n", + " %Arg_0.1.0 = f8e4m3fn[32,64]{1,0} parameter(0)\n", + " %cublas-gemm.1.0 = (bf16[32,128]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[32,64]{1,0} %Arg_0.1.0, f8e4m3fn[128,64]{1,0} %Arg_1.2.0, f32[] %Arg_2.3.0, f32[] %Arg_3.4.0, f32[] %constant_1, /*index=5*/f32[] %constant_1), custom_call_target=\"__cublas$lt$matmul$f8\"\n", + " backend_cfg: {\n", + " \"operation_queue_id\": \"0\",\n", + " \"wait_on_operation_queues\": [],\n", + " \"gemm_backend_config\": {\n", + " \"alpha_real\": 1,\n", + " \"alpha_imag\": 0,\n", + " \"beta\": 0,\n", + " \"dot_dimension_numbers\": {\n", + " \"lhs_contracting_dimensions\": [\n", + " \"1\"\n", + " ],\n", + " \"rhs_contracting_dimensions\": [\n", + " \"1\"\n", + " ],\n", + " \"lhs_batch_dimensions\": [],\n", + " \"rhs_batch_dimensions\": []\n", + " },\n", + " \"precision_config\": {\n", + " \"operand_precision\": [\n", + " \"DEFAULT\",\n", + " \"DEFAULT\"\n", + " ],\n", + " \"algorithm\": \"ALG_UNSET\"\n", + " },\n", + " \"epilogue\": \"DEFAULT\",\n", + " \"damax_output\": false,\n", + " \"selected_algorithm\": \"3\",\n", + " \"lhs_stride\": \"2048\",\n", + " \"rhs_stride\": \"8192\",\n", + " \"grad_x\": false,\n", + " \"grad_y\": false\n", + " },\n", + " \"force_earliest_schedule\": false\n", + " }\n", + " %get-tuple-element.1 = bf16[32,128]{1,0} get-tuple-element((bf16[32,128]{1,0}, s8[33554432]{0}) %cublas-gemm.1.0), index=0\n", + " ROOT %loop_multiply_fusion = bf16[32,128]{1,0} fusion(bf16[32,128]{1,0} %get-tuple-element.1), kind=kLoop, calls=%fused_multiply\n", + "}\n", + "\n", + "\n" + ] + } + ], + "source": [ + "a_aval = jax.core.ShapedArray((32, 64), jnp.float8_e4m3fn)\n", + "b_aval = jax.core.ShapedArray((128, 64), jnp.float8_e4m3fn)\n", + "# `cublasLtMatmul` expecting F32 scales.\n", + "scale_aval = jax.core.ShapedArray((), jnp.float32)\n", + "\n", + "# AOT compilation with JAX, inspecting the (final) HLO module generated.\n", + "fn_compiled = jax.jit(matmul_fn_with_scale).lower(a_aval, b_aval, scale_aval, scale_aval, scale_aval).compile()\n", + "# (Human readable) optimized Hlo module generated by XLA.\n", + "print_hlo_module(fn_compiled, backend_cfg=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 174, + "id": "6bad9562-dc5b-4b89-ac22-85dc75498fd3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{ lambda ; a:bf16[128,64]. let\n", + " b:bf16[128,64] = integer_pow[y=3] a\n", + " c:bf16[128,64] = mul 0.0446777 b\n", + " d:bf16[128,64] = add a c\n", + " e:bf16[128,64] = mul 0.796875 d\n", + " f:bf16[128,64] = tanh e\n", + " g:bf16[128,64] = add 1 f\n", + " h:bf16[128,64] = mul 0.5 g\n", + " i:bf16[128,64] = mul a h\n", + " in (i,) }" + ] + }, + "execution_count": 174, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "aval = jax.core.ShapedArray((128, 64), jnp.bfloat16)\n", + "\n", + "jax.make_jaxpr(gelu)(aval)" + ] + }, + { + "cell_type": "code", + "execution_count": 175, + "id": "52d70b59-6717-4989-a4b8-5c0f3fd131cd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "module @jit_gelu attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {\n", + " func.func public @main(%arg0: tensor<128x64xbf16> {mhlo.layout_mode = \"default\"}) -> (tensor<128x64xbf16> {jax.result_info = \"\", mhlo.layout_mode = \"default\"}) {\n", + " %0 = stablehlo.multiply %arg0, %arg0 : tensor<128x64xbf16>\n", + " %1 = stablehlo.multiply %0, %arg0 : tensor<128x64xbf16>\n", + " %cst = stablehlo.constant dense<4.467770e-02> : tensor\n", + " %2 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<128x64xbf16>\n", + " %3 = stablehlo.multiply %2, %1 : tensor<128x64xbf16>\n", + " %4 = stablehlo.add %arg0, %3 : tensor<128x64xbf16>\n", + " %cst_0 = stablehlo.constant dense<7.968750e-01> : tensor\n", + " %5 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor) -> tensor<128x64xbf16>\n", + " %6 = stablehlo.multiply %5, %4 : tensor<128x64xbf16>\n", + " %7 = stablehlo.tanh %6 : tensor<128x64xbf16>\n", + " %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor\n", + " %8 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor) -> tensor<128x64xbf16>\n", + " %9 = stablehlo.add %8, %7 : tensor<128x64xbf16>\n", + " %cst_2 = stablehlo.constant dense<5.000000e-01> : tensor\n", + " %10 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor) -> tensor<128x64xbf16>\n", + " %11 = stablehlo.multiply %10, %9 : tensor<128x64xbf16>\n", + " %12 = stablehlo.multiply %arg0, %11 : tensor<128x64xbf16>\n", + " return %12 : tensor<128x64xbf16>\n", + " }\n", + "}\n", + "\n" + ] + } + ], + "source": [ + "print(jax.jit(gelu).lower(aval).as_text())" + ] + }, + { + "cell_type": "code", + "execution_count": 176, + "id": "0a006bd0-e5e3-4089-a4a8-eceaa87e2664", + "metadata": {}, + "outputs": [], + "source": [ + "# 0.797884583" + ] + }, + { + "cell_type": "code", + "execution_count": 177, + "id": "3654fca8-d543-4f38-b383-7644a7028fb4", + "metadata": {}, + "outputs": [ + { + "ename": "SyntaxError", + "evalue": "invalid syntax (276308696.py, line 1)", + "output_type": "error", + "traceback": [ + "\u001b[0;36m Cell \u001b[0;32mIn[177], line 1\u001b[0;36m\u001b[0m\n\u001b[0;31m x = <>[16,32] parameter(0)\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n" + ] + } + ], + "source": [ + " x = <>[16,32] parameter(0)\n", + "y = <>[32,16] parameter(1)\n", + "x_bf16 = bf16[16,32] convert(x)\n", + "y_bf16 = bf16[32,16] convert(y)\n", + "x_scale = bf16[] parameter(2)\n", + "y_scale = bf16[] parameter(3)\n", + "x_scale_bcast = bf16[16,32] broadcast(x_scale), dimensions={}\n", + "y_scale_bcast = bf16[32,16] broadcast(y_scale), dimensions={}\n", + "x_unscaled = bf16[16,32] multiply(x_bf16, x_scale_bcast)\n", + "y_unscaled = bf16[32,16] multiply(y_bf16, y_scale_bcast)\n", + "dot = bf16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}\n", + "mul.0 = bf16[16,16] multiply(dot, dot)\n", + "mul.1 = bf16[16,16] multiply(dot, mul.0)\n", + "const.0 = bf16[] constant(0.044715)\n", + "bcast.0 = bf16[16,16] broadcast(const.0), dimensions={}\n", + "mul.2 = bf16[16,16] multiply(mul.1, bcast.0)\n", + "add.0 = bf16[16,16] add(dot, mul.2)\n", + "const.1 = bf16[] constant(0.797884583)\n", + "bcast.1 = bf16[16,16] broadcast(const.1), dimensions={}\n", + "mul.3 = bf16[16,16] multiply(add.0, bcast.1)\n", + "tanh = bf16[16,16] tanh(mul.3)\n", + "const.2 = bf16[] constant(1)\n", + "bcast.2 = bf16[16,16] broadcast(const.2), dimensions={}\n", + "add.2 = bf16[16,16] add(tanh, bcast.2)\n", + "const.3 = bf16[] constant(0.5)\n", + "bcast.3 = bf16[16,16] broadcast(const.3), dimensions={}\n", + "mul.4 = bf16[16,16] multiply(add.2, bcast.3)\n", + "ROOT out = bf16[16,16] multiply(dot, mul.4)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c8bdaf95-7a83-4cef-ad2b-ce6520bf69c8", + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, - "id": "e8d340b1-0527-4014-a82c-2877cc32790a", + "id": "b9a314f6-2e98-41cb-a1e6-c041b0e0b973", "metadata": {}, "outputs": [], "source": []