Predict whether each person in the data set earns <=50K or >50K per year, comparing three classifiers: Logistic Regression, Decision Tree, and Random Forest.
suppressWarnings(library(readr))
# Location of the adult income CSV. Set the ADULT_CSV environment variable to
# point elsewhere; otherwise the author's original path is used unchanged.
adult_path <- Sys.getenv("ADULT_CSV", unset = "C:/Users/dannyhuang/Desktop/adult.csv")
adult <- read_csv(adult_path)
## Parsed with column specification:
## cols(
## age = col_integer(),
## workclass = col_character(),
## fnlwgt = col_integer(),
## education = col_character(),
## `educational-num` = col_integer(),
## `marital-status` = col_character(),
## occupation = col_character(),
## relationship = col_character(),
## race = col_character(),
## gender = col_character(),
## `capital-gain` = col_integer(),
## `capital-loss` = col_integer(),
## `hours-per-week` = col_integer(),
## `native-country` = col_character(),
## income = col_character()
## )
# First ten rows, then the type of every column.
print(head(adult, n = 10))
## # A tibble: 10 × 15
## age workclass fnlwgt education `educational-num`
## <int> <chr> <int> <chr> <int>
## 1 25 Private 226802 11th 7
## 2 38 Private 89814 HS-grad 9
## 3 28 Local-gov 336951 Assoc-acdm 12
## 4 44 Private 160323 Some-college 10
## 5 18 ? 103497 Some-college 10
## 6 34 Private 198693 10th 6
## 7 29 ? 227026 HS-grad 9
## 8 63 Self-emp-not-inc 104626 Prof-school 15
## 9 24 Private 369667 Some-college 10
## 10 55 Private 104996 7th-8th 4
## # ... with 10 more variables: `marital-status` <chr>, occupation <chr>,
## # relationship <chr>, race <chr>, gender <chr>, `capital-gain` <int>,
## # `capital-loss` <int>, `hours-per-week` <int>, `native-country` <chr>,
## # income <chr>
str(object = adult)
## Classes 'tbl_df', 'tbl' and 'data.frame': 48842 obs. of 15 variables:
## $ age : int 25 38 28 44 18 34 29 63 24 55 ...
## $ workclass : chr "Private" "Private" "Local-gov" "Private" ...
## $ fnlwgt : int 226802 89814 336951 160323 103497 198693 227026 104626 369667 104996 ...
## $ education : chr "11th" "HS-grad" "Assoc-acdm" "Some-college" ...
## $ educational-num: int 7 9 12 10 10 6 9 15 10 4 ...
## $ marital-status : chr "Never-married" "Married-civ-spouse" "Married-civ-spouse" "Married-civ-spouse" ...
## $ occupation : chr "Machine-op-inspct" "Farming-fishing" "Protective-serv" "Machine-op-inspct" ...
## $ relationship : chr "Own-child" "Husband" "Husband" "Husband" ...
## $ race : chr "Black" "White" "White" "Black" ...
## $ gender : chr "Male" "Male" "Male" "Male" ...
## $ capital-gain : int 0 0 0 7688 0 0 0 3103 0 0 ...
## $ capital-loss : int 0 0 0 0 0 0 0 0 0 0 ...
## $ hours-per-week : int 40 50 40 40 30 30 40 32 40 10 ...
## $ native-country : chr "United-States" "United-States" "United-States" "United-States" ...
## $ income : chr "<=50K" "<=50K" ">50K" ">50K" ...
## - attr(*, "spec")=List of 2
## ..$ cols :List of 15
## .. ..$ age : list()
## .. .. ..- attr(*, "class")= chr "collector_integer" "collector"
## .. ..$ workclass : list()
## .. .. ..- attr(*, "class")= chr "collector_character" "collector"
## .. ..$ fnlwgt : list()
## .. .. ..- attr(*, "class")= chr "collector_integer" "collector"
## .. ..$ education : list()
## .. .. ..- attr(*, "class")= chr "collector_character" "collector"
## .. ..$ educational-num: list()
## .. .. ..- attr(*, "class")= chr "collector_integer" "collector"
## .. ..$ marital-status : list()
## .. .. ..- attr(*, "class")= chr "collector_character" "collector"
## .. ..$ occupation : list()
## .. .. ..- attr(*, "class")= chr "collector_character" "collector"
## .. ..$ relationship : list()
## .. .. ..- attr(*, "class")= chr "collector_character" "collector"
## .. ..$ race : list()
## .. .. ..- attr(*, "class")= chr "collector_character" "collector"
## .. ..$ gender : list()
## .. .. ..- attr(*, "class")= chr "collector_character" "collector"
## .. ..$ capital-gain : list()
## .. .. ..- attr(*, "class")= chr "collector_integer" "collector"
## .. ..$ capital-loss : list()
## .. .. ..- attr(*, "class")= chr "collector_integer" "collector"
## .. ..$ hours-per-week : list()
## .. .. ..- attr(*, "class")= chr "collector_integer" "collector"
## .. ..$ native-country : list()
## .. .. ..- attr(*, "class")= chr "collector_character" "collector"
## .. ..$ income : list()
## .. .. ..- attr(*, "class")= chr "collector_character" "collector"
## ..$ default: list()
## .. ..- attr(*, "class")= chr "collector_guess" "collector"
## ..- attr(*, "class")= chr "col_spec"
# Frequency of each raw workclass value.
print(table(adult[["workclass"]]))
##
## ? Federal-gov Local-gov Never-worked
## 2799 1432 3136 10
## Private Self-emp-inc Self-emp-not-inc State-gov
## 33906 1695 3862 1981
## Without-pay
## 21
### 2799 rows have workclass '?', the data set's placeholder for a missing value
# Collapse the two no-income work classes into a single 'Unemployed' value.
#
# job: character (or factor) vector of workclass values; a scalar also works.
# Returns a character vector of the same length in which 'Never-worked' and
# 'Without-pay' become 'Unemployed' and every other value, including NA, is
# passed through unchanged.
unemployed <- function(job) {
  job <- as.character(job)
  # %in% never yields NA, so missing values fall through untouched instead of
  # making if() fail with "missing value where TRUE/FALSE needed".
  job[job %in% c('Never-worked', 'Without-pay')] <- 'Unemployed'
  job
}
# vapply() pins the result to one string per row; sapply() would also copy the
# input values onto the result as names.
adult$workclass <- vapply(adult$workclass, unemployed, character(1), USE.NAMES = FALSE)
table(adult$workclass)
##
## ? Federal-gov Local-gov Private
## 2799 1432 3136 33906
## Self-emp-inc Self-emp-not-inc State-gov Unemployed
## 1695 3862 1981 31
# Merge related employer categories.
#
# job: character (or factor) vector of workclass values; a scalar also works.
# Returns a character vector of the same length in which 'Local-gov' and
# 'State-gov' become 'SL-gov', 'Self-emp-inc' and 'Self-emp-not-inc' become
# 'self-emp', and every other value (including NA) is passed through unchanged.
group_emp <- function(job) {
  # Same normalisation as unemployed()/group_marital(), so factor input is safe.
  job <- as.character(job)
  job[job %in% c('Local-gov', 'State-gov')] <- 'SL-gov'
  job[job %in% c('Self-emp-inc', 'Self-emp-not-inc')] <- 'self-emp'
  job
}
# Type-stable per-row mapping with no names attached to the column.
adult$workclass <- vapply(adult$workclass, group_emp, character(1), USE.NAMES = FALSE)
table(adult$workclass)
##
## ? Federal-gov Private self-emp SL-gov Unemployed
## 2799 1432 33906 5557 5117 31
# Frequency of each raw marital-status value.
print(table(adult[["marital-status"]]))
##
## Divorced Married-AF-spouse Married-civ-spouse
## 6633 37 22379
## Married-spouse-absent Never-married Separated
## 628 16117 1530
## Widowed
## 1518
# Reduce marital status to three groups.
#
# mar: character (or factor) vector of marital-status values; a scalar also works.
# Returns a character vector of the same length:
#   'Separated', 'Divorced', 'Widowed' -> 'Not-Married'
#   'Never-married'                    -> 'Never-married'
#   NA                                 -> NA
#   anything else                      -> 'Married'  (catch-all, e.g. the
#                                         'Married-*' variants)
group_marital <- function(mar) {
  mar <- as.character(mar)
  out <- rep('Married', length(mar))
  out[mar %in% c('Separated', 'Divorced', 'Widowed')] <- 'Not-Married'
  out[mar %in% 'Never-married'] <- 'Never-married'
  # Keep missing values missing rather than erroring in if() or silently
  # labelling them 'Married'.
  out[is.na(mar)] <- NA
  out
}
# Type-stable per-row mapping with no names attached to the column.
adult$`marital-status` <- vapply(adult$`marital-status`, group_marital, character(1), USE.NAMES = FALSE)
table(adult$`marital-status`)
##
## Married Never-married Not-Married
## 23044 16117 9681
# Country of origin lives in `native-country`; there is no `country` column, so
# table(adult$country) printed an empty table with an "Unknown column" warning.
table(adult$`native-country`)
unique(adult$`native-country`)
## [1] "United-States" "?"
## [3] "Peru" "Guatemala"
## [5] "Mexico" "Dominican-Republic"
## [7] "Ireland" "Germany"
## [9] "Philippines" "Thailand"
## [11] "Haiti" "El-Salvador"
## [13] "Puerto-Rico" "Vietnam"
## [15] "South" "Columbia"
## [17] "Japan" "India"
## [19] "Cambodia" "Poland"
## [21] "Laos" "England"
## [23] "Cuba" "Taiwan"
## [25] "Italy" "Canada"
## [27] "Portugal" "China"
## [29] "Nicaragua" "Honduras"
## [31] "Iran" "Scotland"
## [33] "Jamaica" "Ecuador"
## [35] "Yugoslavia" "Hungary"
## [37] "Hong" "Greece"
## [39] "Trinadad&Tobago" "Outlying-US(Guam-USVI-etc)"
## [41] "France" "Holand-Netherlands"
# Country -> region lookup lists. Spellings match the raw data values exactly,
# including the data set's own typos ('Columbia', 'Holand-Netherlands',
# 'Trinadad&Tobago').
Asia <- c('China','Hong','India','Iran','Cambodia','Japan', 'Laos' ,
          'Philippines' ,'Vietnam' ,'Taiwan', 'Thailand')
