diff --git a/vllm/hpu/rotary_embed.py b/vllm/hpu/rotary_embed.py index 30a88d68a24af..4b91dd79bda02 100644 --- a/vllm/hpu/rotary_embed.py +++ b/vllm/hpu/rotary_embed.py @@ -107,6 +107,11 @@ def forward(self, positions: torch.Tensor, query: torch.Tensor, else: cos = cos[positions].unsqueeze(2) sin = sin[positions].unsqueeze(2) + if self.dim != self.head_size: + assert (self.head_size % self.dim) == 0 + num = self.head_size // self.dim + sin = sin.repeat(1,1,1,num) + cos = cos.repeat(1,1,1,num) query, key = FusedRoPE.apply(query, cos, sin, 0), FusedRoPE.apply(key, cos, sin, 0) return query.reshape(