Classification

이 예제에서는 트리 기반 모형을 중심으로 코드를 소개한다.
따라서 아래의 패키지를 설치해야 한다.

# Packages used in this tutorial. TH.data supplies the `bodyfat` data set
# used in the rpart (regression tree) section.
install.packages(c("party", "rpart", "randomForest", "TH.data"))

설명도 같은 순서(party → rpart → randomForest)로 진행한다.

party

Conditional Inference Tree를 이용한 분류

library("party")

# Structure of the built-in iris data: 150 rows, four numeric measurements
# and the Species factor.
str(object = iris)
## 'data.frame':    150 obs. of  5 variables:
##  $ Sepal.Length: num  5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
##  $ Sepal.Width : num  3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...
##  $ Petal.Length: num  1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
##  $ Petal.Width : num  0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
##  $ Species     : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...

# Fix the RNG seed so the random train/test split below is reproducible.
# NOTE(review): the printed outputs in this section were recorded from an
# unseeded run, so exact counts may differ when the code is re-run.
set.seed(1234)

# Randomly assign each row to the training (1) or test (2) set, ~70% / 30%
ind <- sample(2, nrow(iris), replace = TRUE, prob = c(0.7, 0.3))

# split test, train set
trainData <- iris[ind == 1, ]
testData <- iris[ind == 2, ]

# variable relation: Species explained by the four flower measurements
myFormula <- Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width

# Fit a conditional inference tree on the training rows
iris_ctree <- ctree(myFormula, data = trainData)

# Training-set confusion matrix (rows: predicted, columns: actual)
table(predict(iris_ctree), trainData$Species)
##             
##              setosa versicolor virginica
##   setosa         39          0         0
##   versicolor      0         27         2
##   virginica       0          1        39

# Show the fitted tree's split rules and terminal nodes
print(x = iris_ctree)
## 
##   Conditional inference tree with 4 terminal nodes
## 
## Response:  Species 
## Inputs:  Sepal.Length, Sepal.Width, Petal.Length, Petal.Width 
## Number of observations:  108 
## 
## 1) Petal.Length <= 1.9; criterion = 1, statistic = 102.182
##   2)*  weights = 39 
## 1) Petal.Length > 1.9
##   3) Petal.Width <= 1.6; criterion = 1, statistic = 45.874
##     4) Petal.Length <= 4.6; criterion = 0.993, statistic = 9.826
##       5)*  weights = 22 
##     4) Petal.Length > 4.6
##       6)*  weights = 7 
##   3) Petal.Width > 1.6
##     7)*  weights = 40

Conditional Inference Tree

# Default (extended) plot of the conditional inference tree
plot(x = iris_ctree)

plot of chunk unnamed-chunk-1


# Compact variant of the same tree plot
plot(x = iris_ctree, type = "simple")

plot of chunk unnamed-chunk-1

테스트 데이터 예측

# Predict species for the held-out test rows
testPred <- predict(iris_ctree, newdata = testData)

# Test-set confusion matrix (rows: predicted, columns: actual)
table(testPred, testData[["Species"]])
##             
## testPred     setosa versicolor virginica
##   setosa         11          0         0
##   versicolor      0         21         2
##   virginica       0          1         7

rpart를 활용한 예측 (회귀 트리: 연속형 변수 DEXfat 예측)

전처리, EDA

# Load the bodyfat data set. It is distributed in the TH.data package;
# current mboost releases no longer include it, so package = "mboost" fails.
data("bodyfat", package = "TH.data")

# Dimensions of the data set (rows, columns)
c(nrow(bodyfat), ncol(bodyfat))
## [1] 71 10

# Column names, row names and class of the data frame
print(attributes(bodyfat))
## $names
##  [1] "age"          "DEXfat"       "waistcirc"    "hipcirc"     
##  [5] "elbowbreadth" "kneebreadth"  "anthro3a"     "anthro3b"    
##  [9] "anthro3c"     "anthro4"     
## 
## $row.names
##  [1] "47"  "48"  "49"  "50"  "51"  "52"  "53"  "54"  "55"  "56"  "57" 
## [12] "58"  "59"  "60"  "61"  "62"  "63"  "64"  "65"  "66"  "67"  "68" 
## [23] "69"  "70"  "71"  "72"  "73"  "74"  "75"  "76"  "77"  "78"  "79" 
## [34] "80"  "81"  "82"  "83"  "84"  "85"  "86"  "87"  "88"  "89"  "90" 
## [45] "91"  "92"  "93"  "94"  "95"  "96"  "97"  "98"  "99"  "100" "101"
## [56] "102" "103" "104" "105" "106" "107" "108" "109" "110" "111" "112"
## [67] "113" "114" "115" "116" "117"
## 
## $class
## [1] "data.frame"

