Install required libraries.

# Install the packages used in this document (run once, then leave commented).
# BiocInstaller/biocLite() have been replaced by BiocManager::install(), which
# also installs CRAN packages.
# install.packages("BiocManager")
# BiocManager::install(c("mlbench", "adabag", "randomForest", "party",
#                        "mboost", "caret", "ggplot2", "rpart.plot"))

Classification Trees

As a simple first dataset for machine learning, we will predict the species of an iris flower from four measurements (sepal length and width, petal length and width).

data("iris")
View(iris)
# Scatterplot matrix of the four numeric measurements, coloured by species
# (column 5 is Species itself, so it is dropped from the matrix).
pairs(iris[, -5], col = iris[["Species"]])

We can start with a simple learner, a classification tree. The algorithm works as follows:

  1. Start with whole dataset.
  2. Consider each feature in turn and find the value of that variable that splits the data into the two most homogeneous groups; split on the best variable/value pair.
  3. For each resulting group, repeat step 2 until all remaining groups have only one class in them.
  4. Optionally, “prune” the tree to keep only splits that are “statistically significant”.

The party package includes a function, ctree(), that “learns” a tree from data.

library(party)
# Fit a conditional inference tree predicting Species from every other column
# of iris, then draw it.
x <- ctree(formula = Species ~ ., data = iris)
plot(x)

And how well does our tree do with predicting the original classes from the data?

library(caret)
# Predict the classes of the same rows the tree was trained on.
prediction <- predict(x, iris)
table(prediction)
## prediction
##     setosa versicolor  virginica 
##         50         54         46
# caret's confusionMatrix(data, reference) takes the predicted classes first
# and the true classes second; swapping them transposes the table and mixes
# up the per-class statistics.
confusionMatrix(prediction, iris$Species)
## Confusion Matrix and Statistics
## 
##             Reference
## Prediction   setosa versicolor virginica
##   setosa         50          0         0
##   versicolor      0         49         5
##   virginica       0          1        45
## 
## Overall Statistics
##                                          
##                Accuracy : 0.96           
##                  95% CI : (0.915, 0.9852)
##     No Information Rate : 0.3333         
##     P-Value [Acc > NIR] : < 2.2e-16      
##                                          
##                   Kappa : 0.94           
##  Mcnemar's Test P-Value : NA             
## 
## Statistics by Class:
## 
##                      Class: setosa Class: versicolor Class: virginica
## Sensitivity                 1.0000            0.9800           0.9000
## Specificity                 1.0000            0.9500           0.9900
## Pos Pred Value              1.0000            0.9074           0.9783
## Neg Pred Value              1.0000            0.9896           0.9519
## Prevalence                  0.3333            0.3333           0.3333
## Detection Rate              0.3333            0.3267           0.3000
## Detection Prevalence        0.3333            0.3600           0.3067
## Balanced Accuracy           1.0000            0.9650           0.9450

What is the problem with what we just did to determine our prediction accuracy?

To deal with this problem, we can split the dataset into two pieces: a “training” set on which we fit the model, and a held-out “test” set on which we check its predictions.

# Randomly assign each row to the training set (probability 0.2) or the test
# set (probability 0.8). This is a random split, not an odd/even split; the
# seed makes it reproducible.
set.seed(42)
trainIdx <- sample(c(TRUE, FALSE), size = nrow(iris), replace = TRUE,
                   prob = c(0.2, 0.8))
irisTrain <- iris[trainIdx, ]
# Every row not drawn for training is used for testing.
irisTest <- iris[!trainIdx, ]

Now, we can “train” our tree on the “training” set.

# Grow the tree on the training rows only, then draw it.
trainTree <- ctree(formula = Species ~ ., data = irisTrain)
plot(trainTree)

And how does our trainTree do at predicting the original classes in the “training” data?

