pytorch自定义数据集的读取

写此文原因

网上其实有不少关于pytorch自定义数据集的tutorial,但是之所以要写这个,是因为我发现他们并没有结合一两个的神经网络来讲解。所以我觉得再写一个tutorial讲解关于如何读取任意的数据集,并且让某个网络训练该数据集还是有必要的。
在初学pytorch的时候,我们一般使用的是pytorch自带的一些数据集,比如 (代码参考1)

from torchvision.datasets.mnist import MNIST
...
data_train = MNIST('./data/mnist',
                   download=True,
                   transform=transforms.Compose([
                       transforms.Resize((32, 32)),
                       transforms.ToTensor()]))
....
data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=8)

引入MNIST数据集。最初始的训练网络是Lenet-5识别MNIST里面的数字。这就导致当你面对很多JPG, PNG的格式的torchvision.datasets里没有的图像时,不知道怎么读取他们。这篇文章会带领大家读取自定义的数据集并训练他们。
最后的lenet5代码自定义数据集的实现请在我的github下载
https://github.com/zhaozhongch/Pytorch_Lenet5_CustomDataset

内容

下面我们从网上下载PNG格式的MNIST数据集。

git clone https://github.com/myleott/mnist_png.git
cd mnist_png
tar -xvf mnist_png.tar.gz #解压文件夹

解压之后在minst_png/mnist_png文件夹里你会看到testingtraining两个文件夹,进入testing你会看到10个文件夹分别储存数字为0~9的图片。下面我们简单实现Lenet-5网络来识别图片中的数字。
Lenet5网络如下图

lenet5.png

途中范例给的输入图片是32X32,实际我们上面的PNG图片大小是28X28,网络其他结构依次减小即可。
输入图片1通道28X28,输入给第一层
第一层卷积层,卷积核大小5X5,输出图像6通道,24X24,卷积之后接激励函数ReLU
第二层池化层,使用平均池化,池化核大小2X2,输出图像6通道,12X12
第三层卷积层,卷积核大小还是5X5,输出图像16通道,大小8X8。之后再接ReLu
第四层再接2X2池化。输出16通道,4X4大小图片。
第五层全连接层,先把16X4X4的图片"展平"为线性向量,再通过线性变换把图片"展平"为120维的变量,接ReLu
第六层再把120维降为84维,接ReLu
第七层再降为10维(对应0~9 10种数字可能性)输出。
讲解Lenet5并不是本文的重点,所以简单的说了上面的网络结构后我们就给出网络实现,对于细节不熟悉的新手可以参考文章1
根据上面的网络结构,网络在pytorch中的实现如下

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1,6,5)
        self.pool = nn.AvgPool2d(2,2)
        self.conv2 = nn.Conv2d(6,16,5)
        self.linear1 = nn.Linear(16*4*4, 120)
        self.linear2 = nn.Linear(120,84)
        self.linear3 = nn.Linear(84,10)
    
    def forward(self,x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool(x)
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16*4*4)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = self.linear3(x)
        return x

理论上来说是很简单的。
那么针对网络对输入数据的要求,我们应该怎么把最开始下载的一堆图片输入进去呢?这就要用到pytorch里的Dataset类了。
你需要定义一个类,继承Dataset类,然后类里必须包含3个函数__init____len____getitem__,具体结构如下

class ReadDataset(Dataset):
    def __init__(self, 参数...):

    def __len__(self, 参数...):
        ...
        return 数据长度

    def __getitem__(self, 参数...):
        ...
        return 字典

__len__需要返回一个表示数据长度的整型量,__getitem__需要返回一个字典。ReadDataset这个类名是自定义的,继承了Dataset即可。
接下来的过程,我们先简单过一遍得到结果,再回看为什么这么做。
为了处理MNIST dataset,我们先把training文件夹里的图像label读取进来

data_length = 60000
data_label = [-1] * data_length
prev_dir = './mnist_png/mnist_png/training/'
after_dir = '.png'

for id in range(10):
    id_string = str(id)
    for filename in glob(prev_dir + id_string +'/*.png'):
        position = filename.replace(prev_dir+id_string+'/', '')
        position = position.replace(after_dir, '')
        data_label[int(position)] = id

这几行代码的作用,是把training文件夹里的10个文件夹里的共计60000张图片放入到data_label里。举个例子,图片编号为21的图,包含的数字是0(在training文件夹的0文件夹里),那么data_label[21] = 0
接下来定义继承了Dataset类的ReadDataset类,具体如下。

class ReadDataset(Dataset):
    def __init__(self, imgs_dir, data_label):
        self.imgs_dir = imgs_dir
        self.ids = data_label

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i):
        idx = self.ids[i]
        imgs_file = self.imgs_dir+ str(idx) + '/' + str(i) + '.png'
        img = Image.open(imgs_file).convert('L')
        img = np.array(img)
        img = img.reshape(1,28,28)
        if img.max() > 1:
            img = img / 255
        return {'image': torch.from_numpy(img), 'label': torch.tensor(idx)}

