Skip to content

Commit

Permalink
Make weights_load_device not change EngineArgs.create_load_config() (HabanaAI#336)
Browse files Browse the repository at this point in the history

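Some backends call EngineArgs.create_load_config() directly, but an earlier
change altered that API by adding a required load_device argument. Altering
the API is not necessary to enable the weights-load-device functionality, so
this PR makes load_device optional (defaulting to the configured device).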
  • Loading branch information
kzawora-intel committed Sep 24, 2024
1 parent 73f4b48 commit 9111a80
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
3 changes: 2 additions & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,8 @@ class LoadConfig:
ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's
checkpoints.
device: Device on which weights are loaded.
device: Device to which model weights will be loaded, default to
device_config.device
"""

load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
Expand Down
9 changes: 6 additions & 3 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument("--weights-load-device",
type=str,
default=EngineArgs.weights_load_device,
choices=["cuda", "neuron", "hpu", "cpu"],
help='Device on which weights are loaded.')
choices=DEVICE_OPTIONS,
help=('Device to which model weights '
'will be loaded.'))
parser.add_argument(
'--config-format',
default=EngineArgs.config_format,
Expand Down Expand Up @@ -843,7 +844,9 @@ def create_model_config(self) -> ModelConfig:
mm_processor_kwargs=self.mm_processor_kwargs,
)

def create_load_config(self, load_device) -> LoadConfig:
def create_load_config(self, load_device=None) -> LoadConfig:
if load_device is None:
load_device = DeviceConfig(device=self.device).device
return LoadConfig(
load_format=self.load_format,
download_dir=self.download_dir,
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def load_model(self, *, model_config: ModelConfig,
model = _initialize_model(model_config, self.load_config,
lora_config, cache_config,
scheduler_config)
logger.info("Loading weights on %s ...", self.load_config.device)
logger.info("Loading weights on %s...", self.load_config.device)
model.load_weights(
self._get_weights_iterator(model_config.model,
model_config.revision,
Expand Down

0 comments on commit 9111a80

Please sign in to comment.