Tensorflow TFRecords及多线程训练介绍 ——详细

也可移步my github查看

先修知识——protocol buffer

TF框架中多处使用了protocol buffer,protocol buffer全称Google Protocol Buffer,简称Protobuf,是一种结构化数据存储格式,类似于常见的Json和xml,而且这种格式经过编译可以生成对应C++或Java或Python类的形式,即可以用编程语言读取或修改数据,不仅如此,还可以进一步将定义的结构化数据进行序列化,转化成二进制数据存下来或发送出去,非常适合做数据存储或 RPC 数据交换格式。更具体的介绍可以参考网上比较推荐的文章:Google Protocol Buffer 的使用和原理。其实TensorFlow计算图思想的实现也是基于protocol buffer的,感兴趣的可以看一下,本文主要介绍TFRecords,TFRecords是TF官方推荐使用的数据存储形式,也是使用了protocol buffer,下面结合TFRecords详细介绍其使用方法和原理。

protocol buffer的使用

参考Google Protocol Buffer 的使用和原理可以发现,要得到本地存储的序列化数据,需要先定义.proto 文件,再编译成编程语言描述的类,然后实例化该类(该类也已自动生成setter getter修改类和序列化类等方法),并序列化保存到本地或进行传输。TFRecords的思想也是将数据集中的数据以结构化的形式存到.proto中,然后序列化存储到本地,方便使用时读取并还原数据,只不过TF又对这个过程进行了一点封装,看起来和protocol buffer原始的使用方式略有差别。

protocol buffer中需要先将数据以结构化文件.proto的格式展现,然后可以编译成C++ Java 或python类进行后续操作,在TFRecords的应用中tf.train.Example类就是扮演了这一角色,TF中它的原始.proto文件定义在tensorflow/core/example/example.proto中,如下代码片:

message Example {
  Features features = 1;
};

可以看到Example类中封装的数据应该是features,是Features类型的,而Features在python代码中就对应了tf.train.Features类,其原始.proto文件定义在tensorflow/core/example/feature.proto中,如下代码片:

message Features {
  // Map from feature name to feature.
  map<string, Feature> feature = 1;
};

可以看到,Features中的数据又是feature(注意没有s),而feature属性的类型是map<string, Feature>类型,string不必说了,关键是Feature类型,和Features一样,Feature对应tf.train.Feature类,其原始.proto文件也定义在tensorflow/core/example/feature.proto中,如下代码片:

