diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu index 8b99f0843aaf6..741cd0c82dc89 100644 --- a/csrc/attention/paged_attention_v1.cu +++ b/csrc/attention/paged_attention_v1.cu @@ -98,6 +98,9 @@ void paged_attention_v1_launcher( // NOTE(woosuk): To reduce the compilation time, we only compile for the // head sizes that we use in the model. However, we can easily extend this // to support any head size which is a multiple of 16. + case 32: + LAUNCH_PAGED_ATTENTION_V1(32); + break; case 64: LAUNCH_PAGED_ATTENTION_V1(64); break; diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu index 3a7a9dee916aa..6de8d0bdd5b8d 100644 --- a/csrc/attention/paged_attention_v2.cu +++ b/csrc/attention/paged_attention_v2.cu @@ -104,6 +104,9 @@ void paged_attention_v2_launcher( // NOTE(woosuk): To reduce the compilation time, we only compile for the // head sizes that we use in the model. However, we can easily extend this // to support any head size which is a multiple of 16. + case 32: + LAUNCH_PAGED_ATTENTION_V2(32); + break; case 64: LAUNCH_PAGED_ATTENTION_V2(64); break; diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index e3953c7c45719..e73eca1b345fd 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -385,6 +385,9 @@ void paged_attention_v1_impl_launcher( int* seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { + case 32: + LAUNCH_V1_ATTENTION_KERNEL(T, 32, BLOCK_SIZE); + break; case 64: LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); break; @@ -702,6 +705,9 @@ void paged_attention_v2_impl_launcher( int* seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { + case 32: + LAUNCH_V2_ATTENTION_KERNEL(T, 32, BLOCK_SIZE); + break; case 64: LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); break; diff --git a/tests/model_executor/test_model_load_with_params.py b/tests/model_executor/test_model_load_with_params.py index 7e5e2780d3916..ed321ba9f00c1 100644 --- a/tests/model_executor/test_model_load_with_params.py +++ b/tests/model_executor/test_model_load_with_params.py @@ -4,12 +4,17 @@ from vllm.model_executor.layers.pooler import PoolingType from vllm.model_executor.models.bert import BertEmbeddingModel +from vllm.model_executor.models.roberta import RobertaEmbeddingModel from vllm.platforms import current_platform MAX_MODEL_LEN = 128 MODEL_NAME = os.environ.get("MODEL_NAME", "BAAI/bge-base-en-v1.5") REVISION = os.environ.get("REVISION", "main") +MODEL_NAME_ROBERTA = os.environ.get("MODEL_NAME", + "intfloat/multilingual-e5-large") +REVISION_ROBERTA = os.environ.get("REVISION", "main") + @pytest.mark.skipif(current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm.") @@ -48,3 +53,42 @@ def test_model_loading_with_params(vllm_runner): assert model._pooler.normalize # assert output assert output + + +@pytest.mark.skipif(current_platform.is_rocm(), + reason="Xformers backend is not supported on ROCm.") +def test_roberta_model_loading_with_params(vllm_runner): + """ + Test parameter weight loading with tp>1. + """ + with vllm_runner(model_name=MODEL_NAME_ROBERTA, + revision=REVISION_ROBERTA, + dtype="float16", + max_model_len=MAX_MODEL_LEN) as model: + output = model.encode("Write a short story about a robot that" + " dreams for the first time.\n") + + model_config = model.model.llm_engine.model_config + + model_tokenizer = model.model.llm_engine.tokenizer + + # asserts on the bert model config file + assert model_config.encoder_config["max_seq_length"] == 512 + assert not model_config.encoder_config["do_lower_case"] + + # asserts on the pooling config files + assert model_config.pooler_config.pooling_type == PoolingType.MEAN.name + assert model_config.pooler_config.pooling_norm + + # asserts on the tokenizer loaded + assert model_tokenizer.tokenizer_id == "intfloat/multilingual-e5-large" + assert not model_tokenizer.tokenizer_config["do_lower_case"] + + model = model.model.llm_engine.model_executor\ + .driver_worker.model_runner.model + assert isinstance(model, RobertaEmbeddingModel) + assert model._pooler.pooling_type == PoolingType.MEAN + assert model._pooler.normalize + + # assert output + assert output diff --git a/tests/models/embedding/language/test_embedding.py b/tests/models/embedding/language/test_embedding.py index cd920aec6502e..fcdd684168d04 100644 --- a/tests/models/embedding/language/test_embedding.py +++ b/tests/models/embedding/language/test_embedding.py @@ -13,10 +13,12 @@ "intfloat/e5-mistral-7b-instruct", "BAAI/bge-base-en-v1.5", "BAAI/bge-multilingual-gemma2", + "intfloat/multilingual-e5-large", ] ENCODER_ONLY = [ "BAAI/bge-base-en-v1.5", + "intfloat/multilingual-e5-large", ] diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py index 6b270ffd5bc00..8df6d4ced9dc6 100644 --- a/vllm/attention/ops/ipex_attn.py +++ b/vllm/attention/ops/ipex_attn.py @@ -10,7 +10,7 @@ class PagedAttention: @staticmethod def get_supported_head_sizes() -> List[int]: - return [64, 80, 96, 112, 128, 256] + return [32, 64, 80, 96, 112, 128, 256] @staticmethod def get_kv_cache_shape( diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 92023d5b75f5a..076f151ffcb61 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -34,7 +34,7 @@ class PagedAttention: @staticmethod def get_supported_head_sizes() -> List[int]: - return [64, 80, 96, 112, 120, 128, 192, 256] + return [32, 64, 80, 96, 112, 120, 128, 192, 256] @staticmethod def get_kv_cache_shape( diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 7dbc7fa0aaba4..42dd6119e76f1 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -5,7 +5,7 @@ from transformers import BertConfig from vllm.attention import Attention, AttentionMetadata, AttentionType -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, PoolerConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -305,14 +305,16 @@ def forward(self, hidden_states: torch.Tensor, class BertModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + embedding_class: type = BertEmbedding): super().__init__() - config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - - self.embeddings = BertEmbedding(config) + self.embeddings = embedding_class(config) self.encoder = BertEncoder(config, cache_config, quant_config, @@ -382,13 +384,9 @@ class BertEmbeddingModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() pooler_config = vllm_config.model_config.pooler_config - self.model = BertModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.CLS, - normalize=True, - softmax=False) + self.model = self._build_model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self._pooler = self._build_pooler(pooler_config) def forward( self, @@ -415,3 +413,16 @@ def pooler( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): self.model.load_weights(weights) + + def _build_model(self, + vllm_config: VllmConfig, + prefix: str = "") -> BertModel: + return BertModel(vllm_config=vllm_config, + prefix=prefix, + embedding_class=BertEmbedding) + + def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: + return Pooler.from_config_with_defaults(pooler_config, + pooling_type=PoolingType.CLS, + normalize=True, + softmax=False) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index f172c06c4a26a..f22d1b04ebf09 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -94,6 +94,8 @@ _EMBEDDING_MODELS = { # [Text-only] "BertModel": ("bert", "BertEmbeddingModel"), + "RobertaModel": ("roberta", "RobertaEmbeddingModel"), + "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"), "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), "Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"), "LlamaModel": ("llama", "LlamaEmbeddingModel"), diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py new file mode 100644 index 0000000000000..c1dcdd36ec3de --- /dev/null +++ b/vllm/model_executor/models/roberta.py @@ -0,0 +1,117 @@ +from typing import List, Optional + +import torch +from torch import nn +from transformers import RobertaConfig + +from vllm.attention import AttentionMetadata +from vllm.config import VllmConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel +from vllm.sequence import IntermediateTensors + + +class RobertaEmbedding(nn.Module): + + def __init__(self, config: RobertaConfig): + super().__init__() + self.size = config.hidden_size + self.word_embeddings = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.padding_idx = config.pad_token_id + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size, + padding_idx=self.padding_idx) + + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.position_ids = nn.Parameter( + torch.empty((1, config.max_position_embeddings)), ) + + self.position_embedding_type = config.position_embedding_type + if self.position_embedding_type != "absolute": + raise ValueError("Only 'absolute' position_embedding_type" + + " is supported") + + def forward( + self, + input_ids: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + input_shape = input_ids.size() + + # Input embeddings. + inputs_embeds = self.word_embeddings(input_ids) + + # TODO: figure out if there is a better way + # to make to make position ids start at padding_idx + 1 + # References: + # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133 + # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669 + position_ids += self.padding_idx + 1 + + # Position embeddings. + position_embeddings = self.position_embeddings(position_ids) + + # Token type embeddings. (TODO: move off hotpath?) + token_type_embeddings = self.token_type_embeddings( + torch.zeros(input_shape, + dtype=torch.long, + device=inputs_embeds.device)) + + embeddings = inputs_embeds + token_type_embeddings + position_embeddings + embeddings = self.LayerNorm(embeddings) + return embeddings + + +class RobertaEmbeddingModel(BertEmbeddingModel): + """A model that uses Roberta to provide embedding functionalities. + + This class encapsulates the BertModel and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + model: An instance of BertModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. + """ + + def _build_model(self, + vllm_config: VllmConfig, + prefix: str = "") -> BertModel: + return BertModel(vllm_config=vllm_config, + prefix=prefix, + embedding_class=RobertaEmbedding) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + # Verify assumption that position are always a sequence from + # 0 to N. (Actually here we just check 0 and N to simplify). + # This is important to fix the position which are assumed to + # start from padding_idx + 1 instead of 0 in the Roberta models. + assert hasattr(attn_metadata, "seq_lens_tensor") + cumulative = attn_metadata.seq_lens_tensor.cumsum(dim=0) + start_pos = torch.cat( + (torch.tensor([0], device=attn_metadata.seq_lens_tensor.device), + cumulative[:-1])) + assert len(torch.nonzero(positions[start_pos])) == 0 + end_pos = cumulative - 1 + last_tokens = attn_metadata.seq_lens_tensor - 1 + assert len(torch.nonzero(positions[end_pos] - last_tokens)) == 0 + + return super().forward(input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds)