9.1.3 常用算法

  1. CART(Classification and Regression Trees)

  2. C4.5(successor of ID3)

9.2 R中的实现

9.2.1 相关软件包

  1. rpart:建立分类树及相关递归划分算法的实现

  2. rpart.plot:专门用于绘制rpart模型的决策树

  3. maptree:修剪、绘制包括rpart在内的树形结构

  4. RWeka:R与Weka之间的接口。Weka集合了用Java编写的一系列机器学习算法,涵盖数据预处理、分类、回归、聚类、关联分析以及各种可视化。

9.2.3 数据集

  1. 数据集概况
# NOTE(review): clearing the workspace and calling setwd() make this script
# depend on one machine's layout; the JPEG files written later in the script
# use relative paths, so they land in this directory.
rm(list = ls())
setwd("E:/Lang/R/Experiment/book_9/")
library(rpart)

# car.test.frame is the example data set that ships with rpart.
data("car.test.frame")
head(car.test.frame)
##                  Price   Country Reliability Mileage  Type Weight Disp.
## Eagle Summit 4    8895       USA           4      33 Small   2560    97
## Ford Escort   4   7402       USA           2      33 Small   2345   114
## Ford Festiva 4    6319     Korea           4      37 Small   1845    81
## Honda Civic 4     6635 Japan/USA           5      32 Small   2260    91
## Mazda Protege 4   6599     Japan           5      32 Small   2440   113
## Mercury Tracer 4  8672    Mexico           4      26 Small   2285    97
##                   HP
## Eagle Summit 4   113
## Ford Escort   4   90
## Ford Festiva 4    63
## Honda Civic 4     92
## Mazda Protege 4  103
## Mercury Tracer 4  82

# Work on a copy and turn Mileage (miles per gallon) into fuel consumption,
# 100 * litres-per-gallon / (km-per-mile * mpg), so larger = thirstier.
# NOTE(review): 4.546 L is an imperial gallon and 1.6 approximates km/mile,
# while rpart documents Mileage as miles per US gallon (3.785 L); the values
# therefore overstate true L/100 km by roughly 20 %. Left unchanged because
# the Group_Mileage cut-offs used later in the script are on this scale.
cardata <- car.test.frame
cardata[["Mileage"]] <- 100 * 4.546 / (1.6 * cardata[["Mileage"]]) # mpg -> fuel consumption
str(cardata) # column types and first values
## 'data.frame':    60 obs. of  8 variables:
##  $ Price      : int  8895 7402 6319 6635 6599 8672 7399 7254 9599 5866 ...
##  $ Country    : Factor w/ 8 levels "France","Germany",..: 8 8 5 4 3 6 4 5 3 3 ...
##  $ Reliability: int  4 2 4 5 5 4 5 1 5 NA ...
##  $ Mileage    : num  8.61 8.61 7.68 8.88 8.88 ...
##  $ Type       : Factor w/ 6 levels "Compact","Large",..: 4 4 4 4 4 4 4 4 4 4 ...
##  $ Weight     : int  2560 2345 1845 2260 2440 2285 2275 2350 2295 1900 ...
##  $ Disp.      : int  97 114 81 91 113 97 97 98 109 73 ...
##  $ HP         : int  113 90 63 92 103 82 90 74 90 73 ...
summary(cardata)
##      Price            Country    Reliability       Mileage      
##  Min.   : 5866   USA      :26   Min.   :1.000   Min.   : 7.679  
##  1st Qu.: 9932   Japan    :19   1st Qu.:2.000   1st Qu.:10.523  
##  Median :12216   Japan/USA: 7   Median :3.000   Median :12.353  
##  Mean   :12616   Korea    : 3   Mean   :3.388   Mean   :11.962  
##  3rd Qu.:14933   Germany  : 2   3rd Qu.:5.000   3rd Qu.:13.530  
##  Max.   :24760   France   : 1   Max.   :5.000   Max.   :15.785  
##                  (Other)  : 2   NA's   :11                      
##       Type        Weight         Disp.             HP       
##  Compact:15   Min.   :1845   Min.   : 73.0   Min.   : 63.0  
##  Large  : 3   1st Qu.:2571   1st Qu.:113.8   1st Qu.:101.5  
##  Medium :13   Median :2885   Median :144.5   Median :111.5  
##  Small  :13   Mean   :2901   Mean   :152.1   Mean   :122.3  
##  Sporty : 9   3rd Qu.:3231   3rd Qu.:180.0   3rd Qu.:142.8  
##  Van    : 7   Max.   :3855   Max.   :305.0   Max.   :225.0  
## 
  1. 数据预处理
