Skip to content

Commit

Permalink
Clean & clear the testing pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
KristinaUlicna committed Oct 2, 2023
1 parent ff69006 commit a93fd4b
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 49 deletions.
34 changes: 15 additions & 19 deletions grace/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def run_grace(config_file: Union[str, os.PathLike]) -> None:

# Define where you'll save the outputs:
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# run_dir = config.log_dir / current_time / "model"
run_dir = config.log_dir / current_time
setattr(config, "run_dir", run_dir)

Expand Down Expand Up @@ -174,31 +175,26 @@ def prepare_dataset(
)

# Save the trained model:
model_save_fn = run_dir / "model" / "classifier.pt"
config.run_dir = run_dir / "model"
model_save_fn = config.run_dir / "classifier.pt"
torch.save(classifier, model_save_fn)

# Save the training hyperparameters:
write_config_file(config, filetype="json")
write_config_file(config, filetype="yaml")

# Archive the model architecture:
model_architecture = ModelArchiver(classifier).architecture
import json
import yaml

architecture_fn = run_dir / "model" / "summary_architecture.json"
with open(architecture_fn, "w") as outfile:
json.dump(model_architecture, outfile, indent=4)

# write_file_with_suffix(model_architecture, architecture_fn)
architecture_fn = run_dir / "model" / "summary_architecture.yaml"
with open(architecture_fn, "w") as outfile:
yaml.dump(
model_architecture,
outfile,
default_flow_style=False,
allow_unicode=True,
)

# write_file_with_suffix(model_architecture, architecture_fn)
write_file_with_suffix(
model_architecture,
config.run_dir / "summary_architecture.json",
convert_types=False,
)
write_file_with_suffix(
model_architecture,
config.run_dir / "summary_architecture.yaml",
convert_types=False,
)

# Project the TSNE manifold:
if config.visualise_tsne_manifold is True:
Expand Down
30 changes: 17 additions & 13 deletions grace/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,20 +221,23 @@ def write_config_file(config: Config, filetype: str = "json") -> None:
if isinstance(config.run_dir, str):
setattr(config, "run_dir", Path(config.run_dir))

fn = config.run_dir / "model" / f"config_hyperparams.{filetype}"
fn = config.run_dir / f"config_hyperparams.{filetype}"
write_file_with_suffix(params, fn)


def write_file_with_suffix(
parameters_dict: dict[Any], filename: str | Path
parameters_dict: dict[Any],
filename: str | Path,
convert_types: bool = True,
) -> None:
if isinstance(filename, str):
filename = Path(filename)

