Goal: Predict the success of horror movies, using budget as the outcome variable

Import data

# Download the TidyTuesday (2022-11-01) horror movies data set as a tibble.
horror_movies <- readr::read_csv(
  file = "https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2022/2022-11-01/horror_movies.csv"
)
## Rows: 32540 Columns: 20
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr  (10): original_title, title, original_language, overview, tagline, post...
## dbl   (8): id, popularity, vote_count, vote_average, budget, revenue, runtim...
## lgl   (1): adult
## date  (1): release_date
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Per-column overview: types, missingness and distribution summaries.
horror_movies %>% skimr::skim()
Data summary
Name horror_movies
Number of rows 32540
Number of columns 20
_______________________
Column type frequency:
character 10
Date 1
logical 1
numeric 8
________________________
Group variables None

Variable type: character

skim_variable n_missing complete_rate min max empty n_unique whitespace
original_title 0 1.00 1 191 0 30296 0
title 0 1.00 1 191 0 29563 0
original_language 0 1.00 2 2 0 97 0
overview 1286 0.96 1 1000 0 31020 0
tagline 19835 0.39 1 237 0 12513 0
poster_path 4474 0.86 30 32 0 28048 0
status 0 1.00 7 15 0 4 0
backdrop_path 18995 0.42 29 32 0 13536 0
genre_names 0 1.00 6 144 0 772 0
collection_name 30234 0.07 4 56 0 815 0

Variable type: Date

skim_variable n_missing complete_rate min max median n_unique
release_date 0 1 1950-01-01 2022-12-31 2012-12-09 10999

Variable type: logical

skim_variable n_missing complete_rate mean count
adult 0 1 0 FAL: 32540

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
id 0 1.00 445910.83 305744.67 17 146494.8 426521.00 707534.00 1033095.00 ▇▆▆▅▅
popularity 0 1.00 4.01 37.51 0 0.6 0.84 2.24 5088.58 ▇▁▁▁▁
vote_count 0 1.00 62.69 420.89 0 0.0 2.00 11.00 16900.00 ▇▁▁▁▁
vote_average 0 1.00 3.34 2.88 0 0.0 4.00 5.70 10.00 ▇▂▆▃▁
budget 0 1.00 543126.59 4542667.81 0 0.0 0.00 0.00 200000000.00 ▇▁▁▁▁
revenue 0 1.00 1349746.73 14430479.15 0 0.0 0.00 0.00 701842551.00 ▇▁▁▁▁
runtime 0 1.00 62.14 41.00 0 14.0 80.00 91.00 683.00 ▇▁▁▁▁
collection 30234 0.07 481534.88 324498.16 656 155421.0 471259.00 759067.25 1033032.00 ▇▅▅▅▅
# Build the modelling data set:
#   1. drop the columns with many missing values, together with the
#      text/date columns that are not used in the analysis below,
#   2. drop any rows that still contain missing values,
#   3. compress the right-skewed popularity score with a log transform.
data <- horror_movies %>%

    # Treat missing values
    select(
        -tagline, -backdrop_path, -collection_name, -collection,
        -release_date, -original_title, -overview, -poster_path
    ) %>%
    na.omit() %>%

    # Log transform variables with positively skewed distributions.
    # popularity has a minimum of 0, so plain log() would produce -Inf;
    # log1p() computes log(1 + x) and stays finite at zero.
    mutate(popularity = log1p(popularity))

Explore data

Identify good predictors

Popularity

# Scatter plot of the log-transformed popularity against budget,
# with budget drawn on a log10 x-axis.
ggplot(data, aes(x = budget, y = popularity)) +
  geom_point() +
  scale_x_log10()
## Warning in scale_x_log10(): log-10 transformation introduced infinite values.

Revenue

# Scatter plot of revenue against budget (both on their raw scales).
ggplot(data, aes(x = budget, y = revenue)) +
  geom_point()

Runtime

# Scatter plot of runtime against budget (both on their raw scales).
ggplot(data, aes(x = budget, y = runtime)) +
  geom_point()

Title

# Which title words are associated with the largest budgets?
data %>%

    # Tokenize title: one row per (movie, word) pair
    unnest_tokens(output = word, input = title) %>%

    # Calculate the average budget and the number of occurrences per word
    group_by(word) %>%
    summarise(budget = mean(budget),
              n = n()) %>%
    ungroup() %>%

    # Keep words that occur more than 5 times and contain no digits,
    # then take the 25 words with the highest average budget
    filter(n > 5, !str_detect(word, "\\d")) %>%
    slice_max(order_by = budget, n = 25) %>%

    # Plot, ordering the words by their average budget
    ggplot(aes(budget, fct_reorder(word, budget))) +
    geom_point() +

    # The x-axis shows a per-word mean, not a single movie's budget
    labs(x = "Average Budget", y = "Words in Title")

