
[Experimental] Support Cross encoder models #10400

Open

wants to merge 9 commits into base: main
Conversation

maxdebayser
Contributor

@maxdebayser maxdebayser commented Nov 17, 2024

This PR contains a proof of concept to add support for Cross Encoder models reusing most of what was done to support embedding models, as well as the chat embeddings API. It's a bit hacky, but I think it helps to have something concrete to iterate and refine the design.

curl -X 'POST' 'http://localhost:8000/v1/embeddings' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
  "encoding_format": "float",
  "messages": [{
    "role": "user",
    "content": "What is the capital of France?"
  },
  {
    "role": "user",
    "content": "The capital of France is Paris."
  }]
}'

Response:

{
  "id": "embd-49858b4f279b4f19939ac832e85de528",
  "object": "list",
  "created": 110449,
  "model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
  "data": [
    {
      "index": 0,
      "object": "embedding",
      "embedding": [
        9.265625
      ]
    }
  ],
  "usage": {
    "prompt_tokens": 17,
    "total_tokens": 17,
    "completion_tokens": 0,
    "prompt_tokens_details": null
  }
}
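For reference, the single value in the embedding field is the raw logit from the cross-encoder head. For a single-label cross-encoder like this one, sentence-transformers applies a sigmoid by default, so the logit can be mapped to a 0-1 relevance score. A minimal sketch of that mapping:

```python
import math

def logit_to_score(logit: float) -> float:
    """Map a raw cross-encoder logit to a 0-1 relevance score via sigmoid."""
    return 1.0 / (1.0 + math.exp(-logit))

# Raw logit from the response above; the sigmoid puts it very close to 1.0,
# i.e. the pair is judged highly relevant.
score = logit_to_score(9.265625)
```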

I've added a new model type and a new task type. From the point of view of the models and the model runner this is not strictly necessary, because the inputs and outputs are compatible with the embedding models. However, at the serving level we need to know that the model has a different task so that we can be sure to call the tokenizer with the text_pair parameter instead of trying to apply a chat template.

What's still missing:

  • Support for cross encoding in the LLM class
  • Tests comparing with sentence-transformers
  • CPU support
  • Roberta cross encoder models

cc: @DarkLight1337 @flaviabeo

Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which executes a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@mergify mergify bot added the frontend label Nov 17, 2024
Member

@DarkLight1337 DarkLight1337 left a comment


Can you add an example script to show how to use this API?

QQ: Why do we have to define a separate "cross_encoding" task for this? I think we can keep using "embedding" task if we just override the pooler method instead of defining a new classification_output method.

Review comment on vllm/model_executor/models/bert.py (outdated; resolved)
@maxdebayser
Contributor Author

Can you add an example script to show how to use this API?

Yes, I've added one in the PR description but it's a good idea to add it to the documentation.

QQ: Why do we have to define a separate "cross_encoding" task for this? I think we can keep using "embedding" task if we just override the pooler method instead of defining a new classification_output method.

I thought about this, and the only reason I kept it that way is that in the serving layer I need to know which task is being performed, because I need to call the tokenizer with tokenizer(text=text1, text_pair=text2). If instead of reusing the chat embeddings API we had a new endpoint just for cross encoding, this wouldn't be necessary. Alternatively, we could add an attribute to one of the config classes to indicate that although the task is "embedding", the model is actually a BertModelForSequenceClassification.
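For context on why text_pair matters: a BERT-style tokenizer encodes the two texts as a single sequence with segment ids, which a chat template cannot reproduce. A simplified, illustrative sketch of the pair encoding (a real tokenizer additionally handles subwords, truncation, and numeric ids):

```python
def encode_pair(tokens_a: list[str], tokens_b: list[str]):
    """Illustrative BERT-style pair encoding: [CLS] A [SEP] B [SEP],
    with token_type_ids distinguishing the two segments. This is what
    tokenizer(text=..., text_pair=...) produces, conceptually."""
    tokens = ["[CLS]", *tokens_a, "[SEP]", *tokens_b, "[SEP]"]
    type_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    return tokens, type_ids

tokens, type_ids = encode_pair(["what", "is", "it"], ["an", "answer"])
```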

@DarkLight1337
Member

DarkLight1337 commented Nov 17, 2024

I think it may be simpler to make this a separate flag, similar to how we have a flag for multimodal models, rather than creating a new task for it. That way, we won't have to change our internals at all.

@DarkLight1337
Member

DarkLight1337 commented Nov 17, 2024

I think having a separate API for this would be cleaner as well - perhaps a Scoring API where we output a single score?
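To make the suggestion concrete, here is one possible shape for such a Scoring API. The field names (text_1, text_2, score) and payload layout are purely illustrative assumptions, not something this PR implements:

```python
import json

# Hypothetical request/response shapes for a dedicated Scoring API.
# All names here are illustrative only.
request = {
    "model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
    "text_1": "What is the capital of France?",
    "text_2": "The capital of France is Paris.",
}
response = {
    "object": "list",
    "data": [{"index": 0, "object": "score", "score": 9.265625}],
}
payload = json.dumps(request)
```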

Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
@maxdebayser
Contributor Author

I've removed the "cross_encoding" task and added a is_cross_encoder property to the ModelConfig class. I've also added support for Roberta models.

Pending TODOs:

  • Add Scoring API
  • Support for cross encoding in the LLM class
  • Tests comparing with sentence-transformers
  • CPU support

Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
@maxdebayser
Contributor Author

I've added a score() method to the LLM class and added tests for it. I've also fixed the CPU support.

Pending TODOs:

  • Add Scoring API
  • Test Scoring API comparing with sentence-transformers
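To illustrate the signature under discussion, here is a hypothetical, model-free sketch of the pairing logic such a score() method might use. The broadcasting of a single string against a list is an assumption for illustration, and the scorer callable stands in for the actual cross-encoder forward pass:

```python
from typing import Callable, Sequence, Union

def score_pairs(
    texts: Union[str, Sequence[str]],
    text_pairs: Union[str, Sequence[str]],
    scorer: Callable[[str, str], float],
) -> list[float]:
    """Illustrative pairing logic: a single string on either side is
    broadcast against the sequence on the other side, then each pair is
    scored. In the real method, the model scores each (text, pair)."""
    if isinstance(texts, str):
        texts = [texts]
    if isinstance(text_pairs, str):
        text_pairs = [text_pairs]
    if len(texts) == 1:
        texts = list(texts) * len(text_pairs)
    if len(text_pairs) == 1:
        text_pairs = list(text_pairs) * len(texts)
    if len(texts) != len(text_pairs):
        raise ValueError("texts and text_pairs must have matching lengths")
    return [scorer(a, b) for a, b in zip(texts, text_pairs)]
```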

Comment on lines 809 to 810
texts: Union[SingletonPrompt, Sequence[SingletonPrompt]],
text_pairs: Union[SingletonPrompt, Sequence[SingletonPrompt]],
Member


Maybe text_1 and text_2? text_pairs is a bit confusing to me as it suggests that I should pass in a list of tuples.

Comment on lines +475 to +481
if (hasattr(config, "sbert_ce_default_activation_function")
and config.sbert_ce_default_activation_function is not None):
self.default_activation_function = import_from_string(
config.sbert_ce_default_activation_function)()
else:
self.default_activation_function = \
nn.Sigmoid() if config.num_labels == 1 else nn.Identity()
Member


Factor out this code to transformer_utils?
