library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr 1.1.3 ✔ readr 2.1.4
## ✔ forcats 1.0.0 ✔ stringr 1.5.0
## ✔ ggplot2 3.4.3 ✔ tibble 3.2.1
## ✔ lubridate 1.9.2 ✔ tidyr 1.3.0
## ✔ purrr 1.0.2
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
# Board game ratings from the TidyTuesday 2022-01-25 release.
ratings <- read_csv(
  file = "https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2022/2022-01-25/ratings.csv"
)
## Rows: 21831 Columns: 10
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (3): name, url, thumbnail
## dbl (7): num, id, year, rank, average, bayes_average, users_rated
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Per-game detail records from the same TidyTuesday 2022-01-25 release.
details <- read_csv(
  file = "https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2022/2022-01-25/details.csv"
)
## Rows: 21631 Columns: 23
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (10): primary, description, boardgamecategory, boardgamemechanic, boardg...
## dbl (13): num, id, yearpublished, minplayers, maxplayers, playingtime, minpl...
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Attach the detail columns to every rating row, matching on the shared `id`
# column. This is a left join, so ratings rows without a matching details row
# are kept, with NA in the detail columns.
ratings_joined <- left_join(ratings, details, by = join_by(id))
# Distribution of the average rating across all joined rows.
# `bins = 30` is the value stat_bin() otherwise picks implicitly. Stating it
# here keeps the plot unchanged and stops the "Pick better value with
# `binwidth`" message from being emitted on every render.
ggplot(ratings_joined, aes(average)) +
  geom_histogram(bins = 30, alpha = 0.8)

# Average rating across four roughly equal-count groups of `minage`.
# Rows with a missing `minage` are removed before binning.
ratings_joined %>%
  filter(!is.na(minage)) %>%
  mutate(minage = cut_number(minage, n = 4)) %>%
  ggplot(aes(x = minage, y = average, fill = minage)) +
  geom_boxplot(show.legend = FALSE, alpha = 0.2)

# Tune an xgboost model ----
library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.1.1 ──
## ✔ broom 1.0.5 ✔ rsample 1.2.0
## ✔ dials 1.2.0 ✔ tune 1.1.2
## ✔ infer 1.0.5 ✔ workflows 1.1.3
## ✔ modeldata 1.2.0 ✔ workflowsets 1.0.1
## ✔ parsnip 1.1.1 ✔ yardstick 1.2.0
## ✔ recipes 1.0.8
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard() masks purrr::discard()
## ✖ dplyr::filter() masks stats::filter()
## ✖ recipes::fixed() masks stringr::fixed()
## ✖ dplyr::lag() masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step() masks stats::step()
## • Use tidymodels_prefer() to resolve common conflicts.
# Train/test split ----
# Keep the game name, the outcome (`average`), every column whose name
# contains "min" or "max", and the raw category string. Rows with a missing
# value in any of these columns are dropped before splitting. Both the
# initial split and the CV folds are stratified on `average`.
set.seed(123)
game_split <- ratings_joined %>%
  select(name, average, matches("min|max"), boardgamecategory) %>%
  na.omit() %>%
  initial_split(strata = average)

game_train <- training(game_split)
game_test <- testing(game_split)

set.seed(234)
game_folds <- game_train %>%
  vfold_cv(strata = average)
