Scikit-Learn&TensorFlow机器学习实用指南(四):分类【上】

作者:LeonG
本文参考自:《Hands-On Machine Learning with Scikit-Learn & TensorFlow 机器学习实用指南》,感谢中文AI社区ApacheCN提供翻译。

本文全部代码和数据集保存在我的github-----LeonG的github

机器学习的监督学习任务中最常见的任务是回归(用于预测某个值)和分类(预测某个类别)。

分类任务
回归任务

在这一章,我们重点来学习分类任务

Scikit-Learn包是一个集成了大部分机器学习经典算法的python库,它可以帮助我们快速上手机器学习的众多应用。

注意:本项目主要使用Jupyter Notebook进行开发,Jupyter Notebook是一个简便易上手的python工作环境,源代码请使用Jupyter Notebook打开。

1.MINIST数据集

1.1下载数据集

首先给大家介绍一下机器学习的经典数据集--MNIST,它是拥有70000张手写数字图片的数据集合。这个数据集经典程度相当于机器学习领域的“Hello World”。

from sklearn.datasets import fetch_mldata
#先将mnist数据集下载到项目目录(./mldata/mnist-original.mat)
#下载地址:https://github.com/amplab/datascience-sp14/raw/master/lab7/mldata/mnist-original.mat
mnist = fetch_mldata('MNIST Original',data_home='./')
mnist
image

介绍一下这三个标签的意义:

  • DESCR键描述数据集

  • target键存放一个标签数组

  • data键存放一个图片数组,数组的一行表示一个样例,也就是一张图片

然后看看这个数据集的结构,主要是查看维度

x,y = mnist["data"],mnist["target"]
x.shape
y.shape
image

现在我们总结一下MNIST数据集的特点:

MNIST有70000张图片,众所周知,图片是由像素点构成的。

784意味着每张图片拥有784个像素点,这是因为每张图片都是28*28像素的,且每个像素点都介于0~255之间。

target标签表示对应位置的图片是什么数字。

我相信每个接触到MNIST数据集的人都迫不及待地想看一下图片是什么样子了,说实话我也是(搓手手),让我们随便展示一张图片。

import matplotlib
from matplotlib import pyplot as plt
some_digit = x[36000] #第36001张图片
img = some_digit.reshape(28,28)  #还原成28*28的结构
plt.imshow(img,cmap = matplotlib.cm.binary, interpolation="nearest")
plt.show()
y[36000] #顺便看一下标签
image

哟,不错,整挺好,可以看到图片上的数字和标签正好对应起来了,有了这样一个利器,我们就能慢慢揭开机器学习分类任务的神秘面纱了。

部分mnist图片

1.1划分训练集

少侠稍安勿躁,在进行下一步之前,我们总是要先将数据集分为训练集和测试集两个部分,最好还能打乱一下顺序,因为有的算法对顺序的敏感度很高。

  1. 把前60000条化为训练集,后10000条归为测试集
  2. 打乱训练集的顺序,测试集的顺序mnist已经帮我们打乱好了
#分割数据集
x_train,x_test,y_train,y_test = x[:60000],x[60000:],y[:60000],y[60000:]
#打乱训练集
import numpy as np
shuffie_index = np.random.permutation(60000)
x_train,y_train = x_train[shuffie_index],y_train[shuffie_index]

接下来正式开始学习分类任务。

2.训练一个二分类器

鲁迅曾经说过:学走先学爬(鲁迅:我没说过),多分类属于比较高级的分类任务,我们先做个简单的,比如鉴定一张图片是否是数字5

这个“数字 5 检测器”就是一个二分类器,能够识别两个类,“是5”和“非5”。让我们为这个分类任务创建新的标签,也就是将所有数字标签转为是和否两种标签(bool类型)

y_train_5 = (y_train == 5) #训练集的标签
y_test_5 = (y_test == 5) #测试集的标签
image

接下来就是选择一个分类器了,给大家介绍一下本次的主角SGD(随机梯度下降分类器)。其实随机梯度下降分类器并不是一个独立的算法,而是一系列利用随机梯度下降求解参数的算法集合。

