From da3f5c927e3e047b13e31975da05b64ebd29f168 Mon Sep 17 00:00:00 2001 From: Libin Tang Date: Sun, 18 Aug 2024 23:59:31 -0700 Subject: [PATCH] split gpt_bigcode forward to small graph to avoid gc issue --- vllm/model_executor/models/gpt_bigcode.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index fc4e13bbb0e68..3ae3c8c8f712c 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -39,6 +39,7 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA @@ -224,9 +225,14 @@ def forward( position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds + if current_platform.is_hpu(): + import habana_frameworks.torch as htorch + htorch.core.mark_step() for i in range(len(self.h)): layer = self.h[i] hidden_states = layer(hidden_states, kv_caches[i], attn_metadata) + if current_platform.is_hpu(): + htorch.core.mark_step() hidden_states = self.ln_f(hidden_states) return hidden_states