0. 목적

이 코드는 Titanic호에 승선했던 사람들의 데이터셋을 이용하여 데이터 탐색(EDA; Explorer Data Analysis)을 수행하고, 머신러닝(randomForest)을 이용하여 생존율을 예측한 후, kaggle에 제출하는 것이다.
1. 데이터: Titanic 데이터 셋
2. 데이터 탐색을 통해 통찰력(insight) 수행
3. 머신러닝(randomForester package)으로 생존율 예측 4. kaggle에 예측 데이터 제출

Data Dictionary

  1. PassengerId = col_double()
  2. survived : 생존=1, 죽음=0
  3. pclass : 승객 등급. 1등급=1, 2등급=2, 3등급=3
  4. Name = col_character()
  5. Sex = col_character()
  6. Age = col_double()
  7. sibsp : 함께 탑승한 형제 또는 배우자 수
  8. parch : 함께 탑승한 부모 또는 자녀 수
  9. ticket : 티켓 번호
  10. Fare = col_double()
  11. cabin : 선실 번호
  12. embarked : 탑승장소 S=Southhampton, C=Cherbourg, Q=Queenstown

1. 데이터의 구조 확인

library(dplyr)
library(readr)
train <- read_csv("train.csv")

str(train)
tibble [891 x 12] (S3: spec_tbl_df/tbl_df/tbl/data.frame)
 $ PassengerId: num [1:891] 1 2 3 4 5 6 7 8 9 10 ...
 $ Survived   : num [1:891] 0 1 1 1 0 0 0 0 1 1 ...
 $ Pclass     : num [1:891] 3 1 3 1 3 3 1 3 3 2 ...
 $ Name       : chr [1:891] "Braund, Mr. Owen Harris" "Cumings, Mrs. John Bradley (Florence Briggs Thayer)" "Heikkinen, Miss. Laina" "Futrelle, Mrs. Jacques Heath (Lily May Peel)" ...
 $ Sex        : chr [1:891] "male" "female" "female" "female" ...
 $ Age        : num [1:891] 22 38 26 35 35 NA 54 2 27 14 ...
 $ SibSp      : num [1:891] 1 1 0 1 0 0 0 3 0 1 ...
 $ Parch      : num [1:891] 0 0 0 0 0 0 0 1 2 0 ...
 $ Ticket     : chr [1:891] "A/5 21171" "PC 17599" "STON/O2. 3101282" "113803" ...
 $ Fare       : num [1:891] 7.25 71.28 7.92 53.1 8.05 ...
 $ Cabin      : chr [1:891] NA "C85" NA "C123" ...
 $ Embarked   : chr [1:891] "S" "C" "S" "S" ...
 - attr(*, "spec")=
  .. cols(
  ..   PassengerId = col_double(),
  ..   Survived = col_double(),
  ..   Pclass = col_double(),
  ..   Name = col_character(),
  ..   Sex = col_character(),
  ..   Age = col_double(),
  ..   SibSp = col_double(),
  ..   Parch = col_double(),
  ..   Ticket = col_character(),
  ..   Fare = col_double(),
  ..   Cabin = col_character(),
  ..   Embarked = col_character()
  .. )

2. 각 컬럼의 범주(factor) 확인:

PassengerId, Name, Age, Ticket, Fare, Cabin을 제외하고, 나머지 컬럼의 범주를 일괄적으로 확인.

train %>% select(-c(1,4,6,9,10,11)) %>% sapply(table)
$Survived

  0   1 
549 342 

$Pclass

  1   2   3 
216 184 491 

$Sex

female   male 
   314    577 

$SibSp

  0   1   2   3   4   5   8 
608 209  28  16  18   5   7 

$Parch

  0   1   2   3   4   5   6 
678 118  80   5   4   5   1 

$Embarked

  C   Q   S 
168  77 644 

3. 결측치 확인: 결측치를 어떻게 처리할 것인지는 데이터 탐색을 더 해본 후 결정

train %>% sapply(is.na) %>% colSums()
PassengerId    Survived      Pclass        Name         Sex         Age 
          0           0           0           0           0         177 
      SibSp       Parch      Ticket        Fare       Cabin    Embarked 
          0           0           0           0         687           2 

