Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a disable_mmap option to the from_single_file loader to improve load performance on network mounts #10305

Open
wants to merge 26 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
fe205a6
Add no_mmap arg.
danhipke Dec 17, 2024
a6b4d8f
Fix arg parsing.
danhipke Dec 17, 2024
3cf01bf
Update another method to force no mmap.
danhipke Dec 18, 2024
2e08242
logging
danhipke Dec 18, 2024
bcca53b
logging2
danhipke Dec 18, 2024
c895d86
propagate no_mmap
danhipke Dec 18, 2024
c081e0b
logging3
danhipke Dec 18, 2024
7231c28
propagate no_mmap
danhipke Dec 18, 2024
0c472b2
logging4
danhipke Dec 18, 2024
c4d4d60
fix open call
danhipke Dec 18, 2024
4f84222
clean up logging
danhipke Dec 19, 2024
5fab6d1
cleanup
danhipke Dec 19, 2024
1d8cf69
fix missing arg
danhipke Dec 19, 2024
5ef288f
update logging and comments
danhipke Dec 19, 2024
fec5753
fix merge conflict
danhipke Dec 19, 2024
3cc50f0
Merge branch 'main' into no-mmap
hlky Dec 20, 2024
f80644d
Rename to disable_mmap and update other references.
danhipke Dec 20, 2024
ffe5aba
[Docs] Update ltx_video.md to remove generator from `from_pretrained(…
sayakpaul Dec 20, 2024
3fc4a42
docs: fix a mistake in docstring (#10319)
Leojc Dec 20, 2024
9e887b4
[BUG FIX] [Stable Audio Pipeline] Resolve torch.Tensor.new_zeros() Ty…
syntaxticsugr Dec 20, 2024
dbbcd0f
[docs] Fix quantization links (#10323)
stevhliu Dec 20, 2024
dfebda2
[Sana]add 2K related model for Sana (#10322)
lawrence-cj Dec 20, 2024
a6e3745
Merge branch 'main' into no-mmap
danhipke Dec 20, 2024
6720c51
Update src/diffusers/loaders/single_file_model.py
danhipke Dec 23, 2024
2926158
Update src/diffusers/loaders/single_file.py
danhipke Dec 23, 2024
5fdd062
Merge branch 'main' into no-mmap
sayakpaul Dec 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/diffusers/loaders/single_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def load_single_file_sub_model(
local_files_only=False,
torch_dtype=None,
is_legacy_loading=False,
disable_mmap=False,
**kwargs,
):
if is_pipeline_module:
Expand Down Expand Up @@ -106,6 +107,7 @@ def load_single_file_sub_model(
subfolder=name,
torch_dtype=torch_dtype,
local_files_only=local_files_only,
disable_mmap=disable_mmap,
**kwargs,
)

Expand Down Expand Up @@ -308,6 +310,9 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
hosted on the Hub.
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
component configs in Diffusers format.
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
Expand Down Expand Up @@ -355,6 +360,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
disable_mmap = kwargs.pop("disable_mmap", False)

is_legacy_loading = False

Expand Down Expand Up @@ -383,6 +389,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
disable_mmap=disable_mmap,
)

if config is None:
Expand Down Expand Up @@ -504,6 +511,7 @@ def load_module(name, value):
original_config=original_config,
local_files_only=local_files_only,
is_legacy_loading=is_legacy_loading,
disable_mmap=disable_mmap,
**kwargs,
)
except SingleFileComponentError as e:
Expand Down
9 changes: 8 additions & 1 deletion src/diffusers/loaders/single_file_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (for example the pipeline components of the
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
Expand Down Expand Up @@ -229,6 +232,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
torch_dtype = kwargs.pop("torch_dtype", None)
quantization_config = kwargs.pop("quantization_config", None)
device = kwargs.pop("device", None)
disable_mmap = kwargs.pop("disable_mmap", False)

if isinstance(pretrained_model_link_or_path_or_dict, dict):
checkpoint = pretrained_model_link_or_path_or_dict
Expand All @@ -241,6 +245,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
disable_mmap=disable_mmap,
)
if quantization_config is not None:
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
Expand Down Expand Up @@ -362,7 +367,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
)

else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
_, unexpected_keys = model.load_state_dict(
diffusers_format_checkpoint, strict=False
)

if model._keys_to_ignore_on_load_unexpected is not None:
for pat in model._keys_to_ignore_on_load_unexpected:
Expand Down
3 changes: 2 additions & 1 deletion src/diffusers/loaders/single_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ def load_single_file_checkpoint(
cache_dir=None,
local_files_only=None,
revision=None,
disable_mmap=False,
):
if os.path.isfile(pretrained_model_link_or_path):
pretrained_model_link_or_path = pretrained_model_link_or_path
Expand All @@ -394,7 +395,7 @@ def load_single_file_checkpoint(
revision=revision,
)

checkpoint = load_state_dict(pretrained_model_link_or_path)
checkpoint = load_state_dict(pretrained_model_link_or_path, disable_mmap=disable_mmap)

# some checkpoints contain the model state dict under a "state_dict" key
while "state_dict" in checkpoint:
Expand Down
9 changes: 7 additions & 2 deletions src/diffusers/models/model_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@ def _fetch_remapped_cls_from_config(config, old_class):
return old_class


def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
def load_state_dict(
checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, disable_mmap: bool = False
):
"""
Reads a checkpoint file, returning properly formatted errors if they arise.
"""
Expand All @@ -142,7 +144,10 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
try:
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
if file_extension == SAFETENSORS_FILE_EXTENSION:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
if disable_mmap:
return safetensors.torch.load(open(checkpoint_file, "rb").read())
else:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
elif file_extension == GGUF_FILE_EXTENSION:
return load_gguf_checkpoint(checkpoint_file)
else:
Expand Down
8 changes: 6 additions & 2 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
weights. If set to `False`, `safetensors` weights are not loaded.
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.

<Tip>

Expand Down Expand Up @@ -604,6 +607,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
quantization_config = kwargs.pop("quantization_config", None)
disable_mmap = kwargs.pop("disable_mmap", False)

allow_pickle = False
if use_safetensors is None:
Expand Down Expand Up @@ -883,7 +887,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# TODO (sayakpaul, SunMarc): remove this after model loading refactor
else:
param_device = torch.device(torch.cuda.current_device())
state_dict = load_state_dict(model_file, variant=variant)
state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap)
model._convert_deprecated_attention_blocks(state_dict)

# move the params from meta device to cpu
Expand Down Expand Up @@ -983,7 +987,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
else:
model = cls.from_config(config, **unused_kwargs)

state_dict = load_state_dict(model_file, variant=variant)
state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap)
model._convert_deprecated_attention_blocks(state_dict)

model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
Expand Down