Commit 8fe99ff

Merge pull request #22699 from sergachev:cudnn_fusion

PiperOrigin-RevId: 671395864
jax authors committed Sep 5, 2024
2 parents 2dd13ce + 85d792a

Showing 4 changed files with 176 additions and 0 deletions.
2 changes: 2 additions & 0 deletions jax/_src/cudnn/__init__.py
@@ -11,3 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .fusion import cudnn_fusion
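With this re-export in place, the decorator is importable from the package directly (a sketch; jax._src is JAX's private namespace, so this path carries no stability guarantee):

from jax._src.cudnn import cudnn_fusion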
91 changes: 91 additions & 0 deletions jax/_src/cudnn/fusion.py
@@ -0,0 +1,91 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import jax
from jax import core as jax_core
from jax.interpreters import mlir
from jax.interpreters.mlir import hlo
from jax.interpreters.mlir import ir



def _cudnn_fusion_impl(*args, jaxpr, **unused_kwargs):
  del unused_kwargs
  return jax_core.jaxpr_as_fun(jaxpr)(*args)


def _custom_abstract_eval(*args, jaxpr, **unused_kwargs):
  del unused_kwargs
  del args
  return jaxpr.out_avals


cudnn_fusion_p = jax_core.Primitive("cudnn_fusion")
cudnn_fusion_p.multiple_results = True
cudnn_fusion_p.def_abstract_eval(_custom_abstract_eval)
cudnn_fusion_p.def_impl(_cudnn_fusion_impl)


def call_cudnn_fusion(f, *args, **kwargs):
  """Creates a new cudnn_fusion corresponding to calling
  the given function f with args and kwargs."""
  jaxpr, out_shapes = jax.make_jaxpr(
      functools.partial(f, **kwargs), return_shape=True
  )(*args)
  flat_args = jax.tree.leaves(args)
  out_tree = jax.tree.structure(out_shapes)
  out_flat = cudnn_fusion_p.bind(*flat_args, name=f.__name__, jaxpr=jaxpr)
  return jax.tree.unflatten(out_tree, out_flat)


def _cudnn_fusion_stablehlo_lowering(
    ctx,
    *args,
    name,
    jaxpr,
):
  """Lowers cudnn_fusion to a __cudnn$fusion custom call wrapping the
  implementation function. Currently this leaks a CallOp since we're using
  the `core_call_lowering` function, but DCE should clean it up easily.
  """
  impl = mlir.core_call_lowering(
      ctx, *args, name=name + ".impl", call_jaxpr=jaxpr
  )
  call_op = impl[0].owner
  called_fn = call_op.attributes["callee"]
  cudnn_fusion = hlo.CustomCallOp(
      [r.type for r in call_op.results],
      call_op.operands,
      call_target_name="__cudnn$fusion",
      called_computations=ir.ArrayAttr.get([called_fn]),
  )
  return cudnn_fusion.results


mlir.register_lowering(
    cudnn_fusion_p, _cudnn_fusion_stablehlo_lowering, platform="cuda"
)


def cudnn_fusion(f):
  """Makes a function become a cuDNN kernel. Relies on XLA's handling of
  custom fusions with the __cudnn$fusion backend. Currently limited to GEMM
  fusions. For example, a batch matmul with mixed types and an addition:

  @cudnn_fusion
  def fn(x, y, z):
      return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z
  """
  return functools.partial(call_cudnn_fusion, f)
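As a usage sketch of the new decorator (assumes a CUDA GPU; the jax._src.cudnn import path is private API, and the function name and shapes here are illustrative):

import jax
import jax.numpy as jnp
from jax._src.cudnn import cudnn_fusion

@cudnn_fusion
def gemm_bias(x, y, z):
  # Mixed-precision batch matmul plus bias, mirroring the docstring example.
  return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z

x = jnp.ones((2, 16, 16), dtype=jnp.int8)
y = jnp.ones((2, 16, 16), dtype=jnp.bfloat16)
z = jnp.ones((2, 16, 16), dtype=jnp.float32)
out = jax.jit(gemm_bias)(x, y, z)  # lowers to a __cudnn$fusion custom call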
14 changes: 14 additions & 0 deletions tests/BUILD
@@ -1523,6 +1523,20 @@ py_test(
],
)

jax_test(
    name = "cudnn_fusion_test",
    srcs = ["cudnn_fusion_test.py"],
    disable_backends = [
        "cpu",
        "tpu",
    ],
    enable_configs = [
        "gpu_a100",
        "gpu_h100",
    ],
    tags = ["multiaccelerator"],
)
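On a CUDA-enabled source checkout this target should be runnable through Bazel, e.g. bazel test //tests:cudnn_fusion_test (a sketch; the targets generated by the jax_test macro and any required --config flags depend on the local build setup).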

exports_files(
    [
        "api_test.py",
69 changes: 69 additions & 0 deletions tests/cudnn_fusion_test.py
@@ -0,0 +1,69 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest, parameterized
from unittest import SkipTest
from jax._src import test_util as jtu
import jax
import jax.numpy as jnp
from jax._src.cudnn import cudnn_fusion


jax.config.parse_flags_with_absl()


class CudnnFusionTest(jtu.JaxTestCase):
  def setUp(self):
    if (not jtu.test_device_matches(["cuda"]) or
        not jtu.is_cuda_compute_capability_at_least("8.0")):
      self.skipTest("Only works on >= sm80 GPUs")
    super().setUp()

  @parameterized.parameters(["", "pmap"])
  @jtu.run_on_devices("cuda")
  def test_cudnn_fusion(self, mode):
    batch_size = 2
    if mode == "pmap" and jax.device_count() < batch_size:
      raise SkipTest("pmap test requires 2 GPUs")

    @cudnn_fusion
    def comp1(x, y, z):
      return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z

    k = jax.random.key(0)
    s = batch_size, 16, 16
    x = jnp.int8(jax.random.normal(k, shape=s))
    y = jnp.bfloat16(jax.random.normal(k, shape=s))
    z = jnp.float32(jax.random.normal(k, shape=s))

    fn = jax.pmap(comp1) if mode == "pmap" else comp1
    jitted = jax.jit(comp1)
    lowered = jitted.lower(x, y, z)
    stablehlo = lowered.as_text("stablehlo")
    self.assertIn("func.func private @comp1", stablehlo)
    self.assertIn("__cudnn$fusion", stablehlo)

    hlo = lowered.as_text("hlo")
    self.assertIn('custom_call_target="__cudnn$fusion"', hlo)
    self.assertIn("called_computations=", hlo)

    hlo_after_opt = lowered.compile().as_text()
    self.assertIn("kind=kCustom", hlo_after_opt)
    self.assertIn("plan_id", hlo_after_opt)

    self.assertAllClose(jitted(x, y, z), fn(x, y, z))


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())
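Because the module ends with an absltest main block, it can also be run directly (a sketch; per setUp it requires a CUDA GPU with compute capability >= 8.0, and the pmap case needs two devices):

python tests/cudnn_fusion_test.py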
