Goal: predict the average rating of horror movies. The data come from the TidyTuesday 2022-11-01 horror movies dataset.
# Load the TidyTuesday (2022-11-01) horror movies data straight from GitHub.
horror_movies <- read_csv(
  file = "https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2022/2022-11-01/horror_movies.csv"
)
## Rows: 32540 Columns: 20
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (10): original_title, title, original_language, overview, tagline, post...
## dbl (8): id, popularity, vote_count, vote_average, budget, revenue, runtim...
## lgl (1): adult
## date (1): release_date
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Overview of every column: type, missingness and basic distribution summary.
skim(horror_movies)
Name | horror_movies |
Number of rows | 32540 |
Number of columns | 20 |
_______________________ | |
Column type frequency: | |
character | 10 |
Date | 1 |
logical | 1 |
numeric | 8 |
________________________ | |
Group variables | None |
Variable type: character
skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
---|---|---|---|---|---|---|---|
original_title | 0 | 1.00 | 1 | 191 | 0 | 30296 | 0 |
title | 0 | 1.00 | 1 | 191 | 0 | 29563 | 0 |
original_language | 0 | 1.00 | 2 | 2 | 0 | 97 | 0 |
overview | 1286 | 0.96 | 1 | 1000 | 0 | 31020 | 0 |
tagline | 19835 | 0.39 | 1 | 237 | 0 | 12513 | 0 |
poster_path | 4474 | 0.86 | 30 | 32 | 0 | 28048 | 0 |
status | 0 | 1.00 | 7 | 15 | 0 | 4 | 0 |
backdrop_path | 18995 | 0.42 | 29 | 32 | 0 | 13536 | 0 |
genre_names | 0 | 1.00 | 6 | 144 | 0 | 772 | 0 |
collection_name | 30234 | 0.07 | 4 | 56 | 0 | 815 | 0 |
Variable type: Date
skim_variable | n_missing | complete_rate | min | max | median | n_unique |
---|---|---|---|---|---|---|
release_date | 0 | 1 | 1950-01-01 | 2022-12-31 | 2012-12-09 | 10999 |
Variable type: logical
skim_variable | n_missing | complete_rate | mean | count |
---|---|---|---|---|
adult | 0 | 1 | 0 | FAL: 32540 |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|---|---|
id | 0 | 1.00 | 445910.83 | 305744.67 | 17 | 146494.8 | 426521.00 | 707534.00 | 1033095.00 | ▇▆▆▅▅ |
popularity | 0 | 1.00 | 4.01 | 37.51 | 0 | 0.6 | 0.84 | 2.24 | 5088.58 | ▇▁▁▁▁ |
vote_count | 0 | 1.00 | 62.69 | 420.89 | 0 | 0.0 | 2.00 | 11.00 | 16900.00 | ▇▁▁▁▁ |
vote_average | 0 | 1.00 | 3.34 | 2.88 | 0 | 0.0 | 4.00 | 5.70 | 10.00 | ▇▂▆▃▁ |
budget | 0 | 1.00 | 543126.59 | 4542667.81 | 0 | 0.0 | 0.00 | 0.00 | 200000000.00 | ▇▁▁▁▁ |
revenue | 0 | 1.00 | 1349746.73 | 14430479.15 | 0 | 0.0 | 0.00 | 0.00 | 701842551.00 | ▇▁▁▁▁ |
runtime | 0 | 1.00 | 62.14 | 41.00 | 0 | 14.0 | 80.00 | 91.00 | 683.00 | ▇▁▁▁▁ |
collection | 30234 | 0.07 | 481534.88 | 324498.16 | 656 | 155421.0 | 471259.00 | 759067.25 | 1033032.00 | ▇▅▅▅▅ |
# Build the modelling table: one row per (movie, genre) for released movies
# that have an overview and at least one vote.
data <- horror_movies %>%
  # Keep only released movies with a description and at least one vote.
  filter(status == "Released", !is.na(overview), vote_count != 0) %>%
  # Work with the rating on the log1p scale.
  mutate(vote_average = log1p(vote_average)) %>%
  # genre_names is a comma-separated list; expand it to one row per genre.
  separate_rows(genre_names, sep = ", ") %>%
  select(id, vote_average, genre_names, overview, runtime, budget) %>%
  na.omit()
