You want to understand how the k-nearest neighbors (kNN) algorithm works & build your own kNN model.
I will be using the R programming language to build the kNN model. The packages I will be using are tidyverse, mlr, & mclust: tidyverse is a collection of packages designed for data science, mlr is the framework we will use to build the kNN model, & mclust provides the dataset we will model.
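If you're following along, a minimal setup block looks like this (assuming all three packages are installed; the parallelization step near the end of this walkthrough also needs the parallelMap & parallel packages):
library(tidyverse)   # collection of packages for data science
library(mlr)         # machine learning framework used to build the kNN model
library(mclust)      # provides the diabetes dataset we model below
library(parallelMap) # used later to parallelize cross-validation
library(parallel)    # provides detectCores()
data(diabetes, package = 'mclust')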
The kNN algorithm is arguably the simplest classification algorithm. However, don’t look down on it for its simplicity: it can provide excellent classification performance & is easy to interpret.
Suppose a reptile conservation project hires you to build a kNN classifier to quickly classify three different species of reptiles: the grass snake, the adder, & the slow worm. To build the kNN classifier, you need data. A biologist takes you into the woodlands to find these three reptiles. Upon finding one, you record its body length & aggression. The biologist manually classifies each reptile for now, but our kNN classifier will classify these reptiles for us in the future.
We plot each of our observations by its body length & aggression, grouped by the species the biologist helped us classify. We then go back into the woodland & collect three more samples. The figure below demonstrates how the plot might look.
The kNN algorithm calculates the distance between each new, unclassified case & all of the classified cases. Then for each unclassified case, the algorithm ranks its ‘neighbors’ from nearest (most similar in terms of body length & aggression) to furthest (least similar in terms of body length & aggression).
The kNN algorithm finds the k classified cases most similar to each new case, where k is a value we need to specify. Each of the k classified cases then votes on the class of the new case, based on its own class. This means the class of the new case will be the class of the majority of its neighbors. Here’s a visualization to aid your understanding.
If k is 1, the algorithm finds the single classified case nearest to each new case. Each of the new cases is closest to a reptile classified as a grass snake, so they would be classified as grass snakes as well. If k is 3, the algorithm finds the 3 classified cases nearest to each new case. When a new case has neighbors belonging to more than one class, each neighbor votes for its own class & the majority vote wins. If k is 5, the algorithm finds the 5 classified cases nearest to each new case, & the new case takes the majority vote of those 5 neighbors. Notice that in all of these scenarios, the value of k impacts how the new cases are classified.
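To make the mechanics concrete, here is a minimal, hand-rolled kNN sketch in base R. The reptile measurements below are made up purely for illustration:
# Classified cases recorded in the woodland (illustrative values)
classified <- data.frame(
  length     = c(60, 70, 40, 45, 35, 38),
  aggression = c(2, 3, 8, 9, 5, 4),
  species    = c('grass snake', 'grass snake', 'adder', 'adder',
                 'slow worm', 'slow worm')
)
newCase <- c(length = 55, aggression = 3)
k <- 3
# Euclidean distance from the new case to every classified case
dists <- sqrt((classified$length - newCase['length'])^2 +
              (classified$aggression - newCase['aggression'])^2)
# The k nearest neighbors vote & the majority class wins
votes <- classified$species[order(dists)][1:k]
names(which.max(table(votes)))  # 'grass snake' wins the vote, 2 to 1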
Now you may be wondering: what happens when the vote is tied? In our previous example, our k values are odd, but what happens when our k is an even value & there is a tied vote? A common approach is to randomly assign cases with tied votes to one of the classes. In theory, the proportion of cases that have ties will be minuscule & have limited impact on the classification accuracy of our model. However, there are ways to avoid ties altogether, such as choosing an odd value of k (which prevents ties in two-class problems), weighting each neighbor’s vote by its distance to the new case, or decreasing k until the tie is broken (sketched below).
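Here is an illustrative sketch of that last strategy: with 4 voting neighbors, the vote below is tied 2-2, so we drop the furthest neighbor & revote until the tie breaks:
votes <- c('adder', 'grass snake', 'adder', 'grass snake')  # nearest first
k <- 4
tally <- table(votes[1:k])
while (sum(tally == max(tally)) > 1 && k > 1) {
  k <- k - 1                 # drop the furthest neighbor & revote
  tally <- table(votes[1:k])
}
names(which.max(tally))      # 'adder' once the tie is broken at k = 3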
Now that we have an understanding of the kNN algorithm, let’s put it into practice. Say you are a biostatistician tasked with a project to improve diagnoses of patients with diabetes. You collect data from suspected diabetes patients & record their diagnoses. This is the dataset you get as a result.
dbTib <- as_tibble(diabetes)  # convert the diabetes data frame into a tibble
head(dbTib)
## # A tibble: 6 × 4
## class glucose insulin sspg
## <fct> <dbl> <dbl> <dbl>
## 1 Normal 80 356 124
## 2 Normal 97 289 117
## 3 Normal 105 319 143
## 4 Normal 90 356 199
## 5 Normal 90 323 240
## 6 Normal 86 381 157
summary(dbTib)
## class glucose insulin sspg
## Chemical:36 Min. : 70 Min. : 45.0 Min. : 10.0
## Normal :76 1st Qu.: 90 1st Qu.: 352.0 1st Qu.:118.0
## Overt :33 Median : 97 Median : 403.0 Median :156.0
## Mean :122 Mean : 540.8 Mean :186.1
## 3rd Qu.:112 3rd Qu.: 558.0 3rd Qu.:221.0
## Max. :353 Max. :1568.0 Max. :748.0
Since we want to build a classification model, we use makeClassifTask() to define the classification task. We supply the function with our data, dbTib, & the name of the variable we want our model to predict, class.
dbTask <- makeClassifTask(data = dbTib, target = 'class')
dbTask
## Supervised task: dbTib
## Type: classif
## Target: class
## Observations: 145
## Features:
## numerics factors ordered functionals
## 3 0 0 0
## Missings: FALSE
## Has weights: FALSE
## Has blocking: FALSE
## Has coordinates: FALSE
## Classes: 3
## Chemical Normal Overt
## 36 76 33
## Positive class: NA
Printing the task gives us information about the number of observations & the number of different variable types. It also indicates whether we have missing data, the number of observations in each class, etc.
Before we continue, allow me to explain a type of cross-validation that is imperative for ML model building. First off, what is cross-validation? The whole point of an ML model is its ability to predict on new data. As such, it is horribly inefficient to gather data, build a model from it, & then gather data again for our model to predict on. Instead, we split our data in two: we use one portion to train the model & the remaining portion to test it. This process is called cross-validation (CV).
In k-fold CV, we split the data into equal-sized chunks called folds, reserving one of the folds for testing & the remaining folds for training. We train the model on the training folds, evaluate it on the test fold & record the performance metric, then repeat with a different fold reserved for testing, continuing until all folds have been used for testing. The average of the performance metrics serves as our estimate of model performance.
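To make the idea concrete, here is a minimal sketch of 10-fold CV performed by hand with knn() from the class package (the same package mlr’s classif.knn learner uses under the hood). This is illustration only; it doesn’t stratify by class, & mlr automates all of it for us below:
set.seed(42)
# Randomly assign each of the 145 patients to one of 10 folds
folds <- sample(rep(1:10, length.out = nrow(dbTib)))
accuracy <- sapply(1:10, function(i) {
  train <- dbTib[folds != i, ]  # train on the other 9 folds
  test  <- dbTib[folds == i, ]  # test on the held-out fold
  preds <- class::knn(train[, -1], test[, -1], cl = train$class, k = 5)
  mean(preds == test$class)
})
mean(accuracy)  # average performance across the 10 folds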
Say our chosen value of k for k-fold is 10 (this k is the number of folds, not to be confused with the k in kNN). This means we split the data into 10 equal-sized chunks to perform CV. We can also repeat this procedure multiple times to get repeated k-fold CV. To perform repeated k-fold CV in mlr, we make a resampling description: a set of instructions for how the data will be split into training & testing sets. We supply four arguments to makeResampleDesc(): method is the type of cross-validation we want to perform, folds is the number of folds, reps is the number of repetitions, & stratify is whether or not to maintain the proportion of each class of patient in each set.
repKFold <- makeResampleDesc(method = 'RepCV', folds = 10, reps = 5,
stratify = TRUE)
repKFold
## Resample description: repeated cross-validation with 50 iterations: 10 folds and 5 reps.
## Predict: test
## Stratification: TRUE
A learner tells mlr which algorithm you want to use. Since we want to use the kNN algorithm, we supply 'classif.knn' as our learner. We also supply par.vals, which allows us to specify the value of k we want our kNN algorithm to use. We can set k to 5 for now.
knnLearner <- makeLearner('classif.knn', par.vals = list('k' = 5))
knnLearner
## Learner classif.knn from package class
## Type: classif
## Name: k-Nearest Neighbor; Short name: knn
## Class: classif.knn
## Properties: twoclass,multiclass,numerics
## Predict-Type: response
## Hyperparameters: k=5
What if, instead of specifying our k value, we tell the computer to find the value of k that gives the best model prediction accuracy? This process is called hyperparameter tuning, & we can perform it with the makeDiscreteParam() function in mlr. The makeDiscreteParam() function allows us to specify the hyperparameter we will tune, k, & the range of values we want to search for its best value. We nestle this function inside the makeParamSet() function, which defines the hyperparameter space, & then command mlr to search that space for the best value of k. We also cross-validate the tuning process to mitigate against an overfit model.
knnParamSpace <- makeParamSet(makeDiscreteParam('k', values = 1:10))
knnParamSpace
## Type len Def Constr Req Tunable Trafo
## k discrete - - 1,2,3,4,5,6,7,8,9,10 - TRUE -
gridSearch <- makeTuneControlGrid()
gridSearch
## Tune control: TuneControlGrid
## Same resampling instance: TRUE
## Imputation value: <worst>
## Start: <NULL>
##
## Tune threshold: FALSE
## Further arguments: resolution=10
bestK <- tuneParams('classif.knn', task = dbTask,
resampling = repKFold, par.set = knnParamSpace,
control = gridSearch)
## [Tune] Started tuning learner classif.knn for parameter set:
## Type len Def Constr Req Tunable Trafo
## k discrete - - 1,2,3,4,5,6,7,8,9,10 - TRUE -
## With control class: TuneControlGrid
## Imputation value: 1
## [Tune-x] 1: k=1
## [Tune-y] 1: mmce.test.mean=0.1043654; time: 0.0 min
## [Tune-x] 2: k=2
## [Tune-y] 2: mmce.test.mean=0.0944194; time: 0.0 min
## [Tune-x] 3: k=3
## [Tune-y] 3: mmce.test.mean=0.0908773; time: 0.0 min
## [Tune-x] 4: k=4
## [Tune-y] 4: mmce.test.mean=0.0936392; time: 0.0 min
## [Tune-x] 5: k=5
## [Tune-y] 5: mmce.test.mean=0.0823297; time: 0.0 min
## [Tune-x] 6: k=6
## [Tune-y] 6: mmce.test.mean=0.0823984; time: 0.0 min
## [Tune-x] 7: k=7
## [Tune-y] 7: mmce.test.mean=0.0741841; time: 0.0 min
## [Tune-x] 8: k=8
## [Tune-y] 8: mmce.test.mean=0.0812582; time: 0.0 min
## [Tune-x] 9: k=9
## [Tune-y] 9: mmce.test.mean=0.0823864; time: 0.0 min
## [Tune-x] 10: k=10
## [Tune-y] 10: mmce.test.mean=0.0838150; time: 0.0 min
## [Tune] Result: k=7 : mmce.test.mean=0.0741841
bestK
## Tune result:
## Op. pars: k=7
## mmce.test.mean=0.0741841
Our best-performing value of k is 7, because it has the lowest mean misclassification error (mmce). Next, we can visualize the tuning process, define a new learner with our best k value, & train a model with that new learner.
plotHyperParsEffect(generateHyperParsEffectData(bestK),
x = 'k', y = 'mmce.test.mean', plot.type = 'line') +
theme_bw() +
labs(title = 'Mean Misclassification Error vs. k')
parallelStartSocket(cpus = detectCores() - 1)  # parallelize the upcoming CV, leaving one core free
## Starting parallelization in mode=socket with cpus=7.
newLearner <- setHyperPars(makeLearner('classif.knn'), par.vals = bestK$x)  # learner with the tuned k
kNNModel <- train(newLearner, dbTask)  # train the final model on all the data
Finally, we estimate our model’s predictive accuracy. Because we used the data to tune k, we wrap the tuning process inside cross-validation with makeTuneWrapper(), so that tuning is repeated within each training set & our performance estimate isn’t optimistically biased.
wrapper <- makeTuneWrapper('classif.knn', resampling = repKFold,
par.set = knnParamSpace, control = gridSearch)
wrapper
## Learner classif.knn.tuned from package class
## Type: classif
## Name: ; Short name:
## Class: TuneWrapper
## Properties: numerics,twoclass,multiclass
## Predict-Type: response
## Hyperparameters:
modelEstimates <- resample(wrapper, dbTask, resampling = repKFold)
## Exporting objects to slaves for mode socket: .mlr.slave.options
## Resampling: repeated cross-validation
## Measures: mmce
## Mapping in parallel: mode = socket; level = mlr.resample; cpus = 7; elements = 50.
##
## Aggregated Result: mmce.test.mean=0.0842280
##
parallelStop()
## Stopped parallelization. All cleaned up.
modelEstimates
## Resample Result
## Task: dbTib
## Learner: classif.knn.tuned
## Aggr perf: mmce.test.mean=0.0842280
## Runtime: 51.775
Our kNN model is estimated to correctly classify over 90% of new cases.
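That figure is simply one minus the aggregated mmce:
1 - 0.0842280
## [1] 0.915772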
Now that we have our model, we collect diagnostic data on 3 new suspected diabetes patients. We can pass the data from these 3 patients to our model & get their predicted diabetes status.
newPatients <- tibble(glucose = c(83, 107, 299),
insulin = c(360, 287, 1051),
sspg = c(199, 185, 134))
newPatients
## # A tibble: 3 × 3
## glucose insulin sspg
## <dbl> <dbl> <dbl>
## 1 83 360 199
## 2 107 287 185
## 3 299 1051 134
getPredictionResponse(predict(kNNModel, newdata = newPatients))
## [1] Normal Normal Overt
## Levels: Chemical Normal Overt
Rhys, Hefin I. Machine Learning with R, the tidyverse, and mlr. Manning Publications, 2020.