diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 49e57a847e847..603579d41946e 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -131,13 +131,11 @@ def __init__(self, torch.empty(self.num_experts, 2 * self.intermediate_size, self.hidden_size, - device="cuda", dtype=self.params_dtype)) self.w2s = nn.Parameter( torch.empty(self.num_experts, self.hidden_size, self.intermediate_size, - device="cuda", dtype=self.params_dtype)) set_weight_attrs(self.ws, { "weight_loader": self.weight_loader, diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index d758333b22388..71362299a9fcf 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -82,21 +82,15 @@ def __init__( self.router = DbrxRouter(config, self.params_dtype) self.ws = nn.Parameter( - torch.empty( - self.num_total_experts, - 2 * self.intermediate_size, - self.d_model, - device="cuda", - dtype=self.params_dtype, - )) + torch.empty(self.num_total_experts, + 2 * self.intermediate_size, + self.d_model, + dtype=self.params_dtype)) self.w2s = nn.Parameter( - torch.empty( - self.num_total_experts, - self.d_model, - self.intermediate_size, - device="cuda", - dtype=self.params_dtype, - )) + torch.empty(self.num_total_experts, + self.d_model, + self.intermediate_size, + dtype=self.params_dtype)) set_weight_attrs( self.ws,