1、什么是数据增强
数据增强是扩充数据样本规模的一种有效地方法。深度学习是基于大数据的一种方法,我们当前希望数据的规模越大、质量越高越好。模型才能够有着更好的泛化能力,然而实际采集数据的时候,往往很难覆盖掉全部的场景,比如:对于光照条件,在采集图像数据时,我们很难控制光线的比例,因此在训练模型的时候,就需要加入光照变化方面的数据增强。再有一方面就是数据的获取也需要大量的成本,如果能够自动化的生成各种训练数据,就能做到更好的开源节流。
2、数据增强的作用
- 增加训练的数据量,提高模型的泛化能力
- 增加噪声数据,提升模型的鲁棒性
3、如何进行数据增强
数据增强可以分为两类,一类是离线增强,一类是在线增强。
- 离线增强 : 直接对数据集进行处理,数据的数目会变成增强因子 x 原数据集的数目 ,这种方法常常用于数据集很小的时候.
- 在线增强 : 这种增强的方法用于,获得 batch 数据之后,然后对这个 batch 的数据进行增强,如旋转、平移、翻折等相应的变化,由于有些数据集不能接受线性级别的增长,这种方法长用于大的数据集,很多机器学习框架已经支持了这种数据增强方式,并且可以使用 GPU 优化计算。
4、pytorch数据增强操作
pytorch中数据增强的常用方法如下:
- 对图片进行一定比例的缩放
- 对图片进行随机的截取
- 对图片进行随机水平和竖直翻转
- 对图片进行随机角度的旋转
- 对图片进行亮度、对比度和颜色的随机变化等
torchvision中内置的transforms包含了这些些常用的图像变换,这些变换能够用Compose串联组合起来。
from PIL import Image
from torchvision import transforms as tfs
img = Image.open('./dog.jpg')
print('原图:')
img
原图:
4.1、中心处裁剪PIL图片
class torchvision.transforms.CenterCrop(size)
- size(序列 或 int)– 需要裁剪出的形状。如果size是int,将会裁剪成正方形;如果是形如(h, w)的序列,将会裁剪成矩形。
print('原图像尺寸:{}'.format(img.size))
re_img = tfs.CenterCrop(200)(img)
print('中心裁剪后尺寸:{}'.format(re_img.size))
re_img
原图像尺寸:(658, 411)
中心裁剪后尺寸:(200, 200)
4.2 随机改变图片的亮度、对比度和饱和度
class torchvision.transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)
- brightness(float或 float类型元组(min, max))– 亮度的扰动幅度。应当是非负数。
- contrast(float或 float类型元组(min, max))– 对比度扰动幅度。应当是非负数。
saturation(float或 float类型元组(min, max))– 饱和度扰动幅度。应当是非负数。
hue(float或 float类型元组(min, max))– 色相扰动幅度。hue_factor从[-hue, hue]中随机采样产生,其值应当满足0<= hue <= 0.5或-0.5 <= min <= max <= 0.5
cj_img = tfs.ColorJitter(0.8, 0.8, 0.5)(img)
cj_img
4.3 图片转换为灰阶
class torchvision.transforms.Grayscale(num_output_channels=1))
- num_output_channels(int,1或3)– 希望得到的图片通道数。
gc_img = tfs.Grayscale(1)(img)
gc_img
4.4 图像的各条边缘进行扩展
class torchvision.transforms.Pad(padding, fill=0, padding_mode='constant')
- padding(int 或 tuple)– 在每条边上展开的宽度。如果传入的是单个int,就在所有边展开。如果传入长为2的元组,则指定左右和上下的展开宽度。如果传入长为4的元组,则依次指定为左、上、右、下的展开宽度。
- fill(int 或 tuple) – 像素填充值。默认是0。如果指定长度为3的元组,表示分别填充R, G, B通道。这个参数仅在padding_mode是‘constant’时指定有效。
- padding_mode(str)– 展开类型。应当是‘constant’,‘edge’,‘reflect’或‘symmetric’之一。默认为‘constant’。
- constant:用常数扩展,这个值由fill参数指定。
- edge:用图像边缘上的值填充。
- reflect:以边缘为对称轴进行轴对称填充(边缘值不重复)。
- symmetric:用图像边缘的反转进行填充(图像的边缘值需要重复)。
# 用常数0填充
con_img = tfs.Pad(50, fill=0, padding_mode='constant')(img)
con_img
# 用图像边缘值填充
edge_img = tfs.Pad(50, fill=0, padding_mode='edge')(img)
edge_img
# 以边缘为对称轴进行轴对称填充
ref_img = tfs.Pad(50, fill=0, padding_mode='reflect')(img)
ref_img
4.5 图片在随机位置处进行裁剪
class torchvision.transforms.RandomCrop(size, padding=0, pad_if_needed=False)
- size(序列 或 int)– 想要裁剪出的图片的形状。如果size是int,按照正方形(size, size)裁剪; 如果size是序列(h, w),裁剪为矩形。
- padding(int 或 序列 , 可选)– 在图像的边缘进行填充,默认0,即不做填充。如果指定长为4的序列,则分别指定左、上、右、下的填充宽度。
- pad_if_needed(boolean)– 如果设置为True,若图片小于目标形状,将进行填充以避免报异常。
rc_img = tfs.RandomCrop(200)(img)
rc_img
4.6 以给定的概率随机水平翻折PIL图片
class torchvision.transforms.RandomHorizontalFlip(p=0.5)
- p(float)– 翻折图片的概率。默认0.5。
rh_img = tfs.RandomHorizontalFlip(1)(img)
rh_img
4.7 以给定的概率随机垂直翻折PIL图片
class torchvision.transforms.RandomVerticalFlip(p=0.5)
- p(float)– 翻折图片的概率。默认0.5。
rv_img = tfs.RandomVerticalFlip(1)(img)
rv_img
4.8 以指定的角度选装图片
class torchvision.transforms.RandomRotation(degrees, resample=False, expand=False, center=None)
- degrees(序列 或 float or int)– 旋转角度的随机选取范围。如果degrees是序列(min, max),则从中随机选取;如果是数字,则选择范围是(-degrees, +degrees)。
- resample({PIL.Image.NEAREST , PIL.Image.BILINEAR , PIL.Image.BICUBIC} , 可选) – 可选的重采样滤波器。如果该选项忽略,或图片模式是“1”或者“P”则设置为PIL.Image.NEAREST。
- expand(bool, 可选)– 可选的扩展标志。如果设置为True, 将输出扩展到足够大从而能容纳全图。如果设置为False或不设置,输出图片将和输入同样大。注意expand标志要求 flag assumes rotation around the center and no translation。
- center(2-tuple , 可选)– 可选的旋转中心坐标。以左上角为原点计算。默认是图像中心。
rr_img = tfs.RandomRotation(45)(img)
rr_img
以上都是对图像做单次变换,torchvision提供torchvision.transforms.Compose()函数,可以将以上图像方法联合起来使用,比如先做随机翻转,然后随机截取,再做对比度增强等。
import matplotlib.pyplot as plt
%matplotlib inline
aug_img = tfs.Compose([
tfs.Resize(200),
tfs.RandomHorizontalFlip(),
tfs.RandomCrop(120),
tfs.RandomVerticalFlip(),
tfs.ColorJitter(0.5, 0.5, 0.5)
])
_, figs = plt.subplots(3, 3, figsize=(10, 10))
for i in range(3):
for j in range(3):
figs[i][j].imshow(aug_img(img))
figs[i][j].axes.get_xaxis().set_visible(False)
figs[i][j].axes.get_yaxis().set_visible(False)