Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ControlNet batch #2909

Merged
merged 1 commit into from
May 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions scripts/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,10 @@ def __init__(self) -> None:
self.post_processors = []
self.noise_modifier = None
self.ui_batch_option_state = [BatchOption.DEFAULT.value, False]
# `Script` instance is created twice, once for img2img and once for txt2img.
# However, A1111 does not pass is_img2img to script constructor, so this field
# is initialized in `ui` method.
self.is_img2img = False
batch_hijack.instance.process_batch_callbacks.append(self.batch_tab_process)
batch_hijack.instance.process_batch_each_callbacks.append(self.batch_tab_process_each)
batch_hijack.instance.postprocess_batch_each_callbacks.insert(0, self.batch_tab_postprocess_each)
Expand Down Expand Up @@ -366,6 +370,8 @@ def ui(self, is_img2img):
The return value should be an array of all components that are used in processing.
Values of those returned components will be passed to run() and process() functions.
"""
self.is_img2img = is_img2img

infotext = Infotext()
ui_groups = []
controls = []
Expand Down Expand Up @@ -1296,18 +1302,30 @@ def postprocess(self, p, processed, *args):
tracemalloc.stop()

def batch_tab_process(self, p, batches, *args, **kwargs):
is_img2img = isinstance(p, StableDiffusionProcessingImg2Img)
if is_img2img != self.is_img2img:
return

self.enabled_units = Script.get_enabled_units(p)
for unit_i, unit in enumerate(self.enabled_units):
unit.batch_images = iter([batch[unit_i] for batch in batches])

def batch_tab_process_each(self, p, *args, **kwargs):
is_img2img = isinstance(p, StableDiffusionProcessingImg2Img)
if is_img2img != self.is_img2img:
return

for unit in self.enabled_units:
if getattr(unit, 'loopback', False) and batch_hijack.instance.batch_index > 0:
continue

unit.image = next(unit.batch_images)

def batch_tab_postprocess_each(self, p, processed, *args, **kwargs):
is_img2img = isinstance(p, StableDiffusionProcessingImg2Img)
if is_img2img != self.is_img2img:
return

for unit_i, unit in enumerate(self.enabled_units):
if getattr(unit, 'loopback', False):
output_images = getattr(processed, 'images', [])[processed.index_of_first_image:]
Expand All @@ -1317,6 +1335,10 @@ def batch_tab_postprocess_each(self, p, processed, *args, **kwargs):
logger.warning(f'Warning: No loopback image found for controlnet unit {unit_i}. Using control map from last batch iteration instead')

def batch_tab_postprocess(self, p, *args, **kwargs):
is_img2img = isinstance(p, StableDiffusionProcessingImg2Img)
if is_img2img != self.is_img2img:
return

self.enabled_units.clear()
self.input_image = None
if self.latest_network is None:
Expand Down
Loading