收入群体分类
乐
初始设置
初步数据检查
变量说明
预处理
# Give the 15 raw columns readable names, in file order; "class" (the income
# label) is the 15th column.
names(data) <- c(
  "age",
  "workclass",
  "fnlwgt",
  "education",
  "education_num",
  "marital_status",
  "occupation",
  "relationship",
  "race",
  "sex",
  "capital_gain",
  "capital_loss",
  "hours_per_week",
  "native_country",
  "class"
)
# Collapse raw `education` labels into a coarser set of schooling levels.
#
# Mapping:
#   "HS-grad"                     -> "High_School_grad"
#   "Bachelors", "Some-college"   -> "Bachelors"
#   grades 1st-12th and preschool -> "compulsory"
#   "Assoc-acdm", "Assoc-voc"     -> "Associate"
#   anything else                 -> returned unchanged
#
# Args:
#   x: education label(s); a character vector of any length (a factor is
#      accepted and treated by its labels).
# Returns:
#   A character vector of the same length as `x`. `NA` inputs stay `NA`
#   instead of raising an error.
edu_level <- function(x) {
  # Both spellings are accepted: the original code matched "preschool", but
  # the UCI Adult data (which this column schema appears to follow) uses
  # "Preschool" -- TODO confirm against the loaded file.
  compulsory <- c(
    "11th", "9th", "7th-8th", "5th-6th", "10th", "1st-4th", "12th",
    "preschool", "Preschool"
  )

  raw <- as.character(x)
  level <- raw

  # Always match against the untouched `raw` labels so that one recode can
  # never feed into another. `%in%` is NA-safe, unlike `==` inside `if`.
  level[raw %in% "HS-grad"] <- "High_School_grad"
  level[raw %in% c("Bachelors", "Some-college")] <- "Bachelors"
  level[raw %in% compulsory] <- "compulsory"
  level[raw %in% c("Assoc-acdm", "Assoc-voc")] <- "Associate"

  level
}
# Map every row's education label to its coarse level, one label at a time,
# then append the result as a new column (it becomes column 16).
education_level <- vapply(
  as.character(data[["education"]]),
  edu_level,
  character(1),
  USE.NAMES = FALSE
)
data <- cbind(data, education_level)

# The income label is the modelling target, so store it as a factor.
data$class <- factor(data$class)
连续变量分析
Age
# Density of age, one filled curve per income class.
data %>%
  ggplot(aes(x = age, fill = class)) +
  geom_density(alpha = 0.8) +
  labs(
    title = "高收入群体年龄中位数较高",
    subtitle = "低收入群体年龄分布呈右偏",
    y = NULL
  )
fnlwgt
# Density of fnlwgt, one filled curve per income class.
data %>%
  ggplot(aes(x = fnlwgt, fill = class)) +
  geom_density(alpha = 0.8) +
  labs(
    title = "两种收入水平群体的fnlwgt无明显差别",
    y = NULL
  )
Education
# Density of years of education, one filled curve per income class.
data %>%
  ggplot(aes(x = education_num, fill = class)) +
  geom_density(alpha = 0.8) +
  labs(
    title = "低收入群体受教育年限多为9-10年",
    x = "education/year",
    y = NULL
  )

# Same variable as a boxplot, to compare the two classes side by side.
ggplot(data, aes(x = class, y = education_num, fill = class)) +
  geom_boxplot() +
  labs(
    title = "高收入群体的受教育年限长于低收入群体",
    y = "education/year"
  )
Capital
# Capital gain density per class; rows with capital_gain >= 20000 are dropped
# before plotting. Inside filter()/aes() the name refers to the data column,
# even though the plot object is stored under the same name.
capital_gain <- ggplot(
  filter(data, capital_gain < 20000),
  aes(x = capital_gain, fill = class)
) +
  geom_density(alpha = 0.8) +
  labs(title = "两种收入群体capital_gain相差悬殊", y = NULL)

# Capital loss density per class, on the unfiltered data.
capital_loss <- ggplot(data, aes(x = capital_loss, fill = class)) +
  geom_density(alpha = 0.8) +
  labs(title = "高收入群体capital_loss在1900处出现峰值", y = NULL)