# Earlier attempts, kept for reference only. They do not work as intended:
# with() evaluates in a temporary copy of the data, and `<<-` assigns in the
# enclosing (global) environment, so cardata itself is never modified.
# cardata$Group_Mileage <- cardata$Mileage; # add a variable of Group_Mileage
# with(cardata,
#      {
#        Group_Mileage[which(Mileage >= 11.6)] <<- "A"
#        Group_Mileage[which(Mileage <= 9)] <<- "C"
#        Group_Mileage[which(! Group_Mileage %in% c("A", "C"))] <<- "B"
#        print(Group_Mileage)
#      }
# )
# names(cardata)
# with(cardata,
#      {
#        Group_Mileage <<- "A"
#      }
# )
# head(cardata$Group_Mileage, 5)

# Copy of the data before Group_Mileage is added.
# NOTE(review): `car` is not used in the visible part of this script; kept in
# case other code relies on it.
car <- cardata

# Bin fuel consumption into three classes:
#   "A": Mileage >= 11.6  (highest consumption)
#   "C": Mileage <= 9     (lowest consumption)
#   "B": everything in between
# A missing Mileage yields NA rather than being silently labelled "B", and no
# intermediate numeric -> character coercion is needed.
Group_Mileage <- ifelse(cardata$Mileage >= 11.6, "A",
                        ifelse(cardata$Mileage <= 9, "C", "B"))
cardata$Group_Mileage <- Group_Mileage
cardata[1:5, c("Mileage", "Group_Mileage")] # check the first 5 rows
##                  Mileage Group_Mileage
## Eagle Summit 4  8.609848             C
## Ford Escort   4 8.609848             C
## Ford Festiva 4  7.679054             C
## Honda Civic 4   8.878906             C
## Mazda Protege 4 8.878906             C

使用sampling包中的strata进行分层抽样,训练集:测试集 = 3:1

使用strata进行分层抽样之前,一定要先按分层变量(此处为Group_Mileage)对数据排序,并且size中各层样本量的顺序要与排序后各层出现的顺序一致。

library("sampling")
set.seed(1) # make the stratified sample reproducible
# Earlier attempt kept for reference (with() + `<<-` writes to the global env):
# with(cardata,
#      {
#        a <<- round(1/4 * sum(Group_Mileage == "A"))
#        b <<- round(1/4 * sum(Group_Mileage == "B"))
#        c <<- round(1/4 * sum(Group_Mileage == "C"))
#      })

# strata() requires the rows to be sorted by the stratification variable, and
# `size` to give one sample size per stratum in that same order.
cardata <- cardata[order(cardata$Group_Mileage), ] # sort the dataset before stratified sampling
# Test-set size per stratum: a quarter of each group (train:test = 3:1).
# table() lists the groups in the same sorted order as the rows above, and
# avoids the old globals `a`, `b`, `c` (the last one masked base::c).
test_sizes <- round(table(cardata$Group_Mileage) / 4)
sub <- strata(cardata, stratanames = "Group_Mileage",
              size = as.vector(test_sizes), method = "srswor")
head(sub)
##    Group_Mileage ID_unit      Prob Stratum
## 7              A       7 0.2571429       1
## 10             A      10 0.2571429       1
## 13             A      13 0.2571429       1
## 17             A      17 0.2571429       1
## 19             A      19 0.2571429       1
## 27             A      27 0.2571429       1
# Sampled rows form the test set; all remaining rows form the training set.
# A logical mask is used because cardata[-integer(0), ] would return no rows.
in_test <- seq_len(nrow(cardata)) %in% sub$ID_unit
Train_Car <- cardata[!in_test, ] # generate a train set
Test_Car <- cardata[sub$ID_unit, ] # generate a test set
nrow(Train_Car)
## [1] 45
nrow(Test_Car)
## [1] 15

