Skip to content

Commit

Permalink
add clamp value for coordinate weights
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 15, 2021
1 parent 878dd8e commit 1153bb0
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 3 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ net = EGNN_Network(
num_tokens = 21,
dim = 32,
depth = 3,
num_nearest_neighbors = 8
num_nearest_neighbors = 8,
coor_weights_clamp_value = 2. # absolute clampd value for the coordinate weights, needed if you increase the num neareest neighbors
)

feats = torch.randint(0, 21, (1, 1024)) # (1, 1024)
Expand Down
9 changes: 8 additions & 1 deletion egnn_pytorch/egnn_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ def __init__(
only_sparse_neighbors = False,
valid_radius = float('inf'),
m_pool_method = 'sum',
soft_edges = False
soft_edges = False,
coor_weights_clamp_value = None
):
super().__init__()
assert m_pool_method in {'sum', 'mean'}, 'pool method must be either sum or mean'
Expand Down Expand Up @@ -162,6 +163,8 @@ def __init__(
self.only_sparse_neighbors = only_sparse_neighbors
self.valid_radius = valid_radius

self.coor_weights_clamp_value = coor_weights_clamp_value

self.init_eps = init_eps
self.apply(self.init_)

Expand Down Expand Up @@ -257,6 +260,10 @@ def forward(self, feats, coors, edges = None, mask = None, adj_mat = None):
if exists(mask):
coor_weights.masked_fill_(~mask, 0.)

if exists(self.coor_weights_clamp_value):
clamp_value = self.coor_weights_clamp_value
coor_weights.clamp_(min = -clamp_value, max = clamp_value)

coors_out = einsum('b i j, b i j c -> b i c', coor_weights, rel_coors) + coors
else:
coors_out = coors
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'egnn-pytorch',
packages = find_packages(),
version = '0.1.7',
version = '0.1.8',
license='MIT',
description = 'E(n)-Equivariant Graph Neural Network - Pytorch',
author = 'Phil Wang, Eric Alcaide',
Expand Down

0 comments on commit 1153bb0

Please sign in to comment.