Elastic Net regression is a bit of a badass, largely because it blends two regularisation techniques into one. Often in a regression setting, LASSO (the L1 penalty) or ridge (the L2 penalty) alone does not cut it, especially when the number of features p in the data set far exceeds the number of observations n, or when the variables are in the bad company of multicollinearity. In such distress, Elastic Net can redeem you. Its objective function is
$ \hat{\beta} = \underset{\beta}{\operatorname{argmin}} \left[ \frac{1}{2N}\sum_{i=1}^N (y_i - x_i^\top \beta )^2 + \lambda P_\alpha(\beta) \right] $
$P_\alpha(\beta) = (1 - \alpha)\, \frac{1}{2} \| \Gamma \beta\|^2_{l_2} + \alpha \| \Gamma \beta\|_{l_1}$
where $\lambda$ is the complexity parameter controlling the overall amount of shrinkage and $\alpha$ is the mixing hyperparameter: $\alpha$ = 0 gives pure ridge and $\alpha$ = 1 gives pure LASSO. $\Gamma$ is the vector of penalty factors, one per coefficient.
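To make the mixing concrete, here is a tiny sketch (not part of the original analysis) of the penalty above in plain R, assuming unit penalty factors ($\Gamma$ = 1); the function and vector names are just illustrative.
# Minimal sketch of P_alpha(beta) with unit penalty factors (Gamma = 1);
# illustrative only, not how glmnet computes it internally.
elastic_net_penalty <- function(beta, alpha) {
  (1 - alpha) * 0.5 * sum(beta^2) + alpha * sum(abs(beta))
}
beta_demo <- c(1.5, -0.5, 0)
elastic_net_penalty(beta_demo, alpha = 0)  # pure ridge: 0.5 * sum(beta^2) = 1.25
elastic_net_penalty(beta_demo, alpha = 1)  # pure LASSO: sum(abs(beta)) = 2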
To demonstrate an implementation of Elastic Net, we choose a classification problem which, although not too high dimensional, is good enough for our purpose.
We shall be using the glmnet package, with help from the parallel, doParallel and foreach packages, in the R programming language and environment.
suppressPackageStartupMessages({
library(useful)
library(glmnet)
library(parallel)
library(doParallel)
library(foreach)
library(reshape2)
library(stringr)
library(ggplot2)  # needed for the ggplot calls later in the post
})
This data set is downloaded from the UCI Machine Learning Repository. The prediction task is to determine whether a person makes over 50K a year.
adult <- read.table("http://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data", sep = ",", header = FALSE, stringsAsFactors = FALSE)
First, let's rename the columns for convenience; the original column names are documented on the UCI website.
names(adult) <- c('age','workclass','fnlwgt','education','education_num','marital_status','occupation','relationship','race','sex','capital_gain','capital_loss','hours_per_week','native_country','income')
Next, let's get a first impression of the data, as we almost always do.
head(adult,10)
| age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 39 | State-gov | 77516 | Bachelors | 13 | Never-married | Adm-clerical | Not-in-family | White | Male | 2174 | 0 | 40 | United-States | <=50K |
| 50 | Self-emp-not-inc | 83311 | Bachelors | 13 | Married-civ-spouse | Exec-managerial | Husband | White | Male | 0 | 0 | 13 | United-States | <=50K |
| 38 | Private | 215646 | HS-grad | 9 | Divorced | Handlers-cleaners | Not-in-family | White | Male | 0 | 0 | 40 | United-States | <=50K |
| 53 | Private | 234721 | 11th | 7 | Married-civ-spouse | Handlers-cleaners | Husband | Black | Male | 0 | 0 | 40 | United-States | <=50K |
| 28 | Private | 338409 | Bachelors | 13 | Married-civ-spouse | Prof-specialty | Wife | Black | Female | 0 | 0 | 40 | Cuba | <=50K |
| 37 | Private | 284582 | Masters | 14 | Married-civ-spouse | Exec-managerial | Wife | White | Female | 0 | 0 | 40 | United-States | <=50K |
| 49 | Private | 160187 | 9th | 5 | Married-spouse-absent | Other-service | Not-in-family | Black | Female | 0 | 0 | 16 | Jamaica | <=50K |
| 52 | Self-emp-not-inc | 209642 | HS-grad | 9 | Married-civ-spouse | Exec-managerial | Husband | White | Male | 0 | 0 | 45 | United-States | >50K |
| 31 | Private | 45781 | Masters | 14 | Never-married | Prof-specialty | Not-in-family | White | Female | 14084 | 0 | 50 | United-States | >50K |
| 42 | Private | 159449 | Bachelors | 13 | Married-civ-spouse | Exec-managerial | Husband | White | Male | 5178 | 0 | 40 | United-States | >50K |
str(adult)
'data.frame': 32561 obs. of 15 variables:
 $ age           : int  39 50 38 53 28 37 49 52 31 42 ...
 $ workclass     : chr  " State-gov" " Self-emp-not-inc" " Private" " Private" ...
 $ fnlwgt        : int  77516 83311 215646 234721 338409 284582 160187 209642 45781 159449 ...
 $ education     : chr  " Bachelors" " Bachelors" " HS-grad" " 11th" ...
 $ education_num : int  13 13 9 7 13 14 5 9 14 13 ...
 $ marital_status: chr  " Never-married" " Married-civ-spouse" " Divorced" " Married-civ-spouse" ...
 $ occupation    : chr  " Adm-clerical" " Exec-managerial" " Handlers-cleaners" " Handlers-cleaners" ...
 $ relationship  : chr  " Not-in-family" " Husband" " Not-in-family" " Husband" ...
 $ race          : chr  " White" " White" " White" " Black" ...
 $ sex           : chr  " Male" " Male" " Male" " Male" ...
 $ capital_gain  : int  2174 0 0 0 0 0 0 0 14084 5178 ...
 $ capital_loss  : int  0 0 0 0 0 0 0 0 0 0 ...
 $ hours_per_week: int  40 13 40 40 40 40 16 45 50 40 ...
 $ native_country: chr  " United-States" " United-States" " United-States" " United-States" ...
 $ income        : chr  " <=50K" " <=50K" " <=50K" " <=50K" ...
summary(adult)
age workclass fnlwgt education
Min. :17.00 Length:32561 Min. : 12285 Length:32561
1st Qu.:28.00 Class :character 1st Qu.: 117827 Class :character
Median :37.00 Mode :character Median : 178356 Mode :character
Mean :38.58 Mean : 189778
3rd Qu.:48.00 3rd Qu.: 237051
Max. :90.00 Max. :1484705
education_num marital_status occupation relationship
Min. : 1.00 Length:32561 Length:32561 Length:32561
1st Qu.: 9.00 Class :character Class :character Class :character
Median :10.00 Mode :character Mode :character Mode :character
Mean :10.08
3rd Qu.:12.00
Max. :16.00
race sex capital_gain capital_loss
Length:32561 Length:32561 Min. : 0 Min. : 0.0
Class :character Class :character 1st Qu.: 0 1st Qu.: 0.0
Mode :character Mode :character Median : 0 Median : 0.0
Mean : 1078 Mean : 87.3
3rd Qu.: 0 3rd Qu.: 0.0
Max. :99999 Max. :4356.0
hours_per_week native_country income
Min. : 1.00 Length:32561 Length:32561
1st Qu.:40.00 Class :character Class :character
Median :40.00 Mode :character Mode :character
Mean :40.44
3rd Qu.:45.00
Max. :99.00
dim(adult)
counter_na <- colSums(is.na(adult))
counter_na
Bravo! No missing values. But careful, we need one more layer of screening.
apply(adult, 2, unique)
Some of the features contain missing observations in the form of ' ?'. In the present context we assume the ' ?' entries are MCAR (missing completely at random) and exclude them from the analysis. So we need to subset to observations with complete cases.
adult <- adult[adult$workclass != " ?" & adult$occupation != " ?" & adult$native_country != " ?",]
dim(adult)
We could have spent much more time on feature engineering and data manipulation (gosh, some variables do look thirsty for such treatment!), but let's save that for another day. Next, we turn 'income' into a binary variable for building a logistic regression.
adult$income <- adult$income == " >50K"
glmnet requires a predictor matrix rather than a formula and a data.frame. The build.x function from the useful package creates such matrices conveniently.
adultX <- build.x(income ~ age + workclass + fnlwgt + education + education_num + marital_status + occupation +
relationship + race + sex + capital_gain + capital_loss + hours_per_week + native_country - 1, data=adult, contrasts = FALSE)
class(adultX);dim(adultX)
adultY <- build.y(income ~ age + workclass + fnlwgt + education + education_num + marital_status + occupation +
relationship + race + sex + capital_gain + capital_loss + hours_per_week + native_country - 1, data=adult)
set.seed(21364597)
cv.glmnet tells us which value of λ minimizes the cross-validation error. It also returns the largest value of λ whose cross-validation error is within one standard error of that minimum. Occam's razor (the principle of parsimony) suggests that the simpler model, even though it is slightly less accurate, should be preferred.
adultCV1 <- cv.glmnet(x=adultX, y=adultY, family = "binomial", nfold=5)
adultCV1$lambda.min
adultCV1$lambda.1se
The cross-validation errors for different values of λ are shown in the plot below. The top row of numbers indicates how many variables (factor levels counted as individual variables) are in the model at a given value of log(λ). The dots show the cross-validation error at each point and the vertical bars are confidence intervals for the error. The leftmost vertical dashed line marks the value of λ where the error is minimized, and the rightmost dashed line marks the largest value of λ whose error is within one standard error of that minimum.
plot(adultCV1)
coef(adultCV1, s="lambda.1se")
105 x 1 sparse Matrix of class "dgCMatrix"
1
(Intercept) -7.452203e+00
age 2.284516e-02
workclass Federal-gov 3.725640e-01
workclass Local-gov .
workclass Private .
workclass Self-emp-inc 1.568324e-01
workclass Self-emp-not-inc -2.937237e-01
workclass State-gov -3.972143e-02
workclass Without-pay .
fnlwgt 2.670333e-07
education 10th .
education 11th .
education 12th .
education 1st-4th .
education 5th-6th .
education 7th-8th .
education 9th .
education Assoc-acdm -1.319282e-01
education Assoc-voc .
education Bachelors .
education Doctorate .
education HS-grad .
education Masters .
education Preschool .
education Prof-school 1.331023e-01
education Some-college .
education_num 2.742087e-01
marital_status Divorced .
marital_status Married-AF-spouse 1.587036e+00
marital_status Married-civ-spouse 1.756452e+00
marital_status Married-spouse-absent .
marital_status Never-married -3.205346e-01
marital_status Separated .
marital_status Widowed .
occupation Adm-clerical .
occupation Armed-Forces .
occupation Craft-repair .
occupation Exec-managerial 6.677254e-01
occupation Farming-fishing -7.778358e-01
occupation Handlers-cleaners -3.915025e-01
occupation Machine-op-inspct -1.398751e-01
occupation Other-service -5.998474e-01
occupation Priv-house-serv .
occupation Prof-specialty 3.628801e-01
occupation Protective-serv 2.307353e-01
occupation Sales 1.482594e-01
occupation Tech-support 3.909261e-01
occupation Transport-moving .
relationship Husband .
relationship Not-in-family .
relationship Other-relative -1.398014e-01
relationship Own-child -5.743604e-01
relationship Unmarried .
relationship Wife 8.543963e-01
race Amer-Indian-Eskimo .
race Asian-Pac-Islander .
race Black .
race Other .
race White 7.449669e-02
sex Female -5.357086e-01
sex Male .
capital_gain 2.255052e-04
capital_loss 5.458884e-04
hours_per_week 2.561027e-02
native_country Cambodia 7.888523e-02
native_country Canada .
native_country China .
native_country Columbia -2.528716e-01
native_country Cuba .
native_country Dominican-Republic .
native_country Ecuador .
native_country El-Salvador .
native_country England .
native_country France .
native_country Germany .
native_country Greece .
native_country Guatemala .
native_country Haiti .
native_country Holand-Netherlands .
native_country Honduras .
native_country Hong .
native_country Hungary .
native_country India .
native_country Iran .
native_country Ireland .
native_country Italy 4.908942e-02
native_country Jamaica .
native_country Japan .
native_country Laos .
native_country Mexico -2.194433e-02
native_country Nicaragua .
native_country Outlying-US(Guam-USVI-etc) .
native_country Peru .
native_country Philippines .
native_country Poland .
native_country Portugal .
native_country Puerto-Rico .
native_country Scotland .
native_country South -1.592238e-01
native_country Taiwan .
native_country Thailand .
native_country Trinadad&Tobago .
native_country United-States 4.851934e-02
native_country Vietnam .
native_country Yugoslavia .
Some factor levels are selected and others are not because the LASSO penalty shrinks to exactly zero the coefficients that add little to the model; when variables are highly correlated with each other, it tends to keep one of them and drop the rest.
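To see this sparsity in numbers, here is a quick sketch (not part of the original post) counting how many coefficients survive at the two values of λ from the fit above.
# Count non-zero coefficients (intercept included) at lambda.1se and lambda.min.
sum(as.matrix(coef(adultCV1, s = "lambda.1se")) != 0)
sum(as.matrix(coef(adultCV1, s = "lambda.min")) != 0)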
In the following plot, each line traces a coefficient's value across values of λ. The leftmost vertical dashed line marks the value of λ where the cross-validation error is minimized, and the rightmost dashed line marks the largest value of λ whose error is within one standard error of the minimum.
plot(adultCV1$glmnet.fit, xvar = "lambda")
abline(v=log(c(adultCV1$lambda.min, adultCV1$lambda.1se)), lty=2)
Setting α to 0 makes the regularisation purely ridge. In this case every variable is kept in the model, just shrunk closer to 0 (a quick check after the plots below confirms this).
set.seed(84756)
adultCV2 <- cv.glmnet(x=adultX, y=adultY, family="binomial", nfold=5, alpha=0)
adultCV2$lambda.min
adultCV2$lambda.1se
coef(adultCV2, s="lambda.1se")
105 x 1 sparse Matrix of class "dgCMatrix"
1
(Intercept) -5.345928e+00
age 2.061885e-02
workclass Federal-gov 4.865905e-01
workclass Local-gov -6.749782e-02
workclass Private 5.027305e-02
workclass Self-emp-inc 3.128812e-01
workclass Self-emp-not-inc -2.972907e-01
workclass State-gov -1.815915e-01
workclass Without-pay -1.645734e+00
fnlwgt 5.461147e-07
education 10th -4.892027e-01
education 11th -4.652898e-01
education 12th -3.171417e-01
education 1st-4th -4.088599e-01
education 5th-6th -4.549892e-01
education 7th-8th -6.641192e-01
education 9th -5.815318e-01
education Assoc-acdm -6.006992e-02
education Assoc-voc 4.105526e-02
education Bachelors 3.674892e-01
education Doctorate 8.969487e-01
education HS-grad -2.034310e-01
education Masters 5.948533e-01
education Preschool -9.038387e-01
education Prof-school 9.204415e-01
education Some-college -1.772004e-02
education_num 1.144983e-01
marital_status Divorced -2.358436e-01
marital_status Married-AF-spouse 1.186445e+00
marital_status Married-civ-spouse 6.996802e-01
marital_status Married-spouse-absent -2.161284e-01
marital_status Never-married -6.015863e-01
marital_status Separated -3.017178e-01
marital_status Widowed -1.257116e-01
occupation Adm-clerical -9.214468e-02
occupation Armed-Forces -8.493606e-01
occupation Craft-repair -3.798050e-02
occupation Exec-managerial 6.268609e-01
occupation Farming-fishing -7.758269e-01
occupation Handlers-cleaners -5.893824e-01
occupation Machine-op-inspct -3.119999e-01
occupation Other-service -6.318488e-01
occupation Priv-house-serv -8.053591e-01
occupation Prof-specialty 3.844989e-01
occupation Protective-serv 3.756021e-01
occupation Sales 1.712084e-01
occupation Tech-support 4.302558e-01
occupation Transport-moving -1.588432e-01
relationship Husband 4.339236e-01
relationship Not-in-family -1.875291e-01
relationship Other-relative -4.439067e-01
relationship Own-child -6.754041e-01
relationship Unmarried -3.157881e-01
relationship Wife 1.309420e+00
race Amer-Indian-Eskimo -3.261232e-01
race Asian-Pac-Islander 1.224241e-01
race Black -8.432171e-02
race Other -3.454349e-01
race White 7.736776e-02
sex Female -2.996986e-01
sex Male 2.984909e-01
capital_gain 8.811657e-05
capital_loss 4.791516e-04
hours_per_week 2.256299e-02
native_country Cambodia 1.029611e+00
native_country Canada 2.497022e-01
native_country China -3.950500e-01
native_country Columbia -1.204090e+00
native_country Cuba 2.495072e-01
native_country Dominican-Republic -9.010458e-01
native_country Ecuador -1.898093e-01
native_country El-Salvador -3.117686e-01
native_country England 3.159462e-01
native_country France 5.677196e-01
native_country Germany 3.841827e-01
native_country Greece -5.508070e-01
native_country Guatemala -2.082662e-01
native_country Haiti -1.372037e-01
native_country Holand-Netherlands -7.821264e-01
native_country Honduras -5.105604e-01
native_country Hong -8.549506e-02
native_country Hungary -2.600314e-02
native_country India -2.072899e-01
native_country Iran 5.295037e-02
native_country Ireland 3.814043e-01
native_country Italy 6.240565e-01
native_country Jamaica -5.401974e-02
native_country Japan 2.764393e-01
native_country Laos -4.710587e-01
native_country Mexico -4.256294e-01
native_country Nicaragua -6.103550e-01
native_country Outlying-US(Guam-USVI-etc) -1.370468e+00
native_country Peru -6.441535e-01
native_country Philippines 3.141296e-01
native_country Poland -4.236572e-02
native_country Portugal -8.456079e-02
native_country Puerto-Rico -2.338966e-01
native_country Scotland -2.968606e-01
native_country South -7.715265e-01
native_country Taiwan -8.763373e-02
native_country Thailand -4.577744e-01
native_country Trinadad&Tobago -3.451659e-01
native_country United-States 1.786337e-01
native_country Vietnam -6.881111e-01
native_country Yugoslavia 4.809380e-01
plot(adultCV2)
plot(adultCV2$glmnet.fit, xvar = "lambda")
abline(v=log(c(adultCV2$lambda.min, adultCV2$lambda.1se)), lty=2)
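As promised, a small sketch (not in the original code) contrasting ridge with the LASSO fit: none of the ridge coefficients should be exactly zero at lambda.1se.
# With alpha = 0 (pure ridge) no coefficient is shrunk exactly to zero.
sum(as.matrix(coef(adultCV2, s = "lambda.1se")) == 0)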
Next we find the optimal value of $\alpha$ with the help of the parallel, doParallel and foreach packages.
First a cluster is created and registered. Setting .errorhandling to "remove" means that if an error occurs, that iteration will be skipped. Setting .inorder to FALSE means that the order of combining the results does not matter and they can be combined whenever returned, which yields significant speed improvements. Because we are using the default combination function, list, which takes multiple arguments at once, we can speed up the process by setting .multicombine to TRUE. We specify in .packages that glmnet should be loaded on each of the workers, again leading to performance improvements. The operator %dopar% tells foreach to work in parallel.
set.seed(25364758)
theFolds <- sample(rep(x=1:5, length.out = nrow(adultX)))
alphas <- seq(from=.5, to=1, by=.05)
set.seed(5245175)
cl <- makeCluster(2)
registerDoParallel(cl)
before <- Sys.time()
adultDouble <- foreach(i=1:length(alphas), .errorhandling = "remove", .inorder = FALSE, .multicombine = TRUE,
.export = c("adultX", "adultY", "alphas", "theFolds"), .packages = "glmnet" ) %dopar%
{
print(alphas[i])
cv.glmnet(x=adultX, y=adultY, family="binomial", nfolds=5, foldid=theFolds, alpha=alphas[i])
}
after <- Sys.time()
stopCluster(cl)
Warning message in e$fun(obj, substitute(ex), parent.frame(), e$data): "already exporting variable(s): adultX, adultY, alphas, theFolds"
after - before
Time difference of 4.216807 mins
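For comparison, the same search could be run sequentially by swapping %dopar% for %do%; this is only a sketch and was not timed in the original run.
# Sequential version of the alpha search, for timing comparison only (sketch).
before_seq <- Sys.time()
adultSeq <- foreach(i = 1:length(alphas)) %do% {
    cv.glmnet(x=adultX, y=adultY, family="binomial", nfolds=5, foldid=theFolds, alpha=alphas[i])
}
Sys.time() - before_seq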
Next we use ggplot2 to find the optimal values of $\alpha$ and $\lambda$ by plotting the cross-validation errors.
# Extract lambda.min, lambda.1se and their cross-validation errors from a cv.glmnet object
extractGlmnetInfo <- function(object) {
lambdaMin <- object$lambda.min
lambda1se <- object$lambda.1se
whichMin <- which(object$lambda == lambdaMin)
which1se <- which(object$lambda == lambda1se)
data.frame(lambda.min=lambdaMin, error.min=object$cvm[whichMin],
lambda.1se=lambda1se, error.1se=object$cvm[which1se])
}
alphaInfo <- Reduce(rbind, lapply(adultDouble, extractGlmnetInfo))
alphaInfo$Alpha <- alphas
alphaMelt <- melt(alphaInfo, id.vars="Alpha", value.name="Value", variable.name="Measure")
alphaMelt$Type <- str_extract(string=alphaMelt$Measure, pattern="(min)|(1se)")
alphaMelt$Measure <- str_replace(string=alphaMelt$Measure, pattern = "\\.(min|1se)", replacement = "")
alphaCast <- dcast(alphaMelt, Alpha + Type ~ Measure, value.var = "Value")
ggplot(alphaCast, aes(x=Alpha, y=error))+
geom_line(aes(group=Type))+
facet_wrap(~Type, scales="free_y", ncol=1)+
geom_point(aes(size=lambda))
From the plot, 0.9 is the optimal value of $\alpha$: it minimizes the one-standard-error measure (top pane), whereas the bottom pane, based on the minimum cross-validation error, would point to $\alpha$ = 1.0. Consistent with the parsimony argument above, we go with the one-standard-error criterion.
Now we refit the model with the optimal $\alpha$.
set.seed(5222841)
adultCV3 <- cv.glmnet(x=adultX, y=adultY, family="binomial", nfold=5, alpha=alphaInfo$Alpha[which.min(alphaInfo$error.1se)])
plot(adultCV3)
plot(adultCV3$glmnet.fit, xvar="lambda")
abline(v=log(c(adultCV3$lambda.min, adultCV3$lambda.1se)), lty=2)
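Before looking at the individual coefficients, a quick hedged sketch (not in the original post) of how the refit model predicts in-sample; predProb is just an illustrative name.
# In-sample predicted probabilities from the refit model and a crude accuracy figure.
# This is only a rough sanity check; a proper assessment would use held-out data.
predProb <- predict(adultCV3, newx=adultX, s="lambda.1se", type="response")
mean((predProb > 0.5) == adultY)  # proportion classified correctly in-sample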
theCoef <- as.matrix(coef(adultCV3, s="lambda.1se"))
coefDF <- data.frame(Value=theCoef, Coefficient=rownames(theCoef))
coefDF <- coefDF[nonzeroCoef(coef(adultCV3, s="lambda.1se")), ]
ggplot(coefDF, aes(x=X1, y=reorder(Coefficient, X1))) +
geom_vline(xintercept = 0, color="grey", linetype=2)+
geom_point(color="blue")+
labs(x="Value", y="Coefficient", title="Coefficient Plot")
Above, in the plot, we find that... wow! Being married to a civilian spouse (marital_status Married-civ-spouse) is the strongest indicator of income over 50K, and so on (drum roll in the background). Thanks for checking this out. I will be very happy to receive expert comments, suggestions, questions and almost anything except gunshots. Cheers!