Horror Movies: Build a regression model to predict the average movie rating (vote_average). Use the horror_movies dataset.

Import Data

horror_movies <- readr::read_csv('https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2022/2022-11-01/horror_movies.csv')
## Rows: 32540 Columns: 20
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr  (10): original_title, title, original_language, overview, tagline, post...
## dbl   (8): id, popularity, vote_count, vote_average, budget, revenue, runtim...
## lgl   (1): adult
## date  (1): release_date
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.

Clean data

horror_movies %>% skimr::skim()
Data summary
Name Piped data
Number of rows 32540
Number of columns 20
_______________________
Column type frequency:
character 10
Date 1
logical 1
numeric 8
________________________
Group variables None

Variable type: character

skim_variable n_missing complete_rate min max empty n_unique whitespace
original_title 0 1.00 1 191 0 30296 0
title 0 1.00 1 191 0 29563 0
original_language 0 1.00 2 2 0 97 0
overview 1286 0.96 1 1000 0 31020 0
tagline 19835 0.39 1 237 0 12513 0
poster_path 4474 0.86 30 32 0 28048 0
status 0 1.00 7 15 0 4 0
backdrop_path 18995 0.42 29 32 0 13536 0
genre_names 0 1.00 6 144 0 772 0
collection_name 30234 0.07 4 56 0 815 0

Variable type: Date

skim_variable n_missing complete_rate min max median n_unique
release_date 0 1 1950-01-01 2022-12-31 2012-12-09 10999

Variable type: logical

skim_variable n_missing complete_rate mean count
adult 0 1 0 FAL: 32540

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
id 0 1.00 445910.83 305744.67 17 146494.8 426521.00 707534.00 1033095.00 ▇▆▆▅▅
popularity 0 1.00 4.01 37.51 0 0.6 0.84 2.24 5088.58 ▇▁▁▁▁
vote_count 0 1.00 62.69 420.89 0 0.0 2.00 11.00 16900.00 ▇▁▁▁▁
vote_average 0 1.00 3.34 2.88 0 0.0 4.00 5.70 10.00 ▇▂▆▃▁
budget 0 1.00 543126.59 4542667.81 0 0.0 0.00 0.00 200000000.00 ▇▁▁▁▁
revenue 0 1.00 1349746.73 14430479.15 0 0.0 0.00 0.00 701842551.00 ▇▁▁▁▁
runtime 0 1.00 62.14 41.00 0 14.0 80.00 91.00 683.00 ▇▁▁▁▁
collection 30234 0.07 481534.88 324498.16 656 155421.0 471259.00 759067.25 1033032.00 ▇▅▅▅▅
data <- horror_movies %>%
    
    # Log transform vote_average
    mutate(vote_average = log1p(vote_average)) %>% # for zeroes: log1p(x) is the same as log(x+1)
    
    # Treat missing values in overview
    filter(!is.na(overview), vote_count != 0) %>%
    
    # Treat multiple categories in genre_names
    separate_rows(genre_names, sep = ", ") %>%
    
    filter(status == "Released") %>%
    
    select(id, vote_average, genre_names, overview, runtime) 

# data <- data %>% sample_n(100)

Explore Data

Check list

data %>% glimpse()
data %>% skimr::skim()
data %>% select(id) %>% explore()
data %>% describe_all()
data %>% describe_cat(genre_names)
data %>% select(-id) %>% explore_all(target = vote_average)
data %>% 
    ggplot(aes(vote_average)) +
    geom_histogram() 
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

data %>% count(genre_names, sort = T)
## # A tibble: 19 × 2
##    genre_names         n
##    <chr>           <int>
##  1 Horror          20493
##  2 Thriller         5899
##  3 Comedy           3118
##  4 Drama            2835
##  5 Mystery          2279
##  6 Science Fiction  2189
##  7 Action           1539
##  8 Fantasy          1485
##  9 Crime             895
## 10 TV Movie          533
## 11 Adventure         504
## 12 Animation         500
## 13 Romance           417
## 14 Documentary       224
## 15 Music             164
## 16 Family            124
## 17 Western           102
## 18 History            76
## 19 War                68
data %>%
    
    group_by(genre_names) %>%
    summarise(
        n = n(),
        avg_vote_average = mean(vote_average) 
    ) %>%
    ungroup() %>%
    
    ggplot(aes(n, avg_vote_average)) +
    # geom_point() +
    geom_text(aes(label = genre_names), check_overlap = TRUE) +
    geom_hline(yintercept = mean(data$vote_average), 
               linewidth = 2, linetype = "dotted", color = "darkgray") +
    
    scale_x_log10()

