Skip to content

Commit

Permalink
Merge pull request #100 from benedekrozemberczki/evolvegcno
Browse files Browse the repository at this point in the history
importing instead of repeating from PyG functions
  • Loading branch information
benedekrozemberczki committed Sep 10, 2021
2 parents f05d692 + 38caa4e commit ffbc9d6
Showing 1 changed file with 4 additions and 50 deletions.
54 changes: 4 additions & 50 deletions torch_geometric_temporal/nn/recurrent/evolvegcno.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,13 @@
import math
from typing import Optional, Tuple

import torch
from torch import Tensor
from torch.nn import Parameter
from torch.nn import GRU
from torch_geometric.typing import Adj, OptTensor, PairTensor
from torch_scatter import scatter_add
from torch_sparse import SparseTensor, matmul, fill_diag, sum as sparsesum, mul
from torch_geometric.nn.inits import zeros
from torch_geometric.typing import Adj, OptTensor
from torch_sparse import SparseTensor
from torch_geometric.nn.inits import glorot
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import add_remaining_self_loops
from torch_geometric.utils.num_nodes import maybe_num_nodes


@torch.jit._overload
def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
add_self_loops=True, dtype=None):
# type: (Tensor, OptTensor, Optional[int], bool, bool, Optional[int]) -> PairTensor # noqa
pass


@torch.jit._overload
def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
add_self_loops=True, dtype=None):
# type: (SparseTensor, OptTensor, Optional[int], bool, bool, Optional[int]) -> SparseTensor # noqa
pass


def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
add_self_loops=True, dtype=None):

fill_value = 2. if improved else 1.

num_nodes = maybe_num_nodes(edge_index, num_nodes)

if edge_weight is None:
edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
device=edge_index.device)

if add_self_loops:
edge_index, tmp_edge_weight = add_remaining_self_loops(
edge_index, edge_weight, fill_value, num_nodes)
assert tmp_edge_weight is not None
edge_weight = tmp_edge_weight

row, col = edge_index[0], edge_index[1]
deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
deg_inv_sqrt = deg.pow_(-0.5)
deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
from torch_geometric.nn.conv.gcn_conv import gcn_norm


class GCNConv_Fixed_W(MessagePassing):
Expand Down Expand Up @@ -143,10 +101,6 @@ def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j


def glorot(tensor):
if tensor is not None:
stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
tensor.data.uniform_(-stdv, stdv)

class EvolveGCNO(torch.nn.Module):
r"""An implementation of the Evolving Graph Convolutional without Hidden Layer.
Expand Down

0 comments on commit ffbc9d6

Please sign in to comment.