Cinpp convs (#107)
* added classes for peptides dataset

* loading peptides dataset

* parametrized down_adj as kwargs from CLI args in ZINC, peptides, OGB datasets

* include down_adj in TU datasets

* Include AP & include_down_adj

* Added support for CIN++ models

* added support for cin++ experiments

* modified data/datasets/__init__.py

* modified data/data_loading.py for peptides graphs

* added down_adj to sh files for cin++

* changed include_down_adj to bool

* if embed_dim is None: embed_dim = hidden

* if embed_dim is None: embed_dim = hidden

* remove dwn_msg from super(CINppCochainConv, self)

* init lower_nns before calling super

* remove override reset_params
lrnzgiusti authored Jun 13, 2023
1 parent 387afe8 commit c4ddd24
Showing 21 changed files with 1,138 additions and 22 deletions.
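Two of the commit messages above describe a small defaulting rule for the model width. A minimal sketch of that rule (the CIN++ model constructor that holds it is not among the diffs shown below, so the surrounding code is assumed):

# Sketch of the embed_dim default named in the commit messages;
# the constructor around it is an assumption, not part of this diff.
def resolve_embed_dim(embed_dim, hidden):
    if embed_dim is None:
        embed_dim = hidden  # fall back to the hidden width
    return embed_dim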
35 changes: 27 additions & 8 deletions data/data_loading.py
@@ -34,11 +34,11 @@
 from data.complex import Cochain, CochainBatch, Complex, ComplexBatch
 from data.datasets import (
     load_sr_graph_dataset, load_tu_graph_dataset, load_zinc_graph_dataset, load_ogb_graph_dataset,
-    load_ring_transfer_dataset, load_ring_lookup_dataset)
+    load_ring_transfer_dataset, load_ring_lookup_dataset, load_pep_f_graph_dataset, load_pep_s_graph_dataset)
 from data.datasets import (
     SRDataset, ClusterDataset, TUDataset, ComplexDataset, FlowDataset,
     OceanDataset, ZincDataset, CSLDataset, OGBDataset, RingTransferDataset, RingLookupDataset,
-    DummyDataset, DummyMolecularDataset)
+    DummyDataset, DummyMolecularDataset, PeptidesFunctionalDataset, PeptidesStructuralDataset)
 
 
 class Collater(object):
@@ -133,19 +133,24 @@ def load_dataset(name, root=os.path.join(ROOT_DIR, 'datasets'), max_dim=2, fold=
                             fold=fold, degree_as_tag=False, init_method=init_method, max_ring_size=kwargs.get('max_ring_size', None))
     elif name == 'PROTEINS':
         dataset = TUDataset(os.path.join(root, name), name, max_dim=max_dim, num_classes=2,
-                            fold=fold, degree_as_tag=False, init_method=init_method, max_ring_size=kwargs.get('max_ring_size', None))
+                            fold=fold, degree_as_tag=False, include_down_adj=kwargs['include_down_adj'],
+                            init_method=init_method, max_ring_size=kwargs.get('max_ring_size', None))
     elif name == 'NCI1':
         dataset = TUDataset(os.path.join(root, name), name, max_dim=max_dim, num_classes=2,
-                            fold=fold, degree_as_tag=False, init_method=init_method, max_ring_size=kwargs.get('max_ring_size', None))
+                            fold=fold, degree_as_tag=False, include_down_adj=kwargs['include_down_adj'],
+                            init_method=init_method, max_ring_size=kwargs.get('max_ring_size', None))
     elif name == 'NCI109':
         dataset = TUDataset(os.path.join(root, name), name, max_dim=max_dim, num_classes=2,
-                            fold=fold, degree_as_tag=False, init_method=init_method, max_ring_size=kwargs.get('max_ring_size', None))
+                            fold=fold, degree_as_tag=False, include_down_adj=kwargs['include_down_adj'],
+                            init_method=init_method, max_ring_size=kwargs.get('max_ring_size', None))
     elif name == 'PTC':
         dataset = TUDataset(os.path.join(root, name), name, max_dim=max_dim, num_classes=2,
-                            fold=fold, degree_as_tag=False, init_method=init_method, max_ring_size=kwargs.get('max_ring_size', None))
+                            fold=fold, degree_as_tag=False, include_down_adj=kwargs['include_down_adj'],
+                            init_method=init_method, max_ring_size=kwargs.get('max_ring_size', None))
     elif name == 'MUTAG':
         dataset = TUDataset(os.path.join(root, name), name, max_dim=max_dim, num_classes=2,
-                            fold=fold, degree_as_tag=False, init_method=init_method, max_ring_size=kwargs.get('max_ring_size', None))
+                            fold=fold, degree_as_tag=False, include_down_adj=kwargs['include_down_adj'],
+                            init_method=init_method, max_ring_size=kwargs.get('max_ring_size', None))
     elif name == 'FLOW':
         dataset = FlowDataset(os.path.join(root, name), name, num_points=kwargs['flow_points'],
                               train_samples=1000, val_samples=200, train_orient=kwargs['train_orient'],
@@ -159,9 +164,11 @@
         dataset = RingLookupDataset(os.path.join(root, name), nodes=kwargs['max_ring_size'])
     elif name == 'ZINC':
         dataset = ZincDataset(os.path.join(root, name), max_ring_size=kwargs['max_ring_size'],
+                              include_down_adj=kwargs['include_down_adj'],
                               use_edge_features=kwargs['use_edge_features'], n_jobs=n_jobs)
     elif name == 'ZINC-FULL':
         dataset = ZincDataset(os.path.join(root, name), subset=False, max_ring_size=kwargs['max_ring_size'],
+                              include_down_adj=kwargs['include_down_adj'],
                               use_edge_features=kwargs['use_edge_features'], n_jobs=n_jobs)
     elif name == 'CSL':
         dataset = CSLDataset(os.path.join(root, name), max_ring_size=kwargs['max_ring_size'],
@@ -172,11 +179,17 @@ def load_dataset(name, root=os.path.join(ROOT_DIR, 'datasets'), max_dim=2, fold=
         official_name = 'ogbg-'+name.lower()
         dataset = OGBDataset(os.path.join(root, name), official_name, max_ring_size=kwargs['max_ring_size'],
                              use_edge_features=kwargs['use_edge_features'], simple=kwargs['simple_features'],
-                             init_method=init_method, n_jobs=n_jobs)
+                             include_down_adj=kwargs['include_down_adj'], init_method=init_method, n_jobs=n_jobs)
     elif name == 'DUMMY':
         dataset = DummyDataset(os.path.join(root, name))
     elif name == 'DUMMYM':
         dataset = DummyMolecularDataset(os.path.join(root, name))
+    elif name == 'PEPTIDES-F':
+        dataset = PeptidesFunctionalDataset(os.path.join(root, name), max_ring_size=kwargs['max_ring_size'],
+                                            include_down_adj=kwargs['include_down_adj'], init_method=init_method, n_jobs=n_jobs)
+    elif name == 'PEPTIDES-S':
+        dataset = PeptidesStructuralDataset(os.path.join(root, name), max_ring_size=kwargs['max_ring_size'],
+                                            include_down_adj=kwargs['include_down_adj'], init_method=init_method, n_jobs=n_jobs)
     else:
         raise NotImplementedError(name)
     return dataset
@@ -217,6 +230,12 @@ def load_graph_dataset(name, root=os.path.join(ROOT_DIR, 'datasets'), fold=0, **kwargs):
     elif name == 'ZINC':
         graph_list, train_ids, val_ids, test_ids = load_zinc_graph_dataset(root=root)
         data = (graph_list, train_ids, val_ids, test_ids, 1)
+    elif name == 'PEPTIDES-F':
+        graph_list, train_ids, val_ids, test_ids = load_pep_f_graph_dataset(root=root)
+        data = (graph_list, train_ids, val_ids, test_ids, 2)
+    elif name == 'PEPTIDES-S':
+        graph_list, train_ids, val_ids, test_ids = load_pep_s_graph_dataset(root=root)
+        data = (graph_list, train_ids, val_ids, test_ids, 2)
     elif name == 'ZINC-FULL':
         graph_list, train_ids, val_ids, test_ids = load_zinc_graph_dataset(root=root, subset=False)
         data = (graph_list, train_ids, val_ids, test_ids, 1)
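With these branches in place, include_down_adj becomes a required kwarg for the datasets above. A minimal call sketch (the argument values are illustrative assumptions, not taken from this diff):

# Sketch: load the new peptides benchmark as a cell complex with
# lower adjacencies enabled; all values below are illustrative.
from data.data_loading import load_dataset

dataset = load_dataset('PEPTIDES-F',
                       max_ring_size=6,
                       include_down_adj=True,
                       init_method='sum',
                       n_jobs=4)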
2 changes: 2 additions & 0 deletions data/datasets/__init__.py
@@ -8,6 +8,8 @@
 from data.datasets.dummy import DummyDataset, DummyMolecularDataset
 from data.datasets.csl import CSLDataset
 from data.datasets.ogb import OGBDataset, load_ogb_graph_dataset
+from data.datasets.peptides_functional import PeptidesFunctionalDataset, load_pep_f_graph_dataset
+from data.datasets.peptides_structural import PeptidesStructuralDataset, load_pep_s_graph_dataset
 from data.datasets.ringtransfer import RingTransferDataset, load_ring_transfer_dataset
 from data.datasets.ringlookup import RingLookupDataset, load_ring_lookup_dataset

6 changes: 4 additions & 2 deletions data/datasets/ogb.py
@@ -10,14 +10,16 @@ class OGBDataset(InMemoryComplexDataset):
     """This is OGB graph-property prediction. These are graph-wise classification tasks."""
 
     def __init__(self, root, name, max_ring_size, use_edge_features=False, transform=None,
-                 pre_transform=None, pre_filter=None, init_method='sum', simple=False, n_jobs=2):
+                 pre_transform=None, pre_filter=None, init_method='sum',
+                 include_down_adj=False, simple=False, n_jobs=2):
         self.name = name
         self._max_ring_size = max_ring_size
         self._use_edge_features = use_edge_features
         self._simple = simple
         self._n_jobs = n_jobs
         super(OGBDataset, self).__init__(root, transform, pre_transform, pre_filter,
-                                         max_dim=2, init_method=init_method, cellular=True)
+                                         max_dim=2, init_method=init_method,
+                                         include_down_adj=include_down_adj, cellular=True)
         self.data, self.slices, idx, self.num_tasks = self.load_dataset()
         self.train_ids = idx['train']
         self.val_ids = idx['valid']
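The same flag now reaches the OGB molecular benchmarks through the constructor above. A minimal construction sketch (the root path and ring size are assumptions, not taken from this diff):

# Sketch: build an OGB benchmark as a cell complex with down adjacency;
# the path and ring size below are illustrative only.
from data.datasets import OGBDataset

dataset = OGBDataset('datasets/MOLHIV', 'ogbg-molhiv', max_ring_size=6,
                     use_edge_features=True, include_down_adj=True)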
257 changes: 257 additions & 0 deletions data/datasets/peptides_functional.py
@@ -0,0 +1,257 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue May 2 21:37:42 2023
@author: renz
"""

import hashlib
import os.path as osp
import os
import pickle
import shutil

import pandas as pd
import torch
from ogb.utils import smiles2graph
from ogb.utils.torch_util import replace_numpy_with_torchtensor
from ogb.utils.url import decide_download
from torch_geometric.data import Data, download_url
from torch_geometric.data import InMemoryDataset
from data.utils import convert_graph_dataset_with_rings
from data.datasets import InMemoryComplexDataset
from tqdm import tqdm


class PeptidesFunctionalDataset(InMemoryComplexDataset):
    """
    PyG dataset of 15,535 peptides represented as their molecular graph
    (SMILES) with 10-way multi-task binary classification of their
    functional classes.

    The goal is to use the molecular representation of peptides instead
    of the amino acid sequence representation ('peptide_seq' field in the
    file, provided for possible baseline benchmarking but not used here)
    to test GNNs' representation capability.

    The 10 classes represent the following functional classes (in order):
    ['antifungal', 'cell_cell_communication', 'anticancer',
    'drug_delivery_vehicle', 'antimicrobial', 'antiviral',
    'antihypertensive', 'antibacterial', 'antiparasitic', 'toxic']

    Args:
        root (string): Root directory where the dataset should be saved.
        smiles2graph (callable): A callable function that converts a SMILES
            string into a graph object. We use the OGB featurization.
            * The default smiles2graph requires rdkit to be installed *
    """
    def __init__(self, root, max_ring_size, smiles2graph=smiles2graph,
                 transform=None, pre_transform=None, pre_filter=None,
                 include_down_adj=False, init_method='sum', n_jobs=2):
        self.original_root = root
        self.smiles2graph = smiles2graph
        self.folder = osp.join(root, 'peptides-functional')

        self.url = 'https://www.dropbox.com/s/ol2v01usvaxbsr8/peptide_multi_class_dataset.csv.gz?dl=1'
        self.version = '701eb743e899f4d793f0e13c8fa5a1b4'  # MD5 hash of the intended dataset file
        self.url_stratified_split = 'https://www.dropbox.com/s/j4zcnx2eipuo0xz/splits_random_stratified_peptide.pickle?dl=1'
        self.md5sum_stratified_split = '5a0114bdadc80b94fc7ae974f13ef061'

        # Check version and update if necessary.
        release_tag = osp.join(self.folder, self.version)
        if osp.isdir(self.folder) and (not osp.exists(release_tag)):
            print(f"{self.__class__.__name__} has been updated.")
            if input("Will you update the dataset now? (y/N)\n").lower() == 'y':
                shutil.rmtree(self.folder)

        self.name = 'peptides_functional'
        self._max_ring_size = max_ring_size
        self._use_edge_features = True
        self._n_jobs = n_jobs
        super(PeptidesFunctionalDataset, self).__init__(root, transform, pre_transform, pre_filter,
                                                        max_dim=2, init_method=init_method,
                                                        include_down_adj=include_down_adj,
                                                        cellular=True, num_classes=1)

        self.data, self.slices, idx, self.num_tasks = self.load_dataset()
        self.train_ids = idx['train']
        self.val_ids = idx['val']
        self.test_ids = idx['test']

        self.num_node_type = 9
        self.num_edge_type = 3

    @property
    def raw_file_names(self):
        return 'peptide_multi_class_dataset.csv.gz'

    @property
    def processed_file_names(self):
        return [f'{self.name}_complex.pt', f'{self.name}_idx.pt', f'{self.name}_tasks.pt']

    @property
    def processed_dir(self):
        """Override to change the directory name based on ring size and edge features."""
        directory = super(PeptidesFunctionalDataset, self).processed_dir
        suffix1 = f"_{self._max_ring_size}rings" if self._cellular else ""
        suffix2 = "-E" if self._use_edge_features else ""
        return directory + suffix1 + suffix2

    def _md5sum(self, path):
        hash_md5 = hashlib.md5()
        with open(path, 'rb') as f:
            buffer = f.read()
            hash_md5.update(buffer)
        return hash_md5.hexdigest()

    def download(self):
        if decide_download(self.url):
            path = download_url(self.url, self.raw_dir)
            # Save to disk the MD5 hash of the downloaded file.
            hash = self._md5sum(path)
            if hash != self.version:
                raise ValueError("Unexpected MD5 hash of the downloaded file")
            open(osp.join(self.root, hash), 'w').close()
            # Download train/val/test splits.
            path_split1 = download_url(self.url_stratified_split, self.root)
            assert self._md5sum(path_split1) == self.md5sum_stratified_split
            old_df_name = osp.join(self.raw_dir,
                                   'peptide_multi_class_dataset.csv.gz?dl=1')
            new_df_name = osp.join(self.raw_dir,
                                   'peptide_multi_class_dataset.csv.gz')
            old_split_file = osp.join(self.root,
                                      "splits_random_stratified_peptide.pickle?dl=1")
            new_split_file = osp.join(self.root,
                                      "splits_random_stratified_peptide.pickle")
            os.rename(old_df_name, new_df_name)
            os.rename(old_split_file, new_split_file)
        else:
            print('Stop download.')
            exit(-1)

    def load_dataset(self):
        """Load the processed dataset from disk."""
        print("Loading dataset from disk...")
        data, slices = torch.load(self.processed_paths[0])
        idx = torch.load(self.processed_paths[1])
        tasks = torch.load(self.processed_paths[2])
        return data, slices, idx, tasks

    def process(self):
        data_df = pd.read_csv(osp.join(self.raw_dir,
                                       'peptide_multi_class_dataset.csv.gz'))
        smiles_list = data_df['smiles']

        print('Converting SMILES strings into graphs...')
        data_list = []
        for i in tqdm(range(len(smiles_list))):
            data = Data()

            smiles = smiles_list[i]
            graph = self.smiles2graph(smiles)

            assert (len(graph['edge_feat']) == graph['edge_index'].shape[1])
            assert (len(graph['node_feat']) == graph['num_nodes'])

            data.__num_nodes__ = int(graph['num_nodes'])
            data.edge_index = torch.from_numpy(graph['edge_index']).to(
                torch.int64)
            data.edge_attr = torch.from_numpy(graph['edge_feat']).to(
                torch.int64)
            data.x = torch.from_numpy(graph['node_feat']).to(torch.int64)
            data.y = torch.Tensor([eval(data_df['labels'].iloc[i])])

            data_list.append(data)

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        split_idx = self.get_idx_split()

        # NB: the init method would basically have no effect if
        # we use edge features and do not initialize rings.
        print(f"Converting the {self.name} dataset to a cell complex...")
        complexes, _, _ = convert_graph_dataset_with_rings(
            data_list,
            max_ring_size=self._max_ring_size,
            include_down_adj=self.include_down_adj,
            init_method=self._init_method,
            init_edges=self._use_edge_features,
            init_rings=False,
            n_jobs=self._n_jobs)

        print(f'Saving processed dataset in {self.processed_paths[0]}...')
        torch.save(self.collate(complexes, self.max_dim), self.processed_paths[0])

        print(f'Saving idx in {self.processed_paths[1]}...')
        torch.save(split_idx, self.processed_paths[1])

        print(f'Saving num_tasks in {self.processed_paths[2]}...')
        torch.save(10, self.processed_paths[2])

    def get_idx_split(self):
        """Get dataset splits.

        Returns:
            Dict with 'train', 'val', and 'test' split indices.
        """
        split_file = osp.join(self.root,
                              "splits_random_stratified_peptide.pickle")
        with open(split_file, 'rb') as f:
            splits = pickle.load(f)
        split_dict = replace_numpy_with_torchtensor(splits)
        split_dict['valid'] = split_dict['val']
        return split_dict


def load_pep_f_graph_dataset(root):
    raw_dir = osp.join(root, 'raw')
    data_df = pd.read_csv(osp.join(raw_dir,
                                   'peptide_multi_class_dataset.csv.gz'))
    smiles_list = data_df['smiles']

    print('Converting SMILES strings into graphs...')
    data_list = []
    for i in tqdm(range(len(smiles_list))):
        data = Data()

        smiles = smiles_list[i]
        graph = smiles2graph(smiles)

        assert (len(graph['edge_feat']) == graph['edge_index'].shape[1])
        assert (len(graph['node_feat']) == graph['num_nodes'])

        data.__num_nodes__ = int(graph['num_nodes'])
        data.edge_index = torch.from_numpy(graph['edge_index']).to(torch.int64)
        data.edge_attr = torch.from_numpy(graph['edge_feat']).to(torch.int64)
        data.x = torch.from_numpy(graph['node_feat']).to(torch.int64)
        # The functional targets live in the 'labels' column of this CSV;
        # the Inertia_*/length_*/Spherocity columns belong to the separate
        # structural dataset and are not present here.
        data.y = torch.Tensor([eval(data_df['labels'].iloc[i])])

        data_list.append(data)

    dataset = InMemoryDataset.collate(data_list)

    # Load the stratified split file.
    split_file = osp.join(root,
                          "splits_random_stratified_peptide.pickle")
    with open(split_file, 'rb') as f:
        splits = pickle.load(f)
    split_dict = replace_numpy_with_torchtensor(splits)
    split_dict['valid'] = split_dict['val']

    return dataset, split_dict['train'], split_dict['valid'], split_dict['test']
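A minimal usage sketch for the new dataset class (the root path and ring size are illustrative assumptions, not part of this diff):

# Sketch: build the functional peptides benchmark with lower adjacencies.
from data.datasets import PeptidesFunctionalDataset

dataset = PeptidesFunctionalDataset('datasets/PEPTIDES-F',
                                    max_ring_size=6,
                                    include_down_adj=True)
print(dataset.num_tasks)  # 10 binary functional classes
print(len(dataset.train_ids), len(dataset.val_ids), len(dataset.test_ids))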