diff --git a/scripts/controlnet.py b/scripts/controlnet.py index dab171f7a..d63ec151d 100644 --- a/scripts/controlnet.py +++ b/scripts/controlnet.py @@ -307,6 +307,10 @@ def __init__(self) -> None: self.post_processors = [] self.noise_modifier = None self.ui_batch_option_state = [BatchOption.DEFAULT.value, False] + # `Script` instance is created twice, once for img2img and once for txt2img. + # However, A1111 does not pass is_img2img to script constructor, so this field + # is initialized in `ui` method. + self.is_img2img = False batch_hijack.instance.process_batch_callbacks.append(self.batch_tab_process) batch_hijack.instance.process_batch_each_callbacks.append(self.batch_tab_process_each) batch_hijack.instance.postprocess_batch_each_callbacks.insert(0, self.batch_tab_postprocess_each) @@ -366,6 +370,8 @@ def ui(self, is_img2img): The return value should be an array of all components that are used in processing. Values of those returned components will be passed to run() and process() functions. """ + self.is_img2img = is_img2img + infotext = Infotext() ui_groups = [] controls = [] @@ -1296,11 +1302,19 @@ def postprocess(self, p, processed, *args): tracemalloc.stop() def batch_tab_process(self, p, batches, *args, **kwargs): + is_img2img = isinstance(p, StableDiffusionProcessingImg2Img) + if is_img2img != self.is_img2img: + return + self.enabled_units = Script.get_enabled_units(p) for unit_i, unit in enumerate(self.enabled_units): unit.batch_images = iter([batch[unit_i] for batch in batches]) def batch_tab_process_each(self, p, *args, **kwargs): + is_img2img = isinstance(p, StableDiffusionProcessingImg2Img) + if is_img2img != self.is_img2img: + return + for unit in self.enabled_units: if getattr(unit, 'loopback', False) and batch_hijack.instance.batch_index > 0: continue @@ -1308,6 +1322,10 @@ def batch_tab_process_each(self, p, *args, **kwargs): unit.image = next(unit.batch_images) def batch_tab_postprocess_each(self, p, processed, *args, **kwargs): + is_img2img = isinstance(p, StableDiffusionProcessingImg2Img) + if is_img2img != self.is_img2img: + return + for unit_i, unit in enumerate(self.enabled_units): if getattr(unit, 'loopback', False): output_images = getattr(processed, 'images', [])[processed.index_of_first_image:] @@ -1317,6 +1335,10 @@ def batch_tab_postprocess_each(self, p, processed, *args, **kwargs): logger.warning(f'Warning: No loopback image found for controlnet unit {unit_i}. Using control map from last batch iteration instead') def batch_tab_postprocess(self, p, *args, **kwargs): + is_img2img = isinstance(p, StableDiffusionProcessingImg2Img) + if is_img2img != self.is_img2img: + return + self.enabled_units.clear() self.input_image = None if self.latest_network is None: