suppressMessages(library(tidyverse))
suppressMessages(library(tidymodels))

# Episode-level Scooby-Doo data from the TidyTuesday project (2021-07-13).
scooby_raw <- read_csv(
  file = "https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2021/2021-07-13/scoobydoo.csv"
)

# Among episodes that feature at least one monster, how many had a real
# monster versus a fake one?
scooby_raw %>%
  filter(monster_amount > 0) %>%
  group_by(monster_real) %>%
  summarise(n = n(), .groups = "drop")
## # A tibble: 2 x 2
##   monster_real     n
##   <chr>        <int>
## 1 FALSE          404
## 2 TRUE           112
# Make the light ggplot2 theme the default for every plot that follows.
theme_set(new = theme_light())

# Episodes that feature a monster, binned by decade aired and split by whether
# the monster turned out to be real.
# Note the `+ 1` shift in the binning: year Y goes to decade
# 10 * ((Y + 1) %/% 10). So 1969 is counted with the 1970s, and 1979 with the
# 1980s.
scooby_raw %>%
  filter(monster_amount > 0) %>%
  count(
    year_aired = 10 * ((lubridate::year(date_aired) + 1) %/% 10),
    monster_real
  ) %>%
  mutate(year_aired = factor(year_aired)) %>%
  ggplot(aes(year_aired, n, fill = monster_real)) +
  geom_col(position = position_dodge(preserve = "single"), alpha = 0.8) +
  # `n` is a row count, i.e. episodes, not individual monsters.
  labs(x = "Decade aired", y = "Episodes per decade", fill = "Real monster?")

# Distribution of IMDB ratings for monster episodes, split by whether the
# monster was real. Densities (not counts) are plotted because the two groups
# differ a lot in size, so they can be compared on one scale.
scooby_raw %>%
  filter(monster_amount > 0) %>%
  mutate(imdb = parse_number(imdb)) %>%
  # Drop ratings that failed to parse explicitly, as the modelling data does.
  # Otherwise ggplot2 removes them itself and emits a warning.
  filter(!is.na(imdb)) %>%
  ggplot(aes(imdb, after_stat(density), fill = monster_real)) +
  geom_histogram(position = "identity", alpha = 0.5) +
  labs(x = "IMDB rating", y = "Density", fill = "Real monster?")

set.seed(123)

# Modelling data: one row per monster episode with a parseable IMDB rating.
# The outcome is recoded to a two-level factor. The data is split into
# training and testing sets with initial_split() (default proportion 3/4),
# stratified on the outcome.
scooby_split <- scooby_raw %>%
  mutate(
    imdb = parse_number(imdb),
    year_aired = lubridate::year(date_aired)
  ) %>%
  filter(monster_amount > 0, !is.na(imdb)) %>%
  mutate(
    # Map both raw values explicitly. Anything else (including NA) gets no
    # match, becomes NA and is dropped below, instead of silently being
    # labelled "real".
    monster_real = case_when(
      monster_real == "FALSE" ~ "fake",
      monster_real == "TRUE" ~ "real"
    ),
    # These are the same alphabetical levels as.factor() would produce, pinned
    # so that the first level, which yardstick treats as the event by default,
    # is explicit.
    monster_real = factor(monster_real, levels = c("fake", "real"))
  ) %>%
  filter(!is.na(monster_real)) %>%
  select(year_aired, imdb, monster_real, title) %>%
  initial_split(strata = monster_real)

scooby_train <- training(scooby_split)
scooby_test <- testing(scooby_split)


# Stratified bootstrap resamples of the training set, used for tuning.
# times = 25 is bootstraps()'s default, written out here for clarity.
set.seed(234)
scooby_fold <- bootstraps(
  data = scooby_train,
  times = 25,
  strata = monster_real
)
scooby_fold
## # Bootstrap sampling using stratification 
## # A tibble: 25 x 2
##    splits            id         
##    <list>            <chr>      
##  1 <split [375/133]> Bootstrap01
##  2 <split [375/144]> Bootstrap02
##  3 <split [375/140]> Bootstrap03
##  4 <split [375/132]> Bootstrap04
##  5 <split [375/139]> Bootstrap05
##  6 <split [375/134]> Bootstrap06
##  7 <split [375/146]> Bootstrap07
##  8 <split [375/132]> Bootstrap08
##  9 <split [375/143]> Bootstrap09
## 10 <split [375/143]> Bootstrap10
## # ... with 15 more rows
# Classification decision tree fitted with rpart. All three complexity
# hyperparameters are left as tune() placeholders, to be filled in by the grid
# search.
tree_spec <- decision_tree(
  mode = "classification",
  cost_complexity = tune(),
  tree_depth = tune(),
  min_n = tune()
) %>%
  set_engine("rpart")

tree_spec
## Decision Tree Model Specification (classification)
## 
## Main Arguments:
##   cost_complexity = tune()
##   tree_depth = tune()
##   min_n = tune()
## 
## Computational engine: rpart
# Regular grid with 4 values of each hyperparameter, spread over its dials
# default range. This gives 4^3 = 64 candidate trees.
tree_grid <- grid_regular(
  cost_complexity(),
  tree_depth(),
  min_n(),
  levels = 4
)