spacy_initialize(model = "en_core_web_sm")

# process documents and obtain a data.table
tidy_data <- data %>%
    
    # Parse overview
    mutate(overview_parsed = map(.x = .$overview, .f = ~spacy_parse(.x))) %>%
    unnest(overview_parsed) %>% 
    
    # Select nouns and adjectives 
    filter(pos %in% c("ADJ", "NOUN"))

data_filtered <- tidy_data %>%
    filter(str_detect(lemma, regex("[a-z]", ignore_case = TRUE))) %>%
    group_by(lemma) %>%
    summarise(
        n = n(),
        avg_vote_average = mean(vote_average)
    ) %>%
    filter(n > 150)

data_filtered %>%
    ggplot(aes(n, avg_vote_average)) +
    # geom_point() +
    geom_text(aes(label = lemma), check_overlap = TRUE) +
    geom_hline(yintercept = mean(data_filtered$avg_vote_average),
               linetype = "dotted", linewidth = 2, color = "darkgray") +
    scale_x_log10()
data %>%
    ggplot(aes(runtime, vote_average)) +
    geom_jitter(alpha = 0.3)

Build a Model

set.seed(123)
data_split <- initial_split(data)
data_train <- training(data_split)
data_test <- testing(data_split)

set.seed(234)
data_folds <- rsample::vfold_cv(data_train)
data_folds
## #  10-fold cross-validation 
## # A tibble: 10 × 2
##    splits               id    
##    <list>               <chr> 
##  1 <split [29324/3259]> Fold01
##  2 <split [29324/3259]> Fold02
##  3 <split [29324/3259]> Fold03
##  4 <split [29325/3258]> Fold04
##  5 <split [29325/3258]> Fold05
##  6 <split [29325/3258]> Fold06
##  7 <split [29325/3258]> Fold07
##  8 <split [29325/3258]> Fold08
##  9 <split [29325/3258]> Fold09
## 10 <split [29325/3258]> Fold10
library(usemodels)
use_xgboost(vote_average ~ ., data = data_train)
## xgboost_recipe <- 
##   recipe(formula = vote_average ~ ., data = data_train) %>% 
##   step_zv(all_predictors()) 
## 
## xgboost_spec <- 
##   boost_tree(trees = tune(), min_n = tune(), tree_depth = tune(), learn_rate = tune(), 
##     loss_reduction = tune(), sample_size = tune()) %>% 
##   set_mode("classification") %>% 
##   set_engine("xgboost") 
## 
## xgboost_workflow <- 
##   workflow() %>% 
##   add_recipe(xgboost_recipe) %>% 
##   add_model(xgboost_spec) 
## 
## set.seed(63112)
## xgboost_tune <-
##   tune_grid(xgboost_workflow, resamples = stop("add your rsample object"), grid = stop("add number of candidate points"))
# xgboost_recipe <- 
#   recipe(formula = vote_average ~ ., data = data_train) %>%
#     recipes::update_role(post_id, new_role = "id") %>%
#     step_tokenize(overview) %>%
#     step_tokenfilter(overview, max_tokens = 100) %>%
#     step_tfidf(overview) %>%
#     step_other(nhood) %>%
#     step_dummy(nhood) %>%
#     step_log(vote_average, sqft, baths) # To transform variables with skewed distribution

xgboost_recipe <- 
  recipe(formula = vote_average ~ ., data = data_train) %>%
    recipes::update_role(id, new_role = "id") %>%
    step_tokenize(overview, engine = "spacyr") %>%
    step_lemma(overview) %>%
    step_pos_filter(overview, keep_tags = c("NOUN", "ADJ"))%>%
    step_tokenfilter(overview, max_tokens = 100) %>%
    step_tfidf(overview) %>%
    step_dummy(genre_names) %>%
    step_YeoJohnson(runtime)  # for log-transformation for a variable with zeroes