4. 그래프로 데이터 탐색

library(ggplot2)
library(plotly) # 대화 형 웹 그래픽 

4.1 남,녀 생존자 비율

rate<-train %>% group_by(Sex, Survived) %>% 
  summarise(n=n()) %>% 
  mutate(rate=n/sum(n))
rate
# A tibble: 4 x 4
# Groups:   Sex [2]
  Sex    Survived     n  rate
  <chr>     <dbl> <int> <dbl>
1 female        0    81 0.258
2 female        1   233 0.742
3 male          0   468 0.811
4 male          1   109 0.189
p<-rate %>% ggplot()+
  aes(x = Sex, y = rate, fill=as.factor(Survived))+
  geom_bar(position = "dodge", stat = 'identity')+
  labs(title="성별 사망&생존-비율-사망:0, 생존:1", fill="Survived")+
  theme(plot.title = element_text(hjust = 0.5))
ggplotly(p)
p<-train %>% group_by(Survived) %>% 
  ggplot()+
  aes(x = Sex, fill=as.factor(Survived))+
  geom_bar(position = "stack", show.legend = T, stat = 'count') + 
  labs(title="성별 사망&생존 수-사망:0, 생존:1", fill="Survived")+
  theme(plot.title = element_text(hjust = 0.5))
ggplotly(p)

소결론: 여자의 생존률이 훨씬 높음

4.2 클라스별 생존자 비율

rate<-train %>% group_by(Pclass, Survived) %>% 
  summarise(total=n()) %>% 
  mutate(rate=total/sum(total))
rate
# A tibble: 6 x 4
# Groups:   Pclass [3]
  Pclass Survived total  rate
   <dbl>    <dbl> <int> <dbl>
1      1        0    80 0.370
2      1        1   136 0.630
3      2        0    97 0.527
4      2        1    87 0.473
5      3        0   372 0.758
6      3        1   119 0.242
p<-rate %>% ggplot()+
  aes(x = Pclass, y = rate, fill=as.factor(Survived))+
  geom_bar(position = "dodge", stat = 'identity')+
  labs(title="등급별 사망&생존율-사망:0, 생존:1", fill="Survived")+
  theme(plot.title = element_text(hjust = 0.5))
ggplotly(p)
train %>% group_by(Pclass, Survived) %>% 
  ggplot()+
  aes(x = Pclass, fill=as.factor(Survived))+
  geom_bar(position = "stack", show.legend = T, stat = 'count')+
  labs(title="등급별 사망&생존 수-사망:0, 생존:1", fill="Survived")+
  theme(plot.title = element_text(hjust = 0.5))

소결론: 1등급의 생존률이 훨씬 높고, 3등급은 사망률이 훨씬 높음

4.3 나이대별 생존율

rate<-train %>% 
  transform(age_group=cut(Age,
                    breaks = c(0,10,20,30,40,50,60,70,80))) %>% 
  group_by(age_group, Survived) %>% 
  summarise(total=n()) %>% 
  mutate(rate=total/sum(total))
rate
# A tibble: 18 x 4
# Groups:   age_group [9]
   age_group Survived total  rate
   <fct>        <dbl> <int> <dbl>
 1 (0,10]           0    26 0.406
 2 (0,10]           1    38 0.594
 3 (10,20]          0    71 0.617
 4 (10,20]          1    44 0.383
 5 (20,30]          0   146 0.635
 6 (20,30]          1    84 0.365
 7 (30,40]          0    86 0.555
 8 (30,40]          1    69 0.445
 9 (40,50]          0    53 0.616
10 (40,50]          1    33 0.384
11 (50,60]          0    25 0.595
12 (50,60]          1    17 0.405
13 (60,70]          0    13 0.765
14 (60,70]          1     4 0.235
15 (70,80]          0     4 0.8  
16 (70,80]          1     1 0.2  
17 <NA>             0   125 0.706
18 <NA>             1    52 0.294
p<-rate %>% ggplot()+
  aes(x = age_group, y = rate, fill=as.factor(Survived))+
  geom_bar(position = "dodge", stat = 'identity')+
  labs(title="연령대별 사망&생존율-사망:0, 생존:1", fill="Survived")+
  theme(plot.title = element_text(hjust = 0.5))
