# Global knitr chunk options applied to every chunk in this document.
# Logical values are spelled out as FALSE: `F` is an ordinary variable that
# can be reassigned, so it is not a safe stand-in for the constant.
knitr::opts_chunk$set(
  echo = FALSE,
  # warning = FALSE,
  # message = FALSE,
  fig.align = "center"
)
pacman::p_load(tidyverse, class)
This assignment only has one objective: Write a function that will tune the value of \(k\) for kNN classification for any data set given to it.
You’ll need to do the following:
Name the function knn_search() with the following
arguments:
- X = the data set of predictors
- class = the class/category for rows of X
- k_grid = the different values of \(k\) to search over
Perform a grid search across \(k\) for k-nearest neighbors classification properly using methods seen in class.
Returns two objects:
- k_error: A data frame with the following columns:
  - k: the values of \(k\) supplied by k_grid
  - rescale: the rescale method (‘norm’ and ‘stan’)
  - error: The misclassification rate for the choice of \(k\) and rescale method
- gg_search: The resulting graph of the grid search with:
  - x = k
  - y = error
  - color = rescale
  - a dashed line at the value of k that minimizes the error rate, with color that matches the rescale method
There are two subsequent code chunks you can use to check your function. The two chunks below those have their output hidden in Brightspace and will be used to grade your function.
# Pin the RNG version and the seed before anything random runs, so the
# check output is reproducible.
RNGversion('4.1.0')
set.seed(2870)

# Grid search on iris: every column except the 5th is a predictor,
# Species is the class label, and k is searched over 5 through 50.
check1_knn <- knn_search(
  X = iris[, -5],
  class = iris$Species,
  k_grid = 5:50
)

# Returned data frame, printed as a tibble with its rows in random order
check1_knn$k_error |>
  slice_sample(n = nrow(check1_knn$k_error)) |>
  tibble()
## # A tibble: 92 × 3
## rescale k error
## <chr> <int> <dbl>
## 1 stan 40 0.107
## 2 stan 6 0.04
## 3 norm 20 0.04
## 4 norm 19 0.04
## 5 stan 32 0.0667
## 6 stan 13 0.0333
## 7 norm 16 0.0333
## 8 stan 41 0.113
## 9 stan 10 0.0467
## 10 stan 39 0.1
## # ℹ 82 more rows
# Graph: the gg_search element returned by knn_search(), with a title added
check1_knn$gg_search + ggtitle('Iris data')
# Pin the RNG version and the seed before anything random runs, so the
# check output is reproducible.
RNGversion('4.1.0')
set.seed(2870)

# Grid search on the penguins data: columns 5 and 6 are the predictors and
# species is the class label. Incomplete rows are dropped from both inputs
# using the same columns, and k is searched over 1 through 100.
check2_knn <- knn_search(
  X = palmerpenguins::penguins[, 5:6] |> drop_na(),
  class = palmerpenguins::penguins |>
    dplyr::select(1, 5:6) |>
    drop_na() |>
    pull(species),
  k_grid = 1:100
)

# Returned data frame, printed as a tibble with its rows in random order
check2_knn$k_error |>
  slice_sample(n = nrow(check2_knn$k_error)) |>
  tibble()
## # A tibble: 200 × 3
## rescale k error
## <chr> <int> <dbl>
## 1 stan 32 0.178
## 2 stan 57 0.202
## 3 norm 87 0.205
## 4 norm 43 0.184
## 5 stan 20 0.184
## 6 stan 86 0.211
## 7 norm 31 0.181
## 8 norm 3 0.234
## 9 stan 91 0.211
## 10 norm 57 0.202
## # ℹ 190 more rows
# Graph: the gg_search element returned by knn_search(), with a title added
check2_knn$gg_search + ggtitle('Penguin data')
## # A tibble: 200 × 3
## rescale k error
## <chr> <int> <dbl>
## 1 norm 23 0.421
## 2 norm 46 0.426
## 3 norm 4 0.410
## 4 stan 1 0.399
## 5 stan 35 0.410
## 6 stan 98 0.459
## 7 norm 58 0.421
## 8 stan 85 0.454
## 9 norm 78 0.404
## 10 stan 38 0.399
## # ℹ 190 more rows
## # A tibble: 600 × 3
## rescale k error
## <chr> <int> <dbl>
## 1 norm 144 0.0633
## 2 stan 119 0.0685
## 3 norm 179 0.0756
## 4 stan 260 0.128
## 5 norm 147 0.0650
## 6 norm 209 0.0896
## 7 stan 169 0.0826
## 8 stan 9 0.0316
## 9 stan 276 0.135
## 10 norm 208 0.0879
## # ℹ 590 more rows