xgboost_recipe %>% prep() %>% bake(new_data = NULL) %>% glimpse()
## Found 'spacy_condaenv'. spacyr will use this environment
## successfully initialized (spaCy Version: 3.1.3, language model: en_core_web_sm)
## (python options: type = "condaenv", value = "spacy_condaenv")
## Rows: 32,583
## Columns: 121
## $ id                          <dbl> 314405, 147061, 166752, 309079, 27297, 427…
## $ runtime                     <dbl> 141.646911, 74.415520, 81.854221, 86.10912…
## $ vote_average                <dbl> 2.001480, 2.197225, 1.064711, 1.791759, 1.…
## $ `tfidf_overview_-`          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_accident     <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_alien        <dbl> 0.0000000, 0.4878952, 0.0000000, 0.0000000…
## $ tfidf_overview_ancient      <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_bad          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_beautiful    <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_blood        <dbl> 0.0000000, 0.0000000, 0.0000000, 0.4825536…
## $ tfidf_overview_body         <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_boy          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_brother      <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0.…
## $ tfidf_overview_child        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_city         <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0.…
## $ tfidf_overview_couple       <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_creature     <dbl> 0.0000000, 0.4255323, 0.0000000, 0.0000000…
## $ tfidf_overview_dark         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_daughter     <dbl> 0.6651315, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_day          <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0.…
## $ tfidf_overview_dead         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_deadly       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_death        <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_demon        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_dream        <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_event        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_evil         <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0.…
## $ tfidf_overview_family       <dbl> 0.0000000, 0.0000000, 0.5243188, 0.0000000…
## $ tfidf_overview_father       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_film         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_first        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_force        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_friend       <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_game         <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_ghost        <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0.…
## $ tfidf_overview_girl         <dbl> 0.5300025, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_good         <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_group        <dbl> 0.0000000, 0.3215255, 0.0000000, 0.0000000…
## $ tfidf_overview_help         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_high         <dbl> 0.0000000, 0.4636848, 0.0000000, 0.0000000…
## $ tfidf_overview_home         <dbl> 0.00000, 0.00000, 0.00000, 0.00000, 0.0000…
## $ tfidf_overview_horror       <dbl> 0.0000000, 0.0000000, 0.0000000, 0.4010642…
## $ tfidf_overview_house        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_human        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_husband      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_killer       <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_life         <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_little       <dbl> 0.000000, 0.465336, 0.000000, 0.000000, 0.…
## $ tfidf_overview_local        <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_love         <dbl> 0.6734058, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_man          <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_monster      <dbl> 0.0000000, 0.0000000, 0.6765847, 0.0000000…
## $ tfidf_overview_more         <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_mother       <dbl> 0.6504896, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_movie        <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_murder       <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_mysterious   <dbl> 0.5218673, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_new          <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_night        <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_nightmare    <dbl> 0.0000000, 0.0000000, 0.0000000, 0.4943259…
## $ tfidf_overview_old          <dbl> 0.0000000, 0.3375260, 0.0000000, 0.0000000…
## $ tfidf_overview_only         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_order        <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_other        <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_own          <dbl> 0.0000000, 0.0000000, 0.0000000, 0.4472877…
## $ tfidf_overview_past         <dbl> 0.0000000, 0.0000000, 0.0000000, 0.4853397…
## $ tfidf_overview_people       <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_place        <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_police       <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_power        <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0.…
## $ tfidf_overview_real         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_remote       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_revenge      <dbl> 0.0000000, 0.0000000, 0.7012580, 0.0000000…
## $ tfidf_overview_school       <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_scientist    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_secret       <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_serial       <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_series       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_short        <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_sister       <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0.…
## $ tfidf_overview_small        <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0.…
## $ tfidf_overview_son          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_spirit       <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_story        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_strange      <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_student      <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0.…
## $ tfidf_overview_supernatural <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_tale         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_team         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_thing        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_time         <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_town         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_vampire      <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_victim       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_village      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_way          <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_wife         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_woman        <dbl> 0.0000000, 0.0000000, 0.4501321, 0.6430458…
## $ tfidf_overview_wood         <dbl> 0.000000, 0.938040, 0.000000, 0.000000, 0.…
## $ tfidf_overview_world        <dbl> 0.0000000, 0.0000000, 0.5532448, 0.0000000…
## $ tfidf_overview_year         <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_young        <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_zombie       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Adventure       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Animation       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Comedy          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Crime           <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Documentary     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Drama           <dbl> 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Family          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Fantasy         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_History         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Horror          <dbl> 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, …
## $ genre_names_Music           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Mystery         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, …
## $ genre_names_Romance         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Science.Fiction <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ genre_names_Thriller        <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ genre_names_TV.Movie        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_War             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Western         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
xgboost_spec <- 
  boost_tree(trees = tune(), min_n = tune()) %>% 
  set_mode("regression") %>% 
  set_engine("xgboost") 

