作者:袁良杰
2019年3月15日
比赛说明
RMS泰坦尼克号的沉没是历史上最臭名昭著的海难之一。1912年4月15日,在她的处女航中,泰坦尼克号在与冰山相撞后沉没,2224名乘客和机组人员中有1502人遇难。这场耸人听闻的悲剧震惊了国际社会,并促使人们为船舶制定了更好的安全规定。
造成如此重大伤亡的原因之一是没有足够的救生艇供乘客和机组人员使用。尽管在沉船中幸存有一些运气因素,但有些人比其他人更容易生存,例如妇女,儿童和上流社会人士。
在这个挑战中,我们要求您完成对哪些人可能存活的分析。特别是,我们要求您运用机器学习工具来预测哪些乘客幸免于悲剧。
1.加载包
# Plotting, theming, scaling and data-manipulation packages.
library(ggplot2)
library(ggthemes)
library(scales)
library(dplyr)
##
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
library(VIM)
## Loading required package: colorspace
## Loading required package: grid
## Loading required package: data.table
##
## Attaching package: 'data.table'
## The following objects are masked from 'package:dplyr':
##
## between, first, last
## VIM is ready to use.
## Since version 4.0.0 the GUI is in its own package VIMGUI.
##
## Please use the package to use the new (and old) GUI.
## Suggestions and bug-reports can be submitted at: https://github.com/alexkowa/VIM/issues
##
## Attaching package: 'VIM'
## The following object is masked from 'package:datasets':
##
## sleep
library(mice)
## Loading required package: lattice
##
## Attaching package: 'mice'
## The following objects are masked from 'package:base':
##
## cbind, rbind
library(randomForest)
## randomForest 4.6-14
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:dplyr':
##
## combine
## The following object is masked from 'package:ggplot2':
##
## margin
2.读取数据,查看数据
# Read the training and test sets, keeping text columns as character
# (no automatic factor conversion), then stack them into one data frame so
# that feature engineering is applied to both at once.
train <- read.csv("C:/Users/袁良杰/Desktop/Titanic/train.csv", stringsAsFactors = FALSE)
test <- read.csv("C:/Users/袁良杰/Desktop/Titanic/test.csv", stringsAsFactors = FALSE)
full <- bind_rows(train, test)
# Inspect the structure of the combined data.
str(full)
## 'data.frame': 1309 obs. of 12 variables:
## $ PassengerId: int 1 2 3 4 5 6 7 8 9 10 ...
## $ Survived : int 0 1 1 1 0 0 0 0 1 1 ...
## $ Pclass : int 3 1 3 1 3 3 1 3 3 2 ...
## $ Name : chr "Braund, Mr. Owen Harris" "Cumings, Mrs. John Bradley (Florence Briggs Thayer)" "Heikkinen, Miss. Laina" "Futrelle, Mrs. Jacques Heath (Lily May Peel)" ...
## $ Sex : chr "male" "female" "female" "female" ...
## $ Age : num 22 38 26 35 35 NA 54 2 27 14 ...
## $ SibSp : int 1 1 0 1 0 0 0 3 0 1 ...
## $ Parch : int 0 0 0 0 0 0 0 1 2 0 ...
## $ Ticket : chr "A/5 21171" "PC 17599" "STON/O2. 3101282" "113803" ...
## $ Fare : num 7.25 71.28 7.92 53.1 8.05 ...
## $ Cabin : chr "" "C85" "" "C123" ...
## $ Embarked : chr "S" "C" "S" "S" ...
##所有的变量全部展现出来,我们从这些变量中寻找隐藏信息
3.特征工程
(1)称谓
# Passenger names embed a title that hints at social status, which may be an
# important predictor of survival. Strip everything up to ", " and everything
# from the following "." onward, leaving only the title.
full[["Title"]] <- gsub(pattern = "(.*, )|(\\..*)", replacement = "", x = full[["Name"]])
# Count passengers of each sex under each title.
table(full[["Sex"]], full[["Title"]])
##
## Capt Col Don Dona Dr Jonkheer Lady Major Master Miss Mlle Mme
## female 0 0 0 1 1 0 1 0 0 260 2 1
## male 1 4 1 0 7 1 0 2 61 0 0 0
##
## Mr Mrs Ms Rev Sir the Countess
## female 0 197 2 0 0 1
## male 757 0 0 8 1 0
# Titles held by only a few, higher-status passengers are pooled into one
# category.
rare.title <- c(
  "Capt", "Col", "Don", "Dona", "Dr", "Jonkheer", "Lady", "Major", "Rev",
  "Sir", "the Countess"
)
# Fold the remaining variants into the common titles:
# Mlle and Ms become Miss, Mme becomes Mrs.
full$Title[full$Title %in% c("Mlle", "Ms")] <- "Miss"
full$Title[full$Title %in% "Mme"] <- "Mrs"
full$Title[full$Title %in% rare.title] <- "Rare.Title"
# Count passengers of each sex under the regrouped titles.
table(full$Sex, full$Title)
##
## Master Miss Mr Mrs Rare.Title
## female 0 264 0 198 4
## male 61 0 757 0 25
(2)家庭规模
# Some passengers travelled alone and others with relatives. The surname is
# the part of Name before the first "," or "."; it is used to label families.
# vapply() guarantees a character vector whatever the input length, unlike
# sapply(), whose return type depends on its input.
full$Surname <- vapply(
  full$Name,
  function(nm) strsplit(nm, split = "[,.]")[[1]][1],
  character(1)
)
# Report how many distinct surnames there are.
cat(paste(nlevels(factor(full$Surname)), "个不同的姓氏"))
## 875个不同的姓氏
# Family size is the passenger's SibSp and Parch counts plus one for the
# passenger themselves.
full$Fsize <- 1 + full$Parch + full$SibSp
# Family label: surname and family size joined by "-".
full$Family <- paste(full$Surname, full$Fsize, sep = "-")
# Plot the share of survivors and non-survivors for each family size, using
# the first 891 rows.
ggplot(data = full[1:891, ], mapping = aes(x = Fsize, fill = factor(Survived))) +
  geom_bar(position = "fill") +
  scale_x_continuous(breaks = 1:12) +
  labs(x = "家庭规模", y = "生存与遇难比")
