[14-03-14]

어떤 분류에 속하는지 알려진 훈련 데이터를 사용해 모델을 훈련시키고, 이 모델을 사용해 새로운 관찰값의 분류를 예측하는 방법을 분류 알고리즘이라고 말한다.

1. 데이터 탐색

1.1 기술 통계

# doBy::summary

library(doBy)
## Loading required package: survival
## Loading required package: splines
## Loading required package: MASS
summary(iris)
##   Sepal.Length   Sepal.Width    Petal.Length   Petal.Width 
##  Min.   :4.30   Min.   :2.00   Min.   :1.00   Min.   :0.1  
##  1st Qu.:5.10   1st Qu.:2.80   1st Qu.:1.60   1st Qu.:0.3  
##  Median :5.80   Median :3.00   Median :4.35   Median :1.3  
##  Mean   :5.84   Mean   :3.06   Mean   :3.76   Mean   :1.2  
##  3rd Qu.:6.40   3rd Qu.:3.30   3rd Qu.:5.10   3rd Qu.:1.8  
##  Max.   :7.90   Max.   :4.40   Max.   :6.90   Max.   :2.5  
##        Species  
##  setosa    :50  
##  versicolor:50  
##  virginica :50  
##                 
##                 
## 

# Hmisc::describe()

library(Hmisc)
## Loading required package: grid
## Loading required package: lattice
## Warning: package 'lattice' was built under R version 3.0.3
## Loading required package: Formula
## 
## Attaching package: 'Hmisc'
## 
## The following objects are masked from 'package:base':
## 
##     format.pval, round.POSIXt, trunc.POSIXt, units
describe(iris)
## iris 
## 
##  5  Variables      150  Observations
## ---------------------------------------------------------------------------
## Sepal.Length 
##       n missing  unique    Mean     .05     .10     .25     .50     .75 
##     150       0      35   5.843   4.600   4.800   5.100   5.800   6.400 
##     .90     .95 
##   6.900   7.255 
## 
## lowest : 4.3 4.4 4.5 4.6 4.7, highest: 7.3 7.4 7.6 7.7 7.9 
## ---------------------------------------------------------------------------
## Sepal.Width 
##       n missing  unique    Mean     .05     .10     .25     .50     .75 
##     150       0      23   3.057   2.345   2.500   2.800   3.000   3.300 
##     .90     .95 
##   3.610   3.800 
## 
## lowest : 2.0 2.2 2.3 2.4 2.5, highest: 3.9 4.0 4.1 4.2 4.4 
## ---------------------------------------------------------------------------
## Petal.Length 
##       n missing  unique    Mean     .05     .10     .25     .50     .75 
##     150       0      43   3.758    1.30    1.40    1.60    4.35    5.10 
##     .90     .95 
##    5.80    6.10 
## 
## lowest : 1.0 1.1 1.2 1.3 1.4, highest: 6.3 6.4 6.6 6.7 6.9 
## ---------------------------------------------------------------------------
## Petal.Width 
##       n missing  unique    Mean     .05     .10     .25     .50     .75 
##     150       0      22   1.199     0.2     0.2     0.3     1.3     1.8 
##     .90     .95 
##     2.2     2.3 
## 
## lowest : 0.1 0.2 0.3 0.4 0.5, highest: 2.1 2.2 2.3 2.4 2.5 
## ---------------------------------------------------------------------------
## Species 
##       n missing  unique 
##     150       0       3 
## 
## setosa (50, 33%), versicolor (50, 33%) 
## virginica (50, 33%) 
## ---------------------------------------------------------------------------

describe()함수는 결측치의 존재 및 서로 다른 값의 수를 알려주는 점이 편리하다.

1.2 데이터 시각화

plot(iris)  # 산점도 행렬을 플롯함

plot of chunk unnamed-chunk-2

plot(iris$Sepal.Length)  # 2차원 산점도를 플롯함

plot of chunk unnamed-chunk-2

plot(iris$Species)  # Species별 도수를 나타내줌

plot of chunk unnamed-chunk-2

그래프에 표시된 데이터가 붓꽃의 어느 종별인지를 표현하기 위해 formula를 사용한다.

plot(iris$Species ~ iris$Sepal.Length, data = iris)

plot of chunk unnamed-chunk-3

산점도 상의 점 색상을 다르게 지정할 수도 있다.

plot(iris$Sepal.Length, col = as.numeric(iris$Species))

plot of chunk unnamed-chunk-4

caret패키지에는 featurePlot()함수가 있는데, 더욱 편하게 그려준다.
featurePlot()은 인자로, X와 Y를 받아 Y에 따라 분리해 표현해준다.
예를 들면 다음과 같다.

library(caret)
## Warning: package 'caret' was built under R version 3.0.3
## Loading required package: ggplot2
## 
## Attaching package: 'caret'
## 
## The following object is masked from 'package:survival':
## 
##     cluster
library(ellipse)
## Warning: package 'ellipse' was built under R version 3.0.3
featurePlot(iris[, 1:4], iris$Species, "ellipse")

plot of chunk unnamed-chunk-5

위 그림을 그리기 위해서는 ellipse패키지를 설치하고 로딩해야만 한다.

2. 전처리(Preprocessing)

2.1 데이터 변환

데이터 정규화(Feature Scaling)

