# Chunk defaults: show the code (echo) and centre figures. Warnings and
# messages are left on because those two options are commented out.
knitr::opts_chunk$set(echo = TRUE,
                      #warning = F,
                      #message = F,
                      fig.align = 'center')

# Load (installing if missing) the tidyverse and class; class provides knn.cv().
pacman::p_load(tidyverse, class)

Homework 9 Description

This assignment only has one objective: Write a function that will tune the value of \(k\) for kNN classification for any data set given to it.

You’ll need to do the following:

  1. Name the function knn_search() with the following arguments:

    • X = the data set of predictors
    • class = the class/category for rows of X
    • k_grid = the different values of \(k\) to search over
  2. Perform a grid search across k for k-nearest neighbors classification, once with the data standardized and once with the data normalized.

  3. Return two objects:

  1. k_error: A data frame with the following columns - k: the values of \(k\) supplied by k_grid - rescale: the rescale method (‘norm’ and ‘stan’). - error: The misclassification rate for the choice of \(k\) and rescale method
  2. gg_search: The resulting graph of the grid search with: - x = k - y = error - color = rescale - a dashed line at the value of k that minimizes the error rate, colored to match the rescale method

There are two subsequent code chunks you can use to check your function. The two chunks after those have their output hidden in Brightspace and will be used to grade your function.

# Grid-search the number of neighbours k for k-nearest-neighbours
# classification, once on normalized and once on standardized predictors.
#
# Arguments:
#   X      - data frame of numeric predictors, one row per observation
#   class  - class/category labels for the rows of X
#   k_grid - values of k to evaluate
#
# Each fit uses class::knn.cv() (leave-one-out cross-validation); the error is
# the share of rows whose predicted class differs from `class`.
#
# Returns a list with:
#   k_error   - data frame with columns rescale ('norm' or 'stan'), k, error
#   gg_search - ggplot of error vs. k coloured by rescale method, with a dashed
#               vertical line (coloured by method) at the row(s) with the
#               lowest error overall
knn_search <- function(X, class, k_grid){
  ## Step 1) Rescale
  # A column whose values are all equal has zero spread, and dividing by that
  # spread yields NaN, which knn.cv() rejects. Such a column adds nothing to
  # any distance, so it is mapped to all zeros instead. isTRUE() keeps columns
  # containing NA on the ordinary path.
  is_constant <- function(x) isTRUE(max(x) == min(x))

  # Standardizing: (x - mean) / sd. as.numeric() drops the one-column matrix
  # (and its attributes) that scale() returns, leaving a plain numeric column.
  standardize <- function(x) {
    if (is_constant(x)) {
      return(rep(0, length(x)))
    }
    as.numeric(scale(x))
  }

  # Normalizing: map each column onto [0, 1].
  normalize <- function(x) {
    if (is_constant(x)) {
      return(rep(0, length(x)))
    }
    (x - min(x)) / (max(x) - min(x))
  }

  X_stan <- X |> mutate(across(.cols = everything(), .fns = standardize))
  X_norm <- X |> mutate(across(.cols = everything(), .fns = normalize))

  ## Step 2) Preallocate one result table per rescale method
  error_norm <- data.frame(k = k_grid, error = rep(NA_real_, length(k_grid)))
  error_stan <- error_norm

  ## Step 3) Grid search
  # knn.cv() breaks voting ties at random, so the normalized and standardized
  # fits are run in a fixed, interleaved order per k to keep results
  # reproducible under a fixed seed. seq_along() makes an empty k_grid a no-op.
  for (i in seq_along(k_grid)) {
    pred_norm <- knn.cv(train = X_norm, cl = class, k = k_grid[i])
    error_norm[i, 'error'] <- mean(class != pred_norm)

    pred_stan <- knn.cv(train = X_stan, cl = class, k = k_grid[i])
    error_stan[i, 'error'] <- mean(class != pred_stan)
  }

  # Stack both tables; .id records which rescale method each row came from
  k_error <-
    bind_rows(
      .id = 'rescale',
      'norm' = error_norm,
      'stan' = error_stan
    )

  # Error curve per method, plus a dashed line at the overall minimum error
  gg_search <-
    ggplot(
      data = k_error,
      mapping = aes(x = k, y = error, color = rescale)
    ) +
    geom_line() +
    geom_vline(
      data = k_error |> slice_min(error, n = 1),
      mapping = aes(xintercept = k, color = rescale),
      linetype = 'dashed',
      show.legend = FALSE
    ) +
    theme_bw()

  list(
    'k_error' = k_error,
    'gg_search' = gg_search
  )
}

Check 1: iris data

# Reproducible RNG: use R 4.1.0's sampling defaults with a fixed seed so random
# steps (slice_sample() below, and any random tie-breaking inside
# knn_search()) give the same result on every run.
RNGversion('4.1.0');set.seed(2870)

# Tune k over 5:50 on iris: predictors are every column except the 5th, and
# the labels are iris$Species.
check1_knn <- 
  knn_search(
    X = iris[, -5], 
    class = iris$Species,
    k_grid = 5:50
  )

