Make repetition penalty faster #442

Open · ccrhx4 wants to merge 5 commits into habana_main
Conversation

@ccrhx4 ccrhx4 commented Oct 29, 2024

This PR fixes the very slow sampling process seen when a repetition penalty is set.

The fix includes (see the sketch after this list):

  1. Enable pin_memory on HPU
  2. Pad prompt tokens and output_tokens to avoid recompilation
  3. Replace slow ops
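
For context, here is a minimal sketch of the padding idea. It is not the PR's exact diff: the helper name make_tensor_with_pad_align and the parameters max_len_align and pin_memory come from the review threads below, while the body is illustrative.

```python
import numpy as np
import torch

def make_tensor_with_pad_align(x, pad, dtype, max_len_align=1024,
                               pin_memory=False):
    """Pad a batch of token lists to an aligned length (sketch)."""
    max_len = max(len(seq) for seq in x)
    # Round the padded length up to the next multiple of max_len_align so
    # the resulting shape always lands on a fixed bucket boundary.
    aligned_len = ((max_len + max_len_align - 1)
                   // max_len_align) * max_len_align
    padded_x = np.full((len(x), aligned_len), pad, dtype=np.int64)
    for i, seq in enumerate(x):
        padded_x[i, :len(seq)] = seq
    tensor = torch.from_numpy(padded_x).to(dtype)
    if pin_memory:
        # Pinning through the HPU backend speeds up host-to-device copies.
        tensor = tensor.pin_memory("hpu")
    return tensor
```

Padding to a bucket boundary trades a little wasted work on pad tokens for a stable tensor shape, which is what keeps the HPU graph compiler from recompiling on every sampling step.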

Before the fix:
SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.06, temperature=1.0, top_p=1.0, top_k=-1, min_p=0.0, seed=None, stop=[], stop_token_ids=[], include_stop_str_in_output=False, ignore_eos=True, max_tokens=1024, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None), guided_decoding=None
Warming up...
Profiling iterations: 100%|5/5 [03:24<00:00, 40.99s/it]
Avg latency: 40.98862759781768 seconds
10% percentile latency: 11.699748958216514 seconds
25% percentile latency: 11.73845003999304 seconds
50% percentile latency: 11.801458386995364 seconds
75% percentile latency: 11.861465670051984 seconds
90% percentile latency: 99.46527566103033 seconds
99% percentile latency: 152.02756165561732 seconds

After the fix:
SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.06, temperature=1.0, top_p=1.0, top_k=-1, min_p=0.0, seed=None, stop=[], stop_token_ids=[], include_stop_str_in_output=False, ignore_eos=True, max_tokens=1024, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None), guided_decoding=None
Warming up...
Profiling iterations: 100%| 5/5 [00:57<00:00, 11.59s/it]
Avg latency: 11.58703240059549 seconds
10% percentile latency: 11.444069900200702 seconds
25% percentile latency: 11.511425047006924 seconds
50% percentile latency: 11.525146245025098 seconds
75% percentile latency: 11.556680046953261 seconds
90% percentile latency: 11.788318535778672 seconds
99% percentile latency: 11.927301629073918 seconds

Test script: https://github.com/ccrhx4/huanxing.vllm-fork/blob/slow_repetition_penalty/benchmarks/reproduce.sh

Review thread on the make_tensor_with_pad_align hunk:

        max_len_align=max_len_align)

    tensor = torch.from_numpy(padded_x).to(device)
    if pin_memory:

Michal (reviewer): `if` not needed, since this method is called for HPU

ccrhx4 (author): Hi Michal, I removed the device check logic here and kept the pin_memory check. This way, the method's behavior is exactly the same as the un-aligned version.
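
After that change the resolved hunk plausibly reads as follows (a sketch; only the lines quoted above are from the actual diff):

```python
tensor = torch.from_numpy(padded_x).to(device)
if pin_memory:
    # Device check removed per review: pin_memory alone gates pinning,
    # since this helper is only ever called on the HPU path.
    tensor = tensor.pin_memory("hpu")
return tensor
```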

Comment on lines +857 to +860
if not current_platform.is_hpu():
tensor = tensor.pin_memory()
else:
tensor = tensor.pin_memory("hpu")

Michal (reviewer): can be removed, as it won't be called now

ccrhx4 (author): Hi Michal, make_tensor_with_pad is still called from several other places; it is replaced by make_tensor_with_pad_align only in the repetition penalty path. So I think we still need the check here.
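
To illustrate the point, here are hypothetical call sites (the two build_* functions and their parameters are assumptions; only the two helper names come from the PR):

```python
import torch
from vllm.utils import make_tensor_with_pad  # existing generic helper

def build_penalty_tensor(output_tokens, vocab_size):
    # Repetition-penalty path: switched to the HPU-only aligned variant
    # sketched earlier, with aligned padding and pinned memory.
    return make_tensor_with_pad_align(output_tokens, pad=vocab_size,
                                      dtype=torch.int64, pin_memory=True)

def build_block_tables(block_tables, device):
    # Every other caller keeps the generic helper, which must work on
    # CUDA and HPU alike; hence the current_platform.is_hpu() branch stays.
    return make_tensor_with_pad(block_tables, pad=0, dtype=torch.int,
                                device=device)
```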
