Lazy Learning: Classification Using Nearest Neighbors (k-NN)

1. Introduction

Lazy learning (the family of methods that includes instance-based learning) refers to algorithms where:

  • No general model of the data is built at training time.
  • Computation is delayed until a prediction is needed.
  • The model memorizes the training data and searches for similar observations when classifying new instances.

A common lazy learning algorithm is the k-Nearest Neighbors (k-NN) classifier.


2. Understanding k-Nearest Neighbors (k-NN)

  • k-NN is a non-parametric, instance-based learning algorithm.
  • It classifies new data points based on the majority class of their k nearest neighbors in the training data.

2.1 Mathematical Formulation

Given a training dataset:

\[ D = \{(x_1, y_1), (x_2, y_2), \dots, (x_n, y_n) \} \]

where:

  • \(x_i\) is a feature vector.
  • \(y_i\) is the class label.

For a new data point \(x^*\), the predicted class \(\hat{y}^*\) is determined as:

\[ \hat{y}^* = \text{mode} \{ y_i \mid x_i \in N_k(x^*) \} \]

where:

  • \(N_k(x^*)\) represents the k nearest neighbors of \(x^*\).
  • The distance metric (e.g., Euclidean, Manhattan) determines the nearest neighbors.
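The formula above can be implemented directly in a few lines of base R. The sketch below is illustrative only. `knn_predict()` is a hypothetical helper, not part of any package. It uses Euclidean distance and takes the mode of the neighbors' labels. It breaks vote ties in favor of the first factor level, whereas `class::knn` (used later in this tutorial) breaks ties at random.

```r
# Minimal from-scratch k-NN classifier (illustrative sketch, base R only).
# knn_predict() is a hypothetical helper written for this example.
knn_predict <- function(train_X, train_y, new_X, k = 5) {
  train_X <- as.matrix(train_X)
  new_X   <- as.matrix(new_X)
  preds   <- character(nrow(new_X))
  for (i in seq_len(nrow(new_X))) {
    # Euclidean distance from the i-th new point to every training point
    diffs <- sweep(train_X, 2, new_X[i, ])
    d <- sqrt(rowSums(diffs^2))
    # Labels of the k closest training points, i.e. N_k(x*)
    nn_labels <- train_y[order(d)[seq_len(k)]]
    # Mode of the neighbor labels (ties go to the first factor level)
    votes <- table(nn_labels)
    preds[i] <- names(votes)[which.max(votes)]
  }
  factor(preds, levels = levels(train_y))
}

# Usage: train on 120 random flowers, predict the remaining 30
set.seed(1)
demo_idx  <- sample(nrow(iris), 120)
demo_pred <- knn_predict(iris[demo_idx, 3:4], iris$Species[demo_idx],
                         iris[-demo_idx, 3:4], k = 5)
mean(demo_pred == iris$Species[-demo_idx])  # test accuracy
```

"Training" here does nothing beyond keeping `train_X` and `train_y` around. All computation happens at prediction time, which is what makes the method lazy.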

3. Choosing a Distance Metric

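The choice of distance metric determines which training points count as neighbors, so it can change the predictions. Note that class::knn, used in the examples below, always uses Euclidean distance.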

3.1 Common Distance Metrics

1. Euclidean Distance (Default)

\[ d(x, y) = \sqrt{\sum_{j=1}^{m} (x_j - y_j)^2} \]

2. Manhattan Distance

\[ d(x, y) = \sum_{j=1}^{m} |x_j - y_j| \]

3. Minkowski Distance (Generalized form)

\[ d(x, y) = \left( \sum_{j=1}^{m} |x_j - y_j|^p \right)^{1/p} \]

where \(m\) is the number of features and \(p \ge 1\) determines the norm:

  • If \(p = 1\), it reduces to Manhattan distance.
  • If \(p = 2\), it reduces to Euclidean distance.

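4. Cosine Similarity

\[ \text{sim}(x, y) = \frac{x \cdot y}{\|x\| \|y\|} \]

This depends only on the angle between the two vectors, not on their magnitudes. Because it is a similarity (larger means more alike), distance-based methods such as k-NN use the cosine distance \(1 - \text{sim}(x, y)\) instead.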
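The four measures above can be written as one-line R functions. The function names below are illustrative and do not come from any package. Base R's `dist()` already provides Euclidean, Manhattan and Minkowski distances, which serves as a cross-check.

```r
# Distance / similarity functions from Section 3.1 (illustrative names)
euclidean  <- function(x, y) sqrt(sum((x - y)^2))
manhattan  <- function(x, y) sum(abs(x - y))
minkowski  <- function(x, y, p) sum(abs(x - y)^p)^(1 / p)
cosine_sim <- function(x, y) sum(x * y) / (sqrt(sum(x^2)) * sqrt(sum(y^2)))

x <- c(1, 2, 3)
y <- c(4, 0, 3)

euclidean(x, y)        # sqrt(9 + 4 + 0) = sqrt(13), about 3.606
manhattan(x, y)        # 3 + 2 + 0 = 5
minkowski(x, y, 1)     # same as manhattan(x, y)
minkowski(x, y, 2)     # same as euclidean(x, y)
cosine_sim(x, y)       # 13 / (sqrt(14) * 5), about 0.695
1 - cosine_sim(x, y)   # cosine distance

# Cross-check with base R
dist(rbind(x, y), method = "minkowski", p = 2)
```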


4. Dataset 1: Classification of Flower Species (Iris Dataset)

The Iris dataset contains 150 observations of flowers. Each flower is described by four numeric features measured in centimeters (Sepal.Length, Sepal.Width, Petal.Length, Petal.Width). Each belongs to one of three species: setosa, versicolor, or virginica (50 flowers each).

# Install any missing packages, then load them
packages <- c("ggplot2", "caret", "class", "dplyr", "gridExtra", "mlbench")
for (p in packages) {
  if (!requireNamespace(p, quietly = TRUE)) install.packages(p, dependencies = TRUE)
  library(p, character.only = TRUE)
}

# Load dataset
data("iris")
set.seed(123)

# Shuffle data to avoid order bias
iris <- iris[sample(nrow(iris)), ]

# Normalize numeric features
normalize <- function(x) { (x - min(x)) / (max(x) - min(x)) }
iris_norm <- as.data.frame(lapply(iris[1:4], normalize))

# Add target variable back
iris_norm$Species <- iris$Species

# Train-test split (80-20)
set.seed(123)
train_index <- sample(seq_len(nrow(iris_norm)), size = 0.8 * nrow(iris_norm))
train_data <- iris_norm[train_index, ]
test_data <- iris_norm[-train_index, ]

ggplot(iris, aes(x = Petal.Length, y = Petal.Width, color = Species)) +
  geom_point(size = 3, alpha = 0.7) +
  labs(title = "Iris Dataset: Petal Length vs. Width",
       x = "Petal Length", y = "Petal Width") +
  theme_minimal()

4. Demonstration of k-NN in 2D Space

To visualize the decision boundaries, let’s focus on two features: Petal.Length and Petal.Width. This allows us to create a 2D grid of points and color them by predicted class.

4.1 Subset and Normalize

We'll keep only the two petal features so the problem stays two-dimensional. We'll also min-max scale both to [0, 1]. Otherwise Petal.Length (about 1 to 6.9 cm) would dominate the Euclidean distance over Petal.Width (about 0.1 to 2.5 cm):

iris_2d <- iris %>%
  select(Petal.Length, Petal.Width, Species)

normalize <- function(x) {
  (x - min(x)) / (max(x) - min(x))
}

iris_2d$Petal.Length <- normalize(iris_2d$Petal.Length)
iris_2d$Petal.Width  <- normalize(iris_2d$Petal.Width)

4.2 Train/Test Split

set.seed(123)
train_index <- sample(1:nrow(iris_2d), size = 0.8 * nrow(iris_2d))
train_data <- iris_2d[train_index, ]
test_data  <- iris_2d[-train_index, ]

# Separate features and labels
train_X <- train_data[, 1:2]
test_X  <- test_data[, 1:2]
train_Y <- train_data$Species
test_Y  <- test_data$Species
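Sections 4.1 and 4.2 rescale the petal features using the minimum and maximum of all 150 flowers, and only then split. The test rows therefore influence the scaling, which is a mild form of data leakage. A leakage-free variant is sketched below as an alternative, not as the tutorial's procedure. It learns the min and max from the training rows only and applies them unchanged to the test rows. `scale_with()` is a hypothetical helper, and the sketch uses its own split of the unshuffled data, so it will not reproduce the exact rows above.

```r
# Leakage-free min-max scaling: learn min/max on training rows, reuse on test rows
iris_raw <- datasets::iris
features <- c("Petal.Length", "Petal.Width")

set.seed(123)
split_idx <- sample(seq_len(nrow(iris_raw)), size = 0.8 * nrow(iris_raw))
train_raw <- iris_raw[split_idx, features]
test_raw  <- iris_raw[-split_idx, features]

# Scaling parameters come from the training rows only
train_min <- sapply(train_raw, min)
train_max <- sapply(train_raw, max)

# Hypothetical helper: apply (x - min) / (max - min) column by column
scale_with <- function(df, mins, maxs) {
  scaled <- sweep(as.matrix(df), 2, mins)             # subtract the minimum
  as.data.frame(sweep(scaled, 2, maxs - mins, "/"))   # divide by the range
}

train_scaled <- scale_with(train_raw, train_min, train_max)  # exactly within [0, 1]
test_scaled  <- scale_with(test_raw,  train_min, train_max)  # may fall slightly outside
summary(test_scaled)
```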

4.3 Fit k-NN (k = 5)

library(class)

k_value <- 5
knn_pred <- knn(train = train_X, test = test_X, cl = train_Y, k = k_value)

# Evaluate performance
conf_matrix <- confusionMatrix(knn_pred, test_Y)
conf_matrix
## Confusion Matrix and Statistics
## 
##             Reference
## Prediction   setosa versicolor virginica
##   setosa          8          0         0
##   versicolor      0          6         1
##   virginica       0          0        15
## 
## Overall Statistics
##                                           
##                Accuracy : 0.9667          
##                  95% CI : (0.8278, 0.9992)
##     No Information Rate : 0.5333          
##     P-Value [Acc > NIR] : 1.759e-07       
##                                           
##                   Kappa : 0.9458          
##                                           
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: setosa Class: versicolor Class: virginica
## Sensitivity                 1.0000            1.0000           0.9375
## Specificity                 1.0000            0.9583           1.0000
## Pos Pred Value              1.0000            0.8571           1.0000
## Neg Pred Value              1.0000            1.0000           0.9333
## Prevalence                  0.2667            0.2000           0.5333
## Detection Rate              0.2667            0.2000           0.5000
## Detection Prevalence        0.2667            0.2333           0.5000
## Balanced Accuracy           1.0000            0.9792           0.9688
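  • The model classifies 29 of the 30 test flowers correctly (accuracy 0.967). The single error is a virginica flower predicted as versicolor. With only 30 test points the 95% confidence interval is wide (0.83 to 1.00), so treat this estimate with some caution.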
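The headline statistics in the caret output can be reproduced by hand from the confusion matrix. The sketch below copies the counts printed above (rows = predictions, columns = true species). It computes accuracy (observed agreement), Cohen's kappa (agreement corrected for chance) and per-class sensitivity.

```r
# Confusion matrix from the output above: rows = predictions, columns = truth
cm <- matrix(c(8, 0,  0,
               0, 6,  1,
               0, 0, 15),
             nrow = 3, byrow = TRUE,
             dimnames = list(Prediction = c("setosa", "versicolor", "virginica"),
                             Reference  = c("setosa", "versicolor", "virginica")))

n_test      <- sum(cm)
p_obs       <- sum(diag(cm)) / n_test                     # accuracy: 29 / 30
p_exp       <- sum(rowSums(cm) * colSums(cm)) / n_test^2  # chance agreement
kappa_stat  <- (p_obs - p_exp) / (1 - p_exp)
sensitivity <- diag(cm) / colSums(cm)                     # per-class recall

round(c(accuracy = p_obs, kappa = kappa_stat), 4)         # 0.9667, 0.9458
round(sensitivity, 4)                                     # 1, 1, 0.9375
```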

5. Visualizing k-NN Decision Boundaries

We can create a fine 2D grid across Petal.Length and Petal.Width and classify each point using our k-NN model.

# Create a grid of points
x_min <- 0
x_max <- 1
y_min <- 0
y_max <- 1

# Sequence of values
grid_points <- 200
x_seq <- seq(x_min, x_max, length.out = grid_points)
y_seq <- seq(y_min, y_max, length.out = grid_points)

# Make a grid of feature vectors
grid_df <- expand.grid(Petal.Length = x_seq, Petal.Width = y_seq)

# Predict on the grid. k-NN has no fitted model to reuse: knn() classifies
# each grid point directly against the stored 2D training data.
grid_pred <- knn(
  train = train_X,
  test = grid_df,
  cl = train_Y,
  k = k_value
)

# Combine predictions with the grid
grid_df$PredClass <- grid_pred

# Plot the boundary
ggplot() +
  geom_tile(
    data = grid_df,
    aes(x = Petal.Length, y = Petal.Width, fill = PredClass),
    alpha = 0.3
  ) +
  geom_point(
    data = train_data,
    aes(x = Petal.Length, y = Petal.Width, color = Species),
    size = 3
  ) +
  labs(
    title = paste("k-NN Decision Boundary with k =", k_value),
    x = "Petal Length (normalized)",
    y = "Petal Width (normalized)"
  ) +
  scale_fill_discrete(name = "Predicted Class") +
  theme_minimal()

  • The colored regions in the background indicate the predicted class for points in that area.
  • The points represent the training samples.

6. Hyperparameter Tuning (Selecting k)

Often, we tune k using cross-validation:

set.seed(123)

# We can do 5-fold CV
train_control <- trainControl(method = "cv", number = 5)

knn_tuned <- train(
  Species ~ Petal.Length + Petal.Width,
  data = train_data,
  method = "knn",
  trControl = train_control,
  tuneLength = 10  # Try 10 values of k (caret's default grid here: k = 5, 7, ..., 23)
)

knn_tuned
## k-Nearest Neighbors 
## 
## 120 samples
##   2 predictor
##   3 classes: 'setosa', 'versicolor', 'virginica' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 97, 96, 95, 97, 95 
## Resampling results across tuning parameters:
## 
##   k   Accuracy   Kappa    
##    5  0.9669710  0.9503396
##    7  0.9589710  0.9382036
##    9  0.9589710  0.9382036
##   11  0.9589710  0.9382036
##   13  0.9589710  0.9379725
##   15  0.9669710  0.9499912
##   17  0.9669710  0.9499912
##   19  0.9669710  0.9501079
##   21  0.9419420  0.9120269
##   23  0.9586377  0.9372259
## 
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was k = 5.
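  • caret evaluated k = 5, 7, ..., 23 with 5-fold cross-validation. Accuracy is nearly flat across this range (about 0.94 to 0.97). k = 5, 15, 17 and 19 share the highest accuracy (0.967), and k = 5 was selected.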
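As an alternative to caret, the `class` package itself provides `knn.cv()`, which performs leave-one-out cross-validation. The sketch below scans odd k from 1 to 23, which also covers values below caret's default grid. For brevity it uses all 150 flowers; in practice pass only the training rows so the test set stays untouched. Because the resampling scheme differs, the selected k may not match caret's.

```r
library(class)

# Min-max scaled petal features (all 150 flowers, for brevity)
petal   <- datasets::iris[, c("Petal.Length", "Petal.Width")]
petal   <- as.data.frame(lapply(petal, function(x) (x - min(x)) / (max(x) - min(x))))
species <- datasets::iris$Species

candidate_k <- seq(1, 23, by = 2)
set.seed(123)  # knn.cv breaks voting ties at random
loocv_acc <- sapply(candidate_k, function(k) {
  pred <- knn.cv(train = petal, cl = species, k = k)  # leave-one-out predictions
  mean(pred == species)
})

data.frame(k = candidate_k, loocv_accuracy = round(loocv_acc, 4))
best_k <- candidate_k[which.max(loocv_acc)]  # smallest k among ties
best_k
```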

7. Additional Plots for Understanding k-NN

7.1 Nearest Neighbors Demo Plot

We can visually highlight the neighbors for a single test point. Let’s pick a test point and see how k-NN finds its neighbors:

# Take one test example
test_example <- test_X[1, , drop = FALSE]

# Distances to each training example
distances <- sqrt(
  (train_X$Petal.Length - test_example$Petal.Length)^2 +
  (train_X$Petal.Width - test_example$Petal.Width)^2
)

# Sort and pick the top k
neighbors_idx <- order(distances)[1:k_value]
neighbors_data <- train_X[neighbors_idx, ]

# Plot
p <- ggplot() +
  geom_point(
    data = train_X,
    aes(x = Petal.Length, y = Petal.Width),
    color = "gray50",
    alpha = 0.5,
    size = 2
  ) +
  geom_point(
    data = neighbors_data,
    aes(x = Petal.Length, y = Petal.Width),
    color = "red",
    size = 3
  ) +
  geom_point(
    data = as.data.frame(test_example),
    aes(x = Petal.Length, y = Petal.Width),
    color = "blue",
    shape = 8,
    size = 4
  ) +
  labs(
    title = paste("k=", k_value, "Nearest Neighbors (red) \nTest Point (blue star)")
  ) +
  xlim(0, 1) + ylim(0, 1) +
  theme_minimal()

p

In this plot:

  • Gray points = all training data
  • Red points = the k=5 neighbors for the test example
  • Blue star = the test example
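The plot shows which points are neighbors but not how they vote. The self-contained sketch below lists the five neighbors' species and distances, then takes the majority, which is the mode in the formula from Section 2.1. It uses a hypothetical query point (0.65, 0.65) in normalized petal space, not the tutorial's test example. That point lies near the versicolor/virginica border.

```r
# How the k = 5 nearest neighbors of one query point vote.
# The query point is hypothetical; features are min-max scaled as in Section 4.1.
petal   <- datasets::iris[, c("Petal.Length", "Petal.Width")]
petal   <- as.data.frame(lapply(petal, function(x) (x - min(x)) / (max(x) - min(x))))
species <- datasets::iris$Species

query <- c(Petal.Length = 0.65, Petal.Width = 0.65)  # hypothetical query point
k <- 5

# Euclidean distance from the query to every flower
dists <- sqrt((petal$Petal.Length - query[["Petal.Length"]])^2 +
              (petal$Petal.Width  - query[["Petal.Width"]])^2)
nn <- order(dists)[seq_len(k)]  # ties in distance are broken by row order here

# The k nearest neighbors with their labels and distances
data.frame(species = species[nn], distance = round(dists[nn], 3))

# Majority vote (mode of the neighbor labels)
votes <- table(species[nn])
votes
names(votes)[which.max(votes)]
```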

8. Conclusion

  • k-NN is intuitive and effective for small to medium-sized datasets.
  • Key hyperparameters: the number of neighbors k, the distance metric, and the features selected.
  • Decision boundary visualization helps understand how k-NN classifies in 2D.
  • Cross-validation is recommended for choosing the optimal k.

Final Takeaways:

  • k-NN is easy to implement and interpret.
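  • It can struggle with large datasets, because every prediction requires computing distances to all stored training points. It also struggles with high-dimensional data, where distances between points become increasingly similar (the curse of dimensionality).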
  • Normalizing features is often important when using distance-based methods.
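The high-dimensional problem can be made concrete with a small simulation. The setup (500 uniformly random points and one random query point) is an arbitrary illustration, not data from this tutorial. The sketch measures the relative contrast (max distance - min distance) / min distance between the query and the points. As the number of dimensions grows this ratio shrinks, so the "nearest" neighbor is barely nearer than the farthest one.

```r
# Relative contrast between the farthest and nearest point from a random query,
# for n uniform random points in d dimensions
relative_contrast <- function(d, n = 500) {
  X     <- matrix(runif(n * d), nrow = n)          # n points, uniform on [0, 1]^d
  query <- runif(d)
  dist_to_query <- sqrt(colSums((t(X) - query)^2)) # Euclidean distances
  (max(dist_to_query) - min(dist_to_query)) / min(dist_to_query)
}

set.seed(123)
dims     <- c(2, 10, 100, 1000)
contrast <- sapply(dims, relative_contrast)
data.frame(dimensions = dims, relative_contrast = round(contrast, 3))
```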

References

  • Cover, T., & Hart, P. (1967). Nearest neighbor pattern classification. IEEE Transactions on Information Theory.
  • Hastie, T., Tibshirani, R., & Friedman, J. (2009). The Elements of Statistical Learning.

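4. Dataset 2: Sonar Signals (Mines vs. Rocks)

k-NN on High-Dimensional Numeric Data (Sonar)

This part covers:

  • k-NN applied to a 60-feature classification problem.
  • Visualizing a test point's nearest neighbors after a PCA projection.
  • Hyperparameter tuning for the optimal k.

1. Introduction

k-NN is widely used for image classification, particularly handwritten digit recognition. Each image is flattened into a vector of pixel intensities and compared with a distance metric; for MNIST, that is 28 × 28 = 784 grayscale values per digit.

MNIST is not bundled with R, so this part uses the Sonar dataset from the mlbench package as a stand-in with the same structure. Each of its 208 observations is a numeric vector of 60 sonar energy readings scaled to [0, 1]. The observations are labeled M (metal cylinder, i.e. a mine) or R (rock).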
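To see how an image becomes a k-NN input, here is a minimal sketch with three hypothetical 3 × 3 grayscale images (not taken from MNIST). Each image is flattened into a 9-element vector. After that, image similarity is just Euclidean distance between vectors.

```r
# Three hypothetical 3 x 3 grayscale images (pixel intensities in [0, 1])
img_a <- matrix(c(0, 1, 0,
                  0, 1, 0,
                  0, 1, 0), nrow = 3, byrow = TRUE)  # vertical stroke, like a "1"
img_b <- matrix(c(0, 1, 0,
                  0, 1, 0,
                  0, 1, 1), nrow = 3, byrow = TRUE)  # same stroke plus one pixel
img_c <- matrix(c(1, 1, 1,
                  1, 0, 1,
                  1, 1, 1), nrow = 3, byrow = TRUE)  # a ring, like a "0"

# Flatten each image into a 9-element feature vector (one row per image)
pixels <- rbind(a = as.vector(img_a), b = as.vector(img_b), c = as.vector(img_c))

# Pairwise Euclidean distances: a-b = 1 (one pixel differs), a-c = sqrt(7)
dist(pixels)
```

Image b's nearest neighbor is a, so a 1-NN classifier would label b as a "1".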

2. Load Required Packages

packages <- c("ggplot2", "dplyr", "class", "caret", "gridExtra", "mlbench")

# Install any missing packages, then load them
for (p in packages) {
  if (!requireNamespace(p, quietly = TRUE)) {
    install.packages(p, dependencies = TRUE)
  }
  library(p, character.only = TRUE)
}

3. Dataset: Sonar (Mines vs. Rocks)

The Sonar data ship with the mlbench package loaded above.

# Load the Sonar dataset from mlbench
data("Sonar")

# Convert label column to a factor
Sonar$Class <- as.factor(Sonar$Class)

# Shuffle data
set.seed(123)
Sonar <- Sonar[sample(nrow(Sonar)), ]

4. Exploratory Data Analysis (EDA)

4.1 Data Overview

str(Sonar)
## 'data.frame':    208 obs. of  61 variables:
##  $ V1   : num  0.0107 0.0303 0.0197 0.009 0.0392 0.013 0.0119 0.0228 0.0211 0.0323 ...
##  $ V2   : num  0.0453 0.0353 0.0394 0.0062 0.0108 0.012 0.0582 0.0106 0.0319 0.0101 ...
##  $ V3   : num  0.0289 0.049 0.0384 0.0253 0.0267 0.0436 0.0623 0.013 0.0415 0.0298 ...
##  $ V4   : num  0.0713 0.0608 0.0076 0.0489 0.0257 0.0624 0.06 0.0842 0.0286 0.0564 ...
##  $ V5   : num  0.1075 0.0167 0.0251 0.1197 0.041 ...
##  $ V6   : num  0.1019 0.1354 0.0629 0.1589 0.0491 ...
##  $ V7   : num  0.1606 0.1465 0.0747 0.1392 0.1053 ...
##  $ V8   : num  0.2119 0.1123 0.0578 0.0987 0.169 ...
##  $ V9   : num  0.3061 0.1945 0.1357 0.0955 0.2105 ...
##  $ V10  : num  0.294 0.235 0.17 0.19 0.247 ...
##  $ V11  : num  0.31 0.29 0.173 0.19 0.268 ...
##  $ V12  : num  0.343 0.281 0.247 0.255 0.305 ...
##  $ V13  : num  0.246 0.158 0.314 0.407 0.286 ...
##  $ V14  : num  0.1887 0.0273 0.3297 0.2988 0.2294 ...
##  $ V15  : num  0.1184 0.0673 0.2759 0.2901 0.1165 ...
##  $ V16  : num  0.208 0.144 0.206 0.533 0.213 ...
##  $ V17  : num  0.274 0.207 0.116 0.402 0.206 ...
##  $ V18  : num  0.327 0.265 0.188 0.157 0.222 ...
##  $ V19  : num  0.234 0.283 0.339 0.302 0.324 ...
##  $ V20  : num  0.126 0.429 0.393 0.391 0.433 ...
##  $ V21  : num  0.0576 0.5685 0.4282 0.3542 0.5071 ...
##  $ V22  : num  0.124 0.699 0.542 0.444 0.594 ...
##  $ V23  : num  0.324 0.725 0.645 0.641 0.708 ...
##  $ V24  : num  0.436 0.762 0.722 0.46 0.764 ...
##  $ V25  : num  0.573 0.924 0.785 0.601 0.888 ...
##  $ V26  : num  0.782 1 0.798 0.869 0.971 ...
##  $ V27  : num  0.925 0.998 0.885 0.835 0.988 ...
##  $ V28  : num  0.935 0.83 0.958 0.767 0.981 ...
##  $ V29  : num  0.935 0.703 0.899 0.508 0.946 ...
##  $ V30  : num  1 0.714 0.683 0.462 0.854 ...
##  $ V31  : num  0.931 0.689 0.611 0.538 0.646 ...
##  $ V32  : num  0.848 0.496 0.548 0.537 0.34 ...
##  $ V33  : num  0.76 0.258 0.506 0.384 0.383 ...
##  $ V34  : num  0.704 0.0969 0.4476 0.3601 0.3204 ...
##  $ V35  : num  0.7539 0.0776 0.2401 0.7402 0.1331 ...
##  $ V36  : num  0.799 0.0364 0.1405 0.7761 0.044 ...
##  $ V37  : num  0.767 0.157 0.177 0.386 0.123 ...
##  $ V38  : num  0.5955 0.1823 0.1742 0.0667 0.203 ...
##  $ V39  : num  0.473 0.135 0.333 0.368 0.165 ...
##  $ V40  : num  0.484 0.0849 0.4021 0.6114 0.1043 ...
##  $ V41  : num  0.434 0.0492 0.3009 0.351 0.1066 ...
##  $ V42  : num  0.395 0.137 0.207 0.231 0.211 ...
##  $ V43  : num  0.484 0.155 0.121 0.22 0.242 ...
##  $ V44  : num  0.5379 0.1548 0.0255 0.3051 0.1631 ...
##  $ V45  : num  0.4485 0.1319 0.0298 0.1937 0.0769 ...
##  $ V46  : num  0.2674 0.0985 0.0691 0.157 0.0723 ...
##  $ V47  : num  0.1541 0.1258 0.0781 0.0479 0.0912 ...
##  $ V48  : num  0.1359 0.0954 0.0777 0.0538 0.0812 ...
##  $ V49  : num  0.0941 0.0489 0.0369 0.0146 0.0496 0.0292 0.0304 0.0649 0.0171 0.0647 ...
##  $ V50  : num  0.0261 0.0241 0.0057 0.0068 0.0101 0.0116 0.0074 0.0313 0.0383 0.0179 ...
##  $ V51  : num  0.0079 0.0042 0.0091 0.0187 0.0089 0.0024 0.0069 0.0185 0.0053 0.0051 ...
##  $ V52  : num  0.0164 0.0086 0.0134 0.0059 0.0083 0.0084 0.0025 0.0098 0.009 0.0061 ...
##  $ V53  : num  0.012 0.0046 0.0097 0.0095 0.008 0.01 0.0103 0.0178 0.0042 0.0093 ...
##  $ V54  : num  0.0113 0.0126 0.0042 0.0194 0.0026 0.0018 0.0074 0.0077 0.0153 0.0135 ...
##  $ V55  : num  0.0021 0.0036 0.0058 0.008 0.0079 0.0035 0.0123 0.0074 0.0106 0.0063 ...
##  $ V56  : num  0.0097 0.0035 0.0072 0.0152 0.0042 0.0058 0.0069 0.0095 0.002 0.0063 ...
##  $ V57  : num  0.0072 0.0034 0.0041 0.0158 0.0071 0.0011 0.0076 0.0055 0.0105 0.0034 ...
##  $ V58  : num  0.006 0.0079 0.0045 0.0053 0.0044 0.0009 0.0073 0.0045 0.0049 0.0032 ...
##  $ V59  : num  0.0017 0.0036 0.0047 0.0189 0.0022 0.0033 0.003 0.0063 0.007 0.0062 ...
##  $ V60  : num  0.0036 0.0048 0.0054 0.0102 0.0014 0.0026 0.0138 0.0039 0.008 0.0067 ...
##  $ Class: Factor w/ 2 levels "M","R": 1 1 1 2 1 1 2 1 2 1 ...
summary(Sonar)
##        V1                V2                V3                V4         
##  Min.   :0.00150   Min.   :0.00060   Min.   :0.00150   Min.   :0.00580  
##  1st Qu.:0.01335   1st Qu.:0.01645   1st Qu.:0.01895   1st Qu.:0.02438  
##  Median :0.02280   Median :0.03080   Median :0.03430   Median :0.04405  
##  Mean   :0.02916   Mean   :0.03844   Mean   :0.04383   Mean   :0.05389  
##  3rd Qu.:0.03555   3rd Qu.:0.04795   3rd Qu.:0.05795   3rd Qu.:0.06450  
##  Max.   :0.13710   Max.   :0.23390   Max.   :0.30590   Max.   :0.42640  
##        V5                V6                V7               V8         
##  Min.   :0.00670   Min.   :0.01020   Min.   :0.0033   Min.   :0.00550  
##  1st Qu.:0.03805   1st Qu.:0.06703   1st Qu.:0.0809   1st Qu.:0.08042  
##  Median :0.06250   Median :0.09215   Median :0.1070   Median :0.11210  
##  Mean   :0.07520   Mean   :0.10457   Mean   :0.1217   Mean   :0.13480  
##  3rd Qu.:0.10028   3rd Qu.:0.13412   3rd Qu.:0.1540   3rd Qu.:0.16960  
##  Max.   :0.40100   Max.   :0.38230   Max.   :0.3729   Max.   :0.45900  
##        V9               V10              V11              V12        
##  Min.   :0.00750   Min.   :0.0113   Min.   :0.0289   Min.   :0.0236  
##  1st Qu.:0.09703   1st Qu.:0.1113   1st Qu.:0.1293   1st Qu.:0.1335  
##  Median :0.15225   Median :0.1824   Median :0.2248   Median :0.2490  
##  Mean   :0.17800   Mean   :0.2083   Mean   :0.2360   Mean   :0.2502  
##  3rd Qu.:0.23342   3rd Qu.:0.2687   3rd Qu.:0.3016   3rd Qu.:0.3312  
##  Max.   :0.68280   Max.   :0.7106   Max.   :0.7342   Max.   :0.7060  
##       V13              V14              V15              V16        
##  Min.   :0.0184   Min.   :0.0273   Min.   :0.0031   Min.   :0.0162  
##  1st Qu.:0.1661   1st Qu.:0.1752   1st Qu.:0.1646   1st Qu.:0.1963  
##  Median :0.2640   Median :0.2811   Median :0.2817   Median :0.3047  
##  Mean   :0.2733   Mean   :0.2966   Mean   :0.3202   Mean   :0.3785  
##  3rd Qu.:0.3513   3rd Qu.:0.3862   3rd Qu.:0.4529   3rd Qu.:0.5357  
##  Max.   :0.7131   Max.   :0.9970   Max.   :1.0000   Max.   :0.9988  
##       V17              V18              V19              V20        
##  Min.   :0.0349   Min.   :0.0375   Min.   :0.0494   Min.   :0.0656  
##  1st Qu.:0.2059   1st Qu.:0.2421   1st Qu.:0.2991   1st Qu.:0.3506  
##  Median :0.3084   Median :0.3683   Median :0.4350   Median :0.5425  
##  Mean   :0.4160   Mean   :0.4523   Mean   :0.5048   Mean   :0.5630  
##  3rd Qu.:0.6594   3rd Qu.:0.6791   3rd Qu.:0.7314   3rd Qu.:0.8093  
##  Max.   :1.0000   Max.   :1.0000   Max.   :1.0000   Max.   :1.0000  
##       V21              V22              V23              V24        
##  Min.   :0.0512   Min.   :0.0219   Min.   :0.0563   Min.   :0.0239  
##  1st Qu.:0.3997   1st Qu.:0.4069   1st Qu.:0.4502   1st Qu.:0.5407  
##  Median :0.6177   Median :0.6649   Median :0.6997   Median :0.6985  
##  Mean   :0.6091   Mean   :0.6243   Mean   :0.6470   Mean   :0.6727  
##  3rd Qu.:0.8170   3rd Qu.:0.8320   3rd Qu.:0.8486   3rd Qu.:0.8722  
##  Max.   :1.0000   Max.   :1.0000   Max.   :1.0000   Max.   :1.0000  
##       V25              V26              V27              V28        
##  Min.   :0.0240   Min.   :0.0921   Min.   :0.0481   Min.   :0.0284  
##  1st Qu.:0.5258   1st Qu.:0.5442   1st Qu.:0.5319   1st Qu.:0.5348  
##  Median :0.7211   Median :0.7545   Median :0.7456   Median :0.7319  
##  Mean   :0.6754   Mean   :0.6999   Mean   :0.7022   Mean   :0.6940  
##  3rd Qu.:0.8737   3rd Qu.:0.8938   3rd Qu.:0.9171   3rd Qu.:0.9003  
##  Max.   :1.0000   Max.   :1.0000   Max.   :1.0000   Max.   :1.0000  
##       V29              V30              V31              V32        
##  Min.   :0.0144   Min.   :0.0613   Min.   :0.0482   Min.   :0.0404  
##  1st Qu.:0.4637   1st Qu.:0.4114   1st Qu.:0.3456   1st Qu.:0.2814  
##  Median :0.6808   Median :0.6071   Median :0.4904   Median :0.4296  
##  Mean   :0.6421   Mean   :0.5809   Mean   :0.5045   Mean   :0.4390  
##  3rd Qu.:0.8521   3rd Qu.:0.7352   3rd Qu.:0.6420   3rd Qu.:0.5803  
##  Max.   :1.0000   Max.   :1.0000   Max.   :0.9657   Max.   :0.9306  
##       V33              V34              V35              V36        
##  Min.   :0.0477   Min.   :0.0212   Min.   :0.0223   Min.   :0.0080  
##  1st Qu.:0.2579   1st Qu.:0.2176   1st Qu.:0.1794   1st Qu.:0.1543  
##  Median :0.3912   Median :0.3510   Median :0.3127   Median :0.3211  
##  Mean   :0.4172   Mean   :0.4032   Mean   :0.3926   Mean   :0.3848  
##  3rd Qu.:0.5561   3rd Qu.:0.5961   3rd Qu.:0.5934   3rd Qu.:0.5565  
##  Max.   :1.0000   Max.   :0.9647   Max.   :1.0000   Max.   :1.0000  
##       V37              V38              V39              V40        
##  Min.   :0.0351   Min.   :0.0383   Min.   :0.0371   Min.   :0.0117  
##  1st Qu.:0.1601   1st Qu.:0.1743   1st Qu.:0.1740   1st Qu.:0.1865  
##  Median :0.3063   Median :0.3127   Median :0.2835   Median :0.2781  
##  Mean   :0.3638   Mean   :0.3397   Mean   :0.3258   Mean   :0.3112  
##  3rd Qu.:0.5189   3rd Qu.:0.4405   3rd Qu.:0.4349   3rd Qu.:0.4244  
##  Max.   :0.9497   Max.   :1.0000   Max.   :0.9857   Max.   :0.9297  
##       V41              V42              V43              V44        
##  Min.   :0.0360   Min.   :0.0056   Min.   :0.0000   Min.   :0.0000  
##  1st Qu.:0.1631   1st Qu.:0.1589   1st Qu.:0.1552   1st Qu.:0.1269  
##  Median :0.2595   Median :0.2451   Median :0.2225   Median :0.1777  
##  Mean   :0.2893   Mean   :0.2783   Mean   :0.2465   Mean   :0.2141  
##  3rd Qu.:0.3875   3rd Qu.:0.3842   3rd Qu.:0.3245   3rd Qu.:0.2717  
##  Max.   :0.8995   Max.   :0.8246   Max.   :0.7733   Max.   :0.7762  
##       V45               V46               V47               V48         
##  Min.   :0.00000   Min.   :0.00000   Min.   :0.00000   Min.   :0.00000  
##  1st Qu.:0.09448   1st Qu.:0.06855   1st Qu.:0.06425   1st Qu.:0.04512  
##  Median :0.14800   Median :0.12135   Median :0.10165   Median :0.07810  
##  Mean   :0.19723   Mean   :0.16063   Mean   :0.12245   Mean   :0.09142  
##  3rd Qu.:0.23155   3rd Qu.:0.20037   3rd Qu.:0.15443   3rd Qu.:0.12010  
##  Max.   :0.70340   Max.   :0.72920   Max.   :0.55220   Max.   :0.33390  
##       V49               V50               V51                V52          
##  Min.   :0.00000   Min.   :0.00000   Min.   :0.000000   Min.   :0.000800  
##  1st Qu.:0.02635   1st Qu.:0.01155   1st Qu.:0.008425   1st Qu.:0.007275  
##  Median :0.04470   Median :0.01790   Median :0.013900   Median :0.011400  
##  Mean   :0.05193   Mean   :0.02042   Mean   :0.016069   Mean   :0.013420  
##  3rd Qu.:0.06853   3rd Qu.:0.02527   3rd Qu.:0.020825   3rd Qu.:0.016725  
##  Max.   :0.19810   Max.   :0.08250   Max.   :0.100400   Max.   :0.070900  
##       V53                V54                V55               V56          
##  Min.   :0.000500   Min.   :0.001000   Min.   :0.00060   Min.   :0.000400  
##  1st Qu.:0.005075   1st Qu.:0.005375   1st Qu.:0.00415   1st Qu.:0.004400  
##  Median :0.009550   Median :0.009300   Median :0.00750   Median :0.006850  
##  Mean   :0.010709   Mean   :0.010941   Mean   :0.00929   Mean   :0.008222  
##  3rd Qu.:0.014900   3rd Qu.:0.014500   3rd Qu.:0.01210   3rd Qu.:0.010575  
##  Max.   :0.039000   Max.   :0.035200   Max.   :0.04470   Max.   :0.039400  
##       V57               V58                V59                V60          
##  Min.   :0.00030   Min.   :0.000300   Min.   :0.000100   Min.   :0.000600  
##  1st Qu.:0.00370   1st Qu.:0.003600   1st Qu.:0.003675   1st Qu.:0.003100  
##  Median :0.00595   Median :0.005800   Median :0.006400   Median :0.005300  
##  Mean   :0.00782   Mean   :0.007949   Mean   :0.007941   Mean   :0.006507  
##  3rd Qu.:0.01043   3rd Qu.:0.010350   3rd Qu.:0.010325   3rd Qu.:0.008525  
##  Max.   :0.03550   Max.   :0.044000   Max.   :0.036400   Max.   :0.043900  
##  Class  
##  M:111  
##  R: 97  
##         
##         
##         
## 

4.2 Class Distribution

ggplot(Sonar, aes(x = Class, fill = Class)) +
  geom_bar() +
  labs(title = "Class Distribution in the Dataset",
       x = "Class",
       y = "Count") +
  theme_minimal()

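5. Train-Test Split

We split the dataset into 80% training and 20% testing sets. No extra normalization is applied, because all 60 Sonar features already lie in [0, 1] (see the summary above). Their ranges still differ, though: V1 never exceeds 0.137, while many mid-range bands reach 1.0.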

set.seed(123)
train_index <- sample(seq_len(nrow(Sonar)), size = 0.8 * nrow(Sonar))
train_data <- Sonar[train_index, ]
test_data  <- Sonar[-train_index, ]

# Extract features and labels
train_X <- train_data[, -ncol(Sonar)]
test_X  <- test_data[, -ncol(Sonar)]
train_Y <- train_data$Class
test_Y  <- test_data$Class
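The split above samples rows without regard to class, so the M/R proportions in the test set can drift from those of the full data. A stratified split avoids this by sampling within each class. The sketch below uses a hypothetical helper, `stratified_split()`, demonstrated on a label vector with Sonar's class counts (111 M, 97 R). caret's `createDataPartition()` implements the same idea.

```r
# Stratified split: sample a fraction p of the row indices within each class.
# stratified_split() is a hypothetical helper, not a package function.
stratified_split <- function(y, p = 0.8) {
  idx_by_class <- split(seq_along(y), y)
  unlist(lapply(idx_by_class, function(idx) {
    idx[sample.int(length(idx), size = floor(p * length(idx)))]
  }), use.names = FALSE)
}

# Example with class counts matching Sonar (111 M, 97 R)
demo_y <- factor(rep(c("M", "R"), times = c(111, 97)))
set.seed(123)
train_idx <- stratified_split(demo_y, p = 0.8)
table(demo_y[train_idx])   # 88 M, 77 R
table(demo_y[-train_idx])  # 23 M, 20 R
```

Rounding down within each class gives 165 training rows here rather than the 166 of the simple split.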

6. Train k-NN Classifier

# Train k-NN classifier with k=3
knn_pred <- knn(train = train_X, test = test_X, cl = train_Y, k = 3)

# Confusion Matrix
conf_matrix <- confusionMatrix(knn_pred, test_Y)
conf_matrix
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  M  R
##          M 21  7
##          R  2 12
##                                          
##                Accuracy : 0.7857         
##                  95% CI : (0.6319, 0.897)
##     No Information Rate : 0.5476         
##     P-Value [Acc > NIR] : 0.001189       
##                                          
##                   Kappa : 0.5574         
##                                          
##  Mcnemar's Test P-Value : 0.182422       
##                                          
##             Sensitivity : 0.9130         
##             Specificity : 0.6316         
##          Pos Pred Value : 0.7500         
##          Neg Pred Value : 0.8571         
##              Prevalence : 0.5476         
##          Detection Rate : 0.5000         
##    Detection Prevalence : 0.6667         
##       Balanced Accuracy : 0.7723         
##                                          
##        'Positive' Class : M              
## 
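Of the 42 test returns, 33 are classified correctly. Most errors are rocks predicted as mines (7 of 19 rocks), which is why specificity (0.63) is much lower than sensitivity (0.91). The sketch below recomputes caret's statistics from the printed 2 × 2 matrix, treating M as the positive class as caret does.

```r
# 2 x 2 confusion matrix from the output above (positive class = M)
cm <- matrix(c(21,  7,
                2, 12),
             nrow = 2, byrow = TRUE,
             dimnames = list(Prediction = c("M", "R"), Reference = c("M", "R")))

tp <- cm["M", "M"]  # mines predicted as mines
fp <- cm["M", "R"]  # rocks predicted as mines
fn <- cm["R", "M"]  # mines predicted as rocks
tn <- cm["R", "R"]  # rocks predicted as rocks

accuracy          <- (tp + tn) / sum(cm)  # 33 / 42
sensitivity       <- tp / (tp + fn)       # 21 / 23
specificity       <- tn / (tn + fp)       # 12 / 19
ppv               <- tp / (tp + fp)       # 21 / 28
npv               <- tn / (tn + fn)       # 12 / 14
balanced_accuracy <- (sensitivity + specificity) / 2

round(c(accuracy = accuracy, sensitivity = sensitivity, specificity = specificity,
        ppv = ppv, npv = npv, balanced = balanced_accuracy), 4)
```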

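7. Visualizing How k-NN Works

  • We pick one test sample and highlight its nearest neighbors. The distances are computed in the full 60-dimensional feature space. The plot then projects all points onto the first two principal components, so the highlighted neighbors need not look like the closest points in the projection.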
# Pick a test point (ensure it has column names)
test_example <- test_X[1, , drop = FALSE]
colnames(test_example) <- colnames(train_X)

# Compute Euclidean distance to all training points
distances <- sqrt(rowSums((train_X - matrix(as.numeric(test_example), 
                                             nrow = nrow(train_X), 
                                             ncol = ncol(train_X), 
                                             byrow = TRUE))^2))

# Select k nearest neighbors
k <- 3
neighbors_idx <- order(distances)[1:k]
neighbors_data <- train_X[neighbors_idx, ]

# PCA for visualization
pca_train <- prcomp(train_X, center = TRUE, scale. = TRUE)
pca_neighbors <- predict(pca_train, neighbors_data)

# Project the test example onto the first two principal components
pca_test <- as.data.frame(predict(pca_train, test_example)[, 1:2, drop = FALSE])

# Create data frames for plotting
df_plot <- data.frame(pca_train$x[, 1:2], Label = train_Y)
colnames(df_plot) <- c("PC1", "PC2", "Label")

df_neighbors <- data.frame(pca_neighbors[, 1:2], Label = train_Y[neighbors_idx])
colnames(df_neighbors) <- c("PC1", "PC2", "Label")

df_test <- pca_test  # Already formatted properly
df_test$Label <- "Test Point"

# Plot PCA-reduced neighbors
ggplot(df_plot, aes(x = PC1, y = PC2, color = Label)) +
  geom_point(alpha = 0.5) +
  geom_point(data = df_neighbors, aes(x = PC1, y = PC2), color = "red", size = 3) +
  geom_point(data = df_test, aes(x = PC1, y = PC2), color = "blue", shape = 8, size = 4) +
  labs(title = "k-NN Neighbors (Red) for Test Point (Blue Star)") +
  theme_minimal()

Repeating the plot with k = 5 only requires recomputing the neighbor set. The distances, PCA fit, and background points from the previous chunk are reused:

# Select the 5 nearest neighbors of the same test point
k <- 5
neighbors_idx <- order(distances)[1:k]
neighbors_data <- train_X[neighbors_idx, ]
pca_neighbors <- predict(pca_train, neighbors_data)

df_neighbors <- data.frame(pca_neighbors[, 1:2], Label = train_Y[neighbors_idx])
colnames(df_neighbors) <- c("PC1", "PC2", "Label")

# Plot PCA-reduced neighbors
ggplot(df_plot, aes(x = PC1, y = PC2, color = Label)) +
  geom_point(alpha = 0.5) +
  geom_point(data = df_neighbors, aes(x = PC1, y = PC2), color = "red", size = 3) +
  geom_point(data = df_test, aes(x = PC1, y = PC2), color = "blue", shape = 8, size = 4) +
  labs(title = "k-NN Neighbors (Red, k = 5) for Test Point (Blue Star)") +
  theme_minimal()

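8. Hyperparameter Tuning (Finding Best k)

set.seed(123)

train_control <- trainControl(method = "cv", number = 5)

# tuneLength = 10 tries caret's default grid k = 5, 7, ..., 23. To include
# smaller values such as k = 3 (used in Section 6), pass
# tuneGrid = data.frame(k = seq(1, 23, by = 2)) instead.
knn_tuned <- train(Class ~ ., data = train_data,
                   method = "knn",
                   trControl = train_control,
                   tuneLength = 10)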

knn_tuned
## k-Nearest Neighbors 
## 
## 166 samples
##  60 predictor
##   2 classes: 'M', 'R' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 132, 134, 133, 133, 132 
## Resampling results across tuning parameters:
## 
##   k   Accuracy   Kappa    
##    5  0.7462121  0.4833202
##    7  0.6853721  0.3559735
##    9  0.6434826  0.2705788
##   11  0.6497215  0.2845346
##   13  0.6679144  0.3202643
##   15  0.6620321  0.3066610
##   17  0.6375891  0.2580823
##   19  0.6194073  0.2199849
##   21  0.6249220  0.2300818
##   23  0.6370432  0.2555753
## 
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was k = 5.
