特别说明:J48()函数的使用需要先安装Java软件。
首先导入car.test.frame数据集,然后为了增加可读性,将其变量名转换成中文。
library(tidyverse)
library(rpart)
# Load the car.test.frame data set and keep a tibble copy of it as `car`.
data("car.test.frame")
car <- as_tibble(car.test.frame)
print(car)
## # A tibble: 60 × 8
## Price Country Reliability Mileage Type Weight Disp. HP
## <int> <fct> <int> <int> <fct> <int> <int> <int>
## 1 8895 USA 4 33 Small 2560 97 113
## 2 7402 USA 2 33 Small 2345 114 90
## 3 6319 Korea 4 37 Small 1845 81 63
## 4 6635 Japan/USA 5 32 Small 2260 91 92
## 5 6599 Japan 5 32 Small 2440 113 103
## 6 8672 Mexico 4 26 Small 2285 97 82
## 7 7399 Japan/USA 5 33 Small 2275 97 90
## 8 7254 Korea 1 28 Small 2350 98 74
## 9 9599 Japan 5 25 Small 2295 109 90
## 10 5866 Japan NA 34 Small 1900 73 73
## # ℹ 50 more rows
# Replace Mileage by 100 * 4.546 / (1.6 * Mileage) and give all eight
# columns Chinese names (the rescaled column becomes "油耗").
# NOTE(review): presumably this converts miles per gallon into litres per
# 100 km; 4.546 L is the imperial gallon and 1.6 a rounded km-per-mile
# factor — confirm the units of the source data.
car <- car %>%
  mutate(Mileage = 100 * 4.546 / (1.6 * Mileage)) %>%
  setNames(c(
    "价格", "产地", "可靠性", "油耗", "类型", "车重",
    "发动机功率", "净马力"
  ))