library(caret)
# Evaluate the tree on the same rows it was trained on.
trainPred <- predict(trainTree, irisTrain)
# confusionMatrix(data, reference): predicted classes first, true classes second.
confusionMatrix(trainPred, irisTrain$Species)
## Confusion Matrix and Statistics
## 
##             Reference
## Prediction   setosa versicolor virginica
##   setosa         18          0         0
##   versicolor      0          0         0
##   virginica       0          5        12
## 
## Overall Statistics
##                                           
##                Accuracy : 0.8571          
##                  95% CI : (0.6974, 0.9519)
##     No Information Rate : 0.5143          
##     P-Value [Acc > NIR] : 2.275e-05       
##                                           
##                   Kappa : 0.7489          
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: setosa Class: versicolor Class: virginica
## Sensitivity                 1.0000            0.0000           1.0000
## Specificity                 1.0000            1.0000           0.7826
## Pos Pred Value              1.0000               NaN           0.7059
## Neg Pred Value              1.0000            0.8571           1.0000
## Prevalence                  0.5143            0.1429           0.3429
## Detection Rate              0.5143            0.0000           0.3429
## Detection Prevalence        0.5143            0.0000           0.4857
## Balanced Accuracy           1.0000            0.5000           0.8913

How is our prediction performance now on the “test” data?

# Evaluate the tree on the held-out test rows.
testPred <- predict(trainTree, irisTest)
# confusionMatrix(data, reference): predicted classes first, true classes second.
# NOTE(review): the printout below was generated with these two arguments
# swapped, so its Prediction/Reference axes (and the per-class statistics and
# No Information Rate derived from them) are transposed; re-knit to refresh it.
confusionMatrix(testPred, irisTest$Species)
## Confusion Matrix and Statistics
## 
##             Reference
## Prediction   setosa versicolor virginica
##   setosa         30          0         2
##   versicolor      0          0        45
##   virginica       0          0        38
## 
## Overall Statistics
##                                           
##                Accuracy : 0.5913          
##                  95% CI : (0.4957, 0.6821)
##     No Information Rate : 0.7391          
##     P-Value [Acc > NIR] : 0.9998          
##                                           
##                   Kappa : 0.4018          
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: setosa Class: versicolor Class: virginica
## Sensitivity                 1.0000                NA           0.4471
## Specificity                 0.9765            0.6087           1.0000
## Pos Pred Value              0.9375                NA           1.0000
## Neg Pred Value              1.0000                NA           0.3896
## Prevalence                  0.2609            0.0000           0.7391
## Detection Rate              0.2609            0.0000           0.3304
## Detection Prevalence        0.2783            0.3913           0.3304
## Balanced Accuracy           0.9882                NA           0.7235

Now let’s make this harder. We will look at a dataset designed to “foil” tree classifiers: two interleaved spirals.

library(mlbench)
# 1000 points drawn from two spirals, with noise of standard deviation 0.1.
spiral <- mlbench.spirals(n = 1000, sd = 0.1)
# Flatten the mlbench object into a data frame: coordinates x, y and the
# class label as a factor.
spiral <- data.frame(
  x = spiral$x[, 1],
  y = spiral$x[, 2],
  class = factor(spiral$classes)
)
library(ggplot2)
ggplot(data = spiral, mapping = aes(x = x, y = y, colour = class)) +
  geom_point()

# Random train/test split. sample() rescales probability weights that do not
# sum to one, so each row goes to training with probability 0.8 / 1.1
# (about 73%).
trainIdx <- sample(c(TRUE, FALSE), nrow(spiral), replace = TRUE,
                   prob = c(0.8, 0.3))
spiralTrain <- spiral[trainIdx, ]
trainTree <- ctree(class ~ ., data = spiralTrain)
plot(trainTree)

