Machine learning classification example. Several machine learning algorithms are used to classify iris flowers into three species based on four measured features.

Load Required Packages

library(caret)
## Warning: package 'caret' was built under R version 3.2.3
## Loading required package: lattice
## Loading required package: ggplot2
## Warning: package 'ggplot2' was built under R version 3.2.3
# libraries for partition trees
library(rpart)
library(rpart.plot)
library(rattle)
## Warning: package 'rattle' was built under R version 3.2.3
## Rattle: A free graphical interface for data mining with R.
## Version 4.1.0 Copyright (c) 2006-2015 Togaware Pty Ltd.
## Type 'rattle()' to shake, rattle, and roll your data.

Load Data and Basic Exploration

data(iris)
head(iris)
##   Sepal.Length Sepal.Width Petal.Length Petal.Width Species
## 1          5.1         3.5          1.4         0.2  setosa
## 2          4.9         3.0          1.4         0.2  setosa
## 3          4.7         3.2          1.3         0.2  setosa
## 4          4.6         3.1          1.5         0.2  setosa
## 5          5.0         3.6          1.4         0.2  setosa
## 6          5.4         3.9          1.7         0.4  setosa
dim(iris)
## [1] 150   5
table(iris$Species)
## 
##     setosa versicolor  virginica 
##         50         50         50

Split Data

set.seed(1)
training_indices <- createDataPartition(y=iris$Species,p=0.8,list=F)
training_set <- iris[training_indices,]
validation_set <- iris[-training_indices,]
dim(training_set); dim(validation_set);
## [1] 120   5
## [1] 30  5
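As a quick sanity check, the split should keep the three species balanced, since createDataPartition samples within each class; a minimal sketch:

# verify that the stratified split preserved the class balance
table(training_set$Species)
table(validation_set$Species)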

Exploratory Data Analysis

# exploratory plots
# see how the data groups into different clusters
set.seed(1)
#km <- kmeans(training_set[,1:4], 3)
# note: model.matrix(~.+0, ...) also expands Species into indicator columns, so the class labels influence the clustering below
km <- kmeans(model.matrix(~.+0, data=training_set),3)
plot(training_set[,1], training_set[,2], col=km$cluster)
points(km$centers[,c(1,2)], col=1:3, pch=19, cex=2)

table(km$cluster, training_set$Species)
##    
##     setosa versicolor virginica
##   1     40          0         0
##   2      0         40         1
##   3      0          0        39
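For comparison, a sketch of the clustering on the four measurements only (the commented-out call above); cluster numbering is arbitrary, so only the pattern of the cross-tabulation matters:

set.seed(1)
# cluster on the numeric features alone, without the Species indicator columns
km_features <- kmeans(training_set[,1:4], 3)
table(km_features$cluster, training_set$Species)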
featurePlot(x=training_set[,1:4], y=training_set[,5], plot="pairs", auto.key=list(columns=3))

featurePlot(x=training_set[,1:4], y=training_set[,5], plot="density", scales=list(x=list(relation="free"), y=list(relation="free")), auto.key=list(columns=3))

featurePlot(x=training_set[,1:4], y=training_set[,5], plot="box", scales=list(x=list(relation="free"), y=list(relation="free")), auto.key=list(columns=3))

# remove near-zero covariates (features with little or no variability)
nearZeroVar(training_set[,-5],saveMetrics = T)
##              freqRatio percentUnique zeroVar   nzv
## Sepal.Length  1.285714      27.50000   FALSE FALSE
## Sepal.Width   1.833333      17.50000   FALSE FALSE
## Petal.Length  1.300000      34.16667   FALSE FALSE
## Petal.Width   2.363636      17.50000   FALSE FALSE
near_zero_covariates <- nearZeroVar(training_set[,-5])
head(near_zero_covariates)
## integer(0)
length(near_zero_covariates)
## [1] 0
feature_correlation <- cor(training_set[,-5])
# findCorrelation searches a correlation matrix and returns the column indices to remove in order to reduce pair-wise correlations
high_correlation <- findCorrelation(feature_correlation,0.9)
head(high_correlation)
## [1] 3
#length(high_correlation)
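If you preferred dropping the flagged column (Petal.Length, column 3 here) over the PCA preprocessing used below, a minimal sketch (assuming high_correlation is non-empty, as it is here):

# remove the highly correlated column from the feature set
reduced_features <- training_set[,-5][,-high_correlation]
names(reduced_features)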

# PCA
pc <- prcomp(training_set[,-5],center=T,scale=T)
plot(pc,type="l")
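Rather than eyeballing the scree plot, the variance explained by each component can be read off directly; a quick sketch:

# proportion and cumulative proportion of variance explained by the components
summary(pc)
cumsum(pc$sdev^2/sum(pc$sdev^2))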

Model Building

# 5-fold cross-validation, keeping the held-out predictions
train_control <- trainControl(method="cv", number=5, savePredictions = T)
# fix the tuning parameters (fL, usekernel) for the naive Bayes model fit below
grid <- expand.grid(.fL=c(0), .usekernel=c(FALSE))

# model building
# svm - linear
set.seed(1)
svm_lm_model <- train(Species ~ .,data = training_set, trControl=train_control, method = "svmLinear",preProcess = c("center", "scale","pca"))
## Loading required package: kernlab
## Warning: package 'kernlab' was built under R version 3.2.3
## Warning in .recacheSubclasses(def@className, def, doSubclasses, env):
## undefined subclass "externalRefMethod" of class "kfunction"; definition not
## updated
## 
## Attaching package: 'kernlab'
## The following object is masked from 'package:ggplot2':
## 
##     alpha
svm_lm_model
## Support Vector Machines with Linear Kernel 
## 
## 120 samples
##   4 predictor
##   3 classes: 'setosa', 'versicolor', 'virginica' 
## 
## Pre-processing: centered (4), scaled (4), principal component
##  signal extraction (4) 
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 96, 96, 96, 96, 96 
## Resampling results
## 
##   Accuracy  Kappa   Accuracy SD  Kappa SD
##   0.925     0.8875  0.06846532   0.102698
## 
## Tuning parameter 'C' was held constant at a value of 1
## 
varImp(svm_lm_model)
## ROC curve variable importance
## 
##   variables are sorted by maximum importance across the classes
##              setosa versicolor virginica
## Petal.Width  100.00     100.00    100.00
## Petal.Length 100.00     100.00    100.00
## Sepal.Length  91.48      71.54     91.48
## Sepal.Width   52.73      52.73      0.00
plot(varImp(svm_lm_model))
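Because trainControl was called with savePredictions = T, the held-out fold predictions are stored on the fitted object, and calling confusionMatrix on a train object gives the resampled (cross-validated) confusion matrix; a minimal sketch:

# inspect the held-out predictions from the cross-validation folds
head(svm_lm_model$pred)
# cross-validated confusion matrix, averaged over the resamples
confusionMatrix(svm_lm_model)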

# svm - rbf kernel
set.seed(1)
svm_rbf_model <- train(Species ~ .,data = training_set, trControl=train_control, method = "svmRadial",preProcess = c("center", "scale","pca"))

# classification Trees
set.seed(1)
#rpart_model <- train(Species ~ .,data = training_set, trControl=train_control, method = "rpart")
rpart_model <- train(Species ~ .,data = training_set, trControl=train_control, method = "rpart",preProcess = c("center", "scale","pca"))
# plot classification trees
fancyRpartPlot(rpart_model$finalModel)
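Since this tree is fit on the principal components, its splits are hard to interpret; a sketch of the same model on the raw features (the commented-out call above), with an illustrative object name, gives splits on the original measurements:

# tree on the original features, for a more readable plot
set.seed(1)
rpart_model_raw <- train(Species ~ .,data = training_set, trControl=train_control, method = "rpart")
fancyRpartPlot(rpart_model_raw$finalModel)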

# random forest
set.seed(1)
rf_model <- train(Species ~ .,data = training_set, trControl=train_control, method = "rf",preProcess = c("center", "scale","pca"),prox=T)
## Loading required package: randomForest
## randomForest 4.6-12
## Type rfNews() to see new features/changes/bug fixes.
## 
## Attaching package: 'randomForest'
## The following object is masked from 'package:ggplot2':
## 
##     margin
## Warning in randomForest.default(x, y, mtry = param$mtry, ...): invalid
## mtry: reset to within valid range
(The same mtry warning is repeated for the remaining resamples and tuning values: the PCA preprocessing leaves only two components here, so the larger mtry values in caret's default grid get reset to the valid range.)
# boosting with trees
set.seed(1)
gbm_model <- train(Species ~ .,data = training_set, trControl=train_control, method = "gbm", preProcess = c("center", "scale","pca"), verbose=F)
## Loading required package: gbm
## Loading required package: survival
## 
## Attaching package: 'survival'
## The following object is masked from 'package:caret':
## 
##     cluster
## Loading required package: splines
## Loading required package: parallel
## Loaded gbm 2.1.1
## Loading required package: plyr
# naive bayes
set.seed(1)
nb_model <- train(Species~., data=training_set, trControl=train_control, method="nb", tuneGrid=grid,preProcess = c("center", "scale","pca"))
## Loading required package: klaR
## Loading required package: MASS
## Warning: package 'MASS' was built under R version 3.2.2
# linear discriminant analysis
set.seed(1)
lda_model <- train(Species~., data=training_set, trControl=train_control, method="lda", preProcess = c("center", "scale","pca"))

# collect resamples
train_results <- resamples(list(SVM_LM=svm_lm_model,SVM_RBF=svm_rbf_model,RPART=rpart_model,GBM=gbm_model,RF=rf_model,NB=nb_model,LDA=lda_model))
# summarize the distributions
summary(train_results)
## 
## Call:
## summary.resamples(object = train_results)
## 
## Models: SVM_LM, SVM_RBF, RPART, GBM, RF, NB, LDA 
## Number of resamples: 5 
## 
## Accuracy 
##           Min. 1st Qu. Median   Mean 3rd Qu.   Max. NA's
## SVM_LM  0.8333  0.8750 0.9583 0.9250  0.9583 1.0000    0
## SVM_RBF 0.8750  0.9167 0.9583 0.9333  0.9583 0.9583    0
## RPART   0.8333  0.8750 0.9167 0.9167  0.9583 1.0000    0
## GBM     0.8750  0.8750 0.9583 0.9333  0.9583 1.0000    0
## RF      0.7500  0.8750 0.8750 0.8917  0.9583 1.0000    0
## NB      0.8333  0.8750 0.9167 0.9083  0.9583 0.9583    0
## LDA     0.8750  0.8750 0.9583 0.9333  0.9583 1.0000    0
## 
## Kappa 
##           Min. 1st Qu. Median   Mean 3rd Qu.   Max. NA's
## SVM_LM  0.7500  0.8125 0.9375 0.8875  0.9375 1.0000    0
## SVM_RBF 0.8125  0.8750 0.9375 0.9000  0.9375 0.9375    0
## RPART   0.7500  0.8125 0.8750 0.8750  0.9375 1.0000    0
## GBM     0.8125  0.8125 0.9375 0.9000  0.9375 1.0000    0
## RF      0.6250  0.8125 0.8125 0.8375  0.9375 1.0000    0
## NB      0.7500  0.8125 0.8750 0.8625  0.9375 0.9375    0
## LDA     0.8125  0.8125 0.9375 0.9000  0.9375 1.0000    0
# boxplots of results
bwplot(train_results)

# the above results suggest that the GBM, SVM, and LDA models perform best on the training data.
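To check whether those differences are more than resampling noise, the resampled results can be compared pairwise; a quick sketch:

# pairwise differences in Accuracy and Kappa across the folds
model_differences <- diff(train_results)
summary(model_differences)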

Evaluate Model Accuracy on Test Set

# Ideally, you would select only the model that performs best on the training data and evaluate it on the test set; all models are evaluated here just for illustration.
validation_pred_svm_lm <- predict(svm_lm_model, newdata=validation_set)
confusionMatrix(data=validation_pred_svm_lm, validation_set$Species)
## Confusion Matrix and Statistics
## 
##             Reference
## Prediction   setosa versicolor virginica
##   setosa         10          0         0
##   versicolor      0          7         1
##   virginica       0          3         9
## 
## Overall Statistics
##                                           
##                Accuracy : 0.8667          
##                  95% CI : (0.6928, 0.9624)
##     No Information Rate : 0.3333          
##     P-Value [Acc > NIR] : 2.296e-09       
##                                           
##                   Kappa : 0.8             
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: setosa Class: versicolor Class: virginica
## Sensitivity                 1.0000            0.7000           0.9000
## Specificity                 1.0000            0.9500           0.8500
## Pos Pred Value              1.0000            0.8750           0.7500
## Neg Pred Value              1.0000            0.8636           0.9444
## Prevalence                  0.3333            0.3333           0.3333
## Detection Rate              0.3333            0.2333           0.3000
## Detection Prevalence        0.3333            0.2667           0.4000
## Balanced Accuracy           1.0000            0.8250           0.8750
validation_pred_svm_rbf <- predict(svm_rbf_model, newdata=validation_set)
confusionMatrix(data=validation_pred_svm_rbf, validation_set$Species)
## Confusion Matrix and Statistics
## 
##             Reference
## Prediction   setosa versicolor virginica
##   setosa         10          0         0
##   versicolor      0          6         1
##   virginica       0          4         9
## 
## Overall Statistics
##                                           
##                Accuracy : 0.8333          
##                  95% CI : (0.6528, 0.9436)
##     No Information Rate : 0.3333          
##     P-Value [Acc > NIR] : 2.444e-08       
##                                           
##                   Kappa : 0.75            
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: setosa Class: versicolor Class: virginica
## Sensitivity                 1.0000            0.6000           0.9000
## Specificity                 1.0000            0.9500           0.8000
## Pos Pred Value              1.0000            0.8571           0.6923
## Neg Pred Value              1.0000            0.8261           0.9412
## Prevalence                  0.3333            0.3333           0.3333
## Detection Rate              0.3333            0.2000           0.3000
## Detection Prevalence        0.3333            0.2333           0.4333
## Balanced Accuracy           1.0000            0.7750           0.8500
validation_pred_rpart <- predict(rpart_model, newdata=validation_set)
#confusionMatrix(data=validation_pred_rpart, validation_set$Species)

validation_pred_gbm <- predict(gbm_model, newdata=validation_set)
confusionMatrix(data=validation_pred_gbm, validation_set$Species)
## Confusion Matrix and Statistics
## 
##             Reference
## Prediction   setosa versicolor virginica
##   setosa         10          0         0
##   versicolor      0          7         0
##   virginica       0          3        10
## 
## Overall Statistics
##                                           
##                Accuracy : 0.9             
##                  95% CI : (0.7347, 0.9789)
##     No Information Rate : 0.3333          
##     P-Value [Acc > NIR] : 1.665e-10       
##                                           
##                   Kappa : 0.85            
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: setosa Class: versicolor Class: virginica
## Sensitivity                 1.0000            0.7000           1.0000
## Specificity                 1.0000            1.0000           0.8500
## Pos Pred Value              1.0000            1.0000           0.7692
## Neg Pred Value              1.0000            0.8696           1.0000
## Prevalence                  0.3333            0.3333           0.3333
## Detection Rate              0.3333            0.2333           0.3333
## Detection Prevalence        0.3333            0.2333           0.4333
## Balanced Accuracy           1.0000            0.8500           0.9250
validation_pred_rf <- predict(rf_model, newdata=validation_set)
#confusionMatrix(data=validation_pred_rf, validation_set$Species)

validation_pred_nb <- predict(nb_model, newdata=validation_set)
#confusionMatrix(data=validation_pred_nb, validation_set$Species)

validation_pred_lda <- predict(lda_model, newdata=validation_set)
confusionMatrix(data=validation_pred_lda, validation_set$Species)
## Confusion Matrix and Statistics
## 
##             Reference
## Prediction   setosa versicolor virginica
##   setosa         10          0         0
##   versicolor      0          8         1
##   virginica       0          2         9
## 
## Overall Statistics
##                                           
##                Accuracy : 0.9             
##                  95% CI : (0.7347, 0.9789)
##     No Information Rate : 0.3333          
##     P-Value [Acc > NIR] : 1.665e-10       
##                                           
##                   Kappa : 0.85            
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: setosa Class: versicolor Class: virginica
## Sensitivity                 1.0000            0.8000           0.9000
## Specificity                 1.0000            0.9500           0.9000
## Pos Pred Value              1.0000            0.8889           0.8182
## Neg Pred Value              1.0000            0.9048           0.9474
## Prevalence                  0.3333            0.3333           0.3333
## Detection Rate              0.3333            0.2667           0.3000
## Detection Prevalence        0.3333            0.3000           0.3667
## Balanced Accuracy           1.0000            0.8750           0.9000
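Finally, a compact way to compare all models on the validation set in one table; a minimal sketch, assuming the fitted model objects above are still in the workspace (model_list is just an illustrative name):

# test-set Accuracy and Kappa for every model
model_list <- list(SVM_LM=svm_lm_model, SVM_RBF=svm_rbf_model, RPART=rpart_model, GBM=gbm_model, RF=rf_model, NB=nb_model, LDA=lda_model)
sapply(model_list, function(m) postResample(predict(m, newdata=validation_set), validation_set$Species))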