Identify good predictors.
budget
# Scatter plot of budget against the (log1p) average rating.
ggplot(data, aes(x = vote_average, y = budget)) +
  geom_point()
runtime
# Scatter plot of runtime against the (log1p) average rating.
ggplot(data, aes(x = vote_average, y = runtime)) +
  geom_point()
overview
# Which words in the overview go along with the highest average ratings?
data %>%
  # `data` holds one row per (movie, genre); keep one row per movie so a
  # multi-genre movie does not contribute its overview several times.
  distinct(id, vote_average, overview) %>%
  # Tokenize the overview into one word per row.
  unnest_tokens(output = word, input = overview) %>%
  # Mean rating (log1p scale) and number of occurrences per word.
  group_by(word) %>%
  summarise(
    vote_average = mean(vote_average),
    n = n()
  ) %>%
  ungroup() %>%
  # Drop rare words and tokens that contain digits.
  filter(n > 10, !str_detect(word, "\\d")) %>%
  slice_max(order_by = vote_average, n = 20) %>%
  # Plot the top-rated words, ordered by their mean rating.
  ggplot(aes(vote_average, fct_reorder(word, vote_average))) +
  geom_point() +
  labs(y = "Words in Overview")
EDA shortcut
# Step 1: Prepare data
# Drop the free-text and identifier columns, then turn what remains into
# binary (binned / one-hot) indicator columns.
data_binarized_tbl <- binarize(select(data, -c(overview, id)))

glimpse(data_binarized_tbl)
## Rows: 43,444
## Columns: 23
## $ `vote_average__-Inf_1.66770682055808` <dbl> 0, 0, 0, 0, 0, 0, 0, 0…
## $ vote_average__1.66770682055808_1.84054963339749 <dbl> 0, 0, 0, 0, 0, 0, 0, 0…
## $ vote_average__1.84054963339749_1.97408102602201 <dbl> 0, 0, 0, 0, 0, 0, 0, 0…
## $ vote_average__1.97408102602201_Inf <dbl> 1, 1, 1, 1, 1, 1, 1, 1…
## $ genre_names__Action <dbl> 0, 0, 0, 0, 0, 0, 0, 0…
## $ genre_names__Adventure <dbl> 0, 0, 1, 0, 0, 0, 0, 0…
## $ genre_names__Animation <dbl> 0, 0, 0, 0, 0, 0, 0, 0…
## $ genre_names__Comedy <dbl> 0, 0, 0, 0, 0, 0, 0, 0…
## $ genre_names__Crime <dbl> 0, 0, 0, 0, 0, 0, 0, 0…
## $ genre_names__Drama <dbl> 0, 0, 0, 1, 0, 0, 0, 0…
## $ genre_names__Fantasy <dbl> 0, 0, 0, 0, 0, 0, 0, 0…
## $ genre_names__Horror <dbl> 1, 0, 0, 0, 1, 1, 0, 0…
## $ genre_names__Mystery <dbl> 0, 0, 0, 0, 0, 0, 1, 0…
## $ genre_names__Science_Fiction <dbl> 0, 0, 0, 0, 0, 0, 0, 0…
## $ genre_names__Thriller <dbl> 0, 1, 0, 0, 0, 0, 0, 1…
## $ genre_names__TV_Movie <dbl> 0, 0, 0, 0, 0, 0, 0, 0…
## $ `genre_names__-OTHER` <dbl> 0, 0, 0, 0, 0, 0, 0, 0…
## $ `runtime__-Inf_75` <dbl> 0, 0, 0, 0, 0, 0, 0, 0…
## $ runtime__75_87 <dbl> 0, 0, 0, 0, 0, 0, 0, 0…
## $ runtime__87_95 <dbl> 0, 0, 1, 1, 1, 0, 0, 0…
## $ runtime__95_Inf <dbl> 1, 1, 0, 0, 0, 1, 1, 1…
## $ budget__0 <dbl> 1, 1, 1, 1, 1, 0, 0, 0…
## $ `budget__-OTHER` <dbl> 0, 0, 0, 0, 0, 1, 1, 1…
# Step 2: Correlate
# Correlate every binarized feature with the highest rating bin.
data_corr_tbl <- correlate(
  data_binarized_tbl,
  target = vote_average__1.97408102602201_Inf
)

