-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_adapter.py
80 lines (69 loc) · 3.14 KB
/
test_adapter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import cv2
import torch
from basicsr.utils import tensor2img
from pytorch_lightning import seed_everything
from torch import autocast
from ldm.inference_base import (diffusion_inference, get_adapters, get_base_argument_parser, get_sd_models)
from ldm.modules.extra_condition import api
from ldm.modules.extra_condition.api import (ExtraCondition, get_adapter_feature, get_cond_model)
# Inference-only script: turn autograd off globally so no gradients are tracked.
torch.set_grad_enabled(False)
def _load_test_cases(list_file):
    """Read ``<condition path>; <prompt>`` pairs from a batch-test text file.

    Args:
        list_file: Path to a text file holding one test case per line, with
            the condition path and the prompt separated by ``'; '``.

    Returns:
        A ``(image_paths, prompts)`` tuple of two equally long lists.

    Raises:
        ValueError: If a non-blank line does not contain the ``'; '``
            separator.
    """
    image_paths = []
    prompts = []
    with open(list_file, 'r') as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) are not test cases.
                continue
            # Split on the first separator only, so the prompt itself may
            # contain '; ' without being truncated.
            cond_path, sep, prompt = line.partition('; ')
            if not sep:
                raise ValueError(
                    f"{list_file}:{line_no}: expected '<cond_path>; <prompt>', got {line!r}")
            image_paths.append(cond_path)
            prompts.append(prompt)
    return image_paths, prompts


def main():
    """Run adapter-conditioned diffusion inference from the command line.

    Parses the base arguments plus ``--which_cond``, builds the SD model, the
    adapter and (for image inputs) the condition model, then for every
    (condition path, prompt) pair generates ``opt.n_samples`` results. For
    each sample two PNG files are written to ``opt.outdir``: the extracted
    condition (``<index>_<which_cond>.png``) and the generated image
    (``<index>_result.png``).

    Two test modes are supported: a single case taken from ``opt.cond_path``
    and ``opt.prompt``, or a batch read from a ``.txt`` file passed as
    ``opt.prompt``.

    Raises:
        ValueError: If the batch ``.txt`` file contains a malformed line.
    """
    supported_cond = [e.name for e in ExtraCondition]
    parser = get_base_argument_parser()
    parser.add_argument(
        '--which_cond',
        type=str,
        required=True,
        choices=supported_cond,
        help='which condition modality you want to test',
    )
    opt = parser.parse_args()
    which_cond = opt.which_cond
    # argparse `choices` restricts which_cond to the ExtraCondition member names.
    cond_type = getattr(ExtraCondition, which_cond)

    if opt.outdir is None:
        opt.outdir = f'outputs/test-{which_cond}'
    os.makedirs(opt.outdir, exist_ok=True)
    if opt.resize_short_edge is None:
        print(f"you don't specify the resize_short_edge, so the maximum resolution is set to {opt.max_resolution}")
    opt.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Support two test modes: single image test, and batch test (through a txt file).
    if opt.prompt.endswith('.txt'):
        # Batch mode: read the (condition path, prompt) pairs from the txt file.
        image_paths, prompts = _load_test_cases(opt.prompt)
    else:
        image_paths = [opt.cond_path]
        prompts = [opt.prompt]
    print(image_paths)

    # Prepare models.
    sd_model, sampler = get_sd_models(opt)
    adapter = get_adapters(opt, cond_type)
    cond_model = None
    if opt.cond_inp_type == 'image':
        # A condition model is only built when the input is a raw image.
        cond_model = get_cond_model(opt, cond_type)

    # Pick the `get_cond_<which_cond>` function from the api module.
    process_cond_module = getattr(api, f'get_cond_{which_cond}')

    # Inference.
    # NOTE(review): the autocast device type is hard-coded to 'cuda' although
    # opt.device may be the CPU -- confirm this is intended.
    with torch.inference_mode(), \
            sd_model.ema_scope(), \
            autocast('cuda'):
        for cond_path, prompt in zip(image_paths, prompts):
            # Re-seed once per test case; the samples of a case then consume
            # successive random states.
            seed_everything(opt.seed)
            for _ in range(opt.n_samples):
                # Turn the input at cond_path into the condition via the selected get_cond function.
                cond = process_cond_module(opt, cond_path, opt.cond_inp_type, cond_model)

                # Save the generated condition. Every sample writes two files
                # (condition + result), so half the file count is used as the
                # next index; this assumes outdir holds only such pairs.
                base_count = len(os.listdir(opt.outdir)) // 2
                cv2.imwrite(os.path.join(opt.outdir, f'{base_count:05}_{which_cond}.png'), tensor2img(cond))

                adapter_features, append_to_context = get_adapter_feature(cond, adapter)
                opt.prompt = prompt
                # Run the diffusion sampling and save the result image.
                result = diffusion_inference(opt, sd_model, sampler, adapter_features, append_to_context)
                cv2.imwrite(os.path.join(opt.outdir, f'{base_count:05}_result.png'), tensor2img(result))
# Run the adapter test only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()