diff --git a/python/hidet/graph/frontend/torch/register_methods.py b/python/hidet/graph/frontend/torch/register_methods.py
index 92a5cdbe7..71d5a8a28 100644
--- a/python/hidet/graph/frontend/torch/register_methods.py
+++ b/python/hidet/graph/frontend/torch/register_methods.py
@@ -219,6 +219,13 @@ def tensor_masked_fill_(self: Tensor, mask: Tensor, value: float) -> Tensor:
 
 @register_method(torch.Tensor.repeat)
 def tensor_repeat(self: Tensor, *sizes: int) -> Tensor:
+    if len(self.shape) < len(sizes):
+        # torch.Tensor.repeat left-pads the tensor shape with 1s when more
+        # repeat sizes are given than dimensions; mirror that before tiling.
+        shape = [1] * (len(sizes) - len(self.shape)) + list(self.shape)
+        x = ops.reshape(self, shape)
+        return ops.tile(x, sizes)
+
     return ops.tile(self, sizes)
 
 
diff --git a/python/hidet/graph/frontend/torch/registry.py b/python/hidet/graph/frontend/torch/registry.py
index 4f9202643..13c525964 100644
--- a/python/hidet/graph/frontend/torch/registry.py
+++ b/python/hidet/graph/frontend/torch/registry.py
@@ -150,7 +150,11 @@ def param(self, name: str, optional=False, steal=False) -> Optional[HidetTensor]
         if steal:
             del self.torch_params[name]
             setattr(self.mod, name, None)
-        self.hidet_params[name] = tensor_from_torch(torch_param)
+        # hidet tensors expect dense row-major storage, so make a
+        # non-contiguous parameter contiguous before converting it.
+        if not torch_param.is_contiguous():
+            torch_param = torch_param.contiguous()
+        self.hidet_params[name] = tensor_from_torch(torch_param)
         del torch_param
         torch.cuda.empty_cache()
         return self.hidet_params[name]
diff --git a/python/hidet/graph/ops/arithmetic.py b/python/hidet/graph/ops/arithmetic.py
index dc3aa0c5f..5349fd8c4 100644
--- a/python/hidet/graph/ops/arithmetic.py
+++ b/python/hidet/graph/ops/arithmetic.py
@@ -1003,6 +1003,16 @@ def where(cond: Tensor, x: Union[Tensor, PyScalar], y: Union[Tensor, PyScalar])
     if cond.dtype != dtypes.boolean:
         raise ValueError('The condition tensor must have dtype "bool", but got {}'.format(cond.dtype.name))
     if isinstance(x, Tensor) and isinstance(y, Tensor):
+        import hidet.ir.primitives.math
+
+        # Promote both branches to a common dtype (mirroring torch.where's
+        # type promotion) so that WhereOp receives matching operand types.
+        out_dtype = hidet.ir.primitives.math.type_infer_func([x.dtype, y.dtype])
+        if x.dtype != out_dtype:
+            x = x.to(dtype=out_dtype)
+        if y.dtype != out_dtype:
+            y = y.to(dtype=out_dtype)
+
         return WhereOp(cond, x, y).outputs[0]
     elif isinstance(x, Tensor) and isinstance(y, (int, float, complex)):
         return WhereTensorScalarOp(cond, x=x, y=y).outputs[0]
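As a quick sanity check (not part of the patch), the three behaviors the changes above reproduce on the hidet side can be demonstrated in plain PyTorch; the tensors and shapes below are illustrative only:

```python
import torch

# 1. repeat() left-pads the shape with 1s when given more repeat sizes than
#    dimensions; the reshape-then-tile lowering above mirrors this.
x = torch.randn(2, 3)
assert x.repeat(2, 1, 2).shape == (2, 2, 6)  # (1, 2, 3) tiled by (2, 1, 2)

# 2. A transposed view is non-contiguous; .contiguous() yields the dense
#    copy that the parameter-import path now falls back to.
p = torch.randn(4, 8).t()
assert not p.is_contiguous() and p.contiguous().is_contiguous()

# 3. torch.where promotes mismatched operand dtypes to a common type; the
#    change in ops.where mirrors this by casting x and y up front.
a = torch.randn(3, dtype=torch.float16)
b = torch.randn(3, dtype=torch.float32)
assert torch.where(a > 0, a, b).dtype == torch.float32
```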
diff --git a/tests/frontends/torch/test_torch_arithmetic.py b/tests/frontends/torch/test_torch_arithmetic.py
new file mode 100644
index 000000000..b89517bcc
--- /dev/null
+++ b/tests/frontends/torch/test_torch_arithmetic.py
@@ -0,0 +1,28 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+import torch
+
+from hidet.testing.torch_utils import FunctionalModule, check_module
+
+
+@pytest.mark.parametrize('a_shape', [[1, 3, 64], [10, 10], [11, 13], [1, 2, 3]])
+@pytest.mark.parametrize('sizes', [[1, 2, 3], [2, 3, 4, 5, 6, 8]])
+def test_tensor_repeat(a_shape, sizes):
+    def tensor_repeat(tensor, sizes):
+        return tensor.repeat(*sizes)
+
+    check_module(FunctionalModule(op=tensor_repeat), args=[torch.randn(a_shape), sizes], atol=0, rtol=0)
+
+
+if __name__ == '__main__':
+    pytest.main([__file__])
diff --git a/tests/operators/test_arithmetic.py b/tests/operators/test_arithmetic.py
index 9425b4987..38e563ef5 100644
--- a/tests/operators/test_arithmetic.py
+++ b/tests/operators/test_arithmetic.py
@@ -12,6 +12,7 @@
 import math
 import pytest
 import hidet
+import torch
 import numpy as np
 
 from hidet import ops
@@ -156,5 +157,20 @@ def test_cast_from_fp16(a_shape):
     check_unary(a_shape, np.float16, np.uint64, lambda x: ops.cast(x, "uint64"))
 
 
+@pytest.mark.parametrize("a_shape", unary_op_shapes)
+@pytest.mark.parametrize(
+    "a_dtype, b_dtype", [['float16', 'float32'], ['int32', 'float32'], ['int8', 'int32'], ['int32', 'float16']]
+)
+def test_where(a_shape, a_dtype, b_dtype):
+    a = hidet.randn(a_shape, dtype=a_dtype)
+    b = hidet.randn(a_shape, dtype=b_dtype)
+    c = hidet.ops.where(a > 0.5, a, b)
+
+    c_torch = torch.where(a.torch() > 0.5, a.torch(), b.torch())
+    # compare dtype names with the 'hidet.'/'torch.' module prefixes stripped
+    assert str(c.dtype).split('.')[1] == str(c_torch.dtype).split('.')[1]
+    np.testing.assert_allclose(c.torch(), c_torch)
+
+
 if __name__ == '__main__':
     pytest.main([__file__])
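For interactive verification, a minimal sketch mirroring test_where above (the shape and the 0.5 threshold are arbitrary choices):

```python
import hidet

a = hidet.randn([10, 10], dtype='float16')
b = hidet.randn([10, 10], dtype='float32')
# With the promotion change, the float16 branch is cast up so the result
# dtype matches torch.where's type-promotion rules.
c = hidet.ops.where(a > 0.5, a, b)
print(c.dtype)  # expected: float32
```

The new tests can be run with `pytest tests/frontends/torch/test_torch_arithmetic.py tests/operators/test_arithmetic.py -k 'repeat or where'`.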