library(readr)
library(dplyr)
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
# Load the hotel bookings data, then convert every character column to a
# factor.
hotels <- read_csv('https://tidymodels.org/start/case-study/hotels.csv')
hotels <- mutate(hotels, across(where(is.character), as.factor))
## Rows: 50000 Columns: 23
## -- Column specification --------------------------------------------------------
## Delimiter: ","
## chr  (11): hotel, children, meal, country, market_segment, distribution_chan...
## dbl  (11): lead_time, stays_in_weekend_nights, stays_in_week_nights, adults,...
## date  (1): arrival_date
## 
## i Use `spec()` to retrieve the full column specification for this data.
## i Specify the column types or set `show_col_types = FALSE` to quiet this message.
hotels %>% dim()  # rows and columns of the loaded data
## [1] 50000    23
# Levels of `children` with their counts and share of all rows
hotels %>%
  count(children) %>%
  mutate(prop = n / sum(n))
## # A tibble: 2 x 3
##   children     n   prop
##   <fct>    <int>  <dbl>
## 1 children  4038 0.0808
## 2 none     45962 0.919
# Base-R alternative: the same proportions as a table, not an extra column
proportions(table(hotels[["children"]]))
## 
## children     none 
##  0.08076  0.91924
# Create an initial split of the data into a training set and a testing set.
# Call the split `splits`.
# Because of the class imbalance, stratify the split by the outcome variable, `children`.
# Extract the training data as `hotel_other` and the testing data as `hotel_test`.
library(tidyverse)
## -- Attaching packages --------------------------------------- tidyverse 1.3.1 --
## v ggplot2 3.3.5     v purrr   0.3.4
## v tibble  3.1.3     v stringr 1.4.0
## v tidyr   1.1.3     v forcats 0.5.1
## -- Conflicts ------------------------------------------ tidyverse_conflicts() --
## x dplyr::filter() masks stats::filter()
## x dplyr::lag()    masks stats::lag()
library(tidymodels)
## Registered S3 method overwritten by 'tune':
##   method                   from   
##   required_pkgs.model_spec parsnip
## -- Attaching packages -------------------------------------- tidymodels 0.1.3 --
## v broom        0.7.9      v rsample      0.1.0 
## v dials        0.0.9      v tune         0.1.6 
## v infer        1.0.0      v workflows    0.2.3 
## v modeldata    0.1.1      v workflowsets 0.1.0 
## v parsnip      0.1.7      v yardstick    0.0.8 
## v recipes      0.1.16
## -- Conflicts ----------------------------------------- tidymodels_conflicts() --
## x scales::discard() masks purrr::discard()
## x dplyr::filter()   masks stats::filter()
## x recipes::fixed()  masks stringr::fixed()
## x dplyr::lag()      masks stats::lag()
## x yardstick::spec() masks readr::spec()
## x recipes::step()   masks stats::step()
## * Use tidymodels_prefer() to resolve common conflicts.
# Reproducible train/test split, stratified on the imbalanced outcome
set.seed(123)
splits <- hotels %>% initial_split(strata = children)
splits
## <Analysis/Assess/Total>
## <37500/12500/50000>
# Held-back test partition and the partition used for model building
hotel_test <- testing(splits)
hotel_other <- training(splits)
# Outcome class balance within the training partition
count(hotel_other, children) %>%
  mutate(prop = n / sum(n))
## # A tibble: 2 x 3
##   children     n   prop
##   <fct>    <int>  <dbl>
## 1 children  3027 0.0807
## 2 none     34473 0.919
# Outcome class balance within the test partition
count(hotel_test, children) %>%
  mutate(prop = n / sum(n))
## # A tibble: 2 x 3
##   children     n   prop
##   <fct>    <int>  <dbl>
## 1 children  1011 0.0809
## 2 none     11489 0.919
# Carve a single validation split out of `hotel_other`: 80% of rows for
# fitting, the rest for assessment, again stratified by `children`
set.seed(234)
val_set <- hotel_other %>%
  validation_split(prop = 0.80, strata = children)
