fix overwrite bug when adding symbol to dictionary #5329
+65
−22
Before submitting
What does this PR do?
Fixes #3064.
Fixes #3705.
Fixes #1309.
TL;DR: This PR fixes a bug that duplicates symbols that were meant to be overwritten when a vocabulary file is loaded. See the detailed explanation in this blog post.
Expected behavior:
A `Dictionary` object has an `indices` dict and two lists (`symbols` and `count`). By default, when a vocabulary is loaded from a file, a `Dictionary` instance is first created with 4 special tokens (`<s>`, `<pad>`, `</s>` and `<unk>`, in that order). Then all the entries from the file are appended to the `Dictionary`. If the vocabulary file already contains some of the special tokens, their entries should be tagged with `#fairseq:overwrite`; otherwise a "duplicate" error is raised at runtime. Furthermore, during preprocessing, the saved dictionary should not contain any of the special symbols.

Current behavior:
The `add_symbol` function is responsible for adding symbols to the `Dictionary`. It has an `overwrite` argument that is set to `True` when the corresponding line in the file carries `#fairseq:overwrite`. Rather than testing `if word in self.indices and overwrite`, it currently tests `if word in self.indices and not overwrite`, so the case where the symbol should actually be overwritten is never handled. Instead, the symbol is appended to the `symbols` list and its index is changed in the `indices` dict. This results in duplicate symbols and incorrect indices. Generally, only the special symbols are affected. However, because the number of special tokens is set during initialization, that number remains correct.

For example, a dictionary with 50K tokens that already includes `<s>`, `<pad>`, `</s>` and `<unk>` with the `#fairseq:overwrite` tag will end up having 50004 tokens when loaded. This also propagates to the downstream model, whose embedding table will have 50004 entries instead of 50K. Also, with `fairseq-preprocess`, the resulting dictionary will skip the first 4 special symbols but will still contain the duplicated ones.

Domino effects and backward compatibility:
With this fix, dictionary files are loaded properly. However, the fix might cause problems in pipelines that rely on existing architectures and pretrained models, because of the resulting mismatch in sentencepiece encoding and/or embedding table size.
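The overwrite bug and the resulting size mismatch can be reproduced with a small, self-contained sketch. This is not the fairseq code: `TinyDictionary`, its `fixed` switch and `add_from_lines` are hypothetical names used only for illustration, and the "fixed" branch (reuse the existing index and replace the count) is an assumption about what overwriting should mean, not necessarily the exact change in this PR.

```python
from typing import Dict, List

SPECIALS = ["<s>", "<pad>", "</s>", "<unk>"]
OVERWRITE_FLAG = "#fairseq:overwrite"


class TinyDictionary:
    """Hypothetical stand-in for the Dictionary described above."""

    def __init__(self, fixed: bool) -> None:
        self.fixed = fixed  # illustration-only switch: buggy vs. assumed fix
        self.symbols: List[str] = []
        self.count: List[int] = []
        self.indices: Dict[str, int] = {}
        for tok in SPECIALS:
            self.add_symbol(tok)
        # Set once at initialization, so it stays correct even with the bug.
        self.nspecial = len(self.symbols)

    def add_symbol(self, word: str, n: int = 1, overwrite: bool = False) -> int:
        if self.fixed and word in self.indices and overwrite:
            # Assumed fix: keep the existing slot and replace its count.
            idx = self.indices[word]
            self.count[idx] = n
            return idx
        if word in self.indices and not overwrite:
            # Known symbol without the flag: accumulate its count.
            idx = self.indices[word]
            self.count[idx] += n
            return idx
        # New symbol, or (buggy path) a known symbol flagged for overwrite:
        # it is appended again and its index is silently moved.
        idx = len(self.symbols)
        self.indices[word] = idx
        self.symbols.append(word)
        self.count.append(n)
        return idx

    def add_from_lines(self, lines: List[str]) -> None:
        """Parse lines of the form '<symbol> <count> [#fairseq:overwrite]'."""
        for line in lines:
            fields = line.split()
            overwrite = fields[-1] == OVERWRITE_FLAG
            if overwrite:
                fields = fields[:-1]
            word, n = fields[0], int(fields[1])
            if word in self.indices and not overwrite:
                raise RuntimeError(f"Duplicate word found: {word!r}")
            self.add_symbol(word, n=n, overwrite=overwrite)


# A 6-entry vocabulary file that already lists the 4 special tokens.
vocab_lines = [f"{tok} 1 {OVERWRITE_FLAG}" for tok in SPECIALS]
vocab_lines += ["hello 10", "world 7"]

buggy = TinyDictionary(fixed=False)
buggy.add_from_lines(vocab_lines)
fixed = TinyDictionary(fixed=True)
fixed.add_from_lines(vocab_lines)

print(len(buggy.symbols), buggy.indices["<s>"])  # 10 4
print(len(fixed.symbols), fixed.indices["<s>"])  # 6 0
```

The buggy variant ends up with 4 more symbols than the file contains, which mirrors the 50K versus 50004 example, and `<s>` no longer maps to index 0. A model trained against the buggy layout therefore expects the larger table and the shifted indices.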
For the sake of backward compatibility, a `#fairseq:duplicate` flag is introduced to ensure that duplicates are kept in the dictionary, matching the behavior before this fix. When used with `fairseq-preprocess`, the produced `dict.txt` file will also write `#fairseq:duplicate` next to the same symbols.

PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Yes, I did 🙃