diff --git a/examples/multi_hot_sparse_categorical_crossentropy.py b/examples/multi_hot_sparse_categorical_crossentropy.py
new file mode 100644
index 00000000000..b4602bbb226
--- /dev/null
+++ b/examples/multi_hot_sparse_categorical_crossentropy.py
@@ -0,0 +1,121 @@
+'''Trains a simple convnet on a multi-label classification problem
+using multi_hot_sparse_categorical_crossentropy.
+
+This example demonstrates:
+1) how to do multi-label classification with the usual categorical crossentropy
+2) how to improve performance with multi_hot_sparse_categorical_crossentropy
+   when labels are sparse
+'''
+
+import time
+import random
+import numpy as np
+import keras
+from keras.models import Sequential
+from keras.layers import Dense, Dropout, Flatten
+from keras.layers import Conv2D, MaxPooling2D
+from keras.preprocessing.sequence import pad_sequences
+from keras import backend as K
+"""
+Input:
+The input data is random images of size (28, 28) in channels-first format.
+
+Labels:
+Traditionally, labels have shape (num_samples, num_classes), for example:
+labels = [[0, 1, 1, ..., 0],
+          [1, 1, 0, ..., 0],
+          ...
+          [0, 0, 0, ..., 1]]
+where len(labels) = num_samples and len(labels[0]) = num_classes.
+
+However, when num_classes is very large and the labels are very sparse,
+we can represent them differently. For example:
+
+There are 1000 classes in total, so there are 1000 distinct labels.
+Each image can belong to at most 5 labels at the same time.
+labels = [[1, 2],
+          [0, 1],
+          ...
+          [999]]
+where labels is a list of lists.
+
+Special note:
+To deal with the varying lengths of the sparse labels, we pad them with a
+negative value, so padding can be distinguished from real labels:
+padded_labels = pad_sequences(labels, value=-1)
+padded_labels = [[-1, -1, -1, 1, 2],
+                 [-1, -1, -1, 0, 1],
+                 ...
+                 [-1, -1, -1, -1, 999]]
+The result has shape (num_samples, 5), which still saves space compared
+to dense labels.
+"""
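+
+# A minimal illustration of the padding described above, on toy labels
+# (not the training data used below). pad_sequences pads at the front by
+# default, so [[1, 2], [0, 1], [999]] becomes [[1, 2], [0, 1], [-1, 999]]:
+print(pad_sequences([[1, 2], [0, 1], [999]], value=-1))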
+
+# input image dimensions
+img_rows, img_cols = 28, 28
+epochs = 5
+num_gpus = 4
+num_classes = 3000
+num_samples = 50000
+input_shape = (3, img_rows, img_cols)
+batch_size = num_gpus * 32 if num_gpus > 1 else 32
+# create random images of size (28, 28) as training data
+x_train = np.random.randint(0, 256, (num_samples, 3, img_rows, img_cols))
+
+# create dense labels and sparse labels
+sparse_labels = []
+dense_labels = np.zeros((num_samples, num_classes))
+for i in range(0, num_samples):
+    # each sample has at most 5 labels
+    for j in range(0, 5):
+        label = random.randint(0, num_classes - 1)
+        dense_labels[i][label] = 1
+        # make the number of labels per sample unequal
+        if random.randint(0, 5) == 1:
+            break
+    sparse_label_i = np.where(dense_labels[i] == 1)[0]
+    sparse_labels.append(sparse_label_i)
+
+# construct a simple CNN model
+model = Sequential()
+model.add(Conv2D(32, kernel_size=(3, 3),
+                 activation='relu',
+                 input_shape=input_shape))
+model.add(Conv2D(64, (3, 3), activation='relu'))
+model.add(MaxPooling2D(pool_size=(2, 2)))
+model.add(Dropout(0.25))
+model.add(Flatten())
+model.add(Dense(128, activation='relu'))
+model.add(Dropout(0.5))
+model.add(Dense(num_classes, activation='softmax'))
+
+# use multiple GPUs
+if num_gpus > 1:
+    model = keras.utils.multi_gpu_model(model, num_gpus)
+
+# train with the usual categorical crossentropy on dense labels
+model.compile(loss=keras.losses.categorical_crossentropy,
+              optimizer=keras.optimizers.Adadelta(),
+              metrics=['accuracy'])
+model.fit(x_train, dense_labels,
+          batch_size=batch_size,
+          epochs=epochs)
+
+
+# train with multi_hot_sparse_categorical_crossentropy on sparse labels
+model.compile(loss=keras.losses.multi_hot_sparse_categorical_crossentropy,
+              optimizer=keras.optimizers.Adadelta(),
+              metrics=['accuracy'])
+
+# pad sparse labels to the same length with value -1 to differentiate
+# them from normal class labels
+y_train_pad = pad_sequences(sparse_labels, value=-1)
+model.fit(x_train, y_train_pad,
+          batch_size=batch_size,
+          epochs=epochs)
+
+# speed comparison of the two losses
+outputs = model.predict(x_train)
+start = time.time()
+loss = keras.losses.categorical_crossentropy(K.variable(dense_labels), K.variable(outputs))
+print("categorical crossentropy loss time per epoch:", time.time() - start)
+outputs = model.predict(x_train)
+start = time.time()
+loss = keras.losses.multi_hot_sparse_categorical_crossentropy(K.variable(y_train_pad), K.variable(outputs))
+print("multi hot sparse categorical crossentropy loss time per epoch:", time.time() - start)
diff --git a/keras/backend/mxnet_backend.py b/keras/backend/mxnet_backend.py
index cf406306344..4c4e7192664 100644
--- a/keras/backend/mxnet_backend.py
+++ b/keras/backend/mxnet_backend.py
@@ -2988,6 +2988,71 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
     return reshape(KerasSymbol(mx_output), target.shape)
 
 
+@keras_mxnet_symbol
+def multi_hot_sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
+    """Calculates categorical crossentropy when the targets are lists of
+    integer class indices (multi-label).
+
+    # Arguments
+        target: An integer tensor.
+        output: A tensor resulting from a softmax
+            (unless `from_logits` is True, in which
+            case `output` is expected to be the logits).
+        from_logits: Boolean, whether `output` is the
+            result of a softmax, or is a tensor of logits.
+
+    # Returns
+        Output tensor.
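+
+    # Note
+        With padded targets `t` and class probabilities `p`, the per-sample
+        loss computed below is effectively
+        `-sum(log(p[t_j]) for t_j in t if t_j >= 0)`,
+        i.e. padding entries (-1) are masked out of the sum.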
+
+    # Example
+    ```
+    # refer to examples/multi_hot_sparse_categorical_crossentropy.py
+    # for a multi-label classification problem with 3 classes
+    # multi-label target for normal categorical crossentropy
+    >>> target_dense = np.array([[0, 1, 1], [1, 0, 1], [1, 0, 0]])
+    # output from softmax remains in the same format
+    >>> output = np.array([[0.1, 0.4, 0.5],
+    ...                    [0.4, 0.2, 0.4],
+    ...                    [0.7, 0.2, 0.1]])
+    # the same target as sparse lists of class indices for
+    # multi_hot_sparse_categorical_crossentropy
+    >>> target_sparse = [[1, 2], [0, 2], [0]]
+    ```
+    """
+    # TODO: remove the version check once MXNet 1.3.1 is a stable release
+    if mx.__version__ < '1.3.1':
+        raise NotImplementedError('MXNet Backend: multi_hot_sparse_categorical_crossentropy only '
+                                  'works with MXNet 1.3.1 or newer, please upgrade MXNet using: '
+                                  'pip install --upgrade mxnet --pre')
+    output_dimensions = list(range(ndim(output)))
+    if axis != -1 and axis not in output_dimensions:
+        raise ValueError(
+            '{}{}{}'.format(
+                'Unexpected channels axis {}. '.format(axis),
+                'Expected to be -1 or one of the axes of `output`, ',
+                'which has {} dimensions.'.format(len(int_shape(output)))))
+
+    mx_output = output.symbol
+    # scale predictions so that the class probabilities of each sample sum to 1
+    if from_logits:
+        mx_output = mx.sym.softmax(mx_output, axis=axis)
+    else:
+        mx_output = mx.sym.broadcast_div(mx_output, mx.sym.sum(mx_output,
+                                                               axis=axis,
+                                                               keepdims=True))
+    # clip to prevent NaNs and Infs
+    mx_output = mx.sym.clip(mx_output, a_min=epsilon(), a_max=1.0 - epsilon())
+
+    # use control flow ops to iterate over the output rows and gather the
+    # predicted probabilities at the target (true label) indices
+    _step = lambda data, _: (mx.sym.take(data[0], data[1]), [])
+    data = [mx_output, target.symbol]
+    outputs, _ = mx.symbol.contrib.foreach(_step, data, [])
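+
+    # For instance (toy values): with an output row [0.1, 0.4, 0.5] and a
+    # target row [1, 2], the gather above yields [0.4, 0.5]. Whatever is
+    # gathered for padded entries (-1) does not matter, since the mask
+    # below drops those positions from the sum.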
+
+    # calculate the loss, masking out the padded labels (-1):
+    # only positions where target >= 0 contribute to the sum
+    outputs = - mx.sym.sum(mx.sym.broadcast_greater_equal(target.symbol, mx.sym.zeros((1, 1))) *
+                           mx.sym.log(outputs), axis=axis)
+    return KerasSymbol(outputs)
+
+
 @keras_mxnet_symbol
 def binary_crossentropy(target, output, from_logits=False):
     """Binary crossentropy between an output tensor and a target tensor.
diff --git a/keras/engine/training.py b/keras/engine/training.py
index 3d8d14a3966..feb082f80b1 100644
--- a/keras/engine/training.py
+++ b/keras/engine/training.py
@@ -381,6 +381,13 @@ def handle_metrics(metrics, weights=None):
                             metric_fn = metrics_module.sparse_categorical_accuracy
                         elif metric in ('crossentropy', 'ce'):
                             metric_fn = metrics_module.sparse_categorical_crossentropy
+                    elif self.loss_functions[i] == losses.multi_hot_sparse_categorical_crossentropy:
+                        # case: multi-hot sparse categorical accuracy/crossentropy
+                        # with sparse lists of integer targets
+                        if metric in ('accuracy', 'acc'):
+                            metric_fn = metrics_module.multi_hot_sparse_categorical_accuracy
+                        elif metric in ('crossentropy', 'ce'):
+                            metric_fn = metrics_module.multi_hot_sparse_categorical_crossentropy
                     else:
                         # case: categorical accuracy/crossentropy
                         if metric in ('accuracy', 'acc'):
@@ -778,13 +785,22 @@ def _standardize_user_data(self, x,
             else:
                 feed_output_shapes.append(output_shape)
 
+        check_last_layer_shape = True
+        # multi_hot_sparse_categorical_crossentropy is only available in the MXNet backend
+        if K.backend() == 'mxnet':
+            for loss_fn in self.loss_functions:
+                if loss_fn is losses.multi_hot_sparse_categorical_crossentropy:
+                    # skip the last-layer shape check when
+                    # multi_hot_sparse_categorical_crossentropy is used,
+                    # since the sparse labels have a reduced dimension
+                    check_last_layer_shape = False
+
         # Standardize the outputs.
         y = standardize_input_data(
             y, feed_output_names,
             feed_output_shapes,
             check_batch_axis=False,  # Don't enforce the batch size.
-            exception_prefix='target')
+            exception_prefix='target',
+            check_last_layer_shape=check_last_layer_shape)
 
         # Generate sample-wise weight values given the `sample_weight` and
         # `class_weight` arguments.
diff --git a/keras/engine/training_utils.py b/keras/engine/training_utils.py
index bc133c39ecd..ca36e881a08 100644
--- a/keras/engine/training_utils.py
+++ b/keras/engine/training_utils.py
@@ -32,7 +32,8 @@ def standardize_input_data(data,
                            names,
                            shapes=None,
                            check_batch_axis=True,
-                           exception_prefix=''):
+                           exception_prefix='',
+                           check_last_layer_shape=True):
     """Normalizes inputs and targets provided by users.
 
     Users may pass data as a list of arrays, dictionary of arrays,
@@ -130,6 +131,11 @@ def standardize_input_data(data,
                     shape = shape[1:]
                 for dim, ref_dim in zip(data_shape, shape):
                     if ref_dim != dim and ref_dim:
+                        # ignore a shape difference in the last layer only when the loss is
+                        # multi_hot_sparse_categorical_crossentropy;
+                        # the last layer can only be a dense or activation layer
+                        if not check_last_layer_shape and names[i].startswith(("dense", "activation")):
+                            continue
                         raise ValueError(
                             'Error when checking ' + exception_prefix +
                             ': expected ' + names[i] + ' to have shape ' +
diff --git a/keras/losses.py b/keras/losses.py
index 1e9c1ccd042..e93f643d241 100644
--- a/keras/losses.py
+++ b/keras/losses.py
@@ -74,6 +74,13 @@ def sparse_categorical_crossentropy(y_true, y_pred):
     return K.sparse_categorical_crossentropy(y_true, y_pred)
 
 
+def multi_hot_sparse_categorical_crossentropy(y_true, y_pred):
+    if K.backend() != 'mxnet':
+        raise NotImplementedError('multi_hot_sparse_categorical_crossentropy '
+                                  'is only available in the MXNet backend')
+    return K.multi_hot_sparse_categorical_crossentropy(y_true, y_pred)
+
+
 def binary_crossentropy(y_true, y_pred):
     return K.mean(K.binary_crossentropy(y_true, y_pred), axis=-1)
diff --git a/keras/metrics.py b/keras/metrics.py
index 3d5df23b9ec..cc07a2399b2 100644
--- a/keras/metrics.py
+++ b/keras/metrics.py
@@ -15,6 +15,7 @@
 from .losses import squared_hinge
 from .losses import categorical_crossentropy
 from .losses import sparse_categorical_crossentropy
+from .losses import multi_hot_sparse_categorical_crossentropy
 from .losses import binary_crossentropy
 from .losses import kullback_leibler_divergence
 from .losses import poisson
@@ -39,6 +40,12 @@ def sparse_categorical_accuracy(y_true, y_pred):
                   K.floatx())
 
 
+def multi_hot_sparse_categorical_accuracy(y_true, y_pred):
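+    # A sketch of the behaviour on padded targets (toy values): with
+    # y_true = [[-1, 1, 2]] and y_pred = [[0.1, 0.4, 0.5]], K.max(y_true)
+    # is 2 and K.argmax(y_pred) is 2, so the sample scores 1.0; a sample
+    # counts as correct when the top prediction equals the largest true
+    # class index.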
+    return K.cast(K.equal(K.max(y_true, axis=-1),
+                          K.cast(K.argmax(y_pred, axis=-1), K.floatx())),
+                  K.floatx())
+
+
 def top_k_categorical_accuracy(y_true, y_pred, k=5):
     return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)
diff --git a/tests/keras/losses_test.py b/tests/keras/losses_test.py
index 04669804119..80d4aa816bf 100644
--- a/tests/keras/losses_test.py
+++ b/tests/keras/losses_test.py
@@ -109,6 +109,22 @@ def test_sparse_categorical_crossentropy_4d():
     assert np.isclose(expected_loss, np.mean(loss))
 
 
+def test_multi_hot_sparse_categorical_crossentropy():
+    y_true_np = np.array([[0, 1, 1], [1, 0, 1], [1, 0, 0]])
+    y_pred_np = np.array([[0.1, 0.4, 0.5],
+                          [0.4, 0.2, 0.4],
+                          [0.7, 0.2, 0.1]])
+    y_true_np2 = [[1, 2], [0, 2], [0]]
+    loss = K.eval(losses.categorical_crossentropy(K.variable(y_true_np), K.variable(y_pred_np)))
+
+    # pad labels to the same size, using -1 to differentiate padding from
+    # normal class labels
+    y_true_np2 = keras.preprocessing.sequence.pad_sequences(y_true_np2, value=-1)
+    y_pred2 = K.variable(y_pred_np)
+    y_true2 = K.variable(y_true_np2)
+    loss2 = K.eval(losses.multi_hot_sparse_categorical_crossentropy(y_true2, y_pred2))
+    assert np.allclose(loss, loss2)
+
+
 class MSE_MAE_loss:
     """Loss function with internal state, for testing serialization code."""
 
     def __init__(self, mse_fraction):
diff --git a/tests/keras/metrics_test.py b/tests/keras/metrics_test.py
index de955877f48..217a8a2c3f6 100644
--- a/tests/keras/metrics_test.py
+++ b/tests/keras/metrics_test.py
@@ -39,6 +39,7 @@
 all_sparse_metrics = [
     metrics.sparse_categorical_accuracy,
     metrics.sparse_categorical_crossentropy,
+    metrics.multi_hot_sparse_categorical_accuracy,
 ]
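For reference, the equivalence asserted in test_multi_hot_sparse_categorical_crossentropy can be checked by hand on the first sample: with p = [0.1, 0.4, 0.5] and true classes {1, 2}, categorical crossentropy on the multi-hot target [0, 1, 1] gives -(log 0.4 + log 0.5) ≈ 1.609, and the multi-hot sparse loss sums -log p[t_j] over the unpadded labels [1, 2], which is the same ≈ 1.609.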