diff --git a/models/tensor_layers.py b/models/tensor_layers.py
index 505b6f42c..b1618c10d 100644
--- a/models/tensor_layers.py
+++ b/models/tensor_layers.py
@@ -1,3 +1,7 @@
+import logging
+import os
+from typing import Tuple, Union, List
+
 import numpy as np
 import torch
 import torch.nn as nn
@@ -9,6 +13,7 @@
 from models.layers import FCBlock
 
 
+
 def get_irrep_seq(ns, nv, use_second_order_repr, reduce_pseudoscalars):
     if use_second_order_repr:
         irrep_seq = [
@@ -117,6 +122,115 @@ def forward(self, in_, sh, weight):
         return torch.cat(out, dim=-1)
 
 
+def tp_scatter_simple(tp, fc_layer, node_attr, edge_index, edge_attr, edge_sh,
+                      out_nodes=None, reduce='mean', edge_weight=1.0):
+    """
+    Perform the TensorProduct + scatter operation, aka graph convolution.
+
+    This function is only for edge_groups == 1. For multiple edge groups, or for large graphs,
+    use tp_scatter_multigroup instead.
+    """
+
+    assert isinstance(edge_attr, torch.Tensor), \
+        "This function is only for a single edge group, so edge_attr must be a tensor and not a list."
+
+    _device = node_attr.device
+    _dtype = node_attr.dtype
+    edge_src, edge_dst = edge_index
+    tp_weights = fc_layer(edge_attr).to(device=_device, dtype=_dtype)
+    tp_weights.mul_(edge_weight)
+    summand = tp(node_attr[edge_dst], edge_sh, tp_weights)
+    out_nodes = out_nodes or node_attr.shape[0]
+    out = scatter(summand, edge_src, dim=0, dim_size=out_nodes, reduce=reduce)
+    return out
+
+
+def tp_scatter_multigroup(tp: o3.TensorProduct, fc_layer: Union[nn.Module, nn.ModuleList],
+                          node_attr: torch.Tensor, edge_index: torch.Tensor,
+                          edge_attr_groups: List[torch.Tensor], edge_sh: torch.Tensor,
+                          out_nodes=None, reduce='mean', edge_weight=1.0):
+    """
+    Perform the TensorProduct + scatter operation, aka graph convolution.
+
+    To keep peak memory usage reasonably low, this function does not concatenate the
+    edge_attr_groups. Instead, we scatter-sum the tensor-product output of each edge group and,
+    for reduce='mean', divide by the number of edges contributing to each node at the end.
+
+    Parameters
+    ----------
+    tp: o3.TensorProduct
+    fc_layer: nn.Module or nn.ModuleList
+        If a ModuleList, it must have the same length as edge_attr_groups.
+    node_attr: torch.Tensor
+    edge_index: torch.Tensor of shape (2, num_edges)
+        Indicates the source and destination nodes of each edge.
+    edge_attr_groups: List[torch.Tensor]
+        List of tensors of shape (X_i, num_edge_attributes). Each tensor is a different group of
+        edge attributes. X_i may differ between tensors, but sum(X_i) must equal edge_index.shape[1].
+    edge_sh: torch.Tensor
+        Spherical harmonics for the edges (see o3.spherical_harmonics).
+    out_nodes: int, optional
+        Number of output nodes. Defaults to node_attr.shape[0].
+    reduce: str
+        'mean' or 'sum'. Reduce function for the scatter.
+    edge_weight: float or torch.Tensor
+        Edge weights. If a tensor, it must have one row per edge and broadcast against the
+        output of fc_layer, e.g. shape (num_edges, 1).
+
+    Returns
+    -------
+    torch.Tensor
+        Result of the graph convolution.
+    """
+
+    assert isinstance(edge_attr_groups, list), "This function is only for a list of edge groups"
+    assert reduce in {"mean", "sum"}, "Only 'mean' and 'sum' are supported for reduce"
+    # It would be possible to support mul/min/max, but that would require more work and more
+    # code, so we will only add it if it's needed.
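+    # Supporting 'mean' without concatenation works because a scatter-mean over all edges
+    # equals the per-node sum over every group divided by the per-node count of contributing
+    # edges, which is exactly what the accumulation below computes.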
+
+    _device = node_attr.device
+    _dtype = node_attr.dtype
+    edge_src, edge_dst = edge_index
+    edge_attr_lengths = [_edge_attr.shape[0] for _edge_attr in edge_attr_groups]
+    total_rows = sum(edge_attr_lengths)
+    assert total_rows == edge_index.shape[1], "Sum of edge_attr_groups lengths must equal edge_index.shape[1]"
+    num_edge_groups = len(edge_attr_groups)
+    edge_weight_is_indexable = hasattr(edge_weight, '__getitem__')
+
+    out_nodes = out_nodes or node_attr.shape[0]
+    total_output_dim = tp.irreps_out.dim
+    final_out = torch.zeros((out_nodes, total_output_dim), device=_device, dtype=_dtype)
+    div_factors = torch.zeros(out_nodes, device=_device, dtype=_dtype)
+
+    cur_start = 0
+    for ii in range(num_edge_groups):
+        cur_length = edge_attr_lengths[ii]
+        cur_end = cur_start + cur_length
+        cur_edge_range = slice(cur_start, cur_end)
+        cur_edge_src, cur_edge_dst = edge_src[cur_edge_range], edge_dst[cur_edge_range]
+
+        cur_fc = fc_layer[ii] if isinstance(fc_layer, nn.ModuleList) else fc_layer
+        cur_out_irreps = cur_fc(edge_attr_groups[ii])
+        if edge_weight_is_indexable:
+            cur_out_irreps.mul_(edge_weight[cur_edge_range])
+        else:
+            cur_out_irreps.mul_(edge_weight)
+
+        summand = tp(node_attr[cur_edge_dst, :], edge_sh[cur_edge_range, :], cur_out_irreps)
+        # Take a plain sum and count the edges contributing to each node,
+        # so that the mean can be taken after all groups have been accumulated.
+        final_out += scatter(summand, cur_edge_src, dim=0, dim_size=out_nodes, reduce="sum")
+        div_factors += torch.bincount(cur_edge_src, minlength=out_nodes)
+
+        cur_start = cur_end
+        del cur_out_irreps, summand
+
+    if reduce == 'mean':
+        div_factors = torch.clamp(div_factors, min=torch.finfo(_dtype).eps)
+        final_out = final_out / div_factors[:, None]
+
+    return final_out
+
+
 class TensorProductConvLayer(torch.nn.Module):
     def __init__(self, in_irreps, sh_irreps, out_irreps, n_edge_features, residual=True, batch_norm=True, dropout=0.0,
                  hidden_features=None, faster=False, edge_groups=1, tp_weights_layers=2, activation='relu', depthwise=False):
@@ -193,17 +307,19 @@ def __init__(self, in_irreps, sh_irreps, out_irreps, n_edge_features, residual=T
         self.batch_norm = BatchNorm(out_irreps) if batch_norm else None
 
     def forward(self, node_attr, edge_index, edge_attr, edge_sh, out_nodes=None, reduce='mean', edge_weight=1.0):
+        if edge_index.shape[1] == 0 and node_attr.shape[0] == 0:
+            raise ValueError("No edges and no nodes")
+
+        _dtype = node_attr.dtype
         if edge_index.shape[1] == 0:
-            out = torch.zeros((node_attr.shape[0], self.out_size), dtype=node_attr.dtype, device=node_attr.device)
+            out = torch.zeros((node_attr.shape[0], self.out_size), dtype=_dtype, device=node_attr.device)
         else:
-            edge_src, edge_dst = edge_index
-            edge_attr_ = self.fc(edge_attr) if self.edge_groups == 1 else torch.cat(
-                [self.fc[i](edge_attr[i]) for i in range(self.edge_groups)], dim=0).to(node_attr.device)
-            tp = self.tp(node_attr[edge_dst], edge_sh, edge_attr_ * edge_weight)
-
-            out_nodes = out_nodes or node_attr.shape[0]
-            out = scatter(tp, edge_src, dim=0, dim_size=out_nodes, reduce=reduce)
+            if self.edge_groups == 1:
+                out = tp_scatter_simple(self.tp, self.fc, node_attr, edge_index, edge_attr, edge_sh,
+                                        out_nodes, reduce, edge_weight)
+            else:
+                out = tp_scatter_multigroup(self.tp, self.fc, node_attr, edge_index, edge_attr, edge_sh,
+                                            out_nodes, reduce, edge_weight)
 
         if self.depthwise:
             out = self.linear_2(out)
@@ -214,6 +330,8 @@ def forward(self, node_attr, edge_index, edge_attr, edge_sh, out_nodes=None, red
 
         if self.residual:
             padded = F.pad(node_attr, (0, out.shape[-1] - node_attr.shape[-1]))
             out = out + padded
+
+        out = out.to(_dtype)
         return out
@@ -240,11 +358,16 @@ def __init__(self, in_irreps, sh_irreps, out_irreps, n_edge_features, residual=T
 
     def forward(self, node_attr, edge_index, edge_attr, edge_sh, out_nodes=None, reduce='mean', edge_weight=1.0):
-        edge_src, edge_dst = edge_index
-        tp = self.tp(node_attr[edge_dst], edge_sh, self.fc(edge_attr) * edge_weight)
+        # Break the edge_attr into chunks to limit the peak memory usage
+        edge_chunk_size = 100_000
+        num_edges = edge_attr.shape[0]
+        # Ceiling division, with at least one chunk so that zero-edge inputs do not break np.array_split
+        num_chunks = max(1, (num_edges + edge_chunk_size - 1) // edge_chunk_size)
+        edge_ranges = np.array_split(np.arange(num_edges), num_chunks)
+        edge_attr_groups = [edge_attr[cur_range] for cur_range in edge_ranges]
 
-        out_nodes = out_nodes or node_attr.shape[0]
-        out = scatter(tp, edge_src, dim=0, dim_size=out_nodes, reduce=reduce)
+        out = tp_scatter_multigroup(self.tp, self.fc, node_attr, edge_index, edge_attr_groups, edge_sh,
+                                    out_nodes, reduce, edge_weight)
 
         if self.residual:
             padded = F.pad(node_attr, (0, out.shape[-1] - node_attr.shape[-1]))
             out = out + padded
@@ -252,4 +375,6 @@ def forward(self, node_attr, edge_index, edge_attr, edge_sh, out_nodes=None, red
 
         if self.batch_norm:
             out = self.batch_norm(out)
+
+        out = out.to(node_attr.dtype)
         return out
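
A minimal sketch for sanity-checking that the single-group and chunked paths agree. The irreps, the toy graph sizes, and the plain nn.Sequential standing in for FCBlock are illustrative assumptions rather than part of this diff; only tp_scatter_simple and tp_scatter_multigroup come from models/tensor_layers.py, and e3nn plus a scatter backend are assumed to be installed.

    import torch
    import torch.nn as nn
    from e3nn import o3

    from models.tensor_layers import tp_scatter_simple, tp_scatter_multigroup

    torch.manual_seed(0)

    # Toy setup: scalar node features, l <= 1 spherical harmonics on the edges.
    in_irreps = o3.Irreps("4x0e")
    sh_irreps = o3.Irreps("1x0e + 1x1o")
    out_irreps = o3.Irreps("4x0e + 2x1o")
    tp = o3.FullyConnectedTensorProduct(in_irreps, sh_irreps, out_irreps,
                                        shared_weights=False, internal_weights=False)

    # Stand-in for FCBlock: maps edge features to the tensor-product weights.
    n_edge_features, num_nodes, num_edges = 8, 10, 50
    fc = nn.Sequential(nn.Linear(n_edge_features, 32), nn.ReLU(),
                       nn.Linear(32, tp.weight_numel))

    node_attr = torch.randn(num_nodes, in_irreps.dim)
    edge_index = torch.randint(num_nodes, (2, num_edges))
    edge_attr = torch.randn(num_edges, n_edge_features)
    edge_sh = o3.spherical_harmonics(sh_irreps, torch.randn(num_edges, 3), normalize=True)

    with torch.no_grad():
        out_single = tp_scatter_simple(tp, fc, node_attr, edge_index, edge_attr, edge_sh)
        # Splitting the edges into groups of 20/20/10 must not change the result.
        groups = list(torch.split(edge_attr, 20, dim=0))
        out_multi = tp_scatter_multigroup(tp, fc, node_attr, edge_index, groups, edge_sh)

    assert torch.allclose(out_single, out_multi, atol=1e-5)

Any partition that preserves the edge order gives the same mean, since tp_scatter_multigroup normalizes by the per-node edge counts only once at the end; the chunked forward pass above simply uses fixed chunks of 100,000 edges.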