MARS可简单理解为分段线性函数,针对某一特征变量x与响应变量y存在较为复杂的非线性关系,通过寻找合适数目(n)的cut point/knot分隔为若干(n+1)近似线性模型(hinge function)。
1、多元回归与MARS的超参数
背景知识理解
- 线性回归的假设是:特征变量与响应变量存在单调的线性关系。当面对复杂数据集有多个特征变量进行多元回归时(暂时假设变量间无相关性),有的特征变量与响应变量是有明显的线性关系;而有的变量情况比较复杂(如上图所示)。
- 在对特征变量x与响应变量y间进行分段回归时,cut-point的选择原则是得到的两个线性模型的SSE比分段之前有所降低,并选取降低最多的点(x)做为分割点。
- 一般线性回归时,1 predictor对应一个terms(可以理解为一种关系或者就是coefficients);当执行分段回归时,1 predictor就会对应 n+1 个 term(假如存在n个term)
超参数:the number of terms retained in the final model
- 如果假设n个特征变量全部与响应变量存在单调线性回归关系,那么理论上就会有 n+1 个terms(还有截距项)
- 但在多元线性回归时,有三点需要注意:
(1)首先并不是所有的特征变量都对模型预测有贡献(干扰变量);
(2)其次并不是所有有贡献的变量都是与响应变量单调线性相关的,此时可以使用分段线性回归。
(3)最后变量间可能互存在一定关联性(interaction) - 对于前两点:到底使用多少个terms参与模型建立是需要调整、优化的,即第一个超参数:the number of terms retained in the final model
- 对于第三点:通过设置超参数:the maximum degree of interactions进行调整,
2、代码实操
示例数据:预测房价
ames <- AmesHousing::make_ames()
dim(ames)
## [1] 2930 81
set.seed(123)
library(rsample)
split <- initial_split(ames, prop = 0.7,
strata = "Sale_Price")
ames_train <- training(split)
# [1] 2049 81
ames_test <- testing(split)
# [1] 881 81
library(caret)
- 对于超参数terms,一般建议每间隔10取一个候选值;
- 对于超参数degree,一般不建议超过3
# create a tuning grid
hyper_grid <- expand.grid(
degree = 1:3,
nprune = seq(2, 100, length.out = 10) %>% floor()
)
# 30 combination
head(hyper_grid)
# degree nprune
# 1 1 2
# 2 2 2
# 3 3 2
# 4 1 12
# 5 2 12
# 6 3 12
- 主要执行MARS的包为
earth
,通过caret包实现超参数组合的交叉验证的grid search
# Cross-validated model
set.seed(1111) # for reproducibility
cv_mars <- train(
x = subset(ames_train, select = -Sale_Price),
y = ames_train$Sale_Price,
method = "earth",
metric = "RMSE",
trControl = trainControl(method = "cv", number = 10),
tuneGrid = hyper_grid
)
# View results
cv_mars$bestTune
# nprune degree
# 5 45 1
# 最佳参数组合的模型性能
cv_mars$results %>%
dplyr::filter(nprune == cv_mars$bestTune$nprune, degree == cv_mars$bestTune$degree)
# degree nprune RMSE Rsquared MAE RMSESD RsquaredSD MAESD
# 1 1 45 26435.26 0.8903344 17013.83 4390.809 0.03478498 1583.466
#可视化
ggplot(cv_mars)
- 使用测试集评价模型
pred = predict(cv_mars, ames_test)
caret::RMSE(ames_test$Sale_Price, pred)
# [1] 23703.75
- 评价特征变量重要性
library(vip)
p1 <- vip(cv_mars, num_features = 40, geom = "point", value = "gcv") + ggtitle("GCV")
p2 <- vip(cv_mars, num_features = 40, geom = "point", value = "rss") + ggtitle("RSS")
gridExtra::grid.arrange(p1, p2, ncol = 2)
library(pdp)
# Construct partial dependence plots
p1 <- partial(cv_mars, pred.var = "Gr_Liv_Area", grid.resolution = 10) %>%
autoplot()
p2 <- partial(cv_mars, pred.var = "Year_Built", grid.resolution = 10) %>%
autoplot()
p3 <- partial(cv_mars, pred.var = c("Gr_Liv_Area", "Year_Built"),
grid.resolution = 10) %>%
plotPartial(levelplot = FALSE, zlab = "yhat", drape = TRUE, colorkey = TRUE,
screen = list(z = -20, x = -60))
# Display plots side by side
gridExtra::grid.arrange(p1, p2, p3, ncol = 3)
此外MARS也可以用于分类问题,这里就暂不记录了