ggplotly(p)
p<-train %>% 
  transform(age_group=cut(Age,
                    breaks = c(0,10,20,30,40,50,60,70,80))) %>% 
  ggplot()+
  aes(age_group, group=as.factor(Survived), fill=as.factor(Survived))+
  geom_bar(position = "dodge")+
  labs(title="연령대별 사망&생존수-사망:0, 생존:1", fill="Survived")+
  theme(plot.title = element_text(hjust = 0.5))
ggplotly(p)
Warning: `group_by_()` is deprecated as of dplyr 0.7.0.
Please use `group_by()` instead.
See vignette('programming') for more help
This warning is displayed once every 8 hours.
Call `lifecycle::last_warnings()` to see where this warning was generated.
train %>% 
  ggplot()+
  aes(Age, group=as.factor(Survived), fill=as.factor(Survived))+
  geom_density(alpha=0.3)+
  labs(title="사망:0, 생존:1", fill="Survived")+
  theme(plot.title = element_text(hjust = 0.5))
Warning: Removed 177 rows containing non-finite values (stat_density).

train %>% 
  ggplot()+
  aes(Age, group=as.factor(Survived), fill=as.factor(Survived))+
  geom_histogram(position = "dodge", bins = 50)+
  labs(title="사망:0, 생존:1", fill="Survived")+
  theme(plot.title = element_text(hjust = 0.5))
Warning: Removed 177 rows containing non-finite values (stat_bin).

train %>% ggplot()+
  aes(Age, group=as.factor(Survived), fill=as.factor(Survived))+
  geom_boxplot()+
  labs(title="사망:0, 생존:1", fill="Survived")+
  theme(plot.title = element_text(hjust = 0.5))
Warning: Removed 177 rows containing non-finite values (stat_boxplot).

train %>% ggplot()+
  aes(x = Age, group=Sex, fill=Sex)+
  geom_density(alpha=0.3)
Warning: Removed 177 rows containing non-finite values (stat_density).

train %>% ggplot()+
  aes(Age, group=Sex, fill=Sex)+
  geom_boxplot()
Warning: Removed 177 rows containing non-finite values (stat_boxplot).

train$Age %>% summary()
   Min. 1st Qu.  Median    Mean 3rd Qu.    Max.    NA's 
   0.42   20.12   28.00   29.70   38.00   80.00     177 

소결론: 10세 이하는 생존율이 훨씬 높고, 또 30~40세가 생존율이 다른 세대에 비해 높은 것으로 봐서 아마 10세 이하의 부모로 판단됨. 50세 이상은 사망율이 매우 높음.

4.4 탑승 선착장에 따른 생존률 분석

rate<-train %>% group_by(Embarked, Survived) %>% 
  summarise(total=n()) %>% 
  mutate(rate=total/sum(total))
rate
# A tibble: 7 x 4
# Groups:   Embarked [4]
  Embarked Survived total  rate
  <chr>       <dbl> <int> <dbl>
1 C               0    75 0.446
2 C               1    93 0.554
3 Q               0    47 0.610
4 Q               1    30 0.390
5 S               0   427 0.663
6 S               1   217 0.337
7 <NA>            1     2 1    
p<-rate %>% ggplot()+
  aes(x = Embarked, y = rate, fill=as.factor(Survived))+
  geom_bar(position = "dodge", stat = 'identity')+
  labs(title="선착장 별 사망&생존율-사망:0, 생존:1", fill="Survived")+
  theme(plot.title = element_text(hjust = 0.5))
ggplotly(p)
train %>% ggplot()+
  aes(Embarked, group=as.factor(Survived), fill=as.factor(Survived))+
  geom_bar(position="dodge", na.rm = T)

소결론: C 선착장은 생존율이 더 높은데 비해서 S 선착장은 사망률이 훨씬 낮음.

4.4.1 이유를 알아보기 위해서 선착장과 요금, Class의 관계를 좀 더 살펴보자.

train %>% ggplot()+
  aes(x = Pclass, y = Fare, group=Embarked, col=Embarked)+
  geom_jitter()

