# Global chunk options: show code, center figures. Warning/message
# suppression is left available but disabled.
knitr::opts_chunk$set(
  echo = TRUE,
  # warning = FALSE,
  # message = FALSE,
  fig.align = "center"
)
pacman::p_load(tidyverse, class)
This assignment only has one objective: Write a function that will tune the value of \(k\) for kNN classification for any data set given to it.
You’ll need to do the following:
Name the function knn_search() with the following arguments:
X: the data set of predictors
class: the class/category for each row of X
k_grid: the different values of \(k\) to search over
Perform a grid search across \(k\) for k-nearest-neighbors classification, once on the standardized data and once on the normalized data.
The function returns two objects:
k_error: a data frame with the following columns:
- k: the values of \(k\) supplied by k_grid
- rescale: the rescaling method ('norm' or 'stan')
- error: the misclassification rate for that choice of \(k\) and rescaling method
gg_search: the resulting graph of the grid search, with:
- x = k
- y = error
- color = rescale
- a dashed line at the value of k that minimizes the error rate, colored to match its rescale method
There are two subsequent code chunks you can use to check your function. The two chunks after those have their output hidden in Brightspace and will be used to grade your function.
# knn_search(): tune k for k-nearest-neighbors classification.
#
# For every value in `k_grid`, runs leave-one-out cross-validated kNN
# (class::knn.cv) on two rescaled copies of the predictors -- min-max
# normalized ('norm') and z-score standardized ('stan') -- and records the
# misclassification rate of each combination.
#
# Arguments
#   X      : data frame of numeric predictors without missing values.
#   class  : class labels, one per row of X (factor or character).
#   k_grid : values of k to try; must contain at least one value.
#
# Returns a list with two elements
#   k_error   : data frame with columns rescale ('norm'/'stan'), k, error.
#   gg_search : ggplot of error vs. k coloured by rescale method, with a
#               dashed vertical line at the k with the lowest error rate
#               (one line per tied minimum).
knn_search <- function(X, class, k_grid) {
  if (length(k_grid) == 0) {
    stop("`k_grid` must contain at least one value of k.", call. = FALSE)
  }
  if (any(is.na(X)) || any(is.na(class))) {
    stop("`X` and `class` must not contain missing values.", call. = FALSE)
  }

  ## Step 1) Rescale
  # z-score standardization. scale() returns a 1-column matrix, so flatten it
  # to keep plain numeric columns. A zero-variance column would become NaN
  # (which knn.cv rejects) and carries no distance information, so it is
  # mapped to 0 instead.
  standardize <- function(x) {
    if (!isTRUE(sd(x) > 0)) {
      return(rep(0, length(x)))
    }
    as.vector(scale(x))
  }
  # Min-max normalization to [0, 1]; a constant column is mapped to 0 for
  # the same reason.
  normalize <- function(x) {
    x_range <- max(x) - min(x)
    if (!isTRUE(x_range > 0)) {
      return(rep(0, length(x)))
    }
    (x - min(x)) / x_range
  }

  X_stan <- X |> mutate(across(everything(), standardize))
  X_norm <- X |> mutate(across(everything(), normalize))

  ## Step 2) Grid search (leave-one-out CV error for each k)
  error_norm <- numeric(length(k_grid))
  error_stan <- numeric(length(k_grid))
  for (i in seq_along(k_grid)) {
    # knn.cv() breaks ties at random, so the call order (norm, then stan, for
    # each k) is kept fixed to reproduce results under a given seed.
    pred_norm <- knn.cv(train = X_norm, cl = class, k = k_grid[i])
    error_norm[i] <- mean(class != pred_norm)

    pred_stan <- knn.cv(train = X_stan, cl = class, k = k_grid[i])
    error_stan[i] <- mean(class != pred_stan)
  }

  # Long data frame: one row per (rescale method, k)
  k_error <- bind_rows(
    .id = "rescale",
    norm = data.frame(k = k_grid, error = error_norm),
    stan = data.frame(k = k_grid, error = error_stan)
  )

  ## Step 3) Graph of the grid search
  # slice_min() keeps ties, so every (rescale, k) pair that attains the
  # overall minimum error gets its own dashed line.
  best_k <- k_error |> slice_min(error, n = 1)

  gg_search <-
    ggplot(
      data = k_error,
      mapping = aes(x = k, y = error, color = rescale)
    ) +
    geom_line() +
    geom_vline(
      data = best_k,
      mapping = aes(xintercept = k, color = rescale),
      linetype = "dashed",
      show.legend = FALSE
    ) +
    theme_bw()

  list(
    k_error = k_error,
    gg_search = gg_search
  )
}
# Check 1: iris -- all columns except the 5th as predictors, Species as class.
# The seed is set immediately before the call so that any random draws made
# afterwards (knn.cv tie-breaking, slice_sample) follow the same RNG stream
# as when the expected output below was recorded; do not reorder these calls.
RNGversion('4.1.0');set.seed(2870)
check1_knn <-
knn_search(
X = iris[, -5],
class = iris$Species,
k_grid = 5:50
)
# Returned data frame (rows shuffled with slice_sample, printed as a tibble)
check1_knn$k_error |>
slice_sample(n = nrow(check1_knn$k_error)) |>
tibble()
# Recorded output:
## # A tibble: 92 × 3
## rescale k error
## <chr> <int> <dbl>
## 1 stan 40 0.107
## 2 stan 6 0.04
## 3 norm 20 0.04
## 4 norm 19 0.04
## 5 stan 32 0.0667
## 6 stan 13 0.0333
## 7 norm 16 0.0333
## 8 stan 41 0.113
## 9 stan 10 0.0467
## 10 stan 39 0.1
## # ℹ 82 more rows
# Graph
check1_knn$gg_search + ggtitle('Iris data')
# Check 2: palmerpenguins -- columns 5-6 as predictors, species as class.
# Both X and class drop rows with NA in columns 5-6 so they stay row-aligned;
# NOTE(review): the class pipeline also drops rows where species is NA, which
# would misalign the two if species ever had missing values -- confirm.
# Seed set right before the call so the recorded output below reproduces.
RNGversion('4.1.0');set.seed(2870)
check2_knn <-
knn_search(
X = drop_na(palmerpenguins::penguins[, 5:6]),
class = palmerpenguins::penguins |> dplyr::select(1, 5:6)|> drop_na() |> pull(species),
k_grid = 1:100
)
# Returned data frame (rows shuffled with slice_sample, printed as a tibble)
check2_knn$k_error |>
slice_sample(n = nrow(check2_knn$k_error)) |>
tibble()
# Recorded output:
## # A tibble: 200 × 3
## rescale k error
## <chr> <int> <dbl>
## 1 stan 32 0.178
## 2 stan 57 0.202
## 3 norm 87 0.205
## 4 norm 43 0.184
## 5 stan 20 0.184
## 6 stan 86 0.211
## 7 norm 31 0.181
## 8 norm 3 0.234
## 9 stan 91 0.211
## 10 norm 57 0.202
## # ℹ 190 more rows
# Graph
check2_knn$gg_search + ggtitle('Penguin data')
# Check 3: ggplot2::mpg restricted to four vehicle classes, with displ, cty
# and hwy as predictors and the (character) class column as the label.
# Seed set before building the data and calling knn_search so the recorded
# output below reproduces; keep the statement order unchanged.
RNGversion('4.1.0');set.seed(2870)
mpg2 <-
mpg |>
filter(class %in% c('compact', 'midsize', 'suv', 'pickup')) |>
dplyr::select(class, displ, cty, hwy)
check3_knn <-
knn_search(
X = mpg2[, 2:4],
class = mpg2$class,
k_grid = 1:100
)
# Returned data frame (rows shuffled with slice_sample, printed as a tibble)
check3_knn$k_error |>
slice_sample(n = nrow(check3_knn$k_error)) |>
tibble()
# Recorded output:
## # A tibble: 200 × 3
## rescale k error
## <chr> <int> <dbl>
## 1 norm 23 0.421
## 2 norm 46 0.426
## 3 norm 4 0.410
## 4 stan 1 0.399
## 5 stan 35 0.410
## 6 stan 98 0.459
## 7 norm 58 0.421
## 8 stan 85 0.454
## 9 norm 78 0.404
## 10 stan 38 0.399
## # ℹ 190 more rows
# Graph
check3_knn$gg_search + ggtitle('Car data')
# Check 4: Wisconsin diagnostic breast cancer data read from the UCI archive
# (requires network access). Column V1 is dropped (presumably a sample ID --
# confirm) and V2 is renamed to diagnosis and used as the class; all
# remaining columns are predictors.
# Seed set before the call so the recorded output below reproduces.
RNGversion('4.1.0');set.seed(2870)
cancer <-
read.csv('https://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/wdbc.data', header = F) |>
dplyr::select(-V1) |>
rename(diagnosis = V2)
check4_knn <-
knn_search(
X = cancer[, -1],
class = cancer$diagnosis,
k_grid = 1:300
)
# Returned data frame (rows shuffled with slice_sample, printed as a tibble)
check4_knn$k_error |>
slice_sample(n = nrow(check4_knn$k_error)) |>
tibble()
# Recorded output:
## # A tibble: 600 × 3
## rescale k error
## <chr> <int> <dbl>
## 1 norm 144 0.0633
## 2 stan 119 0.0685
## 3 norm 179 0.0756
## 4 stan 260 0.128
## 5 norm 147 0.0650
## 6 norm 209 0.0896
## 7 stan 169 0.0826
## 8 stan 9 0.0316
## 9 stan 276 0.135
## 10 norm 208 0.0879
## # ℹ 590 more rows
# Graph
check4_knn$gg_search + ggtitle('Cancer data')