library(rpart)
(m <- rpart(Species ~ ., data = iris))
## n= 150
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 150 100 setosa (0.33333 0.33333 0.33333)
## 2) Petal.Length< 2.45 50 0 setosa (1.00000 0.00000 0.00000) *
## 3) Petal.Length>=2.45 100 50 versicolor (0.00000 0.50000 0.50000)
## 6) Petal.Width< 1.75 54 5 versicolor (0.00000 0.90741 0.09259) *
## 7) Petal.Width>=1.75 46 1 virginica (0.00000 0.02174 0.97826) *
n=150은 150개의 데이터가 있음을 의미한다.
들여쓰기는 가지의 갈라지는 모양을 뜻한다.
'*'는 입사귀 노드를 의미한다.
트리의 최상단은 root노드로 위의 결과에는 '1)'로 표시된다. 괄호안은 iris의 Species별 비율을 의미한다.
모델을 좀 더 쉽게 보기 위해 plot()을 사용해 트리를 그려보자.
plot(m, compress = TRUE, margin = 0.2)
text(m, cex = 1.5)
트리를 좀 더 잘 보이게 조절하기 위한 인자들이 있는데, compress는 나무를 좀 더 조밀하게 그린 것이고, margin은 여백, cex는 글자의 크기를 뜻한다.
다른 방법의 시각화 방법을 알아보자. rpart.plot이라는 패키지를 사용하며, prp()를 통해 시각화를 제공한다.
library(rpart.plot)
prp(m, type = 0, extra = 2)
prp(m, type = 1, extra = 2)
prp(m, type = 2, extra = 2)
prp(m, type = 3, extra = 2)
prp(m, type = 4, extra = 2)
type의 종류는 0부터 4까지 있다.
extra는 “extra information"을 나타내는 인자로 1~9까지 있다.
rpart()를 사용한 예측 역시 predict()를 통해 쉽게 구할 수 있다.
head(predict(m, newdata = iris, type = "class"))
## 1 2 3 4 5 6
## setosa setosa setosa setosa setosa setosa
## Levels: setosa versicolor virginica
분류를 얻을때는 type="class"를 지정해야하지만, 기본 값이 class이므로 생략해도 된다라는 내용은 이 전에 공부한 내용이다.
library(party)
## Warning: package 'party' was built under R version 3.0.3
## Loading required package: grid
## Loading required package: zoo
## Warning: package 'zoo' was built under R version 3.0.3
##
## Attaching package: 'zoo'
##
## The following objects are masked from 'package:base':
##
## as.Date, as.Date.numeric
##
## Loading required package: sandwich
## Warning: package 'sandwich' was built under R version 3.0.3
## Loading required package: strucchange
## Warning: package 'strucchange' was built under R version 3.0.3
## Loading required package: modeltools
## Loading required package: stats4
(m <- ctree(Species ~ ., data = iris))
##
## Conditional inference tree with 4 terminal nodes
##
## Response: Species
## Inputs: Sepal.Length, Sepal.Width, Petal.Length, Petal.Width
## Number of observations: 150
##
## 1) Petal.Length <= 1.9; criterion = 1, statistic = 140.264
## 2)* weights = 50
## 1) Petal.Length > 1.9
## 3) Petal.Width <= 1.7; criterion = 1, statistic = 67.894
## 4) Petal.Length <= 4.8; criterion = 0.999, statistic = 13.865
## 5)* weights = 46
## 4) Petal.Length > 4.8
## 6)* weights = 8
## 3) Petal.Width > 1.7
## 7)* weights = 46
plot(m)
잎사귀 노드(맨 마지막)의 최종 분류 결과가 무엇이었는지, 또 만약 잘못 분류된 경우가 있다면 그 정도가 어느 정도인지 알려주어 모델을 개선하는데 도움을 준다.
예를 들어 Node 6에 총 8개의 결과가 있는데 이때 두개 Species가 거의 동일한 숫자로 나타난 것을 알 수 있다.
따라서 이 경우를 개선하기 위한 모델을 개발한다면 성능을 더 향상시킬 수 있다.
library(randomForest)
## Warning: package 'randomForest' was built under R version 3.0.3
## randomForest 4.6-7
## Type rfNews() to see new features/changes/bug fixes.
m <- randomForest(Species ~ ., data = iris)
m
##
## Call:
## randomForest(formula = Species ~ ., data = iris)
## Type of random forest: classification
## Number of trees: 500
## No. of variables tried at each split: 2
##
## OOB estimate of error rate: 4.67%
## Confusion matrix:
## setosa versicolor virginica class.error
## setosa 50 0 0 0.00
## versicolor 0 47 3 0.06
## virginica 0 4 46 0.08
모델을 출력하면 모델 훈련에 사용되지 않은 데이터를 사용한 에러 추정치를 볼 수 있다.
예측은 Generic Function인 predict()를 사용해 수행한다.
head(predict(m, newdata = iris))
## 1 2 3 4 5 6
## setosa setosa setosa setosa setosa setosa
## Levels: setosa versicolor virginica
m <- randomForest(iris[, 1:4], iris[, 5])
m
##
## Call:
## randomForest(x = iris[, 1:4], y = iris[, 5])
## Type of random forest: classification
## Number of trees: 500
## No. of variables tried at each split: 2
##
## OOB estimate of error rate: 4.67%
## Confusion matrix:
## setosa versicolor virginica class.error
## setosa 50 0 0 0.00
## versicolor 0 47 3 0.06
## virginica 0 4 46 0.08
m <- randomForest(Species ~ ., data = iris, importance = TRUE)
importance(m)
## setosa versicolor virginica MeanDecreaseAccuracy
## Sepal.Length 5.017 7.4838 7.779 10.091
## Sepal.Width 3.911 0.1756 3.635 4.013
## Petal.Length 22.805 34.3148 30.100 34.949
## Petal.Width 21.821 29.9509 29.611 31.440
## MeanDecreaseGini
## Sepal.Length 8.765
## Sepal.Width 1.893
## Petal.Length 45.761
## Petal.Width 42.898
varImpPlot(m)
Accuracy측면에서는 Petal.Length, Petal.Width, Sepal.Length, Sepal.Width순으로 변수가 중요함을 알 수 있다.
Gini측면에서는 Petal.Width, Petal.Length, Sepal.Length, Sepal.Width 순으로 중요 했다.
expand.grid를 사용해 가능한 조합의 목록을 만들 수 있다.
(grid <- expand.grid(ntree = c(10, 100, 200), mrty = c(3, 4)))
## ntree mrty
## 1 10 3
## 2 100 3
## 3 200 3
## 4 10 4
## 5 100 4
## 6 200 4
이 파라미터 조합을 10개로 분할한 데이터(즉, K=10로 한다는 소리)에 적용하여 모델의 성능을 평가하는 일을 R회 반복하면 교차 검증을 사용한 파라미터를 찾을 수 있게 된다.
library(cvTools)
## Warning: package 'cvTools' was built under R version 3.0.3
## Loading required package: lattice
## Warning: package 'lattice' was built under R version 3.0.3
## Loading required package: robustbase
## Warning: package 'robustbase' was built under R version 3.0.3
library(foreach)
library(randomForest)
cvTools는 cvFolds를 사용하기 위함이고, cvFolds는 난수를 사용하여 데이터를 분리한다.
foreach는 값을 반환하여 결과를 한번에 모으는 데 사용하고, .combine을 사용해 결과가 리스트가 아닌 다른 형태로 되게한다.
randomForest()는 모델의 성능과 변수의 중요도를 평가하기 위해 사용되었다.
set.seed(719)
K = 10
R = 3
cv <- cvFolds(NROW(iris), K = K, R = R)
cvFolds의 첫번째 인자는 그룹으로 나눠질 관찰값들의 갯수를 의미한다.
K는 몇개의 데이터 그룹(Folds)으로 나눌 것인지를 의미한다.
R은 K-fold cross-validation을 몇 번 반복할 것인지를 의미한다.
grid <- expand.grid(ntree = c(10, 100, 200), mtry = c(3, 4))
그리고 가능한 조합을 만든다.
result <- foreach(g = 1:NROW(grid), .combine = rbind) %do% {
foreach(r = 1:R, .combine = rbind) %do% {
foreach(k = 1:K, .combine = rbind) %do% {
validation_idx <- cv$subsets[which(cv$which == k), r]
train <- iris[-validation_idx, ]
validation <- iris[validation_idx, ]
# training
m <- randomForest(Species ~ ., data = train, ntree = grid[g, "ntree"],
mtry = grid[g, "mtry"])
# prediction
predicted <- predict(m, newdata = validation)
# estimating performance
precision <- sum(predicted == validation$Species)/NROW(predicted)
return(data.frame(g = g, precision = precision))
}
}
}
cv$which는 몇번째 Fold인지를 나타내주므로 which(cv$which == k)는 는 k번째 Fold의 인덱스를 알려준다. 그리고 r은 몇번째 반복인지를 나타내므로, cv$subsets[which(cv$which == k), r]는 반복 1의 k번째 Folds의 관찰값 인덱스를 반환한다.
train데이터는 검증 데이터와 평가 데이터를 떼어놓은 뒤, 이를 제외한 나머지 데이터로, 모델을 생성한다.
validation데이터는 검증데이터로, 예측한 값과 validation의 값을 비교해 모델의 분류 성능을 검증하기 위해 사용된다.
foreach에서는 반환값이 리스트 형태이므로, 데이터를 데이터 프레임으로 모으기 위해 rbind를 사용하였다.
result
## g precision
## 1 1 1.0000
## 2 1 1.0000
## 3 1 0.9333
## 4 1 0.9333
## 5 1 1.0000
## 6 1 1.0000
## 7 1 0.8000
## 8 1 0.9333
## 9 1 0.9333
## 10 1 0.9333
## 11 1 1.0000
## 12 1 0.9333
## 13 1 0.7333
## 14 1 1.0000
## 15 1 1.0000
## 16 1 1.0000
## 17 1 0.9333
## 18 1 1.0000
## 19 1 0.9333
## 20 1 1.0000
## 21 1 0.9333
## 22 1 0.9333
## 23 1 1.0000
## 24 1 1.0000
## 25 1 0.9333
## 26 1 0.8667
## 27 1 0.9333
## 28 1 0.8000
## 29 1 0.9333
## 30 1 1.0000
## 31 2 1.0000
## 32 2 1.0000
## 33 2 0.9333
## 34 2 0.9333
## 35 2 1.0000
## 36 2 1.0000
## 37 2 0.8000
## 38 2 0.9333
## 39 2 0.9333
## 40 2 0.9333
## 41 2 1.0000
## 42 2 0.9333
## 43 2 0.7333
## 44 2 1.0000
## 45 2 1.0000
## 46 2 1.0000
## 47 2 0.9333
## 48 2 1.0000
## 49 2 1.0000
## 50 2 1.0000
## 51 2 0.9333
## 52 2 0.9333
## 53 2 1.0000
## 54 2 1.0000
## 55 2 0.9333
## 56 2 0.9333
## 57 2 0.9333
## 58 2 0.9333
## 59 2 0.9333
## 60 2 1.0000
## 61 3 1.0000
## 62 3 1.0000
## 63 3 0.9333
## 64 3 0.9333
## 65 3 1.0000
## 66 3 1.0000
## 67 3 0.8667
## 68 3 1.0000
## 69 3 0.9333
## 70 3 0.9333
## 71 3 1.0000
## 72 3 0.9333
## 73 3 0.6667
## 74 3 1.0000
## 75 3 1.0000
## 76 3 1.0000
## 77 3 0.9333
## 78 3 1.0000
## 79 3 0.9333
## 80 3 1.0000
## 81 3 0.9333
## 82 3 0.9333
## 83 3 1.0000
## 84 3 1.0000
## 85 3 0.9333
## 86 3 0.9333
## 87 3 0.9333
## 88 3 0.9333
## 89 3 0.9333
## 90 3 1.0000
## 91 4 1.0000
## 92 4 1.0000
## 93 4 0.9333
## 94 4 0.9333
## 95 4 1.0000
## 96 4 1.0000
## 97 4 0.8667
## 98 4 0.9333
## 99 4 0.9333
## 100 4 0.9333
## 101 4 1.0000
## 102 4 0.9333
## 103 4 0.7333
## 104 4 1.0000
## 105 4 1.0000
## 106 4 1.0000
## 107 4 0.9333
## 108 4 1.0000
## 109 4 1.0000
## 110 4 1.0000
## 111 4 0.9333
## 112 4 0.9333
## 113 4 1.0000
## 114 4 1.0000
## 115 4 0.9333
## 116 4 0.9333
## 117 4 0.9333
## 118 4 0.9333
## 119 4 0.9333
## 120 4 1.0000
## 121 5 1.0000
## 122 5 1.0000
## 123 5 0.9333
## 124 5 0.9333
## 125 5 1.0000
## 126 5 1.0000
## 127 5 0.8000
## 128 5 1.0000
## 129 5 0.9333
## 130 5 0.9333
## 131 5 1.0000
## 132 5 0.9333
## 133 5 0.7333
## 134 5 1.0000
## 135 5 1.0000
## 136 5 1.0000
## 137 5 0.9333
## 138 5 1.0000
## 139 5 0.9333
## 140 5 1.0000
## 141 5 0.9333
## 142 5 0.9333
## 143 5 1.0000
## 144 5 1.0000
## 145 5 0.9333
## 146 5 0.9333
## 147 5 0.9333
## 148 5 0.9333
## 149 5 0.9333
## 150 5 1.0000
## 151 6 1.0000
## 152 6 1.0000
## 153 6 0.9333
## 154 6 0.9333
## 155 6 1.0000
## 156 6 1.0000
## 157 6 0.8000
## 158 6 1.0000
## 159 6 0.9333
## 160 6 0.9333
## 161 6 1.0000
## 162 6 0.9333
## 163 6 0.7333
## 164 6 1.0000
## 165 6 1.0000
## 166 6 1.0000
## 167 6 0.9333
## 168 6 1.0000
## 169 6 1.0000
## 170 6 1.0000
## 171 6 0.9333
## 172 6 0.9333
## 173 6 1.0000
## 174 6 1.0000
## 175 6 0.9333
## 176 6 0.9333
## 177 6 0.9333
## 178 6 0.9333
## 179 6 0.9333
## 180 6 1.0000
이를 g값마다 묶어 평균을 구한다.
library(plyr)
##
## Attaching package: 'plyr'
##
## The following object is masked from 'package:modeltools':
##
## empty
ddply(result, .(g), summarize, mean_precision = mean(precision))
## g mean_precision
## 1 1 0.9444
## 2 2 0.9533
## 3 3 0.9533
## 4 4 0.9556
## 5 5 0.9533
## 6 6 0.9556
ddply()수행결과 가장 높은 성능을 보인 조합은 4번째 와 6번째이다.
grid[c(4, 6), ]
## ntree mtry
## 4 10 4
## 6 200 4