forked from 2kpr/ComfyUI-PMRF
-
Notifications
You must be signed in to change notification settings - Fork 0
/
nodes.py
executable file
·296 lines (265 loc) · 12.8 KB
/
nodes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
import os
import cv2
from tqdm import tqdm
import torch
import torchvision
import numpy as np
from PIL import Image
import folder_paths
from comfy.utils import ProgressBar
from comfy.model_management import get_torch_device, text_encoder_offload_device, soft_empty_cache, should_use_fp16, should_use_bf16
# Torch device returned by comfy.model_management.get_torch_device(); the
# nodes and helper functions below place models and tensors on it.
device = get_torch_device()
# Absolute directory containing this file; PMRF_Loader joins it with
# "checkpoint" to locate the PMRF weights.
script_directory = os.path.dirname(os.path.abspath(__file__))
class PMRF:
    """Node that runs PMRF face restoration on a batch of images.

    The class exposes the attributes a node host looks up (``INPUT_TYPES``,
    ``RETURN_TYPES``, ``RETURN_NAMES``, ``FUNCTION``, ``CATEGORY``,
    ``DESCRIPTION``) and a single entry point, :meth:`pmrf`.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the widget/socket specification for this node.

        The key order of the ``required`` mapping is kept stable because it is
        the order in which the inputs are declared.
        """
        detector_choices = ["retinaface_resnet50", "retinaface_mobile0.25", "YOLOv5l", "YOLOv5n"]
        interpolation_choices = ["lanczos4", "nearest", "linear", "cubic", "area", "linear_exact", "nearest_exact"]
        required = {
            "model": ("PMRF_Model",),
            "facedetection": (detector_choices,),
            "images": ("IMAGE",),
            "scale": ("FLOAT", {"default": 1.0, "min": 1.0, "max": 40.0, "step": 0.1}),
            "num_steps": ("INT", {"default": 25, "min": 1, "max": 400, "step": 1}),
            "seed": ("INT", {"default": 88888888, "min": 0, "max": 2**32, "step": 1}),
            "interpolation": (interpolation_choices,),
            "keep_model_loaded": ("BOOLEAN", {"default": True, "label_on": "yes", "label_off": "no", "tooltip": "Warning: do not delete model unless this node no longer needed, it will try release device_memory and ram. if checked and want to continue node generation, use ComfyUI-Manager `Free model and node cache` to reset node state or change parameter in Loader node to activate.\n注意:仅在这个节点不再需要时删除模型,将尽量释放系统内存和设备专用内存。如果删除后想继续使用此节点,使用ComfyUI-Manager插件的`Free model and node cache`重置节点状态或者更换模型加载节点的参数来激活。"}),
            "keep_model_device": ("BOOLEAN", {"default": True, "label_on": "comfy", "label_off": "device", "tooltip": "Keep model in comfy_auto_unet_offload_device (HIGH_VRAM: device, Others: cpu) or device_memory after generation. \n生图完成后,模型转移到comfy自动选择设备(HIGH_VRAM: device, 其他: cpu)或者保留在设备专用内存上。"}),
        }
        return {"required": required}

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images", )
    FUNCTION = "pmrf"
    CATEGORY = "PMRF"
    DESCRIPTION = "Posterior-Mean Rectified Flow: Towards Minimum MSE Photo-Realistic Image Restoration\nPMRF is a novel photo-realistic image restoration algorithm. It (provably) approximates the optimal estimator that minimizes the Mean Squared Error (MSE) under a perfect perceptual quality constraint.\nIt's a Photo-Realistic Image Restoration project, not designed for super resolution upscale!\n这是写实照片图像恢复项目,不是超分放大!"

    def pmrf(self, images, scale, num_steps, seed, interpolation, model, facedetection, keep_model_loaded, keep_model_device):
        """Restore ``images`` and then park or discard the models.

        Args:
            images: Batch handed straight to ``inference``.
            scale: Requested output scale factor.
            num_steps: Number of flow steps per face.
            seed: Seed forwarded to ``inference``.
            interpolation: Name of the resize filter (one of the UI choices).
            model: Dict produced by ``PMRF_Loader`` with the keys
                ``'upscaler'``, ``'upscaler_model'`` and ``'pmrf'``.
            facedetection: Name of the face detector to use.
            keep_model_loaded: When false, the entries of ``model`` are
                deleted after the run, so the same dict cannot be reused.
            keep_model_device: Only consulted when ``keep_model_loaded`` is
                true; when set, the PMRF model is moved to the offload device.

        Returns:
            Whatever ``inference`` returns (a one-element tuple).
        """
        # The model may have been offloaded by a previous run; bring it back.
        if 'cpu' in model['pmrf'].device.type:
            model['pmrf'].to(device)

        result = inference(images, scale, num_steps, seed, interpolation, model, facedetection)

        if not keep_model_loaded:
            # Drop every reference held by the shared dict before emptying the
            # cache. No local alias of the models is kept on purpose, so the
            # deletions can actually release memory.
            del model['upscaler_model']
            del model['upscaler'].model
            del model['upscaler']
            del model['pmrf']
            soft_empty_cache(True)
        elif keep_model_device:
            # NOTE(review): the tooltip mentions the unet offload device while
            # the code asks for the text-encoder offload device - confirm
            # which one is intended.
            model['pmrf'].to(text_encoder_offload_device())
            soft_empty_cache(True)
        return result
