library(tidymodels)
library(ISLR)
library(ggcorrplot)
theme_set(theme_bw())

# Convert ISLR's `Wage` data frame into a tibble. as_tibble() is the dedicated
# conversion; tibble(Wage) only produced the same 11 columns because tibble()
# splices unnamed data-frame arguments.
wage <- as_tibble(Wage)
glimpse(wage)
Rows: 3,000
Columns: 11
$ year <int> 2006, 2004, 2003, 2003, 2005, 2008, 2009, 2008, 2006, 2004,~
$ age <int> 18, 24, 45, 43, 50, 54, 44, 30, 41, 52, 45, 34, 35, 39, 54,~
$ maritl <fct> 1. Never Married, 1. Never Married, 2. Married, 2. Married,~
$ race <fct> 1. White, 1. White, 1. White, 3. Asian, 1. White, 1. White,~
$ education <fct> 1. < HS Grad, 4. College Grad, 3. Some College, 4. College ~
$ region <fct> 2. Middle Atlantic, 2. Middle Atlantic, 2. Middle Atlantic,~
$ jobclass <fct> 1. Industrial, 2. Information, 1. Industrial, 2. Informatio~
$ health <fct> 1. <=Good, 2. >=Very Good, 1. <=Good, 2. >=Very Good, 1. <=~
$ health_ins <fct> 2. No, 2. No, 1. Yes, 1. Yes, 1. Yes, 1. Yes, 1. Yes, 1. Ye~
$ logwage <dbl> 4.318063, 4.255273, 4.875061, 5.041393, 4.318063, 4.845098,~
$ wage <dbl> 75.04315, 70.47602, 130.98218, 154.68529, 75.04315, 127.115~
# Column-wise summaries. Per the printed summary, `region` has a single
# observed level (2. Middle Atlantic), so it cannot act as a predictor here.
wage %>% summary()
year age maritl race
Min. :2003 Min. :18.00 1. Never Married: 648 1. White:2480
1st Qu.:2004 1st Qu.:33.75 2. Married :2074 2. Black: 293
Median :2006 Median :42.00 3. Widowed : 19 3. Asian: 190
Mean :2006 Mean :42.41 4. Divorced : 204 4. Other: 37
3rd Qu.:2008 3rd Qu.:51.00 5. Separated : 55
Max. :2009 Max. :80.00
education region jobclass
1. < HS Grad :268 2. Middle Atlantic :3000 1. Industrial :1544
2. HS Grad :971 1. New England : 0 2. Information:1456
3. Some College :650 3. East North Central: 0
4. College Grad :685 4. West North Central: 0
5. Advanced Degree:426 5. South Atlantic : 0
6. East South Central: 0
(Other) : 0
health health_ins logwage wage
1. <=Good : 858 1. Yes:2083 Min. :3.000 Min. : 20.09
2. >=Very Good:2142 2. No : 917 1st Qu.:4.447 1st Qu.: 85.38
Median :4.653 Median :104.92
Mean :4.654 Mean :111.70
3rd Qu.:4.857 3rd Qu.:128.68
Max. :5.763 Max. :318.34
# Marginal distributions of the response (wage) and of age.
wage %>%
  ggplot(aes(x = wage)) +
  geom_histogram(fill = "lightgrey", color = "black")
wage %>%
  ggplot(aes(x = age)) +
  geom_histogram(fill = "lightgrey", color = "black")
# Pearson correlation test between wage and age (formula interface; the
# reported data label is still "wage and age").
cor.test(~ wage + age, data = wage)
Pearson's product-moment correlation
data: wage and age
t = 10.923, df = 2998, p-value < 2.2e-16
alternative hypothesis: true correlation is not equal to 0
95 percent confidence interval:
0.1609777 0.2298147
sample estimates:
cor
0.1956372
# Raw relationship between age and wage.
wage %>%
  ggplot(aes(age, wage)) +
  geom_point()
# Wage distribution per race. coord_flip() places wage on the vertical axis;
# the fill legend is shown above the panel.
wage %>%
  ggplot(aes(x = wage, fill = race)) +
  geom_boxplot() +
  coord_flip() +
  theme(legend.position = "top")
# Overlaid wage densities, one curve per job class.
wage %>%
  ggplot(aes(x = wage, color = jobclass)) +
  geom_density()
# Shapiro-Wilk normality p-value of wage within each job class. With roughly
# 1,500 rows per class the test is very sensitive to small departures from
# normality, so tiny p-values are unsurprising.
wage %>%
  group_by(jobclass) %>%
  summarise(shapiro_test = shapiro.test(wage)[["p.value"]])
