Skip to content

Commit

Permalink
revert from_or_to_torch_tensor.py
Browse files Browse the repository at this point in the history
  • Loading branch information
fpzh2011 committed Jun 27, 2024
1 parent 62bcb09 commit b5d2773
Showing 1 changed file with 2 additions and 18 deletions.
20 changes: 2 additions & 18 deletions python/oneflow/utils/tensor/from_or_to_torch_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,7 @@ def from_torch(torch_tensor):
except:
print_error_msg()
assert isinstance(torch_tensor, torch.Tensor)
# return flow.from_dlpack(torch.to_dlpack(torch_tensor))
dtype = flow.float16
if torch_tensor.dtype == torch.int64:
dtype = flow.int64
elif torch_tensor.dtype != torch.float16:
print(torch_tensor.dtype)
return flow.tensor(
torch_tensor.cpu().numpy(), device=flow.device("npu"), dtype=dtype
).reshape([x for x in torch_tensor.shape])
return flow.from_dlpack(torch.to_dlpack(torch_tensor))


def to_torch(flow_tensor):
Expand Down Expand Up @@ -112,12 +104,4 @@ def to_torch(flow_tensor):
"WARNING: `to_torch` received a global tensor. A PyTorch CPU tensor which is a copy of its data will be returned."
)
return torch.from_numpy(flow_tensor.numpy())
# return torch.from_dlpack(flow.to_dlpack(flow_tensor))
dtype = torch.float16
if flow_tensor.dtype == flow.int64:
dtype = torch.int64
elif flow_tensor.dtype != flow.float16:
print(flow_tensor.dtype)
return torch.tensor(flow_tensor.numpy(), device="npu", dtype=dtype).reshape(
[x for x in flow_tensor.shape]
)
return torch.from_dlpack(flow.to_dlpack(flow_tensor))

0 comments on commit b5d2773

Please sign in to comment.