在数据处理和网络定义完成后,跑模型时突然出现了错误:
刚开始也不知道哪里的问题,发现有可能是内存耗尽了,然后就放进去500张图片进行fit,然后问题就消失了,猜想应该是数据太大,内存开销不够。
发现官方文档中说可以使用fit_generator()分批训练。
官方文档如下:
fit_generator(self, generator,
steps_per_epoch=None,
epochs=1,
verbose=1,
callbacks=None,
validation_data=None,
validation_steps=None,
class_weight=None,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
shuffle=True,
initial_epoch=0)
通过Python generator产生一批批的数据用于训练模型。generator可以和模型并行运行,例如,可以使用CPU生成批数据同时在GPU上训练模型。
参数:
- generator:一个generator或Sequence实例,为了避免在使用multiprocessing时直接复制数据。
- steps_per_epoch:从generator产生的步骤的总数(样本批次总数)。通常情况下,应该等于数据集的样本数量除以批量的大小。
- epochs:整数,在数据集上迭代的总数。
- works:在使用基于进程的线程时,最多需要启动的进程数量。
- use_multiprocessing:布尔值。当为True时,使用基于基于过程的线程。
例子:
datagen = ImageDataGenator(...)
model.fit_generator(datagen.flow(x_train, y_train,
batch_size=batch_size),
epochs=epochs,
validation_data=(x_test, y_test),
workers=4)