From 972f3bc8f0a1a11ab84a0edc59bc9e009e29d003 Mon Sep 17 00:00:00 2001 From: "Chendi.Xue" Date: Thu, 29 Aug 2024 00:42:19 +0000 Subject: [PATCH 1/4] remove arctic gpu hardcode Signed-off-by: Chendi.Xue --- vllm/model_executor/models/arctic.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 49e57a847e847..6d92e7597eabf 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -131,14 +131,14 @@ def __init__(self, torch.empty(self.num_experts, 2 * self.intermediate_size, self.hidden_size, - device="cuda", - dtype=self.params_dtype)) + dtype=self.params_dtype), + , requires_grad=False) self.w2s = nn.Parameter( torch.empty(self.num_experts, self.hidden_size, self.intermediate_size, - device="cuda", - dtype=self.params_dtype)) + dtype=self.params_dtype), + requires_grad=False) set_weight_attrs(self.ws, { "weight_loader": self.weight_loader, }) From 778d7e64dcaf2728e9688b1c8d18bed600dab243 Mon Sep 17 00:00:00 2001 From: "Chendi.Xue" Date: Thu, 29 Aug 2024 00:42:50 +0000 Subject: [PATCH 2/4] remove dbrx gpu hardcode Signed-off-by: Chendi.Xue --- vllm/model_executor/models/dbrx.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index d758333b22388..463003d0bba7b 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -86,17 +86,15 @@ def __init__( self.num_total_experts, 2 * self.intermediate_size, self.d_model, - device="cuda", dtype=self.params_dtype, - )) + ), requires_grad=False) self.w2s = nn.Parameter( torch.empty( self.num_total_experts, self.d_model, self.intermediate_size, - device="cuda", dtype=self.params_dtype, - )) + ), requires_grad=False) set_weight_attrs( self.ws, From fb98cad144e9654abcc698c4b56d793d1d56cce7 Mon Sep 17 00:00:00 2001 From: "Chendi.Xue" Date: Tue, 3 Sep 2024 16:30:17 +0000 Subject: [PATCH 3/4] Remove requires_grad=False Signed-off-by: Chendi.Xue --- vllm/model_executor/models/arctic.py | 6 ++---- vllm/model_executor/models/dbrx.py | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 6d92e7597eabf..603579d41946e 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -131,14 +131,12 @@ def __init__(self, torch.empty(self.num_experts, 2 * self.intermediate_size, self.hidden_size, - dtype=self.params_dtype), - , requires_grad=False) + dtype=self.params_dtype)) self.w2s = nn.Parameter( torch.empty(self.num_experts, self.hidden_size, self.intermediate_size, - dtype=self.params_dtype), - requires_grad=False) + dtype=self.params_dtype)) set_weight_attrs(self.ws, { "weight_loader": self.weight_loader, }) diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 463003d0bba7b..e3a45b26d909b 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -86,15 +86,13 @@ def __init__( self.num_total_experts, 2 * self.intermediate_size, self.d_model, - dtype=self.params_dtype, - ), requires_grad=False) + dtype=self.params_dtype)) self.w2s = nn.Parameter( torch.empty( self.num_total_experts, self.d_model, self.intermediate_size, - dtype=self.params_dtype, - ), requires_grad=False) + dtype=self.params_dtype)) set_weight_attrs( self.ws, From 046cb25a4a549f985105152cb3dec2c25279252e Mon Sep 17 00:00:00 2001 From: "Chendi.Xue" Date: Thu, 12 Sep 2024 15:23:51 +0000 Subject: [PATCH 4/4] Fix yapf detected format issue Signed-off-by: Chendi.Xue --- vllm/model_executor/models/dbrx.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index e3a45b26d909b..71362299a9fcf 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -82,17 +82,15 @@ def __init__( self.router = DbrxRouter(config, self.params_dtype) self.ws = nn.Parameter( - torch.empty( - self.num_total_experts, - 2 * self.intermediate_size, - self.d_model, - dtype=self.params_dtype)) + torch.empty(self.num_total_experts, + 2 * self.intermediate_size, + self.d_model, + dtype=self.params_dtype)) self.w2s = nn.Parameter( - torch.empty( - self.num_total_experts, - self.d_model, - self.intermediate_size, - dtype=self.params_dtype)) + torch.empty(self.num_total_experts, + self.d_model, + self.intermediate_size, + dtype=self.params_dtype)) set_weight_attrs( self.ws,