Goal: Predict a horror movie's average vote rating (`vote_average`)

Import data

# Download the TidyTuesday horror-movie data set (2022-11-01) from GitHub.
horror_movies <- readr::read_csv(
    file = "https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2022/2022-11-01/horror_movies.csv"
)
## Rows: 32540 Columns: 20
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr  (10): original_title, title, original_language, overview, tagline, post...
## dbl   (8): id, popularity, vote_count, vote_average, budget, revenue, runtim...
## lgl   (1): adult
## date  (1): release_date
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Per-column overview (missing counts, ranges, histograms) of the raw data.
skimr::skim(data = horror_movies)
Data summary
Name horror_movies
Number of rows 32540
Number of columns 20
_______________________
Column type frequency:
character 10
Date 1
logical 1
numeric 8
________________________
Group variables None

Variable type: character

skim_variable n_missing complete_rate min max empty n_unique whitespace
original_title 0 1.00 1 191 0 30296 0
title 0 1.00 1 191 0 29563 0
original_language 0 1.00 2 2 0 97 0
overview 1286 0.96 1 1000 0 31020 0
tagline 19835 0.39 1 237 0 12513 0
poster_path 4474 0.86 30 32 0 28048 0
status 0 1.00 7 15 0 4 0
backdrop_path 18995 0.42 29 32 0 13536 0
genre_names 0 1.00 6 144 0 772 0
collection_name 30234 0.07 4 56 0 815 0

Variable type: Date

skim_variable n_missing complete_rate min max median n_unique
release_date 0 1 1950-01-01 2022-12-31 2012-12-09 10999

Variable type: logical

skim_variable n_missing complete_rate mean count
adult 0 1 0 FAL: 32540

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
id 0 1.00 445910.83 305744.67 17 146494.8 426521.00 707534.00 1033095.00 ▇▆▆▅▅
popularity 0 1.00 4.01 37.51 0 0.6 0.84 2.24 5088.58 ▇▁▁▁▁
vote_count 0 1.00 62.69 420.89 0 0.0 2.00 11.00 16900.00 ▇▁▁▁▁
vote_average 0 1.00 3.34 2.88 0 0.0 4.00 5.70 10.00 ▇▂▆▃▁
budget 0 1.00 543126.59 4542667.81 0 0.0 0.00 0.00 200000000.00 ▇▁▁▁▁
revenue 0 1.00 1349746.73 14430479.15 0 0.0 0.00 0.00 701842551.00 ▇▁▁▁▁
runtime 0 1.00 62.14 41.00 0 14.0 80.00 91.00 683.00 ▇▁▁▁▁
collection 30234 0.07 481534.88 324498.16 656 155421.0 471259.00 759067.25 1033032.00 ▇▅▅▅▅
data <- horror_movies %>%

    # Treat missing values: drop the mostly-missing columns (tagline,
    # backdrop_path, collection_name, collection) and the columns not used
    # as predictors (release_date, original_title, overview, poster_path),
    # then drop any rows that still contain an NA.
    select(-tagline, -backdrop_path, -collection_name, -collection,
           -release_date, -original_title, -overview, -poster_path) %>%
    na.omit() %>%

    # Log transform popularity, whose distribution is strongly pos-skewed.
    # skim() reports a minimum popularity of 0, and log(0) is -Inf, which
    # would leak into the plots and the modelling recipe. Floor those values
    # at the smallest positive popularity first so every result is finite;
    # positive values are transformed exactly as a plain log() would.
    mutate(popularity = log(pmax(popularity, min(popularity[popularity > 0]))))

Explore data

Identify good predictors

Popularity

# Popularity (already log-transformed during cleaning) against the rating.
# The rating axis stays linear: vote_average contains zeros, which a log10
# scale would turn into -Inf.
data %>%
    ggplot(aes(vote_average, popularity)) +
    geom_point()
## Warning in scale_x_log10(): log-10 transformation introduced infinite values.

Revenue

