diff --git a/vllm/config.py b/vllm/config.py index b8ec23e030ac..011563038e6b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -751,7 +751,8 @@ class LoadConfig: ignore_patterns: The list of patterns to ignore when loading the model. Default to "original/**/*" to avoid repeated loading of llama's checkpoints. - device: Device on which weights are loaded. + device: Device to which model weights will be loaded, default to + device_config.device """ load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ffe12d4cd5fb..84529b267ce0 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -268,8 +268,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument("--weights-load-device", type=str, default=EngineArgs.weights_load_device, - choices=["cuda", "neuron", "hpu", "cpu"], - help='Device on which weights are loaded.') + choices=DEVICE_OPTIONS, + help=('Device to which model weights ' + 'will be loaded.')) parser.add_argument( '--config-format', default=EngineArgs.config_format, @@ -843,7 +844,9 @@ def create_model_config(self) -> ModelConfig: mm_processor_kwargs=self.mm_processor_kwargs, ) - def create_load_config(self, load_device) -> LoadConfig: + def create_load_config(self, load_device=None) -> LoadConfig: + if load_device is None: + load_device = DeviceConfig(device=self.device).device return LoadConfig( load_format=self.load_format, download_dir=self.download_dir, diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index b3274b6d9511..fcff39f79056 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -363,7 +363,7 @@ def load_model(self, *, model_config: ModelConfig, model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config) - logger.info("Loading weights on %s ...", self.load_config.device) + logger.info("Loading weights on %s...", self.load_config.device) model.load_weights( self._get_weights_iterator(model_config.model, model_config.revision,