EDA Shortcut

# Step 1: Prepare data
# Drop the identifier and the free-text title, then let binarize() turn the
# remaining columns into 0/1 indicator columns.
data_binarized_tbl <- binarize(select(data, -c(id, title)))

glimpse(data_binarized_tbl)
## Rows: 32,540
## Columns: 44
## $ original_language__cn                               <dbl> 0, 0, 0, 0, 0, 0, …
## $ original_language__de                               <dbl> 0, 0, 0, 0, 0, 0, …
## $ original_language__en                               <dbl> 1, 1, 1, 1, 0, 0, …
## $ original_language__es                               <dbl> 0, 0, 0, 0, 1, 1, …
## $ original_language__fr                               <dbl> 0, 0, 0, 0, 0, 0, …
## $ original_language__id                               <dbl> 0, 0, 0, 0, 0, 0, …
## $ original_language__it                               <dbl> 0, 0, 0, 0, 0, 0, …
## $ original_language__ja                               <dbl> 0, 0, 0, 0, 0, 0, …
## $ original_language__ko                               <dbl> 0, 0, 0, 0, 0, 0, …
## $ original_language__pt                               <dbl> 0, 0, 0, 0, 0, 0, …
## $ original_language__th                               <dbl> 0, 0, 0, 0, 0, 0, …
## $ original_language__zh                               <dbl> 0, 0, 0, 0, 0, 0, …
## $ `original_language__-OTHER`                         <dbl> 0, 0, 0, 0, 0, 0, …
## $ `popularity__-Inf_-0.510825623765991`               <dbl> 0, 0, 0, 0, 0, 0, …
## $ `popularity__-0.510825623765991_-0.174353387144778` <dbl> 0, 0, 0, 0, 0, 0, …
## $ `popularity__-0.174353387144778_0.807925688568666`  <dbl> 0, 0, 0, 0, 0, 0, …
## $ popularity__0.807925688568666_Inf                   <dbl> 1, 1, 1, 1, 1, 1, …
## $ `vote_count__-Inf_2`                                <dbl> 0, 0, 0, 0, 0, 1, …
## $ vote_count__2_11                                    <dbl> 0, 0, 0, 0, 0, 0, …
## $ vote_count__11_Inf                                  <dbl> 1, 1, 1, 1, 1, 0, …
## $ `vote_average__-Inf_4`                              <dbl> 0, 0, 0, 0, 0, 1, …
## $ vote_average__4_5.7                                 <dbl> 0, 0, 0, 0, 0, 0, …
## $ vote_average__5.7_Inf                               <dbl> 1, 1, 1, 1, 1, 0, …
## $ budget__0                                           <dbl> 1, 1, 0, 0, 1, 1, …
## $ `budget__-OTHER`                                    <dbl> 0, 0, 1, 1, 0, 0, …
## $ revenue__0                                          <dbl> 0, 0, 0, 0, 1, 1, …
## $ `revenue__-OTHER`                                   <dbl> 1, 1, 1, 1, 0, 0, …
## $ `runtime__-Inf_14`                                  <dbl> 0, 0, 0, 0, 1, 1, …
## $ runtime__14_80                                      <dbl> 0, 0, 0, 0, 0, 0, …
## $ runtime__80_91                                      <dbl> 0, 0, 0, 0, 0, 0, …
## $ runtime__91_Inf                                     <dbl> 1, 1, 1, 1, 0, 0, …
## $ status__Released                                    <dbl> 1, 1, 1, 1, 1, 1, …
## $ `status__-OTHER`                                    <dbl> 0, 0, 0, 0, 0, 0, …
## $ `genre_names__Animation,_Horror`                    <dbl> 0, 0, 0, 0, 0, 0, …
## $ `genre_names__Comedy,_Horror`                       <dbl> 0, 0, 0, 0, 0, 0, …
## $ `genre_names__Drama,_Horror`                        <dbl> 0, 0, 0, 0, 0, 0, …
## $ `genre_names__Drama,_Horror,_Thriller`              <dbl> 0, 0, 0, 0, 0, 0, …
## $ `genre_names__Fantasy,_Horror`                      <dbl> 0, 0, 0, 0, 0, 0, …
## $ genre_names__Horror                                 <dbl> 0, 0, 0, 0, 1, 0, …
## $ `genre_names__Horror,_Mystery`                      <dbl> 0, 0, 0, 0, 0, 0, …
## $ `genre_names__Horror,_Mystery,_Thriller`            <dbl> 0, 0, 1, 0, 0, 0, …
## $ `genre_names__Horror,_Science_Fiction`              <dbl> 0, 0, 0, 0, 0, 0, …
## $ `genre_names__Horror,_Thriller`                     <dbl> 1, 0, 0, 1, 0, 1, …
## $ `genre_names__-OTHER`                               <dbl> 0, 1, 0, 0, 0, 0, …
# Step 2: Correlate
# Correlate every binary column with the "budget is non-zero" indicator.
data_corr_tbl <- correlate(data_binarized_tbl, "budget__-OTHER")