North.America <- c('Canada','United-States','Puerto-Rico' )
Europe <- c('England' ,'France', 'Germany' ,'Greece','Holand-Netherlands','Hungary',
            'Ireland','Italy','Poland','Portugal','Scotland','Yugoslavia')
Latin.and.South.America <- c('Columbia','Cuba','Dominican-Republic','Ecuador',
                             'El-Salvador','Guatemala','Haiti','Honduras',
                             'Mexico','Nicaragua','Outlying-US(Guam-USVI-etc)','Peru',
                             'Jamaica','Trinadad&Tobago')
Other <- c('South')

# Map country names to one of five region labels.
#
# ctry: character (or factor) vector of country names; a scalar also works.
# Returns a character vector of the same length. Anything not in one of the
# four lists above, including 'South', the '?' placeholder and NA, becomes
# 'Other'.
group_country <- function(ctry) {
  ctry <- as.character(ctry)
  region <- rep('Other', length(ctry))
  # Assign in reverse priority so that, as in a first-match if/else chain,
  # Asia wins over North.America over Europe over Latin.and.South.America.
  region[ctry %in% Latin.and.South.America] <- 'Latin.and.South.America'
  region[ctry %in% Europe] <- 'Europe'
  region[ctry %in% North.America] <- 'North.America'
  region[ctry %in% Asia] <- 'Asia'
  region
}
# Type-stable per-row mapping with no names attached to the column.
adult$`native-country` <- vapply(adult$`native-country`, group_country, character(1), USE.NAMES = FALSE)
table(adult$`native-country`)
##
## Asia Europe Latin.and.South.America
## 981 780 1911
## North.America Other
## 44198 972
str(adult)
## Classes 'tbl_df', 'tbl' and 'data.frame': 48842 obs. of 15 variables:
## $ age : int 25 38 28 44 18 34 29 63 24 55 ...
## $ workclass : chr "Private" "Private" "SL-gov" "Private" ...
## $ fnlwgt : int 226802 89814 336951 160323 103497 198693 227026 104626 369667 104996 ...
## $ education : chr "11th" "HS-grad" "Assoc-acdm" "Some-college" ...
## $ educational-num: int 7 9 12 10 10 6 9 15 10 4 ...
## $ marital-status : chr "Never-married" "Married" "Married" "Married" ...
## $ occupation : chr "Machine-op-inspct" "Farming-fishing" "Protective-serv" "Machine-op-inspct" ...
## $ relationship : chr "Own-child" "Husband" "Husband" "Husband" ...
## $ race : chr "Black" "White" "White" "Black" ...
## $ gender : chr "Male" "Male" "Male" "Male" ...
## $ capital-gain : int 0 0 0 7688 0 0 0 3103 0 0 ...
## $ capital-loss : int 0 0 0 0 0 0 0 0 0 0 ...
## $ hours-per-week : int 40 50 40 40 30 30 40 32 40 10 ...
## $ native-country : chr "North.America" "North.America" "North.America" "North.America" ...
## $ income : chr "<=50K" "<=50K" ">50K" ">50K" ...
## - attr(*, "spec")=List of 2
## ..$ cols :List of 15
## .. ..$ age : list()
## .. .. ..- attr(*, "class")= chr "collector_integer" "collector"
## .. ..$ workclass : list()
## .. .. ..- attr(*, "class")= chr "collector_character" "collector"
## .. ..$ fnlwgt : list()
## .. .. ..- attr(*, "class")= chr "collector_integer" "collector"
## .. ..$ education : list()
## .. .. ..- attr(*, "class")= chr "collector_character" "collector"
## .. ..$ educational-num: list()
## .. .. ..- attr(*, "class")= chr "collector_integer" "collector"
## .. ..$ marital-status : list()
## .. .. ..- attr(*, "class")= chr "collector_character" "collector"
## .. ..$ occupation : list()
## .. .. ..- attr(*, "class")= chr "collector_character" "collector"
## .. ..$ relationship : list()
## .. .. ..- attr(*, "class")= chr "collector_character" "collector"
## .. ..$ race : list()
## .. .. ..- attr(*, "class")= chr "collector_character" "collector"
## .. ..$ gender : list()
## .. .. ..- attr(*, "class")= chr "collector_character" "collector"
## .. ..$ capital-gain : list()
## .. .. ..- attr(*, "class")= chr "collector_integer" "collector"
## .. ..$ capital-loss : list()
## .. .. ..- attr(*, "class")= chr "collector_integer" "collector"
## .. ..$ hours-per-week : list()
## .. .. ..- attr(*, "class")= chr "collector_integer" "collector"
## .. ..$ native-country : list()
## .. .. ..- attr(*, "class")= chr "collector_character" "collector"
## .. ..$ income : list()
## .. .. ..- attr(*, "class")= chr "collector_character" "collector"
## ..$ default: list()
## .. ..- attr(*, "class")= chr "collector_guess" "collector"
## ..- attr(*, "class")= chr "col_spec"
# Turn the data set's '?' placeholders into real NAs *before* building factors,
# so '?' never becomes a factor level. The previous code compared against '4',
# which left every '?' in place. It also blanked out legitimate numeric 4s
# (e.g. `educational-num` == 4, `hours-per-week` == 4), and those rows were then
# silently discarded by na.omit() further down.
# NOTE: the captured ## output further down predates this fix.
adult[] <- lapply(adult, function(x) replace(x, which(x == '?'), NA))
# Build factors from whole columns, with levels in order of first appearance,
# instead of calling factor() once per element via sapply().
cat_cols <- c("workclass", "native-country", "marital-status")
adult[cat_cols] <- lapply(adult[cat_cols], function(x) factor(x, levels = unique(x[!is.na(x)])))
library(Amelia)
## Warning: package 'Amelia' was built under R version 3.3.2
## Loading required package: Rcpp
## ##
## ## Amelia II: Multiple Imputation
## ## (Version 1.7.4, built: 2015-12-05)
## ## Copyright (C) 2005-2017 James Honaker, Gary King and Matthew Blackwell
## ## Refer to http://gking.harvard.edu/amelia/ for more information
## ##
table(adult$workclass)
##
## Private SL-gov ? self-emp Federal-gov Unemployed
## 33906 5117 2799 5557 1432 31
# Convert every categorical column to a factor in one vectorized step. Levels
# keep their order of first appearance, matching what the old element-wise
# sapply(x, factor) produced. This matters for `income`: glm() treats the first
# level ('<=50K', from row 1) as the failure class.
categorical_cols <- c("workclass", "native-country", "marital-status",
                      "occupation", "relationship", "income", "education",
                      "gender", "race")