prediction <- predict(trainTree, spiralTrain)
# confusionMatrix(data, reference): predicted classes first, true classes second.
# NOTE(review): the printout below was generated with these two arguments
# swapped, so its Prediction/Reference axes (and Sensitivity vs. Pos Pred
# Value, Prevalence vs. Detection Prevalence, NIR) are transposed; re-knit.
confusionMatrix(prediction, spiralTrain$class)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   1   2
##          1 347   6
##          2 269  82
##                                           
##                Accuracy : 0.6094          
##                  95% CI : (0.5722, 0.6456)
##     No Information Rate : 0.875           
##     P-Value [Acc > NIR] : 1               
##                                           
##                   Kappa : 0.2171          
##  Mcnemar's Test P-Value : <2e-16          
##                                           
##             Sensitivity : 0.5633          
##             Specificity : 0.9318          
##          Pos Pred Value : 0.9830          
##          Neg Pred Value : 0.2336          
##              Prevalence : 0.8750          
##          Detection Rate : 0.4929          
##    Detection Prevalence : 0.5014          
##       Balanced Accuracy : 0.7476          
##                                           
##        'Positive' Class : 1               
## 
# Rows not drawn for training form the test set.
spiralTest <- spiral[!trainIdx, ]
prediction <- predict(trainTree, spiralTest)
# confusionMatrix(data, reference): predicted classes first, true classes second.
# NOTE(review): the printout below was generated with these two arguments
# swapped and shows the table transposed; re-knit to refresh it.
confusionMatrix(prediction, spiralTest$class)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   1   2
##          1 144   3
##          2 113  36
##                                           
##                Accuracy : 0.6081          
##                  95% CI : (0.5499, 0.6641)
##     No Information Rate : 0.8682          
##     P-Value [Acc > NIR] : 1               
##                                           
##                   Kappa : 0.2201          
##  Mcnemar's Test P-Value : <2e-16          
##                                           
##             Sensitivity : 0.5603          
##             Specificity : 0.9231          
##          Pos Pred Value : 0.9796          
##          Neg Pred Value : 0.2416          
##              Prevalence : 0.8682          
##          Detection Rate : 0.4865          
##    Detection Prevalence : 0.4966          
##       Balanced Accuracy : 0.7417          
##                                           
##        'Positive' Class : 1               
## 

Trees grown on different samples of the data have similar prediction capability, but each one is really bad. This is characteristic of a “weak learner”. Here we see that in action by repeatedly drawing a bootstrap sample (resampling rows with replacement), training a tree on it, plotting the tree, and checking its prediction accuracy.

# Draw one bootstrap sample of the data, grow a deliberately shallow tree on
# it, plot the tree, and print the tree's accuracy on the out-of-bag rows
# (the rows the bootstrap sample did not draw).
#
# spiral: a data frame with a factor column `class` (the response); all other
#   columns are used as predictors.
# Returns the out-of-bag accuracy invisibly (NA if there are no out-of-bag rows).
plotBootSample <- function(spiral) {
  n <- nrow(spiral)
  bootIdx <- sample(seq_len(n), replace = TRUE)
  spiralTrain <- spiral[bootIdx, ]
  # `controls` must be named: ctree()'s third positional argument is `subset`.
  # ctree_control() has no `maxsplit`; maxdepth = 2 keeps each tree a weak
  # learner.
  trainTree <- ctree(class ~ ., data = spiralTrain,
                     controls = ctree_control(minsplit = 2, maxdepth = 2))
  plot(trainTree)
  # bootIdx holds positive integer row numbers, so `!bootIdx` would be all
  # FALSE and select no rows; the out-of-bag rows are the ones never drawn.
  oobIdx <- setdiff(seq_len(n), bootIdx)
  if (length(oobIdx) == 0L) {
    warning("bootstrap sample contains every row; no out-of-bag rows to test on",
            call. = FALSE)
    return(invisible(NA_real_))
  }
  spiralOob <- spiral[oobIdx, ]
  prediction <- predict(trainTree, spiralOob)
  # Compare out-of-bag predictions with the out-of-bag truth
  # (confusionMatrix(data = predictions, reference = truth)).
  accuracy <- confusionMatrix(prediction, spiralOob$class)$overall["Accuracy"]
  print(accuracy)
  invisible(accuracy)
}
# Keep drawing bootstrap trees until interrupted: press 'ESC' to stop.
# par(ask = TRUE) makes the device wait for the user before each new plot.
repeat {
  par(ask = TRUE)
  plotBootSample(spiral)
}