# Revenue against the rating.
ggplot(data, aes(x = vote_average, y = revenue)) +
    geom_point()

Runtime

# Runtime against the rating.
ggplot(data, aes(x = vote_average, y = runtime)) +
    geom_point()

Title

# Mean rating per title word. Only words that occur more than 5 times
# and contain no digits are kept, and the top 25 by mean rating are
# plotted (ties at the cut-off are kept).
data %>%
    unnest_tokens(word, title) %>%
    group_by(word) %>%
    summarise(
        vote_average = mean(vote_average),
        n = n()
    ) %>%
    ungroup() %>%
    filter(n > 5) %>%
    filter(!str_detect(word, "\\d")) %>%
    slice_max(vote_average, n = 25) %>%
    ggplot(aes(x = vote_average, y = fct_reorder(word, vote_average))) +
    geom_point() +
    labs(y = "Words in Title")

EDA Shortcut

# Step 1: Prepare data - drop the identifier columns and convert every
# remaining column into 0/1 indicator columns with binarize().
data_binarized_tbl <- binarize(select(data, -id, -title))

glimpse(data_binarized_tbl)
## Rows: 32,540
## Columns: 44
## $ original_language__cn                               <dbl> 0, 0, 0, 0, 0, 0, …
## $ original_language__de                               <dbl> 0, 0, 0, 0, 0, 0, …
## $ original_language__en                               <dbl> 1, 1, 1, 1, 0, 0, …
## $ original_language__es                               <dbl> 0, 0, 0, 0, 1, 1, …
## $ original_language__fr                               <dbl> 0, 0, 0, 0, 0, 0, …
## $ original_language__id                               <dbl> 0, 0, 0, 0, 0, 0, …
## $ original_language__it                               <dbl> 0, 0, 0, 0, 0, 0, …
## $ original_language__ja                               <dbl> 0, 0, 0, 0, 0, 0, …
## $ original_language__ko                               <dbl> 0, 0, 0, 0, 0, 0, …
## $ original_language__pt                               <dbl> 0, 0, 0, 0, 0, 0, …
## $ original_language__th                               <dbl> 0, 0, 0, 0, 0, 0, …
## $ original_language__zh                               <dbl> 0, 0, 0, 0, 0, 0, …
## $ `original_language__-OTHER`                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ `popularity__-Inf_-0.510825623765991`               <dbl> 0, 0, 0, 0, 0, 0, …
## $ `popularity__-0.510825623765991_-0.174353387144778` <dbl> 0, 0, 0, 0, 0, 0, …
## $ `popularity__-0.174353387144778_0.807925688568666`  <dbl> 0, 0, 0, 0, 0, 0, …
## $ popularity__0.807925688568666_Inf                   <dbl> 1, 1, 1, 1, 1, 1, …
## $ `vote_count__-Inf_2`                                <dbl> 0, 0, 0, 0, 0, 1, …
## $ vote_count__2_11                                    <dbl> 0, 0, 0, 0, 0, 0, …
## $ vote_count__11_Inf                                  <dbl> 1, 1, 1, 1, 1, 0, …
## $ `vote_average__-Inf_4`                              <dbl> 0, 0, 0, 0, 0, 1, …
## $ vote_average__4_5.7                                 <dbl> 0, 0, 0, 0, 0, 0, …
## $ vote_average__5.7_Inf                               <dbl> 1, 1, 1, 1, 1, 0, …
## $ budget__0                                           <dbl> 1, 1, 0, 0, 1, 1, …
## $ `budget__-OTHER`                                    <dbl> 0, 0, 1, 1, 0, 0, …
## $ revenue__0                                          <dbl> 0, 0, 0, 0, 1, 1, …
## $ `revenue__-OTHER`                                   <dbl> 1, 1, 1, 1, 0, 0, …
## $ `runtime__-Inf_14`                                  <dbl> 0, 0, 0, 0, 1, 1, …
## $ runtime__14_80                                      <dbl> 0, 0, 0, 0, 0, 0, …
## $ runtime__80_91                                      <dbl> 0, 0, 0, 0, 0, 0, …
## $ runtime__91_Inf                                     <dbl> 1, 1, 1, 1, 0, 0, …
## $ status__Released                                    <dbl> 1, 1, 1, 1, 1, 1, …
## $ `status__-OTHER`                                    <dbl> 0, 0, 0, 0, 0, 0, …
## $ `genre_names__Animation,_Horror`                    <dbl> 0, 0, 0, 0, 0, 0, …
## $ `genre_names__Comedy,_Horror`                       <dbl> 0, 0, 0, 0, 0, 0, …
## $ `genre_names__Drama,_Horror`                        <dbl> 0, 0, 0, 0, 0, 0, …
## $ `genre_names__Drama,_Horror,_Thriller`              <dbl> 0, 0, 0, 0, 0, 0, …
## $ `genre_names__Fantasy,_Horror`                      <dbl> 0, 0, 0, 0, 0, 0, …
## $ genre_names__Horror                                 <dbl> 0, 0, 0, 0, 1, 0, …
## $ `genre_names__Horror,_Mystery`                      <dbl> 0, 0, 0, 0, 0, 0, …
## $ `genre_names__Horror,_Mystery,_Thriller`            <dbl> 0, 0, 1, 0, 0, 0, …
## $ `genre_names__Horror,_Science_Fiction`              <dbl> 0, 0, 0, 0, 0, 0, …
## $ `genre_names__Horror,_Thriller`                     <dbl> 1, 0, 0, 1, 0, 1, …
## $ `genre_names__-OTHER`                               <dbl> 0, 1, 0, 0, 0, 0, …
# Step 2: Correlate every indicator column with the outcome of interest.
# The goal is predicting vote_average, so the target is the top rating bin
# (vote_average above 5.7), not a popularity bin.
data_corr_tbl <- data_binarized_tbl %>%
    correlate(target = vote_average__5.7_Inf)

