From 08e739479f2990eab2aeb145d7d9c7e79ef819ad Mon Sep 17 00:00:00 2001 From: aboubezari <126983138+aboubezari@users.noreply.github.com> Date: Wed, 24 Jul 2024 17:27:31 -0700 Subject: [PATCH] Support torch `convert_to_numpy` for all devices (#20042) * Support torch convert_to_numpy for all devices * reformat --- keras/src/backend/torch/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/src/backend/torch/core.py b/keras/src/backend/torch/core.py index 5f01d57d5b7..0b18ae9b6f1 100644 --- a/keras/src/backend/torch/core.py +++ b/keras/src/backend/torch/core.py @@ -240,7 +240,7 @@ def transform(x): if x.requires_grad: x = x.detach() # Tensor has to be moved to CPU before converting to numpy. - if x.is_cuda or x.is_mps: + if x.device != torch.device("cpu"): x = x.cpu() if x.dtype == torch.bfloat16: # Attempting to call .numpy() on a bfloat16 torch tensor leads