Machine learning classification example: several machine learning algorithms are used to classify iris flowers into three species based on four measured features.
Load Required Packages
# Attach the packages used below: caret (data splitting, train(), resampling
# and evaluation helpers) plus rpart, rpart.plot and rattle for fitting and
# drawing classification trees.
invisible(lapply(
  c("caret", "rpart", "rpart.plot", "rattle"),
  library,
  character.only = TRUE
))
## Warning: package 'caret' was built under R version 3.2.3
## Loading required package: lattice
## Loading required package: ggplot2
## Warning: package 'ggplot2' was built under R version 3.2.3
## Warning: package 'rattle' was built under R version 3.2.3
## Rattle: A free graphical interface for data mining with R.
## Version 4.1.0 Copyright (c) 2006-2015 Togaware Pty Ltd.
## Type 'rattle()' to shake, rattle, and roll your data.
Load the data and do some basic exploration
# Load R's built-in iris data (150 flowers: 4 numeric measurements plus the
# Species label) and inspect its first rows, its size and its class balance.
data("iris")
head(iris, n = 6L)
## Sepal.Length Sepal.Width Petal.Length Petal.Width Species
## 1 5.1 3.5 1.4 0.2 setosa
## 2 4.9 3.0 1.4 0.2 setosa
## 3 4.7 3.2 1.3 0.2 setosa
## 4 4.6 3.1 1.5 0.2 setosa
## 5 5.0 3.6 1.4 0.2 setosa
## 6 5.4 3.9 1.7 0.4 setosa
dim(iris)
## [1] 150 5
table(iris[["Species"]])
##
## setosa versicolor virginica
## 50 50 50
Split the data into training (80%) and validation (20%) sets
# Hold out 20% of the rows for validation. createDataPartition() samples
# within each level of the factor `y`, so every species keeps the same share
# in both sets. The seed makes the split reproducible.
set.seed(1)
training_indices <- createDataPartition(y = iris$Species, p = 0.8, list = FALSE)
training_set <- iris[training_indices, ]
validation_set <- iris[-training_indices, ]
dim(training_set)
## [1] 120 5
dim(validation_set)
## [1] 30 5
Exploratory data analysis
# Exploratory plots
# See how the observations group into clusters. k-means (3 centres) runs on
# the four numeric features (columns 1-4) only. The Species label (column 5)
# must NOT be part of the input: one-hot encoding it into the clustering
# matrix hands k-means the answer and makes the cluster/species table
# meaningless.
set.seed(1)
km <- kmeans(training_set[, 1:4], centers = 3)
# Sepal.Length vs Sepal.Width coloured by cluster, with the cluster centres.
plot(training_set[, 1], training_set[, 2], col = km$cluster)
points(km$centers[, c(1, 2)], col = 1:3, pch = 19, cex = 2)
# Cross-tabulate clusters against the true species (re-run to see the table;
# its output is not reproduced in this transcript).
table(km$cluster, training_set$Species)

# Pairwise scatter plots, per-class densities and per-class box plots of
# each feature. Density and box plots use free axis scales per panel.
featurePlot(
  x = training_set[, 1:4], y = training_set[, 5], plot = "pairs",
  auto.key = list(columns = 3)
)
free_scales <- list(x = list(relation = "free"), y = list(relation = "free"))
featurePlot(
  x = training_set[, 1:4], y = training_set[, 5], plot = "density",
  scales = free_scales, auto.key = list(columns = 3)
)
featurePlot(
  x = training_set[, 1:4], y = training_set[, 5], plot = "box",
  scales = free_scales, auto.key = list(columns = 3)
)

# Check for near-zero-variance covariates (features with little or no
# variability). None are flagged, so no feature is dropped.
nearZeroVar(training_set[, -5], saveMetrics = TRUE)
## freqRatio percentUnique zeroVar nzv
## Sepal.Length 1.285714 27.50000 FALSE FALSE
## Sepal.Width 1.833333 17.50000 FALSE FALSE
## Petal.Length 1.300000 34.16667 FALSE FALSE
## Petal.Width 2.363636 17.50000 FALSE FALSE
near_zero_covariates <- nearZeroVar(training_set[, -5])
head(near_zero_covariates)
## integer(0)
length(near_zero_covariates)
## [1] 0

# findCorrelation() searches the correlation matrix and returns the indices
# of columns to remove so that no pair correlates above the cutoff (0.9).
# Column 3 (Petal.Length) is flagged; it is reported only, not removed.
feature_correlation <- cor(training_set[, -5])
high_correlation <- findCorrelation(feature_correlation, cutoff = 0.9)
head(high_correlation)
## [1] 3

