A Python script that uses machine learning (a neural network) to classify handwritten digits from the MNIST dataset.
This project was done as a part of the CSE 574 course on Machine Learning at SUNY Buffalo.
Some details about the project:
1. The MNIST dataset consists of two parts:
i. Training dataset of 60000 handwritten digits
ii. Test dataset of 10000 handwritten digits
We divide the original training dataset of 60000 handwritten digits into two parts:
i. An actual training dataset of 50000 handwritten digits
ii. Validation dataset of 10000 handwritten digits
Thus, we will be working with a total of three datasets:
i. Training dataset of 50000 handwritten digits
ii. Validation dataset of 10000 handwritten digits
iii. Test dataset of 10000 handwritten digits.
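The 50000/10000 split described above could be sketched as follows. This is only an illustration under stated assumptions: the function name `split_train_validation`, the random shuffle, and the fixed seed are hypothetical and not taken from nnScript.py, and zero-filled placeholder arrays stand in for the real MNIST data (60000 images of 28 x 28 = 784 pixels).

```python
import numpy as np


def split_train_validation(features, labels, n_validation=10000, seed=0):
    """Shuffle the examples and hold out n_validation of them for validation.

    Hypothetical helper for illustration; the actual script may split differently.
    """
    n_examples = features.shape[0]
    if labels.shape[0] != n_examples:
        raise ValueError("features and labels must have the same number of rows")
    if not 0 < n_validation < n_examples:
        raise ValueError("n_validation must be between 1 and n_examples - 1")

    # A random permutation keeps each image paired with its label.
    order = np.random.default_rng(seed).permutation(n_examples)
    val_idx, train_idx = order[:n_validation], order[n_validation:]
    return (features[train_idx], labels[train_idx]), (features[val_idx], labels[val_idx])


# Placeholder data with the shape of the MNIST training set (not real digits).
features = np.zeros((60000, 784), dtype=np.uint8)
labels = np.arange(60000) % 10

(train_x, train_y), (val_x, val_y) = split_train_validation(features, labels)
print(train_x.shape, val_x.shape)
```

The test dataset of 10000 digits is left untouched by this split, so it can still serve as a final, unseen check of accuracy.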
2. Training the neural network consists of the following broad steps:
i. Initialization: Set the weights to random values.
ii. Feedforward propagation: Given the current weights of the neural network and an input feature vector, we compute, for each of the ten digits, the probability that the feature vector represents that digit.
iii. Backpropagation: We calculate the prediction error by comparing the predictions with the actual labels. We then propagate this error backwards through the network and update the weights accordingly.
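The three steps above can be sketched for a network with one hidden layer as follows. This is a minimal illustration, not the code in nnScript.py: the function names, the weight-initialization range, the cross-entropy error, the plain gradient-descent update, and the learning rate are all assumptions made for the example, and the tiny random dataset stands in for MNIST.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def init_weights(n_in, n_out, rng):
    """Step i: random weights for one layer; the extra column holds the bias."""
    epsilon = np.sqrt(6.0) / np.sqrt(n_in + n_out + 1)  # assumed range
    return rng.uniform(-epsilon, epsilon, size=(n_out, n_in + 1))


def add_bias(a):
    """Append a constant 1 to every row so the bias acts as an ordinary weight."""
    return np.hstack([a, np.ones((a.shape[0], 1))])


def feedforward(w1, w2, x):
    """Step ii: compute hidden activations and one output score per class."""
    hidden = sigmoid(add_bias(x) @ w1.T)
    output = sigmoid(add_bias(hidden) @ w2.T)
    return hidden, output


def cross_entropy(output, y_onehot):
    """Average error over examples (assumed error function)."""
    n = output.shape[0]
    return -np.sum(y_onehot * np.log(output) + (1 - y_onehot) * np.log(1 - output)) / n


def backprop_step(w1, w2, x, y_onehot, learning_rate=0.5):
    """Step iii: propagate the output error backwards and update both layers."""
    n = x.shape[0]
    hidden, output = feedforward(w1, w2, x)
    delta_out = output - y_onehot  # error at the output layer
    grad_w2 = delta_out.T @ add_bias(hidden) / n
    # Send the error back through w2 (bias column excluded) and the sigmoid slope.
    delta_hidden = (delta_out @ w2[:, :-1]) * hidden * (1.0 - hidden)
    grad_w1 = delta_hidden.T @ add_bias(x) / n
    return w1 - learning_rate * grad_w1, w2 - learning_rate * grad_w2


# Tiny made-up problem: 20 examples, 4 features, 3 classes, 5 hidden nodes.
rng = np.random.default_rng(0)
x = rng.random((20, 4))
y_onehot = np.eye(3)[rng.integers(0, 3, size=20)]
w1, w2 = init_weights(4, 5, rng), init_weights(5, 3, rng)

loss_before = cross_entropy(feedforward(w1, w2, x)[1], y_onehot)
for _ in range(200):
    w1, w2 = backprop_step(w1, w2, x, y_onehot)
hidden, output = feedforward(w1, w2, x)
loss_after = cross_entropy(output, y_onehot)
predicted = np.argmax(output, axis=1)  # most probable class per example
print(loss_before, loss_after)
```

The predicted digit for an example is the class with the largest output, which is what `np.argmax` picks in the last step.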
3. The neural network uses the perceptron as its building block, with the sigmoid function as the activation function. It has one hidden layer with 50 nodes and uses a regularization parameter to reduce overfitting. Both the number of nodes in the hidden layer and the regularization parameter can be varied to observe their effects on training time and accuracy.
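To make the role of the regularization parameter concrete, the sketch below assumes an L2 (squared-weight) penalty added to the error, which is a common choice; the README does not state which form nnScript.py uses. The names `l2_penalty`, `l2_penalty_gradients`, and `lambda_` are hypothetical, and for simplicity the bias weights are penalized as well. The weight shapes follow the architecture described above: 784 inputs, 50 hidden nodes, 10 outputs.

```python
import numpy as np

N_INPUT, N_HIDDEN, N_CLASSES = 784, 50, 10


def l2_penalty(w1, w2, lambda_, n_examples):
    """Term added to the error: large weights cost more as lambda_ grows."""
    return lambda_ / (2.0 * n_examples) * (np.sum(w1 ** 2) + np.sum(w2 ** 2))


def l2_penalty_gradients(w1, w2, lambda_, n_examples):
    """Term added to each weight gradient; it pulls every weight towards zero."""
    return lambda_ * w1 / n_examples, lambda_ * w2 / n_examples


rng = np.random.default_rng(0)
w1 = rng.uniform(-0.1, 0.1, size=(N_HIDDEN, N_INPUT + 1))    # input -> hidden (+ bias)
w2 = rng.uniform(-0.1, 0.1, size=(N_CLASSES, N_HIDDEN + 1))  # hidden -> output (+ bias)

for lambda_ in (0.0, 0.1, 1.0):
    print(lambda_, l2_penalty(w1, w2, lambda_, n_examples=50000))
```

With `lambda_` set to 0 the penalty vanishes and training fits the data alone; larger values trade some training accuracy for smaller weights.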
4. The script consists of a number of functions for tasks such as preprocessing, prediction, and other calculations. Comments in the script explain each function, along with its inputs and outputs.
After making any desired modifications, the script can be executed from the command line as shown below.
python nnScript.py