Support Roberta embedding models #9387

Merged
26 commits merged on Nov 14, 2024
Commits (26)
f7e23fb - support head size 32 (maxdebayser, Oct 22, 2024)
10ebc9e - add support for Roberta models (maxdebayser, Oct 15, 2024)
b457cc5 - fix after refactoring (maxdebayser, Nov 11, 2024)
3fe28f6 - Review suggestions (flaviabeo, Nov 12, 2024)
5b75f4a - Merge branch 'upstream_main' into roberta (flaviabeo, Nov 12, 2024)
971acea - Fixes conflicts with new upstream changes (flaviabeo, Nov 12, 2024)
18a2d58 - Merge changes fixes (flaviabeo, Nov 12, 2024)
40ac579 - More fixed related to the upstream merge (flaviabeo, Nov 12, 2024)
e171896 - Adds test for roberta model executor (flaviabeo, Nov 12, 2024)
55912f9 - Asserts for Roberta models instance (flaviabeo, Nov 12, 2024)
6f06a76 - Fix space for linting (flaviabeo, Nov 12, 2024)
d4c8849 - Fix space for linting (flaviabeo, Nov 12, 2024)
b9e64b1 - Modifies test for multilingual-e5-large (flaviabeo, Nov 12, 2024)
366a992 - Fix linting in test (flaviabeo, Nov 13, 2024)
aed1216 - Merge branch 'upstream_main' into roberta (flaviabeo, Nov 13, 2024)
aae474e - trigger ci (flaviabeo, Nov 13, 2024)
07c931c - finish generalizing the Bert classes (maxdebayser, Nov 13, 2024)
4495a50 - Skips test for ROCm unsupported platform (flaviabeo, Nov 13, 2024)
49e8381 - fix roberta position_ids (maxdebayser, Nov 14, 2024)
1267bba - add assert to verify assumption (maxdebayser, Nov 14, 2024)
49cc57b - improve assert (maxdebayser, Nov 14, 2024)
0f334ae - add model to embedding test (maxdebayser, Nov 14, 2024)
f27aae1 - Remove encoder embedding model for compile test (maxdebayser, Nov 14, 2024)
44a9d22 - trigger ci (maxdebayser, Nov 14, 2024)
9f31bd5 - trigger ci (maxdebayser, Nov 14, 2024)
80ead23 - trigger ci (maxdebayser, Nov 14, 2024)
Files changed (showing changes from 4 commits)
csrc/attention/attention_kernels.cu (6 additions, 0 deletions)

@@ -739,6 +739,9 @@ void paged_attention_v1_launcher(
     // NOTE(woosuk): To reduce the compilation time, we only compile for the
     // head sizes that we use in the model. However, we can easily extend this
     // to support any head size which is a multiple of 16.
+    case 32:
+      LAUNCH_PAGED_ATTENTION_V1(32);
+      break;
     case 64:
       LAUNCH_PAGED_ATTENTION_V1(64);
       break;
@@ -903,6 +906,9 @@ void paged_attention_v2_launcher(
     // NOTE(woosuk): To reduce the compilation time, we only compile for the
     // head sizes that we use in the model. However, we can easily extend this
     // to support any head size which is a multiple of 16.
+    case 32:
+      LAUNCH_PAGED_ATTENTION_V2(32);
+      break;
     case 64:
       LAUNCH_PAGED_ATTENTION_V2(64);
       break;
csrc/cpu/attention.cpp (6 additions, 0 deletions)

@@ -385,6 +385,9 @@ void paged_attention_v1_impl_launcher(
   int* seq_lens_ptr = seq_lens.data_ptr<int>();

   switch (head_size) {
+    case 32:
+      LAUNCH_V1_ATTENTION_KERNEL(T, 32, BLOCK_SIZE);
+      break;
     case 64:
       LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
       break;
@@ -702,6 +705,9 @@ void paged_attention_v2_impl_launcher(
   int* seq_lens_ptr = seq_lens.data_ptr<int>();

   switch (head_size) {
+    case 32:
+      LAUNCH_V2_ATTENTION_KERNEL(T, 32, BLOCK_SIZE);
+      break;
     case 64:
       LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
       break;
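Head size 32 is added to both the CUDA and CPU paged-attention dispatch tables above. An encoder's per-head dimension is hidden_size // num_attention_heads, so small BERT/Roberta-family checkpoints with a 384-wide hidden state and 12 heads land on 32, which the kernels did not previously compile for. A minimal sketch of the arithmetic (the checkpoint name is illustrative, not taken from this PR):

```python
from transformers import AutoConfig

# Illustrative: a small encoder config with hidden_size=384 and 12 attention
# heads has a per-head dimension of 32, the new case in the dispatch above.
cfg = AutoConfig.from_pretrained("BAAI/bge-small-en-v1.5")
head_size = cfg.hidden_size // cfg.num_attention_heads
print(head_size)  # 32 for this config
```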
tests/models/embedding/language/test_embedding.py (2 additions, 2 deletions)

@@ -10,11 +10,11 @@
 MODELS = [
     "intfloat/e5-mistral-7b-instruct",
     "BAAI/bge-base-en-v1.5",
-    "BAAI/bge-multilingual-gemma2",
+    "BAAI/bge-multilingual-gemma2"
 ]

 ENCODER_ONLY = [
-    "BAAI/bge-base-en-v1.5",
+    "BAAI/bge-base-en-v1.5"
 ]

vllm/attention/ops/ipex_attn.py (1 addition, 1 deletion)

@@ -10,7 +10,7 @@ class PagedAttention:

     @staticmethod
     def get_supported_head_sizes() -> List[int]:
-        return [64, 80, 96, 112, 128, 256]
+        return [32, 64, 80, 96, 112, 128, 256]

     @staticmethod
     def get_kv_cache_shape(
vllm/attention/ops/paged_attn.py (1 addition, 1 deletion)

@@ -34,7 +34,7 @@ class PagedAttention:

     @staticmethod
     def get_supported_head_sizes() -> List[int]:
-        return [64, 80, 96, 112, 120, 128, 192, 256]
+        return [32, 64, 80, 96, 112, 120, 128, 192, 256]

     @staticmethod
     def get_kv_cache_shape(
vllm/model_executor/models/bert.py (9 additions, 2 deletions)

@@ -313,9 +313,10 @@ def __init__(self,
                  config: BertConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 embedding_class: type = BertEmbedding):
         super().__init__()
-        self.embeddings = BertEmbedding(config)
+        self.embeddings = embedding_class(config)
         self.encoder = BertEncoder(config,
                                    cache_config,
                                    quant_config,
@@ -422,3 +423,9 @@ def pooler(

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         self.model.load_weights(weights)
+
+    def _build_model(self,
+                     config: BertConfig,
+                     cache_config: Optional[CacheConfig] = None,
+                     quant_config: Optional[QuantizationConfig] = None):
+        return BertModel(config, cache_config, quant_config, embedding_class=BertEmbedding)
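The new embedding_class parameter and the _build_model hook make the Bert stack reusable by subclasses that only need to swap the embedding layer (presumably BertEmbeddingModel constructs its inner model through this hook). A minimal sketch of the pattern; the My* class names are hypothetical, only the vLLM Bert classes come from this PR:

```python
from vllm.model_executor.models.bert import (BertEmbedding, BertEmbeddingModel,
                                             BertModel)


class MyEmbedding(BertEmbedding):
    """Hypothetical embedding variant; override only what differs from Bert."""


class MyEmbeddingModel(BertEmbeddingModel):

    def _build_model(self, config, cache_config=None, quant_config=None):
        # Reuse the Bert encoder stack, but construct it with the custom
        # embedding class instead of the default BertEmbedding.
        return BertModel(config, cache_config, quant_config,
                         embedding_class=MyEmbedding)
```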
vllm/model_executor/models/registry.py (2 additions, 0 deletions)

@@ -94,6 +94,8 @@
 _EMBEDDING_MODELS = {
     # [Text-only]
     "BertModel": ("bert", "BertEmbeddingModel"),
+    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
+    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
     "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
     "Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
     "LlamaModel": ("llama", "LlamaEmbeddingModel"),
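With these registry entries, Hugging Face checkpoints whose config lists RobertaModel or XLMRobertaModel as the architecture resolve to RobertaEmbeddingModel. A minimal offline usage sketch, assuming vLLM's LLM.encode() API for embedding models; the checkpoint is the multilingual-e5-large model exercised in this PR's tests:

```python
from vllm import LLM

# intfloat/multilingual-e5-large is an XLM-Roberta checkpoint; the registry
# entries above route it to RobertaEmbeddingModel.
llm = LLM(model="intfloat/multilingual-e5-large")
outputs = llm.encode(["query: vLLM now supports Roberta embedding models"])
print(len(outputs[0].outputs.embedding))  # embedding dimension (1024 for this model)
```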
vllm/model_executor/models/roberta.py (new file, 85 additions)

@@ -0,0 +1,85 @@
from typing import Optional

import torch
from torch import nn
from transformers import RobertaConfig

from vllm.config import CacheConfig, PoolerConfig
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.models.bert import (BertEmbedding, BertEmbeddingModel,
                                             BertEncoder, BertModel)


class RobertaModel(BertModel):

    def __init__(
        self,
        config: RobertaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        # Skip BertModel.__init__()
        nn.Module.__init__(self)
        self.embeddings = RobertaEmbedding(config)
        self.encoder = BertEncoder(config, cache_config, quant_config)


class RobertaEmbedding(BertEmbedding):

    def __init__(self, config: RobertaConfig):
        # Skip BertEmbedding.__init__()
        nn.Module.__init__(self)
        self.size = config.hidden_size
        self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                      config.hidden_size)
        self.padding_idx = config.pad_token_id
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                config.hidden_size,
                                                padding_idx=self.padding_idx)

        self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
                                                  config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.position_ids = nn.Parameter(
            torch.empty((1, config.max_position_embeddings)), )

        self.position_embedding_type = config.position_embedding_type
        if self.position_embedding_type != "absolute":
            raise ValueError("Only 'absolute' position_embedding_type" +
                             " is supported")


class RobertaEmbeddingModel(BertEmbeddingModel):
    """A model that uses Roberta to provide embedding functionalities.

    This class encapsulates the RobertaModel and provides an interface for
    embedding operations and customized pooling functions.

    Attributes:
        model: An instance of RobertaModel used for forward operations.
        _pooler: An instance of Pooler used for pooling operations.
    """

    def __init__(self,
                 config: RobertaConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 pooler_config: Optional[PoolerConfig] = None) -> None:
        nn.Module.__init__(self)
        self.model = RobertaModel(config, cache_config, quant_config)
        self._pooler = Pooler.from_config_with_defaults(
            pooler_config,
            pooling_type=PoolingType.CLS,
            normalize=True,
            softmax=False)

    def _build_model(self,
                     config: RobertaConfig,
                     cache_config: Optional[CacheConfig] = None,
                     quant_config: Optional[QuantizationConfig] = None):
        return BertModel(config, cache_config, quant_config,
                         embedding_class=RobertaEmbedding)
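One Roberta-specific detail not yet visible in this 4-commit view (it arrives with the later "fix roberta position_ids" commit) is that Roberta's absolute position ids do not start at 0: non-padding tokens count from padding_idx + 1, and padding tokens keep position padding_idx. A sketch of that convention, following the Hugging Face Roberta implementation:

```python
import torch


def create_position_ids(input_ids: torch.Tensor,
                        padding_idx: int) -> torch.Tensor:
    # Non-padding tokens get consecutive positions starting at padding_idx + 1;
    # padding tokens are pinned to padding_idx, matching HF's
    # create_position_ids_from_input_ids.
    mask = input_ids.ne(padding_idx).int()
    return torch.cumsum(mask, dim=1) * mask + padding_idx


# Example: with padding_idx=1, token positions start at 2.
ids = torch.tensor([[0, 42, 77, 1, 1]])
print(create_position_ids(ids, padding_idx=1))  # tensor([[2, 3, 4, 1, 1]])
```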