I would suggest you read Section 5.1 of *An Introduction to Statistical Learning* for a full treatment of this topic.
In classification methods, we are typically interested in using some observed characteristics of a case to predict a binary categorical outcome. This can be extended to multi-category outcomes, but the largest number of applications involve a binary (1/0) outcome.
In these examples, we will use the Demographic and Health Survey (DHS) Model Data. These data are structured like a real DHS survey and are publicly available for practicing with DHS data sets, but they do not represent a real country.
In this example, our outcome is contraceptive choice (modern method vs. other/none).
library(haven)
# read the DHS model data (Stata format) directly from GitHub
dat <- url("https://github.com/coreysparks/data/blob/master/ZZIR62FL.DTA?raw=true")
model.dat <- read_dta(dat)
Here we recode some of our variables and limit our data to those women who are not currently pregnant and who are sexually active.
library(dplyr)
##
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
model.dat2 <- model.dat %>%
  mutate(region = v024,
         modcontra = as.factor(ifelse(v364 == 1, "Modcontra", "NoModContra")), # modern contraception use
         age = cut(v012, breaks = 5),          # age, in 5 groups
         livchildren = v218,                   # number of living children
         educ = v106,                          # education level
         currpreg = v213,                      # currently pregnant
         knowmodern = ifelse(v301 == 3, 1, 0), # knows a modern method
         age2 = v012^2,                        # age squared
         rural = ifelse(v025 == 2, 1, 0),      # rural residence
         wantmore = ifelse(v605 %in% c(1, 2), 1, 0)) %>% # wants more children
  filter(currpreg == 0, v536 > 0) %>% # not pregnant, sexually active
  dplyr::select(caseid, region, modcontra, age, age2, livchildren, educ, knowmodern, rural, wantmore)
knitr::kable(head(model.dat2))
| caseid | region | modcontra | age | age2 | livchildren | educ | knowmodern | rural | wantmore |
|---|---|---|---|---|---|---|---|---|---|
| 1 1 2 | 2 | NoModContra | (28.6,35.4] | 900 | 4 | 0 | 1 | 1 | 1 |
| 1 4 2 | 2 | NoModContra | (35.4,42.2] | 1764 | 2 | 0 | 1 | 1 | 0 |
| 1 4 3 | 2 | NoModContra | (21.8,28.6] | 625 | 3 | 1 | 1 | 1 | 0 |
| 1 5 1 | 2 | NoModContra | (21.8,28.6] | 625 | 2 | 2 | 1 | 1 | 1 |
| 1 6 2 | 2 | NoModContra | (35.4,42.2] | 1369 | 2 | 0 | 1 | 1 | 1 |
| 1 6 3 | 2 | NoModContra | (15,21.8] | 289 | 0 | 2 | 0 | 1 | 1 |
The term cross-validation refers to fitting a model on a subset of data and then testing it on another subset of the data. Typically this process is repeated several times.
The simplest way of doing this is to split the data once into a training set and a test set: fit the model on the training set, then predict the outcome for the held-out test set. This is called hold-out cross-validation.
K-fold cross-validation extends this idea by partitioning the data into k groups, or "folds": each fold is held out in turn, the model is fit on the remaining k-1 folds, and predictions are made for the held-out fold.
A further method is leave-one-out, or LOO, cross-validation. This is the special case of k-fold cross-validation where k equals the number of observations, so each observation is held out and predicted once.
By doing this, we can see how sensitive the model's accuracy is to particular subsets of the data, and because accuracy is measured repeatedly, we can use it to assess things such as model tuning parameters.
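To make the mechanics concrete, here is a minimal hand-rolled sketch of k-fold cross-validation using caret's createFolds(); the logistic regression formula is just an illustrative placeholder, and in practice caret::train() (used below) handles this loop for us.
library(caret)
# a minimal k-fold CV sketch; assumes model.dat2 from above
set.seed(1115)
folds <- createFolds(model.dat2$modcontra, k = 10) # list of 10 held-out index sets

cv_acc <- sapply(folds, function(test_idx){
  train_fold <- model.dat2[-test_idx, ] # fit on the other 9 folds
  test_fold  <- model.dat2[test_idx, ]  # predict the held-out fold
  m <- glm(modcontra ~ age + livchildren + rural + wantmore,
           data = train_fold, family = binomial)
  p <- predict(m, newdata = test_fold, type = "response")
  # glm models Pr(second factor level), here "NoModContra"
  pred <- ifelse(p > .5, "NoModContra", "Modcontra")
  mean(pred == test_fold$modcontra) # fold-specific accuracy
})
mean(cv_acc) # cross-validated accuracy estimate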
If you remember from last time, the recursive partitioning (rpart) analysis depended on choosing a good value for the complexity parameter, cp. In a cross-validation analysis, we can use the various resamplings of the data to examine how sensitive the model's accuracy is to alternative values of this parameter.
This evaluation can be done either systematically, along a grid of values, or using a random search.
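For example, a systematic grid search over cp can be specified by passing a tuneGrid to caret::train(). This sketch assumes the dtrain and fitctrl objects constructed later in this example, and the grid values are arbitrary illustrative choices; supplying a tuneGrid overrides the random search set in the control object.
library(caret)
cp_grid <- expand.grid(cp = seq(0, 0.02, by = 0.002)) # candidate cp values
rp_grid <- train(modcontra ~ ., data = dtrain,
                 method = "rpart",
                 metric = "ROC",
                 tuneGrid = cp_grid, # systematic grid instead of search = "random"
                 trControl = fitctrl)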
We talked last time about using model accuracy as a measure of overall fit, calculated from the observed and predicted values of our outcome. For classification models, another commonly used metric of predictive power is the Receiver Operating Characteristic (ROC) curve. This plots the true positive rate against the false positive rate across classification thresholds, and it is often accompanied by the area under the curve (AUC) measure, which summarizes the separability of the classes. Together they tell you how capable the model is of distinguishing between the classes in the data. Higher values indicate better discrimination; the AUC ranges from 0 to 1, with 0.5 corresponding to a classifier that performs no better than chance.
A nice description of these measures is found here.
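As a toy illustration (with made-up outcomes and probabilities, not the survey data), an ROC curve and its AUC can be computed with the pROC package:
library(pROC)
# hypothetical 0/1 outcomes and predicted probabilities
obs  <- c(0, 0, 0, 1, 0, 1, 1, 0, 1, 1)
phat <- c(.1, .2, .3, .35, .4, .6, .65, .7, .8, .9)
r <- roc(obs, phat) # build the ROC curve
auc(r)              # area under the curve
plot(r)             # plot sensitivity vs. specificity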
As we saw in the first working group example, the classification tree is another common technique used in classification problems. Regression or classification trees attempt to find optimal splits in the data so that the best classification of observations can be found.
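As a quick refresher, a single classification tree can be fit directly with rpart; the cp value here is an arbitrary illustration, since below we instead let caret tune it for us.
library(rpart)
# a minimal classification tree on the recoded data; cp = 0.005 is arbitrary
tree0 <- rpart(modcontra ~ age + livchildren + educ + rural + wantmore,
               data = model.dat2,
               method = "class", # classification tree
               cp = 0.005)       # complexity parameter controlling tree size
printcp(tree0) # error at each cp value rpart considered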
If we have a mixture of factor variables and continuous predictors in our analysis, it is best to set up the design matrix for our models before we run them. Many methods within caret won't handle factor variables correctly unless we set up the dummy-variable representations first.
# design matrix of dummy variables; the -1 removes the intercept so all region levels get a column
datmat <- model.matrix(~factor(region)-1+factor(age)+livchildren+rural+wantmore+factor(educ)-1, data = model.dat2)
datmat <- data.frame(datmat)
datmat$modcontra <- model.dat2$modcontra # add the outcome back
head(datmat)
## factor.region.1 factor.region.2 factor.region.3 factor.region.4
## 1 0 1 0 0
## 2 0 1 0 0
## 3 0 1 0 0
## 4 0 1 0 0
## 5 0 1 0 0
## 6 0 1 0 0
## factor.age..21.8.28.6. factor.age..28.6.35.4. factor.age..35.4.42.2.
## 1 0 1 0
## 2 0 0 1
## 3 1 0 0
## 4 1 0 0
## 5 0 0 1
## 6 0 0 0
## factor.age..42.2.49. livchildren rural wantmore factor.educ.1
## 1 0 4 1 1 0
## 2 0 2 1 0 0
## 3 0 3 1 0 1
## 4 0 2 1 1 0
## 5 0 2 1 1 0
## 6 0 0 1 1 0
## factor.educ.2 factor.educ.3 modcontra
## 1 0 0 NoModContra
## 2 0 0 NoModContra
## 3 0 0 NoModContra
## 4 1 0 NoModContra
## 5 0 0 NoModContra
## 6 1 0 NoModContra
summary(datmat)
## factor.region.1 factor.region.2 factor.region.3 factor.region.4
## Min. :0.0000 Min. :0.0000 Min. :0.0000 Min. :0.000
## 1st Qu.:0.0000 1st Qu.:0.0000 1st Qu.:0.0000 1st Qu.:0.000
## Median :0.0000 Median :0.0000 Median :0.0000 Median :0.000
## Mean :0.3855 Mean :0.2586 Mean :0.1669 Mean :0.189
## 3rd Qu.:1.0000 3rd Qu.:1.0000 3rd Qu.:0.0000 3rd Qu.:0.000
## Max. :1.0000 Max. :1.0000 Max. :1.0000 Max. :1.000
## factor.age..21.8.28.6. factor.age..28.6.35.4. factor.age..35.4.42.2.
## Min. :0.0000 Min. :0.0000 Min. :0.0000
## 1st Qu.:0.0000 1st Qu.:0.0000 1st Qu.:0.0000
## Median :0.0000 Median :0.0000 Median :0.0000
## Mean :0.2448 Mean :0.2168 Mean :0.1605
## 3rd Qu.:0.0000 3rd Qu.:0.0000 3rd Qu.:0.0000
## Max. :1.0000 Max. :1.0000 Max. :1.0000
## factor.age..42.2.49. livchildren rural wantmore
## Min. :0.0000 Min. : 0.000 Min. :0.0000 Min. :0.0000
## 1st Qu.:0.0000 1st Qu.: 1.000 1st Qu.:0.0000 1st Qu.:0.0000
## Median :0.0000 Median : 2.000 Median :1.0000 Median :1.0000
## Mean :0.1255 Mean : 2.555 Mean :0.5934 Mean :0.5209
## 3rd Qu.:0.0000 3rd Qu.: 4.000 3rd Qu.:1.0000 3rd Qu.:1.0000
## Max. :1.0000 Max. :11.000 Max. :1.0000 Max. :1.0000
## factor.educ.1 factor.educ.2 factor.educ.3 modcontra
## Min. :0.0000 Min. :0.0000 Min. :0.00000 Modcontra :1761
## 1st Qu.:0.0000 1st Qu.:0.0000 1st Qu.:0.00000 NoModContra:5044
## Median :0.0000 Median :0.0000 Median :0.00000
## Mean :0.1187 Mean :0.2575 Mean :0.03542
## 3rd Qu.:0.0000 3rd Qu.:1.0000 3rd Qu.:0.00000
## Max. :1.0000 Max. :1.0000 Max. :1.00000
We split the data into training and test sets, using an 80% training fraction.
library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
set.seed(1115)
train <- createDataPartition(y = datmat$modcontra, p = .80, list = FALSE) # stratified 80% sample
dtrain <- datmat[train, c(-1, -5, -9)] # drop region 1, the (21.8,28.6] age dummy, and livchildren
dtest  <- datmat[-train, c(-1, -5, -9)]
table(dtrain$modcontra)
##
## Modcontra NoModContra
## 1409 4036
prop.table(table(dtrain$modcontra))
##
## Modcontra NoModContra
## 0.2587695 0.7412305
summary(dtrain)
## factor.region.2 factor.region.3 factor.region.4 factor.age..28.6.35.4.
## Min. :0.0000 Min. :0.000 Min. :0.0000 Min. :0.0000
## 1st Qu.:0.0000 1st Qu.:0.000 1st Qu.:0.0000 1st Qu.:0.0000
## Median :0.0000 Median :0.000 Median :0.0000 Median :0.0000
## Mean :0.2573 Mean :0.168 Mean :0.1914 Mean :0.2101
## 3rd Qu.:1.0000 3rd Qu.:0.000 3rd Qu.:0.0000 3rd Qu.:0.0000
## Max. :1.0000 Max. :1.000 Max. :1.0000 Max. :1.0000
## factor.age..35.4.42.2. factor.age..42.2.49. rural
## Min. :0.0000 Min. :0.0000 Min. :0.0000
## 1st Qu.:0.0000 1st Qu.:0.0000 1st Qu.:0.0000
## Median :0.0000 Median :0.0000 Median :1.0000
## Mean :0.1638 Mean :0.1225 Mean :0.5927
## 3rd Qu.:0.0000 3rd Qu.:0.0000 3rd Qu.:1.0000
## Max. :1.0000 Max. :1.0000 Max. :1.0000
## wantmore factor.educ.1 factor.educ.2 factor.educ.3
## Min. :0.000 Min. :0.0000 Min. :0.0000 Min. :0.00000
## 1st Qu.:0.000 1st Qu.:0.0000 1st Qu.:0.0000 1st Qu.:0.00000
## Median :1.000 Median :0.0000 Median :0.0000 Median :0.00000
## Mean :0.521 Mean :0.1188 Mean :0.2597 Mean :0.03581
## 3rd Qu.:1.000 3rd Qu.:0.0000 3rd Qu.:1.0000 3rd Qu.:0.00000
## Max. :1.000 Max. :1.0000 Max. :1.0000 Max. :1.00000
## modcontra
## Modcontra :1409
## NoModContra:4036
##
##
##
##
summary(dtest)
## factor.region.2 factor.region.3 factor.region.4 factor.age..28.6.35.4.
## Min. :0.000 Min. :0.0000 Min. :0.0000 Min. :0.0000
## 1st Qu.:0.000 1st Qu.:0.0000 1st Qu.:0.0000 1st Qu.:0.0000
## Median :0.000 Median :0.0000 Median :0.0000 Median :0.0000
## Mean :0.264 Mean :0.1625 Mean :0.1794 Mean :0.2434
## 3rd Qu.:1.000 3rd Qu.:0.0000 3rd Qu.:0.0000 3rd Qu.:0.0000
## Max. :1.000 Max. :1.0000 Max. :1.0000 Max. :1.0000
## factor.age..35.4.42.2. factor.age..42.2.49. rural
## Min. :0.0000 Min. :0.0000 Min. :0.0000
## 1st Qu.:0.0000 1st Qu.:0.0000 1st Qu.:0.0000
## Median :0.0000 Median :0.0000 Median :1.0000
## Mean :0.1471 Mean :0.1375 Mean :0.5963
## 3rd Qu.:0.0000 3rd Qu.:0.0000 3rd Qu.:1.0000
## Max. :1.0000 Max. :1.0000 Max. :1.0000
## wantmore factor.educ.1 factor.educ.2 factor.educ.3
## Min. :0.0000 Min. :0.0000 Min. :0.0000 Min. :0.00000
## 1st Qu.:0.0000 1st Qu.:0.0000 1st Qu.:0.0000 1st Qu.:0.00000
## Median :1.0000 Median :0.0000 Median :0.0000 Median :0.00000
## Mean :0.5206 Mean :0.1184 Mean :0.2485 Mean :0.03382
## 3rd Qu.:1.0000 3rd Qu.:0.0000 3rd Qu.:0.0000 3rd Qu.:0.00000
## Max. :1.0000 Max. :1.0000 Max. :1.0000 Max. :1.00000
## modcontra
## Modcontra : 352
## NoModContra:1008
##
##
##
##
To set up the training controls for a caret model, we typically have to specify the type of re-sampling method, the number of resamplings, and the number of repeats (if you're doing repeated sampling). Here we will do a 10-fold cross-validation; 10 is often recommended as a choice for k based on experimental sensitivity analysis.
The other things we specify are class probabilities, a random search over tuning parameters, down-sampling of the more common outcome class, and a two-class summary function, as in the call below:
fitctrl <- trainControl(method = "repeatedcv",   # repeated k-fold cross-validation
                        number = 10,             # k = 10 folds
                        repeats = 3,             # repeat the whole 10-fold process 3 times
                        classProbs = TRUE,       # save class probabilities, needed for ROC
                        search = "random",       # random search over tuning parameters
                        sampling = "down",       # down-sample the majority class within each resample
                        summaryFunction = twoClassSummary, # report ROC, sensitivity, specificity
                        savePredictions = "all") # keep the resampled predictions for ROC plots
Here we use caret to fit the rpart model, using the ROC as the tuning metric.
library(rpart)
rp1 <- caret::train(modcontra ~ .,
                    data = dtrain,
                    metric = "ROC",
                    method = "rpart",
                    tuneLength = 20, # try 20 random values of the tuning parameter, cp
                    trControl = fitctrl)
rp1
## CART
##
## 5445 samples
## 11 predictor
## 2 classes: 'Modcontra', 'NoModContra'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 3 times)
## Summary of sample sizes: 4901, 4900, 4900, 4901, 4900, 4901, ...
## Addtional sampling using down-sampling
##
## Resampling results across tuning parameters:
##
## cp ROC Sens Spec
## 0.0000000000 0.7118201 0.6656991 0.6573267
## 0.0001182872 0.7118201 0.6656991 0.6573267
## 0.0001419446 0.7118201 0.6656991 0.6573267
## 0.0002365744 0.7115241 0.6654627 0.6573277
## 0.0007097232 0.7078988 0.6671310 0.6574947
## 0.0012775018 0.7042976 0.6732945 0.6570073
## 0.0033120416 0.6940002 0.6654779 0.6620476
## 0.0035486160 0.6919881 0.6586221 0.6662556
## 0.0106458481 0.6818322 0.6709591 0.6462660
## 0.0124201561 0.6809661 0.6844343 0.6344422
##
## ROC was used to select the optimal model using the largest value.
## The final value used for the model was cp = 0.0001419446.
gl1 <- caret::train(modcontra ~ .,
                    data = dtrain,
                    metric = "ROC",
                    method = "glm", # binomial logistic regression by default for a two-class outcome
                    # glm has no tuning parameters, so no tuneLength is needed
                    trControl = fitctrl)
gl1
## Generalized Linear Model
##
## 5445 samples
## 11 predictor
## 2 classes: 'Modcontra', 'NoModContra'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 3 times)
## Summary of sample sizes: 4901, 4900, 4900, 4900, 4900, 4901, ...
## Addtional sampling using down-sampling
##
## Resampling results:
##
## ROC Sens Spec
## 0.7241083 0.6394596 0.705647
## Accuracy on the training set
predrp1<-predict(rp1, newdata=dtrain)
confusionMatrix(data = predrp1,dtrain$modcontra, positive = "Modcontra" )
## Confusion Matrix and Statistics
##
## Reference
## Prediction Modcontra NoModContra
## Modcontra 1002 1389
## NoModContra 407 2647
##
## Accuracy : 0.6702
## 95% CI : (0.6575, 0.6826)
## No Information Rate : 0.7412
## P-Value [Acc > NIR] : 1
##
## Kappa : 0.2991
##
## Mcnemar's Test P-Value : <2e-16
##
## Sensitivity : 0.7111
## Specificity : 0.6558
## Pos Pred Value : 0.4191
## Neg Pred Value : 0.8667
## Prevalence : 0.2588
## Detection Rate : 0.1840
## Detection Prevalence : 0.4391
## Balanced Accuracy : 0.6835
##
## 'Positive' Class : Modcontra
##
predgl1<-predict(gl1, newdata=dtrain)
confusionMatrix(data = predgl1,dtrain$modcontra, positive = "Modcontra" )
## Confusion Matrix and Statistics
##
## Reference
## Prediction Modcontra NoModContra
## Modcontra 902 1179
## NoModContra 507 2857
##
## Accuracy : 0.6904
## 95% CI : (0.6779, 0.7026)
## No Information Rate : 0.7412
## P-Value [Acc > NIR] : 1
##
## Kappa : 0.3013
##
## Mcnemar's Test P-Value : <2e-16
##
## Sensitivity : 0.6402
## Specificity : 0.7079
## Pos Pred Value : 0.4334
## Neg Pred Value : 0.8493
## Prevalence : 0.2588
## Detection Rate : 0.1657
## Detection Prevalence : 0.3822
## Balanced Accuracy : 0.6740
##
## 'Positive' Class : Modcontra
##
We see that by down-sampling the more common level of the outcome, we end up with a much better balance between sensitivity and specificity.
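To see what the sampling = "down" option is doing inside each resample, caret's downSample() function can be applied directly; this sketch just illustrates the rebalancing and is not part of the fits above.
library(caret)
# randomly drop majority-class rows until the classes are balanced
set.seed(1115)
balanced <- downSample(x = dtrain[, names(dtrain) != "modcontra"],
                       y = dtrain$modcontra,
                       yname = "modcontra")
table(balanced$modcontra) # both classes now have 1409 rows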
We can visualize the resulting “best fitting” model like we did before:
library(rpart.plot)
# alternative: rpart.plot(rp1$finalModel, box.palette = "GnBu", shadow.col = "gray", nn = TRUE)
prp(rp1$finalModel, type = 4, extra = 4,
    main = "Classification tree for using modern contraception")
You see that the best-fitting model is much more complicated than the previous one. Each node box displays the classification, the probability of each class at that node (i.e., the probability of the class conditional on the node), and the percentage of observations used at that node.
The ROC curve can be shown for the model:
library(pROC)
## Type 'citation("pROC")' for a citation.
##
## Attaching package: 'pROC'
## The following objects are masked from 'package:stats':
##
## cov, smooth, var
# Select the predictions made at the best tuning-parameter value
selectedIndices <- rp1$pred$cp == rp1$bestTune$cp
# Plot:
plot.roc(rp1$pred$obs[selectedIndices], rp1$pred$Modcontra[selectedIndices], grid=T)
## Setting levels: control = Modcontra, case = NoModContra
## Setting direction: controls > cases
plot.roc(gl1$pred$obs, gl1$pred$Modcontra, add=T, col="blue")
## Setting levels: control = Modcontra, case = NoModContra
## Setting direction: controls > cases
# Values of the ROC and AUC
roc(rp1$pred$obs[selectedIndices], rp1$pred$Modcontra[selectedIndices])
## Setting levels: control = Modcontra, case = NoModContra
## Setting direction: controls > cases
##
## Call:
## roc.default(response = rp1$pred$obs[selectedIndices], predictor = rp1$pred$Modcontra[selectedIndices])
##
## Data: rp1$pred$Modcontra[selectedIndices] in 4227 controls (rp1$pred$obs[selectedIndices] Modcontra) > 12108 cases (rp1$pred$obs[selectedIndices] NoModContra).
## Area under the curve: 0.7137
roc(gl1$pred$obs, gl1$pred$Modcontra)
## Setting levels: control = Modcontra, case = NoModContra
## Setting direction: controls > cases
##
## Call:
## roc.default(response = gl1$pred$obs, predictor = gl1$pred$Modcontra)
##
## Data: gl1$pred$Modcontra in 4227 controls (gl1$pred$obs Modcontra) > 12108 cases (gl1$pred$obs NoModContra).
## Area under the curve: 0.7232
auc(rp1$pred$obs[selectedIndices], rp1$pred$Modcontra[selectedIndices])
## Setting levels: control = Modcontra, case = NoModContra
## Setting direction: controls > cases
## Area under the curve: 0.7137
auc(gl1$pred$obs, gl1$pred$Modcontra)
## Setting levels: control = Modcontra, case = NoModContra
## Setting direction: controls > cases
## Area under the curve: 0.7232
roc.test(roc(gl1$pred$obs, gl1$pred$Modcontra),roc(rp1$pred$obs[selectedIndices], rp1$pred$Modcontra[selectedIndices]) )
## Setting levels: control = Modcontra, case = NoModContra
## Setting direction: controls > cases
## Setting levels: control = Modcontra, case = NoModContra
## Setting direction: controls > cases
##
## DeLong's test for two ROC curves
##
## data: roc(gl1$pred$obs, gl1$pred$Modcontra) and roc(rp1$pred$obs[selectedIndices], rp1$pred$Modcontra[selectedIndices])
## D = -1.4705, df = 32651, p-value = 0.1414
## alternative hypothesis: true difference in AUC is not equal to 0
## sample estimates:
## AUC of roc1 AUC of roc2
## 0.7231722 0.7136719
## Accuracy on the test set
predrp1 <- predict(rp1, newdata = dtest)
confusionMatrix(data = predrp1,dtest$modcontra, positive = "Modcontra" )
## Confusion Matrix and Statistics
##
## Reference
## Prediction Modcontra NoModContra
## Modcontra 238 351
## NoModContra 114 657
##
## Accuracy : 0.6581
## 95% CI : (0.6322, 0.6833)
## No Information Rate : 0.7412
## P-Value [Acc > NIR] : 1
##
## Kappa : 0.269
##
## Mcnemar's Test P-Value : <2e-16
##
## Sensitivity : 0.6761
## Specificity : 0.6518
## Pos Pred Value : 0.4041
## Neg Pred Value : 0.8521
## Prevalence : 0.2588
## Detection Rate : 0.1750
## Detection Prevalence : 0.4331
## Balanced Accuracy : 0.6640
##
## 'Positive' Class : Modcontra
##
predgl1<-predict(gl1, newdata=dtest)
confusionMatrix(data = predgl1,dtest$modcontra, positive = "Modcontra" )
## Confusion Matrix and Statistics
##
## Reference
## Prediction Modcontra NoModContra
## Modcontra 214 288
## NoModContra 138 720
##
## Accuracy : 0.6868
## 95% CI : (0.6614, 0.7114)
## No Information Rate : 0.7412
## P-Value [Acc > NIR] : 1
##
## Kappa : 0.283
##
## Mcnemar's Test P-Value : 5.234e-13
##
## Sensitivity : 0.6080
## Specificity : 0.7143
## Pos Pred Value : 0.4263
## Neg Pred Value : 0.8392
## Prevalence : 0.2588
## Detection Rate : 0.1574
## Detection Prevalence : 0.3691
## Balanced Accuracy : 0.6611
##
## 'Positive' Class : Modcontra
##
By doing cross-validation, we can do a few things: tune model parameters (such as the complexity parameter of the classification tree), obtain a more honest estimate of how a model will perform on data it has not seen, and compare the performance of competing models, as we did here with the classification tree and the logistic regression.