data(iris)
d <- subset(iris, Species == "setosa" | Species == "versicolor")
str(d)
## 'data.frame': 100 obs. of 5 variables:
## $ Sepal.Length: num 5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
## $ Sepal.Width : num 3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...
## $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
## $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
## $ Species : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...
아직 Factor에는 3수준의 요인이 남아있으므로 2요인으로 바꾸어야 한다.
d$Species <- factor(d$Species)
str(d)
## 'data.frame': 100 obs. of 5 variables:
## $ Sepal.Length: num 5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
## $ Sepal.Width : num 3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...
## $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
## $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
## $ Species : Factor w/ 2 levels "setosa","versicolor": 1 1 1 1 1 1 1 1 1 1 ...
모델은 glm()함수에서 family=binomial을 지정해 회귀분석하듯이 만들면 된다.
(m <- glm(Species ~ ., data = d, family = binomial))
## Warning: glm.fit: algorithm did not converge
## Warning: glm.fit: fitted probabilities numerically 0 or 1 occurred
##
## Call: glm(formula = Species ~ ., family = binomial, data = d)
##
## Coefficients:
## (Intercept) Sepal.Length Sepal.Width Petal.Length Petal.Width
## 6.56 -9.88 -7.42 19.05 25.03
##
## Degrees of Freedom: 99 Total (i.e. Null); 95 Residual
## Null Deviance: 139
## Residual Deviance: 1.32e-09 AIC: 10
로지스틱 회귀 모형은 0 또는 1의 값을 예측해서, 어느 그룹에 속할 확률이 큰지를 나타내는 모델이므로, 적합값 fitted()를 보면 setosa에 해당하는 적합값은 0, versicolor는 1로 예측된다.
fitted(m)[c(1:5, 51:55)]
## 1 2 3 4 5 51 52
## 2.220e-16 2.220e-16 2.220e-16 5.152e-13 2.220e-16 1.000e+00 1.000e+00
## 53 54 55
## 1.000e+00 1.000e+00 1.000e+00
예측이 얼마나 잘 되었는지 알아보자.
f <- fitted(m)
as.numeric(d$Species)
## [1] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
## [36] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
## [71] 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
as.numeric()은 요인을 숫자로 저장한 벡터를 변환한다. 근데 1부터 값을 부여하므로, 1을 빼주어야 로지스틱 회귀분석처럼 0 또는 1의 값을 갖게된다.
ifelse(f > 0.5, 1, 0) == as.numeric(d$Species) - 1
## 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
## TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE
## 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
## TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE
## 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
## TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE
## 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
## TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE
## 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75
## TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE
## 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
## TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE
## 91 92 93 94 95 96 97 98 99 100
## TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE
TRUE 갯수를 코드를 사용해 세보고 예측의 정확도를 알아보자.
is_correct <- (ifelse(f > 0.5, 1, 0) == as.numeric(d$Species) - 1)
sum(is_correct)
## [1] 100
sum(is_correct)/NROW(is_correct)
## [1] 1
sum()함수는 TRUE를 1, FALSE를 0으로 취급하고, NROW()는 데이터 갯수를 반환한다.
새로운 데이터에 대한 예측은 predict() 함수를 사용한다. type을 response로 지정하고 예측을 수행하면 0에서 1사이의 결과값을 구해준다.
predict(m, newdata = d[c(1, 10, 55), ], type = "response")
## 1 10 55
## 2.22e-16 2.22e-16 1.00e+00
library(nnet)
## Warning: package 'nnet' was built under R version 3.0.3
(m <- multinom(Species ~ ., data = iris))
## # weights: 18 (10 variable)
## initial value 164.791843
## iter 10 value 16.177348
## iter 20 value 7.111438
## iter 30 value 6.182999
## iter 40 value 5.984028
## iter 50 value 5.961278
## iter 60 value 5.954900
## iter 70 value 5.951851
## iter 80 value 5.950343
## iter 90 value 5.949904
## iter 100 value 5.949867
## final value 5.949867
## stopped after 100 iterations
## Call:
## multinom(formula = Species ~ ., data = iris)
##
## Coefficients:
## (Intercept) Sepal.Length Sepal.Width Petal.Length Petal.Width
## versicolor 18.69 -5.458 -8.707 14.24 -3.098
## virginica -23.84 -7.924 -15.371 23.66 15.135
##
## Residual Deviance: 11.9
## AIC: 31.9
작성한 모델이 주어진 훈련데이터를 어떻게 분류하고 있는지는 fitted()를 사용해 구할 수 있다.
head(fitted(m))
## setosa versicolor virginica
## 1 1 1.526e-09 2.716e-36
## 2 1 3.536e-07 2.884e-32
## 3 1 4.444e-08 6.103e-34
## 4 1 3.164e-06 7.117e-31
## 5 1 1.103e-09 1.290e-36
## 6 1 3.522e-10 1.345e-35
fitted()의 결과는 각 행의 데이터가 각 분류에 속할 확률을 뜻한다.
어떤 분류로 예측되었는지를 알아보기 위해 각 행에서 가장 큰 값이 속하는 열을 뽑을 수도 있겠지만, 더 간단하게 predict()를 사용해도 된다. 특히, predict()에는 newdata에 새로운 데이터를 지정할 수 있다.
apply(fitted(m), 1, max)
## 1 2 3 4 5 6 7 8 9 10
## 1.0000 1.0000 1.0000 1.0000 1.0000 1.0000 1.0000 1.0000 1.0000 1.0000
## 11 12 13 14 15 16 17 18 19 20
## 1.0000 1.0000 1.0000 1.0000 1.0000 1.0000 1.0000 1.0000 1.0000 1.0000
## 21 22 23 24 25 26 27 28 29 30
## 1.0000 1.0000 1.0000 1.0000 1.0000 1.0000 1.0000 1.0000 1.0000 1.0000
## 31 32 33 34 35 36 37 38 39 40
## 1.0000 1.0000 1.0000 1.0000 1.0000 1.0000 1.0000 1.0000 1.0000 1.0000
## 41 42 43 44 45 46 47 48 49 50
## 1.0000 0.9998 1.0000 1.0000 1.0000 1.0000 1.0000 1.0000 1.0000 1.0000
## 51 52 53 54 55 56 57 58 59 60
## 1.0000 1.0000 0.9988 1.0000 0.9986 0.9999 0.9987 1.0000 1.0000 1.0000
## 61 62 63 64 65 66 67 68 69 70
## 1.0000 1.0000 1.0000 0.9992 1.0000 1.0000 0.9986 1.0000 0.9401 1.0000
## 71 72 73 74 75 76 77 78 79 80
## 0.5945 1.0000 0.7743 1.0000 1.0000 1.0000 0.9993 0.7236 0.9990 1.0000
## 81 82 83 84 85 86 87 88 89 90
## 1.0000 1.0000 1.0000 0.8676 0.9978 0.9998 0.9997 0.9997 1.0000 1.0000
## 91 92 93 94 95 96 97 98 99 100
## 1.0000 0.9998 1.0000 1.0000 1.0000 1.0000 1.0000 1.0000 0.9998 1.0000
## 101 102 103 104 105 106 107 108 109 110
## 1.0000 0.9996 1.0000 0.9997 1.0000 1.0000 0.8908 1.0000 1.0000 1.0000
## 111 112 113 114 115 116 117 118 119 120
## 0.9901 0.9997 1.0000 1.0000 1.0000 1.0000 0.9977 1.0000 1.0000 0.9204
## 121 122 123 124 125 126 127 128 129 130
## 1.0000 0.9995 1.0000 0.9481 1.0000 0.9996 0.8239 0.8019 1.0000 0.9711
## 131 132 133 134 135 136 137 138 139 140
## 1.0000 0.9999 1.0000 0.7939 0.9665 1.0000 1.0000 0.9965 0.6689 0.9999
## 141 142 143 144 145 146 147 148 149 150
## 1.0000 0.9999 0.9996 1.0000 1.0000 1.0000 0.9991 0.9990 1.0000 0.9776
a <- apply(fitted(m), 1, max)
ifelse(a == 1, "setosa", ifelse(a == 2, "versicolor", "virginica"))
## 1 2 3 4 5 6
## "virginica" "virginica" "virginica" "virginica" "virginica" "virginica"
## 7 8 9 10 11 12
## "virginica" "virginica" "virginica" "virginica" "virginica" "virginica"
## 13 14 15 16 17 18
## "virginica" "virginica" "virginica" "virginica" "virginica" "virginica"
## 19 20 21 22 23 24
## "virginica" "virginica" "virginica" "virginica" "virginica" "virginica"
## 25 26 27 28 29 30
## "virginica" "virginica" "virginica" "virginica" "virginica" "virginica"
## 31 32 33 34 35 36
## "virginica" "virginica" "virginica" "virginica" "virginica" "virginica"
## 37 38 39 40 41 42
## "virginica" "virginica" "virginica" "virginica" "virginica" "virginica"
## 43 44 45 46 47 48
## "virginica" "virginica" "virginica" "virginica" "virginica" "virginica"
## 49 50 51 52 53 54
## "virginica" "virginica" "virginica" "virginica" "virginica" "virginica"
## 55 56 57 58 59 60
## "virginica" "virginica" "virginica" "virginica" "virginica" "virginica"
## 61 62 63 64 65 66
## "virginica" "virginica" "virginica" "virginica" "virginica" "virginica"
## 67 68 69 70 71 72
## "virginica" "virginica" "virginica" "virginica" "virginica" "virginica"
## 73 74 75 76 77 78
## "virginica" "virginica" "virginica" "virginica" "virginica" "virginica"
## 79 80 81 82 83 84
## "virginica" "virginica" "virginica" "virginica" "virginica" "virginica"
## 85 86 87 88 89 90
## "virginica" "virginica" "virginica" "virginica" "virginica" "virginica"
## 91 92 93 94 95 96
## "virginica" "virginica" "virginica" "virginica" "virginica" "virginica"
## 97 98 99 100 101 102
## "virginica" "virginica" "virginica" "virginica" "virginica" "virginica"
## 103 104 105 106 107 108
## "virginica" "virginica" "virginica" "virginica" "virginica" "virginica"
## 109 110 111 112 113 114
## "virginica" "virginica" "virginica" "virginica" "virginica" "virginica"
## 115 116 117 118 119 120
## "virginica" "virginica" "virginica" "virginica" "virginica" "virginica"
## 121 122 123 124 125 126
## "virginica" "virginica" "virginica" "virginica" "virginica" "virginica"
## 127 128 129 130 131 132
## "virginica" "virginica" "virginica" "virginica" "virginica" "virginica"
## 133 134 135 136 137 138
## "virginica" "virginica" "virginica" "virginica" "virginica" "virginica"
## 139 140 141 142 143 144
## "virginica" "virginica" "virginica" "virginica" "virginica" "virginica"
## 145 146 147 148 149 150
## "virginica" "virginica" "virginica" "virginica" "virginica" "virginica"
predict(m)
## [1] setosa setosa setosa setosa setosa setosa
## [7] setosa setosa setosa setosa setosa setosa
## [13] setosa setosa setosa setosa setosa setosa
## [19] setosa setosa setosa setosa setosa setosa
## [25] setosa setosa setosa setosa setosa setosa
## [31] setosa setosa setosa setosa setosa setosa
## [37] setosa setosa setosa setosa setosa setosa
## [43] setosa setosa setosa setosa setosa setosa
## [49] setosa setosa versicolor versicolor versicolor versicolor
## [55] versicolor versicolor versicolor versicolor versicolor versicolor
## [61] versicolor versicolor versicolor versicolor versicolor versicolor
## [67] versicolor versicolor versicolor versicolor versicolor versicolor
## [73] versicolor versicolor versicolor versicolor versicolor versicolor
## [79] versicolor versicolor versicolor versicolor versicolor virginica
## [85] versicolor versicolor versicolor versicolor versicolor versicolor
## [91] versicolor versicolor versicolor versicolor versicolor versicolor
## [97] versicolor versicolor versicolor versicolor virginica virginica
## [103] virginica virginica virginica virginica virginica virginica
## [109] virginica virginica virginica virginica virginica virginica
## [115] virginica virginica virginica virginica virginica virginica
## [121] virginica virginica virginica virginica virginica virginica
## [127] virginica virginica virginica virginica virginica virginica
## [133] virginica versicolor virginica virginica virginica virginica
## [139] virginica virginica virginica virginica virginica virginica
## [145] virginica virginica virginica virginica virginica virginica
## Levels: setosa versicolor virginica
predict(m, newdata = iris[c(1, 51, 101), ], type = "class")
## [1] setosa versicolor virginica
## Levels: setosa versicolor virginica
분류를 얻을때는 type=“class"를 지정해야하지만, 기본 값이 class이므로 생략해도 된다.
각 분류에 속할 확률을 예측하고자한다면 type="probs"를 지정한다.
predict(m, newdata = iris[c(1, 51, 101), ], type = "probs")
## setosa versicolor virginica
## 1 1.000e+00 1.526e-09 2.716e-36
## 51 2.427e-07 1.000e+00 1.202e-05
## 101 9.454e-25 2.718e-10 1.000e+00
모델의 정확도는 예측된 Species와 실제 Species를 비교한다.
predicted <- predict(m, newdata = iris)
sum(predicted == iris$Species)/NROW(iris)
## [1] 0.9867
분류 대상이 2개 이상인 경우 분할표를 그린다.
xtabs(~predicted + iris$Species)
## iris$Species
## predicted setosa versicolor virginica
## setosa 50 0 0
## versicolor 0 49 1
## virginica 0 1 49
분할표는 xtabs(도수를 나타내는 칼럼 ~ 변수 + 변수 + …)의 형식이다.