diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py
index 03bfd28d0b9a..01b027d386c3 100644
--- a/jax/_src/pallas/core.py
+++ b/jax/_src/pallas/core.py
@@ -23,7 +23,7 @@
 import functools
 import itertools
 import threading
-from typing import Any, ClassVar, Hashable, Union
+from typing import Any, ClassVar, Hashable, Protocol, Union, runtime_checkable
 import warnings
 
 import jax
@@ -66,10 +66,14 @@ def __repr__(self):
 SEMAPHORE_INTERPRET_DTYPE = jnp.int16
 
 
-@dataclasses.dataclass(frozen=True)
-class CompilerParams:
+@runtime_checkable
+class CompilerParams(Protocol):
   """Base class for compiler parameters."""
-  PLATFORM: ClassVar[str] = "unspecified"
+  PLATFORM: ClassVar[str]
+
+  # Subclasses must be dataclasses.
+  __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]]
+
 
 @dataclasses.dataclass(frozen=True)
 class NameAndSrcInfo:
diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py
index e549ee05e770..61b1dc435e72 100644
--- a/jax/_src/pallas/mosaic/core.py
+++ b/jax/_src/pallas/mosaic/core.py
@@ -19,14 +19,14 @@
 import dataclasses
 import enum
 import functools
-from typing import Any, ClassVar, Hashable
+from typing import Any, ClassVar, Hashable, Literal
 
 import jax
 from jax._src import core as jax_core
 from jax._src import dtypes
 from jax._src import util
-import jax.numpy as jnp
 from jax._src.pallas import core as pallas_core
+import jax.numpy as jnp
 import numpy as np
 
 map, unsafe_map = util.safe_map, map
@@ -68,7 +68,7 @@ class TPUCompilerParams(pallas_core.CompilerParams):
     device_type: The device type to compile for.
   """
   PLATFORM: ClassVar[str] = "mosaic"
-  dimension_semantics: Sequence[str] | None = None
+  dimension_semantics: Sequence[Literal["parallel", "arbitrary"]] | None = None
   allow_input_fusion: Sequence[bool] | None = None
   vmem_limit_bytes: int | None = None
   collective_id: int | None = None
diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py
index fd06a9829644..6619a9acfd02 100644
--- a/jax/_src/pallas/mosaic_gpu/core.py
+++ b/jax/_src/pallas/mosaic_gpu/core.py
@@ -14,8 +14,10 @@
 """Contains GPU-specific Pallas abstractions."""
 
+from collections.abc import Sequence
 import dataclasses
 import enum
+from typing import ClassVar, Literal
 
 from jax import core as jax_core
 from jax._src.pallas import core as pallas_core
 import jax.numpy as jnp
@@ -23,6 +25,23 @@
 AbstractMemoryRef = pallas_core.AbstractMemoryRef
 
 
+@dataclasses.dataclass(frozen=True)
+class GPUCompilerParams(pallas_core.CompilerParams):
+  """Mosaic GPU compiler parameters.
+
+  Attributes:
+    dimension_semantics: A list of dimension semantics for each grid
+      dimension of the kernel. Either "parallel" for dimensions that can
+      execute in any order, or "sequential" for dimensions that must be
+      executed sequentially.
+    num_stages: The number of pipeline stages in the kernel. Defaults to 1,
+      meaning no pipelining is done.
+  """
+  PLATFORM: ClassVar[str] = "mosaic_gpu"
+  dimension_semantics: Sequence[Literal["parallel", "sequential"]] | None = None
+  num_stages: int = 1
+
+
 class GPUMemorySpace(enum.Enum):
   GMEM = "gmem"
   SMEM = "smem"
diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py
index fb0119025d4c..88cd4545ae7b 100644
--- a/jax/_src/pallas/mosaic_gpu/lowering.py
+++ b/jax/_src/pallas/mosaic_gpu/lowering.py
@@ -20,7 +20,7 @@
 import dataclasses
 import functools
 import math
-from typing import Any, Literal, TypedDict, cast
+from typing import Any, cast
 
 import jax
 from jax._src import core as jax_core
@@ -152,11 +152,6 @@ def _eval_index_map(
   return tuple(result)
 
 
-class Params(TypedDict, total=False):
-  num_stages: int
-  dimension_semantics: Sequence[Literal["sequential", "parallel"]]
-
-
 def lower_jaxpr_to_module(
     grid_mapping: pallas_core.GridMapping,
     jaxpr: jax_core.Jaxpr,
@@ -199,7 +194,7 @@ def lower_jaxpr_to_module(
   grid += (1,) * (3 - len(grid))
   block = (128,) + (1,) * (len(grid) - 1)
 
-  params = Params(**compiler_params.get("mosaic_gpu", {}))
+  params = compiler_params.get("mosaic_gpu", {})
   num_stages = params.get("num_stages", 1)
   dimension_semantics = params.get(
       "dimension_semantics", ["parallel"] * len(grid_mapping.grid)
diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index 5bbf37dc3663..d3b0ea8ca080 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -1291,7 +1291,7 @@ def pallas_call(
   if compiler_params is None:
     compiler_params = {}
   if isinstance(compiler_params, pallas_core.CompilerParams):
-    if compiler_params.PLATFORM not in ["mosaic", "triton"]:
+    if compiler_params.PLATFORM not in ["mosaic", "mosaic_gpu", "triton"]:
       raise ValueError(
           f"Unknown platform in compiler params: {compiler_params.PLATFORM}")
     compiler_params = {
diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py
index 9a768ed53e75..832f7b7d1184 100644
--- a/jax/experimental/pallas/__init__.py
+++ b/jax/experimental/pallas/__init__.py
@@ -21,6 +21,7 @@
 from jax._src.deprecations import register as _register_deprecation
 from jax._src.pallas.core import Blocked
 from jax._src.pallas.core import BlockSpec
+from jax._src.pallas.core import CompilerParams
 from jax._src.pallas.core import CostEstimate
 from jax._src.pallas.core import IndexingMode
 from jax._src.pallas.core import no_block_spec
diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py
index cd811a874385..eb1e11df17da 100644
--- a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py
+++ b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py
@@ -14,7 +14,9 @@
 
 """PagedAttention TPU kernel."""
 
+from collections.abc import Sequence
 import functools
+from typing import Literal
 
 import jax
 from jax import lax
@@ -516,6 +518,7 @@ def paged_attention(
     )
 
   q_dtype_for_kernel_launch = q.dtype
+  dimension_semantics: Sequence[Literal["parallel", "arbitrary"]]
   if inline_seq_dim:
     kernel = paged_flash_attention_kernel_inline_seq_dim
     grid = (
@@ -525,7 +528,7 @@ def paged_attention(
         if megacore_mode == "kv_head"
        else num_kv_heads,
     )
-    dimension_sematics = ("parallel", "arbitrary", "arbitrary")
+    dimension_semantics = ("parallel", "arbitrary", "arbitrary")
   else:
     kernel = paged_flash_attention_kernel
     grid = (
@@ -536,7 +539,7 @@ def paged_attention(
        else num_kv_heads,
        pages_per_sequence // pages_per_compute_block,
     )  # type: ignore
-    dimension_sematics = ("parallel", "arbitrary", "arbitrary", "arbitrary")  # type: ignore
+    dimension_semantics = ("parallel", "arbitrary", "arbitrary", "arbitrary")
 
   if k_scales_pages is not None and v_scales_pages is not None:
     in_specs = [
@@ -641,7 +644,7 @@ def paged_attention(
           scratch_shapes=scratch_shapes,
       ),
       compiler_params=pltpu.TPUCompilerParams(
-          dimension_semantics=dimension_sematics),
+          dimension_semantics=dimension_semantics),
       out_shape=[
           jax.ShapeDtypeStruct(q.shape, q_dtype_for_kernel_launch),
           jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32),
diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index bdf6396cae5c..80cfd04c44d6 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -80,16 +80,15 @@ def kernel(x_ref, o_ref):
 
   @parameterized.product(num_stages=[1, 2, 3])
   def test_add_one_grid_pipelined(self, num_stages):
+
    @functools.partial(
        pl.pallas_call,
        in_specs=[pl.BlockSpec((128, 16), lambda i, j: (i, j))],
        out_specs=pl.BlockSpec((128, 16), lambda i, j: (i, j)),
        out_shape=jax.ShapeDtypeStruct([128 * 2, 64], jnp.float32),
-        compiler_params=dict(
-            mosaic_gpu=dict(
-                dimension_semantics=["parallel", "sequential"],
-                num_stages=num_stages,
-            ),
+        compiler_params=plgpu.GPUCompilerParams(
+            dimension_semantics=["parallel", "sequential"],
+            num_stages=num_stages,
        ),
        grid=(2, 1),
    )
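
For reference, a minimal sketch of how the new typed parameters are passed to pl.pallas_call after this change, mirroring the updated test above. The kernel body, shapes, grid, and the plgpu import path are illustrative assumptions rather than part of the diff, and running it requires a GPU with Mosaic GPU support.

import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu  # import path assumed


@functools.partial(
    pl.pallas_call,
    in_specs=[pl.BlockSpec((128, 16), lambda i, j: (i, j))],
    out_specs=pl.BlockSpec((128, 16), lambda i, j: (i, j)),
    out_shape=jax.ShapeDtypeStruct([256, 64], jnp.float32),
    grid=(2, 4),
    # Typed, dataclass-based parameters replace the old nested dict
    # compiler_params=dict(mosaic_gpu=dict(...)).
    compiler_params=plgpu.GPUCompilerParams(
        dimension_semantics=["parallel", "sequential"],
        num_stages=2,
    ),
)
def add_one(x_ref, o_ref):
  # Each grid step sees one (128, 16) block of the input and output.
  o_ref[...] = x_ref[...] + 1.0


x = jnp.arange(256 * 64, dtype=jnp.float32).reshape(256, 64)
y = add_one(x)  # y == x + 1, with the second grid axis pipelined over 2 stages

The GPUCompilerParams dataclass also satisfies the CompilerParams protocol checked in pallas_call.py, so passing it directly (instead of the raw dict consumed in lowering.py) goes through the same normalization path as TPUCompilerParams.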