head(cbind(as.data.frame(scale(iris[, 1:4])), iris$Species))
##   Sepal.Length Sepal.Width Petal.Length Petal.Width iris$Species
## 1      -0.8977     1.01560       -1.336      -1.311       setosa
## 2      -1.1392    -0.13154       -1.336      -1.311       setosa
## 3      -1.3807     0.32732       -1.392      -1.311       setosa
## 4      -1.5015     0.09789       -1.279      -1.311       setosa
## 5      -1.0184     1.24503       -1.336      -1.311       setosa
## 6      -0.5354     1.93331       -1.166      -1.049       setosa

PCA(Principal Component Analysis, 주성분 분석)

p <- princomp(iris[, 1:4], cor = TRUE)

summary()를 통해 주성분들이 데이터의 분산 중 얼마만큼을 설명해 주는지 알 수 있다.

summary(p)
## Importance of components:
##                        Comp.1 Comp.2  Comp.3   Comp.4
## Standard deviation     1.7084 0.9560 0.38309 0.143926
## Proportion of Variance 0.7296 0.2285 0.03669 0.005179
## Cumulative Proportion  0.7296 0.9581 0.99482 1.000000
plot(p, type = "l")

plot of chunk unnamed-chunk-8

“Proportion of Variance"를 보면 comp.1(PC1)은 데이터의 분산중 72.96%를 설명함을 알 수 있다.
Scree도표를 통해 더 쉽게 알 수 있다.

head(predict(p, iris[, 1:4]))
##      Comp.1  Comp.2   Comp.3    Comp.4
## [1,] -2.265 -0.4800  0.12771  0.024168
## [2,] -2.081  0.6741  0.23461  0.103007
## [3,] -2.364  0.3419 -0.04420  0.028377
## [4,] -2.299  0.5974 -0.09129 -0.065956
## [5,] -2.390 -0.6468 -0.01574 -0.035923
## [6,] -2.076 -1.4892 -0.02697  0.006608

범주형 변수의 재표현

많은 수의 수준을 가진 범주형 변수를 포함한 데이터를 작성해보자.

(all <- factor(c(paste0(LETTERS, "0"), paste0(LETTERS, "1"))))
##  [1] A0 B0 C0 D0 E0 F0 G0 H0 I0 J0 K0 L0 M0 N0 O0 P0 Q0 R0 S0 T0 U0 V0 W0
## [24] X0 Y0 Z0 A1 B1 C1 D1 E1 F1 G1 H1 I1 J1 K1 L1 M1 N1 O1 P1 Q1 R1 S1 T1
## [47] U1 V1 W1 X1 Y1 Z1
## 52 Levels: A0 A1 B0 B1 C0 C1 D0 D1 E0 E1 F0 F1 G0 G1 H0 H1 I0 I1 J0 ... Z1

paste0를 이용해 A0부터 Z0까지를 손쉽게 만들었다.

(data <- data.frame(lvl = all, value = rnorm(length(all))))
##    lvl    value
## 1   A0 -2.49189
## 2   B0  0.61688
## 3   C0 -0.87622
## 4   D0  0.44166
## 5   E0 -1.67334
## 6   F0  0.15690
## 7   G0 -0.20873
## 8   H0  0.05312
## 9   I0 -1.10307
## 10  J0 -0.38098
## 11  K0 -1.17965
## 12  L0  0.75465
## 13  M0 -0.70436
## 14  N0  0.36806
## 15  O0 -0.73521
## 16  P0 -0.85985
## 17  Q0  0.84881
## 18  R0 -0.08224
## 19  S0  0.82510
## 20  T0  0.22591
## 21  U0 -0.16411
## 22  V0  1.55847
## 23  W0  0.20557
## 24  X0 -1.31290
## 25  Y0  1.92093
## 26  Z0  0.70008
## 27  A1  2.43553
## 28  B1 -1.15030
## 29  C1  0.84308
## 30  D1  0.09538
## 31  E1  0.88024
## 32  F1 -0.51014
## 33  G1  0.22830
## 34  H1 -1.51062
## 35  I1  0.33266
## 36  J1  0.23863
## 37  K1 -0.40896
## 38  L1 -0.38910
## 39  M1 -0.96756
## 40  N1 -0.26465
## 41  O1 -0.99411
## 42  P1  1.90090
## 43  Q1 -0.07783
## 44  R1 -0.08926
## 45  S1  0.45586
## 46  T1 -0.78917
## 47  U1  0.07447
## 48  V1 -2.23557
## 49  W1  1.65267
## 50  X1  1.25777
## 51  Y1  1.18257
## 52  Z1  0.09437

Random Forest는 분류와 회귀분석에 주로 이용되는 Breiman's random forest algorithm을 수행하기 위한 패키지이다.

library(randomForest)
## Warning: package 'randomForest' was built under R version 3.0.3
## randomForest 4.6-7
## Type rfNews() to see new features/changes/bug fixes.
## 
## Attaching package: 'randomForest'
## 
## The following object is masked from 'package:Hmisc':
## 
##     combine
m <- randomForest(value ~ lvl, data = data)
## Error: Can not handle categorical predictors with more than 32 categories.

이 데이터를 Random Forest에 입력으로 주면 에러 메시지가 출력된다. 최대 32개의 범주만 다룰 수 있기 때문이다. 에러를 해결하기 위해 발생 빈도가 적은 수준들을 하나로 묶거나, 범주형 변수의 수준을 숫자로 취급하면 된다.

또다른 방법은 여러개의 가변수를 사용해 범주형 변수를 재표현 하는 것이다. model.matrix()를 사용해 재표현 할 수 있다.

(x <- data.frame(lvl = factor(c("A", "B", "A", "A", "C")), value = c(1, 3, 2, 
    4, 5)))