message Feature {
  // Each feature can be exactly one kind.
  oneof kind {
    BytesList bytes_list = 1; # bytes_list float_list int64_list也是和之前一样,对应一个类
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};

将数据集转化成TFRecords形式

TFRecords的定义过程就是使用了刚介绍的几个类:tf.train.Exampletf.train.Featurestf.train.Feature,知道了这几个类的定义以及它们的嵌套关系,再去理解TFRecords的产生就容易多了。
首先,使用tf.train.Example来封装我们的数据,然后使用tf.python_io.TFRecordWriter来写入磁盘,其中几个类的的嵌套方式和上述一致,见如下代码:

#本段代码来自[TensorFlow高效读取数据](http://ycszen.github.io/2016/08/17/TensorFlow%E9%AB%98%E6%95%88%E8%AF%BB%E5%8F%96%E6%95%B0%E6%8D%AE/)

import os
import tensorflow as tf 
from PIL import Image
cwd = os.getcwd()
'''
此处我加载的数据目录如下:
0 -- img1.jpg
     img2.jpg
     img3.jpg
     ...
1 -- img1.jpg
     img2.jpg
     ...
2 -- ...
...
'''
# 先定义writer对象,writer负责将得到的记录写入TFRecords文件,此处为train.tfrecords文件
writer = tf.python_io.TFRecordWriter("train.tfrecords")
for index, name in enumerate(classes):
class_path = cwd + name + "/"
  # 一张一张的写入TFRecords文件
  for img_name in os.listdir(class_path):
    img_path = class_path + img_name
    img = Image.open(img_path)
    img = img.resize((224, 224)) #对图片做一些预处理操作
    img_raw = img.tobytes()     #将图片转化为原生bytes
    # 封装仅Example对象中
    example = tf.train.Example(features=tf.train.Features(feature={
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
        }))
    writer.write(example.SerializeToString())  #序列化为字符串并写入磁盘
writer.close()

读取数据

以上存储数据时,Example调用SerializeToString()方法将自己序列化并由writer = tf.python_io.TFRecordWriter("train.tfrecords")对象保存,最终是将所有的图片文件和label保存到同一个tfrecords文件train.tfrecords中了。读取数据则以上过程的逆,先获取序列化数据,再解析:由tf.python_io.tf_record_iterator("train.tfrecords")方法(注意这个是方法)返回所有本地序列化文件迭代器,然后由Example调用ParseFromString()方法解析,代码如下:

for serialized_example in tf.python_io.tf_record_iterator("train.tfrecords"):
  # 本段代码来自[TensorFlow高效读取数据](http://ycszen.github.io/2016/08/17/TensorFlow%E9%AB%98%E6%95%88%E8%AF%BB%E5%8F%96%E6%95%B0%E6%8D%AE/)
  example = tf.train.Example()
  # 进行解析
  example.ParseFromString(serialized_example)
  # 逐个读取example对象里封装的东西
  image = example.features.feature['image'].bytes_list.value
  label = example.features.feature['label'].int64_list.value
  # 可以做一些预处理之类的
  print image, labe

这是最基本的数据读取方式,tf.python_io.tf_record_iterator方法每次解析一个.tfrecords文件。而在实际应用中,由于数据集往往很大,所以往往将数据分开保存至多个tfrecords文件中,在这种情况下,TF提供了其他的接口进行读取,所以正常情况下我们可能不会使用上述的数据读取方式,以下才是重点,但必须强调的是整体的思想是一致的,都是先获取序列化文件,然后解析,只是接口函数稍有不同。

TF的多线程训练是TF框架重新设计的,不是简单地使用python语言多线程来搞得,很多时候TF多线程是和TFRecords配套使用的,下面介绍的数据读取方法也是多线程训练的数据读取方式。十图详解tensorflow数据读取机制这篇文章深入浅出>的介绍了TF多线程读取数据和训练的原理,多线程这一块接口多,也比较难以理解,下面仅从使用的角度出发谈谈我个人的理解,不详细追究里面的实现原理。

假设我们按照上述方式将数据保存到了两个tfrecords文件中,分别为'1.tfrecords'和'2.tfrecords',保存在DATA_ROOT路径中,那么我们分几步读取数据,参考如下代码:

    1. 读取tfrecords文件名到队列中,使用tf.train.string_input_producer函数,该函数可以接收一个文件名列表,并自动返回一个对应的文件名队列filename_queue,之所以用队列是为了后续多线程考虑(队列和多线程经常搭配使用)
    1. 实例化tf.TFRecordReader()类生成reader对象,接收filename_queue参数,并读取该队列中文件名对应的文件,得到serialized_example(读到的就是.tfrecords序列化文件)
    1. 解析,注意这里的解析不是用的Example对象里的函数,而是tf.parse_single_example函数,该函数能从serialized_example中解析出一条数据,当然也可以用tf.parse_example解析多条数据,此处暂不赘述。这里tf.parse_single_example函数传入参数serialized_examplefeatures,其中features是字典的形式,指定每个key的解析方式,比如image_raw使用tf.FixedLenFeature方法解析,这种解析方式返回一个Tensor,大多数解析方式也都是这种,另一种是tf.VarLenFeature方法,返回SparseTensor,用于处理稀疏数据,不赘述。这里还要注意必须告诉解析函数以何种数据类型解析,这必须与生成TFRecords文件时指定的数据类型一致。最后返回features是一个字典,里面存放了每一项的解析结果。
    1. 最后只要读出features中的数据即可。比如,features['label'],features['pixels']。但要注意的是,此时的image_raw依然是字符串类型的(可以看写入代码中的image_raw),需要进一步还原成像素数组,用TF提供的函数tf.decode_raw来搞定images = tf.decode_raw(features['image_raw'],tf.uint8)

至此,就定义好了完成一次数据读取的代码,有了它,下面的训练时的多线程方法就有了数据来源,下节讨论。

# 读取文件。
filename_queue = tf.train.string_input_producer(["Records/output.tfrecords"])
reader = tf.TFRecordReader()
_,serialized_example = reader.read(filename_queue)

# 解析读取的样例。
features = tf.parse_single_example(
    serialized_example,
    features={
        'image_raw':tf.FixedLenFeature([],tf.string),
        'pixels':tf.FixedLenFeature([],tf.int64),
        'label':tf.FixedLenFeature([],tf.int64)
    })

images = tf.decode_raw(features['image_raw'],tf.uint8)
labels = tf.cast(features['label'],tf.int32) #需要用tf.cast做一个类型转换
pixels = tf.cast(features['pixels'],tf.int32)

# 下面的代码下节讨论
sess = tf.Session()

# 启动多线程处理输入数据。
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord)

for i in range(10):
    image, label, pixel = sess.run([images, labels, pixels])

TF多线程机制

假设已将数据集文件转换成了TFRecords格式,共两个文件,每个文件中存储两条数据,两个文件如下,下面用多线程的方式读取并训练,分为以下几个步骤:

/patah/to/data.tfrecords-00000-of-00002
/patah/to/data.tfrecords-00001-of-00002
    1. 获取TFRecords文件队列。TF提供了tf.train.match_filenames_once函数帮助获取所有满足条件的TFRecords文件,tf.train.match_filenames_once函数参数为正则表达式,返回匹配上的所有文件名集合变量。当然,也可以选择不用该函数,用纯python也可以匹配,python的话最终返回一个list类型即可,但正规起见,还是推荐使用TF提供的方法。然后tf.train.string_input_producer函数依此生成文件名队列filename_queue
