scripts/vsmlrt.py: add support for SwinIR models
WolframRhodium committed Apr 23, 2024
1 parent e1826de commit ce239ac
Showing 1 changed file with 124 additions and 1 deletion.
scripts/vsmlrt.py: 125 changes (124 additions, 1 deletion)
@@ -1,4 +1,4 @@
__version__ = "3.20.9"
__version__ = "3.20.10"

__all__ = [
"Backend", "BackendV2",
@@ -10,6 +10,7 @@
"RIFE", "RIFEModel", "RIFEMerge",
"SAFA", "SAFAModel", "SAFAAdaptiveMode",
"SCUNet", "SCUNetModel",
"SwinIR", "SwinIRModel",
"inference"
]

@@ -1498,6 +1499,128 @@ def SCUNet(
return clip


@enum.unique
class SwinIRModel(enum.IntEnum):
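    # member names mirror the upstream SwinIR model names: task, training set,
    # training patch size (s) and window size (w), model size (S/M/L), and
    # scale factor / noise level / JPEG quality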
    lightweightSR_DIV2K_s64w8_SwinIR_S_x2 = 0
    lightweightSR_DIV2K_s64w8_SwinIR_S_x3 = 1
    lightweightSR_DIV2K_s64w8_SwinIR_S_x4 = 2
    realSR_BSRGAN_DFOWMFC_s64w8_SwinIR_L_x4_GAN = 3
    # unused
    realSR_BSRGAN_DFOWMFC_s64w8_SwinIR_L_x4_PSNR = 5
    classicalSR_DF2K_s64w8_SwinIR_M_x2 = 6
    classicalSR_DF2K_s64w8_SwinIR_M_x3 = 7
    classicalSR_DF2K_s64w8_SwinIR_M_x4 = 8
    classicalSR_DF2K_s64w8_SwinIR_M_x8 = 9
    realSR_BSRGAN_DFO_s64w8_SwinIR_M_x2_GAN = 10
    realSR_BSRGAN_DFO_s64w8_SwinIR_M_x2_PSNR = 11
    realSR_BSRGAN_DFO_s64w8_SwinIR_M_x4_GAN = 12
    realSR_BSRGAN_DFO_s64w8_SwinIR_M_x4_PSNR = 13
    grayDN_DFWB_s128w8_SwinIR_M_noise15 = 14
    grayDN_DFWB_s128w8_SwinIR_M_noise25 = 15
    grayDN_DFWB_s128w8_SwinIR_M_noise50 = 16
    colorDN_DFWB_s128w8_SwinIR_M_noise15 = 17
    colorDN_DFWB_s128w8_SwinIR_M_noise25 = 18
    colorDN_DFWB_s128w8_SwinIR_M_noise50 = 19
    CAR_DFWB_s126w7_SwinIR_M_jpeg10 = 20
    CAR_DFWB_s126w7_SwinIR_M_jpeg20 = 21
    CAR_DFWB_s126w7_SwinIR_M_jpeg30 = 22
    CAR_DFWB_s126w7_SwinIR_M_jpeg40 = 23
    colorCAR_DFWB_s126w7_SwinIR_M_jpeg10 = 24
    colorCAR_DFWB_s126w7_SwinIR_M_jpeg20 = 25
    colorCAR_DFWB_s126w7_SwinIR_M_jpeg30 = 26
    colorCAR_DFWB_s126w7_SwinIR_M_jpeg40 = 27


def SwinIR(
    clip: vs.VideoNode,
    tiles: typing.Optional[typing.Union[int, typing.Tuple[int, int]]] = None,
    tilesize: typing.Optional[typing.Union[int, typing.Tuple[int, int]]] = None,
    overlap: typing.Optional[typing.Union[int, typing.Tuple[int, int]]] = None,
    model: SwinIRModel = SwinIRModel.lightweightSR_DIV2K_s64w8_SwinIR_S_x2,
    backend: backendT = Backend.OV_CPU()
) -> vs.VideoNode:
""" SwinIR: Image Restoration Using Swin Transformer """

func_name = "vsmlrt.SwinIR"

if not isinstance(clip, vs.VideoNode):
raise TypeError(f'{func_name}: "clip" must be a clip!')

if clip.format.sample_type != vs.FLOAT or clip.format.bits_per_sample not in [16, 32]:
raise ValueError(f"{func_name}: only constant format 16/32 bit float input supported")

if not isinstance(model, int) or model not in SwinIRModel.__members__.values():
raise ValueError(f'{func_name}: invalid "model"')

if model in range(14, 17) or model in range(20, 24):
if clip.format.color_family != vs.GRAY:
raise ValueError(f'{func_name}: "clip" must be of GRAY color family')
elif clip.format.color_family != vs.RGB:
raise ValueError(f'{func_name}: "clip" must be of RGB color family')

if overlap is None:
overlap_w = overlap_h = 16
elif isinstance(overlap, int):
overlap_w = overlap_h = overlap
else:
overlap_w, overlap_h = overlap

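    # multiple = 1: no alignment requirement is imposed on the tile size for SwinIR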
    multiple = 1

    (tile_w, tile_h), (overlap_w, overlap_h) = calc_tilesize(
        tiles=tiles, tilesize=tilesize,
        width=clip.width, height=clip.height,
        multiple=multiple,
        overlap_w=overlap_w, overlap_h=overlap_h
    )

    if tile_w % multiple != 0 or tile_h % multiple != 0:
        raise ValueError(
            f'{func_name}: tile size must be divisible by {multiple} ({tile_w}, {tile_h})'
        )

    backend = init_backend(
        backend=backend,
        trt_opt_shapes=(tile_w, tile_h)
    )

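    # enum value 4 is not assigned, so a positional lookup into __members__
    # must subtract one for models past the x4 GAN entry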
    if model < 4:
        model_name = tuple(SwinIRModel.__members__)[model]
    else:
        model_name = tuple(SwinIRModel.__members__)[model - 1]

    model_name = model_name.replace("SwinIR_", "SwinIR-")

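    # prepend the numeric task prefix used by the official SwinIR model names
    # (001 classical SR, 002 lightweight SR, 003 real-world SR, 004 gray denoise,
    #  005 color denoise, 006 JPEG compression artifact reduction)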
    if model in range(3):
        model_name = f"002_{model_name}"
    elif model in (3, 5):
        model_name = f"003_{model_name}"
    elif model in range(6, 10):
        model_name = f"001_{model_name}"
    elif model in range(10, 14):
        model_name = f"003_{model_name}"
    elif model in range(14, 17):
        model_name = f"004_{model_name}"
    elif model in range(17, 20):
        model_name = f"005_{model_name}"
    elif model in range(20, 28):
        model_name = f"006_{model_name}"

    network_path = os.path.join(
        models_path,
        "swinir",
        f"{model_name}.onnx"
    )

    clip = inference_with_fallback(
        clips=[clip], network_path=network_path,
        overlap=(overlap_w, overlap_h), tilesize=(tile_w, tile_h),
        backend=backend
    )

    return clip


def get_engine_path(
    network_path: str,
    min_shapes: typing.Tuple[int, int],
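
For reference, a minimal usage sketch of the new function, assuming vsmlrt and its SwinIR ONNX models are installed; the blank source clip, tile size, and backend choice below are illustrative only and not part of this commit:

import vapoursynth as vs
from vsmlrt import SwinIR, SwinIRModel, Backend

core = vs.core

# illustrative RGB float source; SwinIR expects constant-format 16/32-bit float input
src = core.std.BlankClip(format=vs.RGBS, width=640, height=360, length=24)

# 2x lightweight super-resolution; tile size and overlap are example values
# chosen to bound memory use, not requirements of the model
sr = SwinIR(
    src,
    tilesize=(320, 180),
    overlap=16,
    model=SwinIRModel.lightweightSR_DIV2K_s64w8_SwinIR_S_x2,
    backend=Backend.OV_CPU()
)

sr.set_output()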