data_corr_tbl
## # A tibble: 23 × 3
## feature bin correlation
## <fct> <chr> <dbl>
## 1 vote_average 1.97408102602201_Inf 1
## 2 vote_average -Inf_1.66770682055808 -0.343
## 3 vote_average 1.84054963339749_1.97408102602201 -0.328
## 4 vote_average 1.66770682055808_1.84054963339749 -0.326
## 5 runtime -Inf_75 0.192
## 6 runtime 87_95 -0.128
## 7 runtime 75_87 -0.125
## 8 genre_names Animation 0.0769
## 9 runtime 95_Inf 0.0593
## 10 budget 0 -0.0497
## # ℹ 13 more rows
# Step 3: Plot
# Visualise the feature/target correlations as a funnel chart.
plot_correlation_funnel(data_corr_tbl)
## Warning: ggrepel: 13 unlabeled data points (too many overlaps). Consider
## increasing max.overlaps
Split Data
#data <- sample_n(data, 100)

# Split into train and test datasets.
# `data` has one row per (movie, genre), so a plain random split would place
# rows of the same movie (same overview and rating) on both sides and leak
# information into the test set. Splitting by `id` keeps every movie entirely
# in one partition.
set.seed(1234)
data_split <- rsample::group_initial_split(data, group = id)
data_train <- training(data_split)
data_test <- testing(data_split)

# Further split the training data for cross-validation, again grouped by
# movie. `v` must be given explicitly: group_vfold_cv() otherwise creates one
# fold per group; 10 matches the fold count of vfold_cv().
set.seed(4321)
data_cv <- rsample::group_vfold_cv(data_train, group = id, v = 10)
data_cv
## # 10-fold cross-validation
## # A tibble: 10 × 2
## splits id
## <list> <chr>
## 1 <split [29324/3259]> Fold01
## 2 <split [29324/3259]> Fold02
## 3 <split [29324/3259]> Fold03
## 4 <split [29325/3258]> Fold04
## 5 <split [29325/3258]> Fold05
## 6 <split [29325/3258]> Fold06
## 7 <split [29325/3258]> Fold07
## 8 <split [29325/3258]> Fold08
## 9 <split [29325/3258]> Fold09
## 10 <split [29325/3258]> Fold10
library(usemodels)
## Warning: package 'usemodels' was built under R version 4.3.3
# Print starter code (recipe, model spec, workflow, tuning call) for xgboost.
usemodels::use_xgboost(formula = vote_average ~ ., data = data_train)
## xgboost_recipe <-
## recipe(formula = vote_average ~ ., data = data_train) %>%
## step_zv(all_predictors())
##
## xgboost_spec <-
## boost_tree(trees = tune(), min_n = tune(), tree_depth = tune(), learn_rate = tune(),
## loss_reduction = tune(), sample_size = tune()) %>%
## set_mode("classification") %>%
## set_engine("xgboost")
##
## xgboost_workflow <-
## workflow() %>%
## add_recipe(xgboost_recipe) %>%
## add_model(xgboost_spec)
##
## set.seed(48629)
## xgboost_tune <-
## tune_grid(xgboost_workflow, resamples = stop("add your rsample object"), grid = stop("add number of candidate points"))
# Preprocessing recipe: rating is the outcome, everything else a predictor.
xgboost_recipe <- recipe(vote_average ~ ., data = data_train) %>%
  # Keep id in the data but give it a non-predictor role.
  recipes::update_role(id, new_role = "id") %>%
  # Overview text -> tokens -> at most 100 retained tokens -> tf-idf columns.
  step_tokenize(overview) %>%
  step_tokenfilter(overview, max_tokens = 100) %>%
  step_tfidf(overview) %>%
  # Genre -> indicator columns; runtime -> Yeo-Johnson transformed.
  step_dummy(genre_names) %>%
  step_YeoJohnson(runtime)