# First five rows
head(bodyfat, n = 5)
##    age DEXfat waistcirc hipcirc elbowbreadth kneebreadth anthro3a anthro3b
## 47  57  41.68     100.0   112.0          7.1         9.4     4.42     4.95
## 48  65  43.29      99.5   116.5          6.5         8.9     4.63     5.01
## 49  59  35.41      96.0   108.5          6.2         8.9     4.12     4.74
## 50  58  22.79      72.0    96.5          6.1         9.2     4.03     4.48
## 51  60  36.42      89.5   100.5          7.1        10.0     4.24     4.68
##    anthro3c anthro4
## 47     4.50    6.13
## 48     4.48    6.37
## 49     4.60    5.82
## 50     3.91    5.66
## 51     4.15    5.91

# Fixed seed: makes both the random split below and rpart's internal
# cross-validation reproducible.
set.seed(123)

# divide to train, test sets: each row goes to set 1 (~70%) or set 2 (~30%)
ind <- sample(x = 2, size = nrow(bodyfat), replace = TRUE,
              prob = c(0.7, 0.3))
bodyfat.train <- bodyfat[ind == 1, ]
bodyfat.test <- bodyfat[ind == 2, ]


library(rpart)

# Regression tree: DEXfat explained by age and four body measurements
myFormula <- DEXfat ~ age + waistcirc + hipcirc + elbowbreadth + kneebreadth

# minsplit = 10: a node needs at least 10 observations before a split is tried
bodyfat_rpart <- rpart(myFormula, data = bodyfat.train,
                       control = rpart.control(minsplit = 10))

트리 플로팅

# Draw the tree skeleton, then label splits and leaves;
# use.n = TRUE adds the number of observations to each leaf label.
plot(x = bodyfat_rpart)

text(x = bodyfat_rpart, use.n = TRUE)

plot of chunk unnamed-chunk-2

# Visualise the cross-validation results stored in the rpart fit:
# cross-validated relative error for each complexity parameter (cp) value.
plotcp(x = bodyfat_rpart)

plot of chunk unnamed-chunk-3

교차검증 오차(xerror)를 최소화하는 최적 CP(complexity parameter) 값을 구하고, 그 값으로 트리를 가지치기(prune)한다.

# Row of the cp table with the smallest cross-validated error (xerror);
# which.min picks the first such row on ties.
cp_table <- bodyfat_rpart$cptable
opt <- which.min(cp_table[, "xerror"])
cp <- cp_table[opt, "CP"]

# Prune the tree back to that complexity parameter
bodyfat_prune <- prune(bodyfat_rpart, cp = cp)

print(bodyfat_prune)
## n= 48 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 48 6298.00 31.20  
##    2) waistcirc< 88.4 28  686.90 23.10  
##      4) waistcirc< 71.5 6   77.12 16.50 *
##      5) waistcirc>=71.5 22  277.10 24.90  
##       10) hipcirc< 99.65 12   68.28 22.87 *
##       11) hipcirc>=99.65 10  100.40 27.33 *
##    3) waistcirc>=88.4 20 1199.00 42.55  
##      6) kneebreadth< 11.1 17  370.60 40.09  
##       12) hipcirc< 109.9 7   75.37 35.72 *
##       13) hipcirc>=109.9 10   67.74 43.15 *
##      7) kneebreadth>=11.1 3  146.30 56.45 *

# Predicted DEXfat for the held-out test rows
DEXfat_pred <- predict(bodyfat_prune, newdata = bodyfat.test)

# Common axis range for the observed-vs-predicted plot, from the full data
xlim <- range(bodyfat[["DEXfat"]])

테스트셋의 실제값(가로축)과 예측값(세로축)을 산점도로 그린다. 점이 대각선(y = x)에 가까울수록 예측이 정확하다.

# Predicted vs observed DEXfat on the test set, same range on both axes
plot(DEXfat_pred ~ DEXfat, data = bodyfat.test,
     xlim = xlim, ylim = xlim,
     xlab = "Observed", ylab = "Predicted")

# Reference line y = x: points lying on it are perfect predictions
abline(0, 1)

plot of chunk unnamed-chunk-4

randomForest를 활용한 예측

전처리 - 학습셋과 테스트셋 분할

모델 생성


# Fix the RNG seed: the train/test split is random, and so is randomForest
# itself (bootstrap samples and per-split variable sampling).
# NOTE(review): the printed outputs in this section were recorded from an
# unseeded run, so exact numbers may differ when the code is re-run.
set.seed(1234)

# Randomly assign each iris row to training (1, ~70%) or test (2, ~30%)
ind <- sample(2, nrow(iris), replace = TRUE, prob = c(0.7, 0.3))

