[LoRA] enable LoRA for Mochi-1 #9943

Open · sayakpaul wants to merge 8 commits into main from mochi-1-lor
+522 −5

Changes from all commits (8 commits)
576cd12  feat: add lora support to Mochi-1.  (sayakpaul)
eab77ed  fix copies  (sayakpaul)
6068621  skip mochi-1 lora tests on MPS.  (sayakpaul)
65ee3a3  remove print.  (sayakpaul)
83bb655  remove space.  (sayakpaul)
36276e9  Merge branch 'main' into mochi-1-lor  (sayakpaul)
e42edd9  Merge branch 'main' into mochi-1-lor  (sayakpaul)
103da43  Merge branch 'main' into mochi-1-lor  (sayakpaul)
@@ -2362,7 +2362,7 @@ def save_lora_weights(

 class CogVideoXLoraLoaderMixin(LoraBaseMixin):
     r"""
-    Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoX`].
+    Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoXPipeline`].
     """

     _lora_loadable_modules = ["transformer"]
@@ -2667,6 +2667,314 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
        super().unfuse_lora(components=components)


class Mochi1LoraLoaderMixin(LoraBaseMixin):

[Review comment] A copy-paste of the Cog LoRA loader classes, indicated by the "Copied from ..." comments.

    r"""
    Load LoRA layers into [`MochiTransformer3DModel`]. Specific to [`MochiPipeline`].
    """

    _lora_loadable_modules = ["transformer"]
    transformer_name = TRANSFORMER_NAME

    @classmethod
    @validate_hf_hub_args
    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
        r"""
        Return state dict for lora weights and the network alphas.

        <Tip warning={true}>

        We support loading A1111 formatted LoRA checkpoints in a limited capacity.

        This function is experimental and might change in the future.

        </Tip>

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
        """
        # Load the main state dict first which has the LoRA layers for either of
        # transformer and text encoder or both.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        state_dict = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )

        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
        if is_dora_scale_present:
            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        return state_dict

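The DoRA guard at the end of `lora_state_dict` drops every `dora_scale` key before returning. A minimal sketch of that filtering step on a toy state dict (the key names and string values are stand-ins for illustration; real checkpoints map keys to `torch.Tensor`s):

```python
# Toy state dict standing in for a loaded LoRA checkpoint.
state_dict = {
    "transformer.blocks.0.attn.to_q.lora_A.weight": "A",
    "transformer.blocks.0.attn.to_q.lora_B.weight": "B",
    "transformer.blocks.0.attn.to_q.lora_magnitude_vector.dora_scale": "D",
}

# DoRA checkpoints are not supported by this loader, so any key
# containing "dora_scale" is filtered out (with a warning in the
# real implementation).
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
    state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
```

After filtering, only the plain `lora_A`/`lora_B` keys survive, which is what the downstream PEFT loading path expects.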
    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
    def load_lora_weights(
        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
    ):
        """
        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
        `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
        [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
        dict is loaded into `self.transformer`.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
            kwargs (`dict`, *optional*):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # if a dict is passed, copy it instead of modifying it inplace
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        self.load_lora_into_transformer(
            state_dict,
            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
            adapter_name=adapter_name,
            _pipeline=self,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )

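Before handing the state dict to the transformer, `load_lora_weights` applies a simple format check: every key must mention `lora`, otherwise the checkpoint is rejected with "Invalid LoRA checkpoint." A small sketch of that check, with hypothetical key names for illustration:

```python
def is_valid_lora_state_dict(state_dict):
    # A supported checkpoint only contains LoRA parameters, so every
    # key must include "lora"; a full-model state dict fails this test.
    return all("lora" in key for key in state_dict.keys())

# Hypothetical key names, for illustration only.
good = {"transformer.blocks.0.attn.to_q.lora_A.weight": None}
bad = {"transformer.blocks.0.attn.to_q.weight": None}
```

This is a coarse filter: it catches accidentally passing base-model weights, not every possible malformed adapter.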
    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel
    def load_lora_into_transformer(
        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `transformer`.

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                encoder lora layers.
            transformer (`CogVideoXTransformer3DModel`):
                The Transformer model to load the LoRA layers into.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
        """
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # Load the layers corresponding to transformer.
        logger.info(f"Loading {cls.transformer_name}.")
        transformer.load_lora_adapter(
            state_dict,
            network_alphas=None,
            adapter_name=adapter_name,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
    ):
        r"""
        Save the LoRA parameters corresponding to the UNet and text encoder.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save LoRA parameters to. Will be created if it doesn't exist.
            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `transformer`.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training and you
                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
                process to avoid race conditions.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
        """
        state_dict = {}

        if not transformer_lora_layers:
            raise ValueError("You must pass `transformer_lora_layers`.")

        if transformer_lora_layers:
            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

        # Save the model
        cls.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

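`save_lora_weights` namespaces the transformer layers under the `transformer` prefix via `cls.pack_weights` before writing them out. A rough sketch of what that prefixing amounts to; the `pack_weights` body here is an assumption modeled on the prefixed key names the saved checkpoints show, not the library's actual implementation:

```python
def pack_weights(layers, prefix):
    # Namespace every LoRA key under `prefix`, mirroring how
    # `save_lora_weights` groups the transformer layers before writing.
    return {f"{prefix}.{key}": value for key, value in layers.items()}

# Hypothetical layer key and value, for illustration only.
transformer_lora_layers = {"blocks.0.attn.to_q.lora_A.weight": "A"}
state_dict = {}
state_dict.update(pack_weights(transformer_lora_layers, "transformer"))
```

The prefix is what lets `load_lora_weights` later route each entry back to the right component when a checkpoint could hold layers for several modules.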
    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
    def fuse_lora(
        self,
        components: List[str] = ["transformer", "text_encoder"],
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
        **kwargs,
    ):
        r"""
        Fuses the LoRA parameters into the original parameters of the corresponding blocks.

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
            lora_scale (`float`, defaults to 1.0):
                Controls how much to influence the outputs with the LoRA parameters.
            safe_fusing (`bool`, defaults to `False`):
                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
            adapter_names (`List[str]`, *optional*):
                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.

        Example:

        ```py
        from diffusers import DiffusionPipeline
        import torch

        pipeline = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
        pipeline.fuse_lora(lora_scale=0.7)
        ```
        """
        super().fuse_lora(
            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
        )

    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
    def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
        r"""
        Reverses the effect of
        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
            unfuse_text_encoder (`bool`, defaults to `True`):
                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
                LoRA parameters then it won't have any effect.
        """
        super().unfuse_lora(components=components)

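Conceptually, fusing folds the adapter into the base weights as W' = W + scale · (B @ A), and unfusing subtracts the same term back out. A pure-Python sketch of that round trip, using scalars in place of the low-rank matrices to keep the arithmetic visible (the real implementation operates on `torch` tensors inside the PEFT backend):

```python
def fuse(w, a, b, scale):
    # Fused weight: W' = W + scale * (B @ A). Scalars stand in for
    # the low-rank matrices A and B here.
    return w + scale * (b * a)

def unfuse(w, a, b, scale):
    # Unfusing subtracts the same LoRA delta, restoring W exactly.
    return w - scale * (b * a)

w, a, b, scale = 2.0, 0.5, 4.0, 0.7
w_fused = fuse(w, a, b, scale)
w_restored = unfuse(w_fused, a, b, scale)
```

Because fuse/unfuse are exact inverses (up to floating-point rounding), `unfuse_lora` can recover the original weights without keeping a separate copy of the base model.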

class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
    def __init__(self, *args, **kwargs):
        deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
[Review comment] Unrelated change but doesn't hurt I guess.
[Review comment] Looks okay to fix here