Lesson 5.1 - Introduction to Cross Validation

Robbie Beane

Load Packages and Data

library(class)
library(RColorBrewer)
# Read the comma-separated diabetes data (first row holds the column names)
# and print a per-column numeric summary.
pima <- read.table(file = 'data/diabetes.csv', sep = ',', header = TRUE)
summary(object = pima)
##   Pregnancies        Glucose      BloodPressure    SkinThickness  
##  Min.   : 0.000   Min.   :  0.0   Min.   :  0.00   Min.   : 0.00  
##  1st Qu.: 1.000   1st Qu.: 99.0   1st Qu.: 62.00   1st Qu.: 0.00  
##  Median : 3.000   Median :117.0   Median : 72.00   Median :23.00  
##  Mean   : 3.845   Mean   :120.9   Mean   : 69.11   Mean   :20.54  
##  3rd Qu.: 6.000   3rd Qu.:140.2   3rd Qu.: 80.00   3rd Qu.:32.00  
##  Max.   :17.000   Max.   :199.0   Max.   :122.00   Max.   :99.00  
##     Insulin           BMI        DiabetesPedigreeFunction      Age       
##  Min.   :  0.0   Min.   : 0.00   Min.   :0.0780           Min.   :21.00  
##  1st Qu.:  0.0   1st Qu.:27.30   1st Qu.:0.2437           1st Qu.:24.00  
##  Median : 30.5   Median :32.00   Median :0.3725           Median :29.00  
##  Mean   : 79.8   Mean   :31.99   Mean   :0.4719           Mean   :33.24  
##  3rd Qu.:127.2   3rd Qu.:36.60   3rd Qu.:0.6262           3rd Qu.:41.00  
##  Max.   :846.0   Max.   :67.10   Max.   :2.4200           Max.   :81.00  
##     Outcome     
##  Min.   :0.000  
##  1st Qu.:0.000  
##  Median :0.000  
##  Mean   :0.349  
##  3rd Qu.:1.000  
##  Max.   :1.000

Variation in Validation Accuracy

# Qualitative palette; individual curves in the plots pick entries from it.
colors <- brewer.pal(n = 8, name = "Set1")

set.seed(3)
k_range <- 1:60

# Accumulates one row of validation accuracies per random split
# (columns follow k_range).
val_acc_by_split_and_k <- c()

for (i in 1:10) {

  # Draw a fresh 70/30 train/validation split of the rows.
  # Columns 1:8 are the predictors, column 9 is the label.
  train_rows <- sample(1:nrow(pima), 0.7 * nrow(pima))
  X_train <- pima[train_rows, 1:8]
  y_train <- pima[train_rows, 9]
  X_valid <- pima[-train_rows, 1:8]
  y_valid <- pima[-train_rows, 9]

  # Validation accuracy of KNN for each K, evaluated in the order of k_range
  val_acc_by_k <- vapply(
    k_range,
    function(k) mean(knn(X_train, X_valid, y_train, k = k) == y_valid),
    numeric(1)
  )

  # The name val_acc_by_k is kept on purpose: rbind() uses it as the row label
  val_acc_by_split_and_k <- rbind(val_acc_by_split_and_k, val_acc_by_k)
}

# Show the accuracies for the first eight values of K across all splits
val_acc_by_split_and_k[, 1:8]
##                   [,1]      [,2]      [,3]      [,4]      [,5]      [,6]
## val_acc_by_k 0.7099567 0.6623377 0.6969697 0.7056277 0.7012987 0.7099567
## val_acc_by_k 0.6623377 0.6623377 0.7012987 0.7056277 0.7316017 0.7575758
## val_acc_by_k 0.6060606 0.6060606 0.6493506 0.6233766 0.6536797 0.6709957
## val_acc_by_k 0.6493506 0.6623377 0.7056277 0.6666667 0.6796537 0.6709957
## val_acc_by_k 0.7186147 0.6839827 0.6883117 0.7142857 0.7056277 0.7142857
## val_acc_by_k 0.6536797 0.6103896 0.6320346 0.6493506 0.6753247 0.6796537
## val_acc_by_k 0.6277056 0.6666667 0.6883117 0.6969697 0.7272727 0.7229437
## val_acc_by_k 0.6580087 0.6450216 0.6969697 0.7229437 0.7186147 0.7229437
## val_acc_by_k 0.6796537 0.6493506 0.7099567 0.6969697 0.7142857 0.7186147
## val_acc_by_k 0.7056277 0.7099567 0.6969697 0.6753247 0.6969697 0.7056277
##                   [,7]      [,8]
## val_acc_by_k 0.7316017 0.7056277
## val_acc_by_k 0.7445887 0.7445887
## val_acc_by_k 0.6666667 0.6580087
## val_acc_by_k 0.6753247 0.6753247
## val_acc_by_k 0.7316017 0.7186147
## val_acc_by_k 0.6839827 0.6839827
## val_acc_by_k 0.7099567 0.7056277
## val_acc_by_k 0.7445887 0.7402597
## val_acc_by_k 0.7359307 0.7359307
## val_acc_by_k 0.7272727 0.7402597

