Merge pull request #121 from alan-turing-institute/120-node-classifer-takes-node-embedding

removed global mean in classifier
chris-soelistyo authored Aug 2, 2023
2 parents 69f3ab6 + 32bab83 commit 7af42d6
Showing 5 changed files with 27 additions and 17 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -45,7 +45,7 @@ The `grace` workflow consists of the following steps:
4. Cropping of image patches (at various scales) from each bounding box detected in the image
5. Latent feature extraction from image patches (_e.g._ pre-trained neural network, such as [_ResNet-152_](https://pytorch.org/vision/main/models/generated/torchvision.models.resnet152.html))
6. *'Human-in-the-loop'* annotation of the desired pattern in the image data (see the [napari plugin](#development) below)
- 7. Classification of each 'nodeness' and 'edgeness' confidence via deep neural network classifiers (_e.g._ using immediate [1-hop neighbourhood](https://arxiv.org/pdf/1907.06051.pdf))
+ 7. Classification of each 'nodeness' and 'edgeness' confidence via deep neural network classifiers. The neural network can be applied to a full graph, or subgraphs around each node (_e.g._ using immediate [1-hop neighbourhood](https://arxiv.org/pdf/1907.06051.pdf)).
8. Combinatorial optimisation via integer linear programming (ILP) to connect the candidate object nodes via edges (see the [expected outcomes](#outcomes) below)
9. Quantitative evaluation of the filament detection performance
10. Ta-da! 🥳
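Step 7 of the workflow above can run on the full graph or on small per-node subgraphs. As an illustration only (not part of this commit), a 1-hop neighbourhood of the kind referenced in the README can be cropped with `networkx.ego_graph`; the graph below is a hypothetical toy example:

```python
import networkx as nx

# Hypothetical toy graph standing in for a grace graph of candidate detections.
graph = nx.Graph()
graph.add_edges_from([(0, 1), (1, 2), (2, 0), (2, 3)])

# 1-hop "ego" subgraph around node 2: the node itself, its immediate
# neighbours, and the edges between them.
sub_graph = nx.ego_graph(graph, n=2, radius=1)
print(sorted(sub_graph.nodes))  # [0, 1, 2, 3]
```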
23 changes: 11 additions & 12 deletions grace/models/classifier.py
@@ -1,9 +1,9 @@
- from typing import List, Tuple, Optional
+ from typing import List, Tuple

import torch
import torch.nn.functional as F
from torch.nn import Linear
- from torch_geometric.nn import GCNConv, global_mean_pool
+ from torch_geometric.nn import GCNConv


class GCN(torch.nn.Module):
@@ -13,8 +13,10 @@ class GCN(torch.nn.Module):
----------
input_channels : int
The dimension of the input; i.e., length of node feature vectors
- embedding_channels : int
+ hidden_channels : int
The dimension of the hidden embeddings.
+ dropout: float
+ Dropout to apply to the embeddings.
node_output_classes : int
The dimension of the node output. This is typically the number of classes in
the classification task.
@@ -33,6 +35,7 @@ def __init__(
input_channels: int,
hidden_channels: List[int],
*,
+ dropout: float = 0.5,
node_output_classes: int = 2,
edge_output_classes: int = 2,
):
@@ -51,12 +54,12 @@ def __init__(
self.edge_classifier = Linear(
hidden_channels_list[-1] * 2, edge_output_classes
)
+ self.dropout = dropout

def forward(
self,
x: torch.Tensor,
edge_index: torch.Tensor,
- batch: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor]:
for layer in range(len(self.conv_layer_list)):
x = self.conv_layer_list[layer](x, edge_index)
@@ -65,14 +68,10 @@ def forward(
else:
embeddings = x

- # Extract the node embeddings for feature classif:
- x = global_mean_pool(
-     embeddings, batch
- )  # [batch_size, hidden_channels]
-
- # TODO: set dropout probability as config hyperparam:
- x = F.dropout(x, p=0.5, training=self.training)
- node_x = self.node_classifier(x)
+ embeddings = F.dropout(
+     embeddings, p=self.dropout, training=self.training
+ )
+ node_x = self.node_classifier(embeddings)

src, dst = edge_index
edge_features = torch.cat(
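The net effect of the `classifier.py` change: dropout now uses the configurable `self.dropout`, and the node head is applied to every node embedding directly rather than to a single `global_mean_pool` vector. A condensed, hypothetical sketch of the resulting forward pass (a single-layer stand-in, not the file verbatim; the edge-feature construction is assumed from the surrounding context):

```python
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import GCNConv


class TinyGCN(torch.nn.Module):
    """Minimal stand-in for grace's GCN after this change."""

    def __init__(
        self,
        input_channels: int = 32,
        hidden_channels: int = 16,
        dropout: float = 0.5,
        node_output_classes: int = 2,
        edge_output_classes: int = 2,
    ):
        super().__init__()
        self.conv = GCNConv(input_channels, hidden_channels)
        self.node_classifier = Linear(hidden_channels, node_output_classes)
        self.edge_classifier = Linear(hidden_channels * 2, edge_output_classes)
        self.dropout = dropout

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor):
        # Per-node embeddings straight from the graph convolution --
        # no global_mean_pool, no `batch` argument.
        embeddings = self.conv(x, edge_index)

        # Dropout probability comes from the constructor, not a hard-coded 0.5.
        embeddings = F.dropout(embeddings, p=self.dropout, training=self.training)

        # Node logits: one row per node, shape [num_nodes, node_output_classes].
        node_x = self.node_classifier(embeddings)

        # Edge logits from concatenated source/destination node embeddings.
        src, dst = edge_index
        edge_features = torch.cat([embeddings[src], embeddings[dst]], dim=-1)
        edge_x = self.edge_classifier(edge_features)
        return node_x, edge_x
```

With this forward pass, `node_x` has one logit row per node, which is what the updated unit test below now asserts.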
13 changes: 12 additions & 1 deletion grace/models/datasets.py
@@ -67,11 +67,11 @@ def dataset_from_graph(

data = Data(
x=_x(sub_graph),
+ y=_y(sub_graph),
pos=pos,
edge_attr=edge_attr,
edge_index=_edge_index(sub_graph),
edge_label=edge_label,
- y=torch.as_tensor([values[GraphAttrs.NODE_GROUND_TRUTH]]),
)

dataset.append(data)
@@ -119,6 +119,17 @@ def _x(graph: nx.Graph):
return torch.Tensor(x)


+ def _y(graph: nx.Graph):
+     y = np.stack(
+         [
+             node[GraphAttrs.NODE_GROUND_TRUTH]
+             for _, node in graph.nodes(data=True)
+         ],
+         axis=0,
+     )
+     return torch.Tensor(y).long()


def _edge_index(graph: nx.Graph):
item = nx.convert_node_labels_to_integers(graph)
edges = list(item.edges)
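The new `_y` helper stacks the per-node ground-truth labels into a 1-D `LongTensor`, so `data.y` now carries one label per node (matching the per-node logits) instead of a single label per subgraph. A self-contained usage sketch, with a hypothetical string key standing in for `GraphAttrs.NODE_GROUND_TRUTH`:

```python
import networkx as nx
import numpy as np
import torch

# Hypothetical stand-in for grace's GraphAttrs.NODE_GROUND_TRUTH key.
NODE_GROUND_TRUTH = "node_ground_truth"


def _y(graph: nx.Graph) -> torch.Tensor:
    # One integer class label per node, in node-iteration order.
    y = np.stack(
        [node[NODE_GROUND_TRUTH] for _, node in graph.nodes(data=True)],
        axis=0,
    )
    return torch.Tensor(y).long()


graph = nx.Graph()
graph.add_node(0, **{NODE_GROUND_TRUTH: 1})
graph.add_node(1, **{NODE_GROUND_TRUTH: 0})
graph.add_node(2, **{NODE_GROUND_TRUTH: 1})

print(_y(graph))  # tensor([1, 0, 1])
```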
4 changes: 2 additions & 2 deletions grace/models/train.py
@@ -77,7 +77,7 @@ def train(loader):
model.train()

for data in loader:
- node_x, edge_x = model(data.x, data.edge_index, data.batch)
+ node_x, edge_x = model(data.x, data.edge_index)

loss_node = node_criterion(node_x, data.y)
loss_edge = edge_criterion(edge_x, data.edge_label)
@@ -97,7 +97,7 @@ def test(loader):
edge_true = []

for data in loader:
- node_x, edge_x = model(data.x, data.edge_index, data.batch)
+ node_x, edge_x = model(data.x, data.edge_index)

node_pred.extend(node_x)
edge_pred.extend(edge_x)
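With the `batch` argument gone, the node loss in `train.py` compares per-node logits of shape `[num_nodes, node_output_classes]` against a `data.y` of length `num_nodes`. A shape-only sketch with hypothetical tensors (assuming a cross-entropy node criterion, as is typical for this kind of classifier):

```python
import torch
import torch.nn.functional as F

num_nodes, node_output_classes = 5, 2

# What the classifier now returns for one (sub)graph: one logit row per node.
node_x = torch.randn(num_nodes, node_output_classes)

# What _y() now provides: one integer label per node.
y = torch.tensor([0, 1, 1, 0, 1])

# The loss reduces over the node dimension directly -- no pooling required.
loss_node = F.cross_entropy(node_x, y)
print(loss_node)  # a scalar tensor
```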
2 changes: 1 addition & 1 deletion tests/test_model.py
@@ -81,7 +81,7 @@ def test_output_sizes(
num_edges = subgraph.number_of_edges()
node_x, edge_x = gcn(x=data.x, edge_index=data.edge_index)

- assert node_x.size() == (1, node_output_classes)
+ assert node_x.size() == (num_nodes, node_output_classes)
assert edge_x.size() == (num_edges, edge_output_classes)