##   lvl value
## 1   A     1
## 2   B     3
## 3   A     2
## 4   A     4
## 5   C     5
model.matrix(~lvl, data = x)
##   (Intercept) lvlB lvlC
## 1           1    0    0
## 2           1    1    0
## 3           1    0    0
## 4           1    0    0
## 5           1    0    1
## attr(,"assign")
## [1] 0 1 1
## attr(,"contrasts")
## attr(,"contrasts")$lvl
## [1] "contr.treatment"
model.matrix(~lvl, data = x)[, -1]
##   lvlB lvlC
## 1    0    0
## 2    1    0
## 3    0    0
## 4    0    0
## 5    0    1

model.matrix()의 인자를 보면 data=x를 전부 사용하는 것이 아니라 lvl이라는 변수만 인자로 넘겨지도록 formula형식을 썼다.
'A', 'B', 'C'의 3개 수준을 저장한 lvl이라는 Factor를 3개의 컬럼으로 재표현한 예이다.
A는 (0,0), B는 (1,0), C는 (0,1)로 변환되었음을 알 수 있다.

2.2 결측값(NA)의 처리

iris_na <- iris
iris_na[c(10, 20, 25, 40, 32), 3] <- NA
iris_na[c(33, 100, 123), 1] <- NA

우선 결측치를 임의적으로 생성해보았다. 3열, 즉 Petal.Length의 10,20,25,40,32번째 값과, Sepal.Length의 33,100,123번째 값을 결측값으로 만들었다.

iris_na[!complete.cases(iris_na), ]
##     Sepal.Length Sepal.Width Petal.Length Petal.Width    Species
## 10           4.9         3.1           NA         0.1     setosa
## 20           5.1         3.8           NA         0.3     setosa
## 25           4.8         3.4           NA         0.2     setosa
## 32           5.4         3.4           NA         0.4     setosa
## 33            NA         4.1          1.5         0.1     setosa
## 40           5.1         3.4           NA         0.2     setosa
## 100           NA         2.8          4.1         1.3 versicolor
## 123           NA         2.8          6.7         2.0  virginica

complete.cases()는 데이터 프레임의 각 행마다 적용하며, 각 행에 저장된 모든 값이 NA가 아닐때에만 TRUE를 반환한다.
진리값은 TRUE를 반환하기 때문에 그에 해당하는 행을 색인으로 하여 결측값이 있는 행을 찾아낼 수 있다.

iris_na[is.na(iris_na$Sepal.Length), ]
##     Sepal.Length Sepal.Width Petal.Length Petal.Width    Species
## 33            NA         4.1          1.5         0.1     setosa
## 100           NA         2.8          4.1         1.3 versicolor
## 123           NA         2.8          6.7         2.0  virginica

is.na()도 TRUE나 FALSE를 반환하기 때문에 색인을 통해 해당하는 행을 찾아 낼 수 있다.

mapply(median, iris_na[1:4], na.rm = TRUE)
## Sepal.Length  Sepal.Width Petal.Length  Petal.Width 
##          5.8          3.0          4.4          1.3

mapply()는 다수의 인자를 함수에 넘긴다. 첫번째 열끼리 묶어 중앙값을 구하고, 다시 두번째 열끼리 묶어 중앙값을 구하는 작업을 반복하게 된다.
na.rm=TRUE는 mapply()의 인자로 주어지지만 mapply()내에서 직접 사용되는 것은 아니고, median()을 호출할때 median()의 인자로 넘겨지게 된다.
NA값이 위치한 곳에 중앙값을 대치하는 패키지는 DMwR::centralImputation이다.

