This case study uses the Wage dataset in the ISLR package. There are 3,000 observations and 11 variables in the data. The target variable is wage. The data dictionary is shown below.
| Variable | Description | Characteristics |
|---|---|---|
| year | Calendar year that wage information was recorded | Integer from 2003 to 2009 |
| age | Age of worker | Integer from 18 to 80 |
| maritl | Marital status of worker | Factor with 5 levels: 1. Never Married 2. Married 3. Widowed 4. Divorced 5. Separated |
| race | Race of worker | Factor with 4 levels: 1. White 2. Black 3. Asian 4. Other |
| education | Education level of worker | Factor with 5 levels: 1. < HS Grad 2. HS Grad 3. Some College 4. College Grad 5. Advanced Degree |
| region | Region of the country | Factor with 9 levels, but only "2. Middle Atlantic" contains observations |
| jobclass | Job class of worker | Factor with 2 levels: 1. Industrial 2. Information |
| health | Health level of worker | Factor with 2 levels: 1. <=Good 2. >=Very Good |
| health_ins | Whether worker has health insurance | Factor with 2 levels: 1. Yes 2. No |
| logwage | Log of worker’s raw wage | Numeric from 3 to 5.763 |
| wage | Worker’s raw wage (in $1,000s) | Numeric from 20.09 to 318.34 |
# CHUNK 1
#### Load the data
library(ISLR)
data("Wage")
#### Summarize the data
summary(Wage)
## year age maritl race
## Min. :2003 Min. :18.00 1. Never Married: 648 1. White:2480
## 1st Qu.:2004 1st Qu.:33.75 2. Married :2074 2. Black: 293
## Median :2006 Median :42.00 3. Widowed : 19 3. Asian: 190
## Mean :2006 Mean :42.41 4. Divorced : 204 4. Other: 37
## 3rd Qu.:2008 3rd Qu.:51.00 5. Separated : 55
## Max. :2009 Max. :80.00
##
## education region jobclass
## 1. < HS Grad :268 2. Middle Atlantic :3000 1. Industrial :1544
## 2. HS Grad :971 1. New England : 0 2. Information:1456
## 3. Some College :650 3. East North Central: 0
## 4. College Grad :685 4. West North Central: 0
## 5. Advanced Degree:426 5. South Atlantic : 0
## 6. East South Central: 0
## (Other) : 0
## health health_ins logwage wage
## 1. <=Good : 858 1. Yes:2083 Min. :3.000 Min. : 20.09
## 2. >=Very Good:2142 2. No : 917 1st Qu.:4.447 1st Qu.: 85.38
## Median :4.653 Median :104.92
## Mean :4.654 Mean :111.70
## 3rd Qu.:4.857 3rd Qu.:128.68
## Max. :5.763 Max. :318.34
##
# CHUNK 2
#### Remove region
Wage$region <- NULL
In this case study, we will practice building classification decision trees.
# CHUNK 3
#### Construct classification trees to predict the probability of a worker earning $100,000 or more.
Wage$wage_flag <- ifelse(Wage$wage >= 100, 1, 0)
#### View summary
summary(Wage$wage_flag)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 0.0000 0.0000 1.0000 0.5523 1.0000 1.0000
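A frequency table gives the class counts of the new binary target directly (a quick sketch using the data prepared above):
#### Sketch: class counts of wage_flag
table(Wage$wage_flag)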
We will end this task by removing the two wage variables, which will not be used later.
# CHUNK 4
#### Remove the two wage variables: wage and logwage
Wage$wage <- NULL
Wage$logwage <- NULL
In this preparatory task, we will identify key variables associated with high earners and process some of the variables to prepare for later analysis.
# CHUNK 5
library(ggplot2)
#### Save the predictor names in a vector.
vars <- colnames(Wage)[-9] # exclude wage_flag
#### Draw a filled bar chart for each predictor using a for loop.
for (i in vars) {
  plot <- ggplot(Wage, aes(x = .data[[i]], fill = factor(wage_flag))) +
    geom_bar(position = "fill") +
    labs(x = i, y = "Proportion of High Earners") +
    theme(axis.text.x = element_text(angle = 90, hjust = 1))
  print(plot)
}
# CHUNK 6
#### Take out the third, fourth, and fifth levels of maritl and change them to "3. Other"
levels(Wage$maritl)[3:5] <- "3. Other"
#### View the new frequency table of maritl
table(Wage$maritl)
##
## 1. Never Married 2. Married 3. Other
## 648 2074 278
maritl: There are only 19 widowed workers and 55 separated workers. The proportions of high earners for these workers are close to the proportion for divorced workers. Because "3. Widowed", "4. Divorced", and "5. Separated" have similar meanings (they all describe what happens after getting married) and similar relationships with wage_flag, we combine them into a single, more populous level, "3. Other", to improve the robustness of the models to be constructed.
# CHUNK 7
#### Split the data into training (70%) and test (30%) sets
library(lattice)
library(caret)
set.seed(2021)
partition <- createDataPartition(y = as.factor(Wage$wage_flag), p = 0.7, list = FALSE)
data.train <- Wage[partition, ]
data.test <- Wage[-partition, ]
The control parameters are deliberately set so that a sufficiently complex tree can be grown.
# CHUNK 8
library(rpart)
## Warning: package 'rpart' was built under R version 4.2.2
library(rpart.plot)
## Warning: package 'rpart.plot' was built under R version 4.2.2
set.seed(60)
########################################################################################################################################
# method = "class" ensures that the target is treated as a categorical variable ######################################################
### maxdepth: the maximum depth of the tree ############################################################################################
########################################################################################################################################
tree1 <- rpart(wage_flag ~ ., data = data.train, method = "class", control = rpart.control(minbucket = 5, cp = 0.0005, maxdepth = 7), parms = list(split = "gini"))
#### Print output for the tree
########################################################################################################################################
### RSS is now replaced by the number of misclassifications in each node, as shown in the loss column ##################################
########################################################################################################################################
tree1
## n= 2101
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 2101 941 1 (0.44788196 0.55211804)
## 2) education=1. < HS Grad,2. HS Grad 853 295 0 (0.65416178 0.34583822)
## 4) health_ins=2. No 334 62 0 (0.81437126 0.18562874) *
## 5) health_ins=1. Yes 519 233 0 (0.55105973 0.44894027)
## 10) maritl=1. Never Married 89 16 0 (0.82022472 0.17977528)
## 20) health=2. >=Very Good 59 5 0 (0.91525424 0.08474576) *
## 21) health=1. <=Good 30 11 0 (0.63333333 0.36666667)
## 42) age>=51.5 9 1 0 (0.88888889 0.11111111) *
## 43) age< 51.5 21 10 0 (0.52380952 0.47619048)
## 86) year< 2005.5 11 4 0 (0.63636364 0.36363636) *
## 87) year>=2005.5 10 4 1 (0.40000000 0.60000000) *
## 11) maritl=2. Married,3. Other 430 213 1 (0.49534884 0.50465116)
## 22) age>=65.5 17 2 0 (0.88235294 0.11764706) *
## 23) age< 65.5 413 198 1 (0.47941889 0.52058111)
## 46) education=1. < HS Grad 67 26 0 (0.61194030 0.38805970)
## 92) age< 34.5 10 0 0 (1.00000000 0.00000000) *
## 93) age>=34.5 57 26 0 (0.54385965 0.45614035)
## 186) year< 2004.5 18 5 0 (0.72222222 0.27777778) *
## 187) year>=2004.5 39 18 1 (0.46153846 0.53846154) *
## 47) education=2. HS Grad 346 157 1 (0.45375723 0.54624277)
## 94) maritl=3. Other 45 19 0 (0.57777778 0.42222222)
## 188) age< 53.5 36 13 0 (0.63888889 0.36111111) *
## 189) age>=53.5 9 3 1 (0.33333333 0.66666667) *
## 95) maritl=2. Married 301 131 1 (0.43521595 0.56478405) *
## 3) education=3. Some College,4. College Grad,5. Advanced Degree 1248 383 1 (0.30689103 0.69310897)
## 6) health_ins=2. No 302 137 0 (0.54635762 0.45364238)
## 12) maritl=1. Never Married 76 10 0 (0.86842105 0.13157895) *
## 13) maritl=2. Married,3. Other 226 99 1 (0.43805310 0.56194690)
## 26) education=3. Some College,4. College Grad 179 85 1 (0.47486034 0.52513966)
## 52) age< 28.5 8 1 0 (0.87500000 0.12500000) *
## 53) age>=28.5 171 78 1 (0.45614035 0.54385965)
## 106) year< 2005.5 85 39 0 (0.54117647 0.45882353)
## 212) jobclass=2. Information 47 18 0 (0.61702128 0.38297872) *
## 213) jobclass=1. Industrial 38 17 1 (0.44736842 0.55263158) *
## 107) year>=2005.5 86 32 1 (0.37209302 0.62790698)
## 214) race=2. Black,3. Asian 20 8 0 (0.60000000 0.40000000) *
## 215) race=1. White 66 20 1 (0.30303030 0.69696970) *
## 27) education=5. Advanced Degree 47 14 1 (0.29787234 0.70212766)
## 54) age>=33.5 41 14 1 (0.34146341 0.65853659)
## 108) age< 46.5 14 6 0 (0.57142857 0.42857143) *
## 109) age>=46.5 27 6 1 (0.22222222 0.77777778)
## 218) race=2. Black,3. Asian 7 2 0 (0.71428571 0.28571429) *
## 219) race=1. White 20 1 1 (0.05000000 0.95000000) *
## 55) age< 33.5 6 0 1 (0.00000000 1.00000000) *
## 7) health_ins=1. Yes 946 218 1 (0.23044397 0.76955603)
## 14) age< 31.5 149 70 1 (0.46979866 0.53020134)
## 28) education=3. Some College 65 23 0 (0.64615385 0.35384615)
## 56) age< 25.5 17 2 0 (0.88235294 0.11764706) *
## 57) age>=25.5 48 21 0 (0.56250000 0.43750000)
## 114) age>=30.5 8 1 0 (0.87500000 0.12500000) *
## 115) age< 30.5 40 20 0 (0.50000000 0.50000000)
## 230) jobclass=1. Industrial 23 9 0 (0.60869565 0.39130435) *
## 231) jobclass=2. Information 17 6 1 (0.35294118 0.64705882) *
## 29) education=4. College Grad,5. Advanced Degree 84 28 1 (0.33333333 0.66666667)
## 58) year< 2004.5 23 10 0 (0.56521739 0.43478261)
## 116) age< 29.5 16 5 0 (0.68750000 0.31250000) *
## 117) age>=29.5 7 2 1 (0.28571429 0.71428571) *
## 59) year>=2004.5 61 15 1 (0.24590164 0.75409836) *
## 15) age>=31.5 797 148 1 (0.18569636 0.81430364)
## 30) education=3. Some College 260 77 1 (0.29615385 0.70384615)
## 60) year< 2003.5 41 20 1 (0.48780488 0.51219512)
## 120) jobclass=2. Information 18 5 0 (0.72222222 0.27777778) *
## 121) jobclass=1. Industrial 23 7 1 (0.30434783 0.69565217)
## 242) age< 38 6 2 0 (0.66666667 0.33333333) *
## 243) age>=38 17 3 1 (0.17647059 0.82352941) *
## 61) year>=2003.5 219 57 1 (0.26027397 0.73972603) *
## 31) education=4. College Grad,5. Advanced Degree 537 71 1 (0.13221601 0.86778399)
## 62) maritl=1. Never Married,3. Other 114 29 1 (0.25438596 0.74561404)
## 124) year< 2004.5 28 12 1 (0.42857143 0.57142857)
## 248) age>=52 6 2 0 (0.66666667 0.33333333) *
## 249) age< 52 22 8 1 (0.36363636 0.63636364) *
## 125) year>=2004.5 86 17 1 (0.19767442 0.80232558) *
## 63) maritl=2. Married 423 42 1 (0.09929078 0.90070922) *
#### Plot the tree
########################################################################################################################################
rpart.plot(tree1,tweak = 2)
### Each node shows the predicted class (top value), the proportion of observations in that node belonging to the second class of the target variable, i.e., high earners (middle value), and the proportion of training observations belonging to that node (bottom value) ###
########################################################################################################################################
# CHUNK 10
#### Get the cptable
tree1$cptable
## CP nsplit rel error xerror xstd
## 1 0.2794899044 0 1.0000000 1.0000000 0.02422262
## 2 0.0297555792 1 0.7205101 0.7810840 0.02323092
## 3 0.0100956429 3 0.6609989 0.6833156 0.02244820
## 4 0.0085015940 5 0.6408077 0.6599362 0.02222665
## 5 0.0074388948 9 0.6068013 0.6567481 0.02219536
## 6 0.0046050301 10 0.5993624 0.6620616 0.02224737
## 7 0.0042507970 13 0.5855473 0.6695005 0.02231896
## 8 0.0031880978 15 0.5770457 0.6673751 0.02229865
## 9 0.0028338647 18 0.5674814 0.6726886 0.02234921
## 10 0.0021253985 21 0.5589798 0.6663124 0.02228845
## 11 0.0017711654 22 0.5568544 0.6461211 0.02208915
## 12 0.0015940489 28 0.5462274 0.6471838 0.02209991
## 13 0.0007084662 30 0.5430393 0.6599362 0.02222665
## 14 0.0005000000 36 0.5387885 0.6737513 0.02235924
#### Get cp value with the minimum of xerror.
cp.min <- tree1$cptable[which.min(tree1$cptable[,"xerror"]),"CP"]
#### Prune the tree with this cp
tree2 <- prune(tree1,cp = cp.min)
rpart.plot(tree2, tweak = 2)
Based on the cptable, a common way to simplify a decision tree is to prune it using the cp value that corresponds to the lowest cross-validation error (xerror). For Tree 1, this cp value is 0.0017711654, found in the 11th row of the cptable.
The pruned tree has 22 splits and 23 terminal nodes. It looks simpler than the original tree, but is still complex.
### Tree 3: Pruning Tree 1 using the one-standard-error rule
We can observe from the cptable that many of the smaller trees (i.e., those with 5 to 21 splits) have a cross-validation error comparable to the minimum while being much smaller in size.
To identify a tree that is much more interpretable but comparably predictive, one common practice is to employ the one-standard-error rule.
One-standard-error rule: select the smallest tree whose cross-validation error is within one standard error of the minimum cross-validation error.
In the current setting, this cutoff equals 0.6461211 + 0.02208915 = 0.66821025.
Among all trees with a cross-validation error within one standard error of the minimum, the simplest has 5 splits (or 6 terminal nodes), as we can see from the cptable, with a complexity parameter of 0.0085015940.
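CHUNK 11 below reads row 4 off the cptable by hand. As a sketch, the same selection can be automated (assuming tree1 from CHUNK 8; cp.1se is a hypothetical name for the chosen complexity parameter):
#### Sketch: select the cp value by the one-standard-error rule
cpt <- tree1$cptable
cutoff <- min(cpt[, "xerror"]) + cpt[which.min(cpt[, "xerror"]), "xstd"]
#### Smallest tree (fewest splits) whose xerror is within one standard error of the minimum
cp.1se <- cpt[min(which(cpt[, "xerror"] <= cutoff)), "CP"]
cp.1se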
# CHUNK 11
#### Prune the tree using one-standard-error rule
tree3 <- prune(tree1, cp = tree1$cptable[4,"CP"])
#### Plot the tree
rpart.plot(tree3)
Tree 3 has 5 splits and 6 terminal nodes. It is much simpler and more interpretable than Tree 1 and Tree 2.
* The first split is based on education, which can be seen as the most important predictor. The subsequent splits are based on health insurance, marital status, age, and education again.
* The first split separates relatively low education levels on the left from high education levels on the right. The predicted class for workers with a low education level is low earners; the predicted class for those with a high education level is high earners.
* Node 7 is further partitioned using an age of 32 as the cutoff level. Workers under this cutoff with some college education are classified as low earners, while workers under this cutoff with a higher education level (College Grad or Advanced Degree) are classified as high earners.
# CHUNK 13
#### Get predicted class(level) for three trees: tree1, tree2, tree3
pred1.class <- predict(tree1, newdata = data.test, type = "class")
pred2.class <- predict(tree2, newdata = data.test, type = "class")
pred3.class <- predict(tree3, newdata = data.test, type = "class")
########################################################################################################################################
### type = "prob": produces a matrix of predicted class probabilities, one column for each class of the target variable. This is the default for a classification tree fitted by rpart(). ###############################################################################
### type = "class": produces a vector of predicted class labels based on a cutoff of 0.5 ###############################################
########################################################################################################################################
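As a quick illustration of the two prediction types, here is a sketch (assuming tree1, data.test, and pred1.class from the lines above; prob1 is a hypothetical name):
#### Sketch: the probability matrix has one column per class ("0" and "1")
head(predict(tree1, newdata = data.test, type = "prob"))
#### Applying a 0.5 cutoff to the class-1 probabilities should reproduce pred1.class
prob1 <- predict(tree1, newdata = data.test, type = "prob")[, "1"]
mean(ifelse(prob1 > 0.5, "1", "0") == as.character(pred1.class)) # should be 1 if no node probability is exactly 0.5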
#### Get confusion matrix
confusionMatrix(pred1.class, as.factor(data.test$wage_flag), positive = "1")
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 253 90
## 1 149 407
##
## Accuracy : 0.7341
## 95% CI : (0.704, 0.7628)
## No Information Rate : 0.5528
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.4546
##
## Mcnemar's Test P-Value : 0.0001756
##
## Sensitivity : 0.8189
## Specificity : 0.6294
## Pos Pred Value : 0.7320
## Neg Pred Value : 0.7376
## Prevalence : 0.5528
## Detection Rate : 0.4527
## Detection Prevalence : 0.6185
## Balanced Accuracy : 0.7241
##
## 'Positive' Class : 1
##
confusionMatrix(pred2.class, as.factor(data.test$wage_flag), positive = "1")
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 267 90
## 1 135 407
##
## Accuracy : 0.7497
## 95% CI : (0.7201, 0.7777)
## No Information Rate : 0.5528
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.4883
##
## Mcnemar's Test P-Value : 0.003353
##
## Sensitivity : 0.8189
## Specificity : 0.6642
## Pos Pred Value : 0.7509
## Neg Pred Value : 0.7479
## Prevalence : 0.5528
## Detection Rate : 0.4527
## Detection Prevalence : 0.6029
## Balanced Accuracy : 0.7415
##
## 'Positive' Class : 1
##
confusionMatrix(pred3.class, as.factor(data.test$wage_flag), positive = "1")
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 298 143
## 1 104 354
##
## Accuracy : 0.7253
## 95% CI : (0.6948, 0.7542)
## No Information Rate : 0.5528
## P-Value [Acc > NIR] : < 2e-16
##
## Kappa : 0.4494
##
## Mcnemar's Test P-Value : 0.01561
##
## Sensitivity : 0.7123
## Specificity : 0.7413
## Pos Pred Value : 0.7729
## Neg Pred Value : 0.6757
## Prevalence : 0.5528
## Detection Rate : 0.3938
## Detection Prevalence : 0.5095
## Balanced Accuracy : 0.7268
##
## 'Positive' Class : 1
##
The test accuracy of the three classification trees is ordered as follows: Tree 3 < Tree 1 < Tree 2. Tree 2 has the highest accuracy among the three constructed trees, which agrees with the result from the cptable.
We can compare them using a cutoff-free metric such as the test AUC.
# CHUNK 14
library(pROC)
## Type 'citation("pROC")' for a citation.
##
## Attaching package: 'pROC'
## The following objects are masked from 'package:stats':
##
## cov, smooth, var
#### Extract the predicted probabilities for the level, 1, of wage_flag
pred1.prob <- predict(tree1, newdata = data.test, type = "prob")[,2]
pred2.prob <- predict(tree2, newdata = data.test, type = "prob")[,2]
pred3.prob <- predict(tree3, newdata = data.test, type = "prob")[,2]
#### Get ROC and AUC
roc(data.test$wage_flag,pred1.prob)
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
##
## Call:
## roc.default(response = data.test$wage_flag, predictor = pred1.prob)
##
## Data: pred1.prob in 402 controls (data.test$wage_flag 0) < 497 cases (data.test$wage_flag 1).
## Area under the curve: 0.8096
roc(data.test$wage_flag,pred2.prob)
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
##
## Call:
## roc.default(response = data.test$wage_flag, predictor = pred2.prob)
##
## Data: pred2.prob in 402 controls (data.test$wage_flag 0) < 497 cases (data.test$wage_flag 1).
## Area under the curve: 0.8135
roc(data.test$wage_flag,pred3.prob)
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
##
## Call:
## roc.default(response = data.test$wage_flag, predictor = pred3.prob)
##
## Data: pred3.prob in 402 controls (data.test$wage_flag 0) < 497 cases (data.test$wage_flag 1).
## Area under the curve: 0.759
The test AUC values for Trees 1, 2, and 3 are 0.8096, 0.8135, and 0.759, respectively. The ordering of these test AUC values follows the ordering of the test accuracies of these trees (Tree 3 < Tree 1 < Tree 2).
* Tree 1: Tree 1 is the most complex tree. The test accuracy and test AUC suggest that it may be unnecessarily complex and may have overfit the data.
* Tree 3: Tree 3 is the simplest tree. The performance metrics suggest that it may be too simple and underfit the data as a result of its simplicity and high interpretability.
* Tree 2: Tree 2 is the pruned tree. It appears to have the optimal level of complexity and has the best performance in terms of test AUC and test accuracy among the three trees. Tree 2 is simpler than Tree 1 while also having better predictive performance. If we have to recommend one tree to use, it is reasonable to suggest Tree 2.
# CHUNK 15
#### Prepare to train random forests
#### Set the resampling controls using trainControl()
#######################################################################################################################################
### method: "cv" or "repeatedcv", corresponding to cross-validation and repeated cross-validation, respectively. #####################
### number: the number of folds used in k-fold cross-validation #######################################################################
### repeats: applicable only if method = "repeatedcv"; controls how many times cross-validation is repeated. ##########################
### sampling: if sampling = "down", undersampling is applied. Because wage_flag is a rather balanced binary variable, the use of undersampling is only for illustration purposes and not absolutely necessary. (You can also try oversampling by specifying sampling = "up"; see the sketch below.) These sampling techniques are used to balance the target classes. #####################################
#######################################################################################################################################
ctrl <- trainControl(method = "repeatedcv", number = 5, repeats = 3, sampling = "down")
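As mentioned in the comment block above, oversampling can be specified instead; a minimal variant of the control (a sketch, not used in the rest of this case study):
#### Sketch: the same controls with oversampling instead of undersampling
ctrl.up <- trainControl(method = "repeatedcv", number = 5, repeats = 3, sampling = "up")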
Among all the parameters of a random forest, the number of features considered at each split, represented by the mtry parameter, is arguably the most important (in fact, the caret package only supports the tuning of mtry for a random forest model). Its default value is \(\sqrt{p}\) for classification and \(p/3\) for regression, where \(p\) is the number of predictors.
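For reference, a short sketch of what those defaults work out to here (data.train has 8 predictors once wage_flag is excluded):
#### Sketch: default mtry values for this dataset
p <- ncol(data.train) - 1  # 8 predictors, excluding wage_flag
floor(sqrt(p))             # classification default: floor(sqrt(8)) = 2
floor(p / 3)               # regression default: floor(8 / 3) = 2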
# CHUNK 16
#### Set up the tuning grid for mtry, the number of features considered at each split (randomization at each split).
rf.grid <- expand.grid(mtry = 1:5)
rf.grid
## mtry
## 1 1
## 2 2
## 3 3
## 4 4
## 5 5
# CHUNK 17
library(randomForest)
## Warning: package 'randomForest' was built under R version 4.2.2
## randomForest 4.7-1.1
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:ggplot2':
##
## margin
#### Set up the x and y variables
target <- factor(data.train$wage_flag)
predictors <- data.train[,-9]
#### Train the first random forest with 5 trees
set.seed(20) # because cross-validation will be done
rf1 <- train(y = target, x = predictors, method = "rf", ntree = 5, importance = TRUE, trControl = ctrl, tuneGrid = rf.grid)
#### View the output
rf1
## Random Forest
##
## 2101 samples
## 8 predictor
## 2 classes: '0', '1'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold, repeated 3 times)
## Summary of sample sizes: 1681, 1681, 1680, 1681, 1681, 1680, ...
## Addtional sampling using down-sampling
##
## Resampling results across tuning parameters:
##
## mtry Accuracy Kappa
## 1 0.6798398 0.3493491
## 2 0.6871251 0.3717869
## 3 0.6783976 0.3520302
## 4 0.6657083 0.3291085
## 5 0.6655571 0.3298912
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 2.
ggplot(rf1)
# CHUNK 18
set.seed(50) # this seed need not be the same as the previous seed
#### Train the second random forest with 20 trees
rf2 <- train(y = target, x = predictors, method = "rf", ntree = 20, importance = TRUE, trControl = ctrl, tuneGrid = rf.grid)
#### Train the third random forest with 100 trees
rf3 <- train(y = target, x = predictors, method = "rf", ntree = 100, importance = TRUE, trControl = ctrl, tuneGrid = rf.grid)
In contrast to an rpart object, for which the predict() function outputs predicted probabilities by default, the default type of prediction returned by predict() applied to an object of the train class is the predicted class (which can be requested explicitly with the type = "raw" option).
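A one-line check of this (a sketch, assuming rf1 and data.test from above):
#### Sketch: the default prediction type for a train object is the predicted class
identical(predict(rf1, newdata = data.test), predict(rf1, newdata = data.test, type = "raw"))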
# CHUNK 19
#### Get predicted class(level) for three random forests: rf1, rf2, rf3
pred.rf1.class <- predict(rf1, newdata = data.test)
pred.rf2.class <- predict(rf2, newdata = data.test)
pred.rf3.class <- predict(rf3, newdata = data.test)
#### Calculate confusion matrix
confusionMatrix(pred.rf1.class, as.factor(data.test$wage_flag), positive = "1")
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 288 148
## 1 114 349
##
## Accuracy : 0.7086
## 95% CI : (0.6777, 0.7381)
## No Information Rate : 0.5528
## P-Value [Acc > NIR] : < 2e-16
##
## Kappa : 0.4153
##
## Mcnemar's Test P-Value : 0.04148
##
## Sensitivity : 0.7022
## Specificity : 0.7164
## Pos Pred Value : 0.7538
## Neg Pred Value : 0.6606
## Prevalence : 0.5528
## Detection Rate : 0.3882
## Detection Prevalence : 0.5150
## Balanced Accuracy : 0.7093
##
## 'Positive' Class : 1
##
confusionMatrix(pred.rf2.class, as.factor(data.test$wage_flag), positive = "1")
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 322 134
## 1 80 363
##
## Accuracy : 0.762
## 95% CI : (0.7327, 0.7895)
## No Information Rate : 0.5528
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.5246
##
## Mcnemar's Test P-Value : 0.0002912
##
## Sensitivity : 0.7304
## Specificity : 0.8010
## Pos Pred Value : 0.8194
## Neg Pred Value : 0.7061
## Prevalence : 0.5528
## Detection Rate : 0.4038
## Detection Prevalence : 0.4928
## Balanced Accuracy : 0.7657
##
## 'Positive' Class : 1
##
confusionMatrix(pred.rf3.class, as.factor(data.test$wage_flag), positive = "1")
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 312 99
## 1 90 398
##
## Accuracy : 0.7898
## 95% CI : (0.7616, 0.816)
## No Information Rate : 0.5528
## P-Value [Acc > NIR] : <2e-16
##
## Kappa : 0.5757
##
## Mcnemar's Test P-Value : 0.5606
##
## Sensitivity : 0.8008
## Specificity : 0.7761
## Pos Pred Value : 0.8156
## Neg Pred Value : 0.7591
## Prevalence : 0.5528
## Detection Rate : 0.4427
## Detection Prevalence : 0.5428
## Balanced Accuracy : 0.7885
##
## 'Positive' Class : 1
##
The test accuracies of the three random forests are ordered as follows: rf1 < rf2 < rf3. This order suggests that the larger the ntree parameter, the more accurately the random forest's predictions match the test set.
# CHUNK 20
#### Add the type = "prob" option to return predicted probabilities
pred.rf1.prob <- predict(rf1, newdata = data.test, type = "prob")[,2]
pred.rf2.prob <- predict(rf2, newdata = data.test, type = "prob")[,2]
pred.rf3.prob <- predict(rf3, newdata = data.test, type = "prob")[,2]
#### Get ROC and AUC
roc(data.test$wage_flag, pred.rf1.prob)
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
##
## Call:
## roc.default(response = data.test$wage_flag, predictor = pred.rf1.prob)
##
## Data: pred.rf1.prob in 402 controls (data.test$wage_flag 0) < 497 cases (data.test$wage_flag 1).
## Area under the curve: 0.7883
roc(data.test$wage_flag, pred.rf2.prob)
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
##
## Call:
## roc.default(response = data.test$wage_flag, predictor = pred.rf2.prob)
##
## Data: pred.rf2.prob in 402 controls (data.test$wage_flag 0) < 497 cases (data.test$wage_flag 1).
## Area under the curve: 0.8391
roc(data.test$wage_flag, pred.rf3.prob)
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
##
## Call:
## roc.default(response = data.test$wage_flag, predictor = pred.rf3.prob)
##
## Data: pred.rf3.prob in 402 controls (data.test$wage_flag 0) < 497 cases (data.test$wage_flag 1).
## Area under the curve: 0.8469
The test AUC values of the three random forests are ordered as follows: rf1 < rf2 < rf3. This order follows the order of their test accuracies. Building more trees tends to improve the predictive performance of a random forest. This makes sense because with more trees, the variance reduction contributed by averaging becomes more significant and the model predictions become more stable.
* Of the three random forests, it is reasonable to recommend rf3, which has the most predictive power according to both test AUC and test accuracy.
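To see this effect directly, one could refit plain random forests over a grid of ntree values and track the test AUC. A rough sketch (assuming the randomForest and pROC packages and the predictors, target, and data.test objects from above; mtry is fixed at 2, the value selected earlier, and the ntree grid is arbitrary):
#### Sketch: test AUC as a function of the number of trees
set.seed(100)
ntrees <- c(5, 20, 100, 500)
aucs <- sapply(ntrees, function(nt) {
  fit <- randomForest(x = predictors, y = target, ntree = nt, mtry = 2)
  prob <- predict(fit, newdata = data.test, type = "prob")[, "1"]
  as.numeric(auc(data.test$wage_flag, prob))
})
data.frame(ntree = ntrees, test.auc = aucs)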
Because ensemble trees consist of potentially hundreds of decision trees, it is difficult, if not impossible, to interpret the relationship between the predictors and the target variable using a series of easy-to-understand classification rules, as a single decision tree allows.
One useful tool is a variable importance plot, which ranks the predictors according to their importance scores.
The importance score for a particular predictor is computed by totaling the drop in node impurity (RSS for regression trees and the Gini index for classification trees) due to splits on that predictor, averaged over all the trees in the ensemble. In other words, it is the average amount of node impurity reduction attributable to that predictor.
In general, variables that are used to form most of the top splits in the individual trees lead to larger improvements in node purity and therefore are more important as captured by the variable importance score.
# CHUNK 21
#### Calculate variable importance score
imp <- varImp(rf3)
#imp
# Draw variable importance plot
plot(imp,main = "Variable Importance of Classification Random Forests")
According to the variable importance plot, education is the most important predictor for determining whether a worker is a high earner. The variables that follow education in importance are health_ins, age, and maritl. The variable year has an importance score of 0; because varImp() scales the scores to a range of 0 to 100 by default, this means year is the least important predictor, not necessarily that it never reduces node impurity.
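If the raw scores are of interest, varImp() can be asked not to rescale them (a sketch, assuming rf3 from above):
#### Sketch: unscaled importance scores
varImp(rf3, scale = FALSE)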
[Comment 1: Interpreting the output of a classification tree]
In the first split of the tree, education is used. This indicates that it is the most distinguishing predictor for determining whether a worker is a high earner. After the first split is made, the number of misclassifications decreases from 941 in the root node to 678 (295 + 383) in nodes 2 and 3 combined, so this split is an improvement.
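These counts can also be read off the fitted object; for a classification tree, the dev column of the rpart frame holds the number of misclassified observations in each node (a sketch, assuming tree1 from CHUNK 8):
#### Sketch: misclassification counts for the root and its two children
tree1$frame[rownames(tree1$frame) %in% c("1", "2", "3"), c("var", "n", "dev")]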