CNN在mnist数据集上实现

这次我们使用CNN中最经典的Lenet网络在mnist数据集上进行训练和预测。

  • 卷积NN
    主要有两部分组成,一部分是对输入图片特征提取,一部分是全连接网络,主要组成操作包括卷积、池化、激活等。

  • Lenet网络模型
    Lenet是提出比较早,能有效解决手写数字图片识别的卷积模型,模型结构如下:


    0.PNG

其中,padding=valid代表非全0填充,输出图片尺寸=(输入尺寸-卷积核尺寸+1)/步长;padding=same代表全0填充,输出尺寸=输入尺寸/步长;pooling不改变深度。
对Lenet进行调整使其使用于mnist数据集,结构如下:


Lenet_on_mnist.PNG

实现还是分三模块:forward,backwa,test,主要改变是在forward:


lenet1.png

定义获得权重、偏执,增加对卷积,池化的函数。
lenet2.png

按上层结构前向传播,返回预测值。

backward和test跟上一篇中改动不大,主要是要注意输入的大小:

leb1.png

输入占位大小改变
leb2.png

喂入的barch_size大小改变
同理,在test文件中,测试数据的大小也相应改变。


新手学习,欢迎指教!

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容