参考:https://www.jianshu.com/p/6e22d21c84be
在实现一个深度学习项目时,除了搭建模型的网络结构,
更重要的一点是处理好项目需要的数据集,数据集主要有几个来源:一是经典的、框架封装好的数据集,可以在代码中直接下载使用;二是框架没有封装过的数据集,包括他人提供的数据集和自己建立的数据集,可以成为本地数据集。
我们知道pytorch有dataset和dataloader两个封装实体,他们是处理包装数据集作为输入的有力武器,目前pytorch(0.4)有两种思路去将我们的本地数据集封装进去:
一是建立自己的dataset数据集类,必须继承pyorch.utils.data.dataset实体类,并且重写它的几个方法,这点可以查看官方文档以及参考https://www.jianshu.com/p/6e22d21c84be。这个方法需要注意的是重写的方法他们的返回值类型,其中getitem方法返回的是{'image':image, 'label':label}的字典类型。利用dataloader封装我们自己实现的数据集实体,便可以加载作为输入。
二是利用pytorch.utils.data.tensordata类,这个封装类要求输入tensor类型的数据,它便可以替我们封装本地数据集为dataset。所以我们需要手动将数据集的图片和标签读取并转换为tensor数据,可以使用循环实现。
下面说说遇到的一些问题:
读取图片时用不同的Python模块,比如PIL,scipy,获得的数据类型是不一样的,需要注意方法一中重写transform时,要求输入的图片是PIL image类型。根据网络输入的size可以使用transforms改变图片的规格。
第三方的数据集标签可能会使用各种格式的文件存储,txt,mat,csv,h5,等等,这些都有对用的模块可以读取操作。需要注意的是读取的数据最好打印shape看看是否有多余的维度,这是经常出现的坑。
另外注意,要认真了解数据集的组成,比如stanford cars,里面竟然既有单通道图像也有3通道图像,,这真的把人坑惨了。防止处理数据和建立网络模型时参数匹配不上。
还有一个大坑,就是在多分类模型中,softmax输出之后的类别是从0开始的,比如10分类问题,网络输出范围是0-9,,然而一些第三方数据集提供的标签是从1开始的,即1-10,,,这就会在运行时报显卡cuda的异常。