tree_grid
## # A tibble: 64 x 3
##    cost_complexity tree_depth min_n
##              <dbl>      <int> <int>
##  1    0.0000000001          1     2
##  2    0.0000001             1     2
##  3    0.0001                1     2
##  4    0.1                   1     2
##  5    0.0000000001          5     2
##  6    0.0000001             5     2
##  7    0.0001                5     2
##  8    0.1                   5     2
##  9    0.0000000001         10     2
## 10    0.0000001            10     2
## # ... with 54 more rows
# Register a two-worker foreach backend so the resampled fits run in parallel.
doParallel::registerDoParallel(cores = 2)

set.seed(345)

# Fit every grid candidate on every bootstrap resample. Score each one on the
# out-of-bag rows with four classification metrics.
tree_rs <- tune_grid(
  tree_spec,
  preprocessor = monster_real ~ year_aired + imdb,
  resamples = scooby_fold,
  grid = tree_grid,
  metrics = metric_set(accuracy, roc_auc, sensitivity, specificity)
)

tree_rs
## # Tuning results
## # Bootstrap sampling using stratification 
## # A tibble: 25 x 4
##    splits            id          .metrics           .notes          
##    <list>            <chr>       <list>             <list>          
##  1 <split [375/133]> Bootstrap01 <tibble [256 x 7]> <tibble [0 x 1]>
##  2 <split [375/144]> Bootstrap02 <tibble [256 x 7]> <tibble [0 x 1]>
##  3 <split [375/140]> Bootstrap03 <tibble [256 x 7]> <tibble [0 x 1]>
##  4 <split [375/132]> Bootstrap04 <tibble [256 x 7]> <tibble [0 x 1]>
##  5 <split [375/139]> Bootstrap05 <tibble [256 x 7]> <tibble [0 x 1]>
##  6 <split [375/134]> Bootstrap06 <tibble [256 x 7]> <tibble [0 x 1]>
##  7 <split [375/146]> Bootstrap07 <tibble [256 x 7]> <tibble [0 x 1]>
##  8 <split [375/132]> Bootstrap08 <tibble [256 x 7]> <tibble [0 x 1]>
##  9 <split [375/143]> Bootstrap09 <tibble [256 x 7]> <tibble [0 x 1]>
## 10 <split [375/143]> Bootstrap10 <tibble [256 x 7]> <tibble [0 x 1]>
## # ... with 15 more rows
# Top candidates ranked by mean resampled accuracy. Naming the metric makes the
# choice explicit. Without it, tune picks a metric itself and warns about it.
show_best(tree_rs, metric = "accuracy")
## # A tibble: 5 x 9
##   cost_complexity tree_depth min_n .metric  .estimator  mean     n std_err
##             <dbl>      <int> <int> <chr>    <chr>      <dbl> <int>   <dbl>
## 1    0.0000000001         10     2 accuracy binary     0.872    25 0.00481
## 2    0.0000001            10     2 accuracy binary     0.872    25 0.00481
## 3    0.0001               10     2 accuracy binary     0.872    25 0.00481
## 4    0.0000000001         15     2 accuracy binary     0.871    25 0.00456
## 5    0.0000001            15     2 accuracy binary     0.871    25 0.00456
## # ... with 1 more variable: .config <chr>
# Resampled metrics across the tuning grid.
# NOTE(review): the "IBMPlexSans" font family must be installed for the
# graphics device; confirm on the machine that renders this.
ggplot2::autoplot(tree_rs) +
  theme_light(base_family = "IBMPlexSans")

# Consider the candidates whose ROC AUC is within one standard error of the
# best. Sorting by `-cost_complexity` orders them from largest to smallest
# cost_complexity, so the most heavily pruned tree among them is chosen.
simpler_tree <- select_by_one_std_err(
  tree_rs,
  -cost_complexity,
  metric = "roc_auc"
)

# Lock the selected hyperparameters into the model specification.
final_tree <- finalize_model(tree_spec, simpler_tree)

# Fit on the training set alone (this fit is what gets drawn with parttree).
final_fit <- fit(
  final_tree,
  monster_real ~ year_aired + imdb,
  data = scooby_train
)

# Refit on the training portion of the split and evaluate once on the test
# portion.
final_rs <- last_fit(
  final_tree,
  monster_real ~ year_aired + imdb,
  split = scooby_split
)
library(parttree)

# Shaded decision regions of the final training-set tree, overlaid on the
# jittered training episodes. Both layers map the real/fake outcome, so both
# legends get the same title used by the other plots in this analysis.
scooby_train %>%
  ggplot(aes(imdb, year_aired)) +
  geom_parttree(data = final_fit, aes(fill = monster_real), alpha = 0.2) +
  geom_jitter(
    aes(color = monster_real),
    alpha = 0.7,
    width = 0.05,
    height = 0.2
  ) +
  labs(
    x = "IMDB rating",
    y = "Year aired",
    fill = "Real monster?",
    color = "Real monster?"
  )