Goal: Predict classification of a Bigfoot report

library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.1.4     ✔ readr     2.1.5
## ✔ forcats   1.0.0     ✔ stringr   1.5.1
## ✔ ggplot2   3.5.1     ✔ tibble    3.2.1
## ✔ lubridate 1.9.3     ✔ tidyr     1.3.1
## ✔ purrr     1.0.2     
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(correlationfunnel)
## ══ correlationfunnel Tip #3 ════════════════════════════════════════════════════
## Using `binarize()` with data containing many columns or many rows can increase dimensionality substantially.
## Try subsetting your data column-wise or row-wise to avoid creating too many columns.
## You can always make a big problem smaller by sampling. :)
data <- readr::read_csv('https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2022/2022-09-13/bigfoot.csv')
## Rows: 5021 Columns: 28
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr  (10): observed, location_details, county, state, season, title, classif...
## dbl  (17): latitude, longitude, number, temperature_high, temperature_mid, t...
## date  (1): date
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.

Explore data

skimr::skim(data)
Data summary
Name data
Number of rows 5021
Number of columns 28
_______________________
Column type frequency:
character 10
Date 1
numeric 17
________________________
Group variables None

Variable type: character

skim_variable n_missing complete_rate min max empty n_unique whitespace
observed 38 0.99 1 30374 0 4982 0
location_details 758 0.85 1 3876 0 4196 0
county 0 1.00 10 30 0 1037 0
state 0 1.00 4 14 0 49 0
season 0 1.00 4 7 0 5 0
title 976 0.81 23 235 0 4045 0
classification 0 1.00 7 7 0 3 0
geohash 976 0.81 10 10 0 4001 0
precip_type 3298 0.34 4 4 0 2 0
summary 1655 0.67 15 103 0 321 0

Variable type: Date

skim_variable n_missing complete_rate min max median n_unique
date 976 0.81 1869-11-10 2021-11-27 2003-11-16 3111

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
latitude 976 0.81 39.36 5.68 25.14 35.35 39.30 43.93 64.89 ▂▇▆▁▁
longitude 976 0.81 -97.42 16.73 -167.13 -117.06 -91.77 -83.07 -68.23 ▁▁▆▆▇
number 0 1.00 21520.23 19259.15 60.00 4595.00 15473.00 33979.00 71997.00 ▇▃▂▂▁
temperature_high 1683 0.66 67.12 17.78 -0.62 55.14 69.97 81.10 106.51 ▁▂▅▇▃
temperature_mid 1835 0.63 57.84 16.40 -8.46 46.77 59.36 70.38 94.03 ▁▁▆▇▃
temperature_low 1832 0.64 48.64 15.94 -22.78 37.50 49.40 60.66 84.34 ▁▁▅▇▃
dew_point 1648 0.67 46.23 16.44 -11.21 34.77 46.69 59.00 77.40 ▁▂▆▇▅
humidity 1648 0.67 0.71 0.16 0.08 0.62 0.73 0.82 1.00 ▁▁▃▇▅
cloud_cover 1937 0.61 0.44 0.33 0.00 0.12 0.40 0.73 1.00 ▇▅▃▃▅
moon_phase 1625 0.68 0.50 0.29 0.00 0.25 0.49 0.75 1.00 ▇▇▇▇▇
precip_intensity 2309 0.54 0.01 0.05 0.00 0.00 0.00 0.00 2.07 ▇▁▁▁▁
precip_probability 2311 0.54 0.30 0.42 0.00 0.00 0.00 0.73 1.00 ▇▁▁▁▃
pressure 2402 0.52 1017.08 6.14 980.34 1013.42 1016.96 1020.64 1042.41 ▁▁▇▆▁
uv_index 1629 0.68 5.16 3.14 0.00 3.00 5.00 8.00 13.00 ▆▇▅▆▁
visibility 1972 0.61 8.49 2.06 0.74 7.66 9.45 10.00 10.00 ▁▁▁▂▇
wind_bearing 1634 0.67 196.57 96.38 0.00 128.00 203.00 273.00 359.00 ▅▅▇▇▆
wind_speed 1632 0.67 3.87 3.28 0.00 1.34 2.93 5.56 23.94 ▇▃▁▁▁

