写此文原因
网上其实有不少关于pytorch自定义数据集的tutorial,但是之所以要写这个,是因为我发现他们并没有结合一两个的神经网络来讲解。所以我觉得再写一个tutorial讲解关于如何读取任意的数据集,并且让某个网络训练该数据集还是有必要的。
在初学pytorch的时候,我们一般使用的是pytorch自带的一些数据集,比如 (代码参考1)
from torchvision.datasets.mnist import MNIST
...
data_train = MNIST('./data/mnist',
download=True,
transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()]))
....
data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=8)
引入MNIST数据集。最初始的训练网络是Lenet-5识别MNIST里面的数字。这就导致当你面对很多JPG, PNG的格式的torchvision.datasets
里没有的图像时,不知道怎么读取他们。这篇文章会带领大家读取自定义的数据集并训练他们。
最后的lenet5代码自定义数据集的实现请在我的github下载
https://github.com/zhaozhongch/Pytorch_Lenet5_CustomDataset
内容
下面我们从网上下载PNG格式的MNIST数据集。
git clone https://github.com/myleott/mnist_png.git
cd mnist_png
tar -xvf mnist_png.tar.gz #解压文件夹
解压之后在minst_png/mnist_png
文件夹里你会看到testing
和training
两个文件夹,进入testing
你会看到10个文件夹分别储存数字为0~9的图片。下面我们简单实现Lenet-5网络来识别图片中的数字。
Lenet5网络如下图
途中范例给的输入图片是32X32,实际我们上面的PNG图片大小是28X28,网络其他结构依次减小即可。
输入图片1通道28X28,输入给第一层
第一层卷积层,卷积核大小5X5,输出图像6通道,24X24,卷积之后接激励函数ReLU
第二层池化层,使用平均池化,池化核大小2X2,输出图像6通道,12X12
第三层卷积层,卷积核大小还是5X5,输出图像16通道,大小8X8。之后再接ReLu
第四层再接2X2池化。输出16通道,4X4大小图片。
第五层全连接层,先把16X4X4的图片"展平"为线性向量,再通过线性变换把图片"展平"为120维的变量,接ReLu
第六层再把120维降为84维,接ReLu
第七层再降为10维(对应0~9 10种数字可能性)输出。
讲解Lenet5并不是本文的重点,所以简单的说了上面的网络结构后我们就给出网络实现,对于细节不熟悉的新手可以参考文章1。
根据上面的网络结构,网络在pytorch中的实现如下
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1,6,5)
self.pool = nn.AvgPool2d(2,2)
self.conv2 = nn.Conv2d(6,16,5)
self.linear1 = nn.Linear(16*4*4, 120)
self.linear2 = nn.Linear(120,84)
self.linear3 = nn.Linear(84,10)
def forward(self,x):
x = self.conv1(x)
x = F.relu(x)
x = self.pool(x)
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16*4*4)
x = F.relu(self.linear1(x))
x = F.relu(self.linear2(x))
x = self.linear3(x)
return x
理论上来说是很简单的。
那么针对网络对输入数据的要求,我们应该怎么把最开始下载的一堆图片输入进去呢?这就要用到pytorch里的Dataset
类了。
你需要定义一个类,继承Dataset
类,然后类里必须包含3个函数__init__
,__len__
,__getitem__
,具体结构如下
class ReadDataset(Dataset):
def __init__(self, 参数...):
def __len__(self, 参数...):
...
return 数据长度
def __getitem__(self, 参数...):
...
return 字典
__len__
需要返回一个表示数据长度的整型量,__getitem__
需要返回一个字典。ReadDataset
这个类名是自定义的,继承了Dataset
即可。
接下来的过程,我们先简单过一遍得到结果,再回看为什么这么做。
为了处理MNIST dataset,我们先把training
文件夹里的图像label读取进来
data_length = 60000
data_label = [-1] * data_length
prev_dir = './mnist_png/mnist_png/training/'
after_dir = '.png'
for id in range(10):
id_string = str(id)
for filename in glob(prev_dir + id_string +'/*.png'):
position = filename.replace(prev_dir+id_string+'/', '')
position = position.replace(after_dir, '')
data_label[int(position)] = id
这几行代码的作用,是把training
文件夹里的10个文件夹里的共计60000张图片放入到data_label
里。举个例子,图片编号为21的图,包含的数字是0(在training
文件夹的0
文件夹里),那么data_label[21] = 0
。
接下来定义继承了Dataset
类的ReadDataset
类,具体如下。
class ReadDataset(Dataset):
def __init__(self, imgs_dir, data_label):
self.imgs_dir = imgs_dir
self.ids = data_label
def __len__(self):
return len(self.ids)
def __getitem__(self, i):
idx = self.ids[i]
imgs_file = self.imgs_dir+ str(idx) + '/' + str(i) + '.png'
img = Image.open(imgs_file).convert('L')
img = np.array(img)
img = img.reshape(1,28,28)
if img.max() > 1:
img = img / 255
return {'image': torch.from_numpy(img), 'label': torch.tensor(idx)}
可以看到,构造函数__init__
里我们有两个参数,一个是imgs_dir
,图像地址,另一个是我们之前创建的列表data_label
,赋值给self.ids
. __len__()
仅仅是返回了data_label
的长度。
有趣的是__getitem__
函数,我们看到这个函数的参数是i
,传入了i
之后,我们首先根据ids
找到它对应的图像里所标识的数字,继而根据
imgs_file = self.imgs_dir+ str(idx) + '/' + str(i) + '.png'
img = Image.open(imgs_file).convert('L')
找到图像并转化为黑白。之后再转化为np,再reshape。原图像读进来本来是28X28,但是根据网络的要求,输入需要是图像通道数X图像尺寸,黑白图片通道为1,所以我们reshape为1X28X28。最后图像的像素点的灰度值归一化到0到1.因为我们要使用cross entropy代价函数来训练,根据官网,要求cross entropy的矩阵输入的值为0到1。返回的内容格式必须是字典,我们这儿字典的内容图像和图像内对应的数字(label)是
{'image': torch.from_numpy(img), 'label': torch.tensor(idx)}
这个getitem函数如果调用,最终达到的目的就是,假如我在代码中输入A = __getitem__(0)
,我就应该能得到0.png
对应的那张图像,获取图像的方式就是A['image'],获取图像是数字几的方式是A['label']
。
有了上面的内容作为铺垫,我们看看主函数里读取数据的具体操作。首先有下面一行内容
prev_dir = './mnist_png/mnist_png/testing/'
...
all_data = ReadDataset(prev_dir, data_label)
我们把prev_dir
和之前得到的data_label
作为参数传入了ReadDataset
并返回了all_data
。有的人可能说,诶,我没看到ReadDataset
有返回值呀。这是因为这些写在了Dataset
这个类里,不然继承它干什么呢。随后,我们把这个返回值赋值给DataLoader
,就可以定义从torchvision
里自带的MNIST dataset一样的操作了。
test_loader = DataLoader(all_data, batch_size=batch_size, shuffle=True, num_workers=4)
定义好batch_size,num_workers,代价函数这些之后,我们就可以在训练的时候使用返回值test_loader
了。
with torch.no_grad():
for data in test_loader:
images = data['image']
labels = data['label']
...
我们可以看到其实我们并没有显式地调用__getitem__
函数,而是通过data遍历test_loader
, data会自动根据ReadDataset
里ids的长度,从1到ids.length
来批量读取图像。如果你设置了batch_size
等于4,那么for循环的第一次循环,会调用__getitem__
四次,data['image']
就会返回__getitem__
,return {'image':...}
中image
所对应的内容。
设置代价函数这些不是本文的内容,就不细讲了。具体的可参见github代码。
可能大家看了上面的例子还是有些不明不白,因为虽然ReadDataset
这个类的内容就是定义三个函数,但是这三个函数具体的内容是什么,就需要根据实际情况确定了。我们上面的数据集的图像是分别储存在0~9个文件夹中,其他的数据可能不是这么储存的,就需要想新的办法获得那个data_label
列表。但是你的最终目的是很明白的,
1:getitem所返回的内容,需要能输入到网络里,比如我们的
images = data['image']
...
outputs = net(images.float())
2: 根据0到ids
的长度的indx,能遍历你想要使用的所有图像。
假想你显式调用__getitem__(0)
,你需要能获得名字为0.png
或者0.jpg
之类的图像的内容。
说这些不如多看两个例子再自己实践一下。上面的lenet5的例子之外,我在github里分别分开写了CPU的方法和GPU的方法,当然其实就一两行代码的事儿。不过考虑到这还是属于接近新手范畴的tutorial,就分开写了。
另外我还在github代码里提供了稍微复杂的网络UNET的实现,UNET是用来做语义分割的网络,不熟悉的同学可以自行看下语义分割是什么blabla。在UNET的这个网络里,我同样是读取的自定义的数据集而不是使用torchvision.dataset
里带的数据集。代码放于github
https://github.com/zhaozhongch/Pytorch_UNET_MultiObjects
当下用得最多的pytroch的UNET的实现还只是一个物体分割的(参考此处
),我顺便拔高了一下实现多个物体的语义分割了,不过最后语义分割的效果图不是非常好,因为懒得花时间去仔细fine tune了。但我相信作为tutorial级别的代码,我觉得跑一遍熟悉一下网络结构怎么定义,怎么自定义数据集,已经很够了。觉得还不错的可以给这俩仓库点个小star哈哈。
关于这两个网络实现或者其他内容不懂的同学欢迎私信。