# setting
PSDS_PATH <- file.path('/Users/jinameliachoi/Documents/statistics-for-data-scientists')

library(dplyr)
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
library(ggplot2)
library(FNN)
library(rpart)
library(randomForest)
## randomForest 4.6-14
## Type rfNews() to see new features/changes/bug fixes.
## 
## Attaching package: 'randomForest'
## The following object is masked from 'package:ggplot2':
## 
##     margin
## The following object is masked from 'package:dplyr':
## 
##     combine
loan200 <- read.csv(file.path(PSDS_PATH, 'data', 'loan200.csv'))
loan200$outcome <- ordered(loan200$outcome, levels=c('paid off', 'default'))

loan3000 <- read.csv(file.path(PSDS_PATH, 'data', 'loan3000.csv'))
loan3000$outcome <- ordered(loan3000$outcome, levels=c('paid off', 'default'))

loan_data <- read.csv(file.path(PSDS_PATH, 'data', 'loan_data.csv'))
loan_data <- select(loan_data, -X, -status)

최근 통계학 분야는 회귀나 분류 같은 예측 모델링을 자동화하기 위한 더 강력한 기술을 개발하는 것을 집중해왔다. 이러한 방법들을 통계적 머신러닝 statistical machine learning 이라한다.

데이터에 기반하며 전체적인 구조를 가정하지 않는다는 고전적인 통계방법과 구분된다.


K 최근접 이웃


K 최근접 이웃 (K-Nearest Neighbors, KNN)은 가장 간단한 예측/분류 방법 중 하나이다.


KNN Process

  1. 특징들이 가장 유사한 (예측변수들이 유사한) K개의 레코드 찾음

  2. 분류: 이 유사한 레코드들 중에 다수에 속한 클래스가 무엇인지 찾은 후에 새로운 레코드를 그 클래스에 할당

  3. 예측 (KNN회귀) : 유사한 레코드들의 평균을 찾아서 새로운 레코드에 대한 예측값으로 사용함

모든 예측변수들은 수치형이어야 한다.


예제 : 대출 연체 예측


예측변수 두 가지만을 고려한 가장 간단한 모델 생각해보자.

변수 payment_inc_ratio는 소득에 대한 대출 상환 비율, dti는 소득에 대한 부채 비율 (모기지 제외)

200개의 대출만 뽑아서 만든 loan200을 이용하여,

K=20으로 할 때, payment_inc_ratio=9, dti=22.5인 새로운 대출에 대한 예측결과 newloan 구하기

newloan <- loan200[1, 2:3, drop=FALSE]

knn_pred <- knn(train=loan200[-1, 2:3], test=newloan, cl=loan200[-1, 1], k=20)
knn_pred == 'paid off'
## [1] TRUE

KNN 결과 이 새로운 대출은 상환될 것으로 예상한다.

loan200[attr(knn_pred, 'nn.index')-1, ]
##      outcome payment_inc_ratio   dti
## 34  paid off           3.84084  9.36
## 181 paid off          11.10790 15.33
## 180 paid off           5.00386  5.97
## 84  paid off           8.60830 16.17
## 8   paid off          13.85620 11.24
## 168  default          10.12410 16.67
## 20  paid off           5.92267 18.11
## 198  default           2.97641 16.41
## 76   default           3.86227 22.91
## 54  paid off           2.49545  2.40
## 140  default          10.02450 19.11
## 30   default          16.41910 26.08
## 65   default          10.85580  6.80
## 162 paid off          10.18890 25.47
## 160 paid off           1.65527 12.91
## 111 paid off           4.76040 25.95
## 77  paid off           4.39094  2.86
## 45   default           1.49472 23.79
## 40  paid off           4.27237 12.22
## 138 paid off           3.29013 24.39
dist <- attr(knn_pred, 'nn.dist')

circleFun <- function(center = c(0,0), r = 1, npoints = 100){
  tt <- seq(0, 2*pi, length.out = npoints-1)
  xx <- center[1] + r * cos(tt)
  yy <- center[2] + r * sin(tt)
  return(data.frame(x = c(xx, xx[1]), y = c(yy, yy[1])))
}

circle_df <- circleFun(center=unlist(newloan), r=max(dist), npoints=201)
loan200_df <- bind_cols(loan200, circle_df)

ggplot(data=loan200_df, aes(x=payment_inc_ratio, dti, color=outcome, shape=outcome)) +
  geom_point(size=2) +
  scale_shape_manual(values = c(1, 4, 15)) +
  geom_path(aes(x=x, y=y), color='black') +
  xlim(3, 15) + 
  ylim(17, 29) +
  theme_bw() 
## Warning: Removed 126 rows containing missing values (geom_point).

위 그림은 해당 예제 내용을 시각화한 모습이다.

동그라미(상환), 엑스(연체)를 나타나며 검은 실선은 가장 가까운 20개의 점들에 대한 경계선을 보여준다.

11번의 상환, 9번의 연체가 이뤄졌다는 것을 알 수 있다.


거리 지표


유사성(근접성)은 거리 지표를 통해 결정된다.

두 벡터 사이에 가장 많이 사용되는 지표는 유클리드 거리 (Euclidean distance)이다.

두 벡터 사이의 유클리드 거리를 구하려면 서로의 차이에 대한 제곱합을 구한 뒤 그 값의 제곱근을 취한다.


그 다음으로 많이 사용되는 거리 지표는 맨하탄 거리 (Manhattan distance)

유클리드 거리는 두 점 사이의 직선 거리라고 볼 수 있다.

반면, 맨하탄 거리는 한 번에 대각선이 아닌 한 축 방향으로만 움직일 수 있다고 할 때 두 점 사이의 거리이다.

따라서 점과 점 사이의 이동 시간으로 근접성을 따질 때 좋은 지표.


원-핫 인코더


대부분의 통계 모델이나 머신러닝 모델에서 요인(문자열) 변수는 이진 가변수의 집합으로 변환해야 한다.

이러한 기법은 원-핫 인코딩 (one-hot encoding)이라고 한다.


표준화 (정규화, z 점수)


표준화(정규화)는 모든 변수에서 평균을 빼고 표준편차를 나누는 과정을 통해 변수들을 비슷한 스케일에 두는 작업을 의미한다.

대출 데이터 예제에서 4가지 변수를 적용해본다.

달러로 신청할 수 있는 총 회전 신용 : revol_bal

이미 사용 중인 신용 비율 : revol_util

loan_df <- model.matrix(~ -1 + payment_inc_ratio + dti + revol_bal + revol_util, data=loan_data)

newloan = loan_df[1,,drop=FALSE]
loan_df = loan_df[-1,]
outcome <- loan_data[-1,1]

knn_pred <- knn(train=loan_df, test=newloan, cl=outcome, k=5)
knn_pred
## [1] 4000
## attr(,"nn.index")
##       [,1]  [,2]  [,3]  [,4]  [,5]
## [1,] 35536 33651 25863 42953 43599
## attr(,"nn.dist")
##          [,1]     [,2]     [,3]     [,4]     [,5]
## [1,] 1.555631 5.640407 7.138838 8.842243 8.972774
## Levels: 4000
loan_df[attr(knn_pred,"nn.index"),]
##       payment_inc_ratio  dti revol_bal revol_util
## 35537           1.47212 1.46      1686       10.0
## 33652           3.38178 6.37      1688        8.4
## 25864           2.36303 1.39      1691        3.5
## 42954           1.28160 7.14      1684        3.9
## 43600           4.12244 8.98      1684        7.2