# PCA on centred and scaled features; the line plot shows the variance
# captured by each component. `scale.` is spelled out in full rather than
# relying on partial argument matching.
pc <- prcomp(training_set[, -5], center = TRUE, scale. = TRUE)
plot(pc, type = "l")
Model building
# 5-fold cross-validation, keeping the held-out predictions of every fold.
train_control <- trainControl(method = "cv", number = 5, savePredictions = TRUE)

# Fixed tuning parameters for naive Bayes: no Laplace correction (fL = 0) and
# no kernel density estimate (usekernel = FALSE). Newer caret versions also
# require an `adjust` column for the "nb" model, while older ones reject it.
# Keep only the columns the installed caret declares, so the grid works with
# either version.
grid <- expand.grid(fL = 0, usekernel = FALSE, adjust = 1)
grid <- grid[
  , intersect(names(grid), as.character(modelLookup("nb")$parameter)),
  drop = FALSE
]
# Model building
# Every model shares one pipeline:
#   - formula Species ~ all four features, fitted on training_set;
#   - the 5-fold CV `train_control`;
#   - centring + scaling + PCA pre-processing;
#   - set.seed(1) right before training, so each model's CV folds are drawn
#     from the same random state.
# Extra arguments (`...`) are forwarded to caret::train() and, from there, to
# the underlying model-fitting function.
fit_pca_model <- function(method, ...) {
  set.seed(1)
  train(
    Species ~ ., data = training_set, trControl = train_control,
    method = method, preProcess = c("center", "scale", "pca"), ...
  )
}

# SVM with a linear kernel
svm_lm_model <- fit_pca_model("svmLinear")
## Loading required package: kernlab
## Warning: package 'kernlab' was built under R version 3.2.3
## Warning in .recacheSubclasses(def@className, def, doSubclasses, env):
## undefined subclass "externalRefMethod" of class "kfunction"; definition not
## updated
##
## Attaching package: 'kernlab'
## The following object is masked from 'package:ggplot2':
##
## alpha
svm_lm_model
## Support Vector Machines with Linear Kernel
##
## 120 samples
## 4 predictor
## 3 classes: 'setosa', 'versicolor', 'virginica'
##
## Pre-processing: centered (4), scaled (4), principal component
## signal extraction (4)
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 96, 96, 96, 96, 96
## Resampling results
##
## Accuracy Kappa Accuracy SD Kappa SD
## 0.925 0.8875 0.06846532 0.102698
##
## Tuning parameter 'C' was held constant at a value of 1
##
varImp(svm_lm_model)
## ROC curve variable importance
##
## variables are sorted by maximum importance across the classes
## setosa versicolor virginica
## Petal.Width 100.00 100.00 100.00
## Petal.Length 100.00 100.00 100.00
## Sepal.Length 91.48 71.54 91.48
## Sepal.Width 52.73 52.73 0.00
plot(varImp(svm_lm_model))

# SVM with a radial basis function (RBF) kernel
svm_rbf_model <- fit_pca_model("svmRadial")

# Classification tree (rpart). To fit the tree on the raw features instead
# of principal components, call train() without `preProcess`.
rpart_model <- fit_pca_model("rpart")
# Plot the fitted classification tree
fancyRpartPlot(rpart_model$finalModel)

# Random forest; `proximity = TRUE` is forwarded to randomForest() and is
# spelled out in full rather than relying on partial argument matching.
# NOTE(review): the "invalid mtry" warnings below presumably arise because
# the default mtry grid is built for the 4 original predictors, while PCA
# leaves fewer columns. Consider a tuneGrid with smaller mtry values.
rf_model <- fit_pca_model("rf", proximity = TRUE)
## Loading required package: randomForest
## randomForest 4.6-12
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:ggplot2':
##
## margin
## Warning in randomForest.default(x, y, mtry = param$mtry, ...): invalid
## mtry: reset to within valid range
## Warning in randomForest.default(x, y, mtry = param$mtry, ...): invalid
## mtry: reset to within valid range
## Warning in randomForest.default(x, y, mtry = param$mtry, ...): invalid
## mtry: reset to within valid range
## Warning in randomForest.default(x, y, mtry = param$mtry, ...): invalid
## mtry: reset to within valid range
## Warning in randomForest.default(x, y, mtry = param$mtry, ...): invalid
## mtry: reset to within valid range
## Warning in randomForest.default(x, y, mtry = param$mtry, ...): invalid
## mtry: reset to within valid range
## Warning in randomForest.default(x, y, mtry = param$mtry, ...): invalid
## mtry: reset to within valid range
## Warning in randomForest.default(x, y, mtry = param$mtry, ...): invalid
## mtry: reset to within valid range
## Warning in randomForest.default(x, y, mtry = param$mtry, ...): invalid
## mtry: reset to within valid range
## Warning in randomForest.default(x, y, mtry = param$mtry, ...): invalid
## mtry: reset to within valid range

