randomforest 알고리즘을 이용하여 ggplot2 패키지에 있는 diamonds dataset의 diamond price 예측

  1. 목적: diamond price 예측
  2. 데이터셋 : ggplot2에 포함된 diamonds
  3. 머신러닝 알고리즘 : randomforest
  4. 검증: Metrics 패키지를 이용하여 검증
  5. randomforest의 특징: 정규화도 필요없고, 범주화도 필요없음.
library(ggplot2) # To get diamonds dataset
library(tidyverse) # To use %>% 연산자
summary(diamonds)
     carat               cut        color        clarity          depth      
 Min.   :0.2000   Fair     : 1610   D: 6775   SI1    :13065   Min.   :43.00  
 1st Qu.:0.4000   Good     : 4906   E: 9797   VS2    :12258   1st Qu.:61.00  
 Median :0.7000   Very Good:12082   F: 9542   SI2    : 9194   Median :61.80  
 Mean   :0.7979   Premium  :13791   G:11292   VS1    : 8171   Mean   :61.75  
 3rd Qu.:1.0400   Ideal    :21551   H: 8304   VVS2   : 5066   3rd Qu.:62.50  
 Max.   :5.0100                     I: 5422   VVS1   : 3655   Max.   :79.00  
                                    J: 2808   (Other): 2531                  
     table           price             x                y         
 Min.   :43.00   Min.   :  326   Min.   : 0.000   Min.   : 0.000  
 1st Qu.:56.00   1st Qu.:  950   1st Qu.: 4.710   1st Qu.: 4.720  
 Median :57.00   Median : 2401   Median : 5.700   Median : 5.710  
 Mean   :57.46   Mean   : 3933   Mean   : 5.731   Mean   : 5.735  
 3rd Qu.:59.00   3rd Qu.: 5324   3rd Qu.: 6.540   3rd Qu.: 6.540  
 Max.   :95.00   Max.   :18823   Max.   :10.740   Max.   :58.900  
                                                                  
       z         
 Min.   : 0.000  
 1st Qu.: 2.910  
 Median : 3.530  
 Mean   : 3.539  
 3rd Qu.: 4.040  
 Max.   :31.800  
                 
str(diamonds)
tibble [53,940 x 10] (S3: tbl_df/tbl/data.frame)
 $ carat  : num [1:53940] 0.23 0.21 0.23 0.29 0.31 0.24 0.24 0.26 0.22 0.23 ...
 $ cut    : Ord.factor w/ 5 levels "Fair"<"Good"<..: 5 4 2 4 2 3 3 3 1 3 ...
 $ color  : Ord.factor w/ 7 levels "D"<"E"<"F"<"G"<..: 2 2 2 6 7 7 6 5 2 5 ...
 $ clarity: Ord.factor w/ 8 levels "I1"<"SI2"<"SI1"<..: 2 3 5 4 2 6 7 3 4 5 ...
 $ depth  : num [1:53940] 61.5 59.8 56.9 62.4 63.3 62.8 62.3 61.9 65.1 59.4 ...
 $ table  : num [1:53940] 55 61 65 58 58 57 57 55 61 61 ...
 $ price  : int [1:53940] 326 326 327 334 335 336 336 337 337 338 ...
 $ x      : num [1:53940] 3.95 3.89 4.05 4.2 4.34 3.94 3.95 4.07 3.87 4 ...
 $ y      : num [1:53940] 3.98 3.84 4.07 4.23 4.35 3.96 3.98 4.11 3.78 4.05 ...
 $ z      : num [1:53940] 2.43 2.31 2.31 2.63 2.75 2.48 2.47 2.53 2.49 2.39 ...

훈련 및 테스트(prediction용) 데이터 분리

idx <- sample(x = nrow(diamonds), size = nrow(diamonds)*0.7, replace = F)
train.df <- diamonds[idx,]
test.df <- diamonds[-idx,]
xtest <- test.df[,-7]
ytest <- test.df[,7]

훈련

library(randomForest)
s_time <- Sys.time()
rf.fit <- randomForest(price ~ .,
                       data = train.df,
                       importance = T)
end_time <- Sys.time();end_time-s_time
Time difference of 10.89859 mins

예측(prediction)

pred <- predict(rf.fit, xtest)

성능 검증

library(Metrics)
# 실제값과 예측값의 절대값 차이(오차) 요약
ae(as.matrix(ytest), as.matrix(pred)) %>% summary()
     price          
 Min.   :    0.019  
 1st Qu.:   38.876  
 Median :  100.838  
 Mean   :  279.214  
 3rd Qu.:  300.378  
 Max.   :10993.738  
# 실제값과 예측값의 절대값 차이 (오차) 비율 요약
ape(as.matrix(ytest), as.matrix(pred)) %>% summary()
     price          
 Min.   :0.0000093  
 1st Qu.:0.0228207  
 Median :0.0492940  
 Mean   :0.0684743  
 3rd Qu.:0.0916389  
 Max.   :2.4151402  
# 실제값과 예측값의 절대값 차이(오차) 평균 
mae(as.matrix(ytest), as.matrix(pred))
[1] 279.2141
# 실제값과 예측값의 절대값 차이(오차) 비율 평균 
mape(as.matrix(ytest), as.matrix(pred)) 
[1] 0.06847432

결론

실제값과 예측값의 절대값 차이가 평균 약 $279 이고, 비율은 약 6.9%이다.
약 6.9%의 오차, 즉, 93.1%의 정확도를 보여준다고 하겠다.