소결론: 이 그림을 보니 이해가 됨. 요금이 낮고, 등급이 낮은 승객의 다수가 S 선착장에서 탑승했음.

4.5 요금과 생존률 탐색(Fare vs Survived)

train %>% 
  ggplot()+
  aes(x = Age, y = Fare, group=Survived, fill=as.factor(Survived), col=as.factor(Survived))+
  geom_jitter()+
  labs(title="요금 별 사망&생존율-사망:0, 생존:1")
Warning: Removed 177 rows containing missing values (geom_point).

#### 소결론: 요금이 낮은 사람들이 훨씬 더 많이 사망했음.

4.6 상관분석: Fare, Age, Sex, Pclass와 Survived의 상관관계

train %>% select(Pclass, Survived) %>% cor()
            Pclass  Survived
Pclass    1.000000 -0.338481
Survived -0.338481  1.000000
train %>% 
  mutate(Sex2=case_when(Sex=='male' ~ 1,Sex=='female' ~ 0),
         Age2=case_when(Age=is.na(Age) ~ mean(Age, na.rm = T),
                        TRUE ~ Age),
         Embarked2=case_when(Embarked=='C' ~ 1,
                             Embarked=='Q' ~ 2,
                             Embarked=='S' ~ 3,
                             TRUE ~ 0)) %>% 
  select(Survived, Fare, Pclass, Sex2, Embarked, Age2) %>% 
  psych::pairs.panels()

소결론: 이 그림을 보니 확실히 생존율이 Sex, Pclass, Fare와 큰 상관이 있음. 나이와는 그리 큰 상관은 없음.

결론 종합: 여기까지의 결론은 가난한(3등급) 남자(male)의 죽을 확률이 매우 높다고 할 수 있다.

5.이제 결측치(Age, Embarked)를 채워넣고 다시 탐색

train %>% sapply(is.na) %>% colSums()
PassengerId    Survived      Pclass        Name         Sex         Age 
          0           0           0           0           0         177 
      SibSp       Parch      Ticket        Fare       Cabin    Embarked 
          0           0           0           0         687           2 

Age의 경우 결측치의 값을 평균,분산을 이용하여 생성 후 대입하는 방법을 아래와 같이 시도해 봤으나, 이렇게 생성하고 보니까 음수 값(-값)이 나와서 안됨. 그리고, 의미도 없어 보이고.

# test$Age[which(is.na(test$Age))]<- 
#   rnorm(n =length(which(is.na(test$Age))),
#         mean=mean(test$Age, na.rm = T),
#         sd=sd(test$Age, na.rm = T))

Age 값을 density, median=28, mean=29.7, IQR=18, mode=24 등으로, 분석해보니, Age의 NA 값은 최빈값(mode=24)를 넣는것이 가장 좋을 것으로 판단됨.

train$Age[is.na(train$Age)]<- train$Age %>% table() %>% which.max() %>% names()

Embarked NA 값은 단 2개인데, 1개는 C, 1개는 S로 넣음.

train$Embarked %>% table()
.
  C   Q   S 
168  77 644 
train$Embarked[is.na(train$Embarked)]<-c('C','S')

이렇게 바꾸고, 다시 상관관계를 확인해 보았으나, 거의 차이가 없음.

train %>% 
  mutate(Sex2=case_when(Sex=='male' ~ 1,Sex=='female' ~ 0),
         Embarked2=case_when(Embarked=='C' ~ 1,
                             Embarked=='Q' ~ 2,
                             Embarked=='S' ~ 3,
                             TRUE ~ 0)) %>% 
  select(Survived, Fare, Pclass, Sex2, Embarked, Age) %>% 
  psych::pairs.panels()

EDA 최종 결론: 결측치를 채우나 혹은 채우지 않으나 거의 차이가 없음. 따라서 최종 결론은 “가난한(3등급) 남자(male)은 죽을 확률이 매우 높다고 할 수 있다.”

6. 머신러닝으로 생존확률 구하기. 생존 확률을 randomForest 머신러닝 패키지로 구해보자.

