Skip to content

Commit

Permalink
Add PuLID attn
Browse files Browse the repository at this point in the history
  • Loading branch information
huchenlei committed May 4, 2024
1 parent 9509e77 commit 67e2369
Show file tree
Hide file tree
Showing 6 changed files with 224 additions and 29 deletions.
7 changes: 6 additions & 1 deletion internal_controlnet/external_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from modules.safe import unsafe_torch_load
from scripts import global_state
from scripts.logging import logger
from scripts.enums import HiResFixOption
from scripts.enums import HiResFixOption, PuLIDMode
from scripts.supported_preprocessor import Preprocessor, PreprocessorParameter

from modules.api import api
Expand Down Expand Up @@ -207,6 +207,10 @@ class ControlNetUnit:
# The effective region mask that unit's effect should be restricted to.
effective_region_mask: Optional[np.ndarray] = None

# The weight mode for PuLID.
# https://github.com/ToTheBeginning/PuLID
pulid_mode: PuLIDMode = PuLIDMode.FIDELITY

# The tensor input for ipadapter. When this field is set in the API,
# the base64string will be interpret by torch.load to reconstruct ipadapter
# preprocessor output.
Expand Down Expand Up @@ -243,6 +247,7 @@ def infotext_excluded_fields() -> List[str]:
# provide much information when restoring the unit.
"inpaint_crop_input_image",
"effective_region_mask",
"pulid_mode",
]

