이 예제에서는 트리 모형을 중심으로 코드를 소개한다.
따라서 아래의 패키지를 설치할 필요가 있다.
# Packages this walkthrough needs. "TH.data" is listed because it distributes
# the bodyfat data set (recent mboost releases no longer bundle it).
required_pkgs <- c("party", "rpart", "randomForest", "TH.data")
# Install only what is missing so that re-running the script does not
# reinstall packages that are already available.
is_installed <- vapply(
  required_pkgs, requireNamespace, logical(1),
  quietly = TRUE
)
missing_pkgs <- required_pkgs[!is_installed]
if (length(missing_pkgs) > 0) {
  install.packages(missing_pkgs)
}
설명도 위 패키지 순서(party, rpart, randomForest)대로 진행된다.
Conditional Inference Tree를 이용한 분류
# party provides ctree(), used below to fit the conditional inference tree
library(party)
# Show the structure (column types and first values) of the iris data set
str(iris)
## 'data.frame': 150 obs. of 5 variables:
## $ Sepal.Length: num 5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
## $ Sepal.Width : num 3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...
## $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
## $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
## $ Species : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...
# Fix the RNG seed so the random train/test split is reproducible.
# NOTE(review): the console output shown in this document appears to come from
# an unseeded run, so exact counts may differ from this seeded split.
set.seed(1234)
# Label every row 1 (training, probability 0.7) or 2 (test, probability 0.3)
ind <- sample(2, nrow(iris), replace = TRUE, prob = c(0.7, 0.3))
# Split into training and test sets according to the labels
trainData <- iris[ind == 1, ]
testData <- iris[ind == 2, ]
# Model formula: predict Species from the four sepal/petal measurements
myFormula <- Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width
# Fit a conditional inference tree on the training set
iris_ctree <- ctree(myFormula, data = trainData)
# Cross-tabulate the tree's predictions (rows) against the true species
# (columns) of the training data
table(predict(iris_ctree), trainData$Species)
##
## setosa versicolor virginica
## setosa 39 0 0
## versicolor 0 27 2
## virginica 0 1 39
# Print the fitted tree: response, inputs, and the split rules per node
print(iris_ctree)
##
## Conditional inference tree with 4 terminal nodes
##
## Response: Species
## Inputs: Sepal.Length, Sepal.Width, Petal.Length, Petal.Width
## Number of observations: 108
##
## 1) Petal.Length <= 1.9; criterion = 1, statistic = 102.182
## 2)* weights = 39
## 1) Petal.Length > 1.9
## 3) Petal.Width <= 1.6; criterion = 1, statistic = 45.874
## 4) Petal.Length <= 4.6; criterion = 0.993, statistic = 9.826
## 5)* weights = 22
## 4) Petal.Length > 4.6
## 6)* weights = 7
## 3) Petal.Width > 1.6
## 7)* weights = 40
Conditional Inference Tree
# Draw the fitted tree with the default plot type
plot(iris_ctree)
# Draw the same tree again using the "simple" plot type
plot(iris_ctree, type = "simple")
테스트 데이터 예측
# Predict the species of the held-out test rows with the fitted tree
testPred <- predict(iris_ctree, newdata = testData)
# Confusion matrix: predictions (rows) against true test labels (columns)
table(testPred, testData$Species)
##
## testPred setosa versicolor virginica
## setosa 11 0 0
## versicolor 0 21 2
## virginica 0 1 7
rpart를 이용한 회귀 트리 - 전처리, EDA
# Load the bodyfat data set. It is distributed in the TH.data package; recent
# mboost releases no longer ship it, so package = "mboost" does not find it.
data("bodyfat", package = "TH.data")
# Number of rows and columns
dim(bodyfat)
## [1] 71 10
# List the data frame's attributes (column names, row names, class)
attributes(bodyfat)
## $names
## [1] "age" "DEXfat" "waistcirc" "hipcirc"
## [5] "elbowbreadth" "kneebreadth" "anthro3a" "anthro3b"
## [9] "anthro3c" "anthro4"
##
## $row.names
## [1] "47" "48" "49" "50" "51" "52" "53" "54" "55" "56" "57"
## [12] "58" "59" "60" "61" "62" "63" "64" "65" "66" "67" "68"
## [23] "69" "70" "71" "72" "73" "74" "75" "76" "77" "78" "79"
## [34] "80" "81" "82" "83" "84" "85" "86" "87" "88" "89" "90"
## [45] "91" "92" "93" "94" "95" "96" "97" "98" "99" "100" "101"
## [56] "102" "103" "104" "105" "106" "107" "108" "109" "110" "111" "112"
## [67] "113" "114" "115" "116" "117"
##
## $class
## [1] "data.frame"
# Show the first five rows of the data
head(bodyfat, n = 5)
## age DEXfat waistcirc hipcirc elbowbreadth kneebreadth anthro3a anthro3b
## 47 57 41.68 100.0 112.0 7.1 9.4 4.42 4.95
## 48 65 43.29 99.5 116.5 6.5 8.9 4.63 5.01
## 49 59 35.41 96.0 108.5 6.2 8.9 4.12 4.74
## 50 58 22.79 72.0 96.5 6.1 9.2 4.03 4.48
## 51 60 36.42 89.5 100.5 7.1 10.0 4.24 4.68
## anthro3c anthro4
## 47 4.50 6.13
## 48 4.48 6.37
## 49 4.60 5.82
## 50 3.91 5.66
## 51 4.15 5.91
# Seed the RNG so the random partition below is repeatable
set.seed(123)
# Tag every row with 1 (training, probability 0.7) or 2 (test, probability
# 0.3), then partition the data by tag
ind <- sample(
  2, nrow(bodyfat),
  replace = TRUE, prob = c(0.7, 0.3)
)
bodyfat.train <- bodyfat[which(ind == 1), ]
bodyfat.test <- bodyfat[which(ind == 2), ]
library(rpart)
# Regress DEXfat on age and the four body measurements
myFormula <- DEXfat ~ age + waistcirc + hipcirc + elbowbreadth + kneebreadth
# Grow the tree on the training rows; minsplit = 10 is set via rpart.control
bodyfat_rpart <- rpart(
  formula = myFormula,
  data = bodyfat.train,
  control = rpart.control(minsplit = 10)
)
트리 플로팅
# Draw the tree skeleton
plot(bodyfat_rpart)
# Add the split and leaf labels to the plot (called with use.n = TRUE)
text(bodyfat_rpart, use.n = TRUE)
# Gives a visual representation of the cross-validation results in an
# rpart object.
plotcp(bodyfat_rpart)
교차검증 오차(xerror)를 최소화하는 최적 CP 값을 구한 뒤 가지치기한다.
# Take the row of the CP table with the smallest cross-validated error
# (xerror) and use its CP value to prune the tree, then print the result
cp_table <- bodyfat_rpart$cptable
opt <- which.min(cp_table[, "xerror"])
cp <- cp_table[opt, "CP"]
bodyfat_prune <- prune(bodyfat_rpart, cp = cp)
print(bodyfat_prune)
## n= 48
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 48 6298.00 31.20
## 2) waistcirc< 88.4 28 686.90 23.10
## 4) waistcirc< 71.5 6 77.12 16.50 *
## 5) waistcirc>=71.5 22 277.10 24.90
## 10) hipcirc< 99.65 12 68.28 22.87 *
## 11) hipcirc>=99.65 10 100.40 27.33 *
## 3) waistcirc>=88.4 20 1199.00 42.55
## 6) kneebreadth< 11.1 17 370.60 40.09
## 12) hipcirc< 109.9 7 75.37 35.72 *
## 13) hipcirc>=109.9 10 67.74 43.15 *
## 7) kneebreadth>=11.1 3 146.30 56.45 *
# Predict DEXfat for the held-out test rows with the pruned tree
DEXfat_pred <- predict(bodyfat_prune, newdata = bodyfat.test)
# Range of DEXfat over the full data set; reused below as both axis limits
xlim <- range(bodyfat$DEXfat)
예측값과 실제 관측값을 플로팅한다.
# Scatter the predictions against the observed test values, using the same
# limits on both axes
plot(
  x = bodyfat.test$DEXfat,
  y = DEXfat_pred,
  xlab = "Observed", ylab = "Predicted",
  xlim = xlim, ylim = xlim
)
# Reference line y = x: points on it are predicted exactly
abline(a = 0, b = 1)
Random Forest를 이용한 분류: 전처리 - 학습셋과 테스트셋 분리
모델 생성
# Fix the RNG seed so this split (and the random forest fitted after it) is
# reproducible.
# NOTE(review): the console output shown in this document appears to come from
# an unseeded run, so exact counts may differ from this seeded split.
set.seed(4321)
# Label every row 1 (training, probability 0.7) or 2 (test, probability 0.3)
ind <- sample(2, nrow(iris), replace = TRUE, prob = c(0.7, 0.3))
# Split into training and test sets according to the labels
trainData <- iris[ind == 1, ]
testData <- iris[ind == 2, ]
library(randomForest)
# Fit a 100-tree random forest predicting Species from all other columns;
# proximity = TRUE additionally requests the proximity measure
rf <- randomForest(Species ~ ., data = trainData, ntree = 100, proximity = TRUE)
# Cross-tabulate predict(rf) (rows) against the true training labels (columns).
# NOTE(review): with no newdata, predict() on a randomForest presumably returns
# the out-of-bag predictions - confirm against the randomForest documentation.
table(predict(rf), trainData$Species)
##
## setosa versicolor virginica
## setosa 40 0 0
## versicolor 0 35 3
## virginica 0 2 32
# Print the forest summary: call, forest type, number of trees, OOB error
# estimate and confusion matrix
print(rf)
##
## Call:
## randomForest(formula = Species ~ ., data = trainData, ntree = 100, proximity = TRUE)
## Type of random forest: classification
## Number of trees: 100
## No. of variables tried at each split: 2
##
## OOB estimate of error rate: 4.46%
## Confusion matrix:
## setosa versicolor virginica class.error
## setosa 40 0 0 0.00000
## versicolor 0 35 2 0.05405
## virginica 0 3 32 0.08571
# List the components stored in the fitted randomForest object and its class
attributes(rf)
## $names
## [1] "call" "type" "predicted"
## [4] "err.rate" "confusion" "votes"
## [7] "oob.times" "classes" "importance"
## [10] "importanceSD" "localImportance" "proximity"
## [13] "ntree" "mtry" "forest"
## [16] "y" "test" "inbag"
## [19] "terms"
##
## $class
## [1] "randomForest.formula" "randomForest"
# Error rates, one row per tree: OOB overall plus one column per class
rf$err.rate
## OOB setosa versicolor virginica
## [1,] 0.08889 0 0.07143 0.20000
## [2,] 0.06061 0 0.04762 0.16667
## [3,] 0.04878 0 0.07692 0.08696
## [4,] 0.04255 0 0.06897 0.07143
## [5,] 0.03846 0 0.06061 0.06061
## [6,] 0.04673 0 0.05882 0.08824
## [7,] 0.04587 0 0.05556 0.08824
## [8,] 0.05405 0 0.05405 0.11765
## [9,] 0.05405 0 0.05405 0.11765
## [10,] 0.05405 0 0.08108 0.08824
## [11,] 0.05357 0 0.08108 0.08571
## [12,] 0.04464 0 0.05405 0.08571
## [13,] 0.05357 0 0.08108 0.08571
## [14,] 0.05357 0 0.08108 0.08571
## [15,] 0.05357 0 0.08108 0.08571
## [16,] 0.05357 0 0.08108 0.08571
## [17,] 0.05357 0 0.08108 0.08571
## [18,] 0.05357 0 0.08108 0.08571
## [19,] 0.05357 0 0.08108 0.08571
## [20,] 0.05357 0 0.08108 0.08571
## [21,] 0.05357 0 0.08108 0.08571
## [22,] 0.05357 0 0.08108 0.08571
## [23,] 0.05357 0 0.08108 0.08571
## [24,] 0.05357 0 0.08108 0.08571
## [25,] 0.05357 0 0.08108 0.08571
## [26,] 0.05357 0 0.08108 0.08571
## [27,] 0.05357 0 0.08108 0.08571
## [28,] 0.06250 0 0.08108 0.11429
## [29,] 0.06250 0 0.08108 0.11429
## [30,] 0.06250 0 0.08108 0.11429
## [31,] 0.05357 0 0.08108 0.08571
## [32,] 0.05357 0 0.08108 0.08571
## [33,] 0.05357 0 0.08108 0.08571
## [34,] 0.05357 0 0.08108 0.08571
## [35,] 0.05357 0 0.08108 0.08571
## [36,] 0.04464 0 0.05405 0.08571
## [37,] 0.05357 0 0.08108 0.08571
## [38,] 0.04464 0 0.05405 0.08571
## [39,] 0.04464 0 0.05405 0.08571
## [40,] 0.04464 0 0.05405 0.08571
## [41,] 0.04464 0 0.05405 0.08571
## [42,] 0.05357 0 0.08108 0.08571
## [43,] 0.05357 0 0.08108 0.08571
## [44,] 0.05357 0 0.08108 0.08571
## [45,] 0.05357 0 0.08108 0.08571
## [46,] 0.05357 0 0.08108 0.08571
## [47,] 0.04464 0 0.05405 0.08571
## [48,] 0.05357 0 0.08108 0.08571
## [49,] 0.04464 0 0.05405 0.08571
## [50,] 0.04464 0 0.05405 0.08571
## [51,] 0.05357 0 0.08108 0.08571
## [52,] 0.05357 0 0.08108 0.08571
## [53,] 0.05357 0 0.08108 0.08571
## [54,] 0.04464 0 0.05405 0.08571
## [55,] 0.04464 0 0.05405 0.08571
## [56,] 0.05357 0 0.08108 0.08571
## [57,] 0.05357 0 0.08108 0.08571
## [58,] 0.05357 0 0.08108 0.08571
## [59,] 0.04464 0 0.05405 0.08571
## [60,] 0.04464 0 0.05405 0.08571
## [61,] 0.04464 0 0.05405 0.08571
## [62,] 0.04464 0 0.05405 0.08571
## [63,] 0.04464 0 0.05405 0.08571
## [64,] 0.04464 0 0.05405 0.08571
## [65,] 0.05357 0 0.08108 0.08571
## [66,] 0.05357 0 0.08108 0.08571
## [67,] 0.04464 0 0.05405 0.08571
## [68,] 0.04464 0 0.05405 0.08571
## [69,] 0.05357 0 0.08108 0.08571
## [70,] 0.05357 0 0.08108 0.08571
## [71,] 0.05357 0 0.08108 0.08571
## [72,] 0.04464 0 0.05405 0.08571
## [73,] 0.04464 0 0.05405 0.08571
## [74,] 0.04464 0 0.05405 0.08571
## [75,] 0.04464 0 0.05405 0.08571
## [76,] 0.04464 0 0.05405 0.08571
## [77,] 0.04464 0 0.05405 0.08571
## [78,] 0.04464 0 0.05405 0.08571
## [79,] 0.04464 0 0.05405 0.08571
## [80,] 0.04464 0 0.05405 0.08571
## [81,] 0.04464 0 0.05405 0.08571
## [82,] 0.04464 0 0.05405 0.08571
## [83,] 0.04464 0 0.05405 0.08571
## [84,] 0.04464 0 0.05405 0.08571
## [85,] 0.04464 0 0.05405 0.08571
## [86,] 0.04464 0 0.05405 0.08571
## [87,] 0.04464 0 0.05405 0.08571
## [88,] 0.04464 0 0.05405 0.08571
## [89,] 0.04464 0 0.05405 0.08571
## [90,] 0.04464 0 0.05405 0.08571
## [91,] 0.04464 0 0.05405 0.08571
## [92,] 0.04464 0 0.05405 0.08571
## [93,] 0.04464 0 0.05405 0.08571
## [94,] 0.04464 0 0.05405 0.08571
## [95,] 0.04464 0 0.05405 0.08571
## [96,] 0.04464 0 0.05405 0.08571
## [97,] 0.04464 0 0.05405 0.08571
## [98,] 0.04464 0 0.05405 0.08571
## [99,] 0.04464 0 0.05405 0.08571
## [100,] 0.04464 0 0.05405 0.08571
트리 개수에 따른 오차율 플로팅 및 변수 중요도 확인 - 변수 선택을 위함
# Plot the forest's error rates as trees are added
plot(rf)
# Variable importance table (MeanDecreaseGini per predictor)
importance(rf)
## MeanDecreaseGini
## Sepal.Length 7.365
## Sepal.Width 1.582
## Petal.Length 29.313
## Petal.Width 35.675
# Plot the variable importance measures of the forest
varImpPlot(rf)
테스트 데이터 예측 및 마진 플로팅 (마진 = 정답 클래스를 예측한 트리의 비율 - 오답 클래스를 예측한 트리의 최대 비율)
# Predict the species of the held-out test rows with the random forest
irisPred <- predict(rf, newdata = testData)
# Confusion matrix: predictions (rows) against true test labels (columns)
table(irisPred, testData$Species)
##
## irisPred setosa versicolor virginica
## setosa 10 0 0
## versicolor 0 12 0
## virginica 0 1 15
# Plot the margin of each training observation: the share of trees voting for
# the correct class minus the share voting for a wrong class
plot(margin(rf, trainData$Species))