# Returned data frame
# (rows shuffled by sampling all of them, then printed as a tibble)
check1_knn$k_error |> 
  slice_sample(n = nrow(check1_knn$k_error)) |> 
  tibble()
## # A tibble: 92 × 3
##    rescale     k  error
##    <chr>   <int>  <dbl>
##  1 stan       40 0.107 
##  2 stan        6 0.04  
##  3 norm       20 0.04  
##  4 norm       19 0.04  
##  5 stan       32 0.0667
##  6 stan       13 0.0333
##  7 norm       16 0.0333
##  8 stan       41 0.113 
##  9 stan       10 0.0467
## 10 stan       39 0.1   
## # ℹ 82 more rows
# Graph
check1_knn$gg_search + ggtitle('Iris data')

Check 2: penguins data

# Reproducible RNG (R 4.1.0 sampling defaults, fixed seed).
RNGversion('4.1.0');set.seed(2870)

# Predictors: columns 5 and 6 of palmerpenguins::penguins (presumably
# flipper_length_mm and body_mass_g -- confirm), with incomplete rows dropped.
# Labels: species, taken after dropping incomplete rows over column 1 plus the
# same columns 5:6 so the labels line up with X.
# NOTE(review): the rows align only if species itself has no missing values.
check2_knn <- 
  knn_search(
    X = drop_na(palmerpenguins::penguins[, 5:6]), 
    class = palmerpenguins::penguins |> dplyr::select(1, 5:6)|> drop_na() |> pull(species),
    k_grid = 1:100
  )

# Returned data frame
# (rows shuffled by sampling all of them, then printed as a tibble)
check2_knn$k_error |> 
  slice_sample(n = nrow(check2_knn$k_error)) |> 
  tibble()
## # A tibble: 200 × 3
##    rescale     k error
##    <chr>   <int> <dbl>
##  1 stan       32 0.178
##  2 stan       57 0.202
##  3 norm       87 0.205
##  4 norm       43 0.184
##  5 stan       20 0.184
##  6 stan       86 0.211
##  7 norm       31 0.181
##  8 norm        3 0.234
##  9 stan       91 0.211
## 10 norm       57 0.202
## # ℹ 190 more rows
# Graph
check2_knn$gg_search + ggtitle('Penguin data')

Check 3: mpg data

# Reproducible RNG (R 4.1.0 sampling defaults, fixed seed).
RNGversion('4.1.0');set.seed(2870)

# Keep four vehicle classes and the columns used below: the label (class)
# followed by three numeric predictors (displ, cty, hwy).
mpg2 <- 
  mpg |> 
  filter(class %in% c('compact', 'midsize', 'suv', 'pickup')) |> 
  dplyr::select(class, displ, cty, hwy)

# Predictors are columns 2:4 (displ, cty, hwy); labels are mpg2$class.
check3_knn <- 
  knn_search(
    X = mpg2[, 2:4],
    class = mpg2$class,
    k_grid = 1:100
  )

# Returned data frame
# (rows shuffled by sampling all of them, then printed as a tibble)
check3_knn$k_error |> 
  slice_sample(n = nrow(check3_knn$k_error)) |> 
  tibble()
## # A tibble: 200 × 3
##    rescale     k error
##    <chr>   <int> <dbl>
##  1 norm       23 0.421
##  2 norm       46 0.426
##  3 norm        4 0.410
##  4 stan        1 0.399
##  5 stan       35 0.410
##  6 stan       98 0.459
##  7 norm       58 0.421
##  8 stan       85 0.454
##  9 norm       78 0.404
## 10 stan       38 0.399
## # ℹ 190 more rows
# Graph
check3_knn$gg_search + ggtitle('Car data')

Check 4: breast cancer data

# Reproducible RNG (R 4.1.0 sampling defaults, fixed seed).
RNGversion('4.1.0');set.seed(2870)

# Download the Wisconsin diagnostic breast cancer data (needs network access).
# The file has no header row, so columns are named V1, V2, ...; V1 is dropped
# (presumably a sample ID -- confirm) and V2 is renamed to diagnosis.
cancer <- 
  read.csv('https://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/wdbc.data', header = F) |> 
  dplyr::select(-V1) |> 
  rename(diagnosis = V2)

# Predictors are every column except diagnosis (the first after the rename).
check4_knn <- 
  knn_search(
    X = cancer[, -1],
    class = cancer$diagnosis,
    k_grid = 1:300
  )

# Returned data frame
# (rows shuffled by sampling all of them, then printed as a tibble)
check4_knn$k_error |> 
  slice_sample(n = nrow(check4_knn$k_error)) |> 
  tibble()
## # A tibble: 600 × 3
##    rescale     k  error
##    <chr>   <int>  <dbl>
##  1 norm      144 0.0633
##  2 stan      119 0.0685
##  3 norm      179 0.0756
##  4 stan      260 0.128 
##  5 norm      147 0.0650
##  6 norm      209 0.0896
##  7 stan      169 0.0826
##  8 stan        9 0.0316
##  9 stan      276 0.135 
## 10 norm      208 0.0879
## # ℹ 590 more rows
# Graph
check4_knn$gg_search + ggtitle('Cancer data')