달러로 표시된 revol_bal의 값 크기가 다른 변수들보다 값이 큰 것을 알 수 있다.

scale 함수를 이용하여 데이터를 표준화한 뒤 KNN과 비교해 보기

loan_df <- model.matrix(~ -1 + payment_inc_ratio + dti + revol_bal + revol_util, data=loan_data)
loan_std <- scale(loan_df)

target_std = loan_std[1,, drop=FALSE]
loan_std = loan_std[-1,]
outcome <- loan_data[-1,1]
knn_pred <- knn(train=loan_std, test=target_std, cl=outcome, k=5)
knn_pred
## [1] 2000
## attr(,"nn.index")
##      [,1] [,2]  [,3]  [,4]  [,5]
## [1,] 2080 1438 30215 28542 44737
## attr(,"nn.dist")
##           [,1]       [,2]       [,3]      [,4]     [,5]
## [1,] 0.0575066 0.09801921 0.09886893 0.1054015 0.116448
## Levels: 2000
loan_df[attr(knn_pred,"nn.index"),]
##       payment_inc_ratio   dti revol_bal revol_util
## 2080           10.04400 19.89      9179       51.5
## 1438            3.87890  5.31      1687       51.1
## 30215           6.71820 15.44      4295       26.0
## 28542           6.93816 20.31     11182       76.1
## 44737           8.20170 16.65      5244       73.9

feature들의 scale이 변환한 것을 확인할 수 있다.


K 선택하기


K를 잘 선택하는 것은 KNN의 성능을 결정하는 아주 중요한 요소이다.

일반적으로 K가 너무 작으면 데이터의 노이즈 성분까지 분석하는 overfitting 문제가 발생한다.

이와 반대로, K가 너무 크면 결정 함수가 너무 과하게 평탄화되어 underfitting(oversmoothing) 문제가 발생한다.


최적의 K값을 찾기 위해 정확도 지표를 활용한다.

특히 홀드아웃 데이터 또는 타당성검사를 위해 따로 떼놓은 데이터에서의 정확도를 가지고 K값을 결정하는 것이 좋다.


KNN을 통한 피처 엔지니어링


KNN은 실용적인 측면에서 다른 분류 방법들의 특정 단계에 사용할 수 있게 모델에 지역적 정보(local knowledge)를 추가하기 위해 사용되는 경우도 있다.

  1. KNN은 데이터에 기반하여 분류 결과를 얻는다.

  2. 이 결과는 해당 레코드에 새로운 특징(feature)로 추가된다. 이 결과를 다른 분류 방법에 사용한다. 원래의 예측변수들을 두 번씩 사용하는 셈이 된다.

다중공선성 문제는 없을까? NO. 소수의 근접한 레코드들로부터 얻는 지협적인 정보이기 때문에.


킹 카운티 주택 데이터를 예를 들 때,

주택 가격을 산정할 때, 부동산 중개업자들은 최근에 팔린 비슷한 집들의 가격(compos)을 기준으로 삼을 것이다.

중개업자들은 비슷한 주택의 매매 가격을 일일 확인하면서 일종의 수동식 KNN을 진행중에 있음. (이 집이 얼마에 팔릴 지 예측하는 것)

우리는 최근 거래 정보에 KNN을 적용해 사로운 예측 변수 - 각 레코드에 대한 (중계업자의 compos와 유사한) KNN 예측변수 - 를 추가한다.

예측결과가 수치형이기 때문에 다수결 결과가 아닌 K 최근접 이웃값의 평균을 사용한다 (KNN 회귀)

borrow_df <- model.matrix(~ -1 + dti + revol_bal + revol_util + open_acc +
                            delinq_2yrs_zero + pub_rec_zero, data=loan_data)
borrow_knn <- knn(borrow_df, test=borrow_df, cl=loan_data[, 'outcome'], prob=TRUE, k=20)
prob <- attr(borrow_knn, "prob")
borrow_feature <- ifelse(borrow_knn=='default', 1-prob, prob)
summary(borrow_feature)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##   0.050   0.400   0.500   0.499   0.600   1.000

트리 모델

트리모델은 회귀 및 분석 트리(classification and regression tree), 의사 결정 트리(decision tree), 혹은 단순히 그냥 tree라고 불리며 대중적인 분류(및 회귀) 방법이다.

트리 모델들과 여기서 파생된 강력한 랜덤 포레스트 random forest와 부스팅 boosting 같은 방법들이 회귀나 분류 문제를 위해 데이터 과학에서 가장 널리 사용되는 강력한 예측 모델링 기법들의 기초이다.


an simple tree model example

3,000개의 대출 데이터에 적합한 트리 모델을 만들어 보자.

payment_inc_ratioborrower_score 변수를 고려한다.

loan_tree <- rpart(outcome ~ borrower_score + payment_inc_ratio,
                   data=loan_data, control=rpart.control(cp=.005))
plot(loan_tree, uniform=TRUE, margin=.05)
text(loan_tree)

일반적으로 그림을 그릴 때는 트리를 거꾸로 나타낸다.

즉, 뿌리가 위로 가고 잎 부분이 아래로 간다.

’손실’의 경우에는 임시 분할에서 발생하는 오분류의 개수를 의미한다.


재귀 분할 알고리즘


의사 결정 트리를 만들 때는 재귀 분할이라고 하는 알고리즘을 사용한다.

예측변수 값을 기준으로 데이터를 반복적으로 분할해나간다.

분할할 때에는 상대적으로 같은 클래스의 데이터들끼리 구분되도록 한다.

r_tree <- data_frame(x1 = c(0.575, 0.375, 0.375, 0.375, 0.475),
                     x2 = c(0.575, 0.375, 0.575, 0.575, 0.475),
                     y1 = c(0,         0, 10.42, 4.426, 4.426),
                     y2 = c(25,       25, 10.42, 4.426, 10.42),
                     rule_number = factor(c(1, 2, 3, 4, 5)))
## Warning: `data_frame()` is deprecated, use `tibble()`.
## This warning is displayed once per session.
r_tree <- as.data.frame(r_tree)

labs <- data.frame(x=c(.575 + (1-.575)/2, 
                       .375/2, 
                       (.375 + .575)/2,
                       (.375 + .575)/2, 
                       (.475 + .575)/2, 
                       (.375 + .475)/2
                       ),
                   y=c(12.5, 
                       12.5,
                       10.42 + (25-10.42)/2,
                       4.426/2, 
                       4.426 + (10.42-4.426)/2,
                       4.426 + (10.42-4.426)/2
                       ),
                   decision = factor(c('paid off', 'default', 'default', 'paid off', 'paid off', 'default')))

