Add support for charged molecules in TensorNet (total charge) (#238)
* Update tensornet.py for support of total charge q

* Update tensornet.py for q support

* fix

* fix

* fix comment

* initialize zero charge as tensor, move charge broadcasting to interaction module

* add clarification comment

* trying fix

* try fix
guillemsimeon authored Nov 3, 2023 · 1 parent e5fc011 · commit af51c58
Showing 1 changed file with 10 additions and 4 deletions.
torchmdnet/models/tensornet.py: 10 additions & 4 deletions
```diff
@@ -212,10 +212,16 @@ def forward(
             edge_vec is not None
         ), "Distance module did not return directional information"
         # Distance module returns -1 for non-existing edges, to avoid having to resize the tensors when we want to ensure static shapes (for CUDA graphs) we make all non-existing edges pertain to a ghost atom
+        # Total charge q is a molecule-wise property. We transform it into an atom-wise property, with all atoms belonging to the same molecule being assigned the same charge q
+        if q is None:
+            q = torch.zeros_like(z, device=z.device, dtype=z.dtype)
+        else:
+            q = q[batch]
         zp = z
         if self.static_shapes:
             mask = (edge_index[0] < 0).unsqueeze(0).expand_as(edge_index)
             zp = torch.cat((z, torch.zeros(1, device=z.device, dtype=z.dtype)), dim=0)
+            q = torch.cat((q, torch.zeros(1, device=q.device, dtype=q.dtype)), dim=0)
             # I trick the model into thinking that the masked edges pertain to the extra atom
             # WARNING: This can hurt performance if max_num_pairs >> actual_num_pairs
             edge_index = edge_index.masked_fill(mask, z.shape[0])
```
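This first hunk broadcasts the molecule-wise total charge to atoms by indexing with the `batch` vector, then pads `q` with a zero-charge entry for the ghost atom used under static shapes. A minimal sketch of that indexing pattern, with made-up tensors that are not part of the commit:

```python
import torch

# Hypothetical mini-batch: two molecules with 3 and 2 atoms.
z = torch.tensor([6, 1, 1, 8, 1])      # atomic numbers, shape [num_atoms]
batch = torch.tensor([0, 0, 0, 1, 1])  # molecule index of each atom
q_mol = torch.tensor([0, -1])          # total charge per molecule

# Molecule-wise -> atom-wise: each atom inherits its molecule's total charge.
q = q_mol[batch]                        # tensor([ 0,  0,  0, -1, -1])

# Under static shapes, one ghost atom with zero charge is appended so that
# masked edges can point at index z.shape[0] without resizing any tensor.
q = torch.cat((q, torch.zeros(1, dtype=q.dtype)), dim=0)
print(q)                                # tensor([ 0,  0,  0, -1, -1,  0])
```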
```diff
@@ -228,7 +234,7 @@ def forward(
             edge_vec = edge_vec / edge_weight.masked_fill(mask, 1).unsqueeze(1)
         X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr)
         for layer in self.layers:
-            X = layer(X, edge_index, edge_weight, edge_attr)
+            X = layer(X, edge_index, edge_weight, edge_attr, q)
         I, A, S = decompose_tensor(X)
         x = torch.cat((tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1)
         x = self.out_norm(x)
```
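For context, `decompose_tensor` splits each per-atom rank-2 tensor into its irreducible components: an isotropic part I, an antisymmetric part A, and a symmetric traceless part S. A hedged reconstruction of that decomposition, based on the standard identity X = I + A + S rather than on the helper's actual source (which lives elsewhere in tensornet.py):

```python
import torch

def decompose_tensor_sketch(X: torch.Tensor):
    # X has shape [..., 3, 3]; returns I, A, S with X == I + A + S.
    trace = X.diagonal(dim1=-2, dim2=-1).sum(-1)
    I = (trace / 3)[..., None, None] * torch.eye(3, device=X.device, dtype=X.dtype)
    A = 0.5 * (X - X.transpose(-2, -1))   # antisymmetric part
    S = 0.5 * (X + X.transpose(-2, -1)) - I  # symmetric traceless part
    return I, A, S
```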
```diff
@@ -379,7 +385,7 @@ def reset_parameters(self):
             linear.reset_parameters()
 
     def forward(
-        self, X: Tensor, edge_index: Tensor, edge_weight: Tensor, edge_attr: Tensor
+        self, X: Tensor, edge_index: Tensor, edge_weight: Tensor, edge_attr: Tensor, q: Tensor
     ) -> Tensor:
 
         C = self.cutoff(edge_weight)
```
```diff
@@ -401,7 +407,7 @@ def forward(
         if self.equivariance_invariance_group == "O(3)":
             A = torch.matmul(msg, Y)
             B = torch.matmul(Y, msg)
-            I, A, S = decompose_tensor(A + B)
+            I, A, S = decompose_tensor((1 + 0.1*q[...,None,None,None])*(A + B))
         if self.equivariance_invariance_group == "SO(3)":
             B = torch.matmul(Y, msg)
             I, A, S = decompose_tensor(2 * B)
```
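Here the charge enters the O(3) branch as a multiplicative factor on the interaction product `A + B`. The three trailing `None` indices broadcast the atom-wise scalar over the message tensor, and the `1 +` offset keeps neutral systems (`q = 0`) exactly on the previous, charge-free code path. A toy broadcasting sketch, where the shapes are illustrative assumptions rather than values taken from the file:

```python
import torch

num_atoms, hidden_channels = 5, 8
AB = torch.randn(num_atoms, hidden_channels, 3, 3)  # stands in for A + B
q = torch.tensor([0.0, 0.0, 0.0, -1.0, -1.0])       # atom-wise charge

scale = 1 + 0.1 * q[..., None, None, None]          # shape [5, 1, 1, 1]
out = scale * AB                                     # broadcasts to [5, 8, 3, 3]

# Atoms in neutral molecules are untouched: their scale factor is exactly 1.
assert torch.allclose(out[:3], AB[:3])
```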
```diff
@@ -411,5 +417,5 @@ def forward(
         A = self.linears_tensor[4](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         S = self.linears_tensor[5](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         dX = I + A + S
-        X = X + dX + torch.matrix_power(dX, 2)
+        X = X + dX + (1 + 0.1*q[...,None,None,None]) * torch.matrix_power(dX, 2)
         return X
```
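The same `(1 + 0.1*q)` factor scales the second-order term of the residual update, so the charge modulates both the message product and the squared update while leaving the linear term, and therefore the uncharged limit of the architecture, unchanged. A toy illustration of the update rule under the same assumed shapes as above:

```python
import torch

X = torch.randn(5, 8, 3, 3)                   # incoming node features
dX = torch.randn(5, 8, 3, 3)                  # update from this layer
q = torch.tensor([0.0, 0.0, 0.0, -1.0, -1.0])

# torch.matrix_power(dX, 2) squares each trailing 3x3 matrix in the batch;
# the charge factor broadcasts over it exactly as in the message step.
X = X + dX + (1 + 0.1 * q[..., None, None, None]) * torch.matrix_power(dX, 2)
```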
