
[Bugfix][SpecDecode] apply sampling parameters to target probabilities for consistency in rejection sampling. #10198

Open · wants to merge 10 commits into base: main
Conversation

jeongin601 commented Nov 10, 2024

FIX #9834

Problem

The current BatchExpansionTop1Scorer implements a speculative scoring mechanism that uses batch expansion to estimate the probabilities of speculative tokens under the scoring (target) model. However, in the existing setup, the target-side SequenceGroupMetadata is built with default sampling parameters (top_p=1.0, temperature=1.0, repetition_penalty=1.0) when generating target probabilities. According to comments in the code, this choice appears to have been made because the sampled tokens are not used directly.
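In effect, the as-is scoring path behaves as if every request had asked for neutral sampling parameters. A minimal illustration of those defaults, assuming the standard vllm.SamplingParams constructor (this is not a quote of the actual batch_expansion.py code):

```python
from vllm import SamplingParams

# As-is (illustration only): target scoring ignores the request's own sampling
# parameters and uses neutral defaults, since the sampled tokens were assumed
# to be unused downstream.
DEFAULT_TARGET_SCORING_PARAMS = SamplingParams(
    temperature=1.0,
    top_p=1.0,
    repetition_penalty=1.0,
)
```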

Modification

Although we do not directly sample tokens from the target model while scoring, I believe applying consistent sampling parameters to both draft and target probabilities is essential for accurate rejection sampling. The current implementation uses draft probabilities that have already been shaped by the sampling parameters (e.g., filtered by top_p), while the target probabilities have not, leading to a mismatch that can affect scoring accuracy. Because the unmodified target probabilities do not represent the distribution actually used at sampling time, I modified the code to apply the same sampling parameters to both draft and target probabilities for consistency in rejection sampling.
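To make the mismatch concrete, here is a small self-contained sketch (toy numbers, plain PyTorch, not vLLM's RejectionSampler code) of the standard rejection-sampling acceptance rule, which accepts a drafted token x with probability min(1, p_target(x) / q_draft(x)); it evaluates the rule once with top-p applied only to the draft distribution (as-is) and once with it applied to both sides (to-be):

```python
import torch

def top_p_filter(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    """Zero out tokens outside the top-p nucleus and renormalize (simplified)."""
    sorted_probs, sorted_idx = probs.sort(descending=True)
    cumulative = sorted_probs.cumsum(dim=-1)
    # Keep a token if the cumulative mass *before* it is still below top_p.
    keep = (cumulative - sorted_probs) < top_p
    filtered = torch.zeros_like(probs)
    filtered[sorted_idx[keep]] = probs[sorted_idx[keep]]
    return filtered / filtered.sum()

# Toy 4-token vocabulary; the drafted token is index 0.
draft = torch.tensor([0.60, 0.20, 0.15, 0.05])
target = torch.tensor([0.50, 0.25, 0.15, 0.10])
token, top_p = 0, 0.8

q = top_p_filter(draft, top_p)          # draft probs are always top-p filtered
p_as_is = target                        # as-is: target probs left unfiltered
p_to_be = top_p_filter(target, top_p)   # to-be: same filtering on the target side

accept_as_is = min(1.0, (p_as_is[token] / q[token]).item())  # ~0.67
accept_to_be = min(1.0, (p_to_be[token] / q[token]).item())  # ~0.74
print(accept_as_is, accept_to_be)
```

With the target distribution left unfiltered, the acceptance ratio is computed against a distribution that still carries tail mass the sampler would never use, which is consistent with the lower acceptance rates in the as-is experiments below.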

In my experiment, this change resulted in a significant difference in the acceptance rate, as shown in the tables below.

Experiment

Setting

  • Target Model / Draft Model: llama3-70B / llama3-8B
  • TP: 4
  • Devices: A100 * 4
  • Total number of requests: 500
  • input length / output length: 1024 / 128
  • sampling parameters: repetition_penalty=1.0, temperature=0.6, top_p=0.9, top_k=-1
  • dataset: c4
  • batch size: 1
  • K: number of speculative tokens (see the configuration sketch below)
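For reference, a hedged sketch of how a run with these settings might be launched; the model identifiers and speculative-decoding arguments below are assumptions based on the vLLM offline API around the time of this PR, not the actual experiment script:

```python
from vllm import LLM, SamplingParams

# Assumed model identifiers and engine arguments; adjust to the actual setup.
llm = LLM(
    model="meta-llama/Meta-Llama-3-70B-Instruct",              # target model (llama3-70B)
    speculative_model="meta-llama/Meta-Llama-3-8B-Instruct",   # draft model (llama3-8B)
    num_speculative_tokens=3,   # K
    tensor_parallel_size=4,     # TP=4 on 4x A100
)

params = SamplingParams(
    temperature=0.6,
    top_p=0.9,
    top_k=-1,
    repetition_penalty=1.0,
    max_tokens=128,             # output length
)

outputs = llm.generate(["<prompt of ~1024 tokens from the c4 dataset>"], params)
```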

As-is

| K | Acceptance rate (%) | System efficiency (%) |
|---|---------------------|------------------------|
| 1 | 65.1 | 82.5 |
| 2 | 63.3 | 68.3 |
| 3 | 62.4 | 57.4 |

To-be (applied in this PR)

| K | Acceptance rate (%) | System efficiency (%) |
|---|---------------------|------------------------|
| 1 | 81.9 | 91.0 |
| 2 | 80.5 | 82.3 |
| 3 | 80.2 | 75.4 |


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

llsj14 (Contributor) commented Nov 11, 2024

@sroy745 @LiuXiaoxuanPKU @njhill
Would you please take a look at this PR, which relates to the sampling process?

njhill (Member) left a comment


Thanks @jeongin601 this looks like a very nice finding!

We may still want to make and use a (shallow) copy of the sampling parameters with the seed removed in the case a seed is set, to avoid doing seeded sampling for the non-final tokens.
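A minimal sketch of the kind of copy being suggested; the names are illustrative stand-ins (SamplingParamsLike is not vLLM's class) and this is not the actual batch_expansion.py change:

```python
import copy
from dataclasses import dataclass
from typing import Optional

@dataclass
class SamplingParamsLike:
    # Stand-in for vllm.SamplingParams; only the fields relevant here.
    temperature: float = 1.0
    top_p: float = 1.0
    seed: Optional[int] = None

def scoring_params(request_params: SamplingParamsLike) -> SamplingParamsLike:
    # Shallow copy keeps the request's temperature/top_p for target scoring,
    # but drops the seed so that scoring the non-final tokens does not
    # perform seeded sampling.
    params = copy.copy(request_params)
    params.seed = None
    return params

print(scoring_params(SamplingParamsLike(temperature=0.6, top_p=0.9, seed=42)))
```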

Resolved review threads on vllm/spec_decode/batch_expansion.py (outdated)
llsj14 (Contributor) commented Nov 12, 2024

> We may still want to make and use a (shallow) copy of the sampling parameters with the seed removed in the case a seed is set, to avoid doing seeded sampling for the non-final tokens.

@njhill, I'm curious about the reason why the seed should be removed, especially if it is used for the target model sampling and affects the output token selection when proposals are rejected.

njhill (Member) commented Nov 12, 2024

> > We may still want to make and use a (shallow) copy of the sampling parameters with the seed removed in the case a seed is set, to avoid doing seeded sampling for the non-final tokens.
>
> @njhill, I'm curious about the reason why the seed should be removed, especially if it is used for the target model sampling and affects the output token selection when proposals are rejected.

@joennlae ah sorry, perhaps I misremembered the logic; I didn't think those sampled tokens could end up getting used. I'll check it again, but if you're right then it makes sense to ignore that seed optimization.

sroy745 added the ready label on Nov 15, 2024
sroy745 (Collaborator) commented Nov 15, 2024

Adding /ready to kick off the tests and verify nothing else fails from this change.

llsj14 (Contributor) commented Nov 15, 2024

> @joennlae ah sorry, perhaps I misremembered the logic; I didn't think those sampled tokens could end up getting used. I'll check it again, but if you're right then it makes sense to ignore that seed optimization.

@njhill Yeah, I also needed to double-check. I think we might need to use seeds in this part, but I haven't examined `seeded_seqs` in detail yet.

```python
# NOTE: the recovered_probs are overwritten by this method.
recovered_token_ids = _multinomial(
    recovered_probs,
    num_samples=1,
    k=k,
    seeded_seqs=seeded_seqs or {},
).reshape(batch_size, k)
```

sroy745 (Collaborator) commented Nov 18, 2024

> @njhill Yeah, I also needed to double-check. I think we might need to use seeds in this part, but I haven't examined `seeded_seqs` in detail yet.
>
> ```python
> # NOTE: the recovered_probs are overwritten by this method.
> recovered_token_ids = _multinomial(
>     recovered_probs,
>     num_samples=1,
>     k=k,
>     seeded_seqs=seeded_seqs or {},
> ).reshape(batch_size, k)
> ```

Hi,
I think this change should not impact the per-request seed handling logic in the RejectionSampler. The per-request seeds are set here, which remains unchanged, so I am wondering whether this should be fine.

cc: @tdoublep, who made the change for respecting the per-request seed in the spec-decode worker. @tdoublep, can you PTAL and see whether this change impacts the per-request seeding logic or not?

@jeongin601 there is one test failure in the spec_decoding tests (test_many_k[1-32-2-test_llm_kwargs3-baseline_llm_kwargs0-per_test_common_llm_kwargs0-common_llm_kwargs0]). I ran the test locally and it passed. Also, from the failure logs it seems transient. Can you please trigger the tests once more to see whether it passes?

llsj14 (Contributor) commented Nov 18, 2024

Thank you @sroy745, I was able to check this properly after your comments.

I found out that this PR also corrects the seed for `non_spec_token_ids`. Although I haven't used `non_spec_token_ids` while running spec decode, if it is used, the seed should be set to match that of `seq_group_metadata`.

I also confirmed that this section remains unchanged by this PR and is already using the correct sampling parameters. This PR cannot affect `seq_group_metadata` (which holds the per-request sampling parameters), as `target_seq_group_metadata_list` is simply generated from `seq_group_metadata`.

llsj14 (Contributor) commented Nov 19, 2024

test_many_k passed, but test_mlp_e2e_seeded_correctness failed (it did not raise an assertion). I think there shouldn't be any issue with the seed, but we need to check. @jeongin601 will rerun the test first.

What I suspect is that the number of sampling calls made with the same seed may have changed due to this PR. This could affect the output because the generator would consume a different part of its random stream. If that's the case, I believe the outcome is not incorrect, but we need to verify it.
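A small self-contained illustration of that effect, using plain PyTorch rather than vLLM's sampler: the seed fixes the random stream, so extra draws beforehand change which value the later "real" sample consumes, without making either outcome invalid.

```python
import torch

# Minimal illustration (not vLLM code): a fixed seed defines a fixed stream of
# random values, so consuming a different number of draws before the "real"
# sample shifts which part of the stream that sample comes from.
def sample_after_extra_draws(extra_draws: int, seed: int = 1234) -> int:
    gen = torch.Generator().manual_seed(seed)
    probs = torch.full((10,), 0.1)                 # uniform over 10 tokens
    for _ in range(extra_draws):                   # e.g. extra sampling calls introduced elsewhere
        torch.multinomial(probs, num_samples=1, generator=gen)
    return int(torch.multinomial(probs, num_samples=1, generator=gen).item())

# Both results are valid seeded samples, but they will generally differ.
print(sample_after_extra_draws(0), sample_after_extra_draws(2))
```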

Labels: ready

Successfully merging this pull request may close these issues:
  • [Bug]: Sampling parameter fixed issue while doing speculative sampling verification step (#9834)

4 participants