在使用 TensorFlow 的时候,很多地方都会遇到 protobuf (.proto) 文件,比如配置 TensorFlow Detection API 的过程中需要执行如下语句:
$ protoc object_detection/protos/*.proto --python_out=.
表示把文件夹 protos 下的所有 .proto 文件转化对应的 .py 文件(配置过程见 TensorFlow 训练自己的目标检测器)。 之后,再借助这些转化来的 .py 文件就可以读取特定格式的 .pbtxt, .config 等文件了。比如,可以使用 string_int_label_map.proto 转化来的 string_int_label_map_pb2.py 文件来读取目标检测与实例分割 的类名与类标号配置文件(这种文件以 .pbtxt 作为后缀名,假设要检测 person 和 car 等类目标):
item {
id: 1
name: 'person'
}
item {
id: 2
name: 'car'
}
...
Protocol Buffer(protobuf) 是谷歌开源的一种数据存储语言,每一个文件以 .proto 为后缀,它不依赖特定的语言与平台,扩展性极强。TensorFlow 内部的数据存储(比如 .pb 和 .ckpt 等)基本都使用 protobuf 格式。
下面以一个简单的例子来说明 protobuf 文件的读取。
1.protobuf 数据结构定义
假设我们的目的是读取如下文件(命名为 students.pbtxt,使用文本编辑器编辑):
student_info {
name: 'Zhang San';
age: 20;
sex: 0;
}
student_info {
name: 'Li Si';
age: 25;
sex: 0;
}
student_info {
name: 'Wang Wu';
age: 18;
sex: 1;
}
显然,这是一份简单的结构化数据,但若使用传统的数据读取方式且要快速方便的解析出其中的学生信息,却不容易。此时,如果使用 protobuf 则相当便捷。
首先,用文本编辑器编辑一个 .proto 文件(命名为:student_info.proto):
syntax = "proto3";
package proto_test;
message Student {
string name = 1;
int32 age = 2;
int32 sex = 3;
}
message StudentInfo {
repeated Student student_info = 1;
}
其中的 syntax
指定使用的 protobuf 版本,可以填写 proto2 (protobuf 2)和 proto3 (protobuf 3),这里使用的是后者。下面的 package
指定 student_info.proto 文件所在的文件夹名字。接下来,以关键字 message
开头定义了一个简单的数据结构 Student,里面包括三个可选字段 name,age 和 sex,后面的数字 1,2,3 指定这三个字段在编码序列化后的二进制数据中的顺序,因此在同一个 message
内部,它们是不允许重复的。如果是 protobuf 2 的版本,需要在可选字段前面加上 optional 关键字(关键字包括 optional、required、repeated)。定义好 Student 结构之后, students.pbtxt 文件的内容基本是重复这个结构,因此还需要定义一个新的 message
StudentInfo,它包含一个可重复的 (repeated) 字段 student_info。至此,数据解析格式定义完了,接下来要将它转化为 Python 格式语言,执行:
$ protoc proto_test/*.proto --python_out=.
会自动生成一个 student_info_pb2.py 文件,前几行如下:
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: proto_test/student_info.proto
import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='proto_test/student_info.proto',
package='proto_test',
syntax='proto3',
...
基于这个 student_info_pb2.py 就可以方便的解析 students.pbtxt 中的内容了。
2.读取 .pbtxt 文件
有了转化来的 student_info_pb2.py 文件,解析 students.pbtxt 就轻而易举了,代码如下(命名为:read_student_info.py):
# -*- coding: utf-8 -*-
"""
Created on Sat Sep 8 19:18:37 2018
@author: shirhe-lyh
"""
import tensorflow as tf
from google.protobuf import text_format
import student_info_pb2
def load_pbtxt_file(path):
"""Read .pbtxt file.
Args:
path: Path to StringIntLabelMap proto text file (.pbtxt file).
Returns:
A StringIntLabelMapProto.
Raises:
ValueError: If path is not exist.
"""
if not tf.gfile.Exists(path):
raise ValueError('`path` is not exist.')
with tf.gfile.GFile(path, 'r') as fid:
pbtxt_string = fid.read()
pbtxt = student_info_pb2.StudentInfo()
try:
text_format.Merge(pbtxt_string, pbtxt)
except text_format.ParseError:
pbtxt.ParseFromString(pbtxt_string)
return pbtxt
def get_student_info_dict(path):
"""Reads a .pbtxt file and returns a dictionary.
Args:
path: Path to StringIntLabelMap proto text file.
Returns:
A dictionary mapping class names to indices.
"""
pbtxt = load_pbtxt_file(path)
result_dict = {}
for student_info in pbtxt.student_info:
result_dict[student_info.name] = [student_info.age, student_info.sex]
return result_dict
首先, 使用 tf.gfile.GFile
将指定文件读入成字符串,然后定义一个 student_info_pb2.StudentInfo() 结构,这样使用 google.protobuf 的 text_format.Merge
直接将 student_info 一个一个解析成 Student 结构,此时读取 name、age 和 sex 字段只需要通过 .
属性即可。
如开头所言,我们来读取 students.pbtxt 文件:
student_info {
name: 'Zhang San';
age: 20;
sex: 0;
}
student_info {
name: 'Li Si';
age: 25;
sex: 0;
}
student_info {
name: 'Wang Wu';
age: 18;
sex: 1;
}
读取代码如下:
# -*- coding: utf-8 -*-
"""
Created on Sat Sep 8 19:26:54 2018
@author: shirhe-lyh
"""
import read_student_info
if __name__ == '__main__':
student_info_path = './students.pbtxt'
students_dict = read_student_info.get_student_info_dict(student_info_path)
print(students_dict)
执行后输出:
{'Li Si': [25, 0], 'Wang Wu': [18, 1], 'Zhang San': [20, 0]}