class PMRF_Loader:
    """Node that builds the model bundle consumed by the ``PMRF`` node.

    The bundle is a dict holding a RealESRGAN background upsampler, its
    underlying network, and the PMRF model cast to the selected dtype.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the widget specification for this node."""
        return {
            "required": {
                "model_name": (folder_paths.get_filename_list("upscale_models"), {'default': 'RealESRGAN_x2plus.pth', 'tooltip': 'Only `RealESRGAN_x2plus.pth` works now.'}),
                "weight_dtype": (["auto","fp16","bf16","fp32", "fp8_e4m3fn", "fp8_e4m3fnuz", "fp8_e5m2", "fp8_e5m2fnuz"],{"default":"auto", 'tooltip': '......'}),
                # Backslashes are escaped explicitly: "\c" and "\C" are not
                # valid escape sequences. The resulting text is unchanged.
                "Auto_Download_Path": ("BOOLEAN", {"default": True, "label_on": "node_dir本地", "label_off": ".cache缓存", "tooltip": "Download or load PMRF_model from:`ComfyUI\\custom_nodes\\ComfyUI-PMRF\\checkpoint` or huggingface_cache_dir.\n下载或者从路径加载PMRF模型:`ComfyUI\\custom_nodes\\ComfyUI-PMRF\\checkpoint`或者huggingface缓存路径。"}),
            },
        }

    RETURN_TYPES = ("PMRF_Model",)
    RETURN_NAMES = ("model", )
    FUNCTION = "pmrf_loader"
    CATEGORY = "PMRF"

    def pmrf_loader(self, model_name, weight_dtype, Auto_Download_Path):
        """Load the upsampler and the PMRF model.

        Args:
            model_name: File name inside the ``upscale_models`` folder. It
                must contain one of the supported RealESRGAN model names.
            weight_dtype: Precision name understood by ``get_dtype_by_name``.
            Auto_Download_Path: When the local checkpoint is missing, ``True``
                downloads it into the node's ``checkpoint`` directory and
                ``False`` loads it by Hugging Face repo id instead.

        Returns:
            A one-element tuple holding the dict
            ``{'upscaler', 'upscaler_model', 'pmrf'}``.

        Raises:
            ValueError: If ``model_name`` is not a supported architecture.
        """
        from basicsr.archs.rrdbnet_arch import RRDBNet
        from .utils.realesrgan_utils import RealESRGANer
        from .lightning_models.mmse_rectified_flow import MMSERectifiedFlow

        model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
        weight_dtype = get_dtype_by_name(weight_dtype)

        # The most specific name is tested first: 'RealESRGAN_x4plus' is a
        # substring of 'RealESRGAN_x4plus_anime_6B', so testing it first would
        # build the anime model with the wrong number of blocks.
        if 'RealESRGAN_x4plus_anime_6B' in model_name:
            num_block, net_scale = 6, 4
        elif 'RealESRGAN_x4plus' in model_name:
            num_block, net_scale = 23, 4
        elif 'RealESRGAN_x2plus' in model_name:
            num_block, net_scale = 23, 2
        else:
            raise ValueError('Model support not added.')

        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=num_block,
            num_grow_ch=32,
            scale=net_scale,
        )
        upsampler = RealESRGANer(
            # Must match the scale the network was built with.
            scale=net_scale,
            model_path=model_path,
            model=model,
            tile=400,
            tile_pad=40,
            pre_pad=0,
            weight_dtype=weight_dtype,
            device=device,
        )

        pmrf_path = os.path.join(script_directory, "checkpoint")
        if not os.path.exists(os.path.join(pmrf_path, 'model.safetensors')):
            if Auto_Download_Path:
                from huggingface_hub import hf_hub_download
                hf_hub_download('ohayonguy/PMRF_blind_face_image_restoration', filename='model.safetensors', local_dir=pmrf_path)
            else:
                # Load by repo id rather than from the node directory.
                pmrf_path = 'ohayonguy/PMRF_blind_face_image_restoration'
        pmrf = MMSERectifiedFlow.from_pretrained(pmrf_path)

        pmrf_models = {
            'upscaler': upsampler,
            'upscaler_model': upsampler.model,
            'pmrf': pmrf.to(dtype=weight_dtype),
        }
        return (pmrf_models, )