files = tf.train.match_filenames_once("/patah/to/data.tfrecords-*") # 
filename_queue = tf.train.string_input_producer(files, shuffle=False)
    1. 解析TFRecords文件中的数据,和上面一样,不赘述。
# 读取文件。
reader = tf.TFRecordReader()
_,serialized_example = reader.read(filename_queue)

# 解析读取的样例。
features = tf.parse_single_example(
    serialized_example,
    features={
        'image_raw':tf.FixedLenFeature([],tf.string),
        'pixels':tf.FixedLenFeature([],tf.int64),
        'label':tf.FixedLenFeature([],tf.int64)
    })

decoded_images = tf.decode_raw(features['image_raw'],tf.uint8)
retyped_images = tf.cast(decoded_images, tf.float32)
#pixels = tf.cast(features['pixels'],tf.int32)
# 最后只要labels和images
labels = tf.cast(features['label'],tf.int32)
images = tf.reshape(retyped_images, [784])
  • 3)将读取到的数据打包为batch。上一段代码得到了labelsimages,这是一条数据,训练一次需要一个batch的数据,怎么搞?难道将上述代码用for循环反复执行batch_size次?这样做未尝不可,但效率很低,TF提供了tf.train.shuffle_batch函数,上述解析代码只要提供一次,然后将labelsimages作为tf.train.shuffle_batch函数的参数,tf.train.shuffle_batch就能自动获取到一个batch的labelsimagestf.train.shuffle_batch函数获取batch的过程需要生成一个队列(加入计算图中),然后一个一个入队labelsimages,然后出队组合batch。关于里面参数的解释,batch_size就是batch的大小,capacity指的是队列的容量,比如capacity设为1,而batch_szie为3,那么组成一个batch的过程中,出队的操作就会因为数据不足而频繁地被阻塞来等待入队加入数据,运行效率很低。相反,如果capacity被设置的很大,比如设为1000,而batch_size设置为3,那么入队操作在空闲时就会频繁入队,供过于求并非坏事,糟糕的是这样会占用很多内存资源,而且没有得到多少效率上的提升。还有一点值得注意,当使用tf.train.shuffle_batch时,为了使得shuffle效果好一点,出队后队列剩余元素必须得足够多,因为太少的话也没什么必要打乱了,因此tf.train.shuffle_batch函数要求提供min_after_dequeue参数来保证出队后队内元素足够多,这样队列就会等队内元素足够多时才会出队。显而易见,capacity必须大于min_after_dequeue。关于capacitymin_after_dequeue的设置,参考《TensorFlow 实战Google深度学习框架》,给出了设置capacity大小的一种比较科学的方式,min_after_dequeue根据数据集大小和batch_size综合考虑,而capacity则设置为capacity= min_after_dequeue+ 3*batch_size,在效率和资源占用之间取得平衡。组合batch_size的代码如下:
min_after_dequeue = 10000
batch_size = 100
capacity = min_after_dequeue + 3 * batch_size

image_batch, label_batch = tf.train.shuffle_batch([images, labels], 
                                                    batch_size=batch_size, 
                                                    capacity=capacity, 
                                                    min_after_dequeue=min_after_dequeue)
    1. 启动多线程训练模型。训练过程和单线程的基本一致,唯一的区别就是多了一个tf.train.start_queue_runners函数,这个函数中传入参数sess,就可以做到多线程训练,具体地细节还不是很了解,但照壶画瓢应该没问题了,有空再深挖下。
# 前向传播
y = inference(image_batch)
    
# 计算交叉熵及其平均值
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=label_batch)
cross_entropy_mean = tf.reduce_mean(cross_entropy)
    
# 计算最后的损失函数(加入正则化)
regularizer = tf.contrib.layers.l2_regularizer(REGULARAZTION_RATE)
regularaztion = regularizer(weights1) + regularizer(weights2)
loss = cross_entropy_mean + regularaztion

# 优化损失函数
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
    
# 初始化会话,并开始训练过程。
with tf.Session() as sess:
  # 初始化所有变量
  tf.global_variables_initializer().run()
   
  coord = tf.train.Coordinator()

  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  # 循环的训练神经网络。
  for i in range(TRAINING_STEPS):
    if i % 1000 == 0:
      print("After %d training step(s), loss is %g " % (i, sess.run(loss)))              
    sess.run(train_step) 

    coord.request_stop()
    coord.join(threads

参考

TensorFlow高效读取数据

Google Protocol Buffer 的使用和原理

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 204,445评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,889评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 151,047评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,760评论 1 276
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,745评论 5 367
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,638评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,011评论 3 398
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,669评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,923评论 1 299
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,655评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,740评论 1 330
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,406评论 4 320
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,995评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,961评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,197评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,023评论 2 350
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,483评论 2 342

推荐阅读更多精彩内容