랜덤포레스트 & K-fold CV



패키지 불러오기
# echo=TRUE : 코드를 결과 문서에 함께 출력한다
# eval=TRUE : 코드를 실행(평가)하고 실행결과를 포함한다
# message=FALSE : 메시지를 출력하지 않는다
# warning=TRUE : 경고메시지를 출력한다
# error=FALSE : 오류가 발생하면 실행을 중단한다 (오류메시지를 문서에 포함하지 않는다)
# tidy=FALSE : 코드 형태를 재정렬하지 않고 작성한 그대로 출력한다

library(readxl)
library(randomForest)

데이터 불러오기
# Load the survey data from the project folder into `edu`.
# NOTE(review): setwd() with an absolute, machine-specific path makes the
# script non-portable; consider a project-relative path instead.
setwd(dir = "C:/Users/user/Desktop/krivet")
edu <- readxl::read_excel(path = "data_merge_3_na.xlsx")

# Response variable: satisfaction55; the other 21 columns are explanatory
# variables. str() prints the type and first values of every column.
str(edu)
## Classes 'tbl_df', 'tbl' and 'data.frame':    312 obs. of  22 variables:
##  $ satisfaction55: num  4 4 4 5 4 4 4 4 2 3 ...
##  $ facilities79_1: num  4 4 4 4 3 1 4 3 1 3 ...
##  $ facilities80_2: num  4 4 4 2 3 2 3 4 1 3 ...
##  $ facilities81_3: num  4 4 5 3 3 3 3 4 1 3 ...
##  $ facilities82_4: num  4 4 5 3 3 2 3 4 1 3 ...
##  $ ralation84_1  : num  4 4 4 3 3 1 3 4 1 3 ...
##  $ pride85_1     : num  4 4 4 3 4 3 4 4 1 3 ...
##  $ pride86_2     : num  4 4 5 3 4 3 4 4 1 3 ...
##  $ schoollife87_1: num  4 4 4 3 4 4 2 4 1 3 ...
##  $ schoollife88_2: num  4 4 4 3 3 1 4 4 1 3 ...
##  $ schoollife89_3: num  3 4 4 3 3 2 4 4 1 3 ...
##  $ schoollife90_4: num  4 4 3 3 4 3 4 4 1 3 ...
##  $ schoollife91_5: num  4 4 4 3 4 4 4 4 1 3 ...
##  $ schoollife92_6: num  4 4 4 3 4 4 4 4 3 4 ...
##  $ schoollife93_7: num  4 4 4 1 3 3 3 4 1 3 ...
##  $ class94_1     : num  4 4 4 3 4 4 4 4 3 3 ...
##  $ class95_2     : num  4 4 4 3 4 4 4 4 2 3 ...
##  $ class96_3     : num  4 4 4 3 4 3 4 4 2 3 ...
##  $ class97_4     : num  4 4 4 3 4 3 4 4 2 4 ...
##  $ class98_5     : num  4 4 4 3 4 5 4 4 3 4 ...
##  $ class99_6     : num  4 4 4 3 4 5 3 4 3 4 ...
##  $ prof102_1     : num  1 1 2 2 1 1 1 1 1 1 ...
# Per-column summary statistics (min, quartiles, mean, max) of edu.
summary(edu)
##  satisfaction55  facilities79_1  facilities80_2  facilities81_3 
##  Min.   :1.000   Min.   :1.000   Min.   :1.000   Min.   :1.000  
##  1st Qu.:3.000   1st Qu.:3.000   1st Qu.:3.000   1st Qu.:3.000  
##  Median :4.000   Median :4.000   Median :4.000   Median :4.000  
##  Mean   :3.872   Mean   :3.583   Mean   :3.506   Mean   :3.609  
##  3rd Qu.:4.000   3rd Qu.:4.000   3rd Qu.:4.000   3rd Qu.:4.000  
##  Max.   :5.000   Max.   :5.000   Max.   :5.000   Max.   :5.000  
##  facilities82_4   ralation84_1     pride85_1       pride86_2    
##  Min.   :1.000   Min.   :1.000   Min.   :1.000   Min.   :1.000  
##  1st Qu.:3.000   1st Qu.:3.000   1st Qu.:3.000   1st Qu.:3.000  
##  Median :4.000   Median :4.000   Median :3.000   Median :4.000  
##  Mean   :3.622   Mean   :3.712   Mean   :3.465   Mean   :3.679  
##  3rd Qu.:4.000   3rd Qu.:4.000   3rd Qu.:4.000   3rd Qu.:4.000  
##  Max.   :5.000   Max.   :5.000   Max.   :5.000   Max.   :5.000  
##  schoollife87_1  schoollife88_2  schoollife89_3  schoollife90_4 
##  Min.   :1.000   Min.   :1.000   Min.   :1.000   Min.   :1.000  
##  1st Qu.:3.000   1st Qu.:3.000   1st Qu.:3.000   1st Qu.:3.000  
##  Median :3.000   Median :3.000   Median :3.000   Median :4.000  
##  Mean   :3.359   Mean   :3.442   Mean   :3.279   Mean   :3.583  
##  3rd Qu.:4.000   3rd Qu.:4.000   3rd Qu.:4.000   3rd Qu.:4.000  
##  Max.   :5.000   Max.   :5.000   Max.   :5.000   Max.   :5.000  
##  schoollife91_5  schoollife92_6  schoollife93_7    class94_1    
##  Min.   :1.000   Min.   :1.000   Min.   :1.000   Min.   :1.000  
##  1st Qu.:3.000   1st Qu.:3.000   1st Qu.:2.000   1st Qu.:3.000  
##  Median :3.000   Median :4.000   Median :3.000   Median :4.000  
##  Mean   :3.388   Mean   :3.619   Mean   :2.955   Mean   :3.766  
##  3rd Qu.:4.000   3rd Qu.:4.000   3rd Qu.:4.000   3rd Qu.:4.000  
##  Max.   :5.000   Max.   :5.000   Max.   :5.000   Max.   :5.000  
##    class95_2       class96_3       class97_4       class98_5    
##  Min.   :1.000   Min.   :1.000   Min.   :1.000   Min.   :1.000  
##  1st Qu.:3.000   1st Qu.:3.000   1st Qu.:3.000   1st Qu.:3.000  
##  Median :4.000   Median :4.000   Median :4.000   Median :4.000  
##  Mean   :3.756   Mean   :3.846   Mean   :3.798   Mean   :3.715  
##  3rd Qu.:4.000   3rd Qu.:4.000   3rd Qu.:4.000   3rd Qu.:4.000  
##  Max.   :5.000   Max.   :5.000   Max.   :5.000   Max.   :5.000  
##    class99_6       prof102_1    
##  Min.   :1.000   Min.   :1.000  
##  1st Qu.:3.000   1st Qu.:1.000  
##  Median :4.000   Median :2.000  
##  Mean   :3.827   Mean   :1.721  
##  3rd Qu.:4.000   3rd Qu.:2.000  
##  Max.   :5.000   Max.   :5.000
# Treat the response as a categorical variable so that randomForest fits a
# classification forest rather than a regression forest.
edu[["satisfaction55"]] <- as.factor(edu[["satisfaction55"]])

