Predicting Iris Flower Species

This project predicts the species of flowers in the Iris dataset. During this exercise we will perform the following steps:
  1. Load and Check Data
  2. Split the Data
  3. Summarize Data
  4. Visualize Data
  5. Build Models
  6. Make Predictions

1. Load and Check Data

# Load and check the training data
irisdata <- read.csv('iris.csv', header=FALSE)
colnames(irisdata) <- c("Sepal.Length","Sepal.Width","Petal.Length","Petal.Width","Species")
str(irisdata)
## 'data.frame':    150 obs. of  5 variables:
##  $ Sepal.Length: num  5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
##  $ Sepal.Width : num  3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...
##  $ Petal.Length: num  1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
##  $ Petal.Width : num  0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
##  $ Species     : Factor w/ 3 levels "Iris-setosa",..: 1 1 1 1 1 1 1 1 1 1 ...
As we can see, the dataset has 150 observations and 5 variables. Looking at the structure of the dataset gives us a sense of the data, such as:
  1. Dimensions of the dataset.
  2. Types of the attributes.
  3. Levels of the class attribute.
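
Before moving on, a couple of quick sanity checks can confirm that the data loaded as expected; the lines below are a small suggested addition (their output is not shown in this report).

# Suggested sanity checks on the loaded data
head(irisdata)          # first few rows, to confirm values landed in the right columns
sum(is.na(irisdata))    # count of missing values; we expect 0 for this dataset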

2. Split the Data

We will split the loaded dataset in two: 80% will be used to train our models and 20% will be used to test them.

library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
# Split the data 80/20, sampling within each Species class
sample <- createDataPartition(irisdata$Species, p=0.80, list=FALSE)

# Create training data
iris_train <- irisdata[sample,]
str(iris_train)
## 'data.frame':    120 obs. of  5 variables:
##  $ Sepal.Length: num  4.9 4.7 4.6 5 5.4 4.6 5 4.4 5.4 4.8 ...
##  $ Sepal.Width : num  3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.7 3.4 ...
##  $ Petal.Length: num  1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 1.6 ...
##  $ Petal.Width : num  0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.2 0.2 ...
##  $ Species     : Factor w/ 3 levels "Iris-setosa",..: 1 1 1 1 1 1 1 1 1 1 ...
# Create test data
iris_test <- irisdata[-sample,]
str(iris_test)
## 'data.frame':    30 obs. of  5 variables:
##  $ Sepal.Length: num  5.1 4.9 4.8 5.1 5.1 5.2 5.5 4.4 5 5.1 ...
##  $ Sepal.Width : num  3.5 3.1 3 3.5 3.8 4.1 4.2 3 3.5 3.8 ...
##  $ Petal.Length: num  1.4 1.5 1.4 1.4 1.5 1.5 1.4 1.3 1.3 1.6 ...
##  $ Petal.Width : num  0.2 0.1 0.1 0.3 0.3 0.1 0.2 0.2 0.3 0.2 ...
##  $ Species     : Factor w/ 3 levels "Iris-setosa",..: 1 1 1 1 1 1 1 1 1 1 ...

The training dataset has 120 observations and the test dataset has 30.
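
Because createDataPartition samples within each level of Species, the split should preserve the class balance. A quick check (a suggested addition, output not shown) is:

# Check that the class distribution carried over to the test split
table(iris_test$Species)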

3. Summarize Data

3.1 Instances by Class

Let’s look at the number of instances (rows) that belong to each class. We can view this as an absolute count and as a percentage.

# Summarize the class distribution
percentage <- prop.table(table(iris_train$Species))*100
cbind(Freq = table(iris_train$Species), Percentage = percentage)
##                 Freq Percentage
## Iris-setosa       40   33.33333
## Iris-versicolor   40   33.33333
## Iris-virginica    40   33.33333

Observation- We can see that each class has the same number of instances (40, or about 33% of the training data).

3.2 Summarize the Dataset

summary(iris_train)
##   Sepal.Length    Sepal.Width     Petal.Length    Petal.Width   
##  Min.   :4.300   Min.   :2.000   Min.   :1.000   Min.   :0.100  
##  1st Qu.:5.100   1st Qu.:2.800   1st Qu.:1.600   1st Qu.:0.300  
##  Median :5.800   Median :3.000   Median :4.400   Median :1.300  
##  Mean   :5.852   Mean   :3.034   Mean   :3.773   Mean   :1.198  
##  3rd Qu.:6.400   3rd Qu.:3.300   3rd Qu.:5.100   3rd Qu.:1.800  
##  Max.   :7.900   Max.   :4.400   Max.   :6.900   Max.   :2.500  
##             Species  
##  Iris-setosa    :40  
##  Iris-versicolor:40  
##  Iris-virginica :40  
##                      
##                      
## 

Observation- We can see that all of the numeric attributes are measured in the same units (centimeters) and have similar ranges, roughly [0, 8] cm.
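
To confirm those ranges numerically, we could also compute the minimum and maximum of each attribute (a suggested addition, output not shown):

# Per-attribute minimum and maximum on the training data
sapply(iris_train[, 1:4], range)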

4. Visualize Data

We are going to look at two types of plots:
  1. Univariate plots to better understand each attribute.
  2. Multivariate plots to better understand the relationships between attributes.

4.1 Univariate Plots

Let’s separate the input attributes (x) from the output attribute, or class (y).

x <- iris_train[,1:4]
y <- iris_train[,5]
4.1.1 - Box Plot

Since the input variables are numeric, we can create a box-and-whisker plot of each.

par(mfrow=c(1,4))
for (i in 1:4) {
  boxplot(x[i], main=names(iris_train)[i])
}

4.1.2 - Bar Plot

Let’s create a bar plot of the class variable.

library(ggplot2)
qplot(y, xlab='Species')

4.2 Multivariate Plots

4.2.1 - Scatter Plots

Let’s look at scatter plots of all pairs of attributes and color the points by class.

library(caret)
featurePlot(x=x, y=y, plot='ellipse', auto.key=list(columns=3))
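
A complementary numeric view of the same pairwise relationships is the correlation matrix of the input attributes (a suggested addition, output not shown):

# Pairwise correlations between the four numeric attributes
round(cor(x), 2)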

4.2.2 - Box Plots

Let’s look at the box plots of each attribute, broken down into separate plots for each class.

featurePlot(x=x, y=y, plot='box', auto.key=list(columns=3))

4.2.3 - Density Plots

Let’s look at the distribution of each attribute, broken down into separate plots for each class, using density plots.

featurePlot(x=x, y=y, 
            plot='density', 
            scales = list(x = list(relation='free'),
                          y = list(relation='free')),
            auto.key=list(columns=3))

Observation- As with the box plots, we can see the difference in the distribution of each attribute by class value. We can also see the roughly bell-shaped distribution of each attribute.

5. Build Models

Let’s evaluate some algorithms and estimate their accuracy on unseen data. We will perform the following in this step:
  1. Build 5 different models to predict species:
  • Linear Discriminant Analysis (LDA)
  • Classification and Regression Trees (CART)
  • k-Nearest Neighbors (KNN)
  • Support Vector Machines (SVM) with a radial kernel
  • Random Forest (RF)
  2. Select the best model

5.1 Evaluate Algorithms and Pick the Best Model

We will use 10-fold cross-validation to estimate accuracy. This splits the training data into 10 parts; we train on 9 parts and test on the remaining 1, repeating so that each part is used once as the test set.
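
To illustrate how these folds are formed, we could partition the training data ourselves with caret’s createFolds (a sketch for illustration only; it is not part of the modelling pipeline below):

# Illustration: how 10-fold CV partitions the 120 training rows
folds <- createFolds(iris_train$Species, k=10)
sapply(folds, length)   # each fold should hold roughly 12 rows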

We will use a mixture of simple linear (LDA), nonlinear (CART, KNN) and complex nonlinear methods (SVM, RF).

library(caret)
control <- trainControl(method='cv', number=10)
metric <- 'Accuracy'

# Linear Discriminant Analysis (LDA)
set.seed(101)
fit.lda <- train(Species~., data=iris_train, method='lda', 
                  trControl=control, metric=metric)
## Loading required package: MASS
# Classification and Regression Trees (CART)
set.seed(101)
fit.cart <- train(Species~., data=iris_train, method='rpart', 
                  trControl=control, metric=metric)
## Loading required package: rpart
# k-Nearest Neighbors (KNN)
set.seed(101)
fit.knn <- train(Species~., data=iris_train, method='knn', 
                  trControl=control, metric=metric)

# Support Vector Machines (SVM) with a radial kernel
set.seed(101)
fit.svm <- train(Species~., data=iris_train, method='svmRadial', 
                  trControl=control, metric=metric)
## Loading required package: kernlab
## 
## Attaching package: 'kernlab'
## The following object is masked from 'package:ggplot2':
## 
##     alpha
# Random Forest (RF)
set.seed(101)
fit.rf <- train(Species~., data=iris_train, method='ranger', 
                  trControl=control, metric=metric)
## Loading required package: e1071
## Loading required package: ranger
# Compare the results of these algorithms
iris.results <- resamples(list(lda=fit.lda, cart=fit.cart, knn=fit.knn, svm=fit.svm, rf=fit.rf))

# Table Comparison
summary(iris.results)
## 
## Call:
## summary.resamples(object = iris.results)
## 
## Models: lda, cart, knn, svm, rf 
## Number of resamples: 10 
## 
## Accuracy 
##        Min. 1st Qu. Median   Mean 3rd Qu. Max. NA's
## lda  0.9167  0.9375 1.0000 0.9750  1.0000    1    0
## cart 0.9167  0.9167 0.9167 0.9250  0.9167    1    0
## knn  0.9167  0.9167 1.0000 0.9667  1.0000    1    0
## svm  0.9167  0.9167 0.9167 0.9500  1.0000    1    0
## rf   0.9167  0.9167 0.9167 0.9500  1.0000    1    0
## 
## Kappa 
##       Min. 1st Qu. Median   Mean 3rd Qu. Max. NA's
## lda  0.875  0.9062  1.000 0.9625   1.000    1    0
## cart 0.875  0.8750  0.875 0.8875   0.875    1    0
## knn  0.875  0.8750  1.000 0.9500   1.000    1    0
## svm  0.875  0.8750  0.875 0.9250   1.000    1    0
## rf   0.875  0.8750  0.875 0.9250   1.000    1    0
# Let's plot the results of these algorithms:
bwplot(iris.results)

dotplot(iris.results)

Observation- Looking at the results and plots, we can see that the mean accuracy of the LDA model is higher than that of the other models. Let’s use the LDA model to make the final predictions.
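
If we want those mean accuracies as a single vector rather than reading them off the summary table, one option (a suggested convenience step, assuming the usual structure of the resamples summary object) is:

# Mean cross-validated accuracy per model
summary(iris.results)$statistics$Accuracy[, "Mean"]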

6. Make Predictions

We will use the LDA model to make final predictions on the test data.

6.1 Predict on Test Data

iris_prediction <- predict(fit.lda, iris_test)
confusionMatrix(iris_prediction, iris_test$Species)
## Confusion Matrix and Statistics
## 
##                  Reference
## Prediction        Iris-setosa Iris-versicolor Iris-virginica
##   Iris-setosa              10               0              0
##   Iris-versicolor           0              10              0
##   Iris-virginica            0               0             10
## 
## Overall Statistics
##                                      
##                Accuracy : 1          
##                  95% CI : (0.8843, 1)
##     No Information Rate : 0.3333     
##     P-Value [Acc > NIR] : 4.857e-15  
##                                      
##                   Kappa : 1          
##  Mcnemar's Test P-Value : NA         
## 
## Statistics by Class:
## 
##                      Class: Iris-setosa Class: Iris-versicolor
## Sensitivity                      1.0000                 1.0000
## Specificity                      1.0000                 1.0000
## Pos Pred Value                   1.0000                 1.0000
## Neg Pred Value                   1.0000                 1.0000
## Prevalence                       0.3333                 0.3333
## Detection Rate                   0.3333                 0.3333
## Detection Prevalence             0.3333                 0.3333
## Balanced Accuracy                1.0000                 1.0000
##                      Class: Iris-virginica
## Sensitivity                         1.0000
## Specificity                         1.0000
## Pos Pred Value                      1.0000
## Neg Pred Value                      1.0000
## Prevalence                          0.3333
## Detection Rate                      0.3333
## Detection Prevalence                0.3333
## Balanced Accuracy                   1.0000

We can see that the accuracy on the test set is 100%. The test set is small (30 observations), but this result falls within the expected margin of 97% +/- 4% from cross-validation, suggesting we may have an accurate and reliable model.
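
For reference, the same test-set accuracy can be computed directly from the predictions (a one-line alternative to reading it off the confusion matrix):

# Test-set accuracy computed directly
mean(iris_prediction == iris_test$Species)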