tensorflow处理变长序列

深度学习中处理变长序列往往都是使用循环神经网络RNN(Recurrent Neural Network),其中RNN有很多变种,包括朴素RNN、LSTM、GRU等。目前的优化算法基本都是基于批量随机梯度下降的,那么批量变长序列在训练中就要格外注意短序列中的填充符。

一般声明序列的真实长度有两种方式,例如批量序列张量维度:[batch_size, time_steps, dims]

  • 真实序列的长度值,sequence_length=[batch_size,],例如tf.nn.dynamic_rnntf.nn.static_rnn中的sequence_length参数,
  • 序列掩码,mask=[batch_size, time_steps],例如tf.keras.layer.RNN类的call方法中的mask参数

当然无论是序列长度值,还是序列掩码包含的信息都是一样的,二者是可以互相转化的。一般处理变长序列的策略是当time_step>sequence_length时,

  • 状态值state一般采用复制前一步的state
  • 输出值output可以直接输出同维度零向量,也可以复制前一步的output

单向RNN

以下代码是tf.python.ops.rnn._rnn_step中关于序列长度条件判断的核心代码,它是单步RNN操作函数,依赖于sequence_length的具体数值,主要步骤包括:

  1. 如果time_step >= max_sequence_length,输出(zero_output, old_state)

    flat_state = nest.flatten(state)
    flat_zero_output = nest.flatten(zero_output)
    empty_update = lambda: flat_zero_output + flat_state
    final_output_and_state = control_flow_ops.cond(
        # if t >= max_seq_len: copy all state through, output zeros
        time >= max_sequence_length, empty_update,
        # otherwise calculation is required: copy some or all of it through
        _maybe_copy_some_through)
    
  2. 计算rnn_cell,得到新的new_output和new_state

    new_output, new_state = call_cell()
    
  3. 如果time_step < min_sequence_length,输出(new_output, new_state)

    flat_new_state = nest.flatten(new_state)
    flat_new_output = nest.flatten(new_output)
    
    control_flow_ops.cond(
        # if t < min_seq_len: calculate and return everything
        time < min_sequence_length, lambda: flat_new_output + flat_new_state,
        # else copy some of it through
        lambda: _copy_some_through(flat_new_output, flat_new_state))
    
  4. 如果time_step >= sequence_length,输出(zero_output, old_state),否则输出(new_output, new_state)

    copy_cond = time >= sequence_length
    
    def _copy_one_through(output, new_output):
        # TensorArray and scalar get passed through.
        if isinstance(output, tensor_array_ops.TensorArray):
            return new_output
        if output.shape.rank == 0:
            return new_output
        # Otherwise propagate the old or the new value.
        with ops.colocate_with(new_output):
            return array_ops.where(copy_cond, output, new_output)
    
    flat_new_output = [
        _copy_one_through(zero_output, new_output)
        for zero_output, new_output in zip(flat_zero_output, flat_new_output)]
    flat_new_state = [
        _copy_one_through(state, new_state)
        for state, new_state in zip(flat_state, flat_new_state)]
    

tf.python.keras.backend.rnn中依赖于序列mask实现了同样的策略。

双向RNN

如果说在单向RNN中,sequence_length的使用还比较显式,那么在双向RNN中就存在一些隐式使用。首先明确一下双向RNN的计算步骤(代码参见tf.nn.static_bidirectional_rnn):

  1. 利用cell_fw计算前向RNN,得到output_fwoutput_state_fw

    output_fw, output_state_fw = static_rnn(
        cell_fw,
        inputs,
        initial_state_fw,
        dtype,
        sequence_length,
        scope=fw_scope)
    
  2. 在time_step维度reverse输入序列inputs

    reversed_inputs = _reverse_seq(inputs, sequence_length)
    
  3. 利用cell_bw计算后向RNN,得到output_bw (tmp)和state_bw。相对于原始序列,output_bwoutput_state_bw都是逆序的。

    tmp, output_state_bw = static_rnn(
        cell_bw,
        reversed_inputs,
        initial_state_bw,
        dtype,
        sequence_length,
        scope=bw_scope)
    
  4. 在time_step维度reverse output_bw

    output_bw = _reverse_seq(tmp, sequence_length)
    
  5. 在time_step维度拼接output_fwoutput_bw,得到outputs

    flat_outputs = tuple(
        array_ops.concat([fw, bw], 1)
        for fw, bw in zip(flat_output_fw, flat_output_bw))
    

与单向RNN一样,sequence_length同样用在RNN的计算上,同时在reverse序列时,也需要输入sequence_length参数,利用tf.reverse_sequence只逆序序列的有效位置,否则在拼接前向和后向输出序列是发生错位,也就是出现了有效位置与填充位置的拼接。

优点

  • 填充符位置采用复制策略,提升计算性能
  • 针对真实序列长度的处理满足理论要求,不然会出现实现与理论不相符的状态
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。