Horror Movies: Build a regression model to predict the average movie rating (vote_average). Use the horror_movies dataset.
horror_movies <- readr::read_csv('https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2022/2022-11-01/horror_movies.csv')
## Rows: 32540 Columns: 20
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (10): original_title, title, original_language, overview, tagline, post...
## dbl (8): id, popularity, vote_count, vote_average, budget, revenue, runtim...
## lgl (1): adult
## date (1): release_date
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
horror_movies %>% skimr::skim()
Name | Piped data |
Number of rows | 32540 |
Number of columns | 20 |
_______________________ | |
Column type frequency: | |
character | 10 |
Date | 1 |
logical | 1 |
numeric | 8 |
________________________ | |
Group variables | None |
Variable type: character
skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
---|---|---|---|---|---|---|---|
original_title | 0 | 1.00 | 1 | 191 | 0 | 30296 | 0 |
title | 0 | 1.00 | 1 | 191 | 0 | 29563 | 0 |
original_language | 0 | 1.00 | 2 | 2 | 0 | 97 | 0 |
overview | 1286 | 0.96 | 1 | 1000 | 0 | 31020 | 0 |
tagline | 19835 | 0.39 | 1 | 237 | 0 | 12513 | 0 |
poster_path | 4474 | 0.86 | 30 | 32 | 0 | 28048 | 0 |
status | 0 | 1.00 | 7 | 15 | 0 | 4 | 0 |
backdrop_path | 18995 | 0.42 | 29 | 32 | 0 | 13536 | 0 |
genre_names | 0 | 1.00 | 6 | 144 | 0 | 772 | 0 |
collection_name | 30234 | 0.07 | 4 | 56 | 0 | 815 | 0 |
Variable type: Date
skim_variable | n_missing | complete_rate | min | max | median | n_unique |
---|---|---|---|---|---|---|
release_date | 0 | 1 | 1950-01-01 | 2022-12-31 | 2012-12-09 | 10999 |
Variable type: logical
skim_variable | n_missing | complete_rate | mean | count |
---|---|---|---|---|
adult | 0 | 1 | 0 | FAL: 32540 |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|---|---|
id | 0 | 1.00 | 445910.83 | 305744.67 | 17 | 146494.8 | 426521.00 | 707534.00 | 1033095.00 | ▇▆▆▅▅ |
popularity | 0 | 1.00 | 4.01 | 37.51 | 0 | 0.6 | 0.84 | 2.24 | 5088.58 | ▇▁▁▁▁ |
vote_count | 0 | 1.00 | 62.69 | 420.89 | 0 | 0.0 | 2.00 | 11.00 | 16900.00 | ▇▁▁▁▁ |
vote_average | 0 | 1.00 | 3.34 | 2.88 | 0 | 0.0 | 4.00 | 5.70 | 10.00 | ▇▂▆▃▁ |
budget | 0 | 1.00 | 543126.59 | 4542667.81 | 0 | 0.0 | 0.00 | 0.00 | 200000000.00 | ▇▁▁▁▁ |
revenue | 0 | 1.00 | 1349746.73 | 14430479.15 | 0 | 0.0 | 0.00 | 0.00 | 701842551.00 | ▇▁▁▁▁ |
runtime | 0 | 1.00 | 62.14 | 41.00 | 0 | 14.0 | 80.00 | 91.00 | 683.00 | ▇▁▁▁▁ |
collection | 30234 | 0.07 | 481534.88 | 324498.16 | 656 | 155421.0 | 471259.00 | 759067.25 | 1033032.00 | ▇▅▅▅▅ |
data <- horror_movies %>%
# Log transform vote_average
mutate(vote_average = log1p(vote_average)) %>% # for zeroes: log1p(x) is the same as log(x+1)
# Treat missing values in overview
filter(!is.na(overview), vote_count != 0) %>%
# Treat multiple categories in genre_names
separate_rows(genre_names, sep = ", ") %>%
filter(status == "Released") %>%
select(id, vote_average, genre_names, overview, runtime)
# data <- data %>% sample_n(100)
Check list
data %>% glimpse()
data %>% skimr::skim()
data %>% select(id) %>% explore()
data %>% describe_all()
data %>% describe_cat(genre_names)
data %>% select(-id) %>% explore_all(target = vote_average)
data %>%
ggplot(aes(vote_average)) +
geom_histogram()
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
data %>% count(genre_names, sort = T)
## # A tibble: 19 × 2
## genre_names n
## <chr> <int>
## 1 Horror 20493
## 2 Thriller 5899
## 3 Comedy 3118
## 4 Drama 2835
## 5 Mystery 2279
## 6 Science Fiction 2189
## 7 Action 1539
## 8 Fantasy 1485
## 9 Crime 895
## 10 TV Movie 533
## 11 Adventure 504
## 12 Animation 500
## 13 Romance 417
## 14 Documentary 224
## 15 Music 164
## 16 Family 124
## 17 Western 102
## 18 History 76
## 19 War 68
data %>%
group_by(genre_names) %>%
summarise(
n = n(),
avg_vote_average = mean(vote_average)
) %>%
ungroup() %>%
ggplot(aes(n, avg_vote_average)) +
# geom_point() +
geom_text(aes(label = genre_names), check_overlap = TRUE) +
geom_hline(yintercept = mean(data$vote_average),
linewidth = 2, linetype = "dotted", color = "darkgray") +
scale_x_log10()
spacy_initialize(model = "en_core_web_sm")
# process documents and obtain a data.table
tidy_data <- data %>%
# Parse overview
mutate(overview_parsed = map(.x = .$overview, .f = ~spacy_parse(.x))) %>%
unnest(overview_parsed) %>%
# Select nouns and adjectives
filter(pos %in% c("ADJ", "NOUN"))
data_filtered <- tidy_data %>%
filter(str_detect(lemma, regex("[a-z]", ignore_case = TRUE))) %>%
group_by(lemma) %>%
summarise(
n = n(),
avg_vote_average = mean(vote_average)
) %>%
filter(n > 150)
data_filtered %>%
ggplot(aes(n, avg_vote_average)) +
# geom_point() +
geom_text(aes(label = lemma), check_overlap = TRUE) +
geom_hline(yintercept = mean(data_filtered$avg_vote_average),
linetype = "dotted", linewidth = 2, color = "darkgray") +
scale_x_log10()
data %>%
ggplot(aes(runtime, vote_average)) +
geom_jitter(alpha = 0.3)
set.seed(123)
data_split <- initial_split(data)
data_train <- training(data_split)
data_test <- testing(data_split)
set.seed(234)
data_folds <- rsample::vfold_cv(data_train)
data_folds
## # 10-fold cross-validation
## # A tibble: 10 × 2
## splits id
## <list> <chr>
## 1 <split [29324/3259]> Fold01
## 2 <split [29324/3259]> Fold02
## 3 <split [29324/3259]> Fold03
## 4 <split [29325/3258]> Fold04
## 5 <split [29325/3258]> Fold05
## 6 <split [29325/3258]> Fold06
## 7 <split [29325/3258]> Fold07
## 8 <split [29325/3258]> Fold08
## 9 <split [29325/3258]> Fold09
## 10 <split [29325/3258]> Fold10
library(usemodels)
use_xgboost(vote_average ~ ., data = data_train)
## xgboost_recipe <-
## recipe(formula = vote_average ~ ., data = data_train) %>%
## step_zv(all_predictors())
##
## xgboost_spec <-
## boost_tree(trees = tune(), min_n = tune(), tree_depth = tune(), learn_rate = tune(),
## loss_reduction = tune(), sample_size = tune()) %>%
## set_mode("classification") %>%
## set_engine("xgboost")
##
## xgboost_workflow <-
## workflow() %>%
## add_recipe(xgboost_recipe) %>%
## add_model(xgboost_spec)
##
## set.seed(63112)
## xgboost_tune <-
## tune_grid(xgboost_workflow, resamples = stop("add your rsample object"), grid = stop("add number of candidate points"))
# xgboost_recipe <-
# recipe(formula = vote_average ~ ., data = data_train) %>%
# recipes::update_role(post_id, new_role = "id") %>%
# step_tokenize(overview) %>%
# step_tokenfilter(overview, max_tokens = 100) %>%
# step_tfidf(overview) %>%
# step_other(nhood) %>%
# step_dummy(nhood) %>%
# step_log(vote_average, sqft, baths) # To transform variables with skewed distribution
xgboost_recipe <-
recipe(formula = vote_average ~ ., data = data_train) %>%
recipes::update_role(id, new_role = "id") %>%
step_tokenize(overview, engine = "spacyr") %>%
step_lemma(overview) %>%
step_pos_filter(overview, keep_tags = c("NOUN", "ADJ"))%>%
step_tokenfilter(overview, max_tokens = 100) %>%
step_tfidf(overview) %>%
step_dummy(genre_names) %>%
step_YeoJohnson(runtime) # for log-transformation for a variable with zeroes
xgboost_recipe %>% prep() %>% bake(new_data = NULL) %>% glimpse()
## Found 'spacy_condaenv'. spacyr will use this environment
## successfully initialized (spaCy Version: 3.1.3, language model: en_core_web_sm)
## (python options: type = "condaenv", value = "spacy_condaenv")
## Rows: 32,583
## Columns: 121
## $ id <dbl> 314405, 147061, 166752, 309079, 27297, 427…
## $ runtime <dbl> 141.646911, 74.415520, 81.854221, 86.10912…
## $ vote_average <dbl> 2.001480, 2.197225, 1.064711, 1.791759, 1.…
## $ `tfidf_overview_-` <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_accident <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_alien <dbl> 0.0000000, 0.4878952, 0.0000000, 0.0000000…
## $ tfidf_overview_ancient <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_bad <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_beautiful <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_blood <dbl> 0.0000000, 0.0000000, 0.0000000, 0.4825536…
## $ tfidf_overview_body <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_boy <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_brother <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0.…
## $ tfidf_overview_child <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_city <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0.…
## $ tfidf_overview_couple <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_creature <dbl> 0.0000000, 0.4255323, 0.0000000, 0.0000000…
## $ tfidf_overview_dark <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_daughter <dbl> 0.6651315, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_day <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0.…
## $ tfidf_overview_dead <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_deadly <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_death <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_demon <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_dream <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_event <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_evil <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0.…
## $ tfidf_overview_family <dbl> 0.0000000, 0.0000000, 0.5243188, 0.0000000…
## $ tfidf_overview_father <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_film <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_first <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_force <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_friend <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_game <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_ghost <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0.…
## $ tfidf_overview_girl <dbl> 0.5300025, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_good <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_group <dbl> 0.0000000, 0.3215255, 0.0000000, 0.0000000…
## $ tfidf_overview_help <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_high <dbl> 0.0000000, 0.4636848, 0.0000000, 0.0000000…
## $ tfidf_overview_home <dbl> 0.00000, 0.00000, 0.00000, 0.00000, 0.0000…
## $ tfidf_overview_horror <dbl> 0.0000000, 0.0000000, 0.0000000, 0.4010642…
## $ tfidf_overview_house <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_human <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_husband <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_killer <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_life <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_little <dbl> 0.000000, 0.465336, 0.000000, 0.000000, 0.…
## $ tfidf_overview_local <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_love <dbl> 0.6734058, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_man <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_monster <dbl> 0.0000000, 0.0000000, 0.6765847, 0.0000000…
## $ tfidf_overview_more <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_mother <dbl> 0.6504896, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_movie <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_murder <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_mysterious <dbl> 0.5218673, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_new <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_night <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_nightmare <dbl> 0.0000000, 0.0000000, 0.0000000, 0.4943259…
## $ tfidf_overview_old <dbl> 0.0000000, 0.3375260, 0.0000000, 0.0000000…
## $ tfidf_overview_only <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_order <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_other <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_own <dbl> 0.0000000, 0.0000000, 0.0000000, 0.4472877…
## $ tfidf_overview_past <dbl> 0.0000000, 0.0000000, 0.0000000, 0.4853397…
## $ tfidf_overview_people <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_place <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_police <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_power <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0.…
## $ tfidf_overview_real <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_remote <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_revenge <dbl> 0.0000000, 0.0000000, 0.7012580, 0.0000000…
## $ tfidf_overview_school <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_scientist <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_secret <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_serial <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_series <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_short <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_sister <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0.…
## $ tfidf_overview_small <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0.…
## $ tfidf_overview_son <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_spirit <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_story <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_strange <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_student <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0.…
## $ tfidf_overview_supernatural <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_tale <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_team <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_thing <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_time <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_town <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_vampire <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_victim <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_village <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_way <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_wife <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ tfidf_overview_woman <dbl> 0.0000000, 0.0000000, 0.4501321, 0.6430458…
## $ tfidf_overview_wood <dbl> 0.000000, 0.938040, 0.000000, 0.000000, 0.…
## $ tfidf_overview_world <dbl> 0.0000000, 0.0000000, 0.5532448, 0.0000000…
## $ tfidf_overview_year <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_young <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000…
## $ tfidf_overview_zombie <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Adventure <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Animation <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Comedy <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Crime <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Documentary <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Drama <dbl> 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Family <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Fantasy <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_History <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Horror <dbl> 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, …
## $ genre_names_Music <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Mystery <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, …
## $ genre_names_Romance <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Science.Fiction <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, …
## $ genre_names_Thriller <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ genre_names_TV.Movie <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_War <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ genre_names_Western <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
xgboost_spec <-
boost_tree(trees = tune(), min_n = tune()) %>%
set_mode("regression") %>%
set_engine("xgboost")
xgboost_workflow <-
workflow() %>%
add_recipe(xgboost_recipe) %>%
add_model(xgboost_spec)
set.seed(15793)
doParallel::registerDoParallel()
xgboost_tune <-
tune_grid(xgboost_workflow,
resamples = data_folds,
grid = 5)
show_best(xgboost_tune, metric = "rmse")
## # A tibble: 5 × 8
## trees min_n .metric .estimator mean n std_err .config
## <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 1593 4 rmse standard 0.246 10 0.00186 Preprocessor1_Model1
## 2 1637 20 rmse standard 0.262 10 0.00151 Preprocessor1_Model3
## 3 923 38 rmse standard 0.277 10 0.00149 Preprocessor1_Model5
## 4 458 31 rmse standard 0.283 10 0.00155 Preprocessor1_Model4
## 5 83 14 rmse standard 0.296 10 0.00144 Preprocessor1_Model2
# How did all the possible parameter combinations do?
autoplot(xgboost_tune)
We can finalize our random forest workflow with the best performing parameters.
final_rf <- xgboost_workflow %>%
finalize_workflow(select_best(xgboost_tune, "rmse"))
The function last_fit() fits this finalized random forest one last time to the training data and evaluates one last time on the testing data.
data_fit <- last_fit(final_rf, data_split)
data_fit
## # Resampling results
## # Manual resampling
## # A tibble: 1 × 6
## splits id .metrics .notes .predictions .workflow
## <list> <chr> <list> <list> <list> <list>
## 1 <split [32583/10861]> train/test sp… <tibble> <tibble> <tibble> <workflow>
collect_metrics(data_fit)
## # A tibble: 2 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 rmse standard 0.229 Preprocessor1_Model1
## 2 rsq standard 0.483 Preprocessor1_Model1
collect_predictions(data_fit)
## # A tibble: 10,861 × 5
## id .pred .row vote_average .config
## <chr> <dbl> <int> <dbl> <chr>
## 1 train/test split 1.95 1 2.07 Preprocessor1_Model1
## 2 train/test split 2.08 4 2.09 Preprocessor1_Model1
## 3 train/test split 1.88 12 1.92 Preprocessor1_Model1
## 4 train/test split 1.80 15 2.08 Preprocessor1_Model1
## 5 train/test split 1.97 23 2.01 Preprocessor1_Model1
## 6 train/test split 1.99 27 2.05 Preprocessor1_Model1
## 7 train/test split 2.02 36 2.03 Preprocessor1_Model1
## 8 train/test split 1.79 39 2.07 Preprocessor1_Model1
## 9 train/test split 1.53 42 2.07 Preprocessor1_Model1
## 10 train/test split 1.94 44 2.07 Preprocessor1_Model1
## # ℹ 10,851 more rows
collect_predictions(data_fit) %>%
ggplot(aes(vote_average, .pred)) +
geom_point(alpha = 0.5, fill = "midnightblue") +
geom_abline(lty = 2, color = "gray50") +
coord_fixed()
data_fit %>%
extract_workflow() %>%
predict(data_test[1,])
## # A tibble: 1 × 1
## .pred
## <dbl>
## 1 1.95
library(vip)
imp_spec <- xgboost_spec %>%
tune::finalize_model(tune::select_best(xgboost_tune)) %>%
parsnip::set_engine("xgboost", importance = "permutation")
workflows::workflow() %>%
add_recipe(xgboost_recipe) %>%
add_model(imp_spec) %>%
fit(data_train) %>%
workflows::pull_workflow_fit() %>%
vip()
## [17:10:42] WARNING: amalgamation/../src/learner.cc:627:
## Parameters: { "importance" } might not be used.
##
## This could be a false alarm, with some parameters getting used by language bindings but
## then being mistakenly passed down to XGBoost core, or some parameter actually being used
## but getting flagged wrongly here. Please open an issue if you find any such cases.