Goal: to predict the average rating of horror movies. Click here for the data

Import Data

horror_movies <- read_csv('https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2022/2022-11-01/horror_movies.csv')
## Rows: 32540 Columns: 20
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr  (10): original_title, title, original_language, overview, tagline, post...
## dbl   (8): id, popularity, vote_count, vote_average, budget, revenue, runtim...
## lgl   (1): adult
## date  (1): release_date
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
skim(horror_movies)
Data summary
Name horror_movies
Number of rows 32540
Number of columns 20
_______________________
Column type frequency:
character 10
Date 1
logical 1
numeric 8
________________________
Group variables None

Variable type: character

skim_variable n_missing complete_rate min max empty n_unique whitespace
original_title 0 1.00 1 191 0 30296 0
title 0 1.00 1 191 0 29563 0
original_language 0 1.00 2 2 0 97 0
overview 1286 0.96 1 1000 0 31020 0
tagline 19835 0.39 1 237 0 12513 0
poster_path 4474 0.86 30 32 0 28048 0
status 0 1.00 7 15 0 4 0
backdrop_path 18995 0.42 29 32 0 13536 0
genre_names 0 1.00 6 144 0 772 0
collection_name 30234 0.07 4 56 0 815 0

Variable type: Date

skim_variable n_missing complete_rate min max median n_unique
release_date 0 1 1950-01-01 2022-12-31 2012-12-09 10999

Variable type: logical

skim_variable n_missing complete_rate mean count
adult 0 1 0 FAL: 32540

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
id 0 1.00 445910.83 305744.67 17 146494.8 426521.00 707534.00 1033095.00 ▇▆▆▅▅
popularity 0 1.00 4.01 37.51 0 0.6 0.84 2.24 5088.58 ▇▁▁▁▁
vote_count 0 1.00 62.69 420.89 0 0.0 2.00 11.00 16900.00 ▇▁▁▁▁
vote_average 0 1.00 3.34 2.88 0 0.0 4.00 5.70 10.00 ▇▂▆▃▁
budget 0 1.00 543126.59 4542667.81 0 0.0 0.00 0.00 200000000.00 ▇▁▁▁▁
revenue 0 1.00 1349746.73 14430479.15 0 0.0 0.00 0.00 701842551.00 ▇▁▁▁▁
runtime 0 1.00 62.14 41.00 0 14.0 80.00 91.00 683.00 ▇▁▁▁▁
collection 30234 0.07 481534.88 324498.16 656 155421.0 471259.00 759067.25 1033032.00 ▇▅▅▅▅
data <- horror_movies %>%
    
    mutate(vote_average = log1p(vote_average)) %>%
    
    filter(!is.na(overview), vote_count != 0) %>%
    
    separate_rows(genre_names, sep = ", ") %>%
    
    filter(status == "Released") %>%
    
    select(id, vote_average, genre_names, overview, runtime, budget) %>%
    
    na.omit()

Explore Data

identify good predictors.

budget

data %>%
    ggplot(aes(vote_average, budget)) +
    geom_point()

runtime

data %>%
    ggplot(aes(vote_average, runtime)) +
    geom_point()

title

data %>%
    
    #tokenize title
    unnest_tokens(output = word, input = overview) %>%
    
    # calculate avg rent per word
    group_by(word) %>%
    summarise(vote_average = mean(vote_average), 
              n    = n()) %>%
    ungroup() %>%
    
    filter(n > 10, !str_detect(word, "\\d")) %>%
    slice_max(order_by = vote_average, n = 20) %>%
    
    # plot
    ggplot(aes(vote_average, fct_reorder(word, vote_average))) +
    geom_point() +
    
    labs(y = "Words in Overview")

EDA shortcut

# Step 1: Prepare data
data_binarized_tbl <- data %>%
    select(-overview, -id) %>%
    binarize()

