Support Roberta embedding models #9387

Merged
merged 26 commits on Nov 14, 2024
Changes from 16 commits
Commits (26)
f7e23fb
support head size 32
maxdebayser Oct 22, 2024
10ebc9e
add support for Roberta models
maxdebayser Oct 15, 2024
b457cc5
fix after refactoring
maxdebayser Nov 11, 2024
3fe28f6
Review suggestions
flaviabeo Nov 12, 2024
5b75f4a
Merge branch 'upstream_main' into roberta
flaviabeo Nov 12, 2024
971acea
Fixes conflicts with new upstream changes
flaviabeo Nov 12, 2024
18a2d58
Merge changes fixes
flaviabeo Nov 12, 2024
40ac579
More fixes related to the upstream merge
flaviabeo Nov 12, 2024
e171896
Adds test for roberta model executor
flaviabeo Nov 12, 2024
55912f9
Asserts for Roberta model instance
flaviabeo Nov 12, 2024
6f06a76
Fix space for linting
flaviabeo Nov 12, 2024
d4c8849
Fix space for linting
flaviabeo Nov 12, 2024
b9e64b1
Modifies test for multilingual-e5-large
flaviabeo Nov 12, 2024
366a992
Fix linting in test
flaviabeo Nov 13, 2024
aed1216
Merge branch 'upstream_main' into roberta
flaviabeo Nov 13, 2024
aae474e
trigger ci
flaviabeo Nov 13, 2024
07c931c
finish generalizing the Bert classes
maxdebayser Nov 13, 2024
4495a50
Skips test on unsupported ROCm platform
flaviabeo Nov 13, 2024
49e8381
fix roberta position_ids
maxdebayser Nov 14, 2024
1267bba
add assert to verify assumption
maxdebayser Nov 14, 2024
49cc57b
improve assert
maxdebayser Nov 14, 2024
0f334ae
add model to embedding test
maxdebayser Nov 14, 2024
f27aae1
Remove encoder embedding model for compile test
maxdebayser Nov 14, 2024
44a9d22
trigger ci
maxdebayser Nov 14, 2024
9f31bd5
trigger ci
maxdebayser Nov 14, 2024
80ead23
trigger ci
maxdebayser Nov 14, 2024
3 changes: 3 additions & 0 deletions csrc/attention/paged_attention_v1.cu
@@ -98,6 +98,9 @@ void paged_attention_v1_launcher(
// NOTE(woosuk): To reduce the compilation time, we only compile for the
// head sizes that we use in the model. However, we can easily extend this
// to support any head size which is a multiple of 16.
case 32:
LAUNCH_PAGED_ATTENTION_V1(32);
break;
case 64:
LAUNCH_PAGED_ATTENTION_V1(64);
break;
3 changes: 3 additions & 0 deletions csrc/attention/paged_attention_v2.cu
@@ -104,6 +104,9 @@ void paged_attention_v2_launcher(
// NOTE(woosuk): To reduce the compilation time, we only compile for the
// head sizes that we use in the model. However, we can easily extend this
// to support any head size which is a multiple of 16.
case 32:
LAUNCH_PAGED_ATTENTION_V2(32);
break;
case 64:
LAUNCH_PAGED_ATTENTION_V2(64);
break;
6 changes: 6 additions & 0 deletions csrc/cpu/attention.cpp
@@ -385,6 +385,9 @@ void paged_attention_v1_impl_launcher(
int* seq_lens_ptr = seq_lens.data_ptr<int>();

switch (head_size) {
case 32:
LAUNCH_V1_ATTENTION_KERNEL(T, 32, BLOCK_SIZE);
break;
case 64:
LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break;
@@ -702,6 +705,9 @@ void paged_attention_v2_impl_launcher(
int* seq_lens_ptr = seq_lens.data_ptr<int>();

switch (head_size) {
case 32:
LAUNCH_V2_ATTENTION_KERNEL(T, 32, BLOCK_SIZE);
break;
case 64:
LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break;
9 changes: 9 additions & 0 deletions tests/compile/test_basic_correctness.py
@@ -62,6 +62,15 @@ class TestSetting:
method="encode",
fullgraph=True,
),
TestSetting(
model="intfloat/multilingual-e5-large",
model_args=["--task", "embedding"],
pp_size=1,
tp_size=1,
attn_backend="FLASHINFER",
method="encode",
fullgraph=True,
),
# vision language model
TestSetting(
model="microsoft/Phi-3.5-vision-instruct",
42 changes: 42 additions & 0 deletions tests/model_executor/test_model_load_with_params.py
@@ -4,12 +4,17 @@

from vllm.model_executor.layers.pooler import PoolingType
from vllm.model_executor.models.bert import BertEmbeddingModel
from vllm.model_executor.models.roberta import RobertaEmbeddingModel
from vllm.platforms import current_platform

MAX_MODEL_LEN = 128
MODEL_NAME = os.environ.get("MODEL_NAME", "BAAI/bge-base-en-v1.5")
REVISION = os.environ.get("REVISION", "main")

MODEL_NAME_ROBERTA = os.environ.get("MODEL_NAME",
"intfloat/multilingual-e5-large")
REVISION_ROBERTA = os.environ.get("REVISION", "main")


@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
@@ -48,3 +53,40 @@ def test_model_loading_with_params(vllm_runner):
assert model._pooler.normalize
# assert output
assert output


def test_roberta_model_loading_with_params(vllm_runner):
"""
Test parameter weight loading with tp>1.
"""
with vllm_runner(model_name=MODEL_NAME_ROBERTA,
revision=REVISION_ROBERTA,
dtype="float16",
max_model_len=MAX_MODEL_LEN) as model:
output = model.encode("Write a short story about a robot that"
" dreams for the first time.\n")

model_config = model.model.llm_engine.model_config

model_tokenizer = model.model.llm_engine.tokenizer

# asserts on the bert model config file
assert model_config.encoder_config["max_seq_length"] == 512
assert not model_config.encoder_config["do_lower_case"]

# asserts on the pooling config files
assert model_config.pooler_config.pooling_type == PoolingType.MEAN.name
assert model_config.pooler_config.pooling_norm

# asserts on the tokenizer loaded
assert model_tokenizer.tokenizer_id == "intfloat/multilingual-e5-large"
assert not model_tokenizer.tokenizer_config["do_lower_case"]

model = model.model.llm_engine.model_executor\
.driver_worker.model_runner.model
assert isinstance(model, RobertaEmbeddingModel)
assert model._pooler.pooling_type == PoolingType.MEAN
assert model._pooler.normalize

# assert output
assert output
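
Note: the values asserted in this test come from the checkpoint's sentence-transformers sidecar configs rather than from vLLM itself. A quick illustration of where they live, assuming the standard sentence-transformers file layout for this repository (this is not vLLM code):

# Where the asserted values come from: the sidecar configs shipped with the
# checkpoint, following the usual sentence-transformers layout.
import json
from huggingface_hub import hf_hub_download

repo = "intfloat/multilingual-e5-large"
with open(hf_hub_download(repo, "sentence_bert_config.json")) as f:
    sbert_cfg = json.load(f)
with open(hf_hub_download(repo, "1_Pooling/config.json")) as f:
    pooling_cfg = json.load(f)

print(sbert_cfg["max_seq_length"], sbert_cfg["do_lower_case"])  # 512 False
print(pooling_cfg["pooling_mode_mean_tokens"])                  # True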
2 changes: 1 addition & 1 deletion vllm/attention/ops/ipex_attn.py
@@ -10,7 +10,7 @@ class PagedAttention:

@staticmethod
def get_supported_head_sizes() -> List[int]:
- return [64, 80, 96, 112, 128, 256]
+ return [32, 64, 80, 96, 112, 128, 256]

@staticmethod
def get_kv_cache_shape(
2 changes: 1 addition & 1 deletion vllm/attention/ops/paged_attn.py
@@ -34,7 +34,7 @@ class PagedAttention:

@staticmethod
def get_supported_head_sizes() -> List[int]:
- return [64, 80, 96, 112, 120, 128, 192, 256]
+ return [32, 64, 80, 96, 112, 120, 128, 192, 256]

@staticmethod
def get_kv_cache_shape(
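
Note: the kernel changes above and the two extended get_supported_head_sizes() lists go together. The paged-attention launchers only compile an explicit set of head sizes, so an encoder whose per-head dimension works out to 32 needs both the new case 32 branches and the Python-side whitelist entries. The head size is simply hidden_size // num_attention_heads; a rough illustration (the checkpoint named here is only an example of a config that yields head size 32, not necessarily one exercised by this PR):

# Illustration of how head size is derived from a model config.
# A hidden size of 384 split over 12 attention heads gives 32.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")  # example only
print(cfg.hidden_size // cfg.num_attention_heads)  # 384 // 12 == 32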
14 changes: 10 additions & 4 deletions vllm/model_executor/models/bert.py
@@ -305,14 +305,16 @@ def forward(self, hidden_states: torch.Tensor,

class BertModel(nn.Module):

- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ def __init__(self,
+     *,
+     vllm_config: VllmConfig,
+     prefix: str = "",
+     embedding_class: type = BertEmbedding):
super().__init__()

config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config

- self.embeddings = BertEmbedding(config)
+ self.embeddings = embedding_class(config)
self.encoder = BertEncoder(config,
cache_config,
quant_config,
@@ -415,3 +417,7 @@ def pooler(

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self.model.load_weights(weights)

def _build_model(self, vllm_config: VllmConfig):
return BertModel(vllm_config=vllm_config,
embedding_class=BertEmbedding)
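
Note: parameterizing BertModel on embedding_class and routing construction through the _build_model() hook is what lets the RoBERTa classes below reuse the whole BERT stack. A standalone sketch of the pattern with placeholder class names (not the actual vLLM classes), assuming BertEmbeddingModel builds its backbone via _build_model():

# Sketch of the override hook this diff introduces: the base embedding model
# constructs its backbone through _build_model(), so a subclass swaps the
# embedding implementation by overriding only that one method.
class DefaultEmbedding:
    pass

class AlternateEmbedding:
    pass

class Backbone:
    def __init__(self, embedding_class=DefaultEmbedding):
        self.embeddings = embedding_class()

class BaseEmbeddingModel:
    def __init__(self):
        self.model = self._build_model()

    def _build_model(self):
        return Backbone(embedding_class=DefaultEmbedding)

class AlternateEmbeddingModel(BaseEmbeddingModel):
    def _build_model(self):
        return Backbone(embedding_class=AlternateEmbedding)

assert isinstance(AlternateEmbeddingModel().model.embeddings, AlternateEmbedding)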
2 changes: 2 additions & 0 deletions vllm/model_executor/models/registry.py
@@ -94,6 +94,8 @@
_EMBEDDING_MODELS = {
# [Text-only]
"BertModel": ("bert", "BertEmbeddingModel"),
"RobertaModel": ("roberta", "RobertaEmbeddingModel"),
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
"LlamaModel": ("llama", "LlamaEmbeddingModel"),
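
Note: with these two registry entries, RobertaModel and XLMRobertaModel architectures resolve to the RobertaEmbeddingModel defined in the new file below. A minimal offline-usage sketch based on the test settings in this PR (the exact LLM constructor arguments may vary between vLLM versions):

# Hedged example: serve an XLM-RoBERTa embedding checkpoint offline.
# Model name and arguments mirror the tests added in this PR.
from vllm import LLM

llm = LLM(model="intfloat/multilingual-e5-large",
          task="embedding",
          dtype="float16",
          max_model_len=128)
outputs = llm.encode(["Write a short story about a robot that dreams"
                      " for the first time."])
print(len(outputs[0].outputs.embedding))  # 1024 for this checkpoint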
73 changes: 73 additions & 0 deletions vllm/model_executor/models/roberta.py
@@ -0,0 +1,73 @@
import torch
from torch import nn
from transformers import RobertaConfig

from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.models.bert import (BertEmbedding, BertEmbeddingModel,
BertEncoder, BertModel)


class RobertaModel(BertModel):

def __init__(self, vllm_config: VllmConfig):
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.embeddings = RobertaEmbedding(config)
self.encoder = BertEncoder(config, cache_config, quant_config)


class RobertaEmbedding(BertEmbedding):

def __init__(self, config: RobertaConfig):
# Skip BertEmbedding.__init__()
nn.Module.__init__(self)
self.size = config.hidden_size
self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.padding_idx = config.pad_token_id
self.position_embeddings = nn.Embedding(config.max_position_embeddings,
config.hidden_size,
padding_idx=self.padding_idx)

self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.position_ids = nn.Parameter(
torch.empty((1, config.max_position_embeddings)), )

self.position_embedding_type = config.position_embedding_type
if self.position_embedding_type != "absolute":
raise ValueError("Only 'absolute' position_embedding_type" +
" is supported")


class RobertaEmbeddingModel(BertEmbeddingModel):
"""A model that uses Roberta to provide embedding functionalities.

This class encapsulates the RobertaModel and provides an interface for
embedding operations and customized pooling functions.

Attributes:
model: An instance of RobertaModel used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""

def __init__(self, *, vllm_config: VllmConfig) -> None:
nn.Module.__init__(self)
pooler_config = vllm_config.model_config.pooler_config
self.model = RobertaModel(vllm_config=vllm_config)
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.CLS,
normalize=True,
softmax=False)

def _build_model(self, vllm_config: VllmConfig):
return BertModel(vllm_config=vllm_config,
embedding_class=RobertaEmbedding)
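
Note: RobertaEmbedding reuses BERT's embedding layout with absolute position embeddings, but RoBERTa checkpoints number positions differently from BERT: positions are counted only over non-padding tokens and offset past the padding index, which is why their max_position_embeddings is 514 rather than 512. The later commits in this PR ("fix roberta position_ids" and the assert commits) deal with exactly this. A simplified sketch of the convention as used by Hugging Face RoBERTa (ignoring past key values; illustrative, not the vLLM implementation):

# Position ids in RoBERTa: count non-padding tokens (1-based) and add
# padding_idx, so real tokens never receive the padding position.
import torch

def roberta_position_ids(input_ids: torch.Tensor, padding_idx: int) -> torch.Tensor:
    mask = input_ids.ne(padding_idx).int()
    incremental = torch.cumsum(mask, dim=1) * mask
    return incremental.long() + padding_idx

ids = torch.tensor([[0, 31414, 232, 2, 1, 1]])  # <s> ... </s> <pad> <pad>
print(roberta_position_ids(ids, padding_idx=1))
# tensor([[2, 3, 4, 5, 1, 1]])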