library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.4.1 ──
## ✔ broom 1.0.10 ✔ recipes 1.3.1
## ✔ dials 1.4.2 ✔ rsample 1.3.1
## ✔ dplyr 1.1.4 ✔ tailor 0.1.0
## ✔ ggplot2 4.0.0 ✔ tidyr 1.3.1
## ✔ infer 1.0.9 ✔ tune 2.0.1
## ✔ modeldata 1.5.1 ✔ workflows 1.3.0
## ✔ parsnip 1.3.3 ✔ workflowsets 1.1.1
## ✔ purrr 1.2.0 ✔ yardstick 1.3.2
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ purrr::discard() masks scales::discard()
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ✖ recipes::step() masks stats::step()
# Load the College data set shipped with ISLR2 and preview its first rows.
applications <- ISLR2::College
applications %>%
  head()
## Private Apps Accept Enroll Top10perc Top25perc
## Abilene Christian University Yes 1660 1232 721 23 52
## Adelphi University Yes 2186 1924 512 16 29
## Adrian College Yes 1428 1097 336 22 50
## Agnes Scott College Yes 417 349 137 60 89
## Alaska Pacific University Yes 193 146 55 16 44
## Albertson College Yes 587 479 158 38 62
## F.Undergrad P.Undergrad Outstate Room.Board Books
## Abilene Christian University 2885 537 7440 3300 450
## Adelphi University 2683 1227 12280 6450 750
## Adrian College 1036 99 11250 3750 400
## Agnes Scott College 510 63 12960 5450 450
## Alaska Pacific University 249 869 7560 4120 800
## Albertson College 678 41 13500 3335 500
## Personal PhD Terminal S.F.Ratio perc.alumni Expend
## Abilene Christian University 2200 70 78 18.1 12 7041
## Adelphi University 1500 29 30 12.2 16 10527
## Adrian College 1165 53 66 12.9 30 8735
## Agnes Scott College 875 92 97 7.7 37 19016
## Alaska Pacific University 1500 76 72 11.9 2 10922
## Albertson College 675 67 73 9.4 11 9727
## Grad.Rate
## Abilene Christian University 60
## Adelphi University 56
## Adrian College 54
## Agnes Scott College 59
## Alaska Pacific University 15
## Albertson College 55
# Correlation plot of every column except `Private`, which is dropped before
# computing the correlation matrix.
applications %>%
  select(!Private) %>%
  cor() %>%
  corrplot::corrplot()
# Model specifications: plain linear regression plus lasso (mixture = 1) and
# ridge (mixture = 0) fitted with glmnet, with the penalty left to be tuned.
lm <- parsnip::linear_reg()

# When no lambda path is supplied, glmnet builds its own, and penalties that
# fall below the smallest value on that path are all predicted from the same
# (smallest-lambda) fit. The candidate penalties are therefore handed to
# glmnet explicitly through `path_values`, so each one is actually evaluated.
# The sequences span the same log10 ranges as the tuning grids used later in
# this script (-6..3 for lasso, -10..3 for ridge, 100 values each).
lasso_path <- 10^seq(-6, 3, length.out = 100)
ridge_path <- 10^seq(-10, 3, length.out = 100)

lasso <- parsnip::linear_reg(mixture = 1.0, penalty = tune()) %>%
  set_engine(engine = "glmnet", standardize = TRUE, path_values = lasso_path)
ridge <- parsnip::linear_reg(mixture = 0.0, penalty = tune()) %>%
  set_engine(engine = "glmnet", standardize = TRUE, path_values = ridge_path)
# Preprocessing recipe: model `Apps` on all other columns, dummy-encode the
# nominal predictors, then normalize the numeric predictors.
recipe <- recipes::recipe(Apps ~ ., data = applications)
recipe <- recipes::step_dummy(recipe, all_nominal_predictors())
recipe <- recipes::step_normalize(recipe, all_numeric_predictors())
# One workflow per model; all three share the same preprocessing recipe.
wf_lm <- workflow(preprocessor = recipe, spec = lm)
wf_ridge <- workflow(preprocessor = recipe, spec = ridge)
wf_lasso <- workflow(preprocessor = recipe, spec = lasso)
# Train/test split and cross-validation folds on the training portion.
# Both steps draw random numbers, so the seed is fixed first; without it every
# run produces a different split, different folds and different metrics.
set.seed(123)
split <- initial_split(applications)
train <- training(split)
test <- testing(split)
cv <- rsample::vfold_cv(train)
# Cross-validated performance of the plain linear model (RMSE, MAPE and MAE),
# keeping the out-of-fold predictions.
cv_results_lm <- wf_lm %>%
  fit_resamples(
    resamples = cv,
    metrics = yardstick::metric_set(rmse, mape, mae),
    control = control_resamples(save_pred = TRUE)
  )
