[Fix] Fixing an error triggered by the operator any (#369)
Closes #367 

Also fixed an error mentioned [here](CentML/hidet#341 (comment)), which was encountered immediately after fixing the error described in #367.
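
For context, a minimal sketch of the kind of call this fix targets (the input shape and the use of `torch.compile` with the `hidet` backend are illustrative assumptions, not taken from the linked issues):

```python
import torch

def f(x):
    # torch.any with an integer `dim` previously raised an error when
    # compiled through the hidet frontend (#367).
    return torch.any(x, dim=0)

compiled = torch.compile(f, backend='hidet')
x = torch.randn(128, 3, 4) > 0   # boolean input, as in the new tests
print(compiled(x).shape)         # torch.Size([3, 4])
```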
BolinSNLHM authored and vadiklyutiy committed Jul 27, 2024
1 parent 95d95a4 commit 6a4c2e5
Showing 3 changed files with 20 additions and 2 deletions.
8 changes: 6 additions & 2 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -1762,9 +1762,11 @@ def torch_all(input):


 @register_function(torch.all)
-def torch_all_v2(input, dim, keepdim=False, *, out=None):
+def torch_all_v2(input, dim: Union[int, Sequence[int]], keepdim=False, *, out=None):
     if out is not None:
         raise NotImplementedError("hidet: does not support torch.all(..., out=...)")
+    if isinstance(dim, int):
+        dim = (dim,)
     return ops.all(input, axis=dim, keepdims=keepdim)


@@ -1790,9 +1792,11 @@ def torch_argmin(x, dim: Int = None, keepdim: bool = False):


 @register_function(torch.any)
-def torch_any_v1(input: Tensor, dim, keepdim=False, *, out=None) -> Tensor:
+def torch_any_v1(input: Tensor, dim: Union[int, Sequence[int]], keepdim=False, *, out=None) -> Tensor:
     if out is not None:
         raise NotImplementedError("hidet: does not support torch.any(..., out=...)")
+    if isinstance(dim, int):
+        dim = (dim,)
     return ops.any(input, axis=dim, keepdims=keepdim)
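
The frontend change is an axis normalization: an integer `dim` is wrapped into a one-element tuple before being passed on to `ops.any`/`ops.all`. A standalone sketch of that pattern (the function name here is illustrative, not hidet's):

```python
from typing import Sequence, Union

def normalize_axis(dim: Union[int, Sequence[int]]) -> Sequence[int]:
    # Mirror of the fix: a bare int becomes a one-element tuple so the
    # downstream reduce op always receives a sequence of axes.
    if isinstance(dim, int):
        dim = (dim,)
    return dim

assert normalize_axis(-1) == (-1,)
assert normalize_axis((0, 2)) == (0, 2)
```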
2 changes: 2 additions & 0 deletions python/hidet/ir/dtypes/vector.py
@@ -14,6 +14,7 @@
 from .floats import float32, float16
 from .integer import int8, uint8
 from .integer_subbyte import int4b, uint4b
+from .boolean import boolean


 class VectorType(DataType):
@@ -116,6 +117,7 @@ def vectorize(base_dtype: DataType, num_lanes: int) -> VectorType:
         (float16, 2): float16x2,
         (int8, 4): int8x4,
         (uint8, 4): uint8x4,
+        (boolean, 4): int8x4,
     }
     if (base_dtype, num_lanes) in table:
         return table[(base_dtype, num_lanes)]
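
The second error (the one noted in the #341 comment) appears to have come from vectorization: the lookup table had no entry for 4-lane boolean vectors. The added entry reuses `int8x4`, which makes sense assuming hidet stores a boolean in one byte. A self-contained sketch of the lookup, with strings standing in for hidet's dtype objects:

```python
# Illustrative stand-in for the (base_dtype, num_lanes) -> vector type table
# in python/hidet/ir/dtypes/vector.py; strings replace hidet's dtype objects.
table = {
    ('float16', 2): 'float16x2',
    ('int8', 4): 'int8x4',
    ('uint8', 4): 'uint8x4',
    ('boolean', 4): 'int8x4',  # the entry added by this commit
}

def vectorize(base_dtype: str, num_lanes: int) -> str:
    # Return the vector type for a given scalar dtype and lane count,
    # or fail if no mapping is registered.
    key = (base_dtype, num_lanes)
    if key in table:
        return table[key]
    raise ValueError(f'cannot vectorize {num_lanes} lanes of {base_dtype}')

print(vectorize('boolean', 4))  # int8x4
```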
12 changes: 12 additions & 0 deletions tests/frontends/torch/test_torch_reduce.py
@@ -125,3 +125,15 @@ def test_mean(shape):
         atol=1e-5,
         rtol=1e-5,
     )
+
+
+@pytest.mark.parametrize(
+    'shape, dim', [[[2, 4], -1], [[128, 3, 4], 0], [[128, 3, 4], 2], [[72, 5, 64], -1], [[67, 128, 233], 1]]
+)
+def test_torch_any(shape, dim):
+    check_module(FunctionalModule(op=lambda x: torch.any(x, dim=dim)), args=[torch.randn(shape) > 0], atol=0, rtol=0)
+
+
+@pytest.mark.parametrize('shape, dim', [[[2, 3], -1]])
+def test_all(shape, dim):
+    check_module(FunctionalModule(op=lambda x: torch.all(x, dim=dim)), args=[torch.randn(shape) > 0], atol=0, rtol=0)