library(ISLR)
library(glmnet)
## Loading required package: Matrix
## Loaded glmnet 4.1-8
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.1.4     ✔ readr     2.1.5
## ✔ forcats   1.0.0     ✔ stringr   1.5.1
## ✔ ggplot2   3.5.2     ✔ tibble    3.2.1
## ✔ lubridate 1.9.4     ✔ tidyr     1.3.1
## ✔ purrr     1.0.4     
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ tidyr::expand() masks Matrix::expand()
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ✖ tidyr::pack()   masks Matrix::pack()
## ✖ tidyr::unpack() masks Matrix::unpack()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.3.0 ──
## ✔ broom        1.0.8     ✔ rsample      1.3.0
## ✔ dials        1.4.0     ✔ tune         1.3.0
## ✔ infer        1.0.8     ✔ workflows    1.2.0
## ✔ modeldata    1.4.0     ✔ workflowsets 1.1.0
## ✔ parsnip      1.3.1     ✔ yardstick    1.3.2
## ✔ recipes      1.3.0
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard() masks purrr::discard()
## ✖ tidyr::expand()   masks Matrix::expand()
## ✖ dplyr::filter()   masks stats::filter()
## ✖ recipes::fixed()  masks stringr::fixed()
## ✖ dplyr::lag()      masks stats::lag()
## ✖ tidyr::pack()     masks Matrix::pack()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step()   masks stats::step()
## ✖ tidyr::unpack()   masks Matrix::unpack()
## ✖ recipes::update() masks Matrix::update(), stats::update()

2. Data preparation

# Load the Hitters data and keep only rows without missing values.
data(Hitters)
Hitters <- drop_na(Hitters)

# Seeded 75/25 train/test split.
set.seed(123)
data_split <- initial_split(Hitters, prop = 0.75)
train_data <- training(data_split)
test_data <- testing(data_split)

# Preprocessing, built up one step at a time:
#   1. dummy-encode nominal predictors,
#   2. drop zero-variance predictors,
#   3. normalize all remaining predictors.
hitters_recipe <- recipe(Salary ~ ., data = train_data)
hitters_recipe <- step_dummy(hitters_recipe, all_nominal_predictors())
hitters_recipe <- step_zv(hitters_recipe, all_predictors())
hitters_recipe <- step_normalize(hitters_recipe, all_predictors())

3. Ridge Regression (Section 6.4)

# Ridge specification: glmnet engine with mixture = 0; the penalty amount is
# left as a tuning parameter.
ridge_spec <- linear_reg(penalty = tune(), mixture = 0) %>%
  set_engine("glmnet")

# Bundle the preprocessing recipe with the model.
ridge_workflow <- workflow() %>%
  add_model(ridge_spec) %>%
  add_recipe(hitters_recipe)

# Seeded 10-fold cross-validation on the training set.
set.seed(234)
folds <- vfold_cv(train_data, v = 10)

# dials::penalty() defaults to the log10 range c(-10, 0), i.e. penalties of at
# most 1. The recipe normalizes predictors only, so Salary stays on its
# original scale and such small penalties barely shrink anything, leaving the
# tuning curve flat. Widen the range so the grid spans both negligible and
# heavy regularization.
lambda_grid <- grid_regular(penalty(range = c(-5, 5)), levels = 50)

# Evaluate every candidate penalty on every fold.
ridge_results <- tune_grid(
  ridge_workflow,
  resamples = folds,
  grid = lambda_grid
)

autoplot(ridge_results) + ggtitle("Ridge Regression - RMSE vs Penalty")

# Keep the penalty with the lowest cross-validated RMSE and refit the
# finalized workflow on the full training set.
best_ridge <- select_best(x = ridge_results, metric = "rmse")
final_ridge <- finalize_workflow(ridge_workflow, best_ridge)
ridge_fit <- fit(final_ridge, data = train_data)

Ridge Coefficients

# Bar chart of the fitted ridge coefficients (intercept excluded), ordered by
# estimate. extract_fit_parsnip() replaces pull_workflow_fit(), which was
# deprecated in workflows 0.2.3.
ridge_fit %>%
  extract_fit_parsnip() %>%
  tidy() %>%
  filter(term != "(Intercept)") %>%
  ggplot(aes(x = reorder(term, estimate), y = estimate)) +
  geom_col() +
  coord_flip() +
  ggtitle("Ridge Coefficients")

4. Lasso Regression (Section 6.5)

# Lasso specification: glmnet engine with mixture = 1; penalty is tuned.
lasso_spec <- set_engine(
  linear_reg(penalty = tune(), mixture = 1),
  "glmnet"
)

# Workflow = model + shared preprocessing recipe.
lasso_workflow <- workflow()
lasso_workflow <- add_model(lasso_workflow, lasso_spec)
lasso_workflow <- add_recipe(lasso_workflow, hitters_recipe)

# Tune over the same folds and penalty grid used for ridge.
lasso_results <- tune_grid(lasso_workflow, resamples = folds, grid = lambda_grid)

autoplot(lasso_results) + ggtitle("Lasso Regression - RMSE vs Penalty")

# Pick the penalty with the best RMSE, plug it into the workflow, and refit
# on the training data.
best_lasso <- select_best(lasso_results, metric = "rmse")
final_lasso <- lasso_workflow %>%
  finalize_workflow(best_lasso)
lasso_fit <- final_lasso %>%
  fit(data = train_data)

Lasso Coefficients

# Bar chart of the fitted lasso coefficients (intercept excluded), ordered by
# estimate. extract_fit_parsnip() replaces pull_workflow_fit(), which was
# deprecated in workflows 0.2.3.
lasso_fit %>%
  extract_fit_parsnip() %>%
  tidy() %>%
  filter(term != "(Intercept)") %>%
  ggplot(aes(x = reorder(term, estimate), y = estimate)) +
  geom_col() +
  coord_flip() +
  ggtitle("Lasso Coefficients")