Skip to content

Commit

Permalink
make torch geometric optional by making import of se3gnn explicit
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 15, 2021
1 parent 706f45f commit 7756773
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 3 deletions.
1 change: 0 additions & 1 deletion egnn_pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
from egnn_pytorch.egnn_pytorch import EGNN, EGNN_sparse, EGNN_Network
from egnn_pytorch.se3gnn import get_sparse_adj_paths, SE3GNN_sparse
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'egnn-pytorch',
packages = find_packages(),
version = '0.1.9',
version = '0.1.10',
license='MIT',
description = 'E(n)-Equivariant Graph Neural Network - Pytorch',
author = 'Phil Wang, Eric Alcaide',
Expand Down
10 changes: 10 additions & 0 deletions tests/test_equivariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@ def test_egnn_equivariance():
assert torch.allclose(coors1, (coors2 @ R + T), atol=1e-6), 'type 1 features are equivariant'
assert not torch.allclose(feats1, feats3, atol=1e-6), 'layer must be equivariant to permutations of node order'

def test_higher_dimension():
layer = EGNN(dim=512, edge_dim=4)

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 5)
edges = torch.randn(1, 16, 16, 4)
mask = torch.ones(1, 16).bool()

feats, coors = layer(feats, coors, edges, mask = mask)
assert True

def test_egnn_equivariance_with_nearest_neighbors():
layer = EGNN(dim=512, edge_dim=1, num_nearest_neighbors=8)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_extend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import torch

from egnn_pytorch.utils import rot
from egnn_pytorch import EGNN, EGNN_sparse, get_sparse_adj_paths, SE3GNN_sparse
from egnn_pytorch import EGNN, EGNN_sparse
from egnn_pytorch.se3gnn import get_sparse_adj_paths, SE3GNN_sparse


def test_geom_angles():
Expand Down

0 comments on commit 7756773

Please sign in to comment.