adult[categorical_cols] <- lapply(
  adult[categorical_cols],
  function(x) factor(x, levels = unique(as.character(x[!is.na(x)])))
)
# Draw the missingness map *before* dropping incomplete rows; drawn after
# na.omit() it can only ever show a fully observed table.
missmap(adult)
## Warning in if (class(obj) == "amelia") {: the condition has length > 1 and only the first
## element will be used
## Warning: Unknown column 'arguments'
## Warning: Unknown column 'arguments'
## Warning: Unknown column 'imputations'
adult <- na.omit(adult)
library(ggplot2)
## Warning: package 'ggplot2' was built under R version 3.3.2
library(dplyr)
## Warning: package 'dplyr' was built under R version 3.3.2
##
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
# Age distribution in one-year bins, each bar filled by income class.
ggplot(data = adult, mapping = aes(x = age)) +
  geom_histogram(mapping = aes(fill = income), color = 'black', binwidth = 1) +
  theme_bw()
### The histogram above shows age; no hours-per-week histogram is drawn here.
library(caTools)
## Warning: package 'caTools' was built under R version 3.3.2
# sample.split() draws a random split, so fix the RNG first. Without a seed, the
# train/test partition and every metric below change from run to run.
set.seed(101)
# TRUE marks the ~70% of rows that go to training; sample.split() balances the
# split on the income label. (Named split_mask so it no longer shadows base::sample.)
split_mask <- sample.split(adult$income, SplitRatio = 0.70)
# Training data
train <- adult[split_mask, ]
# Testing data: one copy per model, so each can add its own prediction columns.
test <- adult[!split_mask, ]  # for the logistic regression model
test2 <- test                 # for the decision tree
test3 <- test                 # for the random forest
suppressWarnings(library(caret))
## Loading required package: lattice
glm_model <- glm(income ~ ., family = binomial(logit), data = train)
# With a factor response, glm() models the probability of the non-first level,
# here '>50K'; type = "response" returns that probability.
test$predicted.income <- predict(glm_model, newdata = test, type = "response")
# The 0.5 threshold below is deterministic; this seed only fixes the RNG state
# for whatever randomized step runs next.
set.seed(1)
test$income_class <- ifelse(test$predicted.income > 0.5, ">50K", "<=50K")
# confusionMatrix() needs `data` and `reference` as factors with identical
# levels; recent caret versions reject a plain character vector.
glm_con <- confusionMatrix(factor(test$income_class, levels = levels(test$income)), test$income)
glm_con
## Confusion Matrix and Statistics
##
## Reference
## Prediction <=50K >50K
## <=50K 10117 1389
## >50K 739 2097
##
## Accuracy : 0.8516
## 95% CI : (0.8457, 0.8574)
## No Information Rate : 0.7569
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.5695
## Mcnemar's Test P-Value : < 2.2e-16
##
## Sensitivity : 0.9319
## Specificity : 0.6015
## Pos Pred Value : 0.8793
## Neg Pred Value : 0.7394
## Prevalence : 0.7569
## Detection Rate : 0.7054
## Detection Prevalence : 0.8023
## Balanced Accuracy : 0.7667
##
## 'Positive' Class : <=50K
##
# AUC of the logistic-regression probabilities against the true income class;
# plotROC = TRUE also draws the ROC curve. (library() is idempotent, so one
# call is enough.)
suppressWarnings(library(caTools))
colAUC(test$predicted.income, test$income, plotROC = TRUE)
## [,1]
## <=50K vs. >50K 0.9033706
library(rpart.plot)
## Warning: package 'rpart.plot' was built under R version 3.3.2
## Loading required package: rpart
## Warning: package 'rpart' was built under R version 3.3.2
suppressWarnings(library(rpart))
suppressWarnings(library(ROCR))
## Loading required package: gplots
##
## Attaching package: 'gplots'
## The following object is masked from 'package:stats':
##
## lowess
# Make column names syntactic (e.g. `hours-per-week` -> hours.per.week).
names(train) <- make.names(names(train))
names(test2) <- make.names(names(test2))
tree_model <- rpart(income ~ ., data = train, method = "class")
all_probs <- predict(tree_model, test2, type = "prob")
# Pick the '<=50K' probability column by name, not by position.
test2$income_class <- ifelse(all_probs[, "<=50K"] > 0.5, "<=50K", ">50K")
# confusionMatrix() needs factors with the same levels as the reference.
dt_con <- confusionMatrix(factor(test2$income_class, levels = levels(test2$income)), test2$income)
dt_con
## Confusion Matrix and Statistics
##
## Reference
## Prediction <=50K >50K
## <=50K 10303 1728
## >50K 553 1758
##
## Accuracy : 0.841
## 95% CI : (0.8349, 0.8469)
## No Information Rate : 0.7569
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.5119
## Mcnemar's Test P-Value : < 2.2e-16
##
## Sensitivity : 0.9491
## Specificity : 0.5043
## Pos Pred Value : 0.8564
## Neg Pred Value : 0.7607
## Prevalence : 0.7569
## Detection Rate : 0.7184
## Detection Prevalence : 0.8389
## Balanced Accuracy : 0.7267
##
## 'Positive' Class : <=50K
##
suppressWarnings(library(randomForest))
## randomForest 4.6-12
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:dplyr':
##
## combine
## The following object is masked from 'package:ggplot2':
##
## margin
# Syntactic column names (e.g. `educational-num` -> educational.num).
names(train) <- make.names(names(train))
names(test3) <- make.names(names(test3))
# Drop the numeric encoding of `education` (compare the two columns in the
# head() output). After make.names() the column is `educational.num`, so the
# previous train$`educational-num` <- NULL matched nothing and removed nothing.
# NOTE: the captured ## output below predates this fix.
train$educational.num <- NULL
set.seed(32423)
rfFit <- randomForest(income ~ ., data = train)
print(rfFit)
##
## Call:
## randomForest(formula = income ~ ., data = train)
## Type of random forest: classification
## Number of trees: 500
## No. of variables tried at each split: 3
##
## OOB estimate of error rate: 13.85%
## Confusion matrix:
## <=50K >50K class.error
## <=50K 23732 1599 0.06312424
## >50K 3036 5097 0.37329399
rf_pred <- predict(rfFit, test3, type = "class")
rf_con <- confusionMatrix(rf_pred, test3$income)
rf_con
## Confusion Matrix and Statistics
##
## Reference
## Prediction <=50K >50K
## <=50K 10220 1288
## >50K 636 2198
##
## Accuracy : 0.8658
## 95% CI : (0.8602, 0.8714)
## No Information Rate : 0.7569
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.6107
## Mcnemar's Test P-Value : < 2.2e-16
##
## Sensitivity : 0.9414
## Specificity : 0.6305
## Pos Pred Value : 0.8881
## Neg Pred Value : 0.7756
## Prevalence : 0.7569
## Detection Rate : 0.7126
## Detection Prevalence : 0.8024
## Balanced Accuracy : 0.7860
##
## 'Positive' Class : <=50K
##
# Test-set accuracy of each model, taken from each confusionMatrix() result.
glmAcu <- glm_con$overall['Accuracy']
dtAcu <- dt_con$overall['Accuracy']
rfAcu <- rf_con$overall['Accuracy']
ACU <- data.frame(
  Model = c("Decision Tree", "Logistic Regression", "Random Forest"),
  Accuracy = c(dtAcu, glmAcu, rfAcu)
)
ggplot(ACU, aes(x = Model, y = Accuracy, fill = Model)) +
  geom_bar(stat = 'identity') +
  theme_bw() +
  ggtitle('Accuracies of Models')
# In the captured run, Random Forest had the highest test accuracy
# (0.8658, vs 0.8516 for logistic regression and 0.841 for the decision tree).