Skip to content

Commit

Permalink
Fix output shape computation for CategoryEncoding when called with a …
Browse files Browse the repository at this point in the history
…list shape
  • Loading branch information
fchollet committed Sep 9, 2024
1 parent 723f8a0 commit b91dade
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions keras/src/layers/preprocessing/category_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,12 @@ def compute_output_shape(self, input_shape):
return (self.num_tokens,)
if self.output_mode == "one_hot":
if input_shape[-1] != 1:
return tuple(input_shape + (self.num_tokens,))
return tuple(input_shape) + (self.num_tokens,)
elif len(input_shape) == 1:
return tuple(input_shape + (self.num_tokens,))
return tuple(input_shape) + (self.num_tokens,)
else:
return tuple(input_shape[:-1] + (self.num_tokens,))
return tuple(input_shape[:-1] + (self.num_tokens,))
return tuple(input_shape[:-1]) + (self.num_tokens,)
return tuple(input_shape[:-1]) + (self.num_tokens,)

def compute_output_spec(self, inputs, count_weights=None):
output_shape = self.compute_output_shape(inputs.shape)
Expand Down

0 comments on commit b91dade

Please sign in to comment.