# A tibble: 2 x 2
jobclass shapiro_test
<fct> <dbl>
1 1. Industrial 9.91e-29
2 2. Information 2.90e-33
# Wage histograms coloured by race, with one panel per job class.
wage %>%
  ggplot(aes(x = wage, fill = race)) +
  geom_histogram() +
  facet_wrap(vars(jobclass))
# Two-sample t-test of mean wage between job classes (the non-integer df in
# the output indicates the Welch version). `order` states the direction of the
# difference explicitly (Industrial minus Information) instead of relying on
# the implicit factor-level order.
t_test(
  wage,
  wage ~ jobclass,
  order = c("1. Industrial", "2. Information")
)
# A tibble: 1 x 6
statistic t_df p_value alternative lower_ci upper_ci
<dbl> <dbl> <dbl> <chr> <dbl> <dbl>
1 -11.5 2715. 7.27e-30 two.sided -20.2 -14.3
# Degree-4 polynomial regression of wage on age as a tidymodels workflow:
# the recipe expands age into four polynomial terms (age_poly_1..4) and the
# model is a linear regression fitted with the "lm" engine.
poly_rec <- recipe(wage ~ age, data = wage) %>%
  step_poly(age, degree = 4)

lm_model <- linear_reg() %>%
  set_engine("lm")

poly_wf <- workflow() %>%
  add_model(lm_model) %>%
  add_recipe(poly_rec)

poly_fit <- poly_wf %>% fit(data = wage)
tidy(poly_fit)
# A tibble: 5 x 5
term estimate std.error statistic p.value
<chr> <dbl> <dbl> <dbl> <dbl>
1 (Intercept) 112. 0.729 153. 0
2 age_poly_1 447. 39.9 11.2 1.48e-28
3 age_poly_2 -478. 39.9 -12.0 2.36e-32
4 age_poly_3 126. 39.9 3.14 1.68e- 3
5 age_poly_4 -77.9 39.9 -1.95 5.10e- 2
# Toy example: semi-transparent scatter of wage against age with ggplot2's
# default smoother drawn on top.
ggplot(wage, aes(x = age, y = wage)) +
  geom_point(alpha = 0.2) +
  geom_smooth()
Let us take that one step further and see what happens to the fitted polynomial curve once we go past the domain it was trained on. The previous plot showed individuals within the age range 18-80. Let us see what happens once we extend this to 18-100. Such ages are not impossible, but they are unrealistic for this data.
# Age grid that extends beyond the observed ages (18-80) up to 100.
wide_age_range <- tibble(age = seq(18, 100))

# Point predictions (.pred) plus confidence limits (.pred_lower/.pred_upper)
# from the polynomial fit over the wide grid.
regression_lines <- poly_fit %>%
  augment(new_data = wide_age_range) %>%
  bind_cols(predict(poly_fit, new_data = wide_age_range, type = "conf_int"))

# Observed data in the background, fitted curve in dark green, and the
# confidence limits as dashed blue lines.
ggplot(wage, aes(x = age, y = wage)) +
  geom_point(alpha = 0.2) +
  geom_line(data = regression_lines, aes(y = .pred), color = "darkgreen") +
  geom_line(data = regression_lines, aes(y = .pred_lower),
            linetype = "dashed", color = "blue") +
  geom_line(data = regression_lines, aes(y = .pred_upper),
            linetype = "dashed", color = "blue")
We see that the curve starts to diverge once we move past the observed ages: by age 93 the predicted wage is negative. The confidence bands also get wider and wider as we move farther away from the data.
# Binary outcome for the classification examples: "High" when wage > 250,
# otherwise "Low". Level order is fixed as High, Low.
wage <- wage %>%
  mutate(
    salary = if_else(wage > 250, "High", "Low"),
    salary = factor(salary, levels = c("High", "Low"))
  )
glimpse(wage)
Rows: 3,000
Columns: 12
$ year <int> 2006, 2004, 2003, 2003, 2005, 2008, 2009, 2008, 2006, 2004,~
$ age <int> 18, 24, 45, 43, 50, 54, 44, 30, 41, 52, 45, 34, 35, 39, 54,~
$ maritl <fct> 1. Never Married, 1. Never Married, 2. Married, 2. Married,~
$ race <fct> 1. White, 1. White, 1. White, 3. Asian, 1. White, 1. White,~
$ education <fct> 1. < HS Grad, 4. College Grad, 3. Some College, 4. College ~
$ region <fct> 2. Middle Atlantic, 2. Middle Atlantic, 2. Middle Atlantic,~
$ jobclass <fct> 1. Industrial, 2. Information, 1. Industrial, 2. Informatio~
$ health <fct> 1. <=Good, 2. >=Very Good, 1. <=Good, 2. >=Very Good, 1. <=~
$ health_ins <fct> 2. No, 2. No, 1. Yes, 1. Yes, 1. Yes, 1. Yes, 1. Yes, 1. Ye~
$ logwage <dbl> 4.318063, 4.255273, 4.875061, 5.041393, 4.318063, 4.845098,~
$ wage <dbl> 75.04315, 70.47602, 130.98218, 154.68529, 75.04315, 127.115~
$ salary <fct> Low, Low, Low, Low, Low, Low, Low, Low, Low, Low, Low, Low,~
# Class imbalance! Number of observations per salary class. count() on the
# ungrouped data returns an ungrouped tibble, whereas the previous
# group_by(salary) %>% count() left the result grouped by salary.
wage %>%
  count(salary)
