improve: stable diffusion flash attention monkey patch
samedii committed Oct 14, 2022
1 parent b86df14 commit 3847c96
Showing 5 changed files with 473 additions and 19 deletions.
16 changes: 14 additions & 2 deletions README.md
@@ -4,16 +4,28 @@ Modular image generation library.

## Install

```bash
poetry add perceptor
```

Or, for the old timers:

```bash
pip install perceptor
```

### CUDA 11.3 (support for RTX cards)

```bash
poetry run pip uninstall torch torchvision -y && poetry run pip install torch==1.12.1 torchvision==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu113
```
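
To confirm that the CUDA 11.3 wheels are the ones actually being used, a quick sanity check (an illustrative snippet, not part of the library) is:

```python
import torch

# Expect a +cu113 build string, "11.3" as the CUDA version, and True on a working RTX setup.
print(torch.__version__, torch.version.cuda, torch.cuda.is_available())
```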

_Stable Diffusion Flash Attention optimization with xformers_.

```bash
poetry run pip uninstall triton xformers -y && poetry run pip install triton xformers
```
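
The patched module added in `perceptor/models/stable_diffusion/attention.py` mirrors the diffusers attention classes but routes the attention computation through `xformers.ops.memory_efficient_attention`. Below is a minimal sketch of how such a monkey patch could be wired up; it assumes diffusers exposes these classes under `diffusers.models.attention` (as it did around this release) and is an illustration rather than perceptor's actual entry point.

```python
# Hypothetical sketch, not perceptor's real patching code: swap the diffusers
# attention classes for the xformers-backed versions from this commit.
# The swap has to run before the UNet modules that reference these classes
# are imported, otherwise they keep their original bindings.
import diffusers.models.attention as diffusers_attention
import perceptor.models.stable_diffusion.attention as flash_attention

diffusers_attention.CrossAttention = flash_attention.CrossAttention
diffusers_attention.SpatialTransformer = flash_attention.SpatialTransformer
diffusers_attention.AttentionBlock = flash_attention.AttentionBlock
```

Since the patched `CrossAttention.forward` calls `xformers.ops.memory_efficient_attention`, both `xformers` and `triton` from the command above need to be importable at runtime.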

## Interface

Shortlist of available features. See the [API reference](https://perceptor.readthedocs.io/en/latest/) for more information.
348 changes: 348 additions & 0 deletions perceptor/models/stable_diffusion/attention.py
@@ -0,0 +1,348 @@
import math
from inspect import isfunction
from typing import Any, Optional

import torch
import torch.nn.functional as F
from torch import nn

import xformers
import xformers.ops


def exists(val):
return val is not None


def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d


class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
Uses three q, k, v linear layers to compute attention.
Parameters:
channels (:obj:`int`): The number of channels in the input and output.
num_head_channels (:obj:`int`, *optional*):
The number of channels in each head. If None, then `num_heads` = 1.
num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
"""

def __init__(
self,
channels: int,
num_head_channels: Optional[int] = None,
num_groups: int = 32,
rescale_output_factor: float = 1.0,
eps: float = 1e-5,
):
super().__init__()
self.channels = channels

self.num_heads = (
channels // num_head_channels if num_head_channels is not None else 1
)
self.num_head_size = num_head_channels
self.group_norm = nn.GroupNorm(
num_channels=channels, num_groups=num_groups, eps=eps, affine=True
)

# define q,k,v as linear layers
self.query = nn.Linear(channels, channels)
self.key = nn.Linear(channels, channels)
self.value = nn.Linear(channels, channels)

self.rescale_output_factor = rescale_output_factor
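        # the third positional argument to nn.Linear is `bias`, so the 1 here simply means bias=True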
self.proj_attn = nn.Linear(channels, channels, 1)

def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
return new_projection

def forward(self, hidden_states):
residual = hidden_states
batch, channel, height, width = hidden_states.shape

# norm
hidden_states = self.group_norm(hidden_states)

hidden_states = hidden_states.view(batch, channel, height * width).transpose(
1, 2
)

# proj to q, k, v
query_proj = self.query(hidden_states)
key_proj = self.key(hidden_states)
value_proj = self.value(hidden_states)

# transpose
query_states = self.transpose_for_scores(query_proj)
key_states = self.transpose_for_scores(key_proj)
value_states = self.transpose_for_scores(value_proj)

# get scores
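        # the same scale is applied to both q and k before the matmul, so the product
        # ends up divided by sqrt(channels / num_heads); splitting the scaling this way
        # keeps intermediate values smaller, which is friendlier to half precision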
scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))

attention_scores = torch.matmul(
query_states * scale, key_states.transpose(-1, -2) * scale
)
attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(
attention_scores.dtype
)

# compute attention output
hidden_states = torch.matmul(attention_probs, value_states)

hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
hidden_states = hidden_states.view(new_hidden_states_shape)

# compute next hidden_states
hidden_states = self.proj_attn(hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(
batch, channel, height, width
)

# res connect and rescale
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states


class SpatialTransformer(nn.Module):
"""
    Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
    standard transformer blocks. Finally, reshape back to an image.
Parameters:
in_channels (:obj:`int`): The number of channels in the input and output.
n_heads (:obj:`int`): The number of heads to use for multi-head attention.
d_head (:obj:`int`): The number of channels in each head.
depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
        num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
        context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
"""

def __init__(
self,
in_channels: int,
n_heads: int,
d_head: int,
depth: int = 1,
dropout: float = 0.0,
num_groups: int = 32,
context_dim: Optional[int] = None,
):
super().__init__()
self.n_heads = n_heads
self.d_head = d_head
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = torch.nn.GroupNorm(
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
)

self.proj_in = nn.Conv2d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
)

self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim
)
for d in range(depth)
]
)

self.proj_out = nn.Conv2d(
inner_dim, in_channels, kernel_size=1, stride=1, padding=0
)

def _set_attention_slice(self, slice_size):
for block in self.transformer_blocks:
block._set_attention_slice(slice_size)

def forward(self, hidden_states, context=None):
# note: if no context is given, cross-attention defaults to self-attention
        batch, channel, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_in(hidden_states)
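        # flatten the spatial grid into a token axis: (b, c, h, w) -> (b, h*w, c)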
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
            batch, height * width, channel
        )
for block in self.transformer_blocks:
hidden_states = block(hidden_states, context=context)
        hidden_states = hidden_states.reshape(batch, height, width, channel).permute(
            0, 3, 1, 2
        )
hidden_states = self.proj_out(hidden_states)
return hidden_states + residual


class BasicTransformerBlock(nn.Module):
r"""
A basic Transformer block.
Parameters:
dim (:obj:`int`): The number of channels in the input and output.
n_heads (:obj:`int`): The number of heads to use for multi-head attention.
d_head (:obj:`int`): The number of channels in each head.
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
        gated_ff (:obj:`bool`, *optional*, defaults to :obj:`True`): Whether to use a gated feed-forward network.
        checkpoint (:obj:`bool`, *optional*, defaults to :obj:`True`): Whether to use checkpointing.
"""

def __init__(
self,
dim: int,
n_heads: int,
d_head: int,
dropout=0.0,
context_dim: Optional[int] = None,
gated_ff: bool = True,
checkpoint: bool = True,
):
super().__init__()
AttentionBuilder = CrossAttention
self.attn1 = AttentionBuilder(
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = AttentionBuilder(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint

def _set_attention_slice(self, slice_size):
self.attn1._slice_size = slice_size
self.attn2._slice_size = slice_size

def forward(self, hidden_states, context=None):
hidden_states = (
hidden_states.contiguous()
if hidden_states.device.type == "mps"
else hidden_states
)
hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
hidden_states = (
self.attn2(self.norm2(hidden_states), context=context) + hidden_states
)
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
return hidden_states


class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)

self.heads = heads
self.dim_head = dim_head

self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
)
self.attention_op: Optional[Any] = None

def forward(self, x, context=None, mask=None):
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)

b, _, _ = q.shape
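        # fold the heads into the batch dimension, (b, t, heads * dim_head) -> (b * heads, t, dim_head),
        # which is the layout handed to xformers.ops.memory_efficient_attention below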
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], self.heads, self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b * self.heads, t.shape[1], self.dim_head)
.contiguous(),
(q, k, v),
)

# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(
q, k, v, attn_bias=None, op=self.attention_op
)

# TODO: Use this directly in the attention operation, as a bias
if exists(mask):
raise NotImplementedError
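        # unfold the heads from the batch dimension again: (b * heads, t, dim_head) -> (b, t, heads * dim_head)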
out = (
out.unsqueeze(0)
.reshape(b, self.heads, out.shape[1], self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b, out.shape[1], self.heads * self.dim_head)
)
return self.to_out(out)


class FeedForward(nn.Module):
r"""
A feed-forward layer.
Parameters:
dim (:obj:`int`): The number of channels in the input.
dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
"""

def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
mult: int = 4,
glu: bool = False,
dropout: float = 0.0,
):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
        project_in = GEGLU(dim, inner_dim)  # GEGLU is always used; the `glu` flag is accepted but not consulted

self.net = nn.Sequential(
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
)

def forward(self, hidden_states):
return self.net(hidden_states)


# feedforward
class GEGLU(nn.Module):
r"""
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
Parameters:
dim_in (:obj:`int`): The number of channels in the input.
dim_out (:obj:`int`): The number of channels in the output.
"""

def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)

def forward(self, hidden_states):
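        # the projection yields 2 * dim_out channels; one half is the value, the other half gates it through GELU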
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
return hidden_states * F.gelu(gate)