Commit

add absolute positional embedding for EGNN_Network, which is automatically summed with feature representations when num_positions is set to the maximum sequence length

lucidrains committed May 15, 2021
1 parent 1153bb0 commit f579883
Showing 3 changed files with 12 additions and 2 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -51,6 +51,7 @@ from egnn_pytorch.egnn_pytorch import EGNN_Network

 net = EGNN_Network(
     num_tokens = 21,
+    num_positions = 1024,  # unless what you are passing in is an unordered set, set this to the maximum sequence length
     dim = 32,
     depth = 3,
     num_nearest_neighbors = 8,
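
For context, a minimal end-to-end sketch of how the new num_positions argument is exercised; the tensor shapes, the mask construction, and the returned (feats, coors) tuple are illustrative assumptions rather than part of this commit:

import torch
from egnn_pytorch.egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    num_positions = 1024,  # maximum sequence length, enables the absolute positional embedding
    dim = 32,
    depth = 3,
    num_nearest_neighbors = 8,
)

feats = torch.randint(0, 21, (1, 1024))  # (batch, seq) of token ids
coors = torch.randn(1, 1024, 3)          # (batch, seq, 3) coordinates
mask = torch.ones(1, 1024).bool()        # attend to every node

feats_out, coors_out = net(feats, coors, mask = mask)  # positional embeddings are summed onto feats internally
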
11 changes: 10 additions & 1 deletion egnn_pytorch/egnn_pytorch.py
@@ -300,15 +300,18 @@ def __init__(
         dim,
         num_tokens = None,
         num_edge_tokens = None,
+        num_positions = None,
         edge_dim = 0,
         num_adj_degrees = None,
         adj_dim = 0,
         **kwargs
     ):
         super().__init__()
         assert not (exists(num_adj_degrees) and num_adj_degrees < 1), 'make sure adjacent degrees is at least 1'
+        self.num_positions = num_positions
 
         self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None
+        self.pos_emb = nn.Embedding(num_positions, dim) if exists(num_positions) else None
         self.edge_emb = nn.Embedding(num_edge_tokens, edge_dim) if exists(num_edge_tokens) else None
         self.has_edges = edge_dim > 0

@@ -323,11 +326,17 @@ def __init__(
             self.layers.append(EGNN(dim = dim, edge_dim = (edge_dim + adj_dim), norm_feats = True, **kwargs))
 
     def forward(self, feats, coors, adj_mat = None, edges = None, mask = None):
-        b = feats.shape[0]
+        b, device = feats.shape[0], feats.device
 
         if exists(self.token_emb):
             feats = self.token_emb(feats)
 
+        if exists(self.pos_emb):
+            n = feats.shape[1]
+            assert n <= self.num_positions, f'given sequence length {n} must be less than or equal to the number of positions {self.num_positions} set at init'
+            pos_emb = self.pos_emb(torch.arange(n, device = device))
+            feats += rearrange(pos_emb, 'n d -> () n d')
+
         if exists(edges) and exists(self.edge_emb):
             edges = self.edge_emb(edges)
 
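The rearrange call above broadcasts a single (n, d) table of positional embeddings across the batch dimension. A standalone sketch of just that step, with illustrative sizes; the embedding table and einops rearrange mirror the pos_emb code above, and all names here are hypothetical:

import torch
from torch import nn
from einops import rearrange

pos_emb_table = nn.Embedding(1024, 32)                # num_positions x dim, like self.pos_emb
feats = torch.randn(4, 100, 32)                       # (batch, seq, dim) token features

n = feats.shape[1]                                    # sequence length, must be <= num_positions
pos_emb = pos_emb_table(torch.arange(n))              # (n, dim) embeddings for positions 0..n-1
feats = feats + rearrange(pos_emb, 'n d -> () n d')   # add a leading axis so it broadcasts over the batch
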
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'egnn-pytorch',
   packages = find_packages(),
-  version = '0.1.8',
+  version = '0.1.9',
   license='MIT',
   description = 'E(n)-Equivariant Graph Neural Network - Pytorch',
   author = 'Phil Wang, Eric Alcaide',
