Added pl.CompilerParams subclass for Mosaic GPU
PiperOrigin-RevId: 671066741
superbobry authored and jax authors committed Sep 4, 2024
1 parent 3672b63 commit a8a55e0
Showing 8 changed files with 44 additions and 23 deletions.
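In practice, the new typed API replaces the nested-dict compiler_params shown in the test at the bottom of this commit. A minimal sketch of the intended usage, adapted from that test (the plgpu import path and the kernel body are assumptions for illustration):

import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu  # assumed import path

@functools.partial(
    pl.pallas_call,
    in_specs=[pl.BlockSpec((128, 16), lambda i, j: (i, j))],
    out_specs=pl.BlockSpec((128, 16), lambda i, j: (i, j)),
    out_shape=jax.ShapeDtypeStruct([128 * 2, 64], jnp.float32),
    grid=(2, 1),
    # Typed, platform-specific parameters instead of dict(mosaic_gpu=dict(...)).
    compiler_params=plgpu.GPUCompilerParams(
        dimension_semantics=["parallel", "sequential"],
        num_stages=2,
    ),
)
def add_one(x_ref, o_ref):
  # Hypothetical kernel body: copy the block, adding one.
  o_ref[...] = x_ref[...] + 1.0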
12 changes: 8 additions & 4 deletions jax/_src/pallas/core.py
@@ -23,7 +23,7 @@
 import functools
 import itertools
 import threading
-from typing import Any, ClassVar, Hashable, Union
+from typing import Any, ClassVar, Hashable, Protocol, Union, runtime_checkable
 import warnings

 import jax
@@ -66,10 +66,14 @@ def __repr__(self):
 SEMAPHORE_INTERPRET_DTYPE = jnp.int16


-@dataclasses.dataclass(frozen=True)
-class CompilerParams:
+@runtime_checkable
+class CompilerParams(Protocol):
   """Base class for compiler parameters."""
-  PLATFORM: ClassVar[str] = "unspecified"
+  PLATFORM: ClassVar[str]
+
+  # Subclasses must be dataclasses.
+  __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]]


 @dataclasses.dataclass(frozen=True)
 class NameAndSrcInfo:
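Since CompilerParams is now a runtime_checkable Protocol rather than a base dataclass, isinstance merely checks that the expected attributes are present, so any frozen dataclass that defines PLATFORM conforms without inheriting. A small sketch (MyParams is a hypothetical class, not part of this commit):

import dataclasses
from typing import ClassVar

from jax._src.pallas import core as pallas_core

@dataclasses.dataclass(frozen=True)
class MyParams:  # hypothetical; does not subclass CompilerParams
  PLATFORM: ClassVar[str] = "example"
  foo: int = 0

# Structural check: MyParams instances expose PLATFORM and
# __dataclass_fields__, which is all the Protocol requires.
assert isinstance(MyParams(), pallas_core.CompilerParams)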
6 changes: 3 additions & 3 deletions jax/_src/pallas/mosaic/core.py
@@ -19,14 +19,14 @@
 import dataclasses
 import enum
 import functools
-from typing import Any, ClassVar, Hashable
+from typing import Any, ClassVar, Hashable, Literal

 import jax
 from jax._src import core as jax_core
 from jax._src import dtypes
 from jax._src import util
-import jax.numpy as jnp
 from jax._src.pallas import core as pallas_core
+import jax.numpy as jnp
 import numpy as np

 map, unsafe_map = util.safe_map, map
@@ -68,7 +68,7 @@ class TPUCompilerParams(pallas_core.CompilerParams):
     device_type: The device type to compile for.
   """
   PLATFORM: ClassVar[str] = "mosaic"
-  dimension_semantics: Sequence[str] | None = None
+  dimension_semantics: Sequence[Literal["parallel", "arbitrary"]] | None = None
   allow_input_fusion: Sequence[bool] | None = None
   vmem_limit_bytes: int | None = None
   collective_id: int | None = None
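Narrowing dimension_semantics from Sequence[str] to Literal values lets a static type checker reject misspelled semantics that a plain str would have accepted. For instance (a sketch; pltpu is the import alias already used elsewhere in this commit):

from jax.experimental.pallas import tpu as pltpu

params = pltpu.TPUCompilerParams(
    dimension_semantics=("parallel", "arbitrary"),  # accepted
)
# A checker would now flag e.g. ("paralel", "arbitrary") at type-check
# time, the same class of typo fixed in the paged_attention kernel below.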
19 changes: 19 additions & 0 deletions jax/_src/pallas/mosaic_gpu/core.py
@@ -14,15 +14,34 @@

 """Contains GPU-specific Pallas abstractions."""

+from collections.abc import Sequence
+import dataclasses
 import enum
+from typing import ClassVar, Literal
 from jax import core as jax_core
 from jax._src.pallas import core as pallas_core
 import jax.numpy as jnp

 AbstractMemoryRef = pallas_core.AbstractMemoryRef


+@dataclasses.dataclass(frozen=True)
+class GPUCompilerParams(pallas_core.CompilerParams):
+  """Mosaic GPU compiler parameters.
+
+  Attributes:
+    dimension_semantics: A list of dimension semantics for each grid
+      dimension of the kernel. Either "parallel" for dimensions that can
+      execute in any order, or "sequential" for dimensions that must be
+      executed sequentially.
+    num_stages: The number of pipeline stages in the kernel. Defaults to 1,
+      meaning no pipelining is done.
+  """
+  PLATFORM: ClassVar[str] = "mosaic_gpu"
+  dimension_semantics: Sequence[Literal["parallel", "sequential"]] | None = None
+  num_stages: int = 1
+
+
 class GPUMemorySpace(enum.Enum):
   GMEM = "gmem"
   SMEM = "smem"
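The new dataclass carries its platform tag and defaults with it. A brief sketch of how it behaves (the internal module path is shown only for illustration):

import dataclasses

from jax._src.pallas.mosaic_gpu import core as gpu_core  # internal path

params = gpu_core.GPUCompilerParams(
    dimension_semantics=["parallel", "sequential"],
    num_stages=2,
)
assert params.PLATFORM == "mosaic_gpu"
# num_stages defaults to 1 (no pipelining) when not given.
assert gpu_core.GPUCompilerParams().num_stages == 1
# Fields round-trip through dataclasses.asdict into the dict form that
# the lowering reads via params.get(...); ClassVars are excluded.
assert dataclasses.asdict(params)["num_stages"] == 2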
9 changes: 2 additions & 7 deletions jax/_src/pallas/mosaic_gpu/lowering.py
@@ -20,7 +20,7 @@
 import dataclasses
 import functools
 import math
-from typing import Any, Literal, TypedDict, cast
+from typing import Any, cast

 import jax
 from jax._src import core as jax_core
@@ -152,11 +152,6 @@ def _eval_index_map(
   return tuple(result)


-class Params(TypedDict, total=False):
-  num_stages: int
-  dimension_semantics: Sequence[Literal["sequential", "parallel"]]
-
-
 def lower_jaxpr_to_module(
     grid_mapping: pallas_core.GridMapping,
     jaxpr: jax_core.Jaxpr,
@@ -199,7 +194,7 @@ def lower_jaxpr_to_module(
   grid += (1,) * (3 - len(grid))
   block = (128,) + (1,) * (len(grid) - 1)

-  params = Params(**compiler_params.get("mosaic_gpu", {}))
+  params = compiler_params.get("mosaic_gpu", {})
   num_stages = params.get("num_stages", 1)
   dimension_semantics = params.get(
       "dimension_semantics", ["parallel"] * len(grid_mapping.grid)
2 changes: 1 addition & 1 deletion jax/_src/pallas/pallas_call.py
@@ -1291,7 +1291,7 @@ def pallas_call(
   if compiler_params is None:
     compiler_params = {}
   if isinstance(compiler_params, pallas_core.CompilerParams):
-    if compiler_params.PLATFORM not in ["mosaic", "triton"]:
+    if compiler_params.PLATFORM not in ["mosaic", "mosaic_gpu", "triton"]:
       raise ValueError(
           f"Unknown platform in compiler params: {compiler_params.PLATFORM}")
     compiler_params = {
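The tail of this hunk is truncated above; judging from the surrounding lines, pallas_call normalizes a typed CompilerParams instance into the platform-keyed dict that the per-backend lowerings consume. A sketch of that shape (not the verbatim source):

import dataclasses

from jax._src.pallas import core as pallas_core

def normalize_compiler_params(
    params: pallas_core.CompilerParams,
) -> dict[str, dict]:
  # e.g. GPUCompilerParams(num_stages=2, ...) becomes
  # {"mosaic_gpu": {"dimension_semantics": [...], "num_stages": 2}}.
  return {params.PLATFORM: dataclasses.asdict(params)}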
1 change: 1 addition & 0 deletions jax/experimental/pallas/__init__.py
@@ -21,6 +21,7 @@
 from jax._src.deprecations import register as _register_deprecation
 from jax._src.pallas.core import Blocked
 from jax._src.pallas.core import BlockSpec
+from jax._src.pallas.core import CompilerParams
 from jax._src.pallas.core import CostEstimate
 from jax._src.pallas.core import IndexingMode
 from jax._src.pallas.core import no_block_spec
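With CompilerParams exported, user code can annotate against the base protocol and accept any backend's parameters, e.g. (a hypothetical helper):

from jax.experimental import pallas as pl

def describe(params: pl.CompilerParams) -> str:
  # Works for TPUCompilerParams, GPUCompilerParams, or any conforming dataclass.
  return f"compiler params for platform {params.PLATFORM!r}"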
9 changes: 6 additions & 3 deletions jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py
@@ -14,7 +14,9 @@

 """PagedAttention TPU kernel."""

+from collections.abc import Sequence
 import functools
+from typing import Literal

 import jax
 from jax import lax
@@ -516,6 +518,7 @@ def paged_attention(
   )
   q_dtype_for_kernel_launch = q.dtype

+  dimension_semantics: Sequence[Literal["parallel", "arbitrary"]]
   if inline_seq_dim:
     kernel = paged_flash_attention_kernel_inline_seq_dim
     grid = (
@@ -525,7 +528,7 @@
         if megacore_mode == "kv_head"
         else num_kv_heads,
     )
-    dimension_sematics = ("parallel", "arbitrary", "arbitrary")
+    dimension_semantics = ("parallel", "arbitrary", "arbitrary")
   else:
     kernel = paged_flash_attention_kernel
     grid = (
@@ -536,7 +539,7 @@ def paged_attention(
         else num_kv_heads,
         pages_per_sequence // pages_per_compute_block,
     )  # type: ignore
-    dimension_sematics = ("parallel", "arbitrary", "arbitrary", "arbitrary")  # type: ignore
+    dimension_semantics = ("parallel", "arbitrary", "arbitrary", "arbitrary")

   if k_scales_pages is not None and v_scales_pages is not None:
     in_specs = [
@@ -641,7 +644,7 @@ def paged_attention(
         scratch_shapes=scratch_shapes,
     ),
     compiler_params=pltpu.TPUCompilerParams(
-        dimension_semantics=dimension_sematics),
+        dimension_semantics=dimension_semantics),
     out_shape=[
         jax.ShapeDtypeStruct(q.shape, q_dtype_for_kernel_launch),
         jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32),
9 changes: 4 additions & 5 deletions tests/pallas/mosaic_gpu_test.py
@@ -80,16 +80,15 @@ def kernel(x_ref, o_ref):

   @parameterized.product(num_stages=[1, 2, 3])
   def test_add_one_grid_pipelined(self, num_stages):

     @functools.partial(
         pl.pallas_call,
         in_specs=[pl.BlockSpec((128, 16), lambda i, j: (i, j))],
         out_specs=pl.BlockSpec((128, 16), lambda i, j: (i, j)),
         out_shape=jax.ShapeDtypeStruct([128 * 2, 64], jnp.float32),
-        compiler_params=dict(
-            mosaic_gpu=dict(
-                dimension_semantics=["parallel", "sequential"],
-                num_stages=num_stages,
-            ),
+        compiler_params=plgpu.GPUCompilerParams(
+            dimension_semantics=["parallel", "sequential"],
+            num_stages=num_stages,
         ),
         grid=(2, 1),
     )
