While the glmnet package is the workhorse behind these methods in R, using it directly can sometimes be a bit cumbersome for tasks like hyperparameter tuning and cross-validation. This is where the caret (Classification And REgression Training) package shines! caret provides a unified interface for training a vast array of machine learning models, including Lasso and Ridge, making the entire process much more efficient and robust.
Both Lasso (Least Absolute Shrinkage and Selection Operator) and Ridge regression are extensions of ordinary least squares (OLS) regression that add a penalty term to the loss function during model fitting. This penalty shrinks the regression coefficients towards zero.
Ridge Regression (L2 Penalty): Adds a penalty proportional to the sum of the squared magnitudes of the coefficients. The penalty term is given by: \[ \lambda \sum_{j=1}^{p} \beta_j^2 \]
Effect: Shrinks coefficients towards zero, but rarely to exactly zero. It’s good for handling multicollinearity by spreading the impact of correlated predictors.
Utility: Reduces variance, especially when predictors are highly correlated.
Lasso Regression (L1 Penalty): Adds a penalty proportional to the sum of the absolute magnitudes of the coefficients. The penalty term is given by: \[ \lambda \sum_{j=1}^{p} |\beta_j| \]
Effect: Shrinks coefficients towards zero, and can shrink some coefficients exactly to zero, effectively performing feature selection.
Utility: Useful for models with many predictors where only a subset are truly relevant. It leads to sparser models.
The strength of this penalty is controlled by a tuning parameter, often denoted as \(\lambda\) (lambda). In glmnet, the alpha parameter controls the mix between Ridge and Lasso: alpha = 0 is pure Ridge, alpha = 1 is pure Lasso, and values between 0 and 1 give Elastic Net. Finding the optimal \(\lambda\) is crucial, and this is typically done via cross-validation.
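To see how alpha ties the two penalties together, here is the Elastic Net objective that glmnet minimizes for a Gaussian response (up to glmnet’s internal scaling conventions):
\[
\min_{\beta_0, \beta} \; \frac{1}{2n} \sum_{i=1}^{n} \left( y_i - \beta_0 - x_i^\top \beta \right)^2 + \lambda \left[ \frac{1-\alpha}{2} \sum_{j=1}^{p} \beta_j^2 + \alpha \sum_{j=1}^{p} |\beta_j| \right]
\]
Setting \(\alpha = 0\) recovers the Ridge penalty above, and \(\alpha = 1\) recovers the Lasso penalty.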
First, let’s load the necessary packages. If you don’t have them installed, you’ll need to run install.packages("package_name").
# Install packages if you haven't already
# install.packages("glmnet")
# install.packages("caret")
# install.packages("Matrix") # Often a dependency for glmnet
# install.packages("mlbench") # For the BostonHousing dataset
# Load necessary libraries
library(glmnet) # The core package for Lasso and Ridge
library(caret) # For streamlined model training and tuning
library(Matrix) # For sparse matrix operations, often used by glmnet
library(dplyr) # For data manipulation
library(ggplot2) # For plotting
library(mlbench) # To load the BostonHousing dataset
The Data: BostonHousing from mlbench
We will use the BostonHousing dataset from the mlbench package. This dataset contains various housing-related features and the median value of owner-occupied homes (medv) in Boston suburbs.
# Load the BostonHousing dataset
data(BostonHousing)
# Inspect the data
head(BostonHousing)
str(BostonHousing)
'data.frame': 506 obs. of 14 variables:
$ crim : num 0.00632 0.02731 0.02729 0.03237 0.06905 ...
$ zn : num 18 0 0 0 0 0 12.5 12.5 12.5 12.5 ...
$ indus : num 2.31 7.07 7.07 2.18 2.18 2.18 7.87 7.87 7.87 7.87 ...
$ chas : Factor w/ 2 levels "0","1": 1 1 1 1 1 1 1 1 1 1 ...
$ nox : num 0.538 0.469 0.469 0.458 0.458 0.458 0.524 0.524 0.524 0.524 ...
$ rm : num 6.58 6.42 7.18 7 7.15 ...
$ age : num 65.2 78.9 61.1 45.8 54.2 58.7 66.6 96.1 100 85.9 ...
$ dis : num 4.09 4.97 4.97 6.06 6.06 ...
$ rad : num 1 2 2 3 3 3 5 5 5 5 ...
$ tax : num 296 242 242 222 222 222 311 311 311 311 ...
$ ptratio: num 15.3 17.8 17.8 18.7 18.7 18.7 15.2 15.2 15.2 15.2 ...
$ b : num 397 397 393 395 397 ...
$ lstat : num 4.98 9.14 4.03 2.94 5.33 ...
$ medv : num 24 21.6 34.7 33.4 36.2 28.7 22.9 27.1 16.5 18.9 ...
# The response variable is 'medv' (median value of owner-occupied homes)
# All other columns will be used as predictors.
data_df <- BostonHousing
Fitting Lasso and Ridge with glmnet (Directly)
Before diving into caret, let’s briefly see how you would fit these models directly using glmnet. This will help us appreciate caret’s abstraction.
When using glmnet, you typically need to:
Prepare your data as a matrix of predictors (x) and a vector for the response (y).
Specify alpha (0 for Ridge, 1 for Lasso).
Use cv.glmnet to perform cross-validation to find the optimal lambda.
# Prepare data for glmnet
# Exclude the 'medv' column from predictors and use it as the response
x_matrix <- as.matrix(data_df %>% select(-medv))
y_vector <- data_df$medv
# --- Ridge Regression with glmnet ---
# alpha = 0 for Ridge
ridge_glmnet <- cv.glmnet(x_matrix, y_vector, alpha = 0, standardize = TRUE)
# Plot the cross-validation curve
plot(ridge_glmnet)
# Use base R title function for glmnet plots as they are base R plots
title("Ridge Regression CV Curve (glmnet)", line = 2.5)
# Get the optimal lambda (lambda.min)
cat("Optimal lambda for Ridge (glmnet):", ridge_glmnet$lambda.min, "\n")
Optimal lambda for Ridge (glmnet): 0.6777654
# Get coefficients at optimal lambda
coef(ridge_glmnet, s = "lambda.min")
14 x 1 sparse Matrix of class "dgCMatrix"
s1
(Intercept) 28.001475824
crim -0.087572712
zn 0.032681030
indus -0.038003639
chas 2.899781645
nox -11.913360479
rm 4.011308385
age -0.003731470
dis -1.118874607
rad 0.153730052
tax -0.005751054
ptratio -0.854984614
b 0.009073740
lstat -0.472423800
# --- Lasso Regression with glmnet ---
# alpha = 1 for Lasso
lasso_glmnet <- cv.glmnet(x_matrix, y_vector, alpha = 1, standardize = TRUE)
# Plot the cross-validation curve
plot(lasso_glmnet)
# Use base R title function for glmnet plots as they are base R plots
title("Lasso Regression CV Curve (glmnet)", line = 2.5)
# Get the optimal lambda (lambda.min)
cat("Optimal lambda for Lasso (glmnet):", lasso_glmnet$lambda.min, "\n")
Optimal lambda for Lasso (glmnet): 0.02118502
# Get coefficients at optimal lambda
coef(lasso_glmnet, s = "lambda.min")
14 x 1 sparse Matrix of class "dgCMatrix"
s1
(Intercept) 34.880894915
crim -0.100714832
zn 0.042486737
indus .
chas 2.693903097
nox -16.562664331
rm 3.851646315
age .
dis -1.419168850
rad 0.263725830
tax -0.010286456
ptratio -0.933927773
b 0.009089735
lstat -0.522521473
While glmnet is powerful, you manually manage the alpha parameter, interpret the CV plots, and extract coefficients. For more complex workflows (e.g., comparing multiple models, pre-processing, feature engineering), this can become tedious.
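As a small illustration of that manual bookkeeping, here is a sketch (using the cv.glmnet objects fitted above) of pulling out the more conservative lambda.1se choice and generating predictions, which requires the same matrix interface used for fitting:
# lambda.1se is the largest lambda whose CV error is within one standard error of the minimum
cat("lambda.1se for Ridge (glmnet):", ridge_glmnet$lambda.1se, "\n")
cat("lambda.1se for Lasso (glmnet):", lasso_glmnet$lambda.1se, "\n")
# Predictions need the predictor matrix (newx), not a data frame
ridge_pred_direct <- predict(ridge_glmnet, newx = x_matrix, s = "lambda.min")
head(ridge_pred_direct)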
Fitting Lasso and Ridge with caret
Now, let’s see how caret simplifies this process. caret provides a unified train() function that handles:
Data Splitting: Creating training and testing sets.
Pre-processing: (e.g., centering, scaling).
Resampling: Cross-validation or bootstrapping.
Hyperparameter Tuning: Searching for the best lambda (and alpha for Elastic Net).
Model Training: Fitting the model.
Performance Evaluation: Calculating metrics like RMSE, R-squared.
caret uses the glmnet package internally for Lasso and Ridge, but it wraps the functionality in a user-friendly way.
First, we define our training control, which specifies the resampling method (e.g., k-fold cross-validation).
# Define training control for cross-validation
# We'll use 10-fold cross-validation, repeated 3 times
fitControl <- trainControl(
method = "repeatedcv",
number = 10,
repeats = 3,
verboseIter = TRUE # Show training progress
)
Ridge Regression with caret
For Ridge regression, caret’s glmnet method requires alpha to be fixed at 0. caret will then tune \(\lambda\) for us.
# --- Ridge Regression with caret ---
# The 'glmnet' method in caret can handle Ridge (alpha=0) and Lasso (alpha=1)
# For Ridge, we specify a tuneGrid where alpha is fixed at 0.
# caret will then search for the best lambda within the specified range.
# Define a grid for lambda (caret will choose from these)
# It's good practice to provide a range of lambda values, often on a log scale.
# caret will automatically select an appropriate range if tuneGrid is not provided,
# but explicit control can be useful.
ridgeGrid <- expand.grid(alpha = 0,
lambda = 10^seq(-3, 1, length = 100)) # A range of lambda values
cat("\n--- Training Ridge Regression with caret ---\n")
--- Training Ridge Regression with caret ---
ridge_caret <- train(
medv ~ ., # Formula interface: response ~ all predictors
data = data_df, # Our data frame
method = "glmnet", # Specify the glmnet model
tuneGrid = ridgeGrid, # Our custom tuning grid for alpha and lambda
trControl = fitControl, # Our defined cross-validation control
preProcess = c("center", "scale") # Center and scale predictors (important for regularization)
)
+ Fold01.Rep1: alpha=0, lambda=10
- Fold01.Rep1: alpha=0, lambda=10
+ Fold02.Rep1: alpha=0, lambda=10
- Fold02.Rep1: alpha=0, lambda=10
... (analogous progress messages for the remaining folds of Rep1–Rep3 omitted) ...
+ Fold10.Rep3: alpha=0, lambda=10
- Fold10.Rep3: alpha=0, lambda=10
Aggregating results
Selecting tuning parameters
Fitting alpha = 0, lambda = 0.614 on full training set
# Print the model results
print(ridge_caret)
glmnet
506 samples
13 predictor
Pre-processing: centered (13), scaled (13)
Resampling: Cross-Validated (10 fold, repeated 3 times)
Summary of sample sizes: 456, 456, 454, 455, 455, 456, ...
Resampling results across tuning parameters:
lambda RMSE Rsquared MAE
0.001000000 4.806746 0.7301647 3.317381
0.001097499 4.806746 0.7301647 3.317381
0.001204504 4.806746 0.7301647 3.317381
0.001321941 4.806746 0.7301647 3.317381
0.001450829 4.806746 0.7301647 3.317381
0.001592283 4.806746 0.7301647 3.317381
0.001747528 4.806746 0.7301647 3.317381
0.001917910 4.806746 0.7301647 3.317381
0.002104904 4.806746 0.7301647 3.317381
0.002310130 4.806746 0.7301647 3.317381
0.002535364 4.806746 0.7301647 3.317381
0.002782559 4.806746 0.7301647 3.317381
0.003053856 4.806746 0.7301647 3.317381
0.003351603 4.806746 0.7301647 3.317381
0.003678380 4.806746 0.7301647 3.317381
0.004037017 4.806746 0.7301647 3.317381
0.004430621 4.806746 0.7301647 3.317381
0.004862602 4.806746 0.7301647 3.317381
0.005336699 4.806746 0.7301647 3.317381
0.005857021 4.806746 0.7301647 3.317381
0.006428073 4.806746 0.7301647 3.317381
0.007054802 4.806746 0.7301647 3.317381
0.007742637 4.806746 0.7301647 3.317381
0.008497534 4.806746 0.7301647 3.317381
0.009326033 4.806746 0.7301647 3.317381
0.010235310 4.806746 0.7301647 3.317381
0.011233240 4.806746 0.7301647 3.317381
0.012328467 4.806746 0.7301647 3.317381
0.013530478 4.806746 0.7301647 3.317381
0.014849683 4.806746 0.7301647 3.317381
0.016297508 4.806746 0.7301647 3.317381
0.017886495 4.806746 0.7301647 3.317381
0.019630407 4.806746 0.7301647 3.317381
0.021544347 4.806746 0.7301647 3.317381
0.023644894 4.806746 0.7301647 3.317381
0.025950242 4.806746 0.7301647 3.317381
0.028480359 4.806746 0.7301647 3.317381
0.031257158 4.806746 0.7301647 3.317381
0.034304693 4.806746 0.7301647 3.317381
0.037649358 4.806746 0.7301647 3.317381
0.041320124 4.806746 0.7301647 3.317381
0.045348785 4.806746 0.7301647 3.317381
0.049770236 4.806746 0.7301647 3.317381
0.054622772 4.806746 0.7301647 3.317381
0.059948425 4.806746 0.7301647 3.317381
0.065793322 4.806746 0.7301647 3.317381
0.072208090 4.806746 0.7301647 3.317381
0.079248290 4.806746 0.7301647 3.317381
0.086974900 4.806746 0.7301647 3.317381
0.095454846 4.806746 0.7301647 3.317381
0.104761575 4.806746 0.7301647 3.317381
0.114975700 4.806746 0.7301647 3.317381
0.126185688 4.806746 0.7301647 3.317381
0.138488637 4.806746 0.7301647 3.317381
0.151991108 4.806746 0.7301647 3.317381
0.166810054 4.806746 0.7301647 3.317381
0.183073828 4.806746 0.7301647 3.317381
0.200923300 4.806746 0.7301647 3.317381
0.220513074 4.806746 0.7301647 3.317381
0.242012826 4.806746 0.7301647 3.317381
0.265608778 4.806746 0.7301647 3.317381
0.291505306 4.806746 0.7301647 3.317381
0.319926714 4.806746 0.7301647 3.317381
0.351119173 4.806746 0.7301647 3.317381
0.385352859 4.806746 0.7301647 3.317381
0.422924287 4.806746 0.7301647 3.317381
0.464158883 4.806746 0.7301647 3.317381
0.509413801 4.806746 0.7301647 3.317381
0.559081018 4.806746 0.7301647 3.317381
0.613590727 4.806746 0.7301647 3.317381
0.673415066 4.807033 0.7301517 3.317428
0.739072203 4.811518 0.7298486 3.317314
0.811130831 4.816688 0.7294945 3.317410
0.890215085 4.822560 0.7290988 3.318470
0.977009957 4.829213 0.7286581 3.320186
1.072267222 4.836750 0.7281649 3.323002
1.176811952 4.845217 0.7276188 3.326951
1.291549665 4.854696 0.7270178 3.331652
1.417474163 4.865317 0.7263531 3.337287
1.555676144 4.877231 0.7256163 3.343453
1.707352647 4.890479 0.7248094 3.350169
1.873817423 4.905252 0.7239185 3.358365
2.056512308 4.921678 0.7229364 3.368157
2.257019720 4.939879 0.7218607 3.379089
2.477076356 4.960040 0.7206732 3.391748
2.718588243 4.982301 0.7193676 3.406106
2.983647240 5.006867 0.7179301 3.422477
3.274549163 5.033853 0.7163524 3.440514
3.593813664 5.063455 0.7146199 3.460589
3.944206059 5.095833 0.7127234 3.482775
4.328761281 5.131167 0.7106439 3.507956
4.750810162 5.169606 0.7083577 3.535950
5.214008288 5.211237 0.7058604 3.566416
5.722367659 5.256181 0.7031371 3.599599
6.280291442 5.304507 0.7001766 3.636717
6.892612104 5.356264 0.6969686 3.675653
7.564633276 5.411474 0.6935052 3.716819
8.302175681 5.470122 0.6897812 3.761187
9.111627561 5.532156 0.6857950 3.808851
10.000000000 5.597482 0.6815522 3.859416
Tuning parameter 'alpha' was held constant at a value of 0
RMSE was used to select the optimal model using the
smallest value.
The final values used for the model were alpha = 0 and
lambda = 0.6135907.
# Plot the tuning results (RMSE vs lambda) using ggplot2 directly
ridge_plot_data <- ridge_caret$results
print(ggplot(ridge_plot_data, aes(x = lambda, y = RMSE)) +
geom_line() +
geom_point() +
# Add error bars if RMSESD is available in the results
{if("RMSESD" %in% names(ridge_plot_data)) geom_errorbar(aes(ymin = RMSE - RMSESD, ymax = RMSE + RMSESD), width = 0.01)} +
ggplot2::labs(title = "Ridge Regression Tuning (caret)",
x = "Lambda",
y = "RMSE") +
theme_minimal())
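As an aside, caret also provides plot() and ggplot() methods for train objects; because lambda here spans several orders of magnitude, a log-scaled x-axis (a presentation choice, not something the post above uses) can make the curve easier to read:
# Built-in tuning plot for a train object
# plot(ridge_caret)
# Or the ggplot method with a log-scaled lambda axis
# ggplot(ridge_caret) + scale_x_log10()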
# Get the best lambda found by caret
cat("Optimal lambda for Ridge (caret):", ridge_caret$bestTune$lambda, "\n")
Optimal lambda for Ridge (caret): 0.6135907
# Get the coefficients from the best model
coef(ridge_caret$finalModel, s = ridge_caret$bestTune$lambda)
14 x 1 sparse Matrix of class "dgCMatrix"
s1
(Intercept) 22.5328063
crim -0.7532606
zn 0.7622018
indus -0.2607184
chas1 0.7365273
nox -1.3804925
rm 2.8184140
age -0.1050366
dis -2.3560256
rad 1.3385674
tax -0.9692660
ptratio -1.8509951
b 0.8283859
lstat -3.3736074
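These coefficients look quite different from the direct glmnet output earlier. That is expected: preProcess = c("center", "scale") standardizes the predictors before fitting, so caret’s coefficients are on the per-standard-deviation scale, while glmnet (even with standardize = TRUE) reports coefficients back on the original scale. A hedged sketch to confirm this, refitting glmnet on manually scaled predictors:
# Assumption: scaling x_matrix ourselves should roughly reproduce caret's coefficient scale
x_scaled <- scale(x_matrix)
ridge_check <- cv.glmnet(x_scaled, y_vector, alpha = 0)
coef(ridge_check, s = ridge_caret$bestTune$lambda)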
Lasso Regression with caret
Similarly, for Lasso regression, we fix \(\alpha\) at 1.
# --- Lasso Regression with caret ---
# For Lasso, we specify a tuneGrid where alpha is fixed at 1.
# caret will then search for the best lambda.
lassoGrid <- expand.grid(alpha = 1,
lambda = 10^seq(-3, 1, length = 100)) # A range of lambda values
cat("\n--- Training Lasso Regression with caret ---\n")
--- Training Lasso Regression with caret ---
lasso_caret <- train(
medv ~ .,
data = data_df,
method = "glmnet",
tuneGrid = lassoGrid,
trControl = fitControl,
preProcess = c("center", "scale")
)
+ Fold01.Rep1: alpha=1, lambda=10
- Fold01.Rep1: alpha=1, lambda=10
+ Fold02.Rep1: alpha=1, lambda=10
- Fold02.Rep1: alpha=1, lambda=10
... (analogous progress messages for the remaining folds of Rep1–Rep3 omitted) ...
+ Fold10.Rep3: alpha=1, lambda=10
- Fold10.Rep3: alpha=1, lambda=10
Warning: There were missing values in resampled performance measures.
Aggregating results
Selecting tuning parameters
Fitting alpha = 1, lambda = 0.0313 on full training set
# Print the model results
print(lasso_caret)
glmnet
506 samples
13 predictor
Pre-processing: centered (13), scaled (13)
Resampling: Cross-Validated (10 fold, repeated 3 times)
Summary of sample sizes: 456, 455, 456, 455, 455, 455, ...
Resampling results across tuning parameters:
lambda RMSE Rsquared MAE
0.001000000 4.827425 0.7335316 3.393066
0.001097499 4.827425 0.7335316 3.393066
0.001204504 4.827425 0.7335316 3.393066
0.001321941 4.827425 0.7335316 3.393066
0.001450829 4.827425 0.7335316 3.393066
0.001592283 4.827425 0.7335316 3.393066
0.001747528 4.827425 0.7335316 3.393066
0.001917910 4.827425 0.7335316 3.393066
0.002104904 4.827425 0.7335316 3.393066
0.002310130 4.827425 0.7335316 3.393066
0.002535364 4.827425 0.7335316 3.393066
0.002782559 4.827425 0.7335316 3.393066
0.003053856 4.827425 0.7335316 3.393066
0.003351603 4.827425 0.7335316 3.393066
0.003678380 4.827425 0.7335316 3.393066
0.004037017 4.827425 0.7335316 3.393066
0.004430621 4.827425 0.7335316 3.393066
0.004862602 4.827425 0.7335316 3.393066
0.005336699 4.827425 0.7335316 3.393066
0.005857021 4.827406 0.7335340 3.393038
0.006428073 4.827260 0.7335475 3.392831
0.007054802 4.827098 0.7335620 3.392556
0.007742637 4.826881 0.7335818 3.392098
0.008497534 4.826607 0.7336092 3.391507
0.009326033 4.826304 0.7336394 3.390856
0.010235310 4.826006 0.7336691 3.390167
0.011233240 4.825662 0.7337037 3.389439
0.012328467 4.825308 0.7337402 3.388642
0.013530478 4.824912 0.7337794 3.387781
0.014849683 4.824455 0.7338255 3.386805
0.016297508 4.824026 0.7338671 3.385767
0.017886495 4.823597 0.7339097 3.384665
0.019630407 4.823147 0.7339534 3.383476
0.021544347 4.822767 0.7339890 3.382235
0.023644894 4.822450 0.7340148 3.380885
0.025950242 4.822233 0.7340258 3.379502
0.028480359 4.822119 0.7340269 3.378029
0.031257158 4.822114 0.7340157 3.376497
0.034304693 4.822354 0.7339813 3.374991
0.037649358 4.822854 0.7339164 3.373470
0.041320124 4.823551 0.7338280 3.371835
0.045348785 4.824612 0.7337017 3.370324
0.049770236 4.826082 0.7335323 3.368953
0.054622772 4.828061 0.7333088 3.367784
0.059948425 4.830641 0.7330209 3.366864
0.065793322 4.833864 0.7326627 3.366030
0.072208090 4.837730 0.7322322 3.365426
0.079248290 4.842752 0.7316706 3.365236
0.086974900 4.849078 0.7309704 3.365654
0.095454846 4.856825 0.7301116 3.366740
0.104761575 4.866181 0.7290690 3.368768
0.114975700 4.877414 0.7278118 3.372438
0.126185688 4.890990 0.7262832 3.377892
0.138488637 4.906864 0.7244850 3.385528
0.151991108 4.923308 0.7226186 3.394208
0.166810054 4.938754 0.7209047 3.403435
0.183073828 4.951737 0.7195315 3.412342
0.200923300 4.962284 0.7184644 3.420201
0.220513074 4.973371 0.7173569 3.428915
0.242012826 4.986318 0.7160580 3.438947
0.265608778 5.001966 0.7144936 3.451218
0.291505306 5.020676 0.7125915 3.466796
0.319926714 5.040717 0.7105349 3.484102
0.351119173 5.062932 0.7082395 3.502757
0.385352859 5.088999 0.7055459 3.523538
0.422924287 5.119854 0.7023088 3.549734
0.464158883 5.150742 0.6990987 3.577273
0.509413801 5.177733 0.6963939 3.600444
0.559081018 5.199051 0.6945347 3.618083
0.613590727 5.217934 0.6932181 3.633864
0.673415066 5.236341 0.6921782 3.649264
0.739072203 5.257714 0.6910335 3.667368
0.811130831 5.282367 0.6897834 3.688690
0.890215085 5.310458 0.6884365 3.713712
0.977009957 5.342468 0.6869322 3.743012
1.072267222 5.377386 0.6855324 3.774360
1.176811952 5.415443 0.6844058 3.805175
1.291549665 5.458257 0.6834713 3.837310
1.417474163 5.507601 0.6825933 3.873493
1.555676144 5.566556 0.6814855 3.918551
1.707352647 5.636855 0.6799977 3.973028
1.873817423 5.720373 0.6780110 4.036070
2.056512308 5.819382 0.6753206 4.109432
2.257019720 5.936463 0.6716188 4.194332
2.477076356 6.074523 0.6664300 4.297626
2.718588243 6.235495 0.6592842 4.422737
2.983647240 6.415520 0.6509253 4.559803
3.274549163 6.598890 0.6464106 4.692162
3.593813664 6.799916 0.6446798 4.830492
3.944206059 7.035107 0.6420236 4.992616
4.328761281 7.308930 0.6374380 5.186985
4.750810162 7.626249 0.6285255 5.418035
5.214008288 7.990396 0.6091422 5.693873
5.722367659 8.387284 0.5752620 6.009139
6.280291442 8.789860 0.5588197 6.345285
6.892612104 9.150170 0.5006470 6.649038
7.564633276 9.152032 NaN 6.650629
8.302175681 9.152032 NaN 6.650629
9.111627561 9.152032 NaN 6.650629
10.000000000 9.152032 NaN 6.650629
Tuning parameter 'alpha' was held constant at a value of 1
RMSE was used to select the optimal model using the
smallest value.
The final values used for the model were alpha = 1 and
lambda = 0.03125716.
# Plot the tuning results (RMSE vs lambda) using ggplot2 directly
lasso_plot_data <- lasso_caret$results
print(ggplot(lasso_plot_data, aes(x = lambda, y = RMSE)) +
geom_line() +
geom_point() +
# Add error bars if RMSESD is available in the results
{if("RMSESD" %in% names(lasso_plot_data)) geom_errorbar(aes(ymin = RMSE - RMSESD, ymax = RMSE + RMSESD), width = 0.01)} +
ggplot2::labs(title = "Lasso Regression Tuning (caret)",
x = "Lambda",
y = "RMSE") +
theme_minimal())
# Get the best lambda found by caret
cat("Optimal lambda for Lasso (caret):", lasso_caret$bestTune$lambda, "\n")
Optimal lambda for Lasso (caret): 0.03125716
# Get the coefficients from the best model
coef(lasso_caret$finalModel, s = lasso_caret$bestTune$lambda)
14 x 1 sparse Matrix of class "dgCMatrix"
s1
(Intercept) 22.5328063
crim -0.8364201
zn 0.9551219
indus .
chas1 0.6809208
nox -1.8755555
rm 2.7219580
age .
dis -2.9166593
rad 2.1553510
tax -1.6200257
ptratio -2.0093438
b 0.8212497
lstat -3.7311343
Comparing caret with glmnet
Here’s a summary of why caret is often preferred for these tasks:
| Feature | glmnet (Direct) | caret (with glmnet method) |
|---|---|---|
| Interface | Requires x (matrix) and y (vector) as inputs. | Uses formula interface (y ~ .) or x and y directly. |
| Pre-processing | Manual (e.g., scale()). | Automated (preProcess argument, e.g., c("center", "scale")). |
| Cross-Validation | cv.glmnet() function. | Integrated into train() via trControl. |
| Hyperparameter Tuning | Manual interpretation of plot(cv.glmnet) and selection of lambda.min/lambda.1se. | Automated search for optimal \(\lambda\) (and alpha for Elastic Net) based on tuneGrid and trControl. |
| Model Comparison | Requires separate calls and manual comparison. | Unified train() output makes it easy to compare different models and tuning results. |
| Output | glmnet object; requires specific functions (coef, predict). | train object; provides consistent methods (predict, plot, print). |
| Workflow | More granular control, but more manual steps. | Streamlined, automated, and reproducible workflow. |
| Flexibility | Highly flexible for glmnet-specific tasks. | Highly flexible for any model, providing a consistent API. |
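One thing the table only hints at: with caret you can tune alpha and lambda jointly for Elastic Net simply by widening the grid. A sketch of what that could look like (illustrative only; not run in this post, and the grid values are arbitrary choices):
# Hypothetical Elastic Net grid: vary both alpha and lambda
enetGrid <- expand.grid(alpha = seq(0, 1, by = 0.25),
                        lambda = 10^seq(-3, 1, length = 50))
# enet_caret <- train(medv ~ ., data = data_df, method = "glmnet",
#                     tuneGrid = enetGrid, trControl = fitControl,
#                     preProcess = c("center", "scale"))
# enet_caret$bestTune  # best (alpha, lambda) pair chosen by cross-validation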
In essence, caret acts as a powerful wrapper around glmnet (and many other packages), abstracting away much of the boilerplate code involved in model training, tuning, and evaluation. This allows you to focus more on the modeling problem and less on the implementation details of individual algorithms.
When evaluating a model, it’s crucial to estimate its performance on unseen data. This is often referred to as out-of-sample error.
While you could manually split your data into training and test sets and then make predictions on the test set, caret’s train() function, when used with resampling methods like cross-validation, automatically provides robust estimates of out-of-sample performance.
The print() function applied to a caret train object (e.g., print(ridge_caret)) displays the cross-validated performance metrics (like RMSE and R-squared) for each lambda value tried, as well as for the optimal lambda. These metrics are calculated from the held-out folds during cross-validation and are excellent indicators of how the model generalizes to new data.
Let’s re-examine the output of print(ridge_caret) and print(lasso_caret):
# Re-printing the caret model results to highlight cross-validated metrics
cat("\n--- Ridge Regression Caret Model Summary (Cross-Validated Results) ---\n")
--- Ridge Regression Caret Model Summary (Cross-Validated Results) ---
print(ridge_caret)
glmnet
506 samples
13 predictor
Pre-processing: centered (13), scaled (13)
Resampling: Cross-Validated (10 fold, repeated 3 times)
Summary of sample sizes: 456, 456, 454, 455, 455, 456, ...
Resampling results across tuning parameters:
lambda RMSE Rsquared MAE
0.001000000 4.806746 0.7301647 3.317381
... (full lambda grid identical to the print(ridge_caret) output shown earlier) ...
10.000000000 5.597482 0.6815522 3.859416
Tuning parameter 'alpha' was held constant at a value of 0
RMSE was used to select the optimal model using the smallest value.
The final values used for the model were alpha = 0 and lambda = 0.6135907.
cat("\n--- Lasso Regression Caret Model Summary (Cross-Validated Results) ---\n")
--- Lasso Regression Caret Model Summary (Cross-Validated Results) ---
print(lasso_caret)
glmnet
506 samples
13 predictor
Pre-processing: centered (13), scaled (13)
Resampling: Cross-Validated (10 fold, repeated 3 times)
Summary of sample sizes: 456, 455, 456, 455, 455, 455, ...
Resampling results across tuning parameters:
lambda RMSE Rsquared MAE
0.001000000 4.827425 0.7335316 3.393066
... (full lambda grid identical to the print(lasso_caret) output shown earlier) ...
10.000000000 9.152032 NaN 6.650629
Tuning parameter 'alpha' was held constant at a value of 1
RMSE was used to select the optimal model using the smallest value.
The final values used for the model were alpha = 1 and lambda = 0.03125716.
In the output above, the RMSE and Rsquared columns represent the average performance across the cross-validation folds for the different lambda values. The row corresponding to bestTune (the optimal lambda) provides the best estimate of your model’s out-of-sample performance. Incidentally, the NaN Rsquared values at the largest lambda values in the Lasso output (and the earlier warning about missing values in resampled performance measures) arise because a strong enough L1 penalty shrinks every coefficient to zero, so the model predicts a constant and R-squared is undefined.
While caret handles the out-of-sample error estimation through cross-validation, you can still use the predict() function to get predictions on new data. For demonstration purposes, we’ll predict on the original data_df, but in a real-world scenario this would be a completely separate test set.
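For reference, here is a minimal sketch of what that proper split could look like with caret (hypothetical object names; not used in the rest of this post):
# Hold out 20% of the data before any model fitting
set.seed(123)
train_idx <- createDataPartition(data_df$medv, p = 0.8, list = FALSE)
train_df  <- data_df[train_idx, ]
test_df   <- data_df[-train_idx, ]
# Refit with train() on train_df only, then estimate out-of-sample error on test_df:
# ridge_split <- train(medv ~ ., data = train_df, method = "glmnet",
#                      tuneGrid = ridgeGrid, trControl = fitControl,
#                      preProcess = c("center", "scale"))
# RMSE(predict(ridge_split, newdata = test_df), test_df$medv)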
# Make predictions on the original data (for demonstration)
# In a real scenario, you would predict on new, unseen test data.
# Predictions from Ridge model
ridge_predictions <- predict(ridge_caret, newdata = data_df)
cat("\nFirst 5 Ridge predictions:\n")
First 5 Ridge predictions:
print(head(ridge_predictions))
1 2 3 4 5 6
30.26572 25.02806 30.53628 28.96496 28.41959 25.56159
# Predictions from Lasso model
lasso_predictions <- predict(lasso_caret, newdata = data_df)
cat("\nFirst 5 Lasso predictions:\n")
First 5 Lasso predictions:
print(head(lasso_predictions))
1 2 3 4 5 6
30.21874 25.05759 30.65066 28.77710 28.12243 25.38625
# Evaluate performance (e.g., RMSE) on the training data
# (Again, in practice, do this on a separate test set)
ridge_rmse <- sqrt(mean((data_df$medv - ridge_predictions)^2))
lasso_rmse <- sqrt(mean((data_df$medv - lasso_predictions)^2))
cat("\nRMSE for Ridge (on training data):", ridge_rmse, "\n")
RMSE for Ridge (on training data): 4.734318
cat("RMSE for Lasso (on training data):", lasso_rmse, "\n")
RMSE for Lasso (on training data): 4.684567
# Compare coefficients (Lasso's sparsity vs Ridge's shrinkage)
cat("\n--- Coefficients from Lasso (best lambda) ---\n")
--- Coefficients from Lasso (best lambda) ---
print(coef(lasso_caret$finalModel, s = lasso_caret$bestTune$lambda))
14 x 1 sparse Matrix of class "dgCMatrix"
s1
(Intercept) 22.5328063
crim -0.8364201
zn 0.9551219
indus .
chas1 0.6809208
nox -1.8755555
rm 2.7219580
age .
dis -2.9166593
rad 2.1553510
tax -1.6200257
ptratio -2.0093438
b 0.8212497
lstat -3.7311343
cat("\n--- Coefficients from Ridge (best lambda) ---\n")
--- Coefficients from Ridge (best lambda) ---
print(coef(ridge_caret$finalModel, s = ridge_caret$bestTune$lambda))
14 x 1 sparse Matrix of class "dgCMatrix"
s1
(Intercept) 22.5328063
crim -0.7532606
zn 0.7622018
indus -0.2607184
chas1 0.7365273
nox -1.3804925
rm 2.8184140
age -0.1050366
dis -2.3560256
rad 1.3385674
tax -0.9692660
ptratio -1.8509951
b 0.8283859
lstat -3.3736074
Notice how Lasso has driven some coefficients exactly to zero, performing feature selection, while Ridge has shrunk them but kept them non-zero.
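Since Lasso’s zeros amount to feature selection, here is a small sketch (using the coefficient matrix printed above) to list the predictors it kept:
# Convert the sparse coefficient matrix to a dense one and keep non-zero rows
lasso_coefs <- as.matrix(coef(lasso_caret$finalModel, s = lasso_caret$bestTune$lambda))
selected <- rownames(lasso_coefs)[lasso_coefs[, 1] != 0]
setdiff(selected, "(Intercept)")  # predictors retained by Lasso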