print(car)
## # A tibble: 60 × 8
## 价格 产地 可靠性 油耗 类型 车重 发动机功率 净马力
## <int> <fct> <int> <dbl> <fct> <int> <int> <int>
## 1 8895 USA 4 8.61 Small 2560 97 113
## 2 7402 USA 2 8.61 Small 2345 114 90
## 3 6319 Korea 4 7.68 Small 1845 81 63
## 4 6635 Japan/USA 5 8.88 Small 2260 91 92
## 5 6599 Japan 5 8.88 Small 2440 113 103
## 6 8672 Mexico 4 10.9 Small 2285 97 82
## 7 7399 Japan/USA 5 8.61 Small 2275 97 90
## 8 7254 Korea 1 10.1 Small 2350 98 74
## 9 9599 Japan 5 11.4 Small 2295 109 90
## 10 5866 Japan NA 8.36 Small 1900 73 73
## # ℹ 50 more rows
summary(car) # Summarise every column of the renamed data set
## 价格 产地 可靠性 油耗 类型
## Min. : 5866 USA :26 Min. :1.000 Min. : 7.679 Compact:15
## 1st Qu.: 9932 Japan :19 1st Qu.:2.000 1st Qu.:10.523 Large : 3
## Median :12216 Japan/USA: 7 Median :3.000 Median :12.353 Medium :13
## Mean :12616 Korea : 3 Mean :3.388 Mean :11.962 Small :13
## 3rd Qu.:14933 Germany : 2 3rd Qu.:5.000 3rd Qu.:13.530 Sporty : 9
## Max. :24760 France : 1 Max. :5.000 Max. :15.785 Van : 7
## (Other) : 2 NA's :11
## 车重 发动机功率 净马力
## Min. :1845 Min. : 73.0 Min. : 63.0
## 1st Qu.:2571 1st Qu.:113.8 1st Qu.:101.5
## Median :2885 Median :144.5 Median :111.5
## Mean :2901 Mean :152.1 Mean :122.3
## 3rd Qu.:3231 3rd Qu.:180.0 3rd Qu.:142.8
## Max. :3855 Max. :305.0 Max. :225.0
##
#传统方式利用现有变量生成新的变量
#Group_Mileage <- matrix(0,60,1) #设矩阵Group_Mileage用来存放新变量
#Group_Mileage[which(car$"油耗">=11.6)] <- "A"
#Group_Mileage[which(car$"油耗"<=9)] <- "C"
#Group_Mileage[which(Group_Mileage==0)] <- "B"
#car$"分组油耗" <- Group_Mileage
#car[1:10,c(4,9)]
#tibble(car)
#利用mutate()函数生成新的变量
library(dplyr)
# Bin fuel consumption into three groups: "A" (>= 11.6), "C" (<= 9) and
# "B" for every remaining row. The two conditions cannot both hold, so
# their order does not matter.
car <- mutate(
  car,
  分组油耗 = case_when(
    油耗 <= 9 ~ "C",
    油耗 >= 11.6 ~ "A",
    TRUE ~ "B"
  )
)
# Show the first ten rows of the consumption column (4th) next to its new
# group label (9th), selected by name.
print(car[seq_len(10), c("油耗", "分组油耗")])
## # A tibble: 10 × 2
## 油耗 分组油耗
## <dbl> <chr>
## 1 8.61 C
## 2 8.61 C
## 3 7.68 C
## 4 8.88 C
## 5 8.88 C
## 6 10.9 B
## 7 8.61 C
## 8 10.1 B
## 9 11.4 B
## 10 8.36 C
# Print the whole data set, now including the grouping column
tibble(car)
## # A tibble: 60 × 9
## 价格 产地 可靠性 油耗 类型 车重 发动机功率 净马力 分组油耗
## <int> <fct> <int> <dbl> <fct> <int> <int> <int> <chr>
## 1 8895 USA 4 8.61 Small 2560 97 113 C
## 2 7402 USA 2 8.61 Small 2345 114 90 C
## 3 6319 Korea 4 7.68 Small 1845 81 63 C
## 4 6635 Japan/USA 5 8.88 Small 2260 91 92 C
## 5 6599 Japan 5 8.88 Small 2440 113 103 C
## 6 8672 Mexico 4 10.9 Small 2285 97 82 B
## 7 7399 Japan/USA 5 8.61 Small 2275 97 90 C
## 8 7254 Korea 1 10.1 Small 2350 98 74 B
## 9 9599 Japan 5 11.4 Small 2295 109 90 B
## 10 5866 Japan NA 8.36 Small 1900 73 73 C
## # ℹ 50 more rows
library(sampling)
# Test-set size per consumption group: one quarter of each group, rounded.
a <- round(1 / 4 * sum(car$"分组油耗" == "A"))
b <- round(1 / 4 * sum(car$"分组油耗" == "B"))
c <- round(1 / 4 * sum(car$"分组油耗" == "C"))
a;b;c
## [1] 9
## [1] 4
## [1] 2
# Stratified sampling without replacement on the grouping column.
# strata() expects `size` in the order in which the strata first occur in
# the data, so the sizes are looked up by group label in that order instead
# of being hard-coded as c(c, b, a). For this data the order is C, B, A and
# the drawn sample is unchanged.
set.seed(12345)
sub <- car %>%
  strata(
    stratanames = "分组油耗",
    size = unname(c(A = a, B = b, C = c)[unique(car$"分组油耗")]),
    method = "srswor"
  )
tibble(sub) # Print the sampling result
## # A tibble: 15 × 4
## 分组油耗 ID_unit Prob Stratum
## <chr> <int> <dbl> <int>
## 1 C 3 0.222 1
## 2 C 12 0.222 1
## 3 B 19 0.25 2
## 4 B 22 0.25 2
## 5 B 25 0.25 2
## 6 B 36 0.25 2
## 7 A 16 0.257 3
## 8 A 29 0.257 3
## 9 A 31 0.257 3
## 10 A 35 0.257 3
## 11 A 47 0.257 3
## 12 A 49 0.257 3
## 13 A 54 0.257 3
## 14 A 57 0.257 3
## 15 A 59 0.257 3
# Rows drawn by the stratified sample form the test set; all remaining rows
# form the training set.
Test_Car <- car[sub$ID_unit, ]
Train_Car <- car[-sub$ID_unit, ]
# Row counts of the training and test set, to check the intended 3:1 ratio.
print(nrow(Train_Car))
print(nrow(Test_Car))
## [1] 45
## [1] 15
注:J48()函数对中文的支持不太完善,所以我们将变量名称还原为英文。
# Switch both splits back to English column names and store the class label
# as a factor. The levels are given explicitly so that the training and test
# factors always share the level set A/B/C, even if a class were missing
# from one split; as.factor() on each split separately does not guarantee
# that.
names(Train_Car) <- names(Test_Car) <- c(
  "Price", "Country", "Reliability", "Mileage",
  "Type", "Weight", "Disp.", "HP", "Oil_Consumption"
)
Train_Car$Oil_Consumption <- factor(
  Train_Car$Oil_Consumption,
  levels = c("A", "B", "C")
)
Test_Car$Oil_Consumption <- factor(
  Test_Car$Oil_Consumption,
  levels = c("A", "B", "C")
)
# install.packages("RWeka")
library(RWeka)
#注明:先安装java软件
# install.packages("rJava")
library(rJava)
# Model the consumption group from every column except Mileage, the column
# the group label was derived from.
formula <- Oil_Consumption ~ Price + Country + Reliability + Type + Weight +
  Disp. + HP
