Commit: Fix formatting errors

Tiefen-boop committed Aug 29, 2024
1 parent f2710c9 commit 53cdd9b
Showing 3 changed files with 21 additions and 13 deletions.
14 changes: 10 additions & 4 deletions vllm/hpu/ops.py
@@ -254,15 +254,19 @@ def forward(self, state):
 class StaticFusedMOE(torch.nn.Module):
     def __init__(self, num_total_experts):
         super().__init__()
-        self.w13_list = torch.nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)])
-        self.w2_list = torch.nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)])
+        self.w13_list = torch.nn.ModuleList(
+            [MoeMatmul() for _ in range(num_total_experts)])
+        self.w2_list = torch.nn.ModuleList(
+            [MoeMatmul() for _ in range(num_total_experts)])
         self.num_total_experts = num_total_experts

     def forward(self, hidden_states, w1, w2, score, topk):
         B, D = hidden_states.shape
         routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
-        routing_weights, selected_experts = torch.topk(routing_weights, topk, dim=-1)
+        routing_weights, selected_experts = torch.topk(routing_weights,
+                                                       topk,
+                                                       dim=-1)
         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
         routing_weights = routing_weights.to(hidden_states.dtype)
         final_hidden_states = torch.zeros((1, B, D),
@@ -278,7 +282,9 @@ def forward(self, hidden_states, w1, w2, score, topk):

         for expert_idx in range(self.num_total_experts):
             padded_weight = padded_weights[expert_idx]
-            w_output = self.w13_list[expert_idx].calc(hidden_states, expert_idx, w1)
+            w_output = self.w13_list[expert_idx].calc(hidden_states,
+                                                      expert_idx,
+                                                      w1)
             w_output = silu_and_mul(w_output)
             w_output = self.w2_list[expert_idx].calc(w_output, expert_idx, w2)
             final_hidden_states += w_output * padded_weight
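For context, a minimal, self-contained PyTorch sketch of the computation StaticFusedMOE.forward implements. The padded_weights construction sits in lines this diff elides, so the scatter-based version below is an assumption inferred from the visible indexing, and static_fused_moe_reference is a name invented for this sketch, not vLLM API:

import torch
import torch.nn.functional as F

def static_fused_moe_reference(hidden_states, w1, w2, score, topk):
    # hidden_states: (B, D); w1: (E, 2*I, D); w2: (E, D, I); score: (B, E)
    B, D = hidden_states.shape
    num_experts = w1.shape[0]
    routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
    routing_weights, selected_experts = torch.topk(routing_weights, topk, dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    routing_weights = routing_weights.to(hidden_states.dtype)

    # "Static" trick: scatter the top-k weights into a dense (E, B, 1)
    # tensor so every expert processes the whole batch with fixed shapes;
    # tokens not routed to an expert get weight 0 and drop out of the sum.
    padded_weights = torch.zeros((B, num_experts), dtype=hidden_states.dtype)
    padded_weights.scatter_(1, selected_experts, routing_weights)
    padded_weights = padded_weights.t().unsqueeze(-1)  # (E, B, 1)

    final_hidden_states = torch.zeros((B, D), dtype=hidden_states.dtype)
    for expert_idx in range(num_experts):
        up = hidden_states @ w1[expert_idx].t()   # w13 projection, (B, 2*I)
        gate, y = up.chunk(2, dim=-1)
        act = F.silu(gate) * y                    # silu_and_mul
        down = act @ w2[expert_idx].t()           # w2 projection, (B, D)
        final_hidden_states += down * padded_weights[expert_idx]
    return final_hidden_states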
8 changes: 5 additions & 3 deletions vllm/model_executor/layers/fused_moe/layer.py
Expand Up @@ -13,7 +13,6 @@
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import is_hpu


logger = init_logger(__name__)


@@ -76,7 +75,8 @@ def apply(
     ) -> torch.Tensor:
         return self.forward(x, layer.w13_weight, layer.w2_weight,
                             router_logits, top_k, renormalize,
-                            use_grouped_topk, num_expert_group, topk_group, layer)
+                            use_grouped_topk, num_expert_group, topk_group,
+                            layer)

     def forward_cuda(
         self,
@@ -106,11 +106,13 @@ def forward_cuda(
     def forward_hpu(self, x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
                     router_logits: torch.Tensor, top_k: int, renormalize: bool,
                     use_grouped_topk: bool, num_expert_group: Optional[int],
-                    topk_group: Optional[int], layer: Optional[torch.nn.Module],):
+                    topk_group: Optional[int], layer: Optional[torch.nn.Module]
+                    ):
         assert not use_grouped_topk, 'use_grouped_topk must be False on HPU'
         assert num_expert_group is None, ('num_expert_group is '
                                           'not supported on HPU')
         assert topk_group is None, 'topk_group is not supported on HPU'
+        assert layer is not None, 'layer has to be provided on HPU'
         return layer.hpu_static_fused_moe(x, w1, w2, router_logits, top_k)

     def forward_cpu(self, *args, **kwargs):
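The new layer argument and assert exist because, unlike forward_cuda, the HPU path delegates to a stateful module stored on the layer rather than to a free-standing kernel. A hedged toy invocation follows; ToyMoELayer and all shapes are illustrative assumptions, and the import assumes vllm/hpu/ops.py from the first file in this commit:

import torch
from vllm.hpu.ops import StaticFusedMOE  # assumed importable per this diff

E, B, D, I, topk = 4, 2, 8, 16, 2

class ToyMoELayer(torch.nn.Module):
    # Minimal stand-in for vLLM's FusedMoE layer: only the attributes
    # that forward_hpu actually touches.
    def __init__(self):
        super().__init__()
        self.w13_weight = torch.nn.Parameter(torch.randn(E, 2 * I, D))
        self.w2_weight = torch.nn.Parameter(torch.randn(E, D, I))
        self.hpu_static_fused_moe = StaticFusedMOE(E)

layer = ToyMoELayer()
x = torch.randn(B, D)
router_logits = torch.randn(B, E)
# Mirrors the delegation in forward_hpu; output has hidden width D.
out = layer.hpu_static_fused_moe(x, layer.w13_weight, layer.w2_weight,
                                 router_logits, topk)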
12 changes: 6 additions & 6 deletions vllm/model_executor/layers/quantization/inc.py
@@ -1,14 +1,13 @@
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional

 import torch
 from torch.nn import Module
-from torch.nn.parameter import Parameter
 import torch.nn.functional as F
+from torch.nn.parameter import Parameter

 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.layer import (
+    FusedMoE, UnquantizedFusedMoEMethod)
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE, UnquantizedFusedMoEMethod
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.utils import set_weight_attrs
@@ -88,7 +87,8 @@ class INCLinearMethod(LinearMethodBase):
         quant_config: The quantization config.
     """

-    def __init__(self, quant_config: INCConfig, separate_bias_add: bool = False):
+    def __init__(self, quant_config: INCConfig,
+                 separate_bias_add: bool = False):
         self.separate_bias_add = separate_bias_add
         self.quant_config = quant_config

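The reflowed __init__ above keeps the separate_bias_add flag. Its apply() body lies outside this diff, so the following is only a hedged sketch of the pattern that flag conventionally selects in vLLM linear methods (apply_linear_sketch is an invented name):

import torch
import torch.nn.functional as F

def apply_linear_sketch(x, weight, bias=None, separate_bias_add=False):
    if separate_bias_add:
        # Run the matmul bias-free and add the bias as a separate op;
        # some backends schedule or fuse the two steps differently.
        out = F.linear(x, weight)
        return out + bias if bias is not None else out
    # Default: let F.linear apply the bias inside the fused kernel.
    return F.linear(x, weight, bias)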
