Skip to content

Commit

Permalink
Fix ControlNet batch (#2909)
Browse files Browse the repository at this point in the history
  • Loading branch information
huchenlei authored May 19, 2024
1 parent cbc7ef3 commit a5c0da5
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions scripts/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,10 @@ def __init__(self) -> None:
self.post_processors = []
self.noise_modifier = None
self.ui_batch_option_state = [BatchOption.DEFAULT.value, False]
# `Script` instance is created twice, once for img2img and once for txt2img.
# However, A1111 does not pass is_img2img to script constructor, so this field
# is initialized in `ui` method.
self.is_img2img = False
batch_hijack.instance.process_batch_callbacks.append(self.batch_tab_process)
batch_hijack.instance.process_batch_each_callbacks.append(self.batch_tab_process_each)
batch_hijack.instance.postprocess_batch_each_callbacks.insert(0, self.batch_tab_postprocess_each)
Expand Down Expand Up @@ -366,6 +370,8 @@ def ui(self, is_img2img):
The return value should be an array of all components that are used in processing.
Values of those returned components will be passed to run() and process() functions.
"""
self.is_img2img = is_img2img

infotext = Infotext()
ui_groups = []
controls = []
Expand Down Expand Up @@ -1296,18 +1302,30 @@ def postprocess(self, p, processed, *args):
tracemalloc.stop()

def batch_tab_process(self, p, batches, *args, **kwargs):
is_img2img = isinstance(p, StableDiffusionProcessingImg2Img)
if is_img2img != self.is_img2img:
return

self.enabled_units = Script.get_enabled_units(p)
for unit_i, unit in enumerate(self.enabled_units):
unit.batch_images = iter([batch[unit_i] for batch in batches])

def batch_tab_process_each(self, p, *args, **kwargs):
is_img2img = isinstance(p, StableDiffusionProcessingImg2Img)
if is_img2img != self.is_img2img:
return

for unit in self.enabled_units:
if getattr(unit, 'loopback', False) and batch_hijack.instance.batch_index > 0:
continue

unit.image = next(unit.batch_images)

def batch_tab_postprocess_each(self, p, processed, *args, **kwargs):
is_img2img = isinstance(p, StableDiffusionProcessingImg2Img)
if is_img2img != self.is_img2img:
return

for unit_i, unit in enumerate(self.enabled_units):
if getattr(unit, 'loopback', False):
output_images = getattr(processed, 'images', [])[processed.index_of_first_image:]
Expand All @@ -1317,6 +1335,10 @@ def batch_tab_postprocess_each(self, p, processed, *args, **kwargs):
logger.warning(f'Warning: No loopback image found for controlnet unit {unit_i}. Using control map from last batch iteration instead')

def batch_tab_postprocess(self, p, *args, **kwargs):
is_img2img = isinstance(p, StableDiffusionProcessingImg2Img)
if is_img2img != self.is_img2img:
return

self.enabled_units.clear()
self.input_image = None
if self.latest_network is None:
Expand Down

0 comments on commit a5c0da5

Please sign in to comment.