之前采用的穷举法来求解最优的w值,但是当数据量大维度多的时候,会导致训练性能急剧下降
而梯度下降算法的原理:迭代找到目标函数的最小值,或者收敛到最小值
梯度的概念:
1.在单变量的函数中,梯度其实就是函数的微分,代表着函数在某个给定点的切线的斜率
2.在多变量函数中,梯度是一个向量,向量有方向,梯度的方向就指出了函数在给定点的上升最快的方向,所以循着梯度的反方向就能找到函数的最小值
3.α在梯度下降算法中被称作为学习率或者步长,意味着我们可以通过α来控制每一步走的距离
本次使用的模型为简单线性回归模型:y = w*x
1.定义线性回归模型
#定义模型
def forward(x):
return w*x
#初始随机猜测一个权值w=1.0
w=-1.0
2.计算损失函数
#定义总成本计算函数---平均平方误差
def cost(xs,ys):
cost = 0
for x,y in zip(xs,ys):
y_pred = forward(x)
cost += (y_pred-y) ** 2
return cost/len(xs)
3.数据集
#定义 DataSet-实际 w==2
x_date = [1.0,2.0,3.0]
y_date = [2.0,4.0,6.0]
4.定义梯度计算公式
#定义平均梯度计算公式
def gradient(xs,ys):
grad = 0
for x, y in zip(xs,ys):
grad += 2*x*(x*w-y)
return grad / len(xs)
5.使用梯度下降算法进行训练
for epoch in range(1,1000):
cost_val = cost(x_date,y_date)
#每一次训练用的梯度都是该数据集的平均梯度,当数据集给定时,平均梯度也就随之确定
w = w - 0.01 * gradient(x_date, y_date) #学习率定义为0.01
#打印输出是第几次训练,训练权值和训练的损失函数
w_list.append(w)
mse_list.append(cost_val)
print("epoch=", epoch, "w=", w, "loss=", cost_val)
6.完整代码及训练结果
# -*- codeing = utf-8 -*-
import matplotlib.pyplot as plt
#采用梯度下降算法来优化穷举法---计算所有样本对的梯度来取平均梯度进行w权值的更新
#优点: 每一次的训练和上一次的训练没有关系,都是用数据集的平均梯度进行计算
# 所以可以采取并行计算的方式进行训练,大大节俭时间成本
#缺点: 时间复杂度低,但结果没有随机梯度训练准确
#训练数据集---实际w权值=2.0
x_date = [1.0,2.0,3.0]
y_date = [2.0,4.0,6.0]
#初始随机猜测一个权值w=1.0
w=-1.0
#定义线性回归模型
def forward(x):
return x*w
#定义总成本计算函数---平均平方误差
def cost(xs,ys):
cost = 0
for x,y in zip(xs,ys):
y_pred = forward(x)
cost += (y_pred-y) ** 2
return cost/len(xs)
w_list = []
mse_list = []
#定义平均梯度计算公式
def gradient(xs,ys):
grad = 0
for x, y in zip(xs,ys):
grad += 2*x*(x*w-y)
return grad / len(xs)
print("predict before training:", "w=", w, forward(4))
#通过梯度下降算法来训练模型---设定为1-100次:左闭右开区间
#每一次训练更新的是 w 的值-而采用的平均梯度没有发生改变
#进行100次训练-每一次采用的梯度都是所有样本对的平均梯度
for epoch in range(1,1000):
cost_val = cost(x_date,y_date)
#每一次训练用的梯度都是该数据集的平均梯度,当数据集给定时,平均梯度也就随之确定
w = w - 0.01 * gradient(x_date, y_date) #学习率定义为0.01
#打印输出是第几次训练,训练权值和训练的损失函数
w_list.append(w)
mse_list.append(cost_val)
print("epoch=", epoch, "w=", w, "loss=", cost_val)
print("predict completed:", "w=", w, forward(4))
plt.plot(w_list,mse_list)
plt.ylabel("mean Square error")
plt.xlabel("w")
plt.show()