suppressMessages(library(tidyverse))
suppressMessages(library(tidymodels))
# Download the Scooby-Doo episode data (TidyTuesday, 2021-07-13).
scooby_raw <- read_csv(
  "https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2021/2021-07-13/scoobydoo.csv"
)

# Class balance of the outcome, restricted to episodes with at least one
# monster (monster_amount > 0).
count(filter(scooby_raw, monster_amount > 0), monster_real)
## # A tibble: 2 x 2
## monster_real n
## <chr> <int>
## 1 FALSE 404
## 2 TRUE 112
# Use the light ggplot2 theme for every plot drawn after this point.
theme_set(theme_light())

# Dodged bar chart: number of real vs. fake monsters per decade.
# The year is shifted by one before the integer division, so a year ending
# in 9 is bucketed with the following decade (1969 -> 1970).
scooby_raw %>%
  filter(monster_amount > 0) %>%
  mutate(year_aired = 10 * ((lubridate::year(date_aired) + 1) %/% 10)) %>%
  count(year_aired, monster_real) %>%
  # A factor gives the decades a discrete x axis.
  mutate(year_aired = factor(year_aired)) %>%
  ggplot(aes(x = year_aired, y = n, fill = monster_real)) +
  geom_col(alpha = 0.8, position = position_dodge(preserve = "single")) +
  labs(x = "Date aired", y = "Monsters per decade", fill = "Real monster?")

# Overlaid density histograms of IMDB rating, split by whether the monster
# was real.
scooby_raw %>%
  filter(monster_amount > 0) %>%
  mutate(imdb = parse_number(imdb)) %>%
  # Ratings that cannot be parsed come back as NA. Drop them explicitly (the
  # modelling data further down applies the same filter) instead of letting
  # ggplot2 discard them with a "removed rows" warning.
  filter(!is.na(imdb)) %>%
  ggplot(aes(imdb, after_stat(density), fill = monster_real)) +
  # bins = 30 is the value ggplot2 falls back to when none is given; stating
  # it keeps the picture unchanged and silences the "pick better value"
  # message.
  geom_histogram(bins = 30, position = "identity", alpha = 0.5) +
  labs(x = "IMDB rating", y = "Density", fill = "Real monster?")

# Build the modelling data set and split it into training and testing
# partitions, stratified on the outcome. The seed is set right before the
# pipeline so the random split is reproducible.
set.seed(123)
scooby_split <- scooby_raw %>%
  mutate(
    year_aired = lubridate::year(date_aired),
    imdb = parse_number(imdb)
  ) %>%
  # Keep episodes that have a monster and a usable rating.
  filter(monster_amount > 0 & !is.na(imdb)) %>%
  mutate(
    # "FALSE" becomes "fake"; every other value (including NA) falls through
    # to "real". The result is stored as a factor for classification.
    monster_real = factor(
      case_when(
        monster_real == "FALSE" ~ "fake",
        TRUE ~ "real"
      )
    )
  ) %>%
  select(year_aired, imdb, monster_real, title) %>%
  initial_split(strata = monster_real)

# Materialise the two partitions of the split.
scooby_train <- training(scooby_split)
scooby_test <- testing(scooby_split)
# Bootstrap resamples of the training data, stratified on the outcome.
# Seeded so the resamples are reproducible.
set.seed(234)
scooby_fold <- scooby_train %>%
  bootstraps(strata = monster_real)
scooby_fold
## # Bootstrap sampling using stratification
## # A tibble: 25 x 2
## splits id
## <list> <chr>
## 1 <split [375/133]> Bootstrap01
## 2 <split [375/144]> Bootstrap02
## 3 <split [375/140]> Bootstrap03
## 4 <split [375/132]> Bootstrap04
## 5 <split [375/139]> Bootstrap05
## 6 <split [375/134]> Bootstrap06
## 7 <split [375/146]> Bootstrap07
## 8 <split [375/132]> Bootstrap08
## 9 <split [375/143]> Bootstrap09
## 10 <split [375/143]> Bootstrap10
## # ... with 15 more rows
# Classification tree fitted with the rpart engine. All three
# hyperparameters are marked with tune() so they are chosen during tuning.
tree_spec <- decision_tree(
  mode = "classification",
  cost_complexity = tune(),
  tree_depth = tune(),
  min_n = tune()
) %>%
  set_engine("rpart")
