2. 지도학습 알고리즘
2.1 의사결정나무 모형(Decision Tree)
- 의사결정나무 모형이란 특정 항목에 대한 의사결정규칙을 나무 형태로 분류해 나가는 분석 기법을 말한다.
- 조건문 형식을 가지는 것으로서 조건에 맞는지 여부에 따라 가지를 반복 분할하면서 모델을 만든다.
- 결과를 해석하고 이해하기 쉽다.
- 업무처리를 지원해 주는 시스템으로 비지니스 모델 ERP 안에서 존재한다.
의사결정나무의 구조
- 뿌리 마디 : 나무 구조가 시작되는 노드로, 분석 대상의 모든 데이터로 구성된다.
- 자식 마디 : 하나의 노드로부터 분기되어 나간 두 개 이상의 노드로, 분석 대상 데이터는 노드 특성에 따라 분리된다.
- 부모 마디 : 자식 마디의 상위 노드이다.
- 끝 마디 : 각 나뭇가지의 끝에 있는 노드로, 나무 모형에서 분류의 규칙은 끝마디의 개수만큼 생성된다.
학습방법
- 의사결정나무의 학습은 학습에 사용되는 데이터 집합을 적절한 분할 기준 또는, 분할 테스트에 따라 나누는 과정이다.
- 순환 분할 방식으로서 더이상 순수도를 높일 수 없거나, 말단 노드에 포함된 개체의 수가 사전에 정한 최솟값에 도달하였거나, 노드의 깊이가 사전에 정해놓은 한계에 이를 때까지 재귀적으로 분할이 반복된다.
- 이를 하향식 귀납법이라고 한다.
- 가지치기의 기준은 순수도를 가장 높여줄 수 있는 변수를 먼저 선택해 진행한다.
- 순수도를 측정하는 척도로는 지니 척도, 정보 이익(= 엔트로피, 정보의 양)이 많이 사용된다.
지니 척도
- 지니 척도는 두 번을 복원추출 했을 때, 동일범주 개체가 선택될 확률이다.
- 지니 척도는 순수도가 높을수록 1에, 균등할수록 0.5에 가까워진다.
- 순수도가 높은 변수로 가지치기를 한다.
- 불순도 : f(p) = p(1-p), p : 특정 종류 변수가 있을 확률
- 순수도를 구할 때는 1 - 불순도
- 정보 이익은 엔트로피의 개념을 사용한다,
- 어떤 그룹에 두 개의 종류가 섞여있고, 각 종류가 존재할 확률이 p1, p2 일 때 엔트로피식은 다음과 같다
- E = -(p1log_2p1 + p1log_2p1)
- 두 범주의 확률이 모두 0.5이면, 엔트로피는 1이 되고, 한 범주만 있는 완전히 순수한 상태이면 0이 된다.
2.1.1 데이터준비
autoparts <- read.csv("autoparts.csv", header = TRUE)
autoparts1 <- autoparts[autoparts$prod_no == "90784-76001", c(2:11)]
autoparts2 <- autoparts1[autoparts1$c_thickness < 1000, ]
autoparts2$y_faulty <- ifelse((autoparts2$c_thickness < 20) | (autoparts2$c_thickness > 32), 1, 0)
t_index <- sample(1:nrow(autoparts2), size = nrow(autoparts2) * 0.7)
train <- autoparts2[t_index, ] # 훈련데이터 (70%)
test <- autoparts2[-t_index, ] # 검증데이터 (30%)2.1.2 모델생성
library(tree)
m <- tree(factor(y_faulty)~fix_time+a_speed+b_speed+separation+s_separation+rate_terms+mpa
+load_time+highpressure_time,data=train)
plot(m)
text(m)2.1.3 가지치기
- 가지가 많을수록 정답률은 높아지나, 모델은 복잡해진다.
prune.m=prune.tree(m,method = "misclass")
# 잘못된 분류를 기준으로 가지치기한다.
# 기본값으로 두어도 큰 차이는 없다.
plot(prune.m)- 가지 9개
prune.m=prune.tree(m,best=9)
plot(prune.m)
text(prune.m)- 가지 3
prune.m2=prune.tree(m,best=3)
plot(prune.m2)
text(prune.m2)2.1.4 검증 데이터 예측하기
yhat_test=predict(prune.m,test,type="class")
yhat_test2=predict(prune.m2,test,type="class")- 가지 9개
#install.packages("caret")
library(caret)## Loading required package: lattice
## Warning: package 'lattice' was built under R version 3.5.1
## Loading required package: ggplot2
## Warning: package 'ggplot2' was built under R version 3.5.1
confusionMatrix(yhat_test,as.factor(test$y_faulty))## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 5448 238
## 1 268 577
##
## Accuracy : 0.9225
## 95% CI : (0.9158, 0.9289)
## No Information Rate : 0.8752
## P-Value [Acc > NIR] : <2e-16
##
## Kappa : 0.6508
## Mcnemar's Test P-Value : 0.1973
##
## Sensitivity : 0.9531
## Specificity : 0.7080
## Pos Pred Value : 0.9581
## Neg Pred Value : 0.6828
## Prevalence : 0.8752
## Detection Rate : 0.8342
## Detection Prevalence : 0.8706
## Balanced Accuracy : 0.8305
##
## 'Positive' Class : 0
##
- 가지 3
confusionMatrix(yhat_test2,as.factor(test$y_faulty))## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 5612 680
## 1 104 135
##
## Accuracy : 0.88
## 95% CI : (0.8718, 0.8877)
## No Information Rate : 0.8752
## P-Value [Acc > NIR] : 0.1264
##
## Kappa : 0.2115
## Mcnemar's Test P-Value : <2e-16
##
## Sensitivity : 0.9818
## Specificity : 0.1656
## Pos Pred Value : 0.8919
## Neg Pred Value : 0.5649
## Prevalence : 0.8752
## Detection Rate : 0.8593
## Detection Prevalence : 0.9634
## Balanced Accuracy : 0.5737
##
## 'Positive' Class : 0
##
2.1.5 ROC, AUC
- 가지 9개
#install.packages("ROCR")
#install.packages("pROC")
library(pROC)## Type 'citation("pROC")' for a citation.
##
## Attaching package: 'pROC'
## The following objects are masked from 'package:stats':
##
## cov, smooth, var
library(Epi)
ROC(test = yhat_test,stat=test$y_faulty,plot="ROC",AUC=T,main="Tree")- 가지 3개
ROC(test=yhat_test2,stat=test$y_faulty,plot="ROC",AUC=T,main="Tree")2.1.6 예측 (1) 한개의 데이터 예측
new.data=data.frame(fix_time=87,a_speed=0.609,b_speed=1.715,separation=242.7,s_separation=657.5,rate_terms=95,mpa=78,load_time=18.1,highpressure_time=82)
predict(prune.m,newdata = new.data,type="class")## [1] 0
## Levels: 0 1
predict(prune.m2,newdata = new.data,type="class")## [1] 0
## Levels: 0 1
2.1.7 예측 (2) 여러개의 데이터 예측
- 벡터
new.data=data.frame(fix_time=c(87,85.6),a_speed=c(0.609,0.472),b_speed=c(1.715,1.685),separation=c(242.7,243.4),s_separation=c(657.5,657.9),rate_terms=c(95,95),mpa=c(78,28.8),load_time=c(18.1,18.2),highpressure_time=c(82,60))
predict(prune.m,newdata = new.data,type="class")## [1] 0 1
## Levels: 0 1
predict(prune.m2,newdata = new.data,type="class")## [1] 0 1
## Levels: 0 1
2.데이터 프레임
new.data=data.frame(fix_time=test$fix_time,a_speed=test$a_speed,b_speed=test$b_speed,separation=test$separation,s_separation=test$s_separation,rate_terms=test$rate_terms,mpa=test$mpa,load_time=test$load_time,highpressure_time=test$highpressure_time)
head(predict(prune.m,newdata = new.data,type="class"))## [1] 0 0 0 0 0 1
## Levels: 0 1
2.1.8 다항 종속변수에 대하여 의사결정나무 모형 만들기
- 데이터 준비
autoparts2$g_class=as.factor(ifelse(autoparts2$c_thickness<20,1,ifelse(autoparts2$c_thickness<32,2,3)))
t_index=sample(1:nrow(autoparts2),size=nrow(autoparts2)*0.7) # 행 인덱스 추출 (70%)
train=autoparts2[t_index,]
test=autoparts2[-t_index,]- 모델생성
m=tree(g_class~fix_time+a_speed+b_speed+separation+s_separation+rate_terms+mpa
+load_time+highpressure_time,data=train)
plot(m)
text(m)- 검증데이터 예측
yhat_test=predict(m,test,type="class")
confusionMatrix(yhat_test,as.factor(test$g_class))## Confusion Matrix and Statistics
##
## Reference
## Prediction 1 2 3
## 1 451 163 0
## 2 188 5357 5
## 3 4 139 224
##
## Overall Statistics
##
## Accuracy : 0.9236
## 95% CI : (0.9169, 0.9299)
## No Information Rate : 0.8665
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.6973
## Mcnemar's Test P-Value : < 2.2e-16
##
## Statistics by Class:
##
## Class: 1 Class: 2 Class: 3
## Sensitivity 0.70140 0.9466 0.97817
## Specificity 0.97232 0.7787 0.97731
## Pos Pred Value 0.73453 0.9652 0.61035
## Neg Pred Value 0.96755 0.6922 0.99919
## Prevalence 0.09845 0.8665 0.03506
## Detection Rate 0.06906 0.8202 0.03430
## Detection Prevalence 0.09401 0.8498 0.05619
## Balanced Accuracy 0.83686 0.8627 0.97774
2.1.9 연속형 종속변수에 대하여 의사결정나무 모형 만들기
- 종속변수가 연속형이므로 type=“class” 사용하지 않음
- 또한 table함수 사용 불가능
- 모델생성
m=tree(c_thickness~fix_time+a_speed+b_speed+separation+s_separation+rate_terms+mpa
+load_time+highpressure_time,data=train)
plot(m)
text(m)- 검증데이터 예측
yhat_test=predict(m,test)
head(yhat_test)## 2 4 7 8 13 16
## 22.84106 22.84106 22.48722 22.84106 22.48722 33.13135
- 예측
new.data=data.frame(fix_time=c(87,85.6),a_speed=c(0.609,0.472),b_speed=c(1.715,1.685),separation=c(242.7,243.4),s_separation=c(657.5,657.9),rate_terms=c(95,95),mpa=c(78,28.8),load_time=c(18.1,18.2),highpressure_time=c(82,60))
predict(prune.m,newdata = new.data)## [,1] [,2]
## 0 0.9328996 0.06710045
## 1 0.3870403 0.61295972
연습문제
- rpart 패키지
library(rpart)
rpart=rpart(as.factor(y_faulty)~fix_time+a_speed+b_speed+separation+s_separation+rate_terms+mpa
+load_time+highpressure_time,data=train)
plot(rpart)
text(rpart)printcp(rpart)##
## Classification tree:
## rpart(formula = as.factor(y_faulty) ~ fix_time + a_speed + b_speed +
## separation + s_separation + rate_terms + mpa + load_time +
## highpressure_time, data = train)
##
## Variables actually used in tree construction:
## [1] a_speed b_speed load_time mpa rate_terms
## [6] s_separation separation
##
## Root node error: 1971/15236 = 0.12936
##
## n= 15236
##
## CP nsplit rel error xerror xstd
## 1 0.118721 0 1.00000 1.00000 0.021017
## 2 0.055302 2 0.76256 0.76560 0.018707
## 3 0.030441 3 0.70726 0.71182 0.018108
## 4 0.028919 4 0.67681 0.69406 0.017903
## 5 0.019280 5 0.64789 0.64587 0.017329
## 6 0.017504 8 0.59006 0.61187 0.016908
## 7 0.013191 10 0.55505 0.56824 0.016343
## 8 0.011669 12 0.52867 0.51091 0.015559
## 9 0.010000 20 0.41197 0.45510 0.014741
plotcp(rpart)rpart.prune=rpart(as.factor(y_faulty)~fix_time+a_speed+b_speed+separation+s_separation+rate_terms+mpa
+load_time+highpressure_time,data=train,control=rpart.control(cp=0.02))
plot(rpart.prune)
text(rpart.prune,pretty = 2)2.2 KNN (K-Nearest Neighbors classfier)
- 새로운 점이 주어졌을 때, 그 점으로부터 가까운 점 K개를 이용하여 분류하는 머신러닝 기법
- 작은 K를 선정하면 주변 소수의 데이터에 너무 큰 영향을 받는다.
- 큰 K를 선정하면 관련이 없는 먼 곳의 데이터까지 분류에 영향을 끼치고, 정작 중요한 주변의 데이터 영향력은 작아진다.
- Cross Validation(교차검증)을 통해 오분류율이 낮은 K를 선정한다.
- K값은 일반적으로 홀 수를 취한다.
- 범주형 자료, 연속형 데이터 모두 사용 가능하다.
2.2.1 자료 준비
데이터
autoparts=read.csv("autoparts.csv",header = TRUE)
autoparts1=autoparts[autoparts$prod_no=="90784-76001",c(2:11)]
autoparts2=autoparts1[autoparts1$c_thickness<1000,]
autoparts2$y_faulty=ifelse((autoparts2$c_thickness<20)|(autoparts2$c_thickness>32),1,0)데이터셋 나누기
t_index=sample(1:nrow(autoparts2),size=nrow(autoparts2)*0.7) # 행 인덱스 추출 (70%)
train=autoparts2[t_index,]
test=autoparts2[-t_index,]2.2.2 Argument 준비
xmat.train=as.matrix(train[1:9])
y_faulty.train=train$y_faulty
head(xmat.train)## fix_time a_speed b_speed separation s_separation rate_terms mpa
## 16462 80.8 0.645 1.681 188.5 711.8 86 75.6
## 17082 80.7 0.651 1.788 185.0 710.7 84 73.9
## 25302 80.8 0.657 1.467 179.3 717.8 84 76.1
## 9139 86.4 0.601 1.666 242.9 654.1 81 77.5
## 3313 85.1 0.586 1.693 250.7 648.9 81 79.9
## 25678 80.8 0.645 1.653 189.1 711.9 84 76.9
## load_time highpressure_time
## 16462 19.3 71
## 17082 19.2 66
## 25302 19.2 76
## 9139 18.2 61
## 3313 18.1 56
## 25678 19.2 69
검증 데이터 행렬
xmat.test=as.matrix(test[1:9])
head(xmat.test)## fix_time a_speed b_speed separation s_separation rate_terms mpa
## 7 86.5 0.606 1.701 243.1 656.9 95 78.2
## 14 85.2 0.606 1.682 249.0 657.4 95 76.8
## 17 85.4 0.470 1.681 243.6 657.1 95 28.1
## 23 85.4 0.474 1.706 242.8 656.9 95 28.1
## 28 86.9 0.606 1.703 243.2 656.9 96 77.6
## 30 86.4 0.609 1.712 248.3 657.1 96 77.8
## load_time highpressure_time
## 7 18.1 55
## 14 18.0 72
## 17 18.1 60
## 23 18.1 60
## 28 18.2 67
## 30 18.1 53
2.2.3 예측 값 생성
library(class)## Warning: package 'class' was built under R version 3.5.1
yhat_test=knn(xmat.train,xmat.test,as.factor(y_faulty.train),k=3)
head(yhat_test,100)## [1] 0 1 1 1 0 1 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
## [36] 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0
## [71] 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0 0 0 0
## Levels: 0 1
table=table(real=test$y_faulty,predict=yhat_test)
table## predict
## real 0 1
## 0 5556 115
## 1 188 672
confusionMatrix(yhat_test,as.factor(test$y_faulty))## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 5556 188
## 1 115 672
##
## Accuracy : 0.9536
## 95% CI : (0.9482, 0.9586)
## No Information Rate : 0.8683
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.7895
## Mcnemar's Test P-Value : 3.53e-05
##
## Sensitivity : 0.9797
## Specificity : 0.7814
## Pos Pred Value : 0.9673
## Neg Pred Value : 0.8539
## Prevalence : 0.8683
## Detection Rate : 0.8507
## Detection Prevalence : 0.8795
## Balanced Accuracy : 0.8806
##
## 'Positive' Class : 0
##
2.2.4 최적 k값 찾기
library(e1071)
tune.out=tune.knn(x=xmat.train,y=as.factor(y_faulty.train),k=1:10)
tune.out##
## Parameter tuning of 'knn.wrapper':
##
## - sampling method: 10-fold cross validation
##
## - best parameters:
## k
## 5
##
## - best performance: 0.04463101
plot(tune.out)최적 K=5로 knn 반복
library(class)
yhat_test=knn(xmat.train,xmat.test,as.factor(y_faulty.train),k=5)
head(yhat_test,100)## [1] 0 1 1 1 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
## [36] 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0
## [71] 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0 0 0 0
## Levels: 0 1
confusionMatrix(yhat_test,as.factor(test$y_faulty))## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 5565 197
## 1 106 663
##
## Accuracy : 0.9536
## 95% CI : (0.9482, 0.9586)
## No Information Rate : 0.8683
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.7876
## Mcnemar's Test P-Value : 2.336e-07
##
## Sensitivity : 0.9813
## Specificity : 0.7709
## Pos Pred Value : 0.9658
## Neg Pred Value : 0.8622
## Prevalence : 0.8683
## Detection Rate : 0.8521
## Detection Prevalence : 0.8823
## Balanced Accuracy : 0.8761
##
## 'Positive' Class : 0
##
2.2.5 ROC,AUC
library(Epi)
ROC(test=yhat_test,stat=test$y_faulty,plot="ROC",AUC=T,main="KNN")2.2.6 데이터 예측
new.data=data.frame(fix_time=c(87,85.6),a_speed=c(0.609,0.472),b_speed=c(1.715,1.685),separation=c(242.7,243.4),s_separation=c(657.5,657.9),rate_terms=c(95,95),mpa=c(78,28.8),load_time=c(18.1,18.2),highpressure_time=c(82,60))
knn(xmat.train,new.data,y_faulty.train,k=5)## [1] 0 1
## Levels: 0 1
2.2.7 종속변수가 다항인 경우
- 이항인 경우와 같음.
2.2.8 종속변수가 연속형인 경우 FNN::knn.reg()
autoparts=read.csv("autoparts.csv",header = TRUE)
autoparts1=autoparts[autoparts$prod_no=="90784-76001",c(2:11)]
autoparts2=autoparts1[autoparts1$c_thickness<1000,]
t_index=sample(1:nrow(autoparts2),size=nrow(autoparts2)*0.7) # 행 인덱스 추출 (70%)
train=autoparts2[t_index,]
test=autoparts2[-t_index,]library(FNN)
yhat_test <- knn.reg(xmat.train,xmat.test,c_thickness.train,k=3)
mse=mean((yhat_test$pred-test$c_thickness)^2)
mse2.3 신경망(Neural Network)
- 인간의 뇌와 같은 형태의 모형
- 입력층 / 은닉층 / 출력층으로 구성
- 입력 - 반응을 결정하는 대표적인 함수는 sigmoid 함수이다.
- 은닉층에서 정보의 조합이 어떻게 이루어지는지 실행 중에 파악하기가 어려워 결과 도출 과정을 설명하기 어렵고 모델 수정도 어렵다.
- 변수 선택에 매우 민감.
- 인공 신경망에서의 학습이란 노드와 노드 사이의 링크에 부여된 가중치를 조절하는 과정이다.
- 가중치를 계속 조절해가며 오차를 줄여 나가도록 한다.
- 가중치의 조정은 출력 노드로 부터 역방향으로 이루어지므로 역전파 알고리즘이라고 부른다.
- 입력 변수가 많아지면 입력 노드가 많아지고, 노드가 많아지면 추정해야하는 가중치의 수가 늘어나게 된다. - 추정해야 할 가중치의 수가 늘어나게 되면 과적합이 발생할 가능성이 높아져 train 데이터의 예측력은 높더라도 test 데이터의 예측력이 떨어지게 된다.
- 따라서 종속변수와의 관계가 깊은 주요 변수를 최소한으로 선택하는 것이 필요하다.
2.3.1 자료 준비
autoparts=read.csv("autoparts.csv",header = TRUE)
autoparts1=autoparts[autoparts$prod_no=="90784-76001",c(2:11)]
autoparts2=autoparts1[autoparts1$c_thickness<1000,]
autoparts2$g_class=as.factor(ifelse(autoparts2$c_thickness<20,1,ifelse(autoparts2$c_thickness<32,2,3)))
t_index=sample(1:nrow(autoparts2),size=nrow(autoparts2)*0.7) # 행 인덱스 추출 (70%)
train=autoparts2[t_index,]
test=autoparts2[-t_index,]2.3.2 모델 생성
library(nnet)## Warning: package 'nnet' was built under R version 3.5.1
m=nnet(g_class~fix_time+a_speed+b_speed+separation+s_separation+rate_terms+mpa
+load_time+highpressure_time,data=train,size=10)## # weights: 133
## initial value 50253.935684
## iter 10 value 7015.347014
## iter 20 value 7007.682857
## final value 7007.682288
## converged
2.3.3 성능 평가
yhat_test=predict(m,test,type="class")
table=table(real=test$g_class,predict=yhat_test)
table## predict
## real 2
## 1 639
## 2 5669
## 3 223
2.3.4 시각화
library(reshape)##
## Attaching package: 'reshape'
## The following object is masked from 'package:class':
##
## condense
library(scales)
library(devtools)
source_url("https://gist.githubusercontent.com/fawda123/7471137/raw/466c1474d0a505ff044412703516c34f1a4684a5/nnet_plot_updata.r")## SHA-1 hash of file is 74c80bd5ddbc17ab3ae5ece9c0ed9beb612e87ef
plot(m)