data_binarized_tbl %>% glimpse()
## Rows: 43,444
## Columns: 23
## $ `vote_average__-Inf_1.66770682055808`           <dbl> 0, 0, 0, 0, 0, 0, 0, 0…
## $ vote_average__1.66770682055808_1.84054963339749 <dbl> 0, 0, 0, 0, 0, 0, 0, 0…
## $ vote_average__1.84054963339749_1.97408102602201 <dbl> 0, 0, 0, 0, 0, 0, 0, 0…
## $ vote_average__1.97408102602201_Inf              <dbl> 1, 1, 1, 1, 1, 1, 1, 1…
## $ genre_names__Action                             <dbl> 0, 0, 0, 0, 0, 0, 0, 0…
## $ genre_names__Adventure                          <dbl> 0, 0, 1, 0, 0, 0, 0, 0…
## $ genre_names__Animation                          <dbl> 0, 0, 0, 0, 0, 0, 0, 0…
## $ genre_names__Comedy                             <dbl> 0, 0, 0, 0, 0, 0, 0, 0…
## $ genre_names__Crime                              <dbl> 0, 0, 0, 0, 0, 0, 0, 0…
## $ genre_names__Drama                              <dbl> 0, 0, 0, 1, 0, 0, 0, 0…
## $ genre_names__Fantasy                            <dbl> 0, 0, 0, 0, 0, 0, 0, 0…
## $ genre_names__Horror                             <dbl> 1, 0, 0, 0, 1, 1, 0, 0…
## $ genre_names__Mystery                            <dbl> 0, 0, 0, 0, 0, 0, 1, 0…
## $ genre_names__Science_Fiction                    <dbl> 0, 0, 0, 0, 0, 0, 0, 0…
## $ genre_names__Thriller                           <dbl> 0, 1, 0, 0, 0, 0, 0, 1…
## $ genre_names__TV_Movie                           <dbl> 0, 0, 0, 0, 0, 0, 0, 0…
## $ `genre_names__-OTHER`                           <dbl> 0, 0, 0, 0, 0, 0, 0, 0…
## $ `runtime__-Inf_75`                              <dbl> 0, 0, 0, 0, 0, 0, 0, 0…
## $ runtime__75_87                                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0…
## $ runtime__87_95                                  <dbl> 0, 0, 1, 1, 1, 0, 0, 0…
## $ runtime__95_Inf                                 <dbl> 1, 1, 0, 0, 0, 1, 1, 1…
## $ budget__0                                       <dbl> 1, 1, 1, 1, 1, 0, 0, 0…
## $ `budget__-OTHER`                                <dbl> 0, 0, 0, 0, 0, 1, 1, 1…
# Step 2: Correlate
data_corr_tbl <- data_binarized_tbl %>%
    correlate(vote_average__1.97408102602201_Inf)

data_corr_tbl
## # A tibble: 23 × 3
##    feature      bin                               correlation
##    <fct>        <chr>                                   <dbl>
##  1 vote_average 1.97408102602201_Inf                   1     
##  2 vote_average -Inf_1.66770682055808                 -0.343 
##  3 vote_average 1.84054963339749_1.97408102602201     -0.328 
##  4 vote_average 1.66770682055808_1.84054963339749     -0.326 
##  5 runtime      -Inf_75                                0.192 
##  6 runtime      87_95                                 -0.128 
##  7 runtime      75_87                                 -0.125 
##  8 genre_names  Animation                              0.0769
##  9 runtime      95_Inf                                 0.0593
## 10 budget       0                                     -0.0497
## # ℹ 13 more rows
# Step 3: Plot
data_corr_tbl %>%
    plot_correlation_funnel()
## Warning: ggrepel: 13 unlabeled data points (too many overlaps). Consider
## increasing max.overlaps

Build Models

Split Data

#data <- sample_n(data, 100)

#Split into train and test dataset
set.seed(1234)
data_split <- rsample::initial_split(data)
data_train <- training(data_split)
data_test <- testing(data_split)

