library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.1.3     ✔ readr     2.1.4
## ✔ forcats   1.0.0     ✔ stringr   1.5.0
## ✔ ggplot2   3.4.3     ✔ tibble    3.2.1
## ✔ lubridate 1.9.2     ✔ tidyr     1.3.0
## ✔ purrr     1.0.2     
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
# Load the TidyTuesday (2022-11-01) horror movies data set directly from the
# rfordatascience GitHub repository.
horror_movies <- readr::read_csv(
  file = "https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2022/2022-11-01/horror_movies.csv"
)
## Rows: 32540 Columns: 20
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr  (10): original_title, title, original_language, overview, tagline, post...
## dbl   (8): id, popularity, vote_count, vote_average, budget, revenue, runtim...
## lgl   (1): adult
## date  (1): release_date
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Histogram of the average vote per movie, using 15 bins.
ggplot(horror_movies, aes(x = vote_average)) +
  geom_histogram(bins = 15)

Explore Data

library(tidytext)

# Split each movie's `genre_names` string into one word per row (column
# `word`), keeping the other columns alongside.
# NOTE(review): tokenising by word breaks multi-word genres apart -- the
# counts below list "science" and "fiction" separately with identical totals.
tidy_horror <- unnest_tokens(horror_movies, output = word, input = genre_names)

# Tally how often each genre word occurs, most frequent first.
count(tidy_horror, word, sort = TRUE)
## # A tibble: 21 × 2
##    word         n
##    <chr>    <int>
##  1 horror   32543
##  2 thriller  7680
##  3 comedy    4963
##  4 drama     4271
##  5 mystery   3138
##  6 fiction   2714
##  7 science   2714
##  8 fantasy   2195
##  9 action    1966
## 10 crime     1153
## # ℹ 11 more rows
# For every genre word, plot how many movies carry it (log10 x axis) against
# the mean `vote_average` of those movies. The dashed horizontal line marks
# the mean `vote_average` over all movies in `horror_movies`.
tidy_horror %>%
  group_by(word) %>%
  summarise(
    n = n(),
    vote_average = mean(vote_average)
  ) %>%
  ggplot(aes(n, vote_average)) +
  geom_hline(
    yintercept = mean(horror_movies$vote_average), lty = 2,
    # `linewidth` (not `size`) sets line thickness; using `size` for lines
    # was deprecated in ggplot2 3.4.0 and triggers a lifecycle warning.
    color = "gray50", linewidth = 1.5
  ) +
  geom_jitter(color = "midnightblue", alpha = 0.7) +
  # Label each point with its genre word; overlapping labels are dropped.
  geom_text(aes(label = word),
    check_overlap = TRUE, family = "sans",
    vjust = "top", hjust = "left"
  ) +
  scale_x_log10()

## Build Models

library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.1.1 ──
## ✔ broom        1.0.5     ✔ rsample      1.2.0
## ✔ dials        1.2.0     ✔ tune         1.1.2
## ✔ infer        1.0.5     ✔ workflows    1.1.3
## ✔ modeldata    1.2.0     ✔ workflowsets 1.0.1
## ✔ parsnip      1.1.1     ✔ yardstick    1.2.0
## ✔ recipes      1.0.8
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard() masks purrr::discard()
## ✖ dplyr::filter()   masks stats::filter()
## ✖ recipes::fixed()  masks stringr::fixed()
## ✖ dplyr::lag()      masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step()   masks stats::step()
## • Use suppressPackageStartupMessages() to eliminate package startup messages
# Split the data into training and test sets, stratified on the outcome
# `vote_average`; the seed makes the split repeatable.
set.seed(123)
horror_split <- horror_movies %>%
  initial_split(strata = vote_average)
horror_train <- horror_split %>% training()
horror_test <- horror_split %>% testing()

# Cross-validation folds built from the training set only, again stratified
# on `vote_average` and seeded separately.
set.seed(234)
horror_folds <- horror_train %>%
  vfold_cv(strata = vote_average)
horror_folds
## #  10-fold cross-validation using stratification 
## # A tibble: 10 × 2
##    splits               id    
##    <list>               <chr> 
##  1 <split [21962/2442]> Fold01
##  2 <split [21962/2442]> Fold02
##  3 <split [21963/2441]> Fold03
##  4 <split [21963/2441]> Fold04
##  5 <split [21964/2440]> Fold05
##  6 <split [21964/2440]> Fold06
##  7 <split [21964/2440]> Fold07
##  8 <split [21964/2440]> Fold08
##  9 <split [21965/2439]> Fold09
## 10 <split [21965/2439]> Fold10

Preprocessing

library(textrecipes)

# Preprocessing recipe: predict `vote_average` from `genre_names` alone.
# Steps are added one at a time, in the order they will be applied.
horror_rec <- recipe(vote_average ~ genre_names, data = horror_train)
# 1. Tokenise the genre string.
horror_rec <- step_tokenize(horror_rec, genre_names)
# 2. Keep at most 100 tokens.
horror_rec <- step_tokenfilter(horror_rec, genre_names, max_tokens = 100)
# 3. Convert the remaining tokens into term-frequency features.
horror_rec <- step_tf(horror_rec, genre_names)