# Fit a J48 tree with the default control settings and print it.
C45_0 <- J48(formula, data = Train_Car)
print(C45_0)
## J48 pruned tree
## ------------------
##
## Disp. <= 133
## | Price <= 9483: C (9.0/3.0)
## | Price > 9483: B (5.0)
## Disp. > 133: A (24.0/2.0)
##
## Number of Leaves : 3
##
## Size of the tree : 5
# Evaluation summary for C45_0; the recorded output below counts 38
# instances although Train_Car has 45 rows.
summary(C45_0)
##
## === Summary ===
##
## Correctly Classified Instances 33 86.8421 %
## Incorrectly Classified Instances 5 13.1579 %
## Kappa statistic 0.766
## Mean absolute error 0.1345
## Root mean squared error 0.2593
## Relative absolute error 34.9259 %
## Root relative squared error 59.4341 %
## Total Number of Instances 38
##
## === Confusion Matrix ===
##
## a b c <-- classified as
## 22 0 0 | a = A
## 2 5 3 | b = B
## 0 0 6 | c = C
# Refit with the Weka option M = 3; the printed tree is identical to C45_0.
# NOTE(review): M is presumably the minimum number of instances per leaf —
# confirm against the Weka J48 options.
C45_1 <- J48(formula, data = Train_Car, control = Weka_control(M = 3))
print(C45_1)
## J48 pruned tree
## ------------------
##
## Disp. <= 133
## | Price <= 9483: C (9.0/3.0)
## | Price > 9483: B (5.0)
## Disp. > 133: A (24.0/2.0)
##
## Number of Leaves : 3
##
## Size of the tree : 5
# Evaluation summary for C45_1; the recorded figures match those of C45_0.
summary(C45_1)
##
## === Summary ===
##
## Correctly Classified Instances 33 86.8421 %
## Incorrectly Classified Instances 5 13.1579 %
## Kappa statistic 0.766
## Mean absolute error 0.1345
## Root mean squared error 0.2593
## Relative absolute error 34.9259 %
## Root relative squared error 59.4341 %
## Total Number of Instances 38
##
## === Confusion Matrix ===
##
## a b c <-- classified as
## 22 0 0 | a = A
## 2 5 3 | b = B
## 0 0 6 | c = C
# install.packages("partykit")
# install.packages("grid")
library(partykit)
library(grid)
# Draw both fitted J48 trees, each titled with its model name.
plot(C45_0,main="C45_0")
plot(C45_1,main="C45_1")
# Predict the class of every test row with the second J48 model.
pre1 <- predict(C45_1, newdata = Test_Car, type = "class")
print(pre1)
## [1] C C B B B B A B A A A A A A A
## Levels: A B C
# Confusion matrix: rows are the true classes, columns the predictions.
ConfM1 <- table(Test_Car$Oil_Consumption, pre1)
print(ConfM1)
## pre1
## A B C
## A 8 1 0
## B 0 4 0
## C 0 0 2
# Misclassification rate: wrongly predicted test rows over all test rows.
p <- sum(pre1 != Test_Car$Oil_Consumption) / nrow(Test_Car)
print(p)
## [1] 0.06666667
# The same rate from the confusion matrix: off-diagonal count over total.
E1 <- (sum(ConfM1) - sum(diag(ConfM1))) / sum(ConfM1)
print(E1)
## [1] 0.06666667
二分类模型的混淆矩阵:
从混淆矩阵计算的典型指标有:\[P_{e}=\frac{\langle a,b\rangle}{n^2},\] 其中\(n\)为总样本数,\(\langle a,b\rangle\)表示向量\(a\)与\(b\)的内积,\(a\)和\(b\)分别定义如下:设有\(C\)类样本,每类真实样本数构成的向量为 \[a=(a_{1},a_{2},\dots,a_{C}),\] 每类预测样本数构成的向量为 \[b=(b_{1},b_{2},\dots,b_{C}).\]
评价指标:
1、准确率accuracy:正确预测的平均数 \[accuracy = \frac{1}{N}\sum_{k=1}^{|C|}\sum_{x:g(x)=k}I\{g(x)=\hat{g}(x)\}\] 其中\(N\)为样本数,\(|C|\)为总的类别数,\(g(x)\)和\(\hat{g}(x)\)分别为目标变量的真实值和预测值。
2、加权准确率 \[\text{weighted-accuracy} = \sum_{k=1}^{|C|}w_{k}\,\frac{1}{|C_{k}|}\sum_{x:g(x)=k}I\{g(x)=\hat{g}(x)\}, \qquad \sum_{k=1}^{|C|}w_{k}=1,\] 其中\(|C_{k}|\)表示第\(k\)类样本数目,\(w_{k}\)为第\(k\)类的权重,\(k \in \{1,\dots,|C|\}\)。
# 加载MASS包
library(MASS)
##
## Attaching package: 'MASS'
## The following object is masked from 'package:dplyr':
##
## select
# Compute the plain accuracy: the share of predictions equal to the labels.
#
# predictions: vector of predicted labels.
# true.labels: vector of true labels, same length as `predictions`.
# Returns a single number: the count of matching positions divided by
# length(true.labels). NA comparisons count as mismatches; empty input
# yields NaN.
# Stops when the two vectors differ in length, because `==` would otherwise
# recycle the shorter one silently and report a meaningless accuracy.
calculate.accuracy <- function(predictions, true.labels) {
  if (length(predictions) != length(true.labels)) {
    stop("predictions and true.labels must have the same length.",
      call. = FALSE
    )
  }
  sum(predictions == true.labels, na.rm = TRUE) / length(true.labels)
}
# Compute a class-weighted accuracy: the per-class accuracies combined with
# the given class weights.
#
# predictions: vector of predicted labels.
# true.labels: factor of true labels; its levels define the classes.
# weights: one weight per level of `true.labels`, in level order; must sum
#   to 1 up to floating-point rounding.
# Returns sum(weights * per-class accuracy). A level with no samples in
# `true.labels` has accuracy NaN, which propagates to the result.
# Stops when the number of weights differs from the number of classes or
# when the weights do not sum to 1.
calculate.w.accuracy <- function(predictions, true.labels, weights) {
  classes <- levels(true.labels)
  if (length(weights) != length(classes)) {
    stop("Number of weights should agree with the number of classes.")
  }
  # Compare with a tolerance: weights such as 1/3 or 4/15 are not exactly
  # representable, so an exact `!= 1` test can reject valid weights.
  if (abs(sum(as.numeric(weights)) - 1) > sqrt(.Machine$double.eps)) {
    stop("Weights do not sum to 1")
  }
  # Accuracy within each class, i.e. over the samples whose true label is
  # that class.
  class_acc <- vapply(classes, function(cls) {
    idx <- which(true.labels == cls)
    calculate.accuracy(predictions[idx], true.labels[idx])
  }, numeric(1), USE.NAMES = FALSE)
  sum(weights * class_acc)
}
# Plain accuracy of the J48 predictions. Both helpers take the predictions
# first and the true labels second.
acc <- calculate.accuracy(pre1, Test_Car$Oil_Consumption)
print(paste0("Accuracy is: ", round(acc, 3)))
## [1] "Accuracy is: 0.933"
# Class weights: the share of each true class in the test set, shown as
# fractions. `%>%` binds tighter than `/`, so in the original chain only the
# total was passed to fractions(); that evaluation order is now explicit.
fraction_weights <- table(Test_Car$Oil_Consumption) /
  fractions(sum(table(Test_Car$Oil_Consumption)))