ggplot(data=loan3000, aes(x=borrower_score, y=payment_inc_ratio)) +
  geom_point( aes(color=outcome, shape=outcome), alpha=.5) +
  scale_color_manual(values=c('blue', 'red')) +
  scale_shape_manual(values = c(1, 46)) +
  # scale_shape_discrete(solid=FALSE) +
  geom_segment(data=r_tree, aes(x=x1, y=y1, xend=x2, yend=y2, linetype=rule_number), size=1.5, alpha=.7) +
  guides(colour = guide_legend(override.aes = list(size=1.5)),
         linetype = guide_legend(keywidth=3, override.aes = list(size=1))) +
  scale_x_continuous(expand=c(0,0)) + 
  scale_y_continuous(expand=c(0,0), limits=c(0, 25)) + 
  geom_label(data=labs, aes(x=x, y=y, label=decision)) +
  #theme(legend.position='bottom') +
  theme_bw()
## Warning: Removed 2 rows containing missing values (geom_point).

알고리즘의 재귀 process

  1. 전체 데이터를 가지고 A를 초기화한다.

  2. A를 두 부분 A1과 A2로 나누기 위해 분할 알고리즘을 적용한다.

  3. A1과 A2 각각에서 2번 과정을 반복한다.

  4. 분할을 해도 더 이상 하위 분할 영역의 동질성이 개선되지 않을 정도로 충분히 분할을 진행했을 때, 알고리즘을 종료한다.


동질성과 불순도 측정하기


트리 모델링은 분할 영역 A(기록 집합)를 재귀적으로 만드는 과정.

이렇게 만들어진 분할 영역을 통해 Y=0 혹은 Y=1의 결과를 예측한다.

=> 각 분할 영역에 대한 동질성, 즉 클래스 순도 class purity를 측정하는 방법이 필요하다.


해당 파티션 내에서 오분류된 레코드의 비율 p로 예측의 정확도를 표시할 수 있으며,

이는 0 (완전)에서 0.5(순수 랜덤 추측) 사이의 값을 갖는다.


지니 불순도 Gini impurity와 엔트로피 entropy가 대표적인 불순도 측정 지표


Gini impurity I(A) = p(1 - p)

Entropy I(A) = -plog(p) - (1-p)log(1-p)

info <- function(x){
  info <- ifelse(x==0, 0, -x * log2(x) - (1-x) * log2(1-x))
  return(info)
}
x <- 0:50/100
plot(x, info(x) + info(1-x))

gini <- function(x){
  return(x * (1-x))
}
plot(x, gini(x))

impure <- data.frame(p = rep(x, 3),
                     impurity = c(2*x,
                                  gini(x)/gini(.5)*info(.5),
                                  info(x)),
                     type = rep(c('Accuracy', 'Gini', 'Entropy'), rep(51,3)))

ggplot(data=impure, aes(x=p, y=impurity, linetype=type, color=type)) + 
  geom_line(size=1.5) +
  guides( linetype = guide_legend( keywidth=3, override.aes = list(size=1))) +
  scale_x_continuous(expand=c(0,0.01)) + 
  scale_y_continuous(expand=c(0,0.01)) + 
  theme_bw() +
  theme( legend.title=element_blank())


트리 형성 중지하기


잎에서의 순도가 완전히 100%가 될 때까지 다 자란 트리는 학습한 데이터에 대해 100%의 정확도를 갖게 된다.

=> 물론 이 정확도는 오버피팅에 의해 얻은 허황된 것이다.


모델에 새로 들어오는 데이터에 대해 좋은 일반화 성능을 얻기 위해 언제 트리 성장을 멈춰야하는지 결정하는 방법이 필요하다.


가지 분할을 멈추는 대표적인 두 가지 방법이 있다.

rpart의 함수에서 minsplit이나 minbucket 같은 파라미터를 이용해 최소 분할 영역 크기나 말단 잎의 크기를 조절할 수 있다.

rpart 함수에서 트리의 복잡도를 의미하는 복잡도 파라미터 complexity parameter인 cp를 이용해 이를 조절한다. (트리가 복잡할 수록 cp의 값 증가)

=> 다소 임의적이라고 할 수 있는 첫번째 방법은 탐색 작업에서는 유용하지만 최적값 결정은 어려움.

복잡도 파라미터 cp를 이용한 어떤 크기의 트리가 새로운 데이터에 대해 가장 좋은 성능을 보일지 추정할 수 있다.


cp가 매우 작다면 트리는 실제 의미 있는 신호뿐 아니라 노이즈까지 학습하여 오버피팅되는 문제가 발생하게 될 것이다. 반면, cp가 너무 크다면 트리가 너무 작아 예측 능력을 거의 갖지 못할 것이다.


최적의 cp를 결정하는 것은 편향-분산 트레이드오프를 보여주는 하나의 대표적인 예이다.

cp를 추정하는 가장 일반적인 방법은 교차타당성 검정을 이용하는 것이다.

  1. 데이터를 학습용 데이터와 타당성검사용 (holdout) 데이터로 나눈다.

  2. 학습 데이터를 이용해 트리를 키운다.

  3. 트리를 단계적으로 계속해서 가지치기한다. 매 단계마다 학습 데이터를 이용해 cp를 기록한다.

  4. 타당성검사 데이터에 대해 최소 에러(손실)를 보이는 cp를 기록한다.

  5. 데이터를 다시 학습용 데이터와 타당성검사용 데이터로 나누고, 마찬가지로 트리를 만들고 가지치기하고 cp를 기록하는 과정을 반복한다.

  6. 이를 여러 번 반복한 후 각 트리에서 최소 에러를 보이는 cp 값의 평균을 구한다.

  7. 원래 데이터를 이용해 위에서 구한 cp의 최적값을 가지고 트리를 만든다.


연속값 예측하기


트리 모델을 이용해 연속값을 예측하는 방법(회귀분석)은,

각 하위 분할 영역에서 평균으로부터의 편차들을 제곱한 값을 이용해 불순도를 측정하는 점과

제곱근 평균제곱오차 (RMSE)를 이용해 예측 성능을 평가한다는 점에서 차이가 있다.


트리 활용하기

예측변수들 간의 비선형 관계를 담아낼 수 있다.

=> 하지만 예측에 관해서는, 다중 트리에서 나온 결과를 이용하는 것이 단일 트리를 이용하는 것보다 보통은 훨씬 강력하다.

특히 랜덤 포레스트와 부스팅 트리 알고리즘은 거의 항상 우수한 예측 정확도나 성능을 보여준다.


배깅과 랜덤 포레스트


다중 모델의 평균을 취하는 방식(혹은 다수결 투표), 다른 말로는 앙상블 모델은 단일 모델을 사용하는 것보다 더 나은 성능을 보인다.


앙상블 방법의 가장 간단한 버전은 다음과 같다.

  1. 주어진 데이터에 대해 예측 모델을 만들고 예측 결과를 기록한다.

  2. 같은 데이터에 대해 여러 모델을 만들고 결과를 기록한다.

  3. 각 레코드에 대해 예측된 결과들의 평균(또는 가중평균, 다수결 투표)을 구한다.


앙상블 기법은 상대적으로 적은 노력만으로도 좋은 예측 모델을 만들 수 있다는 점에서 정말 파워풀하다.