tree_spec
## Decision Tree Model Specification (classification)
##
## Main Arguments:
## cost_complexity = tune()
## tree_depth = tune()
## min_n = tune()
##
## Computational engine: rpart
# Regular tuning grid: 4 levels for each of the three tree hyperparameters,
# i.e. 4^3 = 64 candidate combinations.
tree_grid <- grid_regular(
  cost_complexity(),
  tree_depth(),
  min_n(),
  levels = 4
)
tree_grid
## # A tibble: 64 x 3
## cost_complexity tree_depth min_n
## <dbl> <int> <int>
## 1 0.0000000001 1 2
## 2 0.0000001 1 2
## 3 0.0001 1 2
## 4 0.1 1 2
## 5 0.0000000001 5 2
## 6 0.0000001 5 2
## 7 0.0001 5 2
## 8 0.1 5 2
## 9 0.0000000001 10 2
## 10 0.0000001 10 2
## # ... with 54 more rows
# Register a doParallel backend with 2 cores (presumably so the grid search
# below can run in parallel; confirm against the installed tune version).
doParallel::registerDoParallel(cores = 2)

# Evaluate every grid candidate on every bootstrap resample, recording four
# classification metrics. Seeded for reproducibility.
set.seed(345)
tree_rs <- tune_grid(
  object = tree_spec,
  preprocessor = monster_real ~ year_aired + imdb,
  resamples = scooby_fold,
  grid = tree_grid,
  metrics = metric_set(accuracy, roc_auc, sensitivity, specificity)
)
tree_rs
## # Tuning results
## # Bootstrap sampling using stratification
## # A tibble: 25 x 4
## splits id .metrics .notes
## <list> <chr> <list> <list>
## 1 <split [375/133]> Bootstrap01 <tibble [256 x 7]> <tibble [0 x 1]>
## 2 <split [375/144]> Bootstrap02 <tibble [256 x 7]> <tibble [0 x 1]>
## 3 <split [375/140]> Bootstrap03 <tibble [256 x 7]> <tibble [0 x 1]>
## 4 <split [375/132]> Bootstrap04 <tibble [256 x 7]> <tibble [0 x 1]>
## 5 <split [375/139]> Bootstrap05 <tibble [256 x 7]> <tibble [0 x 1]>
## 6 <split [375/134]> Bootstrap06 <tibble [256 x 7]> <tibble [0 x 1]>
## 7 <split [375/146]> Bootstrap07 <tibble [256 x 7]> <tibble [0 x 1]>
## 8 <split [375/132]> Bootstrap08 <tibble [256 x 7]> <tibble [0 x 1]>
## 9 <split [375/143]> Bootstrap09 <tibble [256 x 7]> <tibble [0 x 1]>
## 10 <split [375/143]> Bootstrap10 <tibble [256 x 7]> <tibble [0 x 1]>
## # ... with 15 more rows
# Show the top candidates ranked by accuracy. Naming the metric explicitly
# avoids the warning tune emits when it has to fall back to the first metric
# in the metric set, and keeps the ranking independent of metric order.
show_best(tree_rs, metric = "accuracy")
## # A tibble: 5 x 9
## cost_complexity tree_depth min_n .metric .estimator mean n std_err
## <dbl> <int> <int> <chr> <chr> <dbl> <int> <dbl>
## 1 0.0000000001 10 2 accuracy binary 0.872 25 0.00481
## 2 0.0000001 10 2 accuracy binary 0.872 25 0.00481
## 3 0.0001 10 2 accuracy binary 0.872 25 0.00481
## 4 0.0000000001 15 2 accuracy binary 0.871 25 0.00456
## 5 0.0000001 15 2 accuracy binary 0.871 25 0.00456
## # ... with 1 more variable: .config <chr>
# Plot performance across the tuning grid.
autoplot(tree_rs) +
  theme_light(base_family = "IBMPlexSans")

# Choose a candidate within one standard error of the best ROC AUC. Sorting
# on -cost_complexity ranks candidates with a larger complexity penalty
# first.
simpler_tree <- tree_rs %>%
  select_by_one_std_err(-cost_complexity, metric = "roc_auc")

# Plug the selected hyperparameters into the model specification.
final_tree <- tree_spec %>%
  finalize_model(simpler_tree)

# Fit on the training data (used for the partition plot later on).
final_fit <- fit(
  final_tree,
  formula = monster_real ~ year_aired + imdb,
  data = scooby_train
)

# Fit on the training partition and evaluate on the testing partition of the
# split.
final_rs <- last_fit(
  final_tree,
  preprocessor = monster_real ~ year_aired + imdb,
  split = scooby_split
)
library(parttree)
# Draw the fitted tree's partitions as shaded regions and overlay the
# training episodes, jittered so overlapping points stay visible.
ggplot(scooby_train, aes(x = imdb, y = year_aired)) +
  geom_parttree(
    mapping = aes(fill = monster_real),
    data = final_fit,
    alpha = 0.2
  ) +
  geom_jitter(
    mapping = aes(color = monster_real),
    alpha = 0.7,
    width = 0.05,
    height = 0.2
  )
