This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Add multi hot sparse categorical crossentropy #163

Merged (6 commits, Sep 20, 2018)
122 changes: 122 additions & 0 deletions examples/multi_hot_sparse_categorical_crossentropy.py
@@ -0,0 +1,122 @@
'''Trains a simple convnet on a multi-label classification task using multi_hot_sparse_categorical_crossentropy.

This example demonstrates
1) how to do multi-label classification using normal categorical crossentropy
2) when labels are sparse, how to improve performance using multi_hot_sparse_categorical_crossentropy
'''

import time
import random
import numpy as np
import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.preprocessing.sequence import pad_sequences
from keras import backend as K
"""
Input:
    input data are random images of shape (3, 28, 28) in channels-first data format

Labels:
    Traditional labels are in the shape of (num_samples, num_classes), for example:
    labels = [[0, 1, 1, ..., 0],
              [1, 1, 0, ..., 0],
              ...
              [0, 0, 0, ..., 1]]
    where len(labels) = num_samples and len(labels[0]) = num_classes

However, when num_classes is very large and the labels are very sparse,
we can represent them differently, for example:

    There are 1000 classes in total, so there are 1000 different label values.
    Each image can have at most 5 labels at the same time.
    labels = [[1, 2],
              [0, 1],
              ...
              [999]]
    where labels is a list of lists

Special Note:
    To deal with the varying lengths of the sparse labels, we pad them with negative values,
    so we can differentiate the padding values from the normal labels. They become:
    padded_labels = pad_sequences(labels, value=-1)
    padded_labels = [[-1, -1, -1, 1, 2],
                     [-1, -1, -1, 0, 1],
                     ...
                     [-1, -1, -1, -1, 999]]
    padded_labels has shape (num_samples, 5), which still saves space compared to dense labels.
"""

# input image dimensions
img_rows, img_cols = 28, 28
epoch = 5
num_gpus = 4
num_classes = 3000
num_samples = 50000
input_shape = (3, img_rows, img_cols)
batch_size = num_gpus * 32 if num_gpus > 1 else 32
# creating random images of shape (3, 28, 28) as training data
x_train = np.random.randint(0, 256, (num_samples, 3, img_rows, img_cols))

# creating dense labels and sparse labels
sparse_labels = []
dense_labels = np.zeros((num_samples, num_classes))
for i in range(0, num_samples):
    # each sample has at most 5 labels
    for j in range(0, 5):
        label = random.randint(0, num_classes - 1)
        dense_labels[i][label] = 1
        # make the number of labels per sample unequal
        if random.randint(0, 5) == 1:
            break
    sparse_label_j = np.where(dense_labels[i] == 1)[0]
    sparse_labels.append(sparse_label_j)

Reviewer: [minor] remove blank line

Author: removed

# construct a simple CNN model
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),
                 activation='relu',
                 input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))

# use multiple GPUs
if num_gpus > 1:
    model = keras.utils.multi_gpu_model(model, num_gpus)

# use normal categorical crossentropy
model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adadelta(),
              metrics=['accuracy'])
model.fit(x_train, dense_labels,
          batch_size=batch_size,
          epochs=epoch)


# use multi_hot_sparse_categorical_crossentropy
model.compile(loss=keras.losses.multi_hot_sparse_categorical_crossentropy,
              optimizer=keras.optimizers.Adadelta(),
              metrics=['accuracy'])

# pad sparse labels to the same length with value -1 to differentiate padding from normal labels
y_train_pad = pad_sequences(sparse_labels, value=-1)
Reviewer: why not add this padding logic inside the multi_hot_sparse_categorical_crossentropy calculation method?

Author: It has to be done at the numpy array level, before the data is fed to the model. Keras can't take an input (y_true) in the form of a list of lists whose lengths vary (e.g. [[1, 2], [0], [0, 1, 2]]).
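A minimal sketch of that numpy-level step (values are illustrative), showing how `pad_sequences` turns the ragged list of lists into a rectangular array that `model.fit` can accept:

```python
from keras.preprocessing.sequence import pad_sequences

labels = [[1, 2], [0], [0, 1, 2]]         # ragged list of lists; not a valid ndarray
padded = pad_sequences(labels, value=-1)  # pads on the left by default
print(padded)
# [[-1  1  2]
#  [-1 -1  0]
#  [ 0  1  2]]
```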

model.fit(x_train, y_train_pad,
          batch_size=batch_size,
          epochs=epoch)

# speed reference for the two losses
outputs = model.predict(x_train)
start = time.time()
loss = keras.losses.categorical_crossentropy(K.variable(dense_labels), K.variable(outputs))
print("categorical crossentropy loss time:", time.time() - start)
outputs = model.predict(x_train)
start = time.time()
loss = keras.losses.multi_hot_sparse_categorical_crossentropy(K.variable(y_train_pad), K.variable(outputs))
print("multi hot sparse categorical crossentropy loss time:", time.time() - start)
60 changes: 60 additions & 0 deletions keras/backend/mxnet_backend.py
@@ -2988,6 +2988,66 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
    return reshape(KerasSymbol(mx_output), target.shape)


@keras_mxnet_symbol
def multi_hot_sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
    """Calculates categorical crossentropy with a list of integer targets (multi-labels).

    # Arguments
        target: An integer tensor.
        output: A tensor resulting from a softmax
            (unless `from_logits` is True, in which
            case `output` is expected to be the logits).
        from_logits: Boolean, whether `output` is the
            result of a softmax, or is a tensor of logits.

    # Returns
        Output tensor.

    # Example
    ```
    # refer to examples/multi_hot_sparse_categorical_crossentropy.py
    # for a multi-label classification problem with 3 classes
    # target with multi labels in normal categorical crossentropy
    >>> target_dense = np.array([[0, 1, 1], [1, 0, 1], [1, 0, 0]])
    # output from softmax remains in the same format
    >>> output = np.array([[0.1, 0.4, 0.5],
    >>>                    [0.4, 0.2, 0.4],
    >>>                    [0.7, 0.2, 0.1]])
    # target with multi labels in multi_hot_sparse_categorical_crossentropy
    >>> target_sparse = np.array([[1, 2], [0, 2], [0]])
    ```
    """
    output_dimensions = list(range(ndim(output)))
    if axis != -1 and axis not in output_dimensions:
        raise ValueError(
            '{}{}{}'.format(
                'Unexpected channels axis {}. '.format(axis),
                'Expected to be -1 or one of the axes of `output`, ',
                'which has {} dimensions.'.format(len(int_shape(output)))))

    mx_output = output.symbol
    # scale predictions so that the class probabilities of each sample sum to 1
    if from_logits:
        mx_output = mx.sym.softmax(mx_output, axis=axis)
    else:
        mx_output = mx.sym.broadcast_div(mx_output, mx.sym.sum(mx_output,
                                                               axis=axis,
                                                               keepdims=True))
    # clip to prevent NaN's and Inf's
    mx_output = mx.sym.clip(mx_output, a_min=epsilon(), a_max=1.0 - epsilon())

Reviewer: Nice work. Please document these steps so it will be easier later for maintenance.


    # use control flow ops to iterate over the output and take the target (true label)
    _step = lambda data, _: (mx.sym.take(data[0], data[1]), [])
    data = [mx_output, target.symbol]
    outputs, _ = mx.symbol.contrib.foreach(_step, data, [])
Reviewer: this is an MXNet contrib API. Do we have an issue to track that this needs to be updated when the stable version of foreach is released?

Author: Currently we have CI to test MXNet API changes (deprecations or breaking changes). It's the same for contrib ops as for other normal ops. If it's moved to mx.sym, it will give a deprecation warning, and we should be able to catch it. Created #173.


    # calculate the loss
    # a label contributes only if it is >= 0; this masks out the padded labels (-1)
    outputs = - mx.sym.sum(mx.sym.broadcast_greater_equal(target.symbol, mx.sym.zeros((1, 1))) *
                           mx.sym.log(outputs), axis=axis)
    return KerasSymbol(outputs)
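For reference, a plain NumPy sketch of what the symbolic graph above computes: rescale the predictions, clip, gather the probabilities at the target indices, and mask out the padded -1 slots. The function name and shapes here are illustrative, not part of the PR:

```python
import numpy as np

def multi_hot_scce_ref(padded_target, output, eps=1e-7):
    """padded_target: (n, k) int labels padded with -1; output: (n, num_classes) probabilities."""
    p = output / output.sum(axis=-1, keepdims=True)   # class probabilities of each sample sum to 1
    p = np.clip(p, eps, 1.0 - eps)                    # prevent NaN's and Inf's in the log
    # np.maximum mimics mx.sym.take's default 'clip' mode for the -1 indices
    gathered = np.take_along_axis(p, np.maximum(padded_target, 0), axis=-1)
    mask = (padded_target >= 0).astype(p.dtype)       # drop contributions from padded slots
    return -(mask * np.log(gathered)).sum(axis=-1)    # per-sample loss
```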


@keras_mxnet_symbol
def binary_crossentropy(target, output, from_logits=False):
"""Binary crossentropy between an output tensor and a target tensor.
18 changes: 17 additions & 1 deletion keras/engine/training.py
@@ -381,6 +381,13 @@ def handle_metrics(metrics, weights=None):
        metric_fn = metrics_module.sparse_categorical_accuracy
    elif metric in ('crossentropy', 'ce'):
        metric_fn = metrics_module.sparse_categorical_crossentropy
elif self.loss_functions[i] == losses.multi_hot_sparse_categorical_crossentropy:
    # case: multi hot sparse categorical accuracy/crossentropy
    # with a sparse list of integer targets
    if metric in ('accuracy', 'acc'):
        metric_fn = metrics_module.multi_hot_sparse_categorical_accuracy
    elif metric in ('crossentropy', 'ce'):
        metric_fn = metrics_module.multi_hot_sparse_categorical_crossentropy
else:
    # case: categorical accuracy/crossentropy
    if metric in ('accuracy', 'acc'):
@@ -778,13 +785,22 @@ def _standardize_user_data(self, x,
        else:
            feed_output_shapes.append(output_shape)

        check_last_layer_shape = True
        # multi_hot_sparse_categorical_crossentropy is only available in the mxnet backend
        if K.backend() == 'mxnet':
            for loss_fn in self.loss_functions:
                if loss_fn is losses.multi_hot_sparse_categorical_crossentropy:
                    # do not check the last layer shape when
                    # multi_hot_sparse_categorical_crossentropy is used,
                    # since we reduce the dimension of the sparse labels
                    check_last_layer_shape = False
        # Standardize the outputs.
        y = standardize_input_data(
            y,
            feed_output_names,
            feed_output_shapes,
            check_batch_axis=False,  # Don't enforce the batch size.
            exception_prefix='target',
            check_last_layer_shape=check_last_layer_shape)

        # Generate sample-wise weight values given the `sample_weight` and
        # `class_weight` arguments.
8 changes: 7 additions & 1 deletion keras/engine/training_utils.py
@@ -32,7 +32,8 @@ def standardize_input_data(data,
                           names,
                           shapes=None,
                           check_batch_axis=True,
                           exception_prefix='',
                           check_last_layer_shape=True):
    """Normalizes inputs and targets provided by users.

    Users may pass data as a list of arrays, dictionary of arrays,
@@ -130,6 +131,11 @@
                shape = shape[1:]
            for dim, ref_dim in zip(data_shape, shape):
                if ref_dim != dim and ref_dim:
                    # ignore the shape difference in the last layer only if the loss is
                    # multi_hot_sparse_categorical_crossentropy;
                    # the last layer can only be a dense or activation layer
                    if not check_last_layer_shape and names[i].startswith(("dense", "activation")):
                        continue
                    raise ValueError(
                        'Error when checking ' + exception_prefix +
                        ': expected ' + names[i] + ' to have shape ' +
7 changes: 7 additions & 0 deletions keras/losses.py
@@ -74,6 +74,13 @@ def sparse_categorical_crossentropy(y_true, y_pred):
    return K.sparse_categorical_crossentropy(y_true, y_pred)


def multi_hot_sparse_categorical_crossentropy(y_true, y_pred):
    if K.backend() != 'mxnet':
        raise NotImplementedError('multi_hot_sparse_categorical_crossentropy '
                                  'is only available in MXNet backend')
    return K.multi_hot_sparse_categorical_crossentropy(y_true, y_pred)


def binary_crossentropy(y_true, y_pred):
    return K.mean(K.binary_crossentropy(y_true, y_pred), axis=-1)

7 changes: 7 additions & 0 deletions keras/metrics.py
@@ -15,6 +15,7 @@
from .losses import squared_hinge
from .losses import categorical_crossentropy
from .losses import sparse_categorical_crossentropy
from .losses import multi_hot_sparse_categorical_crossentropy
from .losses import binary_crossentropy
from .losses import kullback_leibler_divergence
from .losses import poisson
@@ -39,6 +40,12 @@ def sparse_categorical_accuracy(y_true, y_pred):
                  K.floatx())


def multi_hot_sparse_categorical_accuracy(y_true, y_pred):
    return K.cast(K.equal(K.max(y_true, axis=-1),
                          K.cast(K.argmax(y_pred, axis=-1), K.floatx())),
                  K.floatx())
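For intuition, the metric above counts a sample as correct when the predicted argmax equals the largest true label index in the padded array. A rough NumPy equivalent with illustrative values, not part of the diff:

```python
import numpy as np

y_true = np.array([[-1, 1, 2], [-1, -1, 0]])  # padded sparse labels
y_pred = np.array([[0.1, 0.3, 0.6],
                   [0.8, 0.1, 0.1]])
acc = (y_true.max(axis=-1) == y_pred.argmax(axis=-1)).astype(np.float32)
print(acc)  # [1. 1.]
```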


def top_k_categorical_accuracy(y_true, y_pred, k=5):
    return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)

16 changes: 16 additions & 0 deletions tests/keras/losses_test.py
@@ -109,6 +109,22 @@ def test_sparse_categorical_crossentropy_4d():
    assert np.isclose(expected_loss, np.mean(loss))


def test_multi_hot_sparse_categorical_crossentropy():
Reviewer: please add a test for multi_hot_sparse_categorical_accuracy

Author: added, see changes at tests/keras/metrics_test.py

    y_true_np = np.array([[0, 1, 1], [1, 0, 1], [1, 0, 0]])
    y_pred_np = np.array([[0.1, 0.4, 0.5],
                          [0.4, 0.2, 0.4],
                          [0.7, 0.2, 0.1]])
    y_true_np2 = np.array([[1, 2], [0, 2], [0]])
    loss = K.eval(losses.categorical_crossentropy(K.variable(y_true_np), K.variable(y_pred_np)))

    # pad labels to have the same size, use -1 to differentiate from normal class labels
    y_true_np2 = keras.preprocessing.sequence.pad_sequences(y_true_np2, value=-1)
    y_pred2 = K.variable(y_pred_np)
    y_true2 = K.variable(y_true_np2)
    loss2 = K.eval(losses.multi_hot_sparse_categorical_crossentropy(y_true2, y_pred2))
    assert np.allclose(loss, loss2)


class MSE_MAE_loss:
"""Loss function with internal state, for testing serialization code."""
def __init__(self, mse_fraction):
1 change: 1 addition & 0 deletions tests/keras/metrics_test.py
@@ -39,6 +39,7 @@
all_sparse_metrics = [
    metrics.sparse_categorical_accuracy,
    metrics.sparse_categorical_crossentropy,
    metrics.multi_hot_sparse_categorical_accuracy
]

