学习曲线:sklearn.model_selection.learning_curve

第一:学习曲线

    学习曲线是一种用来判断训练模型的一种方法,它会自动 把训练样本的数量按照预定的规则逐渐增加,然后画出不同训练样本数量时的模型准确度。

    我们可以把Jtrain(theta) and Jtest(theta)作为纵坐标,画出与训练集数据集m的大小关系,这就是学习曲线。通过学习曲线,可以直观地观察到模型的准确性和训练数据大小的关系。 我们可以比较直观的了解到我们的模型处于一个什么样的状态,如:过拟合(overfitting)或欠拟合(underfitting)

    如果数据集的大小为m,则通过下面的流程即可画出学习曲线:

1.把数据集分成训练数据集和交叉验证数据集(可以看作测试机)

2.取训练数据及的20%作为训练样本,训练出模型参数。

3.使用交叉验证数据集来计算训练出来的模型的准确性。

4.以续联数据及的准确性和交叉验证的准确性为纵坐标,训练数据集个数作为横坐标,在坐标轴上画出上述步骤计算出来的模型准确性。

5.训练数据集增加10%,调到步骤2,继续执行,知道训练数据集大小为100%。

第二:比较

参考链接:https://blog.csdn.net/u012328159/article/details/79255433

  1. learning_curve():这个函数主要是用来判断(可视化)模型是否过拟合的,关于过拟合,就不多说了,具体可以看以前的博客:模型选择和改进
(X,y) = datasets.load_digits(return_X_y=True)

train_sizes,train_score,test_score = learning_curve(RandomForestClassifier(),X,y,train_sizes=[0.1,0.2,0.4,0.6,0.8,1],cv=10,scoring='accuracy')

train_error =  1- np.mean(train_score,axis=1)

test_error = 1- np.mean(test_score,axis=1)

plt.plot(train_sizes,train_error,'o-',color = 'r',label = 'training')

plt.plot(train_sizes,test_error,'o-',color = 'g',label = 'testing')

plt.legend(loc='best')

plt.xlabel('traing examples')

plt.ylabel('error')

plt.show()
  1. validation_curve():这个函数主要是用来查看在参数不同的取值下模型的性能
(X,y) = datasets.load_digits(return_X_y=True)

# print(X[:2,:])

param_range = [10,20,40,80,160,250]

train_score,test_score = validation_curve(RandomForestClassifier(),X,y,param_name='n_estimators',param_range=param_range,cv=10,scoring='accuracy')

train_score =  np.mean(train_score,axis=1)

test_score = np.mean(test_score,axis=1)

plt.plot(param_range,train_score,'o-',color = 'r',label = 'training')

plt.plot(param_range,test_score,'o-',color = 'g',label = 'testing')

plt.legend(loc='best')

plt.xlabel('number of tree')

plt.ylabel('accuracy')

plt.show()

第三:参数解释

from sklearn.model_selection import learning_curve

参数解释:参考:https://blog.csdn.net/gracejpw/article/details/102370364

image

X : array-like, shape (n_samples, n_features) Training vector, where n_samples is the number of samples and n_features is the number of features.

是一个m*n的矩阵,m:样品数量,n:特征数量

y : array-like, shape (n_samples) or (n_samples, n_features), optional Target relative to X for classification or regression; None for unsupervised learning.

是一个m*1的矩阵,m:样品数量,相对于X的目标进行分类或回归

groups : array-like, with shape (n_samples,), optional Group labels for the samples used while splitting the dataset into train/test set.

将数据集拆分为训练/测试集时使用的样本的标签分组。**[可选]**

**train_sizes **: array-like, shape (n_ticks,), dtype float or int Relative or absolute numbers of training examples that will be used to generate the learning curve. If the dtype is float, it is regarded as a fraction of the maximum size of the training set (that is determined by the selected validation method), i.e. it has to be within (0, 1]. Otherwise it is interpreted as absolute sizes of the training sets. Note that for classification the number of samples usually have to be big enough to contain at least one sample from each class. (default: np.linspace(0.1, 1.0, 5))

指定训练样品数量的变化规则。比如:np.linspace(0.1, 1.0, 5)表示把训练样品数量从0.1-1分成5等分,生成[0.1, 0.325,0.55,0.75,1]的序列,从序列中取出训练样品数量百分比,逐个计算在当前训练样本数量情况下训练出来的模型准确性。

**cv **: int, cross-validation generator or an iterable, optional Determines the cross-validation splitting strategy.

交叉验证拆分策略,可以使用sklearn.model_selection.ShuffleSplit

    None,要使用默认的三折交叉验证(v0.22版本中将改为五折)

    整数,用于指定(分层)KFold中的折叠数,

    CV splitter

    可迭代的集(训练,测试)拆分为索引数组。

    对于整数/无输入,如果估计器是分类器,y是二进制或多类,则使用StratifiedKFold。在所有其他情况下,都使用KFold。