library(DMwR)
## Warning: package 'DMwR' was built under R version 3.0.3
## KernSmooth 2.23 loaded
## Copyright M. P. Wand 1997-2009
centralImputation(iris_na[1:4])
##     Sepal.Length Sepal.Width Petal.Length Petal.Width
## 1            5.1         3.5          1.4         0.2
## 2            4.9         3.0          1.4         0.2
## 3            4.7         3.2          1.3         0.2
## 4            4.6         3.1          1.5         0.2
## 5            5.0         3.6          1.4         0.2
## 6            5.4         3.9          1.7         0.4
## 7            4.6         3.4          1.4         0.3
## 8            5.0         3.4          1.5         0.2
## 9            4.4         2.9          1.4         0.2
## 10           4.9         3.1          4.4         0.1
## 11           5.4         3.7          1.5         0.2
## 12           4.8         3.4          1.6         0.2
## 13           4.8         3.0          1.4         0.1
## 14           4.3         3.0          1.1         0.1
## 15           5.8         4.0          1.2         0.2
## 16           5.7         4.4          1.5         0.4
## 17           5.4         3.9          1.3         0.4
## 18           5.1         3.5          1.4         0.3
## 19           5.7         3.8          1.7         0.3
## 20           5.1         3.8          4.4         0.3
## 21           5.4         3.4          1.7         0.2
## 22           5.1         3.7          1.5         0.4
## 23           4.6         3.6          1.0         0.2
## 24           5.1         3.3          1.7         0.5
## 25           4.8         3.4          4.4         0.2
## 26           5.0         3.0          1.6         0.2
## 27           5.0         3.4          1.6         0.4
## 28           5.2         3.5          1.5         0.2
## 29           5.2         3.4          1.4         0.2
## 30           4.7         3.2          1.6         0.2
## 31           4.8         3.1          1.6         0.2
## 32           5.4         3.4          4.4         0.4
## 33           5.8         4.1          1.5         0.1
## 34           5.5         4.2          1.4         0.2
## 35           4.9         3.1          1.5         0.2
## 36           5.0         3.2          1.2         0.2
## 37           5.5         3.5          1.3         0.2
## 38           4.9         3.6          1.4         0.1
## 39           4.4         3.0          1.3         0.2
## 40           5.1         3.4          4.4         0.2
## 41           5.0         3.5          1.3         0.3
## 42           4.5         2.3          1.3         0.3
## 43           4.4         3.2          1.3         0.2
## 44           5.0         3.5          1.6         0.6
## 45           5.1         3.8          1.9         0.4
## 46           4.8         3.0          1.4         0.3
## 47           5.1         3.8          1.6         0.2
## 48           4.6         3.2          1.4         0.2
## 49           5.3         3.7          1.5         0.2
## 50           5.0         3.3          1.4         0.2
## 51           7.0         3.2          4.7         1.4
## 52           6.4         3.2          4.5         1.5
## 53           6.9         3.1          4.9         1.5
## 54           5.5         2.3          4.0         1.3
## 55           6.5         2.8          4.6         1.5
## 56           5.7         2.8          4.5         1.3
## 57           6.3         3.3          4.7         1.6
## 58           4.9         2.4          3.3         1.0
## 59           6.6         2.9          4.6         1.3
## 60           5.2         2.7          3.9         1.4
## 61           5.0         2.0          3.5         1.0
## 62           5.9         3.0          4.2         1.5
## 63           6.0         2.2          4.0         1.0
## 64           6.1         2.9          4.7         1.4
## 65           5.6         2.9          3.6         1.3
## 66           6.7         3.1          4.4         1.4
## 67           5.6         3.0          4.5         1.5
## 68           5.8         2.7          4.1         1.0
## 69           6.2         2.2          4.5         1.5
## 70           5.6         2.5          3.9         1.1
## 71           5.9         3.2          4.8         1.8
## 72           6.1         2.8          4.0         1.3
## 73           6.3         2.5          4.9         1.5
## 74           6.1         2.8          4.7         1.2
## 75           6.4         2.9          4.3         1.3
## 76           6.6         3.0          4.4         1.4
## 77           6.8         2.8          4.8         1.4
## 78           6.7         3.0          5.0         1.7
## 79           6.0         2.9          4.5         1.5
## 80           5.7         2.6          3.5         1.0
## 81           5.5         2.4          3.8         1.1
## 82           5.5         2.4          3.7         1.0
## 83           5.8         2.7          3.9         1.2
## 84           6.0         2.7          5.1         1.6
## 85           5.4         3.0          4.5         1.5
## 86           6.0         3.4          4.5         1.6
## 87           6.7         3.1          4.7         1.5
## 88           6.3         2.3          4.4         1.3
## 89           5.6         3.0          4.1         1.3
## 90           5.5         2.5          4.0         1.3
## 91           5.5         2.6          4.4         1.2
## 92           6.1         3.0          4.6         1.4
## 93           5.8         2.6          4.0         1.2
## 94           5.0         2.3          3.3         1.0
## 95           5.6         2.7          4.2         1.3
## 96           5.7         3.0          4.2         1.2
## 97           5.7         2.9          4.2         1.3
## 98           6.2         2.9          4.3         1.3
## 99           5.1         2.5          3.0         1.1
## 100          5.8         2.8          4.1         1.3
## 101          6.3         3.3          6.0         2.5
## 102          5.8         2.7          5.1         1.9
## 103          7.1         3.0          5.9         2.1
## 104          6.3         2.9          5.6         1.8
## 105          6.5         3.0          5.8         2.2
## 106          7.6         3.0          6.6         2.1
## 107          4.9         2.5          4.5         1.7
## 108          7.3         2.9          6.3         1.8
## 109          6.7         2.5          5.8         1.8
## 110          7.2         3.6          6.1         2.5
## 111          6.5         3.2          5.1         2.0
## 112          6.4         2.7          5.3         1.9
## 113          6.8         3.0          5.5         2.1
## 114          5.7         2.5          5.0         2.0
## 115          5.8         2.8          5.1         2.4
## 116          6.4         3.2          5.3         2.3
## 117          6.5         3.0          5.5         1.8
## 118          7.7         3.8          6.7         2.2
## 119          7.7         2.6          6.9         2.3
## 120          6.0         2.2          5.0         1.5
## 121          6.9         3.2          5.7         2.3
## 122          5.6         2.8          4.9         2.0
## 123          5.8         2.8          6.7         2.0
## 124          6.3         2.7          4.9         1.8
## 125          6.7         3.3          5.7         2.1
## 126          7.2         3.2          6.0         1.8
## 127          6.2         2.8          4.8         1.8
## 128          6.1         3.0          4.9         1.8
## 129          6.4         2.8          5.6         2.1
## 130          7.2         3.0          5.8         1.6
## 131          7.4         2.8          6.1         1.9
## 132          7.9         3.8          6.4         2.0
## 133          6.4         2.8          5.6         2.2
## 134          6.3         2.8          5.1         1.5
## 135          6.1         2.6          5.6         1.4
## 136          7.7         3.0          6.1         2.3
## 137          6.3         3.4          5.6         2.4
## 138          6.4         3.1          5.5         1.8
## 139          6.0         3.0          4.8         1.8
## 140          6.9         3.1          5.4         2.1
## 141          6.7         3.1          5.6         2.4
## 142          6.9         3.1          5.1         2.3
## 143          5.8         2.7          5.1         1.9
## 144          6.8         3.2          5.9         2.3
## 145          6.7         3.3          5.7         2.5
## 146          6.7         3.0          5.2         2.3
## 147          6.3         2.5          5.0         1.9
## 148          6.5         3.0          5.2         2.0
## 149          6.2         3.4          5.4         2.3
## 150          5.9         3.0          5.1         1.8