# Draw both plots on one page.
Rmisc::multiplot(capital_gain, capital_loss)
非连续变量分析
Workclass
Var1 Freq
1 Private 22696
2 Self-emp-not-inc 2541
3 Local-gov 2093
4 ? 1836
5 State-gov 1298
6 Self-emp-inc 1116
7 Federal-gov 960
8 Without-pay 14
9 Never-worked 7
# Relabel the "?" placeholder as "unknown" and drop the two workclass levels
# that have only a handful of rows (see the frequency table above).
data <- data %>%
  mutate(workclass = ifelse(workclass == "?", "unknown", workclass)) %>%
  filter(!workclass %in% c("Without-pay", "Never-worked"))

# Row counts per (class, workclass), drawn as dodged horizontal bars.
data %>%
  group_by(class, workclass) %>%
  summarise(n = n()) %>%
  ggplot(aes(workclass, n, fill = class)) +
  geom_col(position = "dodge") +
  coord_flip() +
  guides(fill = guide_legend(title = NULL)) +
  theme(
    axis.ticks.length = unit(0.5, "cm"),
    axis.text.x = element_text(angle = 330)
  ) +
  labs(
    x = NULL,
    y = NULL,
    title = "Private职业低收入群体远大于高收入群体"
  )
Marital Status & Relationship
Var1 Freq
1 Married-civ-spouse 14967
2 Never-married 10674
3 Divorced 4442
4 Separated 1025
5 Widowed 992
6 Married-spouse-absent 417
7 Married-AF-spouse 23
# Row counts per (class, marital_status) as dodged horizontal bars, leaving
# out the "Married-AF-spouse" level and everyone aged 20 or younger.
data %>%
  filter(marital_status != "Married-AF-spouse", age > 20) %>%
  group_by(class, marital_status) %>%
  summarise(n = n()) %>%
  ggplot(aes(marital_status, n, fill = class)) +
  geom_col(position = "dodge") +
  coord_flip() +
  guides(fill = guide_legend(title = NULL)) +
  theme(
    axis.ticks.length = unit(0.5, "cm"),
    axis.text.x = element_text(angle = 330)
  ) +
  labs(x = NULL, y = NULL, title = "大部分高收入人群已婚")
Var1 Freq
1 Husband 13189
2 Not-in-family 8304
3 Own-child 5058
4 Unmarried 3444
5 Wife 1564
6 Other-relative 981
# Row counts per (class, relationship) as dodged horizontal bars.
# NOTE(review): the title reads "Own-child" as "families that have children";
# the level may instead describe the respondent's own role -- confirm.
data %>%
  group_by(class, relationship) %>%
  summarise(n = n()) %>%
  ggplot(aes(relationship, n, fill = class)) +
  geom_col(position = "dodge") +
  coord_flip() +
  guides(fill = guide_legend(title = NULL)) +
  theme(
    axis.ticks.length = unit(0.5, "cm"),
    axis.text.x = element_text(angle = 330)
  ) +
  labs(
    x = NULL,
    y = NULL,
    title = "拥有孩子的家庭绝大部分属于低收入人群!"
  )
Occupation
Var1 Freq
1 Prof-specialty 4140
2 Craft-repair 4098
3 Exec-managerial 4066
4 Adm-clerical 3767
5 Sales 3650
6 Other-service 3294
7 Machine-op-inspct 2001
8 ? 1836
9 Transport-moving 1596
10 Handlers-cleaners 1369
11 Farming-fishing 988
12 Tech-support 928
13 Protective-serv 649
14 Priv-house-serv 149
15 Armed-Forces 9
# Relabel the "?" placeholder in occupation as "unknown".
data <- data %>%
  mutate(occupation = ifelse(occupation == "?", "unknown", occupation))

# Occupation frequencies, most common first.
occ <- data$occupation %>%
  table() %>%
  as.data.frame() %>%
  arrange(desc(Freq))

# Row counts per (class, occupation) for the seven most frequent occupations,
# drawn as dodged horizontal bars.
data %>%
  filter(occupation %in% head(occ, 7)$Var1) %>%
  group_by(class, occupation) %>%
  summarise(n = n()) %>%
  ggplot(aes(occupation, n, fill = class)) +
  geom_col(position = "dodge") +
  coord_flip() +
  guides(fill = guide_legend(title = NULL)) +
  theme(
    axis.ticks.length = unit(0.5, "cm"),
    axis.text.x = element_text(angle = 330)
  ) +
  labs(
    x = NULL,
    y = NULL,
    title = "Prof-specialty & Exec-managerial 盛产高收入人员"
  )