val_set
## # Validation Set Split (0.8/0.2)  using stratification 
## # A tibble: 1 x 2
##   splits               id        
##   <list>               <chr>     
## 1 <split [30000/7500]> validation
library(doParallel)
## Loading required package: foreach
## 
## Attaching package: 'foreach'
## The following objects are masked from 'package:purrr':
## 
##     accumulate, when
## Loading required package: iterators
## Loading required package: parallel
# Number of CPU cores; passed to ranger as its thread count further down
cores <- parallel::detectCores()
print(cores)
## [1] 8
# Random forest with 1000 trees whose `mtry` and `min_n` are tuning
# placeholders; ranger runs with `cores` threads.
rf_mod <- rand_forest(mtry = tune(), min_n = tune(), trees = 1000)
rf_mod <- set_engine(rf_mod, "ranger", num.threads = cores)
rf_mod <- set_mode(rf_mod, "classification")
# Preprocessing: derive calendar and holiday features from `arrival_date`
# (step_date/step_holiday defaults), then drop the raw date column.
rf_recipe <- recipe(children ~ ., data = hotel_other)
rf_recipe <- rf_recipe %>%
  step_date(arrival_date) %>%
  step_holiday(arrival_date) %>%
  step_rm(arrival_date)
# Bundle the preprocessing recipe and the tunable model into one workflow
rf_workflow <- workflow() %>%
  add_recipe(rf_recipe) %>%
  add_model(rf_mod)
rf_mod
## Random Forest Model Specification (classification)
## 
## Main Arguments:
##   mtry = tune()
##   trees = 1000
##   min_n = tune()
## 
## Engine-Specific Arguments:
##   num.threads = cores
## 
## Computational engine: ranger
# List the model arguments marked with tune()
parameters(rf_mod)
## Collection of 2 parameters for tuning
## 
##  identifier  type    object
##        mtry  mtry nparam[?]
##       min_n min_n nparam[+]
## 
## Model parameters needing finalization:
##    # Randomly Selected Predictors ('mtry')
## 
## See `?dials::finalize` or `?dials::update.parameters` for more information.
# Score 25 candidate parameter combinations on the validation split by ROC
# AUC, keeping the out-of-sample predictions for later inspection
set.seed(345)
rf_res <- tune_grid(
  rf_workflow,
  resamples = val_set,
  grid = 25,
  metrics = metric_set(roc_auc),
  control = control_grid(save_pred = TRUE)
)
## i Creating pre-processing data to finalize unknown parameter: mtry
# Top candidates ranked by validation ROC AUC
show_best(rf_res, metric = "roc_auc")
## # A tibble: 5 x 8
##    mtry min_n .metric .estimator  mean     n std_err .config              
##   <int> <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
## 1     8     7 roc_auc binary     0.926     1      NA Preprocessor1_Model13
## 2    12     7 roc_auc binary     0.925     1      NA Preprocessor1_Model01
## 3     9    12 roc_auc binary     0.925     1      NA Preprocessor1_Model19
## 4     7    25 roc_auc binary     0.924     1      NA Preprocessor1_Model03
## 5    13     4 roc_auc binary     0.924     1      NA Preprocessor1_Model05
rf_res %>% autoplot()  # validation performance across the tuning grid

