-
Notifications
You must be signed in to change notification settings - Fork 3
/
scale_attention_turbo.py
371 lines (314 loc) · 15.7 KB
/
scale_attention_turbo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
from typing import Any, Dict, Optional
import torch
import torch.nn.functional as F
from einops import rearrange
import random
def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
    """Build a normalised 2-D Gaussian kernel stacked for a depthwise convolution.

    Args:
        kernel_size: side length of the square kernel; the peak sits at index
            ``(kernel_size - 1) / 2``.
        sigma: standard deviation, in kernel-index units.
        channels: number of identical copies stacked along dim 0.

    Returns:
        Float tensor of shape ``(channels, 1, kernel_size, kernel_size)``; each
        ``[c, 0]`` plane is the outer product of a 1-D Gaussian normalised to
        sum to 1, so the plane itself sums to 1 (up to rounding). This is the
        weight layout ``F.conv2d`` expects with ``groups=channels``.
    """
    # Signed distance of each tap from the kernel centre (integer arange minus a float -> float).
    offsets = torch.arange(kernel_size) - (kernel_size - 1) / 2
    profile = torch.exp(-offsets ** 2 / (2 * sigma ** 2))
    profile = profile / profile.sum()
    # Separable Gaussian: the 2-D kernel is the outer product of the 1-D profile with itself.
    plane = torch.outer(profile, profile)
    return plane.expand(channels, 1, kernel_size, kernel_size).clone()
