options(rpubs.upload.method = "internal")
This is an R Markdown Notebook. When you execute code within the notebook, the results appear beneath the code. Execute chunks by clicking the Run button within the chunk or by placing your cursor inside it and pressing Ctrl+Shift+Enter.
This RMD file gives a walkthrough of my solution for Grab’s SEA Traffic Management Challenge.
Economies in Southeast Asia are turning to AI to solve traffic congestion, which hinders mobility and economic growth. The first step in the push towards alleviating traffic congestion is to understand travel demand and travel patterns within the city. In this challenge, I forecast travel demand from historical Grab bookings in order to predict the areas and times with the highest demand.
## Installing Required libraries and initializing custom functions
list.of.packages <- c("dplyr","lubridate","caret","ggplot2", "RcppRoll","xgboost","Matrix")
new.packages <- list.of.packages[!(list.of.packages %in% installed.packages()[,"Package"])]
if(length(new.packages)) install.packages(new.packages)
invisible(lapply(list.of.packages, library, character.only = TRUE))
source("geohash.R")
# source("day_max.R")
source("xgb_data_prep.R")The dataset given comprises of the following features: - geohash6: public domain geocoding system which encodes a geographic location into a short string - day: Sequential order of days - timestamp: start time of 15-minute intervals, in the following format:
traffic_data = read.csv("training.csv",stringsAsFactors = F)
str(traffic_data)## 'data.frame': 4206321 obs. of 4 variables:
## $ geohash6 : chr "qp03wc" "qp03pn" "qp09sw" "qp0991" ...
## $ day : int 18 10 9 32 15 1 25 51 48 4 ...
## $ timestamp: chr "20:0" "14:30" "6:15" "5:0" ...
## $ demand : num 0.0201 0.0247 0.1028 0.0888 0.0745 ...
any(is.na(traffic_data))## [1] FALSE
no_unique_locs = length(unique(traffic_data$geohash6))
no_days = length(unique(traffic_data$day))
expected_records = no_unique_locs*no_days*24*4 # 24 hours x 4 fifteen-minute intervals per day
actual_records = nrow(traffic_data)
missing_number = expected_records-actual_records
data.frame(no_unique_locs,no_days,expected_records,actual_records,missing_number)## no_unique_locs no_days expected_records actual_records missing_number
## 1 1329 61 7782624 4206321 3576303
The number of missing day-hour-interval combinations across locations is high. To fix this, we can create a combination of all locations, days, hours and minutes, and join it with the traffic_data set as and when required.
min_day = min(traffic_data$day)
max_day = max(traffic_data$day)
all_comb = data.frame(expand.grid(unique(traffic_data$geohash6),c(min_day:max_day),c(0:23),c(0,15,30,45),stringsAsFactors = F))
colnames(all_comb) = c("geohash6","day","hour","minute")
head(all_comb %>% arrange(geohash6,day,hour,minute),8)## geohash6 day hour minute
## 1 qp02yc 1 0 0
## 2 qp02yc 1 0 15
## 3 qp02yc 1 0 30
## 4 qp02yc 1 0 45
## 5 qp02yc 1 1 0
## 6 qp02yc 1 1 15
## 7 qp02yc 1 1 30
## 8 qp02yc 1 1 45
traffic_data= traffic_data %>% mutate(timestamp_list = hm(timestamp),hour= timestamp_list@hour,minute = timestamp_list@minute)
head(traffic_data,5)## geohash6 day timestamp demand timestamp_list hour minute
## 1 qp03wc 18 20:0 0.02007179 20H 0M 0S 20 0
## 2 qp03pn 10 14:30 0.02472097 14H 30M 0S 14 30
## 3 qp09sw 9 6:15 0.10282096 6H 15M 0S 6 15
## 4 qp0991 32 5:0 0.08875480 5H 0M 0S 5 0
## 5 qp090q 15 4:0 0.07446839 4H 0M 0S 4 0
# Remove timestamp and timestamp_list as they are redundant now
traffic_data = traffic_data %>% select(-timestamp,-timestamp_list)Now, we join all_comb with traffic_data and replace NA demands with 0
cat("\n join all combinations of time-demand pairs for each location")##
## join all combinations of time-demand pairs for each location
traffic_all = all_comb %>% left_join(traffic_data,by=c("geohash6","day","hour","minute"))
traffic_all$demand[which(is.na(traffic_all$demand))] = 0
Okay, so with our data prepared, we can now look at different features that we can engineer.
(Assumption: the sequential days start from one year before today, i.e. Sys.Date() - 365.)
unique_days = data.frame(day=unique(traffic_data$day))
#Assuming days started 1 year back from now
start_day = Sys.Date()-365
unique_days = unique_days %>% mutate(Date = start_day+day,day_label=wday(Date)) %>% select(-Date)
traffic_all = traffic_all %>% left_join(unique_days,by=c("day")) cat("\n Feature Engineering")##
## Feature Engineering
traffic_all <- traffic_all %>%
group_by(geohash6) %>% arrange(day,hour,minute) %>%
mutate(lag_1 = lag(demand, 1),lag_4 = lag(demand,4), lag_day = lag(demand,4*24), lag_week = lag(demand,4*24*7), roll_avg_week = lag(roll_meanr(demand, 7*24*4), 1), roll_avg_day = lag(roll_meanr(demand, 24*4), 1),roll_avg_hour= lag(roll_meanr(demand, 4,na.rm = T), 1)
) %>% ungroup()
cat("\n Get complete data and recency of purchases")##
## Get complete data and recency of purchases
traffic_all = traffic_all %>% mutate(time_min = ((day*24*60)+(hour*60)+minute)) %>% group_by(geohash6) %>% arrange(geohash6,time_min) %>% mutate(recency = c(0,diff(time_min))) %>% ungroup()
traffic_all = traffic_all[complete.cases(traffic_all),]
cat("\n Fix timestamp")##
## Fix timestamp
# traffic_all = traffic_all %>% left_join(unique_days%>% select(day,day_label),by=c("day"))
traffic_all = traffic_all %>% mutate( timestamp = (Sys.Date()-365)+day)
hour(traffic_all$timestamp) = traffic_all$hour
minute(traffic_all$timestamp) = traffic_all$minute
traffic_all$timestamp=as.POSIXct(traffic_all$timestamp)
traffic_all = traffic_all %>% group_by(geohash6) %>% arrange(geohash6,timestamp) %>% ungroup()
head(traffic_all,3)## # A tibble: 3 x 16
## geohash6 day hour minute demand day_label lag_1 lag_4 lag_day lag_week
## <chr> <int> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 qp02yc 8 0 0 0 3 0 0 0 0
## 2 qp02yc 8 0 15 0 3 0 0 0 0
## 3 qp02yc 8 0 30 0 3 0 0 0 0
## # … with 6 more variables: roll_avg_week <dbl>, roll_avg_day <dbl>,
## # roll_avg_hour <dbl>, time_min <dbl>, recency <dbl>, timestamp <dttm>
Now that we have our data prepared, we can split it into train, validation and test sets and start modelling. We take roughly the first 60% of days for training, the next 20% for validation, and the remaining days for testing (a 60-20-20 split).
train_ind = round(0.6*length(unique(traffic_data$day)))
val_ind = round(0.8*length(unique(traffic_data$day)))
train_all = traffic_all %>% filter(day<=train_ind)
val_all = traffic_all %>% filter(day>train_ind & day<=val_ind)
test_all = traffic_all %>% filter(day>val_ind)
rm(traffic_data,traffic_all)
We have already calculated the rolling averages over the week, day and hour while preparing the data, and the problem is now a straightforward regression. Checking the RMSE of these rolling averages as naive baselines:
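For reference, caret's RMSE(pred, obs) used below is simply the root mean squared error; a one-line manual equivalent (shown only for clarity, the chunks below keep using caret):

rmse_manual <- function(pred, obs) sqrt(mean((pred - obs)^2))  # same value as caret::RMSE(pred, obs)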
print(paste0("The RMSE of demand vs roll_avg_week in train set = ",RMSE(train_all$demand,train_all$roll_avg_week)))## [1] "The RMSE of demand vs roll_avg_week in train set = 0.0714265656242179"
print(paste0("The RMSE of demand vs roll_avg_day in train set = ",RMSE(train_all$demand,train_all$roll_avg_day)))## [1] "The RMSE of demand vs roll_avg_day in train set = 0.069551493702308"
print(paste0("The RMSE of demand vs roll_avg_hour in train set = ",RMSE(train_all$demand,train_all$roll_avg_hour)))## [1] "The RMSE of demand vs roll_avg_hour in train set = 0.0285096210741496"
print(paste0("The RMSE of demand vs roll_avg_week in val set = ",RMSE(val_all$demand,val_all$roll_avg_week)))## [1] "The RMSE of demand vs roll_avg_week in val set = 0.0761336828001388"
print(paste0("The RMSE of demand vs roll_avg_day in val set = ",RMSE(val_all$demand,val_all$roll_avg_day)))## [1] "The RMSE of demand vs roll_avg_day in val set = 0.074389368218852"
print(paste0("The RMSE of demand vs roll_avg_hour in val set = ",RMSE(val_all$demand,val_all$roll_avg_hour)))## [1] "The RMSE of demand vs roll_avg_hour in val set = 0.0298792366635775"
print(paste0("The RMSE of demand vs roll_avg_week in test set = ",RMSE(test_all$demand,test_all$roll_avg_week)))## [1] "The RMSE of demand vs roll_avg_week in test set = 0.0814582597811666"
print(paste0("The RMSE of demand vs roll_avg_day in test set = ",RMSE(test_all$demand,test_all$roll_avg_day)))## [1] "The RMSE of demand vs roll_avg_day in test set = 0.0793905303080019"
print(paste0("The RMSE of demand vs roll_avg_hour in test set = ",RMSE(test_all$demand,test_all$roll_avg_hour)))## [1] "The RMSE of demand vs roll_avg_hour in test set = 0.0312413495269609"
We see that the RMSE between demand and roll_avg_hour is the lowest. Now, we need to experiment with a model that can do better than these baselines.
label <- train_all$demand
# na.action = 'na.pass' returns the object unchanged if there are NA values (instead of dropping those rows)
previous_na_action<- options('na.action')
options(na.action='na.pass')
#Build matrix input for the model
trainMatrix <- sparse.model.matrix(~ geohash6+day+hour+minute+lag_1+lag_4+lag_day+lag_week+roll_avg_hour+roll_avg_day+roll_avg_week+day_label+recency
, data = train_all
, contrasts.arg = c('geohash6', 'day','hour','minute','day_label')
, sparse = FALSE, sci = FALSE)
options(na.action = previous_na_action$na.action)
trainDMatrix <- xgb.DMatrix(data = trainMatrix, label = label)3-fold CV
params <- list(booster = "gbtree"
, objective = "reg:linear"
, eta=0.6
, gamma=0
,random_state = 2
)
xgb.tab <- xgb.cv(data = trainDMatrix
, param = params
, maximize = FALSE, evaluation = "rmse", nrounds = 100
, nthreads = 10, nfold = 3, early_stopping_round = 10)## [1] train-rmse:0.186594+0.000003 test-rmse:0.186696+0.000007
## Multiple eval metrics are present. Will use test_rmse for early stopping.
## Will train until test_rmse hasn't improved in 10 rounds.
##
## [2] train-rmse:0.077724+0.000016 test-rmse:0.077750+0.000011
## [3] train-rmse:0.037539+0.000019 test-rmse:0.037627+0.000028
## [4] train-rmse:0.025688+0.000017 test-rmse:0.025827+0.000049
## [5] train-rmse:0.023138+0.000031 test-rmse:0.023297+0.000040
## [6] train-rmse:0.022615+0.000031 test-rmse:0.022796+0.000053
## [7] train-rmse:0.022472+0.000045 test-rmse:0.022664+0.000040
## [8] train-rmse:0.022400+0.000044 test-rmse:0.022603+0.000042
## [9] train-rmse:0.022346+0.000036 test-rmse:0.022565+0.000047
## [10] train-rmse:0.022288+0.000037 test-rmse:0.022527+0.000049
## [11] train-rmse:0.022248+0.000039 test-rmse:0.022498+0.000045
## [12] train-rmse:0.022206+0.000034 test-rmse:0.022470+0.000051
## [13] train-rmse:0.022166+0.000035 test-rmse:0.022439+0.000048
## [14] train-rmse:0.022131+0.000035 test-rmse:0.022417+0.000047
## [15] train-rmse:0.022105+0.000038 test-rmse:0.022402+0.000047
## [16] train-rmse:0.022068+0.000039 test-rmse:0.022377+0.000049
## [17] train-rmse:0.022027+0.000055 test-rmse:0.022349+0.000031
## [18] train-rmse:0.022006+0.000059 test-rmse:0.022340+0.000028
## [19] train-rmse:0.021974+0.000061 test-rmse:0.022330+0.000032
## [20] train-rmse:0.021949+0.000055 test-rmse:0.022315+0.000034
## [21] train-rmse:0.021925+0.000051 test-rmse:0.022300+0.000035
## [22] train-rmse:0.021899+0.000051 test-rmse:0.022283+0.000036
## [23] train-rmse:0.021876+0.000049 test-rmse:0.022273+0.000040
## [24] train-rmse:0.021852+0.000050 test-rmse:0.022260+0.000034
## [25] train-rmse:0.021826+0.000052 test-rmse:0.022249+0.000030
## [26] train-rmse:0.021804+0.000051 test-rmse:0.022239+0.000028
## [27] train-rmse:0.021780+0.000055 test-rmse:0.022232+0.000033
## [28] train-rmse:0.021766+0.000057 test-rmse:0.022227+0.000034
## [29] train-rmse:0.021742+0.000061 test-rmse:0.022225+0.000034
## [30] train-rmse:0.021720+0.000061 test-rmse:0.022219+0.000035
## [31] train-rmse:0.021702+0.000061 test-rmse:0.022211+0.000037
## [32] train-rmse:0.021686+0.000063 test-rmse:0.022204+0.000037
## [33] train-rmse:0.021664+0.000068 test-rmse:0.022194+0.000034
## [34] train-rmse:0.021648+0.000063 test-rmse:0.022191+0.000034
## [35] train-rmse:0.021633+0.000057 test-rmse:0.022186+0.000040
## [36] train-rmse:0.021597+0.000041 test-rmse:0.022178+0.000040
## [37] train-rmse:0.021580+0.000043 test-rmse:0.022173+0.000038
## [38] train-rmse:0.021561+0.000040 test-rmse:0.022167+0.000043
## [39] train-rmse:0.021548+0.000037 test-rmse:0.022160+0.000046
## [40] train-rmse:0.021533+0.000039 test-rmse:0.022156+0.000050
## [41] train-rmse:0.021500+0.000025 test-rmse:0.022131+0.000065
## [42] train-rmse:0.021484+0.000025 test-rmse:0.022124+0.000067
## [43] train-rmse:0.021469+0.000023 test-rmse:0.022124+0.000067
## [44] train-rmse:0.021453+0.000022 test-rmse:0.022119+0.000065
## [45] train-rmse:0.021424+0.000013 test-rmse:0.022096+0.000075
## [46] train-rmse:0.021408+0.000008 test-rmse:0.022090+0.000077
## [47] train-rmse:0.021393+0.000009 test-rmse:0.022088+0.000078
## [48] train-rmse:0.021379+0.000012 test-rmse:0.022082+0.000079
## [49] train-rmse:0.021366+0.000013 test-rmse:0.022080+0.000079
## [50] train-rmse:0.021354+0.000007 test-rmse:0.022074+0.000078
## [51] train-rmse:0.021340+0.000006 test-rmse:0.022068+0.000077
## [52] train-rmse:0.021327+0.000007 test-rmse:0.022063+0.000078
## [53] train-rmse:0.021315+0.000005 test-rmse:0.022060+0.000080
## [54] train-rmse:0.021298+0.000006 test-rmse:0.022058+0.000082
## [55] train-rmse:0.021282+0.000003 test-rmse:0.022057+0.000080
## [56] train-rmse:0.021272+0.000003 test-rmse:0.022055+0.000081
## [57] train-rmse:0.021258+0.000003 test-rmse:0.022055+0.000082
## [58] train-rmse:0.021249+0.000006 test-rmse:0.022055+0.000082
## [59] train-rmse:0.021239+0.000009 test-rmse:0.022052+0.000083
## [60] train-rmse:0.021229+0.000009 test-rmse:0.022050+0.000082
## [61] train-rmse:0.021219+0.000005 test-rmse:0.022047+0.000083
## [62] train-rmse:0.021207+0.000004 test-rmse:0.022046+0.000085
## [63] train-rmse:0.021196+0.000005 test-rmse:0.022047+0.000087
## [64] train-rmse:0.021180+0.000011 test-rmse:0.022049+0.000085
## [65] train-rmse:0.021173+0.000013 test-rmse:0.022050+0.000085
## [66] train-rmse:0.021161+0.000015 test-rmse:0.022046+0.000082
## [67] train-rmse:0.021149+0.000015 test-rmse:0.022045+0.000083
## [68] train-rmse:0.021136+0.000016 test-rmse:0.022045+0.000079
## [69] train-rmse:0.021126+0.000014 test-rmse:0.022044+0.000078
## [70] train-rmse:0.021115+0.000013 test-rmse:0.022043+0.000077
## [71] train-rmse:0.021105+0.000011 test-rmse:0.022040+0.000076
## [72] train-rmse:0.021097+0.000012 test-rmse:0.022041+0.000077
## [73] train-rmse:0.021082+0.000013 test-rmse:0.022039+0.000078
## [74] train-rmse:0.021073+0.000012 test-rmse:0.022039+0.000078
## [75] train-rmse:0.021064+0.000011 test-rmse:0.022040+0.000079
## [76] train-rmse:0.021054+0.000016 test-rmse:0.022037+0.000078
## [77] train-rmse:0.021044+0.000015 test-rmse:0.022036+0.000077
## [78] train-rmse:0.021033+0.000014 test-rmse:0.022032+0.000077
## [79] train-rmse:0.021025+0.000012 test-rmse:0.022029+0.000078
## [80] train-rmse:0.021014+0.000013 test-rmse:0.022029+0.000075
## [81] train-rmse:0.021005+0.000013 test-rmse:0.022028+0.000076
## [82] train-rmse:0.020996+0.000016 test-rmse:0.022028+0.000074
## [83] train-rmse:0.020982+0.000015 test-rmse:0.022030+0.000073
## [84] train-rmse:0.020973+0.000018 test-rmse:0.022029+0.000071
## [85] train-rmse:0.020962+0.000019 test-rmse:0.022025+0.000071
## [86] train-rmse:0.020952+0.000015 test-rmse:0.022025+0.000071
## [87] train-rmse:0.020943+0.000016 test-rmse:0.022025+0.000070
## [88] train-rmse:0.020936+0.000017 test-rmse:0.022024+0.000070
## [89] train-rmse:0.020928+0.000018 test-rmse:0.022024+0.000072
## [90] train-rmse:0.020922+0.000018 test-rmse:0.022022+0.000073
## [91] train-rmse:0.020913+0.000021 test-rmse:0.022022+0.000074
## [92] train-rmse:0.020904+0.000024 test-rmse:0.022023+0.000076
## [93] train-rmse:0.020896+0.000025 test-rmse:0.022023+0.000077
## [94] train-rmse:0.020887+0.000026 test-rmse:0.022023+0.000076
## [95] train-rmse:0.020878+0.000026 test-rmse:0.022023+0.000075
## [96] train-rmse:0.020869+0.000024 test-rmse:0.022022+0.000075
## [97] train-rmse:0.020859+0.000025 test-rmse:0.022020+0.000075
## [98] train-rmse:0.020850+0.000023 test-rmse:0.022020+0.000076
## [99] train-rmse:0.020843+0.000024 test-rmse:0.022021+0.000076
## [100] train-rmse:0.020833+0.000024 test-rmse:0.022021+0.000076
num_iterations = xgb.tab$best_iteration
model <- xgb.train(data = trainDMatrix
, param = params
, maximize = FALSE, evaluation = 'rmse', nrounds = num_iterations)
xgb.save(model,"xgb_model")## [1] TRUE
importance <- xgb.importance(feature_names = colnames(trainMatrix), model = model)
xgb.ggplot.importance(importance_matrix = importance[1:10,])
pred <- predict(model, trainDMatrix)
pred[pred < 0] <- 0
train_all$pred = pred
print(paste0("RMSE of XGBoost Regression wrt demand = ",RMSE(train_all$pred,train_all$demand)))## [1] "RMSE of XGBoost Regression wrt demand = 0.0210711390059585"
ggplot(data = train_all %>% filter(geohash6==train_all$geohash6[2]), aes(x=timestamp,y=demand) )+geom_line()+geom_line(aes(y=roll_avg_week),color="green") + geom_line(aes(y=roll_avg_day),color="blue") + geom_line(aes(y=pred),color="dark red")+geom_line(aes(y=roll_avg_hour),color="cyan")
Next, we evaluate the model on the validation set.
valMatrix <- sparse.model.matrix(~ geohash6+day+hour+minute+lag_1+lag_4+lag_day+lag_week+roll_avg_hour+roll_avg_day+roll_avg_week+day_label+recency
, data = val_all
, contrasts.arg = c('geohash6', 'day','hour','minute','day_label')
, sparse = FALSE, sci = FALSE)
options(na.action = previous_na_action$na.action)
# valDMatrix <- xgb.DMatrix(data = valMatrix, label = val_all$demand)
pred <- predict(model, valMatrix)
pred[pred < 0] <- 0
val_all$pred = pred
RMSE(val_all$pred,val_all$demand)## [1] 0.02482054
print(paste0("RMSE of XGBoost Regression wrt demand in val_set = ",RMSE(val_all$pred,val_all$demand)))## [1] "RMSE of XGBoost Regression wrt demand in val_set = 0.024820537581574"
ggplot(data = val_all %>% filter(geohash6==val_all$geohash6[2]), aes(x=timestamp,y=demand) )+geom_line()+geom_line(aes(y=roll_avg_week),color="green") + geom_line(aes(y=roll_avg_day),color="blue") + geom_line(aes(y=pred),color="dark red") +geom_line(aes(y=roll_avg_hour),color="cyan")
Finally, we evaluate on the test set using the saved model.
testMatrix <- sparse.model.matrix(~ geohash6+day+hour+minute+lag_1+lag_4+lag_day+lag_week+roll_avg_hour+roll_avg_day+roll_avg_week+day_label+recency
, data = test_all
, contrasts.arg = c('geohash6', 'day','hour','minute','day_label')
, sparse = FALSE, sci = FALSE)
options(na.action = previous_na_action$na.action)
# testDMatrix <- xgb.DMatrix(data = testMatrix, label = test_all$demand)
model2 = xgb.load("xgb_model")
pred <- predict(model2, testMatrix)
pred[pred < 0] <- 0
test_all$pred = pred
print(paste0("RMSE of XGBoost Regression wrt demand in test_set = ",RMSE(test_all$pred,test_all$demand)))## [1] "RMSE of XGBoost Regression wrt demand in test_set = 0.0256797150841778"
ggplot(data = test_all %>% filter(geohash6==test_all$geohash6[2]), aes(x=timestamp,y=demand) )+geom_point()+geom_line(aes(y=roll_avg_week),color="green") + geom_line(aes(y=roll_avg_day),color="blue") + geom_line(aes(y=pred),color="dark red") +geom_line(aes(y=roll_avg_hour),color="cyan")
To forecast future values, we need to forecast recursively using a for loop. One of the difficulties in this problem is that we must predict demand for the next 5 time points, while the current data set only has the lag and rolling-mean values up to time T. To predict over the whole horizon, we recursively append each prediction to the data set and then use the updated data set to predict the next step. One issue is that errors propagate through the model, which can decrease the accuracy of longer-term predictions, but here we only need to predict up to T+5.
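In a minimal sketch, the recursion looks like this (predict_one_step() and train_demand are hypothetical placeholders for illustration, not part of this notebook):

history <- train_demand                   # demand observed up to time T (assumed numeric vector)
for (step in 1:5) {
  next_val <- predict_one_step(history)   # forecast for T + step from the lags/rolling means of history
  history <- c(history, next_val)         # feed the prediction back in as if it were observed
}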
Here we define a function forecast_n that takes in the data and predicts n_pred future steps using the chosen modelling method.
forecast_n = function(train_set, method="xg",n_pred = 5){
train_set= train_set %>% mutate(timestamp_list = hm(timestamp),hour= timestamp_list@hour,minute = timestamp_list@minute)
# Remove timestamp and timestamp_list as they are redundant now
train_set = train_set %>% select(-timestamp,-timestamp_list)
df_test = train_set %>% mutate( timestamp = (Sys.Date()-365)+day) # change back to 365 after test
hour(df_test$timestamp) = df_test$hour
minute(df_test$timestamp) = df_test$minute
df_test$timestamp = as.POSIXct(df_test$timestamp)
max_date = max(df_test$timestamp)
start_day
dates = seq(from = max_date+15*60,to=(max_date+15*60*n_pred+1),by="15 min") #start from max_time+15 mins till max time + n_pred*15 minutes
# df_test = rbind(df_test,all_dats,fill=T)
for(i in 1:length(dates)){
cat(paste0("Predicting for ", dates[i]))
all_dats = data.frame(expand.grid(geohash6=unique(df_test$geohash6),timestamp=dates[i],stringsAsFactors = F))
all_dats = all_dats %>% mutate(day = as.numeric(round(as.Date(timestamp)-(start_day))),hour = hour(timestamp),minute = minute(timestamp))
all_dats$demand = NA
# hour(all_dats$timestamp) = all_dats$hour
# minute(all_dats$timestamp) = all_dats$minute
df_test = df_test %>% select(geohash6,timestamp,day,hour,minute,demand) %>% ungroup()
df_bind = rbind(df_test,all_dats)
#df_test = data.frame(rbind(df_test %>% select(geohash6,timestamp,day,hour,minute,demand)%>% ungroup(),all_dats)) %>% unique()
print(paste0("No. of rows in dftest = ",nrow(df_bind)))
df_bind$demand[which(is.na(df_bind$demand))] = 0
df_test_c <- df_bind %>%
group_by(geohash6) %>% arrange(day,hour,minute) %>%
mutate(lag_1 = lag(demand, 1),lag_4 = lag(demand,4), lag_day = lag(demand,4*24), lag_week = lag(demand,4*24*7), roll_avg_week = lag(roll_meanr(demand, 7*24*4), 1), roll_avg_day = lag(roll_meanr(demand, 24*4), 1),roll_avg_hour= lag(roll_meanr(demand, 4,na.rm = T), 1)
) %>% ungroup() %>% group_by(geohash6,day) %>% #mutate(mean_day = mean(demand,na.rm = T), freq_day=n()) %>%
ungroup() %>%
group_by(geohash6,hour) %>%# mutate(mean_hour = mean(demand,na.rm = T),freq_hour=n()) %>% ungroup() %>%
mutate(day_label = wday(timestamp),time_min = ((day*24*60)+(hour*60)+minute)) %>% group_by(geohash6) %>% arrange(geohash6,time_min) %>% mutate(recency = c(0,diff(time_min))) %>% ungroup() %>%
arrange(timestamp)
df_test_try = df_test_c#[complete.cases(df_test_c),]
cat(paste0("data prep for " , dates[i] , " complete"))
check = df_test_try %>% filter(day == 31) %>% arrange(timestamp)
pred_data = df_test_try %>% filter(timestamp>=dates[i])
# pred_data = pred_data %>% mutate(time_min = ((day*24*60)+(hour*60)+minute)) %>% group_by(geohash6) %>% arrange(geohash6,time_min) %>% mutate(recency = c(0,diff(time_min))) %>% ungroup()
pred_data = pred_data %>% select(geohash6,day,hour,minute,demand,lag_1,lag_4,lag_day,lag_week,roll_avg_week,roll_avg_day,roll_avg_hour,time_min,recency,day_label,timestamp)
pred_data[is.na(pred_data)] <- 0
# check = pred_data %>% filter(day == 31) %>% arrange(geohash6) %>% arrange(geohash6)
if(method == "xg"){
predMatrix <- sparse.model.matrix(~ geohash6+day+hour+minute+lag_1+lag_4+lag_day+lag_week+roll_avg_hour+roll_avg_day+roll_avg_week+day_label+recency
, data = pred_data
, contrasts.arg = c('geohash6', 'day','hour','minute','day_label')
, sparse = FALSE, sci = FALSE)
options(na.action = previous_na_action$na.action)
model = xgb.load("xgb_model")
pred <- predict(model, predMatrix)
pred[pred < 0] <- 0
pred_data$pred = pred
rm(pred)
}
if(method=="roll_hour"){
pred_data$pred = pred_data$roll_avg_hour
}
if(method=="roll_day"){
pred_data$pred = pred_data$roll_avg_day
}
if(method=="roll_week"){
pred_data$pred = pred_data$roll_avg_week
}
sols = data.frame(pred_data %>% select(geohash6,timestamp,day,hour,minute,roll_avg_hour,roll_avg_day,roll_avg_week,pred),stringsAsFactors = F)
df_preds = df_bind %>% group_by(geohash6,timestamp,day,hour,minute) %>% left_join(sols,by=c("geohash6","timestamp","day","hour","minute"))
df_test = df_preds %>% mutate(demand=ifelse(is.na(pred),demand,pred)) %>% select(-roll_avg_hour,-roll_avg_day,-roll_avg_week,-pred)
}
df_test = df_test %>% mutate(future_val = ifelse(timestamp>=dates[1],1,0))
return(df_test)
}
Calling the function on the validation-period data placed in the test/ folder:
filename = paste0("test/",list.files("test/")) # Test data: place your test file in the test/ folder
if(!filename=="test/"){
test_data = read.csv(filename,stringsAsFactors = F)
xgpred_test = forecast_n(test_data) %>% rename(xgb_pred = demand)
# write.csv(xgpred_test,"solution/xgb_solution.csv")
xgpred_hour = forecast_n(test_data,method = "roll_hour") %>% rename(roll_hour_pred = demand, future_val_hour = future_val)
xgpred_day = forecast_n(test_data,method = "roll_day")%>% rename(roll_day_pred = demand, future_val_day = future_val)
xgpred_week = forecast_n(test_data,method = "roll_week")%>% rename(roll_week_pred = demand, future_val_week = future_val)
#Since this is from val set, we can check against first few values of test set
df_all_pred = xgpred_test %>% left_join(xgpred_hour,by=c("geohash6","timestamp","day","hour","minute"))%>% left_join(xgpred_day,by=c("geohash6","timestamp","day","hour","minute"))%>% left_join(xgpred_week,by=c("geohash6","timestamp","day","hour","minute"))
test_few = test_all %>% select(geohash6,timestamp,day,hour,minute,demand)
df_all_pred = df_all_pred %>% left_join(test_few,by=c("geohash6","timestamp","day","hour","minute"))
write.csv(df_all_pred,"all_pred.csv")
p = ggplot(df_all_pred %>% filter(geohash6==df_all_pred$geohash6[1]),aes(x = timestamp,y = xgb_pred ),colour = "red") + geom_line() + geom_line(aes(x = timestamp,y = roll_hour_pred ),colour = "blue")+geom_line(aes(x = timestamp,y = roll_day_pred ),colour = "green")+ geom_line(aes(x = timestamp,y = roll_week_pred ),colour = "cyan") + geom_line(aes(x = timestamp,y = demand ),colour = "black")
}else{p=NULL
print("Enter file in test folder")
}## Predicting for 2018-08-07[1] "No. of rows in dftest = 865955"
## data prep for 2018-08-07 completePredicting for 2018-08-07 00:15:00[1] "No. of rows in dftest = 867271"
## data prep for 2018-08-07 00:15:00 completePredicting for 2018-08-07 00:30:00[1] "No. of rows in dftest = 868587"
## data prep for 2018-08-07 00:30:00 completePredicting for 2018-08-07 00:45:00[1] "No. of rows in dftest = 869903"
## data prep for 2018-08-07 00:45:00 completePredicting for 2018-08-07 01:00:00[1] "No. of rows in dftest = 871219"
## data prep for 2018-08-07 01:00:00 completePredicting for 2018-08-07[1] "No. of rows in dftest = 865955"
## data prep for 2018-08-07 completePredicting for 2018-08-07 00:15:00[1] "No. of rows in dftest = 867271"
## data prep for 2018-08-07 00:15:00 completePredicting for 2018-08-07 00:30:00[1] "No. of rows in dftest = 868587"
## data prep for 2018-08-07 00:30:00 completePredicting for 2018-08-07 00:45:00[1] "No. of rows in dftest = 869903"
## data prep for 2018-08-07 00:45:00 completePredicting for 2018-08-07 01:00:00[1] "No. of rows in dftest = 871219"
## data prep for 2018-08-07 01:00:00 completePredicting for 2018-08-07[1] "No. of rows in dftest = 865955"
## data prep for 2018-08-07 completePredicting for 2018-08-07 00:15:00[1] "No. of rows in dftest = 867271"
## data prep for 2018-08-07 00:15:00 completePredicting for 2018-08-07 00:30:00[1] "No. of rows in dftest = 868587"
## data prep for 2018-08-07 00:30:00 completePredicting for 2018-08-07 00:45:00[1] "No. of rows in dftest = 869903"
## data prep for 2018-08-07 00:45:00 completePredicting for 2018-08-07 01:00:00[1] "No. of rows in dftest = 871219"
## data prep for 2018-08-07 01:00:00 completePredicting for 2018-08-07[1] "No. of rows in dftest = 865955"
## data prep for 2018-08-07 completePredicting for 2018-08-07 00:15:00[1] "No. of rows in dftest = 867271"
## data prep for 2018-08-07 00:15:00 completePredicting for 2018-08-07 00:30:00[1] "No. of rows in dftest = 868587"
## data prep for 2018-08-07 00:30:00 completePredicting for 2018-08-07 00:45:00[1] "No. of rows in dftest = 869903"
## data prep for 2018-08-07 00:45:00 completePredicting for 2018-08-07 01:00:00[1] "No. of rows in dftest = 871219"
## data prep for 2018-08-07 01:00:00 complete
p## Warning: Removed 121 rows containing missing values (geom_path).