# The plot shows lower survival for passengers travelling alone and for
# families of five or more, so family size is binned into three classes.
full$Fsize2[full$Fsize == 1] <- "single"
full$Fsize2[full$Fsize > 1 & full$Fsize <= 4] <- "small"
full$Fsize2[full$Fsize > 4] <- "large"
mosaicplot(table(full$Fsize2, full$Survived), main = "家庭规模与生存率", shade = TRUE)
# Next, look at the Cabin variable, which has many missing values.
full$Cabin[1:100]
## [1] "" "C85" "" "C123" ""
## [6] "" "E46" "" "" ""
## [11] "G6" "C103" "" "" ""
## [16] "" "" "" "" ""
## [21] "" "D56" "" "A6" ""
## [26] "" "" "C23 C25 C27" "" ""
## [31] "" "B78" "" "" ""
## [36] "" "" "" "" ""
## [41] "" "" "" "" ""
## [46] "" "" "" "" ""
## [51] "" "" "D33" "" "B30"
## [56] "C52" "" "" "" ""
## [61] "" "B28" "C83" "" ""
## [66] "" "F33" "" "" ""
## [71] "" "" "" "" ""
## [76] "F G73" "" "" "" ""
## [81] "" "" "" "" ""
## [86] "" "" "" "C23 C25 C27" ""
## [91] "" "" "E31" "" ""
## [96] "" "A5" "D10 D12" "" ""
#查看前100个观测,观察到cabin变量较复杂(A—G,6—123)这里不进行分析
4.缺失值
# Tabulate the missing-data patterns across all columns. `full` is the data
# frame built in this script, not a packaged dataset, so no data() call is
# needed before inspecting it.
md.pattern(full)
## PassengerId Pclass Name Sex SibSp Parch Ticket Cabin Embarked Title
## 714 1 1 1 1 1 1 1 1 1 1
## 331 1 1 1 1 1 1 1 1 1 1
## 177 1 1 1 1 1 1 1 1 1 1
## 86 1 1 1 1 1 1 1 1 1 1
## 1 1 1 1 1 1 1 1 1 1 1
## 0 0 0 0 0 0 0 0 0 0
## Surname Fsize Family Fsize2 Fare Age Survived
## 714 1 1 1 1 1 1 1 0
## 331 1 1 1 1 1 1 0 1
## 177 1 1 1 1 1 0 1 1
## 86 1 1 1 1 1 0 0 2
## 1 1 1 1 1 0 1 0 2
## 0 0 0 0 1 263 418 682
# Age has 263 missing values and Fare has 1; list the rows with missing Age.
which(is.na(full$Age))
## [1] 6 18 20 27 29 30 32 33 37 43 46 47 48 49
## [15] 56 65 66 77 78 83 88 96 102 108 110 122 127 129
## [29] 141 155 159 160 167 169 177 181 182 186 187 197 199 202
## [43] 215 224 230 236 241 242 251 257 261 265 271 275 278 285
## [57] 296 299 301 302 304 305 307 325 331 335 336 348 352 355
## [71] 359 360 365 368 369 376 385 389 410 411 412 414 416 421
## [85] 426 429 432 445 452 455 458 460 465 467 469 471 476 482
## [99] 486 491 496 498 503 508 512 518 523 525 528 532 534 539
## [113] 548 553 558 561 564 565 569 574 579 585 590 594 597 599
## [127] 602 603 612 613 614 630 634 640 644 649 651 654 657 668
## [141] 670 675 681 693 698 710 712 719 728 733 739 740 741 761
## [155] 767 769 774 777 779 784 791 793 794 816 826 827 829 833
## [169] 838 840 847 850 860 864 869 879 889 902 914 921 925 928
## [183] 931 933 939 946 950 957 968 975 976 977 980 983 985 994
## [197] 999 1000 1003 1008 1013 1016 1019 1024 1025 1038 1040 1043 1052 1055
## [211] 1060 1062 1065 1075 1080 1083 1091 1092 1097 1103 1108 1111 1117 1119
## [225] 1125 1135 1136 1141 1147 1148 1157 1158 1159 1160 1163 1165 1166 1174
## [239] 1178 1180 1181 1182 1184 1189 1193 1196 1204 1224 1231 1234 1236 1249
## [253] 1250 1257 1258 1272 1274 1276 1300 1302 1305 1308 1309
# Find the row whose Fare is missing (row 1044).
which(is.na(full[["Fare"]]))
## [1] 1044
# Row 1044 is a third-class passenger who embarked at port S; plot the fare
# density of the passengers with the same class and port, with a dashed red
# line at their median fare.
ggplot(data = full[full$Pclass == "3" & full$Embarked == "S", ], mapping = aes(x = Fare)) +
  geom_density(fill = "green", alpha = 0.4) +
  geom_vline(aes(xintercept = median(Fare, na.rm = TRUE)), colour = "red", linetype = 2, lwd = 1) +
  scale_x_continuous(breaks = 0:60)
