Skip to content

Commit

Permalink
Move enum (#2845)
Browse files Browse the repository at this point in the history
* Move enums to enums.py

* Add missing import

* Remove unused import
  • Loading branch information
huchenlei authored May 4, 2024
1 parent 2a28f08 commit ef65fec
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 61 deletions.
37 changes: 1 addition & 36 deletions internal_controlnet/external_code.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import base64
import io
from dataclasses import dataclass
from enum import Enum
from copy import copy
from typing import List, Any, Optional, Union, Tuple, Dict
import torch
Expand All @@ -11,7 +10,7 @@
from modules.safe import unsafe_torch_load
from scripts import global_state
from scripts.logging import logger
from scripts.enums import HiResFixOption, PuLIDMode
from scripts.enums import HiResFixOption, PuLIDMode, ControlMode, ResizeMode
from scripts.supported_preprocessor import Preprocessor, PreprocessorParameter

from modules.api import api
Expand All @@ -21,40 +20,6 @@ def get_api_version() -> int:
return 2


class ControlMode(Enum):
"""
The improved guess mode.
"""

BALANCED = "Balanced"
PROMPT = "My prompt is more important"
CONTROL = "ControlNet is more important"


class BatchOption(Enum):
DEFAULT = "All ControlNet units for all images in a batch"
SEPARATE = "Each ControlNet unit for each image in a batch"


class ResizeMode(Enum):
"""
Resize modes for ControlNet input images.
"""

RESIZE = "Just Resize"
INNER_FIT = "Crop and Resize"
OUTER_FIT = "Resize and Fill"

def int_value(self):
if self == ResizeMode.RESIZE:
return 0
elif self == ResizeMode.INNER_FIT:
return 1
elif self == ResizeMode.OUTER_FIT:
return 2
assert False, "NOTREACHED"


resize_mode_aliases = {
"Inner Fit (Scale to Fit)": "Crop and Resize",
"Outer Fit (Shrink to Fit)": "Resize and Fill",
Expand Down
31 changes: 19 additions & 12 deletions scripts/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,21 @@
import scripts.preprocessor as preprocessor_init # noqa
from annotator.util import HWC3
from scripts import global_state, hook, external_code, batch_hijack, controlnet_version, utils
from internal_controlnet.external_code import ControlMode
from scripts.controlnet_lora import bind_control_lora, unbind_control_lora
from scripts.controlnet_lllite import clear_all_lllite
from scripts.ipadapter.plugable_ipadapter import ImageEmbed, clear_all_ip_adapter
from scripts.ipadapter.pulid_attn import PULID_SETTING_FIDELITY, PULID_SETTING_STYLE
from scripts.utils import load_state_dict, get_unique_axis0, align_dim_latent
from scripts.hook import ControlParams, UnetHook, HackedImageRNG
from scripts.enums import ControlModelType, StableDiffusionVersion, HiResFixOption, PuLIDMode
from scripts.enums import (
ControlModelType,
StableDiffusionVersion,
HiResFixOption,
PuLIDMode,
ControlMode,
BatchOption,
ResizeMode,
)
from scripts.controlnet_ui.controlnet_ui_group import ControlNetUiGroup, UiControlNetUnit
from scripts.controlnet_ui.photopea import Photopea
from scripts.logging import logger
Expand Down Expand Up @@ -239,7 +246,7 @@ def get_control(
else: # Following operations are only for single input image.
input_image = Script.try_crop_image_with_a1111_mask(p, unit, input_image, resize_mode)
input_image = np.ascontiguousarray(input_image.copy()).copy() # safe numpy
if unit.module == 'inpaint_only+lama' and resize_mode == external_code.ResizeMode.OUTER_FIT:
if unit.module == 'inpaint_only+lama' and resize_mode == ResizeMode.OUTER_FIT:
# inpaint_only+lama is special and required outpaint fix
_, input_image = Script.detectmap_proc(input_image, unit.module, resize_mode, hr_y, hr_x)
input_images = [input_image]
Expand Down Expand Up @@ -335,7 +342,7 @@ def __init__(self) -> None:
self.detected_map = []
self.post_processors = []
self.noise_modifier = None
self.ui_batch_option_state = [external_code.BatchOption.DEFAULT.value, False]
self.ui_batch_option_state = [BatchOption.DEFAULT.value, False]
batch_hijack.instance.process_batch_callbacks.append(self.batch_tab_process)
batch_hijack.instance.process_batch_each_callbacks.append(self.batch_tab_process_each)
batch_hijack.instance.postprocess_batch_each_callbacks.insert(0, self.batch_tab_postprocess_each)
Expand Down Expand Up @@ -366,8 +373,8 @@ def uigroup(self, tabname: str, is_img2img: bool, elem_id_tabname: str, photopea

def ui_batch_options(self, is_img2img: bool, elem_id_tabname: str):
batch_option = gr.Radio(
choices=[e.value for e in external_code.BatchOption],
value=external_code.BatchOption.DEFAULT.value,
choices=[e.value for e in BatchOption],
value=BatchOption.DEFAULT.value,
label="Batch Option",
elem_id=f"{elem_id_tabname}_controlnet_batch_option_radio",
elem_classes="controlnet_batch_option_radio",
Expand Down Expand Up @@ -616,7 +623,7 @@ def high_quality_resize(x, size):

return y

if resize_mode == external_code.ResizeMode.RESIZE:
if resize_mode == ResizeMode.RESIZE:
detected_map = high_quality_resize(detected_map, (w, h))
detected_map = safe_numpy(detected_map)
return get_pytorch_control(detected_map), detected_map
Expand All @@ -629,7 +636,7 @@ def high_quality_resize(x, size):

safeint = lambda x: int(np.round(x))

if resize_mode == external_code.ResizeMode.OUTER_FIT:
if resize_mode == ResizeMode.OUTER_FIT:
k = min(k0, k1)
borders = np.concatenate([detected_map[0, :, :], detected_map[-1, :, :], detected_map[:, 0, :], detected_map[:, -1, :]], axis=0)
high_quality_border_color = np.median(borders, axis=0).astype(detected_map.dtype)
Expand Down Expand Up @@ -683,7 +690,7 @@ def choose_input_image(
p: processing.StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
idx: int
) -> Tuple[np.ndarray, external_code.ResizeMode]:
) -> Tuple[np.ndarray, ResizeMode]:
""" Choose input image from following sources with descending priority:
- p.image_control: [Deprecated] Lagacy way to pass image to controlnet.
- p.control_net_input_image: [Deprecated] Lagacy way to pass image to controlnet.
Expand Down Expand Up @@ -805,7 +812,7 @@ def try_crop_image_with_a1111_mask(
p: StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
input_image: np.ndarray,
resize_mode: external_code.ResizeMode,
resize_mode: ResizeMode,
) -> np.ndarray:
"""
Crop ControlNet input image based on A1111 inpaint mask given.
Expand Down Expand Up @@ -847,7 +854,7 @@ def try_crop_image_with_a1111_mask(

input_image = [x.crop(crop_region) for x in input_image]
input_image = [
images.resize_image(external_code.ResizeMode.OUTER_FIT.int_value(), x, p.width, p.height)
images.resize_image(ResizeMode.OUTER_FIT.int_value(), x, p.width, p.height)
for x in input_image
]

Expand Down Expand Up @@ -922,7 +929,7 @@ def controlnet_main_entry(self, p):
if not batch_hijack.instance.is_batch:
self.enabled_units = Script.get_enabled_units(p)

batch_option_uint_separate = self.ui_batch_option_state[0] == external_code.BatchOption.SEPARATE.value
batch_option_uint_separate = self.ui_batch_option_state[0] == BatchOption.SEPARATE.value
batch_option_style_align = self.ui_batch_option_state[1]

if len(self.enabled_units) == 0 and not batch_option_style_align:
Expand Down
14 changes: 10 additions & 4 deletions scripts/controlnet_ui/controlnet_ui_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@
from scripts.controlnet_ui.preset import ControlNetPresetUI
from scripts.controlnet_ui.photopea import Photopea
from scripts.controlnet_ui.advanced_weight_control import AdvancedWeightControl
from scripts.enums import InputMode, PuLIDMode
from scripts.enums import (
InputMode,
HiResFixOption,
PuLIDMode,
ControlMode,
ResizeMode,
)
from modules import shared
from modules.ui_components import FormRow, FormHTML, ToolButton

Expand Down Expand Up @@ -602,15 +608,15 @@ def render(self, tabname: str, elem_id_tabname: str) -> None:
)

self.control_mode = gr.Radio(
choices=[e.value for e in external_code.ControlMode],
choices=[e.value for e in ControlMode],
value=self.default_unit.control_mode.value,
label="Control Mode",
elem_id=f"{elem_id_tabname}_{tabname}_controlnet_control_mode_radio",
elem_classes="controlnet_control_mode_radio",
)

self.resize_mode = gr.Radio(
choices=[e.value for e in external_code.ResizeMode],
choices=[e.value for e in ResizeMode],
value=self.default_unit.resize_mode.value,
label="Resize Mode",
elem_id=f"{elem_id_tabname}_{tabname}_controlnet_resize_mode_radio",
Expand All @@ -619,7 +625,7 @@ def render(self, tabname: str, elem_id_tabname: str) -> None:
)

self.hr_option = gr.Radio(
choices=[e.value for e in external_code.HiResFixOption],
choices=[e.value for e in HiResFixOption],
value=self.default_unit.hr_option.value,
label="Hires-Fix Option",
elem_id=f"{elem_id_tabname}_{tabname}_controlnet_hr_option_radio",
Expand Down
34 changes: 34 additions & 0 deletions scripts/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,37 @@ class InputMode(Enum):
class PuLIDMode(Enum):
FIDELITY = "Fidelity"
STYLE = "Extremely style"


class ControlMode(Enum):
"""
The improved guess mode.
"""

BALANCED = "Balanced"
PROMPT = "My prompt is more important"
CONTROL = "ControlNet is more important"


class BatchOption(Enum):
DEFAULT = "All ControlNet units for all images in a batch"
SEPARATE = "Each ControlNet unit for each image in a batch"


class ResizeMode(Enum):
"""
Resize modes for ControlNet input images.
"""

RESIZE = "Just Resize"
INNER_FIT = "Crop and Resize"
OUTER_FIT = "Resize and Fill"

def int_value(self):
if self == ResizeMode.RESIZE:
return 0
elif self == ResizeMode.INNER_FIT:
return 1
elif self == ResizeMode.OUTER_FIT:
return 2
assert False, "NOTREACHED"
13 changes: 7 additions & 6 deletions tests/cn_script/cn_script_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@


from scripts import external_code
from scripts.enums import ResizeMode
from scripts.controlnet import prepare_mask, Script, set_numpy_seed
from modules import processing

Expand Down Expand Up @@ -134,30 +135,30 @@ def test_choose_input_image(self):
_, resize_mode = Script.choose_input_image(
p=MockImg2ImgProcessing(
init_images=[TestScript.sample_np_image],
resize_mode=external_code.ResizeMode.OUTER_FIT,
resize_mode=ResizeMode.OUTER_FIT,
),
unit=external_code.ControlNetUnit(
image=TestScript.sample_base64_image,
module="none",
resize_mode=external_code.ResizeMode.INNER_FIT,
resize_mode=ResizeMode.INNER_FIT,
),
idx=0,
)
self.assertEqual(resize_mode, external_code.ResizeMode.INNER_FIT)
self.assertEqual(resize_mode, ResizeMode.INNER_FIT)

with self.subTest(name="A1111 input"):
_, resize_mode = Script.choose_input_image(
p=MockImg2ImgProcessing(
init_images=[TestScript.sample_np_image],
resize_mode=external_code.ResizeMode.OUTER_FIT,
resize_mode=ResizeMode.OUTER_FIT,
),
unit=external_code.ControlNetUnit(
module="none",
resize_mode=external_code.ResizeMode.INNER_FIT,
resize_mode=ResizeMode.INNER_FIT,
),
idx=0,
)
self.assertEqual(resize_mode, external_code.ResizeMode.OUTER_FIT)
self.assertEqual(resize_mode, ResizeMode.OUTER_FIT)


if __name__ == "__main__":
Expand Down
5 changes: 3 additions & 2 deletions tests/external_code_api/external_code_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from copy import copy
from scripts import external_code
from scripts import controlnet
from scripts.enums import ResizeMode
from modules import scripts, ui, shared


Expand Down Expand Up @@ -120,15 +121,15 @@ class TestPixelPerfectResolution(unittest.TestCase):
def test_outer_fit(self):
image = np.zeros((100, 100, 3))
target_H, target_W = 50, 100
resize_mode = external_code.ResizeMode.OUTER_FIT
resize_mode = ResizeMode.OUTER_FIT
result = external_code.pixel_perfect_resolution(image, target_H, target_W, resize_mode)
expected = 50 # manually computed expected result
self.assertEqual(result, expected)

def test_inner_fit(self):
image = np.zeros((100, 100, 3))
target_H, target_W = 50, 100
resize_mode = external_code.ResizeMode.INNER_FIT
resize_mode = ResizeMode.INNER_FIT
result = external_code.pixel_perfect_resolution(image, target_H, target_W, resize_mode)
expected = 100 # manually computed expected result
self.assertEqual(result, expected)
Expand Down
3 changes: 2 additions & 1 deletion tests/external_code_api/script_args_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@


from scripts import external_code
from scripts.enums import ControlMode


class TestGetAllUnitsFrom(unittest.TestCase):
Expand All @@ -15,7 +16,7 @@ def setUp(self):
"resize_mode": 1,
"low_vram": False,
"processor_res": 64,
"control_mode": external_code.ControlMode.BALANCED.value,
"control_mode": ControlMode.BALANCED.value,
}
self.object_unit = external_code.ControlNetUnit(**self.control_unit)

Expand Down

0 comments on commit ef65fec

Please sign in to comment.