DATA 607: Data Science in Context - K Nearest Neighbors Algorithm

Eric Lehmphul

10/27/2021

KNN

Regression Example

Classification Example

Why use KNN?

Limitations of KNN

Euclidean Distance

\[ \text{Distance} = \sqrt{(x_1 - u_1)^2 + (x_2 - u_2)^2 + \dots + (x_p - u_p)^2} \]
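
As a quick illustration, the same distance can be computed directly in base R. The two vectors below are made-up five-feature examples (already normalized), not rows of the demo data:

# Two hypothetical feature vectors (diameter, weight, red, green, blue)
x <- c(0.11, 0.05, 0.62, 0.59, 0.04)
u <- c(0.45, 0.40, 0.30, 0.75, 0.10)

# Euclidean distance, exactly as in the formula above
sqrt(sum((x - u)^2))

# Same result via base R's dist()
dist(rbind(x, u), method = "euclidean")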

Choosing k
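
A common starting point, and the one used in the demo below, is k near the square root of the training-set size; a common refinement is to nudge k to an odd value so a two-class vote cannot tie. A minimal sketch of that heuristic, where 7000 matches the demo's 70% training split:

# Rule-of-thumb starting value for k: roughly sqrt(n), made odd to avoid tied votes
n <- 7000                      # training-set size in the demo below
k <- round(sqrt(n))            # about 84
if (k %% 2 == 0) k <- k + 1    # an odd k cannot produce a tied vote with two classes
k                              # 85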

Demo: Classifying Fruit

Load Libraries

library(class) # for knn model
library(caret) # for confusion matrix
## Loading required package: ggplot2
## Loading required package: lattice
library(ggplot2) # for graphs
library(gghighlight) # also for graphs

Dataset

fruit.data <- read.csv("https://raw.githubusercontent.com/SaneSky109/DATA607/main/DataScienceInContext/Data/citrus.csv")

fruit.data$name <- factor(fruit.data$name)

head(fruit.data)
##     name diameter weight red green blue
## 1 orange     2.96  86.76 172    85    2
## 2 orange     3.91  88.05 166    78    3
## 3 orange     4.42  95.17 156    81    2
## 4 orange     4.47  95.60 163    81    4
## 5 orange     4.48  95.76 161    72    9
## 6 orange     4.59  95.86 142   100    2
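
Before modeling, it is worth confirming the two fruit classes are roughly balanced, since that affects how to read the accuracy later (the confusion matrix below reports a No Information Rate near 0.51, so they are close to even). A quick check, output omitted here:

# Count how many of each fruit are in the data
table(fruit.data$name)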

Normalize Data

normalize <- function(x) {
  # min-max scaling: rescale a numeric vector to the [0, 1] range
  return((x - min(x)) / (max(x) - min(x)))
}
norm.fruit <- as.data.frame(lapply(fruit.data[,-1], normalize)) # normalize every numeric column (name excluded)

norm.fruit$name <- fruit.data$name

head(norm.fruit)
##     diameter      weight       red     green       blue   name
## 1 0.00000000 0.000000000 0.7402597 0.6352941 0.00000000 orange
## 2 0.07042254 0.007381974 0.6623377 0.5529412 0.01851852 orange
## 3 0.10822832 0.048125894 0.5324675 0.5882353 0.00000000 orange
## 4 0.11193477 0.050586552 0.6233766 0.5882353 0.03703704 orange
## 5 0.11267606 0.051502146 0.5974026 0.4823529 0.12962963 orange
## 6 0.12083024 0.052074392 0.3506494 0.8117647 0.00000000 orange
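
Min-max scaling keeps every feature in [0, 1]; z-score standardization is a common alternative that would also put the features on a comparable scale for KNN. A sketch of that option, not used in the rest of the demo:

# Alternative: z-score standardization (mean 0, standard deviation 1) per feature
std.fruit <- as.data.frame(scale(fruit.data[,-1]))
std.fruit$name <- fruit.data$name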

Create Training and Testing Data Frames

set.seed(105)

data.sample <- sample(1:nrow(norm.fruit), size = nrow(norm.fruit)*0.7, replace = FALSE) # random sample of 70% of data

train.fruit <- norm.fruit[data.sample,]
test.fruit <- norm.fruit[-data.sample,]
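
The simple random sample works here because oranges and grapefruits are nearly balanced; with skewed classes, caret's createDataPartition() draws the 70% sample within each class instead. A sketch of that stratified alternative, not used below:

# Stratified 70/30 split: sample proportionally within each fruit class
idx <- createDataPartition(norm.fruit$name, p = 0.7, list = FALSE)
train.strat <- norm.fruit[idx, ]
test.strat  <- norm.fruit[-idx, ]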

Create Model

model <- knn(train = train.fruit[,1:5], test = test.fruit[,1:5], cl = train.fruit[,6], k = sqrt(nrow(train.fruit))) # k set by the sqrt(n) rule of thumb
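
class::knn can also attach the winning vote share to each prediction via prob = TRUE, which gives a rough sense of how confident each classification is. A sketch of that variant:

# Same fit, but keep the proportion of neighbors that voted for the winning class
model.prob <- knn(train = train.fruit[,1:5], test = test.fruit[,1:5],
                  cl = train.fruit[,6], k = sqrt(nrow(train.fruit)), prob = TRUE)
head(attr(model.prob, "prob"))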

Results

actual <- test.fruit[,6]

results <- table(model, actual)
results
##             actual
## model        grapefruit orange
##   grapefruit       1412    111
##   orange            127   1350
confusionMatrix(results)
## Confusion Matrix and Statistics
## 
##             actual
## model        grapefruit orange
##   grapefruit       1412    111
##   orange            127   1350
##                                           
##                Accuracy : 0.9207          
##                  95% CI : (0.9104, 0.9301)
##     No Information Rate : 0.513           
##     P-Value [Acc > NIR] : <2e-16          
##                                           
##                   Kappa : 0.8413          
##                                           
##  Mcnemar's Test P-Value : 0.3309          
##                                           
##             Sensitivity : 0.9175          
##             Specificity : 0.9240          
##          Pos Pred Value : 0.9271          
##          Neg Pred Value : 0.9140          
##              Prevalence : 0.5130          
##          Detection Rate : 0.4707          
##    Detection Prevalence : 0.5077          
##       Balanced Accuracy : 0.9208          
##                                           
##        'Positive' Class : grapefruit      
## 
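
The headline accuracy can also be recovered straight from the contingency table, which is a quick sanity check on the confusionMatrix() output:

# Accuracy = correctly classified / total: (1412 + 1350) / 3000, about 0.92
sum(diag(results)) / sum(results)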

Find the Optimal k in Terms of Accuracy

k.optm <- 1 # vector to hold the test-set accuracy (%) for each candidate k

for (i in 1:sqrt(nrow(train.fruit))) {
  model <- knn(train = train.fruit[,1:5], test = test.fruit[,1:5], cl = train.fruit[,6], k = i)
  k.optm[i] <- 100*sum(test.fruit[,6]==model)/nrow(test.fruit) # percent of test fruits classified correctly
  k <- i
  cat(k,'=', k.optm[i],'\n')
}
## 1 = 88.86667 
## 2 = 89.46667 
## 3 = 91.33333 
## 4 = 90.7 
## 5 = 91.3 
## 6 = 91.1 
## 7 = 91.5 
## 8 = 91.56667 
## 9 = 91.7 
## 10 = 91.83333 
## 11 = 91.73333 
## 12 = 91.7 
## 13 = 91.83333 
## 14 = 91.8 
## 15 = 91.83333 
## 16 = 91.96667 
## 17 = 91.86667 
## 18 = 91.76667 
## 19 = 91.8 
## 20 = 91.96667 
## 21 = 91.96667 
## 22 = 92.16667 
## 23 = 92.2 
## 24 = 92.1 
## 25 = 92.06667 
## 26 = 92.16667 
## 27 = 92.23333 
## 28 = 92.06667 
## 29 = 92.06667 
## 30 = 92.03333 
## 31 = 92.13333 
## 32 = 92.13333 
## 33 = 92.1 
## 34 = 92.1 
## 35 = 92.2 
## 36 = 91.93333 
## 37 = 92.1 
## 38 = 92.1 
## 39 = 92.03333 
## 40 = 91.96667 
## 41 = 91.86667 
## 42 = 91.9 
## 43 = 92.1 
## 44 = 92.16667 
## 45 = 92.13333 
## 46 = 92.16667 
## 47 = 92.2 
## 48 = 92.16667 
## 49 = 92.16667 
## 50 = 92.16667 
## 51 = 92.1 
## 52 = 92.06667 
## 53 = 92.16667 
## 54 = 92.2 
## 55 = 92.03333 
## 56 = 92.1 
## 57 = 92.03333 
## 58 = 91.86667 
## 59 = 91.86667 
## 60 = 91.9 
## 61 = 91.8 
## 62 = 91.93333 
## 63 = 91.93333 
## 64 = 92 
## 65 = 92.06667 
## 66 = 92.1 
## 67 = 91.93333 
## 68 = 91.96667 
## 69 = 92.06667 
## 70 = 92 
## 71 = 92.06667 
## 72 = 92.03333 
## 73 = 92.03333 
## 74 = 91.96667 
## 75 = 92 
## 76 = 91.96667 
## 77 = 91.9 
## 78 = 91.96667 
## 79 = 91.93333 
## 80 = 92.06667 
## 81 = 92 
## 82 = 92.03333 
## 83 = 92.06667
accuracy <- data.frame(k = seq_along(k.optm), k.optm = k.optm)
ggplot(accuracy, aes(x=k,y=k.optm)) +
  geom_point() +
  gghighlight(k.optm == max(k.optm)) +
  geom_label(aes(label = k),
              hjust = 1.2, vjust = 1.2, fill = "dark blue", colour = "white", alpha= 0.5)
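
Rather than reading the best k off the plot, it can be pulled from the search results directly; from the printed accuracies above, that is k = 27 at roughly 92.2%:

# Smallest k that achieves the maximum test accuracy
best.k <- which.max(k.optm)
best.k            # 27
k.optm[best.k]    # 92.23333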

Q/A

Thank you for listening. Are there any questions?