library(ISLR)
library(glmnet)
## Loading required package: Matrix
## Loaded glmnet 4.1-8
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr 1.1.4 ✔ readr 2.1.5
## ✔ forcats 1.0.0 ✔ stringr 1.5.1
## ✔ ggplot2 3.5.2 ✔ tibble 3.2.1
## ✔ lubridate 1.9.4 ✔ tidyr 1.3.1
## ✔ purrr 1.0.4
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ tidyr::expand() masks Matrix::expand()
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ✖ tidyr::pack() masks Matrix::pack()
## ✖ tidyr::unpack() masks Matrix::unpack()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.3.0 ──
## ✔ broom 1.0.8 ✔ rsample 1.3.0
## ✔ dials 1.4.0 ✔ tune 1.3.0
## ✔ infer 1.0.8 ✔ workflows 1.2.0
## ✔ modeldata 1.4.0 ✔ workflowsets 1.1.0
## ✔ parsnip 1.3.1 ✔ yardstick 1.3.2
## ✔ recipes 1.3.0
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard() masks purrr::discard()
## ✖ tidyr::expand() masks Matrix::expand()
## ✖ dplyr::filter() masks stats::filter()
## ✖ recipes::fixed() masks stringr::fixed()
## ✖ dplyr::lag() masks stats::lag()
## ✖ tidyr::pack() masks Matrix::pack()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step() masks stats::step()
## ✖ tidyr::unpack() masks Matrix::unpack()
## ✖ recipes::update() masks Matrix::update(), stats::update()
# 2. Data preparation ----
# Load Hitters and keep only the complete rows (drop any row with an NA).
data(Hitters)
Hitters <- drop_na(Hitters)

# Reproducible 75% / 25% train/test split.
set.seed(123)
data_split <- initial_split(Hitters, prop = 0.75)
train_data <- training(data_split)
test_data <- testing(data_split)

# Preprocessing used by both the ridge and lasso workflows: predict Salary
# from every other column, dummy-encode nominal predictors, drop
# zero-variance columns, then center and scale all predictors.
hitters_recipe <- recipe(Salary ~ ., data = train_data)
hitters_recipe <- step_dummy(hitters_recipe, all_nominal_predictors())
hitters_recipe <- step_zv(hitters_recipe, all_predictors())
hitters_recipe <- step_normalize(hitters_recipe, all_predictors())
# 3. Ridge Regression (Section 6.4) ----
# Ridge regression: glmnet with mixture = 0 (pure L2 penalty). The penalty
# (lambda) is tuned by 10-fold cross-validation on the training set.
ridge_spec <- linear_reg(penalty = tune(), mixture = 0) %>%
  set_engine("glmnet")

ridge_workflow <- workflow() %>%
  add_model(ridge_spec) %>%
  add_recipe(hitters_recipe)

# CV folds on the training data; the lasso section reuses `folds`.
set.seed(234)
folds <- vfold_cv(train_data, v = 10)

# dials::penalty() defaults to 10^-10 .. 10^0. On the Salary scale that is
# typically far too little shrinkage for ridge: every candidate behaves like
# an almost-unpenalized fit and the CV curve is flat. Search a much wider
# log10 range (10^-5 .. 10^5) so the optimum is actually bracketed.
lambda_grid <- grid_regular(penalty(range = c(-5, 5)), levels = 50)

ridge_results <- tune_grid(
  ridge_workflow,
  resamples = folds,
  grid = lambda_grid
)
autoplot(ridge_results) + ggtitle("Ridge Regression - RMSE vs Penalty")

# Refit on the full training set at the penalty with the lowest CV RMSE.
best_ridge <- select_best(x = ridge_results, metric = "rmse")
final_ridge <- finalize_workflow(ridge_workflow, best_ridge)
ridge_fit <- fit(final_ridge, data = train_data)
# Ridge coefficients ----
# Bar chart of the fitted ridge coefficients (intercept excluded), ordered by
# estimate. Predictors were normalized by the recipe, so bar lengths are on a
# comparable scale. extract_fit_parsnip() replaces pull_workflow_fit(), which
# is deprecated as of workflows 0.2.3.
ridge_fit %>%
  extract_fit_parsnip() %>%
  tidy() %>%
  filter(term != "(Intercept)") %>%
  ggplot(aes(x = reorder(term, estimate), y = estimate)) +
  geom_col() +
  coord_flip() +
  ggtitle("Ridge Coefficients")

# 4. Lasso Regression (Section 6.5) ----
# Lasso regression: glmnet with mixture = 1 (pure L1 penalty). It is tuned on
# the same CV `folds` as the ridge model, so the two RMSE curves are directly
# comparable.
lasso_spec <- linear_reg(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet")

lasso_workflow <- workflow() %>%
  add_model(lasso_spec) %>%
  add_recipe(hitters_recipe)

# The dials::penalty() default (10^-10 .. 10^0) rarely pushes any lasso
# coefficient to zero on the Salary scale, which hides the lasso's variable
# selection. Search 10^-2 .. 10^3 so the grid spans a near-OLS fit through to
# heavy sparsity. The grid is defined here, independent of the ridge grid.
lasso_grid <- grid_regular(penalty(range = c(-2, 3)), levels = 50)

lasso_results <- tune_grid(
  lasso_workflow,
  resamples = folds,
  grid = lasso_grid
)
autoplot(lasso_results) + ggtitle("Lasso Regression - RMSE vs Penalty")

# Refit on the full training set at the penalty with the lowest CV RMSE.
best_lasso <- select_best(x = lasso_results, metric = "rmse")
final_lasso <- finalize_workflow(lasso_workflow, best_lasso)
lasso_fit <- fit(final_lasso, data = train_data)
# Lasso coefficients ----
# Bar chart of the fitted lasso coefficients (intercept excluded), ordered by
# estimate. Predictors were normalized by the recipe, so bar lengths are on a
# comparable scale. Uses extract_fit_parsnip() rather than the deprecated
# pull_workflow_fit().
lasso_fit %>%
  extract_fit_parsnip() %>%
  tidy() %>%
  filter(term != "(Intercept)") %>%
  ggplot(aes(x = reorder(term, estimate), y = estimate)) +
  geom_col() +
  coord_flip() +
  ggtitle("Lasso Coefficients")