9.3 应用案例

9.3.1 CART的应用

1. 对“Mileage”变量建立回归树——数字结果

# Regression tree (method = "anova") for the continuous response Mileage,
# using every predictor except the derived class label Group_Mileage.
formula_Car_Reg <- reformulate(
  c("Price", "Country", "Reliability", "Type", "Weight", "Disp.", "HP"),
  response = "Mileage"
)
rp_Car_Reg <- rpart(formula_Car_Reg, Train_Car, method = "anova")
print(rp_Car_Reg) # text listing of the fitted tree
## n= 45 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
## 1) root 45 195.048600 11.836160  
##   2) Disp.< 134 18  26.181890  9.802074 *
##   3) Disp.>=134 27  44.741810 13.192210  
##     6) Weight< 3125 13   9.319054 12.371970 *
##     7) Weight>=3125 14  18.554700 13.953870 *
printcp(rp_Car_Reg) # complexity-parameter table: rel error, xerror, xstd per subtree
## 
## Regression tree:
## rpart(formula = formula_Car_Reg, data = Train_Car, method = "anova")
## 
## Variables actually used in tree construction:
## [1] Disp.  Weight
## 
## Root node error: 195.05/45 = 4.3344
## 
## n= 45 
## 
##         CP nsplit rel error  xerror     xstd
## 1 0.636379      0   1.00000 1.04401 0.179401
## 2 0.086481      1   0.36362 0.42014 0.071784
## 3 0.010000      2   0.27714 0.48651 0.091936
summary(rp_Car_Reg) # per-node detail: primary/surrogate splits, variable importance
## Call:
## rpart(formula = formula_Car_Reg, data = Train_Car, method = "anova")
##   n= 45 
## 
##           CP nsplit rel error    xerror       xstd
## 1 0.63637926      0 1.0000000 1.0440063 0.17940136
## 2 0.08648134      1 0.3636207 0.4201402 0.07178409
## 3 0.01000000      2 0.2771394 0.4865134 0.09193584
## 
## Variable importance
##   Disp.  Weight      HP   Price    Type Country 
##      28      22      15      13      13       9 
## 
## Node number 1: 45 observations,    complexity param=0.6363793
##   mean=11.83616, MSE=4.334412 
##   left son=2 (18 obs) right son=3 (27 obs)
##   Primary splits:
##       Disp.  < 134     to the left,  improve=0.6363793, (0 missing)
##       Weight < 2747.5  to the left,  improve=0.5394177, (0 missing)
##       Price  < 9446.5  to the left,  improve=0.5039169, (0 missing)
##       Type   splits as  LRRLLR,      improve=0.4222991, (0 missing)
##       HP     < 104.5   to the left,  improve=0.3429270, (0 missing)
##   Surrogate splits:
##       Weight  < 2747.5  to the left,  agree=0.889, adj=0.722, (0 split)
##       HP      < 109     to the left,  agree=0.800, adj=0.500, (0 split)
##       Price   < 9446.5  to the left,  agree=0.778, adj=0.444, (0 split)
##       Type    splits as  RRRLRR,      agree=0.778, adj=0.444, (0 split)
##       Country splits as  -LRLL-RR,    agree=0.733, adj=0.333, (0 split)
## 
## Node number 2: 18 observations
##   mean=9.802074, MSE=1.454549 
## 
## Node number 3: 27 observations,    complexity param=0.08648134
##   mean=13.19221, MSE=1.657104 
##   left son=6 (13 obs) right son=7 (14 obs)
##   Primary splits:
##       Weight < 3125    to the left,  improve=0.37700890, (0 missing)
##       Price  < 11522   to the left,  improve=0.36529280, (0 missing)
##       Type   splits as  LLR-LR,      improve=0.23012760, (0 missing)
##       HP     < 143     to the left,  improve=0.08723663, (0 missing)
##       Disp.  < 181.5   to the left,  improve=0.08014554, (0 missing)
##   Surrogate splits:
##       Disp.   < 166.5   to the left,  agree=0.852, adj=0.692, (0 split)
##       Price   < 13172.5 to the left,  agree=0.815, adj=0.615, (0 split)
##       Type    splits as  LRR-RR,      agree=0.778, adj=0.538, (0 split)
##       HP      < 132.5   to the left,  agree=0.778, adj=0.538, (0 split)
##       Country splits as  --RLL-LR,    agree=0.630, adj=0.231, (0 split)
## 
## Node number 6: 13 observations
##   mean=12.37197, MSE=0.7168503 
## 
## Node number 7: 14 observations
##   mean=13.95387, MSE=1.325336