Variation in Validation Accuracy

# Column-wise extremes: the lowest and highest validation accuracy observed
# for each K across the random splits.
min_val_acc_by_k <- apply(X = val_acc_by_split_and_k, MARGIN = 2, FUN = min)
max_val_acc_by_k <- apply(X = val_acc_by_split_and_k, MARGIN = 2, FUN = max)

# Set up the axes (the "." symbols are effectively invisible placeholders)
plot(
  x = k_range, y = min_val_acc_by_k,
  main = "Minimum and Maximum Validation Accuracy",
  xlab = "K", ylab = "Accuracy",
  ylim = c(0.5, 0.85), pch = ".", col = "salmon"
)

# Dashed lower and upper envelopes
for (envelope in list(min_val_acc_by_k, max_val_acc_by_k)) {
  lines(k_range, envelope, lty = 2, col = 'black')
}

Variation in Validation Accuracy

# Gap between the best and worst validation accuracy at each K
diff_val_acc_by_k <- max_val_acc_by_k - min_val_acc_by_k

# Print the position of the widest gap, then the gap itself (no trailing newline)
widest_gap_at <- which.max(diff_val_acc_by_k)
widest_gap <- max(diff_val_acc_by_k)
cat(widest_gap_at, '\n', widest_gap, sep='')
## 34
## 0.1168831

Variation in Validation Accuracy

# Accuracy curves of three of the random splits (rows 3, 6 and 9)
m1 <- val_acc_by_split_and_k[3, ]
m2 <- val_acc_by_split_and_k[6, ]
m3 <- val_acc_by_split_and_k[9, ]

plot(
  x = k_range, y = min_val_acc_by_k,
  main = "Variation in Validation Accuracy",
  xlab = "K", ylab = "Accuracy",
  ylim = c(0.5, 0.85), pch = ".", col = "salmon"
)

# Dashed min/max envelope across all splits
for (envelope in list(min_val_acc_by_k, max_val_acc_by_k)) {
  lines(k_range, envelope, lty = 2, col = 'salmon')
}

highlighted <- list(m1, m2, m3)
highlight_cols <- colors[2:4]

# Draw every curve first ...
for (j in seq_along(highlighted)) {
  lines(k_range, highlighted[[j]], col = highlight_cols[j], lwd = 2)
}

# ... then a dashed vertical marker at the position of each curve's maximum
for (j in seq_along(highlighted)) {
  peak_at <- which.max(highlighted[[j]])
  segments(peak_at, 0, peak_at, max(highlighted[[j]]),
           lty = 2, col = highlight_cols[j])
}

Variation in Validation Accuracy

# Tab-separated table: one line per split with the position of its best
# accuracy and that accuracy.
cat('Model', '\t', 'Best K', '\t', 'Max Val Acc', '\n', sep = '')
## Model    Best K  Max Val Acc
cat('---------------------------\n')
## ---------------------------
for (split_id in seq_len(nrow(val_acc_by_split_and_k))) {
  split_acc <- val_acc_by_split_and_k[split_id, ]
  cat(split_id, '\t', which.max(split_acc), '\t', max(split_acc), '\n', sep = '')
}
## 1    16  0.7402597
## 2    13  0.7922078
## 3    14  0.7186147
## 4    13  0.7142857
## 5    19  0.7489177
## 6    49  0.7402597
## 7    22  0.7445887
## 8    52  0.7532468
## 9    16  0.7922078
## 10   15  0.7662338

10-Fold Cross Validation

set.seed(1)

