The goal of the project is to build a predictive model that, given a person's demographic and occupational attributes, predicts whether their annual income exceeds 50,000 USD. The data come from the Adult Census Income dataset (UCI Machine Learning Repository); the file used here contains 30,162 observations.
The target variable is class (values <=50K / >50K), so this is a binary classification problem.
The project compares five models: logistic regression as a baseline, a classification tree (CART), a tuned Random Forest, gradient boosting (fitted through caret's "gbm" method), and a neural network (nnet).
library(tidyverse)
library(caret)
library(pROC)
library(rpart)
library(rpart.plot)
library(randomForest)
library(ranger)
library(gbm)  # backend for caret's method = "gbm"; xgboost is not used below
library(nnet)
adult.raw <- read.csv("k1_adult.csv", stringsAsFactors = FALSE, sep = ",")
glimpse(adult.raw)
## Rows: 30,162
## Columns: 18
## $ class <chr> "<=50K", "<=50K", "<=50K", "<=50K", "<=50K", "<=50K", …
## $ age <int> 39, 50, 38, 53, 28, 37, 49, 52, 31, 42, 37, 30, 23, 32…
## $ education_num <int> 13, 13, 9, 7, 13, 14, 5, 9, 14, 13, 10, 13, 13, 12, 4,…
## $ marital_status <chr> "Never-married", "Married-civ-spouse", "Divorced", "Ma…
## $ occupation <chr> "Adm-clerical", "Exec-managerial", "Handlers-cleaners"…
## $ relationship <chr> "Not-in-family", "Husband", "Not-in-family", "Husband"…
## $ race <chr> "White", "White", "White", "Black", "Black", "White", …
## $ sex <chr> "Male", "Male", "Male", "Male", "Female", "Female", "F…
## $ hours_per_week <int> 40, 13, 40, 40, 40, 40, 16, 45, 50, 40, 80, 40, 30, 50…
## $ capital_gain <int> 2174, 0, 0, 0, 0, 0, 0, 0, 14084, 5178, 0, 0, 0, 0, 0,…
## $ capital_loss <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ workclass <chr> "State-gov", "Self-emp-not-inc", "Private", "Private",…
## $ native_country <chr> "United-States", "United-States", "United-States", "Un…
## $ net_capital <int> 2174, 0, 0, 0, 0, 0, 0, 0, 14084, 5178, 0, 0, 0, 0, 0,…
## $ capital_flag <int> 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ hours_x_edu <int> 520, 169, 360, 280, 520, 560, 80, 405, 700, 520, 800, …
## $ mid_age <int> 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, …
## $ high_work_hours <int> 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, …
summary(adult.raw)
## class age education_num marital_status
## Length :30162 Min. :17.00 Min. : 1.00 Length :30162
## N.unique : 2 1st Qu.:28.00 1st Qu.: 9.00 N.unique : 7
## N.blank : 0 Median :37.00 Median :10.00 N.blank : 0
## Min.nchar: 4 Mean :38.44 Mean :10.12 Min.nchar: 7
## Max.nchar: 5 3rd Qu.:47.00 3rd Qu.:13.00 Max.nchar: 21
## Max. :90.00 Max. :16.00
## occupation relationship race sex
## Length :30162 Length :30162 Length :30162 Length :30162
## N.unique : 14 N.unique : 6 N.unique : 5 N.unique : 2
## N.blank : 0 N.blank : 0 N.blank : 0 N.blank : 0
## Min.nchar: 5 Min.nchar: 4 Min.nchar: 5 Min.nchar: 4
## Max.nchar: 17 Max.nchar: 14 Max.nchar: 18 Max.nchar: 6
##
## hours_per_week capital_gain capital_loss workclass
## Min. : 1.00 Min. : 0 Min. : 0.00 Length :30162
## 1st Qu.:40.00 1st Qu.: 0 1st Qu.: 0.00 N.unique : 7
## Median :40.00 Median : 0 Median : 0.00 N.blank : 0
## Mean :40.93 Mean : 1092 Mean : 88.37 Min.nchar: 7
## 3rd Qu.:45.00 3rd Qu.: 0 3rd Qu.: 0.00 Max.nchar: 16
## Max. :99.00 Max. :99999 Max. :4356.00
## native_country net_capital capital_flag hours_x_edu
## Length :30162 Min. :-4356 Min. :0.00000 Min. : 6.0
## N.unique : 41 1st Qu.: 0 1st Qu.:0.00000 1st Qu.: 342.0
## N.blank : 0 Median : 0 Median :0.00000 Median : 400.0
## Min.nchar: 4 Mean : 1004 Mean :0.08415 Mean : 418.9
## Max.nchar: 26 3rd Qu.: 0 3rd Qu.:0.00000 3rd Qu.: 520.0
## Max. :99999 Max. :1.00000 Max. :1584.0
## mid_age high_work_hours
## Min. :0.0000 Min. :0.0000
## 1st Qu.:0.0000 1st Qu.:0.0000
## Median :1.0000 Median :0.0000
## Mean :0.5948 Mean :0.3049
## 3rd Qu.:1.0000 3rd Qu.:1.0000
## Max. :1.0000 Max. :1.0000
cat("Number of observations:", nrow(adult.raw), "\n")
## Number of observations: 30162
cat("Duplicates:", sum(duplicated(adult.raw)), "\n")
## Duplicates: 3258
cat("Missing values:\n")
## Missing values:
colSums(is.na(adult.raw))
## class age education_num marital_status occupation
## 0 0 0 0 0
## relationship race sex hours_per_week capital_gain
## 0 0 0 0 0
## capital_loss workclass native_country net_capital capital_flag
## 0 0 0 0 0
## hours_x_edu mid_age high_work_hours
## 0 0 0
The data contain no missing values, but 3,258 rows are flagged as duplicates of earlier rows. They are kept in the analysis that follows.
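The duplicate count printed above comes from duplicated(), which flags every row that repeats an earlier one. A minimal sketch of how such rows could be inspected and, if desired, dropped is shown below. The data frame toy and its values are made up for illustration. Whether repeated rows in census data are true repeats or different people with identical attributes is not settled by the source, so dropping them is an assumption rather than a step of the original analysis.

```r
# Toy data frame (hypothetical values) in which row 3 repeats row 1
toy <- data.frame(
  age   = c(39, 50, 39, 28),
  sex   = c("Male", "Male", "Male", "Female"),
  class = c("<=50K", "<=50K", "<=50K", ">50K"),
  stringsAsFactors = FALSE
)

# duplicated() marks each row that is identical to an earlier row
dup_flag <- duplicated(toy)
n_dup <- sum(dup_flag)

# Keep only the first occurrence of every distinct row
toy_unique <- toy[!dup_flag, ]

cat("Duplicated rows:", n_dup, "\n")
cat("Rows after de-duplication:", nrow(toy_unique), "\n")
```

If duplicates were removed, it would have to happen before the train/test split so that identical rows cannot land on both sides of it.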
tabela_y <- as.data.frame(table(adult.raw$class))
tabela_y$Procent <- round(tabela_y$Freq / sum(tabela_y$Freq) * 100, 1)
colnames(tabela_y) <- c("Klasa", "N", "Procent")
print(tabela_y)
## Klasa N Procent
## 1 <=50K 22654 75.1
## 2 >50K 7508 24.9
ggplot(adult.raw, aes(x = class, fill = class)) +
  geom_bar(width = 0.5, color = "white") +
  geom_text(stat = "count",
            # after_stat() replaces the deprecated ..count.. notation
            aes(label = paste0(after_stat(count), "\n(",
                               round(after_stat(count) / nrow(adult.raw) * 100, 1), "%)")),
            vjust = -0.3, size = 4) +
  scale_fill_manual(values = c("<=50K" = "steelblue", ">50K" = "coral")) +
  labs(title = "Distribution of the target variable",
       x = "Income class", y = "Number of observations") +
  theme_minimal() +
  theme(legend.position = "none")
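The data are imbalanced: the <=50K class accounts for about 75% of observations. We therefore use ROC-AUC rather than accuracy as the main metric, and pass class weights when training the logistic regression, CART and Random Forest models (the gradient boosting model and the neural network are fitted without weights).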
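To make the argument about accuracy concrete, here is a small base-R sketch. It uses made-up labels with a 75/25 split and a "model" that gives everybody the same score. The helper auc_rank is a hypothetical name introduced here; it computes AUC through the rank (Mann-Whitney) formulation rather than through pROC.

```r
# AUC via the rank (Mann-Whitney) formulation, base R only
auc_rank <- function(score, is_positive) {
  r <- rank(score)                    # average ranks handle ties
  n_pos <- sum(is_positive)
  n_neg <- sum(!is_positive)
  (sum(r[is_positive]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Hypothetical labels with a 75/25 class split; TRUE stands for ">50K"
y <- rep(c(FALSE, TRUE), times = c(750, 250))

# A constant score: every observation is predicted as "<=50K" at cutoff 0.5
const_score <- rep(0.1, length(y))

acc_const <- mean((const_score > 0.5) == y)
auc_const <- auc_rank(const_score, y)

cat("Constant model: accuracy =", acc_const, " AUC =", auc_const, "\n")
```

The constant model is right for every majority-class case, so its accuracy equals the majority share, while its AUC stays at the chance level because it cannot rank any positive above any negative.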
adult.raw %>%
select(age, education_num, hours_per_week, capital_gain, capital_loss,
net_capital, class) %>%
pivot_longer(-class, names_to = "zmienna", values_to = "wartosc") %>%
ggplot(aes(x = wartosc, fill = class)) +
geom_histogram(bins = 30, alpha = 0.7, position = "identity", color = "white") +
facet_wrap(~ zmienna, scales = "free", ncol = 3) +
scale_fill_manual(values = c("<=50K" = "steelblue", ">50K" = "coral")) +
labs(title = "Distributions of numeric variables by income class",
x = NULL, y = "Number of observations", fill = "Income") +
theme_minimal() +
theme(legend.position = "bottom")
adult.raw %>%
select(marital_status, occupation, relationship, sex, workclass, class) %>%
pivot_longer(-class, names_to = "zmienna", values_to = "kategoria") %>%
group_by(zmienna, kategoria, class) %>%
summarise(n = n(), .groups = "drop") %>%
group_by(zmienna, kategoria) %>%
mutate(prop_50k = sum(n[class == ">50K"]) / sum(n)) %>%
filter(class == ">50K") %>%
ggplot(aes(x = reorder(kategoria, prop_50k), y = prop_50k)) +
geom_col(fill = "coral") +
geom_hline(yintercept = 0.249, linetype = "dashed", color = "gray50") +
coord_flip() +
facet_wrap(~ zmienna, scales = "free_y", ncol = 2) +
scale_y_continuous(labels = scales::percent_format()) +
labs(title = "Share of people earning >50K in each category",
subtitle = "Dashed line = overall share (~25%)",
x = NULL, y = "Share >50K") +
theme_minimal(base_size = 10)
adult.prep <- adult.raw %>%
mutate(
# Target as a two-level factor: "nizszy" (<=50K) and "wyzszy" (>50K)
income = factor(ifelse(class == ">50K", "wyzszy", "nizszy"),
levels = c("nizszy", "wyzszy"))
) %>%
select(-class) %>%
mutate(across(c(marital_status, occupation, relationship,
race, sex, workclass, native_country), as.factor))
glimpse(adult.prep)
## Rows: 30,162
## Columns: 18
## $ age <int> 39, 50, 38, 53, 28, 37, 49, 52, 31, 42, 37, 30, 23, 32…
## $ education_num <int> 13, 13, 9, 7, 13, 14, 5, 9, 14, 13, 10, 13, 13, 12, 4,…
## $ marital_status <fct> Never-married, Married-civ-spouse, Divorced, Married-c…
## $ occupation <fct> Adm-clerical, Exec-managerial, Handlers-cleaners, Hand…
## $ relationship <fct> Not-in-family, Husband, Not-in-family, Husband, Wife, …
## $ race <fct> White, White, White, Black, Black, White, Black, White…
## $ sex <fct> Male, Male, Male, Male, Female, Female, Female, Male, …
## $ hours_per_week <int> 40, 13, 40, 40, 40, 40, 16, 45, 50, 40, 80, 40, 30, 50…
## $ capital_gain <int> 2174, 0, 0, 0, 0, 0, 0, 0, 14084, 5178, 0, 0, 0, 0, 0,…
## $ capital_loss <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ workclass <fct> State-gov, Self-emp-not-inc, Private, Private, Private…
## $ native_country <fct> United-States, United-States, United-States, United-St…
## $ net_capital <int> 2174, 0, 0, 0, 0, 0, 0, 0, 14084, 5178, 0, 0, 0, 0, 0,…
## $ capital_flag <int> 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ hours_x_edu <int> 520, 169, 360, 280, 520, 560, 80, 405, 700, 520, 800, …
## $ mid_age <int> 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, …
## $ high_work_hours <int> 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, …
## $ income <fct> nizszy, nizszy, nizszy, nizszy, nizszy, nizszy, nizszy…
cat("\nDistribution of the target variable:\n")
##
## Distribution of the target variable:
print(prop.table(table(adult.prep$income)))
##
## nizszy wyzszy
## 0.7510775 0.2489225
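Categorical variables are stored as factors. With the formula interface (income ~ .), caret's train() expands them into dummy (indicator) columns, leaving out one reference level per factor, which is why the coefficient table below contains entries such as sexMale.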
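The expansion of factors into indicator columns can be reproduced with base R's model.matrix(). The sketch below uses a toy data frame with made-up values; it illustrates the kind of encoding a formula interface produces and is not a trace of caret's internals.

```r
# Toy data frame (hypothetical values)
toy <- data.frame(
  age = c(39, 50, 28, 45),
  sex = factor(c("Male", "Male", "Female", "Female")),
  marital_status = factor(c("Never-married", "Divorced",
                            "Married-civ-spouse", "Divorced"))
)

# model.matrix() turns each factor into indicator columns;
# the first level of every factor is the reference and gets no column
X <- model.matrix(~ ., data = toy)

print(colnames(X))
print(levels(toy$marital_status)[1])   # reference level of marital_status
```

Because one level per factor is absorbed into the intercept, a factor with k levels contributes k - 1 columns, which keeps the design matrix from being perfectly collinear.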
# Split before you model!
set.seed(42)
idx.train <- createDataPartition(adult.prep$income, p = 0.75, list = FALSE)
adult.train <- adult.prep[ idx.train, ]
adult.test <- adult.prep[-idx.train, ]
cat("Train:", nrow(adult.train), "| Test:", nrow(adult.test), "\n")
## Train: 22622 | Test: 7540
real.train <- adult.train$income
real.test <- adult.test$income
# Handling class imbalance: weights inversely proportional to class frequency
n_total <- nrow(adult.train)
n_nizszy <- sum(adult.train$income == "nizszy")
n_wyzszy <- sum(adult.train$income == "wyzszy")
waga_nizszy <- n_total / (2 * n_nizszy)
waga_wyzszy <- n_total / (2 * n_wyzszy)
cat(sprintf("Weight of class 'nizszy' (<=50K): %.4f\n", waga_nizszy))
## Weight of class 'nizszy' (<=50K): 0.6657
cat(sprintf("Weight of class 'wyzszy' (>50K): %.4f\n", waga_wyzszy))
## Weight of class 'wyzszy' (>50K): 2.0087
sample_weights <- ifelse(adult.train$income == "wyzszy", waga_wyzszy, waga_nizszy)
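The weighting formula above gives both classes the same total weight while keeping the overall weight equal to the sample size. The sketch below checks this arithmetic. The class counts are not printed in the source; the values used here are back-calculated from the reported training size (22,622) and the printed weights, so treat them as an assumption.

```r
# Class counts consistent with the reported split (assumed, not printed in the source)
n_nizszy <- 16991
n_wyzszy <- 5631
n_total  <- n_nizszy + n_wyzszy

# "Balanced" weights: n_total / (number_of_classes * class_count)
waga_nizszy <- n_total / (2 * n_nizszy)
waga_wyzszy <- n_total / (2 * n_wyzszy)

# Total weight carried by each class
total_nizszy <- n_nizszy * waga_nizszy
total_wyzszy <- n_wyzszy * waga_wyzszy

cat(sprintf("Weights: %.4f / %.4f\n", waga_nizszy, waga_wyzszy))
cat(sprintf("Total weight per class: %.1f / %.1f\n", total_nizszy, total_wyzszy))
```

Each class ends up with half of the total weight, so a weighted loss treats the two classes as if they were equally frequent.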
klasyfikacja.metryki <- function(predicted_probabilities,
real,
cutoff = 0.5,
level_positive = "wyzszy",
level_negative = "nizszy") {
predicted_class <- ifelse(predicted_probabilities > cutoff,
level_positive, level_negative)
predicted_class <- factor(predicted_class,
levels = c(level_negative, level_positive))
real <- factor(real, levels = c(level_negative, level_positive))
ctable <- confusionMatrix(data = predicted_class,
reference = real,
positive = level_positive)
roc_obj <- roc(response = as.numeric(real == level_positive),
predictor = predicted_probabilities,
quiet = TRUE)
data.frame(
AUC = round(as.numeric(auc(roc_obj)), 4),
Accuracy = round(unname(ctable$overall["Accuracy"]), 4),
Sensitivity = round(unname(ctable$byClass["Sensitivity"]), 4),
Specificity = round(unname(ctable$byClass["Specificity"]), 4),
F1 = round(unname(ctable$byClass["F1"]), 4)
)
}
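For readers who want to see what the thresholded metrics mean without caret, the sketch below computes them by hand from the four confusion-matrix cells. The probabilities and labels are made up; the level names follow the document, with "wyzszy" as the positive class.

```r
# Hypothetical predicted probabilities and true labels
prob <- c(0.9, 0.8, 0.4, 0.6, 0.2, 0.1, 0.7, 0.3)
real <- c("wyzszy", "wyzszy", "wyzszy", "nizszy",
          "nizszy", "nizszy", "nizszy", "nizszy")

# Same rule as in klasyfikacja.metryki(): positive when probability > cutoff
pred <- ifelse(prob > 0.5, "wyzszy", "nizszy")

# Confusion-matrix cells
tp <- sum(pred == "wyzszy" & real == "wyzszy")
fp <- sum(pred == "wyzszy" & real == "nizszy")
tn <- sum(pred == "nizszy" & real == "nizszy")
fn <- sum(pred == "nizszy" & real == "wyzszy")

sensitivity <- tp / (tp + fn)           # share of true >50K that is detected
specificity <- tn / (tn + fp)           # share of true <=50K that is detected
precision   <- tp / (tp + fp)
f1          <- 2 * precision * sensitivity / (precision + sensitivity)
accuracy    <- (tp + tn) / length(real)

print(round(c(Accuracy = accuracy, Sensitivity = sensitivity,
              Specificity = specificity, F1 = f1), 4))
```

Unlike these four numbers, AUC does not depend on the cutoff, which is why the function computes it from the raw probabilities.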
set.seed(42)
adult.logit <- train(
income ~ .,
data = adult.train,
method = "glm",
family = "binomial",
trControl = trainControl(method = "cv", number = 5,
classProbs = TRUE,
summaryFunction = twoClassSummary),
metric = "ROC",
weights = sample_weights
)
summary(adult.logit$finalModel)
##
## Call:
## NULL
##
## Coefficients: (2 not defined because of singularities)
## Estimate Std. Error z value
## (Intercept) -8.33961359 0.84053725 -9.922
## age 0.03303479 0.00188352 17.539
## education_num 0.31828812 0.03319493 9.588
## `marital_statusMarried-AF-spouse` 3.15816626 0.60312866 5.236
## `marital_statusMarried-civ-spouse` 2.43084314 0.27143371 8.956
## `marital_statusMarried-spouse-absent` 0.17638658 0.22468476 0.785
## `marital_statusNever-married` -0.39596977 0.08626223 -4.590
## marital_statusSeparated -0.07810964 0.15207493 -0.514
## marital_statusWidowed 0.30740033 0.15473913 1.987
## `occupationArmed-Forces` -1.53154022 1.38432414 -1.106
## `occupationCraft-repair` 0.11733372 0.08442588 1.390
## `occupationExec-managerial` 0.80134180 0.08344880 9.603
## `occupationFarming-fishing` -1.23836519 0.14700485 -8.424
## `occupationHandlers-cleaners` -0.57656470 0.14308674 -4.029
## `occupationMachine-op-inspct` -0.18347804 0.10367172 -1.770
## `occupationOther-service` -0.64801545 0.11590006 -5.591
## `occupationPriv-house-serv` -8.11863550 9.50842799 -0.854
## `occupationProf-specialty` 0.64551891 0.08626117 7.483
## `occupationProtective-serv` 0.78808126 0.13818895 5.703
## occupationSales 0.36083406 0.08858416 4.073
## `occupationTech-support` 0.73809843 0.12058376 6.121
## `occupationTransport-moving` 0.06293983 0.10337770 0.609
## `relationshipNot-in-family` 0.78048829 0.26998644 2.891
## `relationshipOther-relative` -0.00663434 0.25492783 -0.026
## `relationshipOwn-child` -0.33180386 0.26948637 -1.231
## relationshipUnmarried 0.52239942 0.28426215 1.838
## relationshipWife 1.43477151 0.10932395 13.124
## `raceAsian-Pac-Islander` 0.92004698 0.29473689 3.122
## raceBlack 0.54932023 0.23916868 2.297
## raceOther 0.12613396 0.38372181 0.329
## raceWhite 0.64509319 0.22799416 2.829
## sexMale 0.87997054 0.07714557 11.407
## hours_per_week 0.03139478 0.00798781 3.930
## capital_gain 0.00070871 0.00004017 17.645
## capital_loss 0.00064463 0.00004427 14.560
## `workclassLocal-gov` -0.69574807 0.12389796 -5.615
## workclassPrivate -0.52362809 0.10372326 -5.048
## `workclassSelf-emp-inc` -0.29145811 0.14081193 -2.070
## `workclassSelf-emp-not-inc` -0.86074593 0.12088284 -7.120
## `workclassState-gov` -0.86276960 0.13660362 -6.316
## `workclassWithout-pay` -12.96308275 129.58912739 -0.100
## native_countryCanada -0.81435417 0.71408920 -1.140
## native_countryChina -2.19176674 0.73623701 -2.977
## native_countryColumbia -2.99781073 0.92125816 -3.254
## native_countryCuba -0.42309644 0.72250426 -0.586
## `native_countryDominican-Republic` -2.95434516 1.01901771 -2.899
## native_countryEcuador -2.95809466 1.14434520 -2.585
## `native_countryEl-Salvador` -1.87761254 0.81327129 -2.309
## native_countryEngland -0.85238264 0.73218983 -1.164
## native_countryFrance -0.75009159 0.86509537 -0.867
## native_countryGermany -0.74809445 0.70352131 -1.063
## native_countryGreece -2.22168325 0.88688938 -2.505
## native_countryGuatemala -0.60004779 0.91325292 -0.657
## native_countryHaiti -0.80406396 0.89430186 -0.899
## `native_countryHoland-Netherlands` NA NA NA
## native_countryHonduras -2.72593969 2.57442836 -1.059
## native_countryHong -1.13811989 0.88997858 -1.279
## native_countryHungary -0.67415920 1.02123578 -0.660
## native_countryIndia -1.22118198 0.70449776 -1.733
## native_countryIran -1.17971317 0.79052578 -1.492
## native_countryIreland -0.14403721 0.90509750 -0.159
## native_countryItaly -0.19673971 0.74287318 -0.265
## native_countryJamaica -0.80082113 0.77083613 -1.039
## native_countryJapan -0.92504253 0.75943424 -1.218
## native_countryLaos -1.01431117 1.00026394 -1.014
## native_countryMexico -1.37662617 0.67515968 -2.039
## native_countryNicaragua -1.79509743 1.12673633 -1.593
## `native_countryOutlying-US(Guam-USVI-etc)` -12.86671562 172.42426031 -0.075
## native_countryPeru -1.56766889 1.05654676 -1.484
## native_countryPhilippines -0.78028396 0.65908800 -1.184
## native_countryPoland -1.32237104 0.78045299 -1.694
## native_countryPortugal -1.74085600 0.98098724 -1.775
## `native_countryPuerto-Rico` -1.40518633 0.77129132 -1.822
## native_countryScotland -1.48487119 1.08060134 -1.374
## native_countrySouth -2.55095426 0.77970247 -3.272
## native_countryTaiwan -1.08390373 0.82320854 -1.317
## native_countryThailand -1.96473337 1.01007061 -1.945
## `native_countryTrinadad&Tobago` -1.15547083 1.02137799 -1.131
## `native_countryUnited-States` -0.95035085 0.64462900 -1.474
## native_countryVietnam -1.66960191 0.82634817 -2.020
## native_countryYugoslavia -0.00442209 0.94555926 -0.005
## net_capital NA NA NA
## capital_flag -2.55761855 0.21320137 -11.996
## hours_x_edu -0.00101389 0.00071822 -1.412
## mid_age 0.68990030 0.04526037 15.243
## high_work_hours 0.35096864 0.05663846 6.197
## Pr(>|z|)
## (Intercept) < 0.0000000000000002 ***
## age < 0.0000000000000002 ***
## education_num < 0.0000000000000002 ***
## `marital_statusMarried-AF-spouse` 0.0000001638222307 ***
## `marital_statusMarried-civ-spouse` < 0.0000000000000002 ***
## `marital_statusMarried-spouse-absent` 0.43243
## `marital_statusNever-married` 0.0000044260265784 ***
## marital_statusSeparated 0.60751
## marital_statusWidowed 0.04697 *
## `occupationArmed-Forces` 0.26858
## `occupationCraft-repair` 0.16459
## `occupationExec-managerial` < 0.0000000000000002 ***
## `occupationFarming-fishing` < 0.0000000000000002 ***
## `occupationHandlers-cleaners` 0.0000559011122535 ***
## `occupationMachine-op-inspct` 0.07676 .
## `occupationOther-service` 0.0000000225561257 ***
## `occupationPriv-house-serv` 0.39320
## `occupationProf-specialty` 0.0000000000000725 ***
## `occupationProtective-serv` 0.0000000117768491 ***
## occupationSales 0.0000463421866958 ***
## `occupationTech-support` 0.0000000009296472 ***
## `occupationTransport-moving` 0.54263
## `relationshipNot-in-family` 0.00384 **
## `relationshipOther-relative` 0.97924
## `relationshipOwn-child` 0.21823
## relationshipUnmarried 0.06610 .
## relationshipWife < 0.0000000000000002 ***
## `raceAsian-Pac-Islander` 0.00180 **
## raceBlack 0.02163 *
## raceOther 0.74237
## raceWhite 0.00466 **
## sexMale < 0.0000000000000002 ***
## hours_per_week 0.0000848271728951 ***
## capital_gain < 0.0000000000000002 ***
## capital_loss < 0.0000000000000002 ***
## `workclassLocal-gov` 0.0000000196002970 ***
## workclassPrivate 0.0000004457149139 ***
## `workclassSelf-emp-inc` 0.03847 *
## `workclassSelf-emp-not-inc` 0.0000000000010754 ***
## `workclassState-gov` 0.0000000002686600 ***
## `workclassWithout-pay` 0.92032
## native_countryCanada 0.25412
## native_countryChina 0.00291 **
## native_countryColumbia 0.00114 **
## native_countryCuba 0.55815
## `native_countryDominican-Republic` 0.00374 **
## native_countryEcuador 0.00974 **
## `native_countryEl-Salvador` 0.02096 *
## native_countryEngland 0.24436
## native_countryFrance 0.38591
## native_countryGermany 0.28762
## native_countryGreece 0.01224 *
## native_countryGuatemala 0.51115
## native_countryHaiti 0.36860
## `native_countryHoland-Netherlands` NA
## native_countryHonduras 0.28967
## native_countryHong 0.20096
## native_countryHungary 0.50916
## native_countryIndia 0.08302 .
## native_countryIran 0.13562
## native_countryIreland 0.87356
## native_countryItaly 0.79114
## native_countryJamaica 0.29885
## native_countryJapan 0.22320
## native_countryLaos 0.31056
## native_countryMexico 0.04145 *
## native_countryNicaragua 0.11112
## `native_countryOutlying-US(Guam-USVI-etc)` 0.94052
## native_countryPeru 0.13787
## native_countryPhilippines 0.23646
## native_countryPoland 0.09020 .
## native_countryPortugal 0.07596 .
## `native_countryPuerto-Rico` 0.06848 .
## native_countryScotland 0.16941
## native_countrySouth 0.00107 **
## native_countryTaiwan 0.18795
## native_countryThailand 0.05176 .
## `native_countryTrinadad&Tobago` 0.25793
## `native_countryUnited-States` 0.14041
## native_countryVietnam 0.04334 *
## native_countryYugoslavia 0.99627
## net_capital NA
## capital_flag < 0.0000000000000002 ***
## hours_x_edu 0.15805
## mid_age < 0.0000000000000002 ***
## high_work_hours 0.0000000005767818 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for binomial family taken to be 1)
##
## Null deviance: 31361 on 22621 degrees of freedom
## Residual deviance: 16811 on 22538 degrees of freedom
## AIC: 21383
##
## Number of Fisher Scoring iterations: 12
pred.logit.train <- predict(adult.logit, adult.train, type = "prob")[, "wyzszy"]
pred.logit.test <- predict(adult.logit, adult.test, type = "prob")[, "wyzszy"]
cat("=== Logistic regression: train vs test ===\n")
## === Logistic regression: train vs test ===
rbind(train = klasyfikacja.metryki(pred.logit.train, real.train),
test = klasyfikacja.metryki(pred.logit.test, real.test))
## AUC Accuracy Sensitivity Specificity F1
## train 0.9124 0.8151 0.8530 0.8025 0.6966
## test 0.9108 0.8182 0.8434 0.8098 0.6978
ctrl.class <- trainControl(method = "cv",
number = 5,
classProbs = TRUE,
summaryFunction = twoClassSummary)
grid.cart <- expand.grid(cp = c(0.0001, 0.0005, 0.001, 0.005, 0.01))
set.seed(42)
adult.cart <- train(income ~ ., data = adult.train, method = "rpart",
metric = "ROC", trControl = ctrl.class,
tuneGrid = grid.cart, weights = sample_weights)
saveRDS(adult.cart, "adult.cart.rds")
adult.cart <- readRDS("adult.cart.rds")
cat("Best cp:", adult.cart$bestTune$cp, "\n")
## Best cp: 0.0001
rpart.plot(adult.cart$finalModel,
type = 4,
extra = 104,
under = TRUE,
tweak = 1.1,
main = "CART classification tree (Adult Income)")
imp.cart <- varImp(adult.cart)$importance
imp.cart$Zmienna <- rownames(imp.cart)
imp.cart %>%
arrange(desc(Overall)) %>%
head(12) %>%
ggplot(aes(x = reorder(Zmienna, Overall), y = Overall)) +
geom_col(fill = "steelblue") +
coord_flip() +
labs(title = "Variable importance: CART",
x = NULL, y = "Importance") +
theme_minimal()
pred.cart.train <- predict(adult.cart, adult.train, type = "prob")[, "wyzszy"]
pred.cart.test <- predict(adult.cart, adult.test, type = "prob")[, "wyzszy"]
cat("=== CART: train vs test ===\n")
## === CART: train vs test ===
rbind(train = klasyfikacja.metryki(pred.cart.train, real.train),
test = klasyfikacja.metryki(pred.cart.test, real.test))
## AUC Accuracy Sensitivity Specificity F1
## train 0.9288 0.8482 0.9052 0.8294 0.7481
## test 0.8964 0.8150 0.8348 0.8084 0.6920
grid.rf <- expand.grid(mtry = c(3, 5, 7),
splitrule = "gini",
min.node.size = c(5, 10))
set.seed(42)
adult.rf <- train(income ~ ., data = adult.train, method = "ranger",
metric = "ROC", trControl = ctrl.class,
tuneGrid = grid.rf, num.trees = 300,
importance = "impurity", weights = sample_weights)
saveRDS(adult.rf, "adult.rf.rds")
adult.rf <- readRDS("adult.rf.rds")
cat("Best RF parameters:\n")
## Best RF parameters:
print(adult.rf$bestTune)
## mtry splitrule min.node.size
## 5 7 gini 5
imp.rf <- varImp(adult.rf)$importance
imp.rf$Zmienna <- rownames(imp.rf)
imp.rf %>%
arrange(desc(Overall)) %>%
head(12) %>%
ggplot(aes(x = reorder(Zmienna, Overall), y = Overall)) +
geom_col(fill = "#2ECC71") +
coord_flip() +
labs(title = "Variable importance: Random Forest",
x = NULL, y = "Importance (Gini)") +
theme_minimal()
pred.rf.train <- predict(adult.rf, adult.train, type = "prob")[, "wyzszy"]
pred.rf.test <- predict(adult.rf, adult.test, type = "prob")[, "wyzszy"]
cat("=== Random Forest: train vs test ===\n")
## === Random Forest: train vs test ===
rbind(train = klasyfikacja.metryki(pred.rf.train, real.train),
test = klasyfikacja.metryki(pred.rf.test, real.test))
## AUC Accuracy Sensitivity Specificity F1
## train 0.9600 0.8669 0.9343 0.8446 0.7775
## test 0.9206 0.8281 0.8418 0.8236 0.7092
grid.gbm <- expand.grid(
interaction.depth = c(4, 6),
n.trees = c(100, 200),
shrinkage = c(0.05, 0.1),
n.minobsinnode = 10
)
set.seed(42)
adult.gbm <- train(income ~ .,
data = adult.train,
method = "gbm",
metric = "ROC",
trControl = ctrl.class,
tuneGrid = grid.gbm,
verbose = FALSE)
saveRDS(adult.gbm, "adult.gbm.rds")
cat("Best GBM parameters:\n")
## Best GBM parameters:
print(adult.gbm$bestTune)
## n.trees interaction.depth shrinkage n.minobsinnode
## 8 200 6 0.1 10
# Unweighted CART refit (no weights argument). It is stored under its own name
# and file so that it does not overwrite the weighted adult.cart evaluated above.
grid.cart <- expand.grid(cp = c(0.0001, 0.0005, 0.001, 0.005, 0.01))
set.seed(42)
adult.cart.unweighted <- train(income ~ .,
data = adult.train,
method = "rpart",
metric = "ROC",
trControl = ctrl.class,
tuneGrid = grid.cart)
saveRDS(adult.cart.unweighted, "adult.cart.unweighted.rds")
cat("Best cp (unweighted):", adult.cart.unweighted$bestTune$cp, "\n")
## Best cp (unweighted): 0.0005
pred.gbm.train <- predict(adult.gbm, adult.train, type = "prob")[, "wyzszy"]
pred.gbm.test <- predict(adult.gbm, adult.test, type = "prob")[, "wyzszy"]
cat("=== GBM: train vs test ===\n")
## === GBM: train vs test ===
rbind(train = klasyfikacja.metryki(pred.gbm.train, real.train),
test = klasyfikacja.metryki(pred.gbm.test, real.test))
## AUC Accuracy Sensitivity Specificity F1
## train 0.9321 0.8730 0.6567 0.9447 0.7202
## test 0.9242 0.8715 0.6516 0.9444 0.7163
# Standardisation: parameters are estimated on the training set only!
adult.vars.num <- c("age", "education_num", "hours_per_week",
"capital_gain", "capital_loss", "net_capital",
"hours_x_edu", "capital_flag", "mid_age", "high_work_hours")
mn.train <- colMeans(adult.train[, adult.vars.num])
sd.train <- apply(adult.train[, adult.vars.num], 2, sd)
sd.train[sd.train == 0] <- 1
adult.train.s <- adult.train
adult.test.s <- adult.test
adult.train.s[, adult.vars.num] <- scale(adult.train[, adult.vars.num],
center = mn.train, scale = sd.train)
adult.test.s[, adult.vars.num] <- scale(adult.test[, adult.vars.num],
center = mn.train, scale = sd.train)
set.seed(42)
adult.nn <- nnet(income ~ .,
data = adult.train.s,
size = 16,
decay = 0.001,
maxit = 200,
trace = FALSE,
MaxNWts = 5000)
saveRDS(adult.nn, "adult.nn.rds")
saveRDS(list(mean = mn.train, sd = sd.train), "scaler.nn.rds")
cat("Done. Number of weights:", length(adult.nn$wts), "\n")
## Done. Number of weights: 1393
# adult.nn comes from nnet() directly rather than caret::train(), so it has no bestTune element
pred.nn.train <- predict(adult.nn, adult.train.s, type = "raw")[, 1]
pred.nn.test <- predict(adult.nn, adult.test.s, type = "raw")[, 1]
cat("=== Neural network: train vs test ===\n")
## === Neural network: train vs test ===
rbind(train = klasyfikacja.metryki(pred.nn.train, real.train),
test = klasyfikacja.metryki(pred.nn.test, real.test))
## AUC Accuracy Sensitivity Specificity F1
## train 0.9421 0.8795 0.7075 0.9365 0.7451
## test 0.8983 0.8463 0.6335 0.9168 0.6723
porownanie.train <- rbind(
Logit = klasyfikacja.metryki(pred.logit.train, real.train),
CART = klasyfikacja.metryki(pred.cart.train, real.train),
RandomForest = klasyfikacja.metryki(pred.rf.train, real.train),
GBM = klasyfikacja.metryki(pred.gbm.train, real.train),
NeuralNet = klasyfikacja.metryki(pred.nn.train, real.train)
)
knitr::kable(porownanie.train, caption = "Metrics: training set")
| | AUC | Accuracy | Sensitivity | Specificity | F1 |
|---|---|---|---|---|---|
| Logit | 0.9124 | 0.8151 | 0.8530 | 0.8025 | 0.6966 |
| CART | 0.9288 | 0.8482 | 0.9052 | 0.8294 | 0.7481 |
| RandomForest | 0.9600 | 0.8669 | 0.9343 | 0.8446 | 0.7775 |
| GBM | 0.9321 | 0.8730 | 0.6567 | 0.9447 | 0.7202 |
| NeuralNet | 0.9421 | 0.8795 | 0.7075 | 0.9365 | 0.7451 |
porownanie.test <- rbind(
  Logit        = klasyfikacja.metryki(pred.logit.test, real.test),
  CART         = klasyfikacja.metryki(pred.cart.test,  real.test),
  RandomForest = klasyfikacja.metryki(pred.rf.test,    real.test),
  # The boosting model is fitted with caret's "gbm" method, so it is labelled GBM
  GBM          = klasyfikacja.metryki(pred.gbm.test,   real.test),
  NeuralNet    = klasyfikacja.metryki(pred.nn.test,    real.test)
)
knitr::kable(porownanie.test, caption = "Metrics: test set")
| | AUC | Accuracy | Sensitivity | Specificity | F1 |
|---|---|---|---|---|---|
| Logit | 0.9108 | 0.8182 | 0.8434 | 0.8098 | 0.6978 |
| CART | 0.8964 | 0.8150 | 0.8348 | 0.8084 | 0.6920 |
| RandomForest | 0.9206 | 0.8281 | 0.8418 | 0.8236 | 0.7092 |
| GBM | 0.9242 | 0.8715 | 0.6516 | 0.9444 | 0.7163 |
| NeuralNet | 0.8983 | 0.8463 | 0.6335 | 0.9168 | 0.6723 |
auc.test <- porownanie.test$AUC
names(auc.test) <- rownames(porownanie.test)
par(mar = c(9, 4, 3, 1))
barplot(sort(auc.test, decreasing = TRUE),
        main = "AUC on the test set (higher = better)",
        ylab = "AUC", las = 2, ylim = c(0.8, 1),
        col = "steelblue", border = NA, cex.names = 0.9)
par(mar = c(5, 4, 4, 2))
roc.logit <- roc(as.numeric(real.test == "wyzszy"), pred.logit.test, quiet = TRUE)
roc.cart  <- roc(as.numeric(real.test == "wyzszy"), pred.cart.test,  quiet = TRUE)
roc.rf    <- roc(as.numeric(real.test == "wyzszy"), pred.rf.test,    quiet = TRUE)
roc.gbm   <- roc(as.numeric(real.test == "wyzszy"), pred.gbm.test,   quiet = TRUE)
roc.nn    <- roc(as.numeric(real.test == "wyzszy"), pred.nn.test,    quiet = TRUE)
plot(roc.logit, col = "gray50", lwd = 1.5, main = "ROC curves: test set")
plot(roc.cart, col = "steelblue", lwd = 1.5, add = TRUE)
plot(roc.rf,   col = "#2ECC71",   lwd = 1.5, add = TRUE)
plot(roc.gbm,  col = "coral",     lwd = 1.5, add = TRUE)
plot(roc.nn,   col = "#9B59B6",   lwd = 1.5, add = TRUE)
legend("bottomright", bty = "n",
       legend = c(paste0("Logit (AUC=", round(auc(roc.logit), 3), ")"),
                  paste0("CART (AUC=", round(auc(roc.cart), 3), ")"),
                  paste0("Random Forest (AUC=", round(auc(roc.rf), 3), ")"),
                  paste0("GBM (AUC=", round(auc(roc.gbm), 3), ")"),
                  paste0("nnet (AUC=", round(auc(roc.nn), 3), ")")),
       col = c("gray50", "steelblue", "#2ECC71", "coral", "#9B59B6"),
       lwd = 2, cex = 0.85)
cat("AUC difference (train - test):\n")
## AUC difference (train - test):
print(round(porownanie.train$AUC - porownanie.test$AUC, 4))
## [1] 0.0016 0.0324 0.0394 0.0079 0.0438
best.idx <- which.max(porownanie.test$AUC)
cat("Best model:", rownames(porownanie.test)[best.idx], "\n")
## Best model: GBM
cat("AUC:", porownanie.test$AUC[best.idx], "\n")
## AUC: 0.9242
cat("F1: ", porownanie.test$F1[best.idx], "\n")
## F1: 0.7163
Model ranking and rationale:
The best test results come from gradient boosting (fitted with caret's "gbm" method) and Random Forest, with test AUC of 0.924 and 0.921. Both methods combine many trees: boosting fits trees sequentially, each one correcting the errors of the previous ones, while Random Forest averages independently grown trees, which reduces variance. The AUC gap between them (0.0036) is small, so the two are practically equivalent in ranking quality. At the 0.5 cutoff they behave differently, because Random Forest was trained with class weights and the boosting model was not.
Why not accuracy?
The data are imbalanced (75%/25%). A model that always predicts <=50K would reach about 75% accuracy while being useless. We therefore use ROC-AUC as the main metric: it measures how well the model separates the classes independently of the decision threshold.
Most important variables:
Across the models, net_capital, capital_gain and education_num dominate: financial capital and education level are the main predictors of income above 50K. The logistic regression is an exception for net_capital: its coefficient is NA because of a singularity, consistent with net_capital being capital_gain minus capital_loss. marital_status and relationship also matter, possibly in part because they correlate with other socioeconomic variables (for example, Married-civ-spouse with employment type and age).
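Imbalance analysis:
The test results are consistent with class weights raising Sensitivity (detection of the >50K class) at the cost of Specificity. The three weighted models (Logit, CART, Random Forest) reach Sensitivity of 0.83 to 0.84 with Specificity of 0.81 to 0.82, whereas the two models fitted without weights (gradient boosting and nnet) reach Sensitivity of 0.63 to 0.65 with Specificity of 0.92 to 0.94. This trade-off is desirable when missing a high-income person costs more than a false alarm, which we assume to be the typical business case.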
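Neural network:
nnet with a single hidden layer reaches a test AUC of 0.898, slightly below logistic regression (0.911) and clearly below the ensemble methods (about 0.92). Its train-test AUC gap (0.044) is the largest of the five models, which points to overfitting in addition to the limited architecture. Deeper networks (keras3) might give a better result, but they require a separate Python/TensorFlow environment.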
Business context:
The gradient boosting model, with a test AUC of about 0.92, is sufficient for:
With a cutoff of 0.35–0.40 (lower than the default 0.5) Sensitivity increases at the cost of Specificity, which pays off when the costs of the two error types are asymmetric.
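The effect of lowering the cutoff can be illustrated on synthetic scores. In the sketch below the scores are drawn from two beta distributions chosen only so that positives tend to score higher; they are not the model's predictions, and sens_spec is a hypothetical helper.

```r
set.seed(1)

# Synthetic scores (assumption): 750 negatives, 250 positives
n_neg <- 750
n_pos <- 250
score  <- c(rbeta(n_neg, 2, 5), rbeta(n_pos, 4, 3))
is_pos <- rep(c(FALSE, TRUE), times = c(n_neg, n_pos))

# Sensitivity and specificity at a given cutoff
sens_spec <- function(cutoff) {
  pred_pos <- score > cutoff
  c(cutoff      = cutoff,
    sensitivity = sum(pred_pos & is_pos) / sum(is_pos),
    specificity = sum(!pred_pos & !is_pos) / sum(!is_pos))
}

res <- rbind(sens_spec(0.50), sens_spec(0.35))
print(round(res, 3))
```

Every observation flagged at 0.5 is also flagged at 0.35, so Sensitivity can only go up and Specificity can only go down as the cutoff is lowered. In practice the cutoff would be chosen on validation data, not on the test set.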
saveRDS(adult.logit, "adult.logit.rds")
saveRDS(adult.cart, "adult.cart.rds")
saveRDS(adult.rf, "adult.rf.rds")
saveRDS(adult.gbm, "adult.gbm.rds")
saveRDS(adult.nn, "adult.nn.rds")
saveRDS(list(mean = mn.train, sd = sd.train), "scaler.nn.rds")
cat("All models saved.\n")
## All models saved.
cat("To load a model: model <- readRDS('adult.gbm.rds')\n")
## To load a model: model <- readRDS('adult.gbm.rds')
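When the neural network is reused, new data have to be standardised with the stored training parameters before calling predict(). The sketch below shows that round trip under stated assumptions: the scaler values and the new observations are made up, only two variables are used, and a temporary file stands in for scaler.nn.rds.

```r
# Hypothetical scaler with the same structure as the list saved to scaler.nn.rds
scaler <- list(mean = c(age = 38.4, hours_per_week = 40.9),
               sd   = c(age = 13.1, hours_per_week = 12.0))

# Round trip through a temporary file instead of the project directory
path <- tempfile(fileext = ".rds")
saveRDS(scaler, path)
scaler_loaded <- readRDS(path)

# New observations (made-up values) are scaled with the stored parameters,
# never with their own mean and standard deviation
new_data <- data.frame(age = c(25, 60), hours_per_week = c(40, 55))
vars <- names(scaler_loaded$mean)
new_scaled <- new_data
new_scaled[, vars] <- scale(new_data[, vars],
                            center = scaler_loaded$mean,
                            scale  = scaler_loaded$sd)
print(new_scaled)
```

Saving the scaler next to the model keeps the preprocessing and the fitted weights in sync, which is why the source stores both.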