Issues with data:

data_clean <- data %>% 
   
    # Treat missing values
    select(-precip_type, -precip_intensity, -precip_probability) %>% 
    na.omit() %>% 
    
    # Drop date (temporary)
    select(-c(date)) %>% 
    
    # Drop rare level in target variable; drop NAs in observed variable
    filter(classification != "Class C", !is.na(observed)) %>%
    
    # Rename classification "A" and "B"
    mutate(classification = case_when(classification == "Class A" ~ "sighting",
                                      classification == "Class B" ~ "possible")) %>%
    
    # Drop String Variables and Geohash; OBSERVED VARIABLE TEMPORARY
    select(-c(location_details, title, summary, observed, geohash))


#data_clean <- data %>% 
   
    # Address factors imported as numeric
    # none
    
    # Drop zero-variance variables
    # none

Explore data

data_clean %>% count(classification)
## # A tibble: 2 × 2
##   classification     n
##   <chr>          <int>
## 1 possible        1053
## 2 sighting        1019
data_clean %>%
    ggplot(aes(classification)) +
    geom_bar()

classification vs. Temperature_High

data_clean %>% 
    ggplot(aes(classification, temperature_high)) + 
    geom_boxplot()

Correlation Plot

# Step 1: Binarize
data_binarized <- data_clean %>% 
    select(-number) %>%
    binarize()

data_binarized %>% glimpse()
## Rows: 2,072
## Columns: 97
## $ county__Jackson_County              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ county__Jefferson_County            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ county__King_County                 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ county__Pierce_County               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ county__Snohomish_County            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ county__Washington_County           <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `county__-OTHER`                    <dbl> 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1…
## $ state__Alabama                      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ state__Arkansas                     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ state__California                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ state__Colorado                     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ state__Florida                      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ state__Georgia                      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ state__Idaho                        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ state__Illinois                     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ state__Indiana                      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ state__Iowa                         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ state__Kansas                       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ state__Kentucky                     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ state__Michigan                     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ state__Missouri                     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ state__New_Jersey                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1…
## $ state__New_York                     <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ state__North_Carolina               <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0…
## $ state__Ohio                         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ state__Oklahoma                     <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ state__Oregon                       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ state__Pennsylvania                 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ state__Tennessee                    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ state__Texas                        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ state__Virginia                     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ state__Washington                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ state__West_Virginia                <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ state__Wisconsin                    <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ `state__-OTHER`                     <dbl> 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0…
## $ season__Fall                        <dbl> 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ season__Spring                      <dbl> 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0…
## $ season__Summer                      <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1…
## $ season__Unknown                     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ season__Winter                      <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0…
## $ `latitude__-Inf_35.298325`          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ latitude__35.298325_39.642495       <dbl> 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0…
## $ latitude__39.642495_43.46018        <dbl> 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1…
## $ latitude__43.46018_Inf              <dbl> 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0…
## $ `longitude__-Inf_-112.1051`         <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0…
## $ `longitude__-112.1051_-88.748825`   <dbl> 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0…
## $ `longitude__-88.748825_-82.1174575` <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0…
## $ `longitude__-82.1174575_Inf`        <dbl> 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1…
## $ classification__possible            <dbl> 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1…
## $ classification__sighting            <dbl> 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0…
## $ `temperature_high__-Inf_54.65`      <dbl> 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1…
## $ temperature_high__54.65_69.905      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ temperature_high__69.905_81.2625    <dbl> 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0…
## $ temperature_high__81.2625_Inf       <dbl> 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0…
## $ `temperature_mid__-Inf_46.7925`     <dbl> 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1…
## $ temperature_mid__46.7925_59.7775    <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0…
## $ temperature_mid__59.7775_70.86125   <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0…
## $ temperature_mid__70.86125_Inf       <dbl> 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0…
## $ `temperature_low__-Inf_38.04`       <dbl> 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0…
## $ temperature_low__38.04_49.94        <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1…
## $ temperature_low__49.94_61.4425      <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0…
## $ temperature_low__61.4425_Inf        <dbl> 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0…
## $ `dew_point__-Inf_35.5475`           <dbl> 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0…
## $ dew_point__35.5475_47.51            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1…
## $ dew_point__47.51_59.6225            <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0…
## $ dew_point__59.6225_Inf              <dbl> 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0…
## $ `humidity__-Inf_0.64`               <dbl> 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0…
## $ humidity__0.64_0.74                 <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ humidity__0.74_0.82                 <dbl> 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0…
## $ humidity__0.82_Inf                  <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1…
## $ `cloud_cover__-Inf_0.13`            <dbl> 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0…
## $ cloud_cover__0.13_0.41              <dbl> 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0…
## $ cloud_cover__0.41_0.74              <dbl> 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0…
## $ cloud_cover__0.74_Inf               <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1…
## $ `moon_phase__-Inf_0.25`             <dbl> 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0…
## $ moon_phase__0.25_0.51               <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0…
## $ moon_phase__0.51_0.75               <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1…
## $ moon_phase__0.75_Inf                <dbl> 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0…
## $ `pressure__-Inf_1013.32`            <dbl> 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0…
## $ pressure__1013.32_1016.935          <dbl> 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0…
## $ pressure__1016.935_1020.65          <dbl> 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ pressure__1020.65_Inf               <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1…
## $ `uv_index__-Inf_3`                  <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1…
## $ uv_index__3_5                       <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ uv_index__5_8                       <dbl> 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0…
## $ uv_index__8_Inf                     <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0…
## $ `visibility__-Inf_7.63`             <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0…
## $ visibility__7.63_9.4105             <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1…
## $ visibility__9.4105_Inf              <dbl> 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0…
## $ `wind_bearing__-Inf_127`            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1…
## $ wind_bearing__127_202               <dbl> 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0…
## $ wind_bearing__202_268               <dbl> 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0…
## $ wind_bearing__268_Inf               <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0…
## $ `wind_speed__-Inf_1.42`             <dbl> 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0…
## $ wind_speed__1.42_2.97               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1…
## $ wind_speed__2.97_5.4925             <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0…
## $ wind_speed__5.4925_Inf              <dbl> 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0…
# Step 2: Correlation 
data_correlation <- data_binarized %>% 
    correlate(classification__sighting)