# Assign each row to one of 10 folds: shuffle the row numbers and group the
# rows by the shuffled value modulo 10.
ix <- 1:nrow(pima)
shuffled_ix <- sample(ix)
fold_label <- shuffled_ix %% 10
folds <- split(ix, fold_label)

# Row numbers that landed in the first fold
folds[[1]]
##  [1]   6  12  34  43  44  64  69  79  83  87  88 105 121 133 140 153 157
## [18] 161 176 181 220 242 245 259 261 274 281 287 308 310 328 341 363 369
## [35] 385 387 396 412 418 420 427 448 451 454 455 466 467 477 479 482 491
## [52] 494 512 515 529 550 564 590 596 605 615 633 634 648 654 659 660 676
## [69] 693 704 710 716 726 746 753 763

10-Fold Cross Validation

# Tab-separated table listing how many rows each fold contains
cat('Fold', '\t', 'Rows', '\n', sep = '')
## Fold Rows
cat('------------\n')
## ------------
fold_sizes <- lengths(folds)
for (fold_id in seq_along(fold_sizes)) {
  cat(fold_id, '\t', fold_sizes[[fold_id]], '\n', sep = '')
}
## 1    76
## 2    77
## 3    77
## 4    77
## 5    77
## 6    77
## 7    77
## 8    77
## 9    77
## 10   76

10-Fold Cross Validation

# Average validation accuracy over the 10 folds, one entry per K in k_range
avg_accuracy_by_K <- numeric(length(k_range))

for (k_pos in seq_along(k_range)) {

  k <- k_range[k_pos]

  # Running sum of the per-fold accuracies for this K
  total <- 0
  for (i in 1:10) {
    # Fold i is held out for validation; all remaining rows are used to train
    held_out <- folds[[i]]
    X_train_temp <- pima[-held_out, 1:8]
    y_train_temp <- pima[-held_out, 9]
    X_valid_temp <- pima[held_out, 1:8]
    y_valid_temp <- pima[held_out, 9]

    temp_valid_pred <- knn(X_train_temp, X_valid_temp, y_train_temp, k = k)
    temp_valid_acc <- mean(temp_valid_pred == y_valid_temp)
    total <- total + temp_valid_acc
  }

  avg_accuracy_by_K[k_pos] <- total / 10
}

plot(
  x = k_range, y = avg_accuracy_by_K,
  main = "10-Fold Cross-Validation Accuracy",
  xlab = "K", ylab = "Accuracy",
  ylim = c(0.5, 0.85), pch = ".", col = "salmon"
)

# Cross-validation curve in solid black ...
lines(k_range, avg_accuracy_by_K, col = 'black', lwd = 2)

# ... over the dashed min/max envelope from the single random splits
lines(k_range, min_val_acc_by_k, lty = 2, col = 'salmon')
lines(k_range, max_val_acc_by_k, lty = 2, col = 'salmon')

# Dashed vertical marker at the position of the best average accuracy
best_at <- which.max(avg_accuracy_by_K)
segments(best_at, 0, best_at, max(avg_accuracy_by_K), lty = 2)

Variation in 10-Fold Cross Validation Accuracy

set.seed(1)

# Accumulates one row of 10-fold average accuracies per re-shuffling
# (columns follow k_range).
avg_val_acc_by_split_and_k <- c()

# Repeat the whole 10-fold procedure on 10 different partitions of the rows
for (i in 1:10) {

  # New random assignment of rows to 10 folds
  ix <- 1:nrow(pima)
  shuffled_ix <- sample(ix)
  folds <- split(ix, shuffled_ix %% 10)

  # Sweep over the values of K
  avg_accuracy_by_K <- numeric(length(k_range))
  for (k_pos in seq_along(k_range)) {

    k <- k_range[k_pos]

    # Loop over each fold, summing the per-fold validation accuracy
    total <- 0
    for (f in 1:10) {
      held_out <- folds[[f]]
      X_train_temp <- pima[-held_out, 1:8]
      y_train_temp <- pima[-held_out, 9]
      X_valid_temp <- pima[held_out, 1:8]
      y_valid_temp <- pima[held_out, 9]

      temp_valid_pred <- knn(X_train_temp, X_valid_temp, y_train_temp, k = k)
      temp_valid_acc <- mean(temp_valid_pred == y_valid_temp)
      total <- total + temp_valid_acc
    }

    avg_accuracy_by_K[k_pos] <- total / 10
  }

  avg_val_acc_by_split_and_k <- rbind(avg_val_acc_by_split_and_k, avg_accuracy_by_K)
}