data_corr_tbl
## # A tibble: 44 × 3
##    feature           bin                     correlation
##    <fct>             <chr>                         <dbl>
##  1 budget            0                           -1     
##  2 budget            -OTHER                       1     
##  3 revenue           0                           -0.331 
##  4 revenue           -OTHER                       0.331 
##  5 original_language en                           0.163 
##  6 vote_count        11_Inf                       0.161 
##  7 popularity        0.807925688568666_Inf        0.160 
##  8 vote_count        -Inf_2                      -0.0898
##  9 popularity        -Inf_-0.510825623765991     -0.0852
## 10 original_language ja                          -0.0774
## # ℹ 34 more rows
# Step 3: Plot
# Draw the correlation funnel for the table computed in step 2.
plot_correlation_funnel(data_corr_tbl)
## Warning: ggrepel: 23 unlabeled data points (too many overlaps). Consider
## increasing max.overlaps

Build models

Split Data

# Work on a random subset of 100 movies (presumably to keep tuning fast).
# Seed the RNG *before* sampling so the same rows are drawn on every run;
# otherwise the splits and every result below change between runs.
set.seed(1234)
data <- sample_n(data, 100)

# Split into train and test data-set
set.seed(1234)
data_split <- rsample::initial_split(data)
data_train <- training(data_split)
data_test  <- testing(data_split)

# Further split training data-set for cross-validation
set.seed(4321)
data_cv <- rsample::vfold_cv(data_train)
data_cv
## #  10-fold cross-validation 
## # A tibble: 10 × 2
##    splits         id    
##    <list>         <chr> 
##  1 <split [67/8]> Fold01
##  2 <split [67/8]> Fold02
##  3 <split [67/8]> Fold03
##  4 <split [67/8]> Fold04
##  5 <split [67/8]> Fold05
##  6 <split [68/7]> Fold06
##  7 <split [68/7]> Fold07
##  8 <split [68/7]> Fold08
##  9 <split [68/7]> Fold09
## 10 <split [68/7]> Fold10
# Attach usemodels and print template tidymodels code for an xgboost model
# of budget on all other columns of the training set.
library(usemodels)
use_xgboost(formula = budget ~ ., data = data_train)
## xgboost_recipe <- 
##   recipe(formula = budget ~ ., data = data_train) %>% 
##   step_zv(all_predictors()) 
## 
## xgboost_spec <- 
##   boost_tree(trees = tune(), min_n = tune(), tree_depth = tune(), learn_rate = tune(), 
##     loss_reduction = tune(), sample_size = tune()) %>% 
##   set_mode("classification") %>% 
##   set_engine("xgboost") 
## 
## xgboost_workflow <- 
##   workflow() %>% 
##   add_recipe(xgboost_recipe) %>% 
##   add_model(xgboost_spec) 
## 
## set.seed(4796)
## xgboost_tune <-
##   tune_grid(xgboost_workflow, resamples = stop("add your rsample object"), grid = stop("add number of candidate points"))
# Preprocessing recipe for the xgboost model of budget on all other columns.
xgboost_recipe <-
  recipe(formula = budget ~ ., data = data_train) %>%
    # Keep id in the data but exclude it from the predictors
    update_role(id, new_role = "id variable") %>%
    # Turn the title into tf-idf features for the 100 most frequent tokens
    step_tokenize(title) %>%
    step_tokenfilter(title, max_tokens = 100) %>%
    step_tfidf(title) %>%
    # Pool infrequent categories, then one-hot encode the categoricals
    step_other(original_language, genre_names) %>%
    step_dummy(original_language, genre_names, one_hot = TRUE) %>%
    # Reduce skewness of the numeric columns (including the outcome, budget,
    # so predictions and metrics are on the transformed scale)
    step_YeoJohnson(popularity, vote_count, budget, revenue, runtime) %>%
    # Remove zero-variance predictors last, so that constant tf-idf and dummy
    # columns created by the steps above are dropped as well
    step_zv(all_predictors())

