library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.1 ──
## ✓ ggplot2 3.3.5     ✓ purrr   0.3.4
## ✓ tibble  3.1.6     ✓ dplyr   1.0.8
## ✓ tidyr   1.2.0     ✓ stringr 1.4.0
## ✓ readr   2.1.2     ✓ forcats 0.5.1
## Warning: package 'tidyr' was built under R version 4.1.2
## Warning: package 'readr' was built under R version 4.1.2
## Warning: package 'dplyr' was built under R version 4.1.2
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
set.seed(12345)
x1<-runif(60,-1,1)
x2<-runif(60,-1,1)
y<-sample(c(0,1),size=60,replace=TRUE,prob=c(0.3,0.7))
Data<-data.frame(x1,x2,y)
SampleId<-sample(x=1:60,size=18)
DataTest<-Data[SampleId,]
DataTrain<-Data[-SampleId,]
DataTest
## x1 x2 y
## 37 0.73758980 0.44023301 1
## 52 0.65460574 -0.23311073 1
## 41 0.56438656 -0.41106919 0
## 46 -0.35755065 0.80500759 1
## 8 0.01844867 -0.70108411 1
## 20 0.90331751 -0.69980358 0
## 60 -0.52683909 -0.35841586 0
## 11 -0.93092913 0.37670671 1
## 21 -0.09254385 0.74089576 1
## 9 0.45541051 0.20071394 1
## 36 -0.27674886 0.02769131 0
## 10 0.97947388 0.89286150 0
## 7 -0.34980923 0.89741437 0
## 13 0.47136990 -0.25412751 0
## 34 0.36366725 0.20185656 1
## 18 -0.19502972 0.30992102 1
## 1 0.44180779 0.58313560 1
## 58 -0.83932791 0.17589595 0
DataTrain
## x1 x2 y
## 2 0.751546386 -0.482631366 1
## 3 0.521964657 0.971967664 1
## 4 0.772249132 0.513747487 1
## 5 -0.087038080 0.959556494 0
## 6 -0.667256430 -0.562104322 0
## 12 -0.695253020 0.011067446 1
## 14 -0.997726827 -0.328389923 1
## 15 -0.217593329 -0.903497295 0
## 16 -0.075010692 0.237895078 0
## 17 -0.223712037 0.922894584 1
## 19 -0.642072830 0.020583983 1
## 22 -0.346495183 0.028883356 1
## 23 0.930830647 -0.982704173 1
## 24 0.414963754 -0.961610469 1
## 25 0.289085273 -0.710976197 1
## 26 -0.220343030 -0.389936490 0
## 27 0.397087279 0.651313725 0
## 28 0.088115729 0.004689285 1
## 29 -0.547065643 0.607145252 0
## 30 -0.030884490 -0.878720036 1
## 31 0.586014340 0.855910273 0
## 32 -0.988024741 0.616357102 1
## 33 -0.624575108 -0.842373320 0
## 35 -0.259791753 0.429559761 1
## 38 0.808309332 0.499892540 1
## 39 0.234849131 -0.808718952 0
## 40 -0.731936731 -0.204348270 1
## 42 -0.141602360 0.234507331 1
## 43 0.854547950 0.948548255 0
## 44 0.546486449 0.236424076 0
## 45 -0.480637507 0.042738387 1
## 47 -0.879609685 0.274908865 0
## 48 -0.913087092 0.728602266 1
## 49 -0.889892363 -0.497764518 0
## 50 0.251085594 -0.569861839 1
## 51 0.928940577 0.218952090 0
## 53 -0.369943527 0.510542089 0
## 54 -0.573949098 -0.240527514 1
## 55 0.464992238 0.589944127 1
## 56 -0.001517959 0.811382272 0
## 57 0.459543943 0.968052347 0
## 59 -0.128939030 -0.981071897 0
dim<-rbind(dim(DataTrain),dim(DataTest)) %>% data.frame()
names(dim)<-c("nrow","ncol")
dim
## nrow ncol
## 1 42 3
## 2 18 3
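# Sanity check (not in the original walkthrough): compare the class balance of y
# in the two subsets; a badly skewed split would distort the hold-out error
# rates computed below.
prop.table(table(DataTrain$y))
prop.table(table(DataTest$y))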
par(mfrow=c(2,2),mar=c(4,6,4,4))
# All 60 observations, plotting symbol determined by the class y
Data[,1:2] %>% plot(pch=Data[,3]+1,cex=0.8,
                    xlab="x1",ylab="x2",
                    main="Total samples")
# Same scatter again (pch must match all 60 rows, not just the training rows);
# the test observations are then overdrawn in red
Data[,1:2] %>% plot(pch=Data[,3]+1,cex=0.8,
                    xlab="x1",ylab="x2",
                    main="Train and Test Samples")
DataTest[,1:2] %>% points(pch=DataTest[,3]+16,col=2,cex=0.8)
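# Optional sketch (added): since the tidyverse is already attached, the same
# train/test scatter can be drawn with ggplot2; this is one possible mapping,
# not the original base-graphics plot.
Data %>%
  mutate(set = if_else(row_number() %in% SampleId, "test", "train")) %>%
  ggplot(aes(x = x1, y = x2, shape = factor(y), colour = set)) +
  geom_point(size = 2) +
  labs(title = "Train and Test Samples", shape = "y", colour = "subset")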
library(class)
options(digits=3)
# KNN on the full sample: fit and evaluate on all 60 observations
errRatio<-vector()
for(i in 1:30){
  KnnFit<-knn(train=Data[,1:2],test=Data[,1:2],
              cl=Data[,3],k=i)
  CT<-table(Data[,3],KnnFit)                            # confusion matrix
  errRatio<-c(errRatio,(1-sum(diag(CT))/sum(CT))*100)   # apparent error rate (%)
}
plot(errRatio,type="l",col="blue",
     xlab="K: number of neighbors",
     ylab="Error rate (%)",
     main="Error rate vs. number of neighbors K",
     ylim=c(0,80))
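# Note (added): on the full sample the error at k = 1 should be 0%, because each
# observation is its own nearest neighbour (assuming no duplicated rows), so the
# in-sample curve is overly optimistic at small k. A quick check:
fit1<-knn(train=Data[,1:2],test=Data[,1:2],cl=Data[,3],k=1)
mean(fit1!=Data[,3])*100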
# Hold-out KNN: train on DataTrain, evaluate on DataTest
errRatio1<-vector()
for(i in 1:30){
  KnnFit<-knn(train=DataTrain[,1:2],test=DataTest[,1:2],cl=DataTrain[,3],k=i)
  CT<-table(DataTest[,3],KnnFit)   # confusion matrix on the test set
  # misclassification rate (%)
  errRatio1<-c(errRatio1,(1-sum(diag(CT))/sum(CT))*100)
}
lines(1:30,errRatio1,lty=2,col="red")
# Leave-one-out cross-validated KNN
set.seed(12345)
errRatio2<-vector()
for(i in 1:30){
  KnnFit<-knn.cv(train=Data[,1:2],cl=Data[,3],k=i)
  CT<-table(Data[,3],KnnFit)
  errRatio2<-c(errRatio2,(1-sum(diag(CT))/sum(CT))*100)
}
lines(1:30,errRatio2,col="black")
legend("topright",
       c("Full sample","Hold-out","Leave-one-out"),
       lty=c(1,2,1),
       col=c("blue","red","black"),cex=0.6)
which.min(errRatio1)
## [1] 8
which.min(errRatio2[1:10])
## [1] 7
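# Added sketch: refit the hold-out KNN at the k selected above and inspect the
# confusion matrix on the test set (same class::knn call as before).
kBest<-which.min(errRatio1)
FitBest<-knn(train=DataTrain[,1:2],test=DataTest[,1:2],cl=DataTrain[,3],k=kBest)
table(observed=DataTest[,3],predicted=FitBest)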
## Randomly generate data for the KNN regression example (continuous y)
set.seed(12345)
x1<-runif(60,-1,1)
x2<-runif(60,-1,1)
y<-runif(60,10,20)
Data<-data.frame(x1,x2,y)
SampleId<-sample(x=1:60,size=18)
DataTest<-Data[SampleId,]
DataTrain<-Data[-SampleId,]
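# Sanity check (added): the continuous response should cover a similar range in
# both subsets.
summary(DataTrain$y)
summary(DataTest$y)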
# Hold-out KNN "regression" via class::knn
mseVector<-vector()
for(i in 1:30){
  KnnFit<-knn(train=DataTrain[,1:2],
              test=DataTest[,1:2],
              cl=DataTrain[,3],k=i,prob=FALSE)
  # knn() returns a factor, so convert the predicted labels back to numbers
  KnnFit<-as.double(as.vector(KnnFit))
  mse<-sum((DataTest[,3]-KnnFit)^2)/length(DataTest[,3])   # test MSE
  mseVector<-c(mseVector,mse)
}
plot(mseVector,type="l",
     xlab="K: number of neighbors",
     ylab="MSE",main="MSE vs. number of neighbors K",
     ylim=c(0,80))
mseVector
## [1] 13.52 11.76 11.95 18.14 15.71 12.16 15.70 9.39 17.94 19.76 13.06 16.12
## [13] 13.69 13.10 24.79 13.37 13.74 15.62 17.61 13.30 19.95 7.71 6.42 10.42
## [25] 12.81 15.13 21.59 23.51 14.08 17.60
which.min(mseVector[1:10])
## [1] 8
# Full sample: apparent MSE when training and testing on all observations
mseVector1<-vector()
for(i in 1:30){
  KnnFit<-knn(train=Data[,1:2],test=Data[,1:2],cl=Data[,3],k=i)
  KnnFit<-as.double(as.vector(KnnFit))
  mse<-sum((Data[,3]-KnnFit)^2)/length(Data[,3])
  mseVector1<-c(mseVector1,mse)
}
lines(1:30,mseVector1,lty=2,col="red")
mseVector1
## [1] 7.72e-28 7.69e+00 8.76e+00 1.50e+01 9.86e+00 1.10e+01 1.08e+01 1.21e+01
## [9] 1.40e+01 1.18e+01 1.62e+01 1.61e+01 1.46e+01 1.37e+01 1.46e+01 1.11e+01
## [17] 1.06e+01 1.40e+01 1.41e+01 1.16e+01 1.68e+01 1.30e+01 1.21e+01 1.32e+01
## [25] 1.32e+01 1.34e+01 1.66e+01 1.31e+01 1.39e+01 1.68e+01
which.min(mseVector1[1:10])
## [1] 1
# Leave-one-out cross-validation
mseVector2<-vector()
for(i in 1:30){
  KnnFit<-knn.cv(train=Data[,1:2],cl=Data[,3],k=i)
  KnnFit<-as.double(as.vector(KnnFit))
  mse<-sum((Data[,3]-KnnFit)^2)/length(Data[,3])
  mseVector2<-c(mseVector2,mse)
}
lines(1:30,mseVector2,lty=3,col="blue")
legend("topright",
       c("Hold-out","Full sample","Leave-one-out"),
       lty=1:3,
       col=c("black","red","blue"),cex=0.6)
mseVector2
## [1] 14.0 14.7 14.0 15.8 14.5 18.3 16.0 20.8 17.7 17.0 16.5 11.8 15.3 12.7 12.5
## [16] 19.0 16.4 16.3 17.8 11.3 15.6 16.6 15.0 13.9 13.5 10.5 13.0 11.0 14.1 15.1
which.min(mseVector2[1:10])
## [1] 1
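# Note (added): class::knn treats the numeric y as a set of class labels, so each
# prediction is one of the observed y values rather than a neighbourhood average.
# For genuine KNN regression the FNN package offers knn.reg(), which averages the
# responses of the k nearest neighbours. A minimal sketch, assuming FNN is
# installed (FNN also exports knn()/knn.cv(), which would mask class's versions):
library(FNN)
mseReg<-sapply(1:30,function(k){
  pred<-knn.reg(train=DataTrain[,1:2],test=DataTest[,1:2],y=DataTrain[,3],k=k)$pred
  mean((DataTest[,3]-pred)^2)
})
which.min(mseReg)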