데이터 분할
# Hold-out split: 70% of the rows go to `train`, the rest to `test`.
# The RNG is seeded *before* sampling so the split itself is reproducible
# (previously the first set.seed() call came only after the split).
set.seed(1)
n_obs <- nrow(edu)
# floor() makes the (previously implicit) truncation of 0.7 * n explicit;
# seq_len() stays correct even for a zero-row data set.
t_index <- sample(seq_len(n_obs), size = floor(n_obs * 0.7), replace = FALSE)
train <- edu[t_index, ]
test <- edu[-t_index, ]

# Sizes of the two partitions.
nrow(train); nrow(test)
## [1] 218
## [1] 94

분석 및 예측
# Seed the RNG so the fitted forest is reproducible.
set.seed(1)

# Explanatory variables entering the forest, in model order.
rf_predictors <- c(
  "facilities80_2", "facilities81_3", "facilities82_4",
  "pride85_1", "pride86_2",
  "schoollife87_1", "schoollife88_2", "schoollife89_3",
  "class94_1", "class95_2", "class96_3",
  "class97_4", "class98_5", "class99_6"
)

# factor(satisfaction55) ~ <predictors>; the response is passed as a call so
# that factor() is applied to the training rows when the model is fitted.
rf_formula <- reformulate(
  rf_predictors,
  response = quote(factor(satisfaction55))
)

