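This notebook uses the dataset of passengers aboard the Titanic to perform exploratory data analysis (EDA), predicts survival with a machine learning model (random forest, via the randomForest package), and submits the predictions to Kaggle.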
1. Data: the Titanic dataset
2. Gain insight through data exploration
3. Predict survival with machine learning (randomForest package)
4. Submit the predictions to Kaggle
library(dplyr)
library(readr)
train <- read_csv("train.csv")
str(train)
tibble [891 x 12] (S3: spec_tbl_df/tbl_df/tbl/data.frame)
$ PassengerId: num [1:891] 1 2 3 4 5 6 7 8 9 10 ...
$ Survived : num [1:891] 0 1 1 1 0 0 0 0 1 1 ...
$ Pclass : num [1:891] 3 1 3 1 3 3 1 3 3 2 ...
$ Name : chr [1:891] "Braund, Mr. Owen Harris" "Cumings, Mrs. John Bradley (Florence Briggs Thayer)" "Heikkinen, Miss. Laina" "Futrelle, Mrs. Jacques Heath (Lily May Peel)" ...
$ Sex : chr [1:891] "male" "female" "female" "female" ...
$ Age : num [1:891] 22 38 26 35 35 NA 54 2 27 14 ...
$ SibSp : num [1:891] 1 1 0 1 0 0 0 3 0 1 ...
$ Parch : num [1:891] 0 0 0 0 0 0 0 1 2 0 ...
$ Ticket : chr [1:891] "A/5 21171" "PC 17599" "STON/O2. 3101282" "113803" ...
$ Fare : num [1:891] 7.25 71.28 7.92 53.1 8.05 ...
$ Cabin : chr [1:891] NA "C85" NA "C123" ...
$ Embarked : chr [1:891] "S" "C" "S" "S" ...
- attr(*, "spec")=
.. cols(
.. PassengerId = col_double(),
.. Survived = col_double(),
.. Pclass = col_double(),
.. Name = col_character(),
.. Sex = col_character(),
.. Age = col_double(),
.. SibSp = col_double(),
.. Parch = col_double(),
.. Ticket = col_character(),
.. Fare = col_double(),
.. Cabin = col_character(),
.. Embarked = col_character()
.. )
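Check the categories of all remaining columns at once, excluding PassengerId, Name, Age, Ticket, Fare, and Cabin (identifiers, free text, or continuous values). Note that table() drops NA, so the two missing Embarked values are not counted below.
train %>% select(-c(PassengerId, Name, Age, Ticket, Fare, Cabin)) %>% sapply(table)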
$Survived
0 1
549 342
$Pclass
1 2 3
216 184 491
$Sex
female male
314 577
$SibSp
0 1 2 3 4 5 8
608 209 28 16 18 5 7
$Parch
0 1 2 3 4 5 6
678 118 80 5 4 5 1
$Embarked
C Q S
168 77 644
train %>% sapply(is.na) %>% colSums()
PassengerId Survived Pclass Name Sex Age
0 0 0 0 0 177
SibSp Parch Ticket Fare Cabin Embarked
0 0 0 0 687 2
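Raw NA counts are easier to judge as shares of the 891 rows: Cabin is about 77% missing, while Age (about 20%) is still worth imputing. A minimal base-R sketch; `missing_share` is a hypothetical helper name, demonstrated on a small made-up data frame rather than the real train data.

```r
# Hypothetical helper: percentage of missing values per column
missing_share <- function(df) {
  round(100 * colMeans(is.na(df)), 1)   # is.na() on a data frame gives a logical matrix
}

# Toy data frame standing in for train (values are made up)
toy <- data.frame(
  Age   = c(22, NA, 26, 35),
  Cabin = c(NA, "C85", NA, NA),
  Fare  = c(7.25, 71.28, 7.92, 53.1)
)
missing_share(toy)

# Same arithmetic applied to the counts printed above (891 rows)
round(100 * c(Age = 177, Cabin = 687, Embarked = 2) / 891, 1)
```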
library(ggplot2)
library(plotly) # interactive web graphics
rate<-train %>% group_by(Sex, Survived) %>%
summarise(n=n()) %>%
mutate(rate=n/sum(n))
rate
# A tibble: 4 x 4
# Groups: Sex [2]
Sex Survived n rate
<chr> <dbl> <int> <dbl>
1 female 0 81 0.258
2 female 1 233 0.742
3 male 0 468 0.811
4 male 1 109 0.189
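`summarise()` drops the last grouping variable, so the result above is still grouped by `Sex`, and `mutate(rate = n/sum(n))` divides within each sex. The same within-row proportions can be checked in base R with `prop.table(..., margin = 1)`; this sketch rebuilds the 2x2 table from the counts printed above.

```r
# Counts from the output above: rows = Sex, columns = Survived (0 = died, 1 = survived)
counts <- matrix(c(81, 233,
                   468, 109),
                 nrow = 2, byrow = TRUE,
                 dimnames = list(Sex = c("female", "male"),
                                 Survived = c("0", "1")))

# margin = 1 divides each cell by its row total, i.e. rates within each sex
rates <- prop.table(counts, margin = 1)
round(rates, 3)
```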
p<-rate %>% ggplot()+
aes(x = Sex, y = rate, fill=as.factor(Survived))+
geom_bar(position = "dodge", stat = 'identity')+
labs(title="Death/survival rate by sex (0 = died, 1 = survived)", fill="Survived")+
theme(plot.title = element_text(hjust = 0.5))
ggplotly(p)
p<-train %>%
ggplot()+
aes(x = Sex, fill=as.factor(Survived))+
geom_bar(position = "stack", show.legend = T, stat = 'count') +
labs(title="Deaths and survivors by sex (0 = died, 1 = survived)", fill="Survived")+
theme(plot.title = element_text(hjust = 0.5))
ggplotly(p)
rate<-train %>% group_by(Pclass, Survived) %>%
summarise(total=n()) %>%
mutate(rate=total/sum(total))
rate
# A tibble: 6 x 4
# Groups: Pclass [3]
Pclass Survived total rate
<dbl> <dbl> <int> <dbl>
1 1 0 80 0.370
2 1 1 136 0.630
3 2 0 97 0.527
4 2 1 87 0.473
5 3 0 372 0.758
6 3 1 119 0.242
p<-rate %>% ggplot()+
aes(x = Pclass, y = rate, fill=as.factor(Survived))+
geom_bar(position = "dodge", stat = 'identity')+
labs(title="Death/survival rate by class (0 = died, 1 = survived)", fill="Survived")+
theme(plot.title = element_text(hjust = 0.5))
ggplotly(p)
train %>%
ggplot()+
aes(x = Pclass, fill=as.factor(Survived))+
geom_bar(position = "stack", show.legend = T, stat = 'count')+
labs(title="Deaths and survivors by class (0 = died, 1 = survived)", fill="Survived")+
theme(plot.title = element_text(hjust = 0.5))
rate<-train %>%
transform(age_group=cut(Age,
breaks = c(0,10,20,30,40,50,60,70,80))) %>%
group_by(age_group, Survived) %>%
summarise(total=n()) %>%
mutate(rate=total/sum(total))
rate
# A tibble: 18 x 4
# Groups: age_group [9]
age_group Survived total rate
<fct> <dbl> <int> <dbl>
1 (0,10] 0 26 0.406
2 (0,10] 1 38 0.594
3 (10,20] 0 71 0.617
4 (10,20] 1 44 0.383
5 (20,30] 0 146 0.635
6 (20,30] 1 84 0.365
7 (30,40] 0 86 0.555
8 (30,40] 1 69 0.445
9 (40,50] 0 53 0.616
10 (40,50] 1 33 0.384
11 (50,60] 0 25 0.595
12 (50,60] 1 17 0.405
13 (60,70] 0 13 0.765
14 (60,70] 1 4 0.235
15 (70,80] 0 4 0.8
16 (70,80] 1 1 0.2
17 <NA> 0 125 0.706
18 <NA> 1 52 0.294
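`cut()` builds right-closed intervals by default (`right = TRUE`), so an age of exactly 10 falls in `(0,10]`, and a missing age becomes `NA`, which is why an `<NA>` group appears in the table above. A small check on made-up ages:

```r
breaks <- c(0, 10, 20, 30, 40, 50, 60, 70, 80)
ages <- c(0.42, 10, 10.5, 80, NA)   # made-up example values

groups <- cut(ages, breaks = breaks)
as.character(groups)   # NA ages stay NA

# breaks = seq(0, 80, by = 10) produces the same eight intervals
levels(cut(ages, breaks = seq(0, 80, by = 10)))
```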
p<-rate %>% ggplot()+
aes(x = age_group, y = rate, fill=as.factor(Survived))+
geom_bar(position = "dodge", stat = 'identity')+
labs(title="Death/survival rate by age group (0 = died, 1 = survived)", fill="Survived")+
theme(plot.title = element_text(hjust = 0.5))
ggplotly(p)
p<-train %>%
transform(age_group=cut(Age,
breaks = c(0,10,20,30,40,50,60,70,80))) %>%
ggplot()+
aes(age_group, group=as.factor(Survived), fill=as.factor(Survived))+
geom_bar(position = "dodge")+
labs(title="Deaths and survivors by age group (0 = died, 1 = survived)", fill="Survived")+
theme(plot.title = element_text(hjust = 0.5))
ggplotly(p)
Warning: `group_by_()` is deprecated as of dplyr 0.7.0.
Please use `group_by()` instead.
See vignette('programming') for more help
This warning is displayed once every 8 hours.
Call `lifecycle::last_warnings()` to see where this warning was generated.
train %>%
ggplot()+
aes(Age, group=as.factor(Survived), fill=as.factor(Survived))+
geom_density(alpha=0.3)+
labs(title="Age density by survival (0 = died, 1 = survived)", fill="Survived")+
theme(plot.title = element_text(hjust = 0.5))
Warning: Removed 177 rows containing non-finite values (stat_density).
train %>%
ggplot()+
aes(Age, group=as.factor(Survived), fill=as.factor(Survived))+
geom_histogram(position = "dodge", bins = 50)+
labs(title="Age histogram by survival (0 = died, 1 = survived)", fill="Survived")+
theme(plot.title = element_text(hjust = 0.5))
Warning: Removed 177 rows containing non-finite values (stat_bin).
train %>% ggplot()+
aes(Age, group=as.factor(Survived), fill=as.factor(Survived))+
geom_boxplot()+
labs(title="Age boxplot by survival (0 = died, 1 = survived)", fill="Survived")+
theme(plot.title = element_text(hjust = 0.5))
Warning: Removed 177 rows containing non-finite values (stat_boxplot).
train %>% ggplot()+
aes(x = Age, group=Sex, fill=Sex)+
geom_density(alpha=0.3)
Warning: Removed 177 rows containing non-finite values (stat_density).
train %>% ggplot()+
aes(Age, group=Sex, fill=Sex)+
geom_boxplot()
Warning: Removed 177 rows containing non-finite values (stat_boxplot).
train$Age %>% summary()
Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
0.42 20.12 28.00 29.70 38.00 80.00 177
rate<-train %>% group_by(Embarked, Survived) %>%
summarise(total=n()) %>%
mutate(rate=total/sum(total))
rate
# A tibble: 7 x 4
# Groups: Embarked [4]
Embarked Survived total rate
<chr> <dbl> <int> <dbl>
1 C 0 75 0.446
2 C 1 93 0.554
3 Q 0 47 0.610
4 Q 1 30 0.390
5 S 0 427 0.663
6 S 1 217 0.337
7 <NA> 1 2 1
p<-rate %>% ggplot()+
aes(x = Embarked, y = rate, fill=as.factor(Survived))+
geom_bar(position = "dodge", stat = 'identity')+
labs(title="Death/survival rate by port of embarkation (0 = died, 1 = survived)", fill="Survived")+
theme(plot.title = element_text(hjust = 0.5))
ggplotly(p)
train %>% ggplot()+
aes(Embarked, group=as.factor(Survived), fill=as.factor(Survived))+
geom_bar(position="dodge", na.rm = T)
train %>% ggplot()+
aes(x = Pclass, y = Fare, group=Embarked, col=Embarked)+
geom_jitter()
train %>%
ggplot()+
aes(x = Age, y = Fare, group=Survived, fill=as.factor(Survived), col=as.factor(Survived))+
geom_jitter()+
labs(title="Age vs. fare by survival (0 = died, 1 = survived)", fill="Survived", col="Survived")
Warning: Removed 177 rows containing missing values (geom_point).
#### Interim conclusion: passengers who paid lower fares died in much larger numbers.
train %>% select(Pclass, Survived) %>% cor()
Pclass Survived
Pclass 1.000000 -0.338481
Survived -0.338481 1.000000
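Both columns are numeric here, so `cor()` returns the Pearson (point-biserial) correlation. As a sanity check, the value can be reproduced from the class-by-survival counts printed earlier by rebuilding the 891 (Pclass, Survived) pairs:

```r
# Counts per (Pclass, Survived) cell, taken from the Pclass rate table above
cells <- data.frame(
  Pclass   = c(1, 1, 2, 2, 3, 3),
  Survived = c(0, 1, 0, 1, 0, 1),
  n        = c(80, 136, 97, 87, 372, 119)
)

# Expand each cell into n identical rows (891 rows in total)
rows <- cells[rep(seq_len(nrow(cells)), times = cells$n), c("Pclass", "Survived")]

r <- cor(rows$Pclass, rows$Survived)
round(r, 6)
```

Because the correlation depends only on the joint distribution, the rebuilt rows give the same value as the raw data.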
train %>%
mutate(Sex2=case_when(Sex=='male' ~ 1,Sex=='female' ~ 0),
Age2=case_when(is.na(Age) ~ mean(Age, na.rm = T),
TRUE ~ Age),
Embarked2=case_when(Embarked=='C' ~ 1,
Embarked=='Q' ~ 2,
Embarked=='S' ~ 3,
TRUE ~ 0)) %>%
select(Survived, Fare, Pclass, Sex2, Embarked2, Age2) %>%
psych::pairs.panels()
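The `case_when()` calls above are label encodings: each category is mapped to a number so that `pairs.panels()` can correlate it, and a missing or unmatched value falls through to the final `TRUE ~ 0` branch. A base-R sketch of the same Embarked mapping with `match()`; `encode_embarked` is a hypothetical helper name.

```r
# Hypothetical helper mirroring the case_when() mapping: C=1, Q=2, S=3, anything else (incl. NA)=0
encode_embarked <- function(x) {
  code <- match(x, c("C", "Q", "S"))   # position in the lookup vector, NA if not found
  code[is.na(code)] <- 0L
  code
}

encode_embarked(c("S", "C", NA, "Q"))
```

The numbers impose an arbitrary order on the ports, so a correlation with the encoded column reflects that chosen order as much as the data.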
train %>% sapply(is.na) %>% colSums()
PassengerId Survived Pclass Name Sex Age
0 0 0 0 0 177
SibSp Parch Ticket Fare Cabin Embarked
0 0 0 0 687 2
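For Age, I tried filling the missing values with random draws generated from the column's mean and standard deviation, as below, but this produced negative ages, so it does not work. The random values also do not seem meaningful.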
# train$Age[which(is.na(train$Age))]<-
# rnorm(n =length(which(is.na(train$Age))),
# mean=mean(train$Age, na.rm = T),
# sd=sd(train$Age, na.rm = T))
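Why negative values appear: a normal draw with mean m and standard deviation s is below zero with probability `pnorm(0, m, s)`. The notebook never prints the standard deviation of Age, so the 14.5 below is an assumption for illustration; with the printed mean of 29.7, roughly 2% of the imputed ages would then be negative.

```r
m <- 29.7   # mean Age printed by summary() above
s <- 14.5   # ASSUMED standard deviation of Age (not shown in the output above)

# Probability that a single rnorm(mean = m, sd = s) draw is negative
p_neg <- pnorm(0, mean = m, sd = s)
round(p_neg, 3)

# Simulated check: share of negative draws among 177 imputed values
set.seed(1)
draws <- rnorm(177, mean = m, sd = s)
mean(draws < 0)
```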
After examining Age through its density and summary statistics (median = 28, mean = 29.7, IQR ≈ 18, mode = 24), I judged that filling the missing Age values with the mode (24) is the best option.
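# names() returns a string, so convert it back with as.numeric() to keep Age numeric
train$Age[is.na(train$Age)]<- train$Age %>% table() %>% which.max() %>% names() %>% as.numeric()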
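The mode imputation can be wrapped in a small helper that always returns a number, so the same step is easy to reuse on other columns or data sets. A minimal sketch; `num_mode` is a hypothetical name, shown on made-up ages. Since `table()` sorts its levels and `which.max()` returns the first maximum, a tie is resolved in favor of the smallest value.

```r
# Hypothetical helper: most frequent value of a numeric vector, returned as a number
num_mode <- function(x) {
  counts <- table(x)                            # table() drops NA by default
  as.numeric(names(counts)[which.max(counts)])
}

age <- c(22, 24, NA, 24, 30, NA, 35)   # made-up ages
age[is.na(age)] <- num_mode(age)
age
class(age)   # still numeric
```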
There are only two missing Embarked values; I filled one with C and the other with S.
train$Embarked %>% table()
.
C Q S
168 77 644
train$Embarked[is.na(train$Embarked)]<-c('C','S')
After these changes I checked the correlations again, but there is almost no difference.
train %>%
mutate(Sex2=case_when(Sex=='male' ~ 1,Sex=='female' ~ 0),
Embarked2=case_when(Embarked=='C' ~ 1,
Embarked=='Q' ~ 2,
Embarked=='S' ~ 3,
TRUE ~ 0)) %>%
select(Survived, Fare, Pclass, Sex2, Embarked2, Age) %>%
psych::pairs.panels()
library(randomForest)
library(caret) # library used for the confusion matrix during evaluation
library(e1071) # library used for the confusion matrix during evaluation
# Source: https://3months.tistory.com/215 [Deep Play]
train<-train[,-c(4,9,11)] # drop Name (4), Ticket (9), Cabin (11)
train$Survived<-as.factor(train$Survived)
# I also tried converting the other columns to factors, but it made no difference.
# train$Pclass<-as.factor(train$Pclass)
# train$Sex<-as.factor(train$Sex)
# train$Embarked<-as.factor(train$Embarked)
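# No seed is set, so the OOB figures below vary slightly between runs
fit<-randomForest(Survived ~ .-PassengerId, # all remaining columns except PassengerId
data = train,
importance = T) # also assess the importance of each predictor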
fit
Call:
randomForest(formula = Survived ~ . - PassengerId, data = train, importance = T)
Type of random forest: classification
Number of trees: 500
No. of variables tried at each split: 2
OOB estimate of error rate: 17.73%
Confusion matrix:
0 1 class.error
0 499 50 0.09107468
1 108 234 0.31578947
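The output shows an out-of-bag (OOB) error rate of 17.73% together with the confusion matrix, so the model's estimated accuracy is roughly 82%. The errors are uneven: 31.6% of survivors are misclassified, versus 9.1% of non-survivors.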
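The OOB error and the `class.error` column can be recomputed directly from the confusion matrix printed above (rows are the actual classes, columns the predicted ones):

```r
# OOB confusion matrix printed above: rows = actual class, columns = predicted class
cm <- matrix(c(499, 50,
               108, 234),
             nrow = 2, byrow = TRUE,
             dimnames = list(actual = c("0", "1"), predicted = c("0", "1")))

oob_error   <- 1 - sum(diag(cm)) / sum(cm)   # misclassified share of the 891 rows
class_error <- 1 - diag(cm) / rowSums(cm)    # per-class error, as in the class.error column

round(oob_error, 4)
round(1 - oob_error, 4)   # corresponding accuracy
round(class_error, 4)
```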
plot(fit, type="l") # error rates (OOB and per class) against the number of trees
varImp(fit, scale=T) # caret function: variable importance of a trained model object
0 1
Pclass 25.386378 25.386378
Sex 73.216621 73.216621
Age 11.687546 11.687546
SibSp 9.654512 9.654512
Parch 11.494003 11.494003
Fare 21.046498 21.046498
Embarked 10.162607 10.162607
varImpPlot(fit, color='red') # randomForest package function
importance(fit, # extract the variable importance measures
type = NULL) # 1 = mean decrease in accuracy, 2 = mean decrease in node impurity (Gini); NULL returns both
0 1 MeanDecreaseAccuracy MeanDecreaseGini
Pclass 20.051295 30.721462 38.21768 33.97455
Sex 59.270765 87.162477 86.11129 104.45648
Age 7.885424 15.489668 17.77382 47.72601
SibSp 14.857460 4.451563 16.26461 15.02815
Parch 15.233187 7.754819 17.79139 14.17742
Fare 19.179458 22.913537 31.87635 66.58260
Embarked 8.270351 12.054864 14.77798 11.15142
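These measures identify the most influential variables: Sex ranks first on both scales; Pclass and Fare come next by mean decrease in accuracy, Fare and Age by mean decrease in Gini.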
test<-read_csv('test.csv')
test %>% str()
tibble [418 x 11] (S3: spec_tbl_df/tbl_df/tbl/data.frame)
$ PassengerId: num [1:418] 892 893 894 895 896 897 898 899 900 901 ...
$ Pclass : num [1:418] 3 3 2 3 3 3 3 2 3 3 ...
$ Name : chr [1:418] "Kelly, Mr. James" "Wilkes, Mrs. James (Ellen Needs)" "Myles, Mr. Thomas Francis" "Wirz, Mr. Albert" ...
$ Sex : chr [1:418] "male" "female" "male" "male" ...
$ Age : num [1:418] 34.5 47 62 27 22 14 30 26 18 21 ...
$ SibSp : num [1:418] 0 1 0 0 1 0 0 1 0 2 ...
$ Parch : num [1:418] 0 0 0 0 1 0 0 1 0 0 ...
$ Ticket : chr [1:418] "330911" "363272" "240276" "315154" ...
$ Fare : num [1:418] 7.83 7 9.69 8.66 12.29 ...
$ Cabin : chr [1:418] NA NA NA NA ...
$ Embarked : chr [1:418] "Q" "S" "Q" "S" ...
- attr(*, "spec")=
.. cols(
.. PassengerId = col_double(),
.. Pclass = col_double(),
.. Name = col_character(),
.. Sex = col_character(),
.. Age = col_double(),
.. SibSp = col_double(),
.. Parch = col_double(),
.. Ticket = col_character(),
.. Fare = col_double(),
.. Cabin = col_character(),
.. Embarked = col_character()
.. )
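For training, the NAs in train were filled in, Name, Ticket, and Cabin were dropped, and PassengerId was excluded through the model formula. The test set has not had this preprocessing yet; predicting on it directly gives: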
pred <- predict(fit, test, type="class")
pred
1 2 3 4 5 6 7 8 9 10 <NA> 12 13 14 15 16
0 0 0 0 1 0 1 0 1 0 <NA> 0 1 0 1 1
17 18 19 20 21 22 <NA> 24 25 26 27 28 29 <NA> 31 32
0 0 0 1 1 0 <NA> 0 1 0 1 0 0 <NA> 0 0
33 <NA> 35 36 <NA> 38 39 <NA> 41 <NA> 43 44 45 46 47 <NA>
0 <NA> 1 0 <NA> 0 0 <NA> 0 <NA> 0 1 1 0 0 <NA>
49 50 51 52 53 54 <NA> 56 57 58 <NA> 60 61 62 63 64
1 1 0 0 1 1 <NA> 0 0 0 <NA> 1 0 0 0 1
65 <NA> 67 68 69 70 71 72 73 74 75 76 <NA> 78 79 80
1 <NA> 1 0 0 1 1 0 1 0 1 0 <NA> 1 0 1
81 82 83 <NA> <NA> <NA> 87 88 <NA> 90 91 <NA> 93 <NA> 95 96
1 0 0 <NA> <NA> <NA> 1 0 <NA> 1 0 <NA> 1 <NA> 0 0
97 98 99 100 101 102 <NA> 104 105 106 107 <NA> <NA> 110 111 <NA>
1 0 1 0 1 0 <NA> 0 1 0 0 <NA> <NA> 0 0 <NA>
113 114 115 116 <NA> 118 119 120 121 <NA> 123 124 <NA> 126 127 <NA>
1 1 1 0 <NA> 1 0 1 1 <NA> 1 0 <NA> 1 0 <NA>
129 130 131 132 <NA> <NA> 135 136 137 138 139 140 141 142 143 144
0 0 0 0 <NA> <NA> 0 0 0 0 0 0 0 1 0 0
145 146 <NA> 148 <NA> 150 151 <NA> <NA> 154 155 156 157 158 159 160
1 0 <NA> 0 <NA> 1 1 <NA> <NA> 1 0 0 1 1 1 1
<NA> 162 163 <NA> 165 166 167 168 <NA> 170 <NA> 172 173 <NA> 175 176
<NA> 1 1 <NA> 0 0 0 0 <NA> 0 <NA> 0 0 <NA> 0 1
177 178 179 180 181 182 183 <NA> 185 186 187 188 <NA> 190 191 <NA>
1 1 1 1 0 0 1 <NA> 1 0 1 0 <NA> 0 0 <NA>
193 194 195 196 197 198 199 <NA> <NA> 202 203 204 205 <NA> 207 208
0 0 1 0 1 1 0 <NA> <NA> 1 0 1 0 <NA> 1 0
209 210 211 <NA> 213 214 215 216 <NA> 218 219 <NA> 221 222 223 224
1 0 0 <NA> 0 1 0 0 <NA> 0 1 <NA> 1 0 1 0
225 <NA> 227 <NA> 229 230 231 232 233 <NA> 235 236 237 238 239 240
1 <NA> 0 <NA> 0 0 0 1 0 <NA> 1 0 1 0 1 1
241 242 243 <NA> <NA> 246 247 248 249 <NA> 251 252 253 254 255 <NA>
1 1 0 <NA> <NA> 0 1 0 1 <NA> 1 0 0 0 0 <NA>
<NA> 258 259 260 261 262 263 264 265 <NA> <NA> <NA> <NA> 270 271 <NA>
<NA> 0 1 0 0 0 1 1 0 <NA> <NA> <NA> <NA> 0 0 <NA>
273 <NA> <NA> 276 277 278 279 280 281 282 <NA> 284 285 286 <NA> 288
1 <NA> <NA> 1 0 0 0 0 0 1 <NA> 1 1 0 <NA> 0
<NA> <NA> <NA> 292 <NA> 294 295 296 297 <NA> 299 300 301 <NA> 303 304
<NA> <NA> <NA> 1 <NA> 0 0 0 1 <NA> 0 0 0 <NA> 0 0
<NA> 306 307 308 309 310 311 312 <NA> 314 315 316 317 318 319 320
<NA> 1 1 1 0 0 0 0 <NA> 1 1 1 0 0 0 0
321 322 323 324 325 326 327 328 329 330 331 332 <NA> 334 335 336
0 0 0 0 1 0 1 0 0 0 1 0 <NA> 1 0 0
337 338 339 <NA> 341 342 <NA> 344 <NA> 346 347 348 349 350 351 352
0 0 0 <NA> 0 0 <NA> 1 <NA> 1 0 1 0 1 1 0
353 354 355 356 357 <NA> <NA> 360 361 362 363 364 365 <NA> <NA> 368
0 0 1 0 1 <NA> <NA> 0 0 1 1 0 1 <NA> <NA> 0
369 370 371 372 373 374 375 376 377 378 379 380 <NA> 382 <NA> 384
1 0 0 1 0 0 1 1 0 0 0 0 <NA> 0 <NA> 0
<NA> 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400
<NA> 1 0 0 0 0 0 1 0 0 0 1 0 1 0 0
401 402 403 404 405 406 407 408 <NA> 410 <NA> 412 413 <NA> 415 416
1 0 1 0 0 0 0 0 <NA> 1 <NA> 1 1 <NA> 1 0
<NA> <NA>
<NA> <NA>
Levels: 0 1
pred %>% is.na() %>% sum()
[1] 87
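There are 87 NA predictions, matching the number of missing values in test's Age (86) and Fare (1): the model returns NA for any row with a missing predictor. So the missing Age and Fare values in test are filled in next.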
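Rows that will come back as NA can be counted before calling `predict()`. Summing per-column NA counts (86 + 1) equals the number of affected rows only if no row is missing both values, so counting incomplete rows is the more direct check. A base-R sketch; the toy data and `predictor_cols` are assumptions standing in for test and the model's inputs.

```r
# Toy stand-in for test (values are made up); Cabin is missing too but is not a model input
toy_test <- data.frame(
  Pclass = c(3, 1, 2, 3),
  Age    = c(34.5, NA, 62, NA),
  Fare   = c(7.83, 71.3, NA, 8.05),
  Cabin  = c(NA, "C85", NA, NA)
)
predictor_cols <- c("Pclass", "Age", "Fare")   # assumed subset of the model's predictors

# TRUE for rows where at least one predictor is NA, i.e. rows that would get an NA prediction
has_na <- !complete.cases(toy_test[, predictor_cols])
sum(has_na)
colSums(is.na(toy_test[, predictor_cols]))
```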
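#### 6.8 Data preprocessing: Age
# Fill Age with the mode, as for train (here the mode of test's own Age values).
# names() returns a string, so convert it back with as.numeric() to keep Age numeric.
test$Age[is.na(test$Age)]<- test$Age %>% table() %>% which.max() %>% names() %>% as.numeric()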
For Fare, check the distribution to decide which fill value would work best.
test %>% ggplot()+
aes(Fare)+
geom_density()
Warning: Removed 1 rows containing non-finite values (stat_density).
test %>% ggplot()+
aes(Fare)+
geom_boxplot()
Warning: Removed 1 rows containing non-finite values (stat_boxplot).
test$Fare %>% table() %>% which.max() %>% names()
[1] "7.75"
test$Fare %>% summary()
Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
0.000 7.896 14.454 35.627 31.500 512.329 1
test$Fare %>% IQR(na.rm = T)
[1] 23.6042
The distribution is strongly right-skewed (median 14.45, mean 35.63, maximum 512.33), so the median looks like the best fill value.
test$Fare[is.na(test$Fare)]<-test$Fare %>% median(na.rm = T)
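Why the median rather than the mean: a few very large fares pull the mean upward, while the median depends only on the middle of the sorted values. A toy illustration with made-up, right-skewed fares:

```r
fares <- c(7.75, 7.9, 8.05, 13, 14.45, 26, 31.5, 512.33)   # made-up, right-skewed

mean(fares)     # pulled up by the single large value
median(fares)   # average of the two middle values

# Dropping the extreme value moves the mean a lot but the median only slightly
mean(fares[-8])
median(fares[-8])
```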
test %>% sapply(is.na) %>% colSums()
PassengerId Pclass Name Sex Age SibSp
0 0 0 0 0 0
Parch Ticket Fare Cabin Embarked
0 0 0 327 0
pred <- predict(fit, test, type="class")
pred
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
0 0 0 0 1 0 1 0 1 0 0 0 1 0 1 1 0 0 0 1
21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40
1 0 1 0 1 0 1 0 0 0 0 0 1 0 1 0 0 0 0 0
41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
0 0 0 1 1 0 0 0 1 1 0 0 1 1 0 0 0 0 0 1
61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
0 0 0 1 1 1 1 0 0 1 1 0 0 0 1 0 0 1 0 1
81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
1 0 0 0 0 0 1 0 1 1 0 0 1 0 0 0 1 0 1 0
101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
1 0 0 0 1 0 0 0 0 0 0 1 1 1 1 0 0 1 0 1
121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
1 0 1 0 0 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0
141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160
0 1 0 0 0 0 0 0 0 1 1 0 0 1 0 0 1 1 0 1
161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180
1 1 1 0 0 0 0 0 1 0 0 0 0 0 0 1 1 1 1 1
181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200
0 0 1 0 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0
201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220
1 1 0 1 0 0 1 0 1 0 0 0 0 1 0 0 1 0 1 0
221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240
1 0 1 0 1 1 0 1 0 0 0 1 0 0 1 0 1 0 1 1
241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260
1 1 0 0 0 0 1 0 1 0 1 0 0 0 0 0 0 0 1 0
261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280
0 0 1 1 0 0 0 0 0 0 0 0 1 1 0 1 0 0 0 0
281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300
0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0
301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320
0 0 0 0 1 1 1 1 0 0 0 0 0 1 1 1 0 0 0 0
321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340
0 0 0 1 1 0 1 0 0 0 1 0 0 1 0 0 0 0 0 0
341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360
0 0 0 1 0 1 0 1 0 1 1 0 0 0 1 0 1 0 0 0
361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380
0 1 1 0 1 0 0 0 1 0 0 1 0 0 1 1 0 0 0 0
381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400
0 0 0 1 0 1 0 0 0 0 0 1 0 0 0 1 0 1 0 0
401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418
1 0 1 0 0 0 0 0 1 0 1 1 0 0 1 0 0 1
Levels: 0 1
pred %>% class() # it is a factor
[1] "factor"
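# Note: gender_submission.csv is Kaggle's sample submission (women survive, men die), not the true
# test labels, so the figures below measure agreement with that sex-only baseline, not real accuracy.
ori <- read_csv('titanic/gender_submission.csv')
confusionMatrix(pred, as.factor(ori$Survived))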
Confusion Matrix and Statistics
Reference
Prediction 0 1
0 250 31
1 16 121
Accuracy : 0.8876
95% CI : (0.8533, 0.9162)
No Information Rate : 0.6364
P-Value [Acc > NIR] : < 2e-16
Kappa : 0.7518
Mcnemar's Test P-Value : 0.04114
Sensitivity : 0.9398
Specificity : 0.7961
Pos Pred Value : 0.8897
Neg Pred Value : 0.8832
Prevalence : 0.6364
Detection Rate : 0.5981
Detection Prevalence : 0.6722
Balanced Accuracy : 0.8680
'Positive' Class : 0
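The statistics above follow directly from the 2x2 table, with class 0 as the positive class as caret reports. A base-R sketch that recomputes them:

```r
# Confusion matrix above: rows = prediction, columns = reference; 'positive' class is 0
tab <- matrix(c(250, 31,
                16, 121),
              nrow = 2, byrow = TRUE,
              dimnames = list(Prediction = c("0", "1"), Reference = c("0", "1")))
n <- sum(tab)

accuracy    <- sum(diag(tab)) / n
sensitivity <- tab["0", "0"] / sum(tab[, "0"])   # TP / (TP + FN), positive = 0
specificity <- tab["1", "1"] / sum(tab[, "1"])   # TN / (TN + FP)
nir         <- max(colSums(tab)) / n             # no-information rate: share of the majority reference class

# Cohen's kappa: observed agreement corrected for agreement expected by chance
p_expected  <- sum(rowSums(tab) * colSums(tab)) / n^2
cohen_kappa <- (accuracy - p_expected) / (1 - p_expected)

round(c(accuracy = accuracy, sensitivity = sensitivity, specificity = specificity,
        nir = nir, kappa = cohen_kappa,
        balanced = (sensitivity + specificity) / 2), 4)
```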
# One row per test passenger: PassengerId from test, Survived from the predictions (same row order)
gender_submission <- data.frame(PassengerId = test$PassengerId,
Survived = pred)
gender_submission %>% write_csv('gender_submission.csv')
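Kaggle expects exactly one row per test passenger with the columns `PassengerId` and `Survived` (0/1). A hypothetical pre-upload check (`check_submission` is not part of any package), demonstrated on a toy data frame:

```r
# Hypothetical sanity check for a Titanic submission data frame
check_submission <- function(sub, test_ids) {
  stopifnot(
    identical(names(sub), c("PassengerId", "Survived")),
    nrow(sub) == length(test_ids),
    all(sub$PassengerId == test_ids),                  # same passengers, same order
    !anyNA(sub$Survived),                              # no missing predictions left
    all(as.character(sub$Survived) %in% c("0", "1"))   # works for factor or numeric 0/1
  )
  invisible(TRUE)
}

# Toy example: three passengers, predictions stored as a factor like pred
toy_sub <- data.frame(PassengerId = c(892, 893, 894),
                      Survived = factor(c(0, 1, 0)))
check_submission(toy_sub, test_ids = c(892, 893, 894))
```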