# Single best parameter combination according to ROC AUC
rf_best <- select_best(rf_res, metric = "roc_auc")
rf_best
## # A tibble: 1 x 3
##    mtry min_n .config              
##   <int> <int> <chr>                
## 1     8     7 Preprocessor1_Model13
# Every saved validation prediction, one set per candidate
collect_predictions(rf_res)
## # A tibble: 187,500 x 8
##    id         .pred_children .pred_none  .row  mtry min_n children .config      
##    <chr>               <dbl>      <dbl> <int> <int> <int> <fct>    <chr>        
##  1 validation       0.148         0.852    13    12     7 none     Preprocessor~
##  2 validation       0.0262        0.974    20    12     7 none     Preprocessor~
##  3 validation       0.472         0.528    22    12     7 children Preprocessor~
##  4 validation       0.00918       0.991    23    12     7 none     Preprocessor~
##  5 validation       0.0110        0.989    31    12     7 none     Preprocessor~
##  6 validation       0.000643      0.999    38    12     7 none     Preprocessor~
##  7 validation       0             1        39    12     7 none     Preprocessor~
##  8 validation       0.00235       0.998    50    12     7 none     Preprocessor~
##  9 validation       0.0217        0.978    54    12     7 none     Preprocessor~
## 10 validation       0.0389        0.961    57    12     7 children Preprocessor~
## # ... with 187,490 more rows
# ROC curve of the best candidate's validation predictions, tagged with a
# model label (presumably for comparison against other models later)
rf_auc <- rf_res %>%
  collect_predictions(parameters = rf_best) %>%
  roc_curve(truth = children, .pred_children) %>%
  mutate(model = "Random Forest")
# the last model
# Same random forest as `rf_mod`, but `mtry` and `min_n` come from the tuning
# result `rf_best` rather than being hard-coded. The earlier literal values
# (4 and 5) did not match the selected candidate. Impurity-based importance
# is switched on so that variable importance can be plotted afterwards.
last_rf_mod <- 
  rand_forest(mtry = rf_best$mtry, min_n = rf_best$min_n, trees = 1000) %>% 
  set_engine("ranger", num.threads = cores, importance = "impurity") %>% 
  set_mode("classification")
# the last workflow: swap the tunable model for the finalized one
last_rf_workflow <- 
  rf_workflow %>% 
  update_model(last_rf_mod)
# the last fit: train on the training portion of `splits` and evaluate once
# on its test portion
set.seed(345)
last_rf_fit <- 
  last_rf_workflow %>% 
  last_fit(splits)
last_rf_fit
## # Resampling results
## # Manual resampling 
## # A tibble: 1 x 6
##   splits                id               .metrics  .notes .predictions .workflow
##   <list>                <chr>            <list>    <list> <list>       <list>   
## 1 <split [37500/12500]> train/test split <tibble ~ <tibb~ <tibble [12~ <workflo~
# Test-set metrics of the final fit
collect_metrics(last_rf_fit)
## # A tibble: 2 x 4
##   .metric  .estimator .estimate .config             
##   <chr>    <chr>          <dbl> <chr>               
## 1 accuracy binary         0.944 Preprocessor1_Model1
## 2 roc_auc  binary         0.925 Preprocessor1_Model1
library(vip)
## 
## Attaching package: 'vip'
## The following object is masked from 'package:utils':
## 
##     vi
# Variable-importance plot (top 20 features) for the final fitted model.
# The importance values exist because `importance = "impurity"` was set on
# the ranger engine. `extract_fit_parsnip()` replaces `pull_workflow_fit()`,
# which is deprecated as of workflows 0.2.3.
last_rf_fit %>% 
  pluck(".workflow", 1) %>%   
  extract_fit_parsnip() %>% 
  vip(num_features = 20)
## Warning: `pull_workflow_fit()` was deprecated in workflows 0.2.3.
## Please use `extract_fit_parsnip()` instead.