# Classification forest; also stores the proximity matrix and the
# permutation-based importance measures.
edu_randomForest <- randomForest(
  rf_formula,
  data = train,
  proximity = TRUE,
  importance = TRUE
)

# Error rates as a function of the number of trees.
plot(edu_randomForest)

# Variable importance of the fitted forest.
# %IncMSE / IncNodePurity are the measure names reported for regression
# forests; this is a classification forest, so the columns are:
# MeanDecreaseAccuracy: how much a variable contributes to accuracy
# MeanDecreaseGini: how much a variable contributes to reducing node impurity
importance(edu_randomForest)    
##                1          2          3          4          5
## facilities80_2 0 -0.9464267  0.1517930  8.6273066 11.7048013
## facilities81_3 0 -3.3046838  2.9852777  0.9437173  0.4121795
## facilities82_4 0 -0.4248770  0.5661825  5.0905662  6.4906925
## pride85_1      0 -1.3778181  3.4785249  4.2502497  6.1264505
## pride86_2      0 -3.8168220  3.1070652  7.6336569 12.6631886
## schoollife87_1 0  2.3992904  1.1511860  8.0344269  4.0874725
## schoollife88_2 0  4.3266546  0.5859650  1.5150991  5.1889757
## schoollife89_3 0 -1.4814880  7.7631961  5.5224931 -0.4930810
## class94_1      0  1.8467407 -2.4637416  7.0674265 11.2717868
## class95_2      0  3.6840805 -2.3225191  2.7444230 10.2886263
## class96_3      0  4.6506770 -9.3187080 13.4996125  3.4900580
## class97_4      0  1.3720509 -4.9852370  5.8693690  7.0003686
## class98_5      0  1.8145727  2.7735706 12.5986453 -2.2467331
## class99_6      0  3.2241147  7.2696713  5.7779856  7.8288561
##                MeanDecreaseAccuracy MeanDecreaseGini
## facilities80_2            12.376340        10.278381
## facilities81_3             2.248871         7.282547
## facilities82_4             7.294985         8.326825
## pride85_1                  8.417995         8.912493
## pride86_2                 13.940546        10.599365
## schoollife87_1             8.870375         9.281399
## schoollife88_2             4.220253         8.103095
## schoollife89_3             8.307398         8.523642
## class94_1                 10.621491         7.107757
## class95_2                  7.106756         5.943027
## class96_3                  9.110116         6.473564
## class97_4                  6.059553         6.239225
## class98_5                 11.626670         7.426254
## class99_6                 12.866490         8.639981
# Dot chart of the variable-importance measures of the fitted forest.
varImpPlot(x = edu_randomForest)

# Predicted satisfaction class for every row of the hold-out set.
test_pred <- predict(object = edu_randomForest, newdata = test)

예측 결과
# Accuracy check: confusion matrix (rows = observed, columns = predicted).
# Both factors are put on one shared level set so the matrix is always square
# with matching row/column order. Without this, a class that is missing from
# the training rows (and therefore never predicted) yields a non-square table
# whose diagonal no longer lines up with the correct predictions.
class_levels <- union(levels(test$satisfaction55), levels(test_pred))
table <- table(
  real = factor(test$satisfaction55, levels = class_levels),
  predict = factor(test_pred, levels = class_levels)
)
table
##     predict
## real  1  2  3  4  5
##    1  0  0  0  0  0
##    2  0  0  2  1  0
##    3  0  0  5 10  0
##    4  0  1 11 45  5
##    5  0  0  1  9  4
# Alternative: caret::confusionMatrix() reports accuracy and per-class metrics.
# library(caret)
# confusionMatrix(test_pred, test$satisfaction55)

