library(caret)
## Warning: package 'caret' was built under R version 3.5.2
## Loading required package: lattice
## Loading required package: ggplot2
library(dplyr)
## Warning: package 'dplyr' was built under R version 3.5.2
##
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
# Work on a copy of the built-in iris data with short column names
# (four numeric measurements followed by the species label).
data1 <- setNames(iris, c("s.len", "s.wid", "p.len", "p.wid", "species"))
# Preview the first four rows.
head(data1, n = 4)
## s.len s.wid p.len p.wid species
## 1 5.1 3.5 1.4 0.2 setosa
## 2 4.9 3.0 1.4 0.2 setosa
## 3 4.7 3.2 1.3 0.2 setosa
## 4 4.6 3.1 1.5 0.2 setosa
# Fix the RNG seed so the random split below is reproducible.
set.seed(12345)
# Row indices for an 80% partition, drawn with respect to the species label.
# NOTE: despite its name, v.index selects the TRAINING rows.
v.index <- createDataPartition(y = data1[["species"]], p = 0.8, list = FALSE)
# Selected rows become the training set; the remaining rows are held out.
tdata <- data1[v.index, ]
validation <- data1[-v.index, ]
# Check the size of each partition.
dim(tdata)
## [1] 120 5
dim(validation)
## [1] 30 5
# First rows of the training and validation partitions.
head(tdata, n = 6)
## s.len s.wid p.len p.wid species
## 1 5.1 3.5 1.4 0.2 setosa
## 2 4.9 3.0 1.4 0.2 setosa
## 4 4.6 3.1 1.5 0.2 setosa
## 5 5.0 3.6 1.4 0.2 setosa
## 6 5.4 3.9 1.7 0.4 setosa
## 8 5.0 3.4 1.5 0.2 setosa
head(validation, n = 6)
## s.len s.wid p.len p.wid species
## 3 4.7 3.2 1.3 0.2 setosa
## 7 4.6 3.4 1.4 0.3 setosa
## 9 4.4 2.9 1.4 0.2 setosa
## 18 5.1 3.5 1.4 0.3 setosa
## 19 5.7 3.8 1.7 0.3 setosa
## 25 4.8 3.4 1.9 0.2 setosa
# Column classes: four numeric predictors and one factor outcome in each set.
vapply(tdata, class, character(1))
## s.len s.wid p.len p.wid species
## "numeric" "numeric" "numeric" "numeric" "factor"
vapply(validation, class, character(1))
## s.len s.wid p.len p.wid species
## "numeric" "numeric" "numeric" "numeric" "factor"
# Both partitions carry the same three species levels.
levels(tdata[["species"]])
## [1] "setosa" "versicolor" "virginica"
levels(validation[["species"]])
## [1] "setosa" "versicolor" "virginica"
# Class counts in the training set, plus each class's share of the rows
# (the `percent` column is a proportion between 0 and 1).
freq <- data.frame(table(tdata[["species"]]))
mutate(freq, percent = Freq / nrow(tdata))
## Var1 Freq percent
## 1 setosa 40 0.3333333
## 2 versicolor 40 0.3333333
## 3 virginica 40 0.3333333
# Exploratory plots of the training data.
x <- tdata[, 1:4]  # the four numeric predictor columns
y <- tdata[, 5]    # the class label (species)

# One boxplot per predictor, side by side. boxplot() is called only for its
# drawing side effect, so a plain loop is used instead of sapply(): this avoids
# printing the matrix of boxplot statistics and lets every panel be titled with
# its column name. The previous graphics settings are restored afterwards so
# the 1 x 4 layout does not leak into later base-graphics plots.
old_par <- par(mfrow = c(1, 4))
for (predictor in names(x)) {
  boxplot(x[[predictor]], main = predictor)
}
par(old_par)

# Scatterplot matrix of every pair of predictors.
plot(x)

