library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.1.4     ✔ readr     2.1.5
## ✔ forcats   1.0.0     ✔ stringr   1.5.1
## ✔ ggplot2   3.5.1     ✔ tibble    3.2.1
## ✔ lubridate 1.9.3     ✔ tidyr     1.3.1
## ✔ purrr     1.0.4     
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.3.0 ──
## ✔ broom        1.0.7     ✔ rsample      1.2.1
## ✔ dials        1.4.0     ✔ tune         1.3.0
## ✔ infer        1.0.7     ✔ workflows    1.2.0
## ✔ modeldata    1.4.0     ✔ workflowsets 1.1.0
## ✔ parsnip      1.3.1     ✔ yardstick    1.3.2
## ✔ recipes      1.2.1     
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard() masks purrr::discard()
## ✖ dplyr::filter()   masks stats::filter()
## ✖ recipes::fixed()  masks stringr::fixed()
## ✖ dplyr::lag()      masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step()   masks stats::step()
library(ISLR)
library(glmnet)
## Loading required package: Matrix
## 
## Attaching package: 'Matrix'
## 
## The following objects are masked from 'package:tidyr':
## 
##     expand, pack, unpack
## 
## Loaded glmnet 4.1-8
# Load the Hitters data set and discard every row that contains a missing
# value.
data("Hitters")
Hitters <- drop_na(Hitters)

# Seed the RNG so the random split below is reproducible, then hold out
# 25% of the rows: 75% go to train_data, the remainder to test_data.
set.seed(123)
data_split <- Hitters %>% initial_split(prop = 0.75)
train_data <- data_split %>% training()
test_data <- data_split %>% testing()

# Preprocessing recipe built on the training rows: Salary is the outcome and
# every other column is a predictor. Steps are appended one at a time, in the
# order they will be applied.
hitters_recipe <- recipe(Salary ~ ., data = train_data)
# Convert nominal predictors to dummy variables.
hitters_recipe <- step_dummy(hitters_recipe, all_nominal_predictors())
# Remove zero-variance predictors.
hitters_recipe <- step_zv(hitters_recipe, all_predictors())
# Normalize all remaining predictors (the outcome is left untouched).
hitters_recipe <- step_normalize(hitters_recipe, all_predictors())
# Ridge model specification: mixture = 0 with the penalty left as a tune()
# placeholder, fitted through the glmnet engine.
ridge_spec <- set_engine(
  linear_reg(penalty = tune(), mixture = 0),
  engine = "glmnet"
)

# Bundle the ridge model with the shared preprocessing recipe.
ridge_workflow <- add_recipe(
  add_model(workflow(), ridge_spec),
  hitters_recipe
)

# Seed the RNG so the 10-fold cross-validation assignment is reproducible.
set.seed(234)
folds <- train_data %>% vfold_cv(v = 10)

# Regular grid of 50 candidate penalty values, shared by the ridge and lasso
# tuning runs.
# NOTE(review): penalty() is called with its defaults, so the grid only spans
# the default penalty range. Salary is not rescaled by the recipe -- confirm
# that this range reaches penalties large enough to matter.
lambda_grid <- grid_regular(x = penalty(), levels = 50)

# Evaluate the ridge workflow at every penalty in lambda_grid across the
# cross-validation folds.
ridge_results <- ridge_workflow %>%
  tune_grid(resamples = folds, grid = lambda_grid)

# Plot cross-validated RMSE against the penalty. The metric is restricted to
# "rmse" so the figure matches its title; without it autoplot() facets over
# every metric stored in the tuning results.
autoplot(ridge_results, metric = "rmse") +
  ggtitle("Ridge Regression - RMSE vs Penalty")

# Pick the penalty with the best cross-validated RMSE, plug it into the ridge
# workflow in place of the tune() placeholder, and refit on the training data.
best_ridge <- select_best(x = ridge_results, metric = "rmse")
final_ridge <- finalize_workflow(ridge_workflow, best_ridge)
ridge_fit <- fit(final_ridge, data = train_data)

# Bar chart of the fitted ridge coefficients (intercept excluded), ordered by
# estimate. extract_fit_parsnip() replaces pull_workflow_fit(), which was
# deprecated in workflows 0.2.3 and emitted a lifecycle warning here.
ridge_fit %>%
  extract_fit_parsnip() %>%
  tidy() %>%
  filter(term != "(Intercept)") %>%
  ggplot(aes(x = reorder(term, estimate), y = estimate)) +
  geom_col() +
  coord_flip() +
  ggtitle("Ridge Coefficients")

# Lasso model specification: mixture = 1 with the penalty left as a tune()
# placeholder, fitted through the glmnet engine.
lasso_spec <- set_engine(
  linear_reg(penalty = tune(), mixture = 1),
  engine = "glmnet"
)

# Bundle the lasso model with the same preprocessing recipe used for ridge.
lasso_workflow <- add_recipe(
  add_model(workflow(), lasso_spec),
  hitters_recipe
)

# Evaluate the lasso workflow at every penalty in lambda_grid, reusing the
# same cross-validation folds as the ridge run.
lasso_results <- lasso_workflow %>%
  tune_grid(resamples = folds, grid = lambda_grid)

# Plot cross-validated RMSE against the penalty. The metric is restricted to
# "rmse" so the figure matches its title; without it autoplot() facets over
# every metric stored in the tuning results.
autoplot(lasso_results, metric = "rmse") +
  ggtitle("Lasso Regression - RMSE vs Penalty")

# Pick the penalty with the best cross-validated RMSE, plug it into the lasso
# workflow in place of the tune() placeholder, and refit on the training data.
best_lasso <- select_best(x = lasso_results, metric = "rmse")
final_lasso <- finalize_workflow(lasso_workflow, best_lasso)
lasso_fit <- fit(final_lasso, data = train_data)

# Bar chart of the fitted lasso coefficients (intercept excluded), ordered by
# estimate. extract_fit_parsnip() is used instead of pull_workflow_fit(),
# which has been deprecated since workflows 0.2.3.
lasso_fit %>%
  extract_fit_parsnip() %>%
  tidy() %>%
  filter(term != "(Intercept)") %>%
  ggplot(aes(x = reorder(term, estimate), y = estimate)) +
  geom_col() +
  coord_flip() +
  ggtitle("Lasso Coefficients")