Make repetition penalty faster #442
base: habana_main
Conversation
Use masked_fill instead of boolean indexing; add paddings to prompt tokens and output tokens to reduce recompilation.
655018e to ec38d67
        max_len_align=max_len_align)
tensor = torch.from_numpy(padded_x).to(device)
if pin_memory:
The if check is not needed, since this method is only called for HPU.
Hi Michal, I removed the device check logic here and kept the pin_memory check. This way, the method behaves exactly the same as the unaligned version.
if not current_platform.is_hpu():
    tensor = tensor.pin_memory()
else:
    tensor = tensor.pin_memory("hpu")
Can be removed, as it won't be called now.
Hi Michal, make_tensor_with_pad is still called from other places; it is only replaced by make_tensor_with_pad_align in the repetition penalty path. So I think we still need the check here.
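For context, here is a minimal sketch of what an aligned-padding helper in the spirit of make_tensor_with_pad_align could look like. The signature, the default alignment of 1024, and the int64 dtype are illustrative assumptions, not the actual implementation; only the function name and the max_len_align argument appear in the diff above.

import math
from typing import List
import numpy as np
import torch

def make_tensor_with_pad_align(x: List[List[int]],
                               pad: int,
                               device: str,
                               max_len_align: int = 1024,
                               pin_memory: bool = False) -> torch.Tensor:
    # Round the padded length up to a multiple of max_len_align so the
    # tensor shape only changes when a whole new bucket is reached, not
    # every time the longest sequence grows by one token. Stable shapes
    # mean far fewer graph recompilations on HPU.
    max_len = max((len(row) for row in x), default=0)
    padded_len = max(max_len_align,
                     math.ceil(max_len / max_len_align) * max_len_align)
    padded_x = np.full((len(x), padded_len), pad, dtype=np.int64)
    for i, row in enumerate(x):
        padded_x[i, :len(row)] = row
    tensor = torch.from_numpy(padded_x).to(device)
    if pin_memory:
        # The real helper additionally picks the pinning target per
        # platform (see the is_hpu check in the diff above).
        tensor = tensor.pin_memory()
    return tensor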
This PR fixes the very slow sampling process when repetition penalty is set.
The fix includes:
- using masked_fill instead of boolean indexing when applying the repetition penalty (a sketch of the idea follows this list);
- padding the prompt tokens and output tokens to aligned lengths (via make_tensor_with_pad_align) to reduce recompilation.
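To illustrate the first point, here is a minimal sketch, not the exact code from this PR, of applying a repetition penalty with static-shape element-wise ops instead of boolean indexing. The function and argument names are hypothetical; seen_mask stands for the combined prompt-token and output-token mask.

import torch

def apply_repetition_penalty(logits: torch.Tensor,
                             seen_mask: torch.Tensor,
                             penalty: float) -> torch.Tensor:
    # logits:    [batch, vocab] float tensor
    # seen_mask: [batch, vocab] bool tensor, True where the token appeared
    #
    # Boolean indexing (logits[seen_mask]) produces a tensor whose shape
    # depends on how many tokens were seen, which forces dynamic shapes
    # and recompilation on HPU. Element-wise masked ops keep every shape
    # static, so one compiled graph is reused across decoding steps.
    penalized = torch.where(logits > 0, logits / penalty, logits * penalty)
    return torch.where(seen_mask, penalized, logits)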
Before the fix:
SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.06, temperature=1.0, top_p=1.0, top_k=-1, min_p=0.0, seed=None, stop=[], stop_token_ids=[], include_stop_str_in_output=False, ignore_eos=True, max_tokens=1024, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None), guided_decoding=None
Warming up...
Profiling iterations: 100%|5/5 [03:24<00:00, 40.99s/it]
Avg latency: 40.98862759781768 seconds
10% percentile latency: 11.699748958216514 seconds
25% percentile latency: 11.73845003999304 seconds
50% percentile latency: 11.801458386995364 seconds
75% percentile latency: 11.861465670051984 seconds
90% percentile latency: 99.46527566103033 seconds
99% percentile latency: 152.02756165561732 seconds
After the fix:
SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.06, temperature=1.0, top_p=1.0, top_k=-1, min_p=0.0, seed=None, stop=[], stop_token_ids=[], include_stop_str_in_output=False, ignore_eos=True, max_tokens=1024, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None), guided_decoding=None
Warming up...
Profiling iterations: 100%| 5/5 [00:57<00:00, 11.59s/it]
Avg latency: 11.58703240059549 seconds
10% percentile latency: 11.444069900200702 seconds
25% percentile latency: 11.511425047006924 seconds
50% percentile latency: 11.525146245025098 seconds
75% percentile latency: 11.556680046953261 seconds
90% percentile latency: 11.788318535778672 seconds
99% percentile latency: 11.927301629073918 seconds
The test script is at: https://github.com/ccrhx4/huanxing.vllm-fork/blob/slow_repetition_penalty/benchmarks/reproduce.sh
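For reference, a hypothetical minimal reproduction in Python; the linked reproduce.sh is the authoritative script. The model name and prompt below are placeholders, and the sampling parameters mirror the SamplingParams shown above.

from vllm import LLM, SamplingParams

# Placeholder model; the actual benchmark model is set in reproduce.sh.
llm = LLM(model="<model-under-test>")

params = SamplingParams(
    n=1,
    repetition_penalty=1.06,  # the penalty that triggered the slow path
    temperature=1.0,
    top_p=1.0,
    ignore_eos=True,
    max_tokens=1024,
)

outputs = llm.generate(["<benchmark prompt>"], params)
print(outputs[0].outputs[0].text)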