用KNN解决非线性回归问题

一直以为KNN只是分类算法,只能在分类上用,昨天突然想起用KNN试试做回归,最近有一批数据,通过4个特征来预测1个值,原来用线性回归和神经网络尝试过,准确率只能到40%左右。用KNN结合网格搜索和交叉验证,正确率达到了79%,没错,KNN解决回归问题也很赞。

什么是KNN

KNN就是K近邻算法(k-NearestNeighbor),百度百科是这么写的:K最近邻(kNN,k-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。

KNN怎么做回归

要预测的点的值通过求与它距离最近的K个点的值的平均值得到,这里的“距离最近”可以是欧氏距离,也可以是其他距离,具体的效果依数据而定,思路一样。如下图,x轴是一个特征,y是该特征得到的值,红色点是已知点,要预测第一个点的位置,则计算离它最近的三个点(黄色线框里的三个红点)的平均值,得出第一个绿色点,依次类推,就得到了绿色的线,可以看出,这样预测的值明显比直线准。


K=3的拟合.png

上述例子是基于一个特征的,如果是一个特征向量怎么办?其实一样,距离的衡量通过求两个特征向量的欧氏距离或者皮尔逊系数或者余弦距离就行。

parametric learner和non-parametric learner

parametric learner就是像线性回归一样,给一个y=mx+b的函数,找合适的m和b参数。non-parametric learner则没有猜测的函数,KNN做回归就是一个non-parametric learner,最终它也没有得到一个方程,只是能很好地作出预测。parametric learner的优点在于不用存储原始数据,训练慢但是查询快,缺点是不能轻易更新模型;non-parametric learner的优点在于更改模型容易,训练快但是查询慢,缺点是需要存储所有点,消耗空间。

KNN解决非线性回归问题

问题解决流程按照上篇的机器学习项目流程与模型评估验证完成。

数据准备

数据如下,一个csv表格,黄色是4个特征值,绿色是1和待预测值。


数据.png

加载数据

import numpy as np
import pandas as pd
from sklearn.model_selection import ShuffleSplit

# %matplotlib inline  将图表输出内嵌到jupyter notebook中,如果不用jupyter可以忽略这句

data = pd.read_csv('fdata.csv')
data = data[data['Friction']>16]   # 这句和下句点作用是去除异常数据
data = data[data['Friction']<30]  
friction = data['Friction'] 
features = data.drop('Friction', axis = 1) #特征向量为原数据集剔除待预测列
print("数据共有{}条,每条含有{}个特征.".format(*features.shape))

输出为数据共有1635条,每条含有27个特征.

数据分割与重排

这一步使用train_test_split将数据随机拆分为80%的训练集与20%的测试集。如果不设定random_state,划分结果不那么随机,指定了random_state后,划分结果是随机的(具体工作原理没有细查,有朋友知道的感谢指教)。

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(features, friction, test_size=0.2, random_state=50)

# Success
print("训练集与测试集拆分成功,训练集有{}条,测试集有{}条。".format(X_train.shape[0], X_test.shape[0]))

输出为训练集与测试集拆分成功,训练集有1304条,测试集有327条。

定义衡量标准

这一步给模型表现定义一个衡量标准,也就是最后通过什么指标来看模型训练的表现,如果在训练中用了交叉验证来找模型的最优参数,在交叉验证里就可以调用这个衡量标准做评分。上篇的流程图中写过,分类问题的衡量标准有accuracy、precision、recall、F_bate分数,回归问题的衡量标准有平均绝对误差,均方误差,R2分数和可释方差分数。这里用R2分数。

from sklearn.metrics import r2_score
def performance_metric(y_true, y_predict):
    """ Calculates and returns the performance score between 
        true and predicted values based on the metric chosen. """
    
    score = r2_score(y_true, y_predict)
   
    return score

训练模型

重头戏到了,这个部分训练模型,我用了网格搜索和交叉验证从{3,4,5,6,7,8,9,10}里寻找R2分数最高的K作为最优参数,然后用这个K进行预测。我用了shuffleSplit和K-fold两种交叉验证。

  • shuffleSplit
from sklearn.metrics import make_scorer
from sklearn.neighbors import KNeighborsRegressor
from sklearn.model_selection import GridSearchCV

def fit_model_shuffle(X, y):
    """ Performs grid search over the 'max_depth' parameter for a 
        decision tree regressor trained on the input data [X, y]. """
    
    # Create cross-validation sets from the training data
    cv_sets = ShuffleSplit(n_splits = 10, test_size = 0.20, random_state = 0)

    # Create a KNN regressor object
    regressor = KNeighborsRegressor()
    # Create a dictionary for the parameter 'n_neighbors' with a range from 3 to 10
    params = {'n_neighbors':range(3,10)}

    # Transform 'performance_metric' into a scoring function using 'make_scorer' 
    scoring_fnc = make_scorer(performance_metric)

    # Create the grid search object
    grid = GridSearchCV(regressor, param_grid=params,scoring=scoring_fnc,cv=cv_sets)

    # Fit the grid search object to the data to compute the optimal model
    grid = grid.fit(X, y)

    # Return the optimal model after fitting the data
    return grid.best_estimator_
  • k-fold
from sklearn.model_selection import KFold
def fit_model_k_fold(X, y):
    """ Performs grid search over the 'max_depth' parameter for a 
        decision tree regressor trained on the input data [X, y]. """
    
    # Create cross-validation sets from the training data
    # cv_sets = ShuffleSplit(n_splits = 10, test_size = 0.20, random_state = 0)
    k_fold = KFold(n_splits=10)
    
    # TODO: Create a decision tree regressor object
    regressor = KNeighborsRegressor()

    # TODO: Create a dictionary for the parameter 'max_depth' with a range from 1 to 10
    params = {'n_neighbors':range(3,10)}

    # TODO: Transform 'performance_metric' into a scoring function using 'make_scorer' 
    scoring_fnc = make_scorer(performance_metric)

    # TODO: Create the grid search object
    grid = GridSearchCV(regressor, param_grid=params,scoring=scoring_fnc,cv=k_fold)

    # Fit the grid search object to the data to compute the optimal model
    grid = grid.fit(X, y)

    # Return the optimal model after fitting the data
    return grid.best_estimator_

网格搜索返回的是一个Gridsearch的object,想用它的哪个属性就用哪个属性,API都写的很清楚,我这里返回最好的一个estimator。
用下面代码查看找到的最优K:

# Fit the training data to the model using grid search
reg = fit_model_k_fold(X_train, y_train)

print "Parameter 'n_neighbors' is {} for the optimal model.".format(reg.get_params()['n_neighbors'])

用shuffleSplit找到的最优k是8,用k-fold找到的最优k是9。

预测

# Show predictions
for i, friction in enumerate(reg.predict(features)):
    print(friction)

预测表现

用上面定义的衡量标准来衡量预测表现

print(performance_metric(y_test, reg.predict(X_test)))

到这里,整个模型就完成了。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 194,457评论 5 459
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 81,837评论 2 371
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 141,696评论 0 319
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 52,183评论 1 263
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 61,057评论 4 355
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 46,105评论 1 272
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 36,520评论 3 381
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 35,211评论 0 253
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 39,482评论 1 290
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 34,574评论 2 309
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 36,353评论 1 326
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 32,213评论 3 312
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 37,576评论 3 298
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 28,897评论 0 17
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 30,174评论 1 250
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 41,489评论 2 341
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 40,683评论 2 335

推荐阅读更多精彩内容