
Fix guided sampling with outlines #226

Conversation

tae-su-kim

This is a rebase of PR #153 onto habana_main due to the deprecation of habana_next.

The current habana_main already includes the guided-decoding code from vLLM, and the feature is exposed through the OpenAI API endpoint. However, guided decoding currently fails with the following error:

```
...
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1535, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1585, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/workspace/codes/vllm/model_executor/layers/sampler.py", line 138, in forward
    sample_results, maybe_sampled_tokens_tensor = _sample(
  File "/workspace/codes/vllm/model_executor/layers/sampler.py", line 711, in _sample
    return _sample_with_torch(
  File "/workspace/codes/vllm/model_executor/layers/sampler.py", line 592, in _sample_with_torch
    sample_results = _greedy_sample(seq_groups, greedy_samples)
  File "/workspace/codes/vllm/model_executor/layers/sampler.py", line 336, in _greedy_sample
    samples_lst = samples.tolist()
RuntimeError: synNodeCreateWithId failed for node: strided_insert with synStatus 1 [Invalid argument].
```

This PR proposes using masked_fill rather than _add for the masking step of guided decoding; a sketch of the idea follows. With this change, the OpenAI endpoint supports guided decoding, as the example request and output below show.
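As a minimal sketch of the masking change (illustrative names, not the PR's actual diff), assuming a 1-D logits tensor and the allowed token ids produced by the outlines FSM, the additive -inf bias can be replaced with a boolean mask and masked_fill:

```python
import math

import torch


def mask_logits(logits: torch.Tensor, allowed_token_ids: list) -> torch.Tensor:
    # Boolean mask: True for every token the grammar FSM disallows.
    disallowed = torch.ones_like(logits, dtype=torch.bool)
    disallowed[allowed_token_ids] = False
    # masked_fill sidesteps the additive -inf bias tensor that triggered
    # the strided_insert failure on HPU.
    return logits.masked_fill(disallowed, -math.inf)


# Example: vocabulary of 8 tokens, grammar allows only ids 2 and 5.
logits = torch.randn(8)
print(mask_logits(logits, [2, 5]))
```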

Input:

```python
# Values for best_of and use_beam_search are illustrative.
best_of = 1
use_beam_search = False

payload = {
    "model": "/models/Meta-Llama-3-8B-Instruct",
    "messages": [
        {"role": "user", "content": "reply negatively."}
    ],
    "best_of": best_of,
    "use_beam_search": use_beam_search,
    "temperature": 0.0,
    "top_p": 1.0,
    "guided_regex": "[Pp]ositive format |[Nn]egative format",
}
```
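To reproduce the example, the payload can be POSTed to the OpenAI-compatible chat completions route (host and port below assume a local vLLM server and are illustrative):

```python
import requests

# Host/port assume a vLLM OpenAI-compatible server running locally.
response = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json=payload,
)
print(response.json())
```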

Output:

```
{'id': 'cmpl-f3e792eb0197492a8d7eec4bb9916936', 'object': 'chat.completion', 'created': 1722847036, 'model': '/models/Meta-Llama-3-8B-Instruct', 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': 'Negative format'}, 'logprobs': None, 'finish_reason': 'stop', 'stop_reason': None}], 'usage': {'prompt_tokens': 14, 'total_tokens': 21, 'completion_tokens': 7}}
```


tae-su-kim commented Sep 9, 2024

We fixed the abnormal latency overhead with commit 6d57c18.

A rough benchmark is as follows:

| Version | Prefill Latency | Decode Latency |
| --- | --- | --- |
| w/o guided_decode | 53.1s | 962.6s |
| commit 4bea4e3 | 156.9s | 1322.5s |
| commit 6d57c18 | 53.2s | 1149.5s |

Setup: llama-3-8b, greedy sampling, max_num_seqs 256, 1k requests, QPS -1
- prefill_latency: latency(1k random input tokens / 1 output token)
- decode_latency: latency(1k random input tokens / 1k output tokens without EOS) - prefill_latency

```diff
@@ -127,7 +127,7 @@ def __init__(self, schema: Union[str, Dict, BaseModel],
 class CFGLogitsProcessor(BaseLogitsProcessor):
 
     @classmethod
-    @cache()
+    @lru_cache(maxsize=32)
```
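For context on the change above: unlike an unbounded cache decorator, functools.lru_cache(maxsize=32) bounds how many compiled grammars are retained. A minimal, self-contained illustration (the class and method are hypothetical stand-ins, not the PR's code):

```python
from functools import lru_cache


class GrammarCompiler:

    @classmethod
    @lru_cache(maxsize=32)  # keeps at most 32 results, evicting least-recently-used
    def compile(cls, grammar: str) -> str:
        # Stand-in for an expensive FSM/grammar compilation step.
        print(f"compiling: {grammar}")
        return f"compiled({grammar})"


GrammarCompiler.compile("start: WORD")  # compiles
GrammarCompiler.compile("start: WORD")  # cache hit; no recompilation
```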

A reviewer commented on this diff:

Ruff static analysis found that `cache`, imported at the top of the file, is now unused; please remove the import.

@tae-su-kim (Author) replied:

I am benchmarking the performance of .add versus .masked_fill to determine which has better throughput. I'll resolve the conflict based on the benchmark results and fix the Ruff issue before merging. I'll let you know once it's completed. Apologies for the delayed response!
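A minimal CPU-only sketch of such a micro-benchmark (vocabulary size, mask density, and iteration counts are arbitrary; a real comparison on Gaudi would need device placement and synchronization via habana_frameworks, omitted here):

```python
import math
import time

import torch

vocab, iters = 128_000, 1_000
logits = torch.randn(vocab)
allowed = torch.randint(0, vocab, (64,))  # token ids permitted by the grammar

# .add variant: additive bias of -inf on disallowed tokens.
bias = torch.full((vocab,), -math.inf)
bias[allowed] = 0.0

# .masked_fill variant: boolean mask of disallowed tokens.
mask = torch.ones(vocab, dtype=torch.bool)
mask[allowed] = False

t0 = time.perf_counter()
for _ in range(iters):
    _ = logits.add(bias)
t1 = time.perf_counter()
for _ in range(iters):
    _ = logits.masked_fill(mask, -math.inf)
t2 = time.perf_counter()

print(f".add: {t1 - t0:.3f}s  .masked_fill: {t2 - t1:.3f}s")
```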

@michalkuligowski added the external (Issues or PRs submitted by external users) label Sep 20, 2024
@tae-su-kim (Author)

@michalkuligowski I removed the import and resolved the conflict based on the benchmark results. Both the .add-based and .masked_fill-based implementations significantly degrade decode throughput, so I opted to stick with the current code. If any optimizations turn out to be possible, I will open another PR.

@michalkuligowski

Ruff fails on unsorted imports. BTW, is an outlines version bump required now after these changes?

@tae-su-kim (Author)

There are some marginal throughput differences, but I think most of the updates in this PR are already in habana_main. I will close this PR for now. Thank you!

@tae-su-kim closed this Sep 25, 2024