#Further split training dataset for cross-validation
set.seed(4321)
data_cv <- rsample::vfold_cv(data_train)
data_cv
## #  10-fold cross-validation 
## # A tibble: 10 × 2
##    splits               id    
##    <list>               <chr> 
##  1 <split [29324/3259]> Fold01
##  2 <split [29324/3259]> Fold02
##  3 <split [29324/3259]> Fold03
##  4 <split [29325/3258]> Fold04
##  5 <split [29325/3258]> Fold05
##  6 <split [29325/3258]> Fold06
##  7 <split [29325/3258]> Fold07
##  8 <split [29325/3258]> Fold08
##  9 <split [29325/3258]> Fold09
## 10 <split [29325/3258]> Fold10
library(usemodels)
## Warning: package 'usemodels' was built under R version 4.3.3
usemodels::use_xgboost(vote_average ~ ., data = data_train)
## xgboost_recipe <- 
##   recipe(formula = vote_average ~ ., data = data_train) %>% 
##   step_zv(all_predictors()) 
## 
## xgboost_spec <- 
##   boost_tree(trees = tune(), min_n = tune(), tree_depth = tune(), learn_rate = tune(), 
##     loss_reduction = tune(), sample_size = tune()) %>% 
##   set_mode("classification") %>% 
##   set_engine("xgboost") 
## 
## xgboost_workflow <- 
##   workflow() %>% 
##   add_recipe(xgboost_recipe) %>% 
##   add_model(xgboost_spec) 
## 
## set.seed(48629)
## xgboost_tune <-
##   tune_grid(xgboost_workflow, resamples = stop("add your rsample object"), grid = stop("add number of candidate points"))
xgboost_recipe <- 
  recipe(formula = vote_average ~ ., data = data_train) %>% 
    recipes::update_role(id, new_role = "id") %>%
    step_tokenize(overview) %>%
    step_tokenfilter(overview, max_tokens = 100) %>%
    step_tfidf(overview) %>%
    step_dummy(genre_names) %>%
    step_YeoJohnson(runtime)

xgboost_recipe %>% prep() %>% bake(new_data = NULL) %>% glimpse()
## Rows: 32,583
## Columns: 122
## $ id                          <dbl> 611067, 619754, 752440, 113128, 665779, 86…
## $ runtime                     <dbl> 16.123839, 88.509661, 17.232365, 110.91148…
## $ budget                      <dbl> 25058, 25000, 0, 0, 0, 0, 0, 0, 0, 1000000…
## $ vote_average                <dbl> 1.902108, 1.098612, 2.041220, 1.722767, 1.…
## $ tfidf_overview_a            <dbl> 0.06252428, 0.00000000, 0.13641660, 0.1500…
## $ tfidf_overview_about        <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_after        <dbl> 0.05662186, 0.05509154, 0.00000000, 0.0815…
## $ tfidf_overview_all          <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_an           <dbl> 0.00000000, 0.03849377, 0.12947905, 0.0569…
## $ tfidf_overview_and          <dbl> 0.10336027, 0.02514169, 0.00000000, 0.0744…
## $ tfidf_overview_are          <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_as           <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_at           <dbl> 0.00000000, 0.00000000, 0.19032860, 0.0000…
## $ tfidf_overview_back         <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_be           <dbl> 0.06162920, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_becomes      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_been         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_before       <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_begins       <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_being        <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_but          <dbl> 0.05173111, 0.00000000, 0.16930183, 0.0000…
## $ tfidf_overview_by           <dbl> 0.04129338, 0.04017734, 0.00000000, 0.0000…
## $ tfidf_overview_can          <dbl> 0.08205548, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_dark         <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_dead         <dbl> 0.0000000, 0.0000000, 0.2699028, 0.0000000…
## $ tfidf_overview_death        <dbl> 0.0791073, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_evil         <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_family       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_film         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_find         <dbl> 0.00000000, 0.06754842, 0.00000000, 0.0000…
## $ tfidf_overview_finds        <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0.…
## $ tfidf_overview_for          <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_friends      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_from         <dbl> 0.00000000, 0.09480331, 0.00000000, 0.0000…
## $ tfidf_overview_get          <dbl> 0.08410588, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_girl         <dbl> 0.08224722, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_group        <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_has          <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_have         <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_he           <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_help         <dbl> 0.0000000, 0.0000000, 0.0000000, 0.1251635…
## $ tfidf_overview_her          <dbl> 0.08878968, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_him          <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0959…
## $ tfidf_overview_his          <dbl> 0.00000000, 0.00000000, 0.00000000, 0.1716…
## $ tfidf_overview_home         <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_horror       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_house        <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_in           <dbl> 0.02894298, 0.08448221, 0.09472248, 0.0000…
## $ tfidf_overview_into         <dbl> 0.00000000, 0.15624116, 0.00000000, 0.0000…
## $ tfidf_overview_is           <dbl> 0.03314201, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_it           <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0867…
## $ tfidf_overview_killer       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_life         <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_man          <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_more         <dbl> 0.08169083, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_must         <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_mysterious   <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_new          <dbl> 0.00000000, 0.00000000, 0.00000000, 0.1009…
## $ tfidf_overview_night        <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_not          <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_now          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_of           <dbl> 0.05058798, 0.07383111, 0.08278033, 0.0364…
## $ tfidf_overview_old          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_on           <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0612…
## $ tfidf_overview_one          <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_only         <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_or           <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_out          <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_own          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_people       <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_she          <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_soon         <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_story        <dbl> 0.00000000, 0.08102412, 0.00000000, 0.0000…
## $ tfidf_overview_strange      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_take         <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_that         <dbl> 0.04097793, 0.03987042, 0.00000000, 0.0590…
## $ tfidf_overview_the          <dbl> 0.15732697, 0.17494273, 0.14711093, 0.0970…
## $ tfidf_overview_their        <dbl> 0.00000000, 0.09460723, 0.00000000, 0.0000…
## $ tfidf_overview_them         <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_there        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_they         <dbl> 0.00000000, 0.09887806, 0.00000000, 0.0000…
## $ tfidf_overview_this         <dbl> 0.0000000, 0.1230015, 0.0000000, 0.0000000…
## $ tfidf_overview_three        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_through      <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_time         <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_to           <dbl> 0.07395075, 0.07195208, 0.00000000, 0.0709…
## $ tfidf_overview_town         <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_two          <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_up           <dbl> 0.06444458, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_was          <dbl> 0.00000000, 0.07773037, 0.00000000, 0.0000…
## $ tfidf_overview_way          <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_what         <dbl> 0.00000000, 0.00000000, 0.25600187, 0.0000…
## $ tfidf_overview_when         <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0691…
## $ tfidf_overview_where        <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_which        <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_while        <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_who          <dbl> 0.04781363, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_wife         <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_will         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_with         <dbl> 0.00000000, 0.03859542, 0.00000000, 0.0000…
## $ tfidf_overview_woman        <dbl> 0.07006285, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_world        <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_years        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_young        <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ genre_names_Adventure       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Animation       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Comedy          <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, …
## $ genre_names_Crime           <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Documentary     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Drama           <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, …
## $ genre_names_Family          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Fantasy         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ genre_names_History         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Horror          <dbl> 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, …
## $ genre_names_Music           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Mystery         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Romance         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Science.Fiction <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Thriller        <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ genre_names_TV.Movie        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_War             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Western         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
xgboost_spec <- 
  boost_tree(trees = tune(), min_n = tune()) %>% 
  set_mode("regression") %>% 
  set_engine("xgboost") 