Race & Native Country
Var1 Freq
1 White 27799
2 Black 3121
3 Asian-Pac-Islander 1038
4 Amer-Indian-Eskimo 311
5 Other 271
# Row counts per (class, race) as dodged horizontal bars, leaving out the two
# smallest race groups.
data %>%
  filter(!race %in% c("Amer-Indian-Eskimo", "Other")) %>%
  group_by(class, race) %>%
  summarise(n = n()) %>%
  ggplot(aes(race, n, fill = class)) +
  geom_col(position = "dodge") +
  coord_flip() +
  guides(fill = guide_legend(title = NULL)) +
  theme(
    axis.ticks.length = unit(0.5, "cm"),
    axis.text.x = element_text(angle = 330)
  ) +
  labs(
    x = NULL,
    y = NULL,
    title = "收入情况与种族并无明显差别",
    subtitle = "数据内种族比例失调"
  )
Var1 Freq
1 United-States 29150
2 Mexico 643
3 ? 583
4 Philippines 197
5 Germany 137
6 Canada 121
7 Puerto-Rico 114
8 El-Salvador 106
9 India 100
10 Cuba 95
11 England 90
12 Jamaica 81
13 South 80
14 China 75
15 Italy 73
16 Dominican-Republic 70
17 Vietnam 67
18 Guatemala 64
19 Japan 62
20 Poland 60
21 Columbia 59
22 Taiwan 51
23 Haiti 44
24 Iran 43
25 Portugal 37
26 Nicaragua 34
27 Peru 31
28 France 29
29 Greece 29
30 Ecuador 28
31 Ireland 24
32 Hong 20
33 Cambodia 19
34 Trinadad&Tobago 19
35 Laos 18
36 Thailand 18
37 Yugoslavia 16
[ reached 'max' / getOption("max.print") -- omitted 5 rows ]
# Income class counts for rows whose native_country is "United-States".
data %>%
  filter(native_country == "United-States") %>%
  group_by(class) %>%
  summarise(n = n()) %>%
  ggplot(aes(class, n, fill = class)) +
  geom_col(position = "dodge") +
  guides(fill = guide_legend(title = NULL)) +
  theme(axis.ticks.length = unit(0.5, "cm")) +
  labs(
    x = NULL,
    y = NULL,
    title = "美国国籍收入情况",
    subtitle = "低收入群体占据大多数"
  )

# Row counts per (class, native_country) for every other country.
# NOTE(review): arrange() sorts the summary table only; the bars presumably
# still follow ggplot's own ordering of the discrete axis.
data %>%
  filter(native_country != "United-States") %>%
  group_by(class, native_country) %>%
  summarise(n = n()) %>%
  arrange(desc(n)) %>%
  ggplot(aes(native_country, n, fill = class)) +
  geom_col(position = "dodge") +
  coord_flip() +
  guides(fill = guide_legend(title = NULL)) +
  theme(
    axis.ticks.length = unit(0.5, "cm"),
    axis.text.x = element_text(angle = 330)
  ) +
  labs(x = NULL, y = NULL)
Sex
# Pie chart of row counts by sex, labelled with the raw counts.
data %>%
  group_by(sex) %>%
  summarise(n = n()) %>%
  ggplot(aes(x = "", y = n, fill = sex)) +
  geom_bar(stat = "identity", width = 2) +
  coord_polar("y") +
  geom_text(
    aes(label = n),
    position = position_stack(vjust = 0.5),
    check_overlap = TRUE,
    size = 5
  ) +
  labs(x = NULL, y = NULL, fill = NULL, title = "男女占比") +
  theme(
    axis.line = element_blank(),
    axis.text = element_blank(),
    axis.ticks = element_blank(),
    plot.title = element_text(size = 14)
  )