## Warning: Removed 1 rows containing non-finite values (stat_density).
# Fill the missing fare in row 1044 with the median fare of the third-class
# passengers who embarked at port S.
full$Fare[1044] <- median(full$Fare[full$Pclass == "3" & full$Embarked == "S"], na.rm = TRUE)
# Multiple imputation of the missing Age values.
# First convert the categorical columns to factors.
Factor.Vars <- c("PassengerId", "Pclass", "Sex", "Title", "Surname", "Fsize2", "Family")
full[Factor.Vars] <- lapply(full[Factor.Vars], as.factor)
# Run mice on a subset of columns only. The column must be spelled "SibSp":
# `%in%` silently ignores a name that matches no column, so a misspelling
# would drop SibSp from the imputation model without any error.
# NOTE(review): with SibSp included, imputed ages may differ from results
# recorded from a run that excluded it.
imp <- mice(
  full[, names(full) %in% c("Pclass", "Sex", "Age", "Fare", "Title", "Fsize2", "SibSp", "Parch")],
  seed = 1234
)
## ## iter imp variable
## 1 1 Age
## 1 2 Age
## 1 3 Age
## 1 4 Age
## 1 5 Age
## 2 1 Age
## 2 2 Age
## 2 3 Age
## 2 4 Age
## 2 5 Age
## 3 1 Age
## 3 2 Age
## 3 3 Age
## 3 4 Age
## 3 5 Age
## 4 1 Age
## 4 2 Age
## 4 3 Age
## 4 4 Age
## 4 5 Age
## 5 1 Age
## 5 2 Age
## 5 3 Age
## 5 4 Age
## 5 5 Age
# Extract the fifth completed data set from the imputation result.
mice.imp <- mice::complete(imp, action = 5)
# Compare the age distribution before and after imputation, side by side.
par(mfrow = c(1, 2))
hist(full$Age, col = "orange", main = "full$Age", freq = FALSE, ylim = c(0, 0.04))
hist(mice.imp$Age, col = "lightblue", main = "mice.imp$Age", freq = FALSE, ylim = c(0, 0.04))
# The two distributions are very close, so replace Age with the imputed values.
full$Age <- mice.imp$Age
# Count the remaining missing ages; none should be left.
sum(is.na(full$Age))
## [1] 0
特征工程2
# With Age complete, derive an age-group variable: under 18 versus 18 and over.
# NOTE(review): "Chid" looks like a typo for "Child"; the label is kept as-is
# because it becomes a factor level of Aduch.
full$Aduch <- ifelse(full$Age < 18, "Chid", "Adult")
# Cross-tabulate age group against survival.
table(full$Aduch, full$Survived)
##
## 0 1
## Adult 480 271
## Chid 69 71
full$Aduch <- factor(full$Aduch)
# Re-check the missing-data pattern of every column before moving on to
# prediction. `full` is the data frame built in this script, not a packaged
# dataset, so no data() call is needed here.
md.pattern(full)
## PassengerId Pclass Name Sex Age SibSp Parch Ticket Fare Cabin Embarked
## 891 1 1 1 1 1 1 1 1 1 1 1
## 418 1 1 1 1 1 1 1 1 1 1 1
## 0 0 0 0 0 0 0 0 0 0 0
## Title Surname Fsize Family Fsize2 Aduch Survived
## 891 1 1 1 1 1 1 1 0
## 418 1 1 1 1 1 1 0 1
## 0 0 0 0 0 0 418 418
5.预测
# Split the combined data back into the training set (rows 1-891) and the
# test set (rows 892-1309).
train <- full[seq_len(891), ]
test <- full[seq(from = 892, to = 1309), ]
# Fit a random forest on the training set; the seed makes the fit repeatable.
set.seed(1234)
rf.model <- randomForest(
  factor(Survived) ~ Pclass + Sex + Age + Fare + Title + Fsize2 + SibSp + Parch + Aduch,
  data = train,
  na.action = na.roughfix,
  importance = TRUE
)
# Not every variable can be used as a predictor.
# Report variable importance (type = 2, shown as MeanDecreaseGini).
importance(rf.model, type = 2)
## MeanDecreaseGini
## Pclass 33.465766
## Sex 51.314368
## Age 53.151593
## Fare 67.072898
## Title 80.053375
## Fsize2 17.148189
## SibSp 12.992146
## Parch 8.156831
## Aduch 3.940061
# Predict survival for the test rows with the fitted forest.
forest.pred <- predict(rf.model, newdata = test)
# NOTE(review): the missing-data pattern printed earlier shows Survived missing
# for 418 rows, which matches the size of the test set, so this
# cross-tabulation presumably has nothing to count - confirm whether it is needed.
forest.perf <- table(test$Survived, forest.pred)
# Title has the highest importance of all predictors, which suggests survival
# in the disaster was strongly related to passengers' social standing.
# Collect the test-set predictions so they can be saved to a CSV file.
# NOTE(review): the output column is spelled PassengerID while the source
# column is PassengerId - confirm which header the consumer expects.
answer <- data.frame(PassengerID = test$PassengerId, Survived = forest.pred)
write.csv(answer, file = "ans_predict.csv", row.names