# Global chunk options: hide code in the rendered output and centre figures.
# Use the constants TRUE/FALSE rather than T/F, which are ordinary
# (reassignable) bindings.
knitr::opts_chunk$set(echo = FALSE,
                      # warning = FALSE,
                      # message = FALSE,
                      fig.align = "center")

# Load (installing if needed) the tidyverse and `class`, which provides knn().
pacman::p_load(tidyverse, class)

Homework 9 Description

This assignment only has one objective: Write a function that will tune the value of \(k\) for kNN classification for any data set given to it.

You’ll need to do the following:

  1. Name the function knn_search() with the following arguments:

    • X = the data set of predictors
    • class = the class/category for rows of X
    • k_grid = the different values of \(k\) to search over
  2. Properly perform a grid search over \(k\) for k-nearest neighbors classification, using the methods seen in class.

  3. Return two objects:

  1. k_error: A data frame with the following columns: - k: the values of \(k\) supplied by k_grid - rescale: the rescaling method (‘norm’ or ‘stan’) - error: the misclassification rate for that choice of \(k\) and rescaling method
  2. gg_search: The resulting graph of the grid search with: - x = k - y = error - color = rescale - a dashed line at the value of k that minimizes the error rate, colored to match its rescale method

There are two subsequent code chunks you can use to check your function. The two chunks below those have their output hidden in Brightspace and will be used to grade your function.

Check 1: iris data

# Pin the RNG (version and seed) before calling knn_search() and shuffling
# its output, so the printed preview is reproducible.
RNGversion("4.1.0")
set.seed(2870)

# Predictors are the iris measurements (column 5, Species, is the label).
check1_knn <- knn_search(
  X = iris[, -5],
  class = iris$Species,
  k_grid = 5:50
)

# Returned data frame
check1_knn$k_error |>
  slice_sample(n = nrow(check1_knn$k_error)) |>
  tibble()
## # A tibble: 92 × 3
##    rescale     k  error
##    <chr>   <int>  <dbl>
##  1 stan       40 0.107 
##  2 stan        6 0.04  
##  3 norm       20 0.04  
##  4 norm       19 0.04  
##  5 stan       32 0.0667
##  6 stan       13 0.0333
##  7 norm       16 0.0333
##  8 stan       41 0.113 
##  9 stan       10 0.0467
## 10 stan       39 0.1   
## # ℹ 82 more rows
# Graph
check1_knn$gg_search +
  ggtitle("Iris data")

Check 2: penguins data

# Build the penguin subset once. Predictors and class labels then come from
# the same filtered rows and cannot fall out of alignment. The predictor
# columns are flipper_length_mm and body_mass_g, which are columns 5:6 of
# palmerpenguins::penguins.
penguin_subset <- palmerpenguins::penguins |>
  dplyr::select(species, flipper_length_mm, body_mass_g) |>
  drop_na()

# Pin the RNG (version and seed) immediately before the search so results
# match the reference output. The data preparation above draws no random
# numbers.
RNGversion("4.1.0")
set.seed(2870)

check2_knn <- knn_search(
  X = dplyr::select(penguin_subset, -species),
  class = penguin_subset$species,
  k_grid = 1:100
)

# Returned data frame
check2_knn$k_error |>
  slice_sample(n = nrow(check2_knn$k_error)) |>
  tibble()
## # A tibble: 200 × 3
##    rescale     k error
##    <chr>   <int> <dbl>
##  1 stan       32 0.178
##  2 stan       57 0.202
##  3 norm       87 0.205
##  4 norm       43 0.184
##  5 stan       20 0.184
##  6 stan       86 0.211
##  7 norm       31 0.181
##  8 norm        3 0.234
##  9 stan       91 0.211
## 10 norm       57 0.202
## # ℹ 190 more rows
# Graph
check2_knn$gg_search + ggtitle("Penguin data")

Check 3:

## # A tibble: 200 × 3
##    rescale     k error
##    <chr>   <int> <dbl>
##  1 norm       23 0.421
##  2 norm       46 0.426
##  3 norm        4 0.410
##  4 stan        1 0.399
##  5 stan       35 0.410
##  6 stan       98 0.459
##  7 norm       58 0.421
##  8 stan       85 0.454
##  9 norm       78 0.404
## 10 stan       38 0.399
## # ℹ 190 more rows

Check 4:

## # A tibble: 600 × 3
##    rescale     k  error
##    <chr>   <int>  <dbl>
##  1 norm      144 0.0633
##  2 stan      119 0.0685
##  3 norm      179 0.0756
##  4 stan      260 0.128 
##  5 norm      147 0.0650
##  6 norm      209 0.0896
##  7 stan      169 0.0826
##  8 stan        9 0.0316
##  9 stan      276 0.135 
## 10 norm      208 0.0879
## # ℹ 590 more rows