fraction_weights
##
## A B C
## 3/5 4/15 2/15
# With the arguments in the right order, per-class accuracy is computed over
# the true classes that the weights describe. The previous swapped call
# grouped by predicted class and reported 0.947.
acc1 <- calculate.w.accuracy(pre1, Test_Car$Oil_Consumption,
  weights = fraction_weights
)
# as.numeric() drops the "fractions" class so the value prints as a decimal.
print(paste0("weight-Accuracy is: ", round(as.numeric(acc1), 3)))
## [1] "weight-Accuracy is: 0.933"
library(caret)
library(e1071)
# Confusion matrix and per-class statistics for the J48 predictions.
# NOTE(review): the true labels are passed first and the predictions second,
# so the rows labelled "Prediction" in the recorded output are the true
# classes. caret documents the order as confusionMatrix(data = predicted,
# reference = truth) — confirm, and swap the arguments if so.
cm<-confusionMatrix(Test_Car$Oil_Consumption,pre1)
cm
## Confusion Matrix and Statistics
##
## Reference
## Prediction A B C
## A 8 1 0
## B 0 4 0
## C 0 0 2
##
## Overall Statistics
##
## Accuracy : 0.9333
## 95% CI : (0.6805, 0.9983)
## No Information Rate : 0.5333
## P-Value [Acc > NIR] : 0.001135
##
## Kappa : 0.8837
##
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: A Class: B Class: C
## Sensitivity 1.0000 0.8000 1.0000
## Specificity 0.8571 1.0000 1.0000
## Pos Pred Value 0.8889 1.0000 1.0000
## Neg Pred Value 1.0000 0.9091 1.0000
## Prevalence 0.5333 0.3333 0.1333
## Detection Rate 0.5333 0.2667 0.1333
## Detection Prevalence 0.6000 0.2667 0.1333
## Balanced Accuracy 0.9286 0.9000 1.0000
library(rpart)
library(rpart.plot)
library(pander)
# Same model formula as for J48: consumption group against all columns
# except Mileage.
formula <- Oil_Consumption ~ Price + Country + Reliability + Type + Weight +
  Disp. + HP