Variation in 10-Fold Cross Validation Accuracy

# Column-wise extremes of the 10-fold average accuracy across the repeated runs
min_avg_val_acc_by_k <- apply(X = avg_val_acc_by_split_and_k, MARGIN = 2, FUN = min)
max_avg_val_acc_by_k <- apply(X = avg_val_acc_by_split_and_k, MARGIN = 2, FUN = max)

plot(
  x = k_range, y = min_avg_val_acc_by_k,
  main = "Min and Max 10-Fold Cross Validation Accuracy",
  xlab = "K", ylab = "Accuracy",
  ylim = c(0.5, 0.85), pch = ".", col = "salmon"
)

# Envelope from the single random splits, for comparison (salmon)
for (envelope in list(min_val_acc_by_k, max_val_acc_by_k)) {
  lines(k_range, envelope, lty = 2, col = 'salmon')
}

# Envelope from the repeated 10-fold cross-validation runs (black)
for (envelope in list(min_avg_val_acc_by_k, max_avg_val_acc_by_k)) {
  lines(k_range, envelope, lty = 2, col = 'black')
}

Variation in 10-Fold Cross Validation Accuracy

# Gap between the best and worst 10-fold average accuracy at each K,
# taken across the repeated cross-validation runs.
diff_avg_val_acc_by_k <- max_avg_val_acc_by_k - min_avg_val_acc_by_k

# Print the position of the widest gap, then the gap itself. Both values must
# come from the cross-validation gaps; the position was previously looked up
# in diff_val_acc_by_k (the single-split gaps), a copy-paste slip.
# NOTE(review): the recorded output below still shows the position from the
# old call -- re-run this chunk to refresh it.
cat(which.max(diff_avg_val_acc_by_k), '\n', max(diff_avg_val_acc_by_k), sep='')
## 34
## 0.04683869

Variation in 10-Fold Cross Validation Accuracy

# 10-fold average accuracy curves of the first three repeated runs
m1 <- avg_val_acc_by_split_and_k[1, ]
m2 <- avg_val_acc_by_split_and_k[2, ]
m3 <- avg_val_acc_by_split_and_k[3, ]

plot(
  x = k_range, y = min_avg_val_acc_by_k,
  main = "Variation in 10-Fold Cross Validation Accuracy",
  xlab = "K", ylab = "Accuracy",
  ylim = c(0.5, 0.85), pch = ".", col = "salmon"
)

# Dashed min/max envelope across all repeated runs
for (envelope in list(min_avg_val_acc_by_k, max_avg_val_acc_by_k)) {
  lines(k_range, envelope, lty = 2, col = 'salmon')
}

highlighted <- list(m1, m2, m3)
highlight_cols <- colors[2:4]

# Draw every curve first ...
for (j in seq_along(highlighted)) {
  lines(k_range, highlighted[[j]], col = highlight_cols[j], lwd = 2)
}

# ... then a dashed vertical marker at the position of each curve's maximum
for (j in seq_along(highlighted)) {
  peak_at <- which.max(highlighted[[j]])
  segments(peak_at, 0, peak_at, max(highlighted[[j]]),
           lty = 2, col = highlight_cols[j])
}

Variation in 10-Fold Cross Validation Accuracy

# Tab-separated table: one line per repeated cross-validation run with the
# position of its best average accuracy and that accuracy.
cat('Split', '\t', 'Best K', '\t', 'Max AVg Val Acc', '\n', sep='')
## Split    Best K  Max AVg Val Acc
cat('-------------------------------\n')
## -------------------------------
# Iterate over the rows of the matrix that is actually being indexed. The
# bound previously came from val_acc_by_split_and_k, which only worked because
# both matrices happen to have the same number of rows.
for (i in seq_len(nrow(avg_val_acc_by_split_and_k))){
  cat(i, '\t', which.max(avg_val_acc_by_split_and_k[i, ]), '\t', max(avg_val_acc_by_split_and_k[i, ]), '\n', sep='')
}
## 1    22  0.7566131
## 2    17  0.7683014
## 3    19  0.7538107
## 4    17  0.7487867
## 5    17  0.7616712
## 6    17  0.7654819
## 7    17  0.7617738
## 8    16  0.7539815
## 9    19  0.7551948
## 10   16  0.7499487