Skip to content

Commit

Permalink
Implement Spatial Transformer used by LDM UNet (#438)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #438

**TL;DR**
Implements the Spatial Transformer and the Cross-Attention layers used within the LDM Unet.

**Summary of Changes**
1. Implement `TransformerCrossAttentionLayer` similar to LDM;s implementation. This module uses `GEGLU` activation instead of `GLU`
1. Implements `SpatialTransformer` that operates on images.
1. Implement `GEGLU` activation
1. Implement `zero_module` util
1. Add tests for each component

Reviewed By: pbontrager

Differential Revision: D47709486

fbshipit-source-id: 3eb7164a25a1726911cba65b19528cc9753a740d
  • Loading branch information
Abhinav Arora authored and facebook-github-bot committed Jul 27, 2023
1 parent 82c1dc2 commit 7d7e9cd
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 1 deletion.
9 changes: 8 additions & 1 deletion tests/modules/layers/test_activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,18 @@
import torch

from tests.test_utils import assert_expected
from torchmultimodal.modules.layers.activation import SiLU
from torchmultimodal.modules.layers.activation import GEGLU, SiLU


def test_sigmoid_linear_unit():
silu = SiLU()
actual = silu(torch.ones(3))
expected = torch.tensor([0.8458, 0.8458, 0.8458])
assert_expected(actual, expected)


def test_geglu():
geglu = GEGLU()
actual = geglu(torch.ones(10))
expected = torch.tensor([0.8413, 0.8413, 0.8413, 0.8413, 0.8413])
assert_expected(actual, expected, atol=1e-4, rtol=1e-5)
8 changes: 8 additions & 0 deletions tests/utils/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from torch.utils.checkpoint import checkpoint
from torchmultimodal.utils.common import (
checkpoint_wrapper,
init_module_parameters_to_zero,
shift_dim,
tensor_slice,
to_tuple_tuple,
Expand All @@ -30,6 +31,13 @@ def test_shift_dim():
assert_expected(actual, expected)


def test_init_module_parameters_to_zero():
module = nn.Conv2d(10, 10, kernel_size=1)
init_module_parameters_to_zero(module)
for p in module.parameters():
assert_expected(p, torch.zeros_like(p))


class TestTensorSlice:
@pytest.fixture(scope="class")
def test_input(self):
Expand Down
20 changes: 20 additions & 0 deletions torchmultimodal/modules/layers/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F
from torch import nn, Tensor


Expand All @@ -22,3 +23,22 @@ class SiLU(nn.Module):

def forward(self, x: Tensor) -> Tensor:
return torch.sigmoid(1.702 * x) * x


class GEGLU(nn.Module):
"""Gated Linear Unit with GELU activation function

.. math:: \text{GEGLU}(a,b) = a * \text{GELU}(b)

where :math:`a` is the first half of the input matrices and :math:`b` is
the second half, as descibed in the paper:
`"GLU Variants Improve Transformer"<https://arxiv.org/pdf/2002.05202.pdf>`.
"""

def __init__(self, dim: int = -1):
super().__init__()
self.split_dim = dim

def forward(self, x: Tensor) -> Tensor:
x, gate = x.chunk(2, dim=self.split_dim)
return x * F.gelu(gate)
9 changes: 9 additions & 0 deletions torchmultimodal/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,12 @@ def custom_forward(*inputs: Any) -> Callable:

def get_clones(module: nn.Module, n: int) -> nn.ModuleList:
return nn.ModuleList([deepcopy(module) for i in range(n)])


def init_module_parameters_to_zero(module: nn.Module) -> None:
"""
Sets the parameters of a module to zero. This is a commonly used trick
from Fixup initialization, to stabilize training of residual networks.
"""
for p in module.parameters():
nn.init.zeros_(p)

0 comments on commit 7d7e9cd

Please sign in to comment.