From 784b6d01a7a115016d98cceac0bb153b1aa47082 Mon Sep 17 00:00:00 2001
From: Chenlei Hu
Date: Sat, 4 May 2024 12:25:00 -0400
Subject: [PATCH] Support PuLID (#2838)

* Add preprocessors
* Fix resolution param
* Fix various issues
* Add PuLID attn
* remove unused import
* Resize img before passing to facexlib
* safe unload
---
 internal_controlnet/external_code.py | 7 +-
 requirements.txt | 1 +
 scripts/api.py | 2 +-
 scripts/controlnet.py | 22 ++-
 scripts/controlnet_ui/controlnet_ui_group.py | 24 ++-
 scripts/enums.py | 5 +
 scripts/ipadapter/image_proj_models.py | 62 +++++++
 scripts/ipadapter/ipadapter_model.py | 45 ++++-
 scripts/ipadapter/plugable_ipadapter.py | 105 ++++++++---
 scripts/ipadapter/presets.py | 6 +
 scripts/ipadapter/pulid_attn.py | 94 ++++++++++
 scripts/preprocessor/__init__.py | 1 +
 scripts/preprocessor/inpaint.py | 29 +--
 scripts/preprocessor/lama_inpaint.py | 16 +-
 .../legacy/legacy_preprocessors.py | 2 +-
 scripts/preprocessor/pulid.py | 169 ++++++++++++++
 scripts/supported_preprocessor.py | 20 ++-
 scripts/utils.py | 41 ++++-
 18 files changed, 571 insertions(+), 80 deletions(-)
 create mode 100644 scripts/ipadapter/pulid_attn.py
 create mode 100644 scripts/preprocessor/pulid.py

diff --git a/internal_controlnet/external_code.py b/internal_controlnet/external_code.py
index 016142e5a..b471cac5d 100644
--- a/internal_controlnet/external_code.py
+++ b/internal_controlnet/external_code.py
@@ -11,7 +11,7 @@
 from modules.safe import unsafe_torch_load
 from scripts import global_state
 from scripts.logging import logger
-from scripts.enums import HiResFixOption
+from scripts.enums import HiResFixOption, PuLIDMode
 from scripts.supported_preprocessor import Preprocessor, PreprocessorParameter
 from modules.api import api
@@ -207,6 +207,10 @@ class ControlNetUnit:
     # The effective region mask that unit's effect should be restricted to.
     effective_region_mask: Optional[np.ndarray] = None
 
+    # The weight mode for PuLID.
+    # https://github.com/ToTheBeginning/PuLID
+    pulid_mode: PuLIDMode = PuLIDMode.FIDELITY
+
     # The tensor input for ipadapter. When this field is set in the API,
     # the base64string will be interpret by torch.load to reconstruct ipadapter
     # preprocessor output.
@@ -243,6 +247,7 @@ def infotext_excluded_fields() -> List[str]:
             # provide much information when restoring the unit.
"inpaint_crop_input_image", "effective_region_mask", + "pulid_mode", ] @property diff --git a/requirements.txt b/requirements.txt index 5fb5e2e2b..f0072d277 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ addict yapf albumentations==1.4.3 matplotlib +facexlib diff --git a/scripts/api.py b/scripts/api.py index c2d348a3a..2e251d836 100644 --- a/scripts/api.py +++ b/scripts/api.py @@ -179,7 +179,7 @@ def accept(self, json_dict: dict) -> None: low_vram=low_vram, ) if preprocessor.returns_image: - images.append(encode_to_base64(result.display_image)) + images.append(encode_to_base64(result.display_images[0])) else: tensors.append(encode_tensor_to_base64(result.value)) diff --git a/scripts/controlnet.py b/scripts/controlnet.py index 3845d537e..9bb333174 100644 --- a/scripts/controlnet.py +++ b/scripts/controlnet.py @@ -17,12 +17,14 @@ import scripts.preprocessor as preprocessor_init # noqa from annotator.util import HWC3 from scripts import global_state, hook, external_code, batch_hijack, controlnet_version, utils +from internal_controlnet.external_code import ControlMode from scripts.controlnet_lora import bind_control_lora, unbind_control_lora from scripts.controlnet_lllite import clear_all_lllite from scripts.ipadapter.plugable_ipadapter import ImageEmbed, clear_all_ip_adapter +from scripts.ipadapter.pulid_attn import PULID_SETTING_FIDELITY, PULID_SETTING_STYLE from scripts.utils import load_state_dict, get_unique_axis0, align_dim_latent from scripts.hook import ControlParams, UnetHook, HackedImageRNG -from scripts.enums import ControlModelType, StableDiffusionVersion, HiResFixOption +from scripts.enums import ControlModelType, StableDiffusionVersion, HiResFixOption, PuLIDMode from scripts.controlnet_ui.controlnet_ui_group import ControlNetUiGroup, UiControlNetUnit from scripts.controlnet_ui.photopea import Photopea from scripts.logging import logger @@ -279,6 +281,7 @@ def preprocess_input_image(input_image: np.ndarray): ) detected_map = result.value is_image = preprocessor.returns_image + # TODO: Refactor img control detection logic. 
if high_res_fix: if is_image: hr_control, hr_detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, hr_y, hr_x) @@ -293,7 +296,8 @@ def preprocess_input_image(input_image: np.ndarray): store_detected_map(detected_map, unit.module) else: control = detected_map - store_detected_map(input_image, unit.module) + for image in result.display_images: + store_detected_map(image, unit.module) if control_model_type == ControlModelType.T2I_StyleAdapter: control = control['last_hidden_state'] @@ -1092,8 +1096,8 @@ def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe global_average_pooling=global_average_pooling, hr_hint_cond=hr_control, hr_option=HiResFixOption.from_value(unit.hr_option) if high_res_fix else HiResFixOption.BOTH, - soft_injection=control_mode != external_code.ControlMode.BALANCED, - cfg_injection=control_mode == external_code.ControlMode.CONTROL, + soft_injection=control_mode != ControlMode.BALANCED, + cfg_injection=control_mode == ControlMode.CONTROL, effective_region_mask=( get_pytorch_control(unit.effective_region_mask)[:, 0:1, :, :] if unit.effective_region_mask is not None @@ -1190,7 +1194,7 @@ def recolor_intensity_post_processing(x, i): is_low_vram = any(unit.low_vram for unit in self.enabled_units) - for i, param in enumerate(forward_params): + for i, (param, unit) in enumerate(zip(forward_params, self.enabled_units)): if param.control_model_type == ControlModelType.IPAdapter: if param.advanced_weighting is not None: logger.info(f"IP-Adapter using advanced weighting {param.advanced_weighting}") @@ -1205,6 +1209,13 @@ def recolor_intensity_post_processing(x, i): weight = param.weight h, w, hr_y, hr_x = Script.get_target_dimensions(p) + pulid_mode = PuLIDMode(unit.pulid_mode) + if pulid_mode == PuLIDMode.STYLE: + pulid_attn_setting = PULID_SETTING_STYLE + else: + assert pulid_mode == PuLIDMode.FIDELITY + pulid_attn_setting = PULID_SETTING_FIDELITY + param.control_model.hook( model=unet, preprocessor_outputs=param.hint_cond, @@ -1215,6 +1226,7 @@ def recolor_intensity_post_processing(x, i): latent_width=w // 8, latent_height=h // 8, effective_region_mask=param.effective_region_mask, + pulid_attn_setting=pulid_attn_setting, ) if param.control_model_type == ControlModelType.Controlllite: param.control_model.hook( diff --git a/scripts/controlnet_ui/controlnet_ui_group.py b/scripts/controlnet_ui/controlnet_ui_group.py index a155a9dbd..d16cb8e23 100644 --- a/scripts/controlnet_ui/controlnet_ui_group.py +++ b/scripts/controlnet_ui/controlnet_ui_group.py @@ -18,7 +18,7 @@ from scripts.controlnet_ui.preset import ControlNetPresetUI from scripts.controlnet_ui.photopea import Photopea from scripts.controlnet_ui.advanced_weight_control import AdvancedWeightControl -from scripts.enums import InputMode +from scripts.enums import InputMode, PuLIDMode from modules import shared from modules.ui_components import FormRow, FormHTML, ToolButton @@ -287,6 +287,7 @@ def __init__( self.batch_image_dir_state = None self.output_dir_state = None self.advanced_weighting = gr.State(None) + self.pulid_mode = None # API-only fields self.ipadapter_input = gr.State(None) @@ -626,6 +627,15 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: visible=False, ) + self.pulid_mode = gr.Radio( + choices=[e.value for e in PuLIDMode], + value=self.default_unit.pulid_mode.value, + label="PuLID Mode", + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_pulid_mode_radio", + elem_classes="controlnet_pulid_mode_radio", + visible=False, + ) + self.loopback = 
gr.Checkbox( label="[Batch Loopback] Automatically send generated images to this ControlNet unit in batch generation", value=self.default_unit.loopback, @@ -673,6 +683,7 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: self.save_detected_map, self.advanced_weighting, self.effective_region_mask, + self.pulid_mode, ) unit = gr.State(self.default_unit) @@ -947,7 +958,7 @@ def is_openpose(module: str): return ( # Update to `generated_image` - gr.update(value=result.display_image, visible=True, interactive=False), + gr.update(value=result.display_images[0], visible=True, interactive=False), # preprocessor_preview gr.update(value=True), # openpose editor @@ -1118,6 +1129,14 @@ def register_shift_upload_mask(self): show_progress=False, ) + def register_shift_pulid_mode(self): + self.model.change( + fn=lambda model: gr.update(visible="pulid" in model.lower()), + inputs=[self.model], + outputs=[self.pulid_mode], + show_progress=False, + ) + def register_sync_batch_dir(self): def determine_batch_dir(batch_dir, fallback_dir, fallback_fallback_dir): if batch_dir: @@ -1220,6 +1239,7 @@ def register_core_callbacks(self): self.register_build_sliders() self.register_shift_preview() self.register_shift_upload_mask() + self.register_shift_pulid_mode() self.register_create_canvas() self.register_clear_preview() self.register_multi_images_upload() diff --git a/scripts/enums.py b/scripts/enums.py index 327f36431..4daaca9b6 100644 --- a/scripts/enums.py +++ b/scripts/enums.py @@ -247,3 +247,8 @@ class InputMode(Enum): # Input is a directory. 1 generation. Each generation takes N input image # from the directory. MERGE = "merge" + + +class PuLIDMode(Enum): + FIDELITY = "Fidelity" + STYLE = "Extremely style" diff --git a/scripts/ipadapter/image_proj_models.py b/scripts/ipadapter/image_proj_models.py index 8594ac99b..d8dd12157 100644 --- a/scripts/ipadapter/image_proj_models.py +++ b/scripts/ipadapter/image_proj_models.py @@ -269,3 +269,65 @@ def forward(self, x): latents = self.proj_out(latents) return self.norm_out(latents) + + +class PuLIDEncoder(nn.Module): + def __init__(self, width=1280, context_dim=2048, num_token=5): + super().__init__() + self.num_token = num_token + self.context_dim = context_dim + h1 = min((context_dim * num_token) // 4, 1024) + h2 = min((context_dim * num_token) // 2, 1024) + self.body = nn.Sequential( + nn.Linear(width, h1), + nn.LayerNorm(h1), + nn.LeakyReLU(), + nn.Linear(h1, h2), + nn.LayerNorm(h2), + nn.LeakyReLU(), + nn.Linear(h2, context_dim * num_token), + ) + + for i in range(5): + setattr( + self, + f"mapping_{i}", + nn.Sequential( + nn.Linear(1024, 1024), + nn.LayerNorm(1024), + nn.LeakyReLU(), + nn.Linear(1024, 1024), + nn.LayerNorm(1024), + nn.LeakyReLU(), + nn.Linear(1024, context_dim), + ), + ) + + setattr( + self, + f"mapping_patch_{i}", + nn.Sequential( + nn.Linear(1024, 1024), + nn.LayerNorm(1024), + nn.LeakyReLU(), + nn.Linear(1024, 1024), + nn.LayerNorm(1024), + nn.LeakyReLU(), + nn.Linear(1024, context_dim), + ), + ) + + def forward(self, x, y): + # x shape [N, C] + x = self.body(x) + x = x.reshape(-1, self.num_token, self.context_dim) + + hidden_states = () + for i, emb in enumerate(y): + hidden_state = getattr(self, f"mapping_{i}")(emb[:, :1]) + getattr( + self, f"mapping_patch_{i}" + )(emb[:, 1:]).mean(dim=1, keepdim=True) + hidden_states += (hidden_state,) + hidden_states = torch.cat(hidden_states, dim=1) + + return torch.cat([x, hidden_states], dim=1) diff --git a/scripts/ipadapter/ipadapter_model.py b/scripts/ipadapter/ipadapter_model.py index 
7314c9b2d..16d9ac4c5 100644 --- a/scripts/ipadapter/ipadapter_model.py +++ b/scripts/ipadapter/ipadapter_model.py @@ -12,6 +12,7 @@ MLPProjModel, MLPProjModelFaceId, ProjModelFaceIdPlus, + PuLIDEncoder, ) @@ -71,6 +72,7 @@ def __init__( is_faceid: bool, is_portrait: bool, is_instantid: bool, + is_pulid: bool, is_v2: bool, ): super().__init__() @@ -85,9 +87,12 @@ def __init__( self.is_v2 = is_v2 self.is_faceid = is_faceid self.is_instantid = is_instantid + self.is_pulid = is_pulid self.clip_extra_context_tokens = 16 if (self.is_plus or is_portrait) else 4 - if is_instantid: + if self.is_pulid: + self.image_proj_model = PuLIDEncoder() + elif self.is_instantid: self.image_proj_model = self.init_proj_instantid() elif is_faceid: self.image_proj_model = self.init_proj_faceid() @@ -235,6 +240,34 @@ def _get_image_embeds_instantid( self.image_proj_model(torch.zeros_like(prompt_image_emb)), ) + def _get_image_embeds_pulid(self, pulid_proj_input) -> ImageEmbed: + """Get image embeds for pulid.""" + id_cond = torch.cat( + [ + pulid_proj_input.id_ante_embedding.to( + device=self.device, dtype=torch.float32 + ), + pulid_proj_input.id_cond_vit.to( + device=self.device, dtype=torch.float32 + ), + ], + dim=-1, + ) + id_vit_hidden = [ + t.to(device=self.device, dtype=torch.float32) + for t in pulid_proj_input.id_vit_hidden + ] + return ImageEmbed( + self.image_proj_model( + id_cond, + id_vit_hidden, + ), + self.image_proj_model( + torch.zeros_like(id_cond), + [torch.zeros_like(t) for t in id_vit_hidden], + ), + ) + @staticmethod def load(state_dict: dict, model_name: str) -> IPAdapterModel: """ @@ -245,6 +278,7 @@ def load(state_dict: dict, model_name: str) -> IPAdapterModel: is_v2 = "v2" in model_name is_faceid = "faceid" in model_name is_instantid = "instant_id" in model_name + is_pulid = "pulid" in model_name.lower() is_portrait = "portrait" in model_name is_full = "proj.3.weight" in state_dict["image_proj"] is_plus = ( @@ -256,8 +290,8 @@ def load(state_dict: dict, model_name: str) -> IPAdapterModel: sdxl = cross_attention_dim == 2048 sdxl_plus = sdxl and is_plus - if is_instantid: - # InstantID does not use clip embedding. + if is_instantid or is_pulid: + # InstantID/PuLID does not use clip embedding. clip_embeddings_dim = None elif is_faceid: if is_plus: @@ -291,10 +325,13 @@ def load(state_dict: dict, model_name: str) -> IPAdapterModel: is_portrait=is_portrait, is_instantid=is_instantid, is_v2=is_v2, + is_pulid=is_pulid, ) def get_image_emb(self, preprocessor_output) -> ImageEmbed: - if self.is_instantid: + if self.is_pulid: + return self._get_image_embeds_pulid(preprocessor_output) + elif self.is_instantid: return self._get_image_embeds_instantid(preprocessor_output) elif self.is_faceid and self.is_plus: # Note: FaceID plus uses both face_embed and clip_embed. 
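For orientation, here is a minimal standalone sketch of how IPAdapterModel.load decides that a checkpoint should take the PuLID path. The helper name detect_variant is hypothetical; the real code sets these booleans inline from model_name before constructing the model, and only the PuLID check lowercases the name.

def detect_variant(model_name: str) -> dict:
    # Mirrors the substring checks in IPAdapterModel.load shown above.
    return {
        "is_v2": "v2" in model_name,
        "is_faceid": "faceid" in model_name,
        "is_instantid": "instant_id" in model_name,
        "is_pulid": "pulid" in model_name.lower(),  # only this check is case-insensitive
        "is_portrait": "portrait" in model_name,
    }

# The preset added in scripts/ipadapter/presets.py maps the "ip-adapter_pulid"
# module to this model name, which satisfies the is_pulid check:
assert detect_variant("ip-adapter_pulid_sdxl_fp16")["is_pulid"]
assert not detect_variant("ip-adapter-faceid-portrait_sdxl")["is_pulid"]
assert detect_variant("ip-adapter-faceid-portrait_sdxl")["is_faceid"]

Once the PuLID path is taken, _get_image_embeds_pulid concatenates the insightface face embedding and the EVA-CLIP class embedding along the last dimension to form the input of PuLIDEncoder (width=1280 by default), and routes the per-layer EVA-CLIP hidden states through the mapping_{i}/mapping_patch_{i} heads; the exact dimension split inside that 1280-wide vector is an upstream PuLID detail that this diff does not pin down.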
diff --git a/scripts/ipadapter/plugable_ipadapter.py b/scripts/ipadapter/plugable_ipadapter.py index b56522489..72c0e6652 100644 --- a/scripts/ipadapter/plugable_ipadapter.py +++ b/scripts/ipadapter/plugable_ipadapter.py @@ -1,8 +1,9 @@ import itertools import torch import math -from typing import Union, Dict, Optional +from typing import Union, Dict, Optional, Callable +from .pulid_attn import PuLIDAttnSetting from .ipadapter_model import ImageEmbed, IPAdapterModel from ..enums import StableDiffusionVersion, TransformerID @@ -93,7 +94,7 @@ def clear_all_ip_adapter(): class PlugableIPAdapter(torch.nn.Module): def __init__(self, ipadapter: IPAdapterModel): super().__init__() - self.ipadapter = ipadapter + self.ipadapter: IPAdapterModel = ipadapter self.disable_memory_management = True self.dtype = None self.weight: Union[float, Dict[int, float]] = 1.0 @@ -103,6 +104,7 @@ def __init__(self, ipadapter: IPAdapterModel): self.latent_width: int = 0 self.latent_height: int = 0 self.effective_region_mask = None + self.pulid_attn_setting: Optional[PuLIDAttnSetting] = None def reset(self): self.cache = {} @@ -118,6 +120,7 @@ def hook( latent_width: int, latent_height: int, effective_region_mask: Optional[torch.Tensor], + pulid_attn_setting: Optional[PuLIDAttnSetting] = None, dtype=torch.float32, ): global current_model @@ -128,6 +131,7 @@ def hook( self.latent_width = latent_width self.latent_height = latent_height self.effective_region_mask = effective_region_mask + self.pulid_attn_setting = pulid_attn_setting self.cache = {} @@ -186,7 +190,9 @@ def apply_effective_region_mask(self, out: torch.Tensor) -> torch.Tensor: # sequence_length = (latent_height * factor) * (latent_height * factor) # sequence_length = (latent_height * latent_height) * factor ^ 2 factor = math.sqrt(sequence_length / (self.latent_width * self.latent_height)) - assert factor > 0, f"{factor}, {sequence_length}, {self.latent_width}, {self.latent_height}" + assert ( + factor > 0 + ), f"{factor}, {sequence_length}, {self.latent_width}, {self.latent_height}" mask_h = int(self.latent_height * factor) mask_w = int(self.latent_width * factor) @@ -199,6 +205,71 @@ def apply_effective_region_mask(self, out: torch.Tensor) -> torch.Tensor: mask = mask.view(mask.shape[0], -1, 1).repeat(1, 1, out.shape[2]) return out * mask + def attn_eval( + self, + hidden_states: torch.Tensor, + query: torch.Tensor, + cond_uncond_image_emb: torch.Tensor, + attn_heads: int, + head_dim: int, + emb_to_k: Callable[[torch.Tensor], torch.Tensor], + emb_to_v: Callable[[torch.Tensor], torch.Tensor], + ): + if self.ipadapter.is_pulid: + assert self.pulid_attn_setting is not None + return self.pulid_attn_setting.eval( + hidden_states, + query, + cond_uncond_image_emb, + attn_heads, + head_dim, + emb_to_k, + emb_to_v, + ) + else: + return self._attn_eval_ipadapter( + hidden_states, + query, + cond_uncond_image_emb, + attn_heads, + head_dim, + emb_to_k, + emb_to_v, + ) + + def _attn_eval_ipadapter( + self, + hidden_states: torch.Tensor, + query: torch.Tensor, + cond_uncond_image_emb: torch.Tensor, + attn_heads: int, + head_dim: int, + emb_to_k: Callable[[torch.Tensor], torch.Tensor], + emb_to_v: Callable[[torch.Tensor], torch.Tensor], + ): + assert hidden_states.ndim == 3 + batch_size, sequence_length, inner_dim = hidden_states.shape + ip_k = emb_to_k(cond_uncond_image_emb) + ip_v = emb_to_v(cond_uncond_image_emb) + + ip_k, ip_v = map( + lambda t: t.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2), + (ip_k, ip_v), + ) + assert ip_k.dtype == ip_v.dtype + + # 
On MacOS, q can be float16 instead of float32. + # https://github.com/Mikubill/sd-webui-controlnet/issues/2208 + if query.dtype != ip_k.dtype: + ip_k = ip_k.to(dtype=query.dtype) + ip_v = ip_v.to(dtype=query.dtype) + + ip_out = torch.nn.functional.scaled_dot_product_attention( + query, ip_k, ip_v, attn_mask=None, dropout_p=0.0, is_causal=False + ) + ip_out = ip_out.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim) + return ip_out + @torch.no_grad() def patch_forward(self, number: int, transformer_index: int): @torch.no_grad() @@ -220,27 +291,15 @@ def forward(attn_blk, x, q): k_key = f"{number * 2 + 1}_to_k_ip" v_key = f"{number * 2 + 1}_to_v_ip" - cond_uncond_image_emb = self.image_emb.eval(current_model.cond_mark) - ip_k = self.call_ip(k_key, cond_uncond_image_emb, device=q.device) - ip_v = self.call_ip(v_key, cond_uncond_image_emb, device=q.device) - - ip_k, ip_v = map( - lambda t: t.view(batch_size, -1, h, head_dim).transpose(1, 2), - (ip_k, ip_v), + ip_out = self.attn_eval( + hidden_states=x, + query=q, + cond_uncond_image_emb=self.image_emb.eval(current_model.cond_mark), + attn_heads=h, + head_dim=head_dim, + emb_to_k=lambda emb: self.call_ip(k_key, emb, device=q.device), + emb_to_v=lambda emb: self.call_ip(v_key, emb, device=q.device), ) - assert ip_k.dtype == ip_v.dtype - - # On MacOS, q can be float16 instead of float32. - # https://github.com/Mikubill/sd-webui-controlnet/issues/2208 - if q.dtype != ip_k.dtype: - ip_k = ip_k.to(dtype=q.dtype) - ip_v = ip_v.to(dtype=q.dtype) - - ip_out = torch.nn.functional.scaled_dot_product_attention( - q, ip_k, ip_v, attn_mask=None, dropout_p=0.0, is_causal=False - ) - ip_out = ip_out.transpose(1, 2).reshape(batch_size, -1, h * head_dim) - return self.apply_effective_region_mask(ip_out * weight) return forward diff --git a/scripts/ipadapter/presets.py b/scripts/ipadapter/presets.py index 275f70be5..764c83c98 100644 --- a/scripts/ipadapter/presets.py +++ b/scripts/ipadapter/presets.py @@ -166,6 +166,12 @@ def match_model(model_name: str) -> IPAdapterPreset: model="ip-adapter-faceid-portrait_sdxl", sd_version=StableDiffusionVersion.SDXL, ), + IPAdapterPreset( + name="pulid", + module="ip-adapter_pulid", + model="ip-adapter_pulid_sdxl_fp16", + sd_version=StableDiffusionVersion.SDXL, + ), ] _preset_by_model = {p.model: p for p in ipadapter_presets} diff --git a/scripts/ipadapter/pulid_attn.py b/scripts/ipadapter/pulid_attn.py new file mode 100644 index 000000000..e2823470c --- /dev/null +++ b/scripts/ipadapter/pulid_attn.py @@ -0,0 +1,94 @@ +import torch +import torch.nn.functional as F +from dataclasses import dataclass +from typing import Callable + + +@dataclass +class PuLIDAttnSetting: + num_zero: int = 0 + ortho: bool = False + ortho_v2: bool = False + + def eval( + self, + hidden_states: torch.Tensor, + query: torch.Tensor, + id_embedding: torch.Tensor, + attn_heads: int, + head_dim: int, + id_to_k: Callable[[torch.Tensor], torch.Tensor], + id_to_v: Callable[[torch.Tensor], torch.Tensor], + ): + assert hidden_states.ndim == 3 + batch_size, sequence_length, inner_dim = hidden_states.shape + + if self.num_zero == 0: + id_key = id_to_k(id_embedding).to(query.dtype) + id_value = id_to_v(id_embedding).to(query.dtype) + else: + zero_tensor = torch.zeros( + (id_embedding.size(0), self.num_zero, id_embedding.size(-1)), + dtype=id_embedding.dtype, + device=id_embedding.device, + ) + id_key = id_to_k(torch.cat((id_embedding, zero_tensor), dim=1)).to( + query.dtype + ) + id_value = id_to_v(torch.cat((id_embedding, zero_tensor), dim=1)).to( + 
query.dtype + ) + + id_key = id_key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2) + id_value = id_value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + id_hidden_states = F.scaled_dot_product_attention( + query, id_key, id_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + id_hidden_states = id_hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn_heads * head_dim + ) + id_hidden_states = id_hidden_states.to(query.dtype) + + if not self.ortho and not self.ortho_v2: + return id_hidden_states + elif self.ortho_v2: + orig_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + id_hidden_states = id_hidden_states.to(torch.float32) + attn_map = query @ id_key.transpose(-2, -1) + attn_mean = attn_map.softmax(dim=-1).mean(dim=1) + attn_mean = attn_mean[:, :, :5].sum(dim=-1, keepdim=True) + projection = ( + torch.sum((hidden_states * id_hidden_states), dim=-2, keepdim=True) + / torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True) + * hidden_states + ) + orthogonal = id_hidden_states + (attn_mean - 1) * projection + return orthogonal.to(orig_dtype) + else: + orig_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + id_hidden_states = id_hidden_states.to(torch.float32) + projection = ( + torch.sum((hidden_states * id_hidden_states), dim=-2, keepdim=True) + / torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True) + * hidden_states + ) + orthogonal = id_hidden_states - projection + return orthogonal.to(orig_dtype) + + +PULID_SETTING_FIDELITY = PuLIDAttnSetting( + num_zero=8, + ortho=False, + ortho_v2=True, +) + +PULID_SETTING_STYLE = PuLIDAttnSetting( + num_zero=16, + ortho=True, + ortho_v2=False, +) diff --git a/scripts/preprocessor/__init__.py b/scripts/preprocessor/__init__.py index 6bbcb762f..b330e73ce 100644 --- a/scripts/preprocessor/__init__.py +++ b/scripts/preprocessor/__init__.py @@ -1,3 +1,4 @@ +from .pulid import * from .inpaint import * from .lama_inpaint import * from .ip_adapter_auto import * diff --git a/scripts/preprocessor/inpaint.py b/scripts/preprocessor/inpaint.py index 4605dc252..25874196a 100644 --- a/scripts/preprocessor/inpaint.py +++ b/scripts/preprocessor/inpaint.py @@ -1,18 +1,7 @@ -import numpy as np - +from scripts.utils import visualize_inpaint_mask from ..supported_preprocessor import Preprocessor, PreprocessorParameter -def visualize_inpaint_mask(img): - if img.ndim == 3 and img.shape[2] == 4: - result = img.copy() - mask = result[:, :, 3] - mask = 255 - mask // 2 - result[:, :, 3] = mask - return np.ascontiguousarray(result.copy()) - return img - - class PreprocessorInpaint(Preprocessor): def __init__(self): super().__init__(name="inpaint") @@ -23,9 +12,6 @@ def __init__(self): self.accepts_mask = True self.requires_mask = True - def get_display_image(self, input_image: np.ndarray, result): - return visualize_inpaint_mask(result) - def __call__( self, input_image, @@ -35,7 +21,10 @@ def __call__( slider_3=None, **kwargs ): - return input_image + return Preprocessor.Result( + value=input_image, + display_images=visualize_inpaint_mask(input_image)[None, :, :, :], + ) class PreprocessorInpaintOnly(Preprocessor): @@ -47,9 +36,6 @@ def __init__(self): self.accepts_mask = True self.requires_mask = True - def get_display_image(self, input_image: np.ndarray, result): - return visualize_inpaint_mask(result) - def __call__( self, input_image, @@ -59,7 +45,10 @@ def __call__( slider_3=None, **kwargs 
): - return input_image + return Preprocessor.Result( + value=input_image, + display_images=visualize_inpaint_mask(input_image)[None, :, :, :], + ) Preprocessor.add_supported_preprocessor(PreprocessorInpaint()) diff --git a/scripts/preprocessor/lama_inpaint.py b/scripts/preprocessor/lama_inpaint.py index 33aff60bf..1cd1c521c 100644 --- a/scripts/preprocessor/lama_inpaint.py +++ b/scripts/preprocessor/lama_inpaint.py @@ -2,7 +2,7 @@ import numpy as np from ..supported_preprocessor import Preprocessor, PreprocessorParameter -from ..utils import resize_image_with_pad +from ..utils import resize_image_with_pad, visualize_inpaint_mask class PreprocessorLamaInpaint(Preprocessor): @@ -15,12 +15,6 @@ def __init__(self): self.accepts_mask = True self.requires_mask = True - def get_display_image(self, input_image: np.ndarray, result: np.ndarray): - """For lama inpaint, display image should not contain mask.""" - assert result.ndim == 3 - assert result.shape[2] == 4 - return result[:, :, :3] - def __call__( self, input_image, @@ -56,7 +50,13 @@ def __call__( fin_color = fin_color.clip(0, 255).astype(np.uint8) result = np.concatenate([fin_color, raw_mask], axis=2) - return result + return Preprocessor.Result( + value=result, + display_images=[ + result[:, :, :3], + visualize_inpaint_mask(result), + ], + ) Preprocessor.add_supported_preprocessor(PreprocessorLamaInpaint()) diff --git a/scripts/preprocessor/legacy/legacy_preprocessors.py b/scripts/preprocessor/legacy/legacy_preprocessors.py index 902e6c9d0..7c5e1c873 100644 --- a/scripts/preprocessor/legacy/legacy_preprocessors.py +++ b/scripts/preprocessor/legacy/legacy_preprocessors.py @@ -93,7 +93,7 @@ def unload(self): def __call__( self, input_image, - resolution, + resolution=512, slider_1=None, slider_2=None, slider_3=None, diff --git a/scripts/preprocessor/pulid.py b/scripts/preprocessor/pulid.py new file mode 100644 index 000000000..a46f91290 --- /dev/null +++ b/scripts/preprocessor/pulid.py @@ -0,0 +1,169 @@ +# https://github.com/ToTheBeginning/PuLID + +import torch +import cv2 +import numpy as np +from typing import Optional, List +from dataclasses import dataclass +from facexlib.parsing import init_parsing_model +from facexlib.utils.face_restoration_helper import FaceRestoreHelper +from torchvision.transforms.functional import normalize + +from ..supported_preprocessor import Preprocessor, PreprocessorParameter +from scripts.utils import npimg2tensor, tensor2npimg, resize_image_with_pad + + +def to_gray(img): + x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3] + x = x.repeat(1, 3, 1, 1) + return x + + +class PreprocessorFaceXLib(Preprocessor): + def __init__(self): + super().__init__(name="facexlib") + self.tags = [] + self.slider_resolution = PreprocessorParameter(visible=False) + self.model: Optional[FaceRestoreHelper] = None + + def load_model(self): + if self.model is None: + self.model = FaceRestoreHelper( + upscale_factor=1, + face_size=512, + crop_ratio=(1, 1), + det_model="retinaface_resnet50", + save_ext="png", + device=self.device, + ) + self.model.face_parse = init_parsing_model( + model_name="bisenet", device=self.device + ) + self.model.face_parse.to(device=self.device) + self.model.face_det.to(device=self.device) + return self.model + + def unload(self) -> bool: + """@Override""" + if self.model is not None: + self.model.face_parse.to(device="cpu") + self.model.face_det.to(device="cpu") + return True + return False + + def __call__( + self, + input_image, + resolution=512, + slider_1=None, + slider_2=None, + 
slider_3=None, + input_mask=None, + return_tensor=False, + **kwargs + ): + """ + @Override + Returns black and white face features image with background removed. + """ + self.load_model() + self.model.clean_all() + input_image, _ = resize_image_with_pad(input_image, resolution) + # using facexlib to detect and align face + image_bgr = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR) + self.model.read_image(image_bgr) + self.model.get_face_landmarks_5(only_center_face=True) + self.model.align_warp_face() + if len(self.model.cropped_faces) == 0: + raise RuntimeError("facexlib align face fail") + align_face = self.model.cropped_faces[0] + align_face_rgb = cv2.cvtColor(align_face, cv2.COLOR_BGR2RGB) + input = npimg2tensor(align_face_rgb) + input = input.to(self.device) + parsing_out = self.model.face_parse( + normalize(input, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + )[0] + parsing_out = parsing_out.argmax(dim=1, keepdim=True) + bg_label = [0, 16, 18, 7, 8, 9, 14, 15] + bg = sum(parsing_out == i for i in bg_label).bool() + white_image = torch.ones_like(input) + # only keep the face features + face_features_image = torch.where(bg, white_image, to_gray(input)) + if return_tensor: + return face_features_image + else: + return tensor2npimg(face_features_image) + + +@dataclass +class PuLIDProjInput: + id_ante_embedding: torch.Tensor + id_cond_vit: torch.Tensor + id_vit_hidden: List[torch.Tensor] + + +class PreprocessorPuLID(Preprocessor): + """PuLID preprocessor.""" + + def __init__(self): + super().__init__(name="ip-adapter_pulid") + self.tags = ["IP-Adapter"] + self.slider_resolution = PreprocessorParameter(visible=False) + self.returns_image = False + self.preprocessors_deps = [ + "facexlib", + "instant_id_face_embedding", + "EVA02-CLIP-L-14-336", + ] + + def facexlib_detect(self, input_image: np.ndarray) -> torch.Tensor: + facexlib_preprocessor = Preprocessor.get_preprocessor("facexlib") + return facexlib_preprocessor(input_image, return_tensor=True) + + def insightface_antelopev2_detect(self, input_image: np.ndarray) -> torch.Tensor: + antelopev2_preprocessor = Preprocessor.get_preprocessor( + "instant_id_face_embedding" + ) + return antelopev2_preprocessor(input_image) + + def unload(self) -> bool: + unloaded = False + for p_name in self.preprocessors_deps: + p = Preprocessor.get_preprocessor(p_name) + if p is not None: + unloaded = unloaded or p.unload() + return unloaded + + def __call__( + self, + input_image, + resolution, + slider_1=None, + slider_2=None, + slider_3=None, + input_mask=None, + **kwargs + ) -> Preprocessor.Result: + id_ante_embedding = self.insightface_antelopev2_detect(input_image) + if id_ante_embedding.ndim == 1: + id_ante_embedding = id_ante_embedding.unsqueeze(0) + + face_features_image = self.facexlib_detect(input_image) + evaclip_preprocessor = Preprocessor.get_preprocessor("EVA02-CLIP-L-14-336") + assert ( + evaclip_preprocessor is not None + ), "EVA02-CLIP-L-14-336 preprocessor not found! 
Please install sd-webui-controlnet-evaclip" + r = evaclip_preprocessor(face_features_image) + + return Preprocessor.Result( + value=PuLIDProjInput( + id_ante_embedding=id_ante_embedding, + id_cond_vit=r.id_cond_vit, + id_vit_hidden=r.id_vit_hidden, + ), + display_images=[tensor2npimg(face_features_image)], + ) + + +Preprocessor.add_supported_preprocessor(PreprocessorFaceXLib()) +Preprocessor.add_supported_preprocessor(PreprocessorPuLID()) diff --git a/scripts/supported_preprocessor.py b/scripts/supported_preprocessor.py index caf5a6a78..473d6203c 100644 --- a/scripts/supported_preprocessor.py +++ b/scripts/supported_preprocessor.py @@ -4,7 +4,7 @@ import numpy as np import torch -from modules import shared +from modules import shared, devices from scripts.logging import logger from scripts.utils import ndarray_lru_cache @@ -101,6 +101,7 @@ class Preprocessor(ABC): use_soft_projection_in_hr_fix = False expand_mask_when_resize_and_fill = False model: Optional[torch.nn.Module] = None + device = devices.get_device_for("controlnet") all_processors: ClassVar[Dict[str, "Preprocessor"]] = {} all_processors_by_name: ClassVar[Dict[str, "Preprocessor"]] = {} @@ -183,18 +184,19 @@ def unload_unused(cls, active_processors: Set["Preprocessor"]): class Result(NamedTuple): value: Any - # The display image shown on UI. - display_image: np.ndarray - - def get_display_image(self, input_image: np.ndarray, result): - return result if self.returns_image else input_image + # The display images shown on UI. + display_images: List[np.ndarray] def cached_call(self, input_image, *args, **kwargs) -> "Preprocessor.Result": """The function exposed that also returns an image for display.""" result = self._cached_call(input_image, *args, **kwargs) - return Preprocessor.Result( - value=result, display_image=self.get_display_image(input_image, result) - ) + if isinstance(result, Preprocessor.Result): + return result + else: + return Preprocessor.Result( + value=result, + display_images=[result if self.returns_image else input_image], + ) @ndarray_lru_cache(max_size=CACHE_SIZE) def _cached_call(self, *args, **kwargs): diff --git a/scripts/utils.py b/scripts/utils.py index c26750f14..e660279a9 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -1,3 +1,4 @@ +from einops import rearrange import torch import os import functools @@ -105,8 +106,9 @@ def wrapper(*args, **kwargs): class TimeMeta(type): - """ Metaclass to record execution time on all methods of the - child class. """ + """Metaclass to record execution time on all methods of the + child class.""" + def __new__(cls, name, bases, attrs): for attr_name, attr_value in attrs.items(): if callable(attr_value): @@ -161,7 +163,9 @@ def read_image(img_path: str) -> str: return encoded_image -def read_image_dir(img_dir: str, suffixes=('.png', '.jpg', '.jpeg', '.webp')) -> List[str]: +def read_image_dir( + img_dir: str, suffixes=(".png", ".jpg", ".jpeg", ".webp") +) -> List[str]: """Try read all images in given img_dir.""" images = [] for filename in os.listdir(img_dir): @@ -175,7 +179,7 @@ def read_image_dir(img_dir: str, suffixes=('.png', '.jpg', '.jpeg', '.webp')) -> def align_dim_latent(x: int) -> int: - """ Align the pixel dimension (w/h) to latent dimension. + """Align the pixel dimension (w/h) to latent dimension. 
Stable diffusion 1:8 ratio for latent/pixel, i.e., 1 latent unit == 8 pixel unit.""" return (x // 8) * 8 @@ -203,9 +207,34 @@ def resize_image_with_pad(img: np.ndarray, resolution: int): W_target = int(np.round(float(W_raw) * k)) img = cv2.resize(img, (W_target, H_target), interpolation=interpolation) H_pad, W_pad = pad64(H_target), pad64(W_target) - img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode='edge') + img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode="edge") def remove_pad(x): return safer_memory(x[:H_target, :W_target]) - return safer_memory(img_padded), remove_pad \ No newline at end of file + return safer_memory(img_padded), remove_pad + + +def npimg2tensor(img: np.ndarray) -> torch.Tensor: + """Convert numpy img ([H, W, C]) to tensor ([1, C, H, W])""" + return rearrange(torch.from_numpy(img).float() / 255.0, "h w c -> 1 c h w") + + +def tensor2npimg(t: torch.Tensor) -> np.ndarray: + """Convert tensor ([1, C, H, W]) to numpy RGB img ([H, W, C])""" + return ( + (rearrange(t, "1 c h w -> h w c") * 255.0) + .to(dtype=torch.uint8) + .cpu() + .numpy() + ) + + +def visualize_inpaint_mask(img): + if img.ndim == 3 and img.shape[2] == 4: + result = img.copy() + mask = result[:, :, 3] + mask = 255 - mask // 2 + result[:, :, 3] = mask + return np.ascontiguousarray(result.copy()) + return img
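To make the conventions of the new helpers in scripts/utils.py concrete, here is a small self-contained round trip; the helpers are re-defined locally so the snippet runs outside the webui, and the white-square image is just an illustrative input. Note also that Preprocessor.Result now carries display_images (a list) instead of display_image, which callers such as scripts/api.py and the UI preview index with [0].

import numpy as np
import torch
from einops import rearrange


def npimg2tensor(img: np.ndarray) -> torch.Tensor:
    # [H, W, C] uint8 in [0, 255] -> [1, C, H, W] float in [0, 1]
    return rearrange(torch.from_numpy(img).float() / 255.0, "h w c -> 1 c h w")


def tensor2npimg(t: torch.Tensor) -> np.ndarray:
    # [1, C, H, W] float in [0, 1] -> [H, W, C] uint8 in [0, 255]
    return (rearrange(t, "1 c h w -> h w c") * 255.0).to(dtype=torch.uint8).cpu().numpy()


img = np.zeros((64, 64, 3), dtype=np.uint8)
img[16:48, 16:48] = 255  # white square on a black background

t = npimg2tensor(img)
assert t.shape == (1, 3, 64, 64)
assert float(t.min()) == 0.0 and float(t.max()) == 1.0

restored = tensor2npimg(t)
assert restored.shape == (64, 64, 3)
assert np.array_equal(restored, img)  # exact round trip for pure 0/255 inputs

In the patch itself, PreprocessorFaceXLib uses npimg2tensor to hand the aligned face crop to the face-parsing network, and PreprocessorPuLID uses tensor2npimg to turn the resulting face-features tensor back into the numpy image it exposes in display_images.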