def gaussian_filter(latents, kernel_size=3, sigma=1.0):
    """Depthwise Gaussian blur of a 4-D tensor.

    Dim 1 of ``latents`` is treated as the channel axis and each channel is
    blurred independently over the last two dims (``F.conv2d`` NCHW layout),
    using zero padding of ``kernel_size // 2`` on each side. With an odd
    ``kernel_size`` the output shape equals the input shape; an even size
    grows each of the last two dims by one.

    Args:
        latents: 4-D tensor; its device and dtype are used for the kernel.
        kernel_size: side length of the Gaussian kernel.
        sigma: standard deviation of the Gaussian, in kernel-index units.
    """
    num_channels = latents.shape[1]
    weight = gaussian_kernel(kernel_size, sigma, num_channels)
    weight = weight.to(latents.device, latents.dtype)
    return F.conv2d(latents, weight, padding=kernel_size // 2, groups=num_channels)
def get_views(height, width, h_window_size=64, w_window_size=64, scale_factor=8):
    """Tile a ``height`` x ``width`` grid into overlapping, randomly jittered windows.

    ``h_window_size`` / ``w_window_size`` are divided by ``scale_factor`` (and
    truncated to int) to get the window size in grid cells. The stride is half
    the window size, halved before that division. Windows that would run past
    the far edge are pulled back inside.

    Each window is then shifted by a random amount of at most
    ``window // 8`` cells (the "jitter range"):

    * windows touching neither edge move in either direction,
    * windows touching only the low edge move towards it (into padding),
    * windows touching only the high edge move towards it,
    * windows touching both edges do not move.

    Finally every coordinate is offset by the jitter range, so the returned
    windows index a grid zero-padded by that many cells on each side.
    Randomness comes from the module-level ``random`` generator; the w jitter
    is drawn before the h jitter for each window.

    Returns:
        list of ``(h_start, h_end, w_start, w_end)`` int tuples.
    """
    height, width = int(height), int(width)

    win_h = int(h_window_size / scale_factor)
    win_w = int(w_window_size / scale_factor)
    step_h = int((h_window_size // 2) / scale_factor)
    step_w = int((w_window_size // 2) / scale_factor)

    def count_blocks(extent, win, step):
        # Number of window positions needed to cover ``extent`` with stride ``step``.
        if extent <= win:
            return 1
        return int((extent - win) / step - 1e-6) + 2

    def fit(start, win, extent):
        # Pull a window that overshoots the far edge back inside, then clamp at 0.
        end = start + win
        if end > extent:
            start, end = start + extent - end, extent
        if start < 0:
            start, end = 0, end - start
        return start, end

    def pick_jitter(start, end, extent, spread):
        touches_low = start == 0
        touches_high = end == extent
        if touches_low and touches_high:
            return 0
        if touches_low:
            return random.randint(-spread, 0)
        if touches_high:
            return random.randint(0, spread)
        return random.randint(-spread, spread)

    rows = count_blocks(height, win_h, step_h)
    cols = count_blocks(width, win_w, step_w)
    spread_h = win_h // 8
    spread_w = win_w // 8

    views = []
    for idx in range(rows * cols):
        row, col = divmod(idx, cols)
        top, bottom = fit(row * step_h, win_h, height)
        left, right = fit(col * step_w, win_w, width)
        # Draw the width jitter first to keep the random-number order stable.
        shift_w = pick_jitter(left, right, width, spread_w) + spread_w
        shift_h = pick_jitter(top, bottom, height, spread_h) + spread_h
        views.append((top + shift_h, bottom + shift_h, left + shift_w, right + shift_w))
    return views
def _windowed_self_attention(attend, grid, views, h_pad, w_pad, window_h):
    """Average ``attend`` over the overlapping local windows listed in ``views``.

    Args:
        attend: callable applied to a flattened ``(bh, tokens, d)`` window; its
            output must un-flatten back to the window's ``(bh, h, w, d)`` shape.
        grid: un-padded ``(bh, h, w, d)`` feature grid.
        views: ``(h_start, h_end, w_start, w_end)`` windows in the coordinates
            of ``grid`` after zero-padding it by ``h_pad`` rows and ``w_pad``
            columns on each side (the layout ``get_views`` produces).
        h_pad, w_pad: padding applied on each side of the height / width axes.
        window_h: height of every window, used to un-flatten the output.

    Returns:
        ``(bh, h, w, d)`` tensor holding, per position, the mean of all window
        outputs covering it; positions covered by no window stay 0.
    """
    height, width = grid.shape[1], grid.shape[2]
    padded = F.pad(grid, (0, 0, w_pad, w_pad, h_pad, h_pad), 'constant', 0)
    value = torch.zeros_like(padded)
    count = torch.zeros_like(padded)
    for h_start, h_end, w_start, w_end in views:
        window = rearrange(padded[:, h_start:h_end, w_start:w_end, :], 'bh h w d -> bh (h w) d')
        window_out = rearrange(attend(window), 'bh (h w) d -> bh h w d', h=window_h)
        value[:, h_start:h_end, w_start:w_end, :] += window_out
        count[:, h_start:h_end, w_start:w_end, :] += 1
    # Crop the padding with explicit end indices: a ``p:-p`` slice is EMPTY
    # when p == 0, which happens for windows narrower than 8 tokens.
    value = value[:, h_pad:h_pad + height, w_pad:w_pad + width, :]
    count = count[:, h_pad:h_pad + height, w_pad:w_pad + width, :]
    return torch.where(count > 0, value / count, value)


def _dilated_self_attention(attend, grid, step_h, step_w):
    """Run ``attend`` separately on each of the ``step_h * step_w`` dilated sub-grids.

    Sub-grid ``(i, j)`` holds every ``step_h``-th row starting at row ``i`` and
    every ``step_w``-th column starting at column ``j``. These sub-grids
    partition ``grid``, so each position receives exactly one output.
    Returns a tensor shaped like ``grid``.
    """
    value = torch.zeros_like(grid)
    count = torch.zeros_like(grid)
    for row_off in range(step_h):
        for col_off in range(step_w):
            sub_grid = grid[:, row_off::step_h, col_off::step_w, :]
            # Un-flatten with the sub-grid's real height. When the grid size is
            # not a multiple of the step, sub-grids are neither square nor all
            # the same size, so sqrt(tokens) would be wrong.
            sub_rows = sub_grid.shape[1]
            sub_out = attend(rearrange(sub_grid, 'bh h w d -> bh (h w) d'))
            sub_out = rearrange(sub_out, 'bh (h w) d -> bh h w d', h=sub_rows)
            value[:, row_off::step_h, col_off::step_w, :] += sub_out
            count[:, row_off::step_h, col_off::step_w, :] += 1
    return torch.where(count > 0, value / count, value)


def scale_forward(
    self,
    hidden_states: torch.FloatTensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    timestep: Optional[torch.LongTensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    class_labels: Optional[torch.LongTensor] = None,
):
    """Transformer-block forward pass with scale-fused self-attention.

    Self-attention (``self.attn1``) is computed twice and fused:

    * locally, over overlapping jittered windows from ``get_views``;
    * globally, either as one full attention call or, when ``self.fast_mode``
      is truthy, over dilated sub-grids (every ``current_scale_num_h``-th row
      and ``current_scale_num_w``-th column).

    The fused output is ``blur(local) + (global - blur(global))`` using
    ``gaussian_filter``. GLIGEN fusion, cross-attention and the feed-forward
    then run exactly as in ``ori_forward``.

    Reads from ``self``: ``current_hw`` (assumed to be the target
    ``(height, width)`` in pixels — TODO confirm), ``fast_mode``, the norm
    flags and modules, ``attn1``, ``attn2``, ``fuser``, ``ff``,
    ``_chunk_size`` and ``_chunk_dim``.

    Raises:
        ValueError: if feed-forward chunking is enabled and the chunked
            dimension is not a multiple of ``self._chunk_size``.
    """
    # Number of 512-px tiles per axis. Clamped to 1 so sizes below 512 px use
    # the single-tile geometry instead of dividing by zero.
    if self.current_hw:
        current_scale_num_h = max(1, self.current_hw[0] // 512)
        current_scale_num_w = max(1, self.current_hw[1] // 512)
    else:
        current_scale_num_h, current_scale_num_w = 1, 1

    # 0. Self-Attention (normalization always precedes each sub-layer)
    if self.use_ada_layer_norm:
        norm_hidden_states = self.norm1(hidden_states, timestep)
    elif self.use_ada_layer_norm_zero:
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
            hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
        )
    else:
        norm_hidden_states = self.norm1(hidden_states)

    # 2. Prepare GLIGEN inputs (copy first so popping "gligen" leaves the caller's dict intact)
    cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
    gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

    # Recover the 2-D token grid from the flattened sequence, assuming its
    # aspect ratio is current_scale_num_h : current_scale_num_w.
    ratio_hw = current_scale_num_h / current_scale_num_w
    latent_h = int((norm_hidden_states.shape[1] * ratio_hw) ** 0.5)
    latent_w = int(latent_h / ratio_hw)
    # Dividing the 64-unit window size by this factor gives a window of
    # latent_h / current_scale_num_h token rows on the longer axis.
    scale_factor = 64 * current_scale_num_h / latent_h
    if ratio_hw > 1:
        sub_h, sub_w = 64, int(64 / ratio_hw)
    else:
        sub_h, sub_w = int(64 * ratio_hw), 64
    # get_views jitters windows by up to window // 8 tokens; the grid is padded by that much.
    h_jitter_range = int(sub_h / scale_factor // 8)
    w_jitter_range = int(sub_w / scale_factor // 8)
    views = get_views(latent_h, latent_w, sub_h, sub_w, scale_factor=scale_factor)
    current_scale_num = max(current_scale_num_h, current_scale_num_w)
    blur_kernel_size = 2 * current_scale_num - 1

    def attend(states):
        # The block's self-attention on a flattened (bh, tokens, d) sequence.
        return self.attn1(
            states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

    grid = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h=latent_h)
    attn_local = _windowed_self_attention(
        attend, grid, views, h_jitter_range, w_jitter_range, int(sub_h / scale_factor)
    )
    if self.fast_mode:
        # Global context from dilated sub-grids instead of one full-resolution call.
        attn_global = _dilated_self_attention(attend, grid, current_scale_num_h, current_scale_num_w)
    else:
        attn_global = rearrange(attend(norm_hidden_states), 'bh (h w) d -> bh h w d', h=latent_h)

    # Fuse: low-pass of the local result plus the high-pass residual of the global one.
    # NOTE(review): gaussian_filter treats dim 1 as channels, so on these
    # (bh, h, w, d) tensors it blurs over (w, d) rather than (h, w). Kept as-is
    # to preserve outputs — confirm this is intended.
    gaussian_local = gaussian_filter(attn_local, kernel_size=blur_kernel_size, sigma=1.0)
    gaussian_global = gaussian_filter(attn_global, kernel_size=blur_kernel_size, sigma=1.0)
    attn_output = gaussian_local + (attn_global - gaussian_global)
    attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')

    if self.use_ada_layer_norm_zero:
        attn_output = gate_msa.unsqueeze(1) * attn_output
    hidden_states = attn_output + hidden_states

    # 2.5 GLIGEN Control
    if gligen_kwargs is not None:
        hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

    # 3. Cross-Attention
    if self.attn2 is not None:
        norm_hidden_states = (
            self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
        )
        attn_output = self.attn2(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
            **cross_attention_kwargs,
        )
        hidden_states = attn_output + hidden_states

    # 4. Feed-forward
    norm_hidden_states = self.norm3(hidden_states)
    if self.use_ada_layer_norm_zero:
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
    if self._chunk_size is not None:
        # "feed_forward_chunk_size" can be used to save memory
        if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
            raise ValueError(
                f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
            )
        num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
        ff_output = torch.cat(
            [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
            dim=self._chunk_dim,
        )
    else:
        ff_output = self.ff(norm_hidden_states)
    if self.use_ada_layer_norm_zero:
        ff_output = gate_mlp.unsqueeze(1) * ff_output
    hidden_states = ff_output + hidden_states
    return hidden_states
def ori_forward(
    self,
    hidden_states: torch.FloatTensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    timestep: Optional[torch.LongTensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    class_labels: Optional[torch.LongTensor] = None,
):
    """Plain transformer-block forward pass with a single full self-attention call.

    Each sub-layer is pre-normalised and added back as a residual:
    self-attention (``attn1``), optional GLIGEN fusion (``fuser``), optional
    cross-attention (``attn2``) and a feed-forward (``ff``), optionally
    evaluated in chunks along ``self._chunk_dim``. All sub-modules and flags
    are read from ``self``.

    Raises:
        ValueError: if chunking is enabled and the chunked dimension is not a
            multiple of ``self._chunk_size``.
    """
    zero_gated = self.use_ada_layer_norm_zero

    # 1. Self-attention; the ada-zero norm also yields the gates/modulations used later.
    if self.use_ada_layer_norm:
        normed = self.norm1(hidden_states, timestep)
    elif zero_gated:
        normed, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
            hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
        )
    else:
        normed = self.norm1(hidden_states)

    # Private copy so popping "gligen" never mutates the caller's dict.
    extra_kwargs = {} if cross_attention_kwargs is None else cross_attention_kwargs.copy()
    gligen_kwargs = extra_kwargs.pop("gligen", None)

    self_attn_context = encoder_hidden_states if self.only_cross_attention else None
    attn_output = self.attn1(
        normed,
        encoder_hidden_states=self_attn_context,
        attention_mask=attention_mask,
        **extra_kwargs,
    )
    if zero_gated:
        attn_output = gate_msa.unsqueeze(1) * attn_output
    hidden_states = attn_output + hidden_states

    # 2. Optional GLIGEN fusion.
    if gligen_kwargs is not None:
        hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

    # 3. Optional cross-attention.
    if self.attn2 is not None:
        if self.use_ada_layer_norm:
            normed = self.norm2(hidden_states, timestep)
        else:
            normed = self.norm2(hidden_states)
        attn_output = self.attn2(
            normed,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
            **extra_kwargs,
        )
        hidden_states = attn_output + hidden_states

    # 4. Feed-forward.
    normed = self.norm3(hidden_states)
    if zero_gated:
        normed = normed * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

    if self._chunk_size is None:
        ff_output = self.ff(normed)
    else:
        # Chunking trades speed for lower peak memory; chunks must tile the dim exactly.
        chunk_dim = self._chunk_dim
        chunk_len = normed.shape[chunk_dim]
        if chunk_len % self._chunk_size != 0:
            raise ValueError(
                f"`hidden_states` dimension to be chunked: {chunk_len} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
            )
        pieces = normed.chunk(chunk_len // self._chunk_size, dim=chunk_dim)
        ff_output = torch.cat([self.ff(piece) for piece in pieces], dim=chunk_dim)

    if zero_gated:
        ff_output = gate_mlp.unsqueeze(1) * ff_output
    return ff_output + hidden_states