
HPU: offload logits processing to CPU #358

Merged: 4 commits merged from dev/madamczyk/offload_logits into habana_main on Oct 29, 2024

Conversation

madamczykhabana

Due to the high dynamicity of logits processing, it is better to offload it completely to the CPU instead of computing it on the HPU.
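
The general idea, shown below as a minimal illustrative sketch (the class name, constructor argument, and mask construction are assumptions, not the PR's actual code), is to move the logits to the host, run the dynamic masking there, and hand the result back to the device:

    import torch


    class CPUOffloadedLogitsProcessor:
        # Illustrative only: run the highly dynamic masking logic on the CPU
        # instead of the HPU, then return the masked logits on the original device.

        def __init__(self, allowed_token_ids):
            # Stand-in for whatever the guided-decoding FSM allows in the
            # current state (hypothetical, not the PR's actual interface).
            self.allowed_token_ids = allowed_token_ids

        def __call__(self, input_ids, scores: torch.Tensor) -> torch.Tensor:
            device = scores.device
            cpu_scores = scores.to("cpu")                # offload to host
            mask = torch.full_like(cpu_scores, float("-inf"))
            allowed = torch.tensor(sorted(self.allowed_token_ids))
            mask[allowed] = 0.0                          # keep only allowed tokens
            return (cpu_scores + mask).to(device)        # hand back to the device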

@madamczykhabana
Author

@tae-su-kim, @huijjj
Please check on your end whether this PR helps with structured output performance on HPU. In my local testing I saw a performance increase from ~60 tok/sec to ~1100 tok/sec in a 1k/1k scenario.

Note that when cherry-picking it to v0.8, take vllm/model_executor/guided_decoding/outlines_logits_processors.py from this PR as-is and discard any previous modifications, since by default git will keep with_mark_steps and other earlier changes (and calling it for every sample in a batch hurts performance).

@madamczykhabana added the habana label (Issues or PRs submitted by Habana Labs) on Oct 4, 2024
@tae-su-kim

tae-su-kim commented Oct 7, 2024

@madamczykhabana Hi, sorry for the delayed response. We are trying to reproduce the numbers both in our v0.8 branch and in habana_main, and the current PR seems to fail on habana_main. It collides with commit 7c7714d, where MQLLMEngine was introduced.

  File "/workspace/codes/vllm/engine/multiprocessing/client.py", line 552, in _process_request
    lp_bytes = cloudpickle.dumps(logits_processors)
  File "/usr/local/lib/python3.10/dist-packages/cloudpickle/cloudpickle.py", line 1479, in dumps
    cp.dump(obj)
  File "/usr/local/lib/python3.10/dist-packages/cloudpickle/cloudpickle.py", line 1245, in dump
    return super().dump(obj)
_pickle.PicklingError: Can't pickle <functools._lru_cache_wrapper object at 0x7fbc567d5010>: it's not the same object as vllm.model_executor.guided_decoding.outlines_logits_processors.BaseLogitsProcessor._get_mask_tensor

It seems like the new multiprocessing engine pickles the logits processor information using cloudpickle, but cloudpickle is unable to pickle functions decorated with lru_cache (cloudpipe/cloudpickle#178).

I can confirm that this PR works in our v0.8 branch (based on vllm-fork v0.5.3) and improves e2e throughput as follows:
(3.1-8B-Instruct, fixed 1K/1K benchmark) 917 tokens/sec to 1233 tokens/sec
(3.1-8B-Instruct, dynamic 1K/1K benchmark) 59 tokens/sec to 1027 tokens/sec

@madamczykhabana force-pushed the dev/madamczyk/offload_logits branch 2 times, most recently from e0c5216 to b02d483, on October 8, 2024 at 13:07
@madamczykhabana
Author

@tae-su-kim thanks for the info. I've just pushed a workaround for the pickling error. Please check if it helps on your end.
Also, I'm seeing some weird behavior with structured output on habana_main, even without this PR. I'm still trying to figure out why I see a big discrepancy in the number of generated tokens.

@tae-su-kim

@madamczykhabana Thanks for the prompt update! I will benchmark with multiprocessing again. Can you elaborate on the discrepancy in the number of generated tokens? On our side we observe a similar number of tokens, e.g. 952,171 tokens on A100 and 959,993 tokens on Gaudi-2 for the 1K/1K dynamic benchmark with a JSON guide. We can share our setup if you want.

@madamczykhabana marked this pull request as ready for review on October 28, 2024 at 10:29
@madamczykhabana
Author

@tae-su-kim Sorry for the delay. The token discrepancy came from comparing guided-json to non-guided-json runs. The PR has been rebased and is ready for review.

@@ -30,11 +30,48 @@
from transformers import PreTrainedTokenizerBase


# Unfortunately we cannot use lru_cache as it breaks pickling
# so we use a simpler implementation
def _cached(fn):


How many masks are expected at maximum? Since _create_mask_tensor is limited and the _cached function is not, won't we get into a situation where we get a None result for the least recently used entries (128+)?


You can ignore my last comment; the lru_cache maxsize has no effect here, since the _cached cache keeps the objects.
If you reverse the condition in line 39, the else branch can be replaced with return cache[args]:

    if args not in cache:
        cache[args] = fn(*args)
    return cache[args]

Author

@madamczykhabana Oct 29, 2024


Yeah, but at the same time we'd need to calculate the hash and access the map one additional time. With the suggested version there would be three accesses:

  1. if args in cache
  2. cache[args] = fn(*args)
  3. return cache[args]

In the current version you access it twice:

  1. args in cache
  2. cache[args] = result

To be honest I didn't measure how big of a perf impact this might be, but I'd like to be on the safe side.
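
For reference, a cache helper following the two-access pattern described above might look like the sketch below (based on this discussion only, not necessarily the exact code merged in the PR; note the cache is unbounded, as pointed out in the review above):

    def _cached(fn):
        # Plain dict instead of functools.lru_cache, so objects holding the
        # wrapper stay picklable.
        cache = {}

        def cached_fn(*args):
            if args in cache:          # first dict access: membership check
                result = cache[args]   # second access on a hit: read
            else:
                result = fn(*args)
                cache[args] = result   # second access on a miss: store
            return result

        return cached_fn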

@madamczykhabana merged commit 3203bd9 into habana_main on Oct 29, 2024
19 checks passed
@madamczykhabana deleted the dev/madamczyk/offload_logits branch on October 29, 2024 at 07:10