# Control settings for the first CART tree; the depth is limited to 2.
ctl1 <- rpart.control(
  minsplit = 3,
  maxcompete = 3,
  maxdepth = 2,
  cp = 0.001,
  xval = 5
)
# Grow a classification tree with the Gini index as split criterion.
treeFit <- rpart(
  formula,
  data = Train_Car,
  method = "class",
  parms = list(split = "gini"),
  control = ctl1
)
# List the components stored in the fitted rpart object.
print(names(treeFit))
## [1] "frame" "where" "call"
## [4] "terms" "cptable" "method"
## [7] "parms" "control" "functions"
## [10] "numresp" "splits" "csplit"
## [13] "variable.importance" "y" "ordered"
# Print the complexity-parameter table of the tree grown with maxdepth = 2.
printcp(treeFit)
##
## Classification tree:
## rpart(formula = formula, data = Train_Car, method = "class",
## parms = list(split = "gini"), control = ctl1)
##
## Variables actually used in tree construction:
## [1] Disp. Price
##
## Root node error: 19/45 = 0.42222
##
## n= 45
##
## CP nsplit rel error xerror xstd
## 1 0.47368 0 1.00000 1.00000 0.17438
## 2 0.21053 1 0.52632 0.84211 0.16900
## 3 0.00100 2 0.31579 0.73684 0.16345
# Horizontal bar chart of variable importance, smallest value first.
# local() keeps the helper variables out of the global environment, and the
# previous `las` setting is restored afterwards so the label rotation does
# not leak into later plots.
local({
  importance <- sort(treeFit$variable.importance)
  old_par <- par(las = 2)
  barplot(importance,
    cex.names = 0.6,
    names.arg = names(importance),
    horiz = TRUE # `T` is an ordinary variable and could have been reassigned
  )
  invisible(par(old_par))
})
# Draw the fitted tree.
rpart.plot(treeFit, type = 4, branch = 0, extra = 1)
# Predict test-set classes with the depth-2 CART tree.
cFit1 <- predict(treeFit, newdata = Test_Car, type = "class")
# Cross-tabulate true classes (rows) against predictions (columns).
confM1 <- table(Test_Car$Oil_Consumption, cFit1)
# Misclassification rate: off-diagonal count divided by the total count.
error1 <- (sum(confM1) - sum(diag(confM1))) / sum(confM1)
print(error1)
## [1] 0.1333333
# Confusion matrix and per-class statistics for the depth-2 CART tree.
# NOTE(review): the true labels are passed first and the predictions second,
# so the rows labelled "Prediction" in the recorded output are the true
# classes. caret documents the order as confusionMatrix(data = predicted,
# reference = truth) — confirm, and swap the arguments if so.
cm1<-confusionMatrix(Test_Car$Oil_Consumption,cFit1)
cm1
## Confusion Matrix and Statistics
##
## Reference
## Prediction A B C
## A 7 2 0
## B 0 4 0
## C 0 0 2
##
## Overall Statistics
##
## Accuracy : 0.8667
## 95% CI : (0.5954, 0.9834)
## No Information Rate : 0.4667
## P-Value [Acc > NIR] : 0.001684
##
## Kappa : 0.7761
##
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: A Class: B Class: C
## Sensitivity 1.0000 0.6667 1.0000
## Specificity 0.7500 1.0000 1.0000
## Pos Pred Value 0.7778 1.0000 1.0000
## Neg Pred Value 1.0000 0.8182 1.0000
## Prevalence 0.4667 0.4000 0.1333
## Detection Rate 0.4667 0.2667 0.1333
## Detection Prevalence 0.6000 0.2667 0.1333
## Balanced Accuracy 0.8750 0.8333 1.0000
# Side-by-side test error of J48 (E1) and the first CART tree (error1).
tibble(E1,error1)
## # A tibble: 1 × 2
## E1 error1
## <dbl> <dbl>
## 1 0.0667 0.133
# Control settings for the second CART tree: as before, but depth up to 3.
ctl2 <- rpart.control(
  minsplit = 3,
  maxcompete = 3,
  maxdepth = 3,
  cp = 0.001,
  xval = 5
)
# Grow the deeper classification tree with the Gini index as criterion.
treeFit1 <- rpart(
  formula,
  data = Train_Car,
  method = "class",
  parms = list(split = "gini"),
  control = ctl2
)
# List the components stored in the fitted rpart object.
print(names(treeFit1))
## [1] "frame" "where" "call"
## [4] "terms" "cptable" "method"
## [7] "parms" "control" "functions"
## [10] "numresp" "splits" "csplit"
## [13] "variable.importance" "y" "ordered"
# Print the complexity-parameter table of the tree grown with maxdepth = 3.
printcp(treeFit1)
##
## Classification tree:
## rpart(formula = formula, data = Train_Car, method = "class",
## parms = list(split = "gini"), control = ctl2)
##
## Variables actually used in tree construction:
## [1] Country Disp. Price
##
## Root node error: 19/45 = 0.42222
##
## n= 45
##
## CP nsplit rel error xerror xstd
## 1 0.473684 0 1.000000 1.00000 0.17438
## 2 0.210526 1 0.526316 0.78947 0.16644
## 3 0.105263 2 0.315789 0.78947 0.16644
## 4 0.052632 3 0.210526 0.68421 0.16002
## 5 0.001000 6 0.052632 0.73684 0.16345
# Horizontal bar chart of variable importance, smallest value first.
# local() keeps the helper variables out of the global environment, and the
# previous `las` setting is restored afterwards so the label rotation does
# not leak into later plots.
local({
  importance <- sort(treeFit1$variable.importance)
  old_par <- par(las = 2)
  barplot(importance,
    cex.names = 0.6,
    names.arg = names(importance),
    horiz = TRUE # `T` is an ordinary variable and could have been reassigned
  )
  invisible(par(old_par))
})
# Draw the fitted tree.
rpart.plot(treeFit1, type = 4, branch = 0, extra = 1)
# Predict test-set classes with the depth-3 CART tree.
cFit2 <- predict(treeFit1, newdata = Test_Car, type = "class")
# Cross-tabulate true classes (rows) against predictions (columns).
confM2 <- table(Test_Car$Oil_Consumption, cFit2)
# Misclassification rate: off-diagonal count divided by the total count.
error2 <- (sum(confM2) - sum(diag(confM2))) / sum(confM2)
print(error2)
## [1] 0.2
# Confusion matrix and per-class statistics for the depth-3 CART tree.
# NOTE(review): the true labels are passed first and the predictions second,
# so the rows labelled "Prediction" in the recorded output are the true
# classes. caret documents the order as confusionMatrix(data = predicted,
# reference = truth) — confirm, and swap the arguments if so.
cm2<-confusionMatrix(Test_Car$Oil_Consumption,cFit2)
cm2
## Confusion Matrix and Statistics
##
## Reference
## Prediction A B C
## A 7 2 0
## B 0 4 0
## C 0 1 1
##
## Overall Statistics
##
## Accuracy : 0.8
## 95% CI : (0.5191, 0.9567)
## No Information Rate : 0.4667
## P-Value [Acc > NIR] : 0.009047
##
## Kappa : 0.6591
##
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: A Class: B Class: C
## Sensitivity 1.0000 0.5714 1.00000
## Specificity 0.7500 1.0000 0.92857
## Pos Pred Value 0.7778 1.0000 0.50000
## Neg Pred Value 1.0000 0.7273 1.00000
## Prevalence 0.4667 0.4667 0.06667
## Detection Rate 0.4667 0.2667 0.06667
## Detection Prevalence 0.6000 0.2667 0.13333
## Balanced Accuracy 0.8750 0.7857 0.96429
# Prune the depth-3 tree back with a complexity parameter of 0.2.
treeFit3 <- prune(treeFit1, cp = 0.2)
# Draw the pruned tree.
rpart.plot(treeFit3, type = 4, branch = 0, extra = 1)
# Predict test-set classes with the pruned tree.
cFit3 <- predict(treeFit3, newdata = Test_Car, type = "class")
# Cross-tabulate true classes (rows) against predictions (columns).
confM3 <- table(Test_Car$Oil_Consumption, cFit3)
# Misclassification rate: off-diagonal count divided by the total count.
error3 <- (sum(confM3) - sum(diag(confM3))) / sum(confM3)
print(error3)
## [1] 0.1333333
# Confusion matrix and per-class statistics for the pruned CART tree.
# NOTE(review): the true labels are passed first and the predictions second,
# so the rows labelled "Prediction" in the recorded output are the true
# classes. caret documents the order as confusionMatrix(data = predicted,
# reference = truth) — confirm, and swap the arguments if so.
cm3<-confusionMatrix(Test_Car$Oil_Consumption,cFit3)
cm3
## Confusion Matrix and Statistics
##
## Reference
## Prediction A B C
## A 7 2 0
## B 0 4 0
## C 0 0 2
##
## Overall Statistics
##
## Accuracy : 0.8667
## 95% CI : (0.5954, 0.9834)
## No Information Rate : 0.4667
## P-Value [Acc > NIR] : 0.001684
##
## Kappa : 0.7761
##
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: A Class: B Class: C
## Sensitivity 1.0000 0.6667 1.0000
## Specificity 0.7500 1.0000 1.0000
## Pos Pred Value 0.7778 1.0000 1.0000
## Neg Pred Value 1.0000 0.8182 1.0000
## Prevalence 0.4667 0.4000 0.1333
## Detection Rate 0.4667 0.2667 0.1333
## Detection Prevalence 0.6000 0.2667 0.1333
## Balanced Accuracy 0.8750 0.8333 1.0000
# Collect the test error of the J48 model and of the three CART trees
# (depth 2, depth 3, pruned) under descriptive names.
E_C45 <- E1
E1_CART <- error1
E2_CART <- error2
E3_CART <- error3
# One-row matrix whose column names are taken from the variable names.
errorlist <- cbind(E_C45, E1_CART, E2_CART, E3_CART)
# Render the error rates as percentages.
pander::pander(100 * errorlist)
E_C45 | E1_CART | E2_CART | E3_CART |
---|---|---|---|
6.667 | 13.33 | 20 | 13.33 |