랜덤포레스트 & K-fold CV
패키지 불러오기
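# echo=TRUE : include the source code in the rendered output
# eval=TRUE : evaluate (run) the code chunk
# message=FALSE : do not show messages in the output
# warning=TRUE : show warning messages in the output
# error=FALSE : stop rendering when an error occurs instead of printing the error message
# tidy=FALSE : leave the code formatting as written (no automatic tidying)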
library(readxl)
library(randomForest)
데이터 불러오기
# Read the workbook through its full path instead of calling setwd(), so the
# script does not change the session's working directory as a side effect.
data_path <- file.path("C:/Users/user/Desktop/krivet", "data_merge_3_na.xlsx")
edu <- read_excel(data_path)
# Response variable: satisfaction55; the remaining 21 columns are candidate
# explanatory variables.
str(edu)
## Classes 'tbl_df', 'tbl' and 'data.frame': 312 obs. of 22 variables:
## $ satisfaction55: num 4 4 4 5 4 4 4 4 2 3 ...
## $ facilities79_1: num 4 4 4 4 3 1 4 3 1 3 ...
## $ facilities80_2: num 4 4 4 2 3 2 3 4 1 3 ...
## $ facilities81_3: num 4 4 5 3 3 3 3 4 1 3 ...
## $ facilities82_4: num 4 4 5 3 3 2 3 4 1 3 ...
## $ ralation84_1 : num 4 4 4 3 3 1 3 4 1 3 ...
## $ pride85_1 : num 4 4 4 3 4 3 4 4 1 3 ...
## $ pride86_2 : num 4 4 5 3 4 3 4 4 1 3 ...
## $ schoollife87_1: num 4 4 4 3 4 4 2 4 1 3 ...
## $ schoollife88_2: num 4 4 4 3 3 1 4 4 1 3 ...
## $ schoollife89_3: num 3 4 4 3 3 2 4 4 1 3 ...
## $ schoollife90_4: num 4 4 3 3 4 3 4 4 1 3 ...
## $ schoollife91_5: num 4 4 4 3 4 4 4 4 1 3 ...
## $ schoollife92_6: num 4 4 4 3 4 4 4 4 3 4 ...
## $ schoollife93_7: num 4 4 4 1 3 3 3 4 1 3 ...
## $ class94_1 : num 4 4 4 3 4 4 4 4 3 3 ...
## $ class95_2 : num 4 4 4 3 4 4 4 4 2 3 ...
## $ class96_3 : num 4 4 4 3 4 3 4 4 2 3 ...
## $ class97_4 : num 4 4 4 3 4 3 4 4 2 4 ...
## $ class98_5 : num 4 4 4 3 4 5 4 4 3 4 ...
## $ class99_6 : num 4 4 4 3 4 5 3 4 3 4 ...
## $ prof102_1 : num 1 1 2 2 1 1 1 1 1 1 ...
# Per-column summary statistics (min, quartiles, median, mean, max).
summary(edu)
## satisfaction55 facilities79_1 facilities80_2 facilities81_3
## Min. :1.000 Min. :1.000 Min. :1.000 Min. :1.000
## 1st Qu.:3.000 1st Qu.:3.000 1st Qu.:3.000 1st Qu.:3.000
## Median :4.000 Median :4.000 Median :4.000 Median :4.000
## Mean :3.872 Mean :3.583 Mean :3.506 Mean :3.609
## 3rd Qu.:4.000 3rd Qu.:4.000 3rd Qu.:4.000 3rd Qu.:4.000
## Max. :5.000 Max. :5.000 Max. :5.000 Max. :5.000
## facilities82_4 ralation84_1 pride85_1 pride86_2
## Min. :1.000 Min. :1.000 Min. :1.000 Min. :1.000
## 1st Qu.:3.000 1st Qu.:3.000 1st Qu.:3.000 1st Qu.:3.000
## Median :4.000 Median :4.000 Median :3.000 Median :4.000
## Mean :3.622 Mean :3.712 Mean :3.465 Mean :3.679
## 3rd Qu.:4.000 3rd Qu.:4.000 3rd Qu.:4.000 3rd Qu.:4.000
## Max. :5.000 Max. :5.000 Max. :5.000 Max. :5.000
## schoollife87_1 schoollife88_2 schoollife89_3 schoollife90_4
## Min. :1.000 Min. :1.000 Min. :1.000 Min. :1.000
## 1st Qu.:3.000 1st Qu.:3.000 1st Qu.:3.000 1st Qu.:3.000
## Median :3.000 Median :3.000 Median :3.000 Median :4.000
## Mean :3.359 Mean :3.442 Mean :3.279 Mean :3.583
## 3rd Qu.:4.000 3rd Qu.:4.000 3rd Qu.:4.000 3rd Qu.:4.000
## Max. :5.000 Max. :5.000 Max. :5.000 Max. :5.000
## schoollife91_5 schoollife92_6 schoollife93_7 class94_1
## Min. :1.000 Min. :1.000 Min. :1.000 Min. :1.000
## 1st Qu.:3.000 1st Qu.:3.000 1st Qu.:2.000 1st Qu.:3.000
## Median :3.000 Median :4.000 Median :3.000 Median :4.000
## Mean :3.388 Mean :3.619 Mean :2.955 Mean :3.766
## 3rd Qu.:4.000 3rd Qu.:4.000 3rd Qu.:4.000 3rd Qu.:4.000
## Max. :5.000 Max. :5.000 Max. :5.000 Max. :5.000
## class95_2 class96_3 class97_4 class98_5
## Min. :1.000 Min. :1.000 Min. :1.000 Min. :1.000
## 1st Qu.:3.000 1st Qu.:3.000 1st Qu.:3.000 1st Qu.:3.000
## Median :4.000 Median :4.000 Median :4.000 Median :4.000
## Mean :3.756 Mean :3.846 Mean :3.798 Mean :3.715
## 3rd Qu.:4.000 3rd Qu.:4.000 3rd Qu.:4.000 3rd Qu.:4.000
## Max. :5.000 Max. :5.000 Max. :5.000 Max. :5.000
## class99_6 prof102_1
## Min. :1.000 Min. :1.000
## 1st Qu.:3.000 1st Qu.:1.000
## Median :4.000 Median :2.000
## Mean :3.827 Mean :1.721
## 3rd Qu.:4.000 3rd Qu.:2.000
## Max. :5.000 Max. :5.000
# Recode the numeric satisfaction score as a factor so that it is handled as a
# class label rather than as a number.
edu[["satisfaction55"]] <- factor(edu[["satisfaction55"]])
데이터 분할
# Hold out 30% of the rows for testing. The seed is fixed before sampling so
# the train/test partition is the same on every run.
set.seed(1)
# floor() makes the truncation of the fractional 70% sample size explicit.
t_index <- sample(
  seq_len(nrow(edu)),
  size = floor(nrow(edu) * 0.7),
  replace = FALSE
)
train <- edu[t_index, ]
test <- edu[-t_index, ]
nrow(train); nrow(test)
## [1] 218
## [1] 94
분석 및 예측
# Fix the RNG seed so the forest is repeatable for a given training set.
set.seed(1)
# Classification forest on 14 survey items; the proximity matrix and the
# variable-importance measures are stored alongside the fit.
edu_randomForest <- randomForest(
  factor(satisfaction55) ~
    facilities80_2 + facilities81_3 + facilities82_4 +
    pride85_1 + pride86_2 +
    schoollife87_1 + schoollife88_2 + schoollife89_3 +
    class94_1 + class95_2 + class96_3 + class97_4 + class98_5 + class99_6,
  data = train,
  proximity = TRUE,
  importance = TRUE
)
# Error rates plotted against the number of trees grown.
plot(edu_randomForest)

# NOTE: %IncMSE and IncNodePurity are the importance measures of a regression
# forest; this is a classification forest, so the overall measures reported are:
# MeanDecreaseAccuracy: how much the variable contributes to prediction accuracy
# MeanDecreaseGini: how much the variable contributes to reducing node impurity
importance(edu_randomForest)
## 1 2 3 4 5
## facilities80_2 0 -0.9464267 0.1517930 8.6273066 11.7048013
## facilities81_3 0 -3.3046838 2.9852777 0.9437173 0.4121795
## facilities82_4 0 -0.4248770 0.5661825 5.0905662 6.4906925
## pride85_1 0 -1.3778181 3.4785249 4.2502497 6.1264505
## pride86_2 0 -3.8168220 3.1070652 7.6336569 12.6631886
## schoollife87_1 0 2.3992904 1.1511860 8.0344269 4.0874725
## schoollife88_2 0 4.3266546 0.5859650 1.5150991 5.1889757
## schoollife89_3 0 -1.4814880 7.7631961 5.5224931 -0.4930810
## class94_1 0 1.8467407 -2.4637416 7.0674265 11.2717868
## class95_2 0 3.6840805 -2.3225191 2.7444230 10.2886263
## class96_3 0 4.6506770 -9.3187080 13.4996125 3.4900580
## class97_4 0 1.3720509 -4.9852370 5.8693690 7.0003686
## class98_5 0 1.8145727 2.7735706 12.5986453 -2.2467331
## class99_6 0 3.2241147 7.2696713 5.7779856 7.8288561
## MeanDecreaseAccuracy MeanDecreaseGini
## facilities80_2 12.376340 10.278381
## facilities81_3 2.248871 7.282547
## facilities82_4 7.294985 8.326825
## pride85_1 8.417995 8.912493
## pride86_2 13.940546 10.599365
## schoollife87_1 8.870375 9.281399
## schoollife88_2 4.220253 8.103095
## schoollife89_3 8.307398 8.523642
## class94_1 10.621491 7.107757
## class95_2 7.106756 5.943027
## class96_3 9.110116 6.473564
## class97_4 6.059553 6.239225
## class98_5 11.626670 7.426254
## class99_6 12.866490 8.639981
# Dot chart of the variable-importance measures of the fitted forest.
varImpPlot(edu_randomForest)

# Predicted class for every row of the held-out test set.
test_pred <- predict(edu_randomForest, newdata = test)
예측 결과
# Accuracy check: confusion matrix of observed vs. predicted classes.
# The predictions are re-levelled to the full set of observed class levels so
# the matrix is always square with rows and columns in the same order, even if
# the forest was trained without one of the classes.
table <- table(
  real = test$satisfaction55,
  predict = factor(test_pred, levels = levels(test$satisfaction55))
)
table
## predict
## real 1 2 3 4 5
## 1 0 0 0 0 0
## 2 0 0 2 1 0
## 3 0 0 5 10 0
## 4 0 1 11 45 5
## 5 0 0 1 9 4
# Alternative using caret:
# library(caret)
# confusionMatrix(test_pred, test$satisfaction55)
# Accuracy: share of test rows whose predicted class equals the observed class.
# Comparing the labels directly does not depend on the confusion matrix having
# exactly five rows and columns in matching order.
mean(as.character(test_pred) == as.character(test$satisfaction55))
## [1] 0.5744681
randomforest & 10-fold CV
# Assign the rows to 10 folds for cross-validation.
# The seed is fixed so the shuffle, and therefore the folds, are repeatable.
set.seed(1)
t_index <- sample(seq_len(nrow(edu)), size = nrow(edu))
# rep_len() cycles the fold labels 1..10 over all rows explicitly. Passing a
# bare 1:10 recycles the same way but raises the warning "data length is not a
# multiple of split variable" whenever nrow(edu) is not divisible by 10.
split_index <- split(t_index, rep_len(1:10, length(t_index)))
# split() returns a list; each element holds the row indices of one fold.
class(split_index)
## [1] "list"
# Row indices assigned to the first fold.
split_index[[1]]
## [1] 172 199 159 50 312 311 54 162 307 132 272 129 76 126 86 268 198
## [18] 40 60 295 88 204 103 80 183 61 211 196 3 39 301 15
# Run the k-fold cross-validation: fit on all folds but one, score on the
# held-out fold, and record the accuracy.
# One slot per fold, preallocated instead of growing the vector in the loop.
accuracy_3 <- numeric(length(split_index))
for (i in seq_along(split_index)) {
  # Fold i is held out for testing; the remaining folds form the training set.
  test <- edu[split_index[[i]], ]
  train <- edu[-split_index[[i]], ]
  set.seed(1000)
  # factor() in the formula drops classes that are absent from this training
  # fold, so the forest may know fewer classes than the test fold contains.
  edu_randomForest <- randomForest(
    factor(satisfaction55) ~
      facilities80_2 + facilities81_3 + facilities82_4 +
      pride85_1 + pride86_2 +
      schoollife87_1 + schoollife88_2 + schoollife89_3 +
      class94_1 + class95_2 + class96_3 + class97_4 + class98_5 + class99_6,
    data = train,
    proximity = TRUE,
    importance = TRUE
  )
  plot(edu_randomForest)
  varImpPlot(edu_randomForest)
  test_pred <- predict(edu_randomForest, test)
  # Re-level the predictions to the observed class levels so the confusion
  # matrix is square and aligned; otherwise diag() would pair mismatched
  # classes whenever a class is missing from the training fold.
  table <- table(
    real = test$satisfaction55,
    predict = factor(test_pred, levels = levels(test$satisfaction55))
  )
  # Accuracy = correctly classified rows / all rows in the held-out fold.
  accuracy_3[i] <- sum(diag(table)) / sum(table)
}




















# Accuracy obtained on each held-out fold.
accuracy_3
## [1] 0.6250000 0.5000000 0.1935484 0.6451613 0.4516129 0.4838710 0.6129032
## [8] 0.5161290 0.5161290 0.7096774
mean(accuracy_3) # mean accuracy over the 10 folds
## [1] 0.5254032