가장 많이 사용되는 배깅과 부스팅이라는 앙상블 기법이 있다.

=> 이는 트리 모델에 적용될 경우, 랜덤 포레스트와 부스팅 트리가 각각 이에 해당한다.


배깅


배깅이란 부트스트랩 종합 bootstrap aggregating의 줄임말

다양한 모델들을 정확히 같은 데이터에 대해 구하는 대신, 매번 부트스트랩 재표본에 대해 새로운 모델을 만든다.

이 부분만 빼면 앞에서 설명한 기본 앙상블 기법과 동일하다.


랜덤 포레스트


랜덤 포레스트는 의사 결정 트리 모델에 한가지 중요한 요소가 추가된 배깅 방법을 적용한 모델이다. 바로 레코드를 표본추출할 떄, 변수 역시 샘플링하는 것이다.

랜덤 포레스트에서는 알고리즘의 각 단계마다, 고를 수 있는 변수가 랜덤하게 결정된 전체 변수들의 부분집합에 한정된다.

=> 보통은 전체 변수의 개수가 P일 때, root P개 정도를 선택한다.

library(randomForest)
rf <- randomForest(outcome ~ borrower_score + payment_inc_ratio,
                   data = loan3000)
rf
## 
## Call:
##  randomForest(formula = outcome ~ borrower_score + payment_inc_ratio,      data = loan3000) 
##                Type of random forest: classification
##                      Number of trees: 500
## No. of variables tried at each split: 1
## 
##         OOB estimate of  error rate: 38.6%
## Confusion matrix:
##          paid off default class.error
## paid off      964     591   0.3800643
## default       567     878   0.3923875

주머니 외부 out-of-bag(OOB) 추정 에러는 트리 모델을 만들 때 사용했던 학습 데이터에 속하지 않는 데이터를 사용해 구한 학습된 모델의 오차율을 말한다.

error_df = data.frame(error_rate = rf$err.rate[, 'OOB'],
                      num_trees = 1:rf$ntree)

ggplot(error_df, aes(x=num_trees, y=error_rate)) +
  geom_line()

=> 오차율이 0.44에서 0.385 정도가 될 때까지 빠르게 감소한 후 비슷한 수준을 유지하는 것을 볼 수 있다.


predict 함수를 이용해 예측값을 구하고 이를 다음과 같이 그래프로 만들 수 있다.

pred <- predict(rf, prob=TRUE)
rf_df <- cbind(loan3000, pred=pred)
ggplot(data=rf_df, aes(x=borrower_score, y=payment_inc_ratio,
                       color=pred, shape=pred)) +
  geom_point(alpha=.6, size=2) +
  scale_shape_manual(values=c(46, 4)) +
  scale_x_continuous(expand=c(0,0)) + 
  scale_y_continuous(expand=c(0,0), lim=c(0, 20)) + 
  theme_bw()
## Warning: Removed 18 rows containing missing values (geom_point).

랜덤 포레스트는 일종의 블랙박스 모델이다.

단순한 단일 트리보다 훨씬 정확한 예측 성능을 보이지만 간단한 트리를 통해 얻을 수 있는 직관적인 해석은 불가능하다.

또한 예외 사항까지 학습해서 생기는 결과로 인해, 랜덤 포레스트에 의한 오버피팅의 위험성을 보여준다.


변수 중요도

랜덤 포레스트는 피처와 레코드의 개수가 많은 데이터에 대해 예측 모델을 만들 때 장점을 발휘한다.

다수의 예측 변수 중에서 어떤 것이 중요한지, 그리고 이들 사이에 존재하는 상관관계 항들에 대응되는 복잡한 관계들을 자동으로 결정하는 능력이 있다.

rf_all <- randomForest(outcome ~ ., data=loan_data, importance=TRUE)
rf_all
## 
## Call:
##  randomForest(formula = outcome ~ ., data = loan_data, importance = TRUE) 
##                Type of random forest: classification
##                      Number of trees: 500
## No. of variables tried at each split: 4
## 
##         OOB estimate of  error rate: 33.93%
## Confusion matrix:
##          default paid off class.error
## default    15141     7530   0.3321424
## paid off    7853    14818   0.3463897

importance = TRUE 설정은 randomForest 함수에 다른 변수들의 중요도에 관한 정보를 추가적으로 저장하도록 요청한다.

varImpPlot(rf_all, type=1)

imp1 <- importance(rf_all, type=1)
imp2 <- importance(rf_all, type=2)
idx <- order(imp1[,1])
nms <- factor(row.names(imp1)[idx], levels=row.names(imp1)[idx])
imp <- data.frame(Predictor = rep(nms, 2),
                  Importance = c(imp1[idx, 1], imp2[idx, 1]),
                  Type = rep( c('Accuracy Decrease', 'Gini Decrease'), rep(nrow(imp1), 2)))

ggplot(imp) + 
  geom_point(aes(y=Predictor, x=Importance), size=2, stat="identity") + 
  facet_wrap(~Type, ncol=1, scales="free_x") + 
  theme(
    panel.grid.major.x = element_blank() ,
    panel.grid.major.y = element_line(linetype=3, color="darkgray") ) +
  theme_bw()

varImPlot 함수는 변수들의 상대적인 성능을 그래프를 통해 보여준다.


변수 중요도를 측정하는 데에는 두가지 방법이 있다.


하이퍼파라미터


하이퍼파라미터 hyperparameter는 모델을 학습하기 전에 미리 정해야 한다. 이 값들은 학습 과정 중에 최적화되지 않는다. (랜덤 포레스트의 성능을 조절할 수 있는 손잡이가 달린 블랙박스 알고리즘).

특히 오버피팅을 피하기 위해 매우 중요하다.


랜덤 포레스트를 노이즈가 많은 데이터에 적용할 때, 기본 설정으로는 오버피팅에 빠질 수 있다.

nodesize와 maxnodes를 크게 하면 더 작은 트리를 얻게 되고 거짓 예측 규칙들을 만드는 것을 피할 수 있게 된다.


부스팅

선형회귀 모델에서는 피팅이 더 개선될 수 있는지 알아보기 위해 잔차를 종종 사용했다.

부스팅은 바로 이러한 개념을 더 발전시켜, 이전 모델이 갖는 오차를 줄이는 방향으로 다음 모델을 연속적으로 생성한다.

에이다부스트, 그레이디언트 부스팅, 확률적 그레이디언트 부스팅은 가장 자주 사용되는 변형된 형태의 부스팅 알고리즘이다. 이 중에서도 확률적 그레이디언트 부스팅이 일반적으로 가장 널리 사용된다.


부스팅 알고리즘

잘못 분류된 관측치에 대해 가중치를 적용한 합을 의미하는 가중오차가 최소화되도록 학습

=> 모델의 오차가 낮을수록 더 큰 가중치를 부여한다.


XG부스트


XG부스트는 부스팅 방법 가운데 대중적으로 가장 많이 사용되는 오픈소스 소프트웨어

XG부스트에서 가장 중요한 파라미터 두 가지는 subsample과 eta