Model Spec

# Random forest regression specification with 500 trees. No engine is set
# explicitly, so the parsnip default is used (see the printout below).
ranger_spec <- set_mode(rand_forest(trees = 500), "regression")

ranger_spec
## Random Forest Model Specification (regression)
## 
## Main Arguments:
##   trees = 500
## 
## Computational engine: ranger
# Linear support vector machine regression specification. No engine is set
# explicitly, so the parsnip default is used (see the printout below).
svm_spec <- set_mode(svm_linear(), "regression")
svm_spec
## Linear Support Vector Machine Model Specification (regression)
## 
## Computational engine: LiblineaR

Workflow

# Pair the shared preprocessing recipe with each model specification.
ranger_wf <- workflow(preprocessor = horror_rec, spec = ranger_spec)
svm_wf <- workflow(preprocessor = horror_rec, spec = svm_spec)

Evaluate Models

# Register a parallel backend for the resampling fits.
doParallel::registerDoParallel()
# Keep the out-of-fold predictions so they can be collected and plotted later.
contrl_preds <- control_resamples(save_pred = TRUE)

# Fit the SVM workflow on every cross-validation fold.
svm_rs <- svm_wf %>%
  fit_resamples(resamples = horror_folds, control = contrl_preds)

# Fit the random forest workflow on the same folds.
ranger_rs <- ranger_wf %>%
  fit_resamples(resamples = horror_folds, control = contrl_preds)

# Resampled performance of the SVM.
svm_rs %>% collect_metrics()
## # A tibble: 2 × 6
##   .metric .estimator   mean     n std_err .config             
##   <chr>   <chr>       <dbl> <int>   <dbl> <chr>               
## 1 rmse    standard   2.81      10 0.00992 Preprocessor1_Model1
## 2 rsq     standard   0.0453    10 0.00300 Preprocessor1_Model1
# Resampled performance of the random forest.
ranger_rs %>% collect_metrics()
## # A tibble: 2 × 6
##   .metric .estimator   mean     n std_err .config             
##   <chr>   <chr>       <dbl> <int>   <dbl> <chr>               
## 1 rmse    standard   2.81      10 0.00939 Preprocessor1_Model1
## 2 rsq     standard   0.0500    10 0.00278 Preprocessor1_Model1
# Stack the out-of-fold predictions of both models, tagging each row with the
# model it came from, then plot predicted against observed `vote_average`,
# coloured by fold (`id`) and faceted by model.
bind_rows(
  collect_predictions(svm_rs) %>%
    mutate(mod = "SVM"),
  collect_predictions(ranger_rs) %>%
    mutate(mod = "ranger")
) %>%
  ggplot(aes(vote_average, .pred, color = id)) +
  # Dashed identity line: points on it are perfect predictions.
  # `linewidth` (not `size`) sets line thickness; using `size` for lines
  # was deprecated in ggplot2 3.4.0.
  geom_abline(lty = 2, color = "gray50", linewidth = 1.2) +
  geom_jitter(width = 0.5, alpha = 0.5) +
  facet_wrap(vars(mod)) +
  # Same scale on both axes so the identity line sits at 45 degrees.
  coord_fixed()

# Fit the SVM workflow on the training portion of the split and evaluate it
# on the held-out test portion.
final_fitted <- svm_wf %>% last_fit(split = horror_split)
## → A | warning: max_tokens was set to '100', but only 21 was available and selected.
## 
## There were issues with some computations   A: x1
## 
## There were issues with some computations   A: x1
# Test-set performance of the final fit.
final_fitted %>% collect_metrics()
## # A tibble: 2 × 4
##   .metric .estimator .estimate .config             
##   <chr>   <chr>          <dbl> <chr>               
## 1 rmse    standard      2.80   Preprocessor1_Model1
## 2 rsq     standard      0.0468 Preprocessor1_Model1
# Pull the fitted workflow out of the final fit and use it to predict the
# rating for a single test-set row (row 55).
final_wf <- final_fitted %>% extract_workflow()
final_wf %>% predict(new_data = horror_test[55, ])
## # A tibble: 1 × 1
##   .pred
##   <dbl>
## 1  2.77
# Bar chart of the largest model coefficients: up to 10 terms with the
# greatest absolute estimate on each side of zero, ordered by estimate and
# filled by sign.
# NOTE(review): no term is filtered out, so an intercept/bias term (if the
# tidied model reports one) is plotted alongside the genre terms -- confirm
# that is intended.
extract_workflow(final_fitted) %>%
  tidy() %>%
  group_by(estimate > 0) %>%
  slice_max(abs(estimate), n = 10) %>%
  ungroup() %>%
  # Strip the whole term-frequency prefix *including* its trailing separator,
  # so labels read "horror" rather than "_horror".
  mutate(term = str_remove(term, "^tf_genre_names_")) %>%
  ggplot(aes(estimate, fct_reorder(term, estimate),
    fill = estimate > 0
  )) +
  geom_col(alpha = 0.8)