0 前言
要求:
- 安装keras库
- 使用gpu运行。使用gpu大概10分钟能跑完,cpu跑不太现实
- 没有gpu之解决方案,传送至Google colab使用之手把手教学
1 搭建模型
目标:
- 使用keras搭建简单的网络结构,用来预测mnist数据集中的手写数字
步骤:
- 从MNIST加载图片数据。
- 数据进行预处理
- 构建网络模型,模型结构如下图
- 模型结构可视化
- 模型性能评估以及预测
- 保存模型及参数,以便下次直接使用
在这里插入图片描述
1.1、加载数据
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.utils import np_utils
from keras.datasets import mnist
# 加载mnist数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# train中有6万张手写数字图片,test中有1万张手写数字图片
print (x_train.shape)
print (x_test.shape)
运行结果:
在这里插入图片描述
1.2、数据进行预处理
# 进行one-hot(独热编码)
# 进行one-hot编码是因为损失函数需要使用交叉嫡函数(cross_entropy)
# 交叉嫡函数详解 https://zhuanlan.zhihu.com/p/35709485
y_train = np_utils.to_categorical(y_train,10)
y_test = np_utils.to_categorical(y_test,10)
# 由于下载数据得到是uint类型,在神经网络无法进行合理运算,在这里将其转化为float32类型
x_train = x_train.astype(np.float32)
x_test = x_test.astype(np.float32)
# 由于conv2D函数需要这四维的图片数据,这里就reshape维度
x_train = x_train.reshape([-1,28,28,1])
x_test = x_test.reshape([-1,28,28,1])
np_utils.to_categorical(y_train,10),即进行独热编码,解释如下代码:
import matplotlib.pyplot as plt
from keras.utils import np_utils
# 以下说明to_categorical作用,相当于进行one-hot(独热编码)
print(np_utils.to_categorical([5,6,7],8))
运行结果:
在这里插入图片描述
1.3、构建网络模型
# 选择线性结构的网络模型
model = Sequential()
# 使用Conv2D函数的padding参数默认为‘valid’
# 默认使用valid
# For the VALID padding, the output height and width are computed as:
# out_height = ceil(float(in_height - filter_height + 1) / float(strides[1]))
# For the SAME padding, the output height and width are computed as:
# out_height = ceil(float(in_height) / float(strides[1]))
# 知道padding以上“valid”和“same”两种模式,就能知道通过一个卷积层后,
# 模型的输出结构了
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
print (model.output_shape)
model.add(Conv2D(32, (3, 3), activation='relu'))
print (model.output_shape)
model.add(MaxPooling2D(pool_size=(2,2)))
print (model.output_shape)
model.add(Dropout(0.25))
print (model.output_shape)
model.add(Flatten())
print (model.output_shape)
model.add(Dense(100,activation="relu"))
print (model.output_shape)
model.add(Dropout(0.5))
print (model.output_shape)
model.add(Dense(10,activation="softmax"))
print (model.output_shape)
# 模型的损失函数使用交叉嫡函数,优化器使用“sgd”,即随机梯度下降
# 梯度下降详解 https://baijiahao.baidu.com/s?id=1613121229156499765&wfr=spider&for=pc
model.compile(loss="categorical_crossentropy",optimizer="sgd",metrics="acc")
model.fit(x_train,y_train,batch_size=64,epochs=10)
运行结果:
在这里插入图片描述
1.4、模型可视化
from keras.utils import plot_model
plot_model(model,show_shapes=True)
print (model.summary())
运行结果:
在这里插入图片描述
在这里插入图片描述
1.5、模型评估及预测
- 1、模型评估:
# 使用测试集合检验模型性能
loss = model.evaluate(x_test,y_test)
print (loss)
运行结果:
在这里插入图片描述
- 2、模型预测
import numpy as np
import matplotlib.pyplot as plt
# 为了方便显示,这里就显示测试集中前两个预测结果
for i in range(2):
plt.imshow(x_test[i].reshape((28,28)))
plt.show()
print(x_test[i].shape)
# 这里预测接收的结构也是四维的,故reshape
y_pred = model.predict(x_test[i].reshape(1,28,28,1))
print(y_pred)
print(np.argmax(y_pred))
运行结果:
在这里插入图片描述
1.6、保存模型及参数
from datetime import datetime
# 暂时保存模型为h5文件
file_name = datetime.now().strftime("dentify_writtern_number_%Y%m%d_%H%M/epoch:10-loss:0.2525.h5")
model.save(file_name)
运行无结果。生成文件截图如下:
在这里插入图片描述
- keras 对h5文件读取非常简单,详见 keras之读取h5文件(三)
- 模型保存还能做迁移学习,详见 keras之迁移学习小demo(四)
- h5文件是存储神经网络模型及参数的常用文件后缀
- h5文件详解
2 源代码
# 1、加载mnist数据集
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.utils import np_utils
from keras.datasets import mnist
# 加载mnist数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# train中有6万张手写数字图片,test中有1万张手写数字图片
print (x_train.shape)
print (x_test.shape)
# 2、数据预处理
# 进行one-hot(独热编码)
# 进行one-hot编码是因为损失函数需要使用交叉嫡函数(cross_entropy)
# 交叉嫡函数详解 https://zhuanlan.zhihu.com/p/35709485
y_train = np_utils.to_categorical(y_train,10)
y_test = np_utils.to_categorical(y_test,10)
# 由于下载数据得到是uint类型,在神经网络无法进行合理运算,在这里将其转化为float32类型
x_train = x_train.astype(np.float32)
x_test = x_test.astype(np.float32)
# 由于conv2D函数需要这四维的图片数据,这里就reshape维度
x_train = x_train.reshape([-1,28,28,1])
x_test = x_test.reshape([-1,28,28,1])
import matplotlib.pyplot as plt
from keras.utils import np_utils
# 以下说明to_categorical作用,相当于进行one-hot(独热编码)
print(np_utils.to_categorical([5,6,7],8))
# 3、构建网络结构
# 选择线性结构的网络模型
model = Sequential()
# 使用Conv2D函数的padding参数默认为‘valid’
# 默认使用valid
# For the VALID padding, the output height and width are computed as:
# out_height = ceil(float(in_height - filter_height + 1) / float(strides[1]))
# For the SAME padding, the output height and width are computed as:
# out_height = ceil(float(in_height) / float(strides[1]))
# 知道padding以上“valid”和“same”两种模式,就能知道通过一个卷积层后,
# 模型的输出结构了
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
print (model.output_shape)
model.add(Conv2D(32, (3, 3), activation='relu'))
print (model.output_shape)
model.add(MaxPooling2D(pool_size=(2,2)))
print (model.output_shape)
model.add(Dropout(0.25))
print (model.output_shape)
model.add(Flatten())
print (model.output_shape)
model.add(Dense(100,activation="relu"))
print (model.output_shape)
model.add(Dropout(0.5))
print (model.output_shape)
model.add(Dense(10,activation="softmax"))
print (model.output_shape)
# 模型的损失函数使用交叉嫡函数,优化器使用“sgd”,即随机梯度下降
# 梯度下降详解 https://baijiahao.baidu.com/s?id=1613121229156499765&wfr=spider&for=pc
model.compile(loss="categorical_crossentropy",optimizer="sgd",metrics="acc")
model.fit(x_train,y_train,batch_size=64,epochs=10)
# 4、模型可视化
from keras.utils import plot_model
plot_model(model,show_shapes=True)
print (model.summary())
# 5、模型评估及预测
# 使用测试集合检验模型性能
loss = model.evaluate(x_test,y_test)
print (loss)
import numpy as np
import matplotlib.pyplot as plt
# 为了方便显示,这里就显示测试集中前两个预测结果
for i in range(2):
plt.imshow(x_test[i].reshape((28,28)))
plt.show()
print(x_test[i].shape)
# 这里预测接收的结构也是四维的,故reshape
y_pred = model.predict(x_test[i].reshape(1,28,28,1))
print(y_pred)
print(np.argmax(y_pred))
# 6、模型保存
from datetime import datetime
# 暂时保存模型为h5文件
file_name = datetime.now().strftime("dentify_writtern_number_%Y%m%d_%H%M/epoch:10-loss:0.2525.h5")
model.save(file_name)
如有疑惑,以下评论区留言。力所能及,必答之。