diff --git a/torch_geometric_temporal/nn/recurrent/evolvegcno.py b/torch_geometric_temporal/nn/recurrent/evolvegcno.py index 0161dfd5..bef2b881 100644 --- a/torch_geometric_temporal/nn/recurrent/evolvegcno.py +++ b/torch_geometric_temporal/nn/recurrent/evolvegcno.py @@ -1,55 +1,13 @@ -import math from typing import Optional, Tuple import torch from torch import Tensor -from torch.nn import Parameter from torch.nn import GRU -from torch_geometric.typing import Adj, OptTensor, PairTensor -from torch_scatter import scatter_add -from torch_sparse import SparseTensor, matmul, fill_diag, sum as sparsesum, mul -from torch_geometric.nn.inits import zeros +from torch_geometric.typing import Adj, OptTensor +from torch_sparse import SparseTensor +from torch_geometric.nn.inits import glorot from torch_geometric.nn.conv import MessagePassing -from torch_geometric.utils import add_remaining_self_loops -from torch_geometric.utils.num_nodes import maybe_num_nodes - - -@torch.jit._overload -def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False, - add_self_loops=True, dtype=None): - # type: (Tensor, OptTensor, Optional[int], bool, bool, Optional[int]) -> PairTensor # noqa - pass - - -@torch.jit._overload -def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False, - add_self_loops=True, dtype=None): - # type: (SparseTensor, OptTensor, Optional[int], bool, bool, Optional[int]) -> SparseTensor # noqa - pass - - -def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False, - add_self_loops=True, dtype=None): - - fill_value = 2. if improved else 1. - - num_nodes = maybe_num_nodes(edge_index, num_nodes) - - if edge_weight is None: - edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype, - device=edge_index.device) - - if add_self_loops: - edge_index, tmp_edge_weight = add_remaining_self_loops( - edge_index, edge_weight, fill_value, num_nodes) - assert tmp_edge_weight is not None - edge_weight = tmp_edge_weight - - row, col = edge_index[0], edge_index[1] - deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes) - deg_inv_sqrt = deg.pow_(-0.5) - deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0) - return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] +from torch_geometric.nn.conv.gcn_conv import gcn_norm class GCNConv_Fixed_W(MessagePassing): @@ -143,10 +101,6 @@ def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor: return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j -def glorot(tensor): - if tensor is not None: - stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1))) - tensor.data.uniform_(-stdv, stdv) class EvolveGCNO(torch.nn.Module): r"""An implementation of the Evolving Graph Convolutional without Hidden Layer.