-
Notifications
You must be signed in to change notification settings - Fork 2
/
get_data_loaders_tuad.py
72 lines (53 loc) · 2.06 KB
/
get_data_loaders_tuad.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import os.path as osp
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import Constant
from torch_geometric.datasets import TUDataset
from sklearn.model_selection import StratifiedKFold
def get_ad_split_TU(args, fold=5):
    """Compute stratified K-fold index splits for a TU graph dataset.

    The dataset named by ``args.dataset`` is loaded from (or downloaded to)
    the ``data/<name>`` directory next to this file, and its graph labels
    are used to stratify the folds.

    Args:
        args: Object exposing a ``dataset`` attribute holding the TUDataset
            name to load.
        fold: Number of folds handed to ``StratifiedKFold`` (default 5).

    Returns:
        A list with one ``(train_index, test_index)`` pair per fold, each
        element being the index arrays yielded by ``StratifiedKFold.split``.
        The shuffle is seeded with ``random_state=0``, so the splits are
        reproducible for a given dataset.
    """
    DS = args.dataset
    path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data', DS)
    dataset = TUDataset(path, name=DS)

    # Stratification depends only on the labels, so there is no need to keep
    # every graph object alive just to serve as the feature placeholder.
    label_list = [data.y.item() for data in dataset]

    kfd = StratifiedKFold(n_splits=fold, random_state=0, shuffle=True)
    # The first argument is only inspected for its length; the labels double
    # as that placeholder.
    return list(kfd.split(label_list, label_list))
def get_data_loaders_TU(args, split):
    """Build train/test loaders for one fold of a TU graph dataset.

    Training data consists only of the graphs in the train fold whose
    original label is non-zero; those graphs are relabelled ``0`` and tagged
    with a running ``idx``. Every graph in the test fold is kept, with its
    label rewritten to ``1`` if it was originally ``0`` and to ``0``
    otherwise. (Presumably original class 0 plays the role of the anomaly
    class -- confirm against the training/evaluation code.)

    Args:
        args: Object exposing ``dataset`` (TUDataset name), ``batch_size``
            (train loader batch size) and ``batch_size_test`` (test loader
            batch size).
        split: A ``(train_index, test_index)`` pair of iterables of integer
            positions into the dataset.

    Returns:
        A tuple ``(loader_dict, meta)`` where ``loader_dict`` maps
        ``'train'`` and ``'test'`` to their ``DataLoader`` and ``meta``
        holds ``'num_feat'`` (the dataset's ``num_node_features``),
        ``'num_train'`` (number of retained training graphs) and
        ``'num_edge_feat'`` (always 0, since edge attributes are dropped).
    """
    DS = args.dataset
    path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data', DS)

    # NOTE(review): these datasets appear to ship without node features, so a
    # constant feature of 1 is attached to every node -- confirm.
    if DS in ['IMDB-BINARY', 'REDDIT-BINARY', 'COLLAB']:
        dataset = TUDataset(path, name=DS, transform=Constant(1, cat=False))
    else:
        dataset = TUDataset(path, name=DS)

    dataset_num_features = dataset.num_node_features

    # Materialise the graphs once so the in-place label edits below persist
    # on the very objects handed to the loaders.
    data_list = []
    for data in dataset:
        data.edge_attr = None  # edge attributes are discarded
        data_list.append(data)

    train_index, test_index = split
    data_test = [data_list[i] for i in test_index]

    # Train only on graphs whose original label is non-zero.
    data_train = [
        graph
        for graph in (data_list[i] for i in train_index)
        if graph.y != 0
    ]

    for idx, data in enumerate(data_train):
        data.y = 0
        data['idx'] = idx

    # Flip test labels: original class 0 becomes 1, everything else 0.
    for data in data_test:
        data.y = 1 if data.y == 0 else 0

    dataloader = DataLoader(data_train, batch_size=args.batch_size, shuffle=True)
    dataloader_test = DataLoader(data_test, batch_size=args.batch_size_test, shuffle=True)

    meta = {
        'num_feat': dataset_num_features,
        'num_train': len(data_train),
        'num_edge_feat': 0,
    }
    loader_dict = {'train': dataloader, 'test': dataloader_test}
    return loader_dict, meta