Skip to content

Commit

Permalink
release vector gating from https://arxiv.org/abs/2106.03843
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 8, 2021
1 parent 6dea598 commit d709615
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 14 deletions.
32 changes: 22 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ $ pip install geometric-vector-perceptron
```

### Functionality

* `GVP`: Implementing the basic geometric vector perceptron.
* `GVPDropout`: Adapted dropout for GVP in MPNN context
* `GVPLayerNorm`: Adapted LayerNorm for GVP in MPNN context
Expand All @@ -27,7 +28,8 @@ model = GVP(
dim_vectors_in = 1024,
dim_feats_in = 512,
dim_vectors_out = 256,
dim_feats_out = 512
dim_feats_out = 512,
vector_gating = True # use the vector gating as proposed in https://arxiv.org/abs/2106.03843
)

feats, vectors = (torch.randn(1, 512), torch.randn(1, 1024, 3))
Expand All @@ -46,7 +48,8 @@ model = GVP(
dim_vectors_in = 1024,
dim_feats_in = 512,
dim_vectors_out = 256,
dim_feats_out = 512
dim_feats_out = 512,
vector_gating = True
)

dropout = GVPDropout(0.2)
Expand All @@ -66,13 +69,22 @@ The original implementation in TF by the paper authors can be found here: https:
## Citations

```bibtex
@inproceedings{
anonymous2021learning,
title={Learning from Protein Structure with Geometric Vector Perceptrons},
author={Anonymous},
booktitle={Submitted to International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=1YLJDvSx6J4},
note={under review}
@inproceedings{anonymous2021learning,
title = {Learning from Protein Structure with Geometric Vector Perceptrons},
author = {Anonymous},
booktitle = {Submitted to International Conference on Learning Representations},
year = {2021},
url = {https://openreview.net/forum?id=1YLJDvSx6J4}
}
```

```bibtex
@misc{jing2021equivariant,
title = {Equivariant Graph Neural Networks for 3D Macromolecular Structure},
author = {Bowen Jing and Stephan Eismann and Pratham N. Soni and Ron O. Dror},
year = {2021},
eprint = {2106.03843},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
```
25 changes: 22 additions & 3 deletions geometric_vector_perceptron/geometric_vector_perceptron.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
import torch
from torch import nn, einsum
from torch_geometric.nn import MessagePassing

# types

from typing import Optional, List, Union
from torch_geometric.typing import OptPairTensor, Adj, Size, OptTensor, Tensor

# helper functions

def exists(val):
return val is not None

# classes

class GVP(nn.Module):
def __init__(
self,
Expand All @@ -14,7 +23,8 @@ def __init__(
dim_feats_in,
dim_feats_out,
feats_activation = nn.Sigmoid(),
vectors_activation = nn.Sigmoid()
vectors_activation = nn.Sigmoid(),
vector_gating = False
):
super().__init__()
self.dim_vectors_in = dim_vectors_in
Expand All @@ -33,6 +43,10 @@ def __init__(
feats_activation
)

# branching logic to use old GVP, or GVP with vector gating

self.scalar_to_vector_gates = nn.Linear(dim_feats_out, dim_vectors_out) if vector_gating else None

def forward(self, data):
feats, vectors = data
b, n, _, v, c = *feats.shape, *vectors.shape
Expand All @@ -44,13 +58,18 @@ def forward(self, data):
Vu = einsum('b h c, h u -> b u c', Vh, self.Wu)

sh = torch.norm(Vh, p = 2, dim = -1)
vu = torch.norm(Vu, p = 2, dim = -1, keepdim = True)

s = torch.cat((feats, sh), dim = 1)

feats_out = self.to_feats_out(s)
vectors_out = self.vectors_activation(vu) * Vu

if exists(self.scalar_to_vector_gates):
gating = self.scalar_to_vector_gates(feats_out)
gating = gating.unsqueeze(dim = -1)
else:
gating = torch.norm(Vu, p = 2, dim = -1, keepdim = True)

vectors_out = self.vectors_activation(gating) * Vu
return (feats_out, vectors_out)

class GVPDropout(nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'geometric-vector-perceptron',
packages = find_packages(),
version = '0.0.12',
version = '0.0.14',
license='MIT',
description = 'Geometric Vector Perceptron - Pytorch',
author = 'Phil Wang, Eric Alcaide',
Expand Down

0 comments on commit d709615

Please sign in to comment.