xgboost_workflow <- 
  workflow() %>% 
  add_recipe(xgboost_recipe) %>% 
  add_model(xgboost_spec) 

set.seed(1027)
xgboost_tune <-
  tune_grid(xgboost_workflow, resamples = data_cv, grid = 5)
## Warning: package 'xgboost' was built under R version 4.3.3

Evaluate Models

tune::show_best(xgboost_tune, metric = "rmse")
## # A tibble: 5 × 8
##   trees min_n .metric .estimator  mean     n std_err .config             
##   <int> <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>               
## 1  1872    20 rmse    standard   0.218    10 0.00201 Preprocessor1_Model3
## 2  1425    16 rmse    standard   0.218    10 0.00170 Preprocessor1_Model2
## 3   757     2 rmse    standard   0.218    10 0.00234 Preprocessor1_Model1
## 4   807    28 rmse    standard   0.234    10 0.00146 Preprocessor1_Model4
## 5   155    35 rmse    standard   0.277    10 0.00127 Preprocessor1_Model5
# Update model by selecting best hyperparameters
xgboost_fw <- tune::finalize_workflow(xgboost_workflow,
                        tune::select_best(xgboost_tune, metric = "rmse"))

# Fit the model on the entire training data and test it on the test data.
data_fit <- tune::last_fit(xgboost_fw, data_split)
tune::collect_metrics(data_fit)
## # A tibble: 2 × 4
##   .metric .estimator .estimate .config             
##   <chr>   <chr>          <dbl> <chr>               
## 1 rmse    standard       0.206 Preprocessor1_Model1
## 2 rsq     standard       0.587 Preprocessor1_Model1
tune::collect_predictions(data_fit) %>%
    ggplot(aes(vote_average, .pred)) +
    geom_point(alpha = 0.3, fill = "midnightblue") +
    geom_abline(lty = 2, color = "gray50") +
    coord_fixed()

Make Predictions