game_folds
## # 10-fold cross-validation using stratification
## # A tibble: 10 × 2
## splits id
## <list> <chr>
## 1 <split [14407/1602]> Fold01
## 2 <split [14407/1602]> Fold02
## 3 <split [14407/1602]> Fold03
## 4 <split [14408/1601]> Fold04
## 5 <split [14408/1601]> Fold05
## 6 <split [14408/1601]> Fold06
## 7 <split [14408/1601]> Fold07
## 8 <split [14408/1601]> Fold08
## 9 <split [14410/1599]> Fold09
## 10 <split [14410/1599]> Fold10
library(textrecipes)
# Custom tokenizer passed to step_tokenize(custom_token = ...).
# Splits each element of `x` on ", ", then cleans every piece: punctuation
# removed, whitespace squished, lower-cased, and inner spaces replaced by "_".
# Returns a list holding one character vector of tokens per element of `x`.
split_category <- function(x) {
  pieces <- str_split(x, ", ")
  map(pieces, function(tokens) {
    tokens %>%
      str_remove_all("[:punct:]") %>%
      str_squish() %>%
      str_to_lower() %>%
      str_replace_all(" ", "_")
  })
}
# Preprocessing recipe: model `average` from all other kept columns.
# `name` gets the "id" role, so it is carried along but not used as a
# predictor. The category string is tokenized with split_category(), limited
# to max_tokens = 30 tokens, and expanded into term-frequency columns.
game_rec <- recipe(average ~ ., data = game_train) %>%
  update_role(name, new_role = "id") %>%
  step_tokenize(boardgamecategory, custom_token = split_category) %>%
  step_tokenfilter(boardgamecategory, max_tokens = 30) %>%
  step_tf(boardgamecategory)

# Estimate the recipe on the training data and inspect the processed result.
game_prep <- game_rec %>% prep()
game_prep %>%
  bake(new_data = NULL) %>%
  str()
## tibble [16,009 × 37] (S3: tbl_df/tbl/data.frame)
## $ name : Factor w/ 15781 levels "¡Adiós Calavera!",..: 10857 8587 14642 858 15729 6819 13313 1490 3143 9933 ...
## $ minplayers : num [1:16009] 2 2 2 4 2 1 2 2 4 2 ...
## $ maxplayers : num [1:16009] 6 8 10 10 6 8 6 2 16 6 ...
## $ minplaytime : num [1:16009] 120 60 30 30 60 20 60 30 60 45 ...
## $ maxplaytime : num [1:16009] 120 180 30 30 90 20 60 30 60 45 ...
## $ minage : num [1:16009] 10 8 6 12 15 6 8 8 13 8 ...
## $ average : num [1:16009] 5.59 4.37 5.41 5.79 5.8 5.62 4.31 4.66 5.68 5.14 ...
## $ tf_boardgamecategory_abstract_strategy : int [1:16009] 0 0 0 0 0 0 0 0 0 0 ...
## $ tf_boardgamecategory_action_dexterity : int [1:16009] 0 0 0 0 0 1 0 0 0 0 ...
## $ tf_boardgamecategory_adventure : int [1:16009] 0 0 0 0 0 0 0 0 0 0 ...
## $ tf_boardgamecategory_ancient : int [1:16009] 0 0 0 0 0 0 0 0 0 0 ...
## $ tf_boardgamecategory_animals : int [1:16009] 0 0 0 0 0 0 0 0 0 0 ...
## $ tf_boardgamecategory_bluffing : int [1:16009] 0 0 0 0 0 0 0 0 0 0 ...
## $ tf_boardgamecategory_card_game : int [1:16009] 0 0 1 1 0 0 0 0 0 1 ...
## $ tf_boardgamecategory_childrens_game : int [1:16009] 0 0 0 0 0 0 1 1 0 0 ...
## $ tf_boardgamecategory_deduction : int [1:16009] 0 0 0 0 0 0 0 1 0 0 ...
## $ tf_boardgamecategory_dice : int [1:16009] 0 0 0 0 0 0 0 0 0 0 ...
## $ tf_boardgamecategory_economic : int [1:16009] 0 1 0 0 0 0 1 0 0 0 ...
## $ tf_boardgamecategory_exploration : int [1:16009] 0 0 0 0 1 0 0 0 0 0 ...
## $ tf_boardgamecategory_fantasy : int [1:16009] 0 0 0 0 0 0 0 0 0 0 ...
## $ tf_boardgamecategory_fighting : int [1:16009] 0 0 0 0 1 0 0 0 0 0 ...
## $ tf_boardgamecategory_horror : int [1:16009] 0 0 0 0 1 0 0 0 0 0 ...
## $ tf_boardgamecategory_humor : int [1:16009] 0 0 0 1 0 0 0 0 0 0 ...
## $ tf_boardgamecategory_medieval : int [1:16009] 0 0 0 0 0 0 0 0 0 0 ...
## $ tf_boardgamecategory_miniatures : int [1:16009] 0 0 0 0 1 0 0 0 0 0 ...
## $ tf_boardgamecategory_movies_tv_radio_theme: int [1:16009] 0 0 1 0 1 0 0 0 0 0 ...
## $ tf_boardgamecategory_nautical : int [1:16009] 0 0 0 0 0 0 0 1 0 0 ...
## $ tf_boardgamecategory_negotiation : int [1:16009] 0 1 0 0 0 0 0 0 0 0 ...
## $ tf_boardgamecategory_party_game : int [1:16009] 0 0 0 1 0 1 0 0 1 0 ...
## $ tf_boardgamecategory_print_play : int [1:16009] 0 0 0 0 0 0 0 0 0 0 ...
## $ tf_boardgamecategory_puzzle : int [1:16009] 0 0 0 0 0 0 0 0 1 0 ...
## $ tf_boardgamecategory_racing : int [1:16009] 0 0 0 0 0 0 0 0 0 0 ...
## $ tf_boardgamecategory_realtime : int [1:16009] 0 0 0 0 0 0 0 0 0 0 ...
## $ tf_boardgamecategory_science_fiction : int [1:16009] 0 0 0 0 0 0 0 0 0 0 ...
## $ tf_boardgamecategory_trivia : int [1:16009] 0 0 0 0 0 0 0 0 1 0 ...
## $ tf_boardgamecategory_wargame : int [1:16009] 1 0 0 0 0 0 0 1 0 0 ...
## $ tf_boardgamecategory_world_war_ii : int [1:16009] 0 0 0 0 0 0 0 0 0 0 ...
# Boosted-tree regression with the xgboost engine. `mtry`, `trees` and `min_n`
# are left for tuning; the learning rate is fixed at 0.01.
xgb_spec <- boost_tree(
  mtry = tune(),
  trees = tune(),
  min_n = tune(),
  learn_rate = 0.01
) %>%
  set_engine("xgboost") %>%
  set_mode("regression")