data_correlation
## # A tibble: 97 × 3
##    feature        bin                  correlation
##    <fct>          <chr>                      <dbl>
##  1 classification possible                 -1     
##  2 classification sighting                  1     
##  3 wind_speed     -Inf_1.42                -0.0917
##  4 longitude      -112.1051_-88.748825      0.0741
##  5 wind_speed     5.4925_Inf                0.0697
##  6 longitude      -Inf_-112.1051           -0.0686
##  7 state          California               -0.0677
##  8 wind_bearing   -Inf_127                  0.0640
##  9 state          Alabama                   0.0598
## 10 dew_point      35.5475_47.51            -0.0573
## # ℹ 87 more rows
# Step 3: Plot
data_correlation %>% 
    correlationfunnel::plot_correlation_funnel()
## Warning: ggrepel: 35 unlabeled data points (too many overlaps). Consider
## increasing max.overlaps

Model Building

Splitting Data

library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.2.0 ──
## ✔ broom        1.0.5      ✔ rsample      1.2.1 
## ✔ dials        1.2.1      ✔ tune         1.2.1 
## ✔ infer        1.0.7      ✔ workflows    1.1.4 
## ✔ modeldata    1.4.0      ✔ workflowsets 1.1.0 
## ✔ parsnip      1.2.1      ✔ yardstick    1.3.1 
## ✔ recipes      1.0.10
## Warning: package 'modeldata' was built under R version 4.3.3
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard() masks purrr::discard()
## ✖ dplyr::filter()   masks stats::filter()
## ✖ recipes::fixed()  masks stringr::fixed()
## ✖ dplyr::lag()      masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step()   masks stats::step()
## • Use suppressPackageStartupMessages() to eliminate package startup messages
set.seed(1234)
data_clean <- data_clean %>% sample_n(100)

