Abstract

In recent years, big data technology has developed rapidly, and a variety of data-science competition platforms have emerged with it. Big data brings many conveniences to daily life and can help with problems beyond the scope of traditional statistics. The main task of this paper is to predict the survival probability of passengers on the Titanic. The data consist of basic passenger information in 12 variables, including sex, title and passenger class. The cabin-number variable, which is missing for more than 75% of the samples, is deleted. The age variable has fewer missing values, and these are imputed using similarities (proximities) computed by a random forest. Discrete and continuous variables are visualized and grouped with different methods. Finally, a random forest model is built on the resulting variables to predict survival for the passengers in the test set.

Key words: Random forest; Missing value processing; Survival rate; Binary classification

Background and Approach

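On 10 April 1912 the luxury liner Titanic, billed as "a miracle in the history of world industry", set out on her maiden voyage from Southampton, England, bound for New York. On the night of 14 April she struck an iceberg in the North Atlantic and sank. 1,502 people died and 705 were rescued. It was the deadliest peacetime maritime disaster of its time and remains the most famous shipwreck to date.

In such a life-or-death situation, one might expect passengers with higher social rank, greater influence and more public recognition to have had better chances of survival. One might also expect passengers travelling with larger families to have been more likely to live, because family members could cooperate and had a stronger will to survive. Often, however, outcomes are not explained by such intuitions. Only a careful analysis of real data shows that things are rarely so simple, and the underlying causes tend to be hidden in the data.

The data come from the introductory Titanic survival-prediction competition on the Kaggle platform. The task is to use personal information about the passengers, such as age, sex and title, to predict each person's probability of surviving the sinking. Our approach has three steps. First, we visualize the data set to examine its structure and correlations. Second, we preprocess the data, which includes handling missing values and extracting useful information into new variables. Finally, we build a suitable model on the preprocessed data to predict survival.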

Methods and Principles

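Depending on the goal and application, data mining tasks fall roughly into four categories: clustering, classification, association-rule mining and anomaly detection. This article deals with a classification problem. Starting from the available features, we perform feature extraction and selection and train a supervised binary classification model. Many models can handle binary classification, such as GBDT, KNN, neural networks, SVMs and naive Bayes. The feature-extraction and modeling methods used in this article are described below.

A random forest (RF) is an ensemble of decision trees. A single decision tree is usually not very accurate and is prone to overfitting, so researchers improved predictive accuracy by combining many trees, an approach called an ensemble. An ensemble first builds many base models from the training set. It then determines the final prediction either by voting or by averaging the base models' predictions. Bagging (bootstrap aggregating) is one of the earliest ways of combining trees: each tree is built on a random sample drawn from the training set. Random split selection is a second ensemble method. Its basic idea is to choose the split at each node at random from among the n best splits.

A random forest draws many bootstrap samples (samples with replacement) from the original data and grows a decision tree on each one. It then combines the trees and predicts the final result by majority vote, or by averaging for regression. Many studies report that random forests predict well and compare favorably with other algorithms across a wide range of data sets. They can handle high-dimensional data, because each split considers only a random subset of the features. Each tree also has out-of-bag (OOB) samples, the observations that were not drawn into its bootstrap sample. These give an approximately unbiased estimate of the generalization error without a separate validation set. Random forests generalize well and train quickly. Because the trees are grown independently, training is also easy to parallelize.

Random forests also have drawbacks. On some classification and regression problems with a lot of noise they have been found to overfit, performing noticeably worse on the test set than on the training set. In addition, when attributes take different numbers of distinct values, attributes with more values tend to have a disproportionate influence on the forest. The variable-importance weights the forest produces on such data are therefore not reliable.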
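
The bootstrap, out-of-bag and voting ideas above can be made concrete with a small base-R sketch. It is only an illustration, not how the randomForest package works internally, and the helpers `bootstrap_oob` and `majority_vote` are hypothetical. Sampling n rows with replacement leaves each row out with probability (1 - 1/n)^n, which approaches e^(-1), about 36.8%. Those out-of-bag rows are the ones the OOB error estimate is computed on.

```r
# Illustrative sketch of two random-forest building blocks (hypothetical helpers).

# Draw one bootstrap sample of row indices and report which rows are out-of-bag.
bootstrap_oob <- function(n) {
  in_bag <- sample.int(n, size = n, replace = TRUE)  # n draws with replacement
  oob <- setdiff(seq_len(n), in_bag)                 # rows never drawn
  list(in_bag = in_bag, oob = oob)
}

# Combine class predictions from several trees by majority vote.
# `votes` has one row per observation and one column per tree.
majority_vote <- function(votes) {
  apply(votes, 1, function(v) {
    counts <- table(v)
    names(counts)[which.max(counts)]  # ties go to the first level in sorted order
  })
}

set.seed(1)
n <- 10000
oob_share <- length(bootstrap_oob(n)$oob) / n
print(oob_share)   # close to exp(-1), i.e. about 0.368

votes <- rbind(c("1", "1", "0"),
               c("0", "0", "1"),
               c("0", "1", "1"))
print(majority_vote(votes))   # "1" "0" "1"
```

The exact out-of-bag share depends on the random draw, but with 10,000 rows it stays very close to 0.368.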

Data Preprocessing

First, load the required packages.
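
The code later in this article calls `read_csv`, `n_distinct`, `mutate`, `map`, `ggplot`, `rfImpute`, `randomForest`, `importance` and `varImpPlot`. Based on those functions, a minimal setup is sketched below. Attaching the tidyverse meta-package is an assumption; loading readr, dplyr, purrr and ggplot2 individually would work equally well.

```r
# Packages assumed by the code in this article (inferred from the functions it calls).
library(tidyverse)     # readr::read_csv, dplyr::mutate/n_distinct, purrr::map, ggplot2::ggplot
library(randomForest)  # rfImpute(), randomForest(), importance(), varImpPlot()
```

Attaching randomForest after the tidyverse masks `dplyr::combine()` and `ggplot2::margin()`. Neither function is used here.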

Next, read in the Titanic passenger data and take a first look at its structure.

titanic <- read_csv("Titanic-dataset/train.csv")
## Parsed with column specification:
## cols(
##   PassengerId = col_integer(),
##   Survived = col_integer(),
##   Pclass = col_integer(),
##   Name = col_character(),
##   Sex = col_character(),
##   Age = col_double(),
##   SibSp = col_integer(),
##   Parch = col_integer(),
##   Ticket = col_character(),
##   Fare = col_double(),
##   Cabin = col_character(),
##   Embarked = col_character()
## )
summary(titanic)
##   PassengerId       Survived          Pclass          Name          
##  Min.   :  1.0   Min.   :0.0000   Min.   :1.000   Length:891        
##  1st Qu.:223.5   1st Qu.:0.0000   1st Qu.:2.000   Class :character  
##  Median :446.0   Median :0.0000   Median :3.000   Mode  :character  
##  Mean   :446.0   Mean   :0.3838   Mean   :2.309                     
##  3rd Qu.:668.5   3rd Qu.:1.0000   3rd Qu.:3.000                     
##  Max.   :891.0   Max.   :1.0000   Max.   :3.000                     
##                                                                     
##      Sex                 Age            SibSp           Parch       
##  Length:891         Min.   : 0.42   Min.   :0.000   Min.   :0.0000  
##  Class :character   1st Qu.:20.12   1st Qu.:0.000   1st Qu.:0.0000  
##  Mode  :character   Median :28.00   Median :0.000   Median :0.0000  
##                     Mean   :29.70   Mean   :0.523   Mean   :0.3816  
##                     3rd Qu.:38.00   3rd Qu.:1.000   3rd Qu.:0.0000  
##                     Max.   :80.00   Max.   :8.000   Max.   :6.0000  
##                     NA's   :177                                     
##     Ticket               Fare           Cabin             Embarked        
##  Length:891         Min.   :  0.00   Length:891         Length:891        
##  Class :character   1st Qu.:  7.91   Class :character   Class :character  
##  Mode  :character   Median : 14.45   Mode  :character   Mode  :character  
##                     Mean   : 32.20                                        
##                     3rd Qu.: 31.00                                        
##                     Max.   :512.33                                        
## 
n_distinct(titanic$PassengerId)
## [1] 891

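The data set contains 12 variables and 891 records. The variables are:

- passenger ID (PassengerId)
- survival (Survived)
- passenger class (Pclass)
- name including title (Name)
- sex (Sex)
- age (Age)
- number of siblings and spouses aboard (SibSp)
- number of parents and children aboard (Parch)
- ticket number (Ticket)
- fare (Fare)
- cabin number (Cabin)
- port of embarkation (Embarked)

Age and fare are continuous variables; all the others are discrete. The 891 records belong to 891 distinct passengers, so there are no duplicate rows.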

seq_along(titanic)%>%
  map(~sum(is.na(titanic[.])))%>% 
  str()
## List of 12
##  $ : int 0
##  $ : int 0
##  $ : int 0
##  $ : int 0
##  $ : int 0
##  $ : int 177
##  $ : int 0
##  $ : int 0
##  $ : int 0
##  $ : int 0
##  $ : int 687
##  $ : int 2

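The missing-value counts per variable show three affected variables:

- Age has 177 missing values, about 20% of the samples, which can be filled in by an imputation method.
- Cabin has 687 missing values, more than 75% of the samples, so this variable is dropped.
- Embarked has 2 missing values.

No other variable has missing values.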
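
The proportions quoted above can be reproduced directly from the counts. In this sketch `na_share` and `threshold` are illustrative names. The 75% cutoff is this article's own rule, not a general standard.

```r
# Share of missing values per column, using the counts reported above (891 rows).
n_rows    <- 891
na_counts <- c(Age = 177, Cabin = 687, Embarked = 2)
na_share  <- na_counts / n_rows
print(round(na_share, 3))   # Age 0.199, Cabin 0.771, Embarked 0.002

# Drop columns whose share of missing values exceeds the threshold.
threshold <- 0.75
to_drop   <- names(na_share)[na_share > threshold]
print(to_drop)              # "Cabin"

# On a full data frame the same shares come from colSums(is.na(df)) / nrow(df).
toy <- data.frame(a = c(1, NA, 3, 4), b = c(NA, NA, NA, 1))
print(colSums(is.na(toy)) / nrow(toy))   # a 0.25, b 0.75
```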

Data Visualization

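Data visualization uses graphics to convey the information contained in data clearly and effectively. It helps us understand and analyze the data and extract value from it.

In this section we visualize the relationship between each variable and survival and analyze it qualitatively. Discrete variables are generally examined with bar charts, and continuous variables with distribution (frequency-polygon) curves.

We start with passenger class. Both Pclass and Survived are discrete, so bar charts are used.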
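
A proportional bar chart (`position = 'fill'`) shows each group's split across outcomes, so the height of the survivor segment is that group's survival rate. The base-R check below uses made-up toy vectors, not the Titanic data, to show that these heights are row-wise proportions of the contingency table.

```r
# Toy data (hypothetical): passenger class and 0/1 survival indicator.
pclass   <- c(1, 1, 1, 2, 2, 3, 3, 3, 3)
survived <- c(1, 1, 0, 1, 0, 0, 0, 1, 0)

# Row-wise proportions: the quantities a position = 'fill' bar chart stacks.
rates <- prop.table(table(pclass, survived), margin = 1)
print(round(rates, 3))

# The survivor share per class equals the group mean of the 0/1 indicator.
surv_rate <- tapply(survived, pclass, mean)
print(surv_rate)   # class 1: 2/3, class 2: 1/2, class 3: 1/4
```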

ggplot(titanic, aes(Pclass,fill = as.factor(Survived)))+
  geom_bar(position = 'dodge')+
  labs(
    title = paste('Number of Survivals of different class')
  )+
  theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")

ggplot(titanic, aes(Pclass, fill=as.character(Survived)))+
  geom_bar(position = 'fill')+
  labs(
    y = NULL,
    title = paste('Survival rate of different class'))+
  theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")

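The count chart and the proportional (stacked-to-100%) bar chart show that third class has the most passengers but the lowest survival rate. First and second class have fewer passengers, but their survival rates are much higher than that of third class. Passenger class is therefore related to the target variable.

Name is a character variable. We extract the title contained in each name into a new variable, title, and use bar charts to analyze how title relates to survival.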
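
The regular expression in the next chunk removes everything up to and including the comma-space after the surname. It also removes everything from the first remaining period onwards, which leaves the title. The quick check below uses names written in the dataset's `Surname, Title. Given names` format.

```r
# Example names in the "Surname, Title. Given names" format of the Name column.
names_example <- c("Braund, Mr. Owen Harris",
                   "Heikkinen, Miss. Laina",
                   "Rothes, the Countess. of (Lucy Noel Martha Dyer-Edwards)")

# '(.*, )' strips the surname and comma-space; '(\\..*)' strips the period
# that ends the title and everything after it.
titles <- gsub('(.*, )|(\\..*)', '', names_example)
print(titles)   # "Mr" "Miss" "the Countess"
```

The pattern assumes the only comma-space in a name follows the surname. A name with a second ", " would be cut at the last one, because `.*` is greedy.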

titanic$title = gsub('(.*, )|(\\..*)', '', titanic$Name)
ggplot(titanic, aes(title))+
  geom_bar(aes(fill = as.character(Survived)), position='dodge')+
  labs(
    y = NULL,
    title = paste('Number of survivals of different title'))+
  theme(plot.title = element_text(hjust = 0.5), legend.position="bottom")

ggplot(titanic, aes(title))+
  geom_bar(aes(fill = as.character(Survived)), position='dodge')+
  labs(
    y = NULL,
    title = paste('Number of survivals of different title'))+
  coord_cartesian(ylim = c(0, 50))+  # zoom in on the small groups without dropping any bars
  theme(plot.title = element_text(hjust = 0.5), legend.position="bottom")

ggplot(titanic, aes(title))+
  geom_bar(aes(fill = as.character(Survived)), position='fill')+
  labs(
    y = NULL,
    title = paste('Survival rate of different title'))+
  theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")

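In the bar charts, the horizontal axis shows the title groups and the vertical axis the number of passengers. There are 17 different titles: Capt, Col, Don, Dr, Jonkheer, Lady, Major, Master, Miss, Mlle, Mme, Mr, Mrs, Ms, Rev, Sir and the Countess.

The proportional chart shows two extremes. Every passenger titled Capt, Don, Jonkheer or Rev died, while every passenger titled Lady, Mlle, Mme, Ms, Sir or the Countess survived. The titles Capt, Col, Don, Dr, Jonkheer, Lady, Major, Rev, Sir and the Countess each cover only a few passengers, so we restrict the range of the vertical axis to see their counts in detail.

Combining these counts with the observed outcomes, we merge the rare titles into three new groups:

- low: Capt, Don, Jonkheer and Rev
- medium: Col, Dr and Major
- high: Lady, Mlle, Mme, Ms, Sir and the Countess

After merging, 7 title groups remain, and we visualize the merged data again.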
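
Applying this grouping to the list of 17 titles confirms that 7 groups remain. The sketch wraps the article's three assignments in a hypothetical helper `merge_title` and applies it to the distinct titles only, not to the passenger records.

```r
# The 17 titles found in the training data, as listed above.
raw_titles <- c("Capt", "Col", "Don", "Dr", "Jonkheer", "Lady", "Major",
                "Master", "Miss", "Mlle", "Mme", "Mr", "Mrs", "Ms",
                "Rev", "Sir", "the Countess")

# Grouping used in this article: rare titles pooled by observed outcome.
low    <- c("Jonkheer", "Capt", "Don", "Rev")
medium <- c("Col", "Dr", "Major")
high   <- c("Lady", "Mlle", "Ms", "Mme", "Sir", "the Countess")

merge_title <- function(t) {
  t[t %in% low]    <- "low"
  t[t %in% medium] <- "medium"
  t[t %in% high]   <- "high"
  t
}

merged <- merge_title(raw_titles)
print(length(unique(merged)))   # 7
print(unique(merged))           # groups: low, medium, high, Master, Miss, Mr, Mrs
```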

medium_survive_title = c( 'Col','Dr','Major')
high_survive_title = c('Lady','Mlle','Ms','Mme','Sir','the Countess')
low_survive_title = c('Jonkheer','Capt','Don','Rev')
titanic$title[titanic$title %in% medium_survive_title] = 'medium' 
titanic$title[titanic$title %in% high_survive_title] = 'high'
titanic$title[titanic$title %in% low_survive_title] = 'low'

ggplot(titanic, aes(title))+
  geom_bar(aes(fill = as.character(Survived)))+
  labs(
    y = NULL,
    title = paste('Number of survivals of different title'))+
  theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")

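In the stacked bar chart, the horizontal axis shows the title groups and the vertical axis the number of passengers. Merging clearly reduces the number of groups. Most groups show that title is related to survival, but the Master and medium groups do not clearly separate survivors from non-survivors.

Next, we visualize sex.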

ggplot(titanic)+
  geom_bar(aes(Sex, fill=as.character(Survived)), position = 'dodge')+
  labs(
    title = paste('Number of survivals of different sex'))+
  theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")

ggplot(titanic)+
  geom_bar(aes(Sex, fill = as.factor(Survived)), position = 'fill')+
  labs(
    y = NULL,
    title = paste('Survival rate of different sex'))+
  theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")

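The two bar charts show that there are more male than female passengers, but the survival rate of men is far lower than that of women. Sex is therefore strongly associated with the target variable.

Next, we visualize age.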

range(titanic$Age,na.rm=T)
## [1]  0.42 80.00
ggplot(titanic, aes(x = Age,y = ..density..,color = as.character(Survived)))+
  geom_freqpoly(binwidth=2)+
  labs(
    title = paste('Probability of survival of different age'))+
  theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")
## Warning: Removed 177 rows containing non-finite values (stat_bin).

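Age is a continuous variable ranging from 0.42 to 80, so a distribution curve is used to visualize it. The horizontal axis is age and the vertical axis is density; the blue curve represents survivors and the red curve those who died. Children under about 12 were more likely to be rescued, while young people around 20 were more likely to die. We therefore discretize age into three intervals, (0,12], (12,30] and (30,90], and use bar charts to compare survival across the age groups.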
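
`cut()` closes intervals on the right by default (`right = TRUE`), so a 12-year-old counts as a child and a 30-year-old as young. The check below repeats the article's `age_group` function on a few boundary values. Missing ages stay `NA`, which is why 177 rows have no age group before imputation.

```r
# Same binning as the article: (0,12] child, (12,30] young, (30,90] adult.
age_group <- function(age) {
  cut(age,
      breaks = c(0, 12, 30, 90),
      labels = c('child', 'young', 'adult'))
}

ages <- c(0.42, 12, 12.5, 30, 80, NA)
print(age_group(ages))
# 0.42 -> child, 12 -> child, 12.5 -> young, 30 -> young, 80 -> adult, NA stays NA
```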

#split age into 3 groups 
age_group = function(age){
  cut(age,
      breaks = c(0,12,30,90),
      labels = c('child','young','adult'))
}
titanic = titanic%>%
  mutate(age_group=age_group(as.numeric(Age)))
sum(is.na(titanic$age_group))
## [1] 177
ggplot(titanic)+
  geom_bar(aes(age_group,fill = as.character(Survived)), position='dodge')+
  labs(
    title = paste('Number of survivals of different age groups'))+
  theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")

ggplot(titanic)+
  geom_bar(aes(age_group,fill = as.character(Survived)), position='fill')+
  labs(
    y = NULL,
    title = paste('Survival rate of different age groups'))+
  theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")

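The proportional bar chart shows that the survival rate of young people is much lower than that of children. The remaining adults fall in between, without a clear pattern. Age has many missing values, which are imputed later in this article.

Next we visualize the number of siblings/spouses and the number of parents/children aboard. These two variables are related, because together they make up the number of family members on board.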

range(titanic$SibSp)
## [1] 0 8
ggplot(titanic)+
  geom_bar(aes(as.character(SibSp),fill = as.character(Survived)),position='dodge')+
  labs(
    title = paste('Number of survivals by number of siblings/spouses aboard'))+
  theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")

ggplot(titanic)+
  geom_bar(aes(as.character(SibSp),fill = as.character(Survived)),position='fill')+
  labs(
    y = NULL,
    title = paste('Survival rate by number of siblings/spouses aboard'))+
  theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")

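Comparing the two bar charts, most passengers travelled without siblings or spouses. Passengers with exactly one sibling or spouse aboard had the highest chance of survival, and the chance decreases as the number grows. None of the passengers with 5 or more siblings/spouses aboard survived.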

range(titanic$Parch)
## [1] 0 6
ggplot(titanic,aes(as.character(Parch),fill = as.character(Survived)))+
  geom_bar(position='dodge')+
  labs(
    title = paste('Number of survivals with different number of parents and children'))+
  theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")

ggplot(titanic,aes(as.character(Parch),fill = as.character(Survived)))+
  geom_bar(position='fill')+
  labs(
    y=NULL,
    title = paste('Survival rate with different number of parents and children'))+
  theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")

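The proportional bar chart shows that survival by number of parents/children aboard follows a pattern similar to survival by number of siblings/spouses. However, the decline in survival as the number grows is less pronounced. Next we add the number of siblings/spouses, the number of parents/children and the passenger themself to obtain the family size, and plot again.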

titanic$family_size = titanic$SibSp + titanic$Parch + 1
ggplot(titanic,aes(as.factor(family_size),fill = as.factor(Survived)))+
  geom_bar(position='fill')+
  labs(
    y = NULL,
    title = paste('Survival rate with different family size'))+
  theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")

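The horizontal axis is family size, the number of family members aboard. In each bar, the blue segment is the share rescued and the red segment the share who died. Passengers in very large families (8 to 11 people) all died. Those in larger families (5 to 7 people) and those travelling alone had low survival shares, while those in groups of 2 to 4 had higher survival shares. Family size is thus related to survival.

Next, we visualize fare.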

range(titanic$Fare)
## [1]   0.0000 512.3292
ggplot(titanic,aes(x = Fare,y = ..density..,color = as.character(Survived)))+
  geom_freqpoly(binwidth=5)+
  labs(
    title = paste('Survival density of different fares'))+
  coord_cartesian(xlim = c(0, 200))+  # zoom in without discarding fares above 200
  theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")

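Fare is a continuous variable ranging from 0 to 512.33, so a distribution curve is used to visualize it. The horizontal axis is fare and the vertical axis is density; the blue curve represents survivors and the red curve those who died. Passengers who paid less than 15 were more likely to die, while those who paid more than 50 were more likely to survive. We therefore divide fare into three intervals, (-1,15], (15,50] and (50,520], and use a bar chart to compare survival across fare groups.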
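
The lowest break is -1 rather than 0 because `cut()` excludes the left endpoint of the first interval. A break at 0 would turn the zero fares in the data (the minimum fare is 0.00) into `NA`. The check below repeats the article's `fare_group` function on boundary values.

```r
# Same binning as the article: (-1,15] cheap, (15,50] medium, (50,520] expensive.
fare_group <- function(fare) {
  cut(fare,
      breaks = c(-1, 15, 50, 520),
      labels = c('cheap', 'medium', 'expensive'))
}

fares <- c(0, 15, 15.5, 50, 512.3292)
print(fare_group(fares))                    # cheap cheap medium medium expensive

# With a lower break of 0, a zero fare falls outside (0,15] and becomes NA.
print(cut(0, breaks = c(0, 15, 50, 520)))   # NA
```

Passing `include.lowest = TRUE` with a lower break of 0 would be an equivalent choice, since it makes the first interval [0,15].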

#split fare into 3 groups 
fare_group = function(fare){
  cut(fare,
      breaks = c(-1,15,50,520),
      labels = c('cheap','medium','expensive'))
}
titanic = titanic%>%
  mutate(fare_group=fare_group(Fare))
ggplot(titanic,aes(fare_group,fill = as.character(Survived)))+
  geom_bar(position='fill')+
  labs(
    y = NULL,
    title = paste('Survival rate of different fare groups'))+
  theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")

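The proportional bar chart shows that the higher the fare, the higher the survival rate, so fare is clearly associated with survival.

Finally, we visualize the port of embarkation.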

ggplot(titanic,aes(as.character(Embarked),fill = as.character(Survived)))+
  geom_bar(position='fill')+
  labs(
    y = NULL,
    title = paste('Survival rate of different ports of embarkation'))+
  theme(plot.title = element_text(hjust = 0.5), legend.position = "bottom")

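Embarked has two missing values, which could be imputed. However, there is no obvious logical reason why the port of embarkation should affect whether a passenger survived, so this variable is not included in the model.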

titanic$Pclass=as.factor(titanic$Pclass)
titanic$Sex=as.factor(titanic$Sex)
titanic$fare_group=as.factor(titanic$fare_group)
titanic$title=as.factor(titanic$title)
titanic$age_group=as.factor(titanic$age_group)

Missing Value Imputation

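This article drops both Embarked and Cabin, so only the 177 missing ages need to be filled in. The model uses the discretized variable age_group rather than Age itself, so it is age_group that is imputed. We use rfImpute() from the randomForest package. It computes proximities between samples with a random forest and uses them to estimate and fill in the missing values. Survived is stored as an integer, so each of the iter = 3 iterations fits a regression forest, which produces the warnings and OOB MSE tables below. The returned data frame contains the response in the first column followed by the predictors in the formula. This is why later code can refer to test[[1]] and test[2:7].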
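
According to the randomForest documentation, `rfImpute()` starts from a rough fill (`na.roughfix`) and grows a forest. It then replaces each missing value using the proximity matrix. For a numeric variable, the replacement is the proximity-weighted average of the observed values. For a categorical variable, it is the level with the largest average proximity. These steps repeat `iter` times. The base-R sketch below shows only the categorical replacement step, for a factor such as `age_group`. The proximity matrix is made up, and the helper `impute_factor_by_proximity` is hypothetical.

```r
# Hypothetical helper: fill missing factor values from a proximity matrix.
# prox[i, j] = share of trees in which rows i and j end up in the same leaf.
impute_factor_by_proximity <- function(x, prox) {
  observed <- which(!is.na(x))
  for (i in which(is.na(x))) {
    w   <- prox[i, observed]                          # proximities to observed rows
    avg <- tapply(w, droplevels(x[observed]), mean)   # average proximity per level
    x[i] <- names(avg)[which.max(avg)]                # level with the largest average
  }
  x
}

# Toy example: five passengers, the third has a missing age group.
age <- factor(c("child", "adult", NA, "adult", "child"),
              levels = c("child", "young", "adult"))
prox <- matrix(c(1.0, 0.1, 0.7, 0.2, 0.6,
                 0.1, 1.0, 0.2, 0.8, 0.1,
                 0.7, 0.2, 1.0, 0.1, 0.9,
                 0.2, 0.8, 0.1, 1.0, 0.2,
                 0.6, 0.1, 0.9, 0.2, 1.0), nrow = 5, byrow = TRUE)
res <- impute_factor_by_proximity(age, prox)
print(res)   # row 3 becomes "child" (average proximity 0.8 vs 0.15 for "adult")
```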

titanic=rfImpute(Survived ~ Pclass + Sex + fare_group + title + age_group + 
                   family_size, data=titanic, iter=3, ntree=500)
## Warning in randomForest.default(xf, y, ntree = ntree, ..., do.trace =
## ntree, : The response has five or fewer unique values. Are you sure you
## want to do regression?
##      |      Out-of-bag   |
## Tree |      MSE  %Var(y) |
##  500 |   0.1245    52.64 |
## Warning in randomForest.default(xf, y, ntree = ntree, ..., do.trace =
## ntree, : The response has five or fewer unique values. Are you sure you
## want to do regression?
##      |      Out-of-bag   |
## Tree |      MSE  %Var(y) |
##  500 |    0.124    52.42 |
## Warning in randomForest.default(xf, y, ntree = ntree, ..., do.trace =
## ntree, : The response has five or fewer unique values. Are you sure you
## want to do regression?
##      |      Out-of-bag   |
## Tree |      MSE  %Var(y) |
##  500 |   0.1237    52.31 |

Modeling and Empirical Analysis

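Using the variables above, we build a random forest model for prediction. Random forests evolved from decision trees and handle discrete (categorical) predictors well. Note that Survived is stored as an integer, so randomForest() fits a regression forest, as the repeated warning "The response has five or fewer unique values" indicates. Its predictions are therefore estimated survival probabilities, and the error is reported as a mean squared error (MSE). Converting the response with as.factor(Survived) would give a classification forest instead. The modeling steps are as follows.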
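
On a 0/1 response, the MSE of predicted probabilities is the Brier score. The base-R sketch below turns such probabilities into class labels with a 0.5 cutoff and scores them both ways. The vectors are illustrative, not output of the model in this article, and the cutoff is an assumption.

```r
# Illustrative predicted survival probabilities and true outcomes.
p_hat <- c(0.9, 0.2, 0.6, 0.4, 0.1)
y     <- c(1,   0,   0,   1,   0)

brier    <- mean((p_hat - y)^2)        # the MSE used in this article, on 0/1 outcomes
y_class  <- as.integer(p_hat >= 0.5)   # assumed 0.5 cutoff
accuracy <- mean(y_class == y)

print(brier)      # 0.156
print(accuracy)   # 0.6
```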

# split the data: rows 1-600 for training, rows 601-891 for testing
train = titanic[1:600,]
test = titanic[601:891,]

# build the prediction model
titanic.rf1 = randomForest(Survived ~ Pclass + Sex + fare_group + title + 
                            age_group + family_size, data = train, 
                          ntree=100, importance=TRUE, mtry = 2, proximity=T)
## Warning in randomForest.default(m, y, ...): The response has five or fewer
## unique values. Are you sure you want to do regression?
plot(titanic.rf1)

titanic.rf1
## 
## Call:
##  randomForest(formula = Survived ~ Pclass + Sex + fare_group +      title + age_group + family_size, data = train, ntree = 100,      importance = TRUE, mtry = 2, proximity = T) 
##                Type of random forest: regression
##                      Number of trees: 100
## No. of variables tried at each split: 2
## 
##           Mean of squared residuals: 0.1343301
##                     % Var explained: 43.71
imp1 = importance(x = titanic.rf1)  # larger values indicate a more important variable
imp1
##               %IncMSE IncNodePurity
## Pclass      18.176312      9.360425
## Sex          8.890673     18.628440
## fare_group  13.307486      5.572320
## title       13.828542     29.116470
## age_group    9.055191      4.089925
## family_size 15.030640      9.090589
varImpPlot(titanic.rf1)

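First, the number of trees is set to 100. In the plot, the horizontal axis is the number of trees and the vertical axis is the out-of-bag (OOB) MSE of the forest on the training set. As trees are added, the MSE first drops quickly, then levels off at about 0.134 (the printed summary reports 0.1343). Among the variable importance scores, age_group contributes least to node purity (IncNodePurity 4.09). On %IncMSE the lowest score belongs to Sex (8.89), with age_group close behind (9.06). We therefore try a new model without age_group.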
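
`%IncMSE` is a permutation importance. Per the randomForest documentation, the package permutes one predictor in each tree's out-of-bag data and records how much the MSE grows. It then averages these differences over trees and scales them by their standard deviation. The simplified base-R sketch below conveys the idea on synthetic data, with a linear model standing in for the forest and one permutation over the whole data set per variable. `perm_importance` is a hypothetical helper.

```r
# Simplified permutation importance (not randomForest's per-tree OOB version).
perm_importance <- function(model, data, y, vars) {
  base_mse <- mean((predict(model, data) - y)^2)
  sapply(vars, function(v) {
    shuffled <- data
    shuffled[[v]] <- sample(shuffled[[v]])             # break the link between v and y
    mean((predict(model, shuffled) - y)^2) - base_mse  # increase in MSE
  })
}

set.seed(42)
n  <- 500
df <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
df$y <- 2 * df$x1 + rnorm(n, sd = 0.5)   # y depends on x1 only

fit <- lm(y ~ x1 + x2, data = df)
imp <- perm_importance(fit, df, df$y, c("x1", "x2"))
print(round(imp, 3))   # x1 large, x2 close to 0
```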

# build a second model without age_group
titanic.rf2 = randomForest(Survived ~ Pclass + Sex + fare_group + title + 
                         family_size, data = train, ntree=100,
                         importance=TRUE, mtry = 2, proximity=T)
## Warning in randomForest.default(m, y, ...): The response has five or fewer
## unique values. Are you sure you want to do regression?
plot(titanic.rf2)

titanic.rf2
imp2 = importance(x = titanic.rf2)  # larger values indicate a more important variable
imp2
##               %IncMSE IncNodePurity
## Pclass      15.264003     10.259386
## Sex          8.642028     18.560983
## fare_group  13.547646      5.518051
## title       17.206129     31.043194
## family_size 16.137833      9.825409
varImpPlot(titanic.rf2)

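After removing the age-group variable, the MSE stays at roughly the same level (about 0.13), and the share of variance explained does not improve noticeably. The first model is therefore retained.

Next, we use the model to predict survival probabilities for the test set, compute the test-set MSE, and search for the number of trees that gives the best fit.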

MSE = numeric(100)   # test-set MSE for ntree = 1, ..., 100
for (i in 1:100){
  titanic.rf = randomForest(Survived ~ Pclass + Sex + fare_group + title + 
                            age_group + family_size, data = train, 
                          ntree=i, importance=TRUE, mtry = 2, proximity=T)
  prediction = predict(titanic.rf, test[2:7])   # columns 2-7 are the predictors
  MSE[i] = mean((prediction - test[[1]])^2)     # column 1 is Survived
}
## Warning in randomForest.default(m, y, ...): The response has five or fewer
## unique values. Are you sure you want to do regression?

MSE
##   [1] 0.1184807 0.1325057 0.1163700 0.1221974 0.1239209 0.1332497 0.1178497
##   [8] 0.1278534 0.1241177 0.1208822 0.1238956 0.1195575 0.1241976 0.1203918
##  [15] 0.1166526 0.1239402 0.1194311 0.1217139 0.1214666 0.1193377 0.1197277
##  [22] 0.1169179 0.1219785 0.1181112 0.1188443 0.1199948 0.1195515 0.1207829
##  [29] 0.1154690 0.1210636 0.1181157 0.1213097 0.1193304 0.1169185 0.1209694
##  [36] 0.1202204 0.1192203 0.1235869 0.1181122 0.1173937 0.1172977 0.1187657
##  [43] 0.1205281 0.1196954 0.1188193 0.1179140 0.1195792 0.1195855 0.1191853
##  [50] 0.1219210 0.1186126 0.1189868 0.1196729 0.1189038 0.1188475 0.1191219
##  [57] 0.1192790 0.1188679 0.1202945 0.1186181 0.1197412 0.1163795 0.1194676
##  [64] 0.1188208 0.1189766 0.1184523 0.1215369 0.1192862 0.1193127 0.1221132
##  [71] 0.1180586 0.1199215 0.1195164 0.1206050 0.1208982 0.1190916 0.1208144
##  [78] 0.1192770 0.1213136 0.1193227 0.1192813 0.1201426 0.1200201 0.1188056
##  [85] 0.1189111 0.1185196 0.1193564 0.1184725 0.1190194 0.1202242 0.1198914
##  [92] 0.1180638 0.1175627 0.1179119 0.1193513 0.1187116 0.1178236 0.1195129
##  [99] 0.1181008 0.1166771
plot(1:100, MSE, type='l', xlab='ntree', main='MSE on test set')

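In the line chart above, the horizontal axis is the number of trees (ntree) and the vertical axis is the model's MSE on the test set. The MSE levels off once the forest has about 15 trees, and in the run shown above the test-set error is smallest at 29 trees (MSE ≈ 0.1155). Taking the MSE on the training set into account as well, a forest of 30 trees is a reasonable choice. Next, we choose mtry, the number of variables randomly sampled as split candidates at each node.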
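Refitting a separate forest for every ntree value is slow, and each point on the curve then comes from a different random forest. A minimal sketch of an alternative, assuming a single forest is fitted once: the first k trees of a 100-tree forest form a k-tree forest, so the whole curve can be read off the per-tree predictions. With the randomForest package, `predict(fit, newdata, predict.all = TRUE)$individual` returns an n × ntree matrix of per-tree predictions. The helpers `mse_by_ntree` and `pick_ntree`, the tolerance `tol`, and the toy matrix are hypothetical illustrations, not the paper's code or data.

```r
# Hypothetical helper: test-set MSE of the first k trees, k = 1..ntree,
# from an n x ntree matrix of per-tree predictions and observed 0/1 outcomes y.
mse_by_ntree <- function(individual, y) {
  ntree <- ncol(individual)
  cum_mean <- individual
  # Running mean over trees: column k becomes the average of trees 1..k.
  for (k in seq_len(ntree)[-1]) {
    cum_mean[, k] <- cum_mean[, k - 1] + (individual[, k] - cum_mean[, k - 1]) / k
  }
  # y is recycled down each column, so every column is compared with y.
  colMeans((cum_mean - y)^2)
}

# Hypothetical selection rule: the smallest forest whose MSE is within
# `tol` of the best MSE on the curve.
pick_ntree <- function(mse, tol = 0.001) {
  min(which(mse <= min(mse) + tol))
}

# Toy input: 2 observations x 3 trees (not data from this paper).
individual <- matrix(c(0.9, 0.7, 0.8,
                       0.2, 0.4, 0.0), nrow = 2, byrow = TRUE)
y <- c(1, 0)
mse <- mse_by_ntree(individual, y)  # 0.025 0.065 0.040
which.min(mse)                      # 1
```

Applied to the MSE vector printed above, `which.min(MSE)` returns 29. A tolerance rule such as `pick_ntree` trades a little accuracy for a smaller forest; the tolerance value is a judgement call.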

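# Tune mtry (number of variables tried at each split), with ntree fixed at 100
MSE=double()
for (i in 1:5){
  titanic.rf = randomForest(Survived ~ Pclass + Sex + fare_group + title + 
                            age_group + family_size, data = train, 
                            ntree=100, importance=TRUE, mtry = i, proximity=T)
  # Columns 2-7 of test hold the predictors; column 1 is the observed 0/1 Survived
  prediction = predict(titanic.rf, test[2:7])
  # Test-set mean squared error for this mtry
  MSE=c(MSE, mean((prediction-test[[1]])^2))
}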
## Warning in randomForest.default(m, y, ...): The response has five or fewer
## unique values. Are you sure you want to do regression?

MSE
## [1] 0.1275496 0.1192860 0.1220171 0.1230154 0.1245706
plot(1:5, MSE, type='l', xlab='mtry', main='MSE on test set')

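In this line chart the horizontal axis is mtry, the number of variables tried at each split, and the vertical axis is the test-set MSE. The test-set error is smallest at mtry = 2 (MSE ≈ 0.1193), so mtry = 2 is a reasonable choice. Finally, we build the final random forest with ntree = 30 and mtry = 2 and use it for prediction.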
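randomForest draws bootstrap samples and split candidates at random, so without a fixed seed every knit of this report produces slightly different tuning curves. A minimal sketch of a reproducible mtry grid, assuming the randomForest package (already used in this paper) is installed: the seed is reset before each fit, so every mtry value starts from the same random state and the whole grid can be rerun exactly. The toy data frame, its columns `x1` to `x4` and `y`, and the seed 42 are hypothetical stand-ins for the paper's `train` and `test`.

```r
library(randomForest)

# Hypothetical toy data (not the Titanic data): a 0/1 outcome driven by x1 and x2.
set.seed(1)
n <- 300
toy <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n), x4 = rnorm(n))
toy$y <- as.numeric(toy$x1 + 0.5 * toy$x2 + rnorm(n) > 0)
train_toy <- toy[1:200, ]
test_toy  <- toy[201:300, ]

# Test-set MSE for one mtry value. Resetting the seed inside makes each
# call reproducible on its own.
mse_for_mtry <- function(m, seed = 42) {
  set.seed(seed)
  # y is numeric 0/1, so this is a regression forest; suppress the same
  # "five or fewer unique values" warning that appears in this report.
  fit <- suppressWarnings(
    randomForest(y ~ ., data = train_toy, ntree = 100, mtry = m)
  )
  pred <- predict(fit, newdata = test_toy)
  mean((pred - test_toy$y)^2)
}

mse_mtry  <- sapply(1:4, mse_for_mtry)
best_mtry <- which.min(mse_mtry)
```

Because each call fixes its own seed, `mse_for_mtry(2)` returns exactly `mse_mtry[2]` on every rerun, so numbers quoted in the prose stay in step with the printed output.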

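# Final model: ntree = 30, mtry = 2. Survived is numeric 0/1, so randomForest
# fits a regression forest, which triggers the warning below.
titanic.rf = randomForest(Survived ~ Pclass + Sex + fare_group + title + 
                          age_group + family_size, data = train, 
                          ntree=30, importance=TRUE, mtry = 2, proximity=T)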
## Warning in randomForest.default(m, y, ...): The response has five or fewer
## unique values. Are you sure you want to do regression?
plot(titanic.rf)

titanic.rf
## 
## Call:
##  randomForest(formula = Survived ~ Pclass + Sex + fare_group +      title + age_group + family_size, data = train, ntree = 30,      importance = TRUE, mtry = 2, proximity = T) 
##                Type of random forest: regression
##                      Number of trees: 30
## No. of variables tried at each split: 2
## 
##           Mean of squared residuals: 0.135006
##                     % Var explained: 43.42
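The two summary lines printed above are related. As far as we understand the randomForest package, for a regression forest "Mean of squared residuals" is the out-of-bag (OOB) MSE on the training set, and "% Var explained" is derived from it, with the variance of the 0/1 response taken with denominator n:

$$
\%\,\mathrm{Var\ explained} = 100\left(1 - \frac{\mathrm{MSE}_{\mathrm{OOB}}}{\hat\sigma_y^2}\right),
\qquad
\hat\sigma_y^2 = \frac{1}{n}\sum_{i=1}^{n}(y_i - \bar y)^2 = \bar y\,(1 - \bar y).
$$

Plugging in the printed values, $1 - 0.135006/\hat\sigma_y^2 = 0.4342$ gives $\hat\sigma_y^2 \approx 0.135006/0.5658 \approx 0.2386$, so $\bar y(1-\bar y) \approx 0.2386$ and, taking the root below 0.5, $\bar y \approx 0.39$, in line with the roughly 38% survival rate of the Kaggle training data. The OOB value is measured on the training set and is a different quantity from the test-set MSE computed below.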
prediction = predict(titanic.rf, test[2:7])
prediction
##        601        602        603        604        605        606 
## 0.72561573 0.12550731 0.24012987 0.15010325 0.24012987 0.08510214 
##        607        608        609        610        611        612 
## 0.12550731 0.36677737 0.72561573 0.98806807 0.55930078 0.12550731 
##        613        614        615        616        617        618 
## 0.73040038 0.12550731 0.15010325 0.85324018 0.05273942 0.69217388 
##        619        620        621        622        623        624 
## 0.89728099 0.15535733 0.11978069 0.20952459 0.12818724 0.12550731 
##        625        626        627        628        629        630 
## 0.17999744 0.24012987 0.06246657 0.98516964 0.12550731 0.12550731 
##        631        632        633        634        635        636 
## 0.24012987 0.15010325 0.36872179 0.24958208 0.22948575 0.88694656 
##        637        638        639        640        641        642 
## 0.15010325 0.06717356 0.43263411 0.07516471 0.12550731 0.98799151 
##        643        644        645        646        647        648 
## 0.22948575 0.29434988 0.77102073 0.20952459 0.12550731 0.36872179 
##        649        650        651        652        653        654 
## 0.12550731 0.58858436 0.12550731 0.96386813 0.12550731 0.58858436 
##        655        656        657        658        659        660 
## 0.58858436 0.08220039 0.12550731 0.63818882 0.15535733 0.39599733 
##        661        662        663        664        665        666 
## 0.33054206 0.15010325 0.24012987 0.15010325 0.11978069 0.16788304 
##        667        668        669        670        671        672 
## 0.15535733 0.12550731 0.15010325 0.99327009 0.95813908 0.20952459 
##        673        674        675        676        677        678 
## 0.16332821 0.16332821 0.15535733 0.12550731 0.12550731 0.58858436 
##        679        680        681        682        683        684 
## 0.55930078 0.20952459 0.58858436 0.18427797 0.12550731 0.05200790 
##        685        686        687        688        689        690 
## 0.06717356 0.03458462 0.05200790 0.12550731 0.12550731 0.98986661 
##        691        692        693        694        695        696 
## 0.20952459 0.48420426 0.29434988 0.12550731 0.36872179 0.16332821 
##        697        698        699        700        701        702 
## 0.15010325 0.58858436 0.39599733 0.15010325 0.95426055 0.24012987 
##        703        704        705        706        707        708 
## 0.38223786 0.12550731 0.11978069 0.13162343 0.97853420 0.24012987 
##        709        710        711        712        713        714 
## 0.98516964 0.75720379 0.82288682 0.24012987 0.20952459 0.12550731 
##        715        716        717        718        719        720 
## 0.16332821 0.12550731 0.98806807 0.88694656 0.17999744 0.15010325 
##        721        722        723        724        725        726 
## 0.92844552 0.11978069 0.16332821 0.16332821 0.26212573 0.12550731 
##        727        728        729        730        731        732 
## 0.72561573 0.58858436 0.04010310 0.38223786 0.98516964 0.29907363 
##        733        734        735        736        737        738 
## 0.15535733 0.15535733 0.15535733 0.17999744 0.44734000 0.13795074 
##        739        740        741        742        743        744 
## 0.12550731 0.12550731 0.24012987 0.20952459 0.95939979 0.07516471 
##        745        746        747        748        749        750 
## 0.15010325 0.29709705 0.12818724 0.88694656 0.26212573 0.15010325 
##        751        752        753        754        755        756 
## 0.94519766 0.59416123 0.15010325 0.12550731 0.86210759 0.72911956 
##        757        758        759        760        761        762 
## 0.12550731 0.15535733 0.15010325 0.93132204 0.12550731 0.15010325 
##        763        764        765        766        767        768 
## 0.12550731 0.88493676 0.12550731 0.99327009 0.36872179 0.35583739 
##        769        770        771        772        773        774 
## 0.07516471 0.15010325 0.12550731 0.15010325 0.97853420 0.12550731 
##        775        776        777        778        779        780 
## 0.88036130 0.12550731 0.12550731 0.54604513 0.12550731 0.99327009 
##        781        782        783        784        785        786 
## 0.58858436 0.95426055 0.36677737 0.06927101 0.12550731 0.12550731 
##        787        788        789        790        791        792 
## 0.58858436 0.14497916 0.75998157 0.13795074 0.12550731 0.10866631 
##        793        794        795        796        797        798 
## 0.24003096 0.24012987 0.12550731 0.16332821 0.50744634 0.48969497 
##        799        800        801        802        803        804 
## 0.12550731 0.68911471 0.16332821 0.95813908 0.76747391 0.59416123 
##        805        806        807        808        809        810 
## 0.12550731 0.15010325 0.15923074 0.58858436 0.16332821 0.99327009 
##        811        812        813        814        815        816 
## 0.12550731 0.19334344 0.16332821 0.31178734 0.15010325 0.24958208 
##        817        818        819        820        821        822 
## 0.58858436 0.06717356 0.15010325 0.14497916 0.97660342 0.12550731 
##        823        824        825        826        827        828 
## 0.12301262 0.40968654 0.14497916 0.12550731 0.29434988 0.92869692 
##        829        830        831        832        833        834 
## 0.12550731 0.99160342 0.40968654 0.92869692 0.12550731 0.12550731 
##        835        836        837        838        839        840 
## 0.12550731 0.98417918 0.12550731 0.12550731 0.35750755 0.24012987 
##        841        842        843        844        845        846 
## 0.12550731 0.15535733 0.75235091 0.15010325 0.12550731 0.15010325 
##        847        848        849        850        851        852 
## 0.03316014 0.15010325 0.04999025 0.99327009 0.19758859 0.15010325 
##        853        854        855        856        857        858 
## 0.72770807 0.81337111 0.97869463 0.40968654 0.97660342 0.24012987 
##        859        860        861        862        863        864 
## 0.72307607 0.12550731 0.05273942 0.11609362 0.86380834 0.24003096 
##        865        866        867        868        869        870 
## 0.15535733 0.97853420 0.91716301 0.13795074 0.12550731 0.61786494 
##        871        872        873        874        875        876 
## 0.12550731 0.97660342 0.15923074 0.15010325 0.89400959 0.58858436 
##        877        878        879        880        881        882 
## 0.12550731 0.12550731 0.12550731 0.99327009 0.89400959 0.15010325 
##        883        884        885        886        887        888 
## 0.58858436 0.15535733 0.12550731 0.43263411 0.08764166 0.75235091 
##        889        890        891 
## 0.77102073 0.36677737 0.15010325
MSE=sum((prediction-test[[1]])^2)/length(prediction)
MSE
## [1] 0.1225102

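The final model's MSE on the test set is 0.1225 in the run shown above; because randomForest draws its bootstrap samples and split candidates at random, this value changes slightly from run to run unless a random seed is fixed. A passenger whose predicted value is above 0.5 is classified as 1 (survived), and one whose predicted value is below 0.5 is classified as 0.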
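The 0.5 rule above turns the forest's numeric output into a class label. A minimal base-R sketch of that step, together with the accuracy and confusion table it implies. The helper `classify` and the example vectors are hypothetical; values exactly equal to the cutoff are sent to class 0 here, a case the text leaves open. Note also that the MSE of probability predictions against 0/1 outcomes is the Brier score, which is what the MSE values in this paper measure.

```r
# Convert predicted survival probabilities to 0/1 labels with a cutoff.
classify <- function(prob, cutoff = 0.5) {
  as.integer(prob > cutoff)  # values equal to the cutoff go to class 0
}

# Hypothetical example values, not taken from this paper's test set.
prob   <- c(0.73, 0.13, 0.56, 0.49, 0.99)
actual <- c(1, 0, 0, 1, 1)

pred_class <- classify(prob)
accuracy   <- mean(pred_class == actual)
# Fixed factor levels so both classes always appear in the table.
confusion  <- table(predicted = factor(pred_class, levels = 0:1),
                    actual    = factor(actual, levels = 0:1))
# MSE of probabilities against 0/1 outcomes (Brier score).
brier      <- mean((prob - actual)^2)
```

With this report's objects, `mean(classify(prediction) == test[[1]])` would give the test-set accuracy of the final model.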

Conclusion and Future Work

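Using the Titanic survival prediction task from the Kaggle competition platform, this paper predicts each passenger's probability of surviving the sinking from personal information such as age, sex, and social status. The main work consists of data visualization and feature engineering: filling missing values with a random forest, extracting a title from each name to create a new variable, and discretizing age and ticket fare. Finally, a random forest model was built on the preprocessed data to predict survival. Its output is a probability: the closer it is to 1, the more likely the passenger was rescued. The final model's MSE on the test set is 0.1225 in the run reported above. An MSE of this size indicates that the model's predictive ability can still be improved. Future optimization could include: 1) further refining the data preprocessing to extract variables with stronger discriminative power for the model; 2) trying more models, such as GBDT, KNN, and neural networks; 3) combining models (ensembling) to further improve predictive performance.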