Merge pull request #785 from 50h100a/loadfix
Stream models rather than load them completely into RAM.
50h100a authored Oct 21, 2024
2 parents 4d3d819 + 9022c6d commit be4fa5b
Showing 54 changed files with 91 additions and 218 deletions.
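The gist of the change: instead of materializing every checkpoint tensor in a Python list before copying it into the model, the weight generator is now consumed lazily, so only roughly one tensor is resident in RAM at a time. A minimal illustrative sketch of the two patterns (hypothetical names, not code from this commit):

    import torch

    def weight_iter():
        # Stand-in for a checkpoint reader that yields one tensor at a time.
        for i in range(4):
            yield f"{i}.weight", torch.zeros(64, 64)

    model = torch.nn.Sequential(*[torch.nn.Linear(64, 64, bias=False)
                                  for _ in range(4)])
    params = dict(model.named_parameters())

    # Old pattern: list() pins every checkpoint tensor in RAM until the
    # copy loop finishes.
    for name, tensor in list(weight_iter()):
        params[name].data.copy_(tensor)

    # New pattern: iterate the generator directly; each tensor becomes
    # collectable right after its copy, so peak RAM is about one tensor
    # instead of the whole checkpoint.
    for name, tensor in weight_iter():
        params[name].data.copy_(tensor)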
25 changes: 24 additions & 1 deletion aphrodite/common/utils.py
@@ -12,12 +12,13 @@
 import threading
 import uuid
 import warnings
+import math
 from asyncio import FIRST_COMPLETED, ensure_future
 from functools import lru_cache, partial, wraps
 from platform import uname
 from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
                     Hashable, List, Literal, Optional, OrderedDict, Set, Tuple,
-                    Type, TypeVar, Union, overload)
+                    Type, TypeVar, Union, overload, Iterable)
 from uuid import uuid4

 import numpy as np
@@ -1117,3 +1118,25 @@ def progress_bar(iterable, desc="Processing"):
                 progress.update(task, advance=1)
     else:
         yield from iterable
+
+def tensor_progress_bar(iterable:Iterable[Tuple[str, torch.Tensor]],
+                        final_bytes:int, desc="Processing"):
+    show_progress = get_tensor_model_parallel_rank() == 0
+    units = 1024 ** (int(math.log2(final_bytes)) // 10)
+
+    if show_progress:
+        with Progress(
+            SpinnerColumn(),
+            TextColumn("[progress.description]{task.description}"),
+            BarColumn(),
+            MofNCompleteColumn(),
+            TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
+            TimeElapsedColumn(),
+        ) as progress:
+            task = progress.add_task(f"[cyan]{desc}", total=final_bytes/units)
+            for item in iterable:
+                steps = item[1].element_size() * item[1].nelement() / units
+                yield item
+                progress.update(task, advance=steps)
+    else:
+        yield from iterable
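For context on the new helper: the expression units = 1024 ** (int(math.log2(final_bytes)) // 10) selects the largest power-of-1024 scale not exceeding final_bytes, so the bar's M-of-N column counts in whole bytes, KiB, MiB, or GiB. A quick stdlib-only check of that arithmetic (illustrative, not part of the diff):

    import math

    def unit_scale(final_bytes: int) -> int:
        # log2 // 10 counts full factors of 1024: 0..9 -> bytes,
        # 10..19 -> KiB, 20..29 -> MiB, 30..39 -> GiB, and so on.
        return 1024 ** (int(math.log2(final_bytes)) // 10)

    assert unit_scale(500) == 1                 # small totals stay in bytes
    assert unit_scale(2_000) == 1024            # past 2**10, report in KiB
    assert unit_scale(5 * 1024**2) == 1024**2   # MiB
    assert unit_scale(13 * 1024**3) == 1024**3  # a 13 GiB checkpoint -> GiB

Note the helper only draws the bar on tensor-parallel rank 0 and yields every item unchanged either way, so callers can treat it as a transparent wrapper around the weight iterator.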
16 changes: 10 additions & 6 deletions aphrodite/modeling/model_loader/loader.py
@@ -24,7 +24,7 @@
                                      DeviceConfig, LoadConfig, LoadFormat,
                                      LoRAConfig, ModelConfig, MultiModalConfig,
                                      ParallelConfig, SchedulerConfig)
-from aphrodite.common.utils import is_pin_memory_available
+from aphrodite.common.utils import is_pin_memory_available, tensor_progress_bar
 from aphrodite.modeling.model_loader.tensorizer import (
     TensorizerConfig, is_aphrodite_tensorized, load_with_tensorizer,
     serialize_aphrodite_model, tensorizer_weights_iterator)
@@ -303,10 +303,12 @@ def _prepare_weights(self, model_name_or_path: str,
     def _get_weights_iterator(
         self, model_name_or_path: str, revision: Optional[str],
         fall_back_to_pt: bool
-    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], int]:
         """Get an iterator for the model weights based on the load format."""
         hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
             model_name_or_path, revision, fall_back_to_pt)
+        est_weight_bytes = sum(os.path.getsize(f)
+                               for f in hf_weights_files)
         if self.load_config.load_format == LoadFormat.NPCACHE:
             # Currently np_cache only support *.bin checkpoints
             assert use_safetensors is False
@@ -329,7 +331,7 @@ def _xla_weights_iterator(iterator: Generator):
                     xm.mark_step()

             weights_iterator = _xla_weights_iterator(weights_iterator)
-        return weights_iterator
+        return weights_iterator, est_weight_bytes

     def load_model(self, *, model_config: ModelConfig,
                    device_config: DeviceConfig,
@@ -343,13 +345,15 @@ def load_model(self, *, model_config: ModelConfig,
             model = _initialize_model(model_config, self.load_config,
                                       lora_config, cache_config,
                                       scheduler_config)
-            model.load_weights(
-                self._get_weights_iterator(model_config.model,
+            weights, wgt_bytes = self._get_weights_iterator(model_config.model,
                                            model_config.revision,
                                            fall_back_to_pt=getattr(
                                                model,
                                                "fall_back_to_pt_during_load",
-                                               True)), )
+                                               True))
+            model.load_weights(tensor_progress_bar(weights, wgt_bytes,
                                                    "Loading modules..."))

         for _, module in model.named_modules():
             quant_method = getattr(module, "quant_method", None)
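The loader now returns the pair (iterator, estimated total bytes) and feeds both into tensor_progress_bar; the total comes from os.path.getsize over the checkpoint files, which is cheap and close enough for a progress bar since safetensors/bin files are dominated by raw tensor data. A condensed sketch of the composition, assuming an environment where aphrodite is importable and tensor-parallel state is initialized (the stand-in generator is hypothetical):

    import torch
    from aphrodite.common.utils import tensor_progress_bar

    def fake_weights():
        # Stand-in for the generator half of _get_weights_iterator.
        yield "layer.weight", torch.zeros(256, 256)
        yield "layer.bias", torch.zeros(256)

    # The real loader sums os.path.getsize over the checkpoint files; here
    # we sum the stand-in tensors so the sketch is self-contained.
    est_bytes = sum(t.element_size() * t.nelement() for _, t in fake_weights())

    for name, tensor in tensor_progress_bar(fake_weights(), est_bytes,
                                            "Loading modules..."):
        pass  # each tensor is live only for this iteration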
5 changes: 1 addition & 4 deletions aphrodite/modeling/models/arctic.py
@@ -8,7 +8,6 @@
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
-from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -491,9 +490,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             "It will take ~10 minutes loading from the 16-bit weights. "
             "Alternatively, use the prequantized 8-bit weights of arctic "
             "and set load-format to `sharded_state` will accelerate loading.")
-        weights_list = list(weights)
-        for name, loaded_weight in progress_bar(weights_list,
-                                                desc="Loading modules..."):
+        for name, loaded_weight in weights:
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
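The same deletion recurs in every load_weights below and in the files not shown: the old per-model progress_bar needed a sized sequence, so each model first forced the weight generator into a list, pinning the entire checkpoint in RAM. With progress now reported byte-wise by tensor_progress_bar at the loader level, each model can consume the generator lazily, exactly as in the sketch after the commit summary above.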
5 changes: 1 addition & 4 deletions aphrodite/modeling/models/baichuan.py
@@ -28,7 +28,6 @@
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
-from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import SiluAndMul
@@ -368,9 +367,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        weights_list = list(weights)
-        for name, loaded_weight in progress_bar(weights_list,
-                                                desc="Loading modules..."):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             if name == "lm_head.weight":
4 changes: 1 addition & 3 deletions aphrodite/modeling/models/bart.py
@@ -930,12 +930,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         model_params_dict = dict(self.model.named_parameters())
         top_params_dict = dict(self.named_parameters())

-        weights_tuple_list = list(weights)
-
         shared_embedding_weight = None
         shared_embedding_shard_id = None

-        for name, loaded_weight in weights_tuple_list:
+        for name, loaded_weight in weights:

             name = self._rename_key(name)
             name, shard_id = self._rename_stacked_param(name)
6 changes: 1 addition & 5 deletions aphrodite/modeling/models/blip2.py
@@ -10,7 +10,6 @@
 from aphrodite.common.config import CacheConfig, MultiModalConfig
 from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                        SequenceData)
-from aphrodite.common.utils import progress_bar
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -682,10 +681,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-
-        weights_list = list(weights)
-        for name, loaded_weight in progress_bar(weights_list,
-                                                desc="Loading modules..."):
+        for name, loaded_weight in weights:
             if "lm_head.weight" in name:
                 continue
             if "rotary_emb.inv_freq" in name:
5 changes: 1 addition & 4 deletions aphrodite/modeling/models/bloom.py
@@ -26,7 +26,6 @@
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
-from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import get_act_fn
@@ -311,9 +310,7 @@ def sample(

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        weights_list = list(weights)
-        for name, loaded_weight in progress_bar(weights_list,
-                                                desc="Loading modules..."):
+        for name, loaded_weight in weights:
             if name == "lm_head.weight":
                 continue
             if not name.startswith("transformer."):
6 changes: 2 additions & 4 deletions aphrodite/modeling/models/chameleon.py
@@ -12,7 +12,7 @@
 from aphrodite.common.config import CacheConfig, MultiModalConfig
 from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                        SequenceData)
-from aphrodite.common.utils import print_warning_once, progress_bar
+from aphrodite.common.utils import print_warning_once
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.activation import SiluAndMul
@@ -1005,9 +1005,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             (".gate_up_proj", ".up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        weights_list = list(weights)
-        for name, loaded_weight in progress_bar(weights_list,
-                                                desc="Loading modules..."):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue

5 changes: 1 addition & 4 deletions aphrodite/modeling/models/chatglm.py
@@ -11,7 +11,6 @@
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
-from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
@@ -389,9 +388,7 @@ def sample(

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        weights_list = list(weights)
-        for name, loaded_weight in progress_bar(weights_list,
-                                                desc="Loading modules..."):
+        for name, loaded_weight in weights:
             if "rotary_pos_emb.inv_freq" in name:
                 continue
             if "word_embeddings" in name:
5 changes: 1 addition & 4 deletions aphrodite/modeling/models/commandr.py
@@ -30,7 +30,6 @@
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
-from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
@@ -381,9 +380,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         ]
         params_dict = dict(self.named_parameters())
         loaded_params = set()
-        weights_list = list(weights)
-        for name, loaded_weight in progress_bar(weights_list,
-                                                desc="Loading modules..."):
+        for name, loaded_weight in weights:
             for param_name, shard_name, shard_id in stacked_params_mapping:
                 if shard_name not in name:
                     continue
5 changes: 1 addition & 4 deletions aphrodite/modeling/models/dbrx.py
@@ -7,7 +7,6 @@
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
-from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -411,9 +410,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 f"experts.mlp.{weight_name}",
             ) for weight_name in ["w1", "v1", "w2"]]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        weights_list = list(weights)
-        for name, loaded_weight in progress_bar(weights_list,
-                                                desc="Loading modules..."):
+        for name, loaded_weight in weights:
             for param_name, weight_name in expert_params_mapping:
                 if weight_name not in name:
                     continue
5 changes: 1 addition & 4 deletions aphrodite/modeling/models/decilm.py
@@ -29,7 +29,6 @@
 from transformers import LlamaConfig

 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.utils import progress_bar
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.models.llama import LlamaForCausalLM
 from aphrodite.quantization.base_config import QuantizationConfig
@@ -77,9 +76,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        weights_list = list(weights)
-        for name, loaded_weight in progress_bar(weights_list,
-                                                desc="Loading modules..."):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue

5 changes: 1 addition & 4 deletions aphrodite/modeling/models/deepseek.py
@@ -30,7 +30,6 @@
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
-from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -423,9 +422,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         ]

         params_dict = dict(self.named_parameters())
-        weights_list = list(weights)
-        for name, loaded_weight in progress_bar(weights_list,
-                                                desc="Loading modules..."):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
5 changes: 1 addition & 4 deletions aphrodite/modeling/models/deepseek_v2.py
@@ -31,7 +31,6 @@
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
-from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import (get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
 from aphrodite.modeling.layers.activation import SiluAndMul
@@ -489,9 +488,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             num_experts=self.config.n_routed_experts)

         params_dict = dict(self.named_parameters())
-        weights_list = list(weights)
-        for name, loaded_weight in progress_bar(weights_list,
-                                                desc="Loading modules..."):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
6 changes: 2 additions & 4 deletions aphrodite/modeling/models/exaone.py
@@ -31,7 +31,7 @@
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
-from aphrodite.common.utils import is_hip, progress_bar
+from aphrodite.common.utils import is_hip
 from aphrodite.distributed import (get_pp_group,
                                    get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
@@ -534,9 +534,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             (".gate_up_proj", ".c_fc_1", 1),
         ]
         params_dict = dict(self.named_parameters())
-        weights_list = list(weights)
-        for name, loaded_weight in progress_bar(weights_list,
-                                                desc="Loading modules..."):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             if ("rotary_emb.cos_cached" in name
5 changes: 1 addition & 4 deletions aphrodite/modeling/models/falcon.py
@@ -29,7 +29,6 @@
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
-from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -422,9 +421,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         total_num_kv_heads = total_num_heads
         num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        weights_list = list(weights)
-        for name, loaded_weight in progress_bar(weights_list,
-                                                desc="Loading modules..."):
+        for name, loaded_weight in weights:
             if name == "lm_head.weight":
                 # Falcon uses tied embeddings.
                 continue
5 changes: 1 addition & 4 deletions aphrodite/modeling/models/fuyu.py
@@ -28,7 +28,6 @@
 from aphrodite.common.config import CacheConfig, MultiModalConfig
 from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                        SequenceData)
-from aphrodite.common.utils import progress_bar
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.linear import ColumnParallelLinear
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
@@ -310,9 +309,7 @@ def sample(

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        weights_list = list(weights)
-        for name, loaded_weight in progress_bar(weights_list,
-                                                desc="Loading modules..."):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             if ("rotary_emb.cos_cached" in name
5 changes: 1 addition & 4 deletions aphrodite/modeling/models/gemma.py
@@ -25,7 +25,6 @@
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
-from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import GeluAndMul
 from aphrodite.modeling.layers.layernorm import GemmaRMSNorm
@@ -378,9 +377,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         ]
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
-        weights_list = list(weights)
-        for name, loaded_weight in progress_bar(weights_list,
-                                                desc="Loading modules..."):
+        for name, loaded_weight in weights:
             for (param_name, shard_name, shard_id) in stacked_params_mapping:
                 if shard_name not in name:
                     continue
(The remaining 36 changed files are not shown here.)
