Improve memory efficiency of graph convolution operation.
jsilter committed Apr 25, 2024
1 parent 3995b24 commit 2c867df
Showing 1 changed file with 137 additions and 12 deletions.
149 changes: 137 additions & 12 deletions models/tensor_layers.py
@@ -1,3 +1,7 @@
import logging
import os
from typing import Tuple, Union, List

import numpy as np
import torch
import torch.nn as nn
@@ -9,6 +13,7 @@

from models.layers import FCBlock


def get_irrep_seq(ns, nv, use_second_order_repr, reduce_pseudoscalars):
if use_second_order_repr:
irrep_seq = [
@@ -117,6 +122,115 @@ def forward(self, in_, sh, weight):
return torch.cat(out, dim=-1)


def tp_scatter_simple(tp, fc_layer, node_attr, edge_index, edge_attr, edge_sh,
out_nodes=None, reduce='mean', edge_weight=1.0):
"""
    Perform the TensorProduct + scatter operation, i.e. a graph convolution.
    This function is only for edge_groups == 1. For multiple edge groups, or for larger graphs,
    use tp_scatter_multigroup instead.
"""

assert isinstance(edge_attr, torch.Tensor), \
"This function is only for a single edge group, so edge_attr must be a tensor and not a list."

_device = node_attr.device
_dtype = node_attr.dtype
edge_src, edge_dst = edge_index
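    # fc_layer maps the edge features to per-edge tensor-product weights; they are scaled in place by edge_weight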
out_irreps = fc_layer(edge_attr).to(_device).to(_dtype)
out_irreps.mul_(edge_weight)
    tp_out = tp(node_attr[edge_dst], edge_sh, out_irreps)
    out_nodes = out_nodes or node_attr.shape[0]
    out = scatter(tp_out, edge_src, dim=0, dim_size=out_nodes, reduce=reduce)
return out


def tp_scatter_multigroup(tp: o3.TensorProduct, fc_layer: Union[nn.Module, nn.ModuleList],
node_attr: torch.Tensor, edge_index: torch.Tensor,
edge_attr_groups: List[torch.Tensor], edge_sh: torch.Tensor,
out_nodes=None, reduce='mean', edge_weight=1.0):
"""
    Perform the TensorProduct + scatter operation, i.e. a graph convolution.
    To keep peak memory usage reasonably low, this function does not concatenate the edge_attr_groups.
    Instead, it sums the tensor-product output of each edge group and, when reduce='mean',
    divides by the number of contributing edges at the end.

Parameters
----------
tp: o3.TensorProduct
fc_layer: nn.Module, or nn.ModuleList
If a list, must be the same length as edge_attr_groups
node_attr: torch.Tensor
edge_index: torch.Tensor of shape (2, num_edges)
Indicates the source and destination nodes of each edge
edge_attr_groups: List[torch.Tensor]
        List of tensors with shape (X_i, num_edge_attributes); each tensor is a different group of edge attributes.
        X_i may differ between tensors, but sum(X_i) must equal edge_index.shape[1].
edge_sh: torch.Tensor
Spherical harmonics for the edges (see o3.spherical_harmonics)
    out_nodes: int, optional
        Number of output nodes. Defaults to node_attr.shape[0].
reduce: str
'mean' or 'sum'. Reduce function for scatter.
    edge_weight : float or torch.Tensor
        Edge weights. If a tensor, it must be indexable along the edge dimension (one weight, or one row, per edge).
Returns
-------
torch.Tensor
Result of the graph convolution
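
    Examples
    --------
    A minimal sketch of a call; the irreps, feature sizes, and the plain nn.Linear used
    as the weight-producing fc_layer below are illustrative only, not taken from this module:

    >>> from e3nn import o3
    >>> tp = o3.FullyConnectedTensorProduct("4x0e", "0e + 1o", "4x0e + 4x1o",
    ...                                     shared_weights=False, internal_weights=False)
    >>> fc = nn.Linear(8, tp.weight_numel)  # maps 8 edge features to per-edge TP weights
    >>> node_attr = torch.randn(5, 4)
    >>> edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
    >>> edge_sh = torch.randn(4, 4)  # dim of "0e + 1o" is 4
    >>> edge_attr_groups = [torch.randn(2, 8), torch.randn(2, 8)]  # 4 edges split into two groups
    >>> out = tp_scatter_multigroup(tp, fc, node_attr, edge_index, edge_attr_groups, edge_sh)
    >>> out.shape
    torch.Size([5, 16])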
"""

assert isinstance(edge_attr_groups, list), "This function is only for a list of edge groups"
assert reduce in {"mean", "sum"}, "Only 'mean' and 'sum' are supported for reduce"
    # Supporting mul/min/max would require more work and more code, so it is left out until needed.

_device = node_attr.device
_dtype = node_attr.dtype
edge_src, edge_dst = edge_index
edge_attr_lengths = [_edge_attr.shape[0] for _edge_attr in edge_attr_groups]
total_rows = sum(edge_attr_lengths)
    assert total_rows == edge_index.shape[1], "Total rows across edge_attr_groups must equal edge_index.shape[1]"
num_edge_groups = len(edge_attr_groups)
edge_weight_is_indexable = hasattr(edge_weight, '__getitem__')

out_nodes = out_nodes or node_attr.shape[0]
total_output_dim = sum([x.dim for x in tp.irreps_out])
final_out = torch.zeros((out_nodes, total_output_dim), device=_device, dtype=_dtype)
div_factors = torch.zeros(out_nodes, device=_device, dtype=_dtype)

cur_start = 0
for ii in range(num_edge_groups):
cur_length = edge_attr_lengths[ii]
cur_end = cur_start + cur_length
cur_edge_range = slice(cur_start, cur_end)
cur_edge_src, cur_edge_dst = edge_src[cur_edge_range], edge_dst[cur_edge_range]

cur_fc = fc_layer[ii] if isinstance(fc_layer, nn.ModuleList) else fc_layer
cur_out_irreps = cur_fc(edge_attr_groups[ii])
if edge_weight_is_indexable:
cur_out_irreps.mul_(edge_weight[cur_edge_range])
else:
cur_out_irreps.mul_(edge_weight)

summand = tp(node_attr[cur_edge_dst, :], edge_sh[cur_edge_range, :], cur_out_irreps)
# We take a simple sum, and then add up the count of edges which contribute,
# so that we can take the mean later.
final_out += scatter(summand, cur_edge_src, dim=0, dim_size=out_nodes, reduce="sum")
div_factors += torch.bincount(cur_edge_src, minlength=out_nodes)

cur_start = cur_end

del cur_out_irreps, summand

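    # Nodes that received no edges have a zero count; clamping to eps keeps the division finite (their rows stay zero)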
if reduce == 'mean':
div_factors = torch.clamp(div_factors, torch.finfo(_dtype).eps)
final_out = final_out / div_factors[:, None]

return final_out


class TensorProductConvLayer(torch.nn.Module):
def __init__(self, in_irreps, sh_irreps, out_irreps, n_edge_features, residual=True, batch_norm=True, dropout=0.0,
hidden_features=None, faster=False, edge_groups=1, tp_weights_layers=2, activation='relu', depthwise=False):
@@ -193,17 +307,19 @@ def __init__(self, in_irreps, sh_irreps, out_irreps, n_edge_features, residual=T
self.batch_norm = BatchNorm(out_irreps) if batch_norm else None

def forward(self, node_attr, edge_index, edge_attr, edge_sh, out_nodes=None, reduce='mean', edge_weight=1.0):
if edge_index.shape[1] == 0 and node_attr.shape[0] == 0:
raise ValueError("No edges and no nodes")

_dtype = node_attr.dtype
if edge_index.shape[1] == 0:
out = torch.zeros((node_attr.shape[0], self.out_size), dtype=node_attr.dtype, device=node_attr.device)
out = torch.zeros((node_attr.shape[0], self.out_size), dtype=_dtype, device=node_attr.device)
else:
edge_src, edge_dst = edge_index
edge_attr_ = self.fc(edge_attr) if self.edge_groups == 1 else torch.cat(
[self.fc[i](edge_attr[i]) for i in range(self.edge_groups)], dim=0).to(node_attr.device)
tp = self.tp(node_attr[edge_dst], edge_sh, edge_attr_ * edge_weight)

out_nodes = out_nodes or node_attr.shape[0]
out = scatter(tp, edge_src, dim=0, dim_size=out_nodes, reduce=reduce)
if self.edge_groups == 1:
out = tp_scatter_simple(self.tp, self.fc, node_attr, edge_index, edge_attr, edge_sh,
out_nodes, reduce, edge_weight)
else:
out = tp_scatter_multigroup(self.tp, self.fc, node_attr, edge_index, edge_attr, edge_sh,
out_nodes, reduce, edge_weight)

if self.depthwise:
out = self.linear_2(out)
@@ -214,6 +330,8 @@ def forward(self, node_attr, edge_index, edge_attr, edge_sh, out_nodes=None, red
if self.residual:
padded = F.pad(node_attr, (0, out.shape[-1] - node_attr.shape[-1]))
out = out + padded

out = out.to(_dtype)
return out


@@ -240,16 +358,23 @@ def __init__(self, in_irreps, sh_irreps, out_irreps, n_edge_features, residual=T

def forward(self, node_attr, edge_index, edge_attr, edge_sh, out_nodes=None, reduce='mean', edge_weight=1.0):

edge_src, edge_dst = edge_index
tp = self.tp(node_attr[edge_dst], edge_sh, self.fc(edge_attr) * edge_weight)
# Break up the edge_attr into chunks to limit the maximum memory usage
edge_chunk_size = 100_000
num_edges = edge_attr.shape[0]
        # Ceiling division; max(1, ...) keeps np.array_split valid when there are no edges
        num_chunks = max(1, (num_edges + edge_chunk_size - 1) // edge_chunk_size)
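        # np.array_split yields num_chunks contiguous, nearly equal index ranges covering every edge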
edge_ranges = np.array_split(np.arange(num_edges), num_chunks)
edge_attr_groups = [edge_attr[cur_range] for cur_range in edge_ranges]

out_nodes = out_nodes or node_attr.shape[0]
out = scatter(tp, edge_src, dim=0, dim_size=out_nodes, reduce=reduce)
out = tp_scatter_multigroup(self.tp, self.fc, node_attr, edge_index, edge_attr_groups, edge_sh,
out_nodes, reduce, edge_weight)

if self.residual:
padded = F.pad(node_attr, (0, out.shape[-1] - node_attr.shape[-1]))
out = out + padded

if self.batch_norm:
out = self.batch_norm(out)

out = out.to(node_attr.dtype)
return out