utils::sessionInfo()  # record R and package versions for reproducibility
## R version 4.1.1 (2021-08-10)
## Platform: x86_64-w64-mingw32/x64 (64-bit)
## Running under: Windows 10 x64 (build 19043)
## 
## Matrix products: default
## 
## locale:
## [1] LC_COLLATE=English_Germany.1252  LC_CTYPE=English_Germany.1252   
## [3] LC_MONETARY=English_Germany.1252 LC_NUMERIC=C                    
## [5] LC_TIME=English_Germany.1252    
## 
## attached base packages:
## [1] parallel  stats     graphics  grDevices utils     datasets  methods  
## [8] base     
## 
## other attached packages:
##  [1] vip_0.3.2          ranger_0.13.1      vctrs_0.3.8        rlang_0.4.11      
##  [5] doParallel_1.0.16  iterators_1.0.13   foreach_1.5.1      yardstick_0.0.8   
##  [9] workflowsets_0.1.0 workflows_0.2.3    tune_0.1.6         rsample_0.1.0     
## [13] recipes_0.1.16     parsnip_0.1.7      modeldata_0.1.1    infer_1.0.0       
## [17] dials_0.0.9        scales_1.1.1       broom_0.7.9        tidymodels_0.1.3  
## [21] forcats_0.5.1      stringr_1.4.0      purrr_0.3.4        tidyr_1.1.3       
## [25] tibble_3.1.3       ggplot2_3.3.5      tidyverse_1.3.1    dplyr_1.0.7       
## [29] readr_2.0.0       
## 
## loaded via a namespace (and not attached):
##  [1] colorspace_2.0-2   ellipsis_0.3.2     class_7.3-19       fs_1.5.0          
##  [5] rstudioapi_0.13    listenv_0.8.0      furrr_0.2.3        farver_2.1.0      
##  [9] bit64_4.0.5        prodlim_2019.11.13 fansi_0.5.0        lubridate_1.7.10  
## [13] xml2_1.3.2         codetools_0.2-18   splines_4.1.1      knitr_1.33        
## [17] jsonlite_1.7.2     pROC_1.17.0.1      dbplyr_2.1.1       compiler_4.1.1    
## [21] httr_1.4.2         backports_1.2.1    assertthat_0.2.1   Matrix_1.3-4      
## [25] cli_3.0.1          htmltools_0.5.1.1  tools_4.1.1        gtable_0.3.0      
## [29] glue_1.4.2         Rcpp_1.0.7         cellranger_1.1.0   DiceDesign_1.9    
## [33] timeDate_3043.102  gower_0.2.2        xfun_0.25          globals_0.14.0    
## [37] rvest_1.0.1        lifecycle_1.0.0    future_1.21.0      MASS_7.3-54       
## [41] ipred_0.9-11       vroom_1.5.4        hms_1.1.0          yaml_2.2.1        
## [45] curl_4.3.2         gridExtra_2.3      rpart_4.1-15       stringi_1.7.3     
## [49] highr_0.9          lhs_1.1.1          hardhat_0.1.6      lava_1.6.9        
## [53] pkgconfig_2.0.3    evaluate_0.14      lattice_0.20-44    labeling_0.4.2    
## [57] bit_4.0.4          tidyselect_1.1.1   parallelly_1.27.0  plyr_1.8.6        
## [61] magrittr_2.0.1     R6_2.5.0           generics_0.1.0     DBI_1.1.1         
## [65] pillar_1.6.2       haven_2.4.3        withr_2.4.2        survival_3.2-11   
## [69] nnet_7.3-16        modelr_0.1.8       crayon_1.4.1       utf8_1.2.2        
## [73] tzdb_0.1.2         rmarkdown_2.10     grid_4.1.1         readxl_1.3.1      
## [77] reprex_2.0.1       digest_0.6.27      munsell_0.5.0      GPfit_1.0-8

R Markdown

This is an R Markdown document. Markdown is a simple formatting syntax for authoring HTML, PDF, and MS Word documents. For more details on using R Markdown, see <http://rmarkdown.rstudio.com>.

When you click the Knit button a document will be generated that includes both content as well as the output of any embedded R code chunks within the document. You can embed an R code chunk like this:

summary(object = cars)
##      speed           dist       
##  Min.   : 4.0   Min.   :  2.00  
##  1st Qu.:12.0   1st Qu.: 26.00  
##  Median :15.0   Median : 36.00  
##  Mean   :15.4   Mean   : 42.98  
##  3rd Qu.:19.0   3rd Qu.: 56.00  
##  Max.   :25.0   Max.   :120.00

Including Plots

You can also embed plots, for example:

Note that the `echo = FALSE` parameter was added to the code chunk to prevent printing of the R code that generated the plot.