xgboost_workflow <- 
  workflow() %>% 
  add_recipe(xgboost_recipe) %>% 
  add_model(xgboost_spec) 

set.seed(15793)
doParallel::registerDoParallel()
xgboost_tune <-
  tune_grid(xgboost_workflow, 
            resamples = data_folds, 
            grid = 5)

Explore Results

show_best(xgboost_tune, metric = "rmse")
## # A tibble: 5 × 8
##   trees min_n .metric .estimator  mean     n std_err .config             
##   <int> <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>               
## 1  1593     4 rmse    standard   0.246    10 0.00186 Preprocessor1_Model1
## 2  1637    20 rmse    standard   0.262    10 0.00151 Preprocessor1_Model3
## 3   923    38 rmse    standard   0.277    10 0.00149 Preprocessor1_Model5
## 4   458    31 rmse    standard   0.283    10 0.00155 Preprocessor1_Model4
## 5    83    14 rmse    standard   0.296    10 0.00144 Preprocessor1_Model2
# How did all the possible parameter combinations do?
autoplot(xgboost_tune)

We can finalize our random forest workflow with the best performing parameters.

final_rf <- xgboost_workflow %>% 
    finalize_workflow(select_best(xgboost_tune, "rmse"))

The function last_fit() fits this finalized random forest one last time to the training data and evaluates one last time on the testing data.

data_fit <- last_fit(final_rf, data_split)
data_fit
## # Resampling results
## # Manual resampling 
## # A tibble: 1 × 6
##   splits                id             .metrics .notes   .predictions .workflow 
##   <list>                <chr>          <list>   <list>   <list>       <list>    
## 1 <split [32583/10861]> train/test sp… <tibble> <tibble> <tibble>     <workflow>

Evaluate model

collect_metrics(data_fit)
## # A tibble: 2 × 4
##   .metric .estimator .estimate .config             
##   <chr>   <chr>          <dbl> <chr>               
## 1 rmse    standard       0.229 Preprocessor1_Model1
## 2 rsq     standard       0.483 Preprocessor1_Model1
collect_predictions(data_fit)
## # A tibble: 10,861 × 5
##    id               .pred  .row vote_average .config             
##    <chr>            <dbl> <int>        <dbl> <chr>               
##  1 train/test split  1.95     1         2.07 Preprocessor1_Model1
##  2 train/test split  2.08     4         2.09 Preprocessor1_Model1
##  3 train/test split  1.88    12         1.92 Preprocessor1_Model1
##  4 train/test split  1.80    15         2.08 Preprocessor1_Model1
##  5 train/test split  1.97    23         2.01 Preprocessor1_Model1
##  6 train/test split  1.99    27         2.05 Preprocessor1_Model1
##  7 train/test split  2.02    36         2.03 Preprocessor1_Model1
##  8 train/test split  1.79    39         2.07 Preprocessor1_Model1
##  9 train/test split  1.53    42         2.07 Preprocessor1_Model1
## 10 train/test split  1.94    44         2.07 Preprocessor1_Model1
## # ℹ 10,851 more rows
collect_predictions(data_fit) %>%
    ggplot(aes(vote_average, .pred)) +
    geom_point(alpha = 0.5, fill = "midnightblue") +
    geom_abline(lty = 2, color = "gray50") +
    coord_fixed()

  1. Predict
data_fit %>%
    extract_workflow() %>%
    predict(data_test[1,])
## # A tibble: 1 × 1
##   .pred
##   <dbl>
## 1  1.95
  1. View important variables
library(vip)

imp_spec <- xgboost_spec %>%
    tune::finalize_model(tune::select_best(xgboost_tune)) %>%
    parsnip::set_engine("xgboost", importance = "permutation")

workflows::workflow() %>%
    add_recipe(xgboost_recipe) %>%
    add_model(imp_spec) %>%
    fit(data_train) %>%
    workflows::pull_workflow_fit() %>%
    vip()
## [17:10:42] WARNING: amalgamation/../src/learner.cc:627: 
## Parameters: { "importance" } might not be used.
## 
##   This could be a false alarm, with some parameters getting used by language bindings but
##   then being mistakenly passed down to XGBoost core, or some parameter actually being used
##   but getting flagged wrongly here. Please open an issue if you find any such cases.