diff --git a/python/oneflow/utils/tensor/from_or_to_torch_tensor.py b/python/oneflow/utils/tensor/from_or_to_torch_tensor.py index efd697270db..c1e6cbab3b2 100644 --- a/python/oneflow/utils/tensor/from_or_to_torch_tensor.py +++ b/python/oneflow/utils/tensor/from_or_to_torch_tensor.py @@ -62,15 +62,7 @@ def from_torch(torch_tensor): except: print_error_msg() assert isinstance(torch_tensor, torch.Tensor) - # return flow.from_dlpack(torch.to_dlpack(torch_tensor)) - dtype = flow.float16 - if torch_tensor.dtype == torch.int64: - dtype = flow.int64 - elif torch_tensor.dtype != torch.float16: - print(torch_tensor.dtype) - return flow.tensor( - torch_tensor.cpu().numpy(), device=flow.device("npu"), dtype=dtype - ).reshape([x for x in torch_tensor.shape]) + return flow.from_dlpack(torch.to_dlpack(torch_tensor)) def to_torch(flow_tensor): @@ -112,12 +104,4 @@ def to_torch(flow_tensor): "WARNING: `to_torch` received a global tensor. A PyTorch CPU tensor which is a copy of its data will be returned." ) return torch.from_numpy(flow_tensor.numpy()) - # return torch.from_dlpack(flow.to_dlpack(flow_tensor)) - dtype = torch.float16 - if flow_tensor.dtype == flow.int64: - dtype = torch.int64 - elif flow_tensor.dtype != flow.float16: - print(flow_tensor.dtype) - return torch.tensor(flow_tensor.numpy(), device="npu", dtype=dtype).reshape( - [x for x in flow_tensor.shape] - ) + return torch.from_dlpack(flow.to_dlpack(flow_tensor))