data_corr_tbl
## # A tibble: 44 × 3
##    feature      bin                                  correlation
##    <fct>        <chr>                                      <dbl>
##  1 popularity   0.807925688568666_Inf                      1    
##  2 vote_count   11_Inf                                     0.775
##  3 vote_count   -Inf_2                                    -0.588
##  4 popularity   -Inf_-0.510825623765991                   -0.494
##  5 vote_average -Inf_4                                    -0.416
##  6 popularity   -0.174353387144778_0.807925688568666      -0.328
##  7 vote_average 4_5.7                                      0.314
##  8 revenue      0                                         -0.303
##  9 revenue      -OTHER                                     0.303
## 10 runtime      -Inf_14                                   -0.297
## # ℹ 34 more rows
# Step 3: Plot the correlations as a funnel chart.
plot_correlation_funnel(data_corr_tbl)
## Warning: ggrepel: 21 unlabeled data points (too many overlaps). Consider
## increasing max.overlaps

Build models

Split Data

# Down-sample to 100 movies to keep tuning fast. Seed before sampling so
# the subset, and every result computed from it, is reproducible.
set.seed(123)
data <- slice_sample(data, n = 100)

# Split into train and test data-set (rsample's default 3/4 training share)
set.seed(1234)
data_split <- rsample::initial_split(data)
data_train <- rsample::training(data_split)
data_test  <- rsample::testing(data_split)

