From 97264a215cc5755b95f34e9cba41f6b38545798c Mon Sep 17 00:00:00 2001 From: Abhinav Arora Date: Wed, 26 Jul 2023 22:05:55 -0700 Subject: [PATCH] Implement Spatial Transformer used by LDM UNet (#438) Summary: Pull Request resolved: https://github.com/facebookresearch/multimodal/pull/438 **TL;DR** Implements the Spatial Transformer and the Cross-Attention layers used within the LDM Unet. **Summary of Changes** 1. Implement `TransformerCrossAttentionLayer` similar to LDM;s implementation. This module uses `GEGLU` activation instead of `GLU` 1. Implements `SpatialTransformer` that operates on images. 1. Implement `GEGLU` activation 1. Implement `zero_module` util 1. Add tests for each component Reviewed By: pbontrager Differential Revision: D47709486 fbshipit-source-id: 9ccb91afb43f530d73dcbad37db4655ab7dfeb51 --- tests/modules/layers/test_activation.py | 9 ++++++++- tests/utils/test_common.py | 8 ++++++++ torchmultimodal/modules/layers/activation.py | 20 ++++++++++++++++++++ torchmultimodal/utils/common.py | 9 +++++++++ 4 files changed, 45 insertions(+), 1 deletion(-) diff --git a/tests/modules/layers/test_activation.py b/tests/modules/layers/test_activation.py index 4e8a7010..8a0b42aa 100644 --- a/tests/modules/layers/test_activation.py +++ b/tests/modules/layers/test_activation.py @@ -7,7 +7,7 @@ import torch from tests.test_utils import assert_expected -from torchmultimodal.modules.layers.activation import SiLU +from torchmultimodal.modules.layers.activation import GEGLU, SiLU def test_sigmoid_linear_unit(): @@ -15,3 +15,10 @@ def test_sigmoid_linear_unit(): actual = silu(torch.ones(3)) expected = torch.tensor([0.8458, 0.8458, 0.8458]) assert_expected(actual, expected) + + +def test_geglu(): + geglu = GEGLU() + actual = geglu(torch.ones(10)) + expected = torch.tensor([0.8413, 0.8413, 0.8413, 0.8413, 0.8413]) + assert_expected(actual, expected, atol=1e-4, rtol=1e-5) diff --git a/tests/utils/test_common.py b/tests/utils/test_common.py index fda1918f..c5340947 100644 --- a/tests/utils/test_common.py +++ b/tests/utils/test_common.py @@ -13,6 +13,7 @@ from torch.utils.checkpoint import checkpoint from torchmultimodal.utils.common import ( checkpoint_wrapper, + init_module_parameters_to_zero, shift_dim, tensor_slice, to_tuple_tuple, @@ -30,6 +31,13 @@ def test_shift_dim(): assert_expected(actual, expected) +def test_init_module_parameters_to_zero(): + module = nn.Conv2d(10, 10, kernel_size=1) + init_module_parameters_to_zero(module) + for p in module.parameters(): + assert_expected(p, torch.zeros_like(p)) + + class TestTensorSlice: @pytest.fixture(scope="class") def test_input(self): diff --git a/torchmultimodal/modules/layers/activation.py b/torchmultimodal/modules/layers/activation.py index e3f4c374..3ca3615b 100644 --- a/torchmultimodal/modules/layers/activation.py +++ b/torchmultimodal/modules/layers/activation.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import torch +import torch.nn.functional as F from torch import nn, Tensor @@ -22,3 +23,22 @@ class SiLU(nn.Module): def forward(self, x: Tensor) -> Tensor: return torch.sigmoid(1.702 * x) * x + + +class GEGLU(nn.Module): + """Gated Linear Unit with GELU activation function + + .. math:: \text{GEGLU}(a,b) = a * \text{GELU}(b) + + where :math:`a` is the first half of the input matrices and :math:`b` is + the second half, as descibed in the paper: + `"GLU Variants Improve Transformer"`. + """ + + def __init__(self, dim: int = -1): + super().__init__() + self.split_dim = dim + + def forward(self, x: Tensor) -> Tensor: + x, gate = x.chunk(2, dim=self.split_dim) + return x * F.gelu(gate) diff --git a/torchmultimodal/utils/common.py b/torchmultimodal/utils/common.py index 7cff4fb6..5d4fc4c3 100644 --- a/torchmultimodal/utils/common.py +++ b/torchmultimodal/utils/common.py @@ -187,3 +187,12 @@ def custom_forward(*inputs: Any) -> Callable: def get_clones(module: nn.Module, n: int) -> nn.ModuleList: return nn.ModuleList([deepcopy(module) for i in range(n)]) + + +def init_module_parameters_to_zero(module: nn.Module) -> None: + """ + Sets the parameters of a module to zero. This is a commonly used trick + from Fixup initialization, to stabilize training of residual networks. + """ + for p in module.parameters(): + nn.init.zeros_(p)