lora.py
import torch
from io import BytesIO

import requests
from diffusers import (
    StableDiffusionPipeline,
    EulerDiscreteScheduler,
    StableDiffusionImg2ImgPipeline,
)
from flask import Flask, request, send_file, render_template
from flask_cors import CORS
from PIL import Image
from safetensors.torch import load_file


def __load_lora(pipeline, lora_path, lora_weight=0.5):
    """Merge a kohya-ss style LoRA (.safetensors) into a diffusers pipeline in place."""
    state_dict = load_file(lora_path)
    LORA_PREFIX_UNET = 'lora_unet'
    LORA_PREFIX_TEXT_ENCODER = 'lora_te'
    alpha = lora_weight
    visited = []

    # Directly update the weights of the diffusers model.
    for key in state_dict:
        # The scaling factor comes from lora_weight above, so skip the stored
        # '.alpha' entries, and skip keys already handled as part of a pair.
        if '.alpha' in key or key in visited:
            continue

        # Route the key to the text encoder or the UNet and split the rest of
        # the flattened name into candidate module-path segments.
        if 'text' in key:
            layer_infos = key.split('.')[0].split(LORA_PREFIX_TEXT_ENCODER + '_')[-1].split('_')
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = key.split('.')[0].split(LORA_PREFIX_UNET + '_')[-1].split('_')
            curr_layer = pipeline.unet

        # Walk down the module tree to find the target layer. Module.__getattr__
        # only resolves registered submodules, so segments that do not resolve
        # on their own are re-joined with '_' and retried.
        temp_name = layer_infos.pop(0)
        while True:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
                if len(layer_infos) > 0:
                    temp_name = layer_infos.pop(0)
                else:
                    break
            except Exception:
                if len(temp_name) > 0:
                    temp_name += '_' + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)

        # org_forward(x) + lora_up(lora_down(x)) * multiplier
        pair_keys = []
        if 'lora_down' in key:
            pair_keys.append(key.replace('lora_down', 'lora_up'))
            pair_keys.append(key)
        else:
            pair_keys.append(key)
            pair_keys.append(key.replace('lora_up', 'lora_down'))

        # Merge the low-rank update into the layer weight:
        # W += alpha * (up @ down), reshaped back to 4D for conv layers.
        if len(state_dict[pair_keys[0]].shape) == 4:
            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
            weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
        else:
            weight_up = state_dict[pair_keys[0]].to(torch.float32)
            weight_down = state_dict[pair_keys[1]].to(torch.float32)
            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)

        # Mark both keys of the pair as processed.
        for item in pair_keys:
            visited.append(item)

    return pipeline
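
# Usage sketch (illustrative, not executed here): merge a LoRA into a pipeline
# before moving it to the GPU, as the /txt2img route below does. Model id,
# LoRA path, and weight are taken from this file; any other values would work.
#
#     pipe = StableDiffusionPipeline.from_pretrained(
#         "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
#     )
#     pipe = __load_lora(pipeline=pipe, lora_path="./moxin.safetensors", lora_weight=0.8)
#     pipe = pipe.to("cuda")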
app = Flask(__name__, template_folder="frontend", static_folder="frontend")
CORS(app, supports_credentials=True)


@app.route("/")
def index():
    return render_template("index.html")


@app.get("/health")
def health_check():
    return "Healthy", 200


@app.post("/txt2img")
def text_to_img():
    data = request.json
    model_id = "XpucT/Deliberate"
    output = "output_txt2img.png"

    # Build the pipeline per request: load the Euler scheduler, the base model
    # with the long-prompt-weighting community pipeline, then merge the LoRA.
    scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        scheduler=scheduler,
        custom_pipeline="lpw_stable_diffusion",
        torch_dtype=torch.float16,
    )
    lora = (r"./moxin.safetensors", 0.8)
    pipe = __load_lora(pipeline=pipe, lora_path=lora[0], lora_weight=lora[1])
    pipe = pipe.to("cuda")
    # pipe.enable_xformers_memory_efficient_attention()

    image = pipe(data["prompt"], height=data["height"], width=data["width"]).images[0]
    image.save(output)
    return send_file(output), 200


@app.post("/img2img")
def img_to_img():
    data = request.json
    model_id = "runwayml/stable-diffusion-v1-5"
    output = "output_img2img.png"

    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16
    )
    pipe = pipe.to("cuda")

    # Fetch the source image from the provided URL and resize it for img2img.
    response = requests.get(data["url"])
    init_image = Image.open(BytesIO(response.content)).convert("RGB")
    init_image = init_image.resize((768, 512))

    images = pipe(
        prompt=data["prompt"], image=init_image, strength=0.75, guidance_scale=7.5
    ).images
    images[0].save(output)
    return send_file(output), 200


if __name__ == "__main__":
    app.run(host='0.0.0.0', port=5000)
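
# Example client sketch (illustrative; assumes the server above is reachable on
# localhost:5000, and the prompt, size, and output filename are placeholders):
#
#     import requests
#
#     resp = requests.post(
#         "http://localhost:5000/txt2img",
#         json={"prompt": "a portrait in ink wash style", "height": 512, "width": 512},
#     )
#     with open("result.png", "wb") as f:
#         f.write(resp.content)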