今天来看下CNN的具体调用流程,这里我们以图像分类的例子说明
CNN结构
其中:
mnist:获取数据。
to_categorical:进行one-hot编码。
后面的几个就是创建模型所需要的库。
这里我们重点说下模型构建:
Sequential:构建序列化容器,后面的所有层其实就都追加在该序列后了。
add:追加网络层。
Conv2D:添加2维卷积层,其中第一层32指的是32个滤波器,每个滤波器大小是(5,5),激活函数使用relu,input_shap需要指定图像尺寸,这里是高28、宽28、1通道图像。
MaxPool2D:池化层。
Flatten:多维数据转换为一维,多用于卷积到全连接层中间。
Dense:全连接层。
Dropout:降低过拟合风险。
模型训练的两个重要参数:
batch_size:训练多少样本更新一次参数,这里是每次训练100个样本更新一次权重参数。
epochs:数据集被用来训练多少轮,需要寻优,过多容易过拟合。
对于模型参数选择,需要看训练集和测试集的指标变化趋势,从指标变化图来确定最优参数组合。