diff --git a/keras/src/layers/preprocessing/category_encoding.py b/keras/src/layers/preprocessing/category_encoding.py index 4a9fead2152..183debf4990 100644 --- a/keras/src/layers/preprocessing/category_encoding.py +++ b/keras/src/layers/preprocessing/category_encoding.py @@ -131,12 +131,12 @@ def compute_output_shape(self, input_shape): return (self.num_tokens,) if self.output_mode == "one_hot": if input_shape[-1] != 1: - return tuple(input_shape + (self.num_tokens,)) + return tuple(input_shape) + (self.num_tokens,) elif len(input_shape) == 1: - return tuple(input_shape + (self.num_tokens,)) + return tuple(input_shape) + (self.num_tokens,) else: - return tuple(input_shape[:-1] + (self.num_tokens,)) - return tuple(input_shape[:-1] + (self.num_tokens,)) + return tuple(input_shape[:-1]) + (self.num_tokens,) + return tuple(input_shape[:-1]) + (self.num_tokens,) def compute_output_spec(self, inputs, count_weights=None): output_shape = self.compute_output_shape(inputs.shape)