Decision Tree

Decision Tree는 주어진 데이터를 분류(Classification)하는 목적으로 사용되는 분석으로 의사결정규칙(decision rule)을 나무구조(tree)로 도표화하여 분류와 예측을 수행하는 분석 기법이다.

1. 특징

Decision Tree의 대표적인 algorithm에는 C 4.5, CART, CHAID가 있으며 기본적인 생성 방식은 유사하며 가지를 분리하는 분리기준을 선택하는 방식에 약간의 차이가 있다.

1.1. 장점

1.2. 단점

2. 단계

3. 예제

party package example by Zhao

## party package를 이용하여 iris 데이터를 Decision Tree 분석기법으로
## 분류해 본다.
library(party)
## iris 데이터에 대하여 알아본다.
data(iris)
str(iris)
## 'data.frame':    150 obs. of  5 variables:
##  $ Sepal.Length: num  5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
##  $ Sepal.Width : num  3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...
##  $ Petal.Length: num  1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
##  $ Petal.Width : num  0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
##  $ Species     : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...
## iris 데이터를 7:3의 비율로 trainData와 testData로 샘플링한다.
set.seed(1234)
ind <- sample(2, nrow(iris), replace = TRUE, prob = c(0.7, 0.3))
trainData <- iris[ind == 1, ]
testData <- iris[ind == 2, ]
## Decision Tree 분석결과 terminal node에 대한 내용을 출력하고 분류결과를
## plot한다.
iris_ctree <- ctree(Species ~ ., data = trainData)
print(iris_ctree)
## 
##   Conditional inference tree with 4 terminal nodes
## 
## Response:  Species 
## Inputs:  Sepal.Length, Sepal.Width, Petal.Length, Petal.Width 
## Number of observations:  112 
## 
## 1) Petal.Length <= 1.9; criterion = 1, statistic = 104.643
##   2)*  weights = 40 
## 1) Petal.Length > 1.9
##   3) Petal.Width <= 1.7; criterion = 1, statistic = 48.939
##     4) Petal.Length <= 4.4; criterion = 0.974, statistic = 7.397
##       5)*  weights = 21 
##     4) Petal.Length > 4.4
##       6)*  weights = 19 
##   3) Petal.Width > 1.7
##     7)*  weights = 32
plot(iris_ctree)

plot of chunk unnamed-chunk-4

## simple type으로 plot
plot(iris_ctree, type = "simple")

plot of chunk unnamed-chunk-4

## trainData에 대하여 설정한 Tree Model을 이용하여 예측한 값을 출력한다.
trainPred <- predict(iris_ctree, newdata = trainData)
trainPred
##   [1] setosa     setosa     setosa     setosa     setosa     setosa    
##   [7] setosa     setosa     setosa     setosa     setosa     setosa    
##  [13] setosa     setosa     setosa     setosa     setosa     setosa    
##  [19] setosa     setosa     setosa     setosa     setosa     setosa    
##  [25] setosa     setosa     setosa     setosa     setosa     setosa    
##  [31] setosa     setosa     setosa     setosa     setosa     setosa    
##  [37] setosa     setosa     setosa     setosa     versicolor versicolor
##  [43] versicolor versicolor versicolor versicolor versicolor versicolor
##  [49] versicolor versicolor versicolor versicolor versicolor versicolor
##  [55] versicolor virginica  versicolor versicolor versicolor versicolor
##  [61] versicolor versicolor versicolor versicolor versicolor versicolor
##  [67] versicolor versicolor versicolor versicolor versicolor versicolor
##  [73] versicolor versicolor versicolor versicolor versicolor versicolor
##  [79] virginica  virginica  virginica  virginica  virginica  virginica 
##  [85] versicolor virginica  virginica  virginica  virginica  virginica 
##  [91] virginica  virginica  virginica  virginica  virginica  virginica 
##  [97] virginica  virginica  versicolor virginica  virginica  versicolor
## [103] virginica  virginica  virginica  virginica  virginica  virginica 
## [109] virginica  virginica  virginica  virginica 
## Levels: setosa versicolor virginica
## 실제 trainData값과 예측값을 비교하는 교차표를 출력한다.
table(predict(iris_ctree), trainData$Species)
##             
##              setosa versicolor virginica
##   setosa         40          0         0
##   versicolor      0         37         3
##   virginica       0          1        31
## trainData에 대하여 설정한 Tree Model을 이용하여 예측한 값을 출력한다.
testPred <- predict(iris_ctree, newdata = testData)
testPred
##  [1] setosa     setosa     setosa     setosa     setosa     setosa    
##  [7] setosa     setosa     setosa     setosa     versicolor versicolor
## [13] versicolor versicolor versicolor versicolor versicolor versicolor
## [19] versicolor versicolor versicolor versicolor virginica  virginica 
## [25] virginica  virginica  versicolor virginica  virginica  virginica 
## [31] virginica  virginica  versicolor virginica  virginica  virginica 
## [37] virginica  virginica 
## Levels: setosa versicolor virginica
## 실제 trainData값과 예측값을 비교하는 교차표를 출력한다.
table(testPred, testData$Species)
##             
## testPred     setosa versicolor virginica
##   setosa         10          0         0
##   versicolor      0         12         2
##   virginica       0          0        14

rpart package example by Zhao

