diff --git a/keras/src/backend/torch/core.py b/keras/src/backend/torch/core.py index 5f01d57d5b7..0b18ae9b6f1 100644 --- a/keras/src/backend/torch/core.py +++ b/keras/src/backend/torch/core.py @@ -240,7 +240,7 @@ def transform(x): if x.requires_grad: x = x.detach() # Tensor has to be moved to CPU before converting to numpy. - if x.is_cuda or x.is_mps: + if x.device != torch.device("cpu"): x = x.cpu() if x.dtype == torch.bfloat16: # Attempting to call .numpy() on a bfloat16 torch tensor leads