# caret feature plots of the predictors grouped by species.
featurePlot(x = x, y = y, plot = "ellipse")
featurePlot(x = x, y = y, plot = "box")
Run the algorithms using 10-fold cross-validation.
# Performance measure used to tune and compare the models.
metric <- "Accuracy"
# Resampling scheme shared by every model: 10-fold cross-validation.
control <- trainControl(number = 10, method = "cv")
Build models with five different algorithms: (1) Linear Discriminant Analysis (LDA), (2) Classification and Regression Trees (CART), (3) k-Nearest Neighbors (kNN), (4) Support Vector Machines (SVM) with a radial kernel (caret's `svmRadial` method), and (5) Random Forest (RF). The random seed is reset before each fit.
# Fit one caret model of the given type on the training data.
# The RNG seed is reset to the same value before every fit.
fit_with_seed <- function(algorithm) {
  set.seed(3)
  train(
    species ~ .,
    data = tdata,
    method = algorithm,
    metric = metric,
    trControl = control
  )
}

f.lda <- fit_with_seed("lda")         # linear discriminant analysis
f.cart <- fit_with_seed("rpart")      # classification tree
f.knn <- fit_with_seed("knn")         # k-nearest neighbours
f.svm <- fit_with_seed("svmRadial")   # support vector machine, radial kernel
f.rf <- fit_with_seed("rf")           # random forest

# Gather the per-fold results of all five fits and summarise them together.
model_fits <- list(lda = f.lda, cart = f.cart, knn = f.knn, svm = f.svm, rf = f.rf)
result <- resamples(model_fits)
summary(result)
##
## Call:
## summary.resamples(object = result)
##
## Models: lda, cart, knn, svm, rf
## Number of resamples: 10
##
## Accuracy
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## lda 0.8333333 0.9375000 1.0000000 0.9666667 1 1 0
## cart 0.8333333 0.9166667 0.9583333 0.9416667 1 1 0
## knn 0.8333333 0.9166667 1.0000000 0.9500000 1 1 0
## svm 0.7500000 0.8541667 0.9583333 0.9250000 1 1 0
## rf 0.8333333 0.9166667 0.9583333 0.9416667 1 1 0
##
## Kappa
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## lda 0.750 0.90625 1.0000 0.9500 1 1 0
## cart 0.750 0.87500 0.9375 0.9125 1 1 0
## knn 0.750 0.87500 1.0000 0.9250 1 1 0
## svm 0.625 0.78125 0.9375 0.8875 1 1 0
## rf 0.750 0.87500 0.9375 0.9125 1 1 0
The LDA method achieved the highest mean cross-validated accuracy (0.967). The summary of this best model is shown below.
# Print the summary of the LDA fit, including its resampled accuracy and kappa.
print(f.lda)
## Linear Discriminant Analysis
##
## 120 samples
## 4 predictor
## 3 classes: 'setosa', 'versicolor', 'virginica'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 108, 108, 108, 108, 108, 108, ...
## Resampling results:
##
## Accuracy Kappa
## 0.9666667 0.95
Make predictions with the LDA model on the held-out validation (test) data.
# Classify the held-out rows with the LDA model, then cross-tabulate the
# predicted classes against the true species.
prediction <- predict(f.lda, newdata = validation)
confusionMatrix(data = prediction, reference = validation[["species"]])
## Confusion Matrix and Statistics
##
## Reference
## Prediction setosa versicolor virginica
## setosa 10 0 0
## versicolor 0 10 0
## virginica 0 0 10
##
## Overall Statistics
##
## Accuracy : 1
## 95% CI : (0.8843, 1)
## No Information Rate : 0.3333
## P-Value [Acc > NIR] : 4.857e-15
##
## Kappa : 1
##
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: setosa Class: versicolor Class: virginica
## Sensitivity 1.0000 1.0000 1.0000
## Specificity 1.0000 1.0000 1.0000
## Pos Pred Value 1.0000 1.0000 1.0000
## Neg Pred Value 1.0000 1.0000 1.0000
## Prevalence 0.3333 0.3333 0.3333
## Detection Rate 0.3333 0.3333 0.3333
## Detection Prevalence 0.3333 0.3333 0.3333
## Balanced Accuracy 1.0000 1.0000 1.0000
The predictions on the test data yield an accuracy of 100% (all 30 observations are classified correctly).
# Repeat the held-out evaluation with the SVM model: predict the classes and
# cross-tabulate them against the true species.
prediction2 <- predict(f.svm, newdata = validation)
confusionMatrix(data = prediction2, reference = validation[["species"]])
## Confusion Matrix and Statistics
##
## Reference
## Prediction setosa versicolor virginica
## setosa 10 0 0
## versicolor 0 10 0
## virginica 0 0 10
##
## Overall Statistics
##
## Accuracy : 1
## 95% CI : (0.8843, 1)
## No Information Rate : 0.3333
## P-Value [Acc > NIR] : 4.857e-15
##
## Kappa : 1
##
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: setosa Class: versicolor Class: virginica
## Sensitivity 1.0000 1.0000 1.0000
## Specificity 1.0000 1.0000 1.0000
## Pos Pred Value 1.0000 1.0000 1.0000
## Neg Pred Value 1.0000 1.0000 1.0000
## Prevalence 0.3333 0.3333 0.3333
## Detection Rate 0.3333 0.3333 0.3333
## Detection Prevalence 0.3333 0.3333 0.3333
## Balanced Accuracy 1.0000 1.0000 1.0000
The SVM algorithm, which had the lowest mean cross-validated accuracy, also achieves perfect accuracy on the validation (test) data.