분류 모형

종속변수가 범주형 일 때를 말하며, 이진분류에 대해서만 다루겠음. 의사결정트리는 한번 분기마다 변수의 영역을 나누는 모델

영역을 나누는 것의 기준.

영역을 나누는 것의 기준은 분류모델을 기준으로 불순도(impurity)/불확실성(uncertainty)이 최소가 될 수 있도록하는 방향으로 진행된다. 순도나 불확실성의 증감을 두고 정보획득(information gain)으로 표현하기도 함.

이 지표로 지니지수 및 엔트로피가 사용됨.

Entropy

\(Entropy(A)=\sum R_{i}(-\sum p_{k} log_{2}(p_{k}))\)

엔트로피의 감소는 불확실성의 감소를 얘기하는게 되는데 따라서 정보획득량의 증가를 뜻함.

Gini index1),2)

지니불순도와 경제학에서 소득불평등 지수에 사용되는 지니 지수는 동음이의어이다. 의사결정트리에서는 지니 불순도를 사용하며 링크 1을 참조하기 바란다.

\(Gini(A)=\sum^{d}_{i=1}(R_{i}(1-\sum^{m}_{k=1}p^2_{ik}))\)

library(caret)
library(tree)
library(dplyr)
library(rpart)
set.seed(1)
# iris=filter(iris,Species%in%c('setosa','versicolor'))
tree=rpart::rpart(Species~Sepal.Width,data=iris)
plot(tree)
text(tree)

tree
## n= 150 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
## 1) root 150 100 setosa (0.33333333 0.33333333 0.33333333)  
##   2) Sepal.Width>=3.35 37   6 setosa (0.83783784 0.02702703 0.13513514) *
##   3) Sepal.Width< 3.35 113  64 versicolor (0.16814159 0.43362832 0.39823009)  
##     6) Sepal.Width>=2.95 56  32 virginica (0.30357143 0.26785714 0.42857143) *
##     7) Sepal.Width< 2.95 57  23 versicolor (0.03508772 0.59649123 0.36842105) *

트리의 첫번째 분류기준이 Sepal.Width가 3.35보다 큰지 적은지 인데 3.35를 어떻게 구하는지 알아보자.

지니지수로 분류기준 알아보기

위 트리를 확인해보면 3.35에서 첫 분류가 되었다는 것을 알 수 있다. 아래에서 지니지수를 구한 결과를 통해 알아보자.

coef1=NULL
for(i in seq(0,4,0.05)){
#i보다 큰 iris값을 R_filter에 작은값을 L_filter에 할당
R_filter=filter(iris, Sepal.Width>=i)
L_filter=filter(iris, Sepal.Width<i )
#R_table, L_table은 할당된 데이터에서 범주별 개수
R_table=R_filter%>%select(Species)%>%table
R_table
L_table=L_filter%>%select(Species)%>%table
L_table
#전체 자료중에 범주별 비율의 제곱을 1에서 뺌
cost_R=1-sum((R_table/(R_filter%>%nrow))^2)
cost_L=1-sum((L_table/(L_filter%>%nrow))^2)

gini=sum(R_table)/nrow(iris)*cost_R+sum(L_table/nrow(iris))*cost_L
  coef1=rbind(coef1,data.frame(i,cost_R,cost_L,gini))
}
coef1=coef1[complete.cases(coef1),]
plot(coef1[,1],coef1[,4],type='l',ylim=c(0,1))
points(coef1[,1],coef1[,2],type='l',col=2)
points(coef1[,1],coef1[,3],type='l',col=3)
abline(v=3.35)

coef1[which.min(coef1$gini),'gini']
## [1] 0.5397433

지니지수를 쉽게 말하면 왼쪽 분류된 애들의 확률의 제곱 합 \((1-(\frac{범주별 개수_{왼쪽}}{관측치 개수의 총합_{왼쪽}})^2)*\frac{관측치 개수의총합_{왼쪽}}{관측치 개수의 총합_{전체}}+(1-(\frac{범주별 개수_{오른쪽}}{관측치 개수의 총합_{오른쪽}})^2)*\frac{관측치 개수의총합_{오른쪽}}{관측치 개수의 총합_{전체}}\)

