library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr 1.1.4 ✔ readr 2.1.5
## ✔ forcats 1.0.0 ✔ stringr 1.5.1
## ✔ ggplot2 3.5.2 ✔ tibble 3.2.1
## ✔ lubridate 1.9.4 ✔ tidyr 1.3.1
## ✔ purrr 1.0.4
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.3.0 ──
## ✔ broom 1.0.8 ✔ rsample 1.3.0
## ✔ dials 1.4.0 ✔ tune 1.3.0
## ✔ infer 1.0.8 ✔ workflows 1.2.0
## ✔ modeldata 1.4.0 ✔ workflowsets 1.1.0
## ✔ parsnip 1.3.1 ✔ yardstick 1.3.2
## ✔ recipes 1.3.0
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard() masks purrr::discard()
## ✖ dplyr::filter() masks stats::filter()
## ✖ recipes::fixed() masks stringr::fixed()
## ✖ dplyr::lag() masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step() masks stats::step()
library(ISLR)
# ---- Data preparation --------------------------------------------------------
# Load the Hitters data set and discard every row that has a missing value in
# any column.
data(Hitters)
Hitters <- drop_na(Hitters)

# Seeded 80/20 train/test split so the partition is reproducible.
set.seed(123)
hitters_split <- initial_split(Hitters, prop = 0.8)
hitters_train <- training(hitters_split)
hitters_test <- testing(hitters_split)

# Preprocessing shared by the penalized models: dummy-encode nominal
# predictors, remove zero-variance columns, then centre and scale every
# predictor (the penalty is applied to coefficients, so predictors should be
# on a common scale).
hitters_recipe <- recipe(Salary ~ ., data = hitters_train) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_predictors())

# Seeded 10-fold cross-validation resamples drawn from the training set only.
set.seed(123)
hitters_folds <- vfold_cv(hitters_train, v = 10)
# ---- Ridge regression --------------------------------------------------------
# Candidate penalties, 50 values evenly spaced on the log10 scale.
# dials' default `penalty()` range is 10^-10 to 10^0, which tops out at 1.
# Salary-scale fits on normalized predictors need much larger penalties before
# shrinkage has a visible effect, so the grid is widened to 10^-2 .. 10^5.
# NOTE(review): glmnet fits its own lambda path and requested penalties outside
# that path are clamped to its ends, so a grid capped at 1 most likely produces
# a flat CV curve for ridge -- confirm on the plot below.
ridge_grid <- grid_regular(penalty(range = c(-2, 5)), levels = 50)

# Ridge model: glmnet with mixture = 0 (pure L2 penalty); penalty is tuned.
ridge_spec <- linear_reg(penalty = tune(), mixture = 0) %>%
  set_engine("glmnet")

# Bundle the model with the shared preprocessing recipe.
ridge_workflow <- workflow() %>%
  add_model(ridge_spec) %>%
  add_recipe(hitters_recipe)

# Evaluate every candidate penalty on the cross-validation folds.
ridge_results <- tune_grid(
  ridge_workflow,
  resamples = hitters_folds,
  grid = ridge_grid
)

# Cross-validated RMSE as a function of the penalty (log-scaled x axis).
ridge_results %>%
  collect_metrics() %>%
  filter(.metric == "rmse") %>%
  ggplot(aes(x = penalty, y = mean)) +
  geom_line() +
  scale_x_log10() +
  labs(title = "Ridge Regression", x = "Penalty", y = "RMSE")

