library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.1.4     ✔ readr     2.1.5
## ✔ forcats   1.0.1     ✔ stringr   1.5.2
## ✔ ggplot2   4.0.0     ✔ tibble    3.3.0
## ✔ lubridate 1.9.4     ✔ tidyr     1.3.1
## ✔ purrr     1.1.0     
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(readr)
library(skimr)
## Warning: package 'skimr' was built under R version 4.5.2
library(correlationfunnel)
## Warning: package 'correlationfunnel' was built under R version 4.5.2
## ══ Using correlationfunnel? ════════════════════════════════════════════════════
## You might also be interested in applied data science training for business.
## </> Learn more at - www.business-science.io </>
library(tidymodels)
## Warning: package 'tidymodels' was built under R version 4.5.2
## ── Attaching packages ────────────────────────────────────── tidymodels 1.4.1 ──
## ✔ broom        1.0.10     ✔ rsample      1.3.2 
## ✔ dials        1.4.2      ✔ tailor       0.1.0 
## ✔ infer        1.1.0      ✔ tune         2.0.1 
## ✔ modeldata    1.5.1      ✔ workflows    1.3.0 
## ✔ parsnip      1.4.1      ✔ workflowsets 1.1.1 
## ✔ recipes      1.3.1      ✔ yardstick    1.3.2
## Warning: package 'dials' was built under R version 4.5.2
## Warning: package 'infer' was built under R version 4.5.2
## Warning: package 'modeldata' was built under R version 4.5.2
## Warning: package 'parsnip' was built under R version 4.5.2
## Warning: package 'recipes' was built under R version 4.5.2
## Warning: package 'rsample' was built under R version 4.5.2
## Warning: package 'tailor' was built under R version 4.5.2
## Warning: package 'tune' was built under R version 4.5.2
## Warning: package 'workflows' was built under R version 4.5.2
## Warning: package 'workflowsets' was built under R version 4.5.2
## Warning: package 'yardstick' was built under R version 4.5.2
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard() masks purrr::discard()
## ✖ dplyr::filter()   masks stats::filter()
## ✖ recipes::fixed()  masks stringr::fixed()
## ✖ dplyr::lag()      masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step()   masks stats::step()
# Load Data ----
# TidyTuesday 2021-03-02 data set, read straight from GitHub (requires
# network access). readr prints the inferred column specification.
youtube <- "https://raw.githubusercontent.com/rfordatascience/tidytuesday/main/data/2021/2021-03-02/youtube.csv" %>%
  readr::read_csv()
## Rows: 247 Columns: 25
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr  (10): brand, superbowl_ads_dot_com_url, youtube_url, id, kind, etag, ti...
## dbl   (7): year, view_count, like_count, dislike_count, favorite_count, comm...
## lgl   (7): funny, show_product_quickly, patriotic, celebrity, danger, animal...
## dttm  (1): published_at
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Clean Data ----
# Drop URL/identifier/free-text columns, remove incomplete rows, trim the
# top 1% of view counts, and log-transform the right-skewed count columns.
data <- youtube %>%
  select(-superbowl_ads_dot_com_url,
         -youtube_url,
         -thumbnail,
         -title,
         -description,
         -etag,
         -id,
         -kind,
         -channel_title) %>%
  drop_na() %>%
  filter(view_count <= quantile(view_count, 0.99, na.rm = TRUE)) %>%
  mutate(
    # Indicator for more than one million views. It is computed on the raw
    # count, before the log transform below, so it does not depend on the
    # order in which mutate() evaluates its arguments.
    high_views = view_count > 1e6,
    # log1p() (= log(x + 1)) keeps zero counts finite; plain log(0) would
    # give -Inf and break model fitting. All count columns are now treated
    # the same way.
    across(c(view_count, like_count, dislike_count, comment_count), log1p)
  )

# Explore ----
# Log view count against year.
data %>%
  ggplot(aes(x = year, y = view_count)) +
  geom_point()

# Log like count (x) against year (y).
# NOTE(review): axes are swapped relative to the plot above (year is on the
# y-axis here); confirm this orientation is intended.
data %>%
  ggplot(aes(x = like_count, y = year)) +
  geom_point()

