数据导入和预处理

特别说明:J48() 函数来自 RWeka 包,依赖 Java 运行,使用前需先安装 Java 软件(JDK/JRE)。

首先导入 rpart 包自带的 car.test.frame 数据集,将油耗由"英里/加仑"换算为"升/百公里",再为增加可读性将变量名转换成中文。

数据读取

# Load the packages and the car.test.frame data set that ships with rpart.
library(tidyverse)
library(rpart)
data(car.test.frame)

# Work with a tibble copy of the data and show a preview of it.
car <- as_tibble(car.test.frame)
car
## # A tibble: 60 × 8
##    Price Country   Reliability Mileage Type  Weight Disp.    HP
##    <int> <fct>           <int>   <int> <fct>  <int> <int> <int>
##  1  8895 USA                 4      33 Small   2560    97   113
##  2  7402 USA                 2      33 Small   2345   114    90
##  3  6319 Korea               4      37 Small   1845    81    63
##  4  6635 Japan/USA           5      32 Small   2260    91    92
##  5  6599 Japan               5      32 Small   2440   113   103
##  6  8672 Mexico              4      26 Small   2285    97    82
##  7  7399 Japan/USA           5      33 Small   2275    97    90
##  8  7254 Korea               1      28 Small   2350    98    74
##  9  9599 Japan               5      25 Small   2295   109    90
## 10  5866 Japan              NA      34 Small   1900    73    73
## # ℹ 50 more rows
# Turn Mileage (miles per gallon) into fuel consumption in litres per 100 km:
#   L/100km = (100 * litres per gallon) / (km per mile * mpg)
# NOTE(review): 4.546 L is the imperial gallon and 1.6 km approximates a mile.
# rpart's help page appears to give Mileage in miles per US gallon (~3.785 L);
# if so, these figures are about 20% too high. Confirm before changing, because
# the 9 / 11.6 thresholds used to group 油耗 below are expressed on this scale.
car$Mileage <- (100 * 4.546) / (1.6 * car$Mileage)

# Chinese column labels, in the original column order
# (Price, Country, Reliability, Mileage, Type, Weight, Disp., HP).
# NOTE(review): Disp. is presumably engine displacement, but it is labelled
# "发动机功率" (engine power) here -- confirm the label.
names(car) <- c(
  "价格", "产地", "可靠性", "油耗",
  "类型", "车重", "发动机功率", "净马力"
)
car
## # A tibble: 60 × 8
##     价格 产地      可靠性  油耗 类型   车重 发动机功率 净马力
##    <int> <fct>      <int> <dbl> <fct> <int>      <int>  <int>
##  1  8895 USA            4  8.61 Small  2560         97    113
##  2  7402 USA            2  8.61 Small  2345        114     90
##  3  6319 Korea          4  7.68 Small  1845         81     63
##  4  6635 Japan/USA      5  8.88 Small  2260         91     92
##  5  6599 Japan          5  8.88 Small  2440        113    103
##  6  8672 Mexico         4 10.9  Small  2285         97     82
##  7  7399 Japan/USA      5  8.61 Small  2275         97     90
##  8  7254 Korea          1 10.1  Small  2350         98     74
##  9  9599 Japan          5 11.4  Small  2295        109     90
## 10  5866 Japan         NA  8.36 Small  1900         73     73
## # ℹ 50 more rows
summary(object = car)  # Per-column overview of the data set
##       价格              产地        可靠性           油耗             类型   
##  Min.   : 5866   USA      :26   Min.   :1.000   Min.   : 7.679   Compact:15  
##  1st Qu.: 9932   Japan    :19   1st Qu.:2.000   1st Qu.:10.523   Large  : 3  
##  Median :12216   Japan/USA: 7   Median :3.000   Median :12.353   Medium :13  
##  Mean   :12616   Korea    : 3   Mean   :3.388   Mean   :11.962   Small  :13  
##  3rd Qu.:14933   Germany  : 2   3rd Qu.:5.000   3rd Qu.:13.530   Sporty : 9  
##  Max.   :24760   France   : 1   Max.   :5.000   Max.   :15.785   Van    : 7  
##                  (Other)  : 2   NA's   :11                                   
##       车重        发动机功率        净马力     
##  Min.   :1845   Min.   : 73.0   Min.   : 63.0  
##  1st Qu.:2571   1st Qu.:113.8   1st Qu.:101.5  
##  Median :2885   Median :144.5   Median :111.5  
##  Mean   :2901   Mean   :152.1   Mean   :122.3  
##  3rd Qu.:3231   3rd Qu.:180.0   3rd Qu.:142.8  
##  Max.   :3855   Max.   :305.0   Max.   :225.0  
## 

将目标变量转换成分类型变量:按油耗(升/百公里)分为三组,油耗 ≥ 11.6 记为 A,≤ 9 记为 C,其余记为 B。

# Derive the categorical target 分组油耗 (grouped fuel consumption) from 油耗
# (L/100 km): >= 11.6 -> "A", <= 9 -> "C", everything else -> "B".
#
# Equivalent base-R version, kept for reference:
# Group_Mileage <- matrix(0, 60, 1)  # placeholder column for the new variable
# Group_Mileage[which(car$"油耗" >= 11.6)] <- "A"
# Group_Mileage[which(car$"油耗" <= 9)] <- "C"
# Group_Mileage[which(Group_Mileage == 0)] <- "B"
# car$"分组油耗" <- Group_Mileage
library(dplyr)

car <- mutate(
  car,
  分组油耗 = case_when(
    油耗 >= 11.6 ~ "A",
    油耗 <= 9 ~ "C",
    TRUE ~ "B"  # neither the high nor the low band
  )
)

# First ten rows of the source column next to the derived group.
print(car[1:10, c("油耗", "分组油耗")])
## # A tibble: 10 × 2
##     油耗 分组油耗
##    <dbl> <chr>   
##  1  8.61 C       
##  2  8.61 C       
##  3  7.68 C       
##  4  8.88 C       
##  5  8.88 C       
##  6 10.9  B       
##  7  8.61 C       
##  8 10.1  B       
##  9 11.4  B       
## 10  8.36 C
# Print the whole data set, which now includes the 分组油耗 column.
car
## # A tibble: 60 × 9
##     价格 产地      可靠性  油耗 类型   车重 发动机功率 净马力 分组油耗
##    <int> <fct>      <int> <dbl> <fct> <int>      <int>  <int> <chr>   
##  1  8895 USA            4  8.61 Small  2560         97    113 C       
##  2  7402 USA            2  8.61 Small  2345        114     90 C       
##  3  6319 Korea          4  7.68 Small  1845         81     63 C       
##  4  6635 Japan/USA      5  8.88 Small  2260         91     92 C       
##  5  6599 Japan          5  8.88 Small  2440        113    103 C       
##  6  8672 Mexico         4 10.9  Small  2285         97     82 B       
##  7  7399 Japan/USA      5  8.61 Small  2275         97     90 C       
##  8  7254 Korea          1 10.1  Small  2350         98     74 B       
##  9  9599 Japan          5 11.4  Small  2295        109     90 B       
## 10  5866 Japan         NA  8.36 Small  1900         73     73 C       
## # ℹ 50 more rows

生成训练样本集和测试样本集:按分组油耗进行分层抽样,每组随机抽取约 1/4 的样本作为测试集,其余样本作为训练集。

library(sampling)
# Test-set size per fuel-consumption group: a quarter of each group, rounded.
# sampling::strata() matches `size` to the strata in the order in which they
# first appear in the data. The sizes are therefore computed in that order
# (here C, B, A) instead of being typed by hand. This also avoids shadowing
# base::c with a variable named `c`.
stratum_order <- unique(car$"分组油耗")
strata_sizes <- round(
  vapply(stratum_order, function(g) sum(car$"分组油耗" == g), integer(1)) / 4
)
strata_sizes
## C B A 
## 2 4 9 
# Stratified simple random sampling without replacement on 分组油耗.
set.seed(12345)
sub <- car %>%
  strata(stratanames = "分组油耗", size = strata_sizes, method = "srswor")

tibble(sub)   # sampled units: group, row index (ID_unit), inclusion probability
## # A tibble: 15 × 4
##    分组油耗 ID_unit  Prob Stratum
##    <chr>      <int> <dbl>   <int>
##  1 C              3 0.222       1
##  2 C             12 0.222       1
##  3 B             19 0.25        2
##  4 B             22 0.25        2
##  5 B             25 0.25        2
##  6 B             36 0.25        2
##  7 A             16 0.257       3
##  8 A             29 0.257       3
##  9 A             31 0.257       3
## 10 A             35 0.257       3
## 11 A             47 0.257       3
## 12 A             49 0.257       3
## 13 A             54 0.257       3
## 14 A             57 0.257       3
## 15 A             59 0.257       3
Test_Car <- car[sub$ID_unit, ]    # rows drawn by strata() -> test set
Train_Car <- car[-sub$ID_unit, ]  # all remaining rows -> training set
# Row counts of the two sets; they should be roughly 3:1.
nrow(Train_Car)
nrow(Test_Car)
## [1] 45
## [1] 15

注:J48() 函数对中文变量名的支持不完善,所以将变量名还原为英文。此时 Mileage 列保存的是换算后的油耗(升/百公里),Oil_Consumption 则是分组后的油耗类别(A/B/C)。

# J48() does not cope well with Chinese column names, so switch both sets back
# to English names. "Mileage" now holds the converted consumption (L/100 km)
# and "Oil_Consumption" holds the grouped class label (A/B/C).
names(Train_Car) <- c("Price", "Country", "Reliability", "Mileage",
                      "Type", "Weight", "Disp.", "HP", "Oil_Consumption")
names(Test_Car) <- names(Train_Car)

# The class label must be a factor. The test factor reuses the training
# levels, so comparisons with model predictions (==, table(), caret) still work
# when some class happens to be absent from the test set.
Train_Car$Oil_Consumption <- as.factor(Train_Car$Oil_Consumption)
Test_Car$Oil_Consumption <- factor(Test_Car$Oil_Consumption,
                                   levels = levels(Train_Car$Oil_Consumption))

安装RWeka和rJava软件包

# install.packages("RWeka")
library(RWeka)
#注明:先安装java软件
# install.packages("rJava")
library(rJava)

通过 J48() 函数(Weka 对 C4.5 算法的实现)构建决策树

# C4.5 decision tree (RWeka::J48) for the grouped fuel-consumption class.
# Mileage is not used as a predictor: Oil_Consumption was derived from it.
formula <- Oil_Consumption ~ Price + Country + Reliability + Type +
  Weight + Disp. + HP
C45_0 <- J48(formula, data = Train_Car)
C45_0
## J48 pruned tree
## ------------------
## 
## Disp. <= 133
## |   Price <= 9483: C (9.0/3.0)
## |   Price > 9483: B (5.0)
## Disp. > 133: A (24.0/2.0)
## 
## Number of Leaves  :  3
## 
## Size of the tree :   5
# Evaluation on the training data (resubstitution).
# NOTE(review): it reports 38 of the 45 training rows, presumably because rows
# with a missing Reliability value are dropped by the default na.action -- confirm.
summary(C45_0)
## 
## === Summary ===
## 
## Correctly Classified Instances          33               86.8421 %
## Incorrectly Classified Instances         5               13.1579 %
## Kappa statistic                          0.766 
## Mean absolute error                      0.1345
## Root mean squared error                  0.2593
## Relative absolute error                 34.9259 %
## Root relative squared error             59.4341 %
## Total Number of Instances               38     
## 
## === Confusion Matrix ===
## 
##   a  b  c   <-- classified as
##  22  0  0 |  a = A
##   2  5  3 |  b = B
##   0  0  6 |  c = C
# Same model, but requiring at least M = 3 training instances per leaf
# (Weka's default is 2). On this training set the tree does not change.
C45_1 <- J48(formula, data = Train_Car, control = Weka_control(M = 3))
C45_1
## J48 pruned tree
## ------------------
## 
## Disp. <= 133
## |   Price <= 9483: C (9.0/3.0)
## |   Price > 9483: B (5.0)
## Disp. > 133: A (24.0/2.0)
## 
## Number of Leaves  :  3
## 
## Size of the tree :   5
summary(C45_1)
## 
## === Summary ===
## 
## Correctly Classified Instances          33               86.8421 %
## Incorrectly Classified Instances         5               13.1579 %
## Kappa statistic                          0.766 
## Mean absolute error                      0.1345
## Root mean squared error                  0.2593
## Relative absolute error                 34.9259 %
## Root relative squared error             59.4341 %
## Total Number of Instances               38     
## 
## === Confusion Matrix ===
## 
##   a  b  c   <-- classified as
##  22  0  0 |  a = A
##   2  5  3 |  b = B
##   0  0  6 |  c = C

决策树可视化:需要安装 partykit 包(grid 是 R 自带的基础包,无需单独安装)

# partykit is required for plotting the J48 trees; install it once with
# install.packages("partykit"). grid ships with base R and needs no install.
library(partykit)
library(grid)

# Draw each tree, titled with the model name.
plot(C45_0, main = "C45_0")

plot(C45_1, main = "C45_1")

预测

# Predict the test set with the M = 3 C4.5 tree.
pre1 <- predict(C45_1, newdata = Test_Car, type = "class")
pre1
##  [1] C C B B B B A B A A A A A A A
## Levels: A B C
# Confusion matrix: rows are the true classes, columns the predictions.
ConfM1 <- table(Test_Car$Oil_Consumption, pre1)
ConfM1
##    pre1
##     A B C
##   A 8 1 0
##   B 0 4 0
##   C 0 0 2
# Misclassification rate, first computed straight from the predictions ...
(p <- sum(as.numeric(pre1 != Test_Car$Oil_Consumption)) / nrow(Test_Car))
## [1] 0.06666667
# ... and then as the off-diagonal share of the confusion matrix.

(E1 <- (sum(ConfM1) - sum(diag(ConfM1))) / sum(ConfM1))
## [1] 0.06666667

模型评价

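二分类模型的混淆矩阵(行为真实类别,列为预测类别):

| | 预测为正类 | 预测为负类 |
|:--|:--:|:--:|
| 真实为正类 | TP(真正例) | FN(假负例) |
| 真实为负类 | FP(假正例) | TN(真负例) |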

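从混淆矩阵可以计算多种评价指标。例如,Kappa 统计量需要用到偶然一致率 \(P_e\):

\[P_{e}=\frac{\langle a,b\rangle}{n^2}=\frac{\sum_{k=1}^{C}a_{k}b_{k}}{n^2},\] 其中 \(n\) 为总样本数,\(a\) 和 \(b\) 的定义如下。设共有 \(C\) 类样本,每类真实样本数构成的向量为 \[a=(a_{1},a_{2},\dots,a_{C}),\] 每类预测样本数构成的向量为 \[b=(b_{1},b_{2},\dots,b_{C}).\] 记 \(P_o\) 为实际观测一致率(即准确率),则 \[\kappa=\frac{P_o-P_e}{1-P_e}.\]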

评价指标:

1、准确率 accuracy:预测正确的样本占全部样本的比例 \[accuracy = \frac{1}{N}\sum_{k=1}^{|C|}\sum_{x:g(x)=k}I\{g(x)=\hat{g}(x)\}\] 其中 \(N\) 为样本数,\(|C|\) 为类别总数,\(g(x)\) 和 \(\hat{g}(x)\) 分别为目标变量的真实值和预测值,\(I\{\cdot\}\) 为示性函数。

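2、加权准确率:先在每一类内部计算准确率(即该类的召回率),再按权重加权求和 \[weighted\text{-}accuracy = \sum_{k=1}^{|C|}w_{k}\,\frac{1}{|C_{k}|}\sum_{x:g(x)=k}I\{g(x)=\hat{g}(x)\},\qquad w_{k}\ge 0,\ \sum_{k=1}^{|C|}w_{k}=1,\] 其中 \(|C_{k}|\) 表示第 \(k\) 类的样本数目。取 \(w_{k}=1/|C|\) 时即为平衡准确率(各类召回率的平均值);取 \(w_{k}=|C_{k}|/N\)(各类样本占比)时,结果与普通准确率相同。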

# 加载MASS包
library(MASS)
## 
## Attaching package: 'MASS'
## The following object is masked from 'package:dplyr':
## 
##     select
# 定义一个计算简单准确率的函数
# Plain accuracy: the share of positions where the prediction equals the true
# label. Positions whose comparison is NA count as incorrect; empty input
# yields NaN (0 / 0).
calculate.accuracy <- function(predictions, true.labels) {
  n_correct <- sum(predictions == true.labels, na.rm = TRUE)
  n_correct / length(true.labels)
}

# 定义一个计算加权准确率的函数
# Weighted accuracy: computes the accuracy inside each class of `true.labels`
# (that class's recall) and combines the per-class values with `weights`.
#
# predictions  predicted labels, same length as true.labels
# true.labels  factor of true labels; its levels define the classes
# weights      one weight per class, summing to 1. Named weights whose names
#              are exactly the class levels are matched by name; otherwise
#              they are used in the order of levels(true.labels).
# Returns a single number. Equal weights give balanced accuracy; weights equal
# to the class proportions reproduce plain accuracy. A class with no true
# samples yields NaN for that class, and therefore for the result.
# Errors if the number of weights differs from the number of classes or if
# the weights do not sum to 1.
calculate.w.accuracy <- function(predictions, true.labels, weights) {
    C <- levels(true.labels)

    if (length(weights) != length(C)) {
        stop("Number of weights should agree with the number of classes.")
    }

    # Compare with a tolerance: weights such as 1/3 or 4/15 need not sum to
    # exactly 1 in floating point.
    if (!isTRUE(all.equal(sum(as.numeric(weights)), 1))) {
        stop("Weights do not sum to 1")
    }

    # Align named weights with the class levels so their order does not matter.
    if (!is.null(names(weights)) && setequal(names(weights), C) &&
        !identical(names(weights), C)) {
        weights <- weights[C]
    }

    # Accuracy restricted to the samples whose true label is class k.
    accs <- vapply(C, function(k) {
        idx <- which(true.labels == k)
        calculate.accuracy(predictions[idx], true.labels[idx])
    }, numeric(1), USE.NAMES = FALSE)

    sum(weights * accs)
}
# Plain accuracy. The comparison is symmetric, but predictions are passed first
# to match the function signature.
acc <- calculate.accuracy(pre1, Test_Car$Oil_Consumption)
print(paste0("Accuracy is: ", round(acc, 3)))
## [1] "Accuracy is: 0.933"
# Class proportions of the test set, used as weights. MASS::fractions() is
# applied only to display them as exact fractions.
class_props <- prop.table(table(Test_Car$Oil_Consumption))
fraction_weights <- fractions(class_props)

fraction_weights 
## 
##    A    B    C 
##  3/5 4/15 2/15
# calculate.w.accuracy(predictions, true.labels, weights): predictions come
# first, so per-class accuracies are computed within each TRUE class (recall).
# Weighting recall by the class proportions gives back plain accuracy; equal
# weights (1/3 each) would give balanced accuracy instead. The plain numeric
# proportions are passed so that round() prints a decimal, not a fraction.
acc1 <- calculate.w.accuracy(pre1, Test_Car$Oil_Consumption, weights = class_props)
print(paste0("weight-Accuracy is: ", round(acc1, 3)))
## [1] "weight-Accuracy is: 0.933"

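安装 caret 和 e1071 包,利用 confusionMatrix() 函数进行评价。注意参数顺序为 confusionMatrix(data = 预测类别, reference = 真实类别);若两者颠倒,混淆矩阵会被转置,灵敏度、患病率(Prevalence)等指标也会随之错位。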

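Kappa 统计量用于比较两个或多个观测者对同一事物的判断,或同一观测者对同一事物两次或多次观测的结果是否一致。它以实际观测一致性与机遇(偶然)造成的一致性之间的差别大小作为评价基础,计算公式为 \(\kappa=(P_o-P_e)/(1-P_e)\),其中 \(P_o\) 为实际观测一致率(即准确率),\(P_e\) 为偶然一致率。评价分类模型时,可把真实类别与预测类别看作两次"观测"。Kappa 统计量取值在 \([-1,1]\) 之间,其值的大小具有不同的意义,常用的划分(Landis & Koch, 1977)为:\(\kappa<0\) 表示一致性差;0.00–0.20 为轻微一致;0.21–0.40 为一般一致;0.41–0.60 为中等一致;0.61–0.80 为高度一致;0.81–1.00 为几乎完全一致。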
library(caret)
library(e1071)
# caret::confusionMatrix(data, reference): `data` holds the PREDICTED classes
# and `reference` the TRUE classes. The arguments are named so they cannot be
# swapped, which would transpose the table and mislabel sensitivity,
# specificity, prevalence and the no-information rate.
cm <- confusionMatrix(data = pre1, reference = Test_Car$Oil_Consumption)
cm
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction A B C
##          A 8 0 0
##          B 1 4 0
##          C 0 0 2
## 
## Overall Statistics
##                                           
##                Accuracy : 0.9333          
##                  95% CI : (0.6805, 0.9983)
##     No Information Rate : 0.6             
##     P-Value [Acc > NIR] : 0.005172        
##                                           
##                   Kappa : 0.8837          
##                                           
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: A Class: B Class: C
## Sensitivity            0.8889   1.0000   1.0000
## Specificity            1.0000   0.9091   1.0000
## Pos Pred Value         1.0000   0.8000   1.0000
## Neg Pred Value         0.8571   1.0000   1.0000
## Prevalence             0.6000   0.2667   0.1333
## Detection Rate         0.5333   0.2667   0.1333
## Detection Prevalence   0.5333   0.3333   0.1333
## Balanced Accuracy      0.9444   0.9545   1.0000

建立CART模型

library(rpart)
library(rpart.plot)
library(pander)

# Same predictors as the C4.5 model.
formula <- Oil_Consumption ~ Price + Country + Reliability + Type +
  Weight + Disp. + HP

# Growth controls: a node needs at least 3 observations before a split is
# attempted, depth is capped at 2, cp = 0.001 lets almost every split through,
# and xval = 5 runs 5-fold cross-validation to fill the xerror column.
ctl1 <- rpart.control(
  minsplit = 3, maxcompete = 3, maxdepth = 2, cp = 0.001, xval = 5
)
# CART classification tree using the Gini splitting index.
treeFit <- rpart(
  formula,
  data = Train_Car, method = "class",
  parms = list(split = "gini"), control = ctl1
)
names(treeFit)
##  [1] "frame"               "where"               "call"               
##  [4] "terms"               "cptable"             "method"             
##  [7] "parms"               "control"             "functions"          
## [10] "numresp"             "splits"              "csplit"             
## [13] "variable.importance" "y"                   "ordered"
printcp(treeFit)
## 
## Classification tree:
## rpart(formula = formula, data = Train_Car, method = "class", 
##     parms = list(split = "gini"), control = ctl1)
## 
## Variables actually used in tree construction:
## [1] Disp. Price
## 
## Root node error: 19/45 = 0.42222
## 
## n= 45 
## 
##        CP nsplit rel error  xerror    xstd
## 1 0.47368      0   1.00000 1.00000 0.17438
## 2 0.21053      1   0.52632 0.84211 0.16900
## 3 0.00100      2   0.31579 0.73684 0.16345
# Variable importance as a horizontal bar chart; sort() puts the most important
# variable on top. las = 2 turns the axis labels perpendicular to the axis, and
# the previous graphics settings are restored afterwards.
imp <- sort(treeFit$variable.importance)
old_par <- par(las = 2)
barplot(imp, cex.names = 0.6, names.arg = names(imp), horiz = TRUE)
par(old_par)

# Plot the fitted tree.
rpart.plot(treeFit, type = 4, branch = 0, extra = 1)

# Test-set predictions, confusion matrix (rows = true class, columns =
# predicted class) and misclassification rate.
cFit1 <- predict(treeFit, Test_Car, type = "class")
confM1 <- table(Test_Car$Oil_Consumption, cFit1)
error1 <- (sum(confM1) - sum(diag(confM1))) / sum(confM1); error1
## [1] 0.1333333
# caret::confusionMatrix(data = predicted classes, reference = true classes).
cm1 <- confusionMatrix(data = cFit1, reference = Test_Car$Oil_Consumption)
cm1
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction A B C
##          A 7 0 0
##          B 2 4 0
##          C 0 0 2
## 
## Overall Statistics
##                                           
##                Accuracy : 0.8667          
##                  95% CI : (0.5954, 0.9834)
##     No Information Rate : 0.6             
##     P-Value [Acc > NIR] : 0.02711         
##                                           
##                   Kappa : 0.7761          
##                                           
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: A Class: B Class: C
## Sensitivity            0.7778   1.0000   1.0000
## Specificity            1.0000   0.8182   1.0000
## Pos Pred Value         1.0000   0.6667   1.0000
## Neg Pred Value         0.7500   1.0000   1.0000
## Prevalence             0.6000   0.2667   0.1333
## Detection Rate         0.4667   0.2667   0.1333
## Detection Prevalence   0.4667   0.4000   0.1333
## Balanced Accuracy      0.8889   0.9091   1.0000
# Test error so far: C4.5 (E1) against the depth-2 CART tree (error1).
tibble(E1, error1)
## # A tibble: 1 × 2
##       E1 error1
##    <dbl>  <dbl>
## 1 0.0667  0.133
# Second CART model: the same controls as ctl1 but allowing depth 3.
ctl2 <- rpart.control(
  minsplit = 3, maxcompete = 3, maxdepth = 3, cp = 0.001, xval = 5
)
treeFit1 <- rpart(
  formula,
  data = Train_Car, method = "class",
  parms = list(split = "gini"), control = ctl2
)
names(treeFit1)
##  [1] "frame"               "where"               "call"               
##  [4] "terms"               "cptable"             "method"             
##  [7] "parms"               "control"             "functions"          
## [10] "numresp"             "splits"              "csplit"             
## [13] "variable.importance" "y"                   "ordered"
printcp(treeFit1)
## 
## Classification tree:
## rpart(formula = formula, data = Train_Car, method = "class", 
##     parms = list(split = "gini"), control = ctl2)
## 
## Variables actually used in tree construction:
## [1] Country Disp.   Price  
## 
## Root node error: 19/45 = 0.42222
## 
## n= 45 
## 
##         CP nsplit rel error  xerror    xstd
## 1 0.473684      0  1.000000 1.00000 0.17438
## 2 0.210526      1  0.526316 0.78947 0.16644
## 3 0.105263      2  0.315789 0.78947 0.16644
## 4 0.052632      3  0.210526 0.68421 0.16002
## 5 0.001000      6  0.052632 0.73684 0.16345
# Variable importance as a horizontal bar chart, most important on top.
# The previous graphics settings are restored after the las = 2 plot.
imp <- sort(treeFit1$variable.importance)
old_par <- par(las = 2)
barplot(imp, cex.names = 0.6, names.arg = names(imp), horiz = TRUE)
par(old_par)

# Plot the fitted tree.
rpart.plot(treeFit1, type = 4, branch = 0, extra = 1)

# Test-set predictions, confusion matrix (rows = true class, columns =
# predicted class) and misclassification rate.
cFit2 <- predict(treeFit1, Test_Car, type = "class")
confM2 <- table(Test_Car$Oil_Consumption, cFit2)
error2 <- (sum(confM2) - sum(diag(confM2))) / sum(confM2); error2
## [1] 0.2
# caret::confusionMatrix(data = predicted classes, reference = true classes).
cm2 <- confusionMatrix(data = cFit2, reference = Test_Car$Oil_Consumption)
cm2
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction A B C
##          A 7 0 0
##          B 2 4 1
##          C 0 0 1
## 
## Overall Statistics
##                                           
##                Accuracy : 0.8             
##                  95% CI : (0.5191, 0.9567)
##     No Information Rate : 0.6             
##     P-Value [Acc > NIR] : 0.0905          
##                                           
##                   Kappa : 0.6591          
##                                           
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: A Class: B Class: C
## Sensitivity            0.7778   1.0000  0.50000
## Specificity            1.0000   0.7273  1.00000
## Pos Pred Value         1.0000   0.5714  1.00000
## Neg Pred Value         0.7500   1.0000  0.92857
## Prevalence             0.6000   0.2667  0.13333
## Detection Rate         0.4667   0.2667  0.06667
## Detection Prevalence   0.4667   0.4667  0.06667
## Balanced Accuracy      0.8889   0.8636  0.75000
# Prune treeFit1 at cp = 0.2: splits whose complexity is at most 0.2 are
# snipped off. Going by treeFit1's cp table, this should leave only the first
# two splits (CP 0.474 and 0.211).
# NOTE(review): cp = 0.2 is hand-picked. The cp table's smallest xerror is at
# nsplit = 3 (CP 0.0526), so choosing cp from the table (minimum xerror or the
# 1-SE rule) may be more defensible.
treeFit3 <- prune(treeFit1, cp = 0.2)
rpart.plot(treeFit3, type = 4, branch = 0, extra = 1)

# Test-set predictions, confusion matrix and misclassification rate.
cFit3 <- predict(treeFit3, Test_Car, type = "class")
confM3 <- table(Test_Car$Oil_Consumption, cFit3)
error3 <- (sum(confM3) - sum(diag(confM3))) / sum(confM3)
error3
## [1] 0.1333333
# caret::confusionMatrix(data = predicted classes, reference = true classes).
cm3 <- confusionMatrix(data = cFit3, reference = Test_Car$Oil_Consumption)
cm3
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction A B C
##          A 7 0 0
##          B 2 4 0
##          C 0 0 2
## 
## Overall Statistics
##                                           
##                Accuracy : 0.8667          
##                  95% CI : (0.5954, 0.9834)
##     No Information Rate : 0.6             
##     P-Value [Acc > NIR] : 0.02711         
##                                           
##                   Kappa : 0.7761          
##                                           
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: A Class: B Class: C
## Sensitivity            0.7778   1.0000   1.0000
## Specificity            1.0000   0.8182   1.0000
## Pos Pred Value         1.0000   0.6667   1.0000
## Neg Pred Value         0.7500   1.0000   1.0000
## Prevalence             0.6000   0.2667   0.1333
## Detection Rate         0.4667   0.2667   0.1333
## Detection Prevalence   0.4667   0.4000   0.1333
## Balanced Accuracy      0.8889   0.9091   1.0000

结论

# Test-set error rates (in %) of the C4.5 model and the three CART models,
# shown side by side as one table.
E_C45 <- E1
E1_CART <- error1
E2_CART <- error2
E3_CART <- error3
errorlist <- cbind(E_C45, E1_CART, E2_CART, E3_CART)
pander::pander(100 * errorlist)
E_C45 E1_CART E2_CART E3_CART
6.667 13.33 20 13.33