@property
Expand Down
18 changes: 14 additions & 4 deletions scripts/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
import scripts.preprocessor as preprocessor_init # noqa
from annotator.util import HWC3
from scripts import global_state, hook, external_code, batch_hijack, controlnet_version, utils
from internal_controlnet.external_code import ControlMode
from scripts.controlnet_lora import bind_control_lora, unbind_control_lora
from scripts.controlnet_lllite import clear_all_lllite
from scripts.ipadapter.plugable_ipadapter import ImageEmbed, clear_all_ip_adapter
from scripts.ipadapter.pulid_attn import PULID_SETTING_FIDELITY, PULID_SETTING_STYLE
from scripts.utils import load_state_dict, get_unique_axis0, align_dim_latent
from scripts.hook import ControlParams, UnetHook, HackedImageRNG
from scripts.enums import ControlModelType, StableDiffusionVersion, HiResFixOption
from scripts.enums import ControlModelType, StableDiffusionVersion, HiResFixOption, PuLIDMode
from scripts.controlnet_ui.controlnet_ui_group import ControlNetUiGroup, UiControlNetUnit
from scripts.controlnet_ui.photopea import Photopea
from scripts.logging import logger
Expand Down Expand Up @@ -1094,8 +1096,8 @@ def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe
global_average_pooling=global_average_pooling,
hr_hint_cond=hr_control,
hr_option=HiResFixOption.from_value(unit.hr_option) if high_res_fix else HiResFixOption.BOTH,
soft_injection=control_mode != external_code.ControlMode.BALANCED,
cfg_injection=control_mode == external_code.ControlMode.CONTROL,
soft_injection=control_mode != ControlMode.BALANCED,
cfg_injection=control_mode == ControlMode.CONTROL,
effective_region_mask=(
get_pytorch_control(unit.effective_region_mask)[:, 0:1, :, :]
if unit.effective_region_mask is not None
Expand Down Expand Up @@ -1192,7 +1194,7 @@ def recolor_intensity_post_processing(x, i):

is_low_vram = any(unit.low_vram for unit in self.enabled_units)

for i, param in enumerate(forward_params):
for i, (param, unit) in enumerate(zip(forward_params, self.enabled_units)):
if param.control_model_type == ControlModelType.IPAdapter:
if param.advanced_weighting is not None:
logger.info(f"IP-Adapter using advanced weighting {param.advanced_weighting}")
Expand All @@ -1207,6 +1209,13 @@ def recolor_intensity_post_processing(x, i):
weight = param.weight

h, w, hr_y, hr_x = Script.get_target_dimensions(p)
pulid_mode = PuLIDMode(unit.pulid_mode)
if pulid_mode == PuLIDMode.STYLE:
pulid_attn_setting = PULID_SETTING_STYLE
else:
assert pulid_mode == PuLIDMode.FIDELITY
pulid_attn_setting = PULID_SETTING_FIDELITY

param.control_model.hook(
model=unet,
preprocessor_outputs=param.hint_cond,
Expand All @@ -1217,6 +1226,7 @@ def recolor_intensity_post_processing(x, i):
latent_width=w // 8,
latent_height=h // 8,
effective_region_mask=param.effective_region_mask,
pulid_attn_setting=pulid_attn_setting,
)
if param.control_model_type == ControlModelType.Controlllite:
param.control_model.hook(
Expand Down
22 changes: 21 additions & 1 deletion scripts/controlnet_ui/controlnet_ui_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from scripts.controlnet_ui.preset import ControlNetPresetUI
from scripts.controlnet_ui.photopea import Photopea
from scripts.controlnet_ui.advanced_weight_control import AdvancedWeightControl
from scripts.enums import InputMode
from scripts.enums import InputMode, PuLIDMode
from modules import shared
from modules.ui_components import FormRow, FormHTML, ToolButton

Expand Down Expand Up @@ -287,6 +287,7 @@ def __init__(
self.batch_image_dir_state = None
self.output_dir_state = None
self.advanced_weighting = gr.State(None)
self.pulid_mode = None

# API-only fields
self.ipadapter_input = gr.State(None)
Expand Down Expand Up @@ -626,6 +627,15 @@ def render(self, tabname: str, elem_id_tabname: str) -> None:
visible=False,
)

self.pulid_mode = gr.Radio(
choices=[e.value for e in PuLIDMode],
value=self.default_unit.pulid_mode.value,
label="PuLID Mode",
elem_id=f"{elem_id_tabname}_{tabname}_controlnet_pulid_mode_radio",
elem_classes="controlnet_pulid_mode_radio",
visible=False,
)

self.loopback = gr.Checkbox(
label="[Batch Loopback] Automatically send generated images to this ControlNet unit in batch generation",
value=self.default_unit.loopback,
Expand Down Expand Up @@ -673,6 +683,7 @@ def render(self, tabname: str, elem_id_tabname: str) -> None:
self.save_detected_map,
self.advanced_weighting,
self.effective_region_mask,
self.pulid_mode,
)

unit = gr.State(self.default_unit)
Expand Down Expand Up @@ -1118,6 +1129,14 @@ def register_shift_upload_mask(self):
show_progress=False,
)

def register_shift_pulid_mode(self):
self.model.change(
fn=lambda model: gr.update(visible="pulid" in model.lower()),
inputs=[self.model],
outputs=[self.pulid_mode],
show_progress=False,
)

def register_sync_batch_dir(self):
def determine_batch_dir(batch_dir, fallback_dir, fallback_fallback_dir):
if batch_dir:
Expand Down Expand Up @@ -1220,6 +1239,7 @@ def register_core_callbacks(self):
self.register_build_sliders()
self.register_shift_preview()
self.register_shift_upload_mask()
self.register_shift_pulid_mode()
self.register_create_canvas()
self.register_clear_preview()
self.register_multi_images_upload()
Expand Down
7 changes: 7 additions & 0 deletions scripts/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from typing import Any, List, NamedTuple
from functools import lru_cache

from colorama import Style


class UnetBlockType(Enum):
INPUT = "input"
Expand Down Expand Up @@ -247,3 +249,8 @@ class InputMode(Enum):
# Input is a directory. 1 generation. Each generation takes N input image
# from the directory.
MERGE = "merge"


class PuLIDMode(Enum):
FIDELITY = "Fidelity"
STYLE = "Extremely style"
105 changes: 82 additions & 23 deletions scripts/ipadapter/plugable_ipadapter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import itertools
import torch
import math
from typing import Union, Dict, Optional
from typing import Union, Dict, Optional, Callable

from .pulid_attn import PuLIDAttnSetting
from .ipadapter_model import ImageEmbed, IPAdapterModel
from ..enums import StableDiffusionVersion, TransformerID

Expand Down Expand Up @@ -93,7 +94,7 @@ def clear_all_ip_adapter():
class PlugableIPAdapter(torch.nn.Module):
def __init__(self, ipadapter: IPAdapterModel):
super().__init__()
self.ipadapter = ipadapter
self.ipadapter: IPAdapterModel = ipadapter
self.disable_memory_management = True
self.dtype = None
self.weight: Union[float, Dict[int, float]] = 1.0
Expand All @@ -103,6 +104,7 @@ def __init__(self, ipadapter: IPAdapterModel):
self.latent_width: int = 0
self.latent_height: int = 0
self.effective_region_mask = None
self.pulid_attn_setting: Optional[PuLIDAttnSetting] = None

def reset(self):
self.cache = {}
Expand All @@ -118,6 +120,7 @@ def hook(
latent_width: int,
latent_height: int,
effective_region_mask: Optional[torch.Tensor],
pulid_attn_setting: Optional[PuLIDAttnSetting] = None,
dtype=torch.float32,
):
global current_model
Expand All @@ -128,6 +131,7 @@ def hook(
self.latent_width = latent_width
self.latent_height = latent_height
self.effective_region_mask = effective_region_mask
self.pulid_attn_setting = pulid_attn_setting

self.cache = {}

Expand Down Expand Up @@ -186,7 +190,9 @@ def apply_effective_region_mask(self, out: torch.Tensor) -> torch.Tensor:
# sequence_length = (latent_height * factor) * (latent_height * factor)
# sequence_length = (latent_height * latent_height) * factor ^ 2
factor = math.sqrt(sequence_length / (self.latent_width * self.latent_height))
assert factor > 0, f"{factor}, {sequence_length}, {self.latent_width}, {self.latent_height}"
assert (
factor > 0
), f"{factor}, {sequence_length}, {self.latent_width}, {self.latent_height}"
mask_h = int(self.latent_height * factor)
mask_w = int(self.latent_width * factor)

Expand All @@ -199,6 +205,71 @@ def apply_effective_region_mask(self, out: torch.Tensor) -> torch.Tensor:
mask = mask.view(mask.shape[0], -1, 1).repeat(1, 1, out.shape[2])
return out * mask

def attn_eval(
self,
hidden_states: torch.Tensor,
query: torch.Tensor,
cond_uncond_image_emb: torch.Tensor,
attn_heads: int,
head_dim: int,
emb_to_k: Callable[[torch.Tensor], torch.Tensor],
emb_to_v: Callable[[torch.Tensor], torch.Tensor],
):
if self.ipadapter.is_pulid:
assert self.pulid_attn_setting is not None
return self.pulid_attn_setting.eval(
hidden_states,
query,
cond_uncond_image_emb,
attn_heads,
head_dim,
emb_to_k,
emb_to_v,
)
else:
return self._attn_eval_ipadapter(
hidden_states,
query,
cond_uncond_image_emb,
attn_heads,
head_dim,
emb_to_k,
emb_to_v,
)

def _attn_eval_ipadapter(
self,
hidden_states: torch.Tensor,
query: torch.Tensor,
cond_uncond_image_emb: torch.Tensor,
attn_heads: int,
head_dim: int,
emb_to_k: Callable[[torch.Tensor], torch.Tensor],
emb_to_v: Callable[[torch.Tensor], torch.Tensor],
):
assert hidden_states.ndim == 3
batch_size, sequence_length, inner_dim = hidden_states.shape
ip_k = emb_to_k(cond_uncond_image_emb)
ip_v = emb_to_v(cond_uncond_image_emb)

ip_k, ip_v = map(
lambda t: t.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2),
(ip_k, ip_v),
)
assert ip_k.dtype == ip_v.dtype

# On MacOS, q can be float16 instead of float32.
# https://github.com/Mikubill/sd-webui-controlnet/issues/2208
if query.dtype != ip_k.dtype:
ip_k = ip_k.to(dtype=query.dtype)
ip_v = ip_v.to(dtype=query.dtype)

ip_out = torch.nn.functional.scaled_dot_product_attention(
query, ip_k, ip_v, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_out = ip_out.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim)
return ip_out

@torch.no_grad()
def patch_forward(self, number: int, transformer_index: int):
@torch.no_grad()
Expand All @@ -220,27 +291,15 @@ def forward(attn_blk, x, q):

k_key = f"{number * 2 + 1}_to_k_ip"
v_key = f"{number * 2 + 1}_to_v_ip"
cond_uncond_image_emb = self.image_emb.eval(current_model.cond_mark)
ip_k = self.call_ip(k_key, cond_uncond_image_emb, device=q.device)
ip_v = self.call_ip(v_key, cond_uncond_image_emb, device=q.device)

ip_k, ip_v = map(
lambda t: t.view(batch_size, -1, h, head_dim).transpose(1, 2),
(ip_k, ip_v),
ip_out = self.attn_eval(
hidden_states=x,
query=q,
cond_uncond_image_emb=self.image_emb.eval(current_model.cond_mark),
attn_heads=h,
head_dim=head_dim,
emb_to_k=lambda emb: self.call_ip(k_key, emb, device=q.device),
emb_to_v=lambda emb: self.call_ip(v_key, emb, device=q.device),
)
assert ip_k.dtype == ip_v.dtype

# On MacOS, q can be float16 instead of float32.
# https://github.com/Mikubill/sd-webui-controlnet/issues/2208
if q.dtype != ip_k.dtype:
ip_k = ip_k.to(dtype=q.dtype)
ip_v = ip_v.to(dtype=q.dtype)

ip_out = torch.nn.functional.scaled_dot_product_attention(
q, ip_k, ip_v, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_out = ip_out.transpose(1, 2).reshape(batch_size, -1, h * head_dim)

return self.apply_effective_region_mask(ip_out * weight)

return forward
Loading

0 comments on commit 67e2369

Please sign in to comment.