4行伪代码讲清楚RNN的原理

RNN,循环神经网络,可以处理上下文关系的神经网络。

伪代码来自深度学习经典书《Python深度学习》

state_t = 0
for input_t in input_sequence:
  output_t = f(input_t, state_t)
  state_t = output_t

这个故事是这样的:
一个输入的序列,在序列模型下(比如CNN或者DNN串联那种),犹如通过了一个流水线的工厂,最后获得输出;
而RNN会将序列分解后进行遍历,之后当前元素的状态(权重与偏置)取决于上一个元素的状态(权重与偏置)。

更加具体的伪代码为:

state_t = 0
for input_t in input_sequence:
  output_t = activation(dot(W, input_t) + dot(U, state_t) + b)
  state_t = output_t

也就是,比如输入ABC,CNN或者DNN就会把这三个元素所转换为的张量一层一层的传播,而RNN就会把它分为A,B和C。设定一个初始状态,在训练A的时候采用初始状态,在训练B的时候考虑A的状态等等。

再上个图:


image.png
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。