In recent years big data technology has developed rapidly, and data competition platforms have grown along with it. Big data brings many conveniences to daily life and can help with problems beyond the reach of traditional statistics. The main task of this paper is to predict the survival probability of passengers on the Titanic. The data are basic passenger records with 12 variables, including sex, title and passenger class. The cabin-number variable is dropped because its missing values exceed 75% of all samples. The age variable has fewer missing values, and these are filled in using sample proximities computed by a random forest. Discrete and continuous variables are visualized and grouped with different methods. Finally, the resulting variables are used to build a random forest model that predicts survival for the passengers in the test set.
Keywords: random forest; missing value processing; survival rate; binary classification
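On 10 April 1912 the luxury liner Titanic, billed as a "marvel of industrial history", set out on her maiden voyage from Southampton, England, bound for New York. On the night of 14 April she struck an iceberg in the North Atlantic and sank. 1,502 people died and 705 were rescued. It was the deadliest peacetime maritime disaster up to that time and remains the most famous shipwreck in history.
In a life-or-death situation like this, we tend to assume that people of higher social rank, greater influence and wider public recognition were more likely to survive. We also tend to assume that passengers with more family members aboard were more likely to survive, because they could cooperate and had a stronger will to live. The real causes, however, often differ from such subjective guesses. Only a scientific analysis of real data shows that things are less simple than they appear, and the underlying mechanisms are often hidden in the data.
The data come from the introductory Titanic survival prediction competition on the Kaggle platform. The task is to use personal information about the passengers, such as age, sex and title, to predict each passenger's probability of surviving the sinking. The approach of this paper is as follows. First, we visualize the data set to examine its structure and correlations. Next, we preprocess the data, which includes handling missing values and extracting useful information into new variables. Finally, we fit a suitable model on the preprocessed data to predict survival.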
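Depending on the goal and application, data mining tasks fall roughly into four categories: clustering, classification, association-rule mining and anomaly detection. This paper addresses a classification problem. From the available features we extract and select variables, train a model and build a supervised binary classifier. Many models can handle binary classification, such as GBDT, KNN, neural networks, SVM and naive Bayes. The feature-extraction and modelling methods used here are described below.
A random forest is an ensemble of decision trees. A single decision tree usually has limited accuracy and is prone to overfitting, so researchers improve predictive accuracy by combining many trees. This approach is called an ensemble. Many base classifiers are first built from the training set. The final prediction is then obtained by voting or by averaging the base classifiers' predictions. Bagging (bootstrap aggregating) is one of the earliest tree-ensemble methods. Each tree is grown on a bootstrap sample, which is a sample drawn with replacement from the training set. Random split selection is a second ensemble method. Its basic idea is to choose the split at each node at random from among the n best splits.
Random forest is a statistical learning method that draws many bootstrap samples from the original data and grows a decision tree on each one. It then combines the trees and predicts by majority vote, or by averaging for regression. In addition to bootstrap sampling, each split considers only a random subset of the features. Many studies show that random forests predict well and compare favourably with other algorithms across a wide range of data sets. They can handle high-dimensional data because each split uses only a random feature subset. While the forest is being built, the out-of-bag (OOB) error provides an essentially unbiased estimate of the generalization error. The model also generalizes well and trains quickly. Because the trees are grown independently, training is easy to parallelize.
Random forests nevertheless have drawbacks. On very noisy classification and regression problems they have been observed to overfit, and their performance on the training and test sets then differs markedly. When attributes have different numbers of levels, attributes with more levels tend to have a disproportionate influence on the forest. The variable importance scores a forest produces for such data can therefore be unreliable.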
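The bagging-plus-voting idea described above can be sketched in a few lines of R. This is a minimal illustration under stated assumptions and not the randomForest package's implementation. It uses rpart trees on a hypothetical two-feature toy data set, and it performs plain bagging only, so every tree sees every feature. It therefore omits the per-split random feature selection that turns bagging into a random forest. The helper names `bagged_trees` and `predict_vote` are hypothetical.

```r
library(rpart)

set.seed(42)
# Hypothetical toy data: two numeric features and a binary label
n <- 200
toy <- data.frame(x1 = runif(n), x2 = runif(n))
toy$y <- factor(ifelse(toy$x1 + toy$x2 + rnorm(n, sd = 0.2) > 1, "yes", "no"))

# Grow B classification trees, each on a bootstrap sample (drawn with replacement)
bagged_trees <- function(data, B = 25) {
  lapply(seq_len(B), function(b) {
    idx <- sample(nrow(data), replace = TRUE)
    rpart(y ~ x1 + x2, data = data[idx, ], method = "class")
  })
}

# Combine the trees by majority vote: one row per observation, one column per tree
predict_vote <- function(trees, newdata) {
  votes <- sapply(trees, function(tr) as.character(predict(tr, newdata, type = "class")))
  apply(votes, 1, function(v) names(which.max(table(v))))
}

trees <- bagged_trees(toy, B = 25)
pred <- predict_vote(trees, toy)
mean(pred == toy$y)  # in-sample accuracy of the bagged ensemble
```

To get a random forest from this scheme, each split must also be restricted to a random subset of the candidate features. This is the `mtry` parameter of `randomForest()`.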
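First, load the required packages.
library(tidyverse) #read_csv, dplyr verbs, purrr::map and ggplot2
library(randomForest) #rfImpute, randomForest, importance and varImpPlot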
Next, read in the Titanic passenger data and take a first look at its structure.
titanic <- read_csv("Titanic-dataset/train.csv")
## Parsed with column specification:
## cols(
## PassengerId = col_integer(),
## Survived = col_integer(),
## Pclass = col_integer(),
## Name = col_character(),
## Sex = col_character(),
## Age = col_double(),
## SibSp = col_integer(),
## Parch = col_integer(),
## Ticket = col_character(),
## Fare = col_double(),
## Cabin = col_character(),
## Embarked = col_character()
## )
summary(titanic)
## PassengerId Survived Pclass Name
## Min. : 1.0 Min. :0.0000 Min. :1.000 Length:891
## 1st Qu.:223.5 1st Qu.:0.0000 1st Qu.:2.000 Class :character
## Median :446.0 Median :0.0000 Median :3.000 Mode :character
## Mean :446.0 Mean :0.3838 Mean :2.309
## 3rd Qu.:668.5 3rd Qu.:1.0000 3rd Qu.:3.000
## Max. :891.0 Max. :1.0000 Max. :3.000
##
## Sex Age SibSp Parch
## Length:891 Min. : 0.42 Min. :0.000 Min. :0.0000
## Class :character 1st Qu.:20.12 1st Qu.:0.000 1st Qu.:0.0000
## Mode :character Median :28.00 Median :0.000 Median :0.0000
## Mean :29.70 Mean :0.523 Mean :0.3816
## 3rd Qu.:38.00 3rd Qu.:1.000 3rd Qu.:0.0000
## Max. :80.00 Max. :8.000 Max. :6.0000
## NA's :177
## Ticket Fare Cabin Embarked
## Length:891 Min. : 0.00 Length:891 Length:891
## Class :character 1st Qu.: 7.91 Class :character Class :character
## Mode :character Median : 14.45 Mode :character Mode :character
## Mean : 32.20
## 3rd Qu.: 31.00
## Max. :512.33
##
n_distinct(titanic$PassengerId)
## [1] 891
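The data set contains 12 variables and 891 records. The variables are:
- passenger ID (PassengerId)
- survival (Survived)
- passenger class (Pclass)
- name including title (Name)
- sex (Sex)
- age (Age)
- number of siblings and spouses aboard (SibSp)
- number of parents and children aboard (Parch)
- ticket number (Ticket)
- fare (Fare)
- cabin number (Cabin)
- port of embarkation (Embarked)

Age and Fare are continuous variables, and the others are discrete. The 891 records correspond to 891 distinct passengers, so the data contain no duplicates.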
seq_along(titanic)%>%
map(~sum(is.na(titanic[.])))%>%
str()
## List of 12
## $ : int 0
## $ : int 0
## $ : int 0
## $ : int 0
## $ : int 0
## $ : int 177
## $ : int 0
## $ : int 0
## $ : int 0
## $ : int 0
## $ : int 687
## $ : int 2
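The counts of missing values per variable show that Age has 177 missing values, about 20% of the sample, and these can be imputed. Cabin has 687 missing values, which is more than 75% of the sample (687/891 ≈ 77%), so this variable is dropped. Embarked has 2 missing values, and no other variable has any.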
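The 75% rule used above can be applied mechanically by computing each column's share of missing values. Below is a minimal sketch on a hypothetical five-row toy table. The function name `na_share` is illustrative. On the real data the same computation gives 177/891 ≈ 0.199 for Age and 687/891 ≈ 0.771 for Cabin.

```r
# Share of missing values in each column of a data frame (hypothetical helper)
na_share <- function(df) colMeans(is.na(df))

# Hypothetical toy table standing in for the passenger data
toy <- data.frame(
  Age      = c(22, NA, 26, 35, NA),
  Cabin    = c(NA, "C85", NA, NA, NA),
  Embarked = c("S", "C", "S", NA, "S")
)

shares <- na_share(toy)
shares  # Age 0.4, Cabin 0.8, Embarked 0.2

# Keep only columns with at most 75% missing values, as done for Cabin in the text
kept <- toy[, shares <= 0.75, drop = FALSE]
names(kept)  # "Age" "Embarked"
```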
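Data visualization uses graphics to convey the content of data clearly and effectively. It helps us understand and analyse data and extract value from it.
In this chapter we visualize and qualitatively analyse the relationship between each variable and survival. Discrete variables are generally examined with bar charts, and continuous variables with distribution curves.
We start with passenger class. Pclass and Survived are both discrete, so we use bar charts.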
ggplot(titanic, aes(Pclass,fill = as.factor(Survived)))+
geom_bar(position = 'dodge')+
labs(
title = paste('Number of Survivals of different class')
)+
theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")
ggplot(titanic, aes(Pclass, fill=as.character(Survived)))+
geom_bar(position = 'fill')+
labs(
y = NULL,
title = paste('Survival rate of different class'))+
theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")
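The side-by-side and filled bar charts show that class 3 has the most passengers but the lowest survival rate. Classes 1 and 2 have fewer passengers, but their survival rates are much higher than that of class 3. Passenger class is therefore associated with the target variable.
Name is a character variable. We extract the title part of each name into a new variable, title, and use bar charts to examine how title relates to survival.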
titanic$title = gsub('(.*, )|(\\..*)', '', titanic$Name)
ggplot(titanic, aes(title))+
geom_bar(aes(fill = as.character(Survived)), position='dodge')+
labs(
y = NULL,
title = paste('Number of survivals of different title'))+
theme(plot.title = element_text(hjust = 0.5), legend.position="bottom")
ggplot(titanic, aes(title))+
geom_bar(aes(fill = as.character(Survived)), position='dodge')+
labs(
y = NULL,
title = paste('Number of survivals of different title'))+
ylim(-5,50)+ #view details
theme(plot.title = element_text(hjust = 0.5), legend.position="bottom")
## Warning: Removed 5 rows containing missing values (geom_bar).
ggplot(titanic, aes(title))+
geom_bar(aes(fill = as.character(Survived)), position='fill')+
labs(
y = NULL,
title = paste('Survival rate of different title'))+
theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")
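In the bar charts the horizontal axis shows the title groups and the vertical axis the number of passengers. The passengers have 17 distinct titles: Capt, Col, Don, Dr, Jonkheer, Lady, Major, Master, Miss, Mlle, Mme, Mr, Mrs, Ms, Rev, Sir and the Countess. The filled bar chart shows that everyone titled Capt, Don, Jonkheer or Rev died, and everyone titled Lady, Mlle, Mme, Ms, Sir or the Countess survived. Capt, Col, Don, Dr, Jonkheer, Lady, Major, Rev, Sir and the Countess each cover only a few passengers, so we restrict the range of the vertical axis to see these counts in detail. Based on these counts and the observed outcomes, we merge the small groups into three new titles:
- Capt, Don, Jonkheer and Rev become low.
- Col, Dr and Major become medium.
- Lady, Mlle, Mme, Ms, Sir and the Countess become high.

After merging there are 7 title groups, which we visualize again.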
medium_survive_title = c( 'Col','Dr','Major')
high_survive_title = c('Lady','Mlle','Ms','Mme','Sir','the Countess')
low_survive_title = c('Jonkheer','Capt','Don','Rev')
titanic$title[titanic$title %in% medium_survive_title] = 'medium'
titanic$title[titanic$title %in% high_survive_title] = 'high'
titanic$title[titanic$title %in% low_survive_title] = 'low'
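The title extraction and merging steps can be checked on a few names written in the data set's "Surname, Title. Given names" format. The regular expression is the one used above, and the sample names are for illustration only. The lookup vector `merge_map` and the function `merge_titles` are hypothetical. They are an equivalent way to express the three `%in%` assignments.

```r
# Title extraction with the regular expression from the text: remove the
# surname part ending in ", " and everything from the next "." onwards
extract_title <- function(name) gsub('(.*, )|(\\..*)', '', name)

# Sample names in the "Surname, Title. Given names" format (for illustration)
sample_names <- c("Braund, Mr. Owen Harris",
                  "Heikkinen, Miss. Laina",
                  "Rothes, the Countess. of (Lucy Noel Martha Dyer-Edwards)")
titles <- extract_title(sample_names)
titles  # "Mr" "Miss" "the Countess"

# Hypothetical lookup vector equivalent to the three %in% assignments in the text
merge_map <- c(Col = "medium", Dr = "medium", Major = "medium",
               Lady = "high", Mlle = "high", Ms = "high", Mme = "high",
               Sir = "high", "the Countess" = "high",
               Jonkheer = "low", Capt = "low", Don = "low", Rev = "low")
# Titles found in the map are replaced; all others are kept unchanged
merge_titles <- function(t) unname(ifelse(t %in% names(merge_map), merge_map[t], t))

merged <- merge_titles(c(titles, "Rev", "Dr"))
merged  # "Mr" "Miss" "high" "low" "medium"
```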
ggplot(titanic, aes(title))+
geom_bar(aes(fill = as.character(Survived)))+
labs(
y = NULL,
title = paste('Number of survivals of different title'))+
theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")
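In the stacked bar chart the horizontal axis shows the title groups and the vertical axis the number of passengers. Merging clearly reduces the number of groups. Most groups show that title is associated with survival. The Master and medium groups, however, do not separate survivors from non-survivors very clearly.
Next, we visualize sex.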
ggplot(titanic)+
geom_bar(aes(Sex, fill=as.character(Survived)), position = 'dodge')+
labs(
title = paste('Number of survivals of different sex'))+
theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")
ggplot(titanic)+
geom_bar(aes(Sex, fill = as.factor(Survived)), position = 'fill')+
labs(
y = NULL,
title = paste('Survival rate of different sex'))+
theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")
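Both bar charts show that there were more male than female passengers, but males had a far lower survival rate than females. Sex is strongly associated with the target variable.
Next, we visualize age.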
range(titanic$Age,na.rm=T)
## [1] 0.42 80.00
ggplot(titanic, aes(x = Age,y = ..density..,color = as.character(Survived)))+
geom_freqpoly(binwidth=2)+
labs(
title = paste('Age distribution by survival'))+
theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")
## Warning: Removed 177 rows containing non-finite values (stat_bin).
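Age is a continuous variable that ranges from 0.42 to 80, so we visualize it with distribution curves. The horizontal axis is age and the vertical axis is density. The blue curve represents survivors and the red curve those who died. Children under 12 were more likely to be rescued, and young adults around 20 were more likely to die. We therefore discretize age into three intervals, (0,12], (12,30] and (30,90]. By default, cut() includes each interval's right endpoint. We then use bar charts to compare survival across the age groups.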
#split age into 3 groups
age_group = function(age){
cut(age,
breaks = c(0,12,30,90),
labels = c('child','young','adult'))
}
titanic = titanic%>%
mutate(age_group=age_group(as.numeric(Age)))
sum(is.na(titanic$age_group))
## [1] 177
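By default cut() builds right-closed intervals. The check below shows which group each boundary value falls into. It also explains the 177 NA entries in age_group: cut() maps a missing age to NA. The fare grouping defined later works the same way, and its lower break of -1 keeps zero fares in the first group. This is a small check with made-up values. It copies the `age_group` and `fare_group` definitions used in the text.

```r
# Same definitions as in the text: right-closed intervals (a, b]
age_group <- function(age) {
  cut(age, breaks = c(0, 12, 30, 90), labels = c('child', 'young', 'adult'))
}
fare_group <- function(fare) {
  cut(fare, breaks = c(-1, 15, 50, 520), labels = c('cheap', 'medium', 'expensive'))
}

# Made-up boundary values to show which group each one falls into
ages  <- c(0.42, 12, 12.5, 30, 80, NA)
fares <- c(0, 15, 15.01, 512.33)

age_labels  <- as.character(age_group(ages))
fare_labels <- as.character(fare_group(fares))
age_labels   # "child" "child" "young" "young" "adult" NA
fare_labels  # "cheap" "cheap" "medium" "expensive"
```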
ggplot(titanic)+
geom_bar(aes(age_group,fill = as.character(Survived)), position='dodge')+
labs(
title = paste('Number of survivals of different age groups'))+
theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")
ggplot(titanic)+
geom_bar(aes(age_group,fill = as.character(Survived)), position='fill')+
labs(
y = NULL,
title = paste('Survival rate of different age groups'))+
theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")
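The filled bar chart shows that young adults had a far lower survival rate than children. The remaining adults fall in between and show no clear association. Age has many missing values, and they are imputed later in the analysis.
Next, we visualize the number of siblings/spouses and the number of parents/children. These two variables are related, because together they make up the number of family members aboard.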
range(titanic$SibSp)
## [1] 0 8
ggplot(titanic)+
geom_bar(aes(as.character(SibSp),fill = as.character(Survived)),position='dodge')+
labs(
title = paste('Number of survivals with different number of siblings and spouses'))+
theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")
ggplot(titanic)+
geom_bar(aes(as.character(SibSp),fill = as.character(Survived)),position='fill')+
labs(
y = NULL,
title = paste('Survival rate with different number of siblings and spouses'))+
theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")
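Comparing the two bar charts shows that most passengers travelled without siblings or spouses. Passengers with exactly one sibling or spouse aboard had the highest survival probability. Beyond one, survival probability falls as the number increases, and none of the passengers with 5 or more siblings/spouses aboard survived.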
range(titanic$Parch)
## [1] 0 6
ggplot(titanic,aes(as.character(Parch),fill = as.character(Survived)))+
geom_bar(position='dodge')+
labs(
title = paste('Number of survivals with different number of parents and children'))+
theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")
ggplot(titanic,aes(as.character(Parch),fill = as.character(Survived)))+
geom_bar(position='fill')+
labs(
y=NULL,
title = paste('Survival rate with different number of parents and children'))+
theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")
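The filled bar chart shows that the pattern for parents/children resembles the pattern for siblings/spouses. The decline in survival as the number of parents/children grows is less pronounced, however. We now add the two counts plus the passenger themself to obtain the family size, and plot again.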
titanic$family_size = titanic$SibSp + titanic$Parch + 1
ggplot(titanic,aes(as.factor(family_size),fill = as.factor(Survived)))+
geom_bar(position='fill')+
labs(
y = NULL,
title = paste('Survival rate with different family size'))+
theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")
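In this chart the horizontal axis is family size, meaning the number of family members aboard including the passenger. The blue segments show the share rescued and the red segments the share who died. Passengers from very large families (8 to 11 people) all died. Passengers from fairly large families (5 to 7) and those travelling alone had low survival rates. Passengers travelling in groups of 2 to 4 had higher survival rates. Family size is therefore associated with survival.
Next, we visualize the fare.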
range(titanic$Fare)
## [1] 0.0000 512.3292
ggplot(titanic,aes(x = Fare,y = ..density..,color = as.character(Survived)))+
geom_freqpoly(binwidth=5)+
labs(
title=paste('Fare distribution by survival'))+
xlim(0,200)+ #view details
theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")
## Warning: Removed 20 rows containing non-finite values (stat_bin).
## Warning: Removed 4 rows containing missing values (geom_path).
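Fare is a continuous variable that ranges from 0 to 512.33, so we again use distribution curves. The x axis is limited to 0 to 200 to show detail, which drops the 20 fares above 200 from the curves. The horizontal axis is fare and the vertical axis is density. The blue curve represents survivors and the red curve those who died. Passengers who paid less than 15 were more likely to die, and those who paid more than 50 were more likely to be rescued. We therefore split fare into three intervals, [0,15], (15,50] and (50,520]. The lower break of -1 keeps zero fares in the first group. We then use bar charts to compare survival across the fare groups.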
#split fare into 3 groups
fare_group = function(fare){
cut(fare,
breaks = c(-1,15,50,520),
labels = c('cheap','medium','expensive'))
}
titanic = titanic%>%
mutate(fare_group=fare_group(Fare))
ggplot(titanic,aes(fare_group,fill = as.character(Survived)))+
geom_bar(position='fill')+
labs(
y = NULL,
title = paste('Survival rate of different fare groups'))+
theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")
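The filled bar chart above shows that the higher the fare, the more likely a passenger was to survive. Fare is strongly associated with survival.
Finally, we visualize the port of embarkation.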
ggplot(titanic,aes(as.character(Embarked),fill = as.character(Survived)))+
geom_bar(position='fill')+
labs(
y = NULL,
title = paste('Survival rate of different ports of embarkation'))+
theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")
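Embarked has two missing values, which could be imputed. However, we judge subjectively that the port of embarkation has no logical connection with survival, so we leave this variable out of the model.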
titanic$Pclass=as.factor(titanic$Pclass)
titanic$Sex=as.factor(titanic$Sex)
titanic$fare_group=as.factor(titanic$fare_group)
titanic$title=as.factor(titanic$title)
titanic$age_group=as.factor(titanic$age_group)
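Because Embarked and Cabin are dropped, only the 177 missing age values need to be filled in. We use the random forest algorithm to compute proximities between samples and use those proximities to estimate the missing values. Three details of the call below are worth noting:
- The variable actually imputed is the discretized age_group, whose 177 entries are NA, not the raw Age column.
- rfImpute returns a data frame that contains only the response (first column) and the predictors named in the formula. This is why later code can use test[[1]] for Survived and test[2:7] for the predictors.
- The imputation also uses the Survived labels of rows 601 to 891, which later serve as the test set. The test-set error reported below may therefore be slightly optimistic.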
titanic=rfImpute(Survived ~ Pclass + Sex + fare_group + title + age_group +
family_size, data=titanic, iter=3, ntree=500)
## Warning in randomForest.default(xf, y, ntree = ntree, ..., do.trace =
## ntree, : The response has five or fewer unique values. Are you sure you
## want to do regression?
## | Out-of-bag |
## Tree | MSE %Var(y) |
## 500 | 0.1245 52.64 |
## Warning in randomForest.default(xf, y, ntree = ntree, ..., do.trace =
## ntree, : The response has five or fewer unique values. Are you sure you
## want to do regression?
## | Out-of-bag |
## Tree | MSE %Var(y) |
## 500 | 0.124 52.42 |
## Warning in randomForest.default(xf, y, ntree = ntree, ..., do.trace =
## ntree, : The response has five or fewer unique values. Are you sure you
## want to do regression?
## | Out-of-bag |
## Tree | MSE %Var(y) |
## 500 | 0.1237 52.31 |
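The randomForest documentation describes how rfImpute fills gaps. It first replaces NAs with na.roughfix, which uses the median for numeric columns and the most frequent level for factors. In each of `iter` rounds it then grows a forest and computes the proximity matrix. Each missing numeric value is replaced by the proximity-weighted average of the non-missing values. Each missing factor value is replaced by the level with the largest average proximity. The sketch below repeats the call pattern from the text on iris. A hypothetical binned factor, `petal_bin`, stands in for age_group, and the blanked-out values are chosen at random for illustration.

```r
library(randomForest)

set.seed(111)
# Illustrative data: iris with a discretized petal-length factor, mimicking age_group
iris_na <- data.frame(
  Species      = iris$Species,
  Sepal.Length = iris$Sepal.Length,
  Sepal.Width  = iris$Sepal.Width,
  petal_bin    = cut(iris$Petal.Length, breaks = c(0, 2, 5, 7),
                     labels = c("small", "medium", "large"))
)
# Blank out 20 values of the factor and 10 of a numeric column at random
iris_na$petal_bin[sample(nrow(iris_na), 20)] <- NA
iris_na$Sepal.Length[sample(nrow(iris_na), 10)] <- NA

# Same call pattern as in the text: response on the left, 3 iterations
imputed <- rfImpute(Species ~ Sepal.Length + Sepal.Width + petal_bin,
                    data = iris_na, iter = 3, ntree = 300)

colSums(is.na(iris_na))  # missing values before imputation
colSums(is.na(imputed))  # all zero afterwards
```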
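Using the variables above, we build a random forest model for classification and prediction. Random forests were developed from decision trees and are well suited to classifying and predicting discrete outcomes. Note, however, that Survived is still stored as an integer 0/1, so randomForest fits a regression forest. This is why the output below shows the warning "Are you sure you want to do regression?" and reports "Type of random forest: regression". The predictions are therefore estimated survival probabilities between 0 and 1. The reported MSE is the mean squared error of these probabilities, also known as the Brier score. The modelling steps are as follows.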
#split into a training set (rows 1-600) and a test set (rows 601-891)
train = titanic[1:600,]
test = titanic[601:891,]
#build the prediction model
titanic.rf1 = randomForest(Survived ~ Pclass + Sex + fare_group + title +
age_group + family_size, data = train,
ntree=100, importance=TRUE, mtry = 2, proximity=T)
## Warning in randomForest.default(m, y, ...): The response has five or fewer
## unique values. Are you sure you want to do regression?
plot(titanic.rf1)
titanic.rf1
##
## Call:
## randomForest(formula = Survived ~ Pclass + Sex + fare_group + title + age_group + family_size, data = train, ntree = 100, importance = TRUE, mtry = 2, proximity = T)
## Type of random forest: regression
## Number of trees: 100
## No. of variables tried at each split: 2
##
## Mean of squared residuals: 0.1343301
## % Var explained: 43.71
importance=importance(x=titanic.rf1)#larger values indicate a more important variable
importance
## %IncMSE IncNodePurity
## Pclass 18.176312 9.360425
## Sex 8.890673 18.628440
## fare_group 13.307486 5.572320
## title 13.828542 29.116470
## age_group 9.055191 4.089925
## family_size 15.030640 9.090589
varImpPlot(titanic.rf1)
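Survived is an integer column, so every forest in this section is a regression forest whose output estimates P(Survived = 1). The sketch below uses hypothetical toy data to contrast this with a classification forest, which is obtained by converting the response to a factor. It also shows how probability predictions can be turned into 0/1 labels with a 0.5 cut-off. The cut-off and the toy data are assumptions for illustration and are not part of the original analysis.

```r
library(randomForest)

set.seed(1)
# Hypothetical toy data: the 0/1 outcome is stored as an integer, like Survived
n <- 300
toy <- data.frame(x1 = rnorm(n),
                  x2 = factor(sample(c("a", "b"), n, replace = TRUE)))
toy$y <- as.integer(toy$x1 + (toy$x2 == "b") + rnorm(n, sd = 0.5) > 0.5)

# Integer 0/1 response: randomForest fits a regression forest (and warns);
# its predictions are estimates of P(y = 1) between 0 and 1
rf_reg <- suppressWarnings(randomForest(y ~ x1 + x2, data = toy, ntree = 100))
p_hat <- predict(rf_reg, toy)

# Factor response: randomForest fits a classification forest instead
toy$y_f <- factor(toy$y)
rf_cls <- randomForest(y_f ~ x1 + x2, data = toy, ntree = 100)
p_vote <- predict(rf_cls, toy, type = "prob")[, "1"]  # share of trees voting "1"

# Assumed 0.5 cut-off to turn probabilities into labels
label_reg <- as.integer(p_hat > 0.5)
accuracy <- mean(label_reg == toy$y)  # in-sample, so optimistic
c(rf_reg$type, rf_cls$type)           # "regression" "classification"
```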
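We first set the number of trees to 100. In the plot, the horizontal axis is the number of trees and the vertical axis is the out-of-bag (OOB) mean squared error (MSE) on the training set. As trees are added, the MSE first falls quickly, then levels off and stabilizes at about 0.134. The printout reports a mean of squared residuals of 0.1343. We then score the predictors by importance. age_group contributes least to node purity, with the lowest IncNodePurity (4.09). By %IncMSE it ranks second lowest, just above Sex. We therefore try dropping age_group and fitting a new model.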
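For a regression forest, the curve drawn by plot() is the object's `mse` component, which holds the OOB MSE after 1, 2, ..., ntree trees. The "Mean of squared residuals" in the printout is the last element of that vector. Below is a minimal sketch on hypothetical toy data that reads these values directly instead of from the plot.

```r
library(randomForest)

set.seed(2)
# Hypothetical toy data with an integer 0/1 outcome
toy <- data.frame(x = runif(200))
toy$y <- as.integer(toy$x + rnorm(200, sd = 0.3) > 0.5)

rf <- suppressWarnings(randomForest(y ~ x, data = toy, ntree = 100))

oob_mse <- rf$mse             # OOB MSE after 1, 2, ..., 100 trees; plot(rf) draws this
tail(oob_mse, 1)              # the "Mean of squared residuals" shown by print(rf)
best_k <- which.min(oob_mse)  # tree count with the lowest OOB MSE in this run
```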
#build the prediction model without age_group
titanic.rf2 = randomForest(Survived ~ Pclass + Sex + fare_group + title +
family_size, data = train, ntree=100,
importance=TRUE, mtry = 2, proximity=T)
## Warning in randomForest.default(m, y, ...): The response has five or fewer
## unique values. Are you sure you want to do regression?
plot(titanic.rf2)
titanic.rf2
importance=importance(x=titanic.rf2)#larger values indicate a more important variable
importance
## %IncMSE IncNodePurity
## Pclass 15.264003 10.259386
## Sex 8.642028 18.560983
## fare_group 13.547646 5.518051
## title 17.206129 31.043194
## family_size 16.137833 9.825409
varImpPlot(titanic.rf2)
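After age_group is removed, the OOB MSE stays at about 0.134 and the share of variance explained does not improve noticeably, so we keep the first model.
Next, we use the model to predict the test set and obtain survival probabilities. We then compute the test-set MSE and search for the number of trees that gives the best fit.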
MSE=double()
for (i in 1:100){
titanic.rf = randomForest(Survived ~ Pclass + Sex + fare_group + title +
age_group + family_size, data = train,
ntree=i, importance=TRUE, mtry = 2, proximity=T)
prediction = predict(titanic.rf, test[2:7])
MSE=c(MSE,sum((prediction-test[[1]])^2)/length(prediction))
}
## Warning in randomForest.default(m, y, ...): The response has five or fewer
## unique values. Are you sure you want to do regression?
MSE
## [1] 0.1184807 0.1325057 0.1163700 0.1221974 0.1239209 0.1332497 0.1178497
## [8] 0.1278534 0.1241177 0.1208822 0.1238956 0.1195575 0.1241976 0.1203918
## [15] 0.1166526 0.1239402 0.1194311 0.1217139 0.1214666 0.1193377 0.1197277
## [22] 0.1169179 0.1219785 0.1181112 0.1188443 0.1199948 0.1195515 0.1207829
## [29] 0.1154690 0.1210636 0.1181157 0.1213097 0.1193304 0.1169185 0.1209694
## [36] 0.1202204 0.1192203 0.1235869 0.1181122 0.1173937 0.1172977 0.1187657
## [43] 0.1205281 0.1196954 0.1188193 0.1179140 0.1195792 0.1195855 0.1191853
## [50] 0.1219210 0.1186126 0.1189868 0.1196729 0.1189038 0.1188475 0.1191219
## [57] 0.1192790 0.1188679 0.1202945 0.1186181 0.1197412 0.1163795 0.1194676
## [64] 0.1188208 0.1189766 0.1184523 0.1215369 0.1192862 0.1193127 0.1221132
## [71] 0.1180586 0.1199215 0.1195164 0.1206050 0.1208982 0.1190916 0.1208144
## [78] 0.1192770 0.1213136 0.1193227 0.1192813 0.1201426 0.1200201 0.1188056
## [85] 0.1189111 0.1185196 0.1193564 0.1184725 0.1190194 0.1202242 0.1198914
## [92] 0.1180638 0.1175627 0.1179119 0.1193513 0.1187116 0.1178236 0.1195129
## [99] 0.1181008 0.1166771
plot(1:100, MSE, type='l', xlab='ntree', main='MSE on test set')
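The loop above grows 100 separate forests, each with its own random draws. Part of the variation in the curve therefore comes from refitting, not from the number of trees. An alternative is to grow one forest with 100 trees and call `predict(..., predict.all = TRUE)`, which returns each tree's prediction. The prediction of the sub-forest made of the first k trees is then the mean of the first k columns. The sketch below assumes the same regression setup. The toy data and the helper `make_data` are hypothetical.

```r
library(randomForest)

set.seed(3)
# Hypothetical helper generating toy data with an integer 0/1 outcome
make_data <- function(n) {
  d <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
  d$y <- as.integer(d$x1 - d$x2 + rnorm(n, sd = 0.7) > 0)
  d
}
train_toy <- make_data(300)
test_toy  <- make_data(150)

rf <- suppressWarnings(randomForest(y ~ x1 + x2, data = train_toy, ntree = 100))

# predict.all = TRUE also returns every tree's prediction: a 150 x 100 matrix
pred_all <- predict(rf, test_toy, predict.all = TRUE)
per_tree <- pred_all$individual

# Row-wise running means: column k is the prediction of the first k trees
cum_sum  <- t(apply(per_tree, 1, cumsum))
cum_pred <- sweep(cum_sum, 2, seq_len(ncol(per_tree)), "/")

# Test-set MSE for k = 1, ..., 100 trees, all taken from the same forest
test_mse <- colMeans((cum_pred - test_toy$y)^2)
plot(seq_along(test_mse), test_mse, type = "l", xlab = "ntree", main = "MSE on test set")
```

Every point on this curve comes from nested subsets of one forest, so the curve changes smoothly with k and needs only a single fit.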
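In the line chart above, the horizontal axis is the number of trees and the vertical axis is the test-set MSE of the model's predictions. The MSE settles down once the forest has about 15 trees. The lowest test error, about 0.1155, occurs at 29 trees. Each forest in the loop is grown with different random draws and no seed is set, so the remaining fluctuations are largely noise. Taking the OOB MSE on the training set into account as well, 30 trees is a reasonable choice.
Next, we choose mtry, the number of variables randomly sampled as split candidates at each node.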
MSE=double()
for (i in 1:5){
titanic.rf = randomForest(Survived ~ Pclass + Sex + fare_group + title +
age_group + family_size, data = train,
ntree=100, importance=TRUE, mtry = i, proximity=T)
prediction = predict(titanic.rf, test[2:7])
MSE=c(MSE,sum((prediction-test[[1]])^2)/length(prediction))
}
## Warning in randomForest.default(m, y, ...): The response has five or fewer
## unique values. Are you sure you want to do regression?
MSE
## [1] 0.1275496 0.1192860 0.1220171 0.1230154 0.1245706
plot(1:5, MSE, type='l', xlab='mtry', main='MSE on test set')
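In the line chart above, the x-axis is mtry (the number of candidate variables sampled at each split) and the y-axis is the test-set MSE. The test error is smallest at mtry = 2 (0.1193), so we build the model with mtry = 2. Finally, we fit the final random forest with ntree = 30 and mtry = 2 and use it to predict the test set.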
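A toy base-R sketch can make the role of mtry concrete. At a single node, only `mtry` randomly drawn predictors are searched, and the split that minimises the children's total sum of squared errors (the regression criterion) is kept. The helper `best_split_among`, the toy data, and the "x < cut" split convention are assumptions for illustration. This is not randomForest's internal code, which also handles factor predictors differently.

```r
# Toy illustration of what mtry controls (hypothetical helper and data):
# at one node, only `mtry` randomly drawn predictors are searched for the
# split that minimises the children's total sum of squared errors (SSE).
best_split_among <- function(X, y, mtry) {
  candidates <- sample(colnames(X), mtry)   # random subset of predictor names
  best <- list(var = NA_character_, cut = NA_real_, sse = Inf)
  for (v in candidates) {
    x <- X[, v]
    for (cut in sort(unique(x))[-1]) {      # left child: x < cut, right: x >= cut
      left  <- y[x < cut]
      right <- y[x >= cut]
      sse <- sum((left - mean(left))^2) + sum((right - mean(right))^2)
      if (sse < best$sse) best <- list(var = v, cut = cut, sse = sse)
    }
  }
  best
}

set.seed(1)
X <- cbind(x1 = c(1, 2, 3, 4, 5, 6),        # perfectly separates y
           x2 = c(3, 1, 4, 1, 5, 9),
           x3 = c(2, 7, 1, 8, 2, 8))
y <- c(0, 0, 0, 1, 1, 1)                    # 0/1 response, as for Survived

best_split_among(X, y, mtry = 3)$var        # all predictors searched: "x1"
best_split_among(X, y, mtry = 1)$var        # a single randomly drawn predictor
```

With a small mtry, a node is sometimes forced to split on a weaker predictor, which makes the trees less alike. For regression forests, randomForest's default mtry is floor(p/3). With the six predictors used here that default is 2, so the tuned value coincides with it.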
# Final model: 30 trees, 2 candidate variables per split
titanic.rf <- randomForest(Survived ~ Pclass + Sex + fare_group + title +
                             age_group + family_size, data = train,
                           ntree = 30, importance = TRUE, mtry = 2, proximity = T)
## Warning in randomForest.default(m, y, ...): The response has five or fewer
## unique values. Are you sure you want to do regression?
plot(titanic.rf)
titanic.rf
##
## Call:
## randomForest(formula = Survived ~ Pclass + Sex + fare_group + title + age_group + family_size, data = train, ntree = 30, importance = TRUE, mtry = 2, proximity = T)
## Type of random forest: regression
## Number of trees: 30
## No. of variables tried at each split: 2
##
## Mean of squared residuals: 0.135006
## % Var explained: 43.42
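The two numbers in this summary are related. As we understand randomForest's print method for regression forests, "Mean of squared residuals" is the out-of-bag (OOB) MSE on the training data. "% Var explained" compares that MSE with the variance of the response, which is the MSE of a model that always predicts the mean survival rate:

$$
\%\,\text{Var explained} = 100\left(1 - \frac{\mathrm{MSE}_{\mathrm{OOB}}}{\hat\sigma_y^{2}}\right),
\qquad
\hat\sigma_y^{2} = \frac{1}{n}\sum_{i=1}^{n}\left(y_i - \bar y\right)^{2} = \bar y\,(1-\bar y)\ \ \text{for } y_i \in \{0,1\}.
$$

Plugging in the printed values:

$$
\hat\sigma_y^{2} = \frac{0.135006}{1 - 0.4342} \approx 0.2386
\quad\Longrightarrow\quad
\bar y = \frac{1 - \sqrt{1 - 4(0.2386)}}{2} \approx 0.39 .
$$

Here we take the root below 0.5, assuming fewer than half of the training passengers survived. The forest therefore removes roughly 43% of the error of that constant baseline. The OOB figure is an internal estimate on the training data and is separate from the test-set MSE computed next.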
prediction = predict(titanic.rf, test[2:7])
prediction
## 601 602 603 604 605 606
## 0.72561573 0.12550731 0.24012987 0.15010325 0.24012987 0.08510214
## 607 608 609 610 611 612
## 0.12550731 0.36677737 0.72561573 0.98806807 0.55930078 0.12550731
## 613 614 615 616 617 618
## 0.73040038 0.12550731 0.15010325 0.85324018 0.05273942 0.69217388
## 619 620 621 622 623 624
## 0.89728099 0.15535733 0.11978069 0.20952459 0.12818724 0.12550731
## 625 626 627 628 629 630
## 0.17999744 0.24012987 0.06246657 0.98516964 0.12550731 0.12550731
## 631 632 633 634 635 636
## 0.24012987 0.15010325 0.36872179 0.24958208 0.22948575 0.88694656
## 637 638 639 640 641 642
## 0.15010325 0.06717356 0.43263411 0.07516471 0.12550731 0.98799151
## 643 644 645 646 647 648
## 0.22948575 0.29434988 0.77102073 0.20952459 0.12550731 0.36872179
## 649 650 651 652 653 654
## 0.12550731 0.58858436 0.12550731 0.96386813 0.12550731 0.58858436
## 655 656 657 658 659 660
## 0.58858436 0.08220039 0.12550731 0.63818882 0.15535733 0.39599733
## 661 662 663 664 665 666
## 0.33054206 0.15010325 0.24012987 0.15010325 0.11978069 0.16788304
## 667 668 669 670 671 672
## 0.15535733 0.12550731 0.15010325 0.99327009 0.95813908 0.20952459
## 673 674 675 676 677 678
## 0.16332821 0.16332821 0.15535733 0.12550731 0.12550731 0.58858436
## 679 680 681 682 683 684
## 0.55930078 0.20952459 0.58858436 0.18427797 0.12550731 0.05200790
## 685 686 687 688 689 690
## 0.06717356 0.03458462 0.05200790 0.12550731 0.12550731 0.98986661
## 691 692 693 694 695 696
## 0.20952459 0.48420426 0.29434988 0.12550731 0.36872179 0.16332821
## 697 698 699 700 701 702
## 0.15010325 0.58858436 0.39599733 0.15010325 0.95426055 0.24012987
## 703 704 705 706 707 708
## 0.38223786 0.12550731 0.11978069 0.13162343 0.97853420 0.24012987
## 709 710 711 712 713 714
## 0.98516964 0.75720379 0.82288682 0.24012987 0.20952459 0.12550731
## 715 716 717 718 719 720
## 0.16332821 0.12550731 0.98806807 0.88694656 0.17999744 0.15010325
## 721 722 723 724 725 726
## 0.92844552 0.11978069 0.16332821 0.16332821 0.26212573 0.12550731
## 727 728 729 730 731 732
## 0.72561573 0.58858436 0.04010310 0.38223786 0.98516964 0.29907363
## 733 734 735 736 737 738
## 0.15535733 0.15535733 0.15535733 0.17999744 0.44734000 0.13795074
## 739 740 741 742 743 744
## 0.12550731 0.12550731 0.24012987 0.20952459 0.95939979 0.07516471
## 745 746 747 748 749 750
## 0.15010325 0.29709705 0.12818724 0.88694656 0.26212573 0.15010325
## 751 752 753 754 755 756
## 0.94519766 0.59416123 0.15010325 0.12550731 0.86210759 0.72911956
## 757 758 759 760 761 762
## 0.12550731 0.15535733 0.15010325 0.93132204 0.12550731 0.15010325
## 763 764 765 766 767 768
## 0.12550731 0.88493676 0.12550731 0.99327009 0.36872179 0.35583739
## 769 770 771 772 773 774
## 0.07516471 0.15010325 0.12550731 0.15010325 0.97853420 0.12550731
## 775 776 777 778 779 780
## 0.88036130 0.12550731 0.12550731 0.54604513 0.12550731 0.99327009
## 781 782 783 784 785 786
## 0.58858436 0.95426055 0.36677737 0.06927101 0.12550731 0.12550731
## 787 788 789 790 791 792
## 0.58858436 0.14497916 0.75998157 0.13795074 0.12550731 0.10866631
## 793 794 795 796 797 798
## 0.24003096 0.24012987 0.12550731 0.16332821 0.50744634 0.48969497
## 799 800 801 802 803 804
## 0.12550731 0.68911471 0.16332821 0.95813908 0.76747391 0.59416123
## 805 806 807 808 809 810
## 0.12550731 0.15010325 0.15923074 0.58858436 0.16332821 0.99327009
## 811 812 813 814 815 816
## 0.12550731 0.19334344 0.16332821 0.31178734 0.15010325 0.24958208
## 817 818 819 820 821 822
## 0.58858436 0.06717356 0.15010325 0.14497916 0.97660342 0.12550731
## 823 824 825 826 827 828
## 0.12301262 0.40968654 0.14497916 0.12550731 0.29434988 0.92869692
## 829 830 831 832 833 834
## 0.12550731 0.99160342 0.40968654 0.92869692 0.12550731 0.12550731
## 835 836 837 838 839 840
## 0.12550731 0.98417918 0.12550731 0.12550731 0.35750755 0.24012987
## 841 842 843 844 845 846
## 0.12550731 0.15535733 0.75235091 0.15010325 0.12550731 0.15010325
## 847 848 849 850 851 852
## 0.03316014 0.15010325 0.04999025 0.99327009 0.19758859 0.15010325
## 853 854 855 856 857 858
## 0.72770807 0.81337111 0.97869463 0.40968654 0.97660342 0.24012987
## 859 860 861 862 863 864
## 0.72307607 0.12550731 0.05273942 0.11609362 0.86380834 0.24003096
## 865 866 867 868 869 870
## 0.15535733 0.97853420 0.91716301 0.13795074 0.12550731 0.61786494
## 871 872 873 874 875 876
## 0.12550731 0.97660342 0.15923074 0.15010325 0.89400959 0.58858436
## 877 878 879 880 881 882
## 0.12550731 0.12550731 0.12550731 0.99327009 0.89400959 0.15010325
## 883 884 885 886 887 888
## 0.58858436 0.15535733 0.12550731 0.43263411 0.08764166 0.75235091
## 889 890 891
## 0.77102073 0.36677737 0.15010325
MSE <- mean((prediction - test[[1]])^2)   # test-set MSE of the final model
MSE
## [1] 0.1225102
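The final model's MSE on the test set is 0.1225, as printed above. A passenger whose predicted value is greater than 0.5 is classified as 1 (survived), and one whose predicted value is less than 0.5 is classified as 0 (did not survive).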
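The 0.5 classification rule can be sketched in base R, together with a few summaries of the classified result. The helper name `evaluate_scores` and the toy scores are hypothetical. Sending a score of exactly 0.5 to class 0 is also an assumption, since the text does not specify that case. Besides the MSE (for 0/1 outcomes also known as the Brier score), the helper reports accuracy and a confusion matrix.

```r
# Hypothetical helper: turn predicted survival scores into 0/1 labels with a
# cut-off (a score exactly at the cut-off goes to class 0) and summarise them.
evaluate_scores <- function(score, actual, cutoff = 0.5) {
  predicted <- as.integer(score > cutoff)          # 1 = survived, 0 = did not
  list(
    mse       = mean((score - actual)^2),          # MSE of the raw scores (Brier score)
    accuracy  = mean(predicted == actual),         # share classified correctly
    confusion = table(actual = actual, predicted = predicted)
  )
}

# Toy example with made-up scores (not the Titanic output above):
score  <- c(0.9, 0.2, 0.6, 0.4)
actual <- c(1,   0,   0,   1)
res <- evaluate_scores(score, actual)
res$accuracy   # 0.5
```

With the objects from the code above, the call would be `evaluate_scores(prediction, test[[1]])`. This assumes, as that code does, that the first column of `test` holds `Survived`.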
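This paper used the Titanic survival-prediction competition on the Kaggle data-science competition platform. From personal information such as age, sex, and social status, it predicted the probability that each passenger or crew member survived the sinking. The main work covered data visualization and feature engineering. This included imputing missing values with a random forest, extracting each passenger's title from the name as a new variable, and discretizing age and fare. A random forest model was then built on the preprocessed data to predict survival. The model outputs a score between 0 and 1 that can be read as a survival probability: the closer it is to 1, the more likely the passenger was rescued. The final model's MSE on the test set is 0.1225. An MSE of this size leaves room for improvement, and later optimization could: 1) further refine the preprocessing and derive more discriminative variables for the model; 2) try other models, such as GBDT, KNN, and neural networks; 3) combine several models (model ensembling) to further improve predictive performance.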
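Item 3 (model ensembling) can start as simply as a weighted average of the survival scores of several models. The following is a minimal base-R sketch. The helper `blend_scores`, the equal default weights, and the made-up score vectors are assumptions. In practice the weights would be chosen on held-out data rather than on the test set.

```r
# Hypothetical helper: weighted average of several models' survival scores.
blend_scores <- function(score_list, weights = rep(1, length(score_list))) {
  stopifnot(length(score_list) == length(weights))
  w <- weights / sum(weights)                          # normalise weights to sum to 1
  Reduce(`+`, Map(function(s, wi) wi * s, score_list, w))
}

rf_score  <- c(0.8, 0.1, 0.6)   # made-up scores from one model
alt_score <- c(0.6, 0.3, 0.2)   # made-up scores from another model
blend_scores(list(rf_score, alt_score))   # 0.7 0.2 0.4
```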