trainData <- iris[ind == 1, ]

testData <- iris[ind == 2, ]


library(randomForest)

# 100 trees; proximity = TRUE also stores the row-by-row proximity matrix
rf <- randomForest(Species ~ ., data = trainData, ntree = 100, proximity = TRUE)

# predict() without newdata returns the out-of-bag (OOB) predictions;
# cross-tabulate them against the true training labels.
table(predict(object = rf), trainData[["Species"]])
##             
##              setosa versicolor virginica
##   setosa         40          0         0
##   versicolor      0         35         3
##   virginica       0          2        32

# Model summary: call, number of trees, mtry, OOB error and confusion matrix
print(x = rf)
## 
## Call:
##  randomForest(formula = Species ~ ., data = trainData, ntree = 100,      proximity = TRUE) 
##                Type of random forest: classification
##                      Number of trees: 100
## No. of variables tried at each split: 2
## 
##         OOB estimate of  error rate: 4.46%
## Confusion matrix:
##            setosa versicolor virginica class.error
## setosa         40          0         0     0.00000
## versicolor      0         35         2     0.05405
## virginica       0          3        32     0.08571

# Components stored in the fitted randomForest object, and its class
print(attributes(rf))
## $names
##  [1] "call"            "type"            "predicted"      
##  [4] "err.rate"        "confusion"       "votes"          
##  [7] "oob.times"       "classes"         "importance"     
## [10] "importanceSD"    "localImportance" "proximity"      
## [13] "ntree"           "mtry"            "forest"         
## [16] "y"               "test"            "inbag"          
## [19] "terms"          
## 
## $class
## [1] "randomForest.formula" "randomForest"

