-
Notifications
You must be signed in to change notification settings - Fork 1
/
run_subset_generation.py
55 lines (43 loc) · 2.46 KB
/
run_subset_generation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
from torchvision import transforms
import numpy as np
import os
from utils import (MNISTWithIdx, CIFAR10WithIdx, load_subset_indices)
def get_subset_indices(dataset, num_per_class, rng=None):
    """Randomly choose ``num_per_class`` indices per class to build a balanced subset.

    Args:
        dataset: Sized, indexable (map-style) dataset whose items unpack as a
            3-tuple ``(input, label, extra)``; only the label (second element)
            is used.
        num_per_class: Number of indices drawn, without replacement, from every
            class present in ``dataset``.
        rng: Optional seed or ``numpy.random.Generator``. When None (default)
            the global ``numpy.random`` state is used, as before; pass a value
            to make the drawn subset reproducible.

    Returns:
        A list of Python ints, grouped by class in the order in which each
        class first appears in ``dataset``.

    Raises:
        ValueError: If some class has fewer than ``num_per_class`` samples.
    """
    # Group dataset positions by label. Every item is fetched, so any
    # transform attached to the dataset runs once per sample here.
    indices = {}
    for i in range(len(dataset)):
        _, label, _ = dataset[i]
        indices.setdefault(label, []).append(i)

    # Keep the legacy global-state sampler unless the caller asked for a
    # specific (reproducible) generator.
    if rng is None:
        choice = np.random.choice
    else:
        choice = np.random.default_rng(rng).choice

    subset_indices = []
    for label, class_indices in indices.items():
        if len(class_indices) < num_per_class:
            raise ValueError(
                f'class {label!r} has only {len(class_indices)} samples, '
                f'cannot draw {num_per_class} without replacement')
        # .tolist() turns numpy integers into plain Python ints.
        subset_indices += choice(class_indices, num_per_class, replace=False).tolist()
    return subset_indices
def main(resample=False, train_sizes=(10, 20, 50), test_num_per_class=50):
    """Build balanced MNIST / CIFAR-10 train and test subsets and save them with ``torch.save``.

    By default the subset indices are read from the index files stored under
    ``./data/<task>/`` so the exact subsets used in the study are reproduced.

    Args:
        resample: If True, draw fresh random balanced subsets with
            ``get_subset_indices`` instead of loading the stored index files.
        train_sizes: Samples per class for each training subset written as
            ``train_subset_<n>pc.pt``.
        test_num_per_class: Samples per class for the test subset; only used
            when ``resample`` is True (the last value the original per-size
            loop wrote was 50).
    """
    # Download (if necessary) the datasets from torchvision via the
    # project-local *WithIdx wrappers.
    trainset_mnist = MNISTWithIdx(root='./data', train=True, transform=transforms.ToTensor(), download=True)
    testset_mnist = MNISTWithIdx(root='./data', train=False, transform=transforms.ToTensor(), download=True)
    transform_cifar = transforms.Compose([transforms.ToTensor(),
                                          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    trainset_cifar = CIFAR10WithIdx(root='./data', train=True, transform=transform_cifar, download=True)
    testset_cifar = CIFAR10WithIdx(root='./data', train=False, transform=transform_cifar, download=True)

    tasks = {
        'mnist3': (trainset_mnist, testset_mnist),
        'cifar10': (trainset_cifar, testset_cifar),
    }
    # NOTE(review): a torch Subset keeps a reference to its whole parent
    # dataset, so each saved .pt file likely contains the full dataset, not
    # only the selected samples. Consider saving the indices alone if file
    # size matters; kept as-is because downstream loaders expect a Subset.
    for task, (trainset, testset) in tasks.items():
        task_dir = os.path.join(os.getcwd(), 'data', task)
        os.makedirs(task_dir, exist_ok=True)

        # The test subset does not depend on the training size, so build and
        # save it once per task rather than once per training size.
        if resample:
            test_indices = get_subset_indices(testset, test_num_per_class)
        else:
            test_indices = load_subset_indices(os.path.join(task_dir, 'test_subset.txt'))
        test_subset = torch.utils.data.Subset(testset, test_indices)
        torch.save(test_subset, os.path.join(task_dir, 'test_subset.pt'))

        for num_per_class in train_sizes:
            if resample:
                train_indices = get_subset_indices(trainset, num_per_class)
            else:
                # Same indices as used in the study.
                train_indices = load_subset_indices(
                    os.path.join(task_dir, f'train_subset_{num_per_class}pc.txt'))
            train_subset = torch.utils.data.Subset(trainset, train_indices)
            torch.save(train_subset, os.path.join(task_dir, f'train_subset_{num_per_class}pc.pt'))
# Script entry point: build the subsets only when run directly, not on import.
if __name__ == '__main__':
    main()