library(xgboost)
## 
## Attaching package: 'xgboost'
## The following object is masked from 'package:dplyr':
## 
##     slice
predictors <- data.matrix(loan3000[, c('borrower_score','payment_inc_ratio')])
label <- as.numeric(loan3000[, 'outcome'])-1
xgb <- xgboost(data=predictors, label=label,
               objective='binary:logistic',
               params=list(subsample=.63, eta=0.1), nrounds=100)
## [1]  train-error:0.366000 
## [2]  train-error:0.361667 
## [3]  train-error:0.346000 
## [4]  train-error:0.346667 
## [5]  train-error:0.341333 
## [6]  train-error:0.335667 
## [7]  train-error:0.334333 
## [8]  train-error:0.334667 
## [9]  train-error:0.333667 
## [10] train-error:0.326667 
## [11] train-error:0.328000 
## [12] train-error:0.325667 
## [13] train-error:0.323667 
## [14] train-error:0.319667 
## [15] train-error:0.318000 
## [16] train-error:0.318333 
## [17] train-error:0.311667 
## [18] train-error:0.307667 
## [19] train-error:0.304333 
## [20] train-error:0.306000 
## [21] train-error:0.305000 
## [22] train-error:0.301667 
## [23] train-error:0.299333 
## [24] train-error:0.300667 
## [25] train-error:0.301000 
## [26] train-error:0.301000 
## [27] train-error:0.303000 
## [28] train-error:0.301667 
## [29] train-error:0.300000 
## [30] train-error:0.305333 
## [31] train-error:0.302000 
## [32] train-error:0.302333 
## [33] train-error:0.300000 
## [34] train-error:0.298333 
## [35] train-error:0.298667 
## [36] train-error:0.297333 
## [37] train-error:0.296000 
## [38] train-error:0.294333 
## [39] train-error:0.294000 
## [40] train-error:0.293000 
## [41] train-error:0.291667 
## [42] train-error:0.291333 
## [43] train-error:0.291333 
## [44] train-error:0.289333 
## [45] train-error:0.290000 
## [46] train-error:0.290000 
## [47] train-error:0.288333 
## [48] train-error:0.286333 
## [49] train-error:0.283667 
## [50] train-error:0.281000 
## [51] train-error:0.279333 
## [52] train-error:0.278333 
## [53] train-error:0.276333 
## [54] train-error:0.277000 
## [55] train-error:0.275333 
## [56] train-error:0.278667 
## [57] train-error:0.278000 
## [58] train-error:0.272667 
## [59] train-error:0.271333 
## [60] train-error:0.269667 
## [61] train-error:0.268667 
## [62] train-error:0.268333 
## [63] train-error:0.264000 
## [64] train-error:0.264333 
## [65] train-error:0.266000 
## [66] train-error:0.263333 
## [67] train-error:0.263333 
## [68] train-error:0.261000 
## [69] train-error:0.261667 
## [70] train-error:0.260000 
## [71] train-error:0.257667 
## [72] train-error:0.254000 
## [73] train-error:0.253333 
## [74] train-error:0.252667 
## [75] train-error:0.251667 
## [76] train-error:0.248333 
## [77] train-error:0.249333 
## [78] train-error:0.252000 
## [79] train-error:0.254000 
## [80] train-error:0.253333 
## [81] train-error:0.251667 
## [82] train-error:0.248667 
## [83] train-error:0.249000 
## [84] train-error:0.249333 
## [85] train-error:0.245667 
## [86] train-error:0.248667 
## [87] train-error:0.248667 
## [88] train-error:0.246667 
## [89] train-error:0.246000 
## [90] train-error:0.243000 
## [91] train-error:0.243000 
## [92] train-error:0.241000 
## [93] train-error:0.242000 
## [94] train-error:0.243000 
## [95] train-error:0.240667 
## [96] train-error:0.240333 
## [97] train-error:0.240333 
## [98] train-error:0.241000 
## [99] train-error:0.239000 
## [100]    train-error:0.241333
pred <- predict(xgb, newdata=predictors)
xgb_df <- cbind(loan3000, pred_default=pred>.5, prob_default=pred)
ggplot(data=xgb_df, aes(x=borrower_score, y=payment_inc_ratio,
                        color=pred_default, shape=pred_default)) +
  geom_point(alpha=.6, size=2) +
  scale_shape_manual( values=c( 46, 4)) +
  scale_x_continuous(expand=c(.03, 0)) + 
  scale_y_continuous(expand=c(0,0), lim=c(0, 20)) + 
  theme_bw()
## Warning: Removed 18 rows containing missing values (geom_point).


정규화 : 오버피팅 피하기


xgboost 함수를 무작정 사용할 경우, 학습 데이터에 오버피팅되는 불안정한 모델을 얻을 수 있다. 오버피팅은 다음 두 가지 문제를 일으킬 수 있다.


xgboost의 경우를 살펴보자.

predictors <- data.matrix(loan_data[, -which(names(loan_data) %in% 'outcome')])
label <- as.numeric(loan_data$outcome)-1
test_idx <- sample(nrow(loan_data), 10000)
xgb_default <- xgboost(data=predictors[-test_idx,],
                       label=label[-test_idx],
                       objective='binary:logistic', nrounds=250)
