Problem

You want to understand how the k-nearest neighbors (kNN) algorithm works & build your own kNN model.


Stack

I will be using the R programming language to build the kNN model, along with the packages tidyverse, mlr, & mclust. tidyverse is a collection of packages designed for data science, mlr provides the tools to build the kNN model, & mclust supplies the diabetes dataset we will use. The parallelMap & parallel packages are also used later to run resampling in parallel.


What is the kNN Algorithm?

The kNN algorithm is arguably the simplest classification algorithm. However, don't dismiss it because of its simplicity: it can provide excellent classification performance & is easy to interpret.

How does the kNN Algorithm Learn?

Suppose a reptile conservation project hires you to build a kNN classifier to quickly classify three different species of reptiles: the grass snake, the adder, & the slow worm. To build the kNN classifier, you need data. A biologist takes you into the woodlands to find these three reptiles. Whenever you find one, you record its body length & aggression. For now, the biologist identifies each reptile's species by hand, but in the future our kNN classifier will classify these reptiles for us.

We plot each of our observations by its body length & aggression, grouped by the species the biologist identified. We then go back into the woodland & collect three more reptiles that have not yet been classified. The figure below demonstrates how the plot might look.

The kNN algorithm calculates the distance between each new, unclassified case & all of the classified cases. Then for each unclassified case, the algorithm ranks its ‘neighbors’ from nearest (most similar in terms of body length & aggression) to furthest (least similar in terms of body length & aggression).
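As a minimal sketch of this distance-and-ranking step, the base R code below computes the distance from one new reptile to each classified reptile & sorts them from nearest to farthest. Euclidean distance is an assumption here (it is the usual default, & the one used by the class package behind mlr's classif.knn), & the reptile measurements are hypothetical, invented only for illustration.

```r
# Hypothetical classified reptiles (made-up measurements for illustration only)
classified <- data.frame(
  length     = c(30, 35, 60, 65, 90, 95),
  aggression = c(2, 3, 8, 9, 4, 5),
  species    = c("slow worm", "slow worm", "adder", "adder",
                 "grass snake", "grass snake"),
  stringsAsFactors = FALSE
)
new_case <- c(length = 62, aggression = 7)  # one unclassified reptile

# Euclidean distance from the new case to every classified case (one per row)
diffs <- sweep(as.matrix(classified[, c("length", "aggression")]), 2, new_case)
classified$distance <- sqrt(rowSums(diffs^2))

# Rank neighbors from nearest (most similar) to farthest (least similar)
ranked <- classified[order(classified$distance), ]
ranked$species
# "adder" "adder" "slow worm" "grass snake" "slow worm" "grass snake"
```

Because length spans a much wider range than aggression in this made-up data, it dominates the distance; the ranking is only as meaningful as the scale of the features.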

The kNN algorithm then finds the k classified cases most similar to each new case, where k is a value we must specify. Each of these k neighbors votes for its own class, & the new case is assigned the class that receives the most votes. Here's a visualization to aid your understanding.

If k is 1, the algorithm finds the single classified case nearest to each new case. Each of the new cases is closest to a reptile classified as a grass snake, so they would be classified as grass snakes as well. If k is 3, the algorithm finds the 3 classified cases nearest to each new case. When a new case has neighbors belonging to more than one class, each neighbor votes for its own class & the majority vote wins. If k is 5, the algorithm finds the 5 classified cases nearest to each new case, & the new case takes the majority vote of those 5 neighbors. Notice that across these scenarios, the value of k can change how the new cases are classified.
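The voting step can be sketched the same way. This minimal base R version keeps the Euclidean-distance assumption & uses hypothetical reptile data deliberately arranged so that k = 1, 3, & 5 give different answers for the same new case; knn_classify() is an illustrative helper, not part of any package.

```r
# Hypothetical classified reptiles (length, aggression), made up so that k matters
train_x <- matrix(c(51, 5,    # grass snake, distance 1 from the new case
                    52, 5,    # adder,       distance 2
                    50, 8,    # adder,       distance 3
                    54, 5,    # grass snake, distance 4
                    50, 0,    # grass snake, distance 5
                    44, 5,    # slow worm,   distance 6
                    50, 12,   # slow worm,   distance 7
                    42, 5),   # slow worm,   distance 8
                  ncol = 2, byrow = TRUE,
                  dimnames = list(NULL, c("length", "aggression")))
train_y <- c("grass snake", "adder", "adder", "grass snake",
             "grass snake", "slow worm", "slow worm", "slow worm")
new_case <- c(50, 5)

# Find the k nearest classified cases & let each vote for its own class
knn_classify <- function(train_x, train_y, new_x, k) {
  d <- sqrt(rowSums(sweep(train_x, 2, new_x)^2))  # Euclidean distances
  nearest <- order(d)[seq_len(k)]                 # indices of the k nearest
  votes <- table(train_y[nearest])                # votes per class
  names(votes)[which.max(votes)]                  # class with the most votes
}

preds_by_k <- sapply(c(1, 3, 5), function(k) knn_classify(train_x, train_y, new_case, k))
preds_by_k
# returns c("grass snake", "adder", "grass snake")
```

One caveat of this sketch: which.max() resolves a tied vote in favor of whichever tied class comes first alphabetically, which is a silent tie-breaking rule rather than a deliberate one.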

Tied Votes or Equidistant Cases

Now you may be wondering: what happens when the vote is tied? In our previous example every value of k was odd, but what happens when k is even & the vote is tied? A common approach is to randomly assign cases with tied votes to one of the tied classes. In practice, the proportion of cases with ties is usually small, so they have limited impact on the classification accuracy of our model. However, there are ways to avoid or resolve ties, such as:

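  1. Exclusively selecting odd values of k, which rules out ties in two-class problems (with three or more classes an odd k can still tie, e.g. k = 3 with one vote for each of three classes)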
  2. Decreasing the value of k until a majority vote is won (This will not help if a new case is equidistant to its two nearest neighbors. In such a case, we would randomly assign the new case to a class.)
  3. Using a different classification algorithm
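The random tie-break & the decrease-k strategy from the list above can be sketched as two small base R voting functions. Both are hypothetical helpers written for illustration; they assume the neighbors' classes have already been sorted from nearest to farthest, & they do not handle exactly equidistant neighbors (order() simply keeps such neighbors in their original row order).

```r
# neighbor_classes: classes of the nearest neighbors, ordered nearest first

# Strategy A: break a tied vote by picking one of the tied classes at random
vote_random_tie <- function(neighbor_classes) {
  counts <- table(neighbor_classes)
  tied <- names(counts)[counts == max(counts)]
  if (length(tied) == 1) tied else sample(tied, 1)
}

# Strategy B: drop the farthest neighbor (decrease k) until one class leads
vote_shrinking_k <- function(neighbor_classes) {
  for (k in rev(seq_along(neighbor_classes))) {
    counts <- table(neighbor_classes[seq_len(k)])
    leaders <- names(counts)[counts == max(counts)]
    if (length(leaders) == 1) return(leaders)  # k = 1 always has a single leader
  }
}

# k = 4 neighbors, nearest first: a 2-2 tie between adder & grass snake
neighbors <- c("adder", "grass snake", "adder", "grass snake")
vote_shrinking_k(neighbors)  # at k = 3, adder leads 2 votes to 1, so "adder"
vote_random_tie(neighbors)   # "adder" or "grass snake", chosen at random
```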

Building a kNN Model

Now that we understand the kNN algorithm, let's put it into practice. Say you are a biostatistician tasked with a project to improve the diagnosis of diabetes. You collect data from suspected diabetes patients & record their diagnoses. This is the dataset you get as a result.

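library(tidyverse)
library(mlr)
library(mclust)  # provides the diabetes dataset
data(diabetes, package = 'mclust')
dbTib <- as_tibble(diabetes)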
head(dbTib)
## # A tibble: 6 × 4
##   class  glucose insulin  sspg
##   <fct>    <dbl>   <dbl> <dbl>
## 1 Normal      80     356   124
## 2 Normal      97     289   117
## 3 Normal     105     319   143
## 4 Normal      90     356   199
## 5 Normal      90     323   240
## 6 Normal      86     381   157
summary(dbTib)
##       class       glucose       insulin            sspg      
##  Chemical:36   Min.   : 70   Min.   :  45.0   Min.   : 10.0  
##  Normal  :76   1st Qu.: 90   1st Qu.: 352.0   1st Qu.:118.0  
##  Overt   :33   Median : 97   Median : 403.0   Median :156.0  
##                Mean   :122   Mean   : 540.8   Mean   :186.1  
##                3rd Qu.:112   3rd Qu.: 558.0   3rd Qu.:221.0  
##                Max.   :353   Max.   :1568.0   Max.   :748.0

Defining a Task

Since we want to build a classification model, we use makeClassifTask() to define a classification task. We supply the function with our data, dbTib, & the name of the variable we want our model to predict, class.

dbTask <- makeClassifTask(data = dbTib, target = 'class')
dbTask
## Supervised task: dbTib
## Type: classif
## Target: class
## Observations: 145
## Features:
##    numerics     factors     ordered functionals 
##           3           0           0           0 
## Missings: FALSE
## Has weights: FALSE
## Has blocking: FALSE
## Has coordinates: FALSE
## Classes: 3
## Chemical   Normal    Overt 
##       36       76       33 
## Positive class: NA

Printing the task gives us some information about the number of observations & the number of each type of variable. It also indicates whether there is missing data, the number of observations in each class, etc.

K-Fold Cross-Validation

Before we continue, allow me to explain a type of cross-validation that is essential for ML model building. First off, what is cross-validation? The whole point of an ML model is its ability to make predictions on new data. However, it is horribly inefficient to gather data, build a model from it, & then gather more data for the model to predict on. Instead, we split our data in two: we use one portion to train the model & the remaining portion to test it. This process is called cross-validation (CV).

In k-fold CV, we split the data into k roughly equal-sized chunks called folds. We reserve one fold for testing & use the remaining folds for training: we train the model on the training folds, record its performance metric on the test fold, & then repeat with a different fold held out. Once every fold has been used for testing exactly once, we average the performance metric across the folds to get our estimate of model performance.
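The k-fold procedure described above can be sketched in base R. This is not mlr's implementation: it assumes Euclidean distance, uses R's built-in iris data so it runs without extra packages, & the helpers knn_predict() & kfold_cv_error() are hypothetical names. To keep the two k's apart, the number of neighbors is called k_neighbors & the number of folds n_folds.

```r
# From-scratch kNN prediction for every row of test_x (Euclidean distance, majority vote)
knn_predict <- function(train_x, train_y, test_x, k) {
  apply(test_x, 1, function(new_x) {
    d <- sqrt(rowSums(sweep(train_x, 2, new_x)^2))
    votes <- table(train_y[order(d)[seq_len(k)]])
    names(votes)[which.max(votes)]
  })
}

# k-fold CV: each fold is held out once for testing, the rest is used for training
kfold_cv_error <- function(x, y, k_neighbors, n_folds = 10) {
  folds <- sample(rep(seq_len(n_folds), length.out = nrow(x)))  # random, near-equal folds
  fold_errors <- sapply(seq_len(n_folds), function(f) {
    test <- folds == f
    preds <- knn_predict(x[!test, , drop = FALSE], y[!test],
                         x[test, , drop = FALSE], k_neighbors)
    mean(preds != y[test])  # misclassification error on the held-out fold
  })
  mean(fold_errors)         # average over folds = estimate of model performance
}

set.seed(42)
x <- as.matrix(iris[, 1:4])
y <- as.character(iris$Species)
cv_err <- kfold_cv_error(x, y, k_neighbors = 5, n_folds = 10)
cv_err
```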

Say we choose 10 folds (this k, the number of folds, is unrelated to the k in kNN). This means we split the data into 10 equal-sized chunks to perform CV. We can also repeat the whole procedure several times, reshuffling the data into new folds each time, to get repeated k-fold CV. To perform repeated k-fold CV in mlr, we make a resampling description, a set of instructions for how the data will be split into training & test sets. We supply four arguments to makeResampleDesc(): method is the type of cross-validation we want to perform ('RepCV' for repeated CV), folds is the number of folds, reps is the number of repetitions, & stratify is whether to maintain the proportion of each class of patient in each set.

repKFold <- makeResampleDesc(method = 'RepCV', folds = 10, reps = 5,
                          stratify = TRUE)
repKFold
## Resample description: repeated cross-validation with 50 iterations: 10 folds and 5 reps.
## Predict: test
## Stratification: TRUE
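To see what stratify = TRUE & reps = 5 mean in practice, here is a rough base R sketch of stratified, repeated fold assignment. The class counts (36 Chemical, 76 Normal, 33 Overt) come from the diabetes summary above; stratified_folds() is a hypothetical helper & only approximates what mlr does internally.

```r
# Stratified folds: within each class, shuffle the rows & deal them out to the
# folds in turn, so every fold keeps roughly the overall class proportions
stratified_folds <- function(y, n_folds) {
  folds <- integer(length(y))
  for (cls in unique(y)) {
    idx <- which(y == cls)
    folds[idx] <- sample(rep(seq_len(n_folds), length.out = length(idx)))
  }
  folds
}

# Class labels with the same counts as the diabetes data
y <- c(rep("Chemical", 36), rep("Normal", 76), rep("Overt", 33))

# Repeated CV: a fresh set of stratified folds for each of the 5 repetitions
set.seed(1)
fold_sets <- replicate(5, stratified_folds(y, 10))  # 145 x 5 matrix, one column per rep

# Class counts per fold in the first repetition: 3-4 Chemical, 7-8 Normal, 3-4 Overt
table(y, fold_sets[, 1])
```

With 10 folds & 5 repetitions there are 10 × 5 = 50 train/test splits, which matches the 50 iterations reported in the resampling description.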

Defining a Learner

A learner tells mlr which algorithm you want to use. Since we want to use the kNN algorithm, we supply 'classif.knn' as our learner. We also supply par.vals, which lets us specify the value of k we want the kNN algorithm to use. We will set k to 5 for now.

knnLearner <- makeLearner('classif.knn', par.vals = list('k' = 5))
knnLearner
## Learner classif.knn from package class
## Type: classif
## Name: k-Nearest Neighbor; Short name: knn
## Class: classif.knn
## Properties: twoclass,multiclass,numerics
## Predict-Type: response
## Hyperparameters: k=5

Hyperparameter Tuning

What if, instead of specifying k ourselves, we let the computer find the value of k that gives the best predictive accuracy? This process is called hyperparameter tuning, & we can perform it in mlr. The makeDiscreteParam() function specifies the hyperparameter to tune (k) & the values we want to search (here, 1 to 10). We nest it inside makeParamSet(), which defines the hyperparameter space. makeTuneControlGrid() tells mlr to perform a grid search, trying every value in that space, & tuneParams() runs the search, evaluating each candidate value of k with our repeated k-fold CV. Cross-validating the tuning process this way guards against picking a k that only looks good on one particular split of the data.

knnParamSpace <- makeParamSet(makeDiscreteParam('k', values = 1:10))
knnParamSpace
##       Type len Def               Constr Req Tunable Trafo
## k discrete   -   - 1,2,3,4,5,6,7,8,9,10   -    TRUE     -
gridSearch <- makeTuneControlGrid()
gridSearch
## Tune control: TuneControlGrid
## Same resampling instance: TRUE
## Imputation value: <worst>
## Start: <NULL>
## 
## Tune threshold: FALSE
## Further arguments: resolution=10
bestK <- tuneParams('classif.knn', task = dbTask, 
                    resampling = repKFold, par.set = knnParamSpace,
                    control = gridSearch)
## [Tune] Started tuning learner classif.knn for parameter set:
##       Type len Def               Constr Req Tunable Trafo
## k discrete   -   - 1,2,3,4,5,6,7,8,9,10   -    TRUE     -
## With control class: TuneControlGrid
## Imputation value: 1
## [Tune-x] 1: k=1
## [Tune-y] 1: mmce.test.mean=0.1043654; time: 0.0 min
## [Tune-x] 2: k=2
## [Tune-y] 2: mmce.test.mean=0.0944194; time: 0.0 min
## [Tune-x] 3: k=3
## [Tune-y] 3: mmce.test.mean=0.0908773; time: 0.0 min
## [Tune-x] 4: k=4
## [Tune-y] 4: mmce.test.mean=0.0936392; time: 0.0 min
## [Tune-x] 5: k=5
## [Tune-y] 5: mmce.test.mean=0.0823297; time: 0.0 min
## [Tune-x] 6: k=6
## [Tune-y] 6: mmce.test.mean=0.0823984; time: 0.0 min
## [Tune-x] 7: k=7
## [Tune-y] 7: mmce.test.mean=0.0741841; time: 0.0 min
## [Tune-x] 8: k=8
## [Tune-y] 8: mmce.test.mean=0.0812582; time: 0.0 min
## [Tune-x] 9: k=9
## [Tune-y] 9: mmce.test.mean=0.0823864; time: 0.0 min
## [Tune-x] 10: k=10
## [Tune-y] 10: mmce.test.mean=0.0838150; time: 0.0 min
## [Tune] Result: k=7 : mmce.test.mean=0.0741841
bestK
## Tune result:
## Op. pars: k=7
## mmce.test.mean=0.0741841

Our best-performing value of k is 7, because it has the lowest mean misclassification error (mmce). We can visualize the tuning process, then define a new learner with this value of k & train a model with it.

plotHyperParsEffect(generateHyperParsEffectData(bestK),
                    x = 'k', y = 'mmce.test.mean', plot.type = 'line') +
  theme_bw() +
  labs(title = 'Mean Misclassification Error vs. k')
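The grid search that tuneParams() performed can be sketched in base R: compute the CV error for every candidate k & keep the one with the lowest error. As before, this is an illustrative sketch on the built-in iris data with hypothetical helper names, not mlr's code; it reuses one fold assignment for every candidate, mirroring the "Same resampling instance: TRUE" setting in the grid control printed above.

```r
# From-scratch kNN prediction (Euclidean distance, majority vote)
knn_predict <- function(train_x, train_y, test_x, k) {
  apply(test_x, 1, function(new_x) {
    d <- sqrt(rowSums(sweep(train_x, 2, new_x)^2))
    votes <- table(train_y[order(d)[seq_len(k)]])
    names(votes)[which.max(votes)]
  })
}

# Mean misclassification error of one candidate k over a fixed fold assignment
cv_error <- function(x, y, k, folds) {
  fold_errors <- sapply(unique(folds), function(f) {
    test <- folds == f
    preds <- knn_predict(x[!test, , drop = FALSE], y[!test],
                         x[test, , drop = FALSE], k)
    mean(preds != y[test])
  })
  mean(fold_errors)
}

set.seed(7)
x <- as.matrix(iris[, 1:4])
y <- as.character(iris$Species)
folds <- sample(rep(1:10, length.out = nrow(x)))  # same folds reused for every k

grid <- 1:10                                       # candidate values of k
errors <- sapply(grid, function(k) cv_error(x, y, k, folds))
best_k <- grid[which.min(errors)]                  # lowest mean misclassification error wins
best_k
```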

Building the Model & Estimating Model Predictive Accuracy

library(parallelMap)  # parallelStartSocket(), parallelStop()
library(parallel)     # detectCores()
parallelStartSocket(cpus = detectCores() - 1)
## Starting parallelization in mode=socket with cpus=7.
newLearner <- setHyperPars(makeLearner('classif.knn'), par.vals = bestK$x)
kNNModel <- train(newLearner, dbTask)

# Estimate model predictive accuracy with nested CV: the tune wrapper re-tunes k within each training set
wrapper <- makeTuneWrapper('classif.knn', resampling = repKFold,
                           par.set = knnParamSpace, control = gridSearch)
wrapper
## Learner classif.knn.tuned from package class
## Type: classif
## Name: ; Short name: 
## Class: TuneWrapper
## Properties: numerics,twoclass,multiclass
## Predict-Type: response
## Hyperparameters:
modelEstimates <- resample(wrapper, dbTask, resampling = repKFold)
## Exporting objects to slaves for mode socket: .mlr.slave.options
## Resampling: repeated cross-validation
## Measures:             mmce
## Mapping in parallel: mode = socket; level = mlr.resample; cpus = 7; elements = 50.
## 
## Aggregated Result: mmce.test.mean=0.0842280
## 
parallelStop()
## Stopped parallelization. All cleaned up.
modelEstimates
## Resample Result
## Task: dbTib
## Learner: classif.knn.tuned
## Aggr perf: mmce.test.mean=0.0842280
## Runtime: 51.775

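The aggregated mmce is about 0.084, so our kNN model is estimated to correctly classify over 90% of new cases (1 - 0.084 ≈ 0.92). Because the tune wrapper makes resample() re-tune k inside each training set, this estimate does not reuse the same data both to choose k & to judge the model.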
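Tuning inside each outer training set & then scoring the tuned model on a test fold it never saw is known as nested cross-validation, which is what combining makeTuneWrapper() with resample() does. Here is a minimal base R sketch, assuming the built-in iris data, Euclidean distance, & 5 outer & 5 inner folds (the mlr code above used repeated 10-fold CV for both loops); all helper names are hypothetical.

```r
# From-scratch kNN prediction (Euclidean distance, majority vote)
knn_predict <- function(train_x, train_y, test_x, k) {
  apply(test_x, 1, function(new_x) {
    d <- sqrt(rowSums(sweep(train_x, 2, new_x)^2))
    votes <- table(train_y[order(d)[seq_len(k)]])
    names(votes)[which.max(votes)]
  })
}

# Mean CV error of one candidate k over a given fold assignment
cv_error <- function(x, y, k, folds) {
  mean(sapply(unique(folds), function(f) {
    test <- folds == f
    preds <- knn_predict(x[!test, , drop = FALSE], y[!test],
                         x[test, , drop = FALSE], k)
    mean(preds != y[test])
  }))
}

nested_cv_error <- function(x, y, grid, outer_folds = 5, inner_folds = 5) {
  outer <- sample(rep(seq_len(outer_folds), length.out = nrow(x)))
  outer_errors <- sapply(seq_len(outer_folds), function(f) {
    train <- outer != f
    x_tr <- x[train, , drop = FALSE]
    y_tr <- y[train]
    # Inner loop: tune k using only this outer training set
    inner <- sample(rep(seq_len(inner_folds), length.out = nrow(x_tr)))
    inner_errors <- sapply(grid, function(k) cv_error(x_tr, y_tr, k, inner))
    best_k <- grid[which.min(inner_errors)]
    # Outer loop: score the tuned k on the held-out outer fold
    preds <- knn_predict(x_tr, y_tr, x[!train, , drop = FALSE], best_k)
    mean(preds != y[!train])
  })
  mean(outer_errors)  # estimate of the whole tune-then-train procedure
}

set.seed(123)
x <- as.matrix(iris[, 1:4])
y <- as.character(iris$Species)
nested_err <- nested_cv_error(x, y, grid = 1:10)
nested_err
```

The value of k chosen can differ between outer folds; what the nested estimate measures is the performance of the tuning procedure as a whole, not of one fixed k.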


Using our Model to Make Predictions

Now that we have our model, we collect diagnostic data on 3 new suspected diabetes patients. We pass their data to predict() & extract each patient's predicted diabetes status with getPredictionResponse().

newPatients <- tibble(glucose = c(83, 107, 299),
                      insulin = c(360, 287, 1051),
                      sspg = c(199, 185, 134))
newPatients
## # A tibble: 3 × 3
##   glucose insulin  sspg
##     <dbl>   <dbl> <dbl>
## 1      83     360   199
## 2     107     287   185
## 3     299    1051   134
getPredictionResponse(predict(kNNModel, newdata = newPatients))
## [1] Normal Normal Overt 
## Levels: Chemical Normal Overt

Citations/References

Rhys, Hefin I. Machine Learning with R, the Tidyverse, and MLR. Manning Publications, 2020.