diff --git a/tests/modules/layers/test_activation.py b/tests/modules/layers/test_activation.py index 4e8a7010..8a0b42aa 100644 --- a/tests/modules/layers/test_activation.py +++ b/tests/modules/layers/test_activation.py @@ -7,7 +7,7 @@ import torch from tests.test_utils import assert_expected -from torchmultimodal.modules.layers.activation import SiLU +from torchmultimodal.modules.layers.activation import GEGLU, SiLU def test_sigmoid_linear_unit(): @@ -15,3 +15,10 @@ def test_sigmoid_linear_unit(): actual = silu(torch.ones(3)) expected = torch.tensor([0.8458, 0.8458, 0.8458]) assert_expected(actual, expected) + + +def test_geglu(): + geglu = GEGLU() + actual = geglu(torch.ones(10)) + expected = torch.tensor([0.8413, 0.8413, 0.8413, 0.8413, 0.8413]) + assert_expected(actual, expected, atol=1e-4, rtol=1e-5) diff --git a/tests/utils/test_common.py b/tests/utils/test_common.py index fda1918f..f8d2bdf4 100644 --- a/tests/utils/test_common.py +++ b/tests/utils/test_common.py @@ -16,6 +16,7 @@ shift_dim, tensor_slice, to_tuple_tuple, + zero_module, ) @@ -30,6 +31,13 @@ def test_shift_dim(): assert_expected(actual, expected) +def test_zero_module(): + module = nn.Conv2d(10, 10, kernel_size=1) + zero_module(module) + for p in module.parameters(): + assert_expected(p, torch.zeros_like(p)) + + class TestTensorSlice: @pytest.fixture(scope="class") def test_input(self): diff --git a/torchmultimodal/modules/layers/activation.py b/torchmultimodal/modules/layers/activation.py index e3f4c374..3ca3615b 100644 --- a/torchmultimodal/modules/layers/activation.py +++ b/torchmultimodal/modules/layers/activation.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import torch +import torch.nn.functional as F from torch import nn, Tensor @@ -22,3 +23,22 @@ class SiLU(nn.Module): def forward(self, x: Tensor) -> Tensor: return torch.sigmoid(1.702 * x) * x + + +class GEGLU(nn.Module): + """Gated Linear Unit with GELU activation function + + .. math:: \text{GEGLU}(a,b) = a * \text{GELU}(b) + + where :math:`a` is the first half of the input matrices and :math:`b` is + the second half, as descibed in the paper: + `"GLU Variants Improve Transformer"`. + """ + + def __init__(self, dim: int = -1): + super().__init__() + self.split_dim = dim + + def forward(self, x: Tensor) -> Tensor: + x, gate = x.chunk(2, dim=self.split_dim) + return x * F.gelu(gate) diff --git a/torchmultimodal/utils/common.py b/torchmultimodal/utils/common.py index 7cff4fb6..61a8f740 100644 --- a/torchmultimodal/utils/common.py +++ b/torchmultimodal/utils/common.py @@ -187,3 +187,11 @@ def custom_forward(*inputs: Any) -> Callable: def get_clones(module: nn.Module, n: int) -> nn.ModuleList: return nn.ModuleList([deepcopy(module) for i in range(n)]) + + +def zero_module(module: nn.Module) -> None: + """ + Zero out the parameters of a module. + """ + for p in module.parameters(): + nn.init.zeros_(p)