## [1]  train-error:0.342284 
## [2]  train-error:0.330909 
## [3]  train-error:0.328985 
## [4]  train-error:0.325477 
## [5]  train-error:0.321459 
## [6]  train-error:0.318997 
## [7]  train-error:0.317413 
## [8]  train-error:0.315234 
## [9]  train-error:0.313253 
## [10] train-error:0.310877 
## [11] train-error:0.310367 
## [12] train-error:0.308217 
## [13] train-error:0.305585 
## [14] train-error:0.304595 
## [15] train-error:0.303378 
## [16] train-error:0.302218 
## [17] train-error:0.300464 
## [18] train-error:0.299021 
## [19] train-error:0.298993 
## [20] train-error:0.298172 
## [21] train-error:0.296503 
## [22] train-error:0.294777 
## [23] train-error:0.294466 
## [24] train-error:0.292145 
## [25] train-error:0.291296 
## [26] train-error:0.291042 
## [27] train-error:0.289457 
## [28] train-error:0.289259 
## [29] train-error:0.288014 
## [30] train-error:0.287137 
## [31] train-error:0.286288 
## [32] train-error:0.285411 
## [33] train-error:0.283798 
## [34] train-error:0.282412 
## [35] train-error:0.279639 
## [36] train-error:0.278224 
## [37] train-error:0.277347 
## [38] train-error:0.276130 
## [39] train-error:0.274999 
## [40] train-error:0.273895 
## [41] train-error:0.273216 
## [42] train-error:0.272056 
## [43] train-error:0.271349 
## [44] train-error:0.270471 
## [45] train-error:0.269538 
## [46] train-error:0.267698 
## [47] train-error:0.267076 
## [48] train-error:0.265095 
## [49] train-error:0.262719 
## [50] train-error:0.261983 
## [51] train-error:0.259946 
## [52] train-error:0.258531 
## [53] train-error:0.256946 
## [54] train-error:0.256409 
## [55] train-error:0.255871 
## [56] train-error:0.255022 
## [57] train-error:0.254117 
## [58] train-error:0.252872 
## [59] train-error:0.252278 
## [60] train-error:0.251684 
## [61] train-error:0.250863 
## [62] train-error:0.249844 
## [63] train-error:0.249222 
## [64] train-error:0.248571 
## [65] train-error:0.247015 
## [66] train-error:0.245600 
## [67] train-error:0.244185 
## [68] train-error:0.243761 
## [69] train-error:0.243195 
## [70] train-error:0.242742 
## [71] train-error:0.241045 
## [72] train-error:0.239800 
## [73] train-error:0.239828 
## [74] train-error:0.239715 
## [75] train-error:0.237961 
## [76] train-error:0.236659 
## [77] train-error:0.236631 
## [78] train-error:0.235074 
## [79] train-error:0.233716 
## [80] train-error:0.232528 
## [81] train-error:0.231934 
## [82] train-error:0.230660 
## [83] train-error:0.229868 
## [84] train-error:0.228708 
## [85] train-error:0.227406 
## [86] train-error:0.226048 
## [87] train-error:0.226048 
## [88] train-error:0.225426 
## [89] train-error:0.224690 
## [90] train-error:0.224634 
## [91] train-error:0.223417 
## [92] train-error:0.222653 
## [93] train-error:0.221804 
## [94] train-error:0.221549 
## [95] train-error:0.220276 
## [96] train-error:0.219739 
## [97] train-error:0.219342 
## [98] train-error:0.218607 
## [99] train-error:0.218409 
## [100]    train-error:0.217984 
## [101]    train-error:0.216117 
## [102]    train-error:0.215947 
## [103]    train-error:0.214928 
## [104]    train-error:0.213825 
## [105]    train-error:0.212495 
## [106]    train-error:0.212212 
## [107]    train-error:0.211392 
## [108]    train-error:0.210628 
## [109]    train-error:0.210147 
## [110]    train-error:0.209609 
## [111]    train-error:0.208619 
## [112]    train-error:0.207147 
## [113]    train-error:0.206553 
## [114]    train-error:0.204997 
## [115]    train-error:0.204969 
## [116]    train-error:0.204233 
## [117]    train-error:0.204063 
## [118]    train-error:0.203752 
## [119]    train-error:0.203695 
## [120]    train-error:0.203384 
## [121]    train-error:0.202960 
## [122]    train-error:0.201517 
## [123]    train-error:0.200555 
## [124]    train-error:0.200300 
## [125]    train-error:0.199310 
## [126]    train-error:0.198489 
## [127]    train-error:0.197923 
## [128]    train-error:0.197697 
## [129]    train-error:0.197499 
## [130]    train-error:0.197442 
## [131]    train-error:0.196282 
## [132]    train-error:0.195716 
## [133]    train-error:0.195094 
## [134]    train-error:0.193254 
## [135]    train-error:0.192349 
## [136]    train-error:0.191217 
## [137]    train-error:0.190765 
## [138]    train-error:0.190623 
## [139]    train-error:0.189972 
## [140]    train-error:0.189152 
## [141]    train-error:0.188812 
## [142]    train-error:0.188303 
## [143]    train-error:0.188133 
## [144]    train-error:0.187539 
## [145]    train-error:0.187199 
## [146]    train-error:0.186492 
## [147]    train-error:0.186294 
## [148]    train-error:0.185898 
## [149]    train-error:0.185417 
## [150]    train-error:0.184596 
## [151]    train-error:0.184200 
## [152]    train-error:0.184002 
## [153]    train-error:0.183408 
## [154]    train-error:0.182248 
## [155]    train-error:0.182078 
## [156]    train-error:0.182078 
## [157]    train-error:0.181116 
## [158]    train-error:0.179786 
## [159]    train-error:0.179447 
## [160]    train-error:0.178966 
## [161]    train-error:0.178966 
## [162]    train-error:0.178315 
## [163]    train-error:0.177664 
## [164]    train-error:0.177409 
## [165]    train-error:0.177240 
## [166]    train-error:0.176221 
## [167]    train-error:0.175089 
## [168]    train-error:0.174891 
## [169]    train-error:0.174834 
## [170]    train-error:0.174325 
## [171]    train-error:0.174184 
## [172]    train-error:0.173816 
## [173]    train-error:0.172882 
## [174]    train-error:0.172345 
## [175]    train-error:0.171892 
## [176]    train-error:0.171467 
## [177]    train-error:0.170788 
## [178]    train-error:0.170477 
## [179]    train-error:0.170138 
## [180]    train-error:0.169798 
## [181]    train-error:0.169657 
## [182]    train-error:0.169091 
## [183]    train-error:0.167987 
## [184]    train-error:0.167138 
## [185]    train-error:0.165893 
## [186]    train-error:0.165497 
## [187]    train-error:0.165186 
## [188]    train-error:0.163715 
## [189]    train-error:0.162753 
## [190]    train-error:0.162554 
## [191]    train-error:0.161564 
## [192]    train-error:0.160574 
## [193]    train-error:0.160319 
## [194]    train-error:0.159866 
## [195]    train-error:0.159753 
## [196]    train-error:0.159470 
## [197]    train-error:0.158395 
## [198]    train-error:0.158225 
## [199]    train-error:0.157348 
## [200]    train-error:0.156499 
## [201]    train-error:0.155509 
## [202]    train-error:0.155169 
## [203]    train-error:0.154349 
## [204]    train-error:0.154377 
## [205]    train-error:0.154292 
## [206]    train-error:0.154066 
## [207]    train-error:0.153925 
## [208]    train-error:0.153528 
## [209]    train-error:0.153161 
## [210]    train-error:0.152623 
## [211]    train-error:0.152340 
## [212]    train-error:0.151717 
## [213]    train-error:0.151152 
## [214]    train-error:0.149963 
## [215]    train-error:0.149963 
## [216]    train-error:0.149143 
## [217]    train-error:0.148294 
## [218]    train-error:0.147417 
## [219]    train-error:0.147275 
## [220]    train-error:0.146681 
## [221]    train-error:0.146766 
## [222]    train-error:0.145945 
## [223]    train-error:0.145238 
## [224]    train-error:0.145493 
## [225]    train-error:0.144814 
## [226]    train-error:0.144559 
## [227]    train-error:0.144276 
## [228]    train-error:0.143569 
## [229]    train-error:0.143059 
## [230]    train-error:0.142720 
## [231]    train-error:0.142522 
## [232]    train-error:0.141248 
## [233]    train-error:0.140569 
## [234]    train-error:0.140513 
## [235]    train-error:0.140456 
## [236]    train-error:0.140371 
## [237]    train-error:0.139862 
## [238]    train-error:0.139466 
## [239]    train-error:0.139353 
## [240]    train-error:0.138787 
## [241]    train-error:0.138277 
## [242]    train-error:0.137768 
## [243]    train-error:0.137655 
## [244]    train-error:0.136580 
## [245]    train-error:0.136438 
## [246]    train-error:0.136014 
## [247]    train-error:0.135618 
## [248]    train-error:0.135306 
## [249]    train-error:0.135137 
## [250]    train-error:0.134260
pred_default <- predict(xgb_default, predictors[test_idx,])
error_default <- abs(label[test_idx] - pred_default) > 0.5
xgb_default$evaluation_log[250,]
##    iter train_error
## 1:  250     0.13426
mean(error_default)
## [1] 0.3438