data_split <- initial_split(data_clean)
data_train <- training(data_split)
data_test <- testing(data_split)

data_cv <- rsample::vfold_cv(data_train)
data_cv
## #  10-fold cross-validation 
## # A tibble: 10 × 2
##    splits         id    
##    <list>         <chr> 
##  1 <split [67/8]> Fold01
##  2 <split [67/8]> Fold02
##  3 <split [67/8]> Fold03
##  4 <split [67/8]> Fold04
##  5 <split [67/8]> Fold05
##  6 <split [68/7]> Fold06
##  7 <split [68/7]> Fold07
##  8 <split [68/7]> Fold08
##  9 <split [68/7]> Fold09
## 10 <split [68/7]> Fold10

Preprocessing Data

library(themis)
xgboost_rec <- recipes::recipe(classification ~ ., data = data_train) %>% 
    update_role(number, new_role = "ID") %>%
    step_dummy(all_nominal_predictors()) %>%
    step_YeoJohnson(number, cloud_cover, wind_speed) %>%
    step_normalize(all_numeric_predictors()) #%>%
    #step_pca(all_numeric_predictors(), threshold = 0.75)

xgboost_rec %>% prep() %>% juice() %>% glimpse()
## Rows: 75
## Columns: 106
## $ latitude                   <dbl> 0.94939161, 0.08727068, 0.30195605, -0.7299…
## $ longitude                  <dbl> -1.80788951, 0.56429841, 0.66402872, 0.1727…
## $ number                     <dbl> 12.26835, 46.55484, 165.35253, 109.44224, 1…
## $ temperature_high           <dbl> -0.40428414, -1.18286042, -2.41142465, 1.52…
## $ temperature_mid            <dbl> -0.2508239, -1.0283477, -2.4596984, 1.72485…
## $ temperature_low            <dbl> -0.06908514, -0.80471026, -2.38384428, 1.84…
## $ dew_point                  <dbl> 0.11705362, -1.82597460, -2.24033269, 1.003…
## $ humidity                   <dbl> 1.02989171, -1.07807378, 0.72875379, -1.605…
## $ cloud_cover                <dbl> 1.4786909, 0.2408948, 0.4100535, -1.3042633…
## $ moon_phase                 <dbl> -1.020882991, 1.544208507, 0.525199556, -1.…
## $ pressure                   <dbl> -0.7652005, 1.3213607, 1.0454806, -0.512019…
## $ uv_index                   <dbl> -0.1033537, -1.2107147, -1.2107147, 1.74224…
## $ visibility                 <dbl> 6.672057e-01, 6.672057e-01, -2.616650e+00, …
## $ wind_bearing               <dbl> 0.664607974, -1.234881740, 0.269770224, 0.1…
## $ wind_speed                 <dbl> 0.4536867, 0.7147669, -0.2942243, 1.7601305…
## $ classification             <fct> possible, sighting, possible, sighting, pos…
## $ county_Anne.Arundel.County <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Baldwin.County      <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Bayfield.County     <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Belmont.County      <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Bennington.County   <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Brantley.County     <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Bristol.County      <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Buncombe.County     <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Butte.County        <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Caddo.Parish        <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Cass.County         <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Clackamas.County    <dbl> 8.5447840, -0.1154701, -0.1154701, -0.11547…
## $ county_Clark.County        <dbl> -0.2027588, 4.8662100, -0.2027588, -0.20275…
## $ county_Columbia.County     <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Columbiana.County   <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Crittenden.County   <dbl> -0.1154701, -0.1154701, -0.1154701, 8.54478…
## $ county_Effingham.County    <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Erie.County         <dbl> -0.1154701, -0.1154701, 8.5447840, -0.11547…
## $ county_Florence.County     <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Franklin.County     <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Fulton.County       <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Geauga.County       <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Grays.Harbor.County <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Guernsey.County     <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Hardin.County       <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Harrison.County     <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Holmes.County       <dbl> -0.164414, -0.164414, -0.164414, -0.164414,…
## $ county_Hood.River.County   <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Jackson.County      <dbl> -0.2027588, -0.2027588, -0.2027588, -0.2027…
## $ county_Jefferson.County    <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Josephine.County    <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Kinney.County       <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Kitsap.County       <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Kittitas.County     <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Lawrence.County     <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Lee.County          <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Manistee.County     <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Mason.County        <dbl> -0.164414, -0.164414, -0.164414, -0.164414,…
## $ county_McDonough.County    <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Miller.County       <dbl> -0.164414, -0.164414, -0.164414, -0.164414,…
## $ county_Monroe.County       <dbl> -0.164414, -0.164414, -0.164414, -0.164414,…
## $ county_Montgomery.County   <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Morgan.County       <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Okaloosa.County     <dbl> -0.164414, -0.164414, -0.164414, -0.164414,…
## $ county_Oneida.County       <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Oscoda.County       <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Pierce.County       <dbl> -0.164414, -0.164414, -0.164414, -0.164414,…
## $ county_Portage.County      <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Pulaski.County      <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Rapides.Parish      <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Rockland.County     <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_San.Miguel.County   <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Seminole.County     <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Snohomish.County    <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Tarrant.County      <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Tattnall.County     <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Tuscarawas.County   <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Umatilla.County     <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Warren.County       <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Washtenaw.County    <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Winnebago.County    <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Wythe.County        <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ county_Yakima.County       <dbl> -0.164414, -0.164414, -0.164414, -0.164414,…
## $ state_Arkansas             <dbl> -0.2027588, -0.2027588, -0.2027588, 4.86621…
## $ state_California           <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ state_Florida              <dbl> -0.2929114, -0.2929114, -0.2929114, -0.2929…
## $ state_Georgia              <dbl> -0.2027588, -0.2027588, -0.2027588, -0.2027…
## $ state_Illinois             <dbl> -0.2929114, -0.2929114, -0.2929114, -0.2929…
## $ state_Kentucky             <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ state_Louisiana            <dbl> -0.164414, -0.164414, -0.164414, -0.164414,…
## $ state_Maryland             <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ state_Massachusetts        <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ state_Michigan             <dbl> -0.2654735, -0.2654735, -0.2654735, -0.2654…
## $ state_Mississippi          <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ state_Missouri             <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ state_New.Mexico           <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ state_New.York             <dbl> -0.2357686, -0.2357686, -0.2357686, -0.2357…
## $ state_North.Carolina       <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ state_Ohio                 <dbl> -0.411805, 2.395956, 2.395956, -0.411805, -…
## $ state_Oregon               <dbl> 3.7166293, -0.2654735, -0.2654735, -0.26547…
## $ state_Pennsylvania         <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ state_Tennessee            <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ state_Texas                <dbl> -0.2027588, -0.2027588, -0.2027588, -0.2027…
## $ state_Vermont              <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ state_Virginia             <dbl> -0.1154701, -0.1154701, -0.1154701, -0.1154…
## $ state_Washington           <dbl> -0.3896086, -0.3896086, -0.3896086, -0.3896…
## $ state_Wisconsin            <dbl> -0.2027588, -0.2027588, -0.2027588, -0.2027…
## $ season_Spring              <dbl> 2.2759613, -0.4335164, -0.4335164, -0.43351…
## $ season_Summer              <dbl> -0.7886881, -0.7886881, -0.7886881, 1.25102…
## $ season_Winter              <dbl> -0.3668044, -0.3668044, 2.6898988, -0.36680…