# A tibble: 2 x 2
# Groups: salary [2]
salary n
<fct> <int>
1 High 79
2 Low 2921
# Logistic regression of salary class on a degree-4 polynomial in age.
# glm treats the first factor level ("High") as the failure and models the
# probability of the second level ("Low"), so positive coefficients push
# predictions towards "Low" (consistent with the large positive intercept).
glm_rec <- recipe(salary ~ age, data = wage) %>%
  step_poly(age, degree = 4)

glm_model <- logistic_reg() %>%
  set_engine("glm")

glm_wf <- workflow() %>%
  add_model(glm_model) %>%
  add_recipe(glm_rec)

poly_glm_fit <- glm_wf %>% fit(wage)
tidy(poly_glm_fit)
# A tibble: 5 x 5
term estimate std.error statistic p.value
<chr> <dbl> <dbl> <dbl> <dbl>
1 (Intercept) 4.30 0.345 12.5 1.16e-35
2 age_poly_1 -72.0 26.1 -2.76 5.86e- 3
3 age_poly_2 85.8 35.9 2.39 1.69e- 2
4 age_poly_3 -34.2 19.7 -1.74 8.27e- 2
5 age_poly_4 47.4 24.1 1.97 4.91e- 2
# Hard class predictions on the training data.
poly_glm_fit %>% predict(new_data = wage)
# A tibble: 3,000 x 1
.pred_class
<fct>
1 Low
2 Low
3 Low
4 Low
5 Low
6 Low
7 Low
8 Low
9 Low
10 Low
# ... with 2,990 more rows
# Class probabilities (.pred_High / .pred_Low) on the training data.
poly_glm_fit %>% predict(new_data = wage, type = "prob")
# A tibble: 3,000 x 2
.pred_High .pred_Low
<dbl> <dbl>
1 0.00000000983 1.00
2 0.000120 1.00
3 0.0307 0.969
4 0.0320 0.968
5 0.0305 0.970
6 0.0352 0.965
7 0.0313 0.969
8 0.00820 0.992
9 0.0334 0.967
10 0.0323 0.968
# ... with 2,990 more rows
Next, let us take a look at step functions and how to fit a model using them as a preprocessor. You can create step functions in a couple of different ways. step_discretize() converts a numeric variable into a factor with n bins, where n is specified with num_breaks. The bins are chosen so that each contains approximately the same number of points in the training data set.
# Step-function logistic regression: age is cut into 4 bins with roughly equal
# counts in the training data. The bins enter the glm as indicator terms
# agebin2..agebin4, each relative to the first bin.
discret_rec <- recipe(salary ~ age, data = wage) %>%
  step_discretize(age, num_breaks = 4)

discret_wf <- workflow() %>%
  add_model(glm_model) %>%
  add_recipe(discret_rec)

discret_fit <- discret_wf %>% fit(wage)
tidy(discret_fit)
# A tibble: 4 x 5
term estimate std.error statistic p.value
<chr> <dbl> <dbl> <dbl> <dbl>
1 (Intercept) 5.00 0.449 11.2 6.88e-29
2 agebin2 -1.49 0.498 -3.00 2.73e- 3
3 agebin3 -1.67 0.488 -3.41 6.45e- 4
4 agebin4 -1.71 0.494 -3.45 5.64e- 4
If you already know where you want the step function to break, you can use step_cut() and supply the breaks manually.
# Step-function logistic regression with hand-picked age breaks at 30, 50
# and 70. NOTE(review): the (70,80] term has a huge standard error in the
# tidy() output, which suggests (quasi-)separation, e.g. no "High" salaries in
# that bin. Interpret that coefficient with care.
cut_rec <- recipe(salary ~ age, data = wage) %>%
  step_cut(age, breaks = c(30, 50, 70))

cut_wf <- workflow() %>%
  add_recipe(cut_rec) %>%
  add_model(glm_model)

