From 85d792a92d07f5600d7796d57019fcad58228a59 Mon Sep 17 00:00:00 2001
From: Ilia Sergachev
Date: Wed, 24 Jul 2024 16:31:03 +0200
Subject: [PATCH] Add cudnn_fusion decorator lowering computations to XLA cuDNN fusions.

---
 jax/_src/cudnn/__init__.py |  2 +
 jax/_src/cudnn/fusion.py   | 91 ++++++++++++++++++++++++++++++++++++++
 tests/BUILD                | 14 ++++++
 tests/cudnn_fusion_test.py | 69 +++++++++++++++++++++++++++++
 4 files changed, 176 insertions(+)
 create mode 100644 jax/_src/cudnn/fusion.py
 create mode 100644 tests/cudnn_fusion_test.py

diff --git a/jax/_src/cudnn/__init__.py b/jax/_src/cudnn/__init__.py
index 862a661e24b9..23d1fa28ff43 100644
--- a/jax/_src/cudnn/__init__.py
+++ b/jax/_src/cudnn/__init__.py
@@ -11,3 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from .fusion import cudnn_fusion
diff --git a/jax/_src/cudnn/fusion.py b/jax/_src/cudnn/fusion.py
new file mode 100644
index 000000000000..8a13399e3d63
--- /dev/null
+++ b/jax/_src/cudnn/fusion.py
@@ -0,0 +1,91 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import jax
+from jax import core as jax_core
+from jax.interpreters import mlir
+from jax.interpreters.mlir import hlo
+from jax.interpreters.mlir import ir
+
+
+
+def _cudnn_fusion_impl(*args, jaxpr, **unused_kwargs):
+  del unused_kwargs
+  return jax_core.jaxpr_as_fun(jaxpr)(*args)
+
+
+def _custom_abstract_eval(*args, jaxpr, **unused_kwargs):
+  del unused_kwargs
+  del args
+  return jaxpr.out_avals
+
+
+cudnn_fusion_p = jax_core.Primitive("cudnn_fusion")
+cudnn_fusion_p.multiple_results = True
+cudnn_fusion_p.def_abstract_eval(_custom_abstract_eval)
+cudnn_fusion_p.def_impl(_cudnn_fusion_impl)
+
+
+def call_cudnn_fusion(f, *args, **kwargs):
+  """Creates a new cudnn_fusion corresponding to calling
+  the given function f with args and kwargs."""
+  jaxpr, out_shapes = jax.make_jaxpr(
+      functools.partial(f, **kwargs), return_shape=True
+  )(*args)
+  flat_args = jax.tree.leaves(args)
+  out_tree = jax.tree.structure(out_shapes)
+  out_flat = cudnn_fusion_p.bind(*flat_args, name=f.__name__, jaxpr=jaxpr)
+  return jax.tree.unflatten(out_tree, out_flat)
+
+
+def _cudnn_fusion_stablehlo_lowering(
+    ctx,
+    *args,
+    name,
+    jaxpr,
+):
+  """Make cudnn_fusion which calls the implementation function.
+  Currently this leaks a CallOp since we're using the `core_call_lowering`
+  function, but this should get cleaned up by DCE easily.
+ """ + impl = mlir.core_call_lowering( + ctx, *args, name=name + ".impl", call_jaxpr=jaxpr + ) + call_op = impl[0].owner + called_fn = call_op.attributes["callee"] + cudnn_fusion = hlo.CustomCallOp( + [r.type for r in call_op.results], + call_op.operands, + call_target_name="__cudnn$fusion", + called_computations=ir.ArrayAttr.get([called_fn]), + ) + return cudnn_fusion.results + + +mlir.register_lowering( + cudnn_fusion_p, _cudnn_fusion_stablehlo_lowering, platform="cuda" + ) + + +def cudnn_fusion(f): + """Makes a function become a cuDNN kernel. Relies on XLA's handling of + custom fusions with __cudnn$fusion backend. Currently limited to GEMM + fusions. For example - batch matmul with mixed types and addition: + + @cudnn_fusion + def fn(x, y, z): + return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z + """ + return functools.partial(call_cudnn_fusion, f) diff --git a/tests/BUILD b/tests/BUILD index 45743d306fd6..b624b6bef3ac 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1523,6 +1523,20 @@ py_test( ], ) +jax_test( + name = "cudnn_fusion_test", + srcs = ["cudnn_fusion_test.py"], + disable_backends = [ + "cpu", + "tpu", + ], + enable_configs = [ + "gpu_a100", + "gpu_h100", + ], + tags = ["multiaccelerator"], +) + exports_files( [ "api_test.py", diff --git a/tests/cudnn_fusion_test.py b/tests/cudnn_fusion_test.py new file mode 100644 index 000000000000..e70ba12361a2 --- /dev/null +++ b/tests/cudnn_fusion_test.py @@ -0,0 +1,69 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from absl.testing import absltest, parameterized
+from unittest import SkipTest
+from jax._src import test_util as jtu
+import jax
+import jax.numpy as jnp
+from jax._src.cudnn import cudnn_fusion
+
+
+jax.config.parse_flags_with_absl()
+
+
+class CudnnFusionTest(jtu.JaxTestCase):
+  def setUp(self):
+    if (not jtu.test_device_matches(["cuda"]) or
+        not jtu.is_cuda_compute_capability_at_least("8.0")):
+      self.skipTest("Only works on >= sm80 GPUs")
+    super().setUp()
+
+  @parameterized.parameters(["", "pmap"])
+  @jtu.run_on_devices("cuda")
+  def test_cudnn_fusion(self, mode):
+    batch_size = 2
+    if mode == "pmap" and jax.device_count() < batch_size:
+      raise SkipTest("pmap test requires 2 GPUs")
+
+    @cudnn_fusion
+    def comp1(x, y, z):
+      return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z
+
+    k = jax.random.key(0)
+    s = batch_size, 16, 16
+    x = jnp.int8(jax.random.normal(k, shape=s))
+    y = jnp.bfloat16(jax.random.normal(k, shape=s))
+    z = jnp.float32(jax.random.normal(k, shape=s))
+
+    fn = jax.pmap(comp1) if mode == "pmap" else comp1
+    jitted = jax.jit(comp1)
+    lowered = jitted.lower(x, y, z)
+    stablehlo = lowered.as_text("stablehlo")
+    self.assertIn("func.func private @comp1", stablehlo)
+    self.assertIn("__cudnn$fusion", stablehlo)
+
+    hlo = lowered.as_text("hlo")
+    self.assertIn('custom_call_target="__cudnn$fusion"', hlo)
+    self.assertIn("called_computations=", hlo)
+
+    hlo_after_opt = lowered.compile().as_text()
+    self.assertIn("kind=kCustom", hlo_after_opt)
+    self.assertIn("plan_id", hlo_after_opt)
+
+    self.assertAllClose(jitted(x, y, z), fn(x, y, z))
+
+
+if __name__ == '__main__':
+  absltest.main(testLoader=jtu.JaxTestLoader())
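
Usage sketch (not part of the patch): the snippet below mirrors the decorator's
docstring and cudnn_fusion_test.py, applying cudnn_fusion to a mixed-precision
batched matmul plus add and checking that it lowers to a single __cudnn$fusion
custom call. It assumes a CUDA build of JAX on an sm80+ GPU with this patch
applied; the name fused_matmul_add is illustrative.

import jax
import jax.numpy as jnp
from jax._src.cudnn import cudnn_fusion  # exported by __init__.py in this patch


@cudnn_fusion
def fused_matmul_add(x, y, z):
  # Mixed-precision batched matmul with an f32 add: the GEMM-style pattern
  # the decorator currently targets.
  return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z


k = jax.random.key(0)
s = (2, 16, 16)
x = jnp.int8(jax.random.normal(k, shape=s))
y = jnp.bfloat16(jax.random.normal(k, shape=s))
z = jnp.float32(jax.random.normal(k, shape=s))

jitted = jax.jit(fused_matmul_add)
lowered = jitted.lower(x, y, z)
# With the patch, the whole body lowers to one __cudnn$fusion custom call.
assert "__cudnn$fusion" in lowered.as_text("stablehlo")
print(jitted(x, y, z).shape)  # (2, 16, 16)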