参考李航的统计学习
感知机学习算法
Python实现感知机代码
import numpy as np
w = np.array([0,0])
b = 0
#更新w和b
def update(item):
global w,b
w = w + item[1]*np.array(item[0])
b = b + item[1]
def check(date_set):
global w,b
flag = False
#遍历训练集
for item in date_set:
x=np.array(item[0])
jieguo = item[1] * (np.dot(w,x)+b)
#判断是否误分类
if jieguo <= 0:
flag = True
update(item)
print(w,b)
return flag
if __name__ == "__main__":
#训练集
date_set = [[(3,3),1],[(4,3),1],[(1,1),-1]]
#不断测试模型,当模型不存在误分类,停止训练
for i in range(500):
if not check(date_set):break
Python代码实现对偶形式
import numpy as np
#训练集
date_set = np.array([[[3,3],1],[[4,3],1],[[1,1],-1]])
#设置alpha和b的初始值
a = np.zeros((len(date_set),1),np.float)
b = 0.0
x = np.empty((len(date_set),2),np.float)
for i in range(len(date_set)):
x[i] = date_set[i,0]
y = np.array(date_set[:,1])
Gm = None
#求Gram矩阵
def gram():
gm = np.empty((len(date_set),len(date_set)),np.int)
for i in range(len(date_set)):
for j in range(len(date_set)):
gm[i][j] = np.dot(date_set[i][0],date_set[j][0])
return gm
#更新alpha和b的值
def update(i):
global a,b
a[i] += 1
b += y[i]
#测试模型
def check():
global a,b
flag = False
for i in range(len(date_set)):
jieguo = 0
for j in range(len(date_set)):
jieguo += a[j]*y[j]*Gm[i,j]
jieguo = (jieguo + b)*y[i]
print(jieguo)
if jieguo <= 0:
flag = True
update(i)
if not flag:
w=0.0
for i in range(len(date_set)):
w += a[i]*y[i]*x[i]
print(w,b)
return False
return True
if __name__ == "__main__":
Gm = gram()
for i in range(1000):
if not check():break