부스팅 모델을 학습한 결과, 학습 데이터에 대한 오차율은 12.72%였지만, 테스트 데이터에 대한 오차율은 그것보다 훨씬 높은 수치로 나타난다 (mean(error_default)).


정규화 regularization 방법을 이용하면 모델의 복잡도에 따라 벌점을 추가하는 형태로 비용함수를 변경할 수 있다.

의사 결정 트리에서는 지니 불순도와 같은 비용 기준값을 최소화하는 쪽으로 모델을 피팅했다.


xgboost에서는 모델을 정규화하기 위한 두 파라미터 alpha와 lambda가 존재한다.

이 파라미터들을 크게 하면, 모델이 복잡해질수록 더 많은 벌점을 부여하게 되고 결과적으로 얻어지는 트리의 크기가 작아지게 된다.

xgb_penalty <- xgboost(data=predictors[-test_idx,],
                       label=label[-test_idx],
                       params=list(eta=.1, subsample=.63, lambda=1000),
                       objective='binary:logistic', nrounds=250)
## [1]  train-error:0.342029 
## [2]  train-error:0.340049 
## [3]  train-error:0.339624 
## [4]  train-error:0.339737 
## [5]  train-error:0.341463 
## [6]  train-error:0.341803 
## [7]  train-error:0.341294 
## [8]  train-error:0.340332 
## [9]  train-error:0.339285 
## [10] train-error:0.335974 
## [11] train-error:0.336484 
## [12] train-error:0.335974 
## [13] train-error:0.335663 
## [14] train-error:0.335154 
## [15] train-error:0.335040 
## [16] train-error:0.334305 
## [17] train-error:0.334503 
## [18] train-error:0.334842 
## [19] train-error:0.334135 
## [20] train-error:0.333852 
## [21] train-error:0.333513 
## [22] train-error:0.333286 
## [23] train-error:0.332777 
## [24] train-error:0.332494 
## [25] train-error:0.331787 
## [26] train-error:0.332494 
## [27] train-error:0.333116 
## [28] train-error:0.332268 
## [29] train-error:0.331787 
## [30] train-error:0.331475 
## [31] train-error:0.330825 
## [32] train-error:0.330994 
## [33] train-error:0.331079 
## [34] train-error:0.330768 
## [35] train-error:0.330485 
## [36] train-error:0.329976 
## [37] train-error:0.329240 
## [38] train-error:0.329466 
## [39] train-error:0.329155 
## [40] train-error:0.329014 
## [41] train-error:0.329099 
## [42] train-error:0.328872 
## [43] train-error:0.328872 
## [44] train-error:0.328108 
## [45] train-error:0.328023 
## [46] train-error:0.327910 
## [47] train-error:0.327627 
## [48] train-error:0.327995 
## [49] train-error:0.327769 
## [50] train-error:0.327316 
## [51] train-error:0.327373 
## [52] train-error:0.327005 
## [53] train-error:0.326778 
## [54] train-error:0.326439 
## [55] train-error:0.326948 
## [56] train-error:0.326665 
## [57] train-error:0.326043 
## [58] train-error:0.325873 
## [59] train-error:0.326382 
## [60] train-error:0.326071 
## [61] train-error:0.326184 
## [62] train-error:0.325901 
## [63] train-error:0.326269 
## [64] train-error:0.326128 
## [65] train-error:0.326014 
## [66] train-error:0.325562 
## [67] train-error:0.325533 
## [68] train-error:0.325222 
## [69] train-error:0.325505 
## [70] train-error:0.325420 
## [71] train-error:0.325392 
## [72] train-error:0.325335 
## [73] train-error:0.325024 
## [74] train-error:0.324996 
## [75] train-error:0.324769 
## [76] train-error:0.324883 
## [77] train-error:0.324656 
## [78] train-error:0.323921 
## [79] train-error:0.323411 
## [80] train-error:0.323751 
## [81] train-error:0.323751 
## [82] train-error:0.323638 
## [83] train-error:0.323524 
## [84] train-error:0.323468 
## [85] train-error:0.323751 
## [86] train-error:0.324034 
## [87] train-error:0.324119 
## [88] train-error:0.323807 
## [89] train-error:0.324062 
## [90] train-error:0.324147 
## [91] train-error:0.323864 
## [92] train-error:0.323977 
## [93] train-error:0.323722 
## [94] train-error:0.323836 
## [95] train-error:0.323355 
## [96] train-error:0.323524 
## [97] train-error:0.323383 
## [98] train-error:0.323270 
## [99] train-error:0.323128 
## [100]    train-error:0.323157 
## [101]    train-error:0.323241 
## [102]    train-error:0.323157 
## [103]    train-error:0.322902 
## [104]    train-error:0.323326 
## [105]    train-error:0.323100 
## [106]    train-error:0.323015 
## [107]    train-error:0.323213 
## [108]    train-error:0.323128 
## [109]    train-error:0.323213 
## [110]    train-error:0.322817 
## [111]    train-error:0.323213 
## [112]    train-error:0.322987 
## [113]    train-error:0.322619 
## [114]    train-error:0.322534 
## [115]    train-error:0.322506 
## [116]    train-error:0.322506 
## [117]    train-error:0.322449 
## [118]    train-error:0.322732 
## [119]    train-error:0.322647 
## [120]    train-error:0.322619 
## [121]    train-error:0.321940 
## [122]    train-error:0.321996 
## [123]    train-error:0.321940 
## [124]    train-error:0.321968 
## [125]    train-error:0.321883 
## [126]    train-error:0.321827 
## [127]    train-error:0.321402 
## [128]    train-error:0.321346 
## [129]    train-error:0.321063 
## [130]    train-error:0.321091 
## [131]    train-error:0.320893 
## [132]    train-error:0.320752 
## [133]    train-error:0.320553 
## [134]    train-error:0.320808 
## [135]    train-error:0.320808 
## [136]    train-error:0.320667 
## [137]    train-error:0.320214 
## [138]    train-error:0.320327 
## [139]    train-error:0.320440 
## [140]    train-error:0.320129 
## [141]    train-error:0.320214 
## [142]    train-error:0.320497 
## [143]    train-error:0.320327 
## [144]    train-error:0.320214 
## [145]    train-error:0.320214 
## [146]    train-error:0.320044 
## [147]    train-error:0.319903 
## [148]    train-error:0.319959 
## [149]    train-error:0.319761 
## [150]    train-error:0.319393 
## [151]    train-error:0.319450 
## [152]    train-error:0.319280 
## [153]    train-error:0.319337 
## [154]    train-error:0.319224 
## [155]    train-error:0.319054 
## [156]    train-error:0.319110 
## [157]    train-error:0.318601 
## [158]    train-error:0.318856 
## [159]    train-error:0.318743 
## [160]    train-error:0.318686 
## [161]    train-error:0.318743 
## [162]    train-error:0.318516 
## [163]    train-error:0.318714 
## [164]    train-error:0.318375 
## [165]    train-error:0.318346 
## [166]    train-error:0.318460 
## [167]    train-error:0.318092 
## [168]    train-error:0.317979 
## [169]    train-error:0.317979 
## [170]    train-error:0.317837 
## [171]    train-error:0.317384 
## [172]    train-error:0.317300 
## [173]    train-error:0.317356 
## [174]    train-error:0.317639 
## [175]    train-error:0.317498 
## [176]    train-error:0.317526 
## [177]    train-error:0.317130 
## [178]    train-error:0.317186 
## [179]    train-error:0.317243 
## [180]    train-error:0.317271 
## [181]    train-error:0.317300 
## [182]    train-error:0.317158 
## [183]    train-error:0.317215 
## [184]    train-error:0.316677 
## [185]    train-error:0.316677 
## [186]    train-error:0.316677 
## [187]    train-error:0.316224 
## [188]    train-error:0.316111 
## [189]    train-error:0.315941 
## [190]    train-error:0.315970 
## [191]    train-error:0.316139 
## [192]    train-error:0.315828 
## [193]    train-error:0.315743 
## [194]    train-error:0.315630 
## [195]    train-error:0.315828 
## [196]    train-error:0.315602 
## [197]    train-error:0.315687 
## [198]    train-error:0.315743 
## [199]    train-error:0.315687 
## [200]    train-error:0.315574 
## [201]    train-error:0.315517 
## [202]    train-error:0.315856 
## [203]    train-error:0.315574 
## [204]    train-error:0.316111 
## [205]    train-error:0.315574 
## [206]    train-error:0.315687 
## [207]    train-error:0.315262 
## [208]    train-error:0.315234 
## [209]    train-error:0.314979 
## [210]    train-error:0.314612 
## [211]    train-error:0.315008 
## [212]    train-error:0.314866 
## [213]    train-error:0.314979 
## [214]    train-error:0.314894 
## [215]    train-error:0.314866 
## [216]    train-error:0.314810 
## [217]    train-error:0.314753 
## [218]    train-error:0.314810 
## [219]    train-error:0.314442 
## [220]    train-error:0.314527 
## [221]    train-error:0.314470 
## [222]    train-error:0.314272 
## [223]    train-error:0.313961 
## [224]    train-error:0.314159 
## [225]    train-error:0.314046 
## [226]    train-error:0.314159 
## [227]    train-error:0.314017 
## [228]    train-error:0.313904 
## [229]    train-error:0.313989 
## [230]    train-error:0.313706 
## [231]    train-error:0.313508 
## [232]    train-error:0.313197 
## [233]    train-error:0.313310 
## [234]    train-error:0.313197 
## [235]    train-error:0.313225 
## [236]    train-error:0.313168 
## [237]    train-error:0.313197 
## [238]    train-error:0.313084 
## [239]    train-error:0.312942 
## [240]    train-error:0.313027 
## [241]    train-error:0.312970 
## [242]    train-error:0.313112 
## [243]    train-error:0.313055 
## [244]    train-error:0.312801 
## [245]    train-error:0.312886 
## [246]    train-error:0.312518 
## [247]    train-error:0.312801 
## [248]    train-error:0.312574 
## [249]    train-error:0.312461 
## [250]    train-error:0.312348
pred_penalty <- predict(xgb_penalty, predictors[test_idx,])
error_penalty <- abs(label[test_idx] - pred_penalty) > 0.5
xgb_penalty$evaluation_log[250,]
##    iter train_error
## 1:  250    0.312348
mean(error_penalty)
## [1] 0.3249

