An Introduction to Statistical Learning (2nd ed.)

Chapter 07

Moving Beyond Linearity

library(tidymodels)
library(ISLR)
library(ggcorrplot)

# Use the same black-and-white ggplot2 theme for every plot in this document.
theme_set(theme_bw())

# Convert ISLR's Wage data.frame to a tibble. as_tibble() is the explicit
# conversion; tibble(Wage) only worked because tibble() splices unnamed
# data-frame arguments into columns.
wage <- as_tibble(Wage)
glimpse(wage)
Rows: 3,000
Columns: 11
$ year       <int> 2006, 2004, 2003, 2003, 2005, 2008, 2009, 2008, 2006, 2004,~
$ age        <int> 18, 24, 45, 43, 50, 54, 44, 30, 41, 52, 45, 34, 35, 39, 54,~
$ maritl     <fct> 1. Never Married, 1. Never Married, 2. Married, 2. Married,~
$ race       <fct> 1. White, 1. White, 1. White, 3. Asian, 1. White, 1. White,~
$ education  <fct> 1. < HS Grad, 4. College Grad, 3. Some College, 4. College ~
$ region     <fct> 2. Middle Atlantic, 2. Middle Atlantic, 2. Middle Atlantic,~
$ jobclass   <fct> 1. Industrial, 2. Information, 1. Industrial, 2. Informatio~
$ health     <fct> 1. <=Good, 2. >=Very Good, 1. <=Good, 2. >=Very Good, 1. <=~
$ health_ins <fct> 2. No, 2. No, 1. Yes, 1. Yes, 1. Yes, 1. Yes, 1. Yes, 1. Ye~
$ logwage    <dbl> 4.318063, 4.255273, 4.875061, 5.041393, 4.318063, 4.845098,~
$ wage       <dbl> 75.04315, 70.47602, 130.98218, 154.68529, 75.04315, 127.115~
# Per-column summary: level counts for factors, quartiles/mean for numerics.
wage %>% summary()
      year           age                     maritl           race     
 Min.   :2003   Min.   :18.00   1. Never Married: 648   1. White:2480  
 1st Qu.:2004   1st Qu.:33.75   2. Married      :2074   2. Black: 293  
 Median :2006   Median :42.00   3. Widowed      :  19   3. Asian: 190  
 Mean   :2006   Mean   :42.41   4. Divorced     : 204   4. Other:  37  
 3rd Qu.:2008   3rd Qu.:51.00   5. Separated    :  55                  
 Max.   :2009   Max.   :80.00                                          
                                                                       
              education                     region               jobclass   
 1. < HS Grad      :268   2. Middle Atlantic   :3000   1. Industrial :1544  
 2. HS Grad        :971   1. New England       :   0   2. Information:1456  
 3. Some College   :650   3. East North Central:   0                        
 4. College Grad   :685   4. West North Central:   0                        
 5. Advanced Degree:426   5. South Atlantic    :   0                        
                          6. East South Central:   0                        
                          (Other)              :   0                        
            health      health_ins      logwage           wage       
 1. <=Good     : 858   1. Yes:2083   Min.   :3.000   Min.   : 20.09  
 2. >=Very Good:2142   2. No : 917   1st Qu.:4.447   1st Qu.: 85.38  
                                     Median :4.653   Median :104.92  
                                     Mean   :4.654   Mean   :111.70  
                                     3rd Qu.:4.857   3rd Qu.:128.68  
                                     Max.   :5.763   Max.   :318.34  
                                                                     
# Marginal distribution of the response, wage.
wage %>%
  ggplot(aes(x = wage)) +
  geom_histogram(fill = "lightgrey", color = "black")

# Marginal distribution of the predictor, age.
wage %>%
  ggplot(aes(x = age)) +
  geom_histogram(fill = "lightgrey", color = "black")

# Pearson correlation test of wage vs. age. The formula interface reports the
# data as "wage and age", just like with(wage, cor.test(wage, age)).
cor.test(~ wage + age, data = wage)

    Pearson's product-moment correlation

