library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.1.4     ✔ readr     2.1.5
## ✔ forcats   1.0.0     ✔ stringr   1.5.1
## ✔ ggplot2   3.5.2     ✔ tibble    3.2.1
## ✔ lubridate 1.9.4     ✔ tidyr     1.3.1
## ✔ purrr     1.0.4     
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.3.0 ──
## ✔ broom        1.0.8     ✔ rsample      1.3.0
## ✔ dials        1.4.0     ✔ tune         1.3.0
## ✔ infer        1.0.8     ✔ workflows    1.2.0
## ✔ modeldata    1.4.0     ✔ workflowsets 1.1.0
## ✔ parsnip      1.3.1     ✔ yardstick    1.3.2
## ✔ recipes      1.3.0     
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard() masks purrr::discard()
## ✖ dplyr::filter()   masks stats::filter()
## ✖ recipes::fixed()  masks stringr::fixed()
## ✖ dplyr::lag()      masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step()   masks stats::step()
library(ISLR)
# Load the Hitters data and keep only rows with no missing values
data(Hitters)
Hitters <- drop_na(Hitters)

# Reproducible 80/20 train/test split
set.seed(123)
hitters_split <- Hitters %>% initial_split(prop = 0.8)
hitters_train <- hitters_split %>% training()
hitters_test <- hitters_split %>% testing()

# Preprocessing for Salary ~ all other columns: dummy-encode nominal
# predictors, drop zero-variance columns, then center and scale predictors
hitters_recipe <- hitters_train %>%
  recipe(formula = Salary ~ .) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_predictors())
# 10-fold cross-validation folds on the training set (also reused by the
# lasso tuning run).
set.seed(123)
hitters_folds <- vfold_cv(hitters_train, v = 10)

# Candidate ridge penalties. `penalty()` defaults to a log10 range of
# c(-10, 0), i.e. 1e-10 to 1, which is tiny relative to the scale of the
# Salary outcome. With that default grid the best lasso penalty landed on the
# grid's upper edge (penalty = 1), a sign the range was too narrow. Search
# 1e-5 to 1e5 so the cross-validated optimum can fall inside the grid.
ridge_grid <- grid_regular(penalty(range = c(-5, 5)), levels = 50)

# Ridge regression: mixture = 0 is a pure L2 penalty with the glmnet engine
ridge_spec <- linear_reg(penalty = tune(), mixture = 0) %>%
  set_engine("glmnet")

ridge_workflow <- workflow() %>%
  add_model(ridge_spec) %>%
  add_recipe(hitters_recipe)

# Evaluate every candidate penalty on every fold
ridge_results <- tune_grid(
  ridge_workflow,
  resamples = hitters_folds,
  grid = ridge_grid
)

# Mean cross-validated RMSE across the ridge penalty grid (log-scaled x axis)
ggplot(
  data = filter(collect_metrics(ridge_results), .metric == "rmse"),
  mapping = aes(x = penalty, y = mean)
) +
  geom_line() +
  scale_x_log10() +
  labs(title = "Ridge Regression", x = "Penalty", y = "RMSE")

# Lasso regression: mixture = 1 is a pure L1 penalty with the glmnet engine
lasso_spec <- linear_reg(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet")

lasso_workflow <- workflow() %>%
  add_model(lasso_spec) %>%
  add_recipe(hitters_recipe)

# The lasso gets its own grid instead of reusing `ridge_grid`. With the
# default `penalty()` range (1e-10 to 1) the best lasso penalty was the grid's
# upper edge (penalty = 1), while glmnet's lasso path on the training set
# starts around lambda = 295 with no active predictors. A log10 range of
# c(-2, 3), i.e. 0.01 to 1000, spans from a nearly unpenalized fit past that
# all-coefficients-zero end of the path.
lasso_grid <- grid_regular(penalty(range = c(-2, 3)), levels = 50)

# Evaluate every candidate penalty on every fold
lasso_results <- tune_grid(
  lasso_workflow,
  resamples = hitters_folds,
  grid = lasso_grid
)

# Mean cross-validated RMSE across the lasso penalty grid (log-scaled x axis)
ggplot(
  data = filter(collect_metrics(lasso_results), .metric == "rmse"),
  mapping = aes(x = penalty, y = mean)
) +
  geom_line() +
  scale_x_log10() +
  labs(title = "Lasso Regression", x = "Penalty", y = "RMSE")

# Pick the penalty with the lowest mean cross-validated RMSE. tune's
# `select_best()` does the filter-by-metric / sort / take-first-row work that
# was previously hand-rolled here. It returns a one-row tibble holding the
# tuned `penalty` and its `.config` label.
best_lasso <- select_best(lasso_results, metric = "rmse")

# Check which penalty value was selected
best_lasso
## # A tibble: 1 × 7
##   penalty .metric .estimator  mean     n std_err .config              
##     <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
## 1       1 rmse    standard    319.    10    41.2 Preprocessor1_Model50
# Extract the selected penalty as a plain numeric value
best_lasso_penalty <- best_lasso$penalty

# Substitute the selected penalty for the tune() placeholder in the workflow
final_lasso <- lasso_workflow %>%
  finalize_workflow(tibble(penalty = best_lasso_penalty))

# Train the finalized recipe + glmnet lasso workflow on the training set
fit_lasso <- final_lasso %>%
  fit(data = hitters_train)

# Print the trained workflow
fit_lasso
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: linear_reg()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 3 Recipe Steps
## 
## • step_dummy()
## • step_zv()
## • step_normalize()
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## 
## Call:  glmnet::glmnet(x = maybe_matrix(x), y = y, family = "gaussian",      alpha = ~1) 
## 
##    Df  %Dev  Lambda
## 1   0  0.00 295.000
## 2   1  6.70 268.800
## 3   1 12.26 244.900
## 4   2 16.98 223.200
## 5   2 20.90 203.300
## 6   3 24.61 185.300
## 7   3 28.60 168.800
## 8   4 31.98 153.800
## 9   4 35.02 140.100
## 10  4 37.55 127.700
## 11  4 39.64 116.400
## 12  4 41.38 106.000
## 13  5 42.83  96.600
## 14  5 44.05  88.020
## 15  6 45.30  80.200
## 16  6 46.49  73.070
## 17  6 47.47  66.580
## 18  6 48.29  60.670
## 19  7 49.08  55.280
## 20  7 49.73  50.370
## 21  7 50.28  45.890
## 22  7 50.73  41.810
## 23  7 51.11  38.100
## 24  7 51.42  34.710
## 25  7 51.68  31.630
## 26  8 51.91  28.820
## 27  8 52.16  26.260
## 28  8 52.38  23.930
## 29  8 52.55  21.800
## 30  8 52.70  19.870
## 31  8 52.82  18.100
## 32  9 52.94  16.490
## 33  9 53.06  15.030
## 34 10 53.36  13.690
## 35 10 53.72  12.480
## 36 10 54.00  11.370
## 37 10 54.23  10.360
## 38 10 54.42   9.438
## 39 13 54.66   8.599
## 40 14 55.04   7.835
## 41 15 55.75   7.139
## 42 15 56.36   6.505
## 43 15 56.88   5.927
## 44 14 57.30   5.401
## 45 14 57.65   4.921
## 46 14 57.95   4.484
## 
## ...
## and 45 more lines.