# Boosting with trees (gbm); `verbose = FALSE` is forwarded to keep its
# training output quiet.
gbm_model <- fit_pca_model("gbm", verbose = FALSE)
## Loading required package: gbm
## Loading required package: survival
##
## Attaching package: 'survival'
## The following object is masked from 'package:caret':
##
## cluster
## Loading required package: splines
## Loading required package: parallel
## Loaded gbm 2.1.1
## Loading required package: plyr

# Naive Bayes, using the fixed tuning grid `grid`
nb_model <- fit_pca_model("nb", tuneGrid = grid)
## Loading required package: klaR
## Loading required package: MASS
## Warning: package 'MASS' was built under R version 3.2.2

# Linear discriminant analysis
lda_model <- fit_pca_model("lda")
# Gather the cross-validation resamples of all seven models so that their
# Accuracy and Kappa distributions can be compared side by side.
train_results <- resamples(list(
  SVM_LM = svm_lm_model,
  SVM_RBF = svm_rbf_model,
  RPART = rpart_model,
  GBM = gbm_model,
  RF = rf_model,
  NB = nb_model,
  LDA = lda_model
))
# Per-model summary statistics of the resampled metrics
summary(train_results)
##
## Call:
## summary.resamples(object = train_results)
##
## Models: SVM_LM, SVM_RBF, RPART, GBM, RF, NB, LDA
## Number of resamples: 5
##
## Accuracy
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## SVM_LM 0.8333 0.8750 0.9583 0.9250 0.9583 1.0000 0
## SVM_RBF 0.8750 0.9167 0.9583 0.9333 0.9583 0.9583 0
## RPART 0.8333 0.8750 0.9167 0.9167 0.9583 1.0000 0
## GBM 0.8750 0.8750 0.9583 0.9333 0.9583 1.0000 0
## RF 0.7500 0.8750 0.8750 0.8917 0.9583 1.0000 0
## NB 0.8333 0.8750 0.9167 0.9083 0.9583 0.9583 0
## LDA 0.8750 0.8750 0.9583 0.9333 0.9583 1.0000 0
##
## Kappa
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## SVM_LM 0.7500 0.8125 0.9375 0.8875 0.9375 1.0000 0
## SVM_RBF 0.8125 0.8750 0.9375 0.9000 0.9375 0.9375 0
## RPART 0.7500 0.8125 0.8750 0.8750 0.9375 1.0000 0
## GBM 0.8125 0.8125 0.9375 0.9000 0.9375 1.0000 0
## RF 0.6250 0.8125 0.8125 0.8375 0.9375 1.0000 0
## NB 0.7500 0.8125 0.8750 0.8625 0.9375 0.9375 0
## LDA 0.8125 0.8125 0.9375 0.9000 0.9375 1.0000 0
# Box-and-whisker plots of the same distributions
bwplot(train_results)
# SVM_RBF, GBM and LDA share the highest mean cross-validated accuracy
# (0.9333), with SVM_LM close behind (0.9250). These are resampling
# estimates on held-out folds, not accuracy on the training data itself.
Evaluate model accuracy on the held-out validation set
# Ideally, only the model that performs best in cross-validation would be
# scored on the held-out validation set. For illustration, predictions are
# made here for every model, and confusion matrices are printed for four of
# them.

# SVM, linear kernel
validation_pred_svm_lm <- predict(svm_lm_model, validation_set)
confusionMatrix(validation_pred_svm_lm, reference = validation_set[["Species"]])
## Confusion Matrix and Statistics
##
## Reference
## Prediction setosa versicolor virginica
## setosa 10 0 0
## versicolor 0 7 1
## virginica 0 3 9
##
## Overall Statistics
##
## Accuracy : 0.8667
## 95% CI : (0.6928, 0.9624)
## No Information Rate : 0.3333
## P-Value [Acc > NIR] : 2.296e-09
##
## Kappa : 0.8
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: setosa Class: versicolor Class: virginica
## Sensitivity 1.0000 0.7000 0.9000
## Specificity 1.0000 0.9500 0.8500
## Pos Pred Value 1.0000 0.8750 0.7500
## Neg Pred Value 1.0000 0.8636 0.9444
## Prevalence 0.3333 0.3333 0.3333
## Detection Rate 0.3333 0.2333 0.3000
## Detection Prevalence 0.3333 0.2667 0.4000
## Balanced Accuracy 1.0000 0.8250 0.8750