# Further split training data-set for cross-validation (10 folds by default)
set.seed(4321)
data_cv <- rsample::vfold_cv(data_train)
data_cv
## #  10-fold cross-validation 
## # A tibble: 10 × 2
##    splits         id    
##    <list>         <chr> 
##  1 <split [67/8]> Fold01
##  2 <split [67/8]> Fold02
##  3 <split [67/8]> Fold03
##  4 <split [67/8]> Fold04
##  5 <split [67/8]> Fold05
##  6 <split [68/7]> Fold06
##  7 <split [68/7]> Fold07
##  8 <split [68/7]> Fold08
##  9 <split [68/7]> Fold09
## 10 <split [68/7]> Fold10
# Print a starter tidymodels template (recipe, spec, workflow, tuning call)
# for an xgboost model of vote_average; it is adapted by hand below.
library(usemodels)
use_xgboost(formula = vote_average ~ ., data = data_train)
## xgboost_recipe <- 
##   recipe(formula = vote_average ~ ., data = data_train) %>% 
##   step_zv(all_predictors()) 
## 
## xgboost_spec <- 
##   boost_tree(trees = tune(), min_n = tune(), tree_depth = tune(), learn_rate = tune(), 
##     loss_reduction = tune(), sample_size = tune()) %>% 
##   set_mode("classification") %>% 
##   set_engine("xgboost") 
## 
## xgboost_workflow <- 
##   workflow() %>% 
##   add_recipe(xgboost_recipe) %>% 
##   add_model(xgboost_spec) 
## 
## set.seed(4796)
## xgboost_tune <-
##   tune_grid(xgboost_workflow, resamples = stop("add your rsample object"), grid = stop("add number of candidate points"))
# Preprocessing for the xgboost model of vote_average:
# - id stays in the data but is excluded from the predictors.
# - title becomes tf-idf features for its 100 most frequent tokens.
# - infrequent languages/genres are pooled into "other", then one-hot encoded.
# - skewed numeric predictors get a Yeo-Johnson transform.
# - step_zv() runs last so it also removes zero-variance columns created by
#   the tf-idf and dummy steps (likely with only ~75 training rows), not
#   just constant raw columns such as `adult`.
# NOTE(review): `status` is never dummy-encoded; it is only dropped when it
# is constant in the training data - confirm that holds for every resample.
xgboost_recipe <-
  recipe(formula = vote_average ~ ., data = data_train) %>%
  update_role(id, new_role = "id variable") %>%
  step_tokenize(title) %>%
  step_tokenfilter(title, max_tokens = 100) %>%
  step_tfidf(title) %>%
  step_other(original_language, genre_names) %>%
  step_dummy(original_language, genre_names, one_hot = TRUE) %>%
  step_YeoJohnson(popularity, vote_count, budget, revenue, runtime) %>%
  step_zv(all_predictors())