# ---- Lasso regression --------------------------------------------------------
# Lasso model: glmnet with mixture = 1 (pure L1 penalty); penalty is tuned.
lasso_spec <- linear_reg(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet")

# Bundle the model with the shared preprocessing recipe.
lasso_workflow <- workflow() %>%
  add_model(lasso_spec) %>%
  add_recipe(hitters_recipe)

# Dedicated lasso grid, 50 values evenly spaced on the log10 scale.
# dials' default `penalty()` range is 10^-10 to 10^0. A previous run with that
# range selected its largest value (penalty = 1) as the best candidate, and the
# fitted glmnet path starts near lambda = 295, so the optimum was being cut off
# at the grid edge. Searching 10^-2 .. 10^3 lets it fall inside the grid.
lasso_grid <- grid_regular(penalty(range = c(-2, 3)), levels = 50)

# Evaluate every candidate penalty on the cross-validation folds.
lasso_results <- tune_grid(
  lasso_workflow,
  resamples = hitters_folds,
  grid = lasso_grid
)

# Cross-validated RMSE as a function of the penalty (log-scaled x axis).
lasso_results %>%
  collect_metrics() %>%
  filter(.metric == "rmse") %>%
  ggplot(aes(x = penalty, y = mean)) +
  geom_line() +
  scale_x_log10() +
  labs(title = "Lasso Regression", x = "Penalty", y = "RMSE")

# Pick the winning lasso candidate by hand: restrict the collected metrics to
# RMSE and keep the single row with the smallest mean (first one wins on ties).
best_lasso <- collect_metrics(lasso_results) %>%
  filter(.metric == "rmse") %>%
  slice_min(order_by = mean, n = 1, with_ties = FALSE)

# Print the selected row to verify the penalty that was picked.
best_lasso
## # A tibble: 1 × 7
## penalty .metric .estimator mean n std_err .config
## <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 1 rmse standard 319. 10 41.2 Preprocessor1_Model50
# Extract the selected penalty as a plain numeric value.
best_lasso_penalty <- best_lasso[["penalty"]]

# Substitute the selected penalty for the tune() placeholder in the workflow.
final_lasso <- lasso_workflow %>%
  finalize_workflow(parameters = tibble(penalty = best_lasso_penalty))

# Train the finalized workflow (recipe + lasso model) on the training set.
fit_lasso <- final_lasso %>%
  fit(data = hitters_train)

# Display the trained workflow, including the glmnet regularization path.
fit_lasso
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: linear_reg()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 3 Recipe Steps
##
## • step_dummy()
## • step_zv()
## • step_normalize()
##
## ── Model ───────────────────────────────────────────────────────────────────────
##
## Call: glmnet::glmnet(x = maybe_matrix(x), y = y, family = "gaussian", alpha = ~1)
##
## Df %Dev Lambda
## 1 0 0.00 295.000
## 2 1 6.70 268.800
## 3 1 12.26 244.900
## 4 2 16.98 223.200
## 5 2 20.90 203.300
## 6 3 24.61 185.300
## 7 3 28.60 168.800
## 8 4 31.98 153.800
## 9 4 35.02 140.100
## 10 4 37.55 127.700
## 11 4 39.64 116.400
## 12 4 41.38 106.000
## 13 5 42.83 96.600
## 14 5 44.05 88.020
## 15 6 45.30 80.200
## 16 6 46.49 73.070
## 17 6 47.47 66.580
## 18 6 48.29 60.670
## 19 7 49.08 55.280
## 20 7 49.73 50.370
## 21 7 50.28 45.890
## 22 7 50.73 41.810
## 23 7 51.11 38.100
## 24 7 51.42 34.710
## 25 7 51.68 31.630
## 26 8 51.91 28.820
## 27 8 52.16 26.260
## 28 8 52.38 23.930
## 29 8 52.55 21.800
## 30 8 52.70 19.870
## 31 8 52.82 18.100
## 32 9 52.94 16.490
## 33 9 53.06 15.030
## 34 10 53.36 13.690
## 35 10 53.72 12.480
## 36 10 54.00 11.370
## 37 10 54.23 10.360
## 38 10 54.42 9.438
## 39 13 54.66 8.599
## 40 14 55.04 7.835
## 41 15 55.75 7.139
## 42 15 56.36 6.505
## 43 15 56.88 5.927
## 44 14 57.30 5.401
## 45 14 57.65 4.921
## 46 14 57.95 4.484
##
## ...
## and 45 more lines.