From 9111a8059b699344313f21a4314562d9405ec991 Mon Sep 17 00:00:00 2001
From: Konrad Zawora
Date: Tue, 24 Sep 2024 18:07:01 +0200
Subject: [PATCH] Make weights_load_device not change
 EngineArgs.create_load_config() (#336)

Some backends rely on calling EngineArgs.create_load_config() directly,
for which we've altered the API. We don't need to alter it to enable
weight load device functionality. This PR fixes it.
---
 vllm/config.py                             | 3 ++-
 vllm/engine/arg_utils.py                   | 9 ++++++---
 vllm/model_executor/model_loader/loader.py | 2 +-
 3 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index b8ec23e030ac..011563038e6b 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -751,7 +751,8 @@ class LoadConfig:
         ignore_patterns: The list of patterns to ignore when loading the model.
             Default to "original/**/*" to avoid repeated loading of llama's
             checkpoints.
-        device: Device on which weights are loaded.
+        device: Device to which model weights will be loaded, default to
+            device_config.device
     """
 
     load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index ffe12d4cd5fb..84529b267ce0 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -268,8 +268,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         parser.add_argument("--weights-load-device",
                             type=str,
                             default=EngineArgs.weights_load_device,
-                            choices=["cuda", "neuron", "hpu", "cpu"],
-                            help='Device on which weights are loaded.')
+                            choices=DEVICE_OPTIONS,
+                            help=('Device to which model weights '
+                                  'will be loaded.'))
         parser.add_argument(
             '--config-format',
             default=EngineArgs.config_format,
@@ -843,7 +844,9 @@ def create_model_config(self) -> ModelConfig:
             mm_processor_kwargs=self.mm_processor_kwargs,
         )
 
-    def create_load_config(self, load_device) -> LoadConfig:
+    def create_load_config(self, load_device=None) -> LoadConfig:
+        if load_device is None:
+            load_device = DeviceConfig(device=self.device).device
         return LoadConfig(
             load_format=self.load_format,
             download_dir=self.download_dir,
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index b3274b6d9511..fcff39f79056 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -363,7 +363,7 @@ def load_model(self, *, model_config: ModelConfig,
             model = _initialize_model(model_config, self.load_config,
                                       lora_config, cache_config,
                                       scheduler_config)
-            logger.info("Loading weights on %s ...", self.load_config.device)
+            logger.info("Loading weights on %s...", self.load_config.device)
             model.load_weights(
                 self._get_weights_iterator(model_config.model,
                                            model_config.revision,
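
For context, the call pattern this patch restores looks roughly like the sketch below. It mirrors the diff with simplified stand-in classes (the DeviceConfig, LoadConfig, and EngineArgs here are hypothetical minimal versions, not vLLM imports): callers that invoke create_load_config() with no argument keep working and get the device derived from --device, while the weights-load-device path may pass a device explicitly. This is a minimal illustration, not the actual vLLM implementation.

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class DeviceConfig:
        # Simplified stand-in: resolve "auto" to a concrete device name
        # instead of probing the hardware as vLLM does.
        device: str = "auto"

        def __post_init__(self):
            if self.device == "auto":
                self.device = "cpu"  # assumption: safe default for the sketch


    @dataclass
    class LoadConfig:
        # Simplified stand-in for the load configuration.
        load_format: str = "auto"
        device: Optional[str] = None


    @dataclass
    class EngineArgs:
        # Simplified stand-in carrying only the fields used here.
        device: str = "auto"
        weights_load_device: Optional[str] = None
        load_format: str = "auto"

        def create_load_config(self, load_device=None) -> LoadConfig:
            # Same default logic as the patch: when the caller omits the
            # device, fall back to the one derived from --device.
            if load_device is None:
                load_device = DeviceConfig(device=self.device).device
            return LoadConfig(load_format=self.load_format, device=load_device)


    if __name__ == "__main__":
        args = EngineArgs(device="auto", weights_load_device="cpu")

        # Old-style callers keep working: no argument needed.
        print(args.create_load_config().device)

        # The weights-load-device path passes the device explicitly.
        print(args.create_load_config(args.weights_load_device).device)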