Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgraded the sources to 3.10 #295

Merged
merged 1 commit into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions examples/block_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import functools

from typing import Optional

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -113,7 +112,7 @@ def mha(q, k, v, *,
sm_scale: float = 1.0,
block_q: int = 128,
block_k: int = 128,
num_warps: Optional[int] = None,
num_warps: int | None = None,
num_stages: int = 1,
grid=None,
):
Expand Down
3 changes: 1 addition & 2 deletions examples/pallas/blocksparse_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import functools
import timeit

from typing import Tuple

import jax.numpy as jnp
from jax import random
Expand Down Expand Up @@ -63,7 +62,7 @@ class BlockELL:
blocks: jnp.ndarray # float32[n_rows, n_blocks, *block_size]
blocks_per_row: jnp.ndarray # int32[n_rows, n_blocks]
indices: jnp.ndarray # int32[n_rows, max_num_blocks_per_row, 2]
shape: Tuple[int, int] # (n_rows * block_size[0], n_cols * block_size[1])
shape: tuple[int, int] # (n_rows * block_size[0], n_cols * block_size[1])

ndim: int = property(lambda self: len(self.shape))
num_blocks = property(lambda self: self.blocks.shape[0])
Expand Down
4 changes: 2 additions & 2 deletions examples/pallas/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def body(k, acc_refs):
accs = for_loop.for_loop(num_k_blocks, body, [acc_i, acc_f, acc_o, acc_g])
bs = [pl.load(b_ref, (idx_n,))
for b_ref in [b_hi_ref, b_hf_ref, b_hg_ref, b_ho_ref]]
acc_i, acc_f, acc_g, acc_o = [acc + b for acc, b in zip(accs, bs)]
acc_i, acc_f, acc_g, acc_o = (acc + b for acc, b in zip(accs, bs))
i_gate, f_gate, o_gate = (
jax.nn.sigmoid(acc_i), jax.nn.sigmoid(acc_f), jax.nn.sigmoid(acc_o))
cell = jnp.tanh(acc_g)
Expand Down Expand Up @@ -124,7 +124,7 @@ def lstm_cell_reference(weights, x, h, c):
xs = [jnp.dot(x, w) for w in ws]
hs = [jnp.dot(h, u) for u in us]
accs = [x + h for x, h in zip(xs, hs)]
acc_i, acc_f, acc_g, acc_o = [acc + b[None] for acc, b in zip(accs, bs)]
acc_i, acc_f, acc_g, acc_o = (acc + b[None] for acc, b in zip(accs, bs))
i_gate, f_gate, o_gate = (
jax.nn.sigmoid(acc_i), jax.nn.sigmoid(acc_f), jax.nn.sigmoid(acc_o))
cell = jnp.tanh(acc_g)
Expand Down
4 changes: 2 additions & 2 deletions jax_triton/experimental/fusion/fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import functools
import os

from typing import Any, Tuple
from typing import Any

import jax
from jax import lax
Expand Down Expand Up @@ -204,7 +204,7 @@ def make_elementwise(shape, dtype, *args):
class MatmulElementwise(jax_rewrite.JaxExpression):
x: jax_rewrite.JaxExpression
y: jax_rewrite.JaxExpression
elem_ops: Tuple[core.Primitive]
elem_ops: tuple[core.Primitive]

def match(self, expr, bindings, succeed):
if not isinstance(expr, MatmulElementwise):
Expand Down
25 changes: 13 additions & 12 deletions jax_triton/experimental/fusion/jaxpr_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
import dataclasses
import itertools as it

from typing import Any, Callable, List, Tuple, Union
from typing import Any
from collections.abc import Callable

from jax._src import core as jax_core
import jax.numpy as jnp
Expand All @@ -35,7 +36,7 @@
class Node(matcher.Pattern, metaclass=abc.ABCMeta):

@abc.abstractproperty
def parents(self) -> List[Node]:
def parents(self) -> list[Node]:
...


Expand All @@ -51,9 +52,9 @@ def map_parents(self, fn: Callable[[Node], Node]) -> Node:
class Eqn(Node):
primitive: jax_core.Primitive
params: jr.Params
invars: List[Node]
shape: Union[Tuple[int, ...], List[Tuple[int, ...]]]
dtype: Union[jnp.dtype, List[jnp.dtype]]
invars: list[Node]
shape: tuple[int, ...] | list[tuple[int, ...]]
dtype: jnp.dtype | list[jnp.dtype]

@property
def parents(self):
Expand All @@ -77,7 +78,7 @@ def match(self, expr, bindings, succeed):

@dataclasses.dataclass(frozen=True, eq=False)
class JaxprVar(Node):
shape: Tuple[int, ...]
shape: tuple[int, ...]
dtype: jnp.dtype

def match(self, expr, bindings, succeed):
Expand Down Expand Up @@ -131,7 +132,7 @@ def from_literal(cls, var: jax_core.Literal) -> Literal:
@dataclasses.dataclass(eq=False)
class Part(Node):
index: int
shape: Tuple[int, ...]
shape: tuple[int, ...]
dtype: jnp.dtype
parent: Node