# Node registration tables: node identifier -> implementing class, and node
# identifier -> label. Both nodes are displayed under their own identifier.
# (Names follow the ComfyUI custom-node convention - presumably read by the
# host when this module is imported.)
NODE_CLASS_MAPPINGS = dict(PMRF=PMRF, PMRF_Loader=PMRF_Loader)
NODE_DISPLAY_NAME_MAPPINGS = {node_id: node_id for node_id in NODE_CLASS_MAPPINGS}
def generate_reconstructions(pmrf_model, x, y, non_noisy_z0, num_flow_steps, device):
    """Integrate the flow model with fixed-size steps and return the result.

    Starting from ``pmrf_model.create_source_distribution_samples(x, y,
    non_noisy_z0)``, the state is advanced ``num_flow_steps`` times by
    ``velocity * step_size``, where the velocity is the model output at the
    current state and time. Time runs from ``eps`` towards 1, with ``eps``
    read from ``pmrf_model.hparams``.

    Args:
        pmrf_model: Callable model accepting ``x_t``, ``t`` and ``y`` keywords.
        x: Tensor whose first dimension gives the batch size.
        y: Conditioning input forwarded to the model on every step.
        non_noisy_z0: Forwarded unchanged to the source-sample factory.
        num_flow_steps: Number of integration steps (also the progress total).
        device: Device on which the per-sample time vector is created.

    Returns:
        The final state clipped to the range [0, 1].
    """
    x_t_next = pmrf_model.create_source_distribution_samples(x, y, non_noisy_z0).clone()

    eps = pmrf_model.hparams.eps
    step_size = (1.0 / num_flow_steps) * (1.0 - eps)
    ones = torch.ones(x.shape[0], device=device, dtype=pmrf_model.dtype)

    progress = ProgressBar(num_flow_steps)
    for step in tqdm(range(num_flow_steps)):
        # Map the step index onto [eps, 1).
        t_value = (step / num_flow_steps) * (1.0 - eps) + eps
        velocity = pmrf_model(x_t=x_t_next, t=ones * t_value, y=y).to(x_t_next.dtype)
        x_t_next = x_t_next.clone() + velocity * step_size
        progress.update(1)

    return x_t_next.clip(0, 1)
def resize(img, size, interpolation):
    """Rescale ``img`` so that its shorter side becomes ``size`` pixels.

    Both sides are multiplied by the same factor and rounded to integers.

    Args:
        img: Array whose first two dimensions are height and width.
        size: Target length of the shorter side.
        interpolation: OpenCV interpolation flag passed to ``cv2.resize``.

    Returns:
        The resized array produced by ``cv2.resize``.
    """
    # From https://github.com/sczhou/CodeFormer/blob/master/facelib/utils/face_restoration_helper.py
    src_h, src_w = img.shape[0:2]
    factor = float(size) / float(min(src_h, src_w))
    # cv2.resize expects the target size as (width, height).
    target_size = (int(round(src_w * factor)), int(round(src_h * factor)))
    return cv2.resize(img, target_size, interpolation=interpolation)
