iris 데이터를 가지고 몇 개의 모델을 돌려볼 것이다. ‘Species’ 분류에 있어서 정확도가 가장 높은 모델을 채택할 것이다.
부가적인 기술설명은 중간중간 달아놓은 블로그 링크를 통해 확인할 수 있다.
분석에 앞서 가장 먼저 해야할 것은 데이터의 구조를 파악하는 것이다.
여기서 데이터는 tranining_set을 기준으로 파악할 것이며, 먼저 ’Species’를 분류하는데 어떤 데이터가 필요한지 확인해보자.
library(kernlab) # Support Vector Machine
library(nnet) # Logistic regression, Neural Network
library(rpart) # Decision Tree
library(randomForest) # Random Forest
## randomForest 4.6-14
## Type rfNews() to see new features/changes/bug fixes.
library(dplyr) # Data Handling
##
## Attaching package: 'dplyr'
## The following object is masked from 'package:randomForest':
##
## combine
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
library(caret) # Confusion Matrix
## Loading required package: lattice
## Loading required package: ggplot2
##
## Attaching package: 'ggplot2'
## The following object is masked from 'package:randomForest':
##
## margin
## The following object is masked from 'package:kernlab':
##
## alpha
library(DT) # Data visualize
library(class) # KNN
library(ggvis) # Data visualize
##
## Attaching package: 'ggvis'
## The following object is masked from 'package:ggplot2':
##
## resolution
시각화를 통해 ‘Petal.Length’ 와 ‘Petal.Width’ 가 선형성 구조를 보이고 있으므로 이를 통해 분석에 적합성을 가지고 있다는 사실을 알 수 있다.
plot(iris)
그래프를 통해 보는 바와 같이 ’Setosa’의 데이터가 가장 정확하게 분류되는 것을 확인할 수 있다.
iris %>% ggvis(~Petal.Length, ~Petal.Width, fill = ~factor(Species)) %>%
layer_points()
str() 함수를 통해 iris 데이터의 구조를 확인해보자.
str(iris)
## 'data.frame': 150 obs. of 5 variables:
## $ Sepal.Length: num 5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
## $ Sepal.Width : num 3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...
## $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
## $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
## $ Species : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...
iris 데이터의 통계적 수치를 확인해보자.
summary(iris)
## Sepal.Length Sepal.Width Petal.Length Petal.Width
## Min. :4.300 Min. :2.000 Min. :1.000 Min. :0.100
## 1st Qu.:5.100 1st Qu.:2.800 1st Qu.:1.600 1st Qu.:0.300
## Median :5.800 Median :3.000 Median :4.350 Median :1.300
## Mean :5.843 Mean :3.057 Mean :3.758 Mean :1.199
## 3rd Qu.:6.400 3rd Qu.:3.300 3rd Qu.:5.100 3rd Qu.:1.800
## Max. :7.900 Max. :4.400 Max. :6.900 Max. :2.500
## Species
## setosa :50
## versicolor:50
## virginica :50
##
##
##
결측치 유무를 확인해보자.
sum(is.na(iris))
## [1] 0
seed 값을 설정해주고, sampling 과 set 데이터 만들어주자.
# 데이터 할당
df <- iris
# seed값 설정
set.seed(919)
# training / test sampling
training_sampling <- sort(sample(1:nrow(df), nrow(df) * 0.7 ))
test_sampling <- setdiff(1: nrow(df), training_sampling)
# training / test set
training_set <- df[training_sampling,]
test_set <- df [test_sampling,]
Training / Test set을 나눴으니 이제 학습을 시켜보자. 모델은 Logistic regression, Decision Tree, Random Forest 를 사용했다.
이번에는 로지스틱 회귀분석을 해보려 한다. 간단히 말하자면 데이터를 분류할 때 선형분류기를 사용하여 이진분류를 하는 것이 로지스틱 분석의 핵심이다. 링크를 통하여 참고할 수 있다.
https://liujingjun.tistory.com/25
multi_logit_m <- multinom(Species ~ Petal.Length + Petal.Width, data = training_set)
## # weights: 12 (6 variable)
## initial value 115.354290
## iter 10 value 8.103704
## iter 20 value 6.085925
## iter 30 value 6.068446
## iter 40 value 6.061376
## iter 50 value 6.058734
## iter 60 value 6.055210
## iter 70 value 6.053575
## iter 80 value 6.048157
## iter 90 value 6.047643
## iter 100 value 6.047423
## final value 6.047423
## stopped after 100 iterations
multi_logit_p <- predict(multi_logit_m, newdata = test_set, type = "class")
의사결정 나무 기법을 사용한 iris 분류를 해보자. 의사결정 나무란, 데이터를 나무가 가지 치듯이 차례차례 분류하여 최종적으로 분류하는 모델을 의미한다. rpart패키지 불러와 rpart 함수를 사용하여 Decision tree를 생성해보자.
https://liujingjun.tistory.com/19
rpart_m <- rpart(Species ~ Petal.Length + Petal.Width, data = training_set)
rpart_p <- predict(rpart_m, newdata = test_set, type = "class")
앙상블 기법의 일종으로 여러가지 기술을 가진 의사결정 나무들이 모여있는 형태라고 볼 수 있다. randomForest패키지의 randomForest함수를 사용하여 Random Forest 모델을 만들어 보자.
https://liujingjun.tistory.com/27
rf_m <- randomForest(Species ~ Petal.Length + Petal.Width, data = training_set)
rf_p <- predict(rf_m, newdata = test_set, type = "class")
이번엔 Support Vector Machine 을 통해 분류를 해볼 것이다. SVM(Support Vector Machine)이란 데이터 상에 있는 각점들의 거리를 분석해 가장 먼 거리에 있는 점들을 기준으로 support vector를 형성하여 두개의 support vector 중간에 초평면을 만들어 분류를 하는 방법이다. 쉽게말하면 두점 사이의 거리가 최대가 되는 지점을 찾는 것이다. https://liujingjun.tistory.com/42
svm_m <- ksvm(Species ~ Petal.Length + Petal.Width, data = training_set)
svm_p <- predict(svm_m, newdata = test_set)
이번엔 KNN으로 분류를 해보자. KNN이란 K-Nearest Neighbor의 점들에 주어진 가장 근접해있는 K근접이웃을 알아내는 과정이다. 자세한 내용은 기술블로그를 참고하도록 하자. https://liujingjun.tistory.com/29
normalizer <- function(x) {
return_value <- (x - min(x)) / (max(x) - min(x))
return(return_value)
}
normal_iris <- sapply(iris[,1:4], normalizer) %>%
as.data.frame()
# 데이터 생성
df <- cbind(normal_iris, "Species" = iris[,5])
# training / test sampling
training_sampling <- sort(sample(1:nrow(df), nrow(df)* 0.7))
test_sampling <- setdiff(1:nrow(df), training_sampling)
# training_set, test_set
training_set <- df[training_sampling,]
test_set <- df[test_sampling,]
training_set_unlable <- training_set[,1:4]
training_set_lable <- training_set[,5]
test_set_unlable <- test_set[,1:4]
test_set_lable <- test_set[,5]
knn_p <- knn(train = training_set_unlable, test = test_set_unlable, cl = training_set_lable, k =3)
각각의 모델을 평가해 보자. 평가항목은 정확도로 할 것이고 이중에서 가장 좋은 모델을 채택한다.
model_list <- cbind(
as.character(multi_logit_p),
as.character(rpart_p),
as.character(rf_p),
as.character(svm_p),
as.character(knn_p) %>%
as.data.frame()
)
# str(model_list)
total_model_accuracy <- data.frame()
for (model in model_list[, 1:ncol(model_list)]) {
model_cm <- confusionMatrix(model, test_set$Species)
model_cm_class <- model_cm$byClass %>% as.data.frame()
model_accuracy <- model_cm_class$'Balanced Accuracy'
total_model_accuracy <- rbind(total_model_accuracy, model_accuracy)
}
colnames(total_model_accuracy) <- levels(test_set$Species)
rownames(total_model_accuracy) <- c("Logistic Regression", "Decision Tree",
"Random Forest", "Support Vector Machine","KNN")
예측값, 실측값의 비교를 위해 각 모델별 분류에 대한 정확도 비교테이블이다.
datatable(total_model_accuracy)
결과