make it work with torch<2.4.0
chengzeyi committed Dec 19, 2024
1 parent 3618d63 commit 8244ed2
Showing 2 changed files with 20 additions and 20 deletions.
37 changes: 19 additions & 18 deletions src/para_attn/ops/__init__.py
@@ -22,24 +22,25 @@ def cannot_use_attention_backend(*args, **kwargs):
     _torch_custom_op_wrapper = torch.library.custom_op
     _torch_register_fake_wrapper = torch.library.register_fake
 else:
-    raise RuntimeError("torch.library.custom_op requires PyTorch version >= 2.4.0")
-
-    # def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
-    #     def wrap(func):
-    #         return func
-    #     if fn is None:
-    #         return wrap
-    #     return fn
-
-    # def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1):
-    #     def wrap(func):
-    #         return func
-    #     if fn is None:
-    #         return wrap
-    #     return fn
-
-    # _torch_custom_op_wrapper = noop_custom_op_wrapper
-    # _torch_register_fake_wrapper = noop_register_fake_wrapper
+
+    def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
+        def wrap(func):
+            return func
+
+        if fn is None:
+            return wrap
+        return fn
+
+    def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1):
+        def wrap(func):
+            return func
+
+        if fn is None:
+            return wrap
+        return fn
+
+    _torch_custom_op_wrapper = noop_custom_op_wrapper
+    _torch_register_fake_wrapper = noop_register_fake_wrapper
 
 
 def flash_attention_forward_with_lse(
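With this change, importing para_attn on torch < 2.4.0 no longer raises: the no-op wrappers return the undecorated function, so the ops run as plain Python callables instead of dispatcher-registered custom ops. A minimal self-contained sketch of the same fallback pattern (the op name my_lib::add_one and the add_one function are illustrative, not from this repository):

    import torch

    if hasattr(torch.library, "custom_op"):  # PyTorch >= 2.4.0
        _custom_op = torch.library.custom_op
    else:
        # Older PyTorch: a no-op decorator with the same call signature,
        # mirroring noop_custom_op_wrapper above.
        def _custom_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
            def wrap(func):
                return func

            if fn is None:
                return wrap
            return fn

    # The decorated function works either way: as a registered custom op on
    # new PyTorch, or as a plain Python function on old PyTorch.
    @_custom_op("my_lib::add_one", mutates_args=())
    def add_one(x: torch.Tensor) -> torch.Tensor:
        return x + 1

    print(add_one(torch.zeros(2)))  # tensor([1., 1.])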
3 changes: 1 addition & 2 deletions src/para_attn/para_attn_interface.py
@@ -7,6 +7,7 @@
 from torch.overrides import TorchFunctionMode
 
 import para_attn
+import para_attn.ops as para_attn_ops
 import para_attn.primitives as DP
 
 try:
@@ -17,8 +18,6 @@
 if _templated_ring_attention is not None:
     import torch.distributed.tensor.experimental._attention as torch_ring_attention
 
-para_attn_ops = torch.ops.para_attn
-
 __all__ = [
     "UnifiedAttnMode",
     "RingAttnMode",
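The second file swaps the dispatcher-namespace binding for a plain module import, which follows from the same compatibility fix: op lookups under torch.ops.para_attn succeed only after the ops have been registered via torch.library.custom_op, which the change above makes optional on torch < 2.4.0, whereas the module import resolves the same callables on any version. A hedged sketch of the difference (it assumes flash_attention_forward_with_lse, visible in the diff above, is also registered under that name in the dispatcher namespace, which this page does not confirm):

    import torch
    import para_attn.ops as para_attn_ops  # new binding: works on any torch version

    # The module attribute is either a registered custom op (torch >= 2.4.0)
    # or the undecorated Python fallback installed by the no-op wrappers.
    op = para_attn_ops.flash_attention_forward_with_lse

    # Old binding: attribute lookup under torch.ops.para_attn fails when the
    # ops were never registered with torch.library.custom_op.
    try:
        op = torch.ops.para_attn.flash_attention_forward_with_lse
    except AttributeError:
        pass  # torch < 2.4.0: keep the module attribute bound above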