# Income-class share among rows with sex == "Male".
# The percentage denominator is sum(n), i.e. the number of rows actually
# plotted, instead of a hard-coded group size that silently goes stale
# whenever an upstream filter changes.
male <- data %>%
  filter(sex == "Male") %>%
  group_by(class) %>%
  summarise(n = n()) %>%
  ggplot(aes(x = "", y = n, fill = class)) +
  geom_bar(stat = "identity", width = 2) +
  coord_polar("y") +
  geom_text(
    aes(label = paste(100 * round(n / sum(n), 2), "%")),
    position = position_stack(vjust = 0.5),
    check_overlap = TRUE,
    size = 5
  ) +
  labs(x = NULL, y = NULL, fill = NULL, title = "男性收入分布结构") +
  theme(
    axis.line = element_blank(),
    axis.text = element_blank(),
    axis.ticks = element_blank(),
    plot.title = element_text(size = 14)
  )

# Same chart for every row whose sex is not "Male".
female <- data %>%
  filter(sex != "Male") %>%
  group_by(class) %>%
  summarise(n = n()) %>%
  ggplot(aes(x = "", y = n, fill = class)) +
  geom_bar(stat = "identity", width = 2) +
  coord_polar("y") +
  geom_text(
    aes(label = paste(100 * round(n / sum(n), 2), "%")),
    position = position_stack(vjust = 0.5),
    check_overlap = TRUE,
    size = 5
  ) +
  labs(x = NULL, y = NULL, fill = NULL, title = "女性收入分布结构") +
  theme(
    axis.line = element_blank(),
    axis.text = element_blank(),
    axis.ticks = element_blank(),
    plot.title = element_text(size = 14)
  )

# Show the two pies side by side.
Rmisc::multiplot(male, female, cols = 2)
模型处理
训练集与测试集(未做平衡处理)
# 80/20 split of the (unbalanced) data, partitioned on the income class.
# The seed is fixed so the split is reproducible.
set.seed(1234)
index <- createDataPartition(data$class, p = 0.8, list = FALSE)

# char to factor
data$workclass <- factor(data$workclass)
data$education <- factor(data$education)
data$marital_status <- factor(data$marital_status)
data$occupation <- factor(data$occupation)
data$relationship <- factor(data$relationship)
data$race <- factor(data$race)
data$sex <- factor(data$sex)
data$native_country <- factor(data$native_country)
# education_level is the derived column appended after "class"; it stays in
# the predictor set when column 15 is dropped, so it needs the same
# conversion as the other categorical columns. factor() is harmless if the
# column is already a factor.
data$education_level <- factor(data$education_level)

traindata <- data[index, ]
testdata <- data[-index, ]
随机森林
Model
Call:
randomForest(x = traindata[, -15], y = traindata$class, ntree = 300)
Type of random forest: classification
Number of trees: 300
No. of variables tried at each split: 3
OOB estimate of error rate: 13.69%
Confusion matrix:
<=50K >50K class.error
<=50K 18438 1322 0.06690283
>50K 2241 4032 0.35724534
Confusion Matrix & ROC Curve
Confusion Matrix and Statistics
Reference
Prediction <=50K >50K
<=50K 4650 577
>50K 289 991
Accuracy : 0.8669
95% CI : (0.8584, 0.8751)
No Information Rate : 0.759
P-Value [Acc > NIR] : < 2.2e-16
Kappa : 0.6119
Mcnemar's Test P-Value : < 2.2e-16
Sensitivity : 0.9415
Specificity : 0.6320
Pos Pred Value : 0.8896
Neg Pred Value : 0.7742
Prevalence : 0.7590
Detection Rate : 0.7146
Detection Prevalence : 0.8033
Balanced Accuracy : 0.7868
'Positive' Class : <=50K
# Encode the predicted labels numerically ("<=50K" -> 1, otherwise 0) so they
# can be used as the ROC predictor.
# NOTE(review): the ROC is built from hard class labels, not class
# probabilities, so the curve has very few distinct thresholds -- confirm
# this is intended.
pred_rf_roc <- ifelse(pred_rf == "<=50K", 1, 0)
roc_rf <- roc(testdata$class, pred_rf_roc)

# Keep the specificity (true negative rate) and sensitivity vectors; they
# feed the x and y axes of the plot below.
Specificity_rf <- roc_rf[["specificities"]]
Sensitivity_rf <- roc_rf[["sensitivities"]]

