-
Notifications
You must be signed in to change notification settings - Fork 65
Add multi hot sparse categorical crossentropy #163
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome work!
keras/backend/mxnet_backend.py
Outdated
keepdims=True)) | ||
mx_output = mx.sym.clip(mx_output, a_min=epsilon(), a_max=1.0 - epsilon()) | ||
mx_output = mx.sym.concat(mx.sym.full((target.shape[0],1), 0.5), mx_output) | ||
from mxnet.symbol.contrib import foreach |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: move it.
keras/backend/mxnet_backend.py
Outdated
mx_output = mx.sym.broadcast_div(mx_output, mx.sym.sum(mx_output, | ||
axis=axis, | ||
keepdims=True)) | ||
mx_output = mx.sym.clip(mx_output, a_min=epsilon(), a_max=1.0 - epsilon()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice work. Please document these steps so it will be easier later for maintenance.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me, added a few comments inline
break | ||
sparse_label_j = np.where(dense_labels[i] == 1)[0] | ||
sparse_labels.append(sparse_label_j) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[minor] remove blank line
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed
metrics=['accuracy']) | ||
|
||
# pad sparse labels into shape length with value -1 to differentiate from normal labels | ||
y_train_pad = pad_sequences(sparse_labels, value=-1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why not add this padding logic inside multi_hot_sparse_categorical_crossentropy calculating method?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It has to be done at numpy array level, before feeding to the model. Keras can't take a input(y_true) in the form of a list of list, and the list length may vary. (e.g. [[1,2],[0],[0,1,2]])
# using control flow ops to iterate output and take target (true label) | ||
_step = lambda data, _: (mx.sym.take(data[0], data[1]), []) | ||
data = [mx_output, target.symbol] | ||
outputs, _ = mx.symbol.contrib.foreach(_step, data, []) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is a MXNet contrib API do we have an issue to track that this needs to be updated when the stable version is released for foreach?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently we have CI to test MXNet API changes (deprecated or breaking change). It's the same for contrib ops and other normal ops. If it's moved to mx.sym, it will give a deprecated warning. we should be able to catch it. created #173
@@ -109,6 +109,22 @@ def test_sparse_categorical_crossentropy_4d(): | |||
assert np.isclose(expected_loss, np.mean(loss)) | |||
|
|||
|
|||
def test_multi_hot_sparse_categorical_crossentropy(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please add test for multi_hot_sparse_categorical_accuracy
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added, see changes at tests/keras/metrics_test.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your contributions. LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for your contribution @roywei
keras/backend/mxnet_backend.py
Outdated
@@ -3017,6 +3017,11 @@ def multi_hot_sparse_categorical_crossentropy(target, output, from_logits=False, | |||
>>>y_true_np2 = np.array([[1, 2], [0, 2],[0]]) | |||
``` | |||
""" | |||
# TODO: remove version check after mxnet 1.3.1 stable release |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please create a tracking issue to remove this check
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tracked at #175
Summary
This PR adds a new feature to calculate categorical cross-entropy on multi hot sparse labels
Inputs are softmax predictions and true labels.
It should return same loss as categorical cross-entropy.
Example:
Changes
X3 faster for calculating 3000 classes multi labeled sparse labels. (5 labels at most for each data)
PR Overview