Add multi hot sparse categorical crossentropy #163
examples/multi_hot_sparse_categorical_crossentropy.py
@@ -0,0 +1,122 @@
'''Trains a simple convnet for multi-label classification using
multi_hot_sparse_categorical_crossentropy.

This example demonstrates:
1) how to do multi-label classification using normal categorical crossentropy
2) when labels are sparse, how to improve performance using
   multi_hot_sparse_categorical_crossentropy
'''

import time
import random
import numpy as np
import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.preprocessing.sequence import pad_sequences
from keras import backend as K
""" | ||
Input: | ||
input data is random images of size (32, 32) in channels first data format | ||
|
||
Labels: | ||
Tradition labels are in the shape of (num_samples, num_class), for example: | ||
labels = [[0, 1, 1, ..., 0], | ||
[1, 1, 0, ..., 0], | ||
... | ||
[0, 0, 0, ..., 1]] | ||
where len(labels) = num_samples, len(labels[0]) = num_classes | ||
|
||
However, when num_classes are very large and labels are very sparse, | ||
we can represent them differently, for example: | ||
|
||
There are total 1000 classes, so there will be 10000 different labels. | ||
Each image can belong to at most 5 labels at the same time. | ||
labels = [[1, 2], | ||
[0, 1], | ||
... | ||
[999]] | ||
where labels is a list of list | ||
|
||
Special Note: | ||
To deal with different length of sparse labels, we pad them with negative values, | ||
so we can differentiate padding values with normal labels. It will become: | ||
padded_labels = pad_sequeences(labels, value=-1) | ||
padded_labels = [[-1, -1, -1, 1, 2], | ||
[-1, -1, -1, 0, 1], | ||
... | ||
[-1, -1, -1, -1, 999]] | ||
It will have shape (num_samples, 5) which still save space compare to dense labels. | ||
""" | ||

# input image dimensions
img_rows, img_cols = 28, 28
epoch = 5
num_gpus = 4
num_classes = 3000
num_samples = 50000
input_shape = (3, img_rows, img_cols)
batch_size = num_gpus * 32 if num_gpus > 1 else 32
# creating random images of size (28, 28) as training data
x_train = np.random.randint(0, 256, (num_samples, 3, img_rows, img_cols))

# creating dense labels and sparse labels
sparse_labels = []
dense_labels = np.zeros((num_samples, num_classes))
for i in range(0, num_samples):
    # each sample has at most 5 labels
    for j in range(0, 5):
        label = random.randint(0, num_classes - 1)
        dense_labels[i][label] = 1
        # making the number of labels per sample unequal
        if random.randint(0, 5) == 1:
            break
    sparse_label_j = np.where(dense_labels[i] == 1)[0]
    sparse_labels.append(sparse_label_j)
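# Tiny illustration (assumed toy values, not the data above) of the np.where() step:
demo_row = np.array([0., 1., 1., 0.])
demo_sparse = np.where(demo_row == 1)[0]  # -> array([1, 2])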

# construct a simple CNN model
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),
                 activation='relu',
                 input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))

# use multi gpu
if num_gpus > 1:
    model = keras.utils.multi_gpu_model(model, num_gpus)

# use normal categorical crossentropy
model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adadelta(),
              metrics=['accuracy'])
model.fit(x_train, dense_labels,
          batch_size=batch_size,
          epochs=epoch)

# use multi_hot_sparse_categorical_crossentropy
model.compile(loss=keras.losses.multi_hot_sparse_categorical_crossentropy,
              optimizer=keras.optimizers.Adadelta(),
              metrics=['accuracy'])

# pad sparse labels to the same length with value -1 to differentiate them
# from normal class labels
y_train_pad = pad_sequences(sparse_labels, value=-1)
Review comment: why not add this padding logic inside the multi_hot_sparse_categorical_crossentropy calculation itself?

Reply: It has to be done at the numpy array level, before feeding the model. Keras can't take an input (y_true) in the form of a list of lists whose lengths vary (e.g. [[1, 2], [0], [0, 1, 2]]).

model.fit(x_train, y_train_pad,
          batch_size=batch_size,
          epochs=epoch)

# speed reference on the two losses
outputs = model.predict(x_train)
start = time.time()
loss = keras.losses.categorical_crossentropy(K.variable(dense_labels), K.variable(outputs))
print("categorical crossentropy loss time per epoch:", time.time() - start)
outputs = model.predict(x_train)
start = time.time()
loss = keras.losses.multi_hot_sparse_categorical_crossentropy(K.variable(y_train_pad), K.variable(outputs))
print("multi hot sparse categorical crossentropy loss time per epoch:", time.time() - start)

@@ -2988,6 +2988,66 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
    return reshape(KerasSymbol(mx_output), target.shape)

@keras_mxnet_symbol
def multi_hot_sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
"""Calculate Categorical crossentropy with a list of integer targets (multi-labels) | ||
|
||
# Arguments | ||
target: An integer tensor. | ||
output: A tensor resulting from a softmax | ||
(unless `from_logits` is True, in which | ||
case `output` is expected to be the logits). | ||
from_logits: Boolean, whether `output` is the | ||
result of a softmax, or is a tensor of logits. | ||
|
||
# Returns | ||
Output tensor. | ||
|
||
# Example: | ||
``` | ||
# refer to examples/multi_hot_sparse_categorical_crossentropy.py | ||
# for a multi-label classification problem with 3 classes | ||
# target with multi labels in normal categorical crossentropy | ||
>>>target_dense = np.array([[0, 1, 1], [1, 0, 1], [1, 0, 0]]) | ||
# output from softmax remains the same format | ||
>>>output = np.array([[0.1, 0.4, 0.5], | ||
>>> [0.4, 0.2, 0.4], | ||
>>> [0.7, 0.2, 0.1]]) | ||
# target with multi labels in multi_hot categorical crossentropy | ||
>>>y_true_np2 = np.array([[1, 2], [0, 2],[0]]) | ||
``` | ||
""" | ||
    output_dimensions = list(range(ndim(output)))
    if axis != -1 and axis not in output_dimensions:
        raise ValueError(
            '{}{}{}'.format(
                'Unexpected channels axis {}. '.format(axis),
                'Expected to be -1 or one of the axes of `output`, ',
                'which has {} dimensions.'.format(len(int_shape(output)))))

    mx_output = output.symbol
    # scale predictions so that the class probabilities of each sample sum to 1
    if from_logits:
        mx_output = mx.sym.softmax(mx_output, axis=axis)
    else:
        mx_output = mx.sym.broadcast_div(mx_output, mx.sym.sum(mx_output,
                                                               axis=axis,
                                                               keepdims=True))
    # clip to prevent NaNs and Infs
    mx_output = mx.sym.clip(mx_output, a_min=epsilon(), a_max=1.0 - epsilon())
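    # In numpy terms, the normalize-and-clip steps above amount to (a sketch, with
    # `eps` standing in for Keras' epsilon()):
    #     probs = out / out.sum(axis=-1, keepdims=True)
    #     probs = np.clip(probs, eps, 1.0 - eps)
    # so each row becomes a valid, numerically safe probability distribution.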
Review comment: Nice work. Please document these steps so it will be easier to maintain later.

    # using control flow ops to iterate over output and take the target (true label)
    _step = lambda data, _: (mx.sym.take(data[0], data[1]), [])
    data = [mx_output, target.symbol]
    outputs, _ = mx.symbol.contrib.foreach(_step, data, [])
Review comment: this is an MXNet contrib API. Do we have an issue to track that this needs to be updated when the stable version of foreach is released?

Reply: Currently we have CI to test MXNet API changes (deprecations and breaking changes). It's the same for contrib ops and other normal ops. If it's moved to mx.sym, it will give a deprecation warning, so we should be able to catch it. Created #173.
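    # In numpy terms, the foreach/take step gathers, for each sample i, the predicted
    # probabilities at the true label positions (an illustrative sketch, not the
    # backend API):
    #     gathered = np.stack([probs[i][labels[i]] for i in range(len(labels))])
    # Padded -1 entries gather some valid class probability (depending on take's
    # out-of-range mode), which the mask below zeroes out.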

    # calculate the loss:
    # the mask (target >= 0) removes the padded labels (-1) from the sum
    outputs = - mx.sym.sum(mx.sym.broadcast_greater_equal(target.symbol, mx.sym.zeros((1, 1))) *
                           mx.sym.log(outputs), axis=axis)
    return KerasSymbol(outputs)
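
# For maintainers, an equivalent numpy reference of the whole loss (a sketch,
# assuming `probs` rows already sum to 1 and `padded` uses -1 padding):
#     def np_multi_hot_sparse_cce(padded, probs, eps=1e-7):
#         probs = np.clip(probs, eps, 1.0 - eps)
#         gathered = np.stack([probs[i][padded[i]] for i in range(len(padded))])
#         mask = (padded >= 0).astype(probs.dtype)
#         return -(mask * np.log(gathered)).sum(axis=-1)
#     (numpy wraps the -1 indices to the last class, but the mask zeroes those terms)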


@keras_mxnet_symbol
def binary_crossentropy(target, output, from_logits=False):
    """Binary crossentropy between an output tensor and a target tensor.

@@ -109,6 +109,22 @@ def test_sparse_categorical_crossentropy_4d():
    assert np.isclose(expected_loss, np.mean(loss))

def test_multi_hot_sparse_categorical_crossentropy():
Review comment: please add a test for multi_hot_sparse_categorical_accuracy.

Reply: added, see changes at tests/keras/metrics_test.py.
    y_true_np = np.array([[0, 1, 1], [1, 0, 1], [1, 0, 0]])
    y_pred_np = np.array([[0.1, 0.4, 0.5],
                          [0.4, 0.2, 0.4],
                          [0.7, 0.2, 0.1]])
    y_true_np2 = np.array([[1, 2], [0, 2], [0]])
    loss = K.eval(losses.categorical_crossentropy(K.variable(y_true_np), K.variable(y_pred_np)))

    # pad labels to the same length, using -1 to differentiate them from normal class labels
    y_true_np2 = keras.preprocessing.sequence.pad_sequences(y_true_np2, value=-1)
    y_pred2 = K.variable(y_pred_np)
    y_true2 = K.variable(y_true_np2)
    loss2 = K.eval(losses.multi_hot_sparse_categorical_crossentropy(y_true2, y_pred2))
    assert np.allclose(loss, loss2)
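    # Hand check (derived from the fixtures above): both label formats should give
    # per-sample losses of roughly
    #     -(log(0.4) + log(0.5)) ~= 1.609
    #     -(log(0.4) + log(0.4)) ~= 1.833
    #     -log(0.7)              ~= 0.357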


class MSE_MAE_loss:
    """Loss function with internal state, for testing serialization code."""
    def __init__(self, mse_fraction):
Review comment: [minor] remove blank line.

Reply: removed.