Loading and splitting the data.

Let's use the iris dataset and split the data into training and testing sets.

library(class)
library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
library(knitr)


set.seed(1234)
# sample() normalizes the probability weights, so c(0.6, 0.3) gives roughly a 2:1 train/test split.
ind <- sample(2, nrow(iris), replace=TRUE, prob=c(0.6, 0.3))
trainData <- iris[ind==1,]
testData <- iris[ind==2,]

# Drop the Species column (column 5) so only the predictors are passed to knn().
trainData1 = trainData[-5]
testData1 = testData[-5]

# Keep the class labels separately for training and evaluation.
iris_train_lbls <- trainData$Species 
iris_test_lbls <- testData$Species 
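
A quick sanity check, not part of the original write-up, is to confirm that every species is reasonably represented in both splits before fitting anything:

table(iris_train_lbls)
table(iris_test_lbls)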

KNN model with k = 4 (chosen arbitrarily).

The accuracy comes out to about 0.95 when k = 4. We want to find the value of k that gives the best accuracy. If the dataset is imbalanced, be sure to also look at the per-class precision and recall.

knn_model <- knn(train = trainData1, test = testData1, cl = iris_train_lbls, k = 4)

predict.knn_result <- confusionMatrix(knn_model, testData$Species)
predict.knn_result$table
##             Reference
## Prediction   setosa versicolor virginica
##   setosa         12          0         0
##   versicolor      0         13         2
##   virginica       0          0        14
predict.knn_result$overall['Accuracy']
##  Accuracy 
## 0.9512195
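
Since the note above mentions precision and recall for imbalanced data, both can be read off the same confusionMatrix object: in caret's multiclass output, the 'Pos Pred Value' and 'Sensitivity' columns of byClass correspond to per-class precision and recall.

# Per-class precision and recall for the k = 4 model.
predict.knn_result$byClass[, c('Pos Pred Value', 'Sensitivity')]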

Finding the optimal value of K for the KNN classifier

Here let's try building and evaluating KNN models with k values ranging from 1 to 100.

# Column names for the metrics data frame.
metrics <- c("k", "Accuracy", "setosa-precision", "versicolor-precision", "virginica-precision")

knnMetrics <- data.frame(matrix(ncol = 5, nrow = 0), stringsAsFactors = FALSE)

for (k in 1:100){
  iris_test_pred1 <- knn(train = trainData1, test = testData1, cl = iris_train_lbls, k = k, prob = TRUE)
  result <- confusionMatrix(iris_test_pred1, testData$Species)
  # Store k, the overall accuracy and the per-class precision ('Pos Pred Value'), all as percentages.
  knnMetrics <- rbind(knnMetrics,
                      c(k,
                        result$overall['Accuracy']*100,
                        result$byClass[1,'Pos Pred Value']*100,
                        result$byClass[2,'Pos Pred Value']*100,
                        result$byClass[3,'Pos Pred Value']*100),
                      stringsAsFactors = FALSE)
}
colnames(knnMetrics) <- metrics
# Precision is NA when a class is never predicted; treat that as 0.
knnMetrics[is.na(knnMetrics)] <- 0
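
As an aside that is not part of the original approach: picking k by scoring on the test set lets the test data influence model selection. A minimal sketch of the same search done with caret's built-in cross-validation on the training data only (the tuning range of 1 to 30 is an arbitrary choice here):

set.seed(1234)
# 5-fold cross-validation on the training set; tuneGrid sweeps the candidate k values.
cv_ctrl <- trainControl(method = "cv", number = 5)
knn_cv <- train(Species ~ ., data = trainData,
                method = "knn",
                trControl = cv_ctrl,
                tuneGrid = data.frame(k = 1:30))
knn_cv$bestTune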

Visualizing the variation of accuracy and precision for different values of K

We can see that the precision for setosa is stable for almost any value of K, while the overall accuracy and the other precision metrics degrade as K increases.

ggplot(knnMetrics, aes(x = k)) +
  geom_line(aes(y = Accuracy, color = "Accuracy"), size = .5) +
  geom_line(aes(y = `setosa-precision`, color = "Precision-setosa"), size = .5) +
  geom_line(aes(y = `versicolor-precision`, color = "Precision-versicolor"), size = .5) +
  geom_line(aes(y = `virginica-precision`, color = "Precision-virginica"), size = .5) +
  xlab("K - Value") +
  ylab("Performance metrics (%)") +
  ylim(0, 105) +
  scale_x_continuous(breaks = seq(0, 100, 5))

We can see that the optimal K is 5: that is where the accuracy first reaches 100%, and all the precision metrics reach 100% as well.

kable(knnMetrics[1:30,])
   k    Accuracy   setosa-precision   versicolor-precision   virginica-precision
----  ----------  -----------------  ---------------------  --------------------
   1    97.56098                100               92.85714                   100
   2    97.56098                100               92.85714                   100
   3    97.56098                100               92.85714                   100
   4    95.12195                100               86.66667                   100
   5   100.00000                100              100.00000                   100
   6   100.00000                100              100.00000                   100
   7   100.00000                100              100.00000                   100
   8   100.00000                100              100.00000                   100
   9   100.00000                100              100.00000                   100
  10   100.00000                100              100.00000                   100
  11   100.00000                100              100.00000                   100
  12    97.56098                100               92.85714                   100
  13    97.56098                100               92.85714                   100
  14   100.00000                100              100.00000                   100
  15   100.00000                100              100.00000                   100
  16   100.00000                100              100.00000                   100
  17   100.00000                100              100.00000                   100
  18    95.12195                100               86.66667                   100
  19   100.00000                100              100.00000                   100
  20    97.56098                100               92.85714                   100
  21    95.12195                100               86.66667                   100
  22    95.12195                100               86.66667                   100
  23    95.12195                100               86.66667                   100
  24    92.68293                100               81.25000                   100
  25    92.68293                100               81.25000                   100
  26    95.12195                100               86.66667                   100
  27    95.12195                100               86.66667                   100
  28    92.68293                100               81.25000                   100
  29    95.12195                100               86.66667                   100
  30    95.12195                100               86.66667                   100
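
Rather than reading the table by eye, the best k can also be picked programmatically from the knnMetrics frame built above (a small addition to the original code):

# Smallest k that attains the maximum accuracy over the sweep.
min(knnMetrics$k[knnMetrics$Accuracy == max(knnMetrics$Accuracy)])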

Comparing with LDA

With the optimal K, KNN slightly outperforms the LDA model on this test set (100% vs. about 97.6% accuracy), although the untuned k = 4 model was a little worse than LDA.

library(MASS)
# Fit an LDA model on the same training data for comparison.
lda_model <- lda(formula = Species ~ ., data = trainData)

lda_predict <- predict(lda_model, newdata = testData[-5])
predict.lda_result <- confusionMatrix(lda_predict$class, testData$Species)
predict.lda_result$table
##             Reference
## Prediction   setosa versicolor virginica
##   setosa         12          0         0
##   versicolor      0         13         1
##   virginica       0          0        15
predict.lda_result$overall['Accuracy']
##  Accuracy 
## 0.9756098
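
As a final optional step, not in the original write-up, the two test-set accuracies can be placed side by side using the objects computed above:

comparison <- data.frame(
  Model    = c("KNN (best k)", "LDA"),
  Accuracy = c(max(knnMetrics$Accuracy) / 100,
               unname(predict.lda_result$overall['Accuracy'])))
kable(comparison)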