改变rpart参数,探究它的使用

  1. minsplit 20->10:节点中至少包含minsplit个样本时,rpart才会尝试对该节点继续划分(默认值为20)
# minsplit lowered from the default 20 to 10: rpart only tries to split a node
# holding at least this many observations. Here the tree grows from 3 to 7 leaves.
rp_Car_Reg1 <- rpart(formula = formula_Car_Reg, data = Train_Car,
                     method = "anova", minsplit = 10)
print(rp_Car_Reg1)
## n= 45 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 45 195.0486000 11.836160  
##    2) Disp.< 134 18  26.1818900  9.802074  
##      4) Price< 9702.5 8   3.5262410  8.657817 *
##      5) Price>=9702.5 10   3.8013750 10.717480 *
##    3) Disp.>=134 27  44.7418100 13.192210  
##      6) Type=Compact,Large,Medium,Sporty 22  19.7672100 12.750280  
##       12) Price< 11522 7   3.0038350 11.877100 *
##       13) Price>=11522 15   8.9357120 13.157760  
##         26) Price>=12329.5 12   3.6094660 12.889710  
##           52) Type=Compact,Large,Sporty 3   0.1766239 12.181690 *
##           53) Type=Medium 9   1.4276880 13.125710 *
##         27) Price< 12329.5 3   1.0149970 14.229990 *
##      7) Type=Van 5   1.7724000 15.136720 *
printcp(rp_Car_Reg1)
## 
## Regression tree:
## rpart(formula = formula_Car_Reg, data = Train_Car, method = "anova", 
##     minsplit = 10)
## 
## Variables actually used in tree construction:
## [1] Disp. Price Type 
## 
## Root node error: 195.05/45 = 4.3344
## 
## n= 45 
## 
##         CP nsplit rel error  xerror     xstd
## 1 0.636379      0  1.000000 1.03964 0.177064
## 2 0.118956      1  0.363621 0.48156 0.093875
## 3 0.096665      2  0.244665 0.44479 0.079175
## 4 0.040132      3  0.148000 0.31636 0.080004
## 5 0.022103      4  0.107868 0.28433 0.086739
## 6 0.010280      5  0.085765 0.37347 0.128488
## 7 0.010000      6  0.075485 0.36521 0.128047
  1. cp 0.01->0.1:复杂度参数(complexity parameter)。若某次划分不能使模型的整体拟合不足(相对于根节点误差)至少降低cp,则不进行该划分。该参数的作用是剪掉对模型贡献不大的分支,提高算法效率

建立树模型要权衡两方面问题,一个是要拟合得使分组后的变异较小,另一个是要防止过度拟合,而使模型的误差过大,前者的参数是CP,后者的参数是Xerror。所以要在Xerror最小的情况下,也使CP尽量小。如果认为树模型过于复杂,我们需要对其进行修剪 。(摘自推酷上的《分类-回归树模型(CART)在R语言中的实现》)

# cp raised from the default 0.01 to 0.1: splits that do not improve the fit
# by at least cp are not attempted, so this tree stops after one split.
rp_Car_Reg2 <- rpart(formula = formula_Car_Reg, data = Train_Car,
                     method = "anova", cp = 0.1)
