Skip to content

Commit

Permalink
Fix batch hijack test
Browse files Browse the repository at this point in the history
  • Loading branch information
huchenlei committed May 6, 2024
1 parent 596f35f commit f09d790
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 24 deletions.
3 changes: 3 additions & 0 deletions scripts/batch_hijack.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from modules import img2img, processing, shared, script_callbacks
from scripts import external_code
from scripts.enums import InputMode
from scripts.logging import logger

class BatchHijack:
def __init__(self):
Expand Down Expand Up @@ -221,6 +222,8 @@ def get_cn_batches(p: processing.StableDiffusionProcessing) -> Tuple[bool, List[
else:
batches[i].append(unit.image)

if any_unit_is_batch:
logger.info(f"Batch enabled ({len(batches)})")
return any_unit_is_batch, batches, output_dir, input_file_names


Expand Down
61 changes: 37 additions & 24 deletions tests/cn_script/batch_hijack_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
original_process_images_inner = processing.process_images_inner


def create_unit(**kwargs) -> ControlNetUnit:
return ControlNetUnit(enabled=True, **kwargs)


class TestBatchHijack(unittest.TestCase):
@unittest.mock.patch('modules.script_callbacks.on_script_unloaded')
def setUp(self, on_script_unloaded_mock):
Expand Down Expand Up @@ -60,9 +64,18 @@ def assert_get_cn_batches_works(self, batch_images_list):
is_cn_batch, batches, output_dir, _ = batch_hijack.get_cn_batches(self.p)
batch_hijack.instance.dispatch_callbacks(batch_hijack.instance.process_batch_callbacks, self.p, batches, output_dir)

batch_units = [unit for unit in self.p.script_args if getattr(unit, 'input_mode', batch_hijack.InputMode.SIMPLE) == batch_hijack.InputMode.BATCH]
batch_units = [
unit
for unit in self.p.script_args
if getattr(unit, 'input_mode', batch_hijack.InputMode.SIMPLE) == batch_hijack.InputMode.BATCH
]
# Convert iterator to list to avoid double eval of iterator exhausting
# the iterator in following checks.
for unit in batch_units:
unit.batch_images = list(unit.batch_images)

if batch_units:
self.assertEqual(min(len(list(unit.batch_images)) for unit in batch_units), len(batches))
self.assertEqual(min(len(unit.batch_images) for unit in batch_units), len(batches))
else:
self.assertEqual(1, len(batches))

Expand All @@ -75,15 +88,15 @@ def test_get_cn_batches__empty(self):
self.assertEqual(is_batch, False)

def test_get_cn_batches__1_simple(self):
self.p.script_args.append(ControlNetUnit(image=get_dummy_image()))
self.p.script_args.append(create_unit(image=get_dummy_image()))
self.assert_get_cn_batches_works([
[get_dummy_image()],
])

def test_get_cn_batches__2_simples(self):
self.p.script_args.extend([
ControlNetUnit(image=get_dummy_image(0)),
ControlNetUnit(image=get_dummy_image(1)),
create_unit(image=get_dummy_image(0)),
create_unit(image=get_dummy_image(1)),
])
self.assert_get_cn_batches_works([
[get_dummy_image(0)],
Expand All @@ -92,7 +105,7 @@ def test_get_cn_batches__2_simples(self):

def test_get_cn_batches__1_batch(self):
self.p.script_args.extend([
ControlNetUnit(
create_unit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[
get_dummy_image(0),
Expand All @@ -109,14 +122,14 @@ def test_get_cn_batches__1_batch(self):

def test_get_cn_batches__2_batches(self):
self.p.script_args.extend([
ControlNetUnit(
create_unit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[
get_dummy_image(0),
get_dummy_image(1),
],
),
ControlNetUnit(
create_unit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[
get_dummy_image(2),
Expand All @@ -137,8 +150,8 @@ def test_get_cn_batches__2_batches(self):

def test_get_cn_batches__2_mixed(self):
self.p.script_args.extend([
ControlNetUnit(image=get_dummy_image(0)),
ControlNetUnit(
create_unit(image=get_dummy_image(0)),
create_unit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[
get_dummy_image(1),
Expand All @@ -159,16 +172,16 @@ def test_get_cn_batches__2_mixed(self):

def test_get_cn_batches__3_mixed(self):
self.p.script_args.extend([
ControlNetUnit(image=get_dummy_image(0)),
ControlNetUnit(
create_unit(image=get_dummy_image(0)),
create_unit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[
get_dummy_image(1),
get_dummy_image(2),
get_dummy_image(3),
],
),
ControlNetUnit(
create_unit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[
get_dummy_image(4),
Expand Down Expand Up @@ -244,14 +257,14 @@ def test_process_images_no_units_forwards(self):

def test_process_images__only_simple_units__forwards(self):
self.p.script_args = [
ControlNetUnit(image=get_dummy_image()),
ControlNetUnit(image=get_dummy_image()),
create_unit(image=get_dummy_image()),
create_unit(image=get_dummy_image()),
]
self.assert_process_images_hijack_called(batch_count=0)

def test_process_images__1_batch_1_unit__runs_1_batch(self):
self.p.script_args = [
ControlNetUnit(
create_unit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[
get_dummy_image(),
Expand All @@ -262,7 +275,7 @@ def test_process_images__1_batch_1_unit__runs_1_batch(self):

def test_process_images__2_batches_1_unit__runs_2_batches(self):
self.p.script_args = [
ControlNetUnit(
create_unit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[
get_dummy_image(0),
Expand All @@ -275,7 +288,7 @@ def test_process_images__2_batches_1_unit__runs_2_batches(self):
def test_process_images__8_batches_1_unit__runs_8_batches(self):
batch_count = 8
self.p.script_args = [
ControlNetUnit(
create_unit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[get_dummy_image(i) for i in range(batch_count)]
),
Expand All @@ -284,11 +297,11 @@ def test_process_images__8_batches_1_unit__runs_8_batches(self):

def test_process_images__1_batch_2_units__runs_1_batch(self):
self.p.script_args = [
ControlNetUnit(
create_unit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[get_dummy_image(0)]
),
ControlNetUnit(
create_unit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[get_dummy_image(1)]
),
Expand All @@ -297,14 +310,14 @@ def test_process_images__1_batch_2_units__runs_1_batch(self):

def test_process_images__2_batches_2_units__runs_2_batches(self):
self.p.script_args = [
ControlNetUnit(
create_unit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[
get_dummy_image(0),
get_dummy_image(1),
],
),
ControlNetUnit(
create_unit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[
get_dummy_image(2),
Expand All @@ -316,15 +329,15 @@ def test_process_images__2_batches_2_units__runs_2_batches(self):

def test_process_images__3_batches_2_mixed_units__runs_3_batches(self):
self.p.script_args = [
ControlNetUnit(
create_unit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[
get_dummy_image(0),
get_dummy_image(1),
get_dummy_image(2),
],
),
ControlNetUnit(
create_unit(
input_mode=batch_hijack.InputMode.SIMPLE,
image=get_dummy_image(3),
),
Expand Down

0 comments on commit f09d790

Please sign in to comment.