# Inspect the preprocessed training data.
glimpse(bake(prep(xgboost_recipe), new_data = NULL))
## Rows: 32,583
## Columns: 122
## $ id <dbl> 611067, 619754, 752440, 113128, 665779, 86…
## $ runtime <dbl> 16.123839, 88.509661, 17.232365, 110.91148…
## $ budget <dbl> 25058, 25000, 0, 0, 0, 0, 0, 0, 0, 1000000…
## $ vote_average <dbl> 1.902108, 1.098612, 2.041220, 1.722767, 1.…
## $ tfidf_overview_a <dbl> 0.06252428, 0.00000000, 0.13641660, 0.1500…
## $ tfidf_overview_about <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_after <dbl> 0.05662186, 0.05509154, 0.00000000, 0.0815…
## $ tfidf_overview_all <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_an <dbl> 0.00000000, 0.03849377, 0.12947905, 0.0569…
## $ tfidf_overview_and <dbl> 0.10336027, 0.02514169, 0.00000000, 0.0744…
## $ tfidf_overview_are <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_as <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_at <dbl> 0.00000000, 0.00000000, 0.19032860, 0.0000…
## $ tfidf_overview_back <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_be <dbl> 0.06162920, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_becomes <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_been <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_before <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_begins <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_being <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_but <dbl> 0.05173111, 0.00000000, 0.16930183, 0.0000…
## $ tfidf_overview_by <dbl> 0.04129338, 0.04017734, 0.00000000, 0.0000…
## $ tfidf_overview_can <dbl> 0.08205548, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_dark <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_dead <dbl> 0.0000000, 0.0000000, 0.2699028, 0.0000000…
## $ tfidf_overview_death <dbl> 0.0791073, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_evil <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_family <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_film <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_find <dbl> 0.00000000, 0.06754842, 0.00000000, 0.0000…
## $ tfidf_overview_finds <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0.…
## $ tfidf_overview_for <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_friends <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_from <dbl> 0.00000000, 0.09480331, 0.00000000, 0.0000…
## $ tfidf_overview_get <dbl> 0.08410588, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_girl <dbl> 0.08224722, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_group <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_has <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_have <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_he <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_help <dbl> 0.0000000, 0.0000000, 0.0000000, 0.1251635…
## $ tfidf_overview_her <dbl> 0.08878968, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_him <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0959…
## $ tfidf_overview_his <dbl> 0.00000000, 0.00000000, 0.00000000, 0.1716…
## $ tfidf_overview_home <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_horror <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_house <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_in <dbl> 0.02894298, 0.08448221, 0.09472248, 0.0000…
## $ tfidf_overview_into <dbl> 0.00000000, 0.15624116, 0.00000000, 0.0000…
## $ tfidf_overview_is <dbl> 0.03314201, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_it <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0867…
## $ tfidf_overview_killer <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_life <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_man <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_more <dbl> 0.08169083, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_must <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_mysterious <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_new <dbl> 0.00000000, 0.00000000, 0.00000000, 0.1009…
## $ tfidf_overview_night <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_not <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_now <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_of <dbl> 0.05058798, 0.07383111, 0.08278033, 0.0364…
## $ tfidf_overview_old <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_on <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0612…
## $ tfidf_overview_one <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_only <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_or <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_out <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_own <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_people <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_she <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_soon <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_story <dbl> 0.00000000, 0.08102412, 0.00000000, 0.0000…
## $ tfidf_overview_strange <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_take <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_that <dbl> 0.04097793, 0.03987042, 0.00000000, 0.0590…
## $ tfidf_overview_the <dbl> 0.15732697, 0.17494273, 0.14711093, 0.0970…
## $ tfidf_overview_their <dbl> 0.00000000, 0.09460723, 0.00000000, 0.0000…
## $ tfidf_overview_them <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_there <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_they <dbl> 0.00000000, 0.09887806, 0.00000000, 0.0000…
## $ tfidf_overview_this <dbl> 0.0000000, 0.1230015, 0.0000000, 0.0000000…
## $ tfidf_overview_three <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_through <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_time <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_to <dbl> 0.07395075, 0.07195208, 0.00000000, 0.0709…
## $ tfidf_overview_town <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_two <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_up <dbl> 0.06444458, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_was <dbl> 0.00000000, 0.07773037, 0.00000000, 0.0000…
## $ tfidf_overview_way <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_what <dbl> 0.00000000, 0.00000000, 0.25600187, 0.0000…
## $ tfidf_overview_when <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0691…
## $ tfidf_overview_where <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_which <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_while <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_who <dbl> 0.04781363, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_wife <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_will <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_with <dbl> 0.00000000, 0.03859542, 0.00000000, 0.0000…
## $ tfidf_overview_woman <dbl> 0.07006285, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_world <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ tfidf_overview_years <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_young <dbl> 0.00000000, 0.00000000, 0.00000000, 0.0000…
## $ genre_names_Adventure <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Animation <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Comedy <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, …
## $ genre_names_Crime <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Documentary <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Drama <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, …
## $ genre_names_Family <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Fantasy <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, …
## $ genre_names_History <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Horror <dbl> 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, …
## $ genre_names_Music <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Mystery <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Romance <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Science.Fiction <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Thriller <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ genre_names_TV.Movie <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_War <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Western <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
# Boosted-tree regression fitted with xgboost; the number of trees and the
# minimum node size are left as tuning parameters.
xgboost_spec <- boost_tree(
  mode = "regression",
  engine = "xgboost",
  trees = tune(),
  min_n = tune()
)
# Bundle the preprocessing recipe and the model specification.
xgboost_workflow <- workflow(
  preprocessor = xgboost_recipe,
  spec = xgboost_spec
)
# Evaluate 5 candidate hyperparameter combinations on the CV resamples.
set.seed(1027)
xgboost_tune <- xgboost_workflow %>%
  tune_grid(resamples = data_cv, grid = 5)
