Skip to content

Commit

Permalink
dev(narugo): fix bug on torch 1.x
Browse files Browse the repository at this point in the history
  • Loading branch information
HansBug committed Sep 18, 2023
1 parent b539dc7 commit f0ffabe
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 10 deletions.
3 changes: 3 additions & 0 deletions treetensor/torch/funcs/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import wraps

import torch
from hbutils.testing import vpip
from treevalue import func_treelize as original_func_treelize

from ..tensor import Tensor
Expand All @@ -14,6 +15,8 @@
get_func_from_torch = module_func_loader(torch, Tensor,
[(torch.is_tensor, Tensor)])

_is_torch_2 = vpip('torch') >= '2'


def wrap_for_treelize(*args, **kwargs):
def _decorator(func):
Expand Down
22 changes: 12 additions & 10 deletions treetensor/torch/funcs/wrapper.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
import torch
from hbutils.testing import vpip

from .base import doc_from_base, wrap_for_treelize
from .base import doc_from_base, wrap_for_treelize, _is_torch_2

__all__ = [
'vmap',
]

_is_torch_2 = vpip('torch') >= '2'


@doc_from_base()
@wrap_for_treelize()
def vmap(func, *args, **kwargs):
if _is_torch_2:
if _is_torch_2:
@doc_from_base()
@wrap_for_treelize()
def vmap(func, *args, **kwargs):
return torch.vmap(func, *args, **kwargs)
else:

else:
def vmap(func, *args, **kwargs):
"""
.. warning:
:method:`treetensor.torch.vmap` is not supported for torch 1.x.
"""
raise NotImplementedError(f'Function vmap is not supported in torch {torch.__version__}.')

0 comments on commit f0ffabe

Please sign in to comment.