I would suggest you read Section 5.1 of Introduction to Statistical Learning for a full treatment of this topic.
In classification methods, we are typically interested in using some observed characteristics of a case to predict a binary categorical outcome. This can be extended to a multi-category outcome, but most applications involve a binary (1/0) outcome.
In these examples, we will use the Demographic and Health Survey (DHS) Model Data. These data are based on the DHS survey instruments and are publicly available for practicing with DHS data sets, but they do not represent a real country.
In this example, our outcome is contraceptive use: a modern method versus another method or none.
library(haven)
# read the DHS model data (Stata format) directly from GitHub
dat<-url("https://github.com/coreysparks/data/blob/master/ZZIR62FL.DTA?raw=true")
model.dat<-read_dta(dat)
Here we recode some of our variables and limit our data to those women who are not currently pregnant and who are sexually active.
library(dplyr)
##
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
model.dat2<-model.dat%>%
mutate(region = v024,
modcontra= ifelse(v364 ==1,1, 0),
age = cut(v012, breaks = 5),
livchildren=v218,
educ = v106,
currpreg=v213,
wealth = as.factor(v190),
partnered = ifelse(v701<=1, 0, 1),
work = ifelse(v731%in%c(0,1), 0, 1),
knowmodern=ifelse(v301==3, 1, 0),
age2=v012^2,
rural = ifelse(v025==2, 1,0),
wantmore = ifelse(v605%in%c(1,2), 1, 0))%>%
filter(currpreg==0, v536>0, v701!=9)%>% #notpreg, sex active
dplyr::select(caseid, region, modcontra,age, age2,livchildren, educ, knowmodern, rural, wantmore, partnered,wealth, work)
knitr::kable(head(model.dat2))
| caseid | region | modcontra | age | age2 | livchildren | educ | knowmodern | rural | wantmore | partnered | wealth | work |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 1 2 | 2 | 0 | (28.6,35.4] | 900 | 4 | 0 | 1 | 1 | 1 | 0 | 1 | 1 |
| 1 4 2 | 2 | 0 | (35.4,42.2] | 1764 | 2 | 0 | 1 | 1 | 0 | 0 | 3 | 1 |
| 1 4 3 | 2 | 0 | (21.8,28.6] | 625 | 3 | 1 | 1 | 1 | 0 | 0 | 3 | 1 |
| 1 5 1 | 2 | 0 | (21.8,28.6] | 625 | 2 | 2 | 1 | 1 | 1 | 0 | 2 | 1 |
| 1 6 2 | 2 | 0 | (35.4,42.2] | 1369 | 2 | 0 | 1 | 1 | 1 | 0 | 3 | 1 |
| 1 7 2 | 2 | 0 | (15,21.8] | 441 | 1 | 0 | 1 | 1 | 1 | 0 | 1 | 1 |
The term cross-validation refers to fitting a model on one subset of the data and then testing it on another subset. Typically this process is repeated several times.

The simplest approach is the validation-set, or hold-out, method: we split the data once into a training set and a test set, fit the model on the training set, and evaluate its predictions on the held-out test set.

K-fold cross-validation extends this by leaving out a "group," or fold, of observations at a time: the data are split into k folds of roughly equal size, the model is fit on the other k-1 folds and tested on the remaining fold, and this is repeated so that each fold serves once as the test set.

Leave-one-out (LOO) cross-validation is the special case of k-fold cross-validation where k equals the number of observations: each observation is left out in turn, the model is refit on the remaining data, and the left-out observation is predicted.

By doing this, we can see how model accuracy is affected by particular observations, and more generally it allows model accuracy to be measured repeatedly so we can assess things such as model tuning parameters.
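As a concrete illustration, here is a minimal sketch of a k-fold cross-validation loop in base R. The data frame df, the choice of k = 5, and the use of RMSE as the error measure are all made up for illustration and are not part of the analysis below.

```r
# Minimal k-fold cross-validation sketch (illustrative data and model)
set.seed(1)
df <- data.frame(x = rnorm(100))                  # made-up predictor
df$y <- 2 * df$x + rnorm(100)                     # made-up outcome

k <- 5
folds <- sample(rep(1:k, length.out = nrow(df)))  # random fold assignment

rmse <- numeric(k)
for (i in 1:k) {
  train <- df[folds != i, ]                       # fit on the other k-1 folds
  test  <- df[folds == i, ]                       # hold out fold i
  m <- lm(y ~ x, data = train)
  pred <- predict(m, newdata = test)
  rmse[i] <- sqrt(mean((test$y - pred)^2))        # test error for this fold
}
mean(rmse)                                        # cross-validated error estimate
```

Each observation is held out exactly once, so the averaged error reflects performance on data the model did not see when it was fit.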
If you remember from last time, the Lasso analysis depended on choosing a good value for the penalty term \(\lambda\). In a cross-validation analysis, we can use the various resamplings of the data to examine how sensitive the model's accuracy is to alternative values of this parameter.

This evaluation can be done systematically, along a grid of values, or by a random search.
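If the Lasso is fit with the glmnet package, for example, this kind of tuning is built into cv.glmnet(), which computes the cross-validated error over a grid of \(\lambda\) values. The sketch below uses simulated data purely for illustration and is separate from the DHS analysis that follows.

```r
library(glmnet)

# Simulated data for illustration only
set.seed(2)
X <- matrix(rnorm(200 * 10), ncol = 10)
y <- X[, 1] - 2 * X[, 2] + rnorm(200)

# 10-fold cross-validation of the Lasso over a grid of lambda values
cvfit <- cv.glmnet(X, y, alpha = 1, nfolds = 10)
plot(cvfit)        # CV error across the lambda grid
cvfit$lambda.min   # lambda giving the lowest cross-validated error
```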
We talked last time about using model accuracy as a measure of overall fit. This was calculated from the observed and predicted values of our outcome. For classification models, another commonly used metric of predictive power is the Receiver Operating Characteristic (ROC) curve, which plots the true positive rate against the false positive rate across classification thresholds. It is often accompanied by the area under the curve (AUC) measure, which summarizes the separability of the classes. Together they tell you how capable the model is of distinguishing between the classes in the data; higher values are better, and both are bounded between 0 and 1.

A nice description of these can be found here.
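As a quick sketch of how these are computed, here is a toy example using the pROC package (the same package used at the end of this document). The outcomes obs and predicted probabilities phat are simulated for illustration only.

```r
library(pROC)

# Made-up binary outcomes and predicted probabilities
set.seed(3)
obs  <- rbinom(200, 1, 0.3)
phat <- plogis(-1 + 1.5 * obs + rnorm(200))  # probabilities loosely related to obs

r <- roc(obs, phat)  # ROC curve object
plot(r)              # sensitivity vs. specificity across thresholds
auc(r)               # area under the curve
```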
Decision trees are a common technique for both regression and classification problems. Regression or classification trees attempt to find optimal splits in the data so that observations can be predicted or classified as well as possible. Chapter 8 of Introduction to Statistical Learning is a good place to start with this.

Trees generate a set of splitting rules that partition the data into groups based on combinations of the predictors.

This example, from the text, shows a three-region partition of data on baseball hitters, where the outcome is salary in dollars. Region 1 contains players who have played fewer than 4.5 years; they typically have lower salaries. The other two regions consist of players who have played at least 4.5 years and who have either fewer than 117.5 hits or at least 117.5 hits. Those with more hits have higher salaries than those with fewer hits.
The regions can be thought of as nodes (or leaves) on a tree.
Here is a regression tree for these data. The terminal nodes show the mean salary (in thousands of dollars) for players in that region. For example, if a player has less than 4.5 years of experience and fewer than 39.5 hits, their average salary is 676.5 thousand dollars.
library(tree)
data(Hitters, package = "ISLR")
head(Hitters)
## AtBat Hits HmRun Runs RBI Walks Years CAtBat CHits CHmRun
## -Andy Allanson 293 66 1 30 29 14 1 293 66 1
## -Alan Ashby 315 81 7 24 38 39 14 3449 835 69
## -Alvin Davis 479 130 18 66 72 76 3 1624 457 63
## -Andre Dawson 496 141 20 65 78 37 11 5628 1575 225
## -Andres Galarraga 321 87 10 39 42 30 2 396 101 12
## -Alfredo Griffin 594 169 4 74 51 35 11 4408 1133 19
## CRuns CRBI CWalks League Division PutOuts Assists Errors
## -Andy Allanson 30 29 14 A E 446 33 20
## -Alan Ashby 321 414 375 N W 632 43 10
## -Alvin Davis 224 266 263 A W 880 82 14
## -Andre Dawson 828 838 354 N E 200 11 3
## -Andres Galarraga 48 46 33 N E 805 40 4
## -Alfredo Griffin 501 336 194 A W 282 421 25
## Salary NewLeague
## -Andy Allanson NA A
## -Alan Ashby 475.0 N
## -Alvin Davis 480.0 A
## -Andre Dawson 500.0 N
## -Andres Galarraga 91.5 N
## -Alfredo Griffin 750.0 A
fit1<-tree(Salary ~ Years+Hits, data=Hitters)
fit1
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 263 53320000 535.9
## 2) Years < 4.5 90 6769000 225.8
## 4) Hits < 39.5 5 3131000 676.5 *
## 5) Hits > 39.5 85 2564000 199.3
## 10) Years < 3.5 58 309400 138.8 *
## 11) Years > 3.5 27 1586000 329.3 *
## 3) Years > 4.5 173 33390000 697.2
## 6) Hits < 117.5 90 5312000 464.9
## 12) Years < 6.5 26 644100 334.7 *
## 13) Years > 6.5 64 4048000 517.8 *
## 7) Hits > 117.5 83 17960000 949.2
## 14) Hits < 185 76 13290000 914.3
## 28) Years < 5.5 8 82790 622.5 *
## 29) Years > 5.5 68 12450000 948.7 *
## 15) Hits > 185 7 3571000 1328.0 *
plot(fit1); text(fit1, pretty=1)
The cut points are chosen by minimizing the residual sum of squares within the regions. That is, we identify regions of the predictor space, \(R_1, R_2, \dots, R_J\), that minimize

\[\sum_{j=1}^{J} \sum_{i \in R_j} \left ( y_i - \hat{y}_{R_j} \right )^2\]

where \(\hat{y}_{R_j}\) is the mean of the outcome in region \(j\).
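To make the splitting criterion concrete, here is a small illustrative calculation on the Hitters data loaded above. The helper rss_split(), defined here only for illustration, evaluates the sum of squares above for a candidate cut point on Years; minimizing it over candidate cuts should recover the 4.5-year split that tree() chose first.

```r
# Illustrative only: RSS for splitting the Hitters data on Years at a cut point
hit <- na.omit(Hitters[, c("Salary", "Years", "Hits")])

rss_split <- function(cut) {
  left  <- hit$Salary[hit$Years <  cut]
  right <- hit$Salary[hit$Years >= cut]
  sum((left  - mean(left))^2) + sum((right - mean(right))^2)
}

cuts <- seq(1.5, 20.5, by = 1)   # candidate cut points between integer Years values
rss  <- sapply(cuts, rss_split)
cuts[which.min(rss)]             # should match the Years < 4.5 split chosen above
```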
Often this process over-fits the data, meaning it creates too complicated a tree (too many terminal nodes). It's possible to prune the tree to arrive at a simpler tree that may be easier to interpret.

We can tune the tree size by cross-validation, comparing trees of different sizes. In this case a tree with 3 terminal nodes is optimal.
cvt<-cv.tree(fit1)
plot(cvt$size, cvt$dev, type="b")
Then we can prune the tree, which essentially gives us the tree version of the three-region figure from above.
tree2<-prune.tree(fit1, best=3)
plot(tree2); text(tree2, pretty=1)
# plot(x=Hitters$Years, y=Hitters$Hits)
# abline(v=4.5, col=3, lwd=3)
# abline(h=117.5, col=4, lwd=3)
Prediction works by assigning to an observation the mean value of the region whose decision rules it matches. For example, let's make up a player who has 6 years of experience and 200 hits:
new<-data.frame(Hits=200, Years=6)
pred<-predict(fit1, newdata = new)
pred
## 1
## 1327.5
If our outcome is categorical or binary, the tree will be a classification tree. Instead of predicting the mean value in a region, the classification tree predicts the most common class at a particular terminal node. In addition to predicting the class at each node, the tree also gives the class proportions at each node. The classification error rate is the percent of observations at a node that do not belong to the most common class:
\[Error = 1 - \max_k (\hat p_{mk})\] Classification error is not sensitive enough to be a good criterion for growing trees, so instead either the Gini index or the entropy is measured at each node:

\[Gini = \sum_k \hat p_{mk}(1-\hat p_{mk})\] The Gini index is used as a measure of node purity: if a node contains only one class, the Gini index is 0 and the node is considered pure.
\[Entropy = D = - \sum_k \hat p_{mk} \text{log} \hat p_{mk}\]
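A tiny illustrative sketch of these impurity measures, using made-up class proportions for a nearly pure node and a maximally mixed node:

```r
# Node impurity for a vector of class proportions (illustrative values)
gini    <- function(p) sum(p * (1 - p))
entropy <- function(p) -sum(p * log(p))

p_pure  <- c(0.9, 0.1)   # nearly pure node
p_mixed <- c(0.5, 0.5)   # maximally mixed node

gini(p_pure);    gini(p_mixed)      # 0.18 vs 0.50
entropy(p_pure); entropy(p_mixed)   # about 0.33 vs 0.69
```

Both measures are smallest when one class dominates the node and largest when the classes are evenly mixed.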
The example above is a single tree. If we repeated this type of analysis a large number of times, we would end up with a forest of such trees.

Bagging is short for bootstrap aggregation. This is a general-purpose procedure for reducing the variance of a statistical learning method, and it is commonly used in the regression tree context. Here is how it works in this setting: the data are bootstrapped into a large number of training sets, each of the same size as the original data. An unpruned regression tree is fit to each bootstrapped training set, and the predictions from those trees are averaged; the resulting accuracy is higher than for a single tree alone.

Random forests not only bag the trees; at each split, only a random subset of the predictors is considered as a candidate for splitting. This decorrelates the trees, so we not only arrive at a more accurate ensemble than bagging alone, we can also get an idea of how important any particular variable is, based on its average decrease in Gini impurity across all the trees considered.
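To show the mechanics that randomForest() automates below, here is a minimal hand-rolled bagging sketch using the tree package and the Hitters data from above. The number of bootstrap samples B and the seed are arbitrary illustrative choices.

```r
# Hand-rolled bagging sketch (randomForest() does this, and more, below)
library(tree)
hit <- na.omit(Hitters)

set.seed(1234)
B <- 50                                          # number of bootstrap samples
new_player <- data.frame(Years = 6, Hits = 200)  # same made-up player as above

preds <- replicate(B, {
  boot <- hit[sample(nrow(hit), replace = TRUE), ]   # bootstrap resample
  fitb <- tree(Salary ~ Years + Hits, data = boot)   # unpruned tree on the resample
  predict(fitb, newdata = new_player)
})
mean(preds)   # bagged prediction: the average over the bootstrapped trees
```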
Next we apply these methods to country-level data (the prb2 data frame, prepared in a chunk not shown here), using total life expectancy at birth (e0total) as the outcome.
#set up training set identifier
train1<-sample(1:dim(prb2)[1], size = .75*dim(prb2)[1], replace=T)
fit<-tree(e0total~., data=prb2[train1,])
summary(fit)
##
## Regression tree:
## tree(formula = e0total ~ ., data = prb2[train1, ])
## Variables actually used in tree construction:
## [1] "imr" "cdr"
## Number of terminal nodes: 5
## Residual mean deviance: 8.09 = 1222 / 151
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -6.78900 -1.26800 -0.08696 0.00000 1.33900 9.75000
plot(fit); text(fit, pretty=1)
cv.fit<-cv.tree(fit)
plot(cv.fit$size, cv.fit$dev, type="b")
pt1<-prune.tree(fit, best=7)
plot(pt1); text(pt1, pretty=1)
library(randomForest)
## randomForest 4.6-14
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:dplyr':
##
## combine
set.seed(1115)
bag.1<-randomForest(e0total~., data=prb2[train1,], mtry=3, ntree=100,importance=T) #mtry = 3: consider 3 randomly chosen predictors at each split
bag.1
##
## Call:
## randomForest(formula = e0total ~ ., data = prb2[train1, ], mtry = 3, ntree = 100, importance = T)
## Type of random forest: regression
## Number of trees: 100
## No. of variables tried at each split: 3
##
## Mean of squared residuals: 7.312335
## % Var explained: 94.93
plot(bag.1)
importance(bag.1)
## %IncMSE IncNodePurity
## continent 4.019511 1538.7463
## population. 3.162366 282.1635
## cbr 2.929610 928.9195
## cdr 6.189876 2468.9265
## rate.of.natural.increase 4.705163 346.9130
## net.migration.rate 3.337810 171.5961
## imr 7.144786 3032.0655
## womandlifetimeriskmaternaldeath 4.803696 1301.9272
## tfr 2.889436 710.0053
## percpoplt15 4.670797 2535.1616
## percpopgt65 3.762082 712.3060
## percurban 4.039048 639.4774
## percpopinurbangt750k 2.275862 391.5408
## percpop1549hivaids2007 5.609438 1348.0021
## percmarwomcontraall 5.176757 976.0832
## percmarwomcontramodern 3.653181 397.7990
## percppundernourished0204 3.594125 736.6241
## motorvehper1000pop0005 6.055125 476.5836
## percpopwaccessimprovedwatersource 3.884817 2008.8031
## gnippppercapitausdollars 5.499013 939.5480
## popdenspersqkm 2.489169 188.6840
varImpPlot(bag.1, n.var = 10, type=2)
prb2$lowe0<-as.factor(ifelse(prb2$e0total<median(prb2$e0total), "low", "high"))
fit<-tree(lowe0~., data=prb2[train1,-12])
fit
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 156 213.700 high ( 0.56410 0.43590 )
## 2) imr < 20.5 93 54.540 high ( 0.91398 0.08602 )
## 4) percpopinurbangt750k < 22.5 35 37.630 high ( 0.77143 0.22857 )
## 8) percmarwomcontramodern < 53.5 18 24.730 high ( 0.55556 0.44444 )
## 16) net.migration.rate < 2.5 13 17.320 low ( 0.38462 0.61538 )
## 32) motorvehper1000pop0005 < 166.5 7 8.376 high ( 0.71429 0.28571 ) *
## 33) motorvehper1000pop0005 > 166.5 6 0.000 low ( 0.00000 1.00000 ) *
## 17) net.migration.rate > 2.5 5 0.000 high ( 1.00000 0.00000 ) *
## 9) percmarwomcontramodern > 53.5 17 0.000 high ( 1.00000 0.00000 ) *
## 5) percpopinurbangt750k > 22.5 58 0.000 high ( 1.00000 0.00000 ) *
## 3) imr > 20.5 63 24.120 low ( 0.04762 0.95238 )
## 6) percppundernourished0204 < 4.5 6 8.318 high ( 0.50000 0.50000 ) *
## 7) percppundernourished0204 > 4.5 57 0.000 low ( 0.00000 1.00000 ) *
plot(fit); text(fit, pretty=1)
cv.fit<-cv.tree(fit)
cv.fit
## $size
## [1] 7 5 4 3 2 1
##
## $dev
## [1] 140.66418 133.59489 100.78413 97.67824 94.23793 215.04364
##
## $k
## [1] -Inf 8.177421 12.897492 15.804188 16.913619 135.027065
##
## $method
## [1] "deviance"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
plot(cv.fit$size, cv.fit$dev, type="b")
pt1<-prune.tree(fit, best=cv.fit$size[which.min(cv.fit$dev)])
plot(pt1); text(pt1, pretty=1)
#Tune to find best number of variables to try
t1<-tuneRF(y=prb2$lowe0, x=prb2[,c(-12,-23)], trace=T, stepFactor = 2, ntreeTry = 1000, plot=T)
## mtry = 4 OOB error = 11.96%
## Searching left ...
## mtry = 2 OOB error = 12.44%
## -0.04 0.05
## Searching right ...
## mtry = 8 OOB error = 13.4%
## -0.12 0.05
t1
## mtry OOBError
## 2.OOB 2 0.1244019
## 4.OOB 4 0.1196172
## 8.OOB 8 0.1339713
bag.2<-randomForest(lowe0~., data=prb2[train1,-12], mtry=2, ntree=500,importance=T)
bag.2
##
## Call:
## randomForest(formula = lowe0 ~ ., data = prb2[train1, -12], mtry = 2, ntree = 500, importance = T)
## Type of random forest: classification
## Number of trees: 500
## No. of variables tried at each split: 2
##
## OOB estimate of error rate: 5.13%
## Confusion matrix:
## high low class.error
## high 85 3 0.03409091
## low 5 63 0.07352941
plot(bag.2)
importance(bag.2,scale = T )
## high low MeanDecreaseAccuracy
## continent 0.8960079 3.732227 3.230584
## population. 3.7178560 5.365841 6.212556
## cbr 6.6176285 6.520524 8.805896
## cdr 11.5576213 8.769685 13.305334
## rate.of.natural.increase 3.7687952 4.432593 5.647005
## net.migration.rate 6.9591938 8.257085 9.586694
## imr 12.6304590 12.552386 15.343487
## womandlifetimeriskmaternaldeath 11.8247886 11.474400 13.933892
## tfr 5.9173245 6.143863 8.364331
## percpoplt15 7.3961344 7.506203 9.657258
## percpopgt65 5.3528400 6.208990 7.140397
## percurban 9.3710657 6.787947 10.354594
## percpopinurbangt750k 8.7076703 8.144914 10.309891
## percpop1549hivaids2007 5.5805239 4.908233 6.667127
## percmarwomcontraall 9.0006644 8.041696 11.581819
## percmarwomcontramodern 7.1305005 8.056778 9.689848
## percppundernourished0204 9.2346561 10.136050 12.352580
## motorvehper1000pop0005 6.1267178 7.061174 8.616828
## percpopwaccessimprovedwatersource 10.7759029 8.473270 12.130466
## gnippppercapitausdollars 11.1613018 11.187313 14.126590
## popdenspersqkm 5.8920997 4.343651 6.778459
## MeanDecreaseGini
## continent 0.9573333
## population. 1.3505035
## cbr 3.5322531
## cdr 3.2592309
## rate.of.natural.increase 1.9651325
## net.migration.rate 1.8581847
## imr 9.2822795
## womandlifetimeriskmaternaldeath 7.7300771
## tfr 3.1296814
## percpoplt15 3.5768800
## percpopgt65 2.0161076
## percurban 4.3892427
## percpopinurbangt750k 3.2094591
## percpop1549hivaids2007 1.9510930
## percmarwomcontraall 3.7824227
## percmarwomcontramodern 2.5813312
## percppundernourished0204 4.4104305
## motorvehper1000pop0005 2.8350576
## percpopwaccessimprovedwatersource 5.9103717
## gnippppercapitausdollars 7.4416494
## popdenspersqkm 1.0035350
varImpPlot(bag.2, n.var = 10, type=2)
pred<-predict(bag.2, newdata=prb2[-train1,])
table(pred, prb2[-train1, "lowe0"])
##
## pred high low
## high 45 9
## low 3 45
mean(pred==prb2[-train1, "lowe0"]) #accuracy
## [1] 0.8823529
Now we return to the DHS data and split them into training and test sets using caret, with an 80% training fraction.
library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
##
## Attaching package: 'ggplot2'
## The following object is masked from 'package:randomForest':
##
## margin
set.seed(1115)
train<- createDataPartition(y = model.dat2$modcontra , p = .80, list=F)
dtrain<-model.dat2[train,]
## Warning: The `i` argument of ``[`()` can't be a matrix as of tibble 3.0.0.
## Convert to a vector.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_warnings()` to see where this warning was generated.
dtest<-model.dat2[-train,]
If we have a mixture of factor variables and continuous predictors in our analysis, it is best to set up the design matrix for our models before we run them. Many methods within caret won’t use factor variables correctly unless we set up the dummy variable representations first.
y<-dtrain$modcontra
y<-as.factor(ifelse(y==1, "mod", "notmod"))
x<-model.matrix(~factor(region)+factor(age)+livchildren+factor(rural)+factor(wantmore)+factor(educ)+partnered+factor(work)+factor(wealth), data=dtrain)
x<-data.frame(x)[,-1]
table(y)
## y
## mod notmod
## 719 3410
prop.table(table(y))
## y
## mod notmod
## 0.1741342 0.8258658
xtest<-model.matrix(~factor(region)+factor(age)+livchildren+factor(rural)+factor(wantmore)+factor(educ)+partnered+factor(work)+factor(wealth), data=dtest)
xtest<-xtest[,-1]
xtest<-data.frame(xtest)
yt<-dtest$modcontra
yt<-as.factor(ifelse(yt==1, "mod", "notmod"))
prop.table(table(yt))
## yt
## mod notmod
## 0.2102713 0.7897287
To set up the training controls for a caret model, we typically specify the type of resampling method, the number of resamples, and the number of repeats (if we are doing repeated resampling). Here we will do a 10-fold cross-validation; 10 is often recommended as a choice for k based on empirical sensitivity analyses.

The other things we specify include whether to compute class probabilities, how to search over tuning-parameter values, whether to re-sample to balance an unbalanced outcome like this one, and which summary function to use. These appear as arguments to trainControl() below:
fitctrl <- trainControl(method="repeatedcv",
number=10,
repeats=5,
classProbs = TRUE,
search="random", #randomly search on different values of the tuning parameters
sampling = "down", #optional, but good for unbalanced outcomes like this one
summaryFunction=twoClassSummary,
savePredictions = "all")
Here we fit a basic classification tree with the rpart() function, via caret's "rpart" method.
fitctrl <- trainControl(method="cv",
number=10,
#repeats=5,
classProbs = TRUE,
search="random",
sampling = "down",
summaryFunction=twoClassSummary,
savePredictions = "all")
rp1<-caret::train(y=y, x=x,
metric="ROC",
method ="rpart",
tuneLength=20, #try 20 random values of the tuning parameters
trControl=fitctrl,
preProcess=c("center", "scale"))
rp1
## CART
##
## 4129 samples
## 19 predictor
## 2 classes: 'mod', 'notmod'
##
## Pre-processing: centered (19), scaled (19)
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 3716, 3716, 3716, 3716, 3717, 3716, ...
## Addtional sampling using down-sampling prior to pre-processing
##
## Resampling results across tuning parameters:
##
## cp ROC Sens Spec
## 0.0000000000 0.6410664 0.6077856 0.6073314
## 0.0002781641 0.6410664 0.6077856 0.6073314
## 0.0003477051 0.6410664 0.6077856 0.6073314
## 0.0004636069 0.6395609 0.6077856 0.6067449
## 0.0013908206 0.6456062 0.6119327 0.6219941
## 0.0018544274 0.6410840 0.6147105 0.6225806
##
## ROC was used to select the optimal model using the largest value.
## The final value used for the model was cp = 0.001390821.
library(rpart.plot)
## Loading required package: rpart
plot(rp1)
#plot(rp1$finalModel)
prp(rp1$finalModel,type=4, extra = 4,
main="Classification tree for using modern contraception")
varImp(rp1)
## rpart variable importance
##
## Overall
## livchildren 100.000
## factor.educ.2 51.409
## factor.rural.1 49.962
## partnered 45.479
## factor.age..28.6.35.4. 45.425
## factor.region.3 44.418
## factor.age..42.2.49. 42.981
## factor.region.4 41.612
## factor.wealth.5 38.174
## factor.wantmore.1 29.981
## factor.region.2 28.628
## factor.wealth.4 22.629
## factor.age..21.8.28.6. 18.083
## factor.work.1 17.900
## factor.wealth.3 12.914
## factor.age..35.4.42.2. 9.608
## factor.educ.1 5.464
## factor.wealth.2 2.217
## factor.educ.3 0.000
plot(varImp(rp1), top=10)
##Accuracy on training set
pred1<-predict(rp1, newdata=x)
confusionMatrix(data = pred1,reference = y, positive = "mod" )
## Confusion Matrix and Statistics
##
## Reference
## Prediction mod notmod
## mod 479 1113
## notmod 240 2297
##
## Accuracy : 0.6723
## 95% CI : (0.6578, 0.6866)
## No Information Rate : 0.8259
## P-Value [Acc > NIR] : 1
##
## Kappa : 0.2297
##
## Mcnemar's Test P-Value : <2e-16
##
## Sensitivity : 0.6662
## Specificity : 0.6736
## Pos Pred Value : 0.3009
## Neg Pred Value : 0.9054
## Prevalence : 0.1741
## Detection Rate : 0.1160
## Detection Prevalence : 0.3856
## Balanced Accuracy : 0.6699
##
## 'Positive' Class : mod
##
predt1<-predict(rp1, newdata=xtest)
confusionMatrix(data = predt1, yt, positive = "mod" )
## Confusion Matrix and Statistics
##
## Reference
## Prediction mod notmod
## mod 129 271
## notmod 88 544
##
## Accuracy : 0.6521
## 95% CI : (0.6222, 0.6812)
## No Information Rate : 0.7897
## P-Value [Acc > NIR] : 1
##
## Kappa : 0.2001
##
## Mcnemar's Test P-Value : <2e-16
##
## Sensitivity : 0.5945
## Specificity : 0.6675
## Pos Pred Value : 0.3225
## Neg Pred Value : 0.8608
## Prevalence : 0.2103
## Detection Rate : 0.1250
## Detection Prevalence : 0.3876
## Balanced Accuracy : 0.6310
##
## 'Positive' Class : mod
##
fitctrl <- trainControl(method="cv",
number=10,
#repeats=5,
classProbs = TRUE,
search="random",
sampling = "down",
summaryFunction=twoClassSummary,
savePredictions = "all")
bt1<-caret::train(y=y, x=x,
metric="ROC",
method ="treebag",
tuneLength=20, #try 20 random values of the tuning parameters
trControl=fitctrl,
preProcess=c("center", "scale"))
print(bt1)
## Bagged CART
##
## 4129 samples
## 19 predictor
## 2 classes: 'mod', 'notmod'
##
## Pre-processing: centered (19), scaled (19)
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 3717, 3716, 3716, 3716, 3716, 3716, ...
## Addtional sampling using down-sampling prior to pre-processing
##
## Resampling results:
##
## ROC Sens Spec
## 0.6242268 0.6008803 0.5788856
plot(varImp(bt1))
##Accuracy on training set
pred1<-predict(bt1, newdata=x)
confusionMatrix(data = pred1,reference = y, positive = "mod" )
## Confusion Matrix and Statistics
##
## Reference
## Prediction mod notmod
## mod 654 1202
## notmod 65 2208
##
## Accuracy : 0.6931
## 95% CI : (0.6788, 0.7072)
## No Information Rate : 0.8259
## P-Value [Acc > NIR] : 1
##
## Kappa : 0.3431
##
## Mcnemar's Test P-Value : <2e-16
##
## Sensitivity : 0.9096
## Specificity : 0.6475
## Pos Pred Value : 0.3524
## Neg Pred Value : 0.9714
## Prevalence : 0.1741
## Detection Rate : 0.1584
## Detection Prevalence : 0.4495
## Balanced Accuracy : 0.7786
##
## 'Positive' Class : mod
##
predt1<-predict(bt1, newdata=xtest)
confusionMatrix(data = predt1,yt, positive = "mod" )
## Confusion Matrix and Statistics
##
## Reference
## Prediction mod notmod
## mod 132 360
## notmod 85 455
##
## Accuracy : 0.5688
## 95% CI : (0.5379, 0.5993)
## No Information Rate : 0.7897
## P-Value [Acc > NIR] : 1
##
## Kappa : 0.1137
##
## Mcnemar's Test P-Value : <2e-16
##
## Sensitivity : 0.6083
## Specificity : 0.5583
## Pos Pred Value : 0.2683
## Neg Pred Value : 0.8426
## Prevalence : 0.2103
## Detection Rate : 0.1279
## Detection Prevalence : 0.4767
## Balanced Accuracy : 0.5833
##
## 'Positive' Class : mod
##
library(rpart)
rf1<-caret::train(y=y, x=x,
data=dtrain,
metric="ROC",
method ="rf",
tuneLength=20, #try 20 random values of the tuning parameters
trControl=fitctrl,
preProcess=c("center", "scale"))
rf1
## Random Forest
##
## 4129 samples
## 19 predictor
## 2 classes: 'mod', 'notmod'
##
## Pre-processing: centered (19), scaled (19)
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 3716, 3716, 3716, 3716, 3716, 3717, ...
## Addtional sampling using down-sampling prior to pre-processing
##
## Resampling results across tuning parameters:
##
## mtry ROC Sens Spec
## 1 0.6743208 0.5313576 0.7149560
## 2 0.6803158 0.5912559 0.6744868
## 4 0.6665729 0.6260172 0.6258065
## 5 0.6596331 0.6259977 0.6067449
## 6 0.6433240 0.5981808 0.5879765
## 7 0.6491962 0.6218310 0.5976540
## 11 0.6321789 0.6300861 0.5656891
## 12 0.6262939 0.6357394 0.5674487
## 13 0.6270354 0.6092332 0.5651026
## 14 0.6376019 0.6301056 0.5777126
## 15 0.6182209 0.6301056 0.5501466
## 18 0.6354484 0.6370892 0.5633431
##
## ROC was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 2.
##Accuracy on training set
predrf1<-predict(rf1, newdata=x)
confusionMatrix(data = predrf1,y, positive = "mod" )
## Confusion Matrix and Statistics
##
## Reference
## Prediction mod notmod
## mod 474 1046
## notmod 245 2364
##
## Accuracy : 0.6873
## 95% CI : (0.6729, 0.7015)
## No Information Rate : 0.8259
## P-Value [Acc > NIR] : 1
##
## Kappa : 0.2449
##
## Mcnemar's Test P-Value : <2e-16
##
## Sensitivity : 0.6592
## Specificity : 0.6933
## Pos Pred Value : 0.3118
## Neg Pred Value : 0.9061
## Prevalence : 0.1741
## Detection Rate : 0.1148
## Detection Prevalence : 0.3681
## Balanced Accuracy : 0.6763
##
## 'Positive' Class : mod
##
predgl1<-predict(rf1, newdata=xtest)
confusionMatrix(data = predgl1,yt, positive = "mod" )
## Confusion Matrix and Statistics
##
## Reference
## Prediction mod notmod
## mod 126 284
## notmod 91 531
##
## Accuracy : 0.6366
## 95% CI : (0.6064, 0.666)
## No Information Rate : 0.7897
## P-Value [Acc > NIR] : 1
##
## Kappa : 0.1751
##
## Mcnemar's Test P-Value : <2e-16
##
## Sensitivity : 0.5806
## Specificity : 0.6515
## Pos Pred Value : 0.3073
## Neg Pred Value : 0.8537
## Prevalence : 0.2103
## Detection Rate : 0.1221
## Detection Prevalence : 0.3973
## Balanced Accuracy : 0.6161
##
## 'Positive' Class : mod
##
We see that by down sampling the more common level of the outcome, we end up with much more balanced accuracy in terms of specificity and sensitivity.
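For reference, the sampling = "down" option in trainControl() does within each resample what caret::downSample() does to a whole data set: it randomly discards majority-class cases until the classes are balanced. A small sketch on the training data from above; the balanced copy down_train is for illustration only and is not used in the models.

```r
library(caret)

# Randomly drop majority-class rows until the two classes have equal counts
set.seed(1115)
down_train <- downSample(x = x, y = y, yname = "modcontra")
table(down_train$modcontra)   # equal numbers of "mod" and "notmod"
```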
You can see that the best-fitting model is much more complicated than the previous one. In the classification tree plot shown earlier, each node box displays the predicted class, the probability of each class at that node (i.e., the probability of the class conditional on the node), and the percentage of observations at that node.
The ROC curve can be shown for the model:
library(pROC)
## Type 'citation("pROC")' for a citation.
##
## Attaching package: 'pROC'
## The following objects are masked from 'package:stats':
##
## cov, smooth, var
# Select a parameter setting
mycp<-rf1$pred$mtry==rf1$bestTune$mtry
selectedIndices <- mycp==T
# Plot:
plot.roc(rf1$pred$obs[selectedIndices], rf1$pred$mod[selectedIndices], grid=T)
## Setting levels: control = mod, case = notmod
## Setting direction: controls > cases
#Value of ROC and AUC
roc(rf1$pred$obs[selectedIndices], rf1$pred$mod[selectedIndices])
## Setting levels: control = mod, case = notmod
## Setting direction: controls > cases
##
## Call:
## roc.default(response = rf1$pred$obs[selectedIndices], predictor = rf1$pred$mod[selectedIndices])
##
## Data: rf1$pred$mod[selectedIndices] in 719 controls (rf1$pred$obs[selectedIndices] mod) > 3410 cases (rf1$pred$obs[selectedIndices] notmod).
## Area under the curve: 0.6794
auc(rf1$pred$obs[selectedIndices], rf1$pred$mod[selectedIndices])
## Setting levels: control = mod, case = notmod
## Setting direction: controls > cases
## Area under the curve: 0.6794