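Linear regression is a statistical method that uses regression analysis from mathematical statistics to determine the quantitative interdependence between two or more variables. It assumes that the features satisfy a linear relationship, trains a model on the given training data, and then uses that model to make predictions.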
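Consider a table of "years of service vs. salary" data. We assume it follows the linear relationship y = a + bx, where x is the number of years of service and y is the salary.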
Years of service (x): 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
Salary (y): 103100, 104900, 106800, 108700, 110400, 112300, 114200, 116100, 117800, 119700, 121600
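Define the loss function J(a, b) as the sum of squared errors between the predictions a + bx and the observed salaries, take its partial derivatives with respect to a and b, and use them to obtain the gradient descent update rules.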
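The derivation below is reconstructed from the functions cost, calc_diff_a and calc_diff_b in the code that follows. It assumes the loss is the plain (unscaled) sum of squared errors over the 11 samples, with x_i = i, and writes η for the learning rate (ratio in the code):

```latex
\begin{aligned}
J(a, b) &= \sum_{i=0}^{10} \left(a + b x_i - y_i\right)^2 \\
\frac{\partial J}{\partial a} &= \sum_{i=0}^{10} 2\left(a + b x_i - y_i\right) \\
\frac{\partial J}{\partial b} &= \sum_{i=0}^{10} 2 x_i \left(a + b x_i - y_i\right) \\
a &\leftarrow a - \eta \frac{\partial J}{\partial a}, \qquad
b \leftarrow b - \eta \frac{\partial J}{\partial b}
\end{aligned}
```

Both partial derivatives are evaluated at the current (a, b) before either parameter is updated, which is what tmp1 and tmp2 do inside the loop.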
Example code:
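import matplotlib.pyplot as plt
import numpy as np

# Salary data; the index of each entry is the number of years of service x (0 to 10)
y = (103100, 104900, 106800, 108700, 110400, 112300,
     114200, 116100, 117800, 119700, 121600)


def calc_diff_a(a, b):
    """Partial derivative of J(a, b) with respect to a: sum of 2(a + b*x - y)."""
    total = 0
    for x in range(len(y)):
        total += 2 * a + 2 * b * x - 2 * y[x]
    return total


def calc_diff_b(a, b):
    """Partial derivative of J(a, b) with respect to b: sum of 2x(a + b*x - y)."""
    total = 0
    for x in range(len(y)):
        total += x * (2 * a + 2 * b * x - 2 * y[x])
    return total


def cost(a, b):
    """Loss J(a, b): sum of squared errors (a + b*x - y)^2."""
    total = 0
    for x in range(len(y)):
        total += (a + b * x - y[x]) ** 2
    return total


if __name__ == "__main__":
    num1 = 100000   # initial guess for a (intercept)
    num2 = 1        # initial guess for b (slope)
    ratio = 0.0001  # learning rate
    itercnt = 0
    while itercnt < 50000:
        # Evaluate both gradients at the current point before updating either parameter
        tmp1 = calc_diff_a(num1, num2)
        tmp2 = calc_diff_b(num1, num2)
        num1 = num1 - ratio * tmp1
        num2 = num2 - ratio * tmp2
        itercnt = itercnt + 1
        # Uncomment to trace the gradients and the loss:
        # print(tmp1, tmp2, cost(num1, num2))
    print("a =", num1)
    print("b =", num2)

    # Plot the data points and the fitted line
    listx = np.linspace(0, 10, 11)
    listy = num1 + num2 * listx
    plt.figure()
    plt.plot(listx, y, '*')
    plt.plot(listx, listy)
    plt.show()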
Running the program prints:
a = 103086.36363635205
b = 1848.181818183475
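As a sketch (not the original author's code), the same loop can be vectorized with NumPy and its result compared with the closed-form least-squares fit, which gives an independent check that gradient descent has converged. The function names gradient_descent and least_squares are hypothetical, chosen for illustration; the default arguments simply mirror num1, num2, ratio and the 50000 iterations used above.

```python
import numpy as np

# Same data as above: x = years of service, y = salary
x = np.arange(11, dtype=float)
y = np.array([103100, 104900, 106800, 108700, 110400, 112300,
              114200, 116100, 117800, 119700, 121600], dtype=float)


# Hypothetical helper name, for illustration only
def gradient_descent(x, y, a=100000.0, b=1.0, lr=1e-4, iters=50000):
    """Batch gradient descent on J(a, b) = sum((a + b*x - y)**2)."""
    for _ in range(iters):
        residual = a + b * x - y             # vector of a + b*x_i - y_i
        grad_a = 2.0 * residual.sum()        # dJ/da
        grad_b = 2.0 * (x * residual).sum()  # dJ/db
        # Both gradients were computed at the old (a, b), so this is a simultaneous update
        a -= lr * grad_a
        b -= lr * grad_b
    return a, b


# Hypothetical helper name, for illustration only
def least_squares(x, y):
    """Closed-form least-squares intercept and slope, used as a reference."""
    x_mean, y_mean = x.mean(), y.mean()
    b = ((x - x_mean) * (y - y_mean)).sum() / ((x - x_mean) ** 2).sum()
    a = y_mean - b * x_mean
    return a, b


if __name__ == "__main__":
    a_gd, b_gd = gradient_descent(x, y)
    a_ls, b_ls = least_squares(x, y)
    print("gradient descent:", a_gd, b_gd)
    print("closed form:     ", a_ls, b_ls)
```

For this data the closed form gives b = 203300 / 110 ≈ 1848.18 and a ≈ 103086.36, matching the values printed by the loop above to several decimal places. Any remaining tiny difference comes from the finite number of iterations and floating-point rounding.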