# Bundle the preprocessing recipe and the model spec into a single workflow.
xgb_wf <- workflow() %>%
  add_recipe(game_rec) %>%
  add_model(xgb_spec)
xgb_wf
## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: boost_tree()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 3 Recipe Steps
##
## • step_tokenize()
## • step_tokenfilter()
## • step_tf()
##
## ── Model ───────────────────────────────────────────────────────────────────────
## Boosted Tree Model Specification (regression)
##
## Main Arguments:
## mtry = tune()
## trees = tune()
## min_n = tune()
## learn_rate = 0.01
##
## Computational engine: xgboost
library(finetune)
# Register a doParallel backend (default settings) so tuning can use it.
doParallel::registerDoParallel()

# Race a 20-candidate grid over the CV folds with ANOVA-based racing.
# Candidates are eliminated as resamples are evaluated; `verbose_elim`
# reports each elimination round.
set.seed(234)
xgb_game_rs <- xgb_wf %>%
  tune_race_anova(
    resamples = game_folds,
    grid = 20,
    control = control_race(verbose_elim = TRUE)
  )
## i Creating pre-processing data to finalize unknown parameter: mtry
## ℹ Racing will minimize the rmse metric.
## ℹ Resamples are analyzed in a random order.
## ℹ Fold10: 6 eliminated; 14 candidates remain.
##
## ℹ Fold06: 8 eliminated; 6 candidates remain.
##
## ℹ Fold08: 2 eliminated; 4 candidates remain.
##
## ℹ Fold01: 1 eliminated; 3 candidates remain.
##
## ℹ Fold04: All but one parameter combination were eliminated.
xgb_game_rs
## # Tuning results
## # 10-fold cross-validation using stratification
## # A tibble: 10 × 5
## splits id .order .metrics .notes
## <list> <chr> <int> <list> <list>
## 1 <split [14407/1602]> Fold03 1 <tibble [40 × 7]> <tibble [0 × 3]>
## 2 <split [14408/1601]> Fold05 2 <tibble [40 × 7]> <tibble [0 × 3]>
## 3 <split [14410/1599]> Fold10 3 <tibble [40 × 7]> <tibble [0 × 3]>
## 4 <split [14408/1601]> Fold06 4 <tibble [28 × 7]> <tibble [0 × 3]>
## 5 <split [14408/1601]> Fold08 5 <tibble [12 × 7]> <tibble [0 × 3]>
## 6 <split [14407/1602]> Fold01 6 <tibble [8 × 7]> <tibble [0 × 3]>
## 7 <split [14408/1601]> Fold04 7 <tibble [6 × 7]> <tibble [0 × 3]>
## 8 <split [14407/1602]> Fold02 8 <tibble [2 × 7]> <tibble [0 × 3]>
## 9 <split [14408/1601]> Fold07 10 <tibble [2 × 7]> <tibble [0 × 3]>
## 10 <split [14410/1599]> Fold09 9 <tibble [2 × 7]> <tibble [0 × 3]>
# Visualise the racing results for every candidate across resamples.
plot_race(xgb_game_rs)

