Add Masked NLL interface #466
base: master
@ilkerkesen I had a different implementation of masked NLL over here. It is a bit hacky, but it has the advantage of not requiring a separate mask array, and it is probably slightly faster because it has no extra indexing operation. The hack works by requiring the user to put 0 in the answer array wherever the answer should be masked. Then, |
I like yours more. I'm just wondering whether we ever need positive mask/pad indices in the output array, just for ease of use. |
function nll(y, a::AbstractArray{<:Integer}, mask; dims=1, average=true)
    # No AD overhead in the masking since `a` is not a `Param` object.
    # Note that this overwrites the masked entries of `a` in place.
    a[mask] .= 0
    return nll(y, a; dims=dims, average=average)
end
For example, the user may want to mask different tokens in each iteration.
|
|
What do you suggest then? Should I make another PR? This should work:
function nll(y, a::AbstractArray{<:Integer}, mask; dims=1, average=true)
    # `mask` is a Boolean or integer array; positions where it is 0 or false are masked
    return nll(y, a .* mask; dims=dims, average=average)
end
|
The least intrusive (backward-compatible) change is to implement Ekin's suggestion of skipping zeros in the answer key, implemented in 2f33926. This does not break any existing code because zero is currently not a valid answer.

Users will most likely want to pad with a positive number rather than zero; otherwise the decoder input will be invalid, for example in an s2s model. The convention is therefore to prepare, before calling nll/accuracy, an answer key in which padding is marked with zeros. I believe this is the method used in Keras.

Eventually it would be best to implement complete MaskedSequence, PaddedSequence, and PackedSequence types, as in PyTorch. The current solution should be seen as a low-level one. Check out the commit and close the PR if it is satisfactory.

P.S. I had to change the behavior of average=false to return a pair (total, count) rather than just the total, as this was needed everywhere. Please check and fix any code of yours that uses average=false. |
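To make the zero-skipping convention concrete, here is a minimal stand-alone sketch in plain Julia. It is not the code from the commit and does not call Knet: the name `masked_nll`, the (classes, instances) score layout, and the toy data are all assumptions made for illustration. It only shows the idea discussed above: zeros in the answer key are skipped, and `average=false` returns a (total, count) pair.

```julia
# Hypothetical stand-alone sketch; `masked_nll` is NOT Knet's `nll`.
# scores:  (classes, instances) matrix of unnormalized scores (assumed layout).
# answers: integer vector with one entry per instance; 0 marks a masked position.
function masked_nll(scores::AbstractMatrix{<:Real},
                    answers::AbstractVector{<:Integer}; average=true)
    total = 0.0
    n = 0
    for (j, a) in enumerate(answers)
        a == 0 && continue                    # skip masked positions
        col = view(scores, :, j)
        m = maximum(col)                      # subtract the max for numerical stability
        logz = m + log(sum(exp.(col .- m)))   # log-sum-exp over the classes
        total += logz - col[a]                # NLL of the correct class
        n += 1
    end
    # With average=true and every position masked, n is 0 and the result is NaN.
    return average ? total / n : (total, n)
end

scores  = [2.0 0.5 1.0;
           1.0 0.5 3.0]
answers = [1, 2, 2]
mask    = [true, false, true]

key = answers .* mask    # -> [1, 0, 2]; masked answers become 0
total, n = masked_nll(scores, key; average=false)
```

Building the key with `answers .* mask` leaves the original `answers` untouched, so a different mask can be applied in each iteration without restoring the array.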
This PR adds masked NLL functionality with two different interfaces. Currently, the tests fail because of a multiple-dispatch pattern-matching issue.