Skip to content

Commit

Permalink
changed dataset_from_graph
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-soelistyo committed Jul 26, 2023
1 parent cfb907f commit 1ec9afa
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 102 deletions.
18 changes: 2 additions & 16 deletions grace/models/classifier.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple, Optional
from typing import List, Tuple, Optional

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -31,7 +31,7 @@ class GCN(torch.nn.Module):
def __init__(
self,
input_channels: int,
hidden_channels: list[int],
hidden_channels: List[int],
*,
node_output_classes: int = 2,
edge_output_classes: int = 2,
Expand Down Expand Up @@ -82,17 +82,3 @@ def forward(
edge_x = self.edge_classifier(edge_features)

return node_x, edge_x

def predict(
self,
x: torch.Tensor,
edge_index: torch.Tensor,
batch: Optional[torch.Tensor] = None,
):
# predict the labels of the subgraph, no matter the annotations (node, edge)
pass


class Classifier:
def __init__(self, model_type: str = "gcn", layer_list: list[int] = []):
pass
140 changes: 63 additions & 77 deletions grace/models/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,141 +3,127 @@
import networkx as nx
import numpy as np
import torch
import torch_geometric

from torch_geometric.data import Data
from grace.base import GraphAttrs, Annotation


def dataset_from_subgraphs(
def dataset_from_graph(
graph: nx.Graph,
*,
mode: str = "full",
n_hop: int = 1,
in_train_mode: bool = True,
) -> List[torch_geometric.data.Data]:
) -> List[Data]:
"""Create a pytorch geometric dataset from a given networkx graph.
Parameters
----------
graph : graph
A networkx graph.
mode : str
"sub" or "full".
n_hop : int
The number of hops from the central node when creating the subgraphs.
in_train_mode : bool
Traverses & checks sub-graphs to generate training dataset. Default = True
Returns
-------
dataset : list
A list of pytorch geometric data objects representing the extracted
subgraphs.
dataset : List[Data] or Data
A (list of) pytorch geometric data object(s) representing the extracted
subgraphs or full graph.
TODO:
- currently doesn't work on 'corner' nodes i.e. nodes which have
patches cropped at the boundary of the image - need to pad the image beforehand
"""

dataset = []
assert mode in ["sub", "full"]

for node, values in graph.nodes(data=True):
# Define a subgraph - n_hop subgraph at train time, whole graph otherwise:
sub_graph = nx.ego_graph(graph, node, radius=n_hop)
if mode == "sub":
dataset = []

# Constraint: exclusion of unknown nodes at the centre of subgraph:
if in_train_mode is True:
if values[GraphAttrs.NODE_GROUND_TRUTH] is Annotation.UNKNOWN:
for node, values in graph.nodes(data=True):
if (
in_train_mode
and values[GraphAttrs.NODE_GROUND_TRUTH] is Annotation.UNKNOWN
):
continue

edge_label = [
edge[GraphAttrs.EDGE_GROUND_TRUTH]
for _, _, edge in sub_graph.edges(data=True)
]
sub_graph = nx.ego_graph(graph, node, radius=n_hop)
edge_label = [
edge[GraphAttrs.EDGE_GROUND_TRUTH]
for _, _, edge in sub_graph.edges(data=True)
]

# Constraint: exclusion of all unknown edges forming the subgraph:
if in_train_mode is True:
if all([e == Annotation.UNKNOWN for e in edge_label]):
if in_train_mode and all(
[e == Annotation.UNKNOWN for e in edge_label]
):
continue

x = np.stack(
[
node[GraphAttrs.NODE_FEATURES]
for _, node in sub_graph.nodes(data=True)
],
axis=0,
)
pos = np.stack(
[
(node[GraphAttrs.NODE_X], node[GraphAttrs.NODE_Y])
for _, node in graph.nodes(data=True)
],
axis=0,
)
central_node = np.array(
[values[GraphAttrs.NODE_X], values[GraphAttrs.NODE_Y]]
)
edge_attr = pos - central_node

data = _info_from_graph(
sub_graph,
pos=torch.Tensor(pos),
edge_attr=torch.Tensor(edge_attr),
y=torch.as_tensor([values[GraphAttrs.NODE_GROUND_TRUTH]]),
)

dataset.append(data)

elif mode == "full":
edge_label = [
edge[GraphAttrs.EDGE_GROUND_TRUTH]
for _, _, edge in graph.edges(data=True)
]

pos = np.stack(
[
(node[GraphAttrs.NODE_X], node[GraphAttrs.NODE_Y])
for _, node in sub_graph.nodes(data=True)
for _, node in graph.nodes(data=True)
],
axis=0,
)

# TODO: edge attributes
central_node = np.array(
[values[GraphAttrs.NODE_X], values[GraphAttrs.NODE_Y]]
)
edge_attr = pos - central_node

item = nx.convert_node_labels_to_integers(sub_graph)
edges = list(item.edges)
edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()

data = torch_geometric.data.Data(
x=torch.Tensor(x),
edge_index=edge_index,
edge_attr=torch.Tensor(edge_attr),
edge_label=torch.Tensor(edge_label).long(),
data = _info_from_graph(
sub_graph,
pos=torch.Tensor(pos),
edge_attr=torch.Tensor(edge_attr),
y=torch.as_tensor([values[GraphAttrs.NODE_GROUND_TRUTH]]),
edge_label=torch.Tensor(edge_label).long(),
)

dataset.append(data)

return dataset

return data

def dataset_from_whole_graph(graph: nx.Graph) -> torch_geometric.data.Data:
"""Create a single pytorch geometric dataset from an entire give networkx graph.
Parameters
----------
graph : graph
A networkx graph.
Returns
-------
dataset : list
A single pytorch geometric data objects representing the extracted graph.
"""

edge_label = [
edge[GraphAttrs.EDGE_GROUND_TRUTH]
for _, _, edge in graph.edges(data=True)
]

def _info_from_graph(
graph: nx.Graph,
**kwargs,
) -> Data:
x = np.stack(
[node[GraphAttrs.NODE_FEATURES] for _, node in graph.nodes(data=True)],
axis=0,
)

pos = np.stack(
[
(node[GraphAttrs.NODE_X], node[GraphAttrs.NODE_Y])
for _, node in graph.nodes(data=True)
],
axis=0,
)

item = nx.convert_node_labels_to_integers(graph)
edges = list(item.edges)
edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()

data = torch_geometric.data.Data(
data = Data(
x=torch.Tensor(x),
edge_index=edge_index,
edge_label=torch.Tensor(edge_label).long(),
pos=torch.Tensor(pos),
**kwargs,
)

return data
4 changes: 2 additions & 2 deletions grace/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from grace.config import write_config_file, load_config_params
from grace.io.image_dataset import ImageGraphDataset
from grace.models.train import train_model
from grace.models.datasets import dataset_from_subgraphs
from grace.models.datasets import dataset_from_graph
from grace.models.classifier import GCN
from grace.models.feature_extractor import FeatureExtractor
from grace.utils.transforms import get_transforms
Expand Down Expand Up @@ -58,7 +58,7 @@ def transform(img, grph):
input_data, desc="Extracting patch features from training data... "
):
print(target["metadata"]["image_filename"])
dataset.extend(dataset_from_subgraphs(target["graph"]))
dataset.extend(dataset_from_graph(target["graph"], mode="sub"))

classifier = GCN(
input_channels=config.feature_dim,
Expand Down
9 changes: 6 additions & 3 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from grace.io.core import GraphAttrs, Annotation
from grace.io.image_dataset import ImageGraphDataset
from grace.models.datasets import dataset_from_subgraphs
from grace.models.datasets import dataset_from_graph

from _utils import random_image_and_graph

Expand Down Expand Up @@ -43,7 +43,7 @@ def test_dataset_ignores_subgraph_if_all_edges_unknown(default_rng):
# this action is not currently required since edges are by default UNKNOWN;
# however it enables testing of this condition should the default label be changed

assert dataset_from_subgraphs(graph) == []
assert dataset_from_graph(graph, mode="sub") == []


@pytest.mark.parametrize("num_unknown", [7, 17])
Expand Down Expand Up @@ -72,7 +72,10 @@ def test_dataset_ignores_subgraph_if_central_node_unknown(
]
graph.update(nodes=node_update)

assert len(dataset_from_subgraphs(graph)) == num_nodes_total - num_unknown
assert (
len(dataset_from_graph(graph, mode="sub"))
== num_nodes_total - num_unknown
)


def test_dataset_only_takes_common_filenames(tmp_path):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from _utils import random_image_and_graph

from grace.base import GraphAttrs, Annotation
from grace.models.datasets import dataset_from_subgraphs
from grace.models.datasets import dataset_from_graph


@pytest.mark.parametrize("input_channels", [1, 2])
Expand Down Expand Up @@ -74,7 +74,7 @@ def test_output_sizes(
for src, dst in graph.edges
]
)
data = dataset_from_subgraphs(graph)[0]
data = dataset_from_graph(graph, mode="sub")[0]

subgraph = nx.ego_graph(graph, 0)
num_nodes = subgraph.number_of_nodes()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from grace.base import GraphAttrs, Annotation
from grace.utils.metrics import accuracy_metric, confusion_matrix_metric
from grace.models.train import train_model
from grace.models.datasets import dataset_from_subgraphs
from grace.models.datasets import dataset_from_graph
from grace.models.classifier import GCN

from _utils import random_image_and_graph
Expand All @@ -31,7 +31,7 @@ def data_and_model(self, default_rng):
for src, dst in graph.edges
]
)
dataset = dataset_from_subgraphs(graph)
dataset = dataset_from_graph(graph, mode="sub")

return dataset, model

Expand Down

0 comments on commit 1ec9afa

Please sign in to comment.