좀 전에 결측치로 나타난 값들이 열의 중앙값으로 대체되었음을 볼 수 있다.

knnImputation(iris_na[1:4])[c(10, 20, 25, 32, 33, 40, 100, 123), ]
##     Sepal.Length Sepal.Width Petal.Length Petal.Width
## 10         4.900         3.1        1.452         0.1
## 20         5.100         3.8        1.540         0.3
## 25         4.800         3.4        1.457         0.2
## 32         5.400         3.4        1.484         0.4
## 33         5.463         4.1        1.500         0.1
## 40         5.100         3.4        1.476         0.2
## 100        5.891         2.8        4.100         1.3
## 123        7.077         2.8        6.700         2.0

2.3 변수 선택(Feature Selection)

0에 가까운 분산(Near zero Variance)

library(caret)
library(mlbench)
data(Soybean)
nearZeroVar(Soybean)
## [1] 19 26 28
nearZeroVar(Soybean, saveMetrics = TRUE)
##                 freqRatio percentUnique zeroVar   nzv
## Class               1.011        2.7818   FALSE FALSE
## date                1.137        1.0249   FALSE FALSE
## plant.stand         1.208        0.2928   FALSE FALSE
## precip              4.098        0.4392   FALSE FALSE
## temp                1.879        0.4392   FALSE FALSE
## hail                3.425        0.2928   FALSE FALSE
## crop.hist           1.005        0.5857   FALSE FALSE
## area.dam            1.214        0.5857   FALSE FALSE
## sever               1.651        0.4392   FALSE FALSE
## seed.tmt            1.374        0.4392   FALSE FALSE
## germ                1.104        0.4392   FALSE FALSE
## plant.growth        1.951        0.2928   FALSE FALSE
## leaves              7.870        0.2928   FALSE FALSE
## leaf.halo           1.548        0.4392   FALSE FALSE
## leaf.marg           1.615        0.4392   FALSE FALSE
## leaf.size           1.480        0.4392   FALSE FALSE
## leaf.shread         5.073        0.2928   FALSE FALSE
## leaf.malf          12.311        0.2928   FALSE FALSE
## leaf.mild          26.750        0.4392   FALSE  TRUE
## stem                1.253        0.2928   FALSE FALSE
## lodging            12.381        0.2928   FALSE FALSE
## stem.cankers        1.984        0.5857   FALSE FALSE
## canker.lesion       1.808        0.5857   FALSE FALSE
## fruiting.bodies     4.548        0.2928   FALSE FALSE
## ext.decay           3.681        0.4392   FALSE FALSE
## mycelium          106.500        0.2928   FALSE  TRUE
## int.discolor       13.205        0.4392   FALSE FALSE
## sclerotia          31.250        0.2928   FALSE  TRUE
## fruit.pods          3.131        0.5857   FALSE FALSE
## fruit.spots         3.450        0.5857   FALSE FALSE
## seed                4.139        0.2928   FALSE FALSE
## mold.growth         7.821        0.2928   FALSE FALSE
## seed.discolor       8.016        0.2928   FALSE FALSE
## seed.size           9.017        0.2928   FALSE FALSE
## shriveling         14.184        0.2928   FALSE FALSE
## roots               6.407        0.4392   FALSE FALSE

분산이 0에 가까운 변수를 바로 출력해준다.
saveMetrics=TRUE를 지정하면 분석 결과의 표가 출력된다.
출력된 표의 nzv컬럼은 Near Zero Variance를 뜻하므로, TRUE로 표시된 변수들을 제거할 수 있다.

mySoybean <- Soybean[, -nearZeroVar(Soybean)]

상관 계수(Correlation)

library(mlbench)
data(Vehicle)
findCorrelation(cor(subset(Vehicle, select = -c(Class))))
## [1]  3  8 11  7  9  2

3, 8, 11, 7, 9, 2번째 컬럼의 변수간 상관계수가 높은 것으로 나타났다. 변수를 제거할 수 있다.

myVehicle <- Vehicle[, -c(findCorrelation(cor(subset(Vehicle, select = -c(Class)))))]
library(FSelector)
## Warning: package 'FSelector' was built under R version 3.0.3

library(mlbench)

data(Ozone)
(v <- linear.correlation(V4 ~ ., data = subset(Ozone, select = -c(V1, V2, V3))))
##     attr_importance
## V5         0.584144
## V6         0.004681
## V7         0.443566
## V8         0.769864
## V9         0.723173
## V10        0.580268
## V11        0.229903
## V12        0.731950
## V13        0.414715

V1, V2, V3는 Factor형 변수이므로 계산에서 제외되었다.

v는 하나의 열을 갖는 데이터 프레임이다. 따라서 다음과 같이 정렬해 볼 수 있다.

v[order(-v), , drop = FALSE]
##     attr_importance
## V8         0.769864
## V12        0.731950
## V9         0.723173
## V5         0.584144
## V10        0.580268
## V7         0.443566
## V13        0.414715
## V11        0.229903
## V6         0.004681

order(-v)는 큰 수부터 작은수로의 순서를 반환하는 명령이다. 데이터 프레임에 단위의 차이, 크기의 차이가 많이 나면 order()를 실행하는게 별 의미가 없겠지만, 하나의 열을 갖기 때문에 가능하다.
drop=FALSE를 사용해 데이터가 벡터로 변환되는 것을 막았다.

