How much camping gear will one store sell each month in a year? To the uninitiated, calculating sales at this level may seem as difficult as predicting the weather. Both types of forecasting rely on science and historical data. While a wrong weather forecast may result in you carrying around an umbrella on a sunny day, an inaccurate business forecast can result in actual or opportunity losses. In this project, in addition to traditional forecasting methods, we also attempt to use machine learning to improve forecast accuracy.
We use hierarchical sales data from Walmart, the world’s largest company by revenue, to forecast daily sales for the next 28 days. The data covers stores in three US states (California, Texas, and Wisconsin) and includes item-, department-, product-category-, and store-level details. In addition, it has explanatory variables such as price, promotions, day of the week, and special events. Together, this robust dataset can be used to improve forecasting accuracy.
The data: We are working with 42,840 hierarchical time series. The data were obtained in the 3 US states of California (CA), Texas (TX), and Wisconsin (WI). “Hierarchical” here means that the data can be aggregated at different levels: item level, department level, product category level, and state level. The sales information spans January 2011 to June 2016. In addition to the sales numbers, we are also given corresponding data on prices, promotions, and holidays. Note that we have been warned that most of the time series contain zero values.
The data comprises 3049 individual products from 3 categories and 7 departments, sold in 10 stores in 3 states. The hierarchical aggregation captures the combinations of these factors. For instance, we can create 1 time series for all sales, 3 time series for sales per state, and so on. The most granular level consists of the 3049 individual products in each of the 10 stores, i.e. 30,490 time series.
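To make the aggregation concrete, here is a minimal sketch (assuming the wide sales table described below has already been read into train): coarser hierarchy levels are obtained by summing the daily d_ columns within a grouping variable.
library(dplyr)
# 3 series, one per state
state_level <- train %>%
  group_by(state_id) %>%
  summarise(across(starts_with("d_"), sum), .groups = "drop")
# 10 series, one per store
store_level <- train %>%
  group_by(store_id) %>%
  summarise(across(starts_with("d_"), sum), .groups = "drop")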
The training data comes in the shape of 3 separate files:
sales_train.csv: this is our main training data. It has 1 column for each of the 1913 days from 2011-01-29 to 2016-04-24, not including the validation period of 28 days up to 2016-05-22 (nor the final evaluation period up to 2016-06-19). It also includes the IDs for item, department, category, store, and state. The number of rows is 30490, one for each combination of the 3049 items and 10 stores.
sell_prices.csv: the store and item IDs together with the sales price of the item as a weekly average.
calendar.csv: dates together with related features like day-of-the-week, month, year, and 3 binary flags indicating whether the stores in each state allowed purchases with SNAP food stamps on that date (1) or not (0).
library('dplyr')      # data manipulation
library('vroom')      # input/output
library('readr')      # input/output
library('stringr')    # string manipulation
library('tidyr')      # data wrangling
library('purrr')      # data wrangling
library('kableExtra') # display
library('ggplot2')    # visualisation
library('ggthemes')   # visualisation
path <- '/Users/Suen/Downloads/input/m5-forecasting-accuracy/'
train <- vroom(str_c(path,'sales_train_validation.csv'), delim = ",", col_types = cols())
prices <- vroom(str_c(path,'sell_prices.csv'), delim = ",", col_types = cols())
calendar <- read_csv(str_c(path,'calendar.csv'), col_types = cols())
sample_submit <- vroom(str_c(path,'sample_submission.csv'), delim = ",", col_types = cols())
As a first step, let’s have a quick look at the data sets using the head, summary, and glimpse tools where appropriate.
Here are the first 10 columns and rows of our training sales data:
train %>%
select(seq(1,10,1)) %>%
head(10) %>%
kable() %>%
kable_styling()
| id | item_id | dept_id | cat_id | store_id | state_id | d_1 | d_2 | d_3 | d_4 |
|---|---|---|---|---|---|---|---|---|---|
| HOBBIES_1_001_CA_1_validation | HOBBIES_1_001 | HOBBIES_1 | HOBBIES | CA_1 | CA | 0 | 0 | 0 | 0 |
| HOBBIES_1_002_CA_1_validation | HOBBIES_1_002 | HOBBIES_1 | HOBBIES | CA_1 | CA | 0 | 0 | 0 | 0 |
| HOBBIES_1_003_CA_1_validation | HOBBIES_1_003 | HOBBIES_1 | HOBBIES | CA_1 | CA | 0 | 0 | 0 | 0 |
| HOBBIES_1_004_CA_1_validation | HOBBIES_1_004 | HOBBIES_1 | HOBBIES | CA_1 | CA | 0 | 0 | 0 | 0 |
| HOBBIES_1_005_CA_1_validation | HOBBIES_1_005 | HOBBIES_1 | HOBBIES | CA_1 | CA | 0 | 0 | 0 | 0 |
| HOBBIES_1_006_CA_1_validation | HOBBIES_1_006 | HOBBIES_1 | HOBBIES | CA_1 | CA | 0 | 0 | 0 | 0 |
| HOBBIES_1_007_CA_1_validation | HOBBIES_1_007 | HOBBIES_1 | HOBBIES | CA_1 | CA | 0 | 0 | 0 | 0 |
| HOBBIES_1_008_CA_1_validation | HOBBIES_1_008 | HOBBIES_1 | HOBBIES | CA_1 | CA | 12 | 15 | 0 | 0 |
| HOBBIES_1_009_CA_1_validation | HOBBIES_1_009 | HOBBIES_1 | HOBBIES | CA_1 | CA | 2 | 0 | 7 | 3 |
| HOBBIES_1_010_CA_1_validation | HOBBIES_1_010 | HOBBIES_1 | HOBBIES | CA_1 | CA | 0 | 0 | 1 | 0 |
We find:
There is one column each for the IDs of item, department, category, store, and state; plus a general ID that is a combination of the other IDs plus a flag for validation.
The sales per date are encoded as columns starting with the prefix d_. Those are the number of units sold per day (not the total amount of dollars).
We already see that there are quite a lot of zero values.
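Since the daily sales live in wide d_ columns, most of the later analysis reshapes them into a long format with one row per series and day; a minimal sketch with tidyr (the result has over 58 million rows, so it is memory-hungry):
# Reshape the wide d_1, d_2, ... columns into one row per item/store/day
train_long <- train %>%
  pivot_longer(starts_with("d_"), names_to = "d", values_to = "sales")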
This data set has too many columns and rows to display them all:
c(ncol(train), nrow(train))
## [1] 1919 30490
All IDs are marked as “validation”. This refers to the initial validation period of the competition, before we ultimately predict a different 28-day period.
train %>%
mutate(dset = if_else(str_detect(id, "validation"), "validation", "training")) %>%
count(dset)
## # A tibble: 1 x 2
## dset n
## <chr> <int>
## 1 validation 30490
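For the final scoring stage the same series appear with an "_evaluation" suffix (the sample submission contains both sets of ids); a quick sketch with stringr, which is already loaded, shows how those ids can be derived from the validation ids:
# turn "..._validation" ids into "..._evaluation" ids
eval_ids <- str_replace(train$id, "validation$", "evaluation")
head(eval_ids, 3)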
This data set gives us the weekly price changes per item:
prices %>%
head(10) %>%
kable() %>%
kable_styling()
| store_id | item_id | wm_yr_wk | sell_price |
|---|---|---|---|
| CA_1 | HOBBIES_1_001 | 11325 | 9.58 |
| CA_1 | HOBBIES_1_001 | 11326 | 9.58 |
| CA_1 | HOBBIES_1_001 | 11327 | 8.26 |
| CA_1 | HOBBIES_1_001 | 11328 | 8.26 |
| CA_1 | HOBBIES_1_001 | 11329 | 8.26 |
| CA_1 | HOBBIES_1_001 | 11330 | 8.26 |
| CA_1 | HOBBIES_1_001 | 11331 | 8.26 |
| CA_1 | HOBBIES_1_001 | 11332 | 8.26 |
| CA_1 | HOBBIES_1_001 | 11333 | 8.26 |
| CA_1 | HOBBIES_1_001 | 11334 | 8.26 |
summary(prices)
##    store_id           item_id            wm_yr_wk       sell_price
## Length:6841121 Length:6841121 Min. :11101 Min. : 0.010
## Class :character Class :character 1st Qu.:11247 1st Qu.: 2.180
## Mode :character Mode :character Median :11411 Median : 3.470
## Mean :11383 Mean : 4.411
## 3rd Qu.:11517 3rd Qu.: 5.840
## Max. :11621 Max. :107.320
We find:
We have the store_id and item_id to link this data to our training and validation data.
Prices range from 1 cent to a bit more than 100 dollars.
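Because prices are weekly, mapping them to individual days goes through the wm_yr_wk key shared with the calendar; a small sketch for a single item in one store (column names as above):
daily_price <- prices %>%
  filter(store_id == "CA_1", item_id == "HOBBIES_1_001") %>%
  left_join(select(calendar, date, d, wm_yr_wk), by = "wm_yr_wk")
# each weekly price row is repeated for the 7 days of that week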
The calendar data gives us date features such as weekday, month, or year, alongside 2 different event features and one SNAP food stamp flag per state:
calendar %>%
head(8) %>%
kable() %>%
kable_styling()
| date | wm_yr_wk | weekday | wday | month | year | d | event_name_1 | event_type_1 | event_name_2 | event_type_2 | snap_CA | snap_TX | snap_WI |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 2011-01-29 | 11101 | Saturday | 1 | 1 | 2011 | d_1 | NA | NA | NA | NA | 0 | 0 | 0 |
| 2011-01-30 | 11101 | Sunday | 2 | 1 | 2011 | d_2 | NA | NA | NA | NA | 0 | 0 | 0 |
| 2011-01-31 | 11101 | Monday | 3 | 1 | 2011 | d_3 | NA | NA | NA | NA | 0 | 0 | 0 |
| 2011-02-01 | 11101 | Tuesday | 4 | 2 | 2011 | d_4 | NA | NA | NA | NA | 1 | 1 | 0 |
| 2011-02-02 | 11101 | Wednesday | 5 | 2 | 2011 | d_5 | NA | NA | NA | NA | 1 | 0 | 1 |
| 2011-02-03 | 11101 | Thursday | 6 | 2 | 2011 | d_6 | NA | NA | NA | NA | 1 | 1 | 1 |
| 2011-02-04 | 11101 | Friday | 7 | 2 | 2011 | d_7 | NA | NA | NA | NA | 1 | 0 | 0 |
| 2011-02-05 | 11102 | Saturday | 1 | 2 | 2011 | d_8 | NA | NA | NA | NA | 1 | 1 | 1 |
glimpse(calendar)
## Rows: 1,969
## Columns: 14
## $ date <date> 2011-01-29, 2011-01-30, 2011-01-31, 2011-02-01, 2011-02-…
## $ wm_yr_wk <dbl> 11101, 11101, 11101, 11101, 11101, 11101, 11101, 11102, 1…
## $ weekday <chr> "Saturday", "Sunday", "Monday", "Tuesday", "Wednesday", "…
## $ wday <dbl> 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, …
## $ month <dbl> 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, …
## $ year <dbl> 2011, 2011, 2011, 2011, 2011, 2011, 2011, 2011, 2011, 201…
## $ d <chr> "d_1", "d_2", "d_3", "d_4", "d_5", "d_6", "d_7", "d_8", "…
## $ event_name_1 <chr> NA, NA, NA, NA, NA, NA, NA, NA, "SuperBowl", NA, NA, NA, …
## $ event_type_1 <chr> NA, NA, NA, NA, NA, NA, NA, NA, "Sporting", NA, NA, NA, N…
## $ event_name_2 <chr> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, N…
## $ event_type_2 <chr> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, N…
## $ snap_CA <dbl> 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, …
## $ snap_TX <dbl> 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, …
## $ snap_WI <dbl> 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, …
summary(calendar)
##      date               wm_yr_wk       weekday               wday
## Min. :2011-01-29 Min. :11101 Length:1969 Min. :1.000
## 1st Qu.:2012-06-04 1st Qu.:11219 Class :character 1st Qu.:2.000
## Median :2013-10-09 Median :11337 Mode :character Median :4.000
## Mean :2013-10-09 Mean :11347 Mean :3.997
## 3rd Qu.:2015-02-13 3rd Qu.:11502 3rd Qu.:6.000
## Max. :2016-06-19 Max. :11621 Max. :7.000
## month year d event_name_1
## Min. : 1.000 Min. :2011 Length:1969 Length:1969
## 1st Qu.: 3.000 1st Qu.:2012 Class :character Class :character
## Median : 6.000 Median :2013 Mode :character Mode :character
## Mean : 6.326 Mean :2013
## 3rd Qu.: 9.000 3rd Qu.:2015
## Max. :12.000 Max. :2016
## event_type_1 event_name_2 event_type_2 snap_CA
## Length:1969 Length:1969 Length:1969 Min. :0.0000
## Class :character Class :character Class :character 1st Qu.:0.0000
## Mode :character Mode :character Mode :character Median :0.0000
## Mean :0.3301
## 3rd Qu.:1.0000
## Max. :1.0000
## snap_TX snap_WI
## Min. :0.0000 Min. :0.0000
## 1st Qu.:0.0000 1st Qu.:0.0000
## Median :0.0000 Median :0.0000
## Mean :0.3301 Mean :0.3301
## 3rd Qu.:1.0000 3rd Qu.:1.0000
## Max. :1.0000 Max. :1.0000
We find:
The calendar has all the relevant dates, weekdays, and months, plus the binary SNAP flags and the (mostly empty) event columns.
There are only 5 non-NA rows in the event_name_2 column; i.e. only 5 (out of 1969) instances where there is more than 1 event on a particular day.
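This can be checked directly on the calendar table loaded above:
# the few days with a second, overlapping event
calendar %>%
  filter(!is.na(event_name_2)) %>%
  select(date, event_name_1, event_name_2)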
There are no missing values in our sales training data:
sum(is.na(train))
## [1] 0
However, there are a lot of zero values. Here we plot the distribution of the percentage of zeros among all time series:
bar <- train %>%
select(-contains("id")) %>%
na_if(0) %>%
is.na() %>%
as_tibble() %>%
mutate(sum = pmap_dbl(select(., everything()), sum)) %>%
mutate(mean = sum/(ncol(train) - 1)) %>%
select(sum, mean)
bar %>%
ggplot(aes(mean)) +
geom_density(fill = "blue") +
scale_x_continuous(labels = scales::percent) +
coord_cartesian(xlim = c(0, 1)) +
theme_hc() +
theme(axis.text.y = element_blank()) +
labs(x = "", y = "", title = "Density for percentage of zero values - all time series")
Fig. 1
This means that only a minority of time series have less than 50% zero values; the peak of the distribution is rather close to 100%.
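A quick numeric check of this claim using the zero fractions computed above:
# share of series with fewer than 50% zeros, and the distribution of the zero fraction
mean(bar$mean < 0.5)
summary(bar$mean)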
This part contains EDA of Walmart stores located in different states. It aims to answer the following questions:
1. Which state has the highest sales?
2. Which department has the highest sales?
3. How many products are available in the different categories and departments?
4. How are sales distributed across weekdays?
5. How do sales behave on holidays and events?
# Libraries
library(data.table)
library(dplyr)
library(reshape2)
library(lubridate)
library(sqldf)
library(splitstackshape)
library(ggplot2)
library(ggthemes)
library(ggpubr)
library(timetk)
# Read datasets
train <- fread(str_c(path, "sales_train_evaluation.csv"), na.strings = c("", "NULL"), header = TRUE, stringsAsFactors = TRUE)
calendar <- fread(str_c(path, "calendar.csv"), na.strings = c("", "NULL"), header = TRUE, stringsAsFactors = TRUE)
prices <- fread(str_c(path, "sell_prices.csv"), na.strings = c("", "NULL"), header = TRUE, stringsAsFactors = TRUE)
submission <- fread(str_c(path, "sample_submission.csv"))
prices[, c('sell_price')] <- sapply(prices[, c('sell_price')], as.numeric)
calendar$date <- as.Date(calendar$date)
calendar$event_type_1<- as.character(calendar$event_type_1)
calendar$event_type_2<- as.character(calendar$event_type_2)
calendar$event_name_1<- as.character(calendar$event_name_1)
calendar$event_name_2<- as.character(calendar$event_name_2)
## Replace NA values in the event columns with "None"
calendar <- calendar %>%
mutate(event_type_1 = replace(event_type_1, is.na(event_type_1), "None"))%>%
mutate(event_type_2 = replace(event_type_2, is.na(event_type_2), "None"))%>%
mutate(event_name_1 = replace(event_name_1, is.na(event_name_1), "None"))%>%
mutate(event_name_2 = replace(event_name_2, is.na(event_name_2), "None"))
## Convert the event columns back to factors
calendar$event_type_1 <- as.factor(calendar$event_type_1)
calendar$event_type_2 <- as.factor(calendar$event_type_2)
calendar$event_name_1 <- as.factor(calendar$event_name_1)
calendar$event_name_2 <- as.factor(calendar$event_name_2)
newdf <- sqldf("SELECT DISTINCT event_name_1,event_type_1,event_type_2,event_name_2 FROM calendar")
train <- reshape2::melt(train,id.vars = c("id", "item_id", "dept_id", "cat_id", "store_id", "state_id"),
variable.name = "day",
value.name = "Unit_Sales")
data <- left_join(train, calendar,by = c("day" = "d"))
data <- stratified(data, c("item_id", "month"), .10)
data <- data %>%
left_join(prices,
by = c("store_id" = "store_id",
"item_id" = "item_id",
"wm_yr_wk" = "wm_yr_wk"))
suppressMessages({
g1<-data %>% select(state_id,cat_id,Unit_Sales)%>%
group_by(state_id,cat_id) %>%
summarise(sales = sum(Unit_Sales)) %>% ggplot(aes(x = state_id,y=sales, fill = cat_id))+
geom_bar(stat='identity',position=position_dodge()) +
geom_label(aes(label=sales), size=4,position=position_dodge(width=0.9),vjust=-0.1)+
labs(x="States",y="Total Unit Sales",title="Total Sales Per State",
subtitle='Food department is the highest selling',fill = "Product Category") +theme_bw() +
theme(plot.title = element_text(hjust = 0.5, face = "bold"),axis.text.x=element_text(face="bold", color="#993333", size=10),
axis.title.x = element_text(size = 16),
axis.title.y = element_text(size = 16))})
g1
From the chart above, we can see that California (CA) has the highest sales. The FOODS category achieved the highest sales.
suppressMessages({ g3<-data %>% group_by(dept_id) %>% dplyr::count(item_id) %>% summarise(n_items = length(item_id))%>%
ggplot(aes(x = dept_id, y = n_items, fill = dept_id))+
geom_bar(stat = "identity", alpha = 0.7)+
geom_label(aes(label = n_items), vjust = -0.1, show.legend = F)+theme_bw()+
labs(x='Departments', y = "Number of Items", title = "Total available Products in different departments",
subtitle = "FOODS_3 has the largest number of products.")+
theme(plot.title = element_text(hjust = 0.5, face = "bold"),axis.text.x=element_text(face="bold", color="#993333", size=10),
axis.title.x = element_text(size = 16),
axis.title.y = element_text(size = 16),legend.position = "none")})
g3
suppressMessages({
g13<-data %>%
mutate(wday = wday(date, label = TRUE, week_start = 1),
month = month(date, label = TRUE),
year = year(date)) %>%
group_by(wday, month, year) %>%
summarise(sales = sum(Unit_Sales))%>%
ggplot(aes(month, wday, fill = sales)) +
geom_tile()+scale_fill_distiller(palette='Spectral')+theme_tufte()+
theme(legend.position = "top",aspect.ratio = 1/2,
legend.key.width = unit(6, "cm"),legend.title.align = 2.5,
axis.text.x=element_text(face="bold", color="#993333",
size=10),axis.text.y=element_text(face="bold", color="#993333",
size=10),panel.border = element_rect(colour = "grey", fill=NA, size=1),
plot.title = element_text(hjust = 0.5, size = 21, face = "bold",
margin = margin(0,0,0.5,0, unit = "cm")),
axis.title.x = element_text(size = 16), axis.title.y = element_text(size = 16))+
labs(x='Month',y='Weekday',title='Sales HeatMap')
})
g13
As we expected, sales are concentrated more on weekends than on weekdays.
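A quick check on the (10% stratified) sample backs this up, using the weekday column from the calendar join:
# average daily unit sales per weekday in the sampled data
data %>%
  group_by(weekday) %>%
  summarise(avg_daily_sales = sum(Unit_Sales) / n_distinct(date), .groups = "drop") %>%
  arrange(desc(avg_daily_sales))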
events_name1<-dplyr::filter(data,(event_name_1!='None'))
events_type1<-dplyr::filter(data,(event_type_1!='None'))
suppressMessages({g14<-events_name1%>%select(event_name_1, Unit_Sales,cat_id,state_id)%>%
group_by(event_name_1,cat_id,state_id)%>%
summarise(mean_event_sales = sum(Unit_Sales))%>%ggplot()+
geom_bar(aes(x=reorder(event_name_1,mean_event_sales),y=mean_event_sales,fill=cat_id),stat='identity',position='stack')+
coord_flip()+theme_bw()+
labs(x = "", y = "Sales", title = "Total Sales for different events")+
theme(plot.title = element_text(hjust = 0.5, face = "bold"),axis.text.y=element_text(face="bold", color="#993333", size=10),
axis.text.x=element_text(face="bold", color="#993333", size=10))
})
g14
Sales were highest around sporting events such as the SuperBowl. During Thanksgiving, sales were relatively low, likely due to reduced opening hours. On Christmas, sales were zero, presumably because the stores were closed.
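The Christmas observation can be verified on the sampled data (Christmas is tagged in event_name_1):
# unit sales recorded on Christmas day, by year
data %>%
  filter(event_name_1 == "Christmas") %>%
  group_by(year) %>%
  summarise(total_sales = sum(Unit_Sales), .groups = "drop")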
suppressMessages({event_holiday <- data %>%
select(year, month, date, event_name_1, event_name_2, Unit_Sales) %>%
mutate(day = day(date)) %>%
group_by(year, month, day, date, event_name_1, event_name_2) %>%
summarise(Total_Sales = sum(Unit_Sales))
})
suppressMessages({
options(repr.plot.width = 18, repr.plot.height = 12)
jan<-dplyr::filter(event_holiday,(year==2015)& (month<=6))
event_jan<-dplyr::filter(event_holiday,event_name_1!="None"& (year==2015)& (month<=6))
g17<-ggplot(data=jan)+geom_line(aes(x=day,y=Total_Sales),alpha = 0.8)+geom_point(aes(x = day, y = Total_Sales), size = 1)+
facet_grid(month~.)+geom_point(data = event_jan,aes(x = day, y = Total_Sales, col = event_name_1), size = 4)+
geom_text(data = event_jan,aes(x = day, y = Total_Sales, label = event_name_1))+xlab("Day") + ylab("Total Sales") +
ggtitle("Event and Holiday Analysis for First 6 months") +
theme(plot.title = element_text(hjust = 0.5, face = "bold"),
axis.text.x=element_text(face="bold", color="#993333", size=10, angle=45),legend.position = "none")
#g17
#ggsave("g17.png")
options(repr.plot.width = 18, repr.plot.height = 12)
jan<-dplyr::filter(event_holiday,(year==2015)& (month>6))
event_jan<-dplyr::filter(event_holiday,event_name_1!="None"& (year==2015)& (month>6))
g18<-ggplot(data=jan)+geom_line(aes(x=day,y=Total_Sales),alpha = 0.8)+geom_point(aes(x = day, y = Total_Sales), size = 1)+
facet_grid(month~.)+geom_point(data = event_jan,aes(x = day, y = Total_Sales, col = event_name_1), size = 4)+
geom_text(data = event_jan,aes(x = day, y = Total_Sales, label = event_name_1))+xlab("Day") + ylab("Total Sales") +
ggtitle("Event and Holiday Analysis for last 6 months") +
theme(plot.title = element_text(hjust = 0.5, face = "bold"),
axis.text.x=element_text(face="bold", color="#993333", size=8, angle=45),legend.position = "none")
#g18
#ggsave("g18.png",width=15,height=25)
g19<-ggarrange(g17,g18)
#ggsave('g19.png')
})
g19
Sporting events like the NBA Finals show an interesting pattern: sales were high on the day before the event and dropped on the event day itself. Sales also dropped on special days such as Father’s Day and Mother’s Day.
LightGBM is a gradient boosting framework that uses tree-based learning algorithms. It is designed to be distributed and efficient.
At present, decision-tree-based machine learning algorithms dominate Kaggle competitions. Many winning solutions in these competitions have adopted an algorithm called XGBoost.
A couple of years ago, Microsoft released its own gradient boosting framework, LightGBM. Nowadays it steals the spotlight among gradient boosting machines, and Kagglers have started to use LightGBM more than XGBoost, in part because it is reported to train several times faster than XGBoost.
LightGBM is a relatively new algorithm with a long list of parameters, described in the LightGBM documentation.
Dataset sizes are growing rapidly, and it has become difficult for traditional data science algorithms to deliver results quickly. LightGBM is prefixed with “Light” because of its high speed: it can handle large datasets while using relatively little memory.
Another reason LightGBM is popular is its focus on accuracy. It also supports GPU training, which is why data scientists widely use it for data science applications.
It is not advisable to use LightGBM on small datasets, however: it is sensitive to overfitting and can easily overfit small data.
The LightGBM documentation describes its tree growth strategy as follows:
LightGBM grows trees vertically while other tree-based learning algorithms grow trees horizontally; that is, LightGBM grows trees leaf-wise while other algorithms grow level-wise. It chooses the leaf with the maximum delta loss to grow. When growing the same leaf, the leaf-wise algorithm can reduce more loss than a level-wise algorithm.
Most decision tree learning algorithms grow trees level (depth)-wise: every leaf at the current depth is split before the tree grows deeper. Leaf-wise growth instead always expands the single leaf with the largest reduction in loss, wherever it sits in the tree.
If we grow the full tree, best-first (leaf-wise) and depth-first (level-wise) will result in the same tree. The difference is in the order in which the tree is expanded. Since we don’t normally grow trees to their full depth, order matters.
Application of early stopping criteria and pruning methods can result in very different trees. Because leaf-wise chooses splits based on their contribution to the global loss and not just the loss along a particular branch, it often (not always) will learn lower-error trees “faster” than level-wise.
For a small number of nodes, leaf-wise will probably out-perform level-wise. As we add more nodes, without stopping or pruning they will converge to the same performance because they will literally build the same tree eventually.
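To make this concrete for the lightgbm R package: leaf-wise growth is governed mainly by num_leaves, while max_depth and min_data_in_leaf act as guards against the overfitting discussed above. A minimal, untuned parameter sketch (values are illustrative only):
library(lightgbm)
params_leafwise <- list(
  objective        = "regression",
  num_leaves       = 31,   # leaf budget per tree (leaf-wise growth)
  max_depth        = -1,   # -1 means no explicit depth limit
  min_data_in_leaf = 20,   # guards against tiny, overfit leaves
  learning_rate    = 0.1
)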
The lightgbm R package DESCRIPTION summarises it as follows:
Type: Package
Title: Light Gradient Boosting Machine
Version: 3.2.1
Date: 2021-04-12
Description: Tree based algorithms can be improved by introducing boosting frameworks. ‘LightGBM’ is one such framework, based on Ke, Guolin et al. (2017) https://papers.nips.cc/paper/6907-lightgbm-a-highlyefficient-gradient-boosting-decision. This package offers an R interface to work with it. It is designed to be distributed and efficient with the following advantages: 1. Faster training speed and higher efficiency. 2. Lower memory usage. 3. Better accuracy. 4. Parallel learning supported. 5. Capable of handling large-scale data. In recognition of these advantages, ‘LightGBM’ has been widely used in many winning solutions of machine learning competitions. Comparison experiments on public datasets suggest that ‘LightGBM’ can outperform existing boosting frameworks on both efficiency and accuracy, with significantly lower memory consumption. In addition, parallel experiments suggest that in certain circumstances, ‘LightGBM’ can achieve a linear speed-up in training time by using multiple machines.
Encoding: UTF-8
License: MIT + file LICENSE
URL: https://github.com/Microsoft/LightGBM
BugReports: https://github.com/Microsoft/LightGBM/issues
NeedsCompilation: yes
Biarch: true
Suggests: testthat
Depends: R (>= 3.5), R6 (>= 2.0)
Imports: data.table (>= 1.9.6), graphics, jsonlite (>= 1.0), Matrix (>= 1.1-0), methods, utils
Let’s load the packages and provide some basic parameters for our future model:
library(data.table)
library(lightgbm)
library(ggplot2)
set.seed(0)
h <- 28 # forecast horizon
max_lags <- 420 # number of observations to shift by
tr_last <- 1913 # last training day
fday <- as.IDate("2016-04-25") # first day to forecast
nrows <- Inf
path <- '/Users/Suen/Downloads/input2/m5-forecasting-accuracy/'
Also we need auxiliary functions:
free <- function() invisible(gc())

create_dt <- function(is_train = TRUE, nrows = Inf) {
if (is_train) { # create train set
dt <- fread(str_c(path,"sales_train_validation.csv"), nrows = nrows)
cols <- dt[, names(.SD), .SDcols = patterns("^d_")]
#dt <- as.data.table(dt)
# for each row, mark the leading zeros before the first recorded sale as NA
# (note: 1:i-1 parses as 0:(i-1), so nothing changes when the first day already has a sale)
dt[, (cols) := transpose(lapply(transpose(.SD),
function(x) {
i <- min(which(x > 0))
x[1:i-1] <- NA
x})), .SDcols = cols]
free()
} else { # create test set
dt <- fread(str_c(path,"sales_train_validation.csv"), nrows = nrows,
drop = paste0("d_", 1:(tr_last-max_lags))) # keep only max_lags days from the train set
dt[, paste0("d_", (tr_last+1):(tr_last+2*h)) := 0] # add empty columns for forecasting
}
dt <- na.omit(data.table::melt(dt,
measure.vars = patterns("^d_"),
variable.name = "d",
value.name = "sales"))
cal <- fread(str_c(path,"calendar.csv"))
dt <- dt[cal, `:=`(date = as.IDate(i.date, format="%Y-%m-%d"), # merge tables by reference
wm_yr_wk = i.wm_yr_wk,
event_name_1 = i.event_name_1,
snap_CA = i.snap_CA,
snap_TX = i.snap_TX,
snap_WI = i.snap_WI), on = "d"]
prices <- fread(str_c(path,"sell_prices.csv"))
dt[prices, sell_price := i.sell_price, on = c("store_id", "item_id", "wm_yr_wk")] # merge again
}

create_fea <- function(dt) {
dt[, `:=`(d = NULL, # remove useless columns
wm_yr_wk = NULL)]
cols <- c("item_id", "store_id", "state_id", "dept_id", "cat_id", "event_name_1")
dt[, (cols) := lapply(.SD, function(x) as.integer(factor(x))), .SDcols = cols] # convert character columns to integer
free()
lag <- c(7, 28)
lag_cols <- paste0("lag_", lag) # lag columns names
dt[, (lag_cols) := shift(.SD, lag), by = id, .SDcols = "sales"] # add lag vectors
win <- c(7, 28) # rolling window size
roll_cols <- paste0("rmean_", t(outer(lag, win, paste, sep="_"))) # rolling features columns names
dt[, (roll_cols) := frollmean(.SD, win, na.rm = TRUE), by = id, .SDcols = lag_cols] # rolling features on lag_cols
dt[, `:=`(wday = wday(date), # time features
mday = mday(date),
week = week(date),
month = month(date),
year = year(date))]
}
The next step is to prepare data for training. The data set itself is quite large, so we constantly need to collect garbage. Here I use the last 28 days for validation:
tr <- create_dt()
free()
Just to get the general idea, let’s plot grouped sales across all items:
tr[, .(sales = unlist(lapply(.SD, sum))), by = "date", .SDcols = "sales"
][, ggplot(.SD, aes(x = date, y = sales)) +
geom_line(size = 0.3, color = "steelblue", alpha = 0.8) +
geom_smooth(method='lm', formula= y~x, se = FALSE, linetype = 2, size = 0.5, color = "gray20") +
labs(x = "", y = "total sales") +
theme_minimal() +
theme(axis.text.x = element_text(angle = 45, hjust = 1), legend.position="none") +
scale_x_date(labels = scales::date_format("%b %y"), breaks = scales::date_breaks("3 months"))]
free()
We can see a trend and peaks at the end of each year. If only our model could capture seasonality and trend… OK, let’s proceed to the next step and prepare a dataset for lightgbm:
create_fea(tr)
free()
tr <- na.omit(tr) # remove rows with NA to save memory
free()
idx <- tr[date <= max(date)-h, which = TRUE] # indices for training
y <- tr$sales
tr[, c("id", "sales", "date") := NULL]
free()
tr <- data.matrix(tr)
free()
cats <- c("item_id", "store_id", "state_id", "dept_id", "cat_id",
"wday", "mday", "week", "month", "year",
"snap_CA", "snap_TX", "snap_WI") # list of categorical features
xtr <- lgb.Dataset(tr[idx, ], label = y[idx], categorical_feature = cats) # construct lgb dataset
xval <- lgb.Dataset(tr[-idx, ], label = y[-idx], categorical_feature = cats)
rm(tr, y, cats, idx)
free()
p <- list(objective = "poisson",
metric ="rmse",
force_row_wise = TRUE,
learning_rate = 0.075,
num_leaves = 128,
min_data = 100,
sub_feature = 0.8,
sub_row = 0.75,
bagging_freq = 1,
lambda_l2 = 0.1,
nthread = 4)
m_lgb <- lgb.train(params = p,
data = xtr,
#nrounds = 4000,
nrounds = 400,
valids = list(val = xval),
#early_stopping_rounds = 400,
early_stopping_rounds = 40,
#eval_freq = 400
eval_freq = 40)
## [LightGBM] [Warning] Met categorical feature which contains sparse values. Consider renumbering to consecutive integers started from zero
## [LightGBM] [Info] Total Bins 475
## [LightGBM] [Info] Number of data points in the train set: 6489, number of used features: 18
## [LightGBM] [Info] Start training from score -0.126664
## [LightGBM] [Warning] No further splits with positive gain, best gain: -inf
## (this warning is repeated at many of the following iterations)
## [1] "[1]: val's rmse:1.29372"
## [1] "[41]: val's rmse:1.10335"
## [1] "[81]: val's rmse:1.08946"
cat("Best score:", m_lgb$best_score, "at", m_lgb$best_iter, "iteration")
## Best score: 1.08622 at 78 iteration
imp <- lgb.importance(m_lgb)
rm(xtr, xval, p)
free()
imp[order(-Gain)
][1:15, ggplot(.SD, aes(reorder(Feature, Gain), Gain)) +
geom_col(fill = "steelblue") +
xlab("Feature") +
coord_flip() +
theme_minimal()]
We can see that rolling features are very important. I have to admit that chaining with data.table is very convenient, especially when we want to use j for side effects.
And now the hard part. As we are using 7-day lag features, we have to forecast day by day so that the latest predictions can feed into the features for the current day. This slows down the forecasting process tremendously. Also, tree models are unable to extrapolate, which is why we use a kind of “magic” multiplier that slightly inflates the predictions.
te <- create_dt(FALSE, nrows)
for (day in as.list(seq(fday, length.out = 2*h, by = "day"))){
cat(as.character(day), " ")
tst <- te[date >= day - max_lags & date <= day]
create_fea(tst)
tst <- data.matrix(tst[date == day][, c("id", "sales", "date") := NULL])
te[date == day, sales := 1.03*predict(m_lgb, tst)]
}
## 2016-04-25 2016-04-26 2016-04-27 2016-04-28 2016-04-29 2016-04-30 2016-05-01 2016-05-02 2016-05-03 2016-05-04 2016-05-05 2016-05-06 2016-05-07 2016-05-08 2016-05-09 2016-05-10 2016-05-11 2016-05-12 2016-05-13 2016-05-14 2016-05-15 2016-05-16 2016-05-17 2016-05-18 2016-05-19 2016-05-20 2016-05-21 2016-05-22 2016-05-23 2016-05-24 2016-05-25 2016-05-26 2016-05-27 2016-05-28 2016-05-29 2016-05-30 2016-05-31 2016-06-01 2016-06-02 2016-06-03 2016-06-04 2016-06-05 2016-06-06 2016-06-07 2016-06-08 2016-06-09 2016-06-10 2016-06-11 2016-06-12 2016-06-13 2016-06-14 2016-06-15 2016-06-16 2016-06-17 2016-06-18 2016-06-19
Let’s plot our predictions along with given values:
te[, .(sales = unlist(lapply(.SD, sum))), by = "date", .SDcols = "sales"
][, ggplot(.SD, aes(x = date, y = sales, colour = (date < fday))) +
geom_line() +
geom_smooth(method='lm', formula= y~x, se = FALSE, linetype = 2, size = 0.3, color = "gray20") +
labs(x = "", y = "total sales") +
theme_minimal() +
theme(axis.text.x = element_text(angle = 45, hjust = 1), legend.position="none") +
scale_x_date(labels=scales::date_format ("%b %d"), breaks=scales::date_breaks("14 day"))]