data:  wage and age
t = 10.923, df = 2998, p-value < 2.2e-16
alternative hypothesis: true correlation is not equal to 0
95 percent confidence interval:
 0.1609777 0.2298147
sample estimates:
      cor 
0.1956372 
# Raw relationship between age and wage.
wage %>%
  ggplot(aes(x = age, y = wage)) +
  geom_point()

# Wage boxplots for each race; coord_flip() swaps the x and y axes.
wage %>%
  ggplot(aes(x = wage, fill = race)) +
  geom_boxplot() +
  coord_flip() +
  theme(legend.position = "top")

# Wage density curves, one per job class.
wage %>%
  ggplot(aes(x = wage, color = jobclass)) +
  geom_density()

# Shapiro-Wilk normality test of wage within each job class; keep the p-value.
wage %>%
  group_by(jobclass) %>%
  summarise(shapiro_test = shapiro.test(wage)[["p.value"]])
# A tibble: 2 x 2
  jobclass       shapiro_test
  <fct>                 <dbl>
1 1. Industrial      9.91e-29
2 2. Information     2.90e-33
# Wage histograms split into one panel per job class, bars filled by race.
wage %>%
  ggplot(aes(x = wage, fill = race)) +
  geom_histogram() +
  facet_wrap(vars(jobclass))

# Two-sample t-test of mean wage between job classes (unequal variances, as
# the non-integer t_df in the output shows). `order` is stated explicitly so
# the difference is always "1. Industrial" - "2. Information" instead of
# being left to infer's default level order.
t_test(
  wage,
  formula = wage ~ jobclass,
  order = c("1. Industrial", "2. Information")
)
# A tibble: 1 x 6
  statistic  t_df  p_value alternative lower_ci upper_ci
      <dbl> <dbl>    <dbl> <chr>          <dbl>    <dbl>
1     -11.5 2715. 7.27e-30 two.sided      -20.2    -14.3

Polynomial Regression and Step Functions

# Ordinary least-squares model specification (lm engine); also reused by the
# spline workflow further down.
lm_model <- linear_reg() %>%
  set_engine("lm")

# Preprocessor: expand age into a degree-4 polynomial basis (age_poly_1..4).
poly_rec <- recipe(wage ~ age, data = wage) %>%
  step_poly(age, degree = 4)

# Combine model and preprocessor, then fit on the full data set.
poly_wf <- workflow() %>%
  add_model(lm_model) %>%
  add_recipe(poly_rec)
poly_fit <- poly_wf %>% fit(data = wage)

poly_fit %>% tidy()
# A tibble: 5 x 5
  term        estimate std.error statistic  p.value
  <chr>          <dbl>     <dbl>     <dbl>    <dbl>
1 (Intercept)    112.      0.729    153.   0       
2 age_poly_1     447.     39.9       11.2  1.48e-28
3 age_poly_2    -478.     39.9      -12.0  2.36e-32
4 age_poly_3     126.     39.9        3.14 1.68e- 3
5 age_poly_4     -77.9    39.9       -1.95 5.10e- 2
# Scatter plot of wage vs. age with a degree-4 polynomial least-squares fit,
# i.e. the same model as `poly_fit` (step_poly() builds its basis with poly()).
# A bare geom_smooth() would instead draw its default smoother, which is not
# the regression line discussed below.
wage %>%
  ggplot(aes(age, wage)) +
  geom_point(alpha = .2) +
  geom_smooth(method = "lm", formula = y ~ poly(x, 4))

Let us take that one step further and see what happens to the regression line once we go past the domain it was trained on. The previous plot showed individuals within the age range 18-80. Let us see what happens once we push this to 18-100. This range is not impossible, but it is unrealistic.

# Age grid that extends past the observed 18-80 range, to show extrapolation.
wide_age_range <- tibble(age = seq(18, 100))

