About

This is a minimal demonstration of fitting some machine learning models with the new tidymodels package.

Init

options(digits = 3)
library(pacman)
p_load(tidyverse, magrittr, tidymodels, GGally)
theme_set(theme_bw())

iris dataset

Binary logistic regression

A really simple case of classification: we predict whether a flower belongs to a particular species (here, virginica) from the four measurements.

#plot data to see what's up
#https://stackoverflow.com/a/12047554/3980197
ggpairs(iris, aes(colour = Species, alpha = 0.4))
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

#let's make indicator (dummy) columns for each species
iris = cbind(
  iris,
  model.matrix(~ Species - 1, data = iris) %>% 
    set_colnames(colnames(.) %>% str_replace("Species", "")) %>% 
    as_tibble() %>% 
    map_df(as.factor)
) %>% as_tibble()
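Since only the virginica column is used below, here is a simpler sketch that builds just that one indicator (note the levels come out as FALSE/TRUE rather than 0/1; iris_alt is just an illustrative name):

#simpler alternative: build only the indicator we actually use
iris_alt <- iris %>% mutate(virginica = factor(Species == "virginica"))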

#print data
iris
#outcome split
iris$Species %>% table()
## .
##     setosa versicolor  virginica 
##         50         50         50
#make a recipe
iris_recipe <- 
  recipe(virginica ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, data = iris)

#make a model
iris_model <- 
  logistic_reg() %>% 
  set_engine("glm")

#resampling method
set.seed(1)
iris_folds <- vfold_cv(iris, v = 10)

#make workflow
iris_wf <- 
  workflow() %>% 
  add_model(iris_model) %>% 
  add_recipe(iris_recipe)

#fit
iris_fit <- 
  iris_wf %>% 
  fit_resamples(iris_folds)
## ! Fold01: model: glm.fit: fitted probabilities numerically 0 or 1 occurred
## ! Fold02: model: glm.fit: fitted probabilities numerically 0 or 1 occurred
## ! Fold03: model: glm.fit: fitted probabilities numerically 0 or 1 occurred
## ! Fold04: model: glm.fit: fitted probabilities numerically 0 or 1 occurred
## ! Fold05: model: glm.fit: algorithm did not converge, glm.fit: fitted probabilitie...
## ! Fold06: model: glm.fit: fitted probabilities numerically 0 or 1 occurred
## ! Fold07: model: glm.fit: algorithm did not converge, glm.fit: fitted probabilitie...
## ! Fold08: model: glm.fit: fitted probabilities numerically 0 or 1 occurred
## ! Fold09: model: glm.fit: fitted probabilities numerically 0 or 1 occurred
## ! Fold10: model: glm.fit: fitted probabilities numerically 0 or 1 occurred
#metrics
collect_metrics(iris_fit)
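For completeness, a minimal sketch (not part of the resampling flow above) that fits the workflow once on the full data and inspects in-sample class probabilities; iris_full_fit is just an illustrative name:

#fit once on the full data and look at in-sample class probabilities
iris_full_fit <- fit(iris_wf, data = iris)
predict(iris_full_fit, new_data = iris, type = "prob") %>% head()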

Multinomial regression

As above, but now we want to predict every species level, not a single contrast. It took a bit of searching through the list of models supported by parsnip to find the right function; it turns out to be multinom_reg().

#make a recipe
#same as above except that we swap the outcome to Species
iris2_recipe <- 
  recipe(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, data = iris)

#make a model
#swap to multinom_reg() + nnet
iris2_model <- 
  multinom_reg() %>% 
  set_engine("nnet")

#resampling method
#actually the same as above
set.seed(1)
iris2_folds <- vfold_cv(iris, v = 10)

#make workflow
iris2_wf <- 
  workflow() %>% 
  add_model(iris2_model) %>% 
  add_recipe(iris2_recipe)

#fit
iris2_fit <- 
  iris2_wf %>% #note: the multinomial workflow, not the binary iris_wf
  fit_resamples(iris2_folds)
#metrics
collect_metrics(iris2_fit)

So we find that this performs slightly worse when we have to predict each species, not just whether the species is virginica. This is expected. Looking at the plot, the red species (setosa) is very easy to separate even by eye, so any proper model should classify it with ~100% accuracy. The other two species are harder to tell apart, but we still do well: about 97% accuracy versus the ~33% expected by chance (the dataset is evenly split).
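To check the per-species claim directly, here is a sketch that refits while saving the held-out predictions (the fit above did not pass save_pred, so collect_predictions() would fail on it) and cross-tabulates them:

#refit, saving the out-of-fold predictions this time
iris2_fit_pred <- 
  iris2_wf %>% 
  fit_resamples(iris2_folds, control = control_resamples(save_pred = TRUE))

#confusion matrix across all folds: setosa should be near-perfect
iris2_fit_pred %>% 
  collect_predictions() %>% 
  conf_mat(truth = Species, estimate = .pred_class)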

MMPI items to predict intelligence in Vietnam Experience Study dataset

Let’s try something a little more complicated: redoing this analysis I previously did with the earlier meta-modeling package, caret. In this situation, we have data from about 4.5k US veterans on the MMPI items, some 550 items asking about various aspects of psychopathology. We also have general intelligence as measured by the common factor of ~18 diverse cognitive tests.

#read data from file
#you can get it here: https://osf.io/dbn4k/
ves_mmpi = read_rds("data/d_nomiss2.rds") %>% as_tibble()

#print data
ves_mmpi
#make a recipe
ves_recipe <- 
  recipe(g ~ ., data = ves_mmpi)

#make a model
#use glmnet
ves_model <- 
  linear_reg(
    penalty = tune(), #tune this
    mixture = 0 #use ridge
  ) %>% 
  set_engine("glmnet")

#resampling method
set.seed(1)
ves_folds <- vfold_cv(ves_mmpi, v = 10)

#make workflow
ves_wf <- 
  workflow() %>% 
  add_model(ves_model) %>% 
  add_recipe(ves_recipe)

#fit
ves_fit <- 
  ves_wf %>% 
  tune_grid(
    resamples = ves_folds,
    grid = 25,
    control = control_grid(
      save_pred = TRUE
    )
  )

#metrics
collect_metrics(ves_fit) %>% 
  filter(.metric == "rsq") %>% 
  arrange(desc(mean)) #best-performing penalty values first
collect_metrics(ves_fit) %>% 
  ggplot(aes(penalty, mean)) +
  geom_line() +
  scale_x_log10() + #the default penalty grid is log-spaced
  facet_wrap(".metric")

#plot out of sample predictions
ves_fit %>% 
  collect_predictions(parameters = select_best(ves_fit, metric = "rsq")) %>% 
  ggplot(aes(.pred, g)) +
  geom_point() +
  geom_smooth() +
  ggtitle("Prediction of general intelligence from MMPI items",
          str_glue("Out of sample predictions, n = {nrow(ves_mmpi)}, model = ridge regression, r = {collect_metrics(ves_fit) %>% filter(.metric == 'rsq') %>%  arrange(.metric) %>% pull(mean) %>% head(1) %>% sqrt() %>% round(3)}"))
## `geom_smooth()` using method = 'gam' and formula 'y ~ s(x, bs = "cs")'
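Having settled on a penalty, one would normally finalize the workflow and fit it once on the full data. A minimal sketch (ves_best and ves_final_fit are just illustrative names):

#plug the best penalty into the workflow and fit on all the data
ves_best <- select_best(ves_fit, metric = "rsq")
ves_final_fit <- 
  ves_wf %>% 
  finalize_workflow(ves_best) %>% 
  fit(data = ves_mmpi)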

Checking against the prior results, we found an out-of-sample correlation of .83 in the caret analysis; within rounding error, this is the same value we find here. We also see that we don’t need much penalty, indicating that ridge regression is not helping much here compared to plain OLS.
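To check that last claim directly, a sketch that runs an unpenalized linear model through the same folds (assuming the objects above are still in memory; ves_ols_wf is just an illustrative name):

#same recipe and folds, but plain OLS instead of glmnet
ves_ols_wf <- 
  workflow() %>% 
  add_model(linear_reg() %>% set_engine("lm")) %>% 
  add_recipe(ves_recipe)

ves_ols_wf %>% 
  fit_resamples(ves_folds) %>% 
  collect_metrics()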

Meta

sessionInfo()
## R version 4.0.0 (2020-04-24)
## Platform: x86_64-pc-linux-gnu (64-bit)
## Running under: Linux Mint 19.3
## 
## Matrix products: default
## BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.7.1
## LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.7.1
## 
## locale:
##  [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
##  [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
##  [5] LC_MONETARY=de_DE.UTF-8    LC_MESSAGES=en_US.UTF-8   
##  [7] LC_PAPER=de_DE.UTF-8       LC_NAME=C                 
##  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
## [11] LC_MEASUREMENT=de_DE.UTF-8 LC_IDENTIFICATION=C       
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] GGally_1.5.0     yardstick_0.0.6  workflows_0.1.1  tune_0.1.0      
##  [5] rsample_0.0.6    recipes_0.1.12   parsnip_0.1.1    infer_0.5.1     
##  [9] dials_0.0.6      scales_1.1.1     broom_0.5.6      tidymodels_0.1.0
## [13] magrittr_1.5     forcats_0.5.0    stringr_1.4.0    dplyr_0.8.5     
## [17] purrr_0.3.4      readr_1.3.1      tidyr_1.0.3      tibble_3.0.1    
## [21] ggplot2_3.3.0    tidyverse_1.3.0  pacman_0.5.1    
## 
## loaded via a namespace (and not attached):
##   [1] readxl_1.3.1        backports_1.1.7     tidytext_0.2.4     
##   [4] plyr_1.8.6          igraph_1.2.5        splines_4.0.0      
##   [7] crosstalk_1.1.0.1   listenv_0.8.0       SnowballC_0.7.0    
##  [10] rstantools_2.0.0    inline_0.3.15       digest_0.6.25      
##  [13] foreach_1.5.0       htmltools_0.4.0     rsconnect_0.8.16   
##  [16] fansi_0.4.1         globals_0.12.5      modelr_0.1.7       
##  [19] gower_0.2.1         matrixStats_0.56.0  xts_0.12-0         
##  [22] hardhat_0.1.2       prettyunits_1.1.1   colorspace_1.4-1   
##  [25] rvest_0.3.5         haven_2.2.0         xfun_0.13          
##  [28] callr_3.4.3         crayon_1.3.4        jsonlite_1.6.1     
##  [31] lme4_1.1-23         iterators_1.0.12    survival_3.1-12    
##  [34] zoo_1.8-8           glue_1.4.1          gtable_0.3.0       
##  [37] ipred_0.9-9         pkgbuild_1.0.8      rstan_2.19.3       
##  [40] shape_1.4.4         DBI_1.1.0           miniUI_0.1.1.1     
##  [43] Rcpp_1.0.4.6        xtable_1.8-4        GPfit_1.0-8        
##  [46] stats4_4.0.0        lava_1.6.7          StanHeaders_2.19.2 
##  [49] prodlim_2019.11.13  DT_0.13             glmnet_4.0         
##  [52] htmlwidgets_1.5.1   httr_1.4.1          threejs_0.3.3      
##  [55] RColorBrewer_1.1-2  ellipsis_0.3.1      farver_2.0.3       
##  [58] reshape_0.8.8       pkgconfig_2.0.3     loo_2.2.0          
##  [61] nnet_7.3-14         dbplyr_1.4.3        labeling_0.3       
##  [64] tidyselect_1.1.0    rlang_0.4.6         DiceDesign_1.8-1   
##  [67] reshape2_1.4.4      later_1.0.0         munsell_0.5.0      
##  [70] cellranger_1.1.0    tools_4.0.0         cli_2.0.2          
##  [73] generics_0.0.2      ggridges_0.5.2      evaluate_0.14      
##  [76] fastmap_1.0.1       yaml_2.2.1          processx_3.4.2     
##  [79] knitr_1.28          fs_1.4.1            future_1.17.0      
##  [82] nlme_3.1-147        mime_0.9            rstanarm_2.19.3    
##  [85] xml2_1.3.2          tokenizers_0.2.1    compiler_4.0.0     
##  [88] bayesplot_1.7.1     shinythemes_1.1.2   rstudioapi_0.11    
##  [91] reprex_0.3.0        tidyposterior_0.0.2 lhs_1.0.2          
##  [94] statmod_1.4.34      stringi_1.4.6       ps_1.3.3           
##  [97] lattice_0.20-41     Matrix_1.2-18       nloptr_1.2.2.1     
## [100] markdown_1.1        shinyjs_1.1         vctrs_0.3.0        
## [103] pillar_1.4.4        lifecycle_0.2.0     furrr_0.1.0        
## [106] httpuv_1.5.2        R6_2.4.1            promises_1.1.0     
## [109] gridExtra_2.3       janeaustenr_0.1.5   codetools_0.2-16   
## [112] boot_1.3-25         colourpicker_1.0    MASS_7.3-51.6      
## [115] gtools_3.8.2        assertthat_0.2.1    withr_2.2.0        
## [118] shinystan_2.5.0     mgcv_1.8-31         parallel_4.0.0     
## [121] hms_0.5.3           grid_4.0.0          rpart_4.1-15       
## [124] timeDate_3043.102   minqa_1.2.4         class_7.3-17       
## [127] rmarkdown_2.1       pROC_1.16.2         tidypredict_0.4.5  
## [130] shiny_1.4.0.2       lubridate_1.7.8     base64enc_0.1-3    
## [133] dygraphs_1.1.1.6