cv_results_lm %>%
  collect_metrics()
## # A tibble: 3 × 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 mae standard 676. 10 33.9 pre0_mod0_post0
## 2 mape standard 40.8 10 1.80 pre0_mod0_post0
## 3 rmse standard 1177. 10 105. pre0_mod0_post0
# Fit the linear-model workflow on the whole training set.
full_lm <- fit(wf_lm, data = train)

# Regular grids of 100 candidate penalties each; the ranges are given on the
# log10 scale used by `penalty()`.
grid_ridge <- grid_regular(penalty(range = c(-10, 3)), levels = 100)
grid_lasso <- grid_regular(penalty(range = c(-6, 3)), levels = 100)
# Tune the ridge and lasso penalties over their grids.
# The existing `cv` folds are reused on purpose: re-drawing the folds here
# would score the penalized models on different resamples than the plain
# linear model, making their cross-validated metrics not directly comparable.
cv_results_ridge <- wf_ridge %>%
  tune_grid(
    resamples = cv,
    grid = grid_ridge,
    metrics = yardstick::metric_set(rmse, mape, mae)
  )
metrics_ridge <- collect_metrics(cv_results_ridge)

cv_results_lasso <- wf_lasso %>%
  tune_grid(
    resamples = cv,
    grid = grid_lasso,
    metrics = yardstick::metric_set(rmse, mape, mae)
  )
metrics_lasso <- collect_metrics(cv_results_lasso)
# Top ridge candidates ranked by cross-validated RMSE.
cv_results_ridge %>%
  show_best(metric = "rmse")
## # A tibble: 5 × 7
## penalty .metric .estimator mean n std_err .config
## <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 1 e-10 rmse standard 1243. 10 204. pre0_mod001_post0
## 2 1.35e-10 rmse standard 1243. 10 204. pre0_mod002_post0
## 3 1.83e-10 rmse standard 1243. 10 204. pre0_mod003_post0
## 4 2.48e-10 rmse standard 1243. 10 204. pre0_mod004_post0
## 5 3.35e-10 rmse standard 1243. 10 204. pre0_mod005_post0
# Top lasso candidates ranked by cross-validated RMSE.
cv_results_lasso %>%
  show_best(metric = "rmse")
## # A tibble: 5 × 7
## penalty .metric .estimator mean n std_err .config
## <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 0.000001 rmse standard 1182. 10 114. pre0_mod001_post0
## 2 0.00000123 rmse standard 1182. 10 114. pre0_mod002_post0
## 3 0.00000152 rmse standard 1182. 10 114. pre0_mod003_post0
## 4 0.00000187 rmse standard 1182. 10 114. pre0_mod004_post0
## 5 0.00000231 rmse standard 1182. 10 114. pre0_mod005_post0
# Full table of ridge tuning metrics (one row per penalty and metric).
cv_results_ridge %>%
  collect_metrics()
## # A tibble: 300 × 7
## penalty .metric .estimator mean n std_err .config
## <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 1 e-10 mae standard 694. 10 46.7 pre0_mod001_post0
## 2 1 e-10 mape standard 41.3 10 2.06 pre0_mod001_post0
## 3 1 e-10 rmse standard 1243. 10 204. pre0_mod001_post0
## 4 1.35e-10 mae standard 694. 10 46.7 pre0_mod002_post0
## 5 1.35e-10 mape standard 41.3 10 2.06 pre0_mod002_post0
## 6 1.35e-10 rmse standard 1243. 10 204. pre0_mod002_post0
## 7 1.83e-10 mae standard 694. 10 46.7 pre0_mod003_post0
## 8 1.83e-10 mape standard 41.3 10 2.06 pre0_mod003_post0
## 9 1.83e-10 rmse standard 1243. 10 204. pre0_mod003_post0
## 10 2.48e-10 mae standard 694. 10 46.7 pre0_mod004_post0
## # ℹ 290 more rows
# Cross-validated RMSE of the ridge model as a function of the penalty.
# The grid is evenly spaced on the log10 scale (1e-10 to 1e3), so the x axis is
# drawn on a log10 scale too; on a linear axis nearly every point piles up at
# zero and the shape of the curve cannot be read.
collect_metrics(cv_results_ridge) %>%
  filter(.metric == "rmse") %>%
  ggplot(aes(x = penalty, y = mean)) +
  geom_point() +
  geom_line() +
  scale_x_log10()
