scikit-learn 中的超参数优化(网格搜索)

网格遍历搜索

超参数调优,最基本的就是网格搜索的方法。最常用的是网格遍历的方法,其遍历给定的参数组合,来比较模型在各参数组合下的表现。

scikit-learn.model_selection中提供了GridSerachCV,来实现网格搜索。构造网格搜索,需要指定以下元素:

  • 一个分类或回归的学习器
  • 一个需要搜索的参数空间
  • 一个交叉验证模式
  • 一个评分方法

在GridSearchCV中对应的参数为

  • estimator:要实现的学习器对象。学习器要有scoring评分方法。
  • param_grid:列表或字典表示的参数字典网格。
  • cv:指定交叉验证策略。
    1、不传参数,默认使用5层交叉验证。
    2、传入整数参数时,则为指定层数的交叉验证。
    3、或传入交叉验证方法对象。
    4、也可传入一个可迭代的由数据索引数组表示的(train,test)分割。
    不传参或整数参数的情况下,对分类器使用分层交叉验证,其他情况使用K折叠交叉验证。
  • scoring:指定评分方法
    默认使用,学习器自身的scoring方法。可以是指定字符串,自定义的可调评分方法。同时可以将多种评分方法以列表的形式传入。详细内容参见官方文档

此外还有一些参数可用于性能调优,如n_jobs、pre_dispatch可以指定并行运算的作业数量。

GridSearchCV实现了 fit 和 score 方法,学习器自身的方法,也可以通过GridSeachCV直接调用。训练后可通过best_estimator、best_score、best_params等属性获取最佳参数等信息。

网格随机搜索

除了遍历搜索,还可以进行随机搜索,sklearn.model_selection中的RandomizedSearchCV就是这样的方法。

其参数和使用方法与GridSeachCV类似,只是不对所有可能的超参数组合进行遍历。而是根据参数n_iter,生成指定个数的超参数组合,并在其上进行比较。

适用于超参数非常多,不支持穷举遍历的情况,可以结合启发式搜索方法进行参数调优。

scikit-learn中的超参数组合方法

除了遍历和随机的网格搜索方法,scikit-learn也附带提供了遍历和随机的超参数组合方法。只返回超参数组合的集合。

ParameterGrid方法生所有超参数的组合。并可以通过其他函数读取。
ParameterSampler则进行随机组合。不遍历所有超参数组合,而是采用随机采样的方式组合超参数,并生成指定n_iter个组合供迭代使用。

这两种方法只生成超参数的组合。在不使用GridSeachCV或RandomizedSearchCV,而使用其他调优方法时可以应用。

其他优化方法

scikit-learn中还提供了一些模型特定的交叉验证方法,这些方法可以提升验证效率。
包括ElasticNetCV、LassoCV、RidgeCV、LogisticRegressionCV等等。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容