improve: stable diffusion flash attention monkey patch
Showing 5 changed files with 473 additions and 19 deletions.
@@ -0,0 +1,348 @@
import math
from inspect import isfunction
from typing import Any, Optional

import torch
import torch.nn.functional as F
from torch import nn

import xformers
import xformers.ops


def exists(val):
    return val is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

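# Optional sanity-check sketch (illustrative only; assumes a CUDA device and an
# xformers build that provides memory_efficient_attention): probe the kernel
# once before relying on the classes below.
def _xformers_usable() -> bool:
    try:
        probe = torch.randn(1, 2, 40, device="cuda", dtype=torch.float16)
        xformers.ops.memory_efficient_attention(probe, probe, probe)
        return True
    except Exception:
        return False
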
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
    to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    Uses three q, k, v linear layers to compute attention.
    Parameters:
        channels (:obj:`int`): The number of channels in the input and output.
        num_head_channels (:obj:`int`, *optional*):
            The number of channels in each head. If None, then `num_heads` = 1.
        num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
        rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
        eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
    """

    def __init__(
        self,
        channels: int,
        num_head_channels: Optional[int] = None,
        num_groups: int = 32,
        rescale_output_factor: float = 1.0,
        eps: float = 1e-5,
    ):
        super().__init__()
        self.channels = channels

        self.num_heads = (
            channels // num_head_channels if num_head_channels is not None else 1
        )
        self.num_head_size = num_head_channels
        self.group_norm = nn.GroupNorm(
            num_channels=channels, num_groups=num_groups, eps=eps, affine=True
        )

        # define q,k,v as linear layers
        self.query = nn.Linear(channels, channels)
        self.key = nn.Linear(channels, channels)
        self.value = nn.Linear(channels, channels)

        self.rescale_output_factor = rescale_output_factor
        self.proj_attn = nn.Linear(channels, channels, 1)

    def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
        new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
        return new_projection

    def forward(self, hidden_states):
        residual = hidden_states
        batch, channel, height, width = hidden_states.shape

        # norm
        hidden_states = self.group_norm(hidden_states)

        hidden_states = hidden_states.view(batch, channel, height * width).transpose(
            1, 2
        )

        # proj to q, k, v
        query_proj = self.query(hidden_states)
        key_proj = self.key(hidden_states)
        value_proj = self.value(hidden_states)

        # transpose
        query_states = self.transpose_for_scores(query_proj)
        key_states = self.transpose_for_scores(key_proj)
        value_states = self.transpose_for_scores(value_proj)

        # get scores
        scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))

        attention_scores = torch.matmul(
            query_states * scale, key_states.transpose(-1, -2) * scale
        )
        attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(
            attention_scores.dtype
        )

        # compute attention output
        hidden_states = torch.matmul(attention_probs, value_states)

        hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
        new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
        hidden_states = hidden_states.view(new_hidden_states_shape)

        # compute next hidden_states
        hidden_states = self.proj_attn(hidden_states)
        hidden_states = hidden_states.transpose(-1, -2).reshape(
            batch, channel, height, width
        )

        # res connect and rescale
        hidden_states = (hidden_states + residual) / self.rescale_output_factor
        return hidden_states

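# Minimal shape sketch for AttentionBlock (illustrative sizes, runs on CPU):
# the block is channel-preserving self-attention over the spatial positions.
def _demo_attention_block():
    block = AttentionBlock(channels=64, num_head_channels=32)
    feature_map = torch.randn(2, 64, 8, 8)  # (batch, channels, height, width)
    out = block(feature_map)
    assert out.shape == feature_map.shape
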
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
    standard transformer action. Finally, reshape to image.
    Parameters:
        in_channels (:obj:`int`): The number of channels in the input and output.
        n_heads (:obj:`int`): The number of heads to use for multi-head attention.
        d_head (:obj:`int`): The number of channels in each head.
        depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
        context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
    """

    def __init__(
        self,
        in_channels: int,
        n_heads: int,
        d_head: int,
        depth: int = 1,
        dropout: float = 0.0,
        num_groups: int = 32,
        context_dim: Optional[int] = None,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_head
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(
            num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
        )

        self.proj_in = nn.Conv2d(
            in_channels, inner_dim, kernel_size=1, stride=1, padding=0
        )

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim
                )
                for d in range(depth)
            ]
        )

        self.proj_out = nn.Conv2d(
            inner_dim, in_channels, kernel_size=1, stride=1, padding=0
        )

    def _set_attention_slice(self, slice_size):
        for block in self.transformer_blocks:
            block._set_attention_slice(slice_size)

    def forward(self, hidden_states, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        batch, channel, height, width = hidden_states.shape
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        hidden_states = self.proj_in(hidden_states)
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
            batch, height * width, channel
        )
        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, context=context)
        hidden_states = hidden_states.reshape(batch, height, width, channel).permute(
            0, 3, 1, 2
        )
        hidden_states = self.proj_out(hidden_states)
        return hidden_states + residual

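# Minimal usage sketch for SpatialTransformer, assuming xformers is installed
# and a CUDA device is available; the sizes below are illustrative only.
def _demo_spatial_transformer():
    transformer = (
        SpatialTransformer(in_channels=64, n_heads=2, d_head=32, context_dim=768)
        .cuda()
        .half()
    )
    latents = torch.randn(1, 64, 8, 8, device="cuda", dtype=torch.float16)
    context = torch.randn(1, 77, 768, device="cuda", dtype=torch.float16)  # (batch, tokens, context_dim)
    out = transformer(latents, context=context)
    assert out.shape == latents.shape
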
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.
    Parameters:
        dim (:obj:`int`): The number of channels in the input and output.
        n_heads (:obj:`int`): The number of heads to use for multi-head attention.
        d_head (:obj:`int`): The number of channels in each head.
        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
        context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
        gated_ff (:obj:`bool`, *optional*, defaults to :obj:`True`): Whether to use a gated feed-forward network.
        checkpoint (:obj:`bool`, *optional*, defaults to :obj:`True`): Whether to use checkpointing.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        d_head: int,
        dropout=0.0,
        context_dim: Optional[int] = None,
        gated_ff: bool = True,
        checkpoint: bool = True,
    ):
        super().__init__()
        AttentionBuilder = CrossAttention
        self.attn1 = AttentionBuilder(
            query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
        )  # is a self-attention
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = AttentionBuilder(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def _set_attention_slice(self, slice_size):
        self.attn1._slice_size = slice_size
        self.attn2._slice_size = slice_size

    def forward(self, hidden_states, context=None):
        hidden_states = (
            hidden_states.contiguous()
            if hidden_states.device.type == "mps"
            else hidden_states
        )
        hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
        hidden_states = (
            self.attn2(self.norm2(hidden_states), context=context) + hidden_states
        )
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
        return hidden_states

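# Minimal sketch for BasicTransformerBlock, assuming xformers and a CUDA
# device; the sizes are illustrative. attn1 self-attends over the tokens,
# attn2 attends to `context` (it only falls back to self-attention when no
# separate context_dim is configured).
def _demo_basic_transformer_block():
    block = (
        BasicTransformerBlock(dim=64, n_heads=2, d_head=32, context_dim=768)
        .cuda()
        .half()
    )
    tokens = torch.randn(1, 4096, 64, device="cuda", dtype=torch.float16)  # (batch, seq, dim)
    context = torch.randn(1, 77, 768, device="cuda", dtype=torch.float16)  # e.g. text-encoder states
    out = block(tokens, context=context)
    assert out.shape == tokens.shape
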
class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )
        self.attention_op: Optional[Any] = None

    def forward(self, x, context=None, mask=None):
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        b, _, _ = q.shape
        # split heads: (B, T, H * D) -> (B * H, T, D) for the xformers kernel
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, t.shape[1], self.heads, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * self.heads, t.shape[1], self.dim_head)
            .contiguous(),
            (q, k, v),
        )

        # actually compute the attention, what we cannot get enough of
        out = xformers.ops.memory_efficient_attention(
            q, k, v, attn_bias=None, op=self.attention_op
        )

        # TODO: Use this directly in the attention operation, as a bias
        if exists(mask):
            raise NotImplementedError
        # merge heads back: (B * H, T, D) -> (B, T, H * D)
        out = (
            out.unsqueeze(0)
            .reshape(b, self.heads, out.shape[1], self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, out.shape[1], self.heads * self.dim_head)
        )
        return self.to_out(out)

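# Monkey-patch sketch: one way classes like these are typically swapped into
# diffusers. The target module path (`diffusers.models.attention`) and the set
# of names patched are assumptions for illustration; the patch has to run
# before the UNet is constructed, and modules that imported the original names
# directly may need to be patched as well.
def _patch_diffusers_attention():
    import diffusers.models.attention as attention_module

    attention_module.AttentionBlock = AttentionBlock
    attention_module.SpatialTransformer = SpatialTransformer
    attention_module.BasicTransformerBlock = BasicTransformerBlock
    attention_module.CrossAttention = CrossAttention
    attention_module.FeedForward = FeedForward
    attention_module.GEGLU = GEGLU
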
class FeedForward(nn.Module):
    r"""
    A feed-forward layer.
    Parameters:
        dim (:obj:`int`): The number of channels in the input.
        dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        glu: bool = False,
        dropout: float = 0.0,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim
        # note: `glu` is accepted for API compatibility but a GEGLU projection is always used here
        project_in = GEGLU(dim, inner_dim)

        self.net = nn.Sequential(
            project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
        )

    def forward(self, hidden_states):
        return self.net(hidden_states)


# feedforward
class GEGLU(nn.Module):
    r"""
    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
    Parameters:
        dim_in (:obj:`int`): The number of channels in the input.
        dim_out (:obj:`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, hidden_states):
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * F.gelu(gate)
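

# Small CPU sketch of what GEGLU and FeedForward compute (illustrative sizes):
# the projection output is split in half and one half gates the other via GELU.
def _demo_geglu_feedforward():
    geglu = GEGLU(dim_in=64, dim_out=128)
    x = torch.randn(3, 64)
    value, gate = geglu.proj(x).chunk(2, dim=-1)
    assert torch.allclose(geglu(x), value * F.gelu(gate))

    ff = FeedForward(dim=64)  # GEGLU(64 -> 256) -> Dropout -> Linear(256 -> 64)
    assert ff(x).shape == x.shape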