Use the “diamonds” dataset (it ships with ggplot2, which is loaded as part of the tidyverse), with “price” as the outcome variable and all other columns in the dataset as predictors (features).
# Install any of the requested packages that are not available yet.
#
# list.of.packages: character vector of package names.
#
# Called for its side effects. If at least one package is missing, the missing
# ones are passed to install.packages(). Otherwise one "already installed"
# message per requested package is printed (and returned by print()).
install_if_not <- function( list.of.packages ) {
  # find.package(quiet = TRUE) yields character(0) for a package that cannot be
  # located, so each name is checked directly instead of building the full
  # installed.packages() table just to look up a few names.
  is.available <- vapply(
    list.of.packages,
    function(pkg) length(find.package(pkg, quiet = TRUE)) > 0,
    logical(1)
  )
  new.packages <- list.of.packages[!is.available]
  if (length(new.packages) > 0) {
    install.packages(new.packages)
  } else {
    print(paste0("the package '", list.of.packages, "' is already installed"))
  }
}
library(tidyverse)
## -- Attaching packages --------------------------------------- tidyverse 1.3.0 --
## v ggplot2 3.3.3 v purrr 0.3.4
## v tibble 3.0.6 v dplyr 1.0.4
## v tidyr 1.0.2 v stringr 1.4.0
## v readr 1.4.0 v forcats 0.5.0
## -- Conflicts ------------------------------------------ tidyverse_conflicts() --
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
library(caret)
## Loading required package: lattice
##
## Attaching package: 'caret'
## The following object is masked from 'package:purrr':
##
## lift
# Make sure the glmnet package is installed
install_if_not('glmnet')
## [1] "the package 'glmnet' is already installed"
library(glmnet)
## Loading required package: Matrix
##
## Attaching package: 'Matrix'
## The following objects are masked from 'package:tidyr':
##
## expand, pack, unpack
## Loaded glmnet 3.0-2
# Take a look at the data: column names, types and the first few values
glimpse(diamonds)
## Rows: 53,940
## Columns: 10
## $ carat <dbl> 0.23, 0.21, 0.23, 0.29, 0.31, 0.24, 0.24, 0.26, 0.22, 0.23,...
## $ cut <ord> Ideal, Premium, Good, Premium, Good, Very Good, Very Good, ...
## $ color <ord> E, E, E, I, J, J, I, H, E, H, J, J, F, J, E, E, I, J, J, J,...
## $ clarity <ord> SI2, SI1, VS1, VS2, SI2, VVS2, VVS1, SI1, VS2, VS1, SI1, VS...
## $ depth <dbl> 61.5, 59.8, 56.9, 62.4, 63.3, 62.8, 62.3, 61.9, 65.1, 59.4,...
## $ table <dbl> 55, 61, 65, 58, 58, 57, 57, 55, 61, 61, 55, 56, 61, 54, 62,...
## $ price <int> 326, 326, 327, 334, 335, 336, 336, 337, 337, 338, 339, 340,...
## $ x <dbl> 3.95, 3.89, 4.05, 4.20, 4.34, 3.94, 3.95, 4.07, 3.87, 4.00,...
## $ y <dbl> 3.98, 3.84, 4.07, 4.23, 4.35, 3.96, 3.98, 4.11, 3.78, 4.05,...
## $ z <dbl> 2.43, 2.31, 2.31, 2.63, 2.75, 2.48, 2.47, 2.53, 2.49, 2.39,...
\(~\)
\(~\)
\(~\)
\(~\)
# Split the data into training and test set
set.seed(8675309)
# With list = FALSE, createDataPartition() returns the selected row numbers as
# a one-column matrix. Keep only that column so the tibble is indexed with a
# plain integer vector; tibble >= 3.0.0 warns when the row index is a matrix.
training.samples <- createDataPartition(diamonds$price, p = 0.7, list = FALSE)[, 1]
train.data <- diamonds[training.samples, ]
# Negative indexing keeps every row that was not drawn for training
test.data <- diamonds[-training.samples, ]
\(~\)
\(~\)
\(~\)
\(~\)
lambda tuning grid
Set up a range of lambda values for cross-validation (that is, so that the optimal lambda can be chosen).
# Candidate penalty strengths: 100 values evenly spaced on a log10 scale
# from 10^-3 to 10^3. The argument is spelled out as length.out rather than
# relying on partial matching of `length`.
lambda <- 10^seq(-3, 3, length.out = 100)
\(~\)
\(~\)
\(~\)
\(~\)
# Build the ridge model using 10-fold cross-validation
ridge <- train(
  price ~ .,
  data = train.data,
  method = "glmnet",
  trControl = trainControl("cv", number = 10),
  tuneGrid = expand.grid(alpha = 0, lambda = lambda), # alpha = 0 means ridge
  preProcess = c("center", "scale")
)
# Resampled performance across the lambda grid
plot(ridge)
# Model coefficients at the lambda selected by cross-validation
coef_ridge <- coef(ridge$finalModel, ridge$bestTune$lambda)
# After as_tibble(rownames = ) the coefficient values sit in column 2. Rename
# by position: the column label is "1" with the glmnet version used for this
# run, but newer glmnet releases appear to label it "s1", so renaming by name
# is fragile.
coef_ridge <- as_tibble(as.matrix(coef_ridge), rownames = "feature") %>%
  rename(estimate = 2)
coef_ridge
## # A tibble: 24 x 2
## feature estimate
## <chr> <dbl>
## 1 (Intercept) 3932.
## 2 carat 2147.
## 3 cut.L 190.
## 4 cut.Q -112.
## 5 cut.C 26.4
## 6 cut^4 -2.62
## 7 color.L -449.
## 8 color.Q -146.
## 9 color.C -49.7
## 10 color^4 3.15
## # ... with 14 more rows
# Score the held-out test data with the fitted ridge model
predictions <- predict(ridge, test.data)
# cbind() labels the appended column after the variable name, "predictions",
# which is the column the RMSE call refers to afterwards
test.scored.ridge <- cbind(test.data, predictions)
library('yardstick')
## For binary classification, the first factor level is assumed to be the event.
## Set the global option `yardstick.event_first` to `FALSE` to change this.
##
## Attaching package: 'yardstick'
## The following objects are masked from 'package:caret':
##
## precision, recall
## The following object is masked from 'package:readr':
##
## spec
# Test-set RMSE of the ridge model (observed price vs. predicted price)
test.scored.ridge %>%
  rmse(truth = price, estimate = predictions)
## # A tibble: 1 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 rmse standard 1373.
\(~\)
\(~\)
\(~\)
\(~\)
# Build the lasso model using 10-fold cross-validation
lasso <- train(
  price ~ .,
  data = train.data,
  method = "glmnet",
  trControl = trainControl("cv", number = 10),
  tuneGrid = expand.grid(alpha = 1, lambda = lambda), # alpha = 1 means LASSO
  preProcess = c("center", "scale")
)
# Investigate model coefficients at the lambda selected by cross-validation
# You will see that the variable "y" was zeroed out of the model
coef_lasso <- coef(lasso$finalModel, lasso$bestTune$lambda)
# After as_tibble(rownames = ) the coefficient values sit in column 2. Rename
# by position: the column label is "1" with the glmnet version used for this
# run, but newer glmnet releases appear to label it "s1", so renaming by name
# is fragile.
coef_lasso <- as_tibble(as.matrix(coef_lasso), rownames = "feature") %>%
  rename(estimate = 2)
# Score the held-out test data and report the test-set RMSE
predictions <- lasso %>% predict(test.data)
test.scored.lasso <- cbind(test.data, predictions)
rmse(test.scored.lasso, price, predictions)
## # A tibble: 1 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 rmse standard 1139.
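Here, you can see that, compared to the ridge model, lasso produced a model with a lower test-set RMSE (1139 vs. 1373). The test-set output above does not report R-squared, but the cross-validation results further below also show a higher R-squared for lasso. Thus, we would choose lasso over ridge as our final model.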
# Unpenalised baseline fitted through caret's "glm" method, using the same
# 10-fold cross-validation and centering/scaling as the glmnet models
glm_model <- train(
  price ~ .,
  data = train.data,
  method = "glm",
  trControl = trainControl("cv", number = 10),
  preProcess = c("center", "scale")
)
# Score the held-out test data with the baseline model
predictions <- predict(glm_model, test.data)
test.scored.glm <- cbind(test.data, predictions)
# Stack the three scored test sets, tagging each row with the model it came from
test.scored_Stacked <- bind_rows(
  mutate(test.scored.ridge, model = "ridge"),
  mutate(test.scored.lasso, model = "lasso"),
  mutate(test.scored.glm, model = "glm")
)
# Test-set RMSE per model, best (smallest) first
test.scored_Stacked %>%
  group_by(model) %>%
  rmse(truth = price, estimate = predictions) %>%
  arrange(.estimate)
## # A tibble: 3 x 4
## model .metric .estimator .estimate
## <chr> <chr> <chr> <dbl>
## 1 lasso rmse standard 1139.
## 2 glm rmse standard 1141.
## 3 ridge rmse standard 1373.
# Coefficients of the unpenalised fit as a feature/estimate tibble.
# enframe() converts the named coefficient vector directly; passing an atomic
# vector to as_tibble() is superseded.
# The rendered comparison table lists terms such as cut^4 and color^4 twice
# (one row with only a glm value, one with only glmnet values). That is
# consistent with the glm fit reporting non-syntactic term names wrapped in
# backticks while glmnet reports them bare, so the backticks are stripped to
# make the same term line up across models.
coef_glm <- enframe(
  glm_model$finalModel$coefficients,
  name = "feature",
  value = "estimate"
) %>%
  mutate(feature = str_remove_all(feature, fixed("`")))
# Long table: one row per (feature, model) pair
coef_compare <- bind_rows(
  coef_glm %>% mutate(model = "glm"),
  coef_lasso %>% mutate(model = "lasso"),
  coef_ridge %>% mutate(model = "ridge")
)
# Predicate: TRUE when `x` contains at least one missing value, FALSE otherwise
# (including for an empty vector).
any_column_NA <- function(x){
  anyNA(x)
}
# Return `x` with every missing entry replaced by 0; other entries are kept.
replace_NA_0 <- function(x){
  is_missing <- is.na(x)
  if_else(is_missing, 0, x)
}
# Side-by-side coefficient table: one row per feature, one column per model
coef_compare %>%
  pivot_wider(names_from = model, values_from = estimate) %>%
  # A feature that is absent from a model is NA after pivoting; count it as 0.
  # across(where()) is the current replacement for the superseded mutate_if().
  mutate(across(where(any_column_NA), replace_NA_0)) %>%
  # Difference relative to the glm estimate, expressed as a fraction of it
  mutate(
    lasso_pct_diff_glm = (glm - lasso) / glm,
    ridge_pct_diff_glm = (glm - ridge) / glm
  ) %>%
  # Largest relative shrinkage by lasso first
  arrange(desc(lasso_pct_diff_glm)) %>%
  knitr::kable()
feature | glm | lasso | ridge | lasso_pct_diff_glm | ridge_pct_diff_glm |
---|---|---|---|---|---|
cut^4 | 0.0000000 | -5.749140 | -2.615307 | Inf | Inf |
color^5 | 0.0000000 | -32.099308 | -21.247125 | Inf | Inf |
color^6 | 0.0000000 | -19.508480 | -21.264867 | Inf | Inf |
clarity^4 | 0.0000000 | -129.224723 | -35.276471 | Inf | Inf |
`cut^4` | -12.5826932 | 0.000000 | 0.000000 | 1.0000000 | 1.0000000 |
`color^4` | 8.3376384 | 0.000000 | 0.000000 | 1.0000000 | 1.0000000 |
`color^5` | -34.9618857 | 0.000000 | 0.000000 | 1.0000000 | 1.0000000 |
`color^6` | -21.3028796 | 0.000000 | 0.000000 | 1.0000000 | 1.0000000 |
`clarity^4` | -137.0286078 | 0.000000 | 0.000000 | 1.0000000 | 1.0000000 |
`clarity^5` | 92.1936626 | 0.000000 | 0.000000 | 1.0000000 | 1.0000000 |
`clarity^6` | -0.1474419 | 0.000000 | 0.000000 | 1.0000000 | 1.0000000 |
`clarity^7` | 39.2175527 | 0.000000 | 0.000000 | 1.0000000 | 1.0000000 |
y | 165.2820422 | 0.000000 | 569.528185 | 1.0000000 | -2.4457959 |
x | -1068.0369153 | -816.584123 | 647.774978 | 0.2354346 | 1.6065099 |
z | -225.1271735 | -196.577485 | 596.050826 | 0.1268158 | 3.6476183 |
cut.C | 60.1095711 | 56.200869 | 26.409631 | 0.0650263 | 0.5606418 |
depth | -59.7719605 | -56.554788 | -10.247484 | 0.0538241 | 0.8285570 |
color.C | -65.2773861 | -61.903237 | -49.688592 | 0.0516894 | 0.2388085 |
cut.Q | -142.1065994 | -135.024557 | -112.192345 | 0.0498361 | 0.2105057 |
table | -61.2639120 | -58.348170 | -33.411300 | 0.0475931 | 0.4546333 |
clarity.C | 302.2403654 | 290.104646 | 88.484395 | 0.0401525 | 0.7072383 |
clarity.Q | -469.7390115 | -458.474010 | -272.481911 | 0.0239814 | 0.4199291 |
color.Q | -233.1298779 | -227.768585 | -146.265139 | 0.0229970 | 0.3726023 |
carat | 5309.0675020 | 5191.334787 | 2146.939898 | 0.0221758 | 0.5956089 |
cut.L | 206.9329173 | 203.424317 | 190.182371 | 0.0169553 | 0.0809467 |
color.L | -634.9561138 | -628.962264 | -448.988298 | 0.0094398 | 0.2928829 |
clarity.L | 1044.5822449 | 1036.337828 | 798.588123 | 0.0078925 | 0.2354952 |
(Intercept) | 3931.5103949 | 3931.510395 | 3931.510395 | 0.0000000 | 0.0000000 |
color^4 | 0.0000000 | 7.305958 | 3.152159 | -Inf | -Inf |
clarity^5 | 0.0000000 | 85.434511 | 20.131680 | -Inf | -Inf |
clarity^7 | 0.0000000 | 36.719177 | 50.483468 | -Inf | -Inf |
clarity^6 | 0.0000000 | 0.000000 | 9.785508 | NaN | -Inf |
# Absolute coefficient size per feature, with one bar per model side by side.
# geom_col() is the direct form of geom_bar(stat = "identity"). No group_by()
# is needed because ggplot() does not use dplyr grouping; `fill = model`
# separates the bars.
coef_compare %>%
  ggplot(aes(x = feature, y = abs(estimate), fill = model)) +
  geom_col(position = "dodge") +
  coord_flip()
\(~\)
\(~\)
As a final note, we can test both lasso and ridge within the same cross-validated model by expanding our tuning grid further:
# Candidate mixing parameters: 0 is ridge, 1 is lasso
alpha <- c(0, 1)
# One cross-validated search over every alpha/lambda combination, so ridge and
# lasso are compared within the same train() call
lasso_v_ridge <- train(
  price ~ .,
  data = train.data,
  method = "glmnet",
  trControl = trainControl(method = "cv", number = 10),
  tuneGrid = expand.grid(alpha = alpha, lambda = lambda),
  preProcess = c("center", "scale")
)
# Print the resampling results and the selected alpha/lambda
lasso_v_ridge
## glmnet
##
## 37759 samples
## 9 predictor
##
## Pre-processing: centered (23), scaled (23)
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 33983, 33982, 33984, 33982, 33983, 33984, ...
## Resampling results across tuning parameters:
##
## alpha lambda RMSE Rsquared MAE
## 0 1.000000e-03 1328.233 0.8905538 940.5504
## 0 1.149757e-03 1328.233 0.8905538 940.5504
## 0 1.321941e-03 1328.233 0.8905538 940.5504
## 0 1.519911e-03 1328.233 0.8905538 940.5504
## 0 1.747528e-03 1328.233 0.8905538 940.5504
## 0 2.009233e-03 1328.233 0.8905538 940.5504
## 0 2.310130e-03 1328.233 0.8905538 940.5504
## 0 2.656088e-03 1328.233 0.8905538 940.5504
## 0 3.053856e-03 1328.233 0.8905538 940.5504
## 0 3.511192e-03 1328.233 0.8905538 940.5504
## 0 4.037017e-03 1328.233 0.8905538 940.5504
## 0 4.641589e-03 1328.233 0.8905538 940.5504
## 0 5.336699e-03 1328.233 0.8905538 940.5504
## 0 6.135907e-03 1328.233 0.8905538 940.5504
## 0 7.054802e-03 1328.233 0.8905538 940.5504
## 0 8.111308e-03 1328.233 0.8905538 940.5504
## 0 9.326033e-03 1328.233 0.8905538 940.5504
## 0 1.072267e-02 1328.233 0.8905538 940.5504
## 0 1.232847e-02 1328.233 0.8905538 940.5504
## 0 1.417474e-02 1328.233 0.8905538 940.5504
## 0 1.629751e-02 1328.233 0.8905538 940.5504
## 0 1.873817e-02 1328.233 0.8905538 940.5504
## 0 2.154435e-02 1328.233 0.8905538 940.5504
## 0 2.477076e-02 1328.233 0.8905538 940.5504
## 0 2.848036e-02 1328.233 0.8905538 940.5504
## 0 3.274549e-02 1328.233 0.8905538 940.5504
## 0 3.764936e-02 1328.233 0.8905538 940.5504
## 0 4.328761e-02 1328.233 0.8905538 940.5504
## 0 4.977024e-02 1328.233 0.8905538 940.5504
## 0 5.722368e-02 1328.233 0.8905538 940.5504
## 0 6.579332e-02 1328.233 0.8905538 940.5504
## 0 7.564633e-02 1328.233 0.8905538 940.5504
## 0 8.697490e-02 1328.233 0.8905538 940.5504
## 0 1.000000e-01 1328.233 0.8905538 940.5504
## 0 1.149757e-01 1328.233 0.8905538 940.5504
## 0 1.321941e-01 1328.233 0.8905538 940.5504
## 0 1.519911e-01 1328.233 0.8905538 940.5504
## 0 1.747528e-01 1328.233 0.8905538 940.5504
## 0 2.009233e-01 1328.233 0.8905538 940.5504
## 0 2.310130e-01 1328.233 0.8905538 940.5504
## 0 2.656088e-01 1328.233 0.8905538 940.5504
## 0 3.053856e-01 1328.233 0.8905538 940.5504
## 0 3.511192e-01 1328.233 0.8905538 940.5504
## 0 4.037017e-01 1328.233 0.8905538 940.5504
## 0 4.641589e-01 1328.233 0.8905538 940.5504
## 0 5.336699e-01 1328.233 0.8905538 940.5504
## 0 6.135907e-01 1328.233 0.8905538 940.5504
## 0 7.054802e-01 1328.233 0.8905538 940.5504
## 0 8.111308e-01 1328.233 0.8905538 940.5504
## 0 9.326033e-01 1328.233 0.8905538 940.5504
## 0 1.072267e+00 1328.233 0.8905538 940.5504
## 0 1.232847e+00 1328.233 0.8905538 940.5504
## 0 1.417474e+00 1328.233 0.8905538 940.5504
## 0 1.629751e+00 1328.233 0.8905538 940.5504
## 0 1.873817e+00 1328.233 0.8905538 940.5504
## 0 2.154435e+00 1328.233 0.8905538 940.5504
## 0 2.477076e+00 1328.233 0.8905538 940.5504
## 0 2.848036e+00 1328.233 0.8905538 940.5504
## 0 3.274549e+00 1328.233 0.8905538 940.5504
## 0 3.764936e+00 1328.233 0.8905538 940.5504
## 0 4.328761e+00 1328.233 0.8905538 940.5504
## 0 4.977024e+00 1328.233 0.8905538 940.5504
## 0 5.722368e+00 1328.233 0.8905538 940.5504
## 0 6.579332e+00 1328.233 0.8905538 940.5504
## 0 7.564633e+00 1328.233 0.8905538 940.5504
## 0 8.697490e+00 1328.233 0.8905538 940.5504
## 0 1.000000e+01 1328.233 0.8905538 940.5504
## 0 1.149757e+01 1328.233 0.8905538 940.5504
## 0 1.321941e+01 1328.233 0.8905538 940.5504
## 0 1.519911e+01 1328.233 0.8905538 940.5504
## 0 1.747528e+01 1328.233 0.8905538 940.5504
## 0 2.009233e+01 1328.233 0.8905538 940.5504
## 0 2.310130e+01 1328.233 0.8905538 940.5504
## 0 2.656088e+01 1328.233 0.8905538 940.5504
## 0 3.053856e+01 1328.233 0.8905538 940.5504
## 0 3.511192e+01 1328.233 0.8905538 940.5504
## 0 4.037017e+01 1328.233 0.8905538 940.5504
## 0 4.641589e+01 1328.233 0.8905538 940.5504
## 0 5.336699e+01 1328.233 0.8905538 940.5504
## 0 6.135907e+01 1328.233 0.8905538 940.5504
## 0 7.054802e+01 1328.233 0.8905538 940.5504
## 0 8.111308e+01 1328.233 0.8905538 940.5504
## 0 9.326033e+01 1328.233 0.8905538 940.5504
## 0 1.072267e+02 1328.233 0.8905538 940.5504
## 0 1.232847e+02 1328.233 0.8905538 940.5504
## 0 1.417474e+02 1328.233 0.8905538 940.5504
## 0 1.629751e+02 1328.233 0.8905538 940.5504
## 0 1.873817e+02 1328.233 0.8905538 940.5504
## 0 2.154435e+02 1328.233 0.8905538 940.5504
## 0 2.477076e+02 1328.233 0.8905538 940.5504
## 0 2.848036e+02 1328.233 0.8905538 940.5504
## 0 3.274549e+02 1328.233 0.8905538 940.5504
## 0 3.764936e+02 1331.202 0.8901090 942.4954
## 0 4.328761e+02 1348.075 0.8875674 953.3319
## 0 4.977024e+02 1365.185 0.8850186 963.5800
## 0 5.722368e+02 1382.631 0.8824611 973.2589
## 0 6.579332e+02 1400.457 0.8799049 982.3759
## 0 7.564633e+02 1418.812 0.8773460 991.0442
## 0 8.697490e+02 1437.914 0.8747747 999.2806
## 0 1.000000e+03 1457.959 0.8721833 1007.1196
## 1 1.000000e-03 1128.530 0.9197241 740.8260
## 1 1.149757e-03 1128.530 0.9197241 740.8260
## 1 1.321941e-03 1128.530 0.9197241 740.8260
## 1 1.519911e-03 1128.530 0.9197241 740.8260
## 1 1.747528e-03 1128.530 0.9197241 740.8260
## 1 2.009233e-03 1128.530 0.9197241 740.8260
## 1 2.310130e-03 1128.530 0.9197241 740.8260
## 1 2.656088e-03 1128.530 0.9197241 740.8260
## 1 3.053856e-03 1128.530 0.9197241 740.8260
## 1 3.511192e-03 1128.530 0.9197241 740.8260
## 1 4.037017e-03 1128.530 0.9197241 740.8260
## 1 4.641589e-03 1128.530 0.9197241 740.8260
## 1 5.336699e-03 1128.530 0.9197241 740.8260
## 1 6.135907e-03 1128.530 0.9197241 740.8260
## 1 7.054802e-03 1128.530 0.9197241 740.8260
## 1 8.111308e-03 1128.530 0.9197241 740.8260
## 1 9.326033e-03 1128.530 0.9197241 740.8260
## 1 1.072267e-02 1128.530 0.9197241 740.8260
## 1 1.232847e-02 1128.530 0.9197241 740.8260
## 1 1.417474e-02 1128.530 0.9197241 740.8260
## 1 1.629751e-02 1128.530 0.9197241 740.8260
## 1 1.873817e-02 1128.530 0.9197241 740.8260
## 1 2.154435e-02 1128.530 0.9197241 740.8260
## 1 2.477076e-02 1128.530 0.9197241 740.8260
## 1 2.848036e-02 1128.530 0.9197241 740.8260
## 1 3.274549e-02 1128.530 0.9197241 740.8260
## 1 3.764936e-02 1128.530 0.9197241 740.8260
## 1 4.328761e-02 1128.530 0.9197241 740.8260
## 1 4.977024e-02 1128.530 0.9197241 740.8260
## 1 5.722368e-02 1128.530 0.9197241 740.8260
## 1 6.579332e-02 1128.530 0.9197241 740.8260
## 1 7.564633e-02 1128.530 0.9197241 740.8260
## 1 8.697490e-02 1128.530 0.9197241 740.8260
## 1 1.000000e-01 1128.530 0.9197241 740.8260
## 1 1.149757e-01 1128.530 0.9197241 740.8260
## 1 1.321941e-01 1128.530 0.9197241 740.8260
## 1 1.519911e-01 1128.530 0.9197241 740.8260
## 1 1.747528e-01 1128.530 0.9197241 740.8260
## 1 2.009233e-01 1128.530 0.9197241 740.8260
## 1 2.310130e-01 1128.530 0.9197241 740.8260
## 1 2.656088e-01 1128.530 0.9197241 740.8260
## 1 3.053856e-01 1128.530 0.9197241 740.8260
## 1 3.511192e-01 1128.530 0.9197241 740.8260
## 1 4.037017e-01 1128.530 0.9197241 740.8260
## 1 4.641589e-01 1128.530 0.9197241 740.8260
## 1 5.336699e-01 1128.530 0.9197241 740.8260
## 1 6.135907e-01 1128.530 0.9197241 740.8260
## 1 7.054802e-01 1128.530 0.9197241 740.8260
## 1 8.111308e-01 1128.530 0.9197241 740.8260
## 1 9.326033e-01 1128.530 0.9197241 740.8260
## 1 1.072267e+00 1128.530 0.9197241 740.8260
## 1 1.232847e+00 1128.530 0.9197241 740.8260
## 1 1.417474e+00 1128.530 0.9197241 740.8260
## 1 1.629751e+00 1128.530 0.9197241 740.8260
## 1 1.873817e+00 1128.554 0.9197212 740.8719
## 1 2.154435e+00 1128.651 0.9197077 741.1256
## 1 2.477076e+00 1128.754 0.9196929 741.4981
## 1 2.848036e+00 1128.893 0.9196732 741.9429
## 1 3.274549e+00 1129.064 0.9196490 742.4901
## 1 3.764936e+00 1129.278 0.9196187 743.1335
## 1 4.328761e+00 1129.594 0.9195747 743.8847
## 1 4.977024e+00 1129.986 0.9195199 744.7839
## 1 5.722368e+00 1130.502 0.9194480 745.8555
## 1 6.579332e+00 1131.182 0.9193532 747.1605
## 1 7.564633e+00 1132.061 0.9192305 748.7196
## 1 8.697490e+00 1133.229 0.9190674 750.6287
## 1 1.000000e+01 1134.764 0.9188529 752.9746
## 1 1.149757e+01 1136.798 0.9185679 755.9350
## 1 1.321941e+01 1139.499 0.9181888 759.6902
## 1 1.519911e+01 1143.124 0.9176791 764.5271
## 1 1.747528e+01 1147.632 0.9170412 770.3485
## 1 2.009233e+01 1153.172 0.9162504 777.0174
## 1 2.310130e+01 1159.992 0.9152696 784.4810
## 1 2.656088e+01 1166.065 0.9143955 790.1091
## 1 3.053856e+01 1169.270 0.9139619 789.8811
## 1 3.511192e+01 1173.084 0.9134470 789.6917
## 1 4.037017e+01 1176.863 0.9129486 789.4904
## 1 4.641589e+01 1180.521 0.9124875 788.6588
## 1 5.336699e+01 1184.966 0.9119282 787.8838
## 1 6.135907e+01 1190.816 0.9111811 787.6081
## 1 7.054802e+01 1198.061 0.9102534 787.8937
## 1 8.111308e+01 1205.069 0.9094053 788.0329
## 1 9.326033e+01 1213.517 0.9083886 788.7301
## 1 1.072267e+02 1224.594 0.9070231 790.5540
## 1 1.232847e+02 1239.086 0.9051857 794.0593
## 1 1.417474e+02 1257.024 0.9028673 799.1376
## 1 1.629751e+02 1276.842 0.9003378 804.5367
## 1 1.873817e+02 1301.890 0.8970313 812.8215
## 1 2.154435e+02 1331.716 0.8930051 823.6523
## 1 2.477076e+02 1370.108 0.8875213 839.9880
## 1 2.848036e+02 1415.244 0.8808129 859.6805
## 1 3.274549e+02 1465.793 0.8730583 882.9116
## 1 3.764936e+02 1503.311 0.8681694 901.0994
## 1 4.328761e+02 1549.897 0.8617489 929.5854
## 1 4.977024e+02 1608.610 0.8529723 976.3468
## 1 5.722368e+02 1647.380 0.8495885 1008.8091
## 1 6.579332e+02 1679.058 0.8495885 1036.6043
## 1 7.564633e+02 1720.039 0.8495885 1076.7016
## 1 8.697490e+02 1772.761 0.8495885 1131.7484
## 1 1.000000e+03 1840.141 0.8495885 1200.8500
##
## RMSE was used to select the optimal model using the smallest value.
## The final values used for the model were alpha = 1 and lambda = 1.629751.
# Plot the resampling results of the combined ridge/lasso search
plot(lasso_v_ridge)
\(~\)
\(~\)
\(~\)
\(~\)
# Install any of the requested packages that are not available yet.
#
# list.of.packages: character vector of package names.
#
# Called for its side effects. If at least one package is missing, the missing
# ones are passed to install.packages(). Otherwise one "already installed"
# message per requested package is printed (and returned by print()).
install_if_not <- function( list.of.packages ) {
  # find.package(quiet = TRUE) yields character(0) for a package that cannot be
  # located, so each name is checked directly instead of building the full
  # installed.packages() table just to look up a few names.
  is.available <- vapply(
    list.of.packages,
    function(pkg) length(find.package(pkg, quiet = TRUE)) > 0,
    logical(1)
  )
  new.packages <- list.of.packages[!is.available]
  if (length(new.packages) > 0) {
    install.packages(new.packages)
  } else {
    print(paste0("the package '", list.of.packages, "' is already installed"))
  }
}
library(tidyverse)
library(caret)
# Make sure the glmnet package is installed, then load it
install_if_not('glmnet')
library(glmnet)
# Take a look at the data: column names, types and the first few values
glimpse(diamonds)
# Split the data into training and test set
set.seed(8675309)
# With list = FALSE, createDataPartition() returns the selected row numbers as
# a one-column matrix. Keep only that column so the tibble is indexed with a
# plain integer vector; tibble >= 3.0.0 warns when the row index is a matrix.
training.samples <- createDataPartition(diamonds$price, p = 0.7, list = FALSE)[, 1]
train.data <- diamonds[training.samples, ]
# Negative indexing keeps every row that was not drawn for training
test.data <- diamonds[-training.samples, ]
# Candidate penalty strengths: 100 values evenly spaced on a log10 scale
# from 10^-3 to 10^3. The argument is spelled out as length.out rather than
# relying on partial matching of `length`.
lambda <- 10^seq(-3, 3, length.out = 100)
# Build the ridge model using 10-fold cross-validation
ridge <- train(
  price ~ .,
  data = train.data,
  method = "glmnet",
  trControl = trainControl("cv", number = 10),
  tuneGrid = expand.grid(alpha = 0, lambda = lambda), # alpha = 0 means ridge
  preProcess = c("center", "scale")
)
# Resampled performance across the lambda grid
plot(ridge)
# Model coefficients at the lambda selected by cross-validation
coef_ridge <- coef(ridge$finalModel, ridge$bestTune$lambda)
# After as_tibble(rownames = ) the coefficient values sit in column 2. Rename
# by position: the column label is "1" with the glmnet version used for this
# run, but newer glmnet releases appear to label it "s1", so renaming by name
# is fragile.
coef_ridge <- as_tibble(as.matrix(coef_ridge), rownames = "feature") %>%
  rename(estimate = 2)
coef_ridge
# Score the held-out test data with the fitted ridge model
predictions <- predict(ridge, test.data)
# cbind() labels the appended column after the variable name, "predictions",
# which is the column the RMSE call below refers to
test.scored.ridge <- cbind(test.data, predictions)
library("yardstick")
# Test-set RMSE of the ridge model (observed price vs. predicted price)
test.scored.ridge %>%
  rmse(truth = price, estimate = predictions)
# Build the lasso model using 10-fold cross-validation
lasso <- train(
  price ~ .,
  data = train.data,
  method = "glmnet",
  trControl = trainControl("cv", number = 10),
  tuneGrid = expand.grid(alpha = 1, lambda = lambda), # alpha = 1 means LASSO
  preProcess = c("center", "scale")
)
# Investigate model coefficients at the lambda selected by cross-validation
# You will see that the variable "y" was zeroed out of the model
coef_lasso <- coef(lasso$finalModel, lasso$bestTune$lambda)
# After as_tibble(rownames = ) the coefficient values sit in column 2. Rename
# by position: the column label is "1" with the glmnet version used for this
# run, but newer glmnet releases appear to label it "s1", so renaming by name
# is fragile.
coef_lasso <- as_tibble(as.matrix(coef_lasso), rownames = "feature") %>%
  rename(estimate = 2)
# Score the held-out test data and report the test-set RMSE
predictions <- lasso %>% predict(test.data)
test.scored.lasso <- cbind(test.data, predictions)
rmse(test.scored.lasso, price, predictions)
# Unpenalised baseline fitted through caret's "glm" method, using the same
# 10-fold cross-validation and centering/scaling as the glmnet models
glm_model <- train(
  price ~ .,
  data = train.data,
  method = "glm",
  trControl = trainControl("cv", number = 10),
  preProcess = c("center", "scale")
)
# Score the held-out test data with the baseline model
predictions <- predict(glm_model, test.data)
test.scored.glm <- cbind(test.data, predictions)
# Stack the three scored test sets, tagging each row with the model it came from
test.scored_Stacked <- bind_rows(
  mutate(test.scored.ridge, model = "ridge"),
  mutate(test.scored.lasso, model = "lasso"),
  mutate(test.scored.glm, model = "glm")
)
# Test-set RMSE per model, best (smallest) first
test.scored_Stacked %>%
  group_by(model) %>%
  rmse(truth = price, estimate = predictions) %>%
  arrange(.estimate)
# Coefficients of the unpenalised fit as a feature/estimate tibble.
# enframe() converts the named coefficient vector directly; passing an atomic
# vector to as_tibble() is superseded.
# The glm fit appears to report non-syntactic term names wrapped in backticks
# (e.g. `cut^4`) while glmnet reports them bare (cut^4). Left as is, the same
# term ends up on two separate rows of the comparison table, so the backticks
# are stripped to make the names line up across models.
coef_glm <- enframe(
  glm_model$finalModel$coefficients,
  name = "feature",
  value = "estimate"
) %>%
  mutate(feature = str_remove_all(feature, fixed("`")))
# Long table: one row per (feature, model) pair
coef_compare <- bind_rows(
  coef_glm %>% mutate(model = "glm"),
  coef_lasso %>% mutate(model = "lasso"),
  coef_ridge %>% mutate(model = "ridge")
)
# Predicate: TRUE when `x` contains at least one missing value, FALSE otherwise
# (including for an empty vector).
any_column_NA <- function(x){
  anyNA(x)
}
# Return `x` with every missing entry replaced by 0; other entries are kept.
replace_NA_0 <- function(x){
  is_missing <- is.na(x)
  if_else(is_missing, 0, x)
}
# Side-by-side coefficient table: one row per feature, one column per model
coef_compare %>%
  pivot_wider(names_from = model, values_from = estimate) %>%
  # A feature that is absent from a model is NA after pivoting; count it as 0.
  # across(where()) is the current replacement for the superseded mutate_if().
  mutate(across(where(any_column_NA), replace_NA_0)) %>%
  # Difference relative to the glm estimate, expressed as a fraction of it
  mutate(
    lasso_pct_diff_glm = (glm - lasso) / glm,
    ridge_pct_diff_glm = (glm - ridge) / glm
  ) %>%
  # Largest relative shrinkage by lasso first
  arrange(desc(lasso_pct_diff_glm)) %>%
  knitr::kable()
# Absolute coefficient size per feature, with one bar per model side by side.
# geom_col() is the direct form of geom_bar(stat = "identity"). No group_by()
# is needed because ggplot() does not use dplyr grouping; `fill = model`
# separates the bars.
coef_compare %>%
  ggplot(aes(x = feature, y = abs(estimate), fill = model)) +
  geom_col(position = "dodge") +
  coord_flip()
# Candidate mixing parameters: 0 is ridge, 1 is lasso
alpha <- c(0, 1)
# One cross-validated search over every alpha/lambda combination, so ridge and
# lasso are compared within the same train() call
lasso_v_ridge <- train(
  price ~ .,
  data = train.data,
  method = "glmnet",
  trControl = trainControl(method = "cv", number = 10),
  tuneGrid = expand.grid(alpha = alpha, lambda = lambda),
  preProcess = c("center", "scale")
)
# Print the resampling results and the selected alpha/lambda
lasso_v_ridge
# Plot the resampling results of the combined search
plot(lasso_v_ridge)