Skip to content

Commit

Permalink
Cifar-10 dataset (#78)
Browse files Browse the repository at this point in the history
* Added cifar10

* custom mlp model with scaled array params

* copied mnist classifier and edited to use cifar as a test

* removed trial files

---------

Co-authored-by: samho <>
  • Loading branch information
samhosegood authored Jan 12, 2024
1 parent 293c1c8 commit ebfc951
Showing 1 changed file with 58 additions and 0 deletions.
58 changes: 58 additions & 0 deletions experiments/mnist/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import array
import gzip
import os
import pickle
import struct
import tarfile
import urllib.request
from os import path

Expand Down Expand Up @@ -47,6 +49,19 @@ def _one_hot(x, k, dtype=np.float32):
return np.array(x[:, None] == np.arange(k), dtype)


def _unzip(file, dest=None):
    """Extract every member of the tar archive at `file` into a directory.

    Args:
      file: path of the tar archive; compression is detected by `tarfile.open`.
      dest: directory to extract into. When omitted, the module-level `_DATA`
        directory is used, which is what existing callers rely on.
    """
    if dest is None:
        dest = _DATA
    # The context manager closes the archive even when extraction raises.
    with tarfile.open(file) as archive:
        if hasattr(tarfile, "data_filter"):
            # The archive comes from the network: on Python versions that
            # support extraction filters, refuse members that would be written
            # outside `dest` (absolute paths, "..", escaping links).
            archive.extractall(dest, filter="data")
        else:
            archive.extractall(dest)


def _unpickle(file):
    """Load and return the object stored in the pickle file at `file`.

    The file is read with `encoding="bytes"`, so 8-bit strings written by
    Python 2 come back as `bytes` objects (e.g. a key `b"data"`).
    """
    # NOTE(security): unpickling can execute arbitrary code. Only call this on
    # files obtained from a trusted source.
    with open(file, "rb") as fo:
        return pickle.load(fo, encoding="bytes")


def mnist_raw():
"""Download and parse the raw MNIST dataset."""
# CVDF mirror of http://yann.lecun.com/exdb/mnist/
Expand Down Expand Up @@ -93,3 +108,46 @@ def mnist(permute_train=False):
train_labels = train_labels[perm]

return train_images, train_labels, test_images, test_labels


def cifar_raw():
    """Fetch the CIFAR-10 python archive, extract it and load the raw arrays.

    Returns:
      A tuple `(train_images, train_labels, test_images, test_labels)`. The
      training arrays are the `b"data"` / `b"labels"` entries of the five
      `data_batch_*` files concatenated in order; the test arrays come from
      `test_batch`.
    """
    archive_name = "cifar-10-python.tar.gz"
    _download("https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz", archive_name)
    _unzip(path.join(_DATA, archive_name))

    # Directory holding the batch files once the archive has been extracted.
    batch_dir = path.join(_DATA, "cifar-10-batches-py")

    def load_batch(name):
        return _unpickle(path.join(batch_dir, name))

    train_names = (
        "data_batch_1",
        "data_batch_2",
        "data_batch_3",
        "data_batch_4",
        "data_batch_5",
    )
    train_batches = [load_batch(name) for name in train_names]
    train_images = np.concatenate([batch[b"data"] for batch in train_batches])
    train_labels = np.concatenate([batch[b"labels"] for batch in train_batches])

    test_batch = load_batch("test_batch")
    test_images = test_batch[b"data"]
    test_labels = np.array(test_batch[b"labels"])

    return train_images, train_labels, test_images, test_labels


def cifar(permute_train=False):
    """Load CIFAR-10 with images divided by 255 and labels one-hot encoded.

    Args:
      permute_train: when true, reorder the training images and labels with a
        permutation drawn from `np.random.RandomState(0)`, so the shuffle is
        the same on every call.

    Returns:
      A tuple `(train_images, train_labels, test_images, test_labels)`.
    """
    raw_train_x, raw_train_y, raw_test_x, raw_test_y = cifar_raw()

    scale = np.float32(255.0)
    num_classes = 10

    train_images = raw_train_x / scale
    train_labels = _one_hot(raw_train_y, num_classes)
    test_images = raw_test_x / scale
    test_labels = _one_hot(raw_test_y, num_classes)

    if not permute_train:
        return train_images, train_labels, test_images, test_labels

    # Fixed seed keeps the shuffled order reproducible across runs.
    order = np.random.RandomState(0).permutation(train_images.shape[0])
    return train_images[order], train_labels[order], test_images, test_labels

0 comments on commit ebfc951

Please sign in to comment.