# Error rate (overall OOB and per class) after each number of trees
print(rf[["err.rate"]])
##            OOB setosa versicolor virginica
##   [1,] 0.08889      0    0.07143   0.20000
##   [2,] 0.06061      0    0.04762   0.16667
##   [3,] 0.04878      0    0.07692   0.08696
##   [4,] 0.04255      0    0.06897   0.07143
##   [5,] 0.03846      0    0.06061   0.06061
##   [6,] 0.04673      0    0.05882   0.08824
##   [7,] 0.04587      0    0.05556   0.08824
##   [8,] 0.05405      0    0.05405   0.11765
##   [9,] 0.05405      0    0.05405   0.11765
##  [10,] 0.05405      0    0.08108   0.08824
##  [11,] 0.05357      0    0.08108   0.08571
##  [12,] 0.04464      0    0.05405   0.08571
##  [13,] 0.05357      0    0.08108   0.08571
##  [14,] 0.05357      0    0.08108   0.08571
##  [15,] 0.05357      0    0.08108   0.08571
##  [16,] 0.05357      0    0.08108   0.08571
##  [17,] 0.05357      0    0.08108   0.08571
##  [18,] 0.05357      0    0.08108   0.08571
##  [19,] 0.05357      0    0.08108   0.08571
##  [20,] 0.05357      0    0.08108   0.08571
##  [21,] 0.05357      0    0.08108   0.08571
##  [22,] 0.05357      0    0.08108   0.08571
##  [23,] 0.05357      0    0.08108   0.08571
##  [24,] 0.05357      0    0.08108   0.08571
##  [25,] 0.05357      0    0.08108   0.08571
##  [26,] 0.05357      0    0.08108   0.08571
##  [27,] 0.05357      0    0.08108   0.08571
##  [28,] 0.06250      0    0.08108   0.11429
##  [29,] 0.06250      0    0.08108   0.11429
##  [30,] 0.06250      0    0.08108   0.11429
##  [31,] 0.05357      0    0.08108   0.08571
##  [32,] 0.05357      0    0.08108   0.08571
##  [33,] 0.05357      0    0.08108   0.08571
##  [34,] 0.05357      0    0.08108   0.08571
##  [35,] 0.05357      0    0.08108   0.08571
##  [36,] 0.04464      0    0.05405   0.08571
##  [37,] 0.05357      0    0.08108   0.08571
##  [38,] 0.04464      0    0.05405   0.08571
##  [39,] 0.04464      0    0.05405   0.08571
##  [40,] 0.04464      0    0.05405   0.08571
##  [41,] 0.04464      0    0.05405   0.08571
##  [42,] 0.05357      0    0.08108   0.08571
##  [43,] 0.05357      0    0.08108   0.08571
##  [44,] 0.05357      0    0.08108   0.08571
##  [45,] 0.05357      0    0.08108   0.08571
##  [46,] 0.05357      0    0.08108   0.08571
##  [47,] 0.04464      0    0.05405   0.08571
##  [48,] 0.05357      0    0.08108   0.08571
##  [49,] 0.04464      0    0.05405   0.08571
##  [50,] 0.04464      0    0.05405   0.08571
##  [51,] 0.05357      0    0.08108   0.08571
##  [52,] 0.05357      0    0.08108   0.08571
##  [53,] 0.05357      0    0.08108   0.08571
##  [54,] 0.04464      0    0.05405   0.08571
##  [55,] 0.04464      0    0.05405   0.08571
##  [56,] 0.05357      0    0.08108   0.08571
##  [57,] 0.05357      0    0.08108   0.08571
##  [58,] 0.05357      0    0.08108   0.08571
##  [59,] 0.04464      0    0.05405   0.08571
##  [60,] 0.04464      0    0.05405   0.08571
##  [61,] 0.04464      0    0.05405   0.08571
##  [62,] 0.04464      0    0.05405   0.08571
##  [63,] 0.04464      0    0.05405   0.08571
##  [64,] 0.04464      0    0.05405   0.08571
##  [65,] 0.05357      0    0.08108   0.08571
##  [66,] 0.05357      0    0.08108   0.08571
##  [67,] 0.04464      0    0.05405   0.08571
##  [68,] 0.04464      0    0.05405   0.08571
##  [69,] 0.05357      0    0.08108   0.08571
##  [70,] 0.05357      0    0.08108   0.08571
##  [71,] 0.05357      0    0.08108   0.08571
##  [72,] 0.04464      0    0.05405   0.08571
##  [73,] 0.04464      0    0.05405   0.08571
##  [74,] 0.04464      0    0.05405   0.08571
##  [75,] 0.04464      0    0.05405   0.08571
##  [76,] 0.04464      0    0.05405   0.08571
##  [77,] 0.04464      0    0.05405   0.08571
##  [78,] 0.04464      0    0.05405   0.08571
##  [79,] 0.04464      0    0.05405   0.08571
##  [80,] 0.04464      0    0.05405   0.08571
##  [81,] 0.04464      0    0.05405   0.08571
##  [82,] 0.04464      0    0.05405   0.08571
##  [83,] 0.04464      0    0.05405   0.08571
##  [84,] 0.04464      0    0.05405   0.08571
##  [85,] 0.04464      0    0.05405   0.08571
##  [86,] 0.04464      0    0.05405   0.08571
##  [87,] 0.04464      0    0.05405   0.08571
##  [88,] 0.04464      0    0.05405   0.08571
##  [89,] 0.04464      0    0.05405   0.08571
##  [90,] 0.04464      0    0.05405   0.08571
##  [91,] 0.04464      0    0.05405   0.08571
##  [92,] 0.04464      0    0.05405   0.08571
##  [93,] 0.04464      0    0.05405   0.08571
##  [94,] 0.04464      0    0.05405   0.08571
##  [95,] 0.04464      0    0.05405   0.08571
##  [96,] 0.04464      0    0.05405   0.08571
##  [97,] 0.04464      0    0.05405   0.08571
##  [98,] 0.04464      0    0.05405   0.08571
##  [99,] 0.04464      0    0.05405   0.08571
## [100,] 0.04464      0    0.05405   0.08571

트리 개수에 따른 오차율 플로팅(plot(rf))과 변수 중요도 플로팅 - 변수 선택을 위함

# Error rates against the number of trees grown
plot(x = rf)

plot of chunk unnamed-chunk-6


# Variable importance table (mean decrease in Gini impurity)
importance(x = rf)
##              MeanDecreaseGini
## Sepal.Length            7.365
## Sepal.Width             1.582
## Petal.Length           29.313
## Petal.Width            35.675

# Dot chart of the same variable importance values
varImpPlot(x = rf)

plot of chunk unnamed-chunk-6

  1. 예측값 출력
  2. Margin 플로팅

Margin = (정답 클래스에 투표한 트리의 비율) - (나머지 클래스 중 가장 많은 표를 받은 클래스에 투표한 트리의 비율). 다수결 기준으로 margin이 양수이면 올바르게 분류된 것이고, 음수이면 잘못 분류된 것이다.

# Predict species for the held-out test rows with the forest
irisPred <- predict(rf, newdata = testData)

# Test-set confusion matrix (rows: predicted, columns: actual)
table(irisPred, testData[["Species"]])
##             
## irisPred     setosa versicolor virginica
##   setosa         10          0         0
##   versicolor      0         12         0
##   virginica       0          1        15

# Margins of the training observations, computed from the forest's OOB votes
# (positive = correctly classified by majority vote), then plotted.
plot(margin(rf, trainData[["Species"]]))

plot of chunk unnamed-chunk-7

Reference