Boosting

We can “combine” many “weak learners”, giving more “weight” to hard-to-classify observations as we build each new classifier. Here each weak learner is again a small classification tree, this time fitted with rpart and limited to a depth of 2.

library(adabag)
# Roughly half of the rows go to training, the rest to testing.
trainIdx <- sample(c(TRUE, FALSE), nrow(spiral), replace = TRUE,
                   prob = c(0.5, 0.5))
spiralTrain <- spiral[trainIdx, ]
# The rpart control is applied to every boosting round; maxdepth = 2 keeps
# each tree a weak learner.
boostTree <- boosting(class ~ x + y, data = spiralTrain,
                      control = rpart.control(maxdepth = 2))
prediction <- predict(boostTree, spiralTrain)
# predict.boosting() returns $class as character, but caret requires factors
# with matching levels. Predictions go first, truth second
# (confusionMatrix(data, reference)).
confusionMatrix(factor(prediction$class, levels = levels(spiralTrain$class)),
                spiralTrain$class)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   1   2
##          1 254  21
##          2  21 231
##                                          
##                Accuracy : 0.9203         
##                  95% CI : (0.8938, 0.942)
##     No Information Rate : 0.5218         
##     P-Value [Acc > NIR] : <2e-16         
##                                          
##                   Kappa : 0.8403         
##  Mcnemar's Test P-Value : 1              
##                                          
##             Sensitivity : 0.9236         
##             Specificity : 0.9167         
##          Pos Pred Value : 0.9236         
##          Neg Pred Value : 0.9167         
##              Prevalence : 0.5218         
##          Detection Rate : 0.4820         
##    Detection Prevalence : 0.5218         
##       Balanced Accuracy : 0.9202         
##                                          
##        'Positive' Class : 1              
## 
library(rpart.plot)
# Draw the first nine trees of the boosted ensemble in a 3 x 3 grid, without
# pausing between plots.
par(ask = FALSE, mfrow = c(3, 3))
for (i in seq_len(9)) {
  rpart.plot(boostTree$trees[[i]])
}

And how does our boosted tree work on the test data?

# Rows not drawn for training form the test set.
spiralTest <- spiral[!trainIdx, ]
prediction <- predict(boostTree, spiralTest)
# predict.boosting() returns $class as character; convert it to a factor with
# the data's levels, and pass predictions first, truth second.
confusionMatrix(factor(prediction$class, levels = levels(spiralTest$class)),
                spiralTest$class)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   1   2
##          1 203  25
##          2  22 223
##                                           
##                Accuracy : 0.9006          
##                  95% CI : (0.8701, 0.9261)
##     No Information Rate : 0.5243          
##     P-Value [Acc > NIR] : <2e-16          
##                                           
##                   Kappa : 0.8009          
##  Mcnemar's Test P-Value : 0.7705          
##                                           
##             Sensitivity : 0.9022          
##             Specificity : 0.8992          
##          Pos Pred Value : 0.8904          
##          Neg Pred Value : 0.9102          
##              Prevalence : 0.4757          
##          Detection Rate : 0.4292          
##    Detection Prevalence : 0.4820          
##       Balanced Accuracy : 0.9007          
##                                           
##        'Positive' Class : 1               
## 

Random Forests

library(randomForest)
# Random forest on the full iris data with default settings; its printout
# includes the out-of-bag (OOB) error estimate and confusion matrix.
res <- randomForest(Species ~ ., data = iris)
print(res)
## 
## Call:
##  randomForest(formula = Species ~ ., data = iris) 
##                Type of random forest: classification
##                      Number of trees: 500
## No. of variables tried at each split: 2
## 
##         OOB estimate of  error rate: 4%
## Confusion matrix:
##            setosa versicolor virginica class.error
## setosa         50          0         0        0.00
## versicolor      0         47         3        0.06
## virginica       0          3        47        0.06
# Plot how much each predictor contributes to the forest.
varImpPlot(res)