Skip to content

Commit

Permalink
Fix Qwen2 OOM
Browse files: browse the repository at this point in the history.
  • Loading branch information
shepark committed Aug 30, 2024
1 parent 17cd625 commit 58777a3
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions vllm/model_executor/models/qwen2.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -45,6 +45,7 @@
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.platforms import current_platform
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput

Expand Down Expand Up @@ -259,6 +260,11 @@ def forward(
hidden_states = inputs_embeds
else:
hidden_states = self.embed_tokens(input_ids)

if current_platform.is_hpu():
import habana_frameworks.torch as htorch
htorch.core.mark_step()

residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
Expand All @@ -269,6 +275,9 @@ def forward(
attn_metadata,
residual,
)
if current_platform.is_hpu():
htorch.core.mark_step()

hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states

Expand Down

0 comments on commit 58777a3

Please sign in to comment.