Specify Model

xgboost_spec <- 
  boost_tree(trees = tune(), tree_depth = tune()) %>% 
  set_mode("classification") %>% 
  set_engine("xgboost") 

xgboost_workflow <- 
  workflow() %>% 
  add_recipe(xgboost_rec) %>% 
  add_model(xgboost_spec) 

Tune Hyperparameters

doParallel::registerDoParallel()

set.seed(56356)
xgboost_tune <-
  tune_grid(xgboost_workflow,
            resamples = data_cv,
            grid = 5,
            control = control_grid(save_pred = TRUE))
## Warning: package 'xgboost' was built under R version 4.3.3

Model Evaluation

Identify optimal hyperparameters

collect_metrics(xgboost_tune)
## # A tibble: 15 × 8
##    trees tree_depth .metric     .estimator  mean     n std_err .config          
##    <int>      <int> <chr>       <chr>      <dbl> <int>   <dbl> <chr>            
##  1  1599          2 accuracy    binary     0.436    10  0.0601 Preprocessor1_Mo…
##  2  1599          2 brier_class binary     0.424    10  0.0475 Preprocessor1_Mo…
##  3  1599          2 roc_auc     binary     0.398    10  0.0828 Preprocessor1_Mo…
##  4    18          5 accuracy    binary     0.371    10  0.0447 Preprocessor1_Mo…
##  5    18          5 brier_class binary     0.384    10  0.0395 Preprocessor1_Mo…
##  6    18          5 roc_auc     binary     0.397    10  0.0795 Preprocessor1_Mo…
##  7  1982          8 accuracy    binary     0.427    10  0.0700 Preprocessor1_Mo…
##  8  1982          8 brier_class binary     0.456    10  0.0537 Preprocessor1_Mo…
##  9  1982          8 roc_auc     binary     0.405    10  0.0909 Preprocessor1_Mo…
## 10   525         12 accuracy    binary     0.427    10  0.0632 Preprocessor1_Mo…
## 11   525         12 brier_class binary     0.455    10  0.0533 Preprocessor1_Mo…
## 12   525         12 roc_auc     binary     0.383    10  0.0955 Preprocessor1_Mo…
## 13  1037         15 accuracy    binary     0.427    10  0.0700 Preprocessor1_Mo…
## 14  1037         15 brier_class binary     0.454    10  0.0532 Preprocessor1_Mo…
## 15  1037         15 roc_auc     binary     0.407    10  0.0967 Preprocessor1_Mo…
collect_predictions(xgboost_tune) %>%
    group_by(id) %>%
    roc_curve(classification, .pred_possible) %>% 
    autoplot()

