6.2 Fully Connected Layers
The network diagram was drawn with NN-SVG:
http://alexlenail.me/NN-SVG/index.html
6.2.1 Tensor Implementation
In TensorFlow, implementing a fully connected layer only requires defining the weight tensor and the bias tensor and applying a matrix multiplication (tf.matmul(), or the @ operator).
For example, let the input x ∈ R^{2×784}, the weight matrix W ∈ R^{784×256}, and the bias vector b ∈ R^{256}:
import tensorflow as tf

x = tf.random.normal([2,784])                            # batch of 2 flattened 28x28 inputs
w = tf.Variable(tf.random.normal([784,256],stddev=0.1))  # weight matrix [784, 256]
b = tf.Variable(tf.zeros([256]))                         # bias vector [256]
y = x @ w + b                                            # linear transformation
y = tf.nn.relu(y)                                        # ReLU activation
y
# Output: a [2, 256] tensor; negative pre-activations have been clipped to 0 by ReLU
<tf.Tensor: id=32, shape=(2, 256), dtype=float32, numpy=
array([[7.71913290e-01, 0.00000000e+00, 0.00000000e+00, ...,
        3.09266973e+00, 0.00000000e+00, 0.00000000e+00],
       [0.00000000e+00, 1.71956635e+00, 8.52033377e-01, ...,
        4.00854206e+00, 0.00000000e+00, 6.26695919e+00]],
      dtype=float32)>
6.2.2 Layer Implementation
The same layer can be created with the tensorflow.keras.layers.Dense class:
from tensorflow.keras import layers

x = tf.random.normal([4, 28 * 28])
fc = layers.Dense(512, activation='relu')  # fully connected layer with 512 output nodes
h1 = fc(x)                                 # forward pass; output shape [4, 512]
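Once the layer has been built by the call above, its weight and bias tensors are exposed as attributes and can be inspected:
fc.kernel.shape         # TensorShape([784, 512]): the weight matrix W
fc.bias.shape           # TensorShape([512]): the bias vector b
fc.trainable_variables  # the list [kernel, bias] handed to an optimizer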
6.3 Neural Networks
By stacking fully connected layers like the one in neural network example 1 (making sure that each layer's output node count matches the next layer's input node count), networks of any depth can be built. We call such a network of neurons a neural network, as shown in the figure below.
The diagram was drawn with NN-SVG:
http://alexlenail.me/NN-SVG/index.html
A multi-layer neural network can be implemented in either of the following ways.
6.3.1 Tensor Implementation
Each layer's weight matrix and bias vector must be defined explicitly.
# Hidden layer 1 parameters: 784 -> 256
w1 = tf.Variable(tf.random.normal([784,256],stddev=0.1))
b1 = tf.Variable(tf.zeros([256]))
# Hidden layer 2 parameters: 256 -> 128
w2 = tf.Variable(tf.random.normal([256,128],stddev=0.1))
b2 = tf.Variable(tf.zeros([128]))
# Hidden layer 3 parameters: 128 -> 64
w3 = tf.Variable(tf.random.normal([128,64],stddev=0.1))
b3 = tf.Variable(tf.zeros([64]))
# Output layer parameters: 64 -> 10
w4 = tf.Variable(tf.random.normal([64,10],stddev=0.1))
b4 = tf.Variable(tf.zeros([10]))

x = tf.random.normal([4,784])    # a batch of 4 flattened inputs
with tf.GradientTape() as tape:  # record the forward pass so gradients can be taken
    y1 = tf.nn.relu(x @ w1 + b1)   # hidden layer 1
    y2 = tf.nn.relu(y1 @ w2 + b2)  # hidden layer 2
    y3 = tf.nn.relu(y2 @ w3 + b3)  # hidden layer 3
    y4 = y3 @ w4 + b4              # output layer: raw logits, no activation
6.3.2 Layer Implementation
A. Create each layer object separately and call them one by one:
fc1 = layers.Dense(256,activation='relu')  # hidden layer 1
fc2 = layers.Dense(128,activation='relu')  # hidden layer 2
fc3 = layers.Dense(64,activation='relu')   # hidden layer 3
fc4 = layers.Dense(10,activation=None)     # output layer: raw logits
x = tf.random.normal([4,28*28])
h1 = fc1(x)   # [4, 256]
h2 = fc2(h1)  # [4, 128]
h3 = fc3(h2)  # [4, 64]
h4 = fc4(h3)  # [4, 10]
B. Wrap the layers in a Sequential container, so a single call runs the whole network:
model = tf.keras.Sequential([
layers.Dense(256,activation='relu'),
layers.Dense(128,activation='relu'),
layers.Dense(64,activation='relu'),
layers.Dense(10,activation=None)
])
y = model(x)
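After the container has been called on an input it is fully built, so its parameter layout can be inspected:
model.summary()                 # per-layer output shapes and parameter counts
len(model.trainable_variables)  # 8: one kernel and one bias for each of the 4 layers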
6.3.3 Optimization Objective
The computation that carries a neural network from input to output is called forward propagation (the flow of the data tensor from the first layer through to the output layer).
The last step of forward propagation is computing the error between prediction and target; the network parameters are then updated iteratively by gradient descent.
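Putting the two steps together, one training iteration can be sketched as follows; this is a minimal example reusing x and the w1..b4 variables from 6.3.1, with an illustrative learning rate of 0.01 and an MSE loss (both are assumptions, not prescribed above):
y_true = tf.one_hot(tf.constant([1,3,5,7]), depth=10)  # example one-hot targets
with tf.GradientTape() as tape:
    # forward propagation
    h = tf.nn.relu(x @ w1 + b1)
    h = tf.nn.relu(h @ w2 + b2)
    h = tf.nn.relu(h @ w3 + b3)
    out = h @ w4 + b4
    # error computation
    loss = tf.reduce_mean(tf.square(y_true - out))
# gradient descent update: w <- w - lr * grad
variables = [w1, b1, w2, b2, w3, b3, w4, b4]
grads = tape.gradient(loss, variables)
for v, g in zip(variables, grads):
    v.assign_sub(0.01 * g)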
6.4 Activation Functions
See my earlier summary of activation functions:
https://www.jianshu.com/p/4f1a82fe723a
6.5 Output Layer Design
The last layer of a neural network performs dimension transformation and feature extraction like every hidden layer, but it also serves as the output layer, so whether to use an activation function, and which kind, must be decided according to the task.
Common output types include:
- o_i ∈ R: outputs over the whole real space, or some ordinary interval of it, e.g. predicting the trend of a function's value, or predicting age
- o_i ∈ [0,1]: outputs fall specifically in the [0,1] interval, e.g. image generation, where pixel values are usually represented in [0,1], or the probability in a binary classification problem, such as predicting whether a coin lands heads or tails
- o_i ∈ [0,1], Σ_i o_i = 1: outputs fall in [0,1] and sum to 1, most commonly in multi-class classification, e.g. MNIST handwritten digit recognition, where the probabilities over the 10 classes sum to 1
- o_i ∈ [-1,1]: outputs fall in the [-1,1] interval
6.5.1 Ordinary Real Space
Predicting a sine curve, age, or stock trends all output values in the whole real space or a continuous part of it, so the output layer can omit the activation function.
6.5.2 The [0,1] Interval
Image generation, binary classification, and similar tasks all have outputs in [0,1].
In machine learning, image pixel values are generally normalized to [0,1]. If the raw output-layer values were used directly, they would range over the whole real space; to map them into the valid [0,1] range, a suitable activation function is added after the output layer, and the Sigmoid function does exactly this.
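A minimal sketch of squashing raw outputs into [0,1] with Sigmoid (the tensor z is a made-up example):
z = tf.constant([-2., 0., 1., 5.])  # raw outputs spanning the real space
tf.sigmoid(z)  # -> [0.119, 0.5, 0.731, 0.993], all within [0,1]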
6.5.3 The [0,1] Interval with Sum 1
Each output value o_i ∈ [0,1], and all output values sum to 1: Σ_i o_i = 1. This setting is most common in multi-class classification.
It is implemented with the Softmax function: softmax(z)_i = e^{z_i} / Σ_j e^{z_j}.
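A short sketch with made-up logits; the trailing comment notes the numerically safer pattern of passing raw logits straight to the loss:
z = tf.constant([2., 1., 0.1])
tf.nn.softmax(z)  # -> [0.659, 0.242, 0.099]: non-negative and summing to 1
# in practice it is more stable to keep the output layer linear and let the loss
# apply Softmax internally, e.g.:
# tf.keras.losses.categorical_crossentropy(y_onehot, z, from_logits=True)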
6.5.4 The [-1,1] Interval
If the output values should fall in [-1,1], simply use the tanh activation function.
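A one-line check with made-up inputs:
z = tf.constant([-5., 0., 2.])
tf.tanh(z)  # -> [-0.9999, 0., 0.964], all within [-1,1]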
6.6 Error Calculation
Common error (loss) functions include mean squared error (MSE), cross entropy, KL divergence, and the hinge loss.
MSE is used mainly for regression problems, cross entropy mainly for classification.
6.6.1 均方差
均方差误差(Mean Squared Error, MSE)函数把输出向量和真实向量映射到笛卡尔坐标系的两个点上,通过计算这两个点的欧式距离的平方来衡量两个向量之间的差距:
from tensorflow import keras

o = tf.random.normal([2,10])               # network outputs for 2 samples
y_onehot = tf.constant([1,3])              # class labels
y_onehot = tf.one_hot(y_onehot, depth=10)  # one-hot encode to shape [2,10]
loss = keras.losses.MSE(y_onehot, o)       # functional interface: per-sample MSE, shape [2]
criteon = keras.losses.MeanSquaredError()  # class interface
criteon(y_onehot, o)                       # scalar MSE averaged over the batch
6.6.2 Cross Entropy
In information theory, entropy is also called information entropy (Shannon entropy); the larger the entropy, the greater the uncertainty and the larger the amount of information.
The entropy of a distribution P is defined as:

H(P) = -Σ_i P(i) log₂ P(i)

e.g. in a 4-class problem, if a sample's true label is class 4, its one-hot encoding is [0,0,0,1]: the classification is fully determined, so the uncertainty is 0 and the entropy is 0.
If the predicted probability distribution is [0.1,0.1,0.1,0.7], its entropy is about 1.356 bits.
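The 1.356 figure is easy to verify in code:
p = tf.constant([0.1, 0.1, 0.1, 0.7])
-tf.reduce_sum(p * tf.math.log(p) / tf.math.log(2.))  # entropy in bits -> ~1.3568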
Building on entropy, the cross entropy between distributions p and q is defined as:

H(p‖q) = -Σ_i p(i) log₂ q(i)

By a simple transformation, the cross entropy decomposes into the entropy of p plus the KL divergence between p and q:

H(p‖q) = H(p) + D_KL(p‖q)

where the KL divergence is defined as:

D_KL(p‖q) = Σ_i p(i) log( p(i) / q(i) )

Note that neither the cross entropy nor the KL divergence is symmetric:

H(p‖q) ≠ H(q‖p),  D_KL(p‖q) ≠ D_KL(q‖p)

Cross entropy is a good measure of the difference between two distributions. In particular, when the label distribution p in a classification problem is one-hot encoded, H(p) = 0, so H(p‖q) = D_KL(p‖q): minimizing the cross entropy is exactly minimizing the KL divergence between prediction and label, and the loss reduces to -log₂ q(i) at the true class i.
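A small sketch of this in code (the logits are made up; from_logits=True tells the loss to apply Softmax internally, which is numerically more stable):
z = tf.constant([[2., 1., 0.1, -1.]])      # raw logits for one sample
y = tf.one_hot(tf.constant([0]), depth=4)  # the true class is 0
keras.losses.categorical_crossentropy(y, z, from_logits=True)
# equals -log(softmax(z)[0]); note Keras uses the natural log rather than log₂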
6.7 Types of Neural Networks
The fully connected layer is the most basic network layer type.
Drawback: the parameter count becomes very large when the input feature length is large.
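For a sense of scale (back-of-the-envelope arithmetic): a single Dense layer mapping a flattened 28×28 image (784 features) to 256 nodes already holds 784 × 256 + 256 = 200,960 parameters, and a flattened 224×224×3 image (150,528 features) would push the same layer to over 38 million.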
6.7.1 Convolutional Neural Networks
For image classification: AlexNet, VGG, GoogLeNet, ResNet, DenseNet, etc.
For object detection: RCNN, Fast RCNN, Faster RCNN, Mask RCNN, etc.
6.7.2 Recurrent Neural Networks
Convolutional neural networks lack a memory mechanism and the ability to handle variable-length sequential signals, so they are not well suited to natural language tasks; recurrent neural networks have proven effective for such sequence data.
RNN, LSTM, Seq2Seq, GNMT, GRU, bidirectional RNN
6.7.3 Attention (Mechanism) Networks
The attention mechanism overcomes RNN shortcomings such as unstable training and difficulty of parallelization.
Transformer, GPT, BERT, GPT-2
6.7.4 Graph Neural Networks
For data with irregular spatial topology, such as social networks, communication networks, and protein molecular structures, CNNs and RNNs do not perform well.
GCN, GAT, EdgeConv, DeepGCN, etc.
Reference: https://github.com/dragen1860/Deep-Learning-with-TensorFlow-book