From a93fd4b426987610611056709dc30b35b236108f Mon Sep 17 00:00:00 2001 From: KristinaUlicna Date: Mon, 2 Oct 2023 14:09:02 +0100 Subject: [PATCH] Clean & clear the testing pipeline --- grace/run.py | 34 +++++++++++++++------------------- grace/training/config.py | 30 +++++++++++++++++------------- tests/conftest.py | 5 ++--- tests/test_config.py | 2 +- tests/test_dataset.py | 2 +- tests/test_model.py | 23 ++++++++++++++--------- tests/test_train.py | 7 ++++--- 7 files changed, 54 insertions(+), 49 deletions(-) diff --git a/grace/run.py b/grace/run.py index 63479fdc..b54da0ba 100644 --- a/grace/run.py +++ b/grace/run.py @@ -49,6 +49,7 @@ def run_grace(config_file: Union[str, os.PathLike]) -> None: # Define where you'll save the outputs: current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + # run_dir = config.log_dir / current_time / "model" run_dir = config.log_dir / current_time setattr(config, "run_dir", run_dir) @@ -174,31 +175,26 @@ def prepare_dataset( ) # Save the trained model: - model_save_fn = run_dir / "model" / "classifier.pt" + config.run_dir = run_dir / "model" + model_save_fn = config.run_dir / "classifier.pt" torch.save(classifier, model_save_fn) + + # Save the training hyperparameters: write_config_file(config, filetype="json") write_config_file(config, filetype="yaml") # Archive the model architecture: model_architecture = ModelArchiver(classifier).architecture - import json - import yaml - - architecture_fn = run_dir / "model" / "summary_architecture.json" - with open(architecture_fn, "w") as outfile: - json.dump(model_architecture, outfile, indent=4) - - # write_file_with_suffix(model_architecture, architecture_fn) - architecture_fn = run_dir / "model" / "summary_architecture.yaml" - with open(architecture_fn, "w") as outfile: - yaml.dump( - model_architecture, - outfile, - default_flow_style=False, - allow_unicode=True, - ) - - # write_file_with_suffix(model_architecture, architecture_fn) + write_file_with_suffix( + model_architecture, + config.run_dir / "summary_architecture.json", + convert_types=False, + ) + write_file_with_suffix( + model_architecture, + config.run_dir / "summary_architecture.yaml", + convert_types=False, + ) # Project the TSNE manifold: if config.visualise_tsne_manifold is True: diff --git a/grace/training/config.py b/grace/training/config.py index d399620b..656c740b 100644 --- a/grace/training/config.py +++ b/grace/training/config.py @@ -221,20 +221,23 @@ def write_config_file(config: Config, filetype: str = "json") -> None: if isinstance(config.run_dir, str): setattr(config, "run_dir", Path(config.run_dir)) - fn = config.run_dir / "model" / f"config_hyperparams.{filetype}" + fn = config.run_dir / f"config_hyperparams.{filetype}" write_file_with_suffix(params, fn) def write_file_with_suffix( - parameters_dict: dict[Any], filename: str | Path + parameters_dict: dict[Any], + filename: str | Path, + convert_types: bool = True, ) -> None: if isinstance(filename, str): filename = Path(filename) if filename.suffix == ".json": - # Convert all params to strings: - for attr, param in parameters_dict.items(): - parameters_dict[attr] = str(param) + if convert_types is True: + # Convert all params to strings: + for attr, param in parameters_dict.items(): + parameters_dict[attr] = str(param) # Write the file out: with open(filename, "w") as outfile: json.dump( @@ -244,14 +247,15 @@ def write_file_with_suffix( ) elif filename.suffix == ".yaml": - # Convert all params to yaml-parsable types: - for attr, param in parameters_dict.items(): - if isinstance(param, Path): - parameters_dict[attr] = str(param) - elif isinstance(param, tuple): - parameters_dict[attr] = list(param) - else: - parameters_dict[attr] = param + if convert_types is True: + # Convert all params to yaml-parsable types: + for attr, param in parameters_dict.items(): + if isinstance(param, Path): + parameters_dict[attr] = str(param) + elif isinstance(param, tuple): + parameters_dict[attr] = list(param) + else: + parameters_dict[attr] = param # Write the file out in human-readable form: with open(filename, "w") as outfile: yaml.dump( diff --git a/tests/conftest.py b/tests/conftest.py index 7f8ab669..bf6187b2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,6 @@ import numpy as np import torch -import torch.nn as nn from grace.base import GraphAttrs, graph_from_dataframe from pathlib import Path @@ -72,11 +71,11 @@ def mrc_image_and_annotations_dir(tmp_path_factory, default_rng) -> Path: return tmp_data_dir -class SimpleExtractor(nn.Module): +class SimpleExtractor(torch.nn.Module): def forward(self, x): return torch.rand(x.size(0), 2) @pytest.fixture(scope="session") -def simple_extractor() -> nn.Module: +def simple_extractor() -> torch.nn.Module: return SimpleExtractor() diff --git a/tests/test_config.py b/tests/test_config.py index 389cf023..58e1db46 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -15,7 +15,7 @@ def test_load_config_file(tmp_path): config = Config() setattr(config, "feature_dim", 251) setattr(config, "run_dir", tmp_path) - write_config_file(config) + write_config_file(config, "json") loaded_config = load_config_params(tmp_path / "config_hyperparams.json") diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 7b34b7de..4753861d 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,6 +1,6 @@ import pytest -from grace.io.core import GraphAttrs, Annotation +from grace.base import GraphAttrs, Annotation from grace.io.image_dataset import ImageGraphDataset from grace.models.datasets import dataset_from_graph diff --git a/tests/test_model.py b/tests/test_model.py index 14372cf3..6392f33f 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,4 +1,4 @@ -from grace.models.classifier import GCN +from grace.models.classifier import GCNModel from grace.models.feature_extractor import FeatureExtractor import math @@ -14,7 +14,8 @@ @pytest.mark.parametrize("input_channels", [1, 2]) -@pytest.mark.parametrize("hidden_channels", [[16, 4], [32, 8]]) +@pytest.mark.parametrize("hidden_graph_channels", [[16, 4], [32, 8]]) +@pytest.mark.parametrize("hidden_dense_channels", [[16, 4], [32, 8]]) @pytest.mark.parametrize("node_output_classes", [2, 4]) @pytest.mark.parametrize("edge_output_classes", [2, 4]) class TestGCN: @@ -22,13 +23,15 @@ class TestGCN: def gcn( self, input_channels, - hidden_channels, + hidden_graph_channels, + hidden_dense_channels, node_output_classes, edge_output_classes, ): - return GCN( + return GCNModel( input_channels=input_channels, - hidden_channels=hidden_channels, + hidden_graph_channels=hidden_graph_channels, + hidden_dense_channels=hidden_dense_channels, node_output_classes=node_output_classes, edge_output_classes=edge_output_classes, ) @@ -36,7 +39,8 @@ def gcn( def test_model_building( self, input_channels, - hidden_channels, + hidden_graph_channels, + hidden_dense_channels, node_output_classes, edge_output_classes, gcn, @@ -45,17 +49,18 @@ def test_model_building( assert gcn.conv_layer_list[0].in_channels == input_channels - assert gcn.node_classifier.in_features == hidden_channels[-1] + assert gcn.node_classifier.in_features == hidden_dense_channels[-1] assert gcn.node_classifier.out_features == node_output_classes - assert gcn.edge_classifier.in_features == hidden_channels[-1] * 2 + assert gcn.edge_classifier.in_features == hidden_dense_channels[-1] * 2 assert gcn.edge_classifier.out_features == edge_output_classes @pytest.mark.parametrize("num_nodes", [4, 5]) def test_output_sizes( self, input_channels, - hidden_channels, + hidden_graph_channels, + hidden_dense_channels, node_output_classes, edge_output_classes, gcn, diff --git a/tests/test_train.py b/tests/test_train.py index 8b7a85d7..d49e8c3e 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -7,7 +7,7 @@ ) from grace.training.train import train_model from grace.models.datasets import dataset_from_graph -from grace.models.classifier import GCN +from grace.models.classifier import GCNModel from _utils import random_image_and_graph @@ -15,9 +15,10 @@ class TestTraining: @pytest.fixture def data_and_model(self, default_rng): - model = GCN( + model = GCNModel( input_channels=2, - hidden_channels=[16, 4], + hidden_graph_channels=[16, 8], + hidden_dense_channels=[4, 2], ) _, graph = random_image_and_graph(