在利用图像数据进行深度学习建模的任务中,如果数据集较小,我们需要进行Image Data Augmentation:对已有图片进行平移,剪切,垂直对称等操作形成新的图片。将新图片加入数据集,从而扩充数据集。Keras的内置函数ImageDataGenerator就是用来扩充图像数据集的。下面我们对Keras中的ImageDataGenerator的各项参数进行说明和使用策略。
1.ImageDataGenerator 类参数说明
keras.preprocessing.image.ImageDataGenerator(
featurewise_center=False, #将输入全部数据的均值设置为 0。一般不用。
samplewise_center=False, #将每个样本的均值设置为 0。一般不用。
featurewise_std_normalization=False,#将输入除以全部数据标准差。一般不用。
samplewise_std_normalization=False,#将输入除以其标准差。一般不用。
zca_whitening=False,#是否应用 ZCA 白化。
zca_epsilon=1e-06, #ZCA 白化的 epsilon 值。常用。
rotation_range=0,#整数。随机旋转的度数范围。常用。
width_shift_range=0.0,#浮点数,水平平移百分比,不宜太大一般0.1,0.2
height_shift_range=0.0,#浮点数,垂直平移百分比,不宜太大一般0.1,0.2
brightness_range=None,#浮点数,亮度调整。
shear_range=0.0,#浮点数,错切变换角度。
zoom_range=0.0,#浮点数[0,1],随机缩放。[llow,upp]:随机缩放范围。
channel_shift_range=0.0,#浮点数[0.0,255.0],图像上色。
fill_mode='nearest',#边界填充,一般默认。
cval=0.0,#一般不用。
horizontal_flip=False,#水平翻转,常用。
vertical_flip=False,#垂直翻转,看应用场景使用。
rescale=None,#数据缩放,常用:1/255.0。
preprocessing_function=None,
data_format=None,
validation_split=0.0,#验证集划分。常用。
dtype=None
)
2.使用案例
a.使用.flow,传入列表数据进行数据扩充。
import keras
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
#添加数据集
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
#标签向量化
y_train = np_utils.to_categorical(y_train, num_classes)
y_test = np_utils.to_categorical(y_test, num_classes)
#图片生成器
datagen = ImageDataGenerator(
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2,
horizontal_flip=True
)
# fit
datagen.fit(x_train)
# flow
datagen.flow(x_train, y_train, batch_size=32)
b.通过.flow_from_directory(directory)加载文件中的图片并进行扩充。注意:目录下需要有各个类别图像对应的文件夹。如:train文件夹下有cats(里面只有猫的图片),dogs(里面值有狗的图片)文件夹。
train_datagen = ImageDataGenerator(
rescale=1./255,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True)
test_datagen = ImageDataGenerator(rescale=1./255)
#不需要.fit()
train_generator = train_datagen.flow_from_directory(
'data/train',
target_size=(32, 32),
batch_size=32)
validation_generator = test_datagen.flow_from_directory(
'data/validation',
target_size=(32, 32),
batch_size=32)
小伙伴们如果觉得文章还行的请点个赞呦!!同时觉得文章哪里有问题的可以评论一下 谢谢你!