cut_fit <- cut_wf %>% fit(data = wage)
tidy(cut_fit)
# A tibble: 4 x 5
term estimate std.error statistic p.value
<chr> <dbl> <dbl> <dbl> <dbl>
1 (Intercept) 6.26 1.00 6.25 4.11e-10
2 age(30,50] -2.75 1.01 -2.72 6.62e- 3
3 age(50,70] -3.04 1.02 -2.98 2.88e- 3
4 age(70,80] 10.3 446. 0.0231 9.82e- 1
To fit regression splines, or in other words, to use splines as preprocessors when fitting a linear model, we use step_bs() to construct the matrices of basis functions. It uses the bs() function internally, and arguments such as knots can be passed to bs() by supplying a named list to options. Multiple knots must be given as a single vector, e.g. options = list(knots = c(25, 40, 60)).
# Regression spline: step_bs() builds a B-spline basis for age (via bs()),
# and the basis columns are used as predictors in the linear model.
# Entries of `options` are forwarded to bs() by name, so all knots must be a
# single vector. The former list(knots = 25, 40, 60) set knots = 25 only and
# passed 40 and 60 to bs() as unnamed positional arguments.
spline_rec <- recipe(wage ~ age, data = wage) %>%
  step_bs(age, options = list(knots = c(25, 40, 60)))
spline_wf <- workflow() %>%
  add_model(lm_model) %>%
  add_recipe(spline_rec)
spline_fit <- fit(spline_wf, data = wage)
predict(spline_fit, new_data = wage)
# A tibble: 3,000 x 1
.pred
<dbl>
1 58.7
2 84.3
3 120.
4 120.
5 120.
6 119.
7 120.
8 102.
9 119.
10 120.
# ... with 2,990 more rows
An Introduction to Statistical Learning
– END
# Record the R version, platform and package versions used for this analysis.
utils::sessionInfo()
R version 4.1.0 (2021-05-18)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 19043)
Matrix products: default
locale:
[1] LC_COLLATE=Spanish_Mexico.1252 LC_CTYPE=Spanish_Mexico.1252
[3] LC_MONETARY=Spanish_Mexico.1252 LC_NUMERIC=C
[5] LC_TIME=Spanish_Mexico.1252
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] ggcorrplot_0.1.3 ISLR_1.2 yardstick_0.0.8 workflowsets_0.0.2
[5] workflows_0.2.3 tune_0.1.6 tidyr_1.1.3 tibble_3.1.3
[9] rsample_0.1.0 recipes_0.1.16 purrr_0.3.4 parsnip_0.1.7
[13] modeldata_0.1.1 infer_0.5.4 ggplot2_3.3.5 dplyr_1.0.7
[17] dials_0.0.9 scales_1.1.1 broom_0.7.8 tidymodels_0.1.3
loaded via a namespace (and not attached):
[1] sass_0.4.0 jsonlite_1.7.2 splines_4.1.0 foreach_1.5.1
[5] prodlim_2019.11.13 bslib_0.2.5.1 assertthat_0.2.1 highr_0.9
[9] GPfit_1.0-8 yaml_2.2.1 globals_0.14.0 ipred_0.9-11
[13] pillar_1.6.2 backports_1.2.1 lattice_0.20-44 glue_1.4.2
[17] pROC_1.17.0.1 digest_0.6.27 hardhat_0.1.6 colorspace_2.0-2
[21] plyr_1.8.6 htmltools_0.5.1.1 Matrix_1.3-3 timeDate_3043.102
[25] pkgconfig_2.0.3 lhs_1.1.1 DiceDesign_1.9 listenv_0.8.0
[29] gower_0.2.2 lava_1.6.9 mgcv_1.8-35 farver_2.1.0
[33] generics_0.1.0 ellipsis_0.3.2 withr_2.4.2 furrr_0.2.3
[37] nnet_7.3-16 cli_3.0.0 survival_3.2-11 magrittr_2.0.1
[41] crayon_1.4.1 evaluate_0.14 future_1.21.0 fansi_0.5.0
[45] parallelly_1.26.1 nlme_3.1-152 MASS_7.3-54 class_7.3-19
[49] tools_4.1.0 lifecycle_1.0.0 stringr_1.4.0 munsell_0.5.0
[53] compiler_4.1.0 jquerylib_0.1.4 rlang_0.4.11 grid_4.1.0
[57] rstudioapi_0.13 iterators_1.0.13 labeling_0.4.2 rmarkdown_2.9
[61] gtable_0.3.0 codetools_0.2-18 DBI_1.1.1 R6_2.5.0
[65] lubridate_1.7.10 knitr_1.33 utf8_1.2.2 stringi_1.6.2
[69] parallel_4.1.0 Rcpp_1.0.7 vctrs_0.3.8 rpart_4.1-15
[73] tidyselect_1.1.1 xfun_0.24