# Best surviving candidate(s) by cross-validated RMSE. The metric is named
# explicitly: without it, tune warns ("No value of `metric` was given") and
# falls back to a default metric on its own.
show_best(xgb_game_rs, metric = "rmse")
## # A tibble: 1 × 9
## mtry trees min_n .metric .estimator mean n std_err .config
## <int> <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 14 1709 17 rmse standard 0.735 10 0.00550 Preprocessor1_Model08
# Finalize the workflow with the lowest-RMSE candidate, then fit it on the
# training portion of `game_split` and evaluate it on the test portion.
# `metric` is passed by name. Newer tune releases place `...` before
# `metric`, so a positional "rmse" would land in `...` and trigger warnings.
xgb_last <- xgb_wf %>%
  finalize_workflow(select_best(xgb_game_rs, metric = "rmse")) %>%
  last_fit(game_split)
xgb_last
## # Resampling results
## # Manual resampling
## # A tibble: 1 × 6
## splits id .metrics .notes .predictions .workflow
## <list> <chr> <list> <list> <list> <list>
## 1 <split [16009/5339]> train/test spl… <tibble> <tibble> <tibble> <workflow>
library(vip)
##
## Attaching package: 'vip'
## The following object is masked from 'package:utils':
##
## vi
# Pull the fitted parsnip model out of the final workflow and plot its top 12
# variable importances as points.
xgb_fit <- xgb_last %>% extract_fit_parsnip()
xgb_fit %>% vip(num_features = 12, geom = "point")

library(SHAPforxgboost)
# SHAP values for the final xgboost booster.
# The feature matrix is the training data re-baked from `game_prep`, keeping
# only predictor-role columns, as a matrix. NOTE(review): this assumes the
# column order matches the one the workflow used when fitting — confirm if the
# recipe changes.
game_shap <- shap.prep(
  xgb_model = extract_fit_engine(xgb_fit),
  X_train = game_prep %>%
    bake(has_role("predictor"), new_data = NULL, composition = "matrix")
)
shap.plot.summary(game_shap)

# SHAP dependence on `minage`, with points coloured by `minplayers`.
shap.plot.dependence(
  game_shap,
  x = "minage",
  color_feature = "minplayers",
  size0 = 1.2,
  smooth = FALSE,
  add_hist = TRUE
)