R_filter=filter(iris, Sepal.Width>=3.35)
L_filter=filter(iris, Sepal.Width<3.35 )
#R_table, L_table은 할당된 데이터에서 범주별 개수
R_table=R_filter%>%select(Species)%>%table
R_table
## .
##     setosa versicolor  virginica 
##         31          1          5

검은선은 지니지수의 불순도를 나타내고 붉은선은 왼쪽으로 분류된 기준 값, 초록색은 오른쪽으로 분류된 값을 나타내는데 3.35가 지니지수가 가장 낮은 것을 알 수 있다. 다시말해, 이는 정보획득량이 가장 높다고 할 수 있다.

아래는 분류 된 자료 중에서 왼쪽에 해당하는 자료를 다시 분할하는 과정을 나타낸 것이다. 2.95와 같은 결과가 나오는 것을 알 수 있다.

iris2=filter(iris,Sepal.Width<3.35)
coef2=NULL
for(i in seq(0,4,0.05)){
  R_filter=filter(iris2, Sepal.Width>=i)
  L_filter=filter(iris2, Sepal.Width<i )
  R_table=R_filter%>%select(Species)%>%table
  L_table=L_filter%>%select(Species)%>%table
  cost_R=1-sum((R_table/(R_filter%>%nrow))^2)
  cost_L=1-sum((L_table/(L_filter%>%nrow))^2)
  gini=sum(R_table)/nrow(iris2)*cost_R+sum(L_table/nrow(iris2))*cost_L
  coef2=rbind(coef2,data.frame(i,cost_R,cost_L,gini))
}
coef2=coef2[complete.cases(coef2),]

coef2[which.min(coef2$gini),'gini']
## [1] 0.5791858
plot(coef2[,1],coef2[,4],type='l',ylim=c(0,1))
points(coef2[,1],coef2[,2],type='l',col=2)
points(coef2[,1],coef2[,3],type='l',col=3)
abline(v=2.95)

아래는 분류 된 자료 중에서 오른쪽에 해당하는 자료를 다시 분할하는 과정을 나타낸 것이다. 분할하기전보다 불순도가 높다는 것을 알 수 있다. 이에따라 tree를 확인해보면 분할되지 않았음을 알 수 있다.

iris3=filter(iris,Sepal.Width>=3.35)
coef3=NULL
for(i in seq(0,4,0.05)){
  R_filter=filter(iris3, Sepal.Width>=i)
  L_filter=filter(iris3, Sepal.Width<i )
  R_table=R_filter%>%select(Species)%>%table
  L_table=L_filter%>%select(Species)%>%table
  cost_R=1-sum((R_table/(R_filter%>%nrow))^2)
  cost_L=1-sum((L_table/(L_filter%>%nrow))^2)
  gini=sum(R_table)/nrow(iris3)*cost_R+sum(L_table/nrow(iris3))*cost_L
  coef3=rbind(coef3,data.frame(i,cost_R,cost_L,gini))
}
coef3=coef3[complete.cases(coef3),]
coef3[which.min(coef3$gini),'gini']
## [1] 0.2702703
plot(coef3[,1],coef3[,4],type='l',ylim=c(0,1))
points(coef3[,1],coef3[,2],type='l',col=2)
points(coef3[,1],coef3[,3],type='l',col=3)

참고

tree=rpart::rpart(Species~.,data=iris)


