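I. Introduction to Keras
Keras is a high-level neural network API written in pure Python. It uses TensorFlow as its default computation backend and is well suited to quickly prototyping a deep learning project.
Keras currently supports Python 2.7 through 3.6. In practice I found that the TensorFlow (TF below) version matters too, and this is a real pitfall: when I wrote this demo the latest TF release was 1.4, and with that release installed, importing Keras may keep failing with errors. Downgrading TF to 1.3 resolved the incompatibility for me.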
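Because the problem above comes down to which versions are installed, it can help to print them before debugging anything else. The following is a minimal sketch and not part of the original demo: the helper name installed_version is hypothetical, and it assumes Python 3.8 or newer for importlib.metadata, which is newer than the interpreter used in this post.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if it is absent.

    Hypothetical helper for illustration; it reads package metadata only and
    does not import the package itself.
    """
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Report the two packages whose combination mattered in this post.
for name in ("tensorflow", "keras"):
    print(name, installed_version(name) or "not installed")
```

If the reported pair is not one known to work (TF 1.3 in my case), pinning the version at install time, for example with pip install tensorflow==1.3.0, is one way to get back to it.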
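II. MNIST Handwritten Digit Recognition
MNIST handwritten digit recognition can be thought of as the Hello World of deep learning. MNIST is a dataset of handwritten digits with 60,000 training examples and 10,000 test examples.
More details are available on the official site: http://yann.lecun.com/exdb/mnist/
This post covers:
- How to load the data
- Building a neural network classifier
- An explanation of some of the method parameters used in the Keras implementation
III. Classification with a Neural Network
1. Loading the data
In Keras the data is loaded with mnist.load_data(). Unfortunately, in most cases calling it gave me the result below, where the download fails during the SSL handshake; this is most likely related to the network environment.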
Downloading data from https://s3.amazonaws.com/img-datasets/mnist.npz
Traceback (most recent call last):
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/urllib/request.py", line 1318, in do_open
encode_chunked=req.has_header('Transfer-encoding'))
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/http/client.py", line 1239, in request
self._send_request(method, url, body, headers, encode_chunked)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/http/client.py", line 1285, in _send_request
self.endheaders(body, encode_chunked=encode_chunked)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/http/client.py", line 1234, in endheaders
self._send_output(message_body, encode_chunked=encode_chunked)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/http/client.py", line 1026, in _send_output
self.send(msg)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/http/client.py", line 964, in send
self.connect()
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/http/client.py", line 1400, in connect
server_hostname=server_hostname)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/ssl.py", line 401, in wrap_socket
_context=self, _session=session)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/ssl.py", line 808, in __init__
self.do_handshake()
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/ssl.py", line 1061, in do_handshake
self._sslobj.do_handshake()
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/ssl.py", line 683, in do_handshake
self._sslobj.do_handshake()
ssl.SSLError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed (_ssl.c:749)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/utils/data_utils.py", line 221, in get_file
urlretrieve(origin, fpath, dl_progress)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/urllib/request.py", line 248, in urlretrieve
with contextlib.closing(urlopen(url, data)) as fp:
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/urllib/request.py", line 223, in urlopen
return opener.open(url, data, timeout)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/urllib/request.py", line 526, in open
response = self._open(req, data)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/urllib/request.py", line 544, in _open
'_open', req)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/urllib/request.py", line 504, in _call_chain
result = func(*args)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/urllib/request.py", line 1361, in https_open
context=self._context, check_hostname=self._check_hostname)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/urllib/request.py", line 1320, in do_open
raise URLError(err)
urllib.error.URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed (_ssl.c:749)>
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/Users/xhades/Documents/github/PythonEngineer/mnist/kerasmnist.py", line 69, in <module>
mnist.load_data()
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/datasets/mnist.py", line 17, in load_data
file_hash='8a61469f7ea1b51cbae51d4f78837e45')
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/utils/data_utils.py", line 223, in get_file
raise Exception(error_msg.format(origin, e.errno, e.reason))
Exception: URL fetch failure on https://s3.amazonaws.com/img-datasets/mnist.npz: None -- [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed (_ssl.c:749)
Process finished with exit code 1
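The last line of the traceback is a CERTIFICATE_VERIFY_FAILED error, which often means Python could not find a trusted CA bundle rather than that the network is down. One possible way around it, sketched here under assumptions, is to download the file yourself with an explicitly configured SSL context that still verifies certificates. The function name fetch and its cafile parameter are hypothetical; pointing cafile at a CA bundle (for example the path returned by the third-party certifi package) is an assumption about your setup, not something the original demo did.

```python
import shutil
import ssl
import urllib.request


def fetch(url, dest, cafile=None):
    """Download `url` to `dest` using a certificate-verifying SSL context.

    cafile (hypothetical parameter): optional path to a PEM CA bundle;
    when None, the system default trust store is used.
    """
    context = ssl.create_default_context(cafile=cafile)
    with urllib.request.urlopen(url, context=context) as response:
        with open(dest, "wb") as out:
            shutil.copyfileobj(response, out)
    return dest


# Example call (needs network access and an existing MNIST_data directory,
# so it is left commented out):
# fetch("https://s3.amazonaws.com/img-datasets/mnist.npz", "MNIST_data/mnist.npz")
```

On a python.org build of Python for macOS, like the one in the traceback paths, running the Install Certificates.command script that ships with the installer is another commonly suggested fix.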
First, let's see what load_data() looks like:
def load_data(path='mnist.npz'):
    """Loads the MNIST dataset.
    # Arguments
        path: path where to cache the dataset locally
            (relative to ~/.keras/datasets).
    # Returns
        Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
    """
    path = get_file(path,
                    origin='https://s3.amazonaws.com/img-datasets/mnist.npz',
                    file_hash='8a61469f7ea1b51cbae51d4f78837e45')
    f = np.load(path)
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']
    f.close()
    return (x_train, y_train), (x_test, y_test)
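Opening it shows that it in turn calls get_file(). I won't explain that function in detail here; its main job is to check whether the file already exists under the cache path and to download it if it does not.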
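The check-then-download behaviour described above can be sketched with the standard library alone. This is a simplified illustration, not Keras's actual get_file implementation: the names md5_of and fetch_if_missing are hypothetical, and the MD5 comparison is only an assumption about how a 32-character file_hash like the one passed above might be used.

```python
import hashlib
import os
import urllib.request


def md5_of(path, chunk_size=65536):
    """Return the hex MD5 digest of the file at `path`, read in chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def fetch_if_missing(url, cache_path, expected_md5=None):
    """Return `cache_path`, downloading `url` only when no valid copy exists."""
    if os.path.exists(cache_path):
        if expected_md5 is None or md5_of(cache_path) == expected_md5:
            return cache_path  # cached copy is usable, skip the download
    os.makedirs(os.path.dirname(cache_path) or ".", exist_ok=True)
    urllib.request.urlretrieve(url, cache_path)
    if expected_md5 is not None and md5_of(cache_path) != expected_md5:
        raise ValueError("checksum mismatch for " + cache_path)
    return cache_path
```

The practical consequence is that once a valid copy of the file sits at the expected location, no network request is needed at all.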
With that in mind, I downloaded the dataset manually and rewrote load_data() to read it from a local path:
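import numpy as np

# The built-in load_data() failed on every attempt, so after downloading
# the data by hand I load it with this custom function instead.
def load_data(path="MNIST_data/mnist.npz"):
    # np.load on an .npz archive returns a dict-like object; the context
    # manager closes it once the arrays have been read.
    with np.load(path) as f:
        x_train, y_train = f['x_train'], f['y_train']
        x_test, y_test = f['x_test'], f['y_test']
    return (x_train, y_train), (x_test, y_test)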
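To see how the .npz layout read by the function above works without the real download, here is a small round trip with stand-in arrays. It is only a sketch: the array contents, the uint8 dtype and the file name mnist_stand_in.npz are made up for illustration; only the four key names come from the code above.

```python
import os
import tempfile

import numpy as np

# Tiny stand-in arrays; the real archive holds 60000 training and
# 10000 test images of 28*28 pixels each.
x_train = np.zeros((6, 28, 28), dtype=np.uint8)
y_train = np.arange(6, dtype=np.uint8)
x_test = np.zeros((2, 28, 28), dtype=np.uint8)
y_test = np.array([7, 2], dtype=np.uint8)

# Save the arrays under the same four keys that load_data() reads.
path = os.path.join(tempfile.mkdtemp(), "mnist_stand_in.npz")
np.savez(path, x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test)

# Read them back the same way load_data() does.
with np.load(path) as archive:
    keys = sorted(archive.files)
    loaded_shape = archive["x_train"].shape

print(keys)          # ['x_test', 'x_train', 'y_test', 'y_train']
print(loaded_shape)  # (6, 28, 28)
```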
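2. Building the Sequential model
from keras.models import Sequential
from keras.layers import Dense, Activation, Dropout
from keras.optimizers import SGD

# Build and compile the Sequential model
def train():
    model = Sequential()
    model.add(Dense(500, input_shape=(784,)))  # first hidden layer: 28*28=784 inputs, 500 units
    model.add(Activation('tanh'))
    model.add(Dropout(0.3))  # 30% dropout
    model.add(Dense(300))  # second hidden layer, 300 units
    model.add(Activation('tanh'))
    model.add(Dropout(0.3))  # 30% dropout
    model.add(Dense(10))  # output layer, one unit per digit class
    model.add(Activation('softmax'))
    # Compile the model; metrics=['accuracy'] makes fit/evaluate report accuracy as well as loss
    sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
    model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
    return model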
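Keras's Sequential model is a single-input, single-output stack of layers: one path from start to finish, where each layer is connected only to its neighbours and there are no cross-layer connections. Such models compile quickly and are simple to work with. Keras 0.x also had a Graph model; it was removed in Keras 1.0, and models that need branches or multiple inputs and outputs are built with the functional API (Model) instead.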
Dense is the standard fully connected layer. 500 is the layer's output dimension, and 784 is the number of input pixels, i.e. 28*28.
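To make the shapes concrete, a fully connected layer can be written in a few lines of NumPy. This is an illustrative sketch rather than Keras's implementation; the function names and the random initial weights are assumptions, while the layer sizes 784, 500, 300 and 10 come from the model above.

```python
import numpy as np


def dense_forward(x, weights, bias):
    """Fully connected layer without activation: x @ W + b."""
    return x @ weights + bias


def dense_param_count(n_in, n_out):
    """One weight per input/output pair plus one bias per output unit."""
    return n_in * n_out + n_out


rng = np.random.default_rng(0)
batch = rng.random((4, 784))  # 4 flattened 28*28 images
weights = rng.normal(scale=0.05, size=(784, 500))
bias = np.zeros(500)

hidden = dense_forward(batch, weights, bias)
print(hidden.shape)  # (4, 500)

# Trainable parameters across the three Dense layers of the model.
layer_sizes = [784, 500, 300, 10]
total = sum(dense_param_count(a, b) for a, b in zip(layer_sizes, layer_sizes[1:]))
print(total)  # 545810
```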
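The Dropout layer applies dropout to its input: during training, at each parameter update it randomly disconnects input units with a given probability (rate). It is used to prevent overfitting.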
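The idea can be sketched in NumPy as follows. The function is hypothetical, and rescaling the surviving units by 1/(1 - rate), the common "inverted dropout" formulation, is an assumption about the backend rather than something stated in this post.

```python
import numpy as np


def dropout(x, rate, rng, training=True):
    """Zero each element with probability `rate`; rescale survivors by 1/(1 - rate)."""
    if not training or rate == 0.0:
        return x  # dropout is a no-op at test time
    keep_prob = 1.0 - rate
    mask = rng.random(x.shape) < keep_prob
    return x * mask / keep_prob


rng = np.random.default_rng(42)
activations = np.ones((1000, 500))
dropped = dropout(activations, rate=0.3, rng=rng)

print((dropped == 0).mean())  # fraction of zeroed units, close to 0.3
print(dropped.mean())         # close to 1.0, because survivors are scaled up
```

Thanks to the rescaling, the expected value of each activation is the same with and without dropout, so nothing has to change at test time.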
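The Activation layers here use tanh as the activation function. Dense also accepts an activation argument; if it is not specified, no activation function is applied (that is, the linear activation a(x) = x is used).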
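Why a non-linear activation matters can be checked numerically: two stacked layers with the linear activation a(x) = x are equivalent to a single linear layer, while putting tanh between them breaks that equivalence. The sketch below uses small made-up shapes and random weights purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=(5, 8))
w1, b1 = rng.normal(size=(8, 6)), rng.normal(size=6)
w2, b2 = rng.normal(size=(6, 3)), rng.normal(size=3)

# Two layers with linear activation ...
two_linear_layers = (x @ w1 + b1) @ w2 + b2
# ... equal one layer with merged weights and bias.
one_linear_layer = x @ (w1 @ w2) + (b1 @ w2 + b2)
collapses = bool(np.allclose(two_linear_layers, one_linear_layer))

# With tanh in between, the merged layer no longer reproduces the output.
with_tanh = np.tanh(x @ w1 + b1) @ w2 + b2
differs = not np.allclose(with_tanh, one_linear_layer)

print(collapses, differs)
```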
Finally, the softmax function converts the network's outputs into a probability for each label.
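As a rough illustration of that last step, softmax can be written in NumPy as below. Subtracting the row maximum before exponentiating is a standard numerical-stability trick and an addition of this sketch; the example scores are made up.

```python
import numpy as np


def softmax(logits):
    """Turn each row of raw scores into probabilities that sum to 1."""
    shifted = logits - logits.max(axis=-1, keepdims=True)  # avoids overflow in exp
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


scores = np.array([[1.0, 2.0, 3.0],
                   [1000.0, 1000.0, 1000.0]])
probs = softmax(scores)

print(probs.sum(axis=1))     # each row sums to 1
print(probs.argmax(axis=1))  # index of the most likely label per row
```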
3. Training and testing accuracy
def run():
    (x_train, y_train), (x_test, y_test) = load_data()
    # Flatten each 28*28 image into a 784-dimensional vector
    X_train = x_train.reshape(x_train.shape[0], x_train.shape[1] * x_train.shape[2])
    X_test = x_test.reshape(x_test.shape[0], x_test.shape[1] * x_test.shape[2])
    # One-hot encode the labels 0-9
    Y_train = (np.arange(10) == y_train[:, None]).astype(int)
    Y_test = (np.arange(10) == y_test[:, None]).astype(int)
    model = train()
    model.fit(X_train, Y_train, batch_size=200, epochs=10, shuffle=True, verbose=1, validation_split=0.3)
    print("Start Test.....\n")
    # evaluate() returns a bare loss, or [loss, metric, ...] when compile() was given metrics
    scores = np.atleast_1d(model.evaluate(X_test, Y_test, batch_size=200, verbose=1))
    print("The Test Loss: %f" % scores[0])
    if len(scores) > 1:
        print("The Test Accuracy: %f" % scores[1])
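The two preprocessing steps in run(), flattening the images and one-hot encoding the labels, are easy to check on small stand-in arrays. The arrays below are made up for illustration; the expressions themselves are the ones used in run().

```python
import numpy as np

# Two fake 28*28 "images" flattened the same way as in run().
images = np.arange(2 * 28 * 28).reshape(2, 28, 28)
flat = images.reshape(images.shape[0], images.shape[1] * images.shape[2])
print(flat.shape)  # (2, 784)

# Broadcasting compares every label against 0..9, giving one row per label
# with a single 1 at the label's index.
labels = np.array([3, 0, 9])
one_hot = (np.arange(10) == labels[:, None]).astype(int)
print(one_hot[0])  # [0 0 0 1 0 0 0 0 0 0]
```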
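Training is done by calling the fit method: batch_size=200 is the number of samples per gradient update, epochs=10 is the number of passes over the training data, and validation_split=0.3 holds out 30% of the training data for validation.
The evaluate method then measures performance on the test set, returning the loss together with any metrics specified in compile().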
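For the fit() settings used in run(), the sizes involved can be worked out by hand. The helper below is a hypothetical sketch of that arithmetic, not Keras code; as I understand the Keras documentation, the validation samples are taken from the end of the arrays before any shuffling.

```python
import math


def split_sizes(n_samples, validation_split, batch_size):
    """Return (training samples, validation samples, batches per epoch)."""
    n_val = int(round(n_samples * validation_split))
    n_train = n_samples - n_val
    steps_per_epoch = math.ceil(n_train / batch_size)
    return n_train, n_val, steps_per_epoch


# 60000 MNIST training images, validation_split=0.3, batch_size=200
print(split_sizes(60000, 0.3, 200))  # (42000, 18000, 210)
```

So each of the 10 epochs performs 210 gradient updates on 42000 images, and the remaining 18000 are used only to report validation loss.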