First, read the official API guide:
https://www.tensorflow.org/api_guides/python/reading_data
It describes three ways to get data into a TensorFlow program:
- placeholder: feed data in via feed_dict; you have to write your own data iterator and shuffling logic, and control epochs yourself (a minimal sketch follows below).
- Reading from files: TFRecord, CSV, etc.
- Preloaded data: everything is held in memory; only practical for small datasets.

Reading from files as TFRecord is the recommended route: TF loads and shuffles the files automatically, and you can control epochs as well.
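For contrast, the placeholder approach looks roughly like this (a minimal sketch; `next_batch` is a hypothetical iterator you would have to write yourself, and the model behind `train_op` is elided):

```python
import tensorflow as tf

x = tf.placeholder(tf.int32, shape=[None, max_seq_len], name='char_ids')
y = tf.placeholder(tf.int32, shape=[None], name='label_id')
train_op = ...  # build your model and optimizer on x and y

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(num_epochs):  # you control epochs yourself
        # ... as well as shuffling and batching, via your own iterator
        for x_batch, y_batch in next_batch(data, batch_size):
            sess.run(train_op, feed_dict={x: x_batch, y: y_batch})
```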
Converting training data to TFRecord
```python
import tensorflow as tf

def write_tfrecord(writer, char_ids, label_id):
    """
    :param writer: tf.python_io.TFRecordWriter
    :param char_ids: list of int, padded to a fixed length
    :param label_id: int
    """
    example = tf.train.Example(features=tf.train.Features(feature={
        'char_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=char_ids)),
        # Int64List expects a list, so wrap the scalar label
        'label_id': tf.train.Feature(int64_list=tf.train.Int64List(value=[label_id]))
    }))
    writer.write(example.SerializeToString())

writer = tf.python_io.TFRecordWriter(output_file_name)
for line in open(your_files):
    char_ids = ...
    label_id = ...
    write_tfrecord(writer, char_ids, label_id)
writer.close()
```
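One caveat: the reader below parses char_ids with FixedLenFeature([max_seq_len], ...), so every char_ids list must be padded (or truncated) to max_seq_len before it is written. A tiny helper along these lines would do (a sketch; pad_id=0 is an assumption):

```python
def pad_to_fixed_len(ids, max_seq_len, pad_id=0):
    """Pad with pad_id, or truncate, so that len(result) == max_seq_len."""
    return (ids + [pad_id] * max_seq_len)[:max_seq_len]
```

Call it on char_ids before handing them to write_tfrecord.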
Reading and decoding TFRecords
- Decoding only
```python
def test_read_tfrecords():
    filename = "./data/train.tfrecords"
    for serialized_example in tf.python_io.tf_record_iterator(filename):
        example = tf.train.Example()
        example.ParseFromString(serialized_example)
        # traverse the Example proto to get the raw values back
        x = example.features.feature['char_ids'].int64_list.value
        y = example.features.feature['label_id'].int64_list.value
        # do something
```
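Since tf_record_iterator runs in plain Python, with no session or graph involved, this is a convenient way to sanity-check the written records before wiring up the training pipeline.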
Wiring TFRecords, batching, and training together
```python
def read_and_decode(filename_queue):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'char_ids': tf.FixedLenFeature([max_seq_len], tf.int64),
            'label_id': tf.FixedLenFeature([1], tf.int64)
        })
    char_ids = tf.cast(features['char_ids'], tf.int32)
    # label_id = tf.one_hot(tf.cast(features['label_id'], tf.int32)[0], len(data_set.label_dict))
    label_id = tf.cast(features['label_id'], tf.int32)[0]
    return char_ids, label_id
```
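If you would rather not pad at write time, a variable-length variant is possible with tf.VarLenFeature, which yields a SparseTensor that you densify afterwards (a sketch; not the setup used in the rest of this post, which writes fixed-length records):

```python
def read_and_decode_varlen(filename_queue):
    # Variant for records whose char_ids were written unpadded (an assumption)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'char_ids': tf.VarLenFeature(tf.int64),
            'label_id': tf.FixedLenFeature([1], tf.int64)
        })
    # Densify the sparse tensor, padding with 0
    char_ids = tf.sparse_tensor_to_dense(features['char_ids'], default_value=0)
    char_ids = tf.cast(char_ids, tf.int32)
    label_id = tf.cast(features['label_id'], tf.int32)[0]
    return char_ids, label_id
```

Note that variable-length examples cannot go through tf.train.shuffle_batch directly; you would need tf.train.batch(..., dynamic_pad=True) instead, which is why padding at write time keeps things simpler.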
```python
def inputs(filename_list, batch_size, num_epochs=None):
    filename_queue = tf.train.string_input_producer(filename_list, num_epochs=num_epochs)
    char_ids, label_id = read_and_decode(filename_queue)
    batch_char_ids, batch_label_id = tf.train.shuffle_batch(
        [char_ids, label_id], batch_size=batch_size, num_threads=12,
        capacity=1000 + 3 * batch_size,
        min_after_dequeue=1000)
    return batch_char_ids, batch_label_id
```
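Here min_after_dequeue is the buffer that examples are shuffled within: larger means better mixing, at the cost of memory and startup time, and capacity must be strictly larger than it. The TF docs suggest capacity = min_after_dequeue + (num_threads + a small safety margin) * batch_size, which is where 1000 + 3 * batch_size comes from.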
```python
# main session
# Build the whole graph first; with num_epochs set, the input pipeline
# raises OutOfRangeError once the epoch limit is reached.
x_batch, y_batch = inputs(data_files, batch_size, num_epochs=50)
train_op = ...  # build the model and optimizer on x_batch / y_batch

# num_epochs is backed by a local variable, so initialize locals as well
init_op = tf.group(tf.global_variables_initializer(),
                   tf.local_variables_initializer())
sess = tf.Session()
sess.run(init_op)

# Start input enqueue threads.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
    while not coord.should_stop():
        # Run training steps or whatever
        sess.run(train_op)
except tf.errors.OutOfRangeError:
    print('Done training -- epoch limit reached')
finally:
    # When done, ask the threads to stop.
    coord.request_stop()

# Wait for threads to finish.
coord.join(threads)
sess.close()
```
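As an aside, newer TensorFlow versions (1.4+) offer tf.data, which replaces the queue-runner machinery above. A rough equivalent of inputs() might look like this (a sketch under the same max_seq_len assumption, not a drop-in from this post):

```python
def inputs_dataset(filename_list, batch_size, num_epochs=None):
    def _parse(serialized_example):
        features = tf.parse_single_example(
            serialized_example,
            features={
                'char_ids': tf.FixedLenFeature([max_seq_len], tf.int64),
                'label_id': tf.FixedLenFeature([1], tf.int64)
            })
        char_ids = tf.cast(features['char_ids'], tf.int32)
        label_id = tf.cast(features['label_id'], tf.int32)[0]
        return char_ids, label_id

    dataset = (tf.data.TFRecordDataset(filename_list)
               .map(_parse)
               .shuffle(buffer_size=1000)
               .repeat(num_epochs)  # None repeats forever, like string_input_producer
               .batch(batch_size))
    return dataset.make_one_shot_iterator().get_next()
```

The returned tensors plug into the training loop the same way, except no Coordinator or queue-runner threads are needed.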