Merge pull request #99 from benedekrozemberczki/evolvegcno
further improve test coverage
benedekrozemberczki authored Sep 6, 2021
2 parents 5f63c35 + 7a0c892 commit f05d692
Showing 1 changed file with 22 additions and 59 deletions.
torch_geometric_temporal/nn/recurrent/evolvegcno.py: 81 changes (22 additions & 59 deletions)
@@ -33,37 +33,23 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
 
     fill_value = 2. if improved else 1.
 
-    if isinstance(edge_index, SparseTensor):
-        adj_t = edge_index
-        if not adj_t.has_value():
-            adj_t = adj_t.fill_value(1., dtype=dtype)
-        if add_self_loops:
-            adj_t = fill_diag(adj_t, fill_value)
-        deg = sparsesum(adj_t, dim=1)
-        deg_inv_sqrt = deg.pow_(-0.5)
-        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.)
-        adj_t = mul(adj_t, deg_inv_sqrt.view(-1, 1))
-        adj_t = mul(adj_t, deg_inv_sqrt.view(1, -1))
-        return adj_t
-
-    else:
-        num_nodes = maybe_num_nodes(edge_index, num_nodes)
-
-        if edge_weight is None:
-            edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
-                                     device=edge_index.device)
-
-        if add_self_loops:
-            edge_index, tmp_edge_weight = add_remaining_self_loops(
-                edge_index, edge_weight, fill_value, num_nodes)
-            assert tmp_edge_weight is not None
-            edge_weight = tmp_edge_weight
-
-        row, col = edge_index[0], edge_index[1]
-        deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
-        deg_inv_sqrt = deg.pow_(-0.5)
-        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
-        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
+    num_nodes = maybe_num_nodes(edge_index, num_nodes)
+
+    if edge_weight is None:
+        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
+                                 device=edge_index.device)
+
+    if add_self_loops:
+        edge_index, tmp_edge_weight = add_remaining_self_loops(
+            edge_index, edge_weight, fill_value, num_nodes)
+        assert tmp_edge_weight is not None
+        edge_weight = tmp_edge_weight
+
+    row, col = edge_index[0], edge_index[1]
+    deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
+    deg_inv_sqrt = deg.pow_(-0.5)
+    deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
+    return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]


class GCNConv_Fixed_W(MessagePassing):
@@ -139,27 +125,11 @@ def forward(self, W: torch.FloatTensor, x: Tensor, edge_index: Adj,
         """"""
 
         if self.normalize:
-            if isinstance(edge_index, Tensor):
-                cache = self._cached_edge_index
-                if cache is None:
-                    edge_index, edge_weight = gcn_norm(  # yapf: disable
-                        edge_index, edge_weight, x.size(self.node_dim),
-                        self.improved, self.add_self_loops)
-                    if self.cached:
-                        self._cached_edge_index = (edge_index, edge_weight)
-                else:
-                    edge_index, edge_weight = cache[0], cache[1]
-
-            elif isinstance(edge_index, SparseTensor):
-                cache = self._cached_adj_t
-                if cache is None:
-                    edge_index = gcn_norm(  # yapf: disable
-                        edge_index, edge_weight, x.size(self.node_dim),
-                        self.improved, self.add_self_loops)
-                    if self.cached:
-                        self._cached_adj_t = edge_index
-                else:
-                    edge_index = cache
+            cache = self._cached_edge_index
+            if cache is None:
+                edge_index, edge_weight = gcn_norm(  # yapf: disable
+                    edge_index, edge_weight, x.size(self.node_dim),
+                    self.improved, self.add_self_loops)
 
         x = torch.matmul(x, W)
 
@@ -172,13 +142,6 @@ def forward(self, W: torch.FloatTensor, x: Tensor, edge_index: Adj,
     def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
         return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j
 
-    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
-        return matmul(adj_t, x, reduce=self.aggr)
-
-    def __repr__(self):
-        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
-                                   self.out_channels)
-
 
 def glorot(tensor):
     if tensor is not None:
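Note (not part of the commit): after this change, gcn_norm handles only the dense edge_index representation; the SparseTensor branch, and the sparse message_and_aggregate path removed in the last hunk, are gone. Below is a minimal sketch of what the remaining path computes, assuming gcn_norm is importable from the module path shown above and that add_self_loops defaults to True (that default is not visible in this diff); the toy graph and printouts are illustrative only.

```python
import torch
# Assumption: gcn_norm is importable from the edited module.
from torch_geometric_temporal.nn.recurrent.evolvegcno import gcn_norm

# Toy undirected path graph 0 - 1 - 2, stored as a dense edge_index tensor.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])

# Self-loops are appended (assuming the default add_self_loops=True) and each
# edge weight becomes deg_inv_sqrt[row] * w * deg_inv_sqrt[col], i.e. the
# symmetric D^{-1/2} (A + I) D^{-1/2} normalization used by GCN.
norm_index, norm_weight = gcn_norm(edge_index, edge_weight=None, num_nodes=3)
print(norm_index)
print(norm_weight)
```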

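For context, GCNConv_Fixed_W takes the GCN weight matrix W as a forward argument rather than owning it as a learned parameter; in EvolveGCN the recurrent unit supplies W at each time step. The following is a hypothetical usage sketch: the constructor signature is assumed from the in_channels / out_channels attributes referenced in the removed __repr__ and is not shown in this diff.

```python
import torch
from torch_geometric_temporal.nn.recurrent.evolvegcno import GCNConv_Fixed_W

channels = 8
conv = GCNConv_Fixed_W(channels, channels)  # assumed: (in_channels, out_channels)

x = torch.randn(4, channels)                       # 4 nodes, 8 features each
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])          # a directed 4-cycle
W = torch.randn(channels, channels)                # weights supplied externally

# forward(W, x, edge_index, edge_weight=None): normalizes the graph via
# gcn_norm, multiplies x by the supplied W, then propagates messages.
out = conv(W, x, edge_index)
print(out.shape)   # torch.Size([4, 8])
```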