library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
library(ggplot2)
library(dplyr)
##
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
data("iris")
dataset <- iris
# create a list of 80% of the rows in the original dataset we can use for training
validationIndex <- createDataPartition(dataset$Species, p=0.80, list=FALSE)
# select 20% of the data for validation
validation <- dataset[-validationIndex,]
# use the remaining 80% of the data for training and testing the models
dataset <- dataset[validationIndex,]
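Because createDataPartition samples within each class, both splits preserve the balanced class mix of iris. A quick, optional sanity check (not part of the original workflow):
# optional: the split is stratified, so the validation set should hold
# exactly 10 rows per species (20% of 50)
table(validation$Species)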
dim(dataset)
## [1] 120 5
sapply(dataset, class)
## Sepal.Length Sepal.Width Petal.Length Petal.Width Species
## "numeric" "numeric" "numeric" "numeric" "factor"
# peek at the data
head(dataset)
## Sepal.Length Sepal.Width Petal.Length Petal.Width Species
## 1 5.1 3.5 1.4 0.2 setosa
## 2 4.9 3.0 1.4 0.2 setosa
## 3 4.7 3.2 1.3 0.2 setosa
## 4 4.6 3.1 1.5 0.2 setosa
## 5 5.0 3.6 1.4 0.2 setosa
## 6 5.4 3.9 1.7 0.4 setosa
# list the class levels
levels(dataset$Species)
## [1] "setosa" "versicolor" "virginica"
# summarize the class distribution
percentage <- prop.table(table(dataset$Species)) * 100
cbind(freq=table(dataset$Species), percentage=percentage)
## freq percentage
## setosa 40 33.33333
## versicolor 40 33.33333
## virginica 40 33.33333
# summarize attribute distributions
summary(dataset)
## Sepal.Length Sepal.Width Petal.Length Petal.Width
## Min. :4.300 Min. :2.200 Min. :1.000 Min. :0.100
## 1st Qu.:5.100 1st Qu.:2.800 1st Qu.:1.500 1st Qu.:0.300
## Median :5.800 Median :3.000 Median :4.400 Median :1.300
## Mean :5.848 Mean :3.067 Mean :3.768 Mean :1.211
## 3rd Qu.:6.400 3rd Qu.:3.325 3rd Qu.:5.100 3rd Qu.:1.825
## Max. :7.700 Max. :4.200 Max. :6.900 Max. :2.500
## Species
## setosa :40
## versicolor:40
## virginica :40
##
##
##
# univariate plots
# split input and output
x <- dataset[,1:4]
y <- dataset[,5]
# boxplot for each attribute on one image
par(mfrow=c(1,4))
for(i in 1:4) {
  boxplot(x[,i], main=names(x)[i])
}
# barplot for class breakdown
plot(y)
# multivariate plot
# scatterplot matrix
featurePlot(x=x, y=y, plot="ellipse")
We can see some clear relationships between the input attributes (trends) and between the attributes and the class values.
# box and whisker plots for each attribute
featurePlot(x=x, y=y, plot="box")
This is useful: it shows that there are clearly different distributions of the attributes for each class value. Next we can get an idea of the distribution of each attribute, again broken down by class value as with the box-and-whisker plots. Sometimes histograms are good for this, but here we will use probability density plots to give nice smooth lines for each distribution (a histogram sketch follows below for comparison).
# density plots for each attribute by class value
scales <- list(x=list(relation="free"), y=list(relation="free"))
featurePlot(x=x, y=y, plot="density", scales=scales)
Like the boxplots, we can see the difference in distribution of each attribute by class value. We can also see the Gaussian-like distribution (bell curve) of each attribute.
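As a side note, the histogram view mentioned above is easy to reproduce with ggplot2 (already loaded); a minimal sketch, assuming the tidyr package is installed for reshaping:
# per-class histograms of each attribute; tidyr::pivot_longer reshapes the
# four measurement columns into long format for faceting
library(tidyr)
dataset %>%
  pivot_longer(-Species, names_to="attribute", values_to="value") %>%
  ggplot(aes(x=value, fill=Species)) +
  geom_histogram(bins=15, alpha=0.6, position="identity") +
  facet_wrap(~attribute, scales="free")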
Now it is time to create some models of the data and estimate their accuracy on unseen data. Here is what we are going to cover in this step:
1. Set up the test harness to use 10-fold cross-validation.
2. Build five different models to predict species from the flower measurements.
3. Select the best model.
# run algorithms using 10-fold cross-validation
trainControl <- trainControl(method="cv", number=10)
metric <- "Accuracy"
We are using the metric of Accuracy to evaluate models. This is the number of correctly predicted instances divided by the total number of instances in the dataset, multiplied by 100 to give a percentage (e.g. 95% accurate). We will use the metric variable when we build and evaluate each model next.
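For intuition only, here is a minimal sketch with two hypothetical vectors, predicted and actual, showing the quantity caret reports as Accuracy:
# illustrative only: accuracy is the fraction of matching predictions
predicted <- c("setosa", "setosa", "virginica")
actual <- c("setosa", "virginica", "virginica")
mean(predicted == actual) * 100 # 66.67, i.e. 2 of 3 correct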
We don't know which algorithms would be good on this problem or what configurations to use. We do get an idea from the plots that some of the classes are partially linearly separable in some dimensions, so we expect generally good results. Let's evaluate five different algorithms:
- Linear Discriminant Analysis (LDA)
- Classification and Regression Trees (CART)
- k-Nearest Neighbors (KNN)
- Support Vector Machines (SVM) with a radial kernel
- Random Forest (RF)
This is a good mixture of simple linear (LDA), non-linear (CART, KNN) and complex non-linear methods (SVM, RF). We reset the random number seed before each run to ensure that the evaluation of each algorithm is performed using exactly the same data splits; this makes the results directly comparable.
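To see why the seed matters, here is an illustrative check (train() generates its folds internally, but the same principle applies): with the same seed, caret produces identical fold assignments.
# illustrative: identical seeds yield identical CV folds, so every model
# is scored on exactly the same partitions
set.seed(7); foldsA <- createFolds(dataset$Species, k=10)
set.seed(7); foldsB <- createFolds(dataset$Species, k=10)
identical(foldsA, foldsB) # TRUE
Let's build our five models: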
# LDA
set.seed(7)
fit.lda <- train(Species~., data=dataset, method="lda", metric=metric,
  trControl=trainControl)
# CART
set.seed(7)
fit.cart <- train(Species~., data=dataset, method="rpart", metric=metric,
  trControl=trainControl)
# KNN
set.seed(7)
fit.knn <- train(Species~., data=dataset, method="knn", metric=metric,
  trControl=trainControl)
# SVM
set.seed(7)
fit.svm <- train(Species~., data=dataset, method="svmRadial", metric=metric,
  trControl=trainControl)
# Random Forest
set.seed(7)
fit.rf <- train(Species~., data=dataset, method="rf", metric=metric, trControl=trainControl)
# summarize accuracy of models
results <- resamples(list(lda=fit.lda, cart=fit.cart, knn=fit.knn, svm=fit.svm, rf=fit.rf))
summary(results)
##
## Call:
## summary.resamples(object = results)
##
## Models: lda, cart, knn, svm, rf
## Number of resamples: 10
##
## Accuracy
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## lda 0.8333333 1 1 0.9750000 1 1 0
## cart 0.9166667 1 1 0.9833333 1 1 0
## knn 0.9166667 1 1 0.9833333 1 1 0
## svm 0.7500000 1 1 0.9666667 1 1 0
## rf 0.9166667 1 1 0.9833333 1 1 0
##
## Kappa
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## lda 0.750 1 1 0.9625 1 1 0
## cart 0.875 1 1 0.9750 1 1 0
## knn 0.875 1 1 0.9750 1 1 0
## svm 0.625 1 1 0.9500 1 1 0
## rf 0.875 1 1 0.9750 1 1 0
We can also create a plot of the model evaluation results and compare the spread and the mean accuracy of each model. There is a population of accuracy measures for each algorithm because each algorithm was evaluated ten times (10-fold cross-validation).
# compare accuracy of models
dotplot(results)
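caret also supports a box-and-whisker view of the same resamples object, an alternative way to inspect the spread of each model's scores:
# alternative lattice view of the same resample distributions
bwplot(results)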
# summarize the best model
print(fit.lda)
## Linear Discriminant Analysis
##
## 120 samples
## 4 predictor
## 3 classes: 'setosa', 'versicolor', 'virginica'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 108, 108, 108, 108, 108, 108, ...
## Resampling results:
##
## Accuracy Kappa
## 0.975 0.9625
# estimate skill of LDA on the validation dataset
predictions <- predict(fit.lda, validation)
confusionMatrix(predictions, validation$Species)
## Confusion Matrix and Statistics
##
## Reference
## Prediction setosa versicolor virginica
## setosa 10 0 0
## versicolor 0 9 1
## virginica 0 1 9
##
## Overall Statistics
##
## Accuracy : 0.9333
## 95% CI : (0.7793, 0.9918)
## No Information Rate : 0.3333
## P-Value [Acc > NIR] : 8.747e-12
##
## Kappa : 0.9
##
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: setosa Class: versicolor Class: virginica
## Sensitivity 1.0000 0.9000 0.9000
## Specificity 1.0000 0.9500 0.9500
## Pos Pred Value 1.0000 0.9000 0.9000
## Neg Pred Value 1.0000 0.9500 0.9500
## Prevalence 0.3333 0.3333 0.3333
## Detection Rate 0.3333 0.3000 0.3000
## Detection Prevalence 0.3333 0.3333 0.3333
## Balanced Accuracy 1.0000 0.9250 0.9250
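As a quick cross-check, the overall accuracy can be recomputed directly from the confusion-matrix counts:
# sanity check: overall accuracy = correct predictions / total predictions
cm <- table(Prediction=predictions, Reference=validation$Species)
sum(diag(cm)) / sum(cm) # (10 + 9 + 9) / 30 = 0.9333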
The accuracy on the validation set is 93.3%. This is within the expected margin of the cross-validation estimate (97.5%), suggesting we may have an accurate and reliable model for this problem.