diff --git a/python/hidet/graph/frontend/torch/register_functions.py b/python/hidet/graph/frontend/torch/register_functions.py index d56943738..4bedbc773 100644 --- a/python/hidet/graph/frontend/torch/register_functions.py +++ b/python/hidet/graph/frontend/torch/register_functions.py @@ -1762,9 +1762,11 @@ def torch_all(input): @register_function(torch.all) -def torch_all_v2(input, dim, keepdim=False, *, out=None): +def torch_all_v2(input, dim: Union[int, Sequence[int]], keepdim=False, *, out=None): if out is not None: raise NotImplementedError("hidet: does not support torch.all(..., out=...)") + if isinstance(dim, int): + dim = (dim,) return ops.all(input, axis=dim, keepdims=keepdim) @@ -1790,9 +1792,11 @@ def torch_argmin(x, dim: Int = None, keepdim: bool = False): @register_function(torch.any) -def torch_any_v1(input: Tensor, dim, keepdim=False, *, out=None) -> Tensor: +def torch_any_v1(input: Tensor, dim: Union[int, Sequence[int]], keepdim=False, *, out=None) -> Tensor: if out is not None: raise NotImplementedError("hidet: does not support torch.any(..., out=...)") + if isinstance(dim, int): + dim = (dim,) return ops.any(input, axis=dim, keepdims=keepdim) diff --git a/python/hidet/ir/dtypes/vector.py b/python/hidet/ir/dtypes/vector.py index 6d0ac6da4..496e07fb9 100644 --- a/python/hidet/ir/dtypes/vector.py +++ b/python/hidet/ir/dtypes/vector.py @@ -14,6 +14,7 @@ from .floats import float32, float16 from .integer import int8, uint8 from .integer_subbyte import int4b, uint4b +from .boolean import boolean class VectorType(DataType): @@ -116,6 +117,7 @@ def vectorize(base_dtype: DataType, num_lanes: int) -> VectorType: (float16, 2): float16x2, (int8, 4): int8x4, (uint8, 4): uint8x4, + (boolean, 4): int8x4, } if (base_dtype, num_lanes) in table: return table[(base_dtype, num_lanes)] diff --git a/tests/frontends/torch/test_torch_reduce.py b/tests/frontends/torch/test_torch_reduce.py index 66aed0df4..124a99ff9 100644 --- a/tests/frontends/torch/test_torch_reduce.py +++ b/tests/frontends/torch/test_torch_reduce.py @@ -125,3 +125,15 @@ def test_mean(shape): atol=1e-5, rtol=1e-5, ) + + +@pytest.mark.parametrize( + 'shape, dim', [[[2, 4], -1], [[128, 3, 4], 0], [[128, 3, 4], 2], [[72, 5, 64], -1], [[67, 128, 233], 1]] +) +def test_torch_any(shape, dim): + check_module(FunctionalModule(op=lambda x: torch.any(x, dim=dim)), args=[torch.randn(shape) > 0], atol=0, rtol=0) + + +@pytest.mark.parametrize('shape, dim', [[[2, 3], -1]]) +def test_all(shape, dim): + check_module(FunctionalModule(op=lambda x: torch.all(x, dim=dim)), args=[torch.randn(shape) > 0], atol=0, rtol=0)