MNIST 数据集 是一个手写数字组成的数据集,现在被当作一个机器学习算法评测的基准数据集。
这是一个下载并解压数据的脚本:
In [1]:
%%file download_mnist.py
import os
import os.path
import urllib
import gzip
import shutil
if not os.path.exists('mnist'):
os.mkdir('mnist')
def download_and_gzip(name):
if not os.path.exists(name + '.gz'):
urllib.urlretrieve('http://yann.lecun.com/exdb/' + name + '.gz', name + '.gz')
if not os.path.exists(name):
with gzip.open(name + '.gz', 'rb') as f_in, open(name, 'wb') as f_out:
shutil.copyfileobj(f_in, f_out)
download_and_gzip('mnist/train-images-idx3-ubyte')
download_and_gzip('mnist/train-labels-idx1-ubyte')
download_and_gzip('mnist/t10k-images-idx3-ubyte')
download_and_gzip('mnist/t10k-labels-idx1-ubyte')
Overwriting download_mnist.py
可以运行这个脚本来下载和解压数据:
In [2]:
%run download_mnist.py
使用如下的脚本来导入 MNIST 数据,源码地址:
https://github.com/Newmu/Theano-Tutorials/blob/master/load.py
In [3]:
%%file load.py
import numpy as np
import os
datasets_dir = './'
def one_hot(x,n):
if type(x) == list:
x = np.array(x)
x = x.flatten()
o_h = np.zeros((len(x),n))
o_h[np.arange(len(x)),x] = 1
return o_h
def mnist(ntrain=60000,ntest=10000,onehot=True):
data_dir = os.path.join(datasets_dir,'mnist/')
fd = open(os.path.join(data_dir,'train-images-idx3-ubyte'))
loaded = np.fromfile(file=fd,dtype=np.uint8)
trX = loaded[16:].reshape((60000,28*28)).astype(float)
fd = open(os.path.join(data_dir,'train-labels-idx1-ubyte'))
loaded = np.fromfile(file=fd,dtype=np.uint8)
trY = loaded[8:].reshape((60000))
fd = open(os.path.join(data_dir,'t10k-images-idx3-ubyte'))
loaded = np.fromfile(file=fd,dtype=np.uint8)
teX = loaded[16:].reshape((10000,28*28)).astype(float)
fd = open(os.path.join(data_dir,'t10k-labels-idx1-ubyte'))
loaded = np.fromfile(file=fd,dtype=np.uint8)
teY = loaded[8:].reshape((10000))
trX = trX/255.
teX = teX/255.
trX = trX[:ntrain]
trY = trY[:ntrain]
teX = teX[:ntest]
teY = teY[:ntest]
if onehot:
trY = one_hot(trY, 10)
teY = one_hot(teY, 10)
else:
trY = np.asarray(trY)
teY = np.asarray(teY)
return trX,teX,trY,teY
Overwriting load.py
Softmax
回归相当于 Logistic
回归的一个一般化,Logistic
回归处理的是两类问题,Softmax
回归处理的是 N
类问题。
Logistic
回归输出的是标签为 1 的概率(标签为 0 的概率也就知道了),对应地,对 N 类问题 Softmax
输出的是每个类对应的概率。
具体的内容,可以参考 UFLDL
教程:
http://ufldl.stanford.edu/wiki/index.php/Softmax%E5%9B%9E%E5%BD%92
In [4]:
import theano
from theano import tensor as T
import numpy as np
from load import mnist
Using gpu device 1: Tesla C2075 (CNMeM is disabled)
我们来看它具体的实现。
这两个函数一个是将数据转化为 GPU
计算的类型,另一个是初始化权重:
In [5]:
def floatX(X):
return np.asarray(X, dtype=theano.config.floatX)
def init_weights(shape):
return theano.shared(floatX(np.random.randn(*shape) * 0.01))
Softmax
的模型在 theano
中已经实现好了:
In [6]:
A = T.matrix()
B = T.nnet.softmax(A)
test_softmax = theano.function([A], B)
a = floatX(np.random.rand(3, 4))
b = test_softmax(a)
print b.shape
# 行和
print b.sum(1)
(3, 4)
[ 1.00000012 1\. 1\. ]
softmax
函数会按照行对矩阵进行 Softmax
归一化。
所以我们的模型为:
In [7]:
def model(X, w):
return T.nnet.softmax(T.dot(X, w))
导入数据:
In [8]:
trX, teX, trY, teY = mnist(onehot=True)
定义变量,并初始化权重:
In [9]:
X = T.fmatrix()
Y = T.fmatrix()
w = init_weights((784, 10))
定义模型输出和预测:
In [10]:
py_x = model(X, w)
y_pred = T.argmax(py_x, axis=1)
损失函数为多类的交叉熵,这个在 theano
中也被定义好了:
In [11]:
cost = T.mean(T.nnet.categorical_crossentropy(py_x, Y))
gradient = T.grad(cost=cost, wrt=w)
update = [[w, w - gradient * 0.05]]
编译 train
和 predict
函数:
In [12]:
train = theano.function(inputs=[X, Y], outputs=cost, updates=update, allow_input_downcast=True)
predict = theano.function(inputs=[X], outputs=y_pred, allow_input_downcast=True)
迭代 100 次,测试集正确率为 0.925:
In [13]:
for i in range(100):
for start, end in zip(range(0, len(trX), 128), range(128, len(trX), 128)):
cost = train(trX[start:end], trY[start:end])
print "{0:03d}".format(i), np.mean(np.argmax(teY, axis=1) == predict(teX))
000 0.8862
001 0.8985
002 0.9042
003 0.9084
004 0.9104
005 0.9121
006 0.9121
007 0.9142
008 0.9158
009 0.9163
010 0.9162
011 0.9166
012 0.9171
013 0.9176
014 0.9182
015 0.9182
016 0.9184
017 0.9188
018 0.919
019 0.919
020 0.9194
021 0.9201
022 0.9204
023 0.9203
024 0.9205
025 0.9207
026 0.9207
027 0.9209
028 0.9214
029 0.9213
030 0.9212
031 0.9211
032 0.9217
033 0.9217
034 0.9217
035 0.922
036 0.9222
037 0.922
038 0.922
039 0.9218
040 0.9219
041 0.9223
042 0.9225
043 0.9226
044 0.9227
045 0.9225
046 0.9227
047 0.9231
048 0.9231
049 0.9231
050 0.9232
051 0.9232
052 0.9231
053 0.9231
054 0.9233
055 0.9233
056 0.9237
057 0.9239
058 0.9239
059 0.9239
060 0.924
061 0.9242
062 0.9242
063 0.9243
064 0.9243
065 0.9244
066 0.9244
067 0.9244
068 0.9245
069 0.9244
070 0.9244
071 0.9245
072 0.9244
073 0.9243
074 0.9243
075 0.9244
076 0.9243
077 0.9242
078 0.9244
079 0.9244
080 0.9243
081 0.9242
082 0.9239
083 0.9241
084 0.9242
085 0.9243
086 0.9244
087 0.9243
088 0.9243
089 0.9244
090 0.9246
091 0.9246
092 0.9246
093 0.9247
094 0.9246
095 0.9246
096 0.9246
097 0.9246
098 0.9246
099 0.9248