keras中SimpleRNN和日常RNN所不同的就是其隐藏层到输出层之间是没有权重的,即最后时刻隐藏层的输出即为最终的输出。下面以一个例子来说明运算过程
keras自带的SimpleRNN进行计算
- 导入所需要的库
import tensorflow as tf
import keras
from keras import Sequential
from keras.layers import SimpleRNN
- 搭建神经网络
model=Sequential()
model.add(SimpleRNN(32,input_shape=(None,20),activation='relu'))
model.summary()
keras中SimpleRNN 默认的激活函数为tanh,这里为了方便对比,采用relu激活函数。keras中输入的形式一般为[batch_size,timestep,num],在上述代码中,20代表的是num 。time_step是未知的None,一般batch_size在网络输入的时候不直接输进去吧(好像是在训练的时候自己可以改)。此处,我们不训练,直接用生成的初始化参数作为最后的参数,因为只是验证计算过程,只需要对比网络最后的结果和我自己算出的结果是否一致就行。
- 参数的提取
在生成网络后,初始化的参数就已经确定了,导出初始化参数作为自己验证的权重。
U=model.get_weights()[0] #输入层和循环层之间的权重,维度为(20*32)
W=model.get_weights()[1] #循环层与循环层之间的权重,维度为(32*32)
bias=model.get_weights()[2] #隐藏层的偏置项,32个
上述代码中,U表示输入和循环层之间的权重(因为在搭建网络时固定了输入的num为20,循环层t时刻的维数为32),其维度为(2032),W表示循环层之间的权重,因为t-1时刻和t时刻的都为32,所以权重矩阵维度为(3232)。这里其实的计算没有考虑batch_size,如果考虑了batch_size的话,其实都是矩阵的计算。(这里想不明白没关系,后面验证时候具体的计算流程,一看就懂了)。最后的bias表示的是某个时刻t的32个偏置项。
- 测试数据生成
test=np.random.randint(1,20,(10,2,20))
为方便验证,生成的test的维度为[10,2,20],10指的是batch_size,2表示的是时间步time_step , 20指的是输入的维度。
-直接用网络进行计算
model_predict=model.predict(test)
model_predict.shape
输出的维度为[10,32],model_predict的具体数值在后续的对比中进行展示。
自己计算SimpleRNN的输出
- 自己定义一个relu激活函数
def activation_relu(x):
for i in range(x.shape[0]):
for j in range(x.shape[1]):
if x[i,j]<0:
x[i,j]=0
return x
- 矩阵乘法进行验证
x_t1=test[:,0,:]
x_t2=test[:,1,:]
# 第一个循环层的数据
s_t1=activation_relu(np.dot(x_t1,U)+bias)
# 第二个循环层的数据
s_t2=activation_relu(np.dot(x_t2,U)+np.dot(s_t1,W)+bias)
在测试数据生成的时候,我们已经说明了测试数据的时间步为2,第一个时间步x_t1的维度为[10,20],s_t1的维度为, s_t2的维度为
。此时s_t2就为最后的输出。
对比两个结果
从上图发现SimpleRNN和我所理解的计算过程是一样的。为更准确的验证,做一个循环,代码如下:
def compare_value(a,b):
if len(a.shape)!=len(b.shape):
print("两者维度不同,无法比较")
else:
a_shape=np.array(a.shape) #原始的a.shape的格式为tuple
b_shape=np.array(b.shape)
result=a_shape-b_shape
if sum(result)!=0:
print("维度不同,无法比较")
else:
#为防止保留位数的差异,统一保留小数点后6位
a=np.round(a,3)
b=np.round(b,3)
c=a-b
#由于电脑计算位数等等一些方式,可能保留的小数位不同,会在小数点后6,7位有点点误差
#这个不是运算方式,而是电脑保留位数导致的
if c.sum()<1e-5:
print("两者相同")
else:
print("两者数据不同")
compare_value(s_t2,model_predict)
两者相同
以一张图来描述SimpleRNN的计算过程