HYPERPARAMETER TUNING AND TIDYTUESDAY FOOD CONSUMPTION

Last week I published a screencast demonstrating how to use the tidymodels framework, specifically the recipes package. Today, I’m using this week’s TidyTuesday dataset on food consumption around the world to show hyperparameter tuning!

Explore the data

Our modeling goal here is to predict which countries are Asian countries and which are not, based on their patterns of food consumption in the eleven categories from the TidyTuesday dataset. The original data is in a long, tidy format, and includes information on the carbon emissions associated with each category of food consumption.

library(tidyverse)
## -- Attaching packages --------
## v ggplot2 3.3.2     v purrr   0.3.4
## v tibble  3.0.3     v dplyr   1.0.0
## v tidyr   1.1.0     v stringr 1.4.0
## v readr   1.3.1     v forcats 0.5.0
## -- Conflicts -----------------
## x dplyr::filter() masks stats::filter()
## x dplyr::lag()    masks stats::lag()
food_consumption <- readr::read_csv("https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2020/2020-02-18/food_consumption.csv")
## Parsed with column specification:
## cols(
##   country = col_character(),
##   food_category = col_character(),
##   consumption = col_double(),
##   co2_emmission = col_double()
## )
food_consumption
## # A tibble: 1,430 x 4
##    country   food_category            consumption co2_emmission
##    <chr>     <chr>                          <dbl>         <dbl>
##  1 Argentina Pork                           10.5          37.2 
##  2 Argentina Poultry                        38.7          41.5 
##  3 Argentina Beef                           55.5        1712   
##  4 Argentina Lamb & Goat                     1.56         54.6 
##  5 Argentina Fish                            4.36          6.96
##  6 Argentina Eggs                           11.4          10.5 
##  7 Argentina Milk - inc. cheese            195.          278.  
##  8 Argentina Wheat and Wheat Products      103.           19.7 
##  9 Argentina Rice                            8.77         11.2 
## 10 Argentina Soybeans                        0             0   
## # ... with 1,420 more rows

Let’s build a dataset for modeling that is wide instead of long using pivot_wider() from tidyr. We can use the countrycode package to find which continent each country is in, and create a new variable asia to predict: whether a country is in Asia or not.

library(countrycode)
library(janitor)
## 
## Attaching package: 'janitor'
## The following objects are masked from 'package:stats':
## 
##     chisq.test, fisher.test
food <- food_consumption %>%
  select(-co2_emmission) %>%
  pivot_wider(
    names_from = food_category,
    values_from = consumption
  ) %>%
  clean_names() %>%
  mutate(continent = countrycode(
    country,
    origin = 'country.name',
    destination = 'continent'
  )) %>%
  mutate(asia = case_when(
    continent == 'Asia' ~ 'Asia',
    TRUE ~ 'Other'
  )) %>%
  select(-country, -continent) %>%
  mutate_if(is.character, factor)

food
## # A tibble: 130 x 12
##     pork poultry  beef lamb_goat  fish  eggs milk_inc_cheese wheat_and_wheat~
##    <dbl>   <dbl> <dbl>     <dbl> <dbl> <dbl>           <dbl>            <dbl>
##  1  10.5    38.7  55.5      1.56  4.36 11.4             195.            103. 
##  2  24.1    46.1  33.9      9.87 17.7   8.51            234.             70.5
##  3  10.9    13.2  22.5     15.3   3.85 12.5             304.            139. 
##  4  21.7    26.9  13.4     21.1  74.4   8.24            226.             72.9
##  5  22.3    35.0  22.5     18.9  20.4   9.91            137.             76.9
##  6  27.6    50.0  36.2      0.43 12.4  14.6             255.             80.4
##  7  16.8    27.4  29.1      8.23  6.53 13.1             211.            109. 
##  8  43.6    21.4  29.9      1.67 23.1  14.6             255.            103. 
##  9  12.6    45    39.2      0.62 10.0   8.98            149.             53  
## 10  10.4    18.4  23.4      9.56  5.21  8.29            288.             92.3
## # ... with 120 more rows, and 4 more variables: rice <dbl>, soybeans <dbl>,
## #   nuts_inc_peanut_butter <dbl>, asia <fct>

This is not a big dataset, but it will be good for demonstrating how to tune hyperparameters. Before we get started on that, how are the categories of food consumption related? Since these are all numeric variables, we can use ggscatmat() from GGally for a quick visualization.

library(GGally)
## Registered S3 method overwritten by 'GGally':
##   method from   
##   +.gg   ggplot2
ggscatmat(food, columns = 1:11, color = 'asia', alpha = 0.7)

Notice how important rice is! Also see how the relationships between different food categories differ for Asian and non-Asian countries; a tree-based model like a random forest is good at learning interactions like this.

Tune hyperparameters

Now it’s time to tune the hyperparameters for a random forest model. First, let’s create a set of bootstrap resamples to use for tuning, and then let’s create a model specification for a random forest where we will tune mtry (the number of predictors to sample at each split) and min_n (the number of observations needed to keep splitting nodes). These are hyperparameters that can’t be learned from the data when training the model.

library(tidymodels)
## -- Attaching packages --------
## v broom     0.7.0      v recipes   0.1.13
## v dials     0.0.8      v rsample   0.0.7 
## v infer     0.5.3      v tune      0.1.1 
## v modeldata 0.0.2      v workflows 0.1.2 
## v parsnip   0.1.2      v yardstick 0.0.7
## -- Conflicts -----------------
## x scales::discard() masks purrr::discard()
## x dplyr::filter()   masks stats::filter()
## x recipes::fixed()  masks stringr::fixed()
## x dplyr::lag()      masks stats::lag()
## x yardstick::spec() masks readr::spec()
## x recipes::step()   masks stats::step()
set.seed(1234)

food_boot <- bootstraps(food, times = 30)

food_boot
## # Bootstrap sampling 
## # A tibble: 30 x 2
##    splits           id         
##    <list>           <chr>      
##  1 <split [130/48]> Bootstrap01
##  2 <split [130/49]> Bootstrap02
##  3 <split [130/49]> Bootstrap03
##  4 <split [130/51]> Bootstrap04
##  5 <split [130/47]> Bootstrap05
##  6 <split [130/51]> Bootstrap06
##  7 <split [130/57]> Bootstrap07
##  8 <split [130/51]> Bootstrap08
##  9 <split [130/44]> Bootstrap09
## 10 <split [130/53]> Bootstrap10
## # ... with 20 more rows
rf_spec <- rand_forest(
  mode = 'classification',
  mtry = tune(),
  trees = 1000,
  min_n = tune()
) %>%
  set_engine('ranger')

rf_spec
## Random Forest Model Specification (classification)
## 
## Main Arguments:
##   mtry = tune()
##   trees = 1000
##   min_n = tune()
## 
## Computational engine: ranger

We can’t learn the right values when training a single model, but we can train a whole bunch of models and see which ones turn out best. We can use parallel processing to make this go faster, since the different parts of the grid are independent.

doParallel::registerDoParallel()

rf_grid <- tune_grid(
  asia ~ .,
  model = rf_spec,
  resamples = food_boot
)
## Warning: `tune_grid.formula()` is deprecated as of lifecycle 0.1.0.
## The first argument to `tune_grid()` should be either a model or a workflow. In the future, you can use:
## tune_grid(rf_spec, asia ~ ., resamples = food_boot)
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_warnings()` to see where this warning was generated.
## i Creating pre-processing data to finalize unknown parameter: mtry
rf_grid
## # Tuning results
## # Bootstrap sampling 
## # A tibble: 30 x 4
##    splits           id          .metrics          .notes          
##    <list>           <chr>       <list>            <list>          
##  1 <split [130/48]> Bootstrap01 <tibble [20 x 6]> <tibble [0 x 1]>
##  2 <split [130/49]> Bootstrap02 <tibble [20 x 6]> <tibble [0 x 1]>
##  3 <split [130/49]> Bootstrap03 <tibble [20 x 6]> <tibble [0 x 1]>
##  4 <split [130/51]> Bootstrap04 <tibble [20 x 6]> <tibble [0 x 1]>
##  5 <split [130/47]> Bootstrap05 <tibble [20 x 6]> <tibble [0 x 1]>
##  6 <split [130/51]> Bootstrap06 <tibble [20 x 6]> <tibble [0 x 1]>
##  7 <split [130/57]> Bootstrap07 <tibble [20 x 6]> <tibble [0 x 1]>
##  8 <split [130/51]> Bootstrap08 <tibble [20 x 6]> <tibble [0 x 1]>
##  9 <split [130/44]> Bootstrap09 <tibble [20 x 6]> <tibble [0 x 1]>
## 10 <split [130/53]> Bootstrap10 <tibble [20 x 6]> <tibble [0 x 1]>
## # ... with 20 more rows
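
As the deprecation warning notes, newer versions of tune expect the model (or workflow) as the first argument; the equivalent call would be:

# The model-first signature recommended by the deprecation warning above
rf_grid <- tune_grid(
  rf_spec,
  asia ~ .,
  resamples = food_boot
)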

Once we have our tuning results, we can check them out.

rf_grid %>%
  collect_metrics()
## # A tibble: 20 x 8
##     mtry min_n .metric  .estimator  mean     n std_err .config
##    <int> <int> <chr>    <chr>      <dbl> <int>   <dbl> <chr>  
##  1    11    15 accuracy binary     0.813    30 0.0111  Model01
##  2    11    15 roc_auc  binary     0.824    30 0.0108  Model01
##  3     4    33 accuracy binary     0.817    30 0.00873 Model02
##  4     4    33 roc_auc  binary     0.820    30 0.0103  Model02
##  5     5    31 accuracy binary     0.816    30 0.00901 Model03
##  6     5    31 roc_auc  binary     0.821    30 0.0105  Model03
##  7     4    37 accuracy binary     0.815    30 0.00846 Model04
##  8     4    37 roc_auc  binary     0.820    30 0.0102  Model04
##  9     6     9 accuracy binary     0.822    30 0.00956 Model05
## 10     6     9 roc_auc  binary     0.831    30 0.00943 Model05
## 11     2     4 accuracy binary     0.834    30 0.00783 Model06
## 12     2     4 roc_auc  binary     0.844    30 0.00944 Model06
## 13     2    12 accuracy binary     0.827    30 0.00806 Model07
## 14     2    12 roc_auc  binary     0.836    30 0.00929 Model07
## 15     7    21 accuracy binary     0.820    30 0.00956 Model08
## 16     7    21 roc_auc  binary     0.826    30 0.0102  Model08
## 17     8    18 accuracy binary     0.818    30 0.00887 Model09
## 18     8    18 roc_auc  binary     0.826    30 0.00996 Model09
## 19     9    26 accuracy binary     0.814    30 0.0102  Model10
## 20     9    26 roc_auc  binary     0.821    30 0.0106  Model10
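
To get a feel for how performance varies across the grid, one option is to plot the metric estimates against the hyperparameters. Here is a minimal sketch (the particular aesthetics are just one possible choice):

# Plot mean roc_auc for each candidate mtry, colored by min_n
rf_grid %>%
  collect_metrics() %>%
  filter(.metric == 'roc_auc') %>%
  ggplot(aes(mtry, mean, color = min_n)) +
  geom_point(size = 2) +
  labs(y = 'roc_auc')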

And we can see which models performed the best, in terms of some given metric.

rf_grid %>%
  show_best('roc_auc')
## # A tibble: 5 x 8
##    mtry min_n .metric .estimator  mean     n std_err .config
##   <int> <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>  
## 1     2     4 roc_auc binary     0.844    30 0.00944 Model06
## 2     2    12 roc_auc binary     0.836    30 0.00929 Model07
## 3     6     9 roc_auc binary     0.831    30 0.00943 Model05
## 4     8    18 roc_auc binary     0.826    30 0.00996 Model09
## 5     7    21 roc_auc binary     0.826    30 0.0102  Model08
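
From here, we could select the best-performing hyperparameters and plug them back into the model specification. A minimal sketch using select_best() and finalize_model() from tune (the object names best_params and final_rf are just illustrative):

# Select the best hyperparameters by roc_auc and finalize the spec
best_params <- rf_grid %>%
  select_best(metric = 'roc_auc')

final_rf <- finalize_model(rf_spec, best_params)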

If you would like to specify the grid for tuning yourself, check out the dials package.
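
For example, a regular grid over these two hyperparameters might look like the sketch below; the ranges and number of levels here are illustrative choices, not values taken from the tuning above:

# A hand-specified regular grid for mtry and min_n using dials;
# ranges and levels are illustrative
rf_custom_grid <- grid_regular(
  mtry(range = c(2, 11)),
  min_n(range = c(2, 40)),
  levels = 5
)

rf_grid_custom <- tune_grid(
  rf_spec,
  asia ~ .,
  resamples = food_boot,
  grid = rf_custom_grid
)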