@torch.inference_mode()
def enhance_face(img, face_helper, num_flow_steps, scale=2, interpolation=cv2.INTER_LANCZOS4, model=None):
    """Restore every detected face in ``img`` and paste it onto an upsampled background.

    Args:
        img: Input image handed to ``face_helper.read_image`` and to the
            background upsampler.
        face_helper: Face restoration helper; its state is cleared first.
        num_flow_steps: Number of flow steps used for each face.
        scale: ``outscale`` passed to the background upsampler.
        interpolation: OpenCV flag used when the helper's input image is
            resized to a shorter side of 640.
        model: Dict providing ``'upscaler'`` and ``'pmrf'``.

    Returns:
        Tuple ``(cropped_faces, restored_faces, restored_img)`` taken from the
        helper after pasting.
    """
    from basicsr.utils import img2tensor, tensor2img

    # Detect and align faces on a copy resized to a fixed working resolution.
    face_helper.clean_all()
    face_helper.read_image(img)
    face_helper.input_img = resize(face_helper.input_img, 640, interpolation)
    face_helper.get_face_landmarks_5(only_center_face=False, eye_dist_threshold=5)
    face_helper.align_warp_face()

    upsampler = model['upscaler']
    pmrf = model['pmrf']

    # Face restoration: one flow run per aligned crop.
    for _, face_crop in tqdm(enumerate(face_helper.cropped_faces)):
        face_tensor = img2tensor(face_crop / 255.0, bgr2rgb=True, float32=True)
        face_tensor = face_tensor.unsqueeze(0).to(device)
        blank = torch.zeros_like(face_tensor).to(dtype=pmrf.dtype)
        degraded = face_tensor.to(dtype=pmrf.dtype)
        reconstruction = generate_reconstructions(pmrf, blank, degraded, None, num_flow_steps, device)
        face_bgr = tensor2img(reconstruction.to(torch.float32).squeeze(0), rgb2bgr=True, min_max=(0, 1))
        face_helper.add_restored_face(face_bgr.astype("uint8"))

    # Upsample the background; only RealESRGAN is supported for this.
    background = upsampler.enhance(img, outscale=scale)[0]
    face_helper.get_inverse_affine(None)
    # Paste each restored face back onto the upsampled background.
    restored_img = face_helper.paste_faces_to_input_image(upsample_img=background)
    return face_helper.cropped_faces, face_helper.restored_faces, restored_img
