# 2. K-nearest neighbor
# Using the same training and test set as above, tune a K-nearest neighbor algorithm. To do this, make a plot of
# test set accuracy vs. k and choose k.

library(AppliedPredictiveModeling)
data(hepatic)

## 
##  a. Split the data into a training and a test set by random sampling.
## Use 75% of the rows of `bio` for training; the size is derived from the
## data rather than hard-coded, so it stays correct if the row count changes.

nfolds <- round(nrow(bio) * 0.75)
nfolds
## [1] 211
set.seed(12334)
## Row indices of the training observations, drawn without replacement.
ind <- sample(seq_len(nrow(bio)), nfolds)
## Training set
biotrain <- bio[ind, ]
injurytrain <- injury[ind]
## Test set: every row not drawn for training
biotest <- bio[-ind, ]
injurytest <- injury[-ind]

library(class)
library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
## Warning: package 'ggplot2' was built under R version 3.3.2
library(ggplot2)
## Scan candidate values of k: for each one, classify the test set with KNN
## trained on the training set and record the test-set accuracy.
kk <- 3:30
## One preallocated slot per candidate k (position idx <-> kk[idx]), so the
## vector needs no NA padding and no trimming after the loop.
acc <- numeric(length(kk))
for (idx in seq_along(kk)) {
  ii <- kk[idx]
  aa <- knn(biotrain, biotest, injurytrain, k = ii, l = 0, prob = FALSE)
  bb <- confusionMatrix(aa, injurytest)
  ## Pull the accuracy by name rather than by position in `overall`.
  acc[idx] <- bb$overall[["Accuracy"]]
}

acc
##  [1] 0.4142857 0.4857143 0.4571429 0.4714286 0.4285714 0.3142857 0.4000000
##  [8] 0.3714286 0.4142857 0.3714286 0.4428571 0.4142857 0.4142857 0.3714286
## [15] 0.4285714 0.4428571 0.5000000 0.4571429 0.4428571 0.5000000 0.4857143
## [22] 0.5285714 0.5142857 0.5285714 0.5142857 0.5285714 0.5285714 0.5428571
plot(kk, acc, type = "l", ylab = "Accuracy", main = "K-nearest neighbor")

## Redraw the accuracy-vs-k curve with ggplot2.
data2 <- data.frame(kk, Accuracy = acc)
ggplot(data = data2, mapping = aes(x = kk, y = Accuracy)) +
  geom_line() +
  geom_point() +
  ggtitle("K-nearest neighbor") +
  theme_bw() +
  # Centre the title; applied after theme_bw() so the complete theme
  # does not override it.
  theme(plot.title = element_text(hjust = 0.5))

## Refit KNN with the chosen k and report the test-set accuracy, computed
## directly as the share of test observations predicted correctly.
## The previous `acc[ii] <- ...` line was a leftover from the tuning loop and
## wrote past the end of `acc`; it and the unused confusion matrix are removed.
## NOTE(review): the recorded scan above shows higher test accuracy at larger
## k (about 0.54 at k = 30) -- confirm that k = 10 is the intended choice.
aa <- knn(biotrain, biotest, injurytrain, k = 10, prob = FALSE)
bb <- mean(aa == injurytest)
bb
## [1] 0.3857143