Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

skip connection option added #165

Closed
wants to merge 15 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 50 additions & 3 deletions hydragnn/models/Base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,44 @@
from hydragnn.utils.distributed import get_device


class Conv(torch.nn.Module):
pzhanggit marked this conversation as resolved.
Show resolved Hide resolved
"""conv module"""

def __init__(self, conv, batch_norm, act=torch.nn.ReLU()):
super().__init__()
self.conv = conv
self.batch_norm = batch_norm
self.act = act

def forward(self, **args):
c = self.conv(**args)
return self.act(self.batch_norm(c))


class SkipConv(torch.nn.Module):
"""conv module with skip connection"""

def __init__(self, conv, batch_norm, act=torch.nn.ReLU()):
super().__init__()
self.conv = conv
self.batch_norm = batch_norm
self.act = act

def forward(self, **args):
identity = args["x"]
c = self.conv(**args)
return self.act(self.batch_norm(c)) + identity


class ConvSequential(torch.nn.Sequential):
"""Extention to use graph inputs"""

def forward(self, x=None, **args):
for module in self:
x = module(x=x, **args)
return x


class Base(Module):
def __init__(
self,
Expand All @@ -36,6 +74,7 @@ def __init__(
dropout: float = 0.25,
num_conv_layers: int = 16,
num_nodes: int = None,
skip_connection: bool = False,
):
super().__init__()
self.device = get_device()
Expand All @@ -58,6 +97,7 @@ def __init__(
self.batch_norms_node_hidden = ModuleList()
self.convs_node_output = ModuleList()
self.batch_norms_node_output = ModuleList()
self.skip_connection = skip_connection

self.loss_function = loss_function_selection(loss_function_type)
self.ilossweights_nll = ilossweights_nll
Expand Down Expand Up @@ -100,6 +140,15 @@ def __init__(
if self.initial_bias is not None:
self._set_bias()

self.layers = ModuleList()
for conv, batch_norm in zip(self.convs, self.batch_norms):
if self.skip_connection:
self.layers.append(SkipConv(conv, batch_norm))
else:
self.layers.append(Conv(conv, batch_norm))

self.conv_shared = ConvSequential(*self.layers)
allaffa marked this conversation as resolved.
Show resolved Hide resolved

def _init_conv(self):
self.convs.append(self.get_conv(self.input_dim, self.hidden_dim))
self.batch_norms.append(BatchNorm(self.hidden_dim))
Expand Down Expand Up @@ -246,9 +295,7 @@ def forward(self, data):

### encoder part ####
conv_args = self._conv_args(data)
for conv, batch_norm in zip(self.convs, self.batch_norms):
c = conv(x=x, **conv_args)
x = F.relu(batch_norm(c))
x = self.conv_shared(x=x, **conv_args)

#### multi-head decoder part####
# shared dense layers for graph level output
Expand Down
8 changes: 8 additions & 0 deletions hydragnn/models/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def create_model_config(
config["Architecture"]["num_gaussians"],
config["Architecture"]["num_filters"],
config["Architecture"]["radius"],
config["Architecture"]["skip_connection"],
verbosity,
use_gpu,
)
Expand All @@ -75,6 +76,7 @@ def create_model(
num_gaussians: int = None,
num_filters: int = None,
radius: float = None,
skip_connection: bool = False,
verbosity: int = 0,
use_gpu: bool = True,
):
Expand All @@ -98,6 +100,7 @@ def create_model(
initial_bias=initial_bias,
num_conv_layers=num_conv_layers,
num_nodes=num_nodes,
skip_connection=skip_connection,
)

elif model_type == "PNA":
Expand All @@ -116,6 +119,7 @@ def create_model(
initial_bias=initial_bias,
num_conv_layers=num_conv_layers,
num_nodes=num_nodes,
skip_connection=skip_connection,
)

elif model_type == "GAT":
Expand Down Expand Up @@ -153,6 +157,7 @@ def create_model(
initial_bias=initial_bias,
num_conv_layers=num_conv_layers,
num_nodes=num_nodes,
skip_connection=skip_connection,
)

elif model_type == "CGCNN":
Expand All @@ -168,6 +173,7 @@ def create_model(
initial_bias=initial_bias,
num_conv_layers=num_conv_layers,
num_nodes=num_nodes,
skip_connection=skip_connection,
)

elif model_type == "SAGE":
Expand All @@ -182,6 +188,7 @@ def create_model(
freeze_conv=freeze_conv,
num_conv_layers=num_conv_layers,
num_nodes=num_nodes,
skip_connection=skip_connection,
)

elif model_type == "SchNet":
pzhanggit marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -204,6 +211,7 @@ def create_model(
initial_bias=initial_bias,
num_conv_layers=num_conv_layers,
num_nodes=num_nodes,
skip_connection=skip_connection,
)

else:
Expand Down
3 changes: 3 additions & 0 deletions hydragnn/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def update_config(config, train_loader, val_loader, test_loader):
config["NeuralNetwork"]["Architecture"]
)

if "skip_connection" not in config["NeuralNetwork"]["Architecture"]:
config["NeuralNetwork"]["Architecture"]["skip_connection"] = False

if "freeze_conv_layers" not in config["NeuralNetwork"]["Architecture"]:
config["NeuralNetwork"]["Architecture"]["freeze_conv_layers"] = False
if "initial_bias" not in config["NeuralNetwork"]["Architecture"]:
Expand Down
1 change: 1 addition & 0 deletions tests/inputs/ci.json
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"NeuralNetwork": {
"Architecture": {
"model_type": "PNA",
"skip_connection": false,
"radius": 2.0,
"max_neighbours": 100,
"num_gaussians": 50,
Expand Down
1 change: 1 addition & 0 deletions tests/inputs/ci_multihead.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"NeuralNetwork": {
"Architecture": {
"model_type": "PNA",
"skip_connection": false,
"radius": 2.0,
"max_neighbours": 100,
"num_gaussians": 50,
Expand Down
1 change: 1 addition & 0 deletions tests/inputs/ci_vectoroutput.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"NeuralNetwork": {
"Architecture": {
"model_type": "PNA",
"skip_connection": false,
"radius": 2.0,
"max_neighbours": 100,
"periodic_boundary_conditions": false,
Expand Down
25 changes: 18 additions & 7 deletions tests/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@


# Main unit test function called by pytest wrappers.
def unittest_train_model(model_type, ci_input, use_lengths, overwrite_data=False):
def unittest_train_model(
model_type, skip_connection, ci_input, use_lengths, overwrite_data=False
):

world_size, rank = hydragnn.utils.get_comm_size_and_rank()

os.environ["SERIALIZED_DATA_PATH"] = os.getcwd()
Expand All @@ -31,6 +34,7 @@ def unittest_train_model(model_type, ci_input, use_lengths, overwrite_data=False
with open(config_file, "r") as f:
config = json.load(f)
config["NeuralNetwork"]["Architecture"]["model_type"] = model_type
config["NeuralNetwork"]["Architecture"]["skip_connection"] = skip_connection
"""
to test this locally, set ci.json as
"Dataset": {
Expand Down Expand Up @@ -137,6 +141,8 @@ def unittest_train_model(model_type, ci_input, use_lengths, overwrite_data=False
thresholds["PNA"] = [0.10, 0.10]
if use_lengths and "vector" in ci_input:
thresholds["PNA"] = [0.2, 0.15]
if skip_connection:
thresholds["PNA"] = [0.8, 0.7]
verbosity = 2

for ihead in range(len(true_values)):
Expand Down Expand Up @@ -175,18 +181,23 @@ def unittest_train_model(model_type, ci_input, use_lengths, overwrite_data=False
@pytest.mark.parametrize(
"model_type", ["SAGE", "GIN", "GAT", "MFC", "PNA", "CGCNN", "SchNet"]
)
@pytest.mark.parametrize("skip_connection", [False, True])
@pytest.mark.parametrize("ci_input", ["ci.json", "ci_multihead.json"])
def pytest_train_model(model_type, ci_input, overwrite_data=False):
unittest_train_model(model_type, ci_input, False, overwrite_data)
def pytest_train_model(model_type, skip_connection, ci_input, overwrite_data=False):
unittest_train_model(model_type, skip_connection, ci_input, False, overwrite_data)


# Test only models
@pytest.mark.parametrize("model_type", ["PNA", "CGCNN", "SchNet"])
def pytest_train_model_lengths(model_type, overwrite_data=False):
unittest_train_model(model_type, "ci.json", True, overwrite_data)
@pytest.mark.parametrize("skip_connection", [False, True])
def pytest_train_model_lengths(model_type, skip_connection, overwrite_data=False):
unittest_train_model(model_type, skip_connection, "ci.json", True, overwrite_data)


# Test vector output
@pytest.mark.parametrize("model_type", ["PNA"])
def pytest_train_model_vectoroutput(model_type, overwrite_data=False):
unittest_train_model(model_type, "ci_vectoroutput.json", True, overwrite_data)
@pytest.mark.parametrize("skip_connection", [False, True])
def pytest_train_model_vectoroutput(model_type, skip_connection, overwrite_data=False):
unittest_train_model(
model_type, skip_connection, "ci_vectoroutput.json", True, overwrite_data
)