Commit
add rcfmnist
Lupin1998 committed Sep 11, 2023
1 parent bf74d3f commit 26f3e0e
Showing 5 changed files with 324 additions and 2 deletions.
46 changes: 46 additions & 0 deletions config/usb_cv/fullysupervised/fullysupervised_rcfmnist_600_0.yaml
@@ -0,0 +1,46 @@
algorithm: fullysupervised
save_dir: ./saved_models/usb_cv/
save_name: fullysupervised_rcfmnist_600_0
resume: False
load_path: ./saved_models/usb_cv//fullysupervised_rcfmnist_600_0/latest_model.pth
overwrite: True
use_tensorboard: True
use_wandb: False
epoch: 200
num_train_iter: 204800
num_log_iter: 256
num_eval_iter: 2048
batch_size: 32
eval_batch_size: 64
num_warmup_iter: 5120
num_labels: 600
uratio: 1
ema_m: 0.0
img_size: 32
crop_ratio: 0.875
optim: AdamW
lr: 0.0005
layer_decay: 0.5
momentum: 0.9
weight_decay: 0.0005
amp: False
clip: 0.0
use_cat: True
net: vit_small_patch2_32
net_from_name: False
data_dir: ./data/
dataset: rcfmnist
train_sampler: RandomSampler
num_classes: 1
loss_type: 'l1_loss'
num_workers: 4
seed: 0
world_size: 1
rank: 0
multiprocessing_distributed: True
dist_url: tcp://127.0.0.1:10025
dist_backend: nccl
gpu: None
use_pretrain: True
pretrain_path: https://github.com/microsoft/Semi-supervised-learning/releases/download/v.0.0.0/vit_small_patch2_32_mlp_im_1k_32.pth
find_unused_parameters: False
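
The file above is a flat key/value config that the framework consumes as args-style fields. As a quick standalone sanity check it can be read the same way; a minimal sketch, assuming PyYAML is installed and that every key maps one-to-one onto an attribute (the SimpleNamespace wrapper is illustrative, not the framework's own parser):

# illustrative: load the config above into an attribute-style namespace (assumes PyYAML)
import yaml
from types import SimpleNamespace

with open('config/usb_cv/fullysupervised/fullysupervised_rcfmnist_600_0.yaml') as f:
    cfg = SimpleNamespace(**yaml.safe_load(f))

print(cfg.dataset, cfg.num_labels, cfg.num_classes, cfg.loss_type)  # rcfmnist 600 1 l1_loss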
6 changes: 5 additions & 1 deletion semilearn/core/utils/build.py
@@ -70,7 +70,8 @@ def get_dataset(args, algorithm, dataset, num_labels, num_classes, data_dir='./d
data_dir: data folder
include_lb_to_ulb: flag of including labeled data into unlabeled data
"""
from semilearn.datasets import get_agedb, get_eurosat, get_medmnist, get_semi_aves, get_cifar, get_svhn, get_stl10, get_imagenet, get_imdb_wiki, get_json_dset, get_pkl_dset
from semilearn.datasets import (get_agedb, get_eurosat, get_medmnist, get_semi_aves, get_cifar, get_svhn, get_stl10, get_imagenet,
get_imdb_wiki, get_rcfmnist, get_json_dset, get_pkl_dset)

if dataset == "eurosat":
lb_dset, ulb_dset, eval_dset = get_eurosat(args, algorithm, dataset, num_labels, num_classes, data_dir=data_dir, include_lb_to_ulb=include_lb_to_ulb)
@@ -102,6 +103,9 @@ def get_dataset(args, algorithm, dataset, num_labels, num_classes, data_dir='./d
elif dataset == "imdb_wiki":
lb_dset, ulb_dset, eval_dset = get_imdb_wiki(args, algorithm, dataset, num_labels, data_dir=data_dir)
test_dset = None
elif dataset == "rcfmnist":
lb_dset, ulb_dset, eval_dset = get_rcfmnist(args, algorithm, dataset, num_labels, data_dir=data_dir)
test_dset = None
# speech dataset
elif dataset in ['esc50', 'fsdnoisy', 'gtzan', 'superbks', 'superbsi', 'urbansound8k']:
lb_dset, ulb_dset, eval_dset, test_dset = get_pkl_dset(args, algorithm, dataset, num_labels, num_classes, data_dir=data_dir, include_lb_to_ulb=include_lb_to_ulb)
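
With the rcfmnist branch registered in get_dataset, the loader can also be exercised directly through the get_rcfmnist signature added further down in this commit. A hedged sketch: the namespace below carries only the args fields that rcf_mnist.py and split_ssl_data read, with img_size and crop_ratio taken from the config above and the remaining fields set to assumed defaults (downloading and rotating FashionMNIST takes a while):

# sketch: build the RCF-MNIST splits directly; attribute names mirror the args fields referenced in rcf_mnist.py
from types import SimpleNamespace
from semilearn.datasets import get_rcfmnist

args = SimpleNamespace(img_size=32, crop_ratio=0.875,
                       ulb_num_labels=None, lb_imb_ratio=1, ulb_imb_ratio=1)
lb_dset, ulb_dset, eval_dset = get_rcfmnist(args, alg='fullysupervised', name='rcfmnist',
                                            num_labels=600, num_classes=1, data_dir='./data')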
2 changes: 1 addition & 1 deletion semilearn/datasets/__init__.py
@@ -2,7 +2,7 @@
# Licensed under the MIT License.

from semilearn.datasets.utils import split_ssl_data, get_collactor
from semilearn.datasets.cv_datasets import (get_cifar, get_eurosat, get_imagenet, get_agedb, get_imdb_wiki,
from semilearn.datasets.cv_datasets import (get_cifar, get_eurosat, get_imagenet, get_agedb, get_imdb_wiki, get_rcfmnist,
get_medmnist, get_semi_aves, get_stl10, get_svhn, get_food101)
from semilearn.datasets.nlp_datasets import get_json_dset
from semilearn.datasets.audio_datasets import get_pkl_dset
1 change: 1 addition & 0 deletions semilearn/datasets/cv_datasets/__init__.py
@@ -8,6 +8,7 @@
from .imagenet import get_imagenet
from .imdb_wiki import get_imdb_wiki
from .medmnist import get_medmnist
from .rcf_mnist import get_rcfmnist
from .stl10 import get_stl10
from .svhn import get_svhn
from .food101 import get_food101
271 changes: 271 additions & 0 deletions semilearn/datasets/cv_datasets/rcf_mnist.py
@@ -0,0 +1,271 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
import random
import numpy as np
import math
from PIL import Image
import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from semilearn.datasets.cv_datasets.datasetbase import BasicDataset
from semilearn.datasets.augmentation import RandomResizedCropAndInterpolation
from semilearn.datasets.utils import split_ssl_data


class RCFMNISTDataset(BasicDataset):

def __init__(self,
alg,
data,
targets=None,
num_classes=None,
transform=None,
is_ulb=False,
strong_transform=None,
onehot=False,
*args,
**kwargs):
super(RCFMNISTDataset, self).__init__(alg=alg, data=data, targets=targets, num_classes=num_classes,
transform=transform, is_ulb=is_ulb, strong_transform=strong_transform, onehot=onehot, *args, **kwargs)

def __sample__(self, idx):
img = Image.fromarray(self.data[idx, ...]).convert('RGB')
label = np.asarray([self.targets[idx]]).astype('float32')
return img, label


def img_torch2numpy(img): # N C H W --> N H W C
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
if len(img.shape) == 4:
return np.transpose(npimg, (0, 2, 3, 1))
elif len(img.shape) == 3:
return np.transpose(npimg, (1, 2, 0))

def img_numpy2torch(img):  # N H W C --> N C H W
    if len(img.shape) == 4:
        tmp_img = np.transpose(img, (0, 3, 1, 2))
    else:
        tmp_img = np.transpose(img, (2, 0, 1))
tcimg = torch.tensor(tmp_img)
tcimg = (tcimg - 0.5) * 2 # normalize
return tcimg


def get_all_batches(loader):
data_iter = iter(loader)
steplen = len(loader)
img_list = []
label_list = []

for step in range(steplen):
        images, labels = next(data_iter)  # use the next() built-in; the iterator's .next() method does not exist in Python 3
img_list.append(images)
label_list.append(labels)
img_tensor = torch.cat(img_list, dim=0)
label_tensor = torch.cat(label_list, dim=0)

return img_tensor, label_tensor


def rotate_img(img, degree=None):
rotate_class = [(360.0 / 60) * i for i in range(60)]
    # rotate an image with PIL
img = img / 2 + 0.5 # unnormalize
pil_img = transforms.ToPILImage()(img)
if degree is None:
degree = random.sample(rotate_class, 1)[0]
r_img = pil_img.rotate(degree)
r_img = transforms.ToTensor()(r_img) # read float
r_img = (r_img - 0.5) * 2.0 # normalize
return r_img, degree


def get_rotate_imgs(imgs):
    # rotate a set of images; returns numpy images (N, H, W, C) and a tensor of rotation angles
r_img_list, degree_list = [], []
for i in range(imgs.shape[0]):
r_img, degree = rotate_img(imgs[i])
r_img_list.append(r_img.unsqueeze(0))
degree_list.append(torch.Tensor([int(degree)]))

r_img_list = torch.cat(r_img_list, dim=0)
degree_list = torch.cat(degree_list, dim=0)
    r_img_np = img_torch2numpy(r_img_list)

    return r_img_np, degree_list


def copydim(arr: np.ndarray, num=3):
    return np.expand_dims(arr, -1).repeat(num, axis=-1)


def linear_red_blue_preparation(x_train, x_test, y_train, y_test, spurious_ratio=0.9):
    # apply the spurious red/blue colorization to both the train and test splits
x_train, ratio_reshape, y_reshape, idx2assay_train, assay2idx_train, _ = color_linear_red_blue(
x_train, y_train, spurious_ratio, use_spurious=True, inv=False)
x_test, ratio_reshape, y_reshape, _, _, test_assay2idx_list = color_linear_red_blue(
x_test, y_test, spurious_ratio, use_spurious=True, inv=False)

return x_train, x_test, assay2idx_train, test_assay2idx_list


def color_linear_red_blue(x_set: np.ndarray, y_set: np.ndarray, spurious_ratio=0.9, use_spurious=True, inv=False):
y_reshape = y_set
x_tmp = x_set
ratio_reshape = np.zeros_like(y_reshape)
print(f'x_tmp.shape = {x_tmp.shape}, y_reshape.shape = {y_reshape.shape}')

num = int(y_reshape.shape[0])
idx = np.arange(num)
idx2assay = np.zeros(num)

if use_spurious: # spurious
mixtype_normal_idx = np.random.choice(idx, size=int(num * spurious_ratio), replace=False)
mixtype_inverse_idx = np.setdiff1d(idx, mixtype_normal_idx)
        ratio_matrix = copydim(copydim(y_reshape, 1), 1)

        if not inv:  # normal spurious correlation
            x_tmp[mixtype_normal_idx] = red_blue_linear_map(x_tmp[mixtype_normal_idx], ratio_matrix[mixtype_normal_idx])
            x_tmp[mixtype_inverse_idx] = red_blue_linear_map(x_tmp[mixtype_inverse_idx], 1.0 - ratio_matrix[mixtype_inverse_idx])
            ratio_reshape[mixtype_normal_idx] = y_reshape[mixtype_normal_idx]
            ratio_reshape[mixtype_inverse_idx] = 1.0 - y_reshape[mixtype_inverse_idx]
        else:  # inverse spurious correlation
            x_tmp[mixtype_normal_idx] = red_blue_linear_map(x_tmp[mixtype_normal_idx], 1.0 - ratio_matrix[mixtype_normal_idx])
            x_tmp[mixtype_inverse_idx] = red_blue_linear_map(x_tmp[mixtype_inverse_idx], ratio_matrix[mixtype_inverse_idx])
            ratio_reshape[mixtype_normal_idx] = 1.0 - y_reshape[mixtype_normal_idx]
            ratio_reshape[mixtype_inverse_idx] = y_reshape[mixtype_inverse_idx]

idx2assay[mixtype_normal_idx] = 0 # class 0
idx2assay[mixtype_inverse_idx] = 1 # class 1

else: # test random
ratio_reshape = np.random.rand(num)
ratio = copydim(copydim(ratio_reshape, 1), 1)
x_tmp = red_blue_linear_map(x_tmp, ratio)
# all class 0
x_set = x_tmp

    # group sample indices by spurious-correlation group ("assay") for OOD evaluation
    assay2idx_list = [torch.nonzero(torch.tensor(idx2assay == loc)).squeeze(-1)
                      for loc in np.unique(idx2assay)]
    assay2idx = {loc: torch.nonzero(torch.tensor(idx2assay == loc)).squeeze(-1)
                 for loc in np.unique(idx2assay)}
return x_set, ratio_reshape, y_reshape, idx2assay, assay2idx, assay2idx_list


def red_blue_linear_map(imgt: np.ndarray, red_ratio):
    color_lower_bound = 5 / 255  # background threshold
    # R outside background -> red * ratio
    imgt[..., 0] = np.where(imgt[..., 0] > color_lower_bound, imgt[..., 0] * red_ratio, imgt[..., 0])
    # G outside background -> 0
    imgt[..., 1] = np.where(imgt[..., 1] > color_lower_bound, 0, imgt[..., 1])
    # B outside background -> blue * (1 - ratio)
    imgt[..., 2] = np.where(imgt[..., 2] > color_lower_bound, imgt[..., 2] * (1 - red_ratio), imgt[..., 2])
    return imgt


def test_img(img, x, degree, name, iscolor=True):
for i in range(3):
img_save(img[i], f'{name}_{i}')
if iscolor:
img_save(x[i], f'{name}_{i}_color_r_{float(degree[i])}')
else:
img_save(x[i], f'{name}_{i}_no_color_r_{float(degree[i])}')


def img_save(img, file_name):

plt.switch_backend('agg')
plt.axis('off')
plt.xticks([])
plt.yticks([])

img = img / 2 + 0.5 # unnormalize
if not isinstance(img, np.ndarray):
npimg = img.numpy()
else:
npimg = img
plt.imshow(np.transpose(npimg, (1, 2, 0)))

plt.savefig(f'./{file_name}.jpg')


def get_rcfmnist(args, alg='', name=None, num_labels=1000, num_classes=1, data_dir='./data', include_lb_to_ulb=True):

data_dir = os.path.join(data_dir, 'rcf_mnist')
basic_data_transforms = transforms.Compose([
transforms.Resize(32),
transforms.Grayscale(3),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
train_set = FashionMNIST(root=data_dir, train=True, download=True, transform=basic_data_transforms)
test_set = FashionMNIST(root=data_dir, train=False, download=True, transform=basic_data_transforms)
train_loader = DataLoader(train_set, batch_size=1000, shuffle=False, num_workers=2)
test_loader = DataLoader(test_set, batch_size=1000, shuffle=False, num_workers=2)
train_data_raw, _ = get_all_batches(train_loader)
test_data_raw, _ = get_all_batches(test_loader)

train_data, train_labels = get_rotate_imgs(train_data_raw)
test_data, test_labels = get_rotate_imgs(test_data_raw)
    train_data = np.uint8(train_data * 255)  # rescale [0, 1] floats to [0, 255] before numpy-to-PIL conversion
    test_data = np.uint8(test_data * 255)
train_labels, test_labels = train_labels.cpu().numpy(), test_labels.cpu().numpy()

img_size = args.img_size
crop_ratio = args.crop_ratio

transform_weak = transforms.Compose([
transforms.Resize((img_size, img_size)),
transforms.RandomCrop(img_size, padding=2),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

transform_strong = transforms.Compose([
transforms.Resize(int(math.floor(img_size / crop_ratio))),
RandomResizedCropAndInterpolation((img_size, img_size), scale=(0.8, 1.)),
transforms.RandomApply(torch.nn.ModuleList([
transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
]), p=0.8),
transforms.GaussianBlur(kernel_size=11),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

transform_val = transforms.Compose([
transforms.Resize((img_size, img_size)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

lb_data, lb_targets, ulb_data, ulb_targets = split_ssl_data(args, train_data, train_labels, num_classes=1,
lb_num_labels=num_labels,
ulb_num_labels=args.ulb_num_labels,
lb_imbalance_ratio=args.lb_imb_ratio,
ulb_imbalance_ratio=args.ulb_imb_ratio,
include_lb_to_ulb=include_lb_to_ulb)

if alg == 'fullysupervised':
lb_data = train_data
lb_targets = train_labels

lb_dset = RCFMNISTDataset(alg, lb_data, lb_targets, num_classes,
transform_weak, False, None, False)
ulb_dset = RCFMNISTDataset(alg, ulb_data, ulb_targets, num_classes,
transform_weak, True, transform_strong, False)
eval_dset = RCFMNISTDataset(alg, test_data, test_labels, num_classes,
transform_val, False, None, False)

return lb_dset, ulb_dset, eval_dset


if __name__ == '__main__':
    # smoke test; a real args namespace must provide img_size, crop_ratio, ulb_num_labels, lb_imb_ratio and ulb_imb_ratio
    get_rcfmnist(args=None)
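
Each sample's regression target is its rotation angle in degrees, drawn from rotate_class, i.e. 60 values spaced 6 degrees apart. A short hedged probe, continuing from a get_rcfmnist call like the sketch after the build.py hunk and using the __sample__ method defined on RCFMNISTDataset:

# sketch: peek at one labeled sample (lb_dset from the earlier get_rcfmnist sketch)
img, label = lb_dset.__sample__(0)
print(img.size, img.mode)   # (32, 32) RGB
print(label, label.dtype)   # e.g. [138.] float32 -- a rotation angle, always a multiple of 6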