library(randomForest)
library(caret) #성능 평가시 confusion matrix를 그리기 위한 라이브러리
library(e1071) #성능 평가시 confusion matrix를 그리기 위한 라이브러리
#출처: https://3months.tistory.com/215 [Deep Play]  

6.1 필요 없는 컬럼(Name, Ticket, Cabin)을 제외하고 다시 구성.

train<-train[,-c(4,9,11)]

6.2 예측할(predictor) Survived는 범주형이므로 factor로 변환

train$Survived<-as.factor(train$Survived)
# 다른 컬럼들도 범주형으로 변환해서 테스트 해 보았으나 차이가 없음.
# train$Pclass<-as.factor(train$Pclass)
# train$Sex<-as.factor(train$Sex)
# train$Embarked<-as.factor(train$Embarked)

6.3 모델 훈련: 다양한 옵션이 있음.

fit<-randomForest(Survived ~ .-PassengerId,
                  data = train,
                  importance = T) # 예측 변수의 중요성을 평가 여부 

6.4 결과 확인

fit

Call:
 randomForest(formula = Survived ~ . - PassengerId, data = train,      importance = T) 
               Type of random forest: classification
                     Number of trees: 500
No. of variables tried at each split: 2

        OOB estimate of  error rate: 17.73%
Confusion matrix:
    0   1 class.error
0 499  50  0.09107468
1 108 234  0.31578947

결과를 보면 약 18.18%의 에러율(error rate) 과 confusion matrix 값을 보여준다. 즉, 이 모델이 약 88%의 정확도(accuracy)를 보여준다고 하겠다.

plot(fit, type="l") # 이 그림의 의미는 잘 모르겠음.

6.5 어떤 변수(feature)가 가장 중요한지 확인. 숫자가 클수록 가장 큰 영향을 미치는 변수임

varImp(fit, scale=T)  # caret package의 함수. 훈련된 개체의 변수의 중요도  
                 0         1
Pclass   25.386378 25.386378
Sex      73.216621 73.216621
Age      11.687546 11.687546
SibSp     9.654512  9.654512
Parch    11.494003 11.494003
Fare     21.046498 21.046498
Embarked 10.162607 10.162607
varImpPlot(fit, color='red') # randomForest package 함수.

importance(fit,              # 변수 중요도 측정 추출
           type = NULL)      # 1=mean decrease in accuracy, 2=mean decrease in node impurity)
                 0         1 MeanDecreaseAccuracy MeanDecreaseGini
Pclass   20.051295 30.721462             38.21768         33.97455
Sex      59.270765 87.162477             86.11129        104.45648
Age       7.885424 15.489668             17.77382         47.72601
SibSp    14.857460  4.451563             16.26461         15.02815
Parch    15.233187  7.754819             17.79139         14.17742
Fare     19.179458 22.913537             31.87635         66.58260
Embarked  8.270351 12.054864             14.77798         11.15142

이상과 같이 가장 큰 영향을 미치는 변수를 확인했다.

6.6 예측 수행: 이제 test 데이터에 대해 예측을 수행해보자.

test<-read_csv('test.csv')
test %>% str()
tibble [418 x 11] (S3: spec_tbl_df/tbl_df/tbl/data.frame)
 $ PassengerId: num [1:418] 892 893 894 895 896 897 898 899 900 901 ...
 $ Pclass     : num [1:418] 3 3 2 3 3 3 3 2 3 3 ...
 $ Name       : chr [1:418] "Kelly, Mr. James" "Wilkes, Mrs. James (Ellen Needs)" "Myles, Mr. Thomas Francis" "Wirz, Mr. Albert" ...
 $ Sex        : chr [1:418] "male" "female" "male" "male" ...
 $ Age        : num [1:418] 34.5 47 62 27 22 14 30 26 18 21 ...
 $ SibSp      : num [1:418] 0 1 0 0 1 0 0 1 0 2 ...
 $ Parch      : num [1:418] 0 0 0 0 1 0 0 1 0 0 ...
 $ Ticket     : chr [1:418] "330911" "363272" "240276" "315154" ...
 $ Fare       : num [1:418] 7.83 7 9.69 8.66 12.29 ...
 $ Cabin      : chr [1:418] NA NA NA NA ...
 $ Embarked   : chr [1:418] "Q" "S" "Q" "S" ...
 - attr(*, "spec")=
  .. cols(
  ..   PassengerId = col_double(),
  ..   Pclass = col_double(),
  ..   Name = col_character(),
  ..   Sex = col_character(),
  ..   Age = col_double(),
  ..   SibSp = col_double(),
  ..   Parch = col_double(),
  ..   Ticket = col_character(),
  ..   Fare = col_double(),
  ..   Cabin = col_character(),
  ..   Embarked = col_character()
  .. )

