아주 잘 알려진 지도 학습 알고리즘인 k-NN을 국민 데이터인 iris data로 쉽고 재밌게 배워봅시다.
Iris 데이터를 scale() 함수를 써서 표준정규분포로 표준화한다.
# Normalization of all columns except Species
dataNorm <- iris
dataNorm[,-5] <- scale(iris[,-5])
그런 다음, 70%의 train data와 30%의 test data로 나눈다.
# 70% train and 30% test
library(tidyverse)
## -- Attaching packages ---------------------------- tidyverse 1.2.1 --
## √ ggplot2 2.2.1 √ purrr 0.2.4
## √ tibble 1.4.2 √ dplyr 0.7.4
## √ tidyr 0.7.2 √ stringr 1.2.0
## √ readr 1.1.1 √ forcats 0.3.0
## -- Conflicts ------------------------------- tidyverse_conflicts() --
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
set.seed(1234)
# 1. sample 함수 : dataNorm의 data 갯수만큼 1을 0.7% 확률로, 2를 0.3% 확률로 무작위 복원추출
ind <- sample(2, nrow(dataNorm), replace=TRUE, prob=c(0.7,0.3))
trainData <- dataNorm[ind==1,]
testData <- dataNorm[ind==2,]
# # 2. dplyr::sample_frac 함수
# trainData <- sample_frac(dataNorm, 0.7)
# trainIndex <- as.numeric(rownames(trainData)) # rownames() return character
# testData <- dataNorm[-trainIndex, ]
knn()
함수는 다음과 같은 arguments를 가진다:
train
. marix or dataframe 형태의 training data settest
. marix or dataframe 형태의 test data setcl
. training data set의 분류 요이k
. 고려할 이웃의 수# Load the class package
library(class)
# Execution of k-NN with k=1
KnnTestPrediction_k1 <- knn(trainData[,-5], testData[,-5],
trainData$Species, k=1, prob=TRUE)
# Execution of k-NN with k=2
KnnTestPrediction_k2 <- knn(trainData[,-5], testData[,-5],
trainData$Species, k=2, prob=TRUE)
# Execution of k-NN with k=3
KnnTestPrediction_k3 <- knn(trainData[,-5], testData[,-5],
trainData$Species, k=3, prob=TRUE)
# Execution of k-NN with k=4
KnnTestPrediction_k4 <- knn(trainData[,-5], testData[,-5],
trainData$Species, k=4, prob=TRUE)
’k’를 다르게 사용함에 따라 각각의 분류에 대한 정확성을 평가할 수 있고 어떤 값의 ’k’가 가장 훌륭한 결과를 제공하는지 확인할 수 있다.
# Confusion matrix of KnnTestPrediction_k1
table(testData$Species, KnnTestPrediction_k1)
## KnnTestPrediction_k1
## setosa versicolor virginica
## setosa 10 0 0
## versicolor 0 12 0
## virginica 0 2 14
위의 교차표는 어떻게 해석할 수 있을까요?!
test data에서 10개의 setosa 관측치가 정확하게 setosa로 예측된다.
test data에서 12개의 versicolor 관측치가 정확하게 versicolor로 예측된다.
test data에서 16개 중 14개의 virginica가 정확하게 virginica로 예측되고 다른 2개의 관측치는 versicolor로 오분류 되었다.
분류의 정확성을 아래와 같이 계산할 수 있다.
# Classification accuracy of KnnTestPrediction_k1
sum(KnnTestPrediction_k1==testData$Species) / length(testData$Species)*100
## [1] 94.73684
다른 분류 결과들은 아래와 같다.
# Confusion matrix of KnnTestPrediction_k2
table(testData$Species, KnnTestPrediction_k2)
## KnnTestPrediction_k2
## setosa versicolor virginica
## setosa 10 0 0
## versicolor 0 12 0
## virginica 0 3 13
# Classification accuracy of KnnTestPrediction_k2
sum(KnnTestPrediction_k2 == testData$Species) / length(testData$Species) * 100
## [1] 92.10526
# Confusion matrix of KnnTestPrediction_k3
table(testData$Species, KnnTestPrediction_k3)
## KnnTestPrediction_k3
## setosa versicolor virginica
## setosa 10 0 0
## versicolor 0 12 0
## virginica 0 2 14
# Classification accuracy of KnnTestPrediction_k3
sum(KnnTestPrediction_k3 == testData$Species) / length(testData$Species) * 100
## [1] 94.73684
# Confusion matrix of KnnTestPrediction_k4
table(testData$Species, KnnTestPrediction_k4)
## KnnTestPrediction_k4
## setosa versicolor virginica
## setosa 10 0 0
## versicolor 0 12 0
## virginica 0 2 14
# Classification accuracy of KnnTestPrediction_k4
sum(KnnTestPrediction_k4 == testData$Species) / length(testData$Species) * 100
## [1] 94.73684
어떤 ’k’값이 가장 정확한 분류 결과를 가져오는지 plot을 통해 확인해보자.
KnnTestPrediction <- list()
accuracy <- numeric()
for(k in 1:100){
KnnTestPrediction[[k]] <- knn(trainData[,-5], testData[,-5], trainData$Species, k, prob=TRUE)
accuracy[k] <- sum(KnnTestPrediction[[k]]==testData$Species)/length(testData$Species)*100
}
plot(accuracy, type="b", col="dodgerblue", cex=1, pch=20,
xlab="k, number of neighbors", ylab="Classification accuracy",
main="Accuracy vs Neighbors")
# Add lines indicating k with best accuracy
abline(v=which(accuracy==max(accuracy)), col="darkorange", lwd=1.5)
# Add line for max accuracy seen
abline(h=max(accuracy), col="grey", lty=2)
# Add line for min accuracy seen
abline(h=min(accuracy), col="grey", lty=2)
plot에서 확인할 수 있는 점은 10개의 ’k’값이 가장 정확한 정확성을 가졌고 ’k’값이 증가할수록 정확성은 감소하는 것을 볼 수 있다.