# Show the ridge tuning metrics again before the zoomed-in plot.
cv_results_ridge %>%
  collect_metrics()
## # A tibble: 300 × 7
## penalty .metric .estimator mean n std_err .config
## <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 1 e-10 mae standard 694. 10 46.7 pre0_mod001_post0
## 2 1 e-10 mape standard 41.3 10 2.06 pre0_mod001_post0
## 3 1 e-10 rmse standard 1243. 10 204. pre0_mod001_post0
## 4 1.35e-10 mae standard 694. 10 46.7 pre0_mod002_post0
## 5 1.35e-10 mape standard 41.3 10 2.06 pre0_mod002_post0
## 6 1.35e-10 rmse standard 1243. 10 204. pre0_mod002_post0
## 7 1.83e-10 mae standard 694. 10 46.7 pre0_mod003_post0
## 8 1.83e-10 mape standard 41.3 10 2.06 pre0_mod003_post0
## 9 1.83e-10 rmse standard 1243. 10 204. pre0_mod003_post0
## 10 2.48e-10 mae standard 694. 10 46.7 pre0_mod004_post0
## # ℹ 290 more rows
# Ridge RMSE against the penalty, restricted to penalties below 250.
cv_results_ridge %>%
  collect_metrics() %>%
  filter(.metric == "rmse", penalty < 250) %>%
  ggplot(aes(x = penalty, y = mean)) +
  geom_point() +
  geom_line()
# Lasso RMSE against the penalty, restricted to penalties below 50.
cv_results_lasso %>%
  collect_metrics() %>%
  filter(.metric == "rmse", penalty < 50) %>%
  ggplot(aes(x = penalty, y = mean)) +
  geom_point() +
  geom_line()
# Ridge penalty with the lowest cross-validated RMSE.
cv_results_ridge %>%
  select_best(metric = "rmse")
## # A tibble: 1 × 2
## penalty .config
## <dbl> <chr>
## 1 0.0000000001 pre0_mod001_post0
# Plug the RMSE-best ridge penalty into the workflow, refit on the training
# set and list the fitted coefficients.
best_wf <- wf_ridge %>%
  finalize_workflow(select_best(cv_results_ridge, metric = "rmse")) %>%
  fit(data = train)
best_wf %>%
  extract_fit_parsnip() %>%
  tidy()
##
## Attaching package: 'Matrix'
## The following objects are masked from 'package:tidyr':
##
## expand, pack, unpack
## Loaded glmnet 4.1-10
## # A tibble: 18 × 3
## term estimate penalty
## <chr> <dbl> <dbl>
## 1 (Intercept) 3109. 0.0000000001
## 2 Accept 2636. 0.0000000001
## 3 Enroll 337. 0.0000000001
## 4 Top10perc 557. 0.0000000001
## 5 Top25perc -9.16 0.0000000001
## 6 F.Undergrad 358. 0.0000000001
## 7 P.Undergrad 23.9 0.0000000001
## 8 Outstate -119. 0.0000000001
## 9 Room.Board 243. 0.0000000001
## 10 Books 3.75 0.0000000001
## 11 Personal 4.74 0.0000000001
## 12 PhD -88.1 0.0000000001
## 13 Terminal -72.0 0.0000000001
## 14 S.F.Ratio 75.1 0.0000000001
## 15 perc.alumni -79.4 0.0000000001
## 16 Expend 414. 0.0000000001
## 17 Grad.Rate 201. 0.0000000001
## 18 Private_Yes -263. 0.0000000001
# Plug the RMSE-best lasso penalty into the workflow, refit on the training
# set and list the fitted coefficients. Note that `best_wf` is overwritten.
best_wf <- wf_lasso %>%
  finalize_workflow(select_best(cv_results_lasso, metric = "rmse")) %>%
  fit(data = train)
best_wf %>%
  extract_fit_parsnip() %>%
  tidy()