6.7 데이터 전처리 없이 예측.

train 데이터 훈련 시 NA 처리 및 PassengerId, Name, Ticket, Cabin 컬럼 제거했음.

pred <- predict(fit, test, type="class")
pred
   1    2    3    4    5    6    7    8    9   10 <NA>   12   13   14   15   16 
   0    0    0    0    1    0    1    0    1    0 <NA>    0    1    0    1    1 
  17   18   19   20   21   22 <NA>   24   25   26   27   28   29 <NA>   31   32 
   0    0    0    1    1    0 <NA>    0    1    0    1    0    0 <NA>    0    0 
  33 <NA>   35   36 <NA>   38   39 <NA>   41 <NA>   43   44   45   46   47 <NA> 
   0 <NA>    1    0 <NA>    0    0 <NA>    0 <NA>    0    1    1    0    0 <NA> 
  49   50   51   52   53   54 <NA>   56   57   58 <NA>   60   61   62   63   64 
   1    1    0    0    1    1 <NA>    0    0    0 <NA>    1    0    0    0    1 
  65 <NA>   67   68   69   70   71   72   73   74   75   76 <NA>   78   79   80 
   1 <NA>    1    0    0    1    1    0    1    0    1    0 <NA>    1    0    1 
  81   82   83 <NA> <NA> <NA>   87   88 <NA>   90   91 <NA>   93 <NA>   95   96 
   1    0    0 <NA> <NA> <NA>    1    0 <NA>    1    0 <NA>    1 <NA>    0    0 
  97   98   99  100  101  102 <NA>  104  105  106  107 <NA> <NA>  110  111 <NA> 
   1    0    1    0    1    0 <NA>    0    1    0    0 <NA> <NA>    0    0 <NA> 
 113  114  115  116 <NA>  118  119  120  121 <NA>  123  124 <NA>  126  127 <NA> 
   1    1    1    0 <NA>    1    0    1    1 <NA>    1    0 <NA>    1    0 <NA> 
 129  130  131  132 <NA> <NA>  135  136  137  138  139  140  141  142  143  144 
   0    0    0    0 <NA> <NA>    0    0    0    0    0    0    0    1    0    0 
 145  146 <NA>  148 <NA>  150  151 <NA> <NA>  154  155  156  157  158  159  160 
   1    0 <NA>    0 <NA>    1    1 <NA> <NA>    1    0    0    1    1    1    1 
<NA>  162  163 <NA>  165  166  167  168 <NA>  170 <NA>  172  173 <NA>  175  176 
<NA>    1    1 <NA>    0    0    0    0 <NA>    0 <NA>    0    0 <NA>    0    1 
 177  178  179  180  181  182  183 <NA>  185  186  187  188 <NA>  190  191 <NA> 
   1    1    1    1    0    0    1 <NA>    1    0    1    0 <NA>    0    0 <NA> 
 193  194  195  196  197  198  199 <NA> <NA>  202  203  204  205 <NA>  207  208 
   0    0    1    0    1    1    0 <NA> <NA>    1    0    1    0 <NA>    1    0 
 209  210  211 <NA>  213  214  215  216 <NA>  218  219 <NA>  221  222  223  224 
   1    0    0 <NA>    0    1    0    0 <NA>    0    1 <NA>    1    0    1    0 
 225 <NA>  227 <NA>  229  230  231  232  233 <NA>  235  236  237  238  239  240 
   1 <NA>    0 <NA>    0    0    0    1    0 <NA>    1    0    1    0    1    1 
 241  242  243 <NA> <NA>  246  247  248  249 <NA>  251  252  253  254  255 <NA> 
   1    1    0 <NA> <NA>    0    1    0    1 <NA>    1    0    0    0    0 <NA> 