# ROC curve with the diagonal reference line and the AUC printed on the plot.
p_rf <- ggplot(
  data = NULL,
  mapping = aes(x = 1 - Specificity_rf, y = Sensitivity_rf)
) +
  geom_line(size = 1) +
  geom_abline(size = 1) +
  annotate(
    "text",
    x = 0.4,
    y = 0.5,
    label = paste("AUC = ", round(roc_rf[["auc"]], 3))
  ) +
  labs(x = "1 - Specificity", y = "Sensitivities")
p_rf
生成平衡数据
# =======================================================
# Balance the classes by down-sampling: keep every row that is not "<=50K"
# and draw an equally sized random sample (without replacement) from the
# "<=50K" rows.
low <- filter(data, class == "<=50K")
high <- filter(data, class != "<=50K")

set.seed(1234)
index_low <- sample(x = 1:nrow(low), size = nrow(high))
low_new <- low[index_low, ]
data_balance <- rbind(low_new, high)

# 80/20 split of the balanced data, partitioned on the class label.
set.seed(1234)
index2 <- createDataPartition(data_balance$class, p = 0.8, list = FALSE)
traindata2 <- data_balance[index2, ]
testdata2 <- data_balance[-index2, ]
模型改进
# =======================================================
# Refit the random forest on the balanced training split with 500 trees.
# Column 15 is "class" (the target), so it is dropped from the predictors
# and passed separately as the response. The seed is fixed for
# reproducibility.
set.seed(1234)
model_rf2 <- randomForest(traindata2[, -15], traindata2$class, ntree = 500)
# Print the fitted model summary (OOB error and confusion matrix).
model_rf2
Call:
randomForest(x = traindata2[, -15], y = traindata2$class, ntree = 500)
Type of random forest: classification
Number of trees: 500
No. of variables tried at each split: 3
OOB estimate of error rate: 16.85%
Confusion matrix:
<=50K >50K class.error
<=50K 4927 1346 0.2145704
>50K 768 5505 0.1224295
Confusion Matrix and Statistics
Reference
Prediction <=50K >50K
<=50K 1207 187
>50K 361 1381
Accuracy : 0.8253
95% CI : (0.8115, 0.8384)
No Information Rate : 0.5
P-Value [Acc > NIR] : < 2.2e-16
Kappa : 0.6505
Mcnemar's Test P-Value : 1.466e-13
Sensitivity : 0.7698
Specificity : 0.8807
Pos Pred Value : 0.8659
Neg Pred Value : 0.7928
Prevalence : 0.5000
Detection Rate : 0.3849
Detection Prevalence : 0.4445
Balanced Accuracy : 0.8253
'Positive' Class : <=50K
# Encode the predicted labels numerically ("<=50K" -> 1, otherwise 0) so they
# can be used as the ROC predictor.
# NOTE(review): the ROC is built from hard class labels, not class
# probabilities, so the curve has very few distinct thresholds -- confirm
# this is intended.
pred_rf2_roc <- ifelse(pred_rf2 == "<=50K", 1, 0)
roc_rf2 <- roc(testdata2$class, pred_rf2_roc)

# Keep the specificity (true negative rate) and sensitivity vectors; they
# feed the x and y axes of the plot below.
Specificity_rf2 <- roc_rf2[["specificities"]]
Sensitivity_rf2 <- roc_rf2[["sensitivities"]]

# ROC curve with the diagonal reference line and the AUC printed on the plot.
p_rf2 <- ggplot(
  data = NULL,
  mapping = aes(x = 1 - Specificity_rf2, y = Sensitivity_rf2)
) +
  geom_line(size = 1) +
  geom_abline(size = 1) +
  annotate(
    "text",
    x = 0.4,
    y = 0.5,
    label = paste("AUC = ", round(roc_rf2[["auc"]], 3))
  ) +
  labs(x = "1 - Specificity", y = "Sensitivities")
p_rf2
Tune RF
# Search for a better mtry on the balanced training split: predictors are all
# columns except 15, the response is column 15 ("class").
tunerf <- tuneRF(
  traindata2[, -15],
  traindata2[, 15],
  ntreeTry = 300,
  stepFactor = 0.5,
  improve = 0.05,
  trace = TRUE,
  plot = TRUE
)
mtry = 3 OOB error = 16.83%
Searching left ...
mtry = 6 OOB error = 17.87%
-0.0620559 0.05
Searching right ...
mtry = 1 OOB error = 16.98%
-0.009000474 0.05
XGBoost
# ===============================XGBoost========================================
# One-hot encode the predictors of the balanced data (column 15, "class", is
# removed first), then re-attach the class as the first column.
dataset <- one_hot(as.data.table(data_balance[-15]))
dataset <- dataset %>%
  mutate(class = data_balance$class) %>%
  select(class, everything())

