Fix coverage penalty (wu) #119
Conversation
It seems that the return_attn path is also broken, so I've made sure not to go through this path when applying the coverage penalty.
I'm leaving these few comments to try to explain the changes I have made.
if self._cov_pen:  # coverage penalty
    self._prev_penalty = torch.zeros_like(self.topk_log_probs)
    self._coverage = current_attn
else:
    self._coverage = torch.zeros(
self._coverage is built sequentially by concatenation as a tensor of size (beam_size x current_batch_size, T + 1, N), where T is the number of decoding steps and N is the length of the source. In this way, for a given decoding step t, the slice self._coverage[k, t+1, :] is the vector of attentions granted by the t-th target token to the source tokens.
At the first decoding step, the coverage is initialized with a vector of zeros. When the penalty is computed, these zeros are added to the sum of attentions and have no impact on the final result.
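To make the shape bookkeeping concrete, here is a minimal standalone sketch of that construction with toy sizes (the variable names and dimensions below are illustrative, not the actual ones from the code):

import torch

beam_times_batch, N = 6, 10                        # toy sizes
coverage = torch.zeros(beam_times_batch, 1, N)     # step 0: a slice of zeros

for step in range(3):                              # pretend decoding loop
    attn = torch.rand(beam_times_batch, 1, N)      # attention of the current target token
    coverage = torch.cat((coverage, attn), dim=1)  # (beam_times_batch, step + 2, N)

# coverage[k, t + 1, :] holds the attention of the t-th target token over the source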
# shape: (batch_size x beam_size, 1)
self._coverage = torch.cat(
Then, in the following decoding steps, the attentions granted by the current target tokens to the source are retrieved with attn[:, :, : self._coverage.size(-1)], since self._coverage.size(-1) is in fact equal to N (only the first N attentions are kept for the coverage computation). It appears that the attn tensor is "naturally" pruned when the decoding of one of the source sequences in the batch is complete. However, to ensure consistent coverage, only self.select_indices are retained on the first dimension.
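A small hypothetical sketch of that pruning, again with toy sizes (the select_indices values are made up):

import torch

beam_times_batch, N = 6, 10
coverage = torch.zeros(beam_times_batch, 1, N)

# attn may expose more source positions than the coverage tracks,
# so only the first N attentions are kept
attn = torch.rand(beam_times_batch, 1, N + 2)
current = attn[:, :, : coverage.size(-1)]

# when some hypotheses are dropped, only the surviving rows are retained
select_indices = torch.tensor([0, 1, 3, 5])
coverage = torch.cat((coverage[select_indices], current[select_indices]), dim=1)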
cov_penalty = self.global_scorer.cov_penalty(
    self._coverage, beta=self.global_scorer.beta
)
self.topk_scores -= cov_penalty.view(_B, self.beam_size).float()
The coverage penalty is then calculated. As it is negative by construction, it is added to the hypothesis probabilities.
The best hypotheses will be re-ranked by taking into account the penalty.
@@ -65,7 +65,7 @@ def coverage_wu(self, cov, beta=0.0):
         then the ``seq_len`` axis probably sums to (almost) 1.
         """

-        penalty = -torch.min(cov, cov.clone().fill_(1.0)).log().sum(-1)
+        penalty = torch.min(cov.sum(1), cov.clone().sum(1).fill_(1.0)).log().sum(1)
First, the attentions granted by the target tokens to each source token are summed with cov.sum(1), the sums are compared to 1, and the log of the minimum is taken. This gives a tensor of size (beam_size x current_batch_size, N).
Then the logs are summed over the source tokens with sum(1), giving a 1-dim tensor of size (beam_size x current_batch_size).
Finally it is multiplied by beta to get the penalty of each hypothesis.
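As a sanity check, here is a minimal standalone version of that computation on a toy tensor (the shapes and beta value are made up):

import torch

beam_times_batch, T, N = 2, 3, 4
cov = torch.rand(beam_times_batch, T, N)   # attention per decoding step over N source tokens
beta = 0.2

summed = cov.sum(1)                                    # (beam_times_batch, N): total attention per source token
clipped = torch.min(summed, torch.ones_like(summed))   # compare each sum to 1
penalty = beta * clipped.log().sum(1)                  # (beam_times_batch,): one value per hypothesis
# a source token that is already well covered (sum >= 1) contributes log(1) = 0, i.e. no penalty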
@@ -357,30 +356,34 @@ def advance(self, log_probs, attn):
        self.maybe_update_forbidden_tokens()

        if self.return_attention or self._cov_pen:
            current_attn = attn[self.select_indices]
current_attn is no longer used for the coverage calculation.
if step == 1:
    self.alive_attn = current_attn
It leads to an error, so it is removed. alive_attn is only used for the return-attention path, so its handling is separated from the coverage path.
That seems weird. Isn't it needed in update_finished and remove_finished_batches?
    self._coverage + attn, self.global_scorer.beta
).view(_B, self.beam_size)
cov_penalty = self.global_scorer.cov_penalty(attn, self.global_scorer.beta)
self.topk_log_probs -= cov_penalty.view(_B, self.beam_size)
I haven't clearly understood this part. What do _stepwise_cov_pen and _prev_penalty mean?
Here is what I gathered from quickly looking at the code:

- _stepwise_cov_pen is roughly the stepwise_penalty condition, which applies the penalty at each decoding step:

  Lines 44 to 47 in 4a3d0dd

  stepwise_penalty: bool = Field(
      default=False,
      description="Apply coverage penalty at every decoding step. Helpful for summary penalty.",
  )

- in that context, _prev_penalty is just the state at the previous step, so that we accumulate along the decoding steps
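If that reading is correct, the stepwise accumulation can be pictured roughly as follows (a hypothetical standalone sketch mirroring the pre-PR formula, not the actual implementation):

import torch

def stepwise_update(scores, prev_penalty, coverage, attn, beta):
    # undo the penalty applied at the previous step
    scores = scores + prev_penalty
    # accumulate the coverage and recompute the penalty on it
    coverage = coverage + attn
    penalty = -beta * torch.min(coverage, torch.ones_like(coverage)).log().sum(-1)
    return scores - penalty, penalty, coverage

# toy usage: 6 hypotheses (beam_size x batch_size) over 5 source tokens
scores = torch.zeros(6)
prev_penalty = torch.zeros(6)
coverage = torch.zeros(6, 5)
for _ in range(3):
    attn = torch.rand(6, 5)
    scores, prev_penalty, coverage = stepwise_update(scores, prev_penalty, coverage, attn, beta=0.2)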
I'm not sure I understand the underlying issue here, so the fixes are not trivial to grasp either.
Can you elaborate? (Error traces, weird behaviours encountered, etc.)
Maybe in the top PR comment, to clarify the context, the issue faced, and how this PR intends to fix it.
Thanks for your review!
When I use
In my decoding config, I have the following error
I have tried to fix it, but it revamps the penalty calculation a bit.
I compute it "from scratch" at each decoding step, using the attentions.