1 Introduction

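The goal is to predict how many bikes are rented in a given hour in Washington, D.C., using data from the bike-sharing service Capital Bikeshare. Many factors likely influence whether people rent a bike: the time of day (more rentals during the day than before dawn), the weather (people are unlikely to rent a bike in the rain), whether it is a working day (fewer rentals during working hours), and so on. The task is to combine all of these factors to predict bike traffic in Washington, D.C., measured as the number of rentals per hour.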

1.1 Data description

  • datetime - Date and time, written as year-month-day hour:minute:second (e.g. 2011-01-01 00:00:00 is 0:00:00 on January 1, 2011).
  • season - Season, coded 1 = spring, 2 = summer, 3 = fall, 4 = winter.
  • holiday - Holiday. 1 if the day is a public holiday, 0 otherwise.
  • workingday - Working day. 1 if the day is a working day, 0 otherwise.
  • weather - Weather, a value from 1 to 4:
    • 1: Clear, or only a few clouds.
    • 2: Some mist and clouds.
    • 3: Light snow or rain, or thunderstorms.
    • 4: Heavy rain or hail.
  • temp - Temperature in degrees Celsius.
  • atemp - "Feels like" temperature, also in degrees Celsius.
  • humidity - Humidity.
  • windspeed - Wind speed.
  • casual - Number of rentals by non-registered users.
  • registered - Number of rentals by registered users.
  • count - Total number of rentals; equal to casual + registered.
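The description defines `count` as `casual + registered`. A minimal base-R check of that identity, run on the six rows printed in section 2.1 typed in by hand (the helper name `check_count_identity` is illustrative, not part of the original code):

```r
# Returns TRUE when every row satisfies count == casual + registered.
check_count_identity <- function(df) {
  all(df$count == df$casual + df$registered)
}

# First six rows of train.csv, copied by hand from the table in section 2.1
sample_rows <- data.frame(
  casual     = c(3, 8, 5, 3, 0, 0),
  registered = c(13, 32, 27, 10, 1, 1),
  count      = c(16, 40, 32, 13, 1, 1)
)

check_count_identity(sample_rows)
```

On the full data, `check_count_identity(data)` can be called right after `read_csv()`; it should return `TRUE` if the description holds for every row.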

2 Collect the data

2.1 Import libraries & Load data

- Libraries:

suppressMessages(library(ggplot2))
suppressMessages(library(lubridate))
suppressMessages(library(scales))
suppressMessages(library(dplyr))
suppressMessages(library(readr))
suppressMessages(library(sqldf))
suppressMessages(library(randomForest))
suppressMessages(library(kableExtra))

suppressMessages(library(ggthemes))

- Load & Overview data:

data <- read_csv("./input/train.csv")
#str(data)
Original data table (first 6 rows):
datetime             season  holiday  workingday  weather  temp  atemp   humidity  windspeed  casual  registered  count
2011-01-01 00:00:00  1       0        0           1        9.84  14.395  81        0.0000     3       13          16
2011-01-01 01:00:00  1       0        0           1        9.02  13.635  80        0.0000     8       32          40
2011-01-01 02:00:00  1       0        0           1        9.02  13.635  80        0.0000     5       27          32
2011-01-01 03:00:00  1       0        0           1        9.84  14.395  75        0.0000     3       10          13
2011-01-01 04:00:00  1       0        0           1        9.84  14.395  75        0.0000     0       1           1
2011-01-01 05:00:00  1       0        0           2        9.84  12.880  75        6.0032     0       1           1



- Missing values:

sapply(data, function(x) sum(is.na(x)))
##   datetime     season    holiday workingday    weather       temp 
##          0          0          0          0          0          0 
##      atemp   humidity  windspeed     casual registered      count 
##          0          0          0          0          0          0

3 Preprocessing

3.1 Reshape data

3.1.1 Weather, Times, Hour, Year, Month, Day, Weekday, Quarter, Year_month

data$weather <- factor(data$weather, labels = c("Good", "Normal", "Bad", "Very Bad"))

data$hour    <- factor(hour(ymd_hms(data$datetime)))
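# tz = "UTC" in both calls keeps the clock time equal to datetime. Without it,
# strftime() converts to the local time zone, which is why the table below
# shows times shifted by the machine's UTC offset (09:00:00 for hour 0).
data$times   <- as.POSIXct(strftime(ymd_hms(data$datetime), format="%H:%M:%S", tz="UTC"),
                           format="%H:%M:%S", tz="UTC")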

data$year <- year(ymd_hms(data$datetime))
data$month <- month(ymd_hms(data$datetime))
data$day <- day(ymd_hms(data$datetime))
data$weekday <- wday(ymd_hms(data$datetime), label=TRUE)
data$quarter <- quarter(ymd_hms(data$datetime))
data$year_month <- format(as.Date(data$datetime), "%Y-%m")
Modified data table (first 6 rows):
datetime             season  holiday  workingday  weather  temp  atemp   humidity  windspeed  casual  registered  count  hour  times                year  month  day  weekday  quarter  year_month
2011-01-01 00:00:00  1       0        0           Good     9.84  14.395  81        0.0000     3       13          16     0     2018-10-19 09:00:00  2011  1      1    Sat      1        2011-01
2011-01-01 01:00:00  1       0        0           Good     9.02  13.635  80        0.0000     8       32          40     1     2018-10-19 10:00:00  2011  1      1    Sat      1        2011-01
2011-01-01 02:00:00  1       0        0           Good     9.02  13.635  80        0.0000     5       27          32     2     2018-10-19 11:00:00  2011  1      1    Sat      1        2011-01
2011-01-01 03:00:00  1       0        0           Good     9.84  14.395  75        0.0000     3       10          13     3     2018-10-19 12:00:00  2011  1      1    Sat      1        2011-01
2011-01-01 04:00:00  1       0        0           Good     9.84  14.395  75        0.0000     0       1           1      4     2018-10-19 13:00:00  2011  1      1    Sat      1        2011-01
2011-01-01 05:00:00  1       0        0           Normal   9.84  12.880  75        6.0032     0       1           1      5     2018-10-19 14:00:00  2011  1      1    Sat      1        2011-01
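The offset in the `times` column of the table above comes from a conversion through the local time zone. A base-R sketch (no lubridate) that pins every conversion to UTC, so that `hour` and `times` agree with `datetime`; `derive_time_features` is an illustrative name, and weekday names depend on the session locale:

```r
# Derive calendar features from "YYYY-MM-DD HH:MM:SS" strings, treating them as UTC.
derive_time_features <- function(datetime) {
  dt <- as.POSIXct(datetime, format = "%Y-%m-%d %H:%M:%S", tz = "UTC")
  lt <- as.POSIXlt(dt)                      # inherits the UTC time zone of dt
  data.frame(
    hour       = lt$hour,
    year       = lt$year + 1900,            # POSIXlt years count from 1900
    month      = lt$mon + 1,                # POSIXlt months are 0-based
    day        = lt$mday,
    weekday    = weekdays(dt, abbreviate = TRUE),   # locale-dependent names
    quarter    = (lt$mon %/% 3) + 1,
    year_month = format(dt, "%Y-%m", tz = "UTC"),
    # Clock time attached to today's date; tz = "UTC" on both sides avoids a shift
    times      = as.POSIXct(format(dt, "%H:%M:%S", tz = "UTC"),
                            format = "%H:%M:%S", tz = "UTC"),
    stringsAsFactors = FALSE
  )
}

derive_time_features(c("2011-01-01 00:00:00", "2011-01-01 05:00:00"))
```

lubridate's `hour()`, `month()` and similar accessors return the same values when given a UTC date-time such as the output of `ymd_hms()`; the pitfall is only in the `strftime()` round trip.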

3.1.2 Season

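In the data, the four seasons are simply split into calendar quarters: January to March is labelled Spring and October to December Winter. Mislabelled data leads to wrong conclusions, so the seasons are redefined by month:

  • Spring : March to May
  • Summer : June to August
  • Fall : September to November
  • Winter : December to February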
data$season  <- factor(data$season, labels = c("Spring", "Summer", "Fall", "Winter"))
data$season_2 <- ifelse(data$month == 12 | data$month == 1 | data$month == 2, 'Winter',
                      ifelse(data$month == 3 | data$month == 4 | data$month == 5, 'Spring',
                            ifelse(data$month == 6 | data$month == 7 | data$month == 8, 'Summer', 'Fall')))
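The nested `ifelse()` above can also be written as a 12-element lookup vector indexed by month number, which keeps the whole mapping visible in one place. A base-R sketch (`season_of_month` and `month_to_season` are illustrative names):

```r
# Element m of the vector is the season of month m.
season_of_month <- c("Winter", "Winter",                 # Jan, Feb
                     "Spring", "Spring", "Spring",       # Mar, Apr, May
                     "Summer", "Summer", "Summer",       # Jun, Jul, Aug
                     "Fall",   "Fall",   "Fall",         # Sep, Oct, Nov
                     "Winter")                           # Dec

month_to_season <- function(month) season_of_month[month]

month_to_season(c(1, 3, 6, 9, 12))
```

Assuming `month` holds integers from 1 to 12, `data$season_2 <- month_to_season(data$month)` gives the same labels as the `ifelse()` chain.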

- original season

sqldf("
select
    season
    , month
    , count(*) as freq
from data
where season = 'Spring'
group by 1, 2
")
##   season month freq
## 1 Spring     1  884
## 2 Spring     2  901
## 3 Spring     3  901
table(data$season)
## 
## Spring Summer   Fall Winter 
##   2686   2733   2733   2734



- modified season

sqldf("
select
    season_2
    , month
    , count(*) as freq
from data
where season_2 = 'Spring'
group by 1, 2
")
##   season_2 month freq
## 1   Spring     3  901
## 2   Spring     4  909
## 3   Spring     5  912
table(data$season_2)
## 
##   Fall Spring Summer Winter 
##   2731   2722   2736   2697

4 EDA (Exploratory Data Analysis)

4.1 EDA 1 : Year

#table(data$year)
#table(data$year_month)

year_summary <- data %>%
                  group_by(year) %>%
                  summarise(count = mean(count))

year_summary_2 <- data %>%
                    group_by(year, hour) %>%
                    summarise(count = mean(count))

#year_summary_2 <- sqldf("
#select
#    year
#    , hour
#    , count(*) as freq
#    , avg(count) as avg
#from data
#group by 1, 2
#")

year_summary
## # A tibble: 2 x 2
##    year count
##   <dbl> <dbl>
## 1  2011  144.
## 2  2012  239.
#year_summary_2 %>% head()

year_summary$year <- factor(year_summary$year)
year_summary_2$year <-factor(year_summary_2$year)
  
ggplot(data = year_summary, aes(x = year, y = count)) +
  geom_bar(stat = "identity") +
  theme_minimal() +
  ggtitle("Bar chart of year") + 
  theme(plot.title=element_text(size=18))

ggplot(data, aes(x = hour, y = count, colour = year)) +
  geom_point(data = year_summary_2, aes(group = year)) +
  geom_line(data = year_summary_2, aes(group = year)) +
  scale_x_discrete("Hour") +
  scale_y_continuous("Mean count") +
  theme_minimal() +
  ggtitle("Line chart of year") + 
  theme(plot.title=element_text(size=18))

4.2 EDA 2 : Month

#table(data$month)
month_summary <- data %>%
                  group_by(month) %>%
                  summarise(count = mean(count))

month_summary_2 <- data %>%
                    group_by(month, hour) %>%
                    summarise(count = mean(count))
  
#month_summary
#month_summary_2 %>% head()
month_summary$month <- factor(month_summary$month)
month_summary_2$month <- factor(month_summary_2$month)

ggplot(data = month_summary, aes(x = month, y = count)) +
  geom_bar(stat = "identity") + 
  scale_x_discrete("Month") +
  scale_y_continuous("Count") +
  theme_minimal() +
  ggtitle("Bar chart of month") +
  theme(axis.text.x=element_text(angle=0, hjust=1), plot.title=element_text(size=18))

ggplot(data, aes(x = hour, y = count, colour = month)) +
  geom_point(data = month_summary_2, aes(group = month)) +
  geom_line(data = month_summary_2, aes(group = month)) +
  scale_x_discrete("Hour") +
  scale_y_continuous("Mean count") +
  theme_minimal() +
  ggtitle("Line chart of month") + 
  theme(axis.text.x=element_text(angle=0, hjust=1), plot.title=element_text(size=18))

4.3 EDA 3 : Year Month

#table(data$year_month)

year_month_summary <- data %>%
                        group_by(season_2, year_month) %>%
                        summarise(count = mean(count))

#year_month_summary

ggplot(data = year_month_summary, aes(x = year_month, y = count)) +
  geom_bar(stat = "identity") + 
  scale_x_discrete("Year Month") +
  scale_y_continuous("Count") +
  theme_minimal() +
  ggtitle("Bar chart of year month") +
  theme(axis.text.x=element_text(angle=60, hjust=1), plot.title=element_text(size=18))

4.4 EDA 4 : Hour

#table(data$hour)

wrk_hour_summary <- data %>%
                    group_by(hour) %>%
                    summarise(count = mean(count))

ggplot(wrk_hour_summary, aes(x = hour, y = count)) +
  geom_bar(stat = "identity") +
  scale_x_discrete("Hour") +
  scale_y_continuous("Count") +
  theme_minimal() +
  ggtitle("Bar chart of hour") +
  theme(axis.text.x=element_text(angle=0, hjust=1), plot.title=element_text(size=18))


4.5 EDA 5 : Working day

#table(data$workingday)

wrk_day_summary <- data %>%
                    group_by(workingday) %>%
                    summarise(count = mean(count))
        
wrk_day_summary_2 <- data %>%
                    group_by(workingday, hour) %>%
                    summarise(count = mean(count))

wrk_day_summary$workingday <- factor(wrk_day_summary$workingday)
wrk_day_summary_2$workingday <- factor(wrk_day_summary_2$workingday)

#ggplot(data = wrk_day_summary, aes(x = workingday, y = count)) +
#  geom_bar(stat = "identity") +
#  theme_minimal() +
#  ggtitle("Bar chart of working day") + 
#  theme(axis.text.x=element_text(angle=0, hjust=1), plot.title=element_text(size=18))

#ggplot(data = wrk_day_summary_2, aes(x = hour, y = count, fill = workingday)) +
  #ggtitle("Time series by working day") + xlab("Hour") + ylab("Frequency") +
#  geom_text(aes(label=count), vjust=-0.5) + theme_economist() +
#    theme(plot.title=element_text(hjust=0.5), axis.title=element_text(size=12, face="bold"), 
#        axis.text.x=element_text(size=12, angle=90), legend.position="null") + 
  #theme_minimal() +
  #geom_point(data = wrk_day_summary_2, aes(group = workingday)) +
  #geom_line(data = wrk_day_summary_2, aes(group = workingday)) +
  #theme(axis.text.x=element_text(angle=0, hjust=1), plot.title=element_text(size=18))


ggplot(wrk_day_summary_2, aes(x = hour, y = count, colour = workingday)) +
  geom_point(data = wrk_day_summary_2, aes(group = workingday)) +
  geom_line(data = wrk_day_summary_2, aes(group = workingday)) +
  scale_x_discrete("Hour") +
  scale_y_continuous("Mean count") +
  theme_minimal() +
  ggtitle("Line chart of working day") + 
  theme(plot.title=element_text(size=18))

4.6 EDA 6 : Quarter

#table(data$quarter)

quarter_summary <- data %>%
                    group_by(quarter) %>%
                    summarise(count = mean(count))

quarter_summary_2 <- data %>%
                    group_by(quarter, hour) %>%
                    summarise(count = mean(count))

ggplot(data = quarter_summary, aes(x = quarter, y = count)) +
  geom_bar(stat = "identity") +
  theme_minimal() +
  ggtitle("Bar chart of quarter") + 
  theme(axis.text.x=element_text(angle=0, hjust=1), plot.title=element_text(size=18))

ggplot(quarter_summary_2, aes(x = hour, y = count, colour = factor(quarter))) +
  geom_point(data = quarter_summary_2, aes(group = quarter)) +
  geom_line(data = quarter_summary_2, aes(group = quarter)) +
  scale_x_discrete("Hour") +
  scale_y_continuous("Mean count") +
  theme_minimal() +
  ggtitle("Line chart of quarter") + 
  theme(plot.title=element_text(size=18))

4.7 EDA 7 : Season

People rent bikes more in Summer and Fall, and much less in Winter.

#table(data$season)

season_summary <- data %>%
                    group_by(season_2) %>%
                    summarise(count = mean(count))

season_summary_1 <- data %>%
                    group_by(season, hour) %>%
                    summarise(count = mean(count))

season_summary_2 <- data %>%
                    group_by(season_2, hour) %>%
                    summarise(count = mean(count))

ggplot(data = season_summary, aes(x = season_2, y = count)) +
  geom_bar(stat = "identity") +
  theme_minimal() +
  ggtitle("Bar chart of season") + 
  theme(axis.text.x=element_text(angle=0, hjust=1), plot.title=element_text(size=18))

ggplot(data, aes(x = hour, y = count, colour = season)) +
  geom_point(data = season_summary_1, aes(group = season)) +
  geom_line(data = season_summary_1, aes(group = season)) +
  scale_x_discrete("Hour") +
  scale_y_continuous("Count") +
  theme_minimal() +
  ggtitle("Wrong line chart of season") + 
  theme(plot.title=element_text(size=18))

ggplot(data, aes(x = hour, y = count, colour = season_2)) +
  geom_point(data = season_summary_2, aes(group = season_2)) +
  geom_line(data = season_summary_2, aes(group = season_2)) +
  scale_x_discrete("Hour") +
  scale_y_continuous("Count") +
  theme_minimal() +
  ggtitle("Line chart of season")+ 
  theme(plot.title=element_text(size=18))
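The claim at the top of this section can be checked numerically by sorting the per-season means. A base-R sketch with `aggregate()`; `rank_seasons` is an illustrative helper, and the toy numbers below are made up to show the output shape, not taken from the dataset:

```r
# Mean hourly rentals per season, sorted from highest to lowest.
rank_seasons <- function(df) {
  means <- aggregate(count ~ season_2, data = df, FUN = mean)
  means[order(means$count, decreasing = TRUE), , drop = FALSE]
}

# Hypothetical toy data
toy <- data.frame(
  season_2 = c("Spring", "Spring", "Summer", "Summer", "Fall", "Fall", "Winter", "Winter"),
  count    = c(100, 140, 220, 260, 210, 230, 90, 110)
)
rank_seasons(toy)
```

On the real data, `rank_seasons(data)` uses the month-based `season_2` column created in 3.1.2.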

4.8 EDA 8 : Weather

table(data$weather)
## 
##     Good   Normal      Bad Very Bad 
##     7192     2834      859        1
weather_summary <- data %>%
                    group_by(weather) %>%
                    summarise(count = n())
weather_summary_2 <- data %>%
                    group_by(weather, hour) %>%
                    summarise(count = mean(count))

weather_summary
## # A tibble: 4 x 2
##   weather  count
##   <fct>    <int>
## 1 Good      7192
## 2 Normal    2834
## 3 Bad        859
## 4 Very Bad     1
ggplot(data = weather_summary , aes(x = weather, y = count)) +
  geom_bar(stat="identity") + 
  theme_minimal() +
  ggtitle("Bar chart of weather") + 
  theme(axis.text.x=element_text(angle=0, hjust=1), plot.title=element_text(size=18))

ggplot(data, aes(x = hour, y = count, colour = weather)) +
  geom_point(data = weather_summary_2, aes(group = weather)) +
  geom_line(data = weather_summary_2, aes(group = weather)) +
  scale_x_discrete("Hour") +
  scale_y_continuous("Count") +
  theme_minimal() +
  ggtitle("People rent bikes more when the weather is Good.\n") + 
  theme(plot.title=element_text(size=18))

4.9 EDA 9 : Good weather

weather_prob <- data %>%
                  group_by(season_2, hour) %>%
                  summarise(good = mean(weather == 'Good'),
                            normal = mean(weather == 'Normal'),
                            bad = mean(weather == 'Bad'),
                            very_bad = mean(weather == 'Very Bad'))

weather_prob %>% head()
## # A tibble: 6 x 6
## # Groups:   season_2 [1]
##   season_2 hour   good normal    bad very_bad
##   <chr>    <fct> <dbl>  <dbl>  <dbl>    <dbl>
## 1 Fall     0     0.658  0.272 0.0702        0
## 2 Fall     1     0.637  0.265 0.0973        0
## 3 Fall     2     0.681  0.265 0.0531        0
## 4 Fall     3     0.658  0.279 0.0631        0
## 5 Fall     4     0.649  0.307 0.0439        0
## 6 Fall     5     0.605  0.342 0.0526        0
ggplot(data, aes(x = hour, y = good, colour = season_2)) +
  geom_point(data = weather_prob, aes(group = season_2)) +
  geom_line(data = weather_prob, aes(group = season_2)) +
  scale_x_discrete("Hour") +
  scale_y_continuous("Prob of Good") +
  theme_minimal() +
  ggtitle("Good weather is the most likely condition in every season. \n") + 
  theme(plot.title=element_text(size=18))
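`weather_prob` relies on R coercing `TRUE`/`FALSE` to 1/0, so `mean(weather == 'Good')` is the share of hours with Good weather in each group. A minimal base-R illustration on a hand-made vector (not from the dataset):

```r
# mean() of a logical vector = proportion of TRUE values
weather <- factor(c("Good", "Good", "Normal", "Bad", "Good", "Normal", "Good", "Good"),
                  levels = c("Good", "Normal", "Bad", "Very Bad"))

# One share per weather level; every hour has exactly one level, so they sum to 1
share <- sapply(levels(weather), function(lvl) mean(weather == lvl))
share
```

Because each hour has exactly one weather category, the four shares in every row of `weather_prob` add up to 1, so the charts in 4.9 to 4.11 are parts of the same distribution.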

4.10 EDA 10 : Normal weather

ggplot(data, aes(x = hour, y = normal, colour = season_2)) +
  geom_point(data = weather_prob, aes(group = season_2)) +
  geom_line(data = weather_prob, aes(group = season_2)) +
  scale_x_discrete("Hour") +
  scale_y_continuous("Prob of Normal") +
  theme_minimal() +
  ggtitle("The probability of Normal weather is higher in Spring. \n") + 
  theme(plot.title=element_text(size=18))

4.11 EDA 11 : Bad weather

ggplot(data, aes(x = hour, y = bad, colour = season_2)) +
  geom_point(data = weather_prob, aes(group = season_2)) +
  geom_line(data = weather_prob, aes(group = season_2)) +
  scale_x_discrete("Hour") +
  scale_y_continuous("Prob of Bad") +
  theme_minimal() +
  ggtitle("The probability of Bad weather is higher in Summer and Winter. \n") + 
  theme(plot.title=element_text(size=18))

4.12 EDA 12 : Weekday

#table(data$weekday)

day_summary <- data %>%
                group_by(weekday) %>%
                summarise(count = n())

day_summary_2 <- data %>%
                group_by(weekday, hour) %>%
                summarise(count = mean(count))
day_summary
## # A tibble: 7 x 2
##   weekday count
##   <ord>   <int>
## 1 Sun      1579
## 2 Mon      1551
## 3 Tue      1539
## 4 Wed      1551
## 5 Thu      1553
## 6 Fri      1529
## 7 Sat      1584
ggplot(data = day_summary, aes(x = weekday, y = count)) +
  geom_bar(stat = "identity") +
  theme_minimal() +
  ggtitle("Bar chart of weekday")

ggplot(data, aes(x = hour, y = count, colour = weekday)) +
  geom_point(data = day_summary_2, aes(group=weekday)) +
  geom_line(data = day_summary_2, aes(group=weekday)) +
  scale_x_discrete("Hour") +
  scale_y_continuous("Count") +
  theme_minimal() +
  ggtitle("People rent bikes for morning/evening commutes on weekdays, and daytime rides on weekends\n")
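The commute pattern described in the chart title can be summarised by listing the busiest hours for each group. A base-R sketch; `top_hours` is an illustrative helper and the toy data below is synthetic, built so that working days peak at 8 and 17 and other days at midday:

```r
# For each value of the `group` column, return the k hours with the highest mean count.
top_hours <- function(df, group, k = 2) {
  means <- aggregate(df["count"], by = list(group = df[[group]], hour = df$hour), FUN = mean)
  lapply(split(means, means$group), function(g) {
    g$hour[order(g$count, decreasing = TRUE)][seq_len(k)]
  })
}

# Synthetic example (not from the dataset)
toy <- data.frame(
  workingday = rep(c(1, 0), each = 24),
  hour       = rep(0:23, times = 2),
  count      = c(ifelse(0:23 %in% c(8, 17), 500, 50),
                 ifelse(0:23 %in% c(12, 13), 400, 40))
)
top_hours(toy, "workingday")
```

The same call works on the real data, e.g. `top_hours(data, "workingday")` or `top_hours(data, "weekday")`; the hours it returns there depend on the data and are not reproduced here.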

5 Modeling

Regression Tree

Random Forest

5.0.1 Load data & Check missing values

set.seed(1)

train <- read.csv("./input/train.csv")
test <- read.csv("./input/test.csv")

colSums(is.na(train))
##   datetime     season    holiday workingday    weather       temp 
##          0          0          0          0          0          0 
##      atemp   humidity  windspeed     casual registered      count 
##          0          0          0          0          0          0
colSums(is.na(test))
##   datetime     season    holiday workingday    weather       temp 
##          0          0          0          0          0          0 
##      atemp   humidity  windspeed 
##          0          0          0

5.0.2 Preprocessing

## Reshape train and test data with the same steps
### Weather, Times, Hour, Year, Month, Weekday, Quarter, Year_month.
add_time_features <- function(df) {
  dt <- ymd_hms(df$datetime)                          # parse once (UTC)
  df$hour       <- factor(hour(dt))
  # tz = "UTC" keeps the clock time; otherwise strftime() shifts it to local time
  df$times      <- as.POSIXct(strftime(dt, format="%H:%M:%S", tz="UTC"),
                              format="%H:%M:%S", tz="UTC")
  df$weekday    <- as.numeric(wday(dt, label=TRUE))   # 1 = Sun ... 7 = Sat (default week_start)
  df$year_month <- format(dt, "%Y-%m")
  df$year       <- year(dt)
  df$month      <- month(dt)
  df$date       <- day(dt)
  df$quarter    <- quarter(dt)

  # Month-based seasons, coded like the original variable:
  # 1 = Spring (Mar-May), 2 = Summer (Jun-Aug), 3 = Fall (Sep-Nov), 4 = Winter (Dec-Feb)
  df$season <- ifelse(df$month %in% c(12, 1, 2), 4,
               ifelse(df$month %in% c(3, 4, 5), 1,
               ifelse(df$month %in% c(6, 7, 8), 2, 3)))
  df
}

train <- add_time_features(train)
test  <- add_time_features(test)

5.0.3 Feature selection

extractFeatures <- function(data) {
  features <- c("season",
                "holiday",
                "workingday",
                "weather",
                "temp",
                "atemp",
                "humidity",
                "windspeed",
                "hour",
                "weekday",
                "quarter",
                "month",
                "date"
                )
  data$hour <- hour(ymd_hms(data$datetime))
  return(data[,features])
}

trainFea <- extractFeatures(train)
testFea  <- extractFeatures(test)
#submission <- data.frame(datetime=test$datetime, count=NA)

# We only use past data to make predictions on the test set, 
# so we train a new model for each test set cutoff point
#for (i_year in unique(year(ymd_hms(test$datetime)))) {
#  for (i_month in unique(month(ymd_hms(test$datetime)))) {
#    cat("Year: ", i_year, "\tMonth: ", i_month, "\n")
#    testLocs   <- year(ymd_hms(test$datetime))==i_year & month(ymd_hms(test$datetime))==i_month
#    testSubset <- test[testLocs,]
#    trainLocs  <- ymd_hms(train$datetime) <= min(ymd_hms(testSubset$datetime))
#    rf <- randomForest(extractFeatures(train[trainLocs,]), train[trainLocs,"count"], ntree=100)
#    submission[testLocs, "count"] <- predict(rf, extractFeatures(testSubset))
#  }
#}

#write.csv(submission, file = "./1_random_forest_submission.csv", row.names=FALSE)
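The commented-out loop trains one model per test month, using only training rows whose timestamp is no later than the first test timestamp of that month. A base-R sketch of just that row selection (no model fitting), with `expanding_window_splits` as an illustrative name and synthetic timestamps:

```r
# For each year-month present in test_time, return the indices of the test rows in
# that month and of the training rows at or before the month's first test timestamp.
expanding_window_splits <- function(train_time, test_time) {
  key <- format(test_time, "%Y-%m", tz = "UTC")
  lapply(split(seq_along(test_time), key), function(test_idx) {
    cutoff <- min(test_time[test_idx])
    list(test_idx = test_idx, train_idx = which(train_time <= cutoff))
  })
}

to_utc <- function(x) as.POSIXct(x, format = "%Y-%m-%d %H:%M:%S", tz = "UTC")
train_time <- to_utc(c("2011-01-01 00:00:00", "2011-01-15 00:00:00",
                       "2011-02-01 00:00:00", "2011-02-15 00:00:00"))
test_time  <- to_utc(c("2011-01-20 00:00:00", "2011-01-25 00:00:00",
                       "2011-02-20 00:00:00"))

splits <- expanding_window_splits(train_time, test_time)
```

Inside each element, a model would be fitted on `train[train_idx, ]` and used to predict `test[test_idx, ]`, which is what the commented loop does with `randomForest()`. The training window grows month by month, so no prediction uses data from its own future.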

5.0.4 Training model

# Train a model across all the training data and plot the variable importance
rf <- randomForest(extractFeatures(train), train$count, ntree=100, importance=TRUE)
imp <- importance(rf, type=1)
featureImportance <- data.frame(Feature=row.names(imp), Importance=imp[,1])
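With `importance = TRUE`, `importance(rf, type = 1)` returns the permutation measure `%IncMSE`: how much the prediction error grows when one variable's values are shuffled. randomForest computes this on each tree's out-of-bag rows, averages over trees and, by default, divides by the standard deviation of the differences. The base-R sketch below is a simplified version of the same idea that uses a linear model on synthetic data as a stand-in; `permutation_importance` is an illustrative name, not a randomForest function:

```r
# Mean increase in MSE when each predictor column is randomly permuted.
permutation_importance <- function(model, X, y, n_rep = 10) {
  base_mse <- mean((y - predict(model, newdata = X))^2)
  sapply(names(X), function(col) {
    increases <- replicate(n_rep, {
      X_perm <- X
      X_perm[[col]] <- sample(X_perm[[col]])     # break the link between col and y
      mean((y - predict(model, newdata = X_perm))^2) - base_mse
    })
    mean(increases)
  })
}

set.seed(1)
n   <- 500
X   <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
y   <- 3 * X$x1 + rnorm(n)                       # only x1 carries signal
fit <- lm(y ~ x1 + x2, data = cbind(X, y = y))

permutation_importance(fit, X, y)
```

Unlike randomForest's version, this sketch reuses the training rows instead of out-of-bag rows and reports the raw MSE increase without scaling, so its numbers are not comparable to the `Importance` column printed below; only the ranking logic is the same.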

5.0.4.1 Measure model & Feature importance

rf
## 
## Call:
##  randomForest(x = extractFeatures(train), y = train$count, ntree = 100,      importance = TRUE) 
##                Type of random forest: regression
##                      Number of trees: 100
## No. of variables tried at each split: 4
## 
##           Mean of squared residuals: 4444.307
##                     % Var explained: 86.45
featureImportance
##               Feature Importance
## season         season  12.734589
## holiday       holiday   9.180699
## workingday workingday  42.917577
## weather       weather  21.129722
## temp             temp  21.815391
## atemp           atemp  18.102083
## humidity     humidity  29.206170
## windspeed   windspeed  20.639961
## hour             hour  96.407430
## weekday       weekday  26.494560
## quarter       quarter  12.354449
## month           month  18.851967
## date             date  26.139195
varImpPlot(rf)

p <- ggplot(featureImportance, aes(x=reorder(Feature, Importance), y=Importance)) +
     geom_bar(stat="identity", fill="#53cfff") +
     coord_flip() + 
     theme_light(base_size=20) +
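     # coord_flip() only swaps the display: xlab() still names the Feature axis
     # and ylab() the Importance axis
     xlab("") +
     ylab("Importance (%IncMSE)") + 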
     ggtitle("Random Forest Feature Importance\n") +
     theme(plot.title=element_text(size=18))

p

ggsave("2_feature_importance.png", p)
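The two summary numbers printed by `rf` are linked. randomForest reports `% Var explained` as one minus the out-of-bag mean of squared residuals divided by the variance of the response, so the printed output implies the spread of `count` (approximate, because 86.45 is rounded):

```latex
\text{\% Var explained} = 100\left(1 - \frac{\mathrm{MSE}_{\mathrm{OOB}}}{\widehat{\mathrm{Var}}(y)}\right)
\quad\Longrightarrow\quad
\widehat{\mathrm{Var}}(y) \approx \frac{4444.307}{1 - 0.8645} \approx 3.28 \times 10^{4},
\qquad
\widehat{\mathrm{SD}}(y) \approx 181,
\qquad
\mathrm{RMSE}_{\mathrm{OOB}} = \sqrt{4444.307} \approx 66.7
```

In other words, the out-of-bag prediction error of about 67 rentals per hour is a little over a third of the natural hour-to-hour spread of `count`.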

XGBoost