fit_generator

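The fit_generator() method of a Keras model can help solve out-of-memory (OOM) problems.

The usual approach, model.fit(), takes the entire training set as in-memory arrays, which can exhaust memory when the dataset is large. fit_generator() instead pulls one batch at a time from a Python generator, so each training step only needs the current batch (plus a small queue of prefetched batches). Note that in TensorFlow 2.1 and later, fit_generator() is deprecated and model.fit() accepts generators directly.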
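The memory saving comes from the fact that a generator only has to hold what it yields. The minimal sketch below (lazy_batch_generator and load_sample are hypothetical names, not Keras APIs) keeps just file paths and labels in memory and reads one batch of samples from disk per step; Keras may still prefetch a few batches into its queue.

```python
def lazy_batch_generator(paths, labels, batch_size, load_sample):
    """Hypothetical sketch: only `paths` and `labels` are kept in memory.

    `load_sample` is a caller-supplied function (an assumption) that reads
    one sample from disk, so the generator itself holds one batch at a time.
    """
    n = len(paths)
    while True:  # fit_generator expects the generator to loop forever
        for start in range(0, n, batch_size):
            batch_paths = paths[start:start + batch_size]
            inputs = [load_sample(p) for p in batch_paths]  # read from disk only now
            yield inputs, labels[start:start + batch_size]
```

In practice load_sample would decode an image file, and the batch would be stacked into a NumPy array (for example with np.stack(inputs)) before being yielded, since Keras expects arrays rather than lists.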

1. How to use fit_generator()
```python
model.fit_generator(generator,
                    steps_per_epoch=None,
                    epochs=1,
                    verbose=1,
                    callbacks=None,
                    validation_data=None,
                    validation_steps=None,
                    class_weight=None,
                    max_queue_size=10,
                    workers=1,
                    use_multiprocessing=False,
                    shuffle=True,
                    initial_epoch=0)
```

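Below are explanations of the main parameters.

generator: a Python generator (or a keras.utils.Sequence instance) that yields batches as (inputs, targets) or (inputs, targets, sample_weights) tuples and loops over its data indefinitely. You need to create it before calling fit_generator().
steps_per_epoch: the number of batches drawn from the generator before an epoch is considered finished; usually the number of training samples divided by the generator's batch_size, rounded up.
epochs: the total number of training epochs, i.e. the same value you would pass to model.fit().
verbose: 0, 1 or 2, controls the training log: 0 is silent (nothing is written to standard output), 1 shows a progress bar, and 2 prints one line per epoch.
callbacks: a list of callbacks invoked during training, e.g. to record intermediate results such as loss and accuracy (keras.callbacks.TensorBoard is one example).
validation_data: either a generator like the first parameter, or a tuple such as (x_val, y_val). When it is a generator, validation_steps sets how many batches to draw from it.
class_weight: an optional dict mapping class indices to weights, used to weight the loss; useful for imbalanced datasets.
max_queue_size: the maximum number of pre-generated batches kept in the generator queue (default 10).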
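As a quick check of the steps_per_epoch rule, here is a small helper (steps_for is a hypothetical name, not a Keras function) that rounds up so the last, smaller batch is still used:

```python
import math


def steps_for(num_samples: int, batch_size: int) -> int:
    """Number of batches needed to see every sample once per epoch."""
    # Round up so a final partial batch is not dropped.
    return math.ceil(num_samples / batch_size)


print(steps_for(60000, 500))  # 120
print(steps_for(1000, 32))    # 32 (31 full batches + 1 batch of 8)
```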

2. How to create a generator

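Two examples follow (fragments, not complete training scripts).

```python
from random import randint

def batch_generator(x, y, batch_size):  # function name is illustrative
    """Feed data to GPU memory one batch at a time to reduce GPU memory usage."""
    ylen = len(y)
    loopcount = ylen // batch_size  # number of full batches
    while True:  # Keras expects the generator to loop indefinitely
        # random.randint includes both endpoints, so use loopcount - 1
        # to avoid slicing past the last full batch
        i = randint(0, loopcount - 1)
        yield x[i * batch_size:(i + 1) * batch_size], y[i * batch_size:(i + 1) * batch_size]
```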
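The generator above picks batches at random with replacement, so within one epoch some batches may be drawn twice and others not at all. A minimal alternative sketch (epoch_batch_generator is a hypothetical name, not from the original post) shuffles the batch order once per pass, so every sample, including a final partial batch, is seen exactly once per pass when steps_per_epoch equals ceil(len(y) / batch_size):

```python
import random


def epoch_batch_generator(x, y, batch_size, seed=None):
    """Hypothetical sketch: visit every batch exactly once per pass, in random order.

    Only contiguous slices x[a:b] are used, so x and y can be lists or NumPy arrays.
    """
    n = len(y)
    rng = random.Random(seed)
    starts = list(range(0, n, batch_size))  # includes the final partial batch
    while True:  # loop forever, as fit_generator expects
        rng.shuffle(starts)  # new batch order for every pass over the data
        for start in starts:
            end = start + batch_size
            yield x[start:end], y[start:end]
```

Shuffling whole batches rather than individual samples keeps slicing contiguous. If you need the same once-per-epoch guarantee inside Keras itself, subclassing keras.utils.Sequence (with shuffling in on_epoch_end) is the built-in route.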
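Keras's built-in ImageDataGenerator can also act as the generator; its flow() method yields (x, y) batches:

```python
from keras.preprocessing.image import ImageDataGenerator
#...
datagen = ImageDataGenerator(……)

# trainXNoisy (noisy images) are the inputs, trainX (clean images) the targets
model.fit_generator(datagen.flow(trainXNoisy, trainX, batch_size=500),
                    steps_per_epoch=60000 // 500,  # integer number of batches per epoch
                    validation_data=(testXNoisy, testX),
                    verbose=1,
                    epochs=EPOCHS,
                    callbacks=[tensorboard])
```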
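The example above stores both trainXNoisy and trainX, which doubles the memory needed for the training images. Assuming this is a denoising setup (noisy input, clean target), a generator can instead add noise per batch. The sketch below is hypothetical (noisy_batch_generator, Gaussian noise and clipping to [0, 1] are all assumptions) and uses plain lists so it stays self-contained:

```python
import random


def noisy_batch_generator(clean, batch_size, noise_std=0.1, seed=None):
    """Hypothetical sketch: build (noisy input, clean target) pairs per batch.

    `clean` is a sequence of samples, each a flat list of pixel values in [0, 1].
    Only the clean data is stored; noisy copies exist for one batch at a time.
    """
    rng = random.Random(seed)
    n = len(clean)
    while True:  # fit_generator expects an endless generator
        for start in range(0, n, batch_size):
            targets = clean[start:start + batch_size]
            inputs = [
                # add Gaussian noise, then clip back into [0, 1]
                [min(1.0, max(0.0, p + rng.gauss(0.0, noise_std))) for p in sample]
                for sample in targets
            ]
            yield inputs, targets
```

A real version would work on NumPy arrays (for example np.clip(batch + noise, 0.0, 1.0)), since Keras expects arrays rather than lists; it could then be passed to model.fit_generator() in place of datagen.flow(...), at the cost of the augmentation ImageDataGenerator provides.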

Reference: https://blog.csdn.net/sinat_26917383/article/details/74922230

Here is a screenshot of GPU memory usage after switching to ImageDataGenerator and model.fit_generator() (before the switch, GPU memory usage was around 68%).


3. What you should know

Depending on the generator's batch_size, this method can make training slower: the smaller the batch_size, the more steps each epoch needs and the longer training takes. If your graphics card has enough memory, set a larger batch_size and lower steps_per_epoch to match; otherwise the training time can become painfully long.

© Copyright belongs to the author. For reprints or content cooperation, please contact the author.
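Platform statement: The article content (including any images or videos) was uploaded and published by the author and represents only the author's own views. Jianshu is an information publishing platform and only provides information storage services.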