Fit the model for the last time

xgboost_last <- xgboost_workflow %>%
    finalize_workflow(select_best(xgboost_tune, metric = "accuracy")) %>%
    last_fit(data_split)
## → A | warning: ! There are new levels in a factor: `McKean County`, `Daggett County`, `Johnson
##                  County`, `Passaic County`, `St. Louis County`, `Davis County`, `White
##                  County`, `Ohio County`, `Cameron County`, `Muskingum County`, `York County`,
##                  `El Dorado County`, `Lewis County`, `Clarke County`, `Cook County`, `Wilson
##                  County`, `Dubuque County`, `Cherokee County`, …, `Kootenai County`, and
##                  `Sarasota County`., ! There are new levels in a factor: `Utah`, `Iowa`, `New Jersey`, `Minnesota`,
##                  `West Virginia`, `Maine`, and `Idaho`.
## 
There were issues with some computations   A: x1

There were issues with some computations   A: x1
collect_metrics(xgboost_last)
## # A tibble: 3 × 4
##   .metric     .estimator .estimate .config             
##   <chr>       <chr>          <dbl> <chr>               
## 1 accuracy    binary         0.6   Preprocessor1_Model1
## 2 roc_auc     binary         0.564 Preprocessor1_Model1
## 3 brier_class binary         0.328 Preprocessor1_Model1
collect_predictions(xgboost_last) %>%
    yardstick::conf_mat(classification, .pred_class) %>%
    autoplot()

Variable importance

library(vip)
## 
## Attaching package: 'vip'
## The following object is masked from 'package:utils':
## 
##     vi
xgboost_last %>%
    workflows::extract_fit_engine() %>%
    vip()

Conclusion

The previous model had an accuracy of 0.56 and an AUC of 0.532.