@torch.inference_mode()
def inference(
    imgs,
    scale,
    num_flow_steps,
    seed,
    interpolation,
    model,
    facedetection,
):
    """Restore a batch of images and return them as one stacked tensor.

    Each image is converted to a BGR ``uint8`` array, processed by
    ``enhance_face``, resized so that its shorter side equals
    ``min(h, w) * scale``, and converted back to a float tensor in [0, 1].

    Args:
        imgs: Iterable of images; each one is permuted with ``(2, 0, 1)``
            before conversion (assumes channel-last layout - TODO confirm).
        scale: Requested output scale factor.
        num_flow_steps: Number of flow steps per face.
        seed: Value passed to ``torch.manual_seed``.
        interpolation: Filter name from the node UI; any other value is
            passed to OpenCV unchanged.
        model: Dict providing the upsampler and the PMRF model.
        facedetection: Detector name forwarded as ``det_model``.

    Returns:
        A one-element tuple containing the concatenated results.
    """
    # Use the bundled helper when the detector weights already sit in the
    # models directory. Otherwise prefer an installed facexlib and fall back
    # to the bundled copy if it cannot be imported.
    detector_weights = os.path.join(folder_paths.models_dir, 'facedetection', 'detection_Resnet50_Final.pth')
    if os.path.exists(detector_weights):
        from .site_package.UL_facexlib.utils.face_restoration_helper import FaceRestoreHelper as UL_FaceRestoreHelper
    else:
        try:
            from facexlib.utils.face_restoration_helper import FaceRestoreHelper as UL_FaceRestoreHelper
        except ImportError:
            from .site_package.UL_facexlib.utils.face_restoration_helper import FaceRestoreHelper as UL_FaceRestoreHelper

    torch.manual_seed(seed)

    # UI name -> cv2 attribute name. The attribute is resolved lazily so only
    # the selected flag has to exist in the installed OpenCV build.
    cv2_flag_names = {
        "lanczos4": "INTER_LANCZOS4",
        "nearest": "INTER_NEAREST",
        "linear": "INTER_LINEAR",
        "cubic": "INTER_CUBIC",
        "area": "INTER_AREA",
        "linear_exact": "INTER_LINEAR_EXACT",
        "nearest_exact": "INTER_NEAREST_EXACT",
    }
    if interpolation in cv2_flag_names:
        interpolation = getattr(cv2, cv2_flag_names[interpolation])

    restored_batch = []
    for frame in imgs:
        pil_frame = torchvision.transforms.functional.to_pil_image(frame.permute(2, 0, 1).clamp(0, 1)).convert("RGB")
        # RGB -> BGR; the copy makes the reversed view contiguous.
        bgr = np.array(pil_frame)[:, :, ::-1].copy()

        height, width = bgr.shape[0:2]
        short_side = min(height, width)

        # Scale relative to the 640-pixel working size, limited to the range
        # [1.0, scale].
        face_scale = scale * (short_side / 640)
        if not face_scale < scale:
            face_scale = scale
        if not face_scale > 1.0:
            face_scale = 1.0

        face_helper = UL_FaceRestoreHelper(
            face_scale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model=facedetection,
            save_ext="png",
            use_parse=True,
            device=device,
        )
        _, _, restored_img = enhance_face(
            bgr,
            face_helper,
            num_flow_steps=num_flow_steps,
            scale=face_scale,
            interpolation=interpolation,
            model=model,
        )

        rgb = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
        rgb = resize(rgb, short_side * scale, interpolation)
        torch.cuda.empty_cache()

        as_tensor = torchvision.transforms.functional.pil_to_tensor(Image.fromarray(rgb)).to(torch.float32) / 255.0
        # Back to channel-last with a leading batch axis.
        restored_batch.append(as_tensor.permute(1, 2, 0)[None,])

    return (torch.cat(tuple(restored_batch), dim=0),)
# Precision name from the node UI -> name of the matching ``torch`` attribute.
# Attributes are looked up lazily so a torch build lacking some fp8 dtypes
# only fails when such a dtype is actually requested.
_TORCH_DTYPE_ATTRS = {
    "fp16": "float16",
    "bf16": "bfloat16",
    "fp32": "float32",
    "fp8_e4m3fn": "float8_e4m3fn",
    "fp8_e4m3fnuz": "float8_e4m3fnuz",
    "fp8_e5m2": "float8_e5m2",
    "fp8_e5m2fnuz": "float8_e5m2fnuz",
}


def get_dtype_by_name(dtype, debug: bool=False):
    """Return the model precision selected by a precision name.

    Matches the node option
    ``(["auto", "fp16", "bf16", "fp32", "fp8_e4m3fn", "fp8_e4m3fnuz",
    "fp8_e5m2", "fp8_e5m2fnuz"], {"default": "auto"})``.

    Args:
        dtype: One of the names above. ``"auto"`` picks float16 when
            ``should_use_fp16()`` is true, else bfloat16 when
            ``should_use_bf16()`` is true, else float32. Any other value
            (for example an existing ``torch.dtype``) is returned unchanged.
        debug: When true, print the resolved precision.

    Returns:
        The resolved ``torch.dtype``, or ``dtype`` itself when it is not a
        recognised name.

    Raises:
        AttributeError: If automatic detection fails; the original error is
            attached as the cause.
    """
    if dtype == 'auto':
        try:
            if should_use_fp16():
                dtype = torch.float16
            elif should_use_bf16():
                dtype = torch.bfloat16
            else:
                dtype = torch.float32
        except Exception as err:
            # Same exception type as before, but KeyboardInterrupt/SystemExit
            # are no longer swallowed and the root cause stays visible.
            raise AttributeError("ComfyUI version too old, can't autodetect properly. Set your dtypes manually.") from err
    elif isinstance(dtype, str) and dtype in _TORCH_DTYPE_ATTRS:
        dtype = getattr(torch, _TORCH_DTYPE_ATTRS[dtype])

    if debug:
        print("\033[93mModel Precision(模型精度):", dtype, "\033[0m")
    return dtype