这里关于BP算法就不详细说明了,直接上代码:
import numpy as np
import matplotlib.pyplot as plt
n = 0 # 迭代次数
lr = 0.11 # 学习速率
# 输入数据分别是:偏置值、x1、x2、x1^2、x1*x2、x2^2
X = np.array([[1, 0, 0, 0, 0, 0], [1, 0, 1, 0, 0, 1], [1, 1, 0, 1, 0, 0], [1, 1, 1, 1, 1, 1]])
# 标签
Y = np.array([-1, 1, 1, -1])
# 权重初始化,取值范围为-1~1
W = (np.random.random(X.shape[1]) - 0.5) * 2
print('初始化权值:', W)
def get_show(): # 绘图函数
# 正样本
x1 = [0, 1]
y1 = [1, 0]
# 负样本
x2 = [0, 1]
y2 = [0, 1]
# 生成x刻度
x_data = np.linspace(-1, 2)
plt.figure()
# 画出两条分界线
plt.plot(x_data, get_line(x_data, 1), 'r')
plt.plot(x_data, get_line(x_data, 2), 'r')
# 原始数据
plt.plot(x1, y1, 'bo')
plt.plot(x2, y2, 'ro')
plt.show()
# 获得分界线
def get_line(x, root):
a = W[5]
b = W[2] + x * W[4]
c = W[0] + x * W[1] + x * x * W[3]
# 两条不同的分界线
if root == 1:
return (-b + np.sqrt(b * b - 4 * a * c)) / (2 * a)
if root == 2:
return (-b - np.sqrt(b * b - 4 * a * c)) / (2 * a)
# 更新权值函数
def get_update():
global X, Y, W, lr, n
n += 1
# 新输出:X与W的转置相乘,得到的结果再由阶跃函数处理,得到新输出
new_output = np.dot(X, W.T)
new_W = W + lr * ((Y - new_output.T).dot(X))/int(X.shape[0])
W = new_W
if __name__ == '__main__':
for _ in range(100):
get_update()
get_show()
last_output = np.dot(X, W.T)
print('最后的逼近值:',last_output)
结果变化图:
Figure1.png
Figure2.png
Figure3.png
Figure4.png
Figure5.png
Figure6.png
Figure7.png
Figure8.png