# Inspect the preprocessed training data (bake with new_data = NULL returns
# the prepped training set).
bake(prep(xgboost_recipe), new_data = NULL) %>% glimpse()
## Rows: 75
## Columns: 115
## $ id                                  <dbl> 526052, 81384, 13094, 876269, 5328…
## $ popularity                          <dbl> 0.7677197, 0.3245809, 0.7198756, -…
## $ vote_count                          <dbl> 1.7202966, 1.1047995, 1.8427917, 0…
## $ budget                              <dbl> 0.000000, 0.000000, 0.000000, 0.00…
## $ revenue                             <dbl> 0.000000, 0.000000, 0.000000, 0.00…
## $ runtime                             <dbl> 23.134628, 22.988422, 21.490318, 2…
## $ vote_average                        <dbl> 5.5, 4.6, 5.9, 10.0, 0.0, 4.3, 1.0…
## $ tfidf_title_068                     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_1984                    <dbl> 0.000000, 0.000000, 0.000000, 0.00…
## $ tfidf_title_2                       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_50                      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_7                       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_a                       <dbl> 0.000000, 0.000000, 0.000000, 0.00…
## $ tfidf_title_agency                  <dbl> 0.000000, 0.000000, 0.000000, 0.00…
## $ tfidf_title_alive                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_alone                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_american                <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_amerikan                <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_and                     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_apartment               <dbl> 0.000000, 0.000000, 0.000000, 0.00…
## $ tfidf_title_apparition              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_assassin                <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_atta                    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_attack                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_auopssessed             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_aylesbury               <dbl> 0.000000, 0.000000, 0.000000, 2.16…
## $ tfidf_title_battle                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_becoming                <dbl> 4.330733, 0.000000, 0.000000, 0.00…
## $ tfidf_title_bitch                   <dbl> 0.000000, 0.000000, 0.000000, 0.00…
## $ tfidf_title_black                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_blues                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_blurred                 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_bone                    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_boo                     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_breathe                 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_capps                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_cat                     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_cobre                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_cookie                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_cried                   <dbl> 0.000000, 0.000000, 0.000000, 0.00…
## $ tfidf_title_crossing                <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_cut                     <dbl> 0.000000, 0.000000, 0.000000, 0.00…
## $ tfidf_title_dale                    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_dance                   <dbl> 0.000000, 0.000000, 1.082683, 0.00…
## $ tfidf_title_dangerous               <dbl> 0.000000, 0.000000, 0.000000, 0.00…
## $ tfidf_title_dans                    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_darkness                <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_day                     <dbl> 0.0000000, 0.0000000, 0.0000000, 0…
## $ tfidf_title_de                      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_dead                    <dbl> 0.0000000, 0.0000000, 0.8145241, 1…
## $ tfidf_title_death                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_dementia                <dbl> 0.000000, 0.000000, 0.000000, 0.00…
## $ tfidf_title_demons                  <dbl> 0.000000, 0.000000, 0.000000, 0.00…
## $ `tfidf_title_devil's`               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_doctor                  <dbl> 0.000000, 0.000000, 0.000000, 0.00…
## $ `tfidf_title_don't`                 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_du                      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_eizou                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_eta                     <dbl> 0.000000, 0.000000, 0.000000, 0.00…
## $ tfidf_title_evil                    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_experimenting           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_expira                  <dbl> 0.000000, 0.000000, 0.000000, 0.00…
## $ tfidf_title_express                 <dbl> 0.000000, 2.165367, 0.000000, 0.00…
## $ tfidf_title_eyes                    <dbl> 0.000000, 0.000000, 0.000000, 0.00…
## $ tfidf_title_face                    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_familiar                <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_fantasies               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_first                   <dbl> 0.000000, 0.000000, 0.000000, 0.00…
## $ tfidf_title_foot                    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_forbidden               <dbl> 0.000000, 0.000000, 0.000000, 0.00…
## $ tfidf_title_game                    <dbl> 0.000000, 0.000000, 0.000000, 0.00…
## $ tfidf_title_ghost                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_ghosting                <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_girl                    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_glass                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_gojusan                 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_gotas                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_heart                   <dbl> 0.000000, 0.000000, 0.000000, 0.00…
## $ tfidf_title_holocaust               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_hontou                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_howling                 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_in                      <dbl> 0.0000000, 0.0000000, 0.0000000, 0…
## $ tfidf_title_is                      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_joaquim                 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_kalong                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_kill                    <dbl> 0.000000, 0.000000, 0.000000, 0.00…
## $ tfidf_title_killer                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_kore                    <dbl> 0.000000, 0.000000, 0.000000, 0.00…
## $ tfidf_title_la                      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_lost                    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_love                    <dbl> 0.000000, 0.000000, 0.000000, 0.00…
## $ tfidf_title_lurking                 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_mala                    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_man                     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_mari                    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_market                  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_mask                    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_memo                    <dbl> 0.000000, 0.000000, 0.000000, 0.00…
## $ tfidf_title_monk                    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_monster                 <dbl> 0.000000, 0.000000, 0.000000, 0.00…
## $ tfidf_title_moon                    <dbl> 0.000000, 0.000000, 0.000000, 0.00…
## $ tfidf_title_movie                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_new                     <dbl> 0.000000, 0.000000, 0.000000, 0.00…
## $ tfidf_title_of                      <dbl> 0.0000000, 0.0000000, 0.5848498, 0…
## $ tfidf_title_sex                     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ tfidf_title_the                     <dbl> 0.0000000, 0.6785120, 0.3392560, 0…
## $ tfidf_title_vampire                 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ original_language_en                <dbl> 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1…
## $ original_language_ja                <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ original_language_other             <dbl> 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0…
## $ genre_names_Comedy..Horror          <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1…
## $ genre_names_Drama..Horror..Thriller <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ genre_names_Horror                  <dbl> 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0…
## $ genre_names_Horror..Thriller        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0…
## $ genre_names_other                   <dbl> 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0…
# Boosted-tree regression spec with four tuned hyperparameters; every other
# boost_tree() argument keeps its default.
xgboost_spec <- boost_tree(
  mtry = tune(),
  trees = tune(),
  min_n = tune(),
  learn_rate = tune()
) %>%
  set_mode("regression") %>%
  set_engine("xgboost")