# Point predictions (.pred) next to confidence limits (.pred_lower and
# .pred_upper) for every age on the grid.
regression_lines <-
  augment(poly_fit, new_data = wide_age_range) %>%
  bind_cols(predict(poly_fit, new_data = wide_age_range, type = "conf_int"))

# Observed data, fitted curve (green) and confidence limits (dashed blue).
ggplot(wage, aes(x = age, y = wage)) +
  geom_point(alpha = 0.2) +
  geom_line(data = regression_lines, aes(y = .pred),
            color = "darkgreen") +
  geom_line(data = regression_lines, aes(y = .pred_lower),
            linetype = "dashed", color = "blue") +
  geom_line(data = regression_lines, aes(y = .pred_upper),
            linetype = "dashed", color = "blue")

We see that the curve starts to diverge once we move past the observed ages: by age 93, the predicted wage is negative. The confidence bands also get wider and wider as we get farther away from the data.

As a classification problem

# Binary outcome for the classification examples: "High" when wage > 250,
# otherwise "Low". "High" is placed first in the factor levels.
wage <- wage %>%
  mutate(
    salary = factor(
      ifelse(wage > 250, "High", "Low"),
      levels = c("High", "Low")
    )
  )

glimpse(wage)
Rows: 3,000
Columns: 12
$ year       <int> 2006, 2004, 2003, 2003, 2005, 2008, 2009, 2008, 2006, 2004,~
$ age        <int> 18, 24, 45, 43, 50, 54, 44, 30, 41, 52, 45, 34, 35, 39, 54,~
$ maritl     <fct> 1. Never Married, 1. Never Married, 2. Married, 2. Married,~
$ race       <fct> 1. White, 1. White, 1. White, 3. Asian, 1. White, 1. White,~
$ education  <fct> 1. < HS Grad, 4. College Grad, 3. Some College, 4. College ~
$ region     <fct> 2. Middle Atlantic, 2. Middle Atlantic, 2. Middle Atlantic,~
$ jobclass   <fct> 1. Industrial, 2. Information, 1. Industrial, 2. Informatio~
$ health     <fct> 1. <=Good, 2. >=Very Good, 1. <=Good, 2. >=Very Good, 1. <=~
$ health_ins <fct> 2. No, 2. No, 1. Yes, 1. Yes, 1. Yes, 1. Yes, 1. Yes, 1. Ye~
$ logwage    <dbl> 4.318063, 4.255273, 4.875061, 5.041393, 4.318063, 4.845098,~
$ wage       <dbl> 75.04315, 70.47602, 130.98218, 154.68529, 75.04315, 127.115~
$ salary     <fct> Low, Low, Low, Low, Low, Low, Low, Low, Low, Low, Low, Low,~
# Class imbalance! Very few workers fall in the "High" salary class.
# count(salary) tallies rows per level and returns an ungrouped tibble;
# group_by(salary) %>% count() would leave the result grouped by salary.
wage %>%
  count(salary)
# A tibble: 2 x 2
# Groups:   salary [2]
  salary     n
  <fct>  <int>
1 High      79
2 Low     2921
# Logistic-regression specification using the glm engine; also reused by the
# step-function workflows below.
glm_model <- logistic_reg() %>%
  set_engine("glm")

# Preprocessor: degree-4 polynomial expansion of age, predicting salary.
glm_rec <- recipe(salary ~ age, data = wage) %>%
  step_poly(age, degree = 4)

glm_wf <- workflow() %>%
  add_model(glm_model) %>%
  add_recipe(glm_rec)

# NOTE(review): per ?glm, the first factor level ("High") is treated as
# failure, so these coefficients describe the log-odds of salary == "Low".
poly_glm_fit <- glm_wf %>% fit(data = wage)
tidy(poly_glm_fit)
# A tibble: 5 x 5
  term        estimate std.error statistic  p.value
  <chr>          <dbl>     <dbl>     <dbl>    <dbl>
