自然语言处理中的大部分模型输入都是变长的离散序列, 在tensorflow处理变长序列中介绍了tensorflow中如何循环神经网络如何处理变长序列. 目前深度学习的优化算法基本都是基于批量随机梯度下降的, 那么构成批量的变长序列在tensorflow中也有相应的优化策略.
问题
在机器翻译中, 一般会设置最大的序列长度, 如max_length=512, 那么极端情况下, 随机shuffle的训练数据中可能存在序列长度为[1,1,1,512]的一个批量. 虽然在tensorflow处理变长序列一文中处理变长序列的技巧可以保证理论上的正确训练, 但是计算效率是非常低的. 一个[1,1,1,512]批量的计算所需要的时间约等于一个[512, 512, 512, 512]批量.
方案
如何在不影响精度的情况下, 提升计算效率?tensorflow中提供了基于序列长度分桶策略的解决方案, 由tf.data.experimental.bucket_by_sequence_length实现.
@tf_export("data.experimental.bucket_by_sequence_length")
def bucket_by_sequence_length(element_length_func,
bucket_boundaries,
bucket_batch_sizes,
padded_shapes=None,
padding_values=None,
pad_to_bucket_boundary=False,
no_padding=False,
drop_remainder=False)
其中element_length_func用来计算序列真实长度的函数, bucket_boundaries指每一个桶中存储序列长度的边界值, bucket_batch_sizes指每一个桶内的批量大小.
分桶的具体的步骤如下:
- 计算序列的真实长度, 由参数
element_length_func来计算, 如tensor2tensor中计算tf.example的长度代码
def example_length(example):
length = 0
# Length of the example is the maximum length of the feature lengths
for _, v in sorted(six.iteritems(example)):
# For images the sequence length is the size of the spatial dimensions.
feature_length = (tf.shape(v)[0] if len(v.get_shape()) < 3 else
tf.shape(v)[0] * tf.shape(v)[1])
length = tf.maximum(length, feature_length)
return length
- 按照每个桶包含序列的长度范围分配当前序列的桶id, 由
element_to_bucket_id实现
def element_to_bucket_id(*args):
"""Return int64 id of the length bucket for this element."""
seq_length = element_length_func(*args)
boundaries = list(bucket_boundaries)
buckets_min = [np.iinfo(np.int32).min] + boundaries
buckets_max = boundaries + [np.iinfo(np.int32).max]
conditions_c = math_ops.logical_and(
math_ops.less_equal(buckets_min, seq_length),
math_ops.less(seq_length, buckets_max))
bucket_id = math_ops.reduce_min(array_ops.where(conditions_c))
return bucket_id
- 当每个桶中累积到预设批量大小数目的序列后, 按照当前批量中最长序列填充所有序列作为模型的输入。
应用
基于tensor2tensor开发翻译模型时, batch_size不是代表批量序列的数目, 而表示单个批量中包含token的总数量. Problem类的inputs_fn方法中, tensor2tensor.util.data_reader._batching_scheme 依据批量大小(tokens), 最大序列长度(max_length), 最小桶序列长度(min_length_bucket), 桶的步长(length_bucket_step), 计算出每个桶中序列长度的边界bucket_boundaries, 并且每个桶中的批量大小bucket_batch_sizes由总共的token量除以序列长度边界得到.