GBM classification example

Load the libraries

library(mlbench)
library(gbm)
## Loading required package: survival
## Loading required package: splines
## Loading required package: lattice
## Loading required package: parallel
## Loaded gbm 2.1.1

Load the data and take a first look

data("PimaIndiansDiabetes2")
head(PimaIndiansDiabetes2)
##   pregnant glucose pressure triceps insulin mass pedigree age diabetes
## 1        6     148       72      35      NA 33.6    0.627  50      pos
## 2        1      85       66      29      NA 26.6    0.351  31      neg
## 3        8     183       64      NA      NA 23.3    0.672  32      pos
## 4        1      89       66      23      94 28.1    0.167  21      neg
## 5        0     137       40      35     168 43.1    2.288  33      pos
## 6        5     116       74      NA      NA 25.6    0.201  30      neg
mydata = PimaIndiansDiabetes2

mydata$diabetes = as.numeric(mydata$diabetes)   # factor levels "neg"/"pos" become 1/2
head(mydata)
##   pregnant glucose pressure triceps insulin mass pedigree age diabetes
## 1        6     148       72      35      NA 33.6    0.627  50        2
## 2        1      85       66      29      NA 26.6    0.351  31        1
## 3        8     183       64      NA      NA 23.3    0.672  32        2
## 4        1      89       66      23      94 28.1    0.167  21        1
## 5        0     137       40      35     168 43.1    2.288  33        2
## 6        5     116       74      NA      NA 25.6    0.201  30        1

Recode the target. as.numeric() mapped the factor levels to 1 and 2, but gbm's bernoulli distribution expects a 0/1 response, so subtract 1.

mydata = transform(mydata, diabetes=diabetes-1)
head(mydata)
##   pregnant glucose pressure triceps insulin mass pedigree age diabetes
## 1        6     148       72      35      NA 33.6    0.627  50        1
## 2        1      85       66      29      NA 26.6    0.351  31        0
## 3        8     183       64      NA      NA 23.3    0.672  32        1
## 4        1      89       66      23      94 28.1    0.167  21        0
## 5        0     137       40      35     168 43.1    2.288  33        1
## 6        5     116       74      NA      NA 25.6    0.201  30        0
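
The recoding can be verified with a quick sanity check (a minimal sketch using base R only):

table(mydata$diabetes)                          # counts of 0 (neg) and 1 (pos)
stopifnot(all(mydata$diabetes %in% c(0, 1)))    # gbm's bernoulli loss needs 0/1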

Fit the model with gbm

For classification we use the bernoulli distribution. As the package author suggests, shrinkage should normally be small, between 0.001 and 0.01, with n.trees correspondingly large, between 3000 and 10000.

gbm.model = gbm(diabetes~., data=mydata, shrinkage=0.01, distribution = 'bernoulli', cv.folds=5, n.trees=3000, verbose=F)

Check the best iteration number.

best.iter = gbm.perf(gbm.model, method="cv")
best.iter
## [1] 934
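
gbm.perf also supports an out-of-bag estimate of the optimal iteration; the gbm documentation notes that OOB tends to be conservative, which is why the cross-validated choice is used here. A quick comparison (a sketch; the exact value varies with the random seed):

oob.iter = gbm.perf(gbm.model, method="OOB")   # OOB estimate, usually lower than CV
oob.iter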

Summarize the model; this prints the relative influence of each predictor and draws the corresponding importance plot.

summary(gbm.model)
##               var   rel.inf
## glucose   glucose 36.267595
## mass         mass 18.718292
## age           age 13.041025
## pedigree pedigree 10.776126
## insulin   insulin  7.731030
## pressure pressure  4.845085
## pregnant pregnant  4.584827
## triceps   triceps  4.036020
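
The same relative-influence numbers can be pulled programmatically, e.g. to sort or export them; summary.gbm computes them via relative.influence() and the plot can be suppressed (a minimal sketch):

ri = summary(gbm.model, n.trees = best.iter, plotit = FALSE)  # importance at best.iter, no plot
ri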

Plot the marginal effect of the selected variables (partial dependence) by integrating out the other variables.

plot.gbm(gbm.model, 1, best.iter)   # partial dependence of pregnant
plot.gbm(gbm.model, 2, best.iter)   # partial dependence of glucose
plot.gbm(gbm.model, 3, best.iter)   # partial dependence of pressure
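
Before switching to caret, note that the fitted gbm object can itself produce probabilities at the chosen iteration; for the bernoulli loss, type = "response" returns P(diabetes = 1). A minimal sketch:

prob = predict(gbm.model, mydata, n.trees = best.iter, type = "response")
head(round(prob, 3))   # predicted probability of diabetes for the first rows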

Use the caret package to get the model performance at the best iteration. The original data is reloaded below so that diabetes is a factor again, which caret expects for classification.

library(caret)
## Loading required package: ggplot2
## 
## Attaching package: 'caret'
## The following object is masked from 'package:survival':
## 
##     cluster
mydata = PimaIndiansDiabetes2

set.seed(123)
fitControl = trainControl(method="cv", number=5, returnResamp = "all")

model2 = train(diabetes~., data=mydata, method="gbm", distribution="bernoulli",
               trControl=fitControl, verbose=F,
               tuneGrid=data.frame(.n.trees=best.iter, .shrinkage=0.01,
                                   .interaction.depth=1, .n.minobsinnode=1))
## Loading required package: plyr
model2
## Stochastic Gradient Boosting 
## 
## 768 samples
##   8 predictors
##   2 classes: 'neg', 'pos' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 314, 313, 314, 313, 314 
## Resampling results:
## 
##   Accuracy   Kappa    
##   0.7729958  0.4658758
## 
## Tuning parameter 'n.trees' was held constant at a value of 934
## Tuning parameter 'interaction.depth' was held constant at a value of 1
## Tuning parameter 'shrinkage' was held constant at a value of 0.01
## Tuning parameter 'n.minobsinnode' was held constant at a value of 1
## 
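
Here the grid is pinned to a single candidate. caret can just as easily search several combinations, at the cost of more resampling; a sketch with hypothetical grid values, keeping the dotted parameter names used above:

grid = expand.grid(.n.trees = c(500, 1000, best.iter),
                   .shrinkage = c(0.005, 0.01),
                   .interaction.depth = 1,
                   .n.minobsinnode = 1)
model3 = train(diabetes~., data=mydata, method="gbm", distribution="bernoulli",
               trControl=fitControl, verbose=F, tuneGrid=grid)
model3$bestTune   # the winning combination under 5-fold CV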

Check the model performance.

confusionMatrix(model2)
## Cross-Validated (5 fold) Confusion Matrix 
## 
## (entries are percentual average cell counts across resamples)
##  
##           Reference
## Prediction  neg  pos
##        neg 58.2 14.0
##        pos  8.7 19.1
##                            
##  Accuracy (average) : 0.773
mPred = predict(model2, mydata, na.action = na.pass)
postResample(mPred, mydata$diabetes)
##  Accuracy     Kappa 
## 0.7981771 0.5223267
confusionMatrix(mPred, mydata$diabetes)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction neg pos
##        neg 463 118
##        pos  37 150
##                                         
##                Accuracy : 0.7982        
##                  95% CI : (0.768, 0.826)
##     No Information Rate : 0.651         
##     P-Value [Acc > NIR] : < 2.2e-16     
##                                         
##                   Kappa : 0.5223        
##  Mcnemar's Test P-Value : 1.312e-10     
##                                         
##             Sensitivity : 0.9260        
##             Specificity : 0.5597        
##          Pos Pred Value : 0.7969        
##          Neg Pred Value : 0.8021        
##              Prevalence : 0.6510        
##          Detection Rate : 0.6029        
##    Detection Prevalence : 0.7565        
##       Balanced Accuracy : 0.7429        
##                                         
##        'Positive' Class : neg           
## 
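
mPred was computed on the very data the model was trained on, so this resubstitution accuracy (0.798) is optimistic compared with the cross-validated estimate (0.773). A held-out split gives a more honest check; a sketch using caret's createDataPartition with a hypothetical 75/25 split:

set.seed(123)
inTrain = createDataPartition(mydata$diabetes, p = 0.75, list = FALSE)
training = mydata[inTrain, ]
testing  = mydata[-inTrain, ]
fit = train(diabetes~., data=training, method="gbm", distribution="bernoulli",
            trControl=fitControl, verbose=F,
            tuneGrid=data.frame(.n.trees=best.iter, .shrinkage=0.01,
                                .interaction.depth=1, .n.minobsinnode=1))
confusionMatrix(predict(fit, testing, na.action = na.pass), testing$diabetes)
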
Check other relevant metrics:

getTrainPerf(model2)
##   TrainAccuracy TrainKappa method
## 1     0.7729958  0.4658758    gbm
mResults = predict(model2, mydata, na.action = na.pass, type = "prob")
mResults$obs = mydata$diabetes
head(mResults)
##         neg        pos obs
## 1 0.3852106 0.61478941 pos
## 2 0.8698684 0.13013164 neg
## 3 0.1897053 0.81029471 pos
## 4 0.9605993 0.03940066 neg
## 5 0.1850129 0.81498706 pos
## 6 0.9065770 0.09342299 neg
mnLogLoss(mResults, lev = levels(mResults$obs))
##   logLoss 
## 0.4398875
mResults$pred = predict(model2, mydata, na.action = na.pass)
multiClassSummary(mResults, lev = levels(mResults$obs))
##           logLoss               ROC          Accuracy             Kappa 
##         0.4398875         0.8620299         0.7981771         0.5223267 
##       Sensitivity       Specificity    Pos_Pred_Value    Neg_Pred_Value 
##         0.9260000         0.5597015         0.7969019         0.8021390 
##    Detection_Rate Balanced_Accuracy 
##         0.6028646         0.7428507
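
multiClassSummary already reports the area under the ROC curve (0.862); the curve itself can be drawn with pROC, which caret pulls in as a dependency (a minimal sketch):

library(pROC)
rocCurve = roc(mResults$obs, mResults$pos)   # response, predicted P(pos)
plot(rocCurve)
auc(rocCurve)
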
evalResults <- data.frame(Class = mydata$diabetes)
evalResults$GBM <- predict(model2, mydata, na.action = na.pass, type = "prob")[,"neg"]

head(evalResults)
##   Class       GBM
## 1   pos 0.3852106
## 2   neg 0.8698684
## 3   pos 0.1897053
## 4   neg 0.9605993
## 5   pos 0.1850129
## 6   neg 0.9065770
trellis.par.set(caretTheme())
liftData <- lift(Class ~ GBM, data = evalResults)
plot(liftData, values = 60, auto.key = list(columns = 1,
                                            lines = TRUE,
                                            points = FALSE))
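
A natural companion to the lift chart is caret's calibration plot, which bins the predicted probabilities and compares them with the observed event rate in each bin (a sketch on the same evalResults frame):

calData <- calibration(Class ~ GBM, data = evalResults)
xyplot(calData)   # lattice plot method provided by caret
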
sessionInfo()
## R version 3.3.0 beta (2016-03-30 r70404)
## Platform: x86_64-pc-linux-gnu (64-bit)
## Running under: Ubuntu 14.04.4 LTS
## 
## locale:
##  [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
##  [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
##  [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
##  [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
##  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       
## 
## attached base packages:
## [1] parallel  splines   stats     graphics  grDevices utils     datasets 
## [8] methods   base     
## 
## other attached packages:
## [1] plyr_1.8.3      caret_6.0-68    ggplot2_2.1.0   gbm_2.1.1      
## [5] lattice_0.20-24 survival_2.37-7 mlbench_2.1-1  
## 
## loaded via a namespace (and not attached):
##  [1] Rcpp_0.12.4          compiler_3.3.0       formatR_1.3         
##  [4] nloptr_1.0.4         class_7.3-9          iterators_1.0.8     
##  [7] tools_3.3.0          digest_0.6.9         lme4_1.1-11         
## [10] evaluate_0.8.3       gtable_0.2.0         nlme_3.1-126        
## [13] mgcv_1.7-28          Matrix_1.2-4         foreach_1.4.3       
## [16] yaml_2.1.13          SparseM_1.7          e1071_1.6-7         
## [19] stringr_1.0.0        knitr_1.12.3         pROC_1.8            
## [22] MatrixModels_0.4-1   stats4_3.3.0         grid_3.3.0          
## [25] nnet_7.3-7           rmarkdown_0.9.5      knitrBootstrap_1.0.0
## [28] minqa_1.2.4          reshape2_1.4.1       car_2.1-2           
## [31] magrittr_1.5         scales_0.4.0         codetools_0.2-8     
## [34] htmltools_0.3.5      MASS_7.3-29          pbkrtest_0.4-6      
## [37] mime_0.4             colorspace_1.2-6     quantreg_5.21       
## [40] stringi_1.0-1        munsell_0.4.3        markdown_0.7.7
## [1] "Last Updated on Thu Apr 14 13:52:27 2016"