Skip to content

Commit

Permalink
Add gating after the edges MLP (section 3.3 of the paper)
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 4, 2021
1 parent ab220e1 commit 65fcb3e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
13 changes: 11 additions & 2 deletions egnn_pytorch/egnn_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torch_geometric.typing import Adj, Size, OptTensor, Tensor
PYG_AVAILABLE = True
except:
MessagePassing = object
Tensor = OptTensor = Adj = MessagePassing = Size = object
PYG_AVAILABLE = False

# to stop throwing errors from type suggestions
Expand Down Expand Up @@ -114,7 +114,8 @@ def __init__(
update_coors = True,
only_sparse_neighbors = False,
valid_radius = float('inf'),
m_pool_method = 'sum'
m_pool_method = 'sum',
soft_edges = False
):
super().__init__()
assert m_pool_method in {'sum', 'mean'}, 'pool method must be either sum or mean'
Expand All @@ -133,6 +134,11 @@ def __init__(
SiLU()
)

self.edge_gate = nn.Sequential(
nn.Linear(m_dim, 1),
nn.Sigmoid()
) if soft_edges else None

self.node_norm = nn.LayerNorm(dim) if norm_feats else nn.Identity()
self.coors_norm = CoorsNorm() if norm_coors else nn.Identity()

Expand Down Expand Up @@ -229,6 +235,9 @@ def forward(self, feats, coors, edges = None, mask = None, adj_mat = None):

m_ij = self.edge_mlp(edge_input)

if exists(self.edge_gate):
m_ij = m_ij * self.edge_gate(m_ij)

if exists(mask):
mask_i = rearrange(mask, 'b i -> b i ()')

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'egnn-pytorch',
packages = find_packages(),
version = '0.1.4',
version = '0.1.5',
license='MIT',
description = 'E(n)-Equivariant Graph Neural Network - Pytorch',
author = 'Phil Wang, Eric Alcaide',
Expand Down

0 comments on commit 65fcb3e

Please sign in to comment.