这个分类器默认用的算法是SVM(线性支持向量机),既然是线性的,自然就适合二分类方法,毕竟非黑即白嘛。而且该分类器的好处是处理大量数据时非常高效,让我们先创建一个SGDClassifier分类器然后训练一遍。

from sklearn.linear_model import SGDClassifier
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(x_train,y_train_5) #训练分类器

这里的代码很常用所以解释一下,首先调用SGDClassifier函数新建了一个分类器,随机种子设为42,算法中需要的参数都是按照某种方式随机生成的,然后调用fit训练函数,输入训练数据,通过梯度下降法训练参数,最后得到的sgd_clf是训练好的分类器

image

现在随便测试一张图片:

sgd_clf.predict(X[36000])
array([ True], dtype=bool)

分类器猜测这个数字代表 5( True )。看起来在这个例子当中,它猜对了。现在让我们评估 这个模型的性能。

3.评估性能

首先我们应该使用交叉验证法,具体原理参考上一章中提到的K这交叉验证法。在这里我们只管调用函数 cross_val_score就好了。

K = 3意味着将把训练集分成3折,使用其他折进行训练,剩下的折用于测试精确度。

from sklearn.model_selection import cross_val_score
#cv代表折数,scoring为计算方式,这里计算精准度
cross_val_score(sgd_clf,x_train,y_train_5,cv = 3,scoring="accuracy")
array([ 0.9502 , 0.96565, 0.96495]

看起来准确率非常高,但是并不是这样的,想象一下,假如你全猜“非5”,准确率依然有90%!这是因为图片5的数量太少了,只占总数据的十分之一。所以准确率并不能很好的表达模型的精度。

那怎么样才能判断模型的精准度呢,现在我们打开新世界的大门。

3.1混淆矩阵

对于分类器来说,混淆矩阵是个不错的判断精度工具,什么是混淆矩阵呢?

大致的意思就是类别A被判断为B的次数,我们做一个简单的表格就能看懂了,其中加粗的对角线就是正确分类的数量。

image

首先,我们使用K折交叉验证得出一系列的预测,也就是模型得出的预测值。使用函数cross_val_predict,注意,有别于上面提到的 cross_val_score,那个是分数,这个是预测值,然后使用函数confusion_matrix得到混淆矩阵,是不是很简单?

from sklearn.model_selection import cross_val_predict
#计算预测值
y_train_pred = cross_val_predict(sgd_clf,x_train,y_train_5,cv = 3)

from sklearn.metrics import confusion_matrix
#计算混淆矩阵
confusion_matrix(y_train_5,y_train_pred)
array([[53887,   692],
       [ 1279,  4142]], dtype=int64)

来,总结一下这个矩阵中四个值代表的意义:(正例:是5,反例:非5)

53887:代号TN,全名true negative,真反例,就是反例分对了的意思。

1279:代号FN,全名false negative,假反例,就是错分为反例的意思。

692:代号FP,全名fasle positive,假正例,就是错分为正例的意思。

4142:代号TP,全名true positive,真正例,就是正例分对了的意思。

好了,我不是有意要绕晕你的,如果你没看懂,直接看图啦

image

形象生动(并没有)

基于以上的划分方式,我们要引出准确率、召回率这对欢喜冤家。

3.2准确率与召回率

准确率:precision = {TP\over TP+FP}

召回率:recall = {TP\over TP+FN}

准确率的意义很简单,当模型给出一个预测时,该预测的可靠程度。

召回率是指所有正例中,被正确预测的个数。

我们测试一下刚才建立的分类模型,准确率的计算函数是precision_score,召回率是recall_score

from  sklearn.metrics import precision_score,recall_score
precision_score(y_train_5,y_train_pred)
recall_score(y_train_5,y_train_pred)
0.8568473314025652
0.7640656705404907

也就是说,该模型有85.6%的几率预测正确,但是所有的图片5,只有76%被正确识别了。

而准确率和召回率是一对难以调和的冤家,因为当你提高准确率,也就是更加严格的判断是否为5,可能很多的图片5被误杀,导致召回率直线下降。

通俗点说:

准确率高:宁缺毋滥(判断正例的数据少,但是准度高)
召回率高:宁可错杀一千,不可放过一个(判断正例的数据多,但是准度低)

其中意味,需要你自己细细体会,这里将准确率和召回率呈现的趋势展示给大家看:

image

提高阈值,准确率上升,召回率下降,反之准确率下降,召回率提升,那么最好的位置就是两者都在0.8左右。

两者无法兼得,所以提出新的判断标准,F1值,近似为两者的平均值。但是计算方式考虑到了较小的值有更大的权重。

F1 = {2\over {1\over preccsion}+{1\over recall}}

看不懂公式不重要,只要知道F1代表的是两者的平均值就好,咱们直接调函数f1_score

from sklearn.metrics import f1_score
 f1_score(y_train_5, y_train_pred)
0.78468208092485547

有时候你需要很高的准确率,比如判断水果是否变质,我们只需要输出的水果都是好水果就行了。

有时候你需要很高的召回率,比如判断背包内是否携带易燃易爆物品,即使是错误识别了某些背包,也不能让危险的背包通过。

有时候你需要兼顾二者,就着力于提高F1值。

那么如何调整准确率和召回率呢?最直接的办法就是调整阈值了,简单地说就是调整判断图片是否为5的标准。

一般来说,分类器会得出一张图片是正例(是5)的得分值,我们只要调整得分值的判断标准,就相当于调整了了阈值。

调用decision_function方法,这个方法返回了每个样例的得分,然后基于这个得分,你可以任意的调整阈值。

计算所有的样例得分值方法很简单,把交叉验证函数中的method改为decision_function就可以了:

y_scores = cross_val_predict(sgd_clf, x_train, y_train_5, cv=3,method="decision_function")

下面我们单独对第36001张图片进行判断,阈值设为0,这张图我们在前面已经看到了,是图片5

y_score = sgd_clf.decision_function([x[36000]]) #第36001张图片的得分
y_score
threshold = 0 #阈值设为0
y_some_digit_pred = (y_score > threshold) #分类
y_some_digit_pred
array([1066.83525987])
array([ True])

当阈值等于0的时候,这张图片被判断为正例(是5)

阈值对准确率和召回率的影响在前面那张曲线图已经体现出来了, 你可以翻上去看看。

3.3ROC曲线

全称受试者工作特征曲线(ROC)是二分类器中常用的工具,这个曲线能反应预测为正例的数据中正确数据的比例,计算公式也很简单:ROC = {P(TP) \over P(FP)}

TP是真正例,也就是把图片5预测成5的数量

FP是假正例,也就是把不是图片5预测成5的数量

函数P就是对应的数量除以样本总数,也就是占比的意思。

再让大家复习一下这张图:

image

从下图的ROC曲线来看,在曲线最左上角的位置,大概是坐标(0.05,0.9)的位置。在左上角位置的时候,被错分为正例的数量最少,效果也是最好的。

image

调用函数roc_auc_score可以很快的算出roc值。

from sklearn.metrics import roc_auc_score
roc_auc_score(y_train_5, y_scores)
0.9659714668088117

3.4小结

最后为性能评估做个总结,分类的任务远不止二分类,还有多分类方法,但是分类器的性能评估基本上都是用这些工具。

检验自己的分类器性能一般可以使用交叉验证来评估你的分类器,然后选择满足你需要的准确率/召回率位置或者是最佳ROC位置,找到合适的阈值点。至于选择哪种性能评估工具,取决于你的分类器更注重哪方面的性能。

4.本章总结

关于机器学习的分类任务上半部分就讲完了,这一章我们主要学习了如何训练一个二分类器,如何评估一个分类器的性能,下一章我们继续学习分类任务,不过主要是侧重于多分类任务~


欢迎来我的博客留言讨论,我的博客主页:LeonG的博客

我的知乎机器学习专栏:LeonG与机器学习

本文参考自:《Hands-On Machine Learning with Scikit-Learn & TensorFlow机器学习实用指南》,感谢中文AI社区ApacheCN提供翻译。

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 214,444评论 6 496
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,421评论 3 389
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 160,036评论 0 349
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,363评论 1 288
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,460评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,502评论 1 292
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,511评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,280评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,736评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,014评论 2 328
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,190评论 1 342
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,848评论 5 338
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,531评论 3 322
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,159评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,411评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,067评论 2 365
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,078评论 2 352

推荐阅读更多精彩内容