# Inspect the preprocessed training data
xgboost_recipe %>% prep() %>% juice() %>% glimpse()
## Rows: 75
## Columns: 115
## $ id                           <dbl> 684433, 471920, 518695, 741187, 969412, 4…
## $ popularity                   <dbl> -0.7699431, -0.7699431, -0.7699431, -0.76…
## $ vote_count                   <dbl> 0.0000000, 0.0000000, 0.6111858, 0.902255…
## $ vote_average                 <dbl> 0.0, 0.0, 6.0, 6.5, 0.0, 3.9, 10.0, 5.0, …
## $ revenue                      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ runtime                      <dbl> 36.637345, 47.797968, 3.520282, 47.797968…
## $ budget                       <dbl> 0.000000, 1.575996, 0.000000, 0.000000, 0…
## $ tfidf_title_2                <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_4                <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_a                <dbl> 0.0000000, 0.0000000, 0.0000000, 0.000000…
## $ tfidf_title_aber             <dbl> 2.165367, 0.000000, 0.000000, 0.000000, 0…
## $ tfidf_title_above            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_absent           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_addiction        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_aftershocks      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_alive            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_all              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_apocalypse       <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0…
## $ tfidf_title_at               <dbl> 0.0000000, 0.0000000, 0.7217889, 0.000000…
## $ tfidf_title_avacalho         <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0…
## $ tfidf_title_bad              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_badd             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_ballad           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_ban              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_bath             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_beast            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_been             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_bender           <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0…
## $ tfidf_title_bigfoot          <dbl> 0.0000000, 0.0000000, 0.0000000, 0.721788…
## $ tfidf_title_billy            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_bitch            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_bites            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_blackheath       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_blinder          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_blood            <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0…
## $ tfidf_title_blutschrei       <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 2…
## $ tfidf_title_broadcast        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_by               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_call             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_camp             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_carol            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_caronte          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_chapter          <dbl> 0.0000000, 0.0000000, 0.0000000, 0.000000…
## $ tfidf_title_circus           <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0…
## $ tfidf_title_cross            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_curtição         <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0…
## $ tfidf_title_da               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_de               <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0…
## $ tfidf_title_death            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_der              <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 2…
## $ tfidf_title_devil            <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0…
## $ tfidf_title_digging          <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0…
## $ tfidf_title_dilim            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_discarnate       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_disputes         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_do               <dbl> 0.0000000, 0.0000000, 0.0000000, 0.000000…
## $ `tfidf_title_don't`          <dbl> 0.0000000, 0.0000000, 0.7217889, 0.000000…
## $ tfidf_title_dr               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_du               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_earth            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_exorcists        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_experiment       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_faeries          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_fatal            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_fender           <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0…
## $ tfidf_title_file             <dbl> 0.0000000, 0.0000000, 0.0000000, 0.000000…
## $ tfidf_title_final            <dbl> 0.0000000, 0.0000000, 0.0000000, 0.000000…
## $ tfidf_title_flavia           <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0…
## $ tfidf_title_for              <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0…
## $ tfidf_title_francisville     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_frank            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_frightmare       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_ghost            <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0…
## $ tfidf_title_glücklich        <dbl> 2.165367, 0.000000, 0.000000, 0.000000, 0…
## $ tfidf_title_go               <dbl> 0.0000000, 0.0000000, 0.7217889, 0.000000…
## $ tfidf_title_grief            <dbl> 0.000000, 2.165367, 0.000000, 0.000000, 0…
## $ tfidf_title_halloween        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_hell             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_heretic          <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0…
## $ tfidf_title_high             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_himmel           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_horrible         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_house            <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0…
## $ tfidf_title_i                <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ `tfidf_title_i've`           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_iemon            <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0…
## $ tfidf_title_into             <dbl> 0.0000000, 0.0000000, 0.7217889, 0.000000…
## $ tfidf_title_irina            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_is               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_isle             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_it               <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0…
## $ tfidf_title_jaakkirathai     <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0…
## $ tfidf_title_kaiki            <dbl> 0.0000000, 0.0000000, 0.0000000, 0.000000…
## $ tfidf_title_kills            <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0…
## $ tfidf_title_kolobos          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_konak            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_kowasugi         <dbl> 0.0000000, 0.0000000, 0.0000000, 0.000000…
## $ tfidf_title_la               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_larmes           <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0…
## $ tfidf_title_last             <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0…
## $ tfidf_title_legacy           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_legend           <dbl> 0.000000, 0.000000, 0.000000, 0.608443, 0…
## $ tfidf_title_lilith           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ tfidf_title_night            <dbl> 0.000000, 0.000000, 0.608443, 0.000000, 0…
## $ tfidf_title_of               <dbl> 0.0000000, 0.0000000, 0.0000000, 0.389899…
## $ tfidf_title_on               <dbl> 0.000000, 0.000000, 0.000000, 0.608443, 0…
## $ tfidf_title_the              <dbl> 0.0000000, 0.7790723, 0.2596908, 0.519381…
## $ original_language_en         <dbl> 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0,…
## $ original_language_es         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ original_language_ja         <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,…
## $ original_language_other      <dbl> 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1,…
## $ genre_names_Comedy..Horror   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,…
## $ genre_names_Horror           <dbl> 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0,…
## $ genre_names_Horror..Thriller <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,…
## $ genre_names_other            <dbl> 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1,…
# Boosted-tree regression model on the xgboost engine; the number of trees,
# minimum node size, number of sampled predictors and learning rate are
# left as tuning parameters.
xgboost_spec <-
  boost_tree(
    mtry = tune(),
    trees = tune(),
    min_n = tune(),
    learn_rate = tune()
  ) %>%
  set_engine("xgboost") %>%
  set_mode("regression")

