tfrecords读取文件

以前做科研论文的时候, 所使用的音频数据比较少, 所以都是直接读进内存中在feeding给placeholder。现在在做一些偏工程的项目时就发现远远不行了,feeding训练速度远远提不上来。所以这两周都在为训练提速而折磨。在此记录下来尝试的方式。
tensorflow推荐使用tfrecords来存储数据, 这样能加快数据的读取。

def convert_to_tfrecord(loader):
    ''' modefy batch_size=1 in './conf/train_ce.conf' before convert to tfrecord data format '''
    def write_tfrecords(queue, i):
        start_time = time.time()
        while queue.empty():
            if time.time()-start_time > 600:               #超时队列中还没有数据该进程就退出
                print('wait timeout! proc %d exit!'%i)
                exit()
            time.sleep(1)
        writer = tf.python_io.TFRecordWriter('./train_input/tfrecords_file/train_dataset_%d.tfrecords'%i)
        while queue.qsize():
            batch = queue.get()                            # 从队列中获取一个样本
            example = tf.train.Example(features=tf.train.Features(feature={
                'feature':  tf.train.Feature(float_list=tf.train.FloatList(value=batch[0].flatten())),
                'label':    tf.train.Feature(int64_list=tf.train.Int64List(value=batch[1].flatten())),
                'mask':     tf.train.Feature(int64_list=tf.train.Int64List(value=batch[2].flatten())),
                'length':   tf.train.Feature(int64_list=tf.train.Int64List(value=[batch[3][0][0]]))
                #'feature_shape': tf.train.Feature(int64_list=tf.train.Int64List(value=np.array(batch[0].shape)))
            }))                                            # 这里将二维的feature label mask 转为一维的进行存储
            writer.write(example.SerializeToString())
        writer.close()

    start = time.time()
    queue = Queue(512)
    proc_record = []
    for i in range(10):
        p = Process(target=write_tfrecords, args=(queue, i)) #开10个进程用来写入数据
        p.start()
        proc_record.append(p)
    num = 0
    while True:
        try:
            batch = loader.next()                            # 获取一个样本, 压入队列
        except StopIteration:
            tf.logging.info('finished convert to tfrecords')
            break
        if batch is not None:
            queue.put(batch)
            num += 1
        else:
            break
    for p in proc_record:   p.join()                        # 等待所有进程结束
    print('num:', num)
    print('time:', time.time()-start)

程序写了一个多进程写入tfrecords, 在主进程中读取数据压入队列,再开辟10个进程从队列中读取数据, 因为我的loader.next加载数据比较长,所以在子进程中设置了循环等待。

在尝试过多线程, 应为python GIL的原因, 所以速度没有提升, 改成了多进程。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容

  • 必备的理论基础 1.操作系统作用: 隐藏丑陋复杂的硬件接口,提供良好的抽象接口。 管理调度进程,并将多个进程对硬件...
    drfung阅读 3,595评论 0 5
  • 一. 操作系统概念 操作系统位于底层硬件与应用软件之间的一层.工作方式: 向下管理硬件,向上提供接口.操作系统进行...
    月亮是我踢弯得阅读 6,033评论 3 28
  • TF官网上给出了三种读取数据的方式: Preloaded data: 预加载数据 Feeding: Python ...
    Liu91阅读 8,703评论 0 9
  • Tensorflow的数据读取有三种方式: Preloaded data: 预加载数据,也就是TensorFlow...
    是neinei啊阅读 3,864评论 0 2
  • 题目1: ajax 是什么?有什么作用? 是一种用于概括异步加载页面内容的技术,AJAX 可以使网页实现异步更新。...
    大大的萝卜阅读 228评论 0 0