Expand All @@ -153,9 +154,9 @@ def map_parents(self, fn):

@dataclasses.dataclass(eq=True)
class JaxprGraph(matcher.Pattern):
constvars: List[Node]
invars: List[Node]
outvars: List[Node]
constvars: list[Node]
invars: list[Node]
outvars: list[Node]

def get_nodes(self):
nodes = set(self.outvars)
Expand All @@ -167,7 +168,7 @@ def get_nodes(self):
queue.append(p)
return nodes

def get_children(self, node) -> List[Node]:
def get_children(self, node) -> list[Node]:
nodes = self.get_nodes()
return [n for n in nodes if node in n.parents]

Expand Down Expand Up @@ -274,7 +275,7 @@ def to_jaxpr(self) -> jax_core.Jaxpr:
outvars = [env[n] for n in self.outvars]
return jax_core.Jaxpr(constvars, invars, outvars, eqns, jax_core.no_effects)

def toposort(self) -> List[Node]:
def toposort(self) -> list[Node]:
node_stack = list(self.outvars)
child_counts = {}
while node_stack:
Expand Down
4 changes: 2 additions & 2 deletions jax_triton/experimental/fusion/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Contains lowering passes for jaxprs to pallas."""
import functools

from typing import Any, Dict
from typing import Any

import jax
from jax import api_util
Expand Down Expand Up @@ -317,7 +317,7 @@ def read(v: core.Atom) -> Any:
def write(v: Var, val: Any) -> None:
env[v] = val

env: Dict[Var, Any] = {}
env: dict[Var, Any] = {}
map(write, jaxpr.constvars, consts)
map(write, jaxpr.invars, args)
for eqn in jaxpr.eqns:
Expand Down
35 changes: 18 additions & 17 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
import pprint
import tempfile
import types
from typing import Any, Callable, Dict, Optional, Protocol, Sequence, Tuple, Union
from typing import Any, Protocol, Union
from collections.abc import Callable, Sequence
import zlib
from functools import partial

Expand Down Expand Up @@ -102,11 +103,11 @@
jnp.dtype("bool"): "B",
}

Grid = Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]]
GridOrLambda = Union[Grid, Callable[[Dict[str, Any]], Grid]]
Grid = Union[int, tuple[int], tuple[int, int], tuple[int, int, int]]
GridOrLambda = Union[Grid, Callable[[dict[str, Any]], Grid]]


def normalize_grid(grid: GridOrLambda, metaparams) -> Tuple[int, int, int]:
def normalize_grid(grid: GridOrLambda, metaparams) -> tuple[int, int, int]:
if callable(grid):
grid = grid(metaparams)
if isinstance(grid, int):
Expand Down Expand Up @@ -186,8 +187,8 @@ class CompilationResult:
name: str
shared_mem_bytes: int
cluster_dims: tuple
ttgir: Optional[str]
llir: Optional[str]
ttgir: str | None
llir: str | None

def compile_ttir_inplace(
ttir,
Expand Down Expand Up @@ -375,7 +376,7 @@ def get_or_create_triton_kernel(
enable_fp_fusion,
metaparams,
dump: bool,
) -> Tuple[triton_kernel_call_lib.TritonKernel, Any]:
) -> tuple[triton_kernel_call_lib.TritonKernel, Any]:
if num_warps is None:
num_warps = 4
if num_stages is None:
Expand Down Expand Up @@ -730,7 +731,7 @@ def prune_configs(configs, named_args, **kwargs):
class ShapeDtype(Protocol):

@property
def shape(self) -> Tuple[int, ...]:
def shape(self) -> tuple[int, ...]:
...

@property
Expand All @@ -739,21 +740,21 @@ def dtype(self) -> np.dtype:


def triton_call(
*args: Union[jax.Array, bool, int, float, np.float32],
*args: jax.Array | bool | int | float | np.float32,
kernel: triton.JITFunction,
out_shape: Union[ShapeDtype, Sequence[ShapeDtype]],
out_shape: ShapeDtype | Sequence[ShapeDtype],
grid: GridOrLambda,
name: str = "",
custom_call_target_name: str = "triton_kernel_call",
num_warps: Optional[int] = None,
num_stages: Optional[int] = None,
num_warps: int | None = None,
num_stages: int | None = None,
num_ctas: int = 1, # TODO(giorgioa): Add support for dimensions tuple.
compute_capability: Optional[int] = None,
compute_capability: int | None = None,
enable_fp_fusion: bool = True,
input_output_aliases: Optional[Dict[int, int]] = None,
zeroed_outputs: Union[
Sequence[int], Callable[[Dict[str, Any]], Sequence[int]]
] = (),
input_output_aliases: dict[int, int] | None = None,
zeroed_outputs: (
Sequence[int] | Callable[[dict[str, Any]], Sequence[int]]
) = (),
debug: bool = False,
serialized_metadata: bytes = b"",
**metaparams: Any,
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ dependencies = [
"absl-py>=1.4.0",
"jax>=0.4.31",
"triton>=3.0",
"setuptools", # triton seems to need this when installing itself.
]

[project.optional-dependencies]
Expand Down