可以看到,构造函数__init__里我们有两个参数,一个是imgs_dir,图像地址,另一个是我们之前创建的列表data_label,赋值给self.ids. __len__()仅仅是返回了data_label的长度。
有趣的是__getitem__函数,我们看到这个函数的参数是i,传入了i之后,我们首先根据ids找到它对应的图像里所标识的数字,继而根据

imgs_file = self.imgs_dir+ str(idx) + '/' + str(i) + '.png'
img = Image.open(imgs_file).convert('L')

找到图像并转化为黑白。之后再转化为np,再reshape。原图像读进来本来是28X28,但是根据网络的要求,输入需要是图像通道数X图像尺寸,黑白图片通道为1,所以我们reshape为1X28X28。最后图像的像素点的灰度值归一化到0到1.因为我们要使用cross entropy代价函数来训练,根据官网,要求cross entropy的矩阵输入的值为0到1。返回的内容格式必须是字典,我们这儿字典的内容图像和图像内对应的数字(label)是

{'image': torch.from_numpy(img), 'label': torch.tensor(idx)}

这个getitem函数如果调用,最终达到的目的就是,假如我在代码中输入A = __getitem__(0),我就应该能得到0.png对应的那张图像,获取图像的方式就是A['image'],获取图像是数字几的方式是A['label']

有了上面的内容作为铺垫,我们看看主函数里读取数据的具体操作。首先有下面一行内容

prev_dir = './mnist_png/mnist_png/testing/'
...
all_data = ReadDataset(prev_dir, data_label)

我们把prev_dir和之前得到的data_label作为参数传入了ReadDataset并返回了all_data。有的人可能说,诶,我没看到ReadDataset有返回值呀。这是因为这些写在了Dataset这个类里,不然继承它干什么呢。随后,我们把这个返回值赋值给DataLoader,就可以定义从torchvision里自带的MNIST dataset一样的操作了。

test_loader = DataLoader(all_data, batch_size=batch_size, shuffle=True, num_workers=4)

定义好batch_size,num_workers,代价函数这些之后,我们就可以在训练的时候使用返回值test_loader了。

with torch.no_grad():
    for data in test_loader:
        images = data['image']
        labels = data['label']
...

我们可以看到其实我们并没有显式地调用__getitem__函数,而是通过data遍历test_loader, data会自动根据ReadDataset里ids的长度,从1到ids.length来批量读取图像。如果你设置了batch_size等于4,那么for循环的第一次循环,会调用__getitem__四次,data['image']就会返回__getitem__,return {'image':...}image所对应的内容。
设置代价函数这些不是本文的内容,就不细讲了。具体的可参见github代码。
可能大家看了上面的例子还是有些不明不白,因为虽然ReadDataset这个类的内容就是定义三个函数,但是这三个函数具体的内容是什么,就需要根据实际情况确定了。我们上面的数据集的图像是分别储存在0~9个文件夹中,其他的数据可能不是这么储存的,就需要想新的办法获得那个data_label列表。但是你的最终目的是很明白的,
1:getitem所返回的内容,需要能输入到网络里,比如我们的

images = data['image']
...
outputs = net(images.float())

2: 根据0到ids的长度的indx,能遍历你想要使用的所有图像
假想你显式调用__getitem__(0) ,你需要能获得名字为0.png或者0.jpg之类的图像的内容。
说这些不如多看两个例子再自己实践一下。上面的lenet5的例子之外,我在github里分别分开写了CPU的方法和GPU的方法,当然其实就一两行代码的事儿。不过考虑到这还是属于接近新手范畴的tutorial,就分开写了。
另外我还在github代码里提供了稍微复杂的网络UNET的实现,UNET是用来做语义分割的网络,不熟悉的同学可以自行看下语义分割是什么blabla。在UNET的这个网络里,我同样是读取的自定义的数据集而不是使用torchvision.dataset里带的数据集。代码放于github
https://github.com/zhaozhongch/Pytorch_UNET_MultiObjects
当下用得最多的pytroch的UNET的实现还只是一个物体分割的(参考此处
),我顺便拔高了一下实现多个物体的语义分割了,不过最后语义分割的效果图不是非常好,因为懒得花时间去仔细fine tune了。但我相信作为tutorial级别的代码,我觉得跑一遍熟悉一下网络结构怎么定义,怎么自定义数据集,已经很够了。觉得还不错的可以给这俩仓库点个小star哈哈。
关于这两个网络实现或者其他内容不懂的同学欢迎私信。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 214,172评论 6 493
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,346评论 3 389
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 159,788评论 0 349
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,299评论 1 288
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,409评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,467评论 1 292
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,476评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,262评论 0 269
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,699评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,994评论 2 328
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,167评论 1 343
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,827评论 4 337
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,499评论 3 322
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,149评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,387评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,028评论 2 365
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,055评论 2 352