# SVM, RBF kernel
validation_pred_svm_rbf <- predict(svm_rbf_model, validation_set)
confusionMatrix(validation_pred_svm_rbf, reference = validation_set[["Species"]])
## Confusion Matrix and Statistics
##
## Reference
## Prediction setosa versicolor virginica
## setosa 10 0 0
## versicolor 0 6 1
## virginica 0 4 9
##
## Overall Statistics
##
## Accuracy : 0.8333
## 95% CI : (0.6528, 0.9436)
## No Information Rate : 0.3333
## P-Value [Acc > NIR] : 2.444e-08
##
## Kappa : 0.75
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: setosa Class: versicolor Class: virginica
## Sensitivity 1.0000 0.6000 0.9000
## Specificity 1.0000 0.9500 0.8000
## Pos Pred Value 1.0000 0.8571 0.6923
## Neg Pred Value 1.0000 0.8261 0.9412
## Prevalence 0.3333 0.3333 0.3333
## Detection Rate 0.3333 0.2000 0.3000
## Detection Prevalence 0.3333 0.2333 0.4333
## Balanced Accuracy 1.0000 0.7750 0.8500

# Classification tree (confusion matrix not printed)
validation_pred_rpart <- predict(rpart_model, validation_set)
# confusionMatrix(validation_pred_rpart, reference = validation_set[["Species"]])

# Boosted trees
validation_pred_gbm <- predict(gbm_model, validation_set)
confusionMatrix(validation_pred_gbm, reference = validation_set[["Species"]])
## Confusion Matrix and Statistics
##
## Reference
## Prediction setosa versicolor virginica
## setosa 10 0 0
## versicolor 0 7 0
## virginica 0 3 10
##
## Overall Statistics
##
## Accuracy : 0.9
## 95% CI : (0.7347, 0.9789)
## No Information Rate : 0.3333
## P-Value [Acc > NIR] : 1.665e-10
##
## Kappa : 0.85
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: setosa Class: versicolor Class: virginica
## Sensitivity 1.0000 0.7000 1.0000
## Specificity 1.0000 1.0000 0.8500
## Pos Pred Value 1.0000 1.0000 0.7692
## Neg Pred Value 1.0000 0.8696 1.0000
## Prevalence 0.3333 0.3333 0.3333
## Detection Rate 0.3333 0.2333 0.3333
## Detection Prevalence 0.3333 0.2333 0.4333
## Balanced Accuracy 1.0000 0.8500 0.9250

# Random forest and naive Bayes (confusion matrices not printed)
validation_pred_rf <- predict(rf_model, validation_set)
# confusionMatrix(validation_pred_rf, reference = validation_set[["Species"]])
validation_pred_nb <- predict(nb_model, validation_set)
# confusionMatrix(validation_pred_nb, reference = validation_set[["Species"]])

# Linear discriminant analysis
validation_pred_lda <- predict(lda_model, validation_set)
confusionMatrix(validation_pred_lda, reference = validation_set[["Species"]])
## Confusion Matrix and Statistics
##
## Reference
## Prediction setosa versicolor virginica
## setosa 10 0 0
## versicolor 0 8 1
## virginica 0 2 9
##
## Overall Statistics
##
## Accuracy : 0.9
## 95% CI : (0.7347, 0.9789)
## No Information Rate : 0.3333
## P-Value [Acc > NIR] : 1.665e-10
##
## Kappa : 0.85
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: setosa Class: versicolor Class: virginica
## Sensitivity 1.0000 0.8000 0.9000
## Specificity 1.0000 0.9500 0.9000
## Pos Pred Value 1.0000 0.8889 0.8182
## Neg Pred Value 1.0000 0.9048 0.9474
## Prevalence 0.3333 0.3333 0.3333
## Detection Rate 0.3333 0.2667 0.3000
## Detection Prevalence 0.3333 0.3000 0.3667
## Balanced Accuracy 1.0000 0.8750 0.9000