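Keras's fit_generator() function can solve an out-of-memory (OOM) problem.
The usual model.fit call takes the entire dataset as in-memory arrays, which can cause an OOM error on large datasets; fit_generator() instead pulls the data from a generator one batch at a time.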
1. how to use fit_generator()
model.fit_generator(generator,
                    steps_per_epoch=None,
                    epochs=1,
                    verbose=1,
                    callbacks=None,
                    validation_data=None,
                    validation_steps=None,
                    class_weight=None,
                    max_queue_size=10,
                    workers=1,
                    use_multiprocessing=False,
                    shuffle=True,
                    initial_epoch=0)
Below are some explanations of the parameters.
generator: the data generator; you need to create one before calling the function.
steps_per_epoch: the number of training samples divided by the generator's batch_size (rounded up if it does not divide evenly).
epochs: the total number of epochs to train the network.
verbose: 0, 1 or 2; controls how the log is displayed. 0 prints nothing to standard output, 1 shows a progress bar, and 2 prints one line per epoch.
callbacks: a list of callbacks run during training, for example to record intermediate results such as loss and accuracy.
validation_data: can also be a generator, created the same way as the first parameter; in that case validation_steps sets how many batches to draw from it.
class_weight: a dictionary mapping class indices to weights; useful for imbalanced classes.
max_queue_size: the maximum size of the generator queue.
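As an illustration of the class_weight parameter, here is a minimal sketch that builds a class-to-weight dictionary from integer labels using inverse class frequency. The helper name balanced_class_weight and the weighting formula are assumptions chosen for this example, not something Keras prescribes; any dictionary mapping a class index to a float weight can be passed.

```python
from collections import Counter

def balanced_class_weight(labels):
    """Return {class_index: weight}, giving rare classes a larger weight.

    Hypothetical helper: weight = n_samples / (n_classes * class_count).
    """
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * cnt) for cls, cnt in counts.items()}

# Toy labels: class 0 has 8 samples, class 1 has only 2.
labels = [0] * 8 + [1] * 2
class_weight = balanced_class_weight(labels)
print(class_weight)
```

The resulting dictionary could then be passed as class_weight=class_weight when calling fit_generator().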
2. how to create a generator
The snippets below are illustrative only; they will not run as-is without your own data and model.
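from random import randint

# assumed signature: x and y are sliceable arrays, batch_size is an int
def generate_batch_data_random(x, y, batch_size):
    """Feed data to GPU memory batch by batch to reduce GPU memory usage."""
    ylen = len(y)
    loopcount = ylen // batch_size  # number of full batches
    while True:
        # randint is inclusive on both ends, so the last valid index is loopcount - 1
        i = randint(0, loopcount - 1)
        yield x[i * batch_size:(i + 1) * batch_size], y[i * batch_size:(i + 1) * batch_size]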
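The snippet above draws a random batch on every step, so within one epoch some samples may be seen several times and others not at all. An alternative, sketched here as an assumption rather than as the original author's method, reshuffles the sample indices once per epoch and walks through them in order. The function name epoch_batch_indices and its seed parameter are hypothetical; it yields lists of indices, which you would use to select rows from NumPy arrays (x[idx], y[idx]) inside your own generator.

```python
import random

def epoch_batch_indices(n_samples, batch_size, seed=None):
    """Yield index batches forever, covering every sample once per epoch."""
    rng = random.Random(seed)
    indices = list(range(n_samples))
    while True:
        rng.shuffle(indices)  # new sample order at the start of each epoch
        for start in range(0, n_samples, batch_size):
            # the last batch of an epoch may be smaller than batch_size
            yield indices[start:start + batch_size]
```

With this scheme, setting steps_per_epoch to the number of batches per pass (rounded up) makes one Keras epoch correspond to one full pass over the data.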
from keras.preprocessing.image import ImageDataGenerator
#...
datagen = ImageDataGenerator(……)
model.fit_generator(datagen.flow(trainXNoisy, trainX, batch_size=500),
                    steps_per_epoch=60000 // 500,  # integer number of batches per epoch
                    validation_data=(testXNoisy, testX),
                    verbose=1,
                    epochs=EPOCHS,
                    callbacks=[tensorboard])
ref-link: (https://blog.csdn.net/sinat_26917383/article/details/74922230)
Here is a screenshot taken after switching to ImageDataGenerator and model.fit_generator (GPU memory usage was previously around 68%).

3. what you should know:
This method can increase training time, depending on the generator's batch_size.
The smaller the batch_size, the longer training takes.
If your graphics card has enough memory, set a larger batch_size and adjust steps_per_epoch to match; otherwise the extra training time will drive you crazy.
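To make the trade-off concrete, the following sketch computes how many generator steps one epoch needs for several batch sizes, using the 60000-sample figure from the example in section 2. The function name steps_for is hypothetical, and rounding up is an assumption so that a final partial batch is still counted.

```python
import math

def steps_for(n_samples, batch_size):
    """Number of batches needed to cover n_samples once (rounded up)."""
    return math.ceil(n_samples / batch_size)

n_samples = 60000
steps = {bs: steps_for(n_samples, bs) for bs in (50, 500, 5000)}
for bs, n in steps.items():
    print(f"batch_size={bs:>4} -> steps_per_epoch={n}")
```

Fewer, larger batches mean fewer generator calls per epoch, at the price of more memory per batch.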