Add cudnn_fusion decorator lowering computations to XLA cuDNN fusions. #22699

Merged 1 commit on Sep 5, 2024
2 changes: 2 additions & 0 deletions jax/_src/cudnn/__init__.py
@@ -11,3 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .fusion import cudnn_fusion
91 changes: 91 additions & 0 deletions jax/_src/cudnn/fusion.py
@@ -0,0 +1,91 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import jax
from jax import core as jax_core
from jax.interpreters import mlir
from jax.interpreters.mlir import hlo
from jax.interpreters.mlir import ir



def _cudnn_fusion_impl(*args, jaxpr, **unused_kwargs):
  del unused_kwargs
  return jax_core.jaxpr_as_fun(jaxpr)(*args)


def _custom_abstract_eval(*args, jaxpr, **unused_kwargs):
  del unused_kwargs
  del args
  return jaxpr.out_avals


cudnn_fusion_p = jax_core.Primitive("cudnn_fusion")
cudnn_fusion_p.multiple_results = True
cudnn_fusion_p.def_abstract_eval(_custom_abstract_eval)
cudnn_fusion_p.def_impl(_cudnn_fusion_impl)


def call_cudnn_fusion(f, *args, **kwargs):
  """Creates a new cudnn_fusion corresponding to calling
  the given function f with args and kwargs."""
  jaxpr, out_shapes = jax.make_jaxpr(
      functools.partial(f, **kwargs), return_shape=True
  )(*args)
  flat_args = jax.tree.leaves(args)
  out_tree = jax.tree.structure(out_shapes)
  out_flat = cudnn_fusion_p.bind(*flat_args, name=f.__name__, jaxpr=jaxpr)
  return jax.tree.unflatten(out_tree, out_flat)


def _cudnn_fusion_stablehlo_lowering(
    ctx,
    *args,
    name,
    jaxpr,
):
  """Make cudnn_fusion which calls the implementation function.
  Currently this leaks a CallOp since we're using the `core_call_lowering`
  function, but this should get cleaned up by DCE easily.
  """
  impl = mlir.core_call_lowering(
      ctx, *args, name=name + ".impl", call_jaxpr=jaxpr
  )
  call_op = impl[0].owner
  called_fn = call_op.attributes["callee"]
  cudnn_fusion = hlo.CustomCallOp(
      [r.type for r in call_op.results],
      call_op.operands,
      call_target_name="__cudnn$fusion",
      called_computations=ir.ArrayAttr.get([called_fn]),
  )
  return cudnn_fusion.results


mlir.register_lowering(
    cudnn_fusion_p, _cudnn_fusion_stablehlo_lowering, platform="cuda"
)


def cudnn_fusion(f):
"""Makes a function become a cuDNN kernel. Relies on XLA's handling of
custom fusions with __cudnn$fusion backend. Currently limited to GEMM
fusions. For example - batch matmul with mixed types and addition:

@cudnn_fusion
def fn(x, y, z):
return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z
"""
return functools.partial(call_cudnn_fusion, f)
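
A minimal usage sketch of the new decorator, assuming a CUDA build of JAX on an sm80+ GPU; the function name fn and the operand shapes/dtypes here simply mirror the test added below and are illustrative only:

import jax
import jax.numpy as jnp
from jax._src.cudnn import cudnn_fusion

@cudnn_fusion
def fn(x, y, z):
  # Mixed-precision batch matmul followed by an add: the kind of GEMM
  # fusion the __cudnn$fusion custom-call target currently supports.
  return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z

x = jnp.int8(jnp.ones((2, 16, 16)))
y = jnp.bfloat16(jnp.ones((2, 16, 16)))
z = jnp.float32(jnp.ones((2, 16, 16)))

# Under jit, the decorated body lowers to a single __cudnn$fusion custom call.
out = jax.jit(fn)(x, y, z)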
14 changes: 14 additions & 0 deletions tests/BUILD
@@ -1523,6 +1523,20 @@ py_test(
    ],
)

jax_test(
    name = "cudnn_fusion_test",
    srcs = ["cudnn_fusion_test.py"],
    disable_backends = [
        "cpu",
        "tpu",
    ],
    enable_configs = [
        "gpu_a100",
        "gpu_h100",
    ],
    tags = ["multiaccelerator"],
)

exports_files(
    [
        "api_test.py",
69 changes: 69 additions & 0 deletions tests/cudnn_fusion_test.py
@@ -0,0 +1,69 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest, parameterized
from unittest import SkipTest
from jax._src import test_util as jtu
import jax
import jax.numpy as jnp
from jax._src.cudnn import cudnn_fusion


jax.config.parse_flags_with_absl()


class CudnnFusionTest(jtu.JaxTestCase):
  def setUp(self):
    if (not jtu.test_device_matches(["cuda"]) or
        not jtu.is_cuda_compute_capability_at_least("8.0")):
      self.skipTest("Only works on >= sm80 GPUs")
    super().setUp()

  @parameterized.parameters(["", "pmap"])
  @jtu.run_on_devices("cuda")
  def test_cudnn_fusion(self, mode):
    batch_size = 2
    if mode == "pmap" and jax.device_count() < batch_size:
      raise SkipTest("pmap test requires 2 GPUs")

    @cudnn_fusion
    def comp1(x, y, z):
      return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z

    k = jax.random.key(0)
    s = batch_size, 16, 16
    x = jnp.int8(jax.random.normal(k, shape=s))
    y = jnp.bfloat16(jax.random.normal(k, shape=s))
    z = jnp.float32(jax.random.normal(k, shape=s))

    fn = jax.pmap(comp1) if mode == "pmap" else comp1
    jitted = jax.jit(comp1)
    lowered = jitted.lower(x, y, z)
    stablehlo = lowered.as_text("stablehlo")
    self.assertIn("func.func private @comp1", stablehlo)
    self.assertIn("__cudnn$fusion", stablehlo)

    hlo = lowered.as_text("hlo")
    self.assertIn('custom_call_target="__cudnn$fusion"', hlo)
    self.assertIn("called_computations=", hlo)

    hlo_after_opt = lowered.compile().as_text()
    self.assertIn("kind=kCustom", hlo_after_opt)
    self.assertIn("plan_id", hlo_after_opt)

    self.assertAllClose(jitted(x, y, z), fn(x, y, z))


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())