Lazy learning is a type of instance-based learning: the algorithm simply stores the training instances and defers all computation until a prediction is requested, rather than fitting an explicit model during training.
A common lazy learning algorithm is the k-Nearest Neighbors (k-NN) classifier.
Given a training dataset:
\[ D = \{(x_1, y_1), (x_2, y_2), ..., (x_n, y_n) \} \]
where each \(x_i\) is a feature vector and \(y_i\) is its corresponding class label.
For a new data point \(x^*\), the predicted class \(\hat{y}^*\) is determined as:
\[ \hat{y}^* = \text{mode} \{ y_i \mid x_i \in N_k(x^*) \} \]
where \(N_k(x^*)\) is the set of the \(k\) training points closest to \(x^*\) under the chosen distance metric, and the mode is the most frequent class label among those neighbors.
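To make this rule concrete, here is a minimal base-R sketch; the five 1-D training points, their labels, and the query value are made up purely for illustration. It finds the \(k\) nearest neighbors of the query and returns the mode of their labels.
train_x <- c(1.0, 1.2, 3.5, 3.8, 4.0)              # illustrative 1-D training points
train_y <- factor(c("A", "A", "B", "B", "B"))       # illustrative class labels
x_star  <- 1.5                                      # new point to classify
k       <- 3
d <- abs(train_x - x_star)                          # Euclidean distance in 1-D
neighbor_labels <- train_y[order(d)[1:k]]           # labels of the k closest points
y_hat <- names(which.max(table(neighbor_labels)))   # mode of the neighbor labels
y_hat
With k = 3 the two closest points carry label A and one carries B, so the predicted class is A.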
The choice of distance metric significantly affects the model’s performance.
Euclidean distance:
\[ d(x, y) = \sqrt{\sum_{i=1}^{n} (x_i - y_i)^2} \]
Manhattan distance:
\[ d(x, y) = \sum_{i=1}^{n} |x_i - y_i| \]
Minkowski distance:
\[ d(x, y) = \left( \sum_{i=1}^{n} |x_i - y_i|^p \right)^{1/p} \]
where \(p\) determines the norm: \(p = 1\) gives Manhattan distance and \(p = 2\) gives Euclidean distance.
Cosine similarity:
\[ s(x, y) = \frac{x \cdot y}{\|x\| \|y\|} \]
This measures the angle between two vectors rather than their magnitudes; it is a similarity score rather than a distance, and a common conversion is \(d(x, y) = 1 - s(x, y)\).
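As a quick check of these formulas, each metric can be computed directly in base R; the two vectors below are arbitrary values chosen only for illustration.
x <- c(1, 2, 3)
y <- c(2, 4, 1)
# Euclidean distance
sqrt(sum((x - y)^2))
# Manhattan distance
sum(abs(x - y))
# Minkowski distance with p = 3
p <- 3
sum(abs(x - y)^p)^(1 / p)
# Cosine similarity
sum(x * y) / (sqrt(sum(x^2)) * sqrt(sum(y^2)))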
The Iris dataset contains 150 observations of flowers with four features: Sepal.Length, Sepal.Width, Petal.Length, and Petal.Width.
It classifies the flowers into three species: Setosa, Versicolor, and Virginica.
library(ggplot2)
## Warning: package 'ggplot2' was built under R version 4.4.1
library(caret)
## Warning: package 'caret' was built under R version 4.4.1
## Loading required package: lattice
library(class)
library(dplyr)
##
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
library(gridExtra)
## Warning: package 'gridExtra' was built under R version 4.4.1
##
## Attaching package: 'gridExtra'
## The following object is masked from 'package:dplyr':
##
## combine
# Install if necessary
packages <- c("ggplot2", "caret", "class", "dplyr", "gridExtra", "mlbench")
for (p in packages) {
  if (!require(p, character.only = TRUE)) install.packages(p, dependencies = TRUE)
  library(p, character.only = TRUE)
}
## Loading required package: mlbench
## Warning: package 'mlbench' was built under R version 4.4.1
#The Iris dataset contains 150 observations of flowers with four features (Sepal.Length, Sepal.Width, Petal.Length, Petal.Width) and three species.
# Load dataset
data("iris")
set.seed(123)
# Shuffle data to avoid order bias
iris <- iris[sample(nrow(iris)), ]
# Normalize numeric features
normalize <- function(x) { (x - min(x)) / (max(x) - min(x)) }
iris_norm <- as.data.frame(lapply(iris[1:4], normalize))
# Add target variable back
iris_norm$Species <- iris$Species
# Train-test split (80-20)
set.seed(123)
train_index <- sample(seq_len(nrow(iris_norm)), size = 0.8 * nrow(iris_norm))
train_data <- iris_norm[train_index, ]
test_data <- iris_norm[-train_index, ]
data("iris")
set.seed(123)
# Shuffle the data to remove any ordering
iris <- iris[sample(nrow(iris)), ]
ggplot(iris, aes(x = Petal.Length, y = Petal.Width, color = Species)) +
geom_point(size = 3, alpha = 0.7) +
labs(title = "Iris Dataset: Petal Length vs. Width",
x = "Petal Length", y = "Petal Width") +
theme_minimal()
To visualize the decision boundaries, let's focus on two features: Petal.Length and Petal.Width. This allows us to create a 2D grid of points and color them by predicted class.
We'll extract only the two petal features to keep the problem two-dimensional for demonstration. Let's also normalize these two features so distances aren't skewed:
iris_2d <- iris %>%
select(Petal.Length, Petal.Width, Species)
normalize <- function(x) {
  (x - min(x)) / (max(x) - min(x))
}
iris_2d$Petal.Length <- normalize(iris_2d$Petal.Length)
iris_2d$Petal.Width <- normalize(iris_2d$Petal.Width)
set.seed(123)
train_index <- sample(1:nrow(iris_2d), size = 0.8 * nrow(iris_2d))
train_data <- iris_2d[train_index, ]
test_data <- iris_2d[-train_index, ]
# Separate features and labels
train_X <- train_data[, 1:2]
test_X <- test_data[, 1:2]
train_Y <- train_data$Species
test_Y <- test_data$Species
library(class)
k_value <- 5
knn_pred <- knn(train = train_X, test = test_X, cl = train_Y, k = k_value)
# Evaluate performance
conf_matrix <- confusionMatrix(knn_pred, test_Y)
conf_matrix
## Confusion Matrix and Statistics
##
## Reference
## Prediction setosa versicolor virginica
## setosa 8 0 0
## versicolor 0 6 1
## virginica 0 0 15
##
## Overall Statistics
##
## Accuracy : 0.9667
## 95% CI : (0.8278, 0.9992)
## No Information Rate : 0.5333
## P-Value [Acc > NIR] : 1.759e-07
##
## Kappa : 0.9458
##
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: setosa Class: versicolor Class: virginica
## Sensitivity 1.0000 1.0000 0.9375
## Specificity 1.0000 0.9583 1.0000
## Pos Pred Value 1.0000 0.8571 1.0000
## Neg Pred Value 1.0000 1.0000 0.9333
## Prevalence 0.2667 0.2000 0.5333
## Detection Rate 0.2667 0.2000 0.5000
## Detection Prevalence 0.2667 0.2333 0.5000
## Balanced Accuracy 1.0000 0.9792 0.9688
We can create a fine 2D grid across Petal.Length and Petal.Width and classify each point using our k-NN model.
# Create a grid of points
x_min <- 0
x_max <- 1
y_min <- 0
y_max <- 1
# Sequence of values
grid_points <- 200
x_seq <- seq(x_min, x_max, length.out = grid_points)
y_seq <- seq(y_min, y_max, length.out = grid_points)
# Make a grid of feature vectors
grid_df <- expand.grid(Petal.Length = x_seq, Petal.Width = y_seq)
# Predict on the grid
# Note: k-NN has no separate training step; each grid point is classified directly against the 2D training set
grid_pred <- knn(
train = train_X,
test = grid_df,
cl = train_Y,
k = k_value
)
# Combine predictions with the grid
grid_df$PredClass <- grid_pred
# Plot the boundary
ggplot() +
geom_tile(
data = grid_df,
aes(x = Petal.Length, y = Petal.Width, fill = PredClass),
alpha = 0.3
) +
geom_point(
data = train_data,
aes(x = Petal.Length, y = Petal.Width, color = Species),
size = 3
) +
labs(
title = paste("k-NN Decision Boundary with k =", k_value),
x = "Petal Length (normalized)",
y = "Petal Width (normalized)"
) +
scale_fill_discrete(name = "Predicted Class") +
theme_minimal()
* The colored regions in the background indicate the predicted class for points in that area.
* The points represent the training samples.
Often, we tune k using cross-validation:
set.seed(123)
# We can do 5-fold CV
train_control <- trainControl(method = "cv", number = 5)
knn_tuned <- train(
  Species ~ Petal.Length + Petal.Width,
  data = train_data,
  method = "knn",
  trControl = train_control,
  tuneLength = 10  # try 10 candidate values of k (odd values from 5 to 23; see output below)
)
knn_tuned
## k-Nearest Neighbors
##
## 120 samples
## 2 predictor
## 3 classes: 'setosa', 'versicolor', 'virginica'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 97, 96, 95, 97, 95
## Resampling results across tuning parameters:
##
## k Accuracy Kappa
## 5 0.9669710 0.9503396
## 7 0.9589710 0.9382036
## 9 0.9589710 0.9382036
## 11 0.9589710 0.9382036
## 13 0.9589710 0.9379725
## 15 0.9669710 0.9499912
## 17 0.9669710 0.9499912
## 19 0.9669710 0.9501079
## 21 0.9419420 0.9120269
## 23 0.9586377 0.9372259
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was k = 5.
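The tuned model can then be applied to the held-out test split. This follow-up is a small sketch not included in the original run (so no output is shown); it relies only on caret's standard predict() method for train objects.
# Predict species for the held-out 2D test data using the selected k
tuned_pred <- predict(knn_tuned, newdata = test_data)
# Compare predictions against the true labels
confusionMatrix(tuned_pred, test_data$Species)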
We can visually highlight the neighbors for a single test point. Let’s pick a test point and see how k-NN finds its neighbors:
# Take one test example
test_example <- test_X[1, , drop = FALSE]
# Distances to each training example
distances <- sqrt(
(train_X$Petal.Length - test_example$Petal.Length)^2 +
(train_X$Petal.Width - test_example$Petal.Width)^2
)
# Sort and pick the top k
neighbors_idx <- order(distances)[1:k_value]
neighbors_data <- train_X[neighbors_idx, ]
# Plot
p <- ggplot() +
geom_point(
data = train_X,
aes(x = Petal.Length, y = Petal.Width),
color = "gray50",
alpha = 0.5,
size = 2
) +
geom_point(
data = neighbors_data,
aes(x = Petal.Length, y = Petal.Width),
color = "red",
size = 3
) +
geom_point(
data = as.data.frame(test_example),
aes(x = Petal.Length, y = Petal.Width),
color = "blue",
shape = 8,
size = 4
) +
labs(
title = paste("k=", k_value, "Nearest Neighbors (red) \nTest Point (blue star)")
) +
xlim(0, 1) + ylim(0, 1) +
theme_minimal()
p
In this plot, the gray points are the training samples, the red points are the five nearest neighbors (k = 5) of the chosen test point, and the blue star marks the test point itself.
The remainder of this document applies k-NN to a higher-dimensional classification problem.
k-NN is widely used for image classification, particularly digit recognition: a digits dataset such as MNIST stores each grayscale image of the digits 0-9 as a vector of pixel intensity values, and those vectors serve directly as features. As explained below, when no digits dataset is available we fall back to the Sonar dataset.
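To show how such image data feeds into k-NN, here is a minimal sketch with randomly simulated 8x8 "images"; the pixel values and labels are fabricated purely to illustrate the shapes involved and are not real digits.
set.seed(1)
n_img  <- 20        # number of simulated images
n_pix  <- 8 * 8     # each "image" is 8 x 8 pixels, flattened to 64 features
pixels <- matrix(runif(n_img * n_pix), nrow = n_img, ncol = n_pix)
labels <- factor(sample(0:1, n_img, replace = TRUE))   # fake digit labels
# First 15 rows as training images, last 5 as test images
knn(train = pixels[1:15, ], test = pixels[16:20, ], cl = labels[1:15], k = 3)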
packages <- c("ggplot2", "dplyr", "class", "caret", "gridExtra", "mlbench", "R.utils")
for (p in packages) {
  if (!requireNamespace(p, quietly = TRUE)) {
    install.packages(p, dependencies = TRUE)
  }
  library(p, character.only = TRUE)
}
## Warning: package 'R.utils' was built under R version 4.4.1
## Loading required package: R.oo
## Loading required package: R.methodsS3
## R.methodsS3 v1.8.2 (2022-06-13 22:00:14 UTC) successfully loaded. See ?R.methodsS3 for help.
## R.oo v1.26.0 (2024-01-24 05:12:50 UTC) successfully loaded. See ?R.oo for help.
##
## Attaching package: 'R.oo'
## The following object is masked from 'package:R.methodsS3':
##
## throw
## The following objects are masked from 'package:methods':
##
## getClasses, getMethods
## The following objects are masked from 'package:base':
##
## attach, detach, load, save
## R.utils v2.12.3 (2023-11-18 01:00:02 UTC) successfully loaded. See ?R.utils for help.
##
## Attaching package: 'R.utils'
## The following object is masked from 'package:utils':
##
## timestamp
## The following objects are masked from 'package:base':
##
## cat, commandArgs, getOption, isOpen, nullfile, parse, use, warnings
If the ZipDigits dataset is unavailable, we will use the Sonar dataset or the readmnist package.
# Load the Sonar dataset (used here in place of ZipDigits)
data("Sonar")
# Convert label column to a factor
Sonar$Class <- as.factor(Sonar$Class)
# Shuffle data
set.seed(123)
Sonar <- Sonar[sample(nrow(Sonar)), ]
str(Sonar)
## 'data.frame': 208 obs. of 61 variables:
## $ V1 : num 0.0107 0.0303 0.0197 0.009 0.0392 0.013 0.0119 0.0228 0.0211 0.0323 ...
## $ V2 : num 0.0453 0.0353 0.0394 0.0062 0.0108 0.012 0.0582 0.0106 0.0319 0.0101 ...
## $ V3 : num 0.0289 0.049 0.0384 0.0253 0.0267 0.0436 0.0623 0.013 0.0415 0.0298 ...
## $ V4 : num 0.0713 0.0608 0.0076 0.0489 0.0257 0.0624 0.06 0.0842 0.0286 0.0564 ...
## $ V5 : num 0.1075 0.0167 0.0251 0.1197 0.041 ...
## $ V6 : num 0.1019 0.1354 0.0629 0.1589 0.0491 ...
## $ V7 : num 0.1606 0.1465 0.0747 0.1392 0.1053 ...
## $ V8 : num 0.2119 0.1123 0.0578 0.0987 0.169 ...
## $ V9 : num 0.3061 0.1945 0.1357 0.0955 0.2105 ...
## $ V10 : num 0.294 0.235 0.17 0.19 0.247 ...
## $ V11 : num 0.31 0.29 0.173 0.19 0.268 ...
## $ V12 : num 0.343 0.281 0.247 0.255 0.305 ...
## $ V13 : num 0.246 0.158 0.314 0.407 0.286 ...
## $ V14 : num 0.1887 0.0273 0.3297 0.2988 0.2294 ...
## $ V15 : num 0.1184 0.0673 0.2759 0.2901 0.1165 ...
## $ V16 : num 0.208 0.144 0.206 0.533 0.213 ...
## $ V17 : num 0.274 0.207 0.116 0.402 0.206 ...
## $ V18 : num 0.327 0.265 0.188 0.157 0.222 ...
## $ V19 : num 0.234 0.283 0.339 0.302 0.324 ...
## $ V20 : num 0.126 0.429 0.393 0.391 0.433 ...
## $ V21 : num 0.0576 0.5685 0.4282 0.3542 0.5071 ...
## $ V22 : num 0.124 0.699 0.542 0.444 0.594 ...
## $ V23 : num 0.324 0.725 0.645 0.641 0.708 ...
## $ V24 : num 0.436 0.762 0.722 0.46 0.764 ...
## $ V25 : num 0.573 0.924 0.785 0.601 0.888 ...
## $ V26 : num 0.782 1 0.798 0.869 0.971 ...
## $ V27 : num 0.925 0.998 0.885 0.835 0.988 ...
## $ V28 : num 0.935 0.83 0.958 0.767 0.981 ...
## $ V29 : num 0.935 0.703 0.899 0.508 0.946 ...
## $ V30 : num 1 0.714 0.683 0.462 0.854 ...
## $ V31 : num 0.931 0.689 0.611 0.538 0.646 ...
## $ V32 : num 0.848 0.496 0.548 0.537 0.34 ...
## $ V33 : num 0.76 0.258 0.506 0.384 0.383 ...
## $ V34 : num 0.704 0.0969 0.4476 0.3601 0.3204 ...
## $ V35 : num 0.7539 0.0776 0.2401 0.7402 0.1331 ...
## $ V36 : num 0.799 0.0364 0.1405 0.7761 0.044 ...
## $ V37 : num 0.767 0.157 0.177 0.386 0.123 ...
## $ V38 : num 0.5955 0.1823 0.1742 0.0667 0.203 ...
## $ V39 : num 0.473 0.135 0.333 0.368 0.165 ...
## $ V40 : num 0.484 0.0849 0.4021 0.6114 0.1043 ...
## $ V41 : num 0.434 0.0492 0.3009 0.351 0.1066 ...
## $ V42 : num 0.395 0.137 0.207 0.231 0.211 ...
## $ V43 : num 0.484 0.155 0.121 0.22 0.242 ...
## $ V44 : num 0.5379 0.1548 0.0255 0.3051 0.1631 ...
## $ V45 : num 0.4485 0.1319 0.0298 0.1937 0.0769 ...
## $ V46 : num 0.2674 0.0985 0.0691 0.157 0.0723 ...
## $ V47 : num 0.1541 0.1258 0.0781 0.0479 0.0912 ...
## $ V48 : num 0.1359 0.0954 0.0777 0.0538 0.0812 ...
## $ V49 : num 0.0941 0.0489 0.0369 0.0146 0.0496 0.0292 0.0304 0.0649 0.0171 0.0647 ...
## $ V50 : num 0.0261 0.0241 0.0057 0.0068 0.0101 0.0116 0.0074 0.0313 0.0383 0.0179 ...
## $ V51 : num 0.0079 0.0042 0.0091 0.0187 0.0089 0.0024 0.0069 0.0185 0.0053 0.0051 ...
## $ V52 : num 0.0164 0.0086 0.0134 0.0059 0.0083 0.0084 0.0025 0.0098 0.009 0.0061 ...
## $ V53 : num 0.012 0.0046 0.0097 0.0095 0.008 0.01 0.0103 0.0178 0.0042 0.0093 ...
## $ V54 : num 0.0113 0.0126 0.0042 0.0194 0.0026 0.0018 0.0074 0.0077 0.0153 0.0135 ...
## $ V55 : num 0.0021 0.0036 0.0058 0.008 0.0079 0.0035 0.0123 0.0074 0.0106 0.0063 ...
## $ V56 : num 0.0097 0.0035 0.0072 0.0152 0.0042 0.0058 0.0069 0.0095 0.002 0.0063 ...
## $ V57 : num 0.0072 0.0034 0.0041 0.0158 0.0071 0.0011 0.0076 0.0055 0.0105 0.0034 ...
## $ V58 : num 0.006 0.0079 0.0045 0.0053 0.0044 0.0009 0.0073 0.0045 0.0049 0.0032 ...
## $ V59 : num 0.0017 0.0036 0.0047 0.0189 0.0022 0.0033 0.003 0.0063 0.007 0.0062 ...
## $ V60 : num 0.0036 0.0048 0.0054 0.0102 0.0014 0.0026 0.0138 0.0039 0.008 0.0067 ...
## $ Class: Factor w/ 2 levels "M","R": 1 1 1 2 1 1 2 1 2 1 ...
summary(Sonar)
## V1 V2 V3 V4
## Min. :0.00150 Min. :0.00060 Min. :0.00150 Min. :0.00580
## 1st Qu.:0.01335 1st Qu.:0.01645 1st Qu.:0.01895 1st Qu.:0.02438
## Median :0.02280 Median :0.03080 Median :0.03430 Median :0.04405
## Mean :0.02916 Mean :0.03844 Mean :0.04383 Mean :0.05389
## 3rd Qu.:0.03555 3rd Qu.:0.04795 3rd Qu.:0.05795 3rd Qu.:0.06450
## Max. :0.13710 Max. :0.23390 Max. :0.30590 Max. :0.42640
## V5 V6 V7 V8
## Min. :0.00670 Min. :0.01020 Min. :0.0033 Min. :0.00550
## 1st Qu.:0.03805 1st Qu.:0.06703 1st Qu.:0.0809 1st Qu.:0.08042
## Median :0.06250 Median :0.09215 Median :0.1070 Median :0.11210
## Mean :0.07520 Mean :0.10457 Mean :0.1217 Mean :0.13480
## 3rd Qu.:0.10028 3rd Qu.:0.13412 3rd Qu.:0.1540 3rd Qu.:0.16960
## Max. :0.40100 Max. :0.38230 Max. :0.3729 Max. :0.45900
## V9 V10 V11 V12
## Min. :0.00750 Min. :0.0113 Min. :0.0289 Min. :0.0236
## 1st Qu.:0.09703 1st Qu.:0.1113 1st Qu.:0.1293 1st Qu.:0.1335
## Median :0.15225 Median :0.1824 Median :0.2248 Median :0.2490
## Mean :0.17800 Mean :0.2083 Mean :0.2360 Mean :0.2502
## 3rd Qu.:0.23342 3rd Qu.:0.2687 3rd Qu.:0.3016 3rd Qu.:0.3312
## Max. :0.68280 Max. :0.7106 Max. :0.7342 Max. :0.7060
## V13 V14 V15 V16
## Min. :0.0184 Min. :0.0273 Min. :0.0031 Min. :0.0162
## 1st Qu.:0.1661 1st Qu.:0.1752 1st Qu.:0.1646 1st Qu.:0.1963
## Median :0.2640 Median :0.2811 Median :0.2817 Median :0.3047
## Mean :0.2733 Mean :0.2966 Mean :0.3202 Mean :0.3785
## 3rd Qu.:0.3513 3rd Qu.:0.3862 3rd Qu.:0.4529 3rd Qu.:0.5357
## Max. :0.7131 Max. :0.9970 Max. :1.0000 Max. :0.9988
## V17 V18 V19 V20
## Min. :0.0349 Min. :0.0375 Min. :0.0494 Min. :0.0656
## 1st Qu.:0.2059 1st Qu.:0.2421 1st Qu.:0.2991 1st Qu.:0.3506
## Median :0.3084 Median :0.3683 Median :0.4350 Median :0.5425
## Mean :0.4160 Mean :0.4523 Mean :0.5048 Mean :0.5630
## 3rd Qu.:0.6594 3rd Qu.:0.6791 3rd Qu.:0.7314 3rd Qu.:0.8093
## Max. :1.0000 Max. :1.0000 Max. :1.0000 Max. :1.0000
## V21 V22 V23 V24
## Min. :0.0512 Min. :0.0219 Min. :0.0563 Min. :0.0239
## 1st Qu.:0.3997 1st Qu.:0.4069 1st Qu.:0.4502 1st Qu.:0.5407
## Median :0.6177 Median :0.6649 Median :0.6997 Median :0.6985
## Mean :0.6091 Mean :0.6243 Mean :0.6470 Mean :0.6727
## 3rd Qu.:0.8170 3rd Qu.:0.8320 3rd Qu.:0.8486 3rd Qu.:0.8722
## Max. :1.0000 Max. :1.0000 Max. :1.0000 Max. :1.0000
## V25 V26 V27 V28
## Min. :0.0240 Min. :0.0921 Min. :0.0481 Min. :0.0284
## 1st Qu.:0.5258 1st Qu.:0.5442 1st Qu.:0.5319 1st Qu.:0.5348
## Median :0.7211 Median :0.7545 Median :0.7456 Median :0.7319
## Mean :0.6754 Mean :0.6999 Mean :0.7022 Mean :0.6940
## 3rd Qu.:0.8737 3rd Qu.:0.8938 3rd Qu.:0.9171 3rd Qu.:0.9003
## Max. :1.0000 Max. :1.0000 Max. :1.0000 Max. :1.0000
## V29 V30 V31 V32
## Min. :0.0144 Min. :0.0613 Min. :0.0482 Min. :0.0404
## 1st Qu.:0.4637 1st Qu.:0.4114 1st Qu.:0.3456 1st Qu.:0.2814
## Median :0.6808 Median :0.6071 Median :0.4904 Median :0.4296
## Mean :0.6421 Mean :0.5809 Mean :0.5045 Mean :0.4390
## 3rd Qu.:0.8521 3rd Qu.:0.7352 3rd Qu.:0.6420 3rd Qu.:0.5803
## Max. :1.0000 Max. :1.0000 Max. :0.9657 Max. :0.9306
## V33 V34 V35 V36
## Min. :0.0477 Min. :0.0212 Min. :0.0223 Min. :0.0080
## 1st Qu.:0.2579 1st Qu.:0.2176 1st Qu.:0.1794 1st Qu.:0.1543
## Median :0.3912 Median :0.3510 Median :0.3127 Median :0.3211
## Mean :0.4172 Mean :0.4032 Mean :0.3926 Mean :0.3848
## 3rd Qu.:0.5561 3rd Qu.:0.5961 3rd Qu.:0.5934 3rd Qu.:0.5565
## Max. :1.0000 Max. :0.9647 Max. :1.0000 Max. :1.0000
## V37 V38 V39 V40
## Min. :0.0351 Min. :0.0383 Min. :0.0371 Min. :0.0117
## 1st Qu.:0.1601 1st Qu.:0.1743 1st Qu.:0.1740 1st Qu.:0.1865
## Median :0.3063 Median :0.3127 Median :0.2835 Median :0.2781
## Mean :0.3638 Mean :0.3397 Mean :0.3258 Mean :0.3112
## 3rd Qu.:0.5189 3rd Qu.:0.4405 3rd Qu.:0.4349 3rd Qu.:0.4244
## Max. :0.9497 Max. :1.0000 Max. :0.9857 Max. :0.9297
## V41 V42 V43 V44
## Min. :0.0360 Min. :0.0056 Min. :0.0000 Min. :0.0000
## 1st Qu.:0.1631 1st Qu.:0.1589 1st Qu.:0.1552 1st Qu.:0.1269
## Median :0.2595 Median :0.2451 Median :0.2225 Median :0.1777
## Mean :0.2893 Mean :0.2783 Mean :0.2465 Mean :0.2141
## 3rd Qu.:0.3875 3rd Qu.:0.3842 3rd Qu.:0.3245 3rd Qu.:0.2717
## Max. :0.8995 Max. :0.8246 Max. :0.7733 Max. :0.7762
## V45 V46 V47 V48
## Min. :0.00000 Min. :0.00000 Min. :0.00000 Min. :0.00000
## 1st Qu.:0.09448 1st Qu.:0.06855 1st Qu.:0.06425 1st Qu.:0.04512
## Median :0.14800 Median :0.12135 Median :0.10165 Median :0.07810
## Mean :0.19723 Mean :0.16063 Mean :0.12245 Mean :0.09142
## 3rd Qu.:0.23155 3rd Qu.:0.20037 3rd Qu.:0.15443 3rd Qu.:0.12010
## Max. :0.70340 Max. :0.72920 Max. :0.55220 Max. :0.33390
## V49 V50 V51 V52
## Min. :0.00000 Min. :0.00000 Min. :0.000000 Min. :0.000800
## 1st Qu.:0.02635 1st Qu.:0.01155 1st Qu.:0.008425 1st Qu.:0.007275
## Median :0.04470 Median :0.01790 Median :0.013900 Median :0.011400
## Mean :0.05193 Mean :0.02042 Mean :0.016069 Mean :0.013420
## 3rd Qu.:0.06853 3rd Qu.:0.02527 3rd Qu.:0.020825 3rd Qu.:0.016725
## Max. :0.19810 Max. :0.08250 Max. :0.100400 Max. :0.070900
## V53 V54 V55 V56
## Min. :0.000500 Min. :0.001000 Min. :0.00060 Min. :0.000400
## 1st Qu.:0.005075 1st Qu.:0.005375 1st Qu.:0.00415 1st Qu.:0.004400
## Median :0.009550 Median :0.009300 Median :0.00750 Median :0.006850
## Mean :0.010709 Mean :0.010941 Mean :0.00929 Mean :0.008222
## 3rd Qu.:0.014900 3rd Qu.:0.014500 3rd Qu.:0.01210 3rd Qu.:0.010575
## Max. :0.039000 Max. :0.035200 Max. :0.04470 Max. :0.039400
## V57 V58 V59 V60
## Min. :0.00030 Min. :0.000300 Min. :0.000100 Min. :0.000600
## 1st Qu.:0.00370 1st Qu.:0.003600 1st Qu.:0.003675 1st Qu.:0.003100
## Median :0.00595 Median :0.005800 Median :0.006400 Median :0.005300
## Mean :0.00782 Mean :0.007949 Mean :0.007941 Mean :0.006507
## 3rd Qu.:0.01043 3rd Qu.:0.010350 3rd Qu.:0.010325 3rd Qu.:0.008525
## Max. :0.03550 Max. :0.044000 Max. :0.036400 Max. :0.043900
## Class
## M:111
## R: 97
##
##
##
##
ggplot(Sonar, aes(x = Class, fill = Class)) +
geom_bar() +
labs(title = "Class Distribution in the Dataset",
x = "Class",
y = "Count") +
theme_minimal()
## 5. Train-Test Split & Normalization
We split the dataset into 80% training and 20% testing sets. The Sonar features are already on a comparable 0-1 scale, so no separate normalization step is applied here.
set.seed(123)
train_index <- sample(seq_len(nrow(Sonar)), size = 0.8 * nrow(Sonar))
train_data <- Sonar[train_index, ]
test_data <- Sonar[-train_index, ]
# Extract features and labels
train_X <- train_data[, -ncol(Sonar)]
test_X <- test_data[, -ncol(Sonar)]
train_Y <- train_data$Class
test_Y <- test_data$Class
# Train k-NN classifier with k=3
knn_pred <- knn(train = train_X, test = test_X, cl = train_Y, k = 3)
# Confusion Matrix
conf_matrix <- confusionMatrix(knn_pred, test_Y)
conf_matrix
## Confusion Matrix and Statistics
##
## Reference
## Prediction M R
## M 21 7
## R 2 12
##
## Accuracy : 0.7857
## 95% CI : (0.6319, 0.897)
## No Information Rate : 0.5476
## P-Value [Acc > NIR] : 0.001189
##
## Kappa : 0.5574
##
## Mcnemar's Test P-Value : 0.182422
##
## Sensitivity : 0.9130
## Specificity : 0.6316
## Pos Pred Value : 0.7500
## Neg Pred Value : 0.8571
## Prevalence : 0.5476
## Detection Rate : 0.5000
## Detection Prevalence : 0.6667
## Balanced Accuracy : 0.7723
##
## 'Positive' Class : M
##
# Visualize the nearest neighbors of one test point, projected onto the first two principal components
# Pick a test point (ensure it has column names)
test_example <- test_X[1, , drop = FALSE]
colnames(test_example) <- colnames(train_X)
# Compute Euclidean distance to all training points
distances <- sqrt(rowSums((train_X - matrix(as.numeric(test_example),
nrow = nrow(train_X),
ncol = ncol(train_X),
byrow = TRUE))^2))
# Select k nearest neighbors
k <- 3
neighbors_idx <- order(distances)[1:k]
neighbors_data <- train_X[neighbors_idx, ]
# PCA for visualization
pca_train <- prcomp(train_X, center = TRUE, scale. = TRUE)
pca_neighbors <- predict(pca_train, neighbors_data)
# Project the test example onto the principal components and keep only PC1 and PC2
pca_test <- as.data.frame(predict(pca_train, test_example))[, 1:2, drop = FALSE]
colnames(pca_test) <- c("PC1", "PC2")
# Create data frames for plotting
df_plot <- data.frame(pca_train$x[, 1:2], Label = train_Y)
colnames(df_plot) <- c("PC1", "PC2", "Label")
df_neighbors <- data.frame(pca_neighbors[, 1:2], Label = train_Y[neighbors_idx])
colnames(df_neighbors) <- c("PC1", "PC2", "Label")
df_test <- pca_test # Already formatted properly
df_test$Label <- "Test Point"
# Plot PCA-reduced neighbors
ggplot(df_plot, aes(x = PC1, y = PC2, color = Label)) +
geom_point(alpha = 0.5) +
geom_point(data = df_neighbors, aes(x = PC1, y = PC2), color = "red", size = 3) +
geom_point(data = df_test, aes(x = PC1, y = PC2), color = "blue", shape = 8, size = 4) +
labs(title = "k-NN Neighbors (Red) for Test Point (Blue Star)") +
theme_minimal()
# Repeat the neighbor visualization with k = 5
# Pick the same test point (ensure it has column names)
test_example <- test_X[1, , drop = FALSE]
colnames(test_example) <- colnames(train_X)
# Compute Euclidean distance to all training points
distances <- sqrt(rowSums((train_X - matrix(as.numeric(test_example),
nrow = nrow(train_X),
ncol = ncol(train_X),
byrow = TRUE))^2))
# Select k nearest neighbors
k <- 5
neighbors_idx <- order(distances)[1:k]
neighbors_data <- train_X[neighbors_idx, ]
# PCA for visualization
pca_train <- prcomp(train_X, center = TRUE, scale. = TRUE)
pca_neighbors <- predict(pca_train, neighbors_data)
# Project the test example onto the principal components and keep only PC1 and PC2
pca_test <- as.data.frame(predict(pca_train, test_example))[, 1:2, drop = FALSE]
colnames(pca_test) <- c("PC1", "PC2")
# Create data frames for plotting
df_plot <- data.frame(pca_train$x[, 1:2], Label = train_Y)
colnames(df_plot) <- c("PC1", "PC2", "Label")
df_neighbors <- data.frame(pca_neighbors[, 1:2], Label = train_Y[neighbors_idx])
colnames(df_neighbors) <- c("PC1", "PC2", "Label")
df_test <- pca_test # Already formatted properly
df_test$Label <- "Test Point"
# Plot PCA-reduced neighbors
ggplot(df_plot, aes(x = PC1, y = PC2, color = Label)) +
geom_point(alpha = 0.5) +
geom_point(data = df_neighbors, aes(x = PC1, y = PC2), color = "red", size = 3) +
geom_point(data = df_test, aes(x = PC1, y = PC2), color = "blue", shape = 8, size = 4) +
labs(title = "k-NN Neighbors (Red) for Test Point (Blue Star)") +
theme_minimal()
## 8. Hyperparameter Tuning (Finding Best k)
set.seed(123)
train_control <- trainControl(method = "cv", number = 5)
knn_tuned <- train(Class ~ ., data = train_data,
method = "knn",
trControl = train_control,
tuneLength = 10)
knn_tuned
## k-Nearest Neighbors
##
## 166 samples
## 60 predictor
## 2 classes: 'M', 'R'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 132, 134, 133, 133, 132
## Resampling results across tuning parameters:
##
## k Accuracy Kappa
## 5 0.7462121 0.4833202
## 7 0.6853721 0.3559735
## 9 0.6434826 0.2705788
## 11 0.6497215 0.2845346
## 13 0.6679144 0.3202643
## 15 0.6620321 0.3066610
## 17 0.6375891 0.2580823
## 19 0.6194073 0.2199849
## 21 0.6249220 0.2300818
## 23 0.6370432 0.2555753
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was k = 5.