## # A tibble: 18 × 3
## term estimate penalty
## <chr> <dbl> <dbl>
## 1 (Intercept) 3109. 0.000001
## 2 Accept 4083. 0.000001
## 3 Enroll -1080. 0.000001
## 4 Top10perc 1086. 0.000001
## 5 Top25perc -392. 0.000001
## 6 F.Undergrad 484. 0.000001
## 7 P.Undergrad 61.6 0.000001
## 8 Outstate -379. 0.000001
## 9 Room.Board 166. 0.000001
## 10 Books -17.1 0.000001
## 11 Personal 11.9 0.000001
## 12 PhD -161. 0.000001
## 13 Terminal -34.8 0.000001
## 14 S.F.Ratio 77.9 0.000001
## 15 perc.alumni 22.3 0.000001
## 16 Expend 418. 0.000001
## 17 Grad.Rate 161. 0.000001
## 18 Private_Yes -223. 0.000001
# Coefficient table of the linear model fitted on the training set.
tidy(extract_fit_parsnip(full_lm))
## # A tibble: 18 × 5
## term estimate std.error statistic p.value
## <chr> <dbl> <dbl> <dbl> <dbl>
## 1 (Intercept) 3109. 45.1 68.9 7.11e-277
## 2 Accept 4123. 116. 35.5 1.16e-145
## 3 Enroll -1205. 217. -5.55 4.36e- 8
## 4 Top10perc 1121. 123. 9.11 1.46e- 18
## 5 Top25perc -425. 110. -3.84 1.35e- 4
## 6 F.Undergrad 573. 198. 2.90 3.89e- 3
## 7 P.Undergrad 61.9 58.3 1.06 2.89e- 1
## 8 Outstate -392. 92.3 -4.24 2.56e- 5
## 9 Room.Board 168. 65.2 2.58 1.01e- 2
## 10 Books -20.9 48.3 -0.433 6.65e- 1
## 11 Personal 13.4 52.4 0.255 7.99e- 1
## 12 PhD -165. 94.5 -1.74 8.16e- 2
## 13 Terminal -36.0 92.8 -0.388 6.98e- 1
## 14 S.F.Ratio 80.9 63.2 1.28 2.01e- 1
## 15 perc.alumni 31.3 61.9 0.505 6.13e- 1
## 16 Expend 422. 80.6 5.24 2.29e- 7
## 17 Grad.Rate 168. 61.1 2.76 6.03e- 3
## 18 Private_Yes -224. 75.9 -2.95 3.28e- 3
# Classic `summary()` of the underlying engine fit of the linear model.
full_lm %>%
  extract_fit_parsnip() %>%
  extract_fit_engine() %>%
  summary()
##
## Call:
## stats::lm(formula = ..y ~ ., data = data)
##
## Residuals:
## Min 1Q Median 3Q Max
## -4901.3 -446.8 -2.0 353.4 7284.2
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 3109.14 45.11 68.917 < 2e-16 ***
## Accept 4122.67 116.29 35.453 < 2e-16 ***
## Enroll -1205.09 217.06 -5.552 4.36e-08 ***
## Top10perc 1120.57 123.04 9.107 < 2e-16 ***
## Top25perc -424.60 110.44 -3.845 0.000135 ***
## F.Undergrad 573.05 197.70 2.899 0.003893 **
## P.Undergrad 61.87 58.33 1.061 0.289295
## Outstate -391.69 92.29 -4.244 2.56e-05 ***
## Room.Board 168.30 65.22 2.580 0.010122 *
## Books -20.89 48.26 -0.433 0.665239
## Personal 13.38 52.40 0.255 0.798562
## PhD -164.81 94.46 -1.745 0.081558 .
## Terminal -36.01 92.78 -0.388 0.698038
## S.F.Ratio 80.91 63.24 1.279 0.201307
## perc.alumni 31.26 61.86 0.505 0.613472
## Expend 422.29 80.61 5.239 2.29e-07 ***
## Grad.Rate 168.31 61.06 2.756 0.006035 **
## Private_Yes -224.19 75.91 -2.953 0.003275 **
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 1088 on 564 degrees of freedom
## Multiple R-squared: 0.9302, Adjusted R-squared: 0.9281
## F-statistic: 442.2 on 17 and 564 DF, p-value: < 2.2e-16
Lasso logra un buen RMSE, descartando 3 variables. Ridge: no logro que la curva de lambda tenga forma de U; ¿qué característica del dataset debería estar mirando?