if filename.suffix == ".json":
# Convert all params to strings:
for attr, param in parameters_dict.items():
parameters_dict[attr] = str(param)
if convert_types is True:
# Convert all params to strings:
for attr, param in parameters_dict.items():
parameters_dict[attr] = str(param)
# Write the file out:
with open(filename, "w") as outfile:
json.dump(
Expand All @@ -244,14 +247,15 @@ def write_file_with_suffix(
)

elif filename.suffix == ".yaml":
# Convert all params to yaml-parsable types:
for attr, param in parameters_dict.items():
if isinstance(param, Path):
parameters_dict[attr] = str(param)
elif isinstance(param, tuple):
parameters_dict[attr] = list(param)
else:
parameters_dict[attr] = param
if convert_types is True:
# Convert all params to yaml-parsable types:
for attr, param in parameters_dict.items():
if isinstance(param, Path):
parameters_dict[attr] = str(param)
elif isinstance(param, tuple):
parameters_dict[attr] = list(param)
else:
parameters_dict[attr] = param
# Write the file out in human-readable form:
with open(filename, "w") as outfile:
yaml.dump(
Expand Down
5 changes: 2 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import numpy as np

import torch
import torch.nn as nn

from grace.base import GraphAttrs, graph_from_dataframe
from pathlib import Path
Expand Down Expand Up @@ -72,11 +71,11 @@ def mrc_image_and_annotations_dir(tmp_path_factory, default_rng) -> Path:
return tmp_data_dir


class SimpleExtractor(nn.Module):
class SimpleExtractor(torch.nn.Module):
def forward(self, x):
return torch.rand(x.size(0), 2)


@pytest.fixture(scope="session")
def simple_extractor() -> nn.Module:
def simple_extractor() -> torch.nn.Module:
return SimpleExtractor()
2 changes: 1 addition & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_load_config_file(tmp_path):
config = Config()
setattr(config, "feature_dim", 251)
setattr(config, "run_dir", tmp_path)
write_config_file(config)
write_config_file(config, "json")

loaded_config = load_config_params(tmp_path / "config_hyperparams.json")

Expand Down
2 changes: 1 addition & 1 deletion tests/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from grace.io.core import GraphAttrs, Annotation
from grace.base import GraphAttrs, Annotation
from grace.io.image_dataset import ImageGraphDataset
from grace.models.datasets import dataset_from_graph

Expand Down
23 changes: 14 additions & 9 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from grace.models.classifier import GCN
from grace.models.classifier import GCNModel
from grace.models.feature_extractor import FeatureExtractor

import math
Expand All @@ -14,29 +14,33 @@


@pytest.mark.parametrize("input_channels", [1, 2])
@pytest.mark.parametrize("hidden_channels", [[16, 4], [32, 8]])
@pytest.mark.parametrize("hidden_graph_channels", [[16, 4], [32, 8]])
@pytest.mark.parametrize("hidden_dense_channels", [[16, 4], [32, 8]])
@pytest.mark.parametrize("node_output_classes", [2, 4])
@pytest.mark.parametrize("edge_output_classes", [2, 4])
class TestGCN:
@pytest.fixture
def gcn(
self,
input_channels,
hidden_channels,
hidden_graph_channels,
hidden_dense_channels,
node_output_classes,
edge_output_classes,
):
return GCN(
return GCNModel(
input_channels=input_channels,
hidden_channels=hidden_channels,
hidden_graph_channels=hidden_graph_channels,
hidden_dense_channels=hidden_dense_channels,
node_output_classes=node_output_classes,
edge_output_classes=edge_output_classes,
)

def test_model_building(
self,
input_channels,
hidden_channels,
hidden_graph_channels,
hidden_dense_channels,
node_output_classes,
edge_output_classes,
gcn,
Expand All @@ -45,17 +49,18 @@ def test_model_building(

assert gcn.conv_layer_list[0].in_channels == input_channels

assert gcn.node_classifier.in_features == hidden_channels[-1]
assert gcn.node_classifier.in_features == hidden_dense_channels[-1]
assert gcn.node_classifier.out_features == node_output_classes

assert gcn.edge_classifier.in_features == hidden_channels[-1] * 2
assert gcn.edge_classifier.in_features == hidden_dense_channels[-1] * 2
assert gcn.edge_classifier.out_features == edge_output_classes

@pytest.mark.parametrize("num_nodes", [4, 5])
def test_output_sizes(
self,
input_channels,
hidden_channels,
hidden_graph_channels,
hidden_dense_channels,
node_output_classes,
edge_output_classes,
gcn,
Expand Down
7 changes: 4 additions & 3 deletions tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,18 @@
)
from grace.training.train import train_model
from grace.models.datasets import dataset_from_graph
from grace.models.classifier import GCN
from grace.models.classifier import GCNModel

from _utils import random_image_and_graph


class TestTraining:
@pytest.fixture
def data_and_model(self, default_rng):
model = GCN(
model = GCNModel(
input_channels=2,
hidden_channels=[16, 4],
hidden_graph_channels=[16, 8],
hidden_dense_channels=[4, 2],
)

_, graph = random_image_and_graph(
Expand Down

0 comments on commit a93fd4b

Please sign in to comment.