# Split the label off and convert it to 0-based integer codes.
label <- dataset$class
dataset$class <- NULL
label <- as.integer(label) - 1

# Random 80/20 split by row index.
n <- nrow(dataset)
set.seed(1234)
train.index <- sample(n, floor(0.8 * n))
train.data <- as.matrix(dataset[train.index, ])
train.label <- label[train.index]
test.data <- as.matrix(dataset[-train.index, ])
test.label <- label[-train.index]

# Create the xgb.DMatrix objects
xgb.train <- xgb.DMatrix(data = train.data, label = train.label)
xgb.test <- xgb.DMatrix(data = test.data, label = test.label)

# Train
params <- list(
  booster = "gbtree",
  eta = 0.001,
  max_depth = 5,
  gamma = 3,
  subsample = 0.75,
  colsample_bytree = 1,
  objective = "multi:softprob",
  eval_metric = "mlogloss",
  num_class = 2
)
# `nthread` is the parameter name xgboost recognises; the previous spelling
# `nthreads` was not a known parameter, so the thread limit was not applied.
# NOTE(review): the watchlist includes the test set, so early stopping is
# presumably being driven by test data -- consider a separate validation
# split.
xgb.fit <- xgb.train(
  params = params,
  data = xgb.train,
  nrounds = 10000,
  nthread = 1,
  early_stopping_rounds = 10,
  watchlist = list(val1 = xgb.train, val2 = xgb.test),
  verbose = 0
)

# Predict: one row per test observation, one probability column per class.
xgb.pred <- predict(xgb.fit, test.data, reshape = TRUE)
xgb.pred <- as.data.frame(xgb.pred)
colnames(xgb.pred) <- c("<=50K", ">50K")

# identify: pick the class with the highest probability in each row, then
# encode it as 0 ("<=50K") or 1 to match `test.label`.
xgb.pred$prediction <- apply(
  xgb.pred,
  1,
  function(x) colnames(xgb.pred)[which.max(x)]
)
xgb.pred$prediction <- ifelse(xgb.pred$prediction == "<=50K", 0, 1)

# Test-set accuracy.
sum(xgb.pred$prediction == test.label) / nrow(xgb.pred)
[1] 0.8383806
Confusion Matrix and Statistics
Reference
Prediction 0 1
0 1294 192
1 315 1336
Accuracy : 0.8384
95% CI : (0.825, 0.8511)
No Information Rate : 0.5129
P-Value [Acc > NIR] : < 2.2e-16
Kappa : 0.6772
Mcnemar's Test P-Value : 6.02e-08
Sensitivity : 0.8042
Specificity : 0.8743
Pos Pred Value : 0.8708
Neg Pred Value : 0.8092
Prevalence : 0.5129
Detection Rate : 0.4125
Detection Prevalence : 0.4737
Balanced Accuracy : 0.8393
'Positive' Class : 0
# ROC of the XGBoost predictions against the 0/1 test labels.
# NOTE(review): the predictor is the hard 0/1 prediction, not the predicted
# probability, so the curve has very few distinct thresholds -- confirm this
# is intended.
roc_xgb <- roc(test.label, xgb.pred$prediction)

# plot
# Keep the specificity (true negative rate) and sensitivity vectors; they
# feed the x and y axes of the plot below.
Specificity_xgb <- roc_xgb[["specificities"]]
Sensitivity_xgb <- roc_xgb[["sensitivities"]]

# ROC curve with the diagonal reference line and the AUC printed on the plot.
p_xgb <- ggplot(
  data = NULL,
  mapping = aes(x = 1 - Specificity_xgb, y = Sensitivity_xgb)
) +
  geom_line(size = 1) +
  geom_abline(size = 1) +
  annotate(
    "text",
    x = 0.4,
    y = 0.5,
    label = paste("AUC = ", round(roc_xgb[["auc"]], 3))
  ) +
  labs(x = "1 - Specificity", y = "Sensitivities")
p_xgb