Skip to content

Commit

Permalink
Add fastpath for TF for count and multihot encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Sep 11, 2024
1 parent b91dade commit c20c69a
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 3 deletions.
2 changes: 1 addition & 1 deletion keras/src/models/cloning.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def clone_function(layer):
config["seed"] = 1337
return layer.__class__.from_config(config)
new_model = clone_model(model)
new_model = clone_model(model, clone_function=clone_function)
```
Using a `call_function` to add a `Dropout` layer after each `Dense` layer
Expand Down
24 changes: 22 additions & 2 deletions keras/src/utils/numerical_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from keras.src import backend
from keras.src.api_export import keras_export
from keras.src.utils import tf_utils


@keras_export("keras.utils.normalize")
Expand Down Expand Up @@ -119,7 +120,8 @@ def encode_categorical_inputs(
dtype: the dtype of the output, unless `count_weights` is not `None`.
sparse: whether the output should be sparse for backends supporting it.
count_weights: weights to apply if `output_mode` is `"count"`.
backend_module: the backend to use instead of the curren one.
backend_module: the backend to use instead of the current one.
Returns: the encoded inputs.
"""
backend_module = backend_module or backend
Expand All @@ -131,8 +133,26 @@ def encode_categorical_inputs(

# In all cases, we should uprank scalar input to a single sample.
if rank_of_inputs == 0:
# We need to update `rank_of_inputs` if necessary.
inputs = backend_module.numpy.expand_dims(inputs, -1)
rank_of_inputs = 1

if (
backend_module.__name__.endswith("tensorflow")
and rank_of_inputs <= 2
and output_mode in ("multi_hot", "count")
):
# TF only fastpath. Uses bincount; faster. Doesn't work for rank 3+.
try:
return tf_utils.tf_encode_categorical_inputs(
inputs,
output_mode,
depth,
dtype=dtype,
sparse=sparse,
count_weights=count_weights,
)
except ValueError:
pass

if output_mode == "multi_hot":
return backend_module.nn.multi_hot(
Expand Down
113 changes: 113 additions & 0 deletions keras/src/utils/tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,116 @@ def ensure_tensor(inputs, dtype=None):
if dtype is not None and inputs.dtype != dtype:
inputs = tf.cast(inputs, dtype)
return inputs


def sparse_bincount(inputs, depth, binary_output, dtype, count_weights=None):
    """Apply binary or count encoding to an input and return a sparse tensor."""
    # Bincount over the last axis; `binary_output` selects multi-hot
    # (presence) semantics over raw counts.
    counts = tf.sparse.bincount(
        inputs,
        weights=count_weights,
        minlength=depth,
        maxlength=depth,
        axis=-1,
        binary_output=binary_output,
    )
    counts = tf.cast(counts, dtype)
    # Rebuild the sparse tensor so the final dimension is statically `depth`.
    if inputs.shape.rank == 1:
        dense_shape = (depth,)
    else:
        dense_shape = (tf.shape(counts)[0], depth)
    return tf.SparseTensor(
        indices=counts.indices,
        values=counts.values,
        dense_shape=dense_shape,
    )


def dense_bincount(inputs, depth, binary_output, dtype, count_weights=None):
    """Apply binary or count encoding to an input."""
    counts = tf.math.bincount(
        inputs,
        weights=count_weights,
        minlength=depth,
        maxlength=depth,
        dtype=dtype,
        axis=-1,
        binary_output=binary_output,
    )
    # `bincount` loses static shape information; restore it here.
    if inputs.shape.rank == 1:
        counts.set_shape(tf.TensorShape((depth,)))
    else:
        static_batch = inputs.shape.as_list()[0]
        counts.set_shape(tf.TensorShape((static_batch, depth)))
    return counts


def expand_dims(inputs, axis):
    """Expand dims on sparse, ragged, or dense tensors."""
    # Sparse tensors need the dedicated sparse op; dense (and ragged,
    # via dispatch) go through the regular `tf.expand_dims`.
    if not isinstance(inputs, tf.SparseTensor):
        return tf.expand_dims(inputs, axis)
    return tf.sparse.expand_dims(inputs, axis)


def tf_encode_categorical_inputs(
    inputs,
    output_mode,
    depth,
    dtype="float32",
    sparse=False,
    count_weights=None,
    idf_weights=None,
):
    """Encodes categorical inputs according to output_mode.

    Faster method that relies on bincount.

    Args:
        inputs: integer tensor of categorical values; rank must be at most 2
            after any upranking performed below.
        output_mode: one of `"int"`, `"one_hot"`, `"multi_hot"`, `"count"`,
            or `"tf_idf"`.
        depth: size of the encoding dimension (number of categories).
        dtype: dtype of the returned tensor. Defaults to `"float32"`.
        count_weights: optional per-value weights passed to bincount.
        idf_weights: weights applied when `output_mode` is `"tf_idf"`;
            required in that mode.

    Returns:
        The encoded tensor (`tf.Tensor`, or `tf.SparseTensor` when
        `sparse=True` and `output_mode` is not `"int"`).

    Raises:
        ValueError: if the (possibly upranked) input has rank > 2, or if
            `output_mode` is `"tf_idf"` and `idf_weights` is `None`.
    """

    # "int" mode is a plain cast; no encoding needed.
    if output_mode == "int":
        return tf.identity(tf.cast(inputs, dtype))

    original_shape = inputs.shape
    # In all cases, we should uprank scalar input to a single sample.
    if inputs.shape.rank == 0:
        inputs = expand_dims(inputs, -1)
    # One hot will uprank only if the final output dimension is not already 1.
    if output_mode == "one_hot":
        if inputs.shape[-1] != 1:
            inputs = expand_dims(inputs, -1)

    # bincount-based encoding only supports rank 1 or 2 inputs.
    if inputs.shape.rank > 2:
        raise ValueError(
            "When output_mode is not `'int'`, maximum supported output rank "
            f"is 2. Received output_mode {output_mode} and input shape "
            f"{original_shape}, "
            f"which would result in output rank {inputs.shape.rank}."
        )

    # "multi_hot"/"one_hot" record presence only; "count"/"tf_idf" keep
    # occurrence counts.
    binary_output = output_mode in ("multi_hot", "one_hot")
    if sparse:
        bincounts = sparse_bincount(
            inputs, depth, binary_output, dtype, count_weights
        )
    else:
        bincounts = dense_bincount(
            inputs, depth, binary_output, dtype, count_weights
        )

    bincounts = tf.cast(bincounts, dtype)
    if output_mode != "tf_idf":
        return bincounts

    if idf_weights is None:
        raise ValueError(
            "When output mode is `'tf_idf'`, idf_weights must be provided. "
            f"Received: output_mode={output_mode} and idf_weights={idf_weights}"
        )

    if sparse:
        # Scale each nonzero value by the idf weight of its category
        # (last index column holds the category id).
        value_weights = tf.gather(idf_weights, bincounts.indices[:, -1])
        return tf.SparseTensor(
            bincounts.indices,
            value_weights * bincounts.values,
            bincounts.dense_shape,
        )
    else:
        # Dense case: broadcast idf weights across the batch dimension.
        return tf.multiply(bincounts, idf_weights)

0 comments on commit c20c69a

Please sign in to comment.