Introduction

How much camping gear will one store sell each month in a year? To the uninitiated, forecasting sales at this level may seem as difficult as predicting the weather. Both types of forecasting rely on science and historical data. While a wrong weather forecast may result in you carrying around an umbrella on a sunny day, inaccurate business forecasts can result in actual or opportunity losses. In this project, in addition to traditional forecasting methods, we also use machine learning to improve forecast accuracy.

We use hierarchical sales data from Walmart, the world’s largest company by revenue, to forecast daily sales for the next 28 days. The data covers stores in three US states (California, Texas, and Wisconsin) and includes item-level, department, product-category, and store details. In addition, it has explanatory variables such as price, promotions, day of the week, and special events. Together, this makes for a rich dataset that can be used to improve forecasting accuracy.

Data Preparation

The data: We are working with 42,840 hierarchical time series. The data were collected in the 3 US states of California (CA), Texas (TX), and Wisconsin (WI). “Hierarchical” here means that the data can be aggregated on different levels: item level, department level, product category level, and state level. The sales information spans January 2011 to June 2016. In addition to the sales numbers, we are also given corresponding data on prices, promotions, and holidays. Note that most of the time series contain zero values.

The data comprises 3,049 individual products from 3 categories and 7 departments, sold in 10 stores in 3 states. The hierarchical aggregation captures the combinations of these factors. For instance, we can create 1 time series for all sales, 3 time series for all sales per state, and so on. The most granular level consists of the sales of each of the 3,049 products in each of the 10 stores, giving 30,490 time series.
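To make these aggregation levels concrete, here is a minimal sketch (assuming the sales data has been read into train, as done in the next subsection) that rolls the 30,490 item-store series up to the 3 state-level series:

library(dplyr)
library(tidyr)

# Sketch: aggregate the item-store series to state level
# (assumes `train` is the wide sales table loaded below)
state_series <- train %>% 
  select(state_id, starts_with("d_")) %>%                          # state id + daily sales columns
  pivot_longer(starts_with("d_"), names_to = "d", values_to = "sales") %>% 
  group_by(state_id, d) %>% 
  summarise(sales = sum(sales), .groups = "drop")                  # one series per state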

The training data comes in the shape of 3 separate files: the sales history (sales_train_validation.csv), the sell prices (sell_prices.csv), and the calendar (calendar.csv).

library('dplyr') # data manipulation
library('vroom') # input/output
library('readr') # input/output
library('stringr') # string manipulation
library('tidyr') # data wrangling
library('purrr') # data wrangling

library('kableExtra') # display
library('ggplot2') # visualisation
library('ggthemes') # visualisation

path <- '/Users/Suen/Downloads/input/m5-forecasting-accuracy/'
train <- vroom(str_c(path,'sales_train_validation.csv'), delim = ",", col_types = cols())
prices <- vroom(str_c(path,'sell_prices.csv'), delim = ",", col_types = cols())
calendar <- read_csv(str_c(path,'calendar.csv'), col_types = cols())

sample_submit <- vroom(str_c(path,'sample_submission.csv'), delim = ",", col_types = cols())

Quick Look: File structure and content

As a first step, let’s take a quick look at the data sets using the head, summary, and glimpse tools where appropriate.

Training sales data

Here are the first 10 columns and rows of our training sales data:

train %>% 
  select(seq(1,10,1)) %>% 
  head(10) %>% 
  kable() %>% 
  kable_styling()
id item_id dept_id cat_id store_id state_id d_1 d_2 d_3 d_4
HOBBIES_1_001_CA_1_validation HOBBIES_1_001 HOBBIES_1 HOBBIES CA_1 CA 0 0 0 0
HOBBIES_1_002_CA_1_validation HOBBIES_1_002 HOBBIES_1 HOBBIES CA_1 CA 0 0 0 0
HOBBIES_1_003_CA_1_validation HOBBIES_1_003 HOBBIES_1 HOBBIES CA_1 CA 0 0 0 0
HOBBIES_1_004_CA_1_validation HOBBIES_1_004 HOBBIES_1 HOBBIES CA_1 CA 0 0 0 0
HOBBIES_1_005_CA_1_validation HOBBIES_1_005 HOBBIES_1 HOBBIES CA_1 CA 0 0 0 0
HOBBIES_1_006_CA_1_validation HOBBIES_1_006 HOBBIES_1 HOBBIES CA_1 CA 0 0 0 0
HOBBIES_1_007_CA_1_validation HOBBIES_1_007 HOBBIES_1 HOBBIES CA_1 CA 0 0 0 0
HOBBIES_1_008_CA_1_validation HOBBIES_1_008 HOBBIES_1 HOBBIES CA_1 CA 12 15 0 0
HOBBIES_1_009_CA_1_validation HOBBIES_1_009 HOBBIES_1 HOBBIES CA_1 CA 2 0 7 3
HOBBIES_1_010_CA_1_validation HOBBIES_1_010 HOBBIES_1 HOBBIES CA_1 CA 0 0 1 0

We find:

  • There is one column each for the IDs of item, department, category, store, and state; plus a general ID that is a combination of the other IDs plus a flag for validation.

  • The sales per date are encoded as columns starting with the prefix d_. Those are the number of units sold per day (not the total amount of dollars).

  • We already see that there are quite a lot of zero values.

This data set has too many columns and rows to display them all:

c(ncol(train),nrow(train))
## [1]  1919 30490

All IDs are marked as “validation”. This refers to the initial validation period of the competition, before we ultimately predict a different 28-day period.

train %>% 
  mutate(dset = if_else(str_detect(id, "validation"), "validation", "training")) %>% 
  count(dset)
## # A tibble: 1 x 2
##   dset           n
##   <chr>      <int>
## 1 validation 30490

Sales prices

This data set gives us the weekly sell prices per item and store:

prices %>% 
  head(10) %>% 
  kable() %>% 
  kable_styling()
store_id item_id wm_yr_wk sell_price
CA_1 HOBBIES_1_001 11325 9.58
CA_1 HOBBIES_1_001 11326 9.58
CA_1 HOBBIES_1_001 11327 8.26
CA_1 HOBBIES_1_001 11328 8.26
CA_1 HOBBIES_1_001 11329 8.26
CA_1 HOBBIES_1_001 11330 8.26
CA_1 HOBBIES_1_001 11331 8.26
CA_1 HOBBIES_1_001 11332 8.26
CA_1 HOBBIES_1_001 11333 8.26
CA_1 HOBBIES_1_001 11334 8.26
summary(prices)
##    store_id           item_id             wm_yr_wk       sell_price     
##  Length:6841121     Length:6841121     Min.   :11101   Min.   :  0.010  
##  Class :character   Class :character   1st Qu.:11247   1st Qu.:  2.180  
##  Mode  :character   Mode  :character   Median :11411   Median :  3.470  
##                                        Mean   :11383   Mean   :  4.411  
##                                        3rd Qu.:11517   3rd Qu.:  5.840  
##                                        Max.   :11621   Max.   :107.320

We find:

  • We have the store_id and item_id to link this data to our training and validation data.

  • Prices range from $0.01 to a bit more than $107.

Calendar

The calendar data gives us date features such as weekday, month, or year, alongside 2 different event features and SNAP food-stamp flags for each of the three states:

calendar %>% 
  head(8) %>% 
  kable() %>% 
  kable_styling()
date wm_yr_wk weekday wday month year d event_name_1 event_type_1 event_name_2 event_type_2 snap_CA snap_TX snap_WI
2011-01-29 11101 Saturday 1 1 2011 d_1 NA NA NA NA 0 0 0
2011-01-30 11101 Sunday 2 1 2011 d_2 NA NA NA NA 0 0 0
2011-01-31 11101 Monday 3 1 2011 d_3 NA NA NA NA 0 0 0
2011-02-01 11101 Tuesday 4 2 2011 d_4 NA NA NA NA 1 1 0
2011-02-02 11101 Wednesday 5 2 2011 d_5 NA NA NA NA 1 0 1
2011-02-03 11101 Thursday 6 2 2011 d_6 NA NA NA NA 1 1 1
2011-02-04 11101 Friday 7 2 2011 d_7 NA NA NA NA 1 0 0
2011-02-05 11102 Saturday 1 2 2011 d_8 NA NA NA NA 1 1 1
glimpse(calendar)
## Rows: 1,969
## Columns: 14
## $ date         <date> 2011-01-29, 2011-01-30, 2011-01-31, 2011-02-01, 2011-02-…
## $ wm_yr_wk     <dbl> 11101, 11101, 11101, 11101, 11101, 11101, 11101, 11102, 1…
## $ weekday      <chr> "Saturday", "Sunday", "Monday", "Tuesday", "Wednesday", "…
## $ wday         <dbl> 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, …
## $ month        <dbl> 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, …
## $ year         <dbl> 2011, 2011, 2011, 2011, 2011, 2011, 2011, 2011, 2011, 201…
## $ d            <chr> "d_1", "d_2", "d_3", "d_4", "d_5", "d_6", "d_7", "d_8", "…
## $ event_name_1 <chr> NA, NA, NA, NA, NA, NA, NA, NA, "SuperBowl", NA, NA, NA, …
## $ event_type_1 <chr> NA, NA, NA, NA, NA, NA, NA, NA, "Sporting", NA, NA, NA, N…
## $ event_name_2 <chr> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, N…
## $ event_type_2 <chr> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, N…
## $ snap_CA      <dbl> 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, …
## $ snap_TX      <dbl> 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, …
## $ snap_WI      <dbl> 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, …
summary(calendar)
##       date               wm_yr_wk       weekday               wday      
##  Min.   :2011-01-29   Min.   :11101   Length:1969        Min.   :1.000  
##  1st Qu.:2012-06-04   1st Qu.:11219   Class :character   1st Qu.:2.000  
##  Median :2013-10-09   Median :11337   Mode  :character   Median :4.000  
##  Mean   :2013-10-09   Mean   :11347                      Mean   :3.997  
##  3rd Qu.:2015-02-13   3rd Qu.:11502                      3rd Qu.:6.000  
##  Max.   :2016-06-19   Max.   :11621                      Max.   :7.000  
##      month             year           d             event_name_1      
##  Min.   : 1.000   Min.   :2011   Length:1969        Length:1969       
##  1st Qu.: 3.000   1st Qu.:2012   Class :character   Class :character  
##  Median : 6.000   Median :2013   Mode  :character   Mode  :character  
##  Mean   : 6.326   Mean   :2013                                        
##  3rd Qu.: 9.000   3rd Qu.:2015                                        
##  Max.   :12.000   Max.   :2016                                        
##  event_type_1       event_name_2       event_type_2          snap_CA      
##  Length:1969        Length:1969        Length:1969        Min.   :0.0000  
##  Class :character   Class :character   Class :character   1st Qu.:0.0000  
##  Mode  :character   Mode  :character   Mode  :character   Median :0.0000  
##                                                           Mean   :0.3301  
##                                                           3rd Qu.:1.0000  
##                                                           Max.   :1.0000  
##     snap_TX          snap_WI      
##  Min.   :0.0000   Min.   :0.0000  
##  1st Qu.:0.0000   1st Qu.:0.0000  
##  Median :0.0000   Median :0.0000  
##  Mean   :0.3301   Mean   :0.3301  
##  3rd Qu.:1.0000   3rd Qu.:1.0000  
##  Max.   :1.0000   Max.   :1.0000

We find:

  • The calendar has all the relevant dates, weekdays, and months, plus binary SNAP flags and the event name/type columns.

  • There are only 5 non-NA rows in the event_name_2 column; i.e. only 5 (out of 1969) instances where there is more than one event on a particular day (a quick check is sketched below).
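A quick sketch of that check:

calendar %>% 
  filter(!is.na(event_name_2)) %>% 
  count(event_name_2, event_type_2)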

Missing values & zero values

There are no missing values in our sales training data:

sum(is.na(train))
## [1] 0

However, there are a lot of zero values. Here we plot the distribution of the percentage of zeros across all time series:

bar <- train %>% 
  select(-contains("id")) %>% 
  na_if(0) %>% 
  is.na() %>% 
  as_tibble() %>% 
  mutate(sum = pmap_dbl(select(., everything()), sum)) %>% 
  mutate(mean = sum/(ncol(train) - 1)) %>% 
  select(sum, mean)
  
bar %>% 
  ggplot(aes(mean)) +
  geom_density(fill = "blue") +
  scale_x_continuous(labels = scales::percent) +
  coord_cartesian(xlim = c(0, 1)) +
  theme_hc() +
  theme(axis.text.y = element_blank()) +
  labs(x = "", y = "", title = "Density for percentage of zero values - all time series")
Fig. 1: Density of the percentage of zero values across all time series

This means that only a minority of time series have less than 50% zero values; the density peaks rather close to 100%.
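The same per-series zero share can also be computed more directly with rowMeans() over the day columns only (a sketch; note that its denominator is the number of day columns, whereas the plot above divides by ncol(train) - 1):

day_cols <- grep("^d_", names(train), value = TRUE)    # the daily sales columns
zero_share <- rowMeans(train[, day_cols] == 0)         # fraction of zero-sales days per series
summary(zero_share)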

EDA - Exploratory Data Analysis

This part contains EDA of the Walmart stores located in the different states. It aims to answer the following questions:
1. Which state has the highest sales?
2. Which department has the highest sales?
3. How many products are there in each category?
4. How are sales distributed across weekdays?
5. How do sales behave around holidays and events?

Data Preprocessing

#Library
library(data.table)
library(dplyr)
library(reshape2)
library(lubridate)
library(sqldf)
library(splitstackshape)
library(ggplot2)
library(ggthemes)
library(ggpubr)
library(timetk)
#Read dataset
train <- fread(str_c(path,"sales_train_evaluation.csv"), na.strings=c("", "NULL"), header = TRUE, stringsAsFactors = TRUE)
calendar <- fread(str_c(path,"calendar.csv"), na.strings=c("", "NULL"), header = TRUE, stringsAsFactors = TRUE)
prices <- fread(str_c(path,"sell_prices.csv"), na.strings=c("", "NULL"), header = TRUE, stringsAsFactors = TRUE)
submission = fread(str_c(path,'sample_submission.csv'))
prices[, c('sell_price')] <- sapply(prices[, c('sell_price')], as.numeric)
calendar$date <- as.Date(calendar$date)

calendar$event_type_1<- as.character(calendar$event_type_1)
calendar$event_type_2<- as.character(calendar$event_type_2)
calendar$event_name_1<- as.character(calendar$event_name_1)
calendar$event_name_2<- as.character(calendar$event_name_2)

## Replace NA values in the event name/type columns with "None"
calendar <- calendar %>%
            mutate(event_type_1 = replace(event_type_1, is.na(event_type_1), "None"))%>%
            mutate(event_type_2 = replace(event_type_2, is.na(event_type_2), "None"))%>%
            mutate(event_name_1 = replace(event_name_1, is.na(event_name_1), "None"))%>%
            mutate(event_name_2 = replace(event_name_2, is.na(event_name_2), "None"))

## Convert the event name/type columns back to factors
calendar$event_type_1 <- as.factor(calendar$event_type_1)
calendar$event_type_2 <- as.factor(calendar$event_type_2)
calendar$event_name_1 <- as.factor(calendar$event_name_1)
calendar$event_name_2 <- as.factor(calendar$event_name_2)

newdf <- sqldf("SELECT DISTINCT event_name_1,event_type_1,event_type_2,event_name_2 FROM calendar")

train <- reshape2::melt(train,id.vars = c("id", "item_id", "dept_id", "cat_id", "store_id", "state_id"),
                 variable.name = "day", 
                 value.name = "Unit_Sales") 

data <- left_join(train, calendar,by = c("day" = "d"))

data <- stratified(data, c("item_id", "month"), .10) # take a 10% stratified sample per item and month to keep the EDA data manageable

data <- data %>%
            left_join(prices, 
                      by = c("store_id" = "store_id",
                             "item_id" = "item_id",
                             "wm_yr_wk" = "wm_yr_wk"))

Data Visualization

suppressMessages({ 
g1<-data %>% select(state_id,cat_id,Unit_Sales)%>% 
  group_by(state_id,cat_id) %>% 
  summarise(sales = sum(Unit_Sales)) %>% ggplot(aes(x = state_id,y=sales, fill = cat_id))+
        geom_bar(stat='identity',position=position_dodge())  +
          geom_label(aes(label=sales), size=4,position=position_dodge(width=0.9),vjust=-0.1)+
        labs(x="States",y="Total Unit Sales",title="Total Sales Per State",
             subtitle='Food department is the highest selling',fill = "Product Category") +theme_bw() +
        theme(plot.title = element_text(hjust = 0.5, face = "bold"),axis.text.x=element_text(face="bold", color="#993333", size=10),
             axis.title.x = element_text(size = 16),
             axis.title.y = element_text(size = 16))})
g1


From the chart above, we can see that California (CA) has the highest sales. The FOODS category achieved the highest sales.

suppressMessages({ g3<-data %>% group_by(dept_id) %>% dplyr::count(item_id) %>% summarise(n_items = length(item_id))%>%
  ggplot(aes(x = dept_id, y = n_items, fill = dept_id))+
  geom_bar(stat = "identity", alpha = 0.7)+
  geom_label(aes(label = n_items), vjust = -0.1, show.legend = F)+theme_bw()+
  labs(x='Departments', y = "Number of Items", title = "Total available Products in different departments",
       subtitle = "FOODS_3 has the largest number of products.")+
        theme(plot.title = element_text(hjust = 0.5, face = "bold"),axis.text.x=element_text(face="bold", color="#993333", size=10),
             axis.title.x = element_text(size = 16),
             axis.title.y = element_text(size = 16),legend.position = "none")})
g3

suppressMessages({
g13<-data %>% 
  mutate(wday = wday(date, label = TRUE, week_start = 1),
         month = month(date, label = TRUE),
         year = year(date)) %>% 
  group_by(wday, month, year) %>% 
  summarise(sales = sum(Unit_Sales))%>%
  ggplot(aes(month, wday, fill = sales)) +
  geom_tile()+scale_fill_distiller(palette='Spectral')+theme_tufte()+
    theme(legend.position = "top",aspect.ratio = 1/2,
        legend.key.width = unit(6, "cm"),legend.title.align = 2.5,
     axis.text.x=element_text(face="bold", color="#993333", 
                           size=10),axis.text.y=element_text(face="bold", color="#993333", 
                           size=10),panel.border = element_rect(colour = "grey", fill=NA, size=1),
      plot.title = element_text(hjust = 0.5, size = 21, face = "bold",
                                  margin = margin(0,0,0.5,0, unit = "cm")),
             axis.title.x = element_text(size = 16), axis.title.y = element_text(size = 16))+
    labs(x='Month',y='Weekday',title='Sales HeatMap')
    })
g13


As we expected, sales are concentrated more on weekends than on weekdays.

events_name1<-dplyr::filter(data,(event_name_1!='None'))
events_type1<-dplyr::filter(data,(event_type_1!='None'))

suppressMessages({g14<-events_name1%>%select(event_name_1, Unit_Sales,cat_id,state_id)%>%
                    group_by(event_name_1,cat_id,state_id)%>%
                    summarise(mean_event_sales = sum(Unit_Sales))%>%ggplot()+
                geom_bar(aes(x=reorder(event_name_1,mean_event_sales),y=mean_event_sales,fill=cat_id),stat='identity',position='stack')+
                coord_flip()+theme_bw()+
              labs(x = "", y = "Sales", title = "Total Sales for different events")+
        theme(plot.title = element_text(hjust = 0.5, face = "bold"),axis.text.y=element_text(face="bold", color="#993333", size=10),
             axis.text.x=element_text(face="bold", color="#993333", size=10))
})
g14


Sales were highest around the SuperBowl sporting event. During Thanksgiving, sales were relatively low, likely due to reduced opening hours. On Christmas, sales were zero, presumably because the stores were closed.

suppressMessages({event_holiday <- data %>%
                    select(year, month, date, event_name_1, event_name_2, Unit_Sales) %>%
                    mutate(day = day(date)) %>%
                    group_by(year, month, day, date, event_name_1, event_name_2) %>%
                    summarise(Total_Sales = sum(Unit_Sales))
                  })
suppressMessages({
options(repr.plot.width = 18, repr.plot.height = 12)
jan<-dplyr::filter(event_holiday,(year==2015)& (month<=6))
event_jan<-dplyr::filter(event_holiday,event_name_1!="None"& (year==2015)& (month<=6))
g17<-ggplot(data=jan)+geom_line(aes(x=day,y=Total_Sales),alpha = 0.8)+geom_point(aes(x = day, y = Total_Sales), size = 1)+
                facet_grid(month~.)+geom_point(data = event_jan,aes(x = day, y = Total_Sales, col = event_name_1), size = 4)+
geom_text(data = event_jan,aes(x = day, y = Total_Sales, label = event_name_1))+xlab("Day") + ylab("Total Sales") +
                ggtitle("Event and Holiday Analysis for First 6 months") + 
                theme(plot.title = element_text(hjust = 0.5, face = "bold"),
                      axis.text.x=element_text(face="bold", color="#993333", size=10, angle=45),legend.position = "none")
#g17
#ggsave("g17.png")
options(repr.plot.width = 18, repr.plot.height = 12)
jan<-dplyr::filter(event_holiday,(year==2015)& (month>6))
event_jan<-dplyr::filter(event_holiday,event_name_1!="None"& (year==2015)& (month>6))
g18<-ggplot(data=jan)+geom_line(aes(x=day,y=Total_Sales),alpha = 0.8)+geom_point(aes(x = day, y = Total_Sales), size = 1)+
                facet_grid(month~.)+geom_point(data = event_jan,aes(x = day, y = Total_Sales, col = event_name_1), size = 4)+
geom_text(data = event_jan,aes(x = day, y = Total_Sales, label = event_name_1))+xlab("Day") + ylab("Total Sales") +
                ggtitle("Event and Holiday Analysis for last 6 months") + 
                theme(plot.title = element_text(hjust = 0.5, face = "bold"),
                      axis.text.x=element_text(face="bold", color="#993333", size=8, angle=45),legend.position = "none")
#g18
#ggsave("g18.png",width=15,height=25)
g19<-ggarrange(g17,g18)
#ggsave('g19.png')
    })
g19


Sporting events like the NBA Finals show an interesting pattern: sales were high the day before the event and dropped on the event days themselves. Sales also dropped on special days such as Father’s Day and Mother’s Day.

Methodology

LightGBM intuition

  • LightGBM is a gradient boosting framework that uses tree-based learning algorithms.

  • The LightGBM documentation states:

LightGBM grows tree vertically while other tree based learning algorithms grow trees horizontally. It means that LightGBM grows tree leaf-wise while other algorithms grow level-wise. It will choose the leaf with max delta loss to grow. When growing the same leaf, leaf-wise algorithm can reduce more loss than a level-wise algorithm.

  • So, we need to understand the distinction between leaf-wise tree growth and level-wise tree growth.

Leaf-wise tree growth

  • Leaf-wise tree growth can best be explained with the following visual: (figure: leaf-wise tree growth)

Level-wise tree growth

  • Most decision tree learning algorithms grow trees level-wise (i.e. depth-wise).

  • Level-wise tree growth can best be explained with the following visual: (figure: level-wise tree growth)

Important points about tree growth

  • If we grow the full tree, best-first (leaf-wise) and depth-first (level-wise) will result in the same tree. The difference is in the order in which the tree is expanded. Since we don’t normally grow trees to their full depth, order matters.

  • Application of early stopping criteria and pruning methods can result in very different trees. Because leaf-wise chooses splits based on their contribution to the global loss and not just the loss along a particular branch, it often (not always) will learn lower-error trees “faster” than level-wise.

  • For a small number of nodes, leaf-wise will probably out-perform level-wise. As we add more nodes, without stopping or pruning they will converge to the same performance because they will literally build the same tree eventually.
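In the R package, this growth behaviour is steered mainly by num_leaves, with max_depth as an optional cap. The sketch below uses hypothetical toy data, purely to show where these knobs go; it is not part of the forecasting pipeline:

library(lightgbm)

set.seed(1)
x <- matrix(rnorm(1000 * 5), ncol = 5)    # toy features (hypothetical data)
y <- rowSums(x) + rnorm(1000)             # toy target
dtrain <- lgb.Dataset(x, label = y)

# Leaf-wise growth: cap the number of leaves, leave depth unconstrained
p_leafwise <- list(objective = "regression", num_leaves = 31, max_depth = -1)

# Roughly level-wise behaviour: constrain depth and allow a full tree of that depth
p_levelwise <- list(objective = "regression", num_leaves = 2^4, max_depth = 4)

m_leaf  <- lgb.train(params = p_leafwise,  data = dtrain, nrounds = 10, verbose = -1)
m_level <- lgb.train(params = p_levelwise, data = dtrain, nrounds = 10, verbose = -1)

Keeping num_leaves modest while leaving max_depth unconstrained gives the leaf-wise behaviour described above; setting num_leaves to 2^max_depth with a fixed max_depth roughly emulates a fully grown level-wise tree.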

LightGBM Package

Key fields from the package DESCRIPTION:

Type: Package
Title: Light Gradient Boosting Machine
Version: 3.2.1
Date: 2021-04-12
Description: Tree based algorithms can be improved by introducing boosting frameworks. ‘LightGBM’ is one such framework, based on Ke, Guolin et al. (2017) https://papers.nips.cc/paper/6907-lightgbm-a-highlyefficient-gradient-boosting-decision. This package offers an R interface to work with it. It is designed to be distributed and efficient with the following advantages: 1. Faster training speed and higher efficiency. 2. Lower memory usage. 3. Better accuracy. 4. Parallel learning supported. 5. Capable of handling large-scale data. In recognition of these advantages, ‘LightGBM’ has been widely used in many winning solutions of machine learning competitions. Comparison experiments on public datasets suggest that ‘LightGBM’ can outperform existing boosting frameworks on both efficiency and accuracy, with significantly lower memory consumption. In addition, parallel experiments suggest that in certain circumstances, ‘LightGBM’ can achieve a linear speed-up in training time by using multiple machines.
Encoding: UTF-8
License: MIT + file LICENSE
URL: https://github.com/Microsoft/LightGBM
BugReports: https://github.com/Microsoft/LightGBM/issues
NeedsCompilation: yes
Biarch: true
Suggests: testthat
Depends: R (>= 3.5), R6 (>= 2.0)
Imports: data.table (>= 1.9.6), graphics, jsonlite (>= 1.0), Matrix (>= 1.1-0), methods, utils
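If the package is not yet installed, the CRAN release can be obtained in the usual way (a one-line sketch):

install.packages("lightgbm")   # CRAN release; 3.2.1 at the time of writing
library(lightgbm)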

Implementation & Results

Preparations

Let’s load the packages and provide some basic parameters for our future model:

library(data.table)
library(lightgbm)
library(ggplot2)

set.seed(0)

h <- 28 # forecast horizon
max_lags <- 420 # number of observations to shift by
tr_last <- 1913 # last training day
fday <- as.IDate("2016-04-25") # first day to forecast
nrows <- Inf

path <- '/Users/Suen/Downloads/input2/m5-forecasting-accuracy/'

We also need a few auxiliary functions:

  • free() just calls a garbage collector
free <- function() invisible(gc()) 
  • create_dt() creates a training or testing data table from a wide-format file with leading zeros removed. Pay attention to a neat feature of melt() and data.table: we can choose columns by regex patterns. Notice also that the columns of the second table are referenced with the i. prefix when merging data tables (e.g. i.event_name_1 is a column of the cal table); a toy illustration of this idiom follows the function definition.
create_dt <- function(is_train = TRUE, nrows = Inf) {
  
  if (is_train) { # create train set
    dt <- fread(str_c(path,"sales_train_validation.csv"), nrows = nrows)
    cols <- dt[, names(.SD), .SDcols = patterns("^d_")]
    #dt <- as.data.table(dt)
    dt[, (cols) := transpose(lapply(transpose(.SD),
                                    function(x) {
                                      i <- min(which(x > 0)) # first day with a sale
                                      x[1:i-1] <- NA # note: 1:i-1 parses as 0:(i-1), so the leading zero days become NA
                                      x})), .SDcols = cols]
    free()
  } else { # create test set
    dt <- fread(str_c(path,"sales_train_validation.csv"), nrows = nrows,
                drop = paste0("d_", 1:(tr_last-max_lags))) # keep only max_lags days from the train set
    dt[, paste0("d_", (tr_last+1):(tr_last+2*h)) := 0] # add empty columns for forecasting
  }
  
  dt <- na.omit(data.table::melt(dt,
                     measure.vars = patterns("^d_"),
                     variable.name = "d",
                     value.name = "sales"))
  
  cal <- fread(str_c(path,"calendar.csv"))
  dt <- dt[cal, `:=`(date = as.IDate(i.date, format="%Y-%m-%d"), # merge tables by reference
                     wm_yr_wk = i.wm_yr_wk,
                     event_name_1 = i.event_name_1,
                     snap_CA = i.snap_CA,
                     snap_TX = i.snap_TX,
                     snap_WI = i.snap_WI), on = "d"]
  
  prices <- fread(str_c(path,"sell_prices.csv"))
  dt[prices, sell_price := i.sell_price, on = c("store_id", "item_id", "wm_yr_wk")] # merge again
}
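Here is the promised toy illustration of the i.-prefix update-join idiom used in create_dt() (standalone, not part of the pipeline):

library(data.table)

# Toy update-join: copy a column from `cal` into `sales` by reference
sales <- data.table(d = c("d_1", "d_2", "d_3"), units = c(3, 0, 5))
cal   <- data.table(d = c("d_1", "d_2", "d_3"),
                    event_name_1 = c(NA, "SuperBowl", NA))
sales[cal, event_name_1 := i.event_name_1, on = "d"]   # i. refers to cal's columns
sales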
  • create_fea() adds lags, rolling features and time variables to the data table. frollmean() is a fast rolling function that calculates means on a sliding window; a toy illustration of shift() and frollmean() follows the function definition. Notice how we use the := operator to add new columns.
create_fea <- function(dt) {
  dt[, `:=`(d = NULL, # remove useless columns
            wm_yr_wk = NULL)]
  
  cols <- c("item_id", "store_id", "state_id", "dept_id", "cat_id", "event_name_1") 
  dt[, (cols) := lapply(.SD, function(x) as.integer(factor(x))), .SDcols = cols] # convert character columns to integer
  free()
  
  lag <- c(7, 28) 
  lag_cols <- paste0("lag_", lag) # lag columns names
  dt[, (lag_cols) := shift(.SD, lag), by = id, .SDcols = "sales"] # add lag vectors
  
  win <- c(7, 28) # rolling window size
  roll_cols <- paste0("rmean_", t(outer(lag, win, paste, sep="_"))) # rolling features columns names
  dt[, (roll_cols) := frollmean(.SD, win, na.rm = TRUE), by = id, .SDcols = lag_cols] # rolling features on lag_cols
  
  dt[, `:=`(wday = wday(date), # time features
            mday = mday(date),
            week = week(date),
            month = month(date),
            year = year(date))]
}
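And the promised toy illustration of shift() and frollmean() (standalone, not part of the pipeline):

library(data.table)

x <- c(1, 2, 4, 8, 16, 32, 64)
shift(x, 2)       # lag by 2: NA NA 1 2 4 8 16
frollmean(x, 3)   # 3-period rolling mean: NA NA 2.33 4.67 9.33 18.67 37.33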

The next step is to prepare data for training. The data set itself is quite large, so we constantly need to collect garbage. Here I use the last 28 days for validation:

tr <- create_dt()
free()

Just to get the general idea, let’s plot the total daily sales summed across all items:

tr[, .(sales = unlist(lapply(.SD, sum))), by = "date", .SDcols = "sales"
   ][, ggplot(.SD, aes(x = date, y = sales)) +
       geom_line(size = 0.3, color = "steelblue", alpha = 0.8) + 
       geom_smooth(method='lm', formula= y~x, se = FALSE, linetype = 2, size = 0.5, color = "gray20") + 
       labs(x = "", y = "total sales") +
       theme_minimal() +
       theme(axis.text.x = element_text(angle = 45, hjust = 1), legend.position="none") +
       scale_x_date(labels=scales::date_format ("%b %y"), breaks=scales::date_breaks("3 months"))]

free()

We can see a trend and peaks at the end of each year. If only our model could model seasonality and trend… Ok, let’s proceed to the next step and prepare a dataset for lightgbm:

create_fea(tr)
free()

tr <- na.omit(tr) # remove rows with NA to save memory
free()

idx <- tr[date <= max(date)-h, which = TRUE] # indices for training
y <- tr$sales
tr[, c("id", "sales", "date") := NULL]
free()
tr <- data.matrix(tr)
free()
cats <- c("item_id", "store_id", "state_id", "dept_id", "cat_id", 
          "wday", "mday", "week", "month", "year",
          "snap_CA", "snap_TX", "snap_WI") # list of categorical features

xtr <- lgb.Dataset(tr[idx, ], label = y[idx], categorical_feature = cats) # construct lgb dataset
xval <- lgb.Dataset(tr[-idx, ], label = y[-idx], categorical_feature = cats)

rm(tr, y, cats, idx)
free()

Training model

It’s time to train our not-so-simple model with a Poisson loss, which is suitable for counts. Lately I tune tree models manually following this approach:
p <- list(objective = "poisson",
          metric ="rmse",
          force_row_wise = TRUE,
          learning_rate = 0.075,
          num_leaves = 128,
          min_data = 100,
          sub_feature = 0.8,
          sub_row = 0.75,
          bagging_freq = 1,
          lambda_l2 = 0.1,
          nthread = 4)

m_lgb <- lgb.train(params = p,
                   data = xtr,
                   #nrounds = 4000,
                   nrounds = 400,
                   valids = list(val = xval),
                   #early_stopping_rounds = 400,
                   early_stopping_rounds = 40,
                   #eval_freq = 400
                   eval_freq = 40)
## [LightGBM] [Warning] Met categorical feature which contains sparse values. Consider renumbering to consecutive integers started from zero
## [LightGBM] [Info] Total Bins 475
## [LightGBM] [Info] Number of data points in the train set: 6489, number of used features: 18
## [LightGBM] [Info] Start training from score -0.126664
## [LightGBM] [Warning] No further splits with positive gain, best gain: -inf
## (this warning is repeated many times throughout training; output truncated)
## [1] "[1]:  val's rmse:1.29372"
## [1] "[41]:  val's rmse:1.10335"
## [1] "[81]:  val's rmse:1.08946"
cat("Best score:", m_lgb$best_score, "at", m_lgb$best_iter, "iteration")   
## Best score: 1.08622 at 78 iteration
imp <- lgb.importance(m_lgb)

rm(xtr, xval, p)
free()
imp[order(-Gain)
    ][1:15, ggplot(.SD, aes(reorder(Feature, Gain), Gain)) +
        geom_col(fill = "steelblue") +
        xlab("Feature") +
        coord_flip() +
        theme_minimal()]

We can see that rolling features are very important. I have to admit that chaining with data.table is very convenient, especially when we want to use j for side effects.
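As a tiny illustration of using j for a side effect in a data.table chain (toy table, not part of the pipeline):

library(data.table)

dt <- data.table(g = c("a", "a", "b"), v = 1:3)
dt[, .(v = sum(v)), by = g    # first chain link: aggregate
   ][, print(.SD)]            # second link: j used purely for its side effect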

Forecasting

And now the hard part. As we are using 7-day lag features, we have to forecast day by day in order to feed the latest predictions into the features for the current day. This slows down the forecasting process tremendously. Also, tree models are unable to extrapolate, which is why we use a kind of “magic” multiplier that slightly inflates the predictions.

te <- create_dt(FALSE, nrows)

for (day in as.list(seq(fday, length.out = 2*h, by = "day"))){
  cat(as.character(day), " ")
  tst <- te[date >= day - max_lags & date <= day]
  create_fea(tst)
  tst <- data.matrix(tst[date == day][, c("id", "sales", "date") := NULL])
  te[date == day, sales := 1.03*predict(m_lgb, tst)]
}
## 2016-04-25  2016-04-26  2016-04-27  2016-04-28  2016-04-29  2016-04-30  2016-05-01  2016-05-02  2016-05-03  2016-05-04  2016-05-05  2016-05-06  2016-05-07  2016-05-08  2016-05-09  2016-05-10  2016-05-11  2016-05-12  2016-05-13  2016-05-14  2016-05-15  2016-05-16  2016-05-17  2016-05-18  2016-05-19  2016-05-20  2016-05-21  2016-05-22  2016-05-23  2016-05-24  2016-05-25  2016-05-26  2016-05-27  2016-05-28  2016-05-29  2016-05-30  2016-05-31  2016-06-01  2016-06-02  2016-06-03  2016-06-04  2016-06-05  2016-06-06  2016-06-07  2016-06-08  2016-06-09  2016-06-10  2016-06-11  2016-06-12  2016-06-13  2016-06-14  2016-06-15  2016-06-16  2016-06-17  2016-06-18  2016-06-19

Let’s plot our predictions along with given values:

te[, .(sales = unlist(lapply(.SD, sum))), by = "date", .SDcols = "sales"
   ][, ggplot(.SD, aes(x = date, y = sales, colour = (date < fday))) +
       geom_line() + 
       geom_smooth(method='lm', formula= y~x, se = FALSE, linetype = 2, size = 0.3, color = "gray20") + 
       labs(x = "", y = "total sales") +
       theme_minimal() +
       theme(axis.text.x = element_text(angle = 45, hjust = 1), legend.position="none") +
       scale_x_date(labels=scales::date_format ("%b %d"), breaks=scales::date_breaks("14 day"))]