[Pallas/MGPU] Skip outgoing TMA when the output is being revisited
Otherwise we end up with programs that race on writes to the same GMEM location.
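The decision the commit makes can be illustrated with a small, hypothetical helper (a sketch, not the actual Mosaic GPU implementation): when consecutive pipeline steps map to the same output block, only the last visit should issue the outgoing copy, so that two in-flight TMAs never write the same GMEM location.

```python
def emit_stores(out_block_indices):
    """For each pipeline step, decide whether the outgoing copy
    (SMEM -> GMEM store) should be issued.

    The store is skipped when the next step revisits the same output
    block: issuing both copies would race on the same GMEM location,
    so only the final visit to a block performs the store.
    """
    n = len(out_block_indices)
    return [
        i == n - 1 or out_block_indices[i] != out_block_indices[i + 1]
        for i in range(n)
    ]
```

For example, a grid that visits output blocks `[0, 0, 1, 1, 2]` stores only at the second visit of blocks 0 and 1.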

PiperOrigin-RevId: 679111363
apaszke authored and Google-ML-Automation committed Sep 26, 2024
1 parent 3c25da2 commit 3580394
Showing 7 changed files with 241 additions and 126 deletions.
6 changes: 3 additions & 3 deletions jax/_src/pallas/mosaic_gpu/core.py
@@ -41,13 +41,13 @@ class GPUCompilerParams(pallas_core.CompilerParams):
       dimension of the kernel. Either "parallel" for dimensions that can
       execute in any order, or "sequential" for dimensions that must be
       executed sequentially.
-    num_stages: The number of pipline stages in the kernel. Defaults to 1,
-      meaning no pipelining is done.
+    max_concurrent_steps: The maximum number of sequential stages that are
+      active concurrently. Defaults to 2 and cannot be lower.
   """
   PLATFORM: ClassVar[str] = "mosaic_gpu"
   approx_math: bool = False
   dimension_semantics: Sequence[Literal["parallel", "sequential"]] | None = None
-  num_stages: int = 1
+  max_concurrent_steps: int = 2
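The renamed parameter bounds how far the pipeline runs ahead, rather than fixing a stage count. A minimal sketch of that scheduling idea (a hypothetical simulation, not the Mosaic GPU lowering itself): input fetches are issued as early as possible, but no more than `max_concurrent_steps` steps may be in flight ahead of compute.

```python
def pipeline_schedule(num_steps, max_concurrent_steps):
    """Event order for a simple software pipeline: each step's input
    fetch is issued eagerly, but at most max_concurrent_steps fetches
    may run ahead of the compute that consumes them."""
    events = []
    fetched = 0
    for step in range(num_steps):
        # Keep the fetch window full: at most max_concurrent_steps
        # steps ahead of the current compute step.
        while fetched < num_steps and fetched < step + max_concurrent_steps:
            events.append(("fetch", fetched))
            fetched += 1
        events.append(("compute", step))
    return events
```

With `max_concurrent_steps=1` this degenerates to fetch/compute in lockstep (no overlap), which is why the docstring says the value cannot go below the default window needed for overlap.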


class GPUMemorySpace(enum.Enum):