# Bundle the preprocessing recipe and the model specification.
xgboost_workflow <-
  workflow() %>%
  add_model(xgboost_spec) %>%
  add_recipe(xgboost_recipe)

# Evaluate 5 candidate parameter sets on the cross-validation folds.
set.seed(4796)
xgboost_tune <- xgboost_workflow %>%
  tune_grid(resamples = data_cv, grid = 5)
## i Creating pre-processing data to finalize unknown parameter: mtry
## Warning: package 'xgboost' was built under R version 4.3.3
## → A | warning: A correlation computation is required, but `estimate` is constant and has 0
##                standard deviation, resulting in a divide by 0 error. `NA` will be returned.
## 
There were issues with some computations   A: x1

There were issues with some computations   A: x2

There were issues with some computations   A: x3

There were issues with some computations   A: x4

There were issues with some computations   A: x5

There were issues with some computations   A: x6

                                                 
→ B | warning: A correlation computation is required, but `truth` is constant and has 0
##                standard deviation, resulting in a divide by 0 error. `NA` will be returned.
## There were issues with some computations   A: x6

There were issues with some computations   A: x6   B: x1

There were issues with some computations   A: x7   B: x1

There were issues with some computations   A: x8   B: x1

There were issues with some computations   A: x9   B: x1

There were issues with some computations   A: x9   B: x1

Evaluate Models

# Rank the candidates by RMSE first, because that is the metric used below to
# select the final hyperparameters; then show the R-squared ranking for
# reference (R-squared is NA for resamples with constant predictions/truth).
tune::show_best(xgboost_tune, metric = "rmse")
tune::show_best(xgboost_tune, metric = "rsq")
## # A tibble: 4 × 10
##    mtry trees min_n learn_rate .metric .estimator  mean     n std_err .config   
##   <int> <int> <int>      <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>     
## 1    50   527     2    0.00349 rsq     standard   0.374     9  0.131  Preproces…
## 2    32  1073    15    0.0434  rsq     standard   0.260     9  0.0929 Preproces…
## 3    91  1985    22    0.0183  rsq     standard   0.162     9  0.0398 Preproces…
## 4    88   337    31    0.00111 rsq     standard   0.122     9  0.0323 Preproces…
# Pick the candidate with the lowest cross-validated RMSE and plug its
# hyperparameters into the workflow
xgboost_best_params <- tune::select_best(xgboost_tune, metric = "rmse")
xgboost_fw <- xgboost_workflow %>%
  tune::finalize_workflow(xgboost_best_params)

# Refit on the full training set, evaluate once on the held-out test set,
# and report the test metrics
data_fit <- xgboost_fw %>% tune::last_fit(data_split)
data_fit %>% tune::collect_metrics()
## # A tibble: 2 × 4
##   .metric .estimator .estimate .config             
##   <chr>   <chr>          <dbl> <chr>               
## 1 rmse    standard       0.453 Preprocessor1_Model1
## 2 rsq     standard       0.158 Preprocessor1_Model1
# Predicted vs. observed budget on the test set. Both are on the
# Yeo-Johnson-transformed scale applied to budget in the recipe.
tune::collect_predictions(data_fit) %>%
    ggplot(aes(budget, .pred)) +
    # The default point shape has no fill, so the colour must be set via
    # `color`; `fill` would be silently ignored
    geom_point(alpha = 0.3, color = "midnightblue") +
    # Dashed identity line: points on it are perfect predictions
    geom_abline(lty = 2, color = "gray50") +
    # Same unit length on both axes so the identity line sits at 45 degrees
    coord_fixed()