Case study X1: Elastic-net regularisation models

Foreword: About the Machine Learning in Medicine (MLM) project

The MLM project was launched in 2016 and aims to:

1. Encourage the use of Machine Learning techniques in medical research in Vietnam.

2. Promote the use of the R statistical programming language, an open-source and leading tool for practicing data science.

Introduction

Most traditional model selection methods, such as forward/backward stepwise selection or best-subsets filtering, become ineffective on very large datasets with many independent variables (sometimes even more variables than observations). Advanced regression methods such as the LASSO and the Elastic net allow us to overcome this problem. The glmnet package, developed by Jerome Friedman, Trevor Hastie and Rob Tibshirani, provides a powerful and flexible solution for fitting many types of generalized linear models with either LASSO, Ridge or Elastic net regularisation. The package also supports a very fast K-fold cross-validation procedure.
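
Formally, for a generalized linear model with log-likelihood ℓ, glmnet solves the penalised problem below, where the mixing parameter α interpolates between the Ridge penalty (α = 0) and the LASSO penalty (α = 1):

$$
\min_{\beta_0,\,\beta}\; -\frac{1}{n}\,\ell(\beta_0,\beta)\; +\; \lambda\left[\frac{1-\alpha}{2}\,\lVert\beta\rVert_2^2 + \alpha\,\lVert\beta\rVert_1\right]
$$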

This case study demonstrates the use of the glmnet package in 3 different ways: on its own, and within the mlr and caret frameworks.

Material and method

In this study, we use the “Heart disease” dataset, which includes clinical data of more than 700 patients from 4 Cardiology centers (Cleveland, Budapest, Long Beach and Zurich). The goal is to diagnose heart disease based on 13 explanatory variables (features), including age, sex, chest pain type, serum cholesterol, fasting blood sugar test, ST depression induced by exercise relative to rest (Oldpeak) and maximum heart rate achieved during CPET.

library(tidyverse)
library(magrittr)  # needed for the %<>% compound-assignment pipe used below

# Download the four centre files from the UCI repository
va = read.table("https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/processed.va.data", sep = ",", na.strings = "?", strip.white = TRUE, fill = TRUE) %>% as_tibble()
hu = read.table("https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/processed.hungarian.data", sep = ",", na.strings = "?", strip.white = TRUE, fill = TRUE) %>% as_tibble()
sw = read.table("https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/processed.switzerland.data", sep = ",", na.strings = "?", strip.white = TRUE, fill = TRUE) %>% as_tibble()
cl = read.table("https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/processed.cleveland.data", sep = ",", na.strings = "?", strip.white = TRUE, fill = TRUE) %>% as_tibble()

df=rbind(va,hu,sw,cl)
names(df)=c("Age","Sex","ChestPain","RestBP","Chol","FBS","RestECG","MaxHR","CPETAgina","Oldpeak","Slope","CA","Thal","Class")

# Drop Slope, CA and Thal (columns 11-13, which contain many missing values)
# and rows where Chol = 0 (an implausible value, treated as a missing code)
data = df[, -c(11, 12, 13)] %>% filter(., Chol != 0)

data = na.omit(data)

# Recode the categorical variables as labelled factors
data$Sex %<>% as.factor() %>% recode_factor(., `0` = "Female", `1` = "Male")

data$ChestPain %<>% as.factor() %>% recode_factor(., `1` = "Typical", `2` = "Atypical", `3` = "Non_aginal", `4` = "asymptomatic")

data$FBS %<>% as.factor() %>% recode_factor(., `0` = "No", `1` = "Yes")

data$RestECG %<>% as.factor() %>% recode_factor(., `0` = "Normal", `1` = "Abnormal_ST", `2` = "LVHypertrophy")

data$CPETAgina %<>% as.factor() %>% recode_factor(., `0` = "No", `1` = "Yes")

data$Class %<>% as.factor() %>% recode_factor(., `0` = "Negative", `1` = "Positive", `2` = "Positive", `3` = "Positive", `4` = "Positive")

# Creating a 2nd dataset with dummy variables

library(dummy)

data2 = data %>% .[, c(1:10)] %>% dummy() %>%
  mutate(., Age = data$Age, Chol = data$Chol, MaxHR = data$MaxHR, Oldpeak = data$Oldpeak, Class = data$Class) %>%
  as_tibble()
# NB: RestBP is not carried over here, so the standalone glmnet models below
# omit it, while the mlr and caret models (fitted on `data`) include it

data2 %>% select(., c(1:13)) %>% map(~as.integer(.)) -> data2[, c(1:13)]  # dummy columns to integers

It should be noted that the glmnet package does not accept factor variables in their native form, so we must first expand all factors into dummy variables. Both mlr and caret handle factor variables automatically; hence there are 2 different datasets in our experiment.
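
As an aside, base R's model.matrix() offers another way to build a dummy-coded design matrix; a minimal sketch (note that, unlike the dummy package, it uses treatment contrasts by default, so one reference level per factor is absorbed into the intercept):

# Sketch: dummy-code the predictors with base R instead of the dummy package
x_alt = model.matrix(Class ~ ., data = data)[, -1]  # drop the intercept column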

After removing the missing values, both datasets (n = 661) will be randomly split into training (n = 561) and testing (n = 100) subsets.

library(caret)
set.seed(123)
idx <- createDataPartition(data2$Class, p = 99/661, list = FALSE)  # hold out ~100 cases for testing

trainorg<- data[-idx,]
testorg<- data[idx,]

On the training subset, 5 different models will be fitted as follows:

  1. A LASSO-regularised model, by setting alpha = 1 in the glmnet function
  2. A Ridge-regularised model, by setting alpha = 0 in the glmnet function
  3. An Elastic net model with 10-fold cross-validation, using the cv.glmnet function
  4. A manually grid-tuned Elastic net model, using the mlr framework
  5. An automatically tuned Elastic net model, using the caret framework

The performance of these 5 models will be evaluated on the same testing subset.

Results

1) LASSO model

trainset <- data2[-idx,]
testset <- data2[idx,]

# glmnet needs numeric matrices: columns 1-17 are the predictors
trainx = as.matrix(trainset[, c(1:17)])
testx = as.matrix(testset[, c(1:17)])

library(glmnet)

# alpha = 1 selects the pure LASSO penalty
lasso = glmnet(x = trainx, y = trainset$Class, family = "binomial", alpha = 1)

plot(lasso, xvar = "lambda")

coef(lasso, s = min(lasso$lambda))  # coefficients at the smallest (least penalised) lambda on the path
## 18 x 1 sparse Matrix of class "dgCMatrix"
##                                    1
## (Intercept)             1.740551e+00
## Sex_Female             -1.439619e+00
## Sex_Male                1.238079e-13
## ChestPain_Typical      -2.261458e-01
## ChestPain_Atypical     -3.492781e-01
## ChestPain_Non_aginal    .           
## ChestPain_asymptomatic  1.278536e+00
## FBS_No                 -5.850382e-01
## FBS_Yes                 9.668426e-03
## RestECG_Normal         -2.183646e-01
## RestECG_Abnormal_ST     9.609337e-02
## RestECG_LVHypertrophy   .           
## CPETAgina_No           -1.158147e+00
## CPETAgina_Yes           3.864020e-13
## Age                     2.655659e-02
## Chol                    4.261702e-03
## MaxHR                  -8.308493e-03
## Oldpeak                 7.268944e-01
predlasso=predict(lasso, newx = testx, s = min(lasso$lambda), type = "class")

cm1=confusionMatrix(predlasso,testset$Class,positive="Positive")

cm1
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Negative Positive
##   Negative       42        6
##   Positive       10       42
##                                           
##                Accuracy : 0.84            
##                  95% CI : (0.7532, 0.9057)
##     No Information Rate : 0.52            
##     P-Value [Acc > NIR] : 1.863e-11       
##                                           
##                   Kappa : 0.6805          
##  Mcnemar's Test P-Value : 0.4533          
##                                           
##             Sensitivity : 0.8750          
##             Specificity : 0.8077          
##          Pos Pred Value : 0.8077          
##          Neg Pred Value : 0.8750          
##              Prevalence : 0.4800          
##          Detection Rate : 0.4200          
##    Detection Prevalence : 0.5200          
##       Balanced Accuracy : 0.8413          
##                                           
##        'Positive' Class : Positive        
## 
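
As a quick check, the predictors that the LASSO retains at a given lambda can be listed straight from the sparse coefficient matrix; a minimal sketch:

# Sketch: list the nonzero LASSO coefficients at the smallest lambda on the path
b = coef(lasso, s = min(lasso$lambda))
rownames(b)[as.vector(b) != 0]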

2) Ridge model

# alpha = 0 selects the pure Ridge penalty
ridge = glmnet(x = trainx, y = trainset$Class, family = "binomial", alpha = 0)

plot(ridge,xvar="lambda")

coef(ridge,s=min(ridge$lambda))
## 18 x 1 sparse Matrix of class "dgCMatrix"
##                                   1
## (Intercept)            -1.032373920
## Sex_Female             -0.635334238
## Sex_Male                0.632930848
## ChestPain_Typical      -0.470871993
## ChestPain_Atypical     -0.684132823
## ChestPain_Non_aginal   -0.345786305
## ChestPain_asymptomatic  0.815172235
## FBS_No                 -0.279138466
## FBS_Yes                 0.278334874
## RestECG_Normal         -0.145189119
## RestECG_Abnormal_ST     0.153032209
## RestECG_LVHypertrophy   0.086293399
## CPETAgina_No           -0.576595767
## CPETAgina_Yes           0.576579724
## Age                     0.023530949
## Chol                    0.003467407
## MaxHR                  -0.008191013
## Oldpeak                 0.580417760
predridge=predict(ridge, newx = testx, s =min(ridge$lambda), type = "class")

cm2=confusionMatrix(predridge,testset$Class,positive="Positive")

cm2
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Negative Positive
##   Negative       42        8
##   Positive       10       40
##                                           
##                Accuracy : 0.82            
##                  95% CI : (0.7305, 0.8897)
##     No Information Rate : 0.52            
##     P-Value [Acc > NIR] : 3.758e-10       
##                                           
##                   Kappa : 0.64            
##  Mcnemar's Test P-Value : 0.8137          
##                                           
##             Sensitivity : 0.8333          
##             Specificity : 0.8077          
##          Pos Pred Value : 0.8000          
##          Neg Pred Value : 0.8400          
##              Prevalence : 0.4800          
##          Detection Rate : 0.4000          
##    Detection Prevalence : 0.5000          
##       Balanced Accuracy : 0.8205          
##                                           
##        'Positive' Class : Positive        
## 
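
The contrast between the two penalties is easiest to see when their coefficient paths are drawn side by side: the LASSO sets coefficients exactly to zero as lambda grows, while Ridge only shrinks them towards zero. A minimal plotting sketch:

# Sketch: coefficient paths of both penalties on one device
par(mfrow = c(1, 2))
plot(lasso, xvar = "lambda"); title("LASSO (alpha = 1)", line = 2.5)
plot(ridge, xvar = "lambda"); title("Ridge (alpha = 0)", line = 2.5)
par(mfrow = c(1, 1))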

3) Elastic net model with cross-validation

# Model 3: cv.glmnet with 10-fold CV (the default). Note that alpha keeps glmnet's
# default value of 1 unless set explicitly, so only lambda is cross-validated here.

enet = cv.glmnet(x = trainx, y = trainset$Class, family = "binomial", type.measure = "class")

plot(enet)

coef(enet,s="lambda.min")
## 18 x 1 sparse Matrix of class "dgCMatrix"
##                                    1
## (Intercept)             5.759316e-01
## Sex_Female             -7.672276e-01
## Sex_Male                9.329967e-14
## ChestPain_Typical       .           
## ChestPain_Atypical     -6.835365e-02
## ChestPain_Non_aginal    .           
## ChestPain_asymptomatic  1.104963e+00
## FBS_No                 -1.908759e-01
## FBS_Yes                 .           
## RestECG_Normal          .           
## RestECG_Abnormal_ST     .           
## RestECG_LVHypertrophy   .           
## CPETAgina_No           -9.622832e-01
## CPETAgina_Yes           3.437673e-13
## Age                     1.552322e-02
## Chol                    2.347580e-04
## MaxHR                  -4.851293e-03
## Oldpeak                 5.127933e-01
predenet=predict(enet, newx = testx, s = "lambda.min", type = "class")

cm3=confusionMatrix(predenet,testset$Class,positive="Positive")

cm3
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Negative Positive
##   Negative       42        8
##   Positive       10       40
##                                           
##                Accuracy : 0.82            
##                  95% CI : (0.7305, 0.8897)
##     No Information Rate : 0.52            
##     P-Value [Acc > NIR] : 3.758e-10       
##                                           
##                   Kappa : 0.64            
##  Mcnemar's Test P-Value : 0.8137          
##                                           
##             Sensitivity : 0.8333          
##             Specificity : 0.8077          
##          Pos Pred Value : 0.8000          
##          Neg Pred Value : 0.8400          
##              Prevalence : 0.4800          
##          Detection Rate : 0.4000          
##    Detection Prevalence : 0.5000          
##       Balanced Accuracy : 0.8205          
##                                           
##        'Positive' Class : Positive        
## 
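
Because cv.glmnet cross-validates lambda only, tuning a true Elastic net additionally requires searching over alpha. A minimal sketch, using a shared foldid so every alpha is judged on identical folds (the alpha grid itself is an arbitrary choice):

# Sketch: small grid search over alpha with common CV folds
set.seed(123)
foldid = sample(rep(1:10, length.out = nrow(trainx)))
alphas = c(0, 0.25, 0.5, 0.75, 1)
cvfits = lapply(alphas, function(a)
  cv.glmnet(x = trainx, y = trainset$Class, family = "binomial",
            type.measure = "class", alpha = a, foldid = foldid))
# best alpha = the one whose optimal lambda gives the lowest CV misclassification error
alphas[which.min(sapply(cvfits, function(m) min(m$cvm)))]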

4) Grid-based tuning of an Elastic net model with the mlr package

library(mlr)

# Define train/test classification tasks on the original factor-coded data
tasktrain = makeClassifTask(id = "Hearttrain", data = trainorg, target = "Class", positive = "Positive")
tasktest = makeClassifTask(id = "Hearttest", data = testorg, target = "Class", positive = "Positive")

learner = makeLearner("classif.glmnet", predict.type = "prob")

learner$par.set
##                           Type  len     Def                 Constr Req
## alpha                  numeric    -       1                 0 to 1   -
## s                      numeric    -       -               0 to Inf   -
## exact                  logical    -   FALSE                      -   -
## nlambda                integer    -     100               1 to Inf   -
## lambda.min.ratio       numeric    -       -                 0 to 1   -
## lambda           numericvector <NA>       -               0 to Inf   -
## standardize            logical    -    TRUE                      -   -
## intercept              logical    -    TRUE                      -   -
## thresh                 numeric    -   1e-07               0 to Inf   -
## dfmax                  integer    -       -               0 to Inf   -
## pmax                   integer    -       -               0 to Inf   -
## exclude          integervector <NA>       -               1 to Inf   -
## penalty.factor   numericvector <NA>       -                 0 to 1   -
## lower.limits     numericvector <NA>       -              -Inf to 0   -
## upper.limits     numericvector <NA>       -               0 to Inf   -
## maxit                  integer    -  100000               1 to Inf   -
## type.logistic         discrete    -       - Newton,modified.Newton   -
## type.multinomial      discrete    -       -      ungrouped,grouped   -
## fdev                   numeric    -   1e-05                 0 to 1   -
## devmax                 numeric    -   0.999                 0 to 1   -
## eps                    numeric    -   1e-06                 0 to 1   -
## big                    numeric    - 9.9e+35            -Inf to Inf   -
## mnlam                  integer    -       5               1 to Inf   -
## pmin                   numeric    -   1e-09                 0 to 1   -
## exmx                   numeric    -     250            -Inf to Inf   -
## prec                   numeric    -   1e-10            -Inf to Inf   -
## mxit                   integer    -     100               1 to Inf   -
## factory                logical    -   FALSE                      -   -
##                  Tunable Trafo
## alpha               TRUE     -
## s                   TRUE     -
## exact               TRUE     -
## nlambda             TRUE     -
## lambda.min.ratio    TRUE     -
## lambda              TRUE     -
## standardize         TRUE     -
## intercept           TRUE     -
## thresh              TRUE     -
## dfmax               TRUE     -
## pmax                TRUE     -
## exclude             TRUE     -
## penalty.factor      TRUE     -
## lower.limits        TRUE     -
## upper.limits        TRUE     -
## maxit               TRUE     -
## type.logistic       TRUE     -
## type.multinomial    TRUE     -
## fdev                TRUE     -
## devmax              TRUE     -
## eps                 TRUE     -
## big                 TRUE     -
## mnlam               TRUE     -
## pmin                TRUE     -
## exmx                TRUE     -
## prec                TRUE     -
## mxit                TRUE     -
## factory             TRUE     -
ps = makeParamSet(
  makeDiscreteParam("alpha", values = c(0, 0.1, 0.2, 1)),
  makeNumericParam("lambda", lower = 0, upper = 0.05)
)

# 10 times repeated 10-fold CV; grid resolution 20 along the numeric lambda axis
rdesc = makeResampleDesc("RepCV", reps = 10, folds = 10)
ctrl = makeTuneControlGrid(resolution = 20L)
set.seed(123)
res = tuneParams(learner, task = tasktrain, resampling = rdesc, par.set = ps, control = ctrl, measures = list(mmce))

res$x
## $alpha
## [1] 0.1
## 
## $lambda
## [1] 0.04736842
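
Note that these hyperparameters were tuned and assessed on the same resampling, which can be slightly optimistic. For an unbiased performance estimate, mlr supports nested resampling via makeTuneWrapper; a minimal sketch, reusing the learner and grid above with an arbitrary 5-fold inner / 3-fold outer split (slow to run):

# Sketch: nest the grid search inside an outer resampling loop
innerlrn = makeTuneWrapper(learner, resampling = makeResampleDesc("CV", iters = 5),
                           par.set = ps, control = ctrl, measures = list(mmce))
# nested = resample(innerlrn, tasktrain, makeResampleDesc("CV", iters = 3), measures = list(mmce))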
resdf = generateHyperParsEffectData(res)

resdata = resdf$data %>% as_tibble()

resdata %>% ggplot(aes(x = lambda, y = alpha)) +
  geom_point(aes(size = mmce.test.mean, fill = mmce.test.mean), alpha = 0.6, shape = 21) +
  geom_vline(xintercept = res$x$lambda, color = "red", size = 0.7) +
  geom_hline(yintercept = res$x$alpha, color = "red", size = 0.7) +
  scale_fill_gradient(high = "purple", low = "#ff0033") +
  scale_y_continuous(breaks = c(0, 0.1, 0.2, 1)) +
  theme_bw()

# Refit the learner with the tuned hyperparameters and predict on the test task
learner2 = setHyperPars(learner, par.vals = res$x)

glmnmlr = mlr::train(learner2, tasktrain)
predmlr = predict(glmnmlr, tasktest) %>% as_tibble()

mets = list(auc, bac, tpr, tnr, mmce, ber, fpr, fnr)
performance(predict(glmnmlr, tasktest), measures = mets)
##       auc       bac       tpr       tnr      mmce       ber       fpr 
## 0.9186699 0.8100962 0.8125000 0.8076923 0.1900000 0.1899038 0.1923077 
##       fnr 
## 0.1875000
coef(glmnmlr$learner.model)
## 19 x 1 sparse Matrix of class "dgCMatrix"
##                                  s0
## (Intercept)            -1.908175782
## Age                     0.019101064
## RestBP                  0.005334390
## Chol                    0.002510014
## MaxHR                  -0.007621614
## Oldpeak                 0.486407276
## Sex.Female             -0.547810060
## Sex.Male                0.546888991
## ChestPain.Typical      -0.261334077
## ChestPain.Atypical     -0.550772485
## ChestPain.Non_aginal   -0.209744300
## ChestPain.asymptomatic  0.835723281
## FBS.No                 -0.222860161
## FBS.Yes                 0.222491693
## RestECG.Normal         -0.162116838
## RestECG.Abnormal_ST     0.029799702
## RestECG.LVHypertrophy   0.013371833
## CPETAgina.No           -0.546782116
## CPETAgina.Yes           0.546076066
cm4=confusionMatrix(predmlr$response,testorg$Class,positive="Positive")
cm4
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Negative Positive
##   Negative       42        9
##   Positive       10       39
##                                           
##                Accuracy : 0.81            
##                  95% CI : (0.7193, 0.8816)
##     No Information Rate : 0.52            
##     P-Value [Acc > NIR] : 1.528e-09       
##                                           
##                   Kappa : 0.6197          
##  Mcnemar's Test P-Value : 1               
##                                           
##             Sensitivity : 0.8125          
##             Specificity : 0.8077          
##          Pos Pred Value : 0.7959          
##          Neg Pred Value : 0.8235          
##              Prevalence : 0.4800          
##          Detection Rate : 0.3900          
##    Detection Prevalence : 0.4900          
##       Balanced Accuracy : 0.8101          
##                                           
##        'Positive' Class : Positive        
## 

5) Automated tuning of an Elastic net model with the caret package

Control <- trainControl(method = "repeatedcv", number = 10, repeats = 10, summaryFunction = defaultSummary)

glmncaret = caret::train(Class ~ ., data = trainorg, method = "glmnet", trControl = Control, tuneLength = 10)

glmncaret$finalModel$tuneValue
##    alpha    lambda
## 25   0.3 0.0410411
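
Instead of tuneLength, caret also accepts an explicit tuneGrid, which gives finer control over the search space; a minimal sketch with arbitrary grid values:

# Sketch: an explicit alpha/lambda grid instead of tuneLength
grid = expand.grid(alpha = seq(0, 1, by = 0.25), lambda = seq(0.001, 0.1, length.out = 10))
# glmncaret2 = caret::train(Class ~ ., data = trainorg, method = "glmnet",
#                           trControl = Control, tuneGrid = grid)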
plot(glmncaret$finalModel,xvar="lambda")

# NB: this extracts coefficients at the smallest lambda on the path, not at the tuned
# value; use s = glmncaret$bestTune$lambda for the tuned model's coefficients
coef(glmncaret$finalModel, s = min(glmncaret$finalModel$lambda))
## 14 x 1 sparse Matrix of class "dgCMatrix"
##                                  1
## (Intercept)           -5.240741319
## Age                    0.023979407
## SexMale                1.451750595
## ChestPainAtypical     -0.141629659
## ChestPainNon_aginal    0.248911346
## ChestPainasymptomatic  1.532453439
## RestBP                 0.007292650
## Chol                   0.004132244
## FBSYes                 0.572148517
## RestECGAbnormal_ST     0.273293000
## RestECGLVHypertrophy   0.224103698
## MaxHR                 -0.008586779
## CPETAginaYes           1.137733201
## Oldpeak                0.714533244
predcaret=predict(glmncaret,newdata=testorg)

cm5=confusionMatrix(predcaret,testorg$Class,positive="Positive")
cm5
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Negative Positive
##   Negative       42        9
##   Positive       10       39
##                                           
##                Accuracy : 0.81            
##                  95% CI : (0.7193, 0.8816)
##     No Information Rate : 0.52            
##     P-Value [Acc > NIR] : 1.528e-09       
##                                           
##                   Kappa : 0.6197          
##  Mcnemar's Test P-Value : 1               
##                                           
##             Sensitivity : 0.8125          
##             Specificity : 0.8077          
##          Pos Pred Value : 0.7959          
##          Neg Pred Value : 0.8235          
##              Prevalence : 0.4800          
##          Detection Rate : 0.3900          
##    Detection Prevalence : 0.4900          
##       Balanced Accuracy : 0.8101          
##                                           
##        'Positive' Class : Positive        
## 
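
If class probabilities are needed, e.g. for a ROC analysis, trainControl must request them explicitly; a minimal sketch, assuming the ROC metric from caret's twoClassSummary:

# Sketch: optimise AUC instead of accuracy and return class probabilities
ControlROC <- trainControl(method = "repeatedcv", number = 10, repeats = 10,
                           classProbs = TRUE, summaryFunction = twoClassSummary)
# glmncaretROC <- caret::train(Class ~ ., data = trainorg, method = "glmnet",
#                              trControl = ControlROC, metric = "ROC", tuneLength = 10)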

Performance of 5 models

perf = rbind(cm1$byClass, cm2$byClass, cm3$byClass, cm4$byClass, cm5$byClass) %>%
  as_tibble() %>%
  mutate(., Model = c("Lasso", "Ridge", "Elasticnet", "MLR", "CARET")) %>%
  .[, c(1, 2, 11, 12)]

perf %>% gather(Sensitivity:`Balanced Accuracy`, key = "Metric", value = "Value") %>%
  ggplot(aes(x = Metric, y = Value, fill = Model)) +
  geom_point(shape = 21, color = "black", size = 5) +
  geom_hline(yintercept = 0.84, color = "blue", linetype = "dashed", size = 1) +
  coord_flip() +
  theme_bw()

library(viridis)
perf %>% gather(Sensitivity:`Balanced Accuracy`, key = "Metric", value = "Value") %>%
  ggplot() +
  geom_tile(aes(y = reorder(Model, Value), x = reorder(Metric, -Value), fill = Value), color = "black") +
  geom_text(aes(y = Model, x = Metric, label = round(Value, 4)), color = "white", fontface = "bold") +
  scale_fill_viridis(option = "A", begin = 0.7, end = 0.1) +
  scale_y_discrete("Models") +
  scale_x_discrete("Metrics") +
  ggtitle("Performance of 5 models") +
  theme_bw()

Conclusion

Although the glmnet package is supported by both the caret and mlr frameworks, we recommend using the algorithm on its own. Fitting an Elastic net model via caret is a bad idea, since the tuning process costs a lot of time without guaranteeing a reliable outcome. The mlr package is even worse: by default there is no tuning process at all, so the algorithm performs no better than the standalone glmnet package. A greedy grid search can be misleading and, again, too slow. The best way to fit an Elastic net model is to use the built-in cross-validation function cv.glmnet; it runs very fast and provides well-optimised results in most cases.