# Correlation Funnel ----
# Binarize the cleaned data (dropping brand and the published_at timestamp)
# and rank each binarized feature by its correlation with high_views.
# NOTE(review): view_count is still included and high_views is derived from
# it, so the view_count bins will trivially sit at the top of the funnel.
data_binarized_tbl <- data %>%
  select(-brand, -published_at) %>%
  binarize()

# Find the binarized column for "high_views is TRUE". The suffix presumably
# depends on how binarize() encodes the logical ("1" or "TRUE"); accept
# either. Fail with a clear message rather than letting NA reach rlang::sym().
target_candidates <- grep(
  "^high_views__?(1|TRUE)$",
  names(data_binarized_tbl),
  value = TRUE
)
if (length(target_candidates) == 0L) {
  stop(
    "No binarized high_views target column found among: ",
    paste(names(data_binarized_tbl), collapse = ", "),
    call. = FALSE
  )
}
target_col <- target_candidates[[1]]

data_corr_tbl <- data_binarized_tbl %>%
  correlate(!!rlang::sym(target_col))

plot_correlation_funnel(data_corr_tbl)
## Warning: The `size` argument of `element_line()` is deprecated as of ggplot2 3.4.0.
## ℹ Please use the `linewidth` argument instead.
## ℹ The deprecated feature was likely used in the correlationfunnel package.
##   Please report the issue at
##   <https://github.com/business-science/correlationfunnel/issues>.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.
## Warning: The `size` argument of `element_rect()` is deprecated as of ggplot2 3.4.0.
## ℹ Please use the `linewidth` argument instead.
## ℹ The deprecated feature was likely used in the correlationfunnel package.
##   Please report the issue at
##   <https://github.com/business-science/correlationfunnel/issues>.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.

# Train/Test Split ----
# Random split with rsample's default proportion. The seed is fixed so the
# split is reproducible.
set.seed(1234)
data_split <- data %>% initial_split()
data_train <- data_split %>% training()
data_test  <- data_split %>% testing()

# Recipe ----
# Outcome: log-transformed view_count; predictors: every other column.
# brand and published_at are dropped as unused. high_views is dropped
# because it is derived from view_count and would leak the outcome.
# Logical predictors are converted to numeric 0/1, and zero-variance
# predictors are removed.
xgboost_recipe <- recipe(view_count ~ ., data = data_train)
xgboost_recipe <- step_rm(xgboost_recipe, brand, published_at, high_views)
xgboost_recipe <- step_mutate_at(
  xgboost_recipe,
  all_logical_predictors(),
  fn = as.numeric
)
xgboost_recipe <- step_zv(xgboost_recipe, all_predictors())

# Model ----
# Boosted-tree regression on the xgboost engine with fixed (untuned)
# hyperparameters: 500 depth-3 trees, learning rate 0.05, min_n 5, and
# sample_size 0.8 (proportion of rows exposed to each fitting round).
xgboost_spec <- boost_tree(
  mode        = "regression",
  engine      = "xgboost",
  trees       = 500,
  tree_depth  = 3,
  min_n       = 5,
  learn_rate  = 0.05,
  sample_size = 0.8
)

# Workflow + Fit ----
# Bundle the preprocessing recipe with the model spec, then fit the whole
# pipeline on the training set.
xgboost_workflow <- workflow(
  preprocessor = xgboost_recipe,
  spec         = xgboost_spec
)

xgboost_fit <- xgboost_workflow %>% fit(data = data_train)

# Evaluate ----
# Score the held-out test set. view_count was log-transformed during
# cleaning, so RMSE and R-squared are on that log scale, not raw views.
test_preds <- bind_cols(
  predict(xgboost_fit, new_data = data_test),
  data_test
)

test_preds %>% rmse(truth = view_count, estimate = .pred)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 rmse    standard       0.720
test_preds %>% rsq(truth = view_count, estimate = .pred)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 rsq     standard       0.922
# Observed vs. predicted, with the y = x line marking perfect predictions.
test_preds %>%
  ggplot(aes(x = view_count, y = .pred)) +
  geom_point() +
  geom_abline(slope = 1, intercept = 0)