Chi Square

chi.squared(Class ~ ., data = Vehicle)
##              attr_importance
## Comp                  0.3043
## Circ                  0.2975
## D.Circ                0.3588
## Rad.Ra                0.3509
## Pr.Axis.Ra            0.2265
## Max.L.Ra              0.3235
## Scat.Ra               0.4654
## Elong                 0.4557
## Pr.Axis.Rect          0.4475
## Max.L.Rect            0.3060
## Sc.Var.Maxis          0.4338
## Sc.Var.maxis          0.4922
## Ra.Gyr                0.2940
## Skew.Maxis            0.3088
## Skew.maxis            0.2470
## Kurt.maxis            0.3339
## Kurt.Maxis            0.2732
## Holl.Ra               0.3886

Class를 예측하는데 변수들의 중요도를 평가하는 것이므로, formula에 형식이 ” Class ~ . “ 이다.

모델을 사용한 변수 중요도 평가

library(mlbench)
library(rpart)
library(caret)
data(BreastCancer)
m <- rpart(Class ~ ., data = BreastCancer)

rpart는 의사결정나무를 그려주는 패키지이다.
varImp()를 이용해 변수 중요도를 찾아보자.

varImp(m)
##                 Overall
## Bare.nuclei       203.7
## Bl.cromatin       197.9
## Cell.shape        216.4
## Cell.size         222.9
## Id                307.9
## Cl.thickness        0.0
## Marg.adhesion       0.0
## Epith.c.size        0.0
## Normal.nucleoli     0.0
## Mitoses             0.0

3. 모델 평가 방법

3.1 평가 메트릭(metric)

예측한 분류결과가 담긴 predicted벡터와 이들의 실제 분류가 담긴 actual벡터를 정의해보자.

predicted <- c(1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1)
actual <- c(1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1)
xtabs(~predicted + actual)
##          actual
## predicted 0 1
##         0 3 2
##         1 1 6

예측 결과와 실제 결과가 일치하는 경우와 그렇지 않은 경우를 쉽게 알 수 있다.

prop.table()을 사용해 비율도 쉽게 계산가능하다.

prop.table(xtabs(~predicted + actual))
##          actual
## predicted       0       1
##         0 0.25000 0.16667
##         1 0.08333 0.50000

정분류 비율(Accuracy)는 예측값중 올바른 값의 비율로 대각선의 비율을 더하면 된다. 0.25 + 0.50 = 0.7이 정분류 비율이다.

다음과 같이 Accuracy를 구할 수도 있다.

sum(predicted == actual)/NROW(actual)
## [1] 0.75

predicted와 actual이 같은 값을 갖는 갯수를 세고, 총 갯수(actual이나 predicted나 같다.)

library(caret)
library(class)
confusionMatrix(predicted, actual)
## Warning: package 'e1071' was built under R version 3.0.3
## 
## Attaching package: 'e1071'
## 
## The following object is masked from 'package:Hmisc':
## 
##     impute
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction 0 1
##          0 3 2
##          1 1 6
##                                         
##                Accuracy : 0.75          
##                  95% CI : (0.428, 0.945)
##     No Information Rate : 0.667         
##     P-Value [Acc > NIR] : 0.393         
##                                         
##                   Kappa : 0.471         
##  Mcnemar's Test P-Value : 1.000         
##                                         
##             Sensitivity : 0.750         
##             Specificity : 0.750         
##          Pos Pred Value : 0.600         
##          Neg Pred Value : 0.857         
##              Prevalence : 0.333         
##          Detection Rate : 0.250         
##    Detection Prevalence : 0.417         
##       Balanced Accuracy : 0.750         
##                                         
##        'Positive' Class : 0             
## 

'Error in library(e1071) : there is no package called ‘e1071’'이런 에러 메시지가 출력되는데, e1071이라는 패키지를 설치하면 에러가 없어지고 정상적으로 실행된다….왜그럴까

세부 결과를 바로 얻고자하면 str()을 사용해 구조를 살펴보고 개별 메트릭을 가져올 수 있다.

cm <- confusionMatrix(predicted, actual)
str(cm)
## List of 5
##  $ positive: chr "0"
##  $ table   : 'table' int [1:2, 1:2] 3 1 2 6
##   ..- attr(*, "dimnames")=List of 2
##   .. ..$ Prediction: chr [1:2] "0" "1"
##   .. ..$ Reference : chr [1:2] "0" "1"
##  $ overall : Named num [1:7] 0.75 0.471 0.428 0.945 0.667 ...
##   ..- attr(*, "names")= chr [1:7] "Accuracy" "Kappa" "AccuracyLower" "AccuracyUpper" ...
##  $ byClass : Named num [1:8] 0.75 0.75 0.6 0.857 0.333 ...
##   ..- attr(*, "names")= chr [1:8] "Sensitivity" "Specificity" "Pos Pred Value" "Neg Pred Value" ...
##  $ dots    : list()
##  - attr(*, "class")= chr "confusionMatrix"
cm$overall["Accuracy"]
## Accuracy 
##     0.75

3.2 ROC 커브

probs <- runif(100)  # 균일 분포에서 표본 추출
labels <- as.factor(ifelse(probs > 0.5 & runif(100) < 0.4, "A", "B"))

probs는 분류 알고리즘이 예측한 점수이고, labels는 정답에 해당하는 분류(true class)가 저장된 벡터이다.
labels내의 ifelse는 약간의 분류 실패를 시뮬레이션 해 본 것이다.