scoring:字符串,可调用或无,可选,默认:None,模型性能的评价指标,如(‘accuracy’、‘f1’、”mean_squared_error”等)

exploit_incremental_learning:布尔值,可选,默认值:False

如果估算器支持增量学习,此参数将用于加快拟合不同训练集大小的速度。

n_jobs:int或None,可选(默认=None)

要并行运行的作业数。None表示1。 -1表示使用所有处理器。

pre_dispatch:整数或字符串,可选

并行执行的预调度作业数(默认为全部)。该选项可以减少分配的内存。该字符串可以是“ 2 * n_jobs”之类的表达式。

shuffle:布尔值,可选

是否在基于``train_sizes’'为前缀之前对训练数据进行洗牌。

random_state:int,RandomState实例或无,可选(默认=None)

如果为int,则random_state是随机数生成器使用的种子;否则为false。如果是RandomState实例,则random_state是随机数生成器;如果为None,则随机数生成器是np.random使用的RandomState实例。在shuffle为True时使用。

error_score:‘raise’ | ‘raise-deprecating’ 或数字

如果估算器拟合中出现错误,则分配给分数的值。如果设置为“ raise”,则会引发错误。如果设置为“raise-deprecating”,则会在出现错误之前打印FutureWarning。如果给出数值,则引发FitFailedWarning。此参数不会影响重新安装步骤,这将始终引发错误。默认值为“不赞成使用”,但从0.22版开始,它将更改为np.nan。

返回值:

image

第四:使用


from sklearn.linear_model import LogisticRegression

from sklearn.model_selection import ShuffleSplit

from sklearn.model_selection import train_test_split

from sklearn.model_selection import learning_curve

from sklearn.preprocessing import PolynomialFeatures

from sklearn.pipeline import Pipeline

from sklearn.datasets import load_breast_cancer

import matplotlib.pyplot as plt

import numpy as np

import time

cancer = load_breast_cancer()

X      = cancer.data

y      = cancer.target

def polynomial_model(degree = 1, **kargs):

    polynomial_features = PolynomialFeatures(degree = degree, include_bias = False)

    logistic_regression = LogisticRegression(**kargs)

    pipeline            = Pipeline([("pf", polynomial_features),

                                    ("lr", logistic_regression)])

    return pipeline

def plot_learning_curve(plt, estimator, title, X, y, ylim = None, cv = None, n_jobs = 1, train_size = np.linspace(0.1,1,5)):

    plt.title(title)

    if ylim is not None:

        plt.ylim(*ylim)

    plt.xlabel("Training examples")

    plt.ylabel("Score")

    train_sizes, train_scores, test_scores = learning_curve(estimator, X, y, cv = cv, n_jobs = n_jobs, train_sizes = train_size)

    print("train_sizes:\n",train_sizes, "\ntrain_scores:\n",train_scores, "\ntest_scores:\n",test_scores)

    train_scores_mean = np.mean(train_scores, axis = 1)

    test_scores_mean  = np.mean(test_scores, axis = 1)

    train_scores_std  = np.std(train_scores, axis = 1)

    test_scores_std  = np.std(test_scores,  axis = 1)

    plt.grid()

    plt.fill_between(train_sizes, train_scores_mean - train_scores_std, train_scores_mean + train_scores_std, alpha = 0.1,color = "r")

    plt.fill_between(train_sizes, test_scores_mean - test_scores_std, test_scores_mean + test_scores_std, alpha = 0.1,color = "g")

    plt.plot(train_sizes, train_scores_mean, "o-", color = "r", label = "Training score")

    plt.plot(train_sizes, test_scores_mean,"o-", color = "g", label = "Cross-validation score")

    plt.legend(loc = "best")

    return plt

cv = ShuffleSplit(n_splits = 10, test_size = 0.2, random_state = 0)

title = "Learning Curves (degreee={0}, penalty={1})"

degrees = [1,2]

penalty = ["l1", "l2"]

start = time.clock()

plt.figure(figsize = (12,4), dpi = 144)

j = 0

for p in penalty:

    for i in range(len(degrees)):

        plt.subplot(len(penalty), len(degrees), j + 1)

        plot_learning_curve(plt, polynomial_model(degree = degrees[i], penalty = p), title.format(degrees[i], p), X, y, ylim = (0.8,1.01), cv = cv)

        j += 1

plt.tight_layout()

plt.savefig("1.png")

learning_curve的返回值结果如下:


learning_curve返回结果展示

共选择了5组数据且选择了10折交叉验证,所以,train_sizes 为5个元素的narray,train_scores 和 test_scores为5*10的矩阵,每一行,为一次数据的每一折的结果,对其求平均值,作为最终的准确性。
第五:性能评估

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 204,189评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,577评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 150,857评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,703评论 1 276
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,705评论 5 366
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,620评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,995评论 3 396
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,656评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,898评论 1 298
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,639评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,720评论 1 330
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,395评论 4 319
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,982评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,953评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,195评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 44,907评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,472评论 2 342