library(mlbench)
library(gbm)
data("PimaIndiansDiabetes2")
head(PimaIndiansDiabetes2)
##   pregnant glucose pressure triceps insulin mass pedigree age diabetes
## 1        6     148       72      35      NA 33.6    0.627  50      pos
## 2        1      85       66      29      NA 26.6    0.351  31      neg
## 3        8     183       64      NA      NA 23.3    0.672  32      pos
## 4        1      89       66      23      94 28.1    0.167  21      neg
## 5        0     137       40      35     168 43.1    2.288  33      pos
## 6        5     116       74      NA      NA 25.6    0.201  30      neg
mydata = PimaIndiansDiabetes2
mydata$diabetes = as.numeric(mydata$diabetes)
head(mydata)
##   pregnant glucose pressure triceps insulin mass pedigree age diabetes
## 1        6     148       72      35      NA 33.6    0.627  50        2
## 2        1      85       66      29      NA 26.6    0.351  31        1
## 3        8     183       64      NA      NA 23.3    0.672  32        2
## 4        1      89       66      23      94 28.1    0.167  21        1
## 5        0     137       40      35     168 43.1    2.288  33        2
## 6        5     116       74      NA      NA 25.6    0.201  30        1
mydata = transform(mydata, diabetes=diabetes-1)
head(mydata)
##   pregnant glucose pressure triceps insulin mass pedigree age diabetes
## 1        6     148       72      35      NA 33.6    0.627  50        1
## 2        1      85       66      29      NA 26.6    0.351  31        0
## 3        8     183       64      NA      NA 23.3    0.672  32        1
## 4        1      89       66      23      94 28.1    0.167  21        0
## 5        0     137       40      35     168 43.1    2.288  33        1
## 6        5     116       74      NA      NA 25.6    0.201  30        0
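gbm's bernoulli loss expects a numeric 0/1 response, which is why diabetes is recoded in two steps above. The same recoding can be done in one step (an equivalent sketch):
mydata = PimaIndiansDiabetes2
mydata$diabetes = ifelse(PimaIndiansDiabetes2$diabetes == "pos", 1, 0)  # pos -> 1, neg -> 0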
For classification we use the Bernoulli distribution. As the gbm author suggests, one should normally use a small shrinkage, between 0.001 and 0.01, together with a large number of trees, n.trees, between 3000 and 10000.
gbm.model = gbm(diabetes ~ ., data = mydata, shrinkage = 0.01, distribution = "bernoulli",
                cv.folds = 5, n.trees = 3000, verbose = FALSE)
Check the best iteration number.
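The chunk that computes best.iter appears to have been lost from the page; a minimal reconstruction with gbm's gbm.perf(), which returns (and plots) the iteration with the lowest cross-validated error. Judging by the caret output further down, it evaluated to 934 here.
best.iter = gbm.perf(gbm.model, method = "cv")  # iteration minimizing the 5-fold CV deviance
print(best.iter)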
Summarize the model results; summary() also draws the relative-influence plot of the predictors.
summary(gbm.model)
##                var   rel.inf
## glucose    glucose 36.267595
## mass          mass 18.718292
## age            age 13.041025
## pedigree  pedigree 10.776126
## insulin    insulin  7.731030
## pressure  pressure  4.845085
## pregnant  pregnant  4.584827
## triceps    triceps  4.036020
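To get the influence table without the bar plot, summary.gbm accepts a plotit argument:
summary(gbm.model, plotit = FALSE)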
Plot the marginal effect of selected variables by "integrating out" the other variables (partial dependence plots).
plot(gbm.model, i.var = 1, n.trees = best.iter)
plot(gbm.model, i.var = 2, n.trees = best.iter)
plot(gbm.model, i.var = 3, n.trees = best.iter)
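Variables can also be selected by name, and the joint effect of a pair of variables can be drawn the same way (a sketch using the same plot method):
plot(gbm.model, i.var = "glucose", n.trees = best.iter)             # single variable, by name
plot(gbm.model, i.var = c("glucose", "mass"), n.trees = best.iter)  # two-variable joint effect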
library(caret)
For caret we reload the original data so that diabetes is a factor again, since train() expects a factor outcome for classification.
mydata = PimaIndiansDiabetes2
set.seed(123)
fitControl = trainControl(method = "cv", number = 5, returnResamp = "all")
model2 = train(diabetes ~ ., data = mydata, method = "gbm", distribution = "bernoulli",
               trControl = fitControl, verbose = FALSE,
               tuneGrid = data.frame(.n.trees = best.iter, .shrinkage = 0.01,
                                     .interaction.depth = 1, .n.minobsinnode = 1))
model2
## Stochastic Gradient Boosting
##
## 768 samples
## 8 predictors
## 2 classes: 'neg', 'pos'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 314, 313, 314, 313, 314
## Resampling results:
##
## Accuracy Kappa
## 0.7729958 0.4658758
##
## Tuning parameter 'n.trees' was held constant at a value of 934
## Tuning parameter 'interaction.depth' was held constant at a value of 1
## Tuning parameter 'shrinkage' was held constant at a value of 0.01
## Tuning parameter 'n.minobsinnode' was held constant at a value of 1
##
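Here all four tuning parameters were held fixed, so caret had nothing to search over. To tune the model instead, pass a full grid; note that recent caret versions name the grid columns without the leading dots used above (an illustrative sketch, values not from the original post):
gbmGrid <- expand.grid(n.trees = c(500, 1000, 2000),
                       interaction.depth = c(1, 3, 5),
                       shrinkage = 0.01,
                       n.minobsinnode = c(1, 10))
model3 <- train(diabetes ~ ., data = mydata, method = "gbm", distribution = "bernoulli",
                trControl = fitControl, verbose = FALSE, tuneGrid = gbmGrid)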
confusionMatrix(model2)
## Cross-Validated (5 fold) Confusion Matrix
##
## (entries are percentual average cell counts across resamples)
##
## Reference
## Prediction neg pos
## neg 58.2 14.0
## pos 8.7 19.1
##
## Accuracy (average) : 0.773
Predict on the full (training) data to obtain the apparent performance.
mPred = predict(model2, mydata, na.action = na.pass)
postResample(mPred, mydata$diabetes)
## Accuracy Kappa
## 0.7981771 0.5223267
confusionMatrix(mPred, mydata$diabetes)
## Confusion Matrix and Statistics
##
## Reference
## Prediction neg pos
## neg 463 118
## pos 37 150
##
## Accuracy : 0.7982
## 95% CI : (0.768, 0.826)
## No Information Rate : 0.651
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.5223
## Mcnemar's Test P-Value : 1.312e-10
##
## Sensitivity : 0.9260
## Specificity : 0.5597
## Pos Pred Value : 0.7969
## Neg Pred Value : 0.8021
## Prevalence : 0.6510
## Detection Rate : 0.6029
## Detection Prevalence : 0.7565
## Balanced Accuracy : 0.7429
##
## 'Positive' Class : neg
##
The apparent accuracy on the training data (0.798) is, as expected, somewhat higher than the cross-validated estimate of 0.773 shown earlier; getTrainPerf() retrieves those resampled estimates directly.
getTrainPerf(model2)
## TrainAccuracy TrainKappa method
## 1 0.7729958 0.4658758 gbm
Next, compute class probabilities so that probability-based metrics such as the log loss can be evaluated.
mResults = predict(model2, mydata, na.action = na.pass, type = "prob")
mResults$obs = mydata$diabetes
head(mResults)
##         neg        pos obs
## 1 0.3852106 0.61478941 pos
## 2 0.8698684 0.13013164 neg
## 3 0.1897053 0.81029471 pos
## 4 0.9605993 0.03940066 neg
## 5 0.1850129 0.81498706 pos
## 6 0.9065770 0.09342299 neg
mnLogLoss(mResults, lev = levels(mResults$obs))
## logLoss
## 0.4398875
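For two classes, the multinomial log loss computed by mnLogLoss reduces to the binomial cross-entropy; a quick manual check (a sketch):
y = as.numeric(mResults$obs == "pos")                            # 1 for pos, 0 for neg
-mean(y * log(mResults$pos) + (1 - y) * log(1 - mResults$pos))   # should reproduce ~0.4399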
mResults$pred = predict(model2, mydata, na.action = na.pass)
multiClassSummary(mResults, lev = levels(mResults$obs))
## logLoss ROC Accuracy Kappa
## 0.4398875 0.8620299 0.7981771 0.5223267
## Sensitivity Specificity Pos_Pred_Value Neg_Pred_Value
## 0.9260000 0.5597015 0.7969019 0.8021390
## Detection_Rate Balanced_Accuracy
## 0.6028646 0.7428507
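The ROC entry is the area under the ROC curve; it can be reproduced with the pROC package, which the sessionInfo below shows is available (a sketch):
library(pROC)
rocObj = roc(mResults$obs, mResults$pos, levels = c("neg", "pos"))
auc(rocObj)   # should agree with the ROC value above, about 0.862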
Finally, build a lift chart from the predicted probabilities of the "neg" class.
evalResults <- data.frame(Class = mydata$diabetes)
evalResults$GBM <- predict(model2, mydata, na.action = na.pass, type = "prob")[,"neg"]
head(evalResults)
##   Class       GBM
## 1   pos 0.3852106
## 2   neg 0.8698684
## 3   pos 0.1897053
## 4   neg 0.9605993
## 5   pos 0.1850129
## 6   neg 0.9065770
trellis.par.set(caretTheme())
liftData <- lift(Class ~ GBM, data = evalResults)
plot(liftData, values = 60, auto.key = list(columns = 1,
lines = TRUE,
points = FALSE))
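A calibration plot is a natural companion to the lift chart and uses the same formula interface (a sketch with caret's calibration()):
calData <- calibration(Class ~ GBM, data = evalResults, cuts = 10)
plot(calData)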
sessionInfo()
## R version 3.3.0 beta (2016-03-30 r70404)
## Platform: x86_64-pc-linux-gnu (64-bit)
## Running under: Ubuntu 14.04.4 LTS
##
## locale:
## [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
## [3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8
## [5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
## [7] LC_PAPER=en_US.UTF-8 LC_NAME=C
## [9] LC_ADDRESS=C LC_TELEPHONE=C
## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
##
## attached base packages:
## [1] parallel splines stats graphics grDevices utils datasets
## [8] methods base
##
## other attached packages:
## [1] plyr_1.8.3 caret_6.0-68 ggplot2_2.1.0 gbm_2.1.1
## [5] lattice_0.20-24 survival_2.37-7 mlbench_2.1-1
##
## loaded via a namespace (and not attached):
## [1] Rcpp_0.12.4 compiler_3.3.0 formatR_1.3
## [4] nloptr_1.0.4 class_7.3-9 iterators_1.0.8
## [7] tools_3.3.0 digest_0.6.9 lme4_1.1-11
## [10] evaluate_0.8.3 gtable_0.2.0 nlme_3.1-126
## [13] mgcv_1.7-28 Matrix_1.2-4 foreach_1.4.3
## [16] yaml_2.1.13 SparseM_1.7 e1071_1.6-7
## [19] stringr_1.0.0 knitr_1.12.3 pROC_1.8
## [22] MatrixModels_0.4-1 stats4_3.3.0 grid_3.3.0
## [25] nnet_7.3-7 rmarkdown_0.9.5 knitrBootstrap_1.0.0
## [28] minqa_1.2.4 reshape2_1.4.1 car_2.1-2
## [31] magrittr_1.5 scales_0.4.0 codetools_0.2-8
## [34] htmltools_0.3.5 MASS_7.3-29 pbkrtest_0.4-6
## [37] mime_0.4 colorspace_1.2-6 quantreg_5.21
## [40] stringi_1.0-1 munsell_0.4.3 markdown_0.7.7
## [1] "Last Updated on Thu Apr 14 13:52:27 2016"