print(rp_Car_Reg2)
## n= 45 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
## 1) root 45 195.04860 11.836160  
##   2) Disp.< 134 18  26.18189  9.802074 *
##   3) Disp.>=134 27  44.74181 13.192210 *
printcp(rp_Car_Reg2)
## 
## Regression tree:
## rpart(formula = formula_Car_Reg, data = Train_Car, method = "anova", 
##     cp = 0.1)
## 
## Variables actually used in tree construction:
## [1] Disp.
## 
## Root node error: 195.05/45 = 4.3344
## 
## n= 45 
## 
##        CP nsplit rel error  xerror     xstd
## 1 0.63638      0   1.00000 1.06217 0.178321
## 2 0.10000      1   0.36362 0.44258 0.075434
  1. 同时,我们可以通过剪枝函数“prune.rpart”进行剪枝
# Post-pruning: cut the already-fitted rp_Car_Reg back to its cp = 0.1 subtree.
# prune() dispatches to prune.rpart() for rpart objects. The pruned tree keeps
# the original call and cross-validation results, as printcp shows below.
rp_Car_Reg3 <- prune(rp_Car_Reg, cp = 0.1)
print(rp_Car_Reg3)
## n= 45 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
## 1) root 45 195.04860 11.836160  
##   2) Disp.< 134 18  26.18189  9.802074 *
##   3) Disp.>=134 27  44.74181 13.192210 *
printcp(rp_Car_Reg3)
## 
## Regression tree:
## rpart(formula = formula_Car_Reg, data = Train_Car, method = "anova")
## 
## Variables actually used in tree construction:
## [1] Disp.
## 
## Root node error: 195.05/45 = 4.3344
## 
## n= 45 
## 
##        CP nsplit rel error  xerror     xstd
## 1 0.63638      0   1.00000 1.04401 0.179401
## 2 0.10000      1   0.36362 0.42014 0.071784
  1. maxdepth可以调节树的高度,根节点为0
# maxdepth = 1: no node may lie deeper than depth 1 (the root is depth 0),
# which limits the tree to a single split.
rp_Car_Reg5 <- rpart(formula = formula_Car_Reg, data = Train_Car,
                     method = "anova", maxdepth = 1)
print(rp_Car_Reg5)
## n= 45 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
## 1) root 45 195.04860 11.836160  
##   2) Disp.< 134 18  26.18189  9.802074 *
##   3) Disp.>=134 27  44.74181 13.192210 *

2. 对“Mileage”变量建立回归树——树形结果

1> rpart.plot
# Fit the minsplit = 10 tree again (same settings as rp_Car_Reg1) for plotting.
rp_Car_Reg6 <- rpart(formula = formula_Car_Reg, data = Train_Car,
                     method = "anova", minsplit = 10)
print(rp_Car_Reg6)
## n= 45 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 45 195.0486000 11.836160  
##    2) Disp.< 134 18  26.1818900  9.802074  
##      4) Price< 9702.5 8   3.5262410  8.657817 *
##      5) Price>=9702.5 10   3.8013750 10.717480 *
##    3) Disp.>=134 27  44.7418100 13.192210  
##      6) Type=Compact,Large,Medium,Sporty 22  19.7672100 12.750280  
##       12) Price< 11522 7   3.0038350 11.877100 *
##       13) Price>=11522 15   8.9357120 13.157760  
##         26) Price>=12329.5 12   3.6094660 12.889710  
##           52) Type=Compact,Large,Sporty 3   0.1766239 12.181690 *
##           53) Type=Medium 9   1.4276880 13.125710 *
##         27) Price< 12329.5 3   1.0149970 14.229990 *
##      7) Type=Van 5   1.7724000 15.136720 *
# install.packages("rpart.plot")
library(rpart.plot)
rpart.plot(rp_Car_Reg6) # tree diagram with rpart.plot's default settings

当Disp. >= 134且Type不属于Compact、Large、Medium、Sporty(即Type == Van)时,Mileage的预测值约为15(15.14)。

# Save the current graphics parameters so they can be restored afterwards.
opar <- par(no.readonly = TRUE)
par(mar = rep(1, 4)) # narrow margins leave more room for the tree
# `type` controls how split and node labels are drawn:
# 1 = label all nodes, 2 = split labels below node labels,
# 3 = separate left/right split labels, 4 = like 3 but label all nodes.
rpart.plot(rp_Car_Reg6, type = 4) # draw a decision-tree