ROCR을 사용하기 위해 prediction객체를 만든다.

library(ROCR)
## Warning: package 'ROCR' was built under R version 3.0.3
## Loading required package: gplots
## Warning: package 'gplots' was built under R version 3.0.3
## 
## Attaching package: 'gplots'
## 
## The following object is masked from 'package:stats':
## 
##     lowess
pred <- prediction(probs, labels)

prediction 객체를 performance() 함수에 넘겨 분류 알고리즘의 다양한 성능을 얻을 수 있다.

plot(performance(prediction(probs, labels), "tpr", "fpr"))

plot of chunk unnamed-chunk-37

plot(performance(prediction(probs, labels), "acc", "cutoff"))

plot of chunk unnamed-chunk-37

performance(prediction.obj, measure, x.measure="cutoff”, …)의 형태를 가진다.
tpr(True Positive Rate, 정확하게 Positive로 분류함), fpr(False Positive Rate, Positive라고 예측한 것이 잘못됨)을 사용해 ROC커브를 그려보았다.
마찬가지로 acc(Accuracy, 정분류), cutoff를 인자로 지장하면, cutoff값에 따른 Accuracy의 변화를 볼 수 있다. cutoff는 민감도와 특이도를 원하는 수준으로 맞추기 위한 분류의 임계값? 기준값? 이다.

3.3 교차 검증(cross validation)

iris 데이터에 대해 10-Fold Cross Validation을 3회 반복 수행하기위해 cvFolds()를 사용한 예이다.

library(cvTools)
## Warning: package 'cvTools' was built under R version 3.0.3
## Loading required package: robustbase
## Warning: package 'robustbase' was built under R version 3.0.3
## 
## Attaching package: 'robustbase'
## 
## The following object is masked from 'package:survival':
## 
##     heart

set.seed(719)
(cv <- cvFolds(NROW(iris), K = 10, R = 3))
## 
## Repeated 10-fold CV with 3 replications:    
## Fold      1   2   3
##   1      92   4  86
##   2      52   3 144
##   3      17  75   5
##   4      13  98  61
##   5      61  30  16
##   6       9 129 148
##   7       8 121  49
##   8      31  89  37
##   9     136 140  82
##   10     37 102 141
##   1      90  51  78
##   2     119  80 132
##   3     116  70  97
##   4      11  29  93
##   5      39  72  45
##   6      94 125 114
##   7      68  84  25
##   8      75 123  69
##   9     131  73  87
##   10    100 132  63
##   1      14 113  54
##   2      93  65  10
##   3      64  87  74
##   4      23  66  31
##   5     102   8  18
##   6      51 101  23
##   7      77  79 107
##   8     122 142   9
##   9     143 108 128
##   10    108  18 117
##   1     105  60 142
##   2      49  12  21
##   3     132 135  48
##   4      33  48  76
##   5      46 109 110
##   6      89  24 119
##   7      84  63  17
##   8       7  16 130
##   9      41   2 116
##   10      6  59  41
##   1      85  47 104
##   2      19  55  39
##   3      47 117  35
##   4      87 126 112
##   5      25  42  51
##   6     110  82  66
##   7      62  88  89
##   8      53  53 140
##   9     120  91  36
##   10    141  97   3
##   1      95   7  95
##   2      99   1  96
##   3     150  56  72
##   4      71   6  90
##   5      32 116 125
##   6      34  22 121
##   7      81 115 113
##   8      83  14 135
##   9      21  44   4
##   10    115  83  98
##   1     128  37 147
##   2      60  45  47
##   3      82  71  11
##   4      74  68 106
##   5      18 127   2
##   6      22 145 120
##   7      80  26  32
##   8      50  99  62
##   9      59  25  80
##   10     27  32 126
##   1      56  92 149
##   2      40 136  13
##   3     134  61  64
##   4      12  39  53
##   5     114 122  55
##   6     118  90  83
##   7      26  31  75
##   8       3 139 127
##   9     126  46  44
##   10    127  50  52
##   1     121  93  85
##   2      57  40 145
##   3      48  94  26
##   4     113  52  65
##   5     106  41 131
##   6     147  23 150
##   7      78 148  20
##   8       2  95  88
##   9      55  36  79
##   10     69  49  91
##   1     112  11  12
##   2      58  28  71
##   3     145  78  22
##   4     142  76 109
##   5      63  21  43
##   6     140 144 101
##   7      91 131 137
##   8      24 150 100
##   9     109 114  56
##   10     42 112  19
##   1     144 111  29
##   2     123 110  14
##   3      67 138  68
##   4      16  54 115
##   5      79 146 108
##   6      38  33  15
##   7      45  77  99
##   8      86   5  73
##   9      54  58  67
##   10    107  81  24
##   1       5 119  92
##   2     139 141  38
##   3      96 128  30
##   4       1  10 146
##   5     137  17  27
##   6      70   9  94
##   7     101  69 123
##   8     103  96  60
##   9      10 143  59
##   10    129 149   6
##   1      65  62 105
##   2      30  38  77
##   3     138  67  57
##   4      29  13  28
##   5      15 118 136
##   6     125  85  42
##   7      88 103  46
##   8     130 124 133
##   9      97  57   1
##   10     98  15 124
##   1     146 104  70
##   2      28  19 118
##   3      36 134   8
##   4      76 100  34
##   5     149 105 111
##   6     117  35 143
##   7     135  43 102
##   8     104  86 122
##   9      20 130 103
##   10     43  34  40
##   1       4 147 129
##   2     148 120  33
##   3      44 107  81
##   4     124 106  58
##   5      72  20  84
##   6      35 137 139
##   7      73 133  50
##   8     133  64   7
##   9     111  74 134
##   10     66  27 138

cvFolds() 실행전에 호출한 set.seed는 난수를 생성하는 초기값(seed)를 지정하기 위해 사용하였다.
cvFolds()는 난수를 사용하여 데이터를 분리한다. set.seed를 지정해주면 매번 같은 folds를 결과로 내놓는다.

결과를 보면, 각 행은 Folds를 의미하고 각 열은 매 반복을 의미한다. 1행 1열의 92는, 첫번째 반복에서 첫번째 Fold의 Validation Data에 iris의 92번째 데이터를 사용하라는 의미이다. 마찬가지로 Fold가 1인 또 다른 행을 찾아보면 90, 51, 78이 있다.

위 표의 Fold에 해당하는 부분은 cv$which에, 실제 선택할 행을 저장한 부분은 cv$subset에 저장된다.

head(cv$which, 20)
##  [1]  1  2  3  4  5  6  7  8  9 10  1  2  3  4  5  6  7  8  9 10
head(cv$subset)
##      [,1] [,2] [,3]
## [1,]   92    4   86
## [2,]   52    3  144
## [3,]   17   75    5
## [4,]   13   98   61
## [5,]   61   30   16
## [6,]    9  129  148

따라서 첫번째 반복의 첫번째 fold에서 Valiation Date로 사용해야할 행의 번호는 다음과 같이 구할 수 있다.

which(cv$which == 1)
##  [1]   1  11  21  31  41  51  61  71  81  91 101 111 121 131 141
validation_idx <- cv$subsets[which(cv$which == 1), 1]
validation_idx
##  [1]  92  90  14 105  85  95 128  56 121 112 144   5  65 146   4

which() 함수는 주어진 조건을 만족하는 행의 번호를 얻기 위해 사용하였다.

이제 iris데이터에서, 첫번째 반복의 첫번째 fold에서 Training, Validation Data를 구해보자.

train <- iris[-validation_idx, ]
validation <- iris[validation_idx, ]

이를 사용해 K fold cross Validation을 반복하는 전체 코드 모습은 다음과 같다.

library(foreach)
set.seed(719)
R = 3
K = 10
cv <- cvFolds(NROW(iris), K = K, R = R)

foreach(r = 1:R) %do% {
    foreach(k = 1:K, .combine = c) %do% {
        validation_idx <- cv$subsets[which(cv$which == k), r]
        train <- iris[-validation_idx, ]
        validation <- iris[validation_idx, ]
        # preprocessing

        # training

        # prediction

        # estimation performance used runif for demonstration purpose
        return(runif(1))
    }
}
## [[1]]
##  [1] 0.3841 0.9323 0.4759 0.9419 0.2496 0.7490 0.6130 0.6606 0.4116 0.6134
## 
## [[2]]
##  [1] 0.34082 0.11616 0.08346 0.62234 0.32939 0.97048 0.09819 0.02693
##  [9] 0.70704 0.78078
## 
## [[3]]
##  [1] 0.8645 0.8138 0.1619 0.2863 0.1375 0.4782 0.1847 0.5155 0.5843 0.0190

for 대신 foreach를 사용했다. foreach()는 값을 반환할 수 있어 모델을 평가한 결과를 모으는데 유용하다.
또, foreach() 사용시 내부의 foreach()에서는 .combine에 'c'를 지정하여 결과가 리스트가 아닌 벡터로 되게하였다. 이렇게하면 리스트의 리스트가 아니라 벡터의 리스트가 되어 조작이 용이할 것이다.

좋은 모델 성능 평가가 되려면 예측하고자하는 분류(Y), 그리고 예측에 사용하는 설명 변수(X)에 대한 고려가 필요하다.

caret패키지의 여러가지 함수들은 Y값을 고려한 훈련 데이터(training data)와 검증 데이터(validation data)의 분리를 지원하며, 이들 함수를 사용해 분리한 데이터는 Y값의 비율이 원본 데이터와 같게 유지된다.

library(caret)
(parts <- createDataPartition(iris$Species, p = 0.8))
## $Resample1
##   [1]   1   2   3   5   7  10  11  12  13  14  15  17  19  20  21  22  23
##  [18]  24  26  27  28  29  30  31  33  34  35  37  38  39  40  41  42  43
##  [35]  44  45  46  47  48  49  52  53  54  55  56  57  58  60  61  62  63
##  [52]  64  65  66  69  70  71  73  74  75  76  77  78  80  82  83  84  85
##  [69]  86  87  88  89  90  91  92  93  94  95  97  98 101 103 104 106 107
##  [86] 109 110 111 114 115 116 117 119 120 121 122 123 124 125 126 127 128
## [103] 129 130 131 132 133 134 137 138 139 140 141 142 143 144 145 146 147
## [120] 150

iris 데이터의 80%를 훈련 데이터, 나머지 20%를 검증 데이터로 분리한 예이다.

table(iris[parts$Resample1, "Species"])
## 
##     setosa versicolor  virginica 
##         40         40         40

parts$Resample1에 나온 숫자가 행의 번호가 되고, 그 행의 Species를 table()로 그린 결과이다. Species마다 40개씩을 훈련 데이터로 추출했다. 나머지 데이터는 검증데이터로 Species마다 각 10개씩 할당된다.

table(iris[-parts$Resample1, "Species"])
## 
##     setosa versicolor  virginica 
##         10         10         10