Decision Tree는 주어진 데이터를 분류(Classification)하는 목적으로 사용되는 분석으로 의사결정규칙(decision rule)을 나무구조(tree)로 도표화하여 분류와 예측을 수행하는 분석 기법이다.
Decision Tree의 대표적인 algorithm에는 C 4.5, CART, CHAID가 있으며 기본적인 생성 방식은 유사하며 가지를 분리하는 분리기준을 선택하는 방식에 약간의 차이가 있다.
해석이 용이하다.
missing value을 효과적으로 처리한다.
처리속도가 빠르다.
예측모델로서의 정확도가 떨어진다.
variance가 높아 불안정하다.
같은 데이터로부터 다른 Tree가 나올 수 있다.
Step1. 최적의 분리기준을 가지는 변수를 기준으로 데이터를 분류한다.
Step2. Step1 단계를 반복한다.
Stop1. 분리된 데이터의 속성이 같을 때
Stop2. 더 이상 분리할 기준변수가 없을 때
Stop3. 더 이상 분리할 데이터가 없을 때
## party package를 이용하여 iris 데이터를 Decision Tree 분석기법으로
## 분류해 본다.
library(party)
## iris 데이터에 대하여 알아본다.
data(iris)
str(iris)
## 'data.frame': 150 obs. of 5 variables:
## $ Sepal.Length: num 5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
## $ Sepal.Width : num 3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...
## $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
## $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
## $ Species : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...
## iris 데이터를 7:3의 비율로 trainData와 testData로 샘플링한다.
set.seed(1234)
ind <- sample(2, nrow(iris), replace = TRUE, prob = c(0.7, 0.3))
trainData <- iris[ind == 1, ]
testData <- iris[ind == 2, ]
## Decision Tree 분석결과 terminal node에 대한 내용을 출력하고 분류결과를
## plot한다.
iris_ctree <- ctree(Species ~ ., data = trainData)
print(iris_ctree)
##
## Conditional inference tree with 4 terminal nodes
##
## Response: Species
## Inputs: Sepal.Length, Sepal.Width, Petal.Length, Petal.Width
## Number of observations: 112
##
## 1) Petal.Length <= 1.9; criterion = 1, statistic = 104.643
## 2)* weights = 40
## 1) Petal.Length > 1.9
## 3) Petal.Width <= 1.7; criterion = 1, statistic = 48.939
## 4) Petal.Length <= 4.4; criterion = 0.974, statistic = 7.397
## 5)* weights = 21
## 4) Petal.Length > 4.4
## 6)* weights = 19
## 3) Petal.Width > 1.7
## 7)* weights = 32
plot(iris_ctree)
## simple type으로 plot
plot(iris_ctree, type = "simple")
## trainData에 대하여 설정한 Tree Model을 이용하여 예측한 값을 출력한다.
trainPred <- predict(iris_ctree, newdata = trainData)
trainPred
## [1] setosa setosa setosa setosa setosa setosa
## [7] setosa setosa setosa setosa setosa setosa
## [13] setosa setosa setosa setosa setosa setosa
## [19] setosa setosa setosa setosa setosa setosa
## [25] setosa setosa setosa setosa setosa setosa
## [31] setosa setosa setosa setosa setosa setosa
## [37] setosa setosa setosa setosa versicolor versicolor
## [43] versicolor versicolor versicolor versicolor versicolor versicolor
## [49] versicolor versicolor versicolor versicolor versicolor versicolor
## [55] versicolor virginica versicolor versicolor versicolor versicolor
## [61] versicolor versicolor versicolor versicolor versicolor versicolor
## [67] versicolor versicolor versicolor versicolor versicolor versicolor
## [73] versicolor versicolor versicolor versicolor versicolor versicolor
## [79] virginica virginica virginica virginica virginica virginica
## [85] versicolor virginica virginica virginica virginica virginica
## [91] virginica virginica virginica virginica virginica virginica
## [97] virginica virginica versicolor virginica virginica versicolor
## [103] virginica virginica virginica virginica virginica virginica
## [109] virginica virginica virginica virginica
## Levels: setosa versicolor virginica
## 실제 trainData값과 예측값을 비교하는 교차표를 출력한다.
table(predict(iris_ctree), trainData$Species)
##
## setosa versicolor virginica
## setosa 40 0 0
## versicolor 0 37 3
## virginica 0 1 31
## trainData에 대하여 설정한 Tree Model을 이용하여 예측한 값을 출력한다.
testPred <- predict(iris_ctree, newdata = testData)
testPred
## [1] setosa setosa setosa setosa setosa setosa
## [7] setosa setosa setosa setosa versicolor versicolor
## [13] versicolor versicolor versicolor versicolor versicolor versicolor
## [19] versicolor versicolor versicolor versicolor virginica virginica
## [25] virginica virginica versicolor virginica virginica virginica
## [31] virginica virginica versicolor virginica virginica virginica
## [37] virginica virginica
## Levels: setosa versicolor virginica
## 실제 trainData값과 예측값을 비교하는 교차표를 출력한다.
table(testPred, testData$Species)
##
## testPred setosa versicolor virginica
## setosa 10 0 0
## versicolor 0 12 2
## virginica 0 0 14
## rpart package를 이용하여 bodyfat 데이터를 Decision Tree 분석기법으로
## 분류해 본다.
library(rpart)
## bodyfat 데이터에 대하여 알아본다.
data("bodyfat", package = "mboost")
dim(bodyfat)
## [1] 71 10
attributes(bodyfat)
## $names
## [1] "age" "DEXfat" "waistcirc" "hipcirc"
## [5] "elbowbreadth" "kneebreadth" "anthro3a" "anthro3b"
## [9] "anthro3c" "anthro4"
##
## $row.names
## [1] "47" "48" "49" "50" "51" "52" "53" "54" "55" "56" "57"
## [12] "58" "59" "60" "61" "62" "63" "64" "65" "66" "67" "68"
## [23] "69" "70" "71" "72" "73" "74" "75" "76" "77" "78" "79"
## [34] "80" "81" "82" "83" "84" "85" "86" "87" "88" "89" "90"
## [45] "91" "92" "93" "94" "95" "96" "97" "98" "99" "100" "101"
## [56] "102" "103" "104" "105" "106" "107" "108" "109" "110" "111" "112"
## [67] "113" "114" "115" "116" "117"
##
## $class
## [1] "data.frame"
## bodyfat 데이터를 7:3의 비율로 bodyfat.train과 bodyfat.test로
## 샘플링한다.
set.seed(1234)
ind <- sample(2, nrow(bodyfat), replace = TRUE, prob = c(0.7, 0.3))
bodyfat.train <- bodyfat[ind == 1, ]
bodyfat.test <- bodyfat[ind == 2, ]
## Decision Tree 분석결과 terminal node에 대한 내용을 출력하고 분류결과를
## plot한다.
myFormula <- DEXfat ~ age + waistcirc + hipcirc + elbowbreadth + kneebreadth
bodyfat_rpart <- rpart(myFormula, data = bodyfat.train, control = rpart.control(minsplit = 10))
print(bodyfat_rpart)
## n= 56
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 56 7265.0000 30.95
## 2) waistcirc< 88.4 31 960.5000 22.56
## 4) hipcirc< 96.25 14 222.3000 18.41
## 8) age< 60.5 9 66.8800 16.19 *
## 9) age>=60.5 5 31.2800 22.41 *
## 5) hipcirc>=96.25 17 299.6000 25.97
## 10) waistcirc< 77.75 6 30.7300 22.32 *
## 11) waistcirc>=77.75 11 145.7000 27.96
## 22) hipcirc< 99.5 3 0.2569 23.75 *
## 23) hipcirc>=99.5 8 72.2900 29.54 *
## 3) waistcirc>=88.4 25 1417.0000 41.35
## 6) waistcirc< 104.8 18 330.6000 38.09
## 12) hipcirc< 109.9 9 69.0000 34.38 *
## 13) hipcirc>=109.9 9 13.0800 41.81 *
## 7) waistcirc>=104.8 7 404.3000 49.73 *
plot(bodyfat_rpart)
text(bodyfat_rpart, use.n = TRUE)
## prune(가지치기)를 통하여 Tree의 과적합을 방지할 수 있다.
print(bodyfat_rpart$cptable)
## CP nsplit rel error xerror xstd
## 1 0.67273 0 1.00000 1.0195 0.18724
## 2 0.09391 1 0.32727 0.4415 0.10853
## 3 0.06038 2 0.23337 0.4271 0.09363
## 4 0.03420 3 0.17299 0.3842 0.09031
## 5 0.01708 4 0.13879 0.3038 0.07296
## 6 0.01696 5 0.12170 0.2740 0.06600
## 7 0.01007 6 0.10475 0.2694 0.06614
## 8 0.01000 7 0.09468 0.2695 0.06621
## cptable의 xerror를 최소로 하는 cp를 기준으로 가지치기 강도를
## 설정하였다.
opt <- which.min(bodyfat_rpart$cptable[, "xerror"])
cp <- bodyfat_rpart$cptable[opt, "CP"]
bodyfat_prune <- prune(bodyfat_rpart, cp = cp)
## 가지치기를 실시한 Decision Tree 분석결과 terminal node에 대한 내용을
## 출력하고 분류결과를 plot한다.
print(bodyfat_prune)
## n= 56
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 56 7265.00 30.95
## 2) waistcirc< 88.4 31 960.50 22.56
## 4) hipcirc< 96.25 14 222.30 18.41
## 8) age< 60.5 9 66.88 16.19 *
## 9) age>=60.5 5 31.28 22.41 *
## 5) hipcirc>=96.25 17 299.60 25.97
## 10) waistcirc< 77.75 6 30.73 22.32 *
## 11) waistcirc>=77.75 11 145.70 27.96 *
## 3) waistcirc>=88.4 25 1417.00 41.35
## 6) waistcirc< 104.8 18 330.60 38.09
## 12) hipcirc< 109.9 9 69.00 34.38 *
## 13) hipcirc>=109.9 9 13.08 41.81 *
## 7) waistcirc>=104.8 7 404.30 49.73 *
plot(bodyfat_prune)
text(bodyfat_prune, use.n = TRUE)
## trainData에 대하여 설정한 Tree Model을 이용하여 testData에 적합하여
## 예측한 결과를 plot을 통해 확인해 본다.
DEXfat_pred <- predict(bodyfat_prune, newdata = bodyfat.test)
xlim <- range(bodyfat$DEXfat)
plot(DEXfat_pred ~ DEXfat, data = bodyfat.test, xlab = "Observed", ylab = "Predicted",
ylim = xlim, xlim = xlim)
abline(a = 0, b = 1)