Skip to content

Commit

Permalink
address comments
Browse files (browse the repository at this point in the history)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
  • Loading branch information
jikunshang committed Nov 13, 2024
1 parent 4cd3598 commit 7e1e8a0
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions vllm/worker/hpu_model_runner.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -274,14 +274,13 @@ def precompute_indices_and_offsets(block_size, slot_mapping, is_prompt):

def modify_decoder_layer(module: torch.nn.Module, suffix="DecoderLayer"):
if module.__class__.__name__.endswith(suffix):
module.original_forward = module.forward

def new_forward(self, *args, **kwargs):
ret = self.original_forward(*args, **kwargs)
def forward_hook(module, args, output):
htorch.core.mark_step()
return ret
return output

module.register_forward_hook(forward_hook)

module.forward = new_forward.__get__(module)
for child_name, child_module in module.named_children():
modify_decoder_layer(child_module)

Expand Down

0 comments on commit 7e1e8a0

Please sign in to comment.