ntreelimit 를 이용할 때, 예측을 위해 사용하는 모델의 개수에 따른 표본 내 오차율과 표본 밖 오차율을 더 쉽게 비교할 수 있다.

error_default <- rep(0, 250)
error_penalty <- rep(0, 250)

for(i in 1:250){
  pred_def <- predict(xgb_default, predictors[test_idx,], ntreelimit=i)
  error_default[i] <- mean(abs(label[test_idx] - pred_def) >= 0.5)
  pred_pen <- predict(xgb_penalty, predictors[test_idx,], ntreelimit=i)
  error_penalty[i] <- mean(abs(label[test_idx] - pred_pen) >= 0.5)
}
errors <- rbind(xgb_default$evaluation_log,
                xgb_penalty$evaluation_log,
                data.frame(iter=1:250, train_error=error_default),
                data.frame(iter=1:250, train_error=error_penalty))

errors$type <- rep(c('default train', 'penalty train',
                     'default test', 'penalty test'), rep(250, 4))

ggplot(errors, aes(x=iter, y=train_error, group=type)) +
  geom_line(aes(linetype=type, color=type))

==> 기본 모형은 정확도가 학습 데이터에 대해서는 꾸준히 좋아지지만 테스트 데이터에 대해서는 나빠진다. 하지만 벌점을 추가한 모형에서는 그렇지 않다.


하이퍼파라미터와 교차타당성검사


교차타당성검사는 일단 데이터를 K개의 서로 다른 그룹(fold)으로 랜덤하게 나눈다. 각 폴드마다 해당 폴드에 속한 데이터를 제외한 나머지 데이터를 가지고 모델을 학습한 후, 폴드에 속한 데이터를 이용해 모델을 평가한다.

이는 결국 표본 밖 데이터에 대한 모델의 성능을 보여준다. 전체적으로 가장 낮은 평균 오차를 갖는 최적의 하이퍼파라미터 조합을 찾는다.

N <- nrow(loan_data)
fold_number <- sample(1:5, N, replace=TRUE)
params <- data.frame(eta = rep(c(.1, .5, .9), 3),
                     max_depth = rep(c(3, 6, 12), rep(3, 3)))

이제 각 5개의 fold를 이용해 각 모델에 대한 오차를 계산한다.

error <- matrix(0, nrow=9, ncol=5)
for(i in 1:nrow(params)){
  for(k in 1:5){
    fold_idx <- (1:N)[fold_number == k]
    xgb <- xgboost(data=predictors[-fold_idx,], label=label[-fold_idx],
                   params=list(eta = params[i, 'eta'],
                               max_depth = params[i, 'max_depth']),
                   objective = 'binary:logistic', nrounds=100, verbose=0)
    pred <- predict(xgb, predictors[fold_idx,])
    error[i, k] <- mean(abs(label[fold_idx] - pred) >= 0.5)
  }
}
avg_error <- 100 * rowMeans(error)
cbind(params, avg_error)
##   eta max_depth avg_error
## 1 0.1         3  32.99163
## 2 0.5         3  33.42027
## 3 0.9         3  34.41122
## 4 0.1         6  33.12896
## 5 0.5         6  35.23196
## 6 0.9         6  37.77478
## 7 0.1        12  34.55011
## 8 0.5        12  37.14070
## 9 0.9        12  38.05394

교차타당성검사를 통해 eta 값이 작으면서 깊이가 얕은 트리를 사용하는 것이 좀 더 정확한 성능을 보인다는 사실을 알게 되었다.

최적의 파라미터는 eta=0.1, max_depth=3 (또는 max_depth=6)이라고 할 수 있다.