## rpart package를 이용하여 bodyfat 데이터를 Decision Tree 분석기법으로
## 분류해 본다.
library(rpart)
## bodyfat 데이터에 대하여 알아본다.
data("bodyfat", package = "mboost")
dim(bodyfat)
## [1] 71 10
attributes(bodyfat)
## $names
##  [1] "age"          "DEXfat"       "waistcirc"    "hipcirc"     
##  [5] "elbowbreadth" "kneebreadth"  "anthro3a"     "anthro3b"    
##  [9] "anthro3c"     "anthro4"     
## 
## $row.names
##  [1] "47"  "48"  "49"  "50"  "51"  "52"  "53"  "54"  "55"  "56"  "57" 
## [12] "58"  "59"  "60"  "61"  "62"  "63"  "64"  "65"  "66"  "67"  "68" 
## [23] "69"  "70"  "71"  "72"  "73"  "74"  "75"  "76"  "77"  "78"  "79" 
## [34] "80"  "81"  "82"  "83"  "84"  "85"  "86"  "87"  "88"  "89"  "90" 
## [45] "91"  "92"  "93"  "94"  "95"  "96"  "97"  "98"  "99"  "100" "101"
## [56] "102" "103" "104" "105" "106" "107" "108" "109" "110" "111" "112"
## [67] "113" "114" "115" "116" "117"
## 
## $class
## [1] "data.frame"
## bodyfat 데이터를 7:3의 비율로 bodyfat.train과 bodyfat.test로
## 샘플링한다.
set.seed(1234)
ind <- sample(2, nrow(bodyfat), replace = TRUE, prob = c(0.7, 0.3))
bodyfat.train <- bodyfat[ind == 1, ]
bodyfat.test <- bodyfat[ind == 2, ]
## Decision Tree 분석결과 terminal node에 대한 내용을 출력하고 분류결과를
## plot한다.
myFormula <- DEXfat ~ age + waistcirc + hipcirc + elbowbreadth + kneebreadth
bodyfat_rpart <- rpart(myFormula, data = bodyfat.train, control = rpart.control(minsplit = 10))
print(bodyfat_rpart)
## n= 56 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 56 7265.0000 30.95  
##    2) waistcirc< 88.4 31  960.5000 22.56  
##      4) hipcirc< 96.25 14  222.3000 18.41  
##        8) age< 60.5 9   66.8800 16.19 *
##        9) age>=60.5 5   31.2800 22.41 *
##      5) hipcirc>=96.25 17  299.6000 25.97  
##       10) waistcirc< 77.75 6   30.7300 22.32 *
##       11) waistcirc>=77.75 11  145.7000 27.96  
##         22) hipcirc< 99.5 3    0.2569 23.75 *
##         23) hipcirc>=99.5 8   72.2900 29.54 *
##    3) waistcirc>=88.4 25 1417.0000 41.35  
##      6) waistcirc< 104.8 18  330.6000 38.09  
##       12) hipcirc< 109.9 9   69.0000 34.38 *
##       13) hipcirc>=109.9 9   13.0800 41.81 *
##      7) waistcirc>=104.8 7  404.3000 49.73 *
plot(bodyfat_rpart)
text(bodyfat_rpart, use.n = TRUE)

plot of chunk unnamed-chunk-10

## prune(가지치기)를 통하여 Tree의 과적합을 방지할 수 있다.
print(bodyfat_rpart$cptable)
##        CP nsplit rel error xerror    xstd
## 1 0.67273      0   1.00000 1.0195 0.18724
## 2 0.09391      1   0.32727 0.4415 0.10853
## 3 0.06038      2   0.23337 0.4271 0.09363
## 4 0.03420      3   0.17299 0.3842 0.09031
## 5 0.01708      4   0.13879 0.3038 0.07296
## 6 0.01696      5   0.12170 0.2740 0.06600
## 7 0.01007      6   0.10475 0.2694 0.06614
## 8 0.01000      7   0.09468 0.2695 0.06621
## cptable의 xerror를 최소로 하는 cp를 기준으로 가지치기 강도를
## 설정하였다.
opt <- which.min(bodyfat_rpart$cptable[, "xerror"])
cp <- bodyfat_rpart$cptable[opt, "CP"]
bodyfat_prune <- prune(bodyfat_rpart, cp = cp)
## 가지치기를 실시한 Decision Tree 분석결과 terminal node에 대한 내용을
## 출력하고 분류결과를 plot한다.
print(bodyfat_prune)
## n= 56 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 56 7265.00 30.95  
##    2) waistcirc< 88.4 31  960.50 22.56  
##      4) hipcirc< 96.25 14  222.30 18.41  
##        8) age< 60.5 9   66.88 16.19 *
##        9) age>=60.5 5   31.28 22.41 *
##      5) hipcirc>=96.25 17  299.60 25.97  
##       10) waistcirc< 77.75 6   30.73 22.32 *
##       11) waistcirc>=77.75 11  145.70 27.96 *
##    3) waistcirc>=88.4 25 1417.00 41.35  
##      6) waistcirc< 104.8 18  330.60 38.09  
##       12) hipcirc< 109.9 9   69.00 34.38 *
##       13) hipcirc>=109.9 9   13.08 41.81 *
##      7) waistcirc>=104.8 7  404.30 49.73 *
plot(bodyfat_prune)
text(bodyfat_prune, use.n = TRUE)

plot of chunk unnamed-chunk-11

## trainData에 대하여 설정한 Tree Model을 이용하여 testData에 적합하여
## 예측한 결과를 plot을 통해 확인해 본다.
DEXfat_pred <- predict(bodyfat_prune, newdata = bodyfat.test)
xlim <- range(bodyfat$DEXfat)
plot(DEXfat_pred ~ DEXfat, data = bodyfat.test, xlab = "Observed", ylab = "Predicted", 
    ylim = xlim, xlim = xlim)
abline(a = 0, b = 1)

plot of chunk unnamed-chunk-12