<NA>  258  259  260  261  262  263  264  265 <NA> <NA> <NA> <NA>  270  271 <NA> 
<NA>    0    1    0    0    0    1    1    0 <NA> <NA> <NA> <NA>    0    0 <NA> 
 273 <NA> <NA>  276  277  278  279  280  281  282 <NA>  284  285  286 <NA>  288 
   1 <NA> <NA>    1    0    0    0    0    0    1 <NA>    1    1    0 <NA>    0 
<NA> <NA> <NA>  292 <NA>  294  295  296  297 <NA>  299  300  301 <NA>  303  304 
<NA> <NA> <NA>    1 <NA>    0    0    0    1 <NA>    0    0    0 <NA>    0    0 
<NA>  306  307  308  309  310  311  312 <NA>  314  315  316  317  318  319  320 
<NA>    1    1    1    0    0    0    0 <NA>    1    1    1    0    0    0    0 
 321  322  323  324  325  326  327  328  329  330  331  332 <NA>  334  335  336 
   0    0    0    0    1    0    1    0    0    0    1    0 <NA>    1    0    0 
 337  338  339 <NA>  341  342 <NA>  344 <NA>  346  347  348  349  350  351  352 
   0    0    0 <NA>    0    0 <NA>    1 <NA>    1    0    1    0    1    1    0 
 353  354  355  356  357 <NA> <NA>  360  361  362  363  364  365 <NA> <NA>  368 
   0    0    1    0    1 <NA> <NA>    0    0    1    1    0    1 <NA> <NA>    0 
 369  370  371  372  373  374  375  376  377  378  379  380 <NA>  382 <NA>  384 
   1    0    0    1    0    0    1    1    0    0    0    0 <NA>    0 <NA>    0 
<NA>  386  387  388  389  390  391  392  393  394  395  396  397  398  399  400 
<NA>    1    0    0    0    0    0    1    0    0    0    1    0    1    0    0 
 401  402  403  404  405  406  407  408 <NA>  410 <NA>  412  413 <NA>  415  416 
   1    0    1    0    0    0    0    0 <NA>    1 <NA>    1    1 <NA>    1    0 
<NA> <NA> 
<NA> <NA> 
Levels: 0 1
pred %>% is.na() %>% sum() 
[1] 87

확인해 보니, 87개로 NA 값을 가진 Age:86, Fare:1의 개수만큼 NA로 채워졌음. 그래서 test 데이터에 대해 Age, Fare의 NA 값을 채워 넣음.
#### 6.8 데이터 전처리: Age

# Age는 train과 같이 최빈값을 채워 넣음.
test$Age[is.na(test$Age)]<- test$Age %>% table() %>% which.max() %>% names()

6.8 데이터 전처리: Fare

어떤 값을 넣는 것이 가장 좋을 지 확인하기 위해 데이터 분포 확인.

test %>% ggplot()+
  aes(Fare)+
  geom_density()
Warning: Removed 1 rows containing non-finite values (stat_density).

test %>% ggplot()+
  aes(Fare)+
  geom_boxplot()
Warning: Removed 1 rows containing non-finite values (stat_boxplot).

test$Fare %>% table() %>% which.max() %>% names()
[1] "7.75"
test$Fare %>% summary()
   Min. 1st Qu.  Median    Mean 3rd Qu.    Max.    NA's 
  0.000   7.896  14.454  35.627  31.500 512.329       1 
test$Fare %>% IQR(na.rm = T)
[1] 23.6042

확인 후 median 값을 넣는 것이 가장 좋을 것 같음.

test$Fare[is.na(test$Fare)]<-test$Fare %>% median(na.rm = T)
test %>% sapply(is.na) %>% colSums()
PassengerId      Pclass        Name         Sex         Age       SibSp 
          0           0           0           0           0           0 
      Parch      Ticket        Fare       Cabin    Embarked 
          0           0           0         327           0 

6.9 데이터 전처리 후 다시 예측.