par(mfrow=c(3,1))
coef1=NULL
for(i in seq(0,4,0.05)){
  R_filter=filter(iris, Petal.Length>=i)
  L_filter=filter(iris, Petal.Length<i )
  R_table=R_filter%>%select(Species)%>%table
  L_table=L_filter%>%select(Species)%>%table
  cost_R=  1-sum((R_table/(R_filter%>%nrow))^2)
  cost_L=1-sum((L_table/(L_filter%>%nrow))^2)
  gini=sum(R_table)/nrow(iris)*cost_R+sum(L_table/nrow(iris))*cost_L
  coef1=rbind(coef1,data.frame(i,cost_R,cost_L,gini))
}
coef1=coef1[complete.cases(coef1),]
plot(coef1[,1],coef1[,4],type='l',ylim=c(0,1))
points(coef1[,1],coef1[,2],type='l',col=2)
points(coef1[,1],coef1[,3],type='l',col=3)
abline(v=coef1[which.min(coef1$gini),'i'])
tree
## n= 150 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
## 1) root 150 100 setosa (0.33333333 0.33333333 0.33333333)  
##   2) Petal.Length< 2.45 50   0 setosa (1.00000000 0.00000000 0.00000000) *
##   3) Petal.Length>=2.45 100  50 versicolor (0.00000000 0.50000000 0.50000000)  
##     6) Petal.Width< 1.75 54   5 versicolor (0.00000000 0.90740741 0.09259259) *
##     7) Petal.Width>=1.75 46   1 virginica (0.00000000 0.02173913 0.97826087) *
iris2=filter(iris,Petal.Length>=2.45)
coef2=NULL
for(i in seq(0,4,0.05)){
  R_filter=filter(iris2, Petal.Width>=i)
  L_filter=filter(iris2, Petal.Width<i )
  R_table=R_filter%>%select(Species)%>%table
  L_table=L_filter%>%select(Species)%>%table
  cost_R=1-sum((R_table/(R_filter%>%nrow))^2)
  cost_L=1-sum((L_table/(L_filter%>%nrow))^2)
  gini=sum(R_table)/nrow(iris2)*cost_R+sum(L_table/nrow(iris2))*cost_L
  coef2=rbind(coef2,data.frame(i,cost_R,cost_L,gini))
}
coef2=coef2[complete.cases(coef2),]

coef2[which.min(coef2$gini),'gini']
## [1] 0.110306
plot(coef2[,1],coef2[,4],type='l',ylim=c(0,1))
points(coef2[,1],coef2[,2],type='l',col=2)
points(coef2[,1],coef2[,3],type='l',col=3)
abline(v=coef2[which.min(coef2$gini),'i'])
tree
## n= 150 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
## 1) root 150 100 setosa (0.33333333 0.33333333 0.33333333)  
##   2) Petal.Length< 2.45 50   0 setosa (1.00000000 0.00000000 0.00000000) *
##   3) Petal.Length>=2.45 100  50 versicolor (0.00000000 0.50000000 0.50000000)  
##     6) Petal.Width< 1.75 54   5 versicolor (0.00000000 0.90740741 0.09259259) *
##     7) Petal.Width>=1.75 46   1 virginica (0.00000000 0.02173913 0.97826087) *
iris3=filter(iris,Petal.Length>=1.9)
coef3=NULL
for(i in seq(0,4,0.05)){
  R_filter=filter(iris3, Petal.Width>=i)
  L_filter=filter(iris3, Petal.Width<i )
  R_table=R_filter%>%select(Species)%>%table
  L_table=L_filter%>%select(Species)%>%table
  cost_R=1-sum((R_table/(R_filter%>%nrow))^2)
  cost_L=1-sum((L_table/(L_filter%>%nrow))^2)
  gini=sum(R_table)/nrow(iris3)*cost_R+sum(L_table/nrow(iris3))*cost_L
  coef3=rbind(coef3,data.frame(i,cost_R,cost_L,gini))
}
coef3=coef3[complete.cases(coef3),]
coef3[which.min(coef3$gini),'gini']
## [1] 0.142781
plot(coef3[,1],coef3[,4],type='l',ylim=c(0,1))
points(coef3[,1],coef3[,2],type='l',col=2)
points(coef3[,1],coef3[,3],type='l',col=3)
abline(v=coef3[which.min(coef3$gini),'i'])

library(tree)
treemod<-tree(Species~. , data=iris)
plot(treemod);text(treemod)
cv.tree(treemod, FUN=prune.misclass )
## $size
## [1] 6 4 3 2 1
## 
## $dev
## [1]   7   9  11  88 112
## 
## $k
## [1] -Inf    0    2   44   50
## 
## $method
## [1] "misclass"
## 
## attr(,"class")
## [1] "prune"         "tree.sequence"
plot(cv.tree(treemod, FUN=prune.misclass ))
prune.trees <- prune.misclass(treemod, best=6)  # for regression decision tree, use prune.tree function

plot(treemod);text(treemod)

plot(prune.trees)
text(prune.trees, pretty=0)
plotcp(tree)