# build_cifar10.py
#
# Exports the TFDS "cifar10" splits as PNG files grouped by class name under
# data/cifar10/<split>/<class name>/, and writes a label-to-name JSON map plus
# per-split class counts as JSON.
import os
import cv2
#import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
#import tensorflow as tf
import json
import tensorflow_datasets as tfds
from tqdm import tqdm
def write_json(label_dict, json_path):
    """Write ``label_dict`` to ``json_path`` as JSON indented by four spaces.

    The object is serialized first and the file is opened afterwards in text
    mode ``'w'``, so an existing file at ``json_path`` is overwritten.
    """
    serialized = json.dumps(label_dict, indent=4)
    with open(json_path, 'w') as out_file:
        out_file.write(serialized)
# Integer class label -> human-readable class name (labels 0 through 9).
label_map = dict(enumerate((
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)))
# Where the exported dataset lives: <base_dir>/<dataset_name>/
base_dir = "data"
dataset_name = "cifar10"
dataset_dir = os.path.join(base_dir, dataset_name)

# Per-split image counts keyed by class name; filled in during the export.
train_class, test_class = {}, {}

# Create the dataset root and store the label -> class-name map next to the images.
os.makedirs(dataset_dir, exist_ok=True)
write_json(
    label_map,
    os.path.join(dataset_dir, "label2name.json"),
)
# Load every split of the dataset; the result is iterated as a mapping of
# split name -> dataset.
ds = tfds.load(dataset_name)

# Counter dict to fill for each known split; any other split name is rejected.
split_counters = {"train": train_class, "test": test_class}

for traintest, cur_ds in ds.items():
    # train/test folder dir
    traintest_dir = os.path.join(dataset_dir, traintest)
    os.makedirs(traintest_dir, exist_ok=True)

    if traintest not in split_counters:
        raise ValueError(f"unknown traintest: {traintest}")
    class_dict = split_counters[traintest]

    for data in tqdm(cur_ds):
        image_id = data['id'].numpy().decode()
        # Plain int so the dict lookup and the file name do not depend on
        # numpy scalar hashing/formatting.
        image_label = int(data['label'].numpy())
        label_name = label_map[image_label]

        # The image is treated as RGB and converted to BGR, the channel
        # order cv2.imwrite stores.
        image_rgb = data['image'].numpy()
        image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)

        # One sub-folder per class; create it only the first time the class
        # is seen in this split instead of once per image.
        save_dir = os.path.join(traintest_dir, label_name)
        if label_name not in class_dict:
            os.makedirs(save_dir, exist_ok=True)

        image_name = f"{image_id}_{image_label}.png"
        image_path = os.path.join(save_dir, image_name)
        # cv2.imwrite reports failure through its return value rather than
        # by raising, so check it to avoid silently missing images.
        if not cv2.imwrite(image_path, image_bgr):
            raise OSError(f"failed to write image: {image_path}")

        # class stat: count only images that were actually written
        class_dict[label_name] = class_dict.get(label_name, 0) + 1

    # Per-split class counts, stored next to the split folders.
    write_json(class_dict, os.path.join(dataset_dir, f"{traintest}_stat.json"))