1 (Intercept)     4.30     0.345     12.5  1.16e-35
2 age_poly_1    -72.0     26.1       -2.76 5.86e- 3
3 age_poly_2     85.8     35.9        2.39 1.69e- 2
4 age_poly_3    -34.2     19.7       -1.74 8.27e- 2
5 age_poly_4     47.4     24.1        1.97 4.91e- 2
poly_glm_fit %>% predict(new_data = wage)  # hard class predictions
# A tibble: 3,000 x 1
   .pred_class
   <fct>      
 1 Low        
 2 Low        
 3 Low        
 4 Low        
 5 Low        
 6 Low        
 7 Low        
 8 Low        
 9 Low        
10 Low        
# ... with 2,990 more rows
poly_glm_fit %>% predict(new_data = wage, type = "prob")  # class probabilities
# A tibble: 3,000 x 2
      .pred_High .pred_Low
           <dbl>     <dbl>
 1 0.00000000983     1.00 
 2 0.000120          1.00 
 3 0.0307            0.969
 4 0.0320            0.968
 5 0.0305            0.970
 6 0.0352            0.965
 7 0.0313            0.969
 8 0.00820           0.992
 9 0.0334            0.967
10 0.0323            0.968
# ... with 2,990 more rows

Next, let us take a look at step functions and how to fit a model using one as a preprocessor. You can create step functions in a couple of different ways. step_discretize() converts a numeric variable into a factor with n bins, where n is specified with num_breaks. The bins will have approximately the same number of points in them, according to the training data set.

# Step-function preprocessor: cut age into 4 bins with roughly equal numbers
# of training rows. The first bin is the reference level (absorbed into the
# intercept), so only agebin2..agebin4 get coefficients.
discret_rec <- recipe(salary ~ age, data = wage) %>%
  step_discretize(age, num_breaks = 4)

discret_wf <- workflow() %>%
  add_model(glm_model) %>%
  add_recipe(discret_rec)

discret_fit <- discret_wf %>% fit(data = wage)
tidy(discret_fit)
# A tibble: 4 x 5
  term        estimate std.error statistic  p.value
  <chr>          <dbl>     <dbl>     <dbl>    <dbl>
1 (Intercept)     5.00     0.449     11.2  6.88e-29
2 agebin2        -1.49     0.498     -3.00 2.73e- 3
3 agebin3        -1.67     0.488     -3.41 6.45e- 4
4 agebin4        -1.71     0.494     -3.45 5.64e- 4

If you already know where you want the step function to break, you can use step_cut() and supply the breaks manually.

# Step function with hand-picked cut points at ages 30, 50 and 70; the lowest
# interval acts as the reference level.
cut_rec <-
  recipe(salary ~ age, data = wage) %>%
  step_cut(age, breaks = c(30, 50, 70))

cut_wf <-
  workflow() %>%
  add_recipe(cut_rec) %>%
  add_model(glm_model)

cut_fit <- cut_wf %>% fit(data = wage)
cut_fit %>% tidy()
# A tibble: 4 x 5
  term        estimate std.error statistic  p.value
  <chr>          <dbl>     <dbl>     <dbl>    <dbl>
1 (Intercept)     6.26      1.00    6.25   4.11e-10
2 age(30,50]     -2.75      1.01   -2.72   6.62e- 3
3 age(50,70]     -3.04      1.02   -2.98   2.88e- 3
4 age(70,80]     10.3     446.      0.0231 9.82e- 1

Splines

In order to fit regression splines, or in other words, use splines as preprocessors when fitting a linear model, we use step_bs() to construct the matrices of basis functions. Under the hood the bs() function is used, and arguments such as knots can be passed to bs() by supplying a named list to options (e.g. `options = list(knots = c(25, 40, 60))`).