rpart.plot(rp_Car_Reg6, type = 3) # draw a decision-tree

rpart.plot(rp_Car_Reg6, type = 2) # draw a decision-tree

rpart.plot(rp_Car_Reg6, type = 1) # draw a decision-tree

rpart.plot(rp_Car_Reg6, type = 4, branch = 1) # branch = 1: square-shouldered branches

rpart.plot(rp_Car_Reg6, type = 4, fallen.leaves = TRUE) # put all leaves on the bottom row

# Restore the parameters saved above; previously `opar` was never used,
# so the 1-line margins leaked into every later plot.
par(opar)

2> maptree
# install.packages("maptree")
library("maptree")
## Loading required package: cluster
# Draw the tree with maptree::draw.tree() into a JPEG file in the working directory.
jpeg(filename = "draw_tree.jpg" )
# One colour entry per leaf. Count the leaves (rows of the rpart frame whose
# `var` is "<leaf>") instead of hard-coding 7, so this still works if the
# tree changes.
n_leaves <- sum(rp_Car_Reg6$frame$var == "<leaf>")
draw.tree(rp_Car_Reg6, col = rep(1, n_leaves), nodeinfo = TRUE)
dev.off() # close the JPEG device so the file is written
## png 
##   2
3> plot
# Base-graphics tree plot; uniform = TRUE uses even vertical spacing between levels.
plot(rp_Car_Reg6, uniform = TRUE, main = "plot:Regression Tree" )
# Add labels: use.n = TRUE appends node sizes, all = TRUE labels interior nodes too.
text(rp_Car_Reg6, use.n = TRUE, all = TRUE)

4> post
# post() with file = "" draws the presentation-style tree on the current
# device, which here is a JPEG file.
jpeg(filename = "post.jpg")
post(rp_Car_Reg6, file = "")
# dev.off() must be *called* to close the JPEG device and write the file.
# A bare `dev.off` (no parentheses) only prints the function's source code,
# so the device would stay open and later plots would go into post.jpg.
dev.off()

3. 对“Group_Mileage”变量建立分类树

method用于指定划分方法,应与因变量的数据类型相对应,本参数有四种取值:连续型“anova”;离散型“class”;计数型(泊松过程)“poisson”;生存分析型“exp”。若不指定,程序会根据因变量的类型自动选择方法,但一般情况下最好还是指明本参数,以便让程序清楚做哪一种树模型。

# Classification tree for the three-level label Group_Mileage.
formula_Car_Cla <- Group_Mileage ~ Price + Country + Reliability + Type + Weight + Disp. + HP
# method = "class" requests a classification tree; minsplit = 5 means a node
# needs at least 5 observations before rpart tries to split it.
rp_Car_Cla <- rpart(formula_Car_Cla, Train_Car, method = "class", minsplit = 5)
print(rp_Car_Cla)
## n= 45 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
##  1) root 45 19 A (0.57777778 0.26666667 0.15555556)  
##    2) Disp.>=134 27  2 A (0.92592593 0.07407407 0.00000000)  
##      4) Price>=11222 21  0 A (1.00000000 0.00000000 0.00000000) *
##      5) Price< 11222 6  2 A (0.66666667 0.33333333 0.00000000)  
##       10) Disp.< 152 4  0 A (1.00000000 0.00000000 0.00000000) *
##       11) Disp.>=152 2  0 B (0.00000000 1.00000000 0.00000000) *
##    3) Disp.< 134 18  8 B (0.05555556 0.55555556 0.38888889)  
##      6) Price>=9702.5 10  1 B (0.10000000 0.90000000 0.00000000) *
##      7) Price< 9702.5 8  1 C (0.00000000 0.12500000 0.87500000) *
rpart.plot(rp_Car_Cla, type = 4, fallen.leaves = TRUE)

4. 利用测试集预测目标变量Group_Mileage

