Skip to content

Commit

Permalink
Merge pull request #306 from superbobry:maint-2
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 688121597
  • Loading branch information
Google-ML-Automation committed Oct 21, 2024
2 parents c7b8cd5 + be5c9c2 commit cd49678
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 10 deletions.
4 changes: 2 additions & 2 deletions jax_triton/experimental/fusion/jaxpr_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def shape(self):

def match(self, expr, bindings, succeed):
if not isinstance(expr, Literal):
return []
return [] # noqa: B901
yield from matcher.matcher((self.value, self.dtype))((expr.value,
expr.dtype), bindings, succeed)

Expand All @@ -138,7 +138,7 @@ class Part(Node):

def match(self, expr, bindings, succeed):
if not isinstance(expr, Part):
return []
return [] # noqa: B901
yield from matcher.matcher((self.index, self.shape, self.dtype, self.parent))((
expr.index, expr.shape, expr.dtype, expr.parent), bindings, succeed)

Expand Down
3 changes: 1 addition & 2 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
import triton.language as tl
from triton.runtime import autotuner
import triton._C.libtriton as _triton
from triton._C.libtriton import ir as tl_ir
import triton.backends.nvidia.compiler as cb

CAN_USE_TRITON = True
Expand Down Expand Up @@ -375,7 +374,7 @@ def get_or_create_triton_kernel(
args_for_specialization_attr[i] = v
specialization_attr = fn._get_config(*args_for_specialization_attr) # pylint: disable=protected-access

constants = {k: v for k, v in metaparams.items()}
constants = dict(metaparams)
constants.update({k: None for _, k, v in scalar_args if v is None})
constants.update({fn.arg_names[i]: 1 for i in specialization_attr.equal_to_1})

Expand Down
36 changes: 30 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,34 @@ target-version = "py310"

[tool.ruff.lint]
ignore = [
# Do not assign a `lambda` expression, use a `def`
"E731",
# Module level import not at top of file
"E402",
# Ambiguous variable name
"E741",
# Unnecessary collection call
"C408",
# Unnecessary map usage
"C417",
# Unnecessary dict comprehension for iterable
"C420",
# Object names too complex
"C901",
# Local variable is assigned to but never used
"F841",
# Raise with from clause inside except block
"B904",
# Zip without explicit strict parameter
"B905",
]
select = [
"B9",
"C",
"F",
"W",
"YTT",
"ASYNC",
"E101",
"E112",
"E113",
"E115",
"E117",
"E225",
"E227",
"E228",
]

0 comments on commit cd49678

Please sign in to comment.