# Regression spline on age: step_bs() builds a B-spline basis, forwarding the
# named list in `options` to splines::bs(). The knots must be ONE numeric
# vector. With list(knots = 25, 40, 60) only 25 is named `knots`; 40 and 60
# would become unnamed entries passed to bs() positionally.
spline_rec <- recipe(wage ~ age, data = wage) %>%
  step_bs(age, options = list(knots = c(25, 40, 60)))

spline_wf <- workflow() %>%
  add_model(lm_model) %>%
  add_recipe(spline_rec)

spline_fit <- fit(spline_wf, data = wage)

predict(spline_fit, new_data = wage)
# A tibble: 3,000 x 1
   .pred
   <dbl>
 1  58.7
 2  84.3
 3 120. 
 4 120. 
 5 120. 
 6 119. 
 7 120. 
 8 102. 
 9 119. 
10 120. 
# ... with 2,990 more rows

An Introduction to Statistical Learning

ISLR tidymodels Labs

– END

utils::sessionInfo()  # record R and package versions used for this render
R version 4.1.0 (2021-05-18)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 19043)

Matrix products: default

locale:
[1] LC_COLLATE=Spanish_Mexico.1252  LC_CTYPE=Spanish_Mexico.1252   
[3] LC_MONETARY=Spanish_Mexico.1252 LC_NUMERIC=C                   
[5] LC_TIME=Spanish_Mexico.1252    

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] ggcorrplot_0.1.3   ISLR_1.2           yardstick_0.0.8    workflowsets_0.0.2
 [5] workflows_0.2.3    tune_0.1.6         tidyr_1.1.3        tibble_3.1.3      
 [9] rsample_0.1.0      recipes_0.1.16     purrr_0.3.4        parsnip_0.1.7     
[13] modeldata_0.1.1    infer_0.5.4        ggplot2_3.3.5      dplyr_1.0.7       
[17] dials_0.0.9        scales_1.1.1       broom_0.7.8        tidymodels_0.1.3  

loaded via a namespace (and not attached):
 [1] sass_0.4.0         jsonlite_1.7.2     splines_4.1.0      foreach_1.5.1     
 [5] prodlim_2019.11.13 bslib_0.2.5.1      assertthat_0.2.1   highr_0.9         
 [9] GPfit_1.0-8        yaml_2.2.1         globals_0.14.0     ipred_0.9-11      
[13] pillar_1.6.2       backports_1.2.1    lattice_0.20-44    glue_1.4.2        
[17] pROC_1.17.0.1      digest_0.6.27      hardhat_0.1.6      colorspace_2.0-2  
[21] plyr_1.8.6         htmltools_0.5.1.1  Matrix_1.3-3       timeDate_3043.102 
[25] pkgconfig_2.0.3    lhs_1.1.1          DiceDesign_1.9     listenv_0.8.0     
[29] gower_0.2.2        lava_1.6.9         mgcv_1.8-35        farver_2.1.0      
[33] generics_0.1.0     ellipsis_0.3.2     withr_2.4.2        furrr_0.2.3       
[37] nnet_7.3-16        cli_3.0.0          survival_3.2-11    magrittr_2.0.1    
[41] crayon_1.4.1       evaluate_0.14      future_1.21.0      fansi_0.5.0       
[45] parallelly_1.26.1  nlme_3.1-152       MASS_7.3-54        class_7.3-19      
[49] tools_4.1.0        lifecycle_1.0.0    stringr_1.4.0      munsell_0.5.0     
[53] compiler_4.1.0     jquerylib_0.1.4    rlang_0.4.11       grid_4.1.0        
[57] rstudioapi_0.13    iterators_1.0.13   labeling_0.4.2     rmarkdown_2.9     
[61] gtable_0.3.0       codetools_0.2-18   DBI_1.1.1          R6_2.5.0          
[65] lubridate_1.7.10   knitr_1.33         utf8_1.2.2         stringi_1.6.2     
[69] parallel_4.1.0     Rcpp_1.0.7         vctrs_0.3.8        rpart_4.1-15      
[73] tidyselect_1.1.1   xfun_0.24