종속변수가 범주형 일 때를 말하며, 이진분류에 대해서만 다루겠음. 의사결정트리는 한번 분기마다 변수의 영역을 나누는 모델
영역을 나누는 것의 기준은 분류모델을 기준으로 불순도(impurity)/불확실성(uncertainty)이 최소가 될 수 있도록하는 방향으로 진행된다. 순도나 불확실성의 증감을 두고 정보획득(information gain)으로 표현하기도 함.
이 지표로 지니지수 및 엔트로피가 사용됨.
\(Entropy(A)=\sum R_{i}(-\sum p_{k} log_{2}(p_{k}))\)
엔트로피의 감소는 불확실성의 감소를 얘기하는게 되는데 따라서 정보획득량의 증가를 뜻함.
지니불순도와 경제학에서 소득불평등 지수에 사용되는 지니 지수는 동음이의어이다. 의사결정트리에서는 지니 불순도를 사용하며 링크 1을 참조하기 바란다.
\(Gini(A)=\sum^{d}_{i=1}(R_{i}(1-\sum^{m}_{k=1}p^2_{ik}))\)
library(caret)
library(tree)
library(dplyr)
library(rpart)
set.seed(1)
# iris=filter(iris,Species%in%c('setosa','versicolor'))
tree=rpart::rpart(Species~Sepal.Width,data=iris)
plot(tree)
text(tree)
tree
## n= 150
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 150 100 setosa (0.33333333 0.33333333 0.33333333)
## 2) Sepal.Width>=3.35 37 6 setosa (0.83783784 0.02702703 0.13513514) *
## 3) Sepal.Width< 3.35 113 64 versicolor (0.16814159 0.43362832 0.39823009)
## 6) Sepal.Width>=2.95 56 32 virginica (0.30357143 0.26785714 0.42857143) *
## 7) Sepal.Width< 2.95 57 23 versicolor (0.03508772 0.59649123 0.36842105) *
트리의 첫번째 분류기준이 Sepal.Width가 3.35보다 큰지 적은지 인데 3.35를 어떻게 구하는지 알아보자.
위 트리를 확인해보면 3.35에서 첫 분류가 되었다는 것을 알 수 있다. 아래에서 지니지수를 구한 결과를 통해 알아보자.
coef1=NULL
for(i in seq(0,4,0.05)){
#i보다 큰 iris값을 R_filter에 작은값을 L_filter에 할당
R_filter=filter(iris, Sepal.Width>=i)
L_filter=filter(iris, Sepal.Width<i )
#R_table, L_table은 할당된 데이터에서 범주별 개수
R_table=R_filter%>%select(Species)%>%table
R_table
L_table=L_filter%>%select(Species)%>%table
L_table
#전체 자료중에 범주별 비율의 제곱을 1에서 뺌
cost_R=1-sum((R_table/(R_filter%>%nrow))^2)
cost_L=1-sum((L_table/(L_filter%>%nrow))^2)
gini=sum(R_table)/nrow(iris)*cost_R+sum(L_table/nrow(iris))*cost_L
coef1=rbind(coef1,data.frame(i,cost_R,cost_L,gini))
}
coef1=coef1[complete.cases(coef1),]
plot(coef1[,1],coef1[,4],type='l',ylim=c(0,1))
points(coef1[,1],coef1[,2],type='l',col=2)
points(coef1[,1],coef1[,3],type='l',col=3)
abline(v=3.35)
coef1[which.min(coef1$gini),'gini']
## [1] 0.5397433
지니지수를 쉽게 말하면 왼쪽 분류된 애들의 확률의 제곱 합 \((1-(\frac{범주별 개수_{왼쪽}}{관측치 개수의 총합_{왼쪽}})^2)*\frac{관측치 개수의총합_{왼쪽}}{관측치 개수의 총합_{전체}}+(1-(\frac{범주별 개수_{오른쪽}}{관측치 개수의 총합_{오른쪽}})^2)*\frac{관측치 개수의총합_{오른쪽}}{관측치 개수의 총합_{전체}}\)
R_filter=filter(iris, Sepal.Width>=3.35)
L_filter=filter(iris, Sepal.Width<3.35 )
#R_table, L_table은 할당된 데이터에서 범주별 개수
R_table=R_filter%>%select(Species)%>%table
R_table
## .
## setosa versicolor virginica
## 31 1 5
검은선은 지니지수의 불순도를 나타내고 붉은선은 왼쪽으로 분류된 기준 값, 초록색은 오른쪽으로 분류된 값을 나타내는데 3.35가 지니지수가 가장 낮은 것을 알 수 있다. 다시말해, 이는 정보획득량이 가장 높다고 할 수 있다.
아래는 분류 된 자료 중에서 왼쪽에 해당하는 자료를 다시 분할하는 과정을 나타낸 것이다. 2.95와 같은 결과가 나오는 것을 알 수 있다.
iris2=filter(iris,Sepal.Width<3.35)
coef2=NULL
for(i in seq(0,4,0.05)){
R_filter=filter(iris2, Sepal.Width>=i)
L_filter=filter(iris2, Sepal.Width<i )
R_table=R_filter%>%select(Species)%>%table
L_table=L_filter%>%select(Species)%>%table
cost_R=1-sum((R_table/(R_filter%>%nrow))^2)
cost_L=1-sum((L_table/(L_filter%>%nrow))^2)
gini=sum(R_table)/nrow(iris2)*cost_R+sum(L_table/nrow(iris2))*cost_L
coef2=rbind(coef2,data.frame(i,cost_R,cost_L,gini))
}
coef2=coef2[complete.cases(coef2),]
coef2[which.min(coef2$gini),'gini']
## [1] 0.5791858
plot(coef2[,1],coef2[,4],type='l',ylim=c(0,1))
points(coef2[,1],coef2[,2],type='l',col=2)
points(coef2[,1],coef2[,3],type='l',col=3)
abline(v=2.95)
아래는 분류 된 자료 중에서 오른쪽에 해당하는 자료를 다시 분할하는 과정을 나타낸 것이다. 분할하기전보다 불순도가 높다는 것을 알 수 있다. 이에따라 tree를 확인해보면 분할되지 않았음을 알 수 있다.
iris3=filter(iris,Sepal.Width>=3.35)
coef3=NULL
for(i in seq(0,4,0.05)){
R_filter=filter(iris3, Sepal.Width>=i)
L_filter=filter(iris3, Sepal.Width<i )
R_table=R_filter%>%select(Species)%>%table
L_table=L_filter%>%select(Species)%>%table
cost_R=1-sum((R_table/(R_filter%>%nrow))^2)
cost_L=1-sum((L_table/(L_filter%>%nrow))^2)
gini=sum(R_table)/nrow(iris3)*cost_R+sum(L_table/nrow(iris3))*cost_L
coef3=rbind(coef3,data.frame(i,cost_R,cost_L,gini))
}
coef3=coef3[complete.cases(coef3),]
coef3[which.min(coef3$gini),'gini']
## [1] 0.2702703
plot(coef3[,1],coef3[,4],type='l',ylim=c(0,1))
points(coef3[,1],coef3[,2],type='l',col=2)
points(coef3[,1],coef3[,3],type='l',col=3)
tree=rpart::rpart(Species~.,data=iris)
par(mfrow=c(3,1))
coef1=NULL
for(i in seq(0,4,0.05)){
R_filter=filter(iris, Petal.Length>=i)
L_filter=filter(iris, Petal.Length<i )
R_table=R_filter%>%select(Species)%>%table
L_table=L_filter%>%select(Species)%>%table
cost_R= 1-sum((R_table/(R_filter%>%nrow))^2)
cost_L=1-sum((L_table/(L_filter%>%nrow))^2)
gini=sum(R_table)/nrow(iris)*cost_R+sum(L_table/nrow(iris))*cost_L
coef1=rbind(coef1,data.frame(i,cost_R,cost_L,gini))
}
coef1=coef1[complete.cases(coef1),]
plot(coef1[,1],coef1[,4],type='l',ylim=c(0,1))
points(coef1[,1],coef1[,2],type='l',col=2)
points(coef1[,1],coef1[,3],type='l',col=3)
abline(v=coef1[which.min(coef1$gini),'i'])
tree
## n= 150
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 150 100 setosa (0.33333333 0.33333333 0.33333333)
## 2) Petal.Length< 2.45 50 0 setosa (1.00000000 0.00000000 0.00000000) *
## 3) Petal.Length>=2.45 100 50 versicolor (0.00000000 0.50000000 0.50000000)
## 6) Petal.Width< 1.75 54 5 versicolor (0.00000000 0.90740741 0.09259259) *
## 7) Petal.Width>=1.75 46 1 virginica (0.00000000 0.02173913 0.97826087) *
iris2=filter(iris,Petal.Length>=2.45)
coef2=NULL
for(i in seq(0,4,0.05)){
R_filter=filter(iris2, Petal.Width>=i)
L_filter=filter(iris2, Petal.Width<i )
R_table=R_filter%>%select(Species)%>%table
L_table=L_filter%>%select(Species)%>%table
cost_R=1-sum((R_table/(R_filter%>%nrow))^2)
cost_L=1-sum((L_table/(L_filter%>%nrow))^2)
gini=sum(R_table)/nrow(iris2)*cost_R+sum(L_table/nrow(iris2))*cost_L
coef2=rbind(coef2,data.frame(i,cost_R,cost_L,gini))
}
coef2=coef2[complete.cases(coef2),]
coef2[which.min(coef2$gini),'gini']
## [1] 0.110306
plot(coef2[,1],coef2[,4],type='l',ylim=c(0,1))
points(coef2[,1],coef2[,2],type='l',col=2)
points(coef2[,1],coef2[,3],type='l',col=3)
abline(v=coef2[which.min(coef2$gini),'i'])
tree
## n= 150
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 150 100 setosa (0.33333333 0.33333333 0.33333333)
## 2) Petal.Length< 2.45 50 0 setosa (1.00000000 0.00000000 0.00000000) *
## 3) Petal.Length>=2.45 100 50 versicolor (0.00000000 0.50000000 0.50000000)
## 6) Petal.Width< 1.75 54 5 versicolor (0.00000000 0.90740741 0.09259259) *
## 7) Petal.Width>=1.75 46 1 virginica (0.00000000 0.02173913 0.97826087) *
iris3=filter(iris,Petal.Length>=1.9)
coef3=NULL
for(i in seq(0,4,0.05)){
R_filter=filter(iris3, Petal.Width>=i)
L_filter=filter(iris3, Petal.Width<i )
R_table=R_filter%>%select(Species)%>%table
L_table=L_filter%>%select(Species)%>%table
cost_R=1-sum((R_table/(R_filter%>%nrow))^2)
cost_L=1-sum((L_table/(L_filter%>%nrow))^2)
gini=sum(R_table)/nrow(iris3)*cost_R+sum(L_table/nrow(iris3))*cost_L
coef3=rbind(coef3,data.frame(i,cost_R,cost_L,gini))
}
coef3=coef3[complete.cases(coef3),]
coef3[which.min(coef3$gini),'gini']
## [1] 0.142781
plot(coef3[,1],coef3[,4],type='l',ylim=c(0,1))
points(coef3[,1],coef3[,2],type='l',col=2)
points(coef3[,1],coef3[,3],type='l',col=3)
abline(v=coef3[which.min(coef3$gini),'i'])
library(tree)
treemod<-tree(Species~. , data=iris)
plot(treemod);text(treemod)
cv.tree(treemod, FUN=prune.misclass )
## $size
## [1] 6 4 3 2 1
##
## $dev
## [1] 7 9 11 88 112
##
## $k
## [1] -Inf 0 2 44 50
##
## $method
## [1] "misclass"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
plot(cv.tree(treemod, FUN=prune.misclass ))
prune.trees <- prune.misclass(treemod, best=6) # for regression decision tree, use prune.tree function
plot(treemod);text(treemod)
plot(prune.trees)
text(prune.trees, pretty=0)
plotcp(tree)