Remove own RGCN implementation
GlycanConnector committed Oct 22, 2024
1 parent 2fe83c6 commit 7277ba9
Showing 6 changed files with 63 additions and 168 deletions.
108 changes: 55 additions & 53 deletions configs/downstream/all.yaml
@@ -1,11 +1,13 @@
 seed: 42
-root_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data
-logs_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs
+# root_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data
+# logs_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs
+root_dir: /home/daniel/Data1/roman/GIFFLAR/data
+logs_dir: /home/daniel/Data1/roman/GIFFLAR/logs
 datasets:
-  - name: Immunogenicity
-    task: classification
-  - name: Glycosylation
-    task: classification
+  #- name: Immunogenicity
+  #  task: classification
+  #- name: Glycosylation
+  #  task: classification
   - name: Taxonomy_Domain
     task: multilabel
   - name: Taxonomy_Kingdom
@@ -36,14 +38,14 @@ pre-transforms:
     dim: 20
     individual: False
 model:
-  - name: rf
-    n_estimators: 500
-    n_jobs: -1
-    random_state: 42
-  - name: svm
-    random_state: 42
-  - name: xgb
-    random_state: 42
+  #- name: rf
+  #  n_estimators: 500
+  #  n_jobs: -1
+  #  random_state: 42
+  #- name: svm
+  #  random_state: 42
+  #- name: xgb
+  #  random_state: 42
   - name: mlp
     feat_dim: 1024
     hidden_dim: 1024
@@ -53,42 +55,42 @@ model:
     patience: 30
     learning_rate: 0.001
     optimizer: Adam
-  - name: sweetnet
-    feat_dim: 128
-    hidden_dim: 1024
-    batch_size: 512
-    num_layers: 16
-    epochs: 100
-    patience: 30
-    learning_rate: 0.001
-    optimizer: Adam
-    suffix:
-  - name: gnngly
-    feat_dim: 128
-    hidden_dim: 1024
-    batch_size: 512
-    num_layers: 8
-    epochs: 100
-    patience: 30
-    learning_rate: 0.001
-    optimizer: Adam
-    suffix:
-  - name: pyrgcn
-    feat_dim: 128
-    hidden_dim: 1024
-    batch_size: 256
-    num_layers: 8
-    epochs: 100
-    learning_rate: 0.001
-    optimizer: Adam
-    suffix:
-  - name: gifflar
-    feat_dim: 128
-    hidden_dim: 1024
-    batch_size: 256
-    num_layers: 8
-    epochs: 100
-    learning_rate: 0.001
-    optimizer: Adam
-    pooling: global_pool
-    suffix: _128_8_gp
+  #- name: sweetnet
+  #  feat_dim: 128
+  #  hidden_dim: 1024
+  #  batch_size: 512
+  #  num_layers: 16
+  #  epochs: 100
+  #  patience: 30
+  #  learning_rate: 0.001
+  #  optimizer: Adam
+  #  suffix:
+  #- name: gnngly
+  #  feat_dim: 128
+  #  hidden_dim: 1024
+  #  batch_size: 512
+  #  num_layers: 8
+  #  epochs: 100
+  #  patience: 30
+  #  learning_rate: 0.001
+  #  optimizer: Adam
+  #  suffix:
+  #- name: rgcn
+  #  feat_dim: 128
+  #  hidden_dim: 1024
+  #  batch_size: 256
+  #  num_layers: 8
+  #  epochs: 100
+  #  learning_rate: 0.001
+  #  optimizer: Adam
+  #  suffix:
+  #- name: gifflar
+  #  feat_dim: 128
+  #  hidden_dim: 1024
+  #  batch_size: 256
+  #  num_layers: 8
+  #  epochs: 100
+  #  learning_rate: 0.001
+  #  optimizer: Adam
+  #  pooling: global_mean
+  #  suffix: _128_8_gp
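Net effect of the config change: paths move from the scratch filesystem to a local data directory, the Immunogenicity and Glycosylation datasets and all models except mlp are commented out, and the former pyrgcn entry reappears, still commented, under its new name rgcn. Since YAML comments disappear at parse time, the disabled entries are invisible to whatever consumes the file; a minimal sketch of that consumption pattern (the loader below is an illustration, not the repository's actual entry point):

```python
# Hypothetical consumer of the YAML above; only the key names ("datasets",
# "model", "name") come from the file itself, everything else is assumed.
from pathlib import Path

import yaml

config = yaml.safe_load(Path("configs/downstream/all.yaml").read_text())
for dataset in config["datasets"]:       # after this commit: Taxonomy_* only
    for model_cfg in config["model"]:    # after this commit: the mlp entry only
        print(f"train {model_cfg['name']} on {dataset['name']}")
```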
3 changes: 2 additions & 1 deletion gifflar/benchmarks.py
@@ -47,7 +47,7 @@ def get_taxonomic_level(
     """
     if not (p := (root / Path(f"taxonomy_{level}.tsv"))).exists():
         # Chop to taxonomic level of interest and remove invalid rows
-        tax = get_taxonomy()[["glycan", level]]
+        tax = get_taxonomy(root)[["glycan", level]]
         tax.rename(columns={"glycan": "IUPAC"}, inplace=True)
         tax[tax[level] == "undetermined"] = np.nan
         tax.dropna(inplace=True)
@@ -148,6 +148,7 @@ def get_dataset(data_config: dict, root: Path | str) -> dict:
     Returns:
         The configuration of the dataset with the filepath added and made sure the dataset is preprocessed
     """
+    Path(root).mkdir(exist_ok=True, parents=True)
     name_fracs = data_config["name"].split("_")
     match name_fracs[0]:
         case "Taxonomy":
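Both fixes in benchmarks.py serve the same cache-under-root discipline: get_taxonomy now receives root so the derived table lands next to the rest of the data, and get_dataset creates root before anything tries to write into it. A self-contained sketch of the cache-or-compute idiom this code follows (the helper and file names below are illustrative, not the repository's):

```python
from pathlib import Path
from typing import Callable

import pandas as pd


def cached_table(root: Path, name: str, compute: Callable[[], pd.DataFrame]) -> pd.DataFrame:
    """Compute a table once, cache it as TSV under `root`, reuse it afterwards."""
    root.mkdir(exist_ok=True, parents=True)        # mirrors the get_dataset fix
    if not (p := root / f"{name}.tsv").exists():   # first call: compute and cache
        compute().to_csv(p, sep="\t", index=False)
    return pd.read_csv(p, sep="\t")                # later calls: read the cache
```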
1 change: 0 additions & 1 deletion gifflar/data/hetero.py
@@ -129,7 +129,6 @@ def hetero_collate(data: Optional[Union[list[list[HeteroData]], list[HeteroData]

     # For each baseline, collate its node features and edge indices as well
     for b in baselines:
-        break
         kwargs[f"{b}_x"] = torch.cat([d[f"{b}_x"] for d in data], dim=0)
         edges = []
         batch = []
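The deleted break had short-circuited the loop on its first iteration, so the per-baseline tensors ({b}_x, edge indices, batch vector) were never collated; removing it restores them. The subtle step in this kind of collation is shifting each graph's edge indices past the nodes already stacked. A self-contained sketch of that idea, reduced to one feature tensor and one edge index per graph:

```python
import torch


def collate_sketch(node_feats: list[torch.Tensor], edge_indices: list[torch.Tensor]):
    """Illustrative mini-collate: stack features, re-index edges, build a batch vector."""
    x = torch.cat(node_feats, dim=0)                   # all nodes in one tensor
    edges, batch, offset = [], [], 0
    for i, (feats, ei) in enumerate(zip(node_feats, edge_indices)):
        edges.append(ei + offset)                      # shift into the stacked indexing
        batch.append(torch.full((feats.size(0),), i))  # graph id per node, for pooling
        offset += feats.size(0)
    return x, torch.cat(edges, dim=1), torch.cat(batch)
```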
85 changes: 2 additions & 83 deletions gifflar/model/baselines/rgcn.py
@@ -8,34 +8,7 @@
 from gifflar.model.downstream import DownstreamGGIN


-class HeteroPReLU(nn.Module):
-    def __init__(self, prelus: dict[str, nn.PReLU]):
-        """
-        A module that applies a PReLU activation function to each input.
-
-        Args:
-            prelus: The PReLU activations to apply to each input.
-        """
-        super(HeteroPReLU, self).__init__()
-        for name, prelu in prelus.items():
-            setattr(self, name, prelu)
-
-    def forward(self, input_: dict) -> dict:
-        """
-        Apply the PReLU activation to the input.
-
-        Args:
-            input_: The input to apply the activation to.
-
-        Returns:
-            The input with the PReLU activation applied.
-        """
-        for key, value in input_.items():
-            input_[key] = getattr(self, key).forward(value)
-        return input_
-
-
-class PyTorchRGCN(DownstreamGGIN):
+class RGCN(DownstreamGGIN):
     def __init__(
         self,
         feat_dim: int,
@@ -47,7 +20,7 @@ def __init__(
         pre_transform_args: Optional[dict] = None,
         **kwargs: Any,
     ):
-        super(PyTorchRGCN, self).__init__(feat_dim, hidden_dim, output_dim, task, num_layers, batch_size,
+        super(RGCN, self).__init__(feat_dim, hidden_dim, output_dim, task, num_layers, batch_size,
                                    pre_transform_args, **kwargs)

         dims = [feat_dim]
@@ -91,57 +64,3 @@ def forward(self, batch: HeteroDataBatch) -> dict[str, torch.Tensor]:
             "preds": pred,
         }
-
-
-class RGCN(DownstreamGGIN):
-    def __init__(
-            self,
-            feat_dim: int,
-            hidden_dim: int,
-            output_dim: int,
-            task: Literal["regression", "classification", "multilabel"],
-            num_layers: int = 3,
-            batch_size: int = 32,
-            pre_transform_args: Optional[dict] = None,
-            **kwargs: Any
-    ):
-        """
-        Implementation of the relational GCN model.
-
-        Args:
-            feat_dim: The feature dimension of the model.
-            hidden_dim: The dimensionality of the hidden layers.
-            output_dim: The output dimension of the model
-            task: The type of task to perform.
-            num_layers: The number of layers in the network.
-            batch_size: The batch size to use
-            pre_transform_args: The arguments for the pre-transforms.
-            kwargs: Additional arguments
-        """
-        super(RGCN, self).__init__(feat_dim, hidden_dim, output_dim, task, num_layers, batch_size, pre_transform_args,
-                                   **kwargs)
-
-        self.convs = torch.nn.ModuleList()
-        dims = [feat_dim, hidden_dim // 2] + [hidden_dim] * (num_layers - 1)
-        for i in range(num_layers):
-            convs = {
-                # Set the inner layers to be a single weight without using the nodes embedding (therefore, e=-1)
-                key: GINConv(nn.Sequential(nn.Linear(dims[i], dims[i + 1])), eps=-1) for key in [
-                    ("atoms", "coboundary", "atoms"),
-                    ("atoms", "to", "bonds"),
-                    ("bonds", "to", "monosacchs"),
-                    ("bonds", "boundary", "bonds"),
-                    ("monosacchs", "boundary", "monosacchs")
-                ]
-            }
-            self_loop_weight = nn.Sequential(nn.Linear(dims[i], dims[i + 1]))
-            convs.update({
-                ("atoms", "self", "atoms"): GINConv(self_loop_weight, eps=-1),
-                ("bonds", "self", "bonds"): GINConv(self_loop_weight, eps=-1),
-                ("monosacchs", "self", "monosacchs"): GINConv(self_loop_weight, eps=-1),
-            })
-            self.convs.append(HeteroConv(convs))
-            self.convs.append(HeteroPReLU({
-                "atoms": nn.PReLU(),
-                "bonds": nn.PReLU(),
-                "monosacchs": nn.PReLU(),
-            }))
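What survives in rgcn.py is the PyTorch-Geometric-backed variant, now simply named RGCN; the hand-rolled stack of HeteroConv-wrapped GINConvs with a shared self-loop weight and HeteroPReLU activations is gone. For contrast, a minimal sketch of the library-backed pattern such a class builds on (the class, dimensions, and layer count below are illustrative, not the file's actual contents):

```python
import torch
from torch import nn
from torch_geometric.nn import RGCNConv


class RGCNSketch(nn.Module):
    """Relational GCN on a homogenized graph: one node tensor, each edge tagged
    with a relation id, one weight matrix per relation inside RGCNConv."""

    def __init__(self, feat_dim: int, hidden_dim: int, num_relations: int, num_layers: int = 3):
        super().__init__()
        dims = [feat_dim] + [hidden_dim] * num_layers
        self.convs = nn.ModuleList(
            RGCNConv(dims[i], dims[i + 1], num_relations) for i in range(num_layers)
        )
        self.act = nn.PReLU()

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_type: torch.Tensor) -> torch.Tensor:
        for conv in self.convs:
            # edge_type[e] selects the relation-specific weight for edge e
            x = self.act(conv(x, edge_index, edge_type))
        return x
```

Note that RGCNConv already carries a root (self) weight, which is why the explicit self-loop machinery of the removed implementation is not needed.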
28 changes: 0 additions & 28 deletions gifflar/pretransforms.py
@@ -136,32 +136,6 @@ def __call__(self, data: HeteroData) -> HeteroData:


 class RGCNTransform(RootTransform):
-    def __call__(self, data: HeteroData) -> HeteroData:
-        """
-        Add self-loops to the graph for the RGCN model.
-
-        Args:
-            data: The input data to be transformed.
-
-        Returns:
-            The transformed data.
-        """
-        data["atoms", "self", "atoms"].edge_index = torch.stack([
-            torch.arange(data["atoms"]["num_nodes"]),
-            torch.arange(data["atoms"]["num_nodes"])
-        ])
-        data["bonds", "self", "bonds"].edge_index = torch.stack([
-            torch.arange(data["bonds"]["num_nodes"]),
-            torch.arange(data["bonds"]["num_nodes"])
-        ])
-        data["monosacchs", "self", "monosacchs"].edge_index = torch.stack([
-            torch.arange(data["monosacchs"]["num_nodes"]),
-            torch.arange(data["monosacchs"]["num_nodes"])
-        ])
-        return data
-
-
-class PyTorchRGCNTransform(RootTransform):
     def __call__(self, data: HeteroData) -> HeteroData:
         data["rgcn_x"] = torch.cat([
             data["atoms"].x,
@@ -557,8 +531,6 @@ def get_pretransforms(dataset_name: str = "", **pre_transform_args: dict[str, di
             pre_transforms.append(SweetNetTransform(**args or {}))
         case "RGCNTransform":
             pre_transforms.append(RGCNTransform(**args or {}))
-        case "PyTorchRGCNTransform":
-            pre_transforms.append(PyTorchRGCNTransform(**args or {}))
         case "LaplacianPE":
             pre_transforms.append(LaplacianPE(**args or {}))
         case "RandomWalkPE":
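The transform that remains (formerly PyTorchRGCNTransform, now registered as RGCNTransform) homogenizes the heterogeneous glycan graph for exactly this kind of model: per-type node features are concatenated into a single rgcn_x tensor, which requires offsetting every edge index by the node counts of the types stacked before it. A sketch of that bookkeeping; only rgcn_x is visible in the diff, so the edge field names and the exact procedure are assumptions:

```python
import torch
from torch_geometric.data import HeteroData


def homogenize_sketch(data: HeteroData) -> HeteroData:
    node_types = ["atoms", "bonds", "monosacchs"]
    # One feature matrix for all nodes, in a fixed type order
    data["rgcn_x"] = torch.cat([data[nt].x for nt in node_types], dim=0)
    # Remember where each node type starts in the stacked tensor
    offsets, start = {}, 0
    for nt in node_types:
        offsets[nt] = start
        start += data[nt].num_nodes
    edges, types = [], []
    for rel_id, (src, rel, dst) in enumerate(data.edge_types):
        ei = data[(src, rel, dst)].edge_index.clone()
        ei[0] += offsets[src]   # re-index sources ...
        ei[1] += offsets[dst]   # ... and targets
        edges.append(ei)
        types.append(torch.full((ei.size(1),), rel_id, dtype=torch.long))
    data["rgcn_edge_index"] = torch.cat(edges, dim=1)  # assumed field name
    data["rgcn_edge_type"] = torch.cat(types)          # assumed field name
    return data
```

The removed self-loop transform becomes redundant under this scheme, since RGCNConv handles the self-connection internally.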
6 changes: 4 additions & 2 deletions gifflar/train.py
@@ -1,3 +1,6 @@
+import os
+os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+
 from pathlib import Path
 from typing import Any
 import time
@@ -13,7 +16,7 @@
 from gifflar.data.modules import DownsteamGDM, PretrainGDM
 from gifflar.model.baselines.gnngly import GNNGLY
 from gifflar.model.baselines.mlp import MLP
-from gifflar.model.baselines.rgcn import RGCN, PyTorchRGCN
+from gifflar.model.baselines.rgcn import RGCN
 from gifflar.model.baselines.sweetnet import SweetNetLightning
 from gifflar.benchmarks import get_dataset
 from gifflar.model.downstream import DownstreamGGIN
@@ -29,7 +32,6 @@
     "gnngly": GNNGLY,
     "mlp": MLP,
     "rgcn": RGCN,
-    "pyrgcn": PyTorchRGCN,
     "sweetnet": SweetNetLightning,
 }
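The two lines added at the very top of train.py pin the process to GPU 1. The placement matters: CUDA_VISIBLE_DEVICES is only honored if it is set before the CUDA runtime initializes, and once torch has initialized CUDA, later changes to the variable are ignored. A slightly gentler variant of the same idea (using setdefault instead of the commit's hard assignment is an alternative, not the repository's code):

```python
import os

# Must run before any CUDA initialization; setdefault keeps an externally
# provided CUDA_VISIBLE_DEVICES intact instead of overwriting it.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")

import torch  # imported afterwards so device enumeration honors the variable

if torch.cuda.is_available():
    print(torch.cuda.device_count())  # reports only the visible device(s)
```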