# Accuracy = share of test rows whose predicted class equals the observed one.
# The labels are compared directly instead of summing table[1,1]..table[5,5],
# so the result does not depend on there being exactly five classes or on the
# confusion matrix being square and aligned. na.rm = TRUE mirrors table(),
# which drops NA entries by default.
mean(
  as.character(test_pred) == as.character(test$satisfaction55),
  na.rm = TRUE
)
## [1] 0.5744681

randomforest & 10-fold CV
# 10-fold CV: assign every row index of edu to one of 10 folds.

# Seed the RNG so the fold assignment is reproducible.
set.seed(1)
# Random permutation of the row indices.
t_index <- sample(seq_len(nrow(edu)))
# rep_len() recycles the fold labels 1..10 to exactly one label per row. This
# is the same assignment split(t_index, 1:10) produced, but without the
# "data length is not a multiple of split variable" warning.
split_index <- split(t_index, rep_len(1:10, length(t_index)))
class(split_index)
## [1] "list"
# Row indices that make up the first fold.
split_index[[1]]
##  [1] 172 199 159  50 312 311  54 162 307 132 272 129  76 126  86 268 198
## [18]  40  60 295  88 204 103  80 183  61 211 196   3  39 301  15
# Run k-fold CV: each fold in split_index is used once as the test set while
# the forest is trained on all remaining rows.

# Model formula, built once because it does not change between folds.
# factor() re-levels the response on the training rows of each fold, dropping
# classes that do not occur there (randomForest rejects empty classes).
cv_formula <- factor(satisfaction55) ~
  facilities80_2 + facilities81_3 + facilities82_4 +
  pride85_1 + pride86_2 +
  schoollife87_1 + schoollife88_2 + schoollife89_3 +
  class94_1 + class95_2 + class96_3 + class97_4 + class98_5 + class99_6

# Preallocated vector that receives one accuracy value per fold.
accuracy_3 <- numeric(length(split_index))

for (i in seq_along(split_index)) {
  test <- edu[split_index[[i]], ]
  train <- edu[-split_index[[i]], ]

  # Same seed in every fold so differences come from the data, not the RNG.
  set.seed(1000)
  edu_randomForest <- randomForest(
    cv_formula,
    data = train,
    proximity = TRUE,
    importance = TRUE
  )
  plot(edu_randomForest)
  # The bare importance(edu_randomForest) call was removed: inside a for loop
  # its value is neither printed nor stored, so it had no effect.
  varImpPlot(edu_randomForest)

  test_pred <- predict(edu_randomForest, test)

  # Put observed and predicted classes on one shared level set. If a class is
  # absent from the training fold, the predictions carry fewer levels than the
  # observed values; the raw table is then not square and diag() would pick
  # cells that do not correspond to correct predictions.
  class_levels <- union(levels(test$satisfaction55), levels(test_pred))
  table <- table(
    real = factor(test$satisfaction55, levels = class_levels),
    predict = factor(test_pred, levels = class_levels)
  )

  # Accuracy: correctly classified rows (diagonal) over all rows of the fold.
  accuracy_3[i] <- sum(diag(table)) / sum(table)
}

# Accuracy of each of the 10 folds (the ## lines are output recorded from an
# earlier run).
accuracy_3
##  [1] 0.6250000 0.5000000 0.1935484 0.6451613 0.4516129 0.4838710 0.6129032
##  [8] 0.5161290 0.5161290 0.7096774
mean(accuracy_3)  # mean accuracy over the 10 folds
## [1] 0.5254032