# Predict class labels for the held-out test set.
pre_Car_Cla <- predict(rp_Car_Cla, Test_Car, type = "class")
pre_Car_Cla # a named factor; the names are the row names of Test_Car
##        Mitsubishi Sigma V6              Peugeot 405 4 
##                          A                          B 
##            Acura Legend V6           Eagle Premier V6 
##                          A                          A 
##        Ford Thunderbird V6       Chevrolet Caprice V8 
##                          A                          A 
## Ford LTD Crown Victoria V8     Dodge Grand Caravan V6 
##                          A                          A 
##         Mitsubishi Wagon 4           Mercury Tracer 4 
##                          A                          C 
##            Subaru Loyale 4           Toyota Corolla 4 
##                          C                          C 
##             Plymouth Laser              Honda Civic 4 
##                          B                          C 
##             Subaru Justy 3 
##                          C 
## Levels: A B C
# Misclassification rate: the share of test rows whose predicted class
# differs from the observed Group_Mileage (mean of a logical vector).
p <- mean(pre_Car_Cla != Test_Car$Group_Mileage)
p
## [1] 0.2666667
table(Test_Car$Group_Mileage) # observed class counts
## 
## A B C 
## 9 4 2
table(pre_Car_Cla) # predicted class counts
## pre_Car_Cla
## A B C 
## 8 2 5
table(Test_Car$Group_Mileage, pre_Car_Cla) # confusion matrix: rows observed, columns predicted
##    pre_Car_Cla
##     A B C
##   A 8 1 0
##   B 0 1 3
##   C 0 0 2

混淆矩阵:不在对角线上的元素为分类错误的样本;行为实际类别Test_Car$Group_Mileage,列为预测类别pre_Car_Cla,按行求和或按列求和即可分别得到实际类别或预测类别的频数。
混淆矩阵如何建立?如上所示,用 table(实际类别, 预测类别) 交叉制表即可。

9.3.2 C4.5的应用

C4.5(RWeka中的J48)只能构建分类树:目标变量必须是离散的类别变量(因此下面先把Group_Mileage转换为因子),而自变量可以是连续变量(如下面结果中的Price、Disp.)。

#install.packages("RWeka")
library(RWeka)
# J48 (Weka's C4.5) needs a factor response, so turn the character labels into a factor.
Train_Car$Group_Mileage <- factor(Train_Car$Group_Mileage)
# NOTE: the variable name `formula` masks stats::formula (calls to the function
# still resolve); it is reused by the J48 control example that follows.
formula <- Group_Mileage ~
  Price + Country + Reliability + Type + Weight + Disp. + HP
C45_0 <- J48(formula, Train_Car) # default Weka options
print(C45_0)
## J48 pruned tree
## ------------------
## 
## Price <= 9410: C (8.0/1.0)
## Price > 9410
## |   Disp. <= 133: B (8.0/1.0)
## |   Disp. > 133: A (22.0/2.0)
## 
## Number of Leaves  :  3
## 
## Size of the tree :   5
# NOTE(review): the summary reports 38 instances although Train_Car has 45
# rows; presumably rows with missing values (e.g. Reliability NA) were dropped
# by the default na.action before reaching Weka - confirm.
summary(C45_0)
## 
## === Summary ===
## 
## Correctly Classified Instances          34               89.4737 %
## Incorrectly Classified Instances         4               10.5263 %
## Kappa statistic                          0.8203
## Mean absolute error                      0.1252
## Root mean squared error                  0.2502
## Relative absolute error                 31.4614 %
## Root relative squared error             56.3318 %
## Total Number of Instances               38     
## 
## === Confusion Matrix ===
## 
##   a  b  c   <-- classified as
##  20  1  0 |  a = A
##   2  7  1 |  b = B
##   0  0  7 |  c = C
plot(C45_0)

control参数设置

# Same J48 model with Weka option M = 3: every leaf must hold at least
# 3 training instances (Weka's default is 2).
C45_1 <- J48(formula, Train_Car, control = Weka_control(M = 3))
# Inspect the model that was just fitted. The original code printed,
# summarised and plotted C45_0 here, so the effect of M = 3 was never shown.
print(C45_1)
summary(C45_1)
plot(C45_1)
# NOTE: the output recorded here previously belonged to C45_0; re-run this
# chunk to regenerate the output for C45_1.