pred <- predict(fit, test, type="class")
pred
  1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18  19  20 
  0   0   0   0   1   0   1   0   1   0   0   0   1   0   1   1   0   0   0   1 
 21  22  23  24  25  26  27  28  29  30  31  32  33  34  35  36  37  38  39  40 
  1   0   1   0   1   0   1   0   0   0   0   0   1   0   1   0   0   0   0   0 
 41  42  43  44  45  46  47  48  49  50  51  52  53  54  55  56  57  58  59  60 
  0   0   0   1   1   0   0   0   1   1   0   0   1   1   0   0   0   0   0   1 
 61  62  63  64  65  66  67  68  69  70  71  72  73  74  75  76  77  78  79  80 
  0   0   0   1   1   1   1   0   0   1   1   0   0   0   1   0   0   1   0   1 
 81  82  83  84  85  86  87  88  89  90  91  92  93  94  95  96  97  98  99 100 
  1   0   0   0   0   0   1   0   1   1   0   0   1   0   0   0   1   0   1   0 
101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 
  1   0   0   0   1   0   0   0   0   0   0   1   1   1   1   0   0   1   0   1 
121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 
  1   0   1   0   0   1   0   1   0   0   0   0   0   0   0   0   0   0   0   0 
141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 
  0   1   0   0   0   0   0   0   0   1   1   0   0   1   0   0   1   1   0   1 
161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 
  1   1   1   0   0   0   0   0   1   0   0   0   0   0   0   1   1   1   1   1 
181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 
  0   0   1   0   1   0   1   0   0   0   0   0   0   0   0   0   0   0   0   0 
201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 
  1   1   0   1   0   0   1   0   1   0   0   0   0   1   0   0   1   0   1   0 
221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 
  1   0   1   0   1   1   0   1   0   0   0   1   0   0   1   0   1   0   1   1 
241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 
  1   1   0   0   0   0   1   0   1   0   1   0   0   0   0   0   0   0   1   0 
261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 
  0   0   1   1   0   0   0   0   0   0   0   0   1   1   0   1   0   0   0   0 
281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 
  0   1   1   0   0   0   0   0   0   0   0   0   0   0   0   0   1   0   0   0 
301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 
  0   0   0   0   1   1   1   1   0   0   0   0   0   1   1   1   0   0   0   0 
321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 
  0   0   0   1   1   0   1   0   0   0   1   0   0   1   0   0   0   0   0   0 
341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 
  0   0   0   1   0   1   0   1   0   1   1   0   0   0   1   0   1   0   0   0 
361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 
  0   1   1   0   1   0   0   0   1   0   0   1   0   0   1   1   0   0   0   0 
381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 
  0   0   0   1   0   1   0   0   0   0   0   1   0   0   0   1   0   1   0   0 
401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 
  1   0   1   0   0   0   0   0   1   0   1   1   0   0   1   0   0   1 
Levels: 0 1

6.10 예측 값 비교

pred %>% class()   # factor 형임.
[1] "factor"
ori <- read_csv('titanic/gender_submission.csv')

confusionMatrix(pred, as.factor(ori$Survived)) 
Confusion Matrix and Statistics

          Reference
Prediction   0   1
         0 250  31
         1  16 121
                                          
               Accuracy : 0.8876          
                 95% CI : (0.8533, 0.9162)
    No Information Rate : 0.6364          
    P-Value [Acc > NIR] : < 2e-16         
                                          
                  Kappa : 0.7518          
                                          
 Mcnemar's Test P-Value : 0.04114         
                                          
            Sensitivity : 0.9398          
            Specificity : 0.7961          
         Pos Pred Value : 0.8897          
         Neg Pred Value : 0.8832          
             Prevalence : 0.6364          
         Detection Rate : 0.5981          
   Detection Prevalence : 0.6722          
      Balanced Accuracy : 0.8680          
                                          
       'Positive' Class : 0               
                                          

결과값은 훈련 때와 거의 유사하게 약 88%의 Accuracy를 보여준다.

7. kaggle에 제출될 form으로 변경해서 제출

gender_submission <- read_csv('gender_submission.csv')
gender_submission <- cbind(gender_submission, pred) %>% 
  select(1,3) %>%
  rename(Survived=pred)
gender_submission %>% write_csv('gender_submission.csv')