# Bundle model and preprocessing so they are tuned and fit together.
xgboost_workflow <- workflow() %>%
  add_model(xgboost_spec) %>%
  add_recipe(xgboost_recipe)

# Grid search: an integer grid asks tune_grid() for 5 candidate parameter
# sets, each evaluated on every fold of data_cv.
# NOTE(review): 5 candidates for 4 tuned parameters is a coarse search.
set.seed(4796)
xgboost_tune <- xgboost_workflow %>%
  tune_grid(resamples = data_cv, grid = 5)
## i Creating pre-processing data to finalize unknown parameter: mtry
## Warning: package 'xgboost' was built under R version 4.3.3
## → A | warning: A correlation computation is required, but `estimate` is constant and has 0
##                standard deviation, resulting in a divide by 0 error. `NA` will be returned.
## 
There were issues with some computations   A: x1

There were issues with some computations   A: x2

There were issues with some computations   A: x3

There were issues with some computations   A: x4

There were issues with some computations   A: x5

There were issues with some computations   A: x6

There were issues with some computations   A: x7

There were issues with some computations   A: x8

There were issues with some computations   A: x9

There were issues with some computations   A: x10

There were issues with some computations   A: x10

Evaluate Models

# Top candidates ranked by RMSE, the same metric select_best() uses to pick
# the final hyperparameters.
tune::show_best(xgboost_tune, metric = "rmse")
## # A tibble: 4 × 10
##    mtry trees min_n learn_rate .metric .estimator  mean     n std_err .config   
##   <int> <int> <int>      <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>     
## 1    91  1985    22    0.0183  rsq     standard   0.672    10  0.0722 Preproces…
## 2    50   527     2    0.00349 rsq     standard   0.663    10  0.0736 Preproces…
## 3    32  1073    15    0.0434  rsq     standard   0.611    10  0.0716 Preproces…
## 4    88   337    31    0.00111 rsq     standard   0.152    10  0.0512 Preproces…
# Replace the tune() placeholders with the candidate that has the lowest
# cross-validated RMSE.
xgboost_fw <- xgboost_tune %>%
    tune::select_best(metric = "rmse") %>%
    tune::finalize_workflow(x = xgboost_workflow, parameters = .)

# Fit the finalized workflow once on the whole training set, then score it
# on the held-out test set.
data_fit <- xgboost_fw %>%
    tune::last_fit(split = data_split)

data_fit %>%
    tune::collect_metrics()
## # A tibble: 2 × 4
##   .metric .estimator .estimate .config             
##   <chr>   <chr>          <dbl> <chr>               
## 1 rmse    standard       1.42  Preprocessor1_Model1
## 2 rsq     standard       0.728 Preprocessor1_Model1
# Observed vs. predicted ratings on the test set; the dashed line is the
# identity line, so points on it are perfect predictions.
tune::collect_predictions(data_fit) %>%
    ggplot(aes(vote_average, .pred)) +
    # The default point shape is drawn with `color`; `fill` has no visible
    # effect on it, so the colour must be set via `color`.
    geom_point(alpha = 0.3, color = "midnightblue") +
    geom_abline(lty = 2, color = "gray50") +
    coord_fixed()