PART 1: K-Nearest Neighbors (KNN) Classification: A Comprehensive Analysis Using Python

Independent Data Analysis Project

Published: November 14, 2024

Modified: November 14, 2024

Executive Summary

This project applies the K-Nearest Neighbors (KNN) algorithm to a binary classification problem on a synthetic dataset of 1,000 observations with ten numeric features. KNN assigns each point the majority class among its K nearest neighbors in feature space. The analysis begins with exploratory data analysis (EDA) to understand the dataset’s characteristics, followed by feature standardization so that no single feature dominates the distance calculations. We implemented the KNN classifier in Python with scikit-learn and evaluated it with precision, recall, F1-score, and a confusion matrix. The initial model (K=5) reached 92% accuracy on the held-out test set. Comparing error rates for K from 1 to 40 led us to K=17, which raised accuracy to 93%. The project shows how strongly KNN depends on feature scaling and on the choice of K. Future work includes comparing KNN with other algorithms and techniques for higher predictive accuracy.

Keywords

Data analysis, Python, Pandas, Seaborn, NumPy, Descriptive Analysis, Data Science, Machine Learning, Scikit-learn, K-Nearest Neighbors (KNN)

Background

The K-Nearest Neighbors (KNN) algorithm is a popular, simple, and intuitive approach for both classification and regression tasks in machine learning. For classification, it looks at the K data points closest to a new data point in feature space and assigns the class most common among them; for regression, it averages their target values.

How KNN Works

The steps for classifying a new data point x with KNN are:

  1. Calculate the distance from x to every point in the existing (training) dataset.

  2. Sort the points by increasing distance to x.

  3. Select the K nearest points and predict the majority class among them.

The choice of K significantly affects the model’s performance: a very small K makes predictions sensitive to noise in individual points, while a very large K blurs the boundaries between classes. K is therefore a hyperparameter that needs tuning.
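Steps 1 to 3 above can be written directly in NumPy. This is a minimal, unoptimized sketch for illustration only: the function name knn_predict and the toy data are made up here. scikit-learn’s KNeighborsClassifier, used later in this project, implements the same majority-vote idea with faster neighbor search.

```python
import numpy as np
from collections import Counter


def knn_predict(X_train, y_train, x, k=5):
    """Predict the class of a single point x by plain majority vote (hypothetical helper)."""
    # Step 1: Euclidean distance from x to every training point
    distances = np.sqrt(((X_train - x) ** 2).sum(axis=1))
    # Step 2: indices of the training points, sorted by increasing distance
    order = np.argsort(distances)
    # Step 3: majority class among the k nearest points
    # (ties go to the label encountered first, i.e. the closer one)
    nearest_labels = y_train[order[:k]]
    return Counter(nearest_labels.tolist()).most_common(1)[0][0]


# Tiny made-up data: two well-separated clusters
X_train = np.array([[0.0, 0.0], [0.1, 0.2], [0.2, 0.1],
                    [1.0, 1.0], [0.9, 1.1], [1.1, 0.9]])
y_train = np.array([0, 0, 0, 1, 1, 1])

print(knn_predict(X_train, y_train, np.array([0.15, 0.15]), k=3))  # 0
print(knn_predict(X_train, y_train, np.array([0.95, 1.0]), k=3))   # 1
```

Note that there is no training step: all the work happens at prediction time, which is why KNN becomes expensive on large datasets.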

Advantages of KNN

  • Simple and easy to implement: No complex training phase required.

  • Supports multiclass classification.

  • Flexible with new data: Adding new data does not require retraining the entire model.

  • Minimal parameters: Only requires the distance metric and the number of neighbors (K).

Limitations of KNN

  • Computationally expensive for large datasets since every prediction requires computing distances to all data points.

  • High-dimensional data can reduce its effectiveness due to the “curse of dimensionality”.

  • The algorithm is sensitive to the scale of data, so feature scaling is necessary.

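  • KNN does not handle categorical features well, because distances between arbitrary category codes are not meaningful unless the categories are encoded first (for example, one-hot encoded).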
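To make the scale-sensitivity point concrete, here is a small numeric illustration with two made-up points. Feature 0 lives in [0, 1] and feature 1 is in the thousands; a large relative gap on feature 0 is swamped by a small relative gap on feature 1.

```python
import numpy as np

# Two hypothetical points: feature 0 lies in [0, 1], feature 1 is in the thousands
a = np.array([0.2, 1000.0])
b = np.array([0.9, 1010.0])

diff = a - b
print(np.abs(diff))                # per-feature gaps: about 0.7 and 10.0
print(np.sqrt((diff ** 2).sum()))  # Euclidean distance, about 10.02

# Fraction of the squared distance contributed by feature 1 alone
share = diff[1] ** 2 / (diff ** 2).sum()
print(round(share, 3))             # 0.995
```

Feature 1 accounts for over 99% of the squared distance, so the nearest neighbors would be chosen almost entirely by that feature unless the features are scaled.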

Below, we walk through a step-by-step implementation of a KNN classifier using a synthetic dataset to predict a target class.

Key Takeaways:

  • Feature scaling is critical for the KNN algorithm to function correctly.

  • The choice of K plays a crucial role in the model’s performance.

  • KNN is a powerful but computationally expensive algorithm, best suited for smaller datasets.


Data Analysis and Preprocessing

We start by loading the necessary Python libraries and reading in the data.

Importing Libraries

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

Loading and Inspecting the Data

We read in the data from a CSV file:

mydata = pd.read_csv("Classified Data")
mydata.head()
mydata.info()
mydata.describe()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1000 entries, 0 to 999
Data columns (total 12 columns):
 #   Column        Non-Null Count  Dtype  
---  ------        --------------  -----  
 0   Unnamed: 0    1000 non-null   int64  
 1   WTT           1000 non-null   float64
 2   PTI           1000 non-null   float64
 3   EQW           1000 non-null   float64
 4   SBI           1000 non-null   float64
 5   LQE           1000 non-null   float64
 6   QWG           1000 non-null   float64
 7   FDJ           1000 non-null   float64
 8   PJF           1000 non-null   float64
 9   HQE           1000 non-null   float64
 10  NXJ           1000 non-null   float64
 11  TARGET CLASS  1000 non-null   int64  
dtypes: float64(10), int64(2)
memory usage: 93.9 KB
       Unnamed: 0         WTT         PTI         EQW         SBI         LQE
count 1000.000000 1000.000000 1000.000000 1000.000000 1000.000000 1000.000000
mean   499.500000    0.949682    1.114303    0.834127    0.682099    1.032336
std    288.819436    0.289635    0.257085    0.291554    0.229645    0.243413
min      0.000000    0.174412    0.441398    0.170924    0.045027    0.315307
25%    249.750000    0.742358    0.942071    0.615451    0.515010    0.870855
50%    499.500000    0.940475    1.118486    0.813264    0.676835    1.035824
75%    749.250000    1.163295    1.307904    1.028340    0.834317    1.198270
max    999.000000    1.721779    1.833757    1.722725    1.634884    1.650050

              QWG         FDJ         PJF         HQE         NXJ TARGET CLASS
count 1000.000000 1000.000000 1000.000000 1000.000000 1000.000000   1000.00000
mean     0.943534    0.963422    1.071960    1.158251    1.362725      0.50000
std      0.256121    0.255118    0.288982    0.293738    0.204225      0.50025
min      0.262389    0.295228    0.299476    0.365157    0.639693      0.00000
25%      0.761064    0.784407    0.866306    0.934340    1.222623      0.00000
50%      0.941502    0.945333    1.065500    1.165556    1.375368      0.50000
75%      1.123060    1.134852    1.283156    1.383173    1.504832      1.00000
max      1.666902    1.713342    1.785420    1.885690    1.893950      1.00000

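The dataset has 1,000 rows and no missing values. Ten numeric features (WTT through NXJ) are used to predict the binary TARGET CLASS, and the classes are balanced: TARGET CLASS has a mean of 0.5, so half of the rows belong to each class. The first column, Unnamed: 0, runs from 0 to 999 and appears to be a row index saved with the CSV rather than a genuine feature; reading the file with pd.read_csv("Classified Data", index_col=0) would exclude it, whereas the code in this walkthrough keeps it among the scaled features.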
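"Unnamed: 0" is the name pandas gives a column whose header cell is blank, which usually happens when a DataFrame was saved to CSV together with its row index. The sketch below uses a made-up three-row CSV (assumption: the real file has the same blank first header, as the info() output suggests) to show how index_col=0 turns that column into the index instead of a feature.

```python
import io

import pandas as pd

# Hypothetical CSV text with a blank first header cell, mimicking the real file
csv_text = """,WTT,PTI,TARGET CLASS
0,0.91,0.64,1
1,0.63,1.00,0
2,0.72,1.20,0
"""

# Default: the saved index becomes an ordinary column named "Unnamed: 0"
as_column = pd.read_csv(io.StringIO(csv_text))
print(as_column.columns.tolist())  # ['Unnamed: 0', 'WTT', 'PTI', 'TARGET CLASS']

# index_col=0: the first column is used as the row index instead
as_index = pd.read_csv(io.StringIO(csv_text), index_col=0)
print(as_index.columns.tolist())   # ['WTT', 'PTI', 'TARGET CLASS']
```

Dropping the column after loading, with mydata.drop(columns=["Unnamed: 0"]), has the same effect on the feature set.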

Visualizing the Data

We start with a pair plot to visualize the relationships between features:

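g = sns.pairplot(mydata, corner=True, hue="TARGET CLASS", palette="rocket")
# plt.title() would label only the current subplot; suptitle labels the whole grid
g.figure.suptitle("Pair Plot of Features", y=1.02)
plt.show()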

We also generate a correlation heatmap:

sns.heatmap(mydata.corr(), cmap="vlag", annot=True)
plt.title("Correlation Matrix")
plt.show()


Data Preprocessing

To ensure that all features contribute comparably to the distance calculations, we standardize them.

Scaling the Features

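# Standardize every column except the label (this includes the Unnamed: 0 column)
scaler = StandardScaler()
raw_features = mydata.drop(columns=['TARGET CLASS'])
scaler.fit(raw_features)
scaled_features = scaler.transform(raw_features)

# Wrap the scaled NumPy array back into a DataFrame with the original column names
features = pd.DataFrame(scaled_features, columns=raw_features.columns)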

StandardScaler standardizes each feature to a mean of 0 and a standard deviation of 1. Without this step, features with larger numeric ranges would dominate the Euclidean distance that KNeighborsClassifier uses by default, which can degrade KNN’s accuracy.
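Concretely, each value is transformed as z = (x − μ) / σ, where μ and σ are that column’s mean and standard deviation. The sketch below checks this formula against StandardScaler on a small made-up matrix. One detail worth knowing: StandardScaler uses the population standard deviation (ddof=0), whereas pandas describe() reports the sample standard deviation (ddof=1), so the std row of describe() is slightly larger than the σ the scaler divides by.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Small made-up feature matrix: 4 rows, 2 features on very different scales
X = np.array([[1.0, 100.0],
              [2.0, 300.0],
              [3.0, 200.0],
              [4.0, 400.0]])

# z = (x - mean) / std, column by column, using the population std (ddof=0)
manual = (X - X.mean(axis=0)) / X.std(axis=0, ddof=0)
scaled = StandardScaler().fit_transform(X)

print(np.allclose(manual, scaled))  # True
print(scaled.mean(axis=0))          # approximately 0 for each column
print(scaled.std(axis=0))           # approximately 1 for each column
```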

Splitting the Data

We split the data into training and testing sets:

X = features
y = mydata['TARGET CLASS']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=101)
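One caveat with the order used above: the scaler is fit on all 1,000 rows before the split, so the means and standard deviations of the test rows leak into the training features. The effect of this particular leak may well be small here, but a cleaner arrangement splits first and lets a scikit-learn Pipeline fit the scaler on the training rows only. A minimal sketch follows; because the CSV is not reproduced here, make_classification generates a synthetic stand-in with the same shape, so the printed accuracy says nothing about the project’s data.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in for the raw (unscaled) features and labels; with the real
# data these would be mydata.drop(columns=['TARGET CLASS']) and mydata['TARGET CLASS']
X_raw, y = make_classification(n_samples=1000, n_features=10, random_state=101)

# Split first, so the test rows play no part in computing the scaling statistics
X_train, X_test, y_train, y_test = train_test_split(
    X_raw, y, test_size=0.3, random_state=101)

# The pipeline fits StandardScaler on X_train only and reuses those training
# means and standard deviations when it transforms X_test
model = make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=5))
model.fit(X_train, y_train)
print(f"Test accuracy: {model.score(X_test, y_test):.3f}")
```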

Building and Evaluating the KNN Model

Initial KNN Model (K=5)

We start by training a KNN model with K=5:

knn_class = KNeighborsClassifier(n_neighbors=5)
knn_class.fit(X_train, y_train)
predictions = knn_class.predict(X_test)

Model Evaluation

We evaluate the model using a classification report and confusion matrix:

print(classification_report(y_test, predictions))
print(confusion_matrix(y_test, predictions))
              precision    recall  f1-score   support

           0       0.92      0.92      0.92       159
           1       0.91      0.91      0.91       141

    accuracy                           0.92       300
   macro avg       0.92      0.92      0.92       300
weighted avg       0.92      0.92      0.92       300

[[147  12]
 [ 12 129]]

Results:

  • The model achieved an accuracy of 92% (276 of 300 test points classified correctly).
  • Precision, recall, and F1-score were all 0.92 for class 0 and 0.91 for class 1; 12 points were misclassified in each direction.
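Every figure in the classification report follows from the confusion matrix. As a sanity check, the short sketch below recomputes them by hand from the K=5 matrix printed above (rows are the true classes 0 and 1, columns the predicted classes).

```python
import numpy as np

# K=5 confusion matrix copied from the output above
# rows = true class (0, 1), columns = predicted class (0, 1)
cm = np.array([[147, 12],
               [12, 129]])

tp = np.diag(cm)                 # correct predictions for each class
precision = tp / cm.sum(axis=0)  # column sums: all points predicted as that class
recall = tp / cm.sum(axis=1)     # row sums: all points truly in that class
f1 = 2 * precision * recall / (precision + recall)
accuracy = tp.sum() / cm.sum()

print("precision:", precision.round(2))
print("recall:   ", recall.round(2))
print("f1:       ", f1.round(2))
print("accuracy: ", round(accuracy, 2))
```

Here precision equals recall for each class because each class has exactly 12 misclassified points, so its false positives and false negatives are equal in number.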

Visualizing the Confusion Matrix

sns.heatmap(confusion_matrix(y_test, predictions), cmap="Blues", annot=True, fmt=".0f")
plt.title("Confusion Matrix for K=5")
plt.show()


Hyperparameter Tuning: Optimizing K

The choice of K significantly influences model accuracy. To choose it, we train a model for each K from 1 to 40 and record the fraction of test points it misclassifies.

Testing Different K Values

error_rate = []
for i in range(1, 41):
    knn_class = KNeighborsClassifier(n_neighbors=i)
    knn_class.fit(X_train, y_train)
    pred = knn_class.predict(X_test)
    error_rate.append(np.mean(pred != y_test))

Plotting Error Rates

plt.figure(figsize=(10, 6))
sns.lineplot(x=np.arange(1, 41), y=error_rate)
plt.xlabel('K Value')
plt.ylabel('Error Rate')
plt.title('Error Rate vs K Value')
plt.show()

Insight: The error rate stabilizes around K=17. Beyond this point, increasing K does not significantly improve performance.
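Because K was chosen by looking at errors on the test set, the tuned model’s test metrics are likely somewhat optimistic. A common alternative is to pick K by cross-validation on the training data and use the test set only once, for the final estimate. A minimal sketch with scikit-learn’s GridSearchCV follows; make_classification again stands in for the CSV (an assumption), so the selected K will not necessarily be 17.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in for the raw features and labels
X_raw, y = make_classification(n_samples=1000, n_features=10, random_state=101)
X_train, X_test, y_train, y_test = train_test_split(
    X_raw, y, test_size=0.3, random_state=101)

# Scaling inside the pipeline is refit on each training fold
pipe = Pipeline([
    ("scale", StandardScaler()),
    ("knn", KNeighborsClassifier()),
])

# 5-fold cross-validation on the training set only, over K = 1..40
search = GridSearchCV(pipe, {"knn__n_neighbors": list(range(1, 41))}, cv=5)
search.fit(X_train, y_train)

best_k = search.best_params_["knn__n_neighbors"]
print("Best K by cross-validation:", best_k)

# The held-out test set is used once, with the refit best model
print(f"Test accuracy: {search.score(X_test, y_test):.3f}")
```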

Training the Final Model with Optimal K

Based on the error rate plot, we retrain the model using K=17:

tuned_model = KNeighborsClassifier(n_neighbors=17)
tuned_model.fit(X_train, y_train)
new_preds = tuned_model.predict(X_test)

Evaluating the Tuned Model

print(classification_report(y_test, new_preds))
sns.heatmap(confusion_matrix(y_test, new_preds), cmap="Blues", annot=True, fmt=".0f")
plt.title("Confusion Matrix for Optimized K=17")
plt.show()
              precision    recall  f1-score   support

           0       0.93      0.94      0.94       159
           1       0.94      0.92      0.93       141

    accuracy                           0.93       300
   macro avg       0.93      0.93      0.93       300
weighted avg       0.93      0.93      0.93       300

Results:
With K=17, test accuracy rose from 0.92 to 0.93, and precision, recall, and F1-scores ranged from 0.92 to 0.94 across both classes. The gain is small, but it shows that tuning K does affect KNN’s performance.


Conclusion

In this project, we used the K-Nearest Neighbors (KNN) algorithm to classify data points based on their feature values. We explored how different values of K affect the test error rate and found that the error levels off around K=17; retraining with K=17 raised test accuracy from 92% (K=5) to 93%.

Future Work:

To further improve classification accuracy:

  • Experiment with techniques such as distance-weighted KNN, or Principal Component Analysis (PCA) to reduce dimensionality.

  • Explore other classification algorithms like Support Vector Machines (SVM) or Random Forests for comparison (Muddana and Vinayakam 2024; James et al. 2013).
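The first future-work item can be tried with scikit-learn as it stands: KNeighborsClassifier accepts weights="distance" for inverse-distance-weighted voting, and PCA can sit in the same pipeline before the classifier. The sketch below compares the three variants by 5-fold cross-validation. It runs on synthetic stand-in data from make_classification, so the printed numbers say nothing about the project’s dataset; K=17 is simply carried over from the tuning above.

```python
from sklearn.datasets import make_classification
from sklearn.decomposition import PCA
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in data (assumption: same shape as the project's features)
X_raw, y = make_classification(n_samples=1000, n_features=10, random_state=101)

candidates = {
    # Plain majority vote, as used in this project
    "uniform K=17": make_pipeline(
        StandardScaler(), KNeighborsClassifier(n_neighbors=17)),
    # Weighted KNN: each neighbor's vote is weighted by 1 / distance
    "distance-weighted K=17": make_pipeline(
        StandardScaler(), KNeighborsClassifier(n_neighbors=17, weights="distance")),
    # PCA keeps enough components to explain 95% of the variance before KNN
    "PCA(95%) + K=17": make_pipeline(
        StandardScaler(), PCA(n_components=0.95),
        KNeighborsClassifier(n_neighbors=17)),
}

results = {}
for name, model in candidates.items():
    scores = cross_val_score(model, X_raw, y, cv=5)
    results[name] = scores.mean()
    print(f"{name:>24}: mean CV accuracy {scores.mean():.3f}")
```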


References

James, Gareth, Daniela Witten, Trevor Hastie, Robert Tibshirani, et al. 2013. An Introduction to Statistical Learning. Vol. 112. Springer.
Muddana, A Lakshmi, and Sandhya Vinayakam. 2024. Python for Data Science. Springer.