## Warning: package 'xgboost' was built under R version 4.3.3
# Best tuning results ranked by RMSE.
xgboost_tune %>% tune::show_best(metric = "rmse")
## # A tibble: 5 × 8
## trees min_n .metric .estimator mean n std_err .config
## <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 1872 20 rmse standard 0.218 10 0.00201 Preprocessor1_Model3
## 2 1425 16 rmse standard 0.218 10 0.00170 Preprocessor1_Model2
## 3 757 2 rmse standard 0.218 10 0.00234 Preprocessor1_Model1
## 4 807 28 rmse standard 0.234 10 0.00146 Preprocessor1_Model4
## 5 155 35 rmse standard 0.277 10 0.00127 Preprocessor1_Model5
# Plug the best hyperparameters (by RMSE) into the workflow.
xgboost_fw <- xgboost_workflow %>%
  tune::finalize_workflow(tune::select_best(xgboost_tune, metric = "rmse"))

# Fit on the full training set and evaluate once on the held-out test set.
data_fit <- xgboost_fw %>% tune::last_fit(data_split)
data_fit %>% tune::collect_metrics()
## # A tibble: 2 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 rmse standard 0.206 Preprocessor1_Model1
## 2 rsq standard 0.587 Preprocessor1_Model1
# Predicted vs. observed ratings on the test set (both on the log1p scale).
tune::collect_predictions(data_fit) %>%
  ggplot(aes(vote_average, .pred)) +
  # The default point shape ignores `fill`, so the requested colour was never
  # shown; `color` is the aesthetic that actually paints these points.
  geom_point(alpha = 0.3, color = "midnightblue") +
  # Dashed identity line: points on it are perfect predictions.
  geom_abline(lty = 2, color = "gray50") +
  coord_fixed()