
Fix coverage penalty (wu) #119

Closed
wants to merge 4 commits

Conversation

@l-k-11235 (Contributor) commented Sep 25, 2024

When I use the following in my decoding config:

beam_size: 2
coverage_penalty: 'wu'
beta: 1

I get the following error:

Traceback (most recent call last):
  File "/usr/local/bin/eole", line 33, in <module>
    sys.exit(load_entry_point('EOLE', 'console_scripts', 'eole')())
  File "/workdir/eole/eole/bin/main.py", line 39, in main
    bin_cls.run(args)
  File "/workdir/eole/eole/bin/run/predict.py", line 42, in run
    predict(config)
  File "/workdir/eole/eole/bin/run/predict.py", line 18, in predict
    _, _, _ = engine.infer_file()
  File "/workdir/eole/eole/inference_engine.py", line 38, in infer_file
    scores, estims, preds = self._predict(infer_iter)
  File "/workdir/eole/eole/inference_engine.py", line 170, in _predict
    scores, estims, preds = self.predictor._predict(
  File "/workdir/eole/eole/predict/inference.py", line 475, in _predict
    batch_data = self.predict_batch(batch, attn_debug)
  File "/workdir/eole/eole/predict/generator.py", line 71, in predict_batch
    return self._predict_batch_with_strategy(batch, decode_strategy)
  File "/workdir/eole/eole/predict/generator.py", line 149, in _predict_batch_with_strategy
    decode_strategy.advance(log_probs, attn)
  File "/workdir/eole/eole/predict/beam_search.py", line 437, in advance
    super(BeamSearchLM, self).advance(log_probs, attn)
  File "/workdir/eole/eole/predict/beam_search.py", line 383, in advance
    self.topk_scores -= cov_penalty.view(_B, self.beam_size).float()
RuntimeError: shape '[1, 2]' is invalid for input of size 518

I have tried to fix it, but the fix revamps the penalty calculation a bit.
I compute the penalty "from scratch" at each decoding step, using the attentions.

@l-k-11235 changed the title from "fixed coverage_wu" to "Fix coverage penalty (wu)" on Sep 25, 2024
@l-k-11235 (Contributor Author)

It seems that the return_attn path is also broken, so I've made sure we don't go through this path when applying the coverage penalty.

@l-k-11235 (Contributor Author) left a comment

I'm leaving a few comments to try to explain the changes I have made.

if self._cov_pen:  # coverage penalty
    self._prev_penalty = torch.zeros_like(self.topk_log_probs)
    self._coverage = current_attn
else:
    self._coverage = torch.zeros(
@l-k-11235 (Contributor Author), Sep 27, 2024

self._coverage is built up sequentially, by concatenation, as a tensor of size (beam_size x current_batch_size, T + 1, N), where T is the number of decoding steps and N is the length of the source. In this way, for a given decoding step t, the slice self._coverage[k, t + 1, :] holds the attentions granted by the t-th target token to the source tokens.
At the first decoding step, the coverage is initialized with a vector of zeros. When the penalty is computed, these zeros are simply added to the sum of attentions and have no impact on the final result.
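
To make the shapes concrete, here is a minimal, self-contained illustration of that bookkeeping (toy sizes, not the actual eole code):

import torch

beam_size, batch_size, src_len = 2, 3, 7
B = beam_size * batch_size

# step 0: a slice of zeros, so the later sum of attentions is unaffected
coverage = torch.zeros(B, 1, src_len)

for step in range(4):
    attn = torch.rand(B, 1, src_len)          # attention of the new target token
    attn = attn / attn.sum(-1, keepdim=True)  # each row roughly sums to 1
    coverage = torch.cat([coverage, attn], dim=1)

print(coverage.shape)  # torch.Size([6, 5, 7]), i.e. (beam_size x batch_size, T + 1, N)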

# shape: (batch_size x beam_size, 1)
self._coverage = torch.cat(
@l-k-11235 (Contributor Author)

Then, in the following decoding steps, the attentions granted by the current target tokens to the source are retrieved with attn[:, :, : self._coverage.size(-1)], since self._coverage.size(-1) is equal to N (only the first N attentions are kept for the coverage computation). The attn tensor appears to be "naturally" pruned when the decoding of one of the source sequences in the batch is complete. However, to keep the coverage consistent, only self.select_indices are retained on the first dimension.
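
A hypothetical sketch of that per-step update (the names mirror the comments above; the actual eole code may differ):

import torch

def update_coverage(coverage, attn, select_indices):
    """Append the current step's attentions to the coverage history.

    coverage: (beam_size x batch_size, t + 1, N)
    attn:     (beam_size x batch_size, 1, >= N)
    """
    src_len = coverage.size(-1)                       # N, the source length
    current_attn = attn[select_indices, :, :src_len]  # surviving beams, first N positions
    # re-select the coverage rows too, so they stay aligned with the surviving beams
    return torch.cat([coverage[select_indices], current_attn], dim=1)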

cov_penalty = self.global_scorer.cov_penalty(
    self._coverage, beta=self.global_scorer.beta
)
self.topk_scores -= cov_penalty.view(_B, self.beam_size).float()
@l-k-11235 (Contributor Author)

The coverage penalty is then computed. As it is negative by construction, it is added to the hypothesis scores.

@l-k-11235 (Contributor Author), Sep 27, 2024

The best hypotheses will then be re-ranked with the penalty taken into account.

@@ -65,7 +65,7 @@ def coverage_wu(self, cov, beta=0.0):
         then the ``seq_len`` axis probably sums to (almost) 1.
         """

-        penalty = -torch.min(cov, cov.clone().fill_(1.0)).log().sum(-1)
+        penalty = torch.min(cov.sum(1), cov.clone().sum(1).fill_(1.0)).log().sum(1)
@l-k-11235 (Contributor Author)

First, the attentions granted by the target tokens to each source token are summed with cov.sum(1); the sums are compared to 1, and the log of the minimum is taken. This gives a tensor of size (beam_size x current_batch_size, N).
Then the logs are summed over the source tokens with sum(1), giving a 1-dim tensor of size (beam_size x current_batch_size).
Finally it is multiplied by beta to obtain the penalty of each hypothesis.
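
Written out as a standalone function, the reworked penalty would look roughly like this (a sketch of the formula described above, not the exact eole implementation; the beta handling is an assumption):

import torch

def coverage_wu(cov, beta=0.0):
    """cov: attention history of shape (beam_size x batch_size, T + 1, N)."""
    summed = cov.sum(1)                                   # total attention per source token, (B, N)
    capped = torch.min(summed, torch.ones_like(summed))   # cap at 1.0, as in Wu et al.
    penalty = capped.log().sum(1)                         # sum the logs over source tokens, (B,)
    return beta * penalty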

@@ -357,30 +356,34 @@ def advance(self, log_probs, attn):
        self.maybe_update_forbidden_tokens()

        if self.return_attention or self._cov_pen:
            current_attn = attn[self.select_indices]
@l-k-11235 (Contributor Author)

current_attn is no longer used for the coverage computation.

if step == 1:
    self.alive_attn = current_attn
@l-k-11235 (Contributor Author)

This leads to an error, so it is removed. alive_attn is only used on the return_attention path, so it is handled separately from the coverage path.

@francoishernandez (Contributor)

That seems weird. Isn't it needed in update_finished and remove_finished_batches?

-            self._coverage + attn, self.global_scorer.beta
-        ).view(_B, self.beam_size)
+        cov_penalty = self.global_scorer.cov_penalty(attn, self.global_scorer.beta)
+        self.topk_log_probs -= cov_penalty.view(_B, self.beam_size)
@l-k-11235 (Contributor Author)

I haven't clearly understood this part. What do _stepwise_cov_pen and _prev_penalty mean?

@francoishernandez (Contributor)

Here is what I gathered from a quick look at the code:

  • _stepwise_cov_pen is roughly the stepwise_penalty condition, which applies the penalty at every step:

    stepwise_penalty: bool = Field(
        default=False,
        description="Apply coverage penalty at every decoding step. Helpful for summary penalty.",
    )

  • in that context, _prev_penalty is just the state at the previous step, so that we accumulate along the way (see the sketch below)
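
For reference, a rough sketch of how those two pieces fit together on the stepwise path, reconstructed from the context above (not verbatim eole code):

if self._stepwise_cov_pen and self._prev_penalty is not None:
    # undo the penalty applied at the previous step...
    self.topk_log_probs += self._prev_penalty
    # ...and re-apply it, recomputed on the coverage accumulated so far
    self._prev_penalty = self.global_scorer.cov_penalty(
        self._coverage + attn, self.global_scorer.beta
    ).view(_B, self.beam_size)
    self.topk_log_probs -= self._prev_penalty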

@francoishernandez (Contributor) left a comment

I'm not sure I understand the underlying issue here, so the fixes aren't trivial to grasp either.
Can you elaborate? (Error traces, weird behaviours encountered, etc.)
Maybe in the top PR comment, to clarify the context, the issue faced and how this PR intends to fix it.


@l-k-11235 (Contributor Author) commented Oct 4, 2024

Thanks for your review!
I hadn't understood the use of accumulated coverage, which keeps a step-by-step running table of each hypothesis's summed attention over the source tokens, and which changes the penalty formula from the one in the paper. In this fix I use the table of each hypothesis token's attentions over the source and recompute the coverage "from scratch" at each step (using the paper's formula), which is not optimal. So I'm opening another PR.
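
Roughly, the difference between the two approaches looks like this (toy sketch, not eole code):

import torch

def update_accumulated(coverage_acc, attn):
    # accumulated design: keep one running sum of attention per source token, shape (B, N)
    return coverage_acc + attn.squeeze(1)

def update_history(coverage_hist, attn):
    # this PR: keep the full attention history, shape (B, T + 1, N), and recompute
    # the Wu et al. penalty on it from scratch at every step
    return torch.cat([coverage_hist, attn], dim=1)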
