Python module to download and extract the MNIST database for training and testing deep learning neural networks in computer vision.

JulianChia/mnist

mnist.py

This Python module provides a simple-to-use function to download and extract the MNIST database of handwritten digits provided by http://yann.lecun.com/exdb/mnist/.

Function:

load_MNIST(path=None, normalise=True, flatten=True, onehot=True)

kwargs:

path - str: MNIST datasets directory. Defaults to current directory/MNIST.
            Created if nonexistent. Any missing MNIST files are downloaded.
normalise - boolean: yes -> pixel grayscale values [0,255] divided by 255.
                     no  -> pixel grayscale values [0,255].
flatten   - boolean: yes -> pixels of all images stored as 2D numpy array.
                     no  -> pixels of all images stored as 3D numpy array.
onehot    - boolean: yes -> labels stored as one-hot encoded numpy array.
                     no  -> label values (digits 0-9) stored as 1D numpy array.
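
As an illustration of what the three boolean options mean, here is a minimal NumPy sketch that applies them to synthetic data. The helper name preprocess and the random stand-in arrays are assumptions made for this example; this is not the module's actual implementation.

```python
import numpy as np

def preprocess(pixels, labels, normalise=True, flatten=True, onehot=True):
    """Hypothetical helper: apply the three options to raw uint8 data.

    pixels: uint8 array of shape (N, rows, cols); labels: uint8 array of shape (N,).
    """
    if normalise:
        # Scale [0, 255] to [0.0, 1.0]; the result is float32.
        pixels = pixels.astype(np.float32) / 255.0
    if flatten:
        # One row per image: (N, rows, cols) -> (N, rows * cols).
        pixels = pixels.reshape(pixels.shape[0], -1)
    if onehot:
        # Row k of the 10x10 identity matrix is the one-hot code for digit k.
        labels = np.eye(10, dtype=np.uint8)[labels]
    return pixels, labels

# Synthetic stand-in for five 28x28 MNIST images and their labels.
rng = np.random.default_rng(0)
raw_pixels = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)
raw_labels = np.array([3, 0, 9, 1, 7], dtype=np.uint8)

pixels, labels = preprocess(raw_pixels, raw_labels)
print(pixels.shape, pixels.dtype)  # shape and dtype after normalise + flatten
print(labels.shape, labels.dtype)  # shape and dtype after one-hot encoding
```

With all three options switched off, the inputs pass through unchanged as uint8 arrays.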

Returns a nested dictionary:

 {'train': {'images': train_images, 'labels': train_labels},
  'test': {'images': test_images, 'labels': test_labels}}
 where,
  train_images = MNISTimages(magic_number=2051, nimages=60000, nrows=28,
                             ncols=28, pixels=np.array())
                 if normalise, pixels dtype='float32'
                 else,         pixels dtype='uint8'
                 if flatten,   pixels.shape = (60000, 784)
                 else,         pixels.shape = (60000, 28, 28)
  train_labels = MNISTlabels(magic_number=2049, nlabels=60000,
                             labels=np.array() dtype='uint8')
                 if onehot,   labels.shape = (60000, 10)
                 else,        labels.shape = (60000,)
  test_images = MNISTimages(magic_number=2051, nimages=10000, nrows=28,
                            ncols=28, pixels=np.array())
                if normalise, pixels dtype='float32'
                else,         pixels dtype='uint8'
                if flatten,   pixels.shape = (10000, 784)
                else,         pixels.shape = (10000, 28, 28)
  test_labels = MNISTlabels(magic_number=2049, nlabels=10000,
                            labels=np.array() dtype='uint8')
                if onehot,   labels.shape = (10000, 10)
                else,        labels.shape = (10000,)
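
The magic_number, nimages, nrows, ncols and nlabels fields above correspond to the header of the IDX file format used by the MNIST files: big-endian 32-bit integers followed by the raw bytes. The sketch below parses such headers from small synthetic buffers using only the standard library. The function names are hypothetical and this is not necessarily how the module reads the files; the downloaded files are gzip-compressed, so they would need decompressing first.

```python
import struct

IMAGE_MAGIC = 2051  # magic number of an IDX image file
LABEL_MAGIC = 2049  # magic number of an IDX label file

def parse_idx_images(buf):
    """Hypothetical parser: split an uncompressed IDX image buffer into header and pixels."""
    magic, nimages, nrows, ncols = struct.unpack(">IIII", buf[:16])
    if magic != IMAGE_MAGIC:
        raise ValueError(f"unexpected image magic number: {magic}")
    pixels = buf[16:]
    if len(pixels) != nimages * nrows * ncols:
        raise ValueError("pixel data length does not match header")
    return {"magic_number": magic, "nimages": nimages,
            "nrows": nrows, "ncols": ncols, "pixels": pixels}

def parse_idx_labels(buf):
    """Hypothetical parser: split an uncompressed IDX label buffer into header and labels."""
    magic, nlabels = struct.unpack(">II", buf[:8])
    if magic != LABEL_MAGIC:
        raise ValueError(f"unexpected label magic number: {magic}")
    labels = buf[8:]
    if len(labels) != nlabels:
        raise ValueError("label data length does not match header")
    return {"magic_number": magic, "nlabels": nlabels, "labels": labels}

# Synthetic buffers: two blank 28x28 images and two labels.
image_buf = struct.pack(">IIII", IMAGE_MAGIC, 2, 28, 28) + bytes(2 * 28 * 28)
label_buf = struct.pack(">II", LABEL_MAGIC, 2) + bytes([5, 0])

images = parse_idx_images(image_buf)
labels = parse_idx_labels(label_buf)
print(images["nimages"], images["nrows"], images["ncols"])
print(list(labels["labels"]))
```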

Remarks:

MNISTimages() and MNISTlabels() are dataklass objects. On my system, they performed ~25x faster than Python 3's built-in dataclass objects and ~5x faster than namedtuple.

How to use?

from mnist import load_MNIST                  # Import function from module
mdb = load_MNIST()                            # Get MNIST database using default settings
train_images = mdb['train']['images'].pixels  # A 60000x784 numpy array with float32 values
train_labels = mdb['train']['labels'].labels  # A 60000x10 numpy array with uint8 values
test_images = mdb['test']['images'].pixels    # A 10000x784 numpy array with float32 values
test_labels = mdb['test']['labels'].labels    # A 10000x10 numpy array with uint8 values
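
If the data was loaded with the default settings, the flattened pixels and one-hot labels can be converted back to 2D images and integer digits with plain NumPy. The sketch below uses small synthetic arrays as stand-ins for the arrays returned by load_MNIST(), so it runs without downloading anything; only the shapes and dtypes are taken from the description above.

```python
import numpy as np

# Stand-ins with the same layout as the default output, but only 6 samples.
flat_images = np.zeros((6, 784), dtype=np.float32)
onehot_labels = np.eye(10, dtype=np.uint8)[np.arange(6)]

# (N, 784) -> (N, 28, 28), e.g. for plotting or convolutional layers.
images_3d = flat_images.reshape(-1, 28, 28)

# One-hot rows -> integer digits: the position of the 1 in each row.
digits = onehot_labels.argmax(axis=1)

print(images_3d.shape)
print(digits.tolist())
```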
