Summary
I’ve been noticing some healthy competition between the tidymodels set of packages and the DALEX set of packages. Each framework tries to provide end-to-end machine learning capabilities while lowering the entry bar with intuitive function names and defaults for most hyperparameters.
In this post I will walk through A Gentle Introduction to tidymodels by Edgar Ruiz using one of my favorite datasets, the World Happiness Report.
I will pre-process the data, apply three machine learning algorithms, then evaluate their performance against the test set. You’ll also find out which countries are happier than they should be.
The tidymodels process
Import data
The data is compiled by the United Nations (UN), which conducts annual surveys in over 160 countries. This analysis is limited to countries that are members of the Organization for Economic Cooperation and Development (OECD); one of the requirements of membership is to perform these surveys every year. See the listing of countries below. OECD countries are generally industrialized and have democratic governments.
#options("max.print" = 100)
library(tidymodels)
library(tidyverse)
library(openxlsx)
#library(yardstick)
# Modelling introduces randomness. Random seed is set for reproducibility.
set.seed(317)
# Import World Happiness Report (whr) data
whr <- read.xlsx("WHR2019 and stress 2018.xlsx",
sheet = "WHR_Clean") %>%
filter(OECD == TRUE) %>%
mutate(confidence_in_govt = as.numeric(confidence_in_govt)) %>%
na.omit() %>%
mutate(country = as.factor(country)) %>%
select(-OECD)
levels(whr$country)
## [1] "Australia" "Austria" "Belgium" "Canada"
## [5] "Chile" "Czech Republic" "Denmark" "Estonia"
## [9] "Finland" "France" "Germany" "Greece"
## [13] "Hungary" "Iceland" "Ireland" "Israel"
## [17] "Italy" "Japan" "Latvia" "Lithuania"
## [21] "Luxembourg" "Mexico" "Netherlands" "New Zealand"
## [25] "Norway" "Poland" "Portugal" "Slovakia"
## [29] "Slovenia" "South Korea" "Spain" "Sweden"
## [33] "Switzerland" "Turkey" "United Kingdom" "United States"
Split data into training and test sets
The split keeps every column, including the response variable. Printing the split object shows the number of records in the training set, the test set, and the total.
whr_split <- initial_split(whr, prop = 0.8)
# Keep the first two columns (country and year) to attach to predictions later
train_countries <- training(whr_split)[, 1:2]
test_countries <- testing(whr_split)[, 1:2]
whr_split
## <236/58/294>
Take a look at the training set
A full description of the variables can be found here.
whr_split %>%
training() %>%
glimpse(width = 80)
## Observations: 236
## Variables: 15
## $ country <fct> Australia, Australia, Australia, Australia, Aus…
## $ year <dbl> 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2010,…
## $ happiness <dbl> 7.405616, 7.195586, 7.364169, 7.288550, 7.30906…
## $ gdp_per_capita <dbl> 10.64290, 10.66323, 10.67170, 10.68160, 10.6902…
## $ social_support <dbl> 0.9670292, 0.9445990, 0.9282052, 0.9237987, 0.9…
## $ life_expectancy <dbl> 72.30, 72.40, 72.50, 72.60, 72.70, 73.00, 73.30…
## $ life_choices <dbl> 0.9445865, 0.9351463, 0.9333792, 0.9229323, 0.9…
## $ generosity <dbl> 0.360889763, 0.265168518, 0.260279536, 0.310035…
## $ corruption_perception <dbl> 0.3817717, 0.3682517, 0.4315390, 0.4420214, 0.3…
## $ positive_feelings <dbl> 0.8158603, 0.8107418, 0.8188353, 0.7752100, 0.7…
## $ negative_feelings <dbl> 0.1953237, 0.2143969, 0.1771423, 0.2453043, 0.2…
## $ confidence_in_govt <dbl> 0.5307867, 0.4204192, 0.4558715, 0.4646762, 0.4…
## $ democratic_quality <dbl> 1.1947155, 1.2485938, 1.2337434, 1.1969540, 1.1…
## $ delivery_quality <dbl> 1.835444, 1.790157, 1.751174, 1.811844, 1.76520…
## $ gini_household_income <dbl> 0.4311813, 0.4060564, 0.3812075, 0.4515168, 0.4…
Pre-processing the data
Our data is clean, with no NAs and all numeric predictors, just like the source example.
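The chunk that defines whr_recipe is not shown, so here is a minimal sketch of what it likely looks like, assuming the pre-processing implied by the baked output below: country and year are dropped, the twelve numeric predictors are centered and scaled, and happiness is left untouched. The exact steps in the original recipe may differ.
# Sketch of the recipe (assumed, not the original chunk): drop the identifier
# columns, then center and scale every remaining predictor. prep() estimates
# the centering and scaling parameters from the training set only.
whr_recipe <- training(whr_split) %>%
  recipe(happiness ~ .) %>%
  step_rm(country, year) %>%
  step_center(all_predictors()) %>%
  step_scale(all_predictors()) %>%
  prep()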
Apply pre-processing to test set
In tidymodels this is referred to as baking, and it ensures that the same pre-processing parameters are applied to any dataset. Perfect for “what if” scenarios later (a sketch of one appears after the glimpse output below).
whr_testing <- whr_recipe %>%
bake(testing(whr_split))
glimpse(whr_testing, width = 80)
## Observations: 58
## Variables: 13
## $ happiness <dbl> 7.450047, 6.948936, 7.415144, 7.304258, 6.63565…
## $ gdp_per_capita <dbl> 0.53848227, 0.58296762, 0.56229379, 0.65308934,…
## $ social_support <dbl> 0.94696267, 0.43101279, 0.81792767, 0.20634986,…
## $ life_expectancy <dbl> 0.60935009, 0.37318871, 0.73530302, 0.86125596,…
## $ life_choices <dbl> 0.97574408, 0.50790107, 0.87626537, 1.02400376,…
## $ generosity <dbl> 1.53729840, -0.48999264, 1.39498804, 1.28323134…
## $ corruption_perception <dbl> -1.18352946, -0.63819595, -0.76794657, -0.86765…
## $ positive_feelings <dbl> 0.9390345, 0.1913895, 1.1802202, 0.9300989, 0.6…
## $ negative_feelings <dbl> -0.39956570, 0.25414428, -0.24664515, 0.2367862…
## $ confidence_in_govt <dbl> 1.31180067, 0.06503883, 0.74718718, 0.70469903,…
## $ democratic_quality <dbl> 0.54904783, 0.08979786, 0.77559674, 0.81056988,…
## $ delivery_quality <dbl> 0.9612146, 0.2951689, 0.9218724, 0.9905740, 0.2…
## $ gini_household_income <dbl> 0.16355229, -0.46560573, 4.77554716, 2.76625414…
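Speaking of “what if” scenarios, any new or modified raw data can be pushed through the same recipe and will receive the identical centering and scaling learned from the training set. The sketch below is illustrative and not part of the original analysis; the scenario (adding two years of life expectancy to the most recent records) is made up.
# Hypothetical "what if" scenario (illustrative only): take the most recent raw
# records, bump life expectancy by two years, and bake them with the same
# recipe so they are pre-processed exactly like the training data.
whr_what_if <- whr %>%
  filter(year == max(year)) %>%
  mutate(life_expectancy = life_expectancy + 2)
whr_what_if_baked <- bake(whr_recipe, whr_what_if)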
Summarized test set
The only big outlier is in gini_household_income.
# Summary
summary(whr_testing)
## happiness gdp_per_capita social_support life_expectancy
## Min. :4.669 Min. :-2.30821 Min. :-3.12362 Min. :-2.5552
## 1st Qu.:5.959 1st Qu.:-1.00864 1st Qu.:-0.59181 1st Qu.:-0.9454
## Median :6.730 Median : 0.09842 Median : 0.32963 Median : 0.3141
## Mean :6.559 Mean :-0.03228 Mean :-0.02473 Mean :-0.1108
## 3rd Qu.:7.407 3rd Qu.: 0.69644 3rd Qu.: 0.67730 3rd Qu.: 0.7589
## Max. :7.678 Max. : 3.02246 Max. : 1.56113 Max. : 1.6327
## life_choices generosity corruption_perception
## Min. :-2.51718 Min. :-1.86087 Min. :-1.91511
## 1st Qu.:-0.52165 1st Qu.:-0.72834 1st Qu.:-0.84273
## Median : 0.41692 Median :-0.01175 Median : 0.20133
## Mean : 0.07525 Mean : 0.05407 Mean : 0.03367
## 3rd Qu.: 0.90520 3rd Qu.: 0.82046 3rd Qu.: 0.92609
## Max. : 1.13306 Max. : 1.93239 Max. : 1.37609
## positive_feelings negative_feelings confidence_in_govt
## Min. :-2.3805 Min. :-1.8817 Min. :-2.023414
## 1st Qu.:-0.4129 1st Qu.:-0.7944 1st Qu.:-0.653645
## Median : 0.3379 Median :-0.2613 Median :-0.158614
## Mean : 0.1633 Mean :-0.1600 Mean :-0.002656
## 3rd Qu.: 0.8780 3rd Qu.: 0.3650 3rd Qu.: 0.740290
## Max. : 1.3872 Max. : 2.0888 Max. : 2.546842
## democratic_quality delivery_quality gini_household_income
## Min. :-2.4317 Min. :-2.401118 Min. :-1.06123
## 1st Qu.:-0.2810 1st Qu.:-0.692741 1st Qu.:-0.54746
## Median : 0.1656 Median : 0.301956 Median :-0.08329
## Mean : 0.1259 Mean : 0.001291 Mean : 0.21470
## 3rd Qu.: 0.8018 3rd Qu.: 0.917473 3rd Qu.: 0.65965
## Max. : 1.2487 Max. : 1.344256 Max. : 4.77555
Test set averages
Notice that the means are NOT equal to zero. Because the centering and scaling parameters were learned from the training set, the mean of each predictor is an indication of how much the test data is shifted relative to the training data.
# Averages
sapply(whr_testing, mean)
## happiness gdp_per_capita social_support
## 6.559258329 -0.032276696 -0.024733466
## life_expectancy life_choices generosity
## -0.110807044 0.075253904 0.054070771
## corruption_perception positive_feelings negative_feelings
## 0.033667859 0.163339413 -0.159984273
## confidence_in_govt democratic_quality delivery_quality
## -0.002656145 0.125947427 0.001291523
## gini_household_income
## 0.214695903
To further illustrate, here are the means of the training set. All predictors are essentially zero, which is what you expect after centering (happiness, the outcome, was not centered).
# Averages
sapply(juice(whr_recipe), mean)
## gdp_per_capita social_support life_expectancy
## 1.443334e-15 1.101309e-15 4.761637e-16
## life_choices generosity corruption_perception
## -4.270309e-17 -7.781907e-18 -1.076022e-16
## positive_feelings negative_feelings confidence_in_govt
## -1.657116e-16 1.662387e-16 1.630346e-16
## democratic_quality delivery_quality gini_household_income
## -6.066226e-17 4.536881e-18 1.574353e-16
## happiness
## 6.574871e+00
Pull pre-processed training set from the recipe
The pre-processing of the training set was done when the recipe was prepped, but you may want it in a separate object. Sticking with the culinary terms, you juice the recipe to get the prepared training set.
whr_training <- juice(whr_recipe)
Model fitting
Model 1 - Random Forest
The blog post I’m following solved a classification problem, while this is a regression problem. Luckily, random forest works for both, and I’ll fit a couple of more traditional models after RF. The default random forest engine in parsnip is ranger, so we’ll use that.
library(ranger)
fit_rf <- rand_forest(trees = 500, mode = "regression") %>%
set_engine("ranger") %>%
fit(happiness ~ ., data = whr_training)
pred_test_rf <- fit_rf %>%
predict(whr_testing) %>%
bind_cols(whr_testing) %>%
rename(actual = happiness,
prediction = .pred) %>%
mutate(difference = actual - prediction,
model = "RF") %>%
select(model, actual, prediction, difference)
glimpse(pred_test_rf, width = 80)
## Observations: 58
## Variables: 4
## $ model <chr> "RF", "RF", "RF", "RF", "RF", "RF", "RF", "RF", "RF", "RF"…
## $ actual <dbl> 7.450047, 6.948936, 7.415144, 7.304258, 6.635656, 6.599129…
## $ prediction <dbl> 7.319408, 6.812941, 7.414979, 7.372932, 6.649756, 6.719880…
## $ difference <dbl> 0.130639213, 0.135995920, 0.000165059, -0.068673873, -0.01…
Model Validation
Random forest is notorious for overfitting, so don’t get too excited if there is a high r-squared on the training set.
I am going to get the metrics from the test set, which is the real measure of accuracy.
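Before that, as a quick illustration of the overfitting point (not part of the original walkthrough), here are the same yardstick metrics computed on the training set; expect them to look noticeably rosier than the test-set numbers below.
# Illustrative aside: in-sample metrics for the random forest. These are
# optimistic and should not be used to judge the model.
fit_rf %>%
  predict(whr_training) %>%
  bind_cols(whr_training) %>%
  metrics(truth = happiness, estimate = .pred)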
pred_test_rf_metrics <- pred_test_rf %>%
metrics(truth = actual, estimate = prediction) %>%
mutate(model = "RF") %>%
select(model, everything())
pred_test_rf_metrics
## # A tibble: 3 x 4
## model .metric .estimator .estimate
## <chr> <chr> <chr> <dbl>
## 1 RF rmse standard 0.273
## 2 RF rsq standard 0.908
## 3 RF mae standard 0.203
An r-squared of 0.91 on the test set is very good. Let’s see if the more regression-oriented models do better.
Model 2 - Linear Model
I always fit a linear model for regression. Not because it’s usually competitive but because it provides a good baseline to measure other models against.
fit_lm <- linear_reg(mode = "regression") %>%
set_engine("lm") %>%
fit(happiness ~ ., data = whr_training)
pred_test_lm <- fit_lm %>%
predict(whr_testing) %>%
bind_cols(whr_testing) %>%
rename(actual = happiness,
prediction = .pred) %>%
mutate(difference = actual - prediction,
model = "RF") %>%
select(model, actual, prediction, difference)
glimpse(pred_test_lm, width = 80)
## Observations: 58
## Variables: 4
## $ model <chr> "LM", "LM", "LM", "LM", "LM", "LM", "LM", "LM", "LM", "LM"…
## $ actual <dbl> 7.450047, 6.948936, 7.415144, 7.304258, 6.635656, 6.599129…
## $ prediction <dbl> 7.557909, 6.735188, 7.168694, 7.137612, 6.426501, 6.553792…
## $ difference <dbl> -0.107862351, 0.213748025, 0.246450022, 0.166646235, 0.209…
Model Validation
pred_test_lm_metrics <- pred_test_lm %>%
metrics(truth = actual, estimate = prediction) %>%
mutate(model = "LM") %>%
select(model, everything())
pred_test_lm_metrics
## # A tibble: 3 x 4
## model .metric .estimator .estimate
## <chr> <chr> <chr> <dbl>
## 1 LM rmse standard 0.419
## 2 LM rsq standard 0.758
## 3 LM mae standard 0.314
Model 3 - MARS
Multivariate Adaptive Regression Splines (MARS) is one of my favorite regression models because of its ability to mold to the data better than a plain linear model.
library(earth)
fit_mars <- mars(mode = "regression") %>%
set_engine("earth") %>%
fit(happiness ~ ., data = whr_training)
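If you want to see how MARS molds to the data, you can inspect the hinge functions it selected. This is an optional, illustrative peek (output not shown; it depends on the earth version and the training data).
# Optional: the underlying earth object holds the selected terms (hinge
# functions) and their coefficients.
summary(fit_mars$fit)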
pred_test_mars <- fit_mars %>%
predict(whr_testing) %>%
bind_cols(whr_testing) %>%
rename(actual = happiness,
prediction = .pred) %>%
mutate(difference = actual - prediction,
model = "RF") %>%
select(model, actual, prediction, difference)
glimpse(pred_test_mars, width = 80)
## Observations: 58
## Variables: 4
## $ model <chr> "MARS", "MARS", "MARS", "MARS", "MARS", "MARS", "MARS", "M…
## $ actual <dbl> 7.450047, 6.948936, 7.415144, 7.304258, 6.635656, 6.599129…
## $ prediction <dbl> 7.393590, 6.756081, 7.445648, 7.463295, 6.271285, 6.259817…
## $ difference <dbl> 0.056456842, 0.192855698, -0.030503948, -0.159036719, 0.36…
Model Validation
pred_test_mars_metrics <- pred_test_mars %>%
metrics(truth = actual, estimate = prediction) %>%
mutate(model = "MARS") %>%
select(model, everything())
pred_test_mars_metrics
## # A tibble: 3 x 4
## model .metric .estimator .estimate
## <chr> <chr> <chr> <dbl>
## 1 MARS rmse standard 0.310
## 2 MARS rsq standard 0.868
## 3 MARS mae standard 0.241
Evaluation
Random forest leads on all three metrics, so that will be the model chosen for future use.
For r-squared, a higher number is better, with 1 being perfect; rmse and mae should be as low as possible because they measure error.
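To make that concrete, here is an illustrative hand computation of rmse and mae from the random forest’s test-set differences; it should agree with the yardstick numbers above.
# Hand-computed check (illustrative): rmse is the square root of the mean
# squared error, mae is the mean absolute error.
pred_test_rf %>%
  summarise(rmse_by_hand = sqrt(mean(difference^2)),
            mae_by_hand = mean(abs(difference)))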
metrics <- bind_rows(pred_test_rf_metrics,
pred_test_lm_metrics,
pred_test_mars_metrics) %>%
arrange(.metric)
metrics
## # A tibble: 9 x 4
## model .metric .estimator .estimate
## <chr> <chr> <chr> <dbl>
## 1 RF mae standard 0.203
## 2 LM mae standard 0.314
## 3 MARS mae standard 0.241
## 4 RF rmse standard 0.273
## 5 LM rmse standard 0.419
## 6 MARS rmse standard 0.310
## 7 RF rsq standard 0.908
## 8 LM rsq standard 0.758
## 9 MARS rsq standard 0.868
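As an optional extra (not in the source post), a quick reshape puts the models side by side, with one column per model.
# Illustrative: one row per metric, one column per model (tidyr is loaded with
# the tidyverse).
metrics %>%
  select(-.estimator) %>%
  spread(model, .estimate)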
Obligatory Insight
Even though the source blog didn’t include any visualizations, the model we have can help us understand which countries are more or less happy than the model predicted.
This looks only at the latest year in the data, which is 2017. Countries on the positive side report happiness higher than predicted. The maximum difference was around 0.25 in either direction, which isn’t much variance.
pred_train <- fit_rf %>%
predict(whr_training) %>%
bind_cols(whr_training) %>%
rename(actual = happiness,
prediction = .pred) %>%
mutate(difference = actual - prediction,
model = "RF") %>%
bind_cols(train_countries) %>%
select(model, country, year, actual, prediction, difference)
pred_test <- pred_test_rf %>%
bind_cols(test_countries) %>%
select(model, country, year, actual, prediction, difference)
pred_all <- bind_rows(pred_train, pred_test)
pred_sample <- pred_all %>%
filter(year == max(year)) %>%
select(-year) %>%
arrange(-difference)
pred_sample$country2 <- factor(pred_sample$country) %>%
fct_reorder(pred_sample$difference)
ggplot(pred_sample, aes(country2, difference)) +
geom_col(fill = "darkseagreen3") +
scale_y_continuous(limits = c(-0.25, 0.25)) +
coord_flip() +
ggtitle("2017 difference between actual happiness prediction") +
labs(y = "Difference", x = "OECD Member") Conclusion
The tidymodels framework consists of the rsample and recipes packages for pre-processing, the parsnip package for model training, and the yardstick package for validation. All of the packages work well with each other and cover many of the scenarios you will encounter in machine learning.
People who are already familiar with the tidyverse will immediately feel comfortable with this approach to modelling. The designers go to great lengths to put everything in data frames as opposed to lists. The only lists produced (yes, I know data frames are lists, but you know what I mean) were the fit objects, which is impressive.
This walkthrough was purely focused on prediction, which is why there is very little here about explaining the model to a layperson.
I just scratched the surface of the tidymodels framework and look forward to stress-testing it in the future.