1. Problem description

The goal of the project is to build a predictive model that, from a person's demographic and occupational characteristics, predicts whether their annual income exceeds USD 50,000. The data come from the Adult Census Income dataset (UCI Machine Learning Repository) and contain over 30,000 observations (30,162 rows in the file used here).

The target variable is class (values <=50K / >50K), so this is a binary classification problem.

We compare five models: logistic regression as the baseline, a classification tree (CART), Random Forest (with tuning), gradient boosting (fitted with caret's gbm method in section 7), and a neural network (nnet).

library(tidyverse)
library(caret)
library(pROC)
library(rpart)
library(rpart.plot)
library(randomForest)
library(ranger)
library(xgboost)
library(nnet)

2. Data exploration

adult.raw <- read.csv("k1_adult.csv", stringsAsFactors = FALSE, sep = ",")
glimpse(adult.raw)
## Rows: 30,162
## Columns: 18
## $ class           <chr> "<=50K", "<=50K", "<=50K", "<=50K", "<=50K", "<=50K", …
## $ age             <int> 39, 50, 38, 53, 28, 37, 49, 52, 31, 42, 37, 30, 23, 32…
## $ education_num   <int> 13, 13, 9, 7, 13, 14, 5, 9, 14, 13, 10, 13, 13, 12, 4,…
## $ marital_status  <chr> "Never-married", "Married-civ-spouse", "Divorced", "Ma…
## $ occupation      <chr> "Adm-clerical", "Exec-managerial", "Handlers-cleaners"…
## $ relationship    <chr> "Not-in-family", "Husband", "Not-in-family", "Husband"…
## $ race            <chr> "White", "White", "White", "Black", "Black", "White", …
## $ sex             <chr> "Male", "Male", "Male", "Male", "Female", "Female", "F…
## $ hours_per_week  <int> 40, 13, 40, 40, 40, 40, 16, 45, 50, 40, 80, 40, 30, 50…
## $ capital_gain    <int> 2174, 0, 0, 0, 0, 0, 0, 0, 14084, 5178, 0, 0, 0, 0, 0,…
## $ capital_loss    <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ workclass       <chr> "State-gov", "Self-emp-not-inc", "Private", "Private",…
## $ native_country  <chr> "United-States", "United-States", "United-States", "Un…
## $ net_capital     <int> 2174, 0, 0, 0, 0, 0, 0, 0, 14084, 5178, 0, 0, 0, 0, 0,…
## $ capital_flag    <int> 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ hours_x_edu     <int> 520, 169, 360, 280, 520, 560, 80, 405, 700, 520, 800, …
## $ mid_age         <int> 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, …
## $ high_work_hours <int> 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, …
summary(adult.raw)
##        class            age        education_num     marital_status 
##  Length   :30162   Min.   :17.00   Min.   : 1.00   Length   :30162  
##  N.unique :    2   1st Qu.:28.00   1st Qu.: 9.00   N.unique :    7  
##  N.blank  :    0   Median :37.00   Median :10.00   N.blank  :    0  
##  Min.nchar:    4   Mean   :38.44   Mean   :10.12   Min.nchar:    7  
##  Max.nchar:    5   3rd Qu.:47.00   3rd Qu.:13.00   Max.nchar:   21  
##                    Max.   :90.00   Max.   :16.00                    
##      occupation       relationship          race              sex       
##  Length   :30162   Length   :30162   Length   :30162   Length   :30162  
##  N.unique :   14   N.unique :    6   N.unique :    5   N.unique :    2  
##  N.blank  :    0   N.blank  :    0   N.blank  :    0   N.blank  :    0  
##  Min.nchar:    5   Min.nchar:    4   Min.nchar:    5   Min.nchar:    4  
##  Max.nchar:   17   Max.nchar:   14   Max.nchar:   18   Max.nchar:    6  
##                                                                         
##  hours_per_week   capital_gain    capital_loss         workclass    
##  Min.   : 1.00   Min.   :    0   Min.   :   0.00   Length   :30162  
##  1st Qu.:40.00   1st Qu.:    0   1st Qu.:   0.00   N.unique :    7  
##  Median :40.00   Median :    0   Median :   0.00   N.blank  :    0  
##  Mean   :40.93   Mean   : 1092   Mean   :  88.37   Min.nchar:    7  
##  3rd Qu.:45.00   3rd Qu.:    0   3rd Qu.:   0.00   Max.nchar:   16  
##  Max.   :99.00   Max.   :99999   Max.   :4356.00                    
##    native_country   net_capital     capital_flag      hours_x_edu    
##  Length   :30162   Min.   :-4356   Min.   :0.00000   Min.   :   6.0  
##  N.unique :   41   1st Qu.:    0   1st Qu.:0.00000   1st Qu.: 342.0  
##  N.blank  :    0   Median :    0   Median :0.00000   Median : 400.0  
##  Min.nchar:    4   Mean   : 1004   Mean   :0.08415   Mean   : 418.9  
##  Max.nchar:   26   3rd Qu.:    0   3rd Qu.:0.00000   3rd Qu.: 520.0  
##                    Max.   :99999   Max.   :1.00000   Max.   :1584.0  
##     mid_age       high_work_hours 
##  Min.   :0.0000   Min.   :0.0000  
##  1st Qu.:0.0000   1st Qu.:0.0000  
##  Median :1.0000   Median :0.0000  
##  Mean   :0.5948   Mean   :0.3049  
##  3rd Qu.:1.0000   3rd Qu.:1.0000  
##  Max.   :1.0000   Max.   :1.0000
cat("Number of observations:", nrow(adult.raw), "\n")
## Number of observations: 30162
cat("Duplicates:", sum(duplicated(adult.raw)), "\n")
## Duplicates: 3258
cat("Missing values:\n")
## Missing values:
colSums(is.na(adult.raw))
##           class             age   education_num  marital_status      occupation 
##               0               0               0               0               0 
##    relationship            race             sex  hours_per_week    capital_gain 
##               0               0               0               0               0 
##    capital_loss       workclass  native_country     net_capital    capital_flag 
##               0               0               0               0               0 
##     hours_x_edu         mid_age high_work_hours 
##               0               0               0

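The data contain no missing values. However, duplicated() flags 3,258 rows as exact copies of earlier rows; they are kept in the analysis that follows.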
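
Whether to drop the rows flagged by duplicated() is a modelling decision: in census-style data, two different people can share identical values on every recorded feature. As a minimal sketch on a made-up data frame (the name toy and its values are illustrative, not taken from the dataset), this is how the duplicates could be counted and removed if that were desired:

```r
# Toy data frame standing in for adult.raw (values are made up)
toy <- data.frame(
  age   = c(39, 50, 39, 28, 50),
  sex   = c("Male", "Male", "Male", "Female", "Male"),
  class = c("<=50K", "<=50K", "<=50K", ">50K", "<=50K"),
  stringsAsFactors = FALSE
)

# duplicated() marks the second and later copies of a row, not the first
n_dup <- sum(duplicated(toy))

# Keep only the first occurrence of each distinct row
toy_dedup <- toy[!duplicated(toy), ]

cat("Duplicates:", n_dup, "| rows after removal:", nrow(toy_dedup), "\n")
## Duplicates: 2 | rows after removal: 3
```

If duplicates are removed, it is safest to do so before the train/test split so that identical rows cannot end up on both sides.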

Distribution of the target variable

tabela_y <- as.data.frame(table(adult.raw$class))
tabela_y$Procent <- round(tabela_y$Freq / sum(tabela_y$Freq) * 100, 1)
colnames(tabela_y) <- c("Class", "N", "Percent")
print(tabela_y)
##   Class     N Percent
## 1 <=50K 22654    75.1
## 2  >50K  7508    24.9
ggplot(adult.raw, aes(x = class, fill = class)) +
  geom_bar(width = 0.5, color = "white") +
  # after_stat(count) replaces the deprecated ..count.. notation
  geom_text(stat = "count",
            aes(label = paste0(after_stat(count), "\n(",
                               round(after_stat(count) / nrow(adult.raw) * 100, 1), "%)")),
            vjust = -0.3, size = 4) +
  scale_fill_manual(values = c("<=50K" = "steelblue", ">50K" = "coral")) +
  labs(title = "Distribution of the target variable",
       x = "Income class", y = "Number of observations") +
  theme_minimal() +
  theme(legend.position = "none")

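The classes are imbalanced: <=50K accounts for about 75% of observations. We therefore use ROC-AUC rather than accuracy as the main metric, and pass class weights to the logistic regression, CART and Random Forest fits; the gradient boosting and neural network fits further down are trained without weights.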
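
To illustrate why accuracy can mislead at this class balance, the sketch below uses made-up labels with a 75/25 split and a "model" that gives every observation the same score. The helper auc_rank is a hypothetical name introduced here; it computes AUC through the rank (Mann-Whitney) formulation rather than through pROC.

```r
# Toy labels with a 75% / 25% class balance (0 = <=50K, 1 = >50K)
y <- c(rep(0, 75), rep(1, 25))

# A degenerate "model": the same low score for everyone
score_const <- rep(0.1, length(y))
pred_class  <- ifelse(score_const > 0.5, 1, 0)
accuracy    <- mean(pred_class == y)

# AUC via the rank (Mann-Whitney) formulation; ties receive average ranks
auc_rank <- function(score, label) {
  r  <- rank(score)
  n1 <- sum(label == 1)
  n0 <- sum(label == 0)
  (sum(r[label == 1]) - n1 * (n1 + 1) / 2) / (n1 * n0)
}
auc_const <- auc_rank(score_const, y)

cat("Accuracy:", accuracy, "| AUC:", auc_const, "\n")
## Accuracy: 0.75 | AUC: 0.5
```

The constant predictor reaches 75% accuracy without separating the classes at all, while its AUC stays at the chance level of 0.5.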

Numeric variables

adult.raw %>%
  select(age, education_num, hours_per_week, capital_gain, capital_loss,
         net_capital, class) %>%
  pivot_longer(-class, names_to = "zmienna", values_to = "wartosc") %>%
  ggplot(aes(x = wartosc, fill = class)) +
  geom_histogram(bins = 30, alpha = 0.7, position = "identity", color = "white") +
  facet_wrap(~ zmienna, scales = "free", ncol = 3) +
  scale_fill_manual(values = c("<=50K" = "steelblue", ">50K" = "coral")) +
  labs(title = "Distributions of numeric variables by income class",
       x = NULL, y = "Number of observations", fill = "Income") +
  theme_minimal() +
  theme(legend.position = "bottom")

Categorical variables

adult.raw %>%
  select(marital_status, occupation, relationship, sex, workclass, class) %>%
  pivot_longer(-class, names_to = "zmienna", values_to = "kategoria") %>%
  group_by(zmienna, kategoria, class) %>%
  summarise(n = n(), .groups = "drop") %>%
  group_by(zmienna, kategoria) %>%
  mutate(prop_50k = sum(n[class == ">50K"]) / sum(n)) %>%
  filter(class == ">50K") %>%
  ggplot(aes(x = reorder(kategoria, prop_50k), y = prop_50k)) +
  geom_col(fill = "coral") +
  geom_hline(yintercept = 0.249, linetype = "dashed", color = "gray50") +
  coord_flip() +
  facet_wrap(~ zmienna, scales = "free_y", ncol = 2) +
  scale_y_continuous(labels = scales::percent_format()) +
  labs(title    = "Share earning >50K in each category",
       subtitle = "Dashed line = overall share (~25%)",
       x = NULL, y = "Share >50K") +
  theme_minimal(base_size = 10)

3. Preprocessing and data preparation

Cleaning and encoding

adult.prep <- adult.raw %>%
  mutate(
    # Target variable as a two-level factor: nizszy (<=50K) / wyzszy (>50K)
    income = factor(ifelse(class == ">50K", "wyzszy", "nizszy"),
                    levels = c("nizszy", "wyzszy"))
  ) %>%
  select(-class) %>%
  mutate(across(c(marital_status, occupation, relationship,
                  race, sex, workclass, native_country), as.factor))

glimpse(adult.prep)
## Rows: 30,162
## Columns: 18
## $ age             <int> 39, 50, 38, 53, 28, 37, 49, 52, 31, 42, 37, 30, 23, 32…
## $ education_num   <int> 13, 13, 9, 7, 13, 14, 5, 9, 14, 13, 10, 13, 13, 12, 4,…
## $ marital_status  <fct> Never-married, Married-civ-spouse, Divorced, Married-c…
## $ occupation      <fct> Adm-clerical, Exec-managerial, Handlers-cleaners, Hand…
## $ relationship    <fct> Not-in-family, Husband, Not-in-family, Husband, Wife, …
## $ race            <fct> White, White, White, Black, Black, White, Black, White…
## $ sex             <fct> Male, Male, Male, Male, Female, Female, Female, Male, …
## $ hours_per_week  <int> 40, 13, 40, 40, 40, 40, 16, 45, 50, 40, 80, 40, 30, 50…
## $ capital_gain    <int> 2174, 0, 0, 0, 0, 0, 0, 0, 14084, 5178, 0, 0, 0, 0, 0,…
## $ capital_loss    <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ workclass       <fct> State-gov, Self-emp-not-inc, Private, Private, Private…
## $ native_country  <fct> United-States, United-States, United-States, United-St…
## $ net_capital     <int> 2174, 0, 0, 0, 0, 0, 0, 0, 14084, 5178, 0, 0, 0, 0, 0,…
## $ capital_flag    <int> 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ hours_x_edu     <int> 520, 169, 360, 280, 520, 560, 80, 405, 700, 520, 800, …
## $ mid_age         <int> 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, …
## $ high_work_hours <int> 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, …
## $ income          <fct> nizszy, nizszy, nizszy, nizszy, nizszy, nizszy, nizszy…
cat("\nDistribution of the target variable:\n")
## 
## Distribution of the target variable:
print(prop.table(table(adult.prep$income)))
## 
##    nizszy    wyzszy 
## 0.7510775 0.2489225

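Categorical variables are stored as factors. With the formula interface used in the sections that follow, caret expands each factor into dummy (indicator) columns automatically, dropping one reference level per factor, so no manual encoding is needed.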
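
The dummy expansion can be previewed with base R's model.matrix(), which produces the same kind of columns that appear in the coefficient table of the logistic regression (for example sexMale). The data frame toy below is made up for illustration.

```r
# Toy data frame with two factors and one numeric column (values are made up)
toy <- data.frame(
  sex  = factor(c("Male", "Female", "Male")),
  race = factor(c("White", "Black", "Other")),
  age  = c(39, 28, 50)
)

# Treatment contrasts: each factor with k levels yields k - 1 indicator columns;
# the first level (alphabetically, by default) becomes the reference
X <- model.matrix(~ ., data = toy)
print(colnames(X))
```

Because the reference level is dropped, there is no sexFemale column: a row with sexMale equal to 0 is read as Female.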

Train / test split

# Split before you model!
set.seed(42)
idx.train <- createDataPartition(adult.prep$income, p = 0.75, list = FALSE)
adult.train <- adult.prep[ idx.train, ]
adult.test  <- adult.prep[-idx.train, ]

cat("Train:", nrow(adult.train), "| Test:", nrow(adult.test), "\n")
## Train: 22622 | Test: 7540
real.train <- adult.train$income
real.test  <- adult.test$income
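
createDataPartition() samples within each level of the outcome so that both parts keep roughly the same class proportions. The sketch below is a base-R analogue of that idea on made-up labels; it is not caret's internal implementation, and the names y, idx, prop_train and prop_test are illustrative.

```r
set.seed(42)

# Toy outcome with an 80% / 20% class balance
y <- factor(c(rep("nizszy", 800), rep("wyzszy", 200)),
            levels = c("nizszy", "wyzszy"))

# Sample 75% of the row indices separately within each class
idx <- unlist(lapply(split(seq_along(y), y),
                     function(i) sample(i, size = round(0.75 * length(i)))))

# Class proportions in the two parts
prop_train <- prop.table(table(y[idx]))
prop_test  <- prop.table(table(y[-idx]))
print(rbind(train = prop_train, test = prop_test))
```

A plain random split would give similar proportions only on average; stratifying makes them match by construction.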

Class weights

# Handling class imbalance: weights inversely proportional to class size
n_total  <- nrow(adult.train)
n_nizszy <- sum(adult.train$income == "nizszy")
n_wyzszy <- sum(adult.train$income == "wyzszy")

waga_nizszy <- n_total / (2 * n_nizszy)
waga_wyzszy <- n_total / (2 * n_wyzszy)

cat(sprintf("Weight of class 'nizszy' (<=50K): %.4f\n", waga_nizszy))
## Weight of class 'nizszy' (<=50K): 0.6657
cat(sprintf("Weight of class 'wyzszy' (>50K):  %.4f\n", waga_wyzszy))
## Weight of class 'wyzszy' (>50K):  2.0087
sample_weights <- ifelse(adult.train$income == "wyzszy", waga_wyzszy, waga_nizszy)
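
The weights follow w_c = n / (2 * n_c), where n is the number of training rows and n_c the number of rows in class c. A quick check on made-up counts (75 and 25 rows; the names below are illustrative) shows the effect: each class ends up with the same total weight, and the weights still sum to n.

```r
# Toy outcome: 75 rows of the majority class, 25 of the minority class
y <- factor(c(rep("nizszy", 75), rep("wyzszy", 25)),
            levels = c("nizszy", "wyzszy"))

n_total <- length(y)
n_class <- table(y)                # rows per class
w_class <- n_total / (2 * n_class) # w_c = n / (2 * n_c)

# One weight per observation, looked up by class name
w_obs <- as.numeric(w_class[as.character(y)])

# Total weight per class and overall
weight_by_class <- tapply(w_obs, y, sum)
print(weight_by_class)
cat("Sum of all weights:", sum(w_obs), "\n")
```

With these counts each class carries a total weight of 50, so the loss treats the two classes as equally important even though one has three times as many rows.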

Metrics function

klasyfikacja.metryki <- function(predicted_probabilities,
                                  real,
                                  cutoff         = 0.5,
                                  level_positive = "wyzszy",
                                  level_negative = "nizszy") {

  predicted_class <- ifelse(predicted_probabilities > cutoff,
                            level_positive, level_negative)
  predicted_class <- factor(predicted_class,
                            levels = c(level_negative, level_positive))
  real <- factor(real, levels = c(level_negative, level_positive))

  ctable <- confusionMatrix(data      = predicted_class,
                            reference = real,
                            positive  = level_positive)

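  # Fix the class order and direction explicitly so that the AUC is not
  # silently flipped if a model scores worse than chance
  roc_obj <- roc(response  = as.numeric(real == level_positive),
                 predictor = predicted_probabilities,
                 levels    = c(0, 1),
                 direction = "<",
                 quiet     = TRUE)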

  data.frame(
    AUC         = round(as.numeric(auc(roc_obj)), 4),
    Accuracy    = round(unname(ctable$overall["Accuracy"]), 4),
    Sensitivity = round(unname(ctable$byClass["Sensitivity"]), 4),
    Specificity = round(unname(ctable$byClass["Specificity"]), 4),
    F1          = round(unname(ctable$byClass["F1"]), 4)
  )
}
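
All metrics except AUC depend on the cutoff argument. As a small sketch with made-up probabilities (the helper sens_spec is a hypothetical name, written in base R so it runs without caret), lowering the cutoff trades specificity for sensitivity:

```r
# Toy labels (1 = positive class) and predicted probabilities (made up)
real <- c(1, 1, 1, 1, 0, 0, 0, 0, 0, 0)
prob <- c(0.9, 0.7, 0.45, 0.3, 0.6, 0.4, 0.35, 0.2, 0.1, 0.05)

# Sensitivity and specificity at a given cutoff, using the same
# "probability > cutoff" rule as klasyfikacja.metryki
sens_spec <- function(prob, real, cutoff) {
  pred <- as.integer(prob > cutoff)
  c(sensitivity = sum(pred == 1 & real == 1) / sum(real == 1),
    specificity = sum(pred == 0 & real == 0) / sum(real == 0))
}

at_050 <- sens_spec(prob, real, cutoff = 0.5)
at_030 <- sens_spec(prob, real, cutoff = 0.3)
print(rbind(cutoff_0.5 = at_050, cutoff_0.3 = at_030))
```

This is one reason to compare models on AUC first: accuracy, sensitivity, specificity and F1 at a fixed 0.5 cutoff also reflect how each model's probabilities are calibrated.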

4. Benchmark model: logistic regression

set.seed(42)
adult.logit <- train(
  income ~ .,
  data      = adult.train,
  method    = "glm",
  family    = "binomial",
  trControl = trainControl(method = "cv", number = 5,
                           classProbs      = TRUE,
                           summaryFunction = twoClassSummary),
  metric    = "ROC",
  weights   = sample_weights
)

summary(adult.logit$finalModel)
## 
## Call:
## NULL
## 
## Coefficients: (2 not defined because of singularities)
##                                                Estimate   Std. Error z value
## (Intercept)                                 -8.33961359   0.84053725  -9.922
## age                                          0.03303479   0.00188352  17.539
## education_num                                0.31828812   0.03319493   9.588
## `marital_statusMarried-AF-spouse`            3.15816626   0.60312866   5.236
## `marital_statusMarried-civ-spouse`           2.43084314   0.27143371   8.956
## `marital_statusMarried-spouse-absent`        0.17638658   0.22468476   0.785
## `marital_statusNever-married`               -0.39596977   0.08626223  -4.590
## marital_statusSeparated                     -0.07810964   0.15207493  -0.514
## marital_statusWidowed                        0.30740033   0.15473913   1.987
## `occupationArmed-Forces`                    -1.53154022   1.38432414  -1.106
## `occupationCraft-repair`                     0.11733372   0.08442588   1.390
## `occupationExec-managerial`                  0.80134180   0.08344880   9.603
## `occupationFarming-fishing`                 -1.23836519   0.14700485  -8.424
## `occupationHandlers-cleaners`               -0.57656470   0.14308674  -4.029
## `occupationMachine-op-inspct`               -0.18347804   0.10367172  -1.770
## `occupationOther-service`                   -0.64801545   0.11590006  -5.591
## `occupationPriv-house-serv`                 -8.11863550   9.50842799  -0.854
## `occupationProf-specialty`                   0.64551891   0.08626117   7.483
## `occupationProtective-serv`                  0.78808126   0.13818895   5.703
## occupationSales                              0.36083406   0.08858416   4.073
## `occupationTech-support`                     0.73809843   0.12058376   6.121
## `occupationTransport-moving`                 0.06293983   0.10337770   0.609
## `relationshipNot-in-family`                  0.78048829   0.26998644   2.891
## `relationshipOther-relative`                -0.00663434   0.25492783  -0.026
## `relationshipOwn-child`                     -0.33180386   0.26948637  -1.231
## relationshipUnmarried                        0.52239942   0.28426215   1.838
## relationshipWife                             1.43477151   0.10932395  13.124
## `raceAsian-Pac-Islander`                     0.92004698   0.29473689   3.122
## raceBlack                                    0.54932023   0.23916868   2.297
## raceOther                                    0.12613396   0.38372181   0.329
## raceWhite                                    0.64509319   0.22799416   2.829
## sexMale                                      0.87997054   0.07714557  11.407
## hours_per_week                               0.03139478   0.00798781   3.930
## capital_gain                                 0.00070871   0.00004017  17.645
## capital_loss                                 0.00064463   0.00004427  14.560
## `workclassLocal-gov`                        -0.69574807   0.12389796  -5.615
## workclassPrivate                            -0.52362809   0.10372326  -5.048
## `workclassSelf-emp-inc`                     -0.29145811   0.14081193  -2.070
## `workclassSelf-emp-not-inc`                 -0.86074593   0.12088284  -7.120
## `workclassState-gov`                        -0.86276960   0.13660362  -6.316
## `workclassWithout-pay`                     -12.96308275 129.58912739  -0.100
## native_countryCanada                        -0.81435417   0.71408920  -1.140
## native_countryChina                         -2.19176674   0.73623701  -2.977
## native_countryColumbia                      -2.99781073   0.92125816  -3.254
## native_countryCuba                          -0.42309644   0.72250426  -0.586
## `native_countryDominican-Republic`          -2.95434516   1.01901771  -2.899
## native_countryEcuador                       -2.95809466   1.14434520  -2.585
## `native_countryEl-Salvador`                 -1.87761254   0.81327129  -2.309
## native_countryEngland                       -0.85238264   0.73218983  -1.164
## native_countryFrance                        -0.75009159   0.86509537  -0.867
## native_countryGermany                       -0.74809445   0.70352131  -1.063
## native_countryGreece                        -2.22168325   0.88688938  -2.505
## native_countryGuatemala                     -0.60004779   0.91325292  -0.657
## native_countryHaiti                         -0.80406396   0.89430186  -0.899
## `native_countryHoland-Netherlands`                   NA           NA      NA
## native_countryHonduras                      -2.72593969   2.57442836  -1.059
## native_countryHong                          -1.13811989   0.88997858  -1.279
## native_countryHungary                       -0.67415920   1.02123578  -0.660
## native_countryIndia                         -1.22118198   0.70449776  -1.733
## native_countryIran                          -1.17971317   0.79052578  -1.492
## native_countryIreland                       -0.14403721   0.90509750  -0.159
## native_countryItaly                         -0.19673971   0.74287318  -0.265
## native_countryJamaica                       -0.80082113   0.77083613  -1.039
## native_countryJapan                         -0.92504253   0.75943424  -1.218
## native_countryLaos                          -1.01431117   1.00026394  -1.014
## native_countryMexico                        -1.37662617   0.67515968  -2.039
## native_countryNicaragua                     -1.79509743   1.12673633  -1.593
## `native_countryOutlying-US(Guam-USVI-etc)` -12.86671562 172.42426031  -0.075
## native_countryPeru                          -1.56766889   1.05654676  -1.484
## native_countryPhilippines                   -0.78028396   0.65908800  -1.184
## native_countryPoland                        -1.32237104   0.78045299  -1.694
## native_countryPortugal                      -1.74085600   0.98098724  -1.775
## `native_countryPuerto-Rico`                 -1.40518633   0.77129132  -1.822
## native_countryScotland                      -1.48487119   1.08060134  -1.374
## native_countrySouth                         -2.55095426   0.77970247  -3.272
## native_countryTaiwan                        -1.08390373   0.82320854  -1.317
## native_countryThailand                      -1.96473337   1.01007061  -1.945
## `native_countryTrinadad&Tobago`             -1.15547083   1.02137799  -1.131
## `native_countryUnited-States`               -0.95035085   0.64462900  -1.474
## native_countryVietnam                       -1.66960191   0.82634817  -2.020
## native_countryYugoslavia                    -0.00442209   0.94555926  -0.005
## net_capital                                          NA           NA      NA
## capital_flag                                -2.55761855   0.21320137 -11.996
## hours_x_edu                                 -0.00101389   0.00071822  -1.412
## mid_age                                      0.68990030   0.04526037  15.243
## high_work_hours                              0.35096864   0.05663846   6.197
##                                                        Pr(>|z|)    
## (Intercept)                                < 0.0000000000000002 ***
## age                                        < 0.0000000000000002 ***
## education_num                              < 0.0000000000000002 ***
## `marital_statusMarried-AF-spouse`            0.0000001638222307 ***
## `marital_statusMarried-civ-spouse`         < 0.0000000000000002 ***
## `marital_statusMarried-spouse-absent`                   0.43243    
## `marital_statusNever-married`                0.0000044260265784 ***
## marital_statusSeparated                                 0.60751    
## marital_statusWidowed                                   0.04697 *  
## `occupationArmed-Forces`                                0.26858    
## `occupationCraft-repair`                                0.16459    
## `occupationExec-managerial`                < 0.0000000000000002 ***
## `occupationFarming-fishing`                < 0.0000000000000002 ***
## `occupationHandlers-cleaners`                0.0000559011122535 ***
## `occupationMachine-op-inspct`                           0.07676 .  
## `occupationOther-service`                    0.0000000225561257 ***
## `occupationPriv-house-serv`                             0.39320    
## `occupationProf-specialty`                   0.0000000000000725 ***
## `occupationProtective-serv`                  0.0000000117768491 ***
## occupationSales                              0.0000463421866958 ***
## `occupationTech-support`                     0.0000000009296472 ***
## `occupationTransport-moving`                            0.54263    
## `relationshipNot-in-family`                             0.00384 ** 
## `relationshipOther-relative`                            0.97924    
## `relationshipOwn-child`                                 0.21823    
## relationshipUnmarried                                   0.06610 .  
## relationshipWife                           < 0.0000000000000002 ***
## `raceAsian-Pac-Islander`                                0.00180 ** 
## raceBlack                                               0.02163 *  
## raceOther                                               0.74237    
## raceWhite                                               0.00466 ** 
## sexMale                                    < 0.0000000000000002 ***
## hours_per_week                               0.0000848271728951 ***
## capital_gain                               < 0.0000000000000002 ***
## capital_loss                               < 0.0000000000000002 ***
## `workclassLocal-gov`                         0.0000000196002970 ***
## workclassPrivate                             0.0000004457149139 ***
## `workclassSelf-emp-inc`                                 0.03847 *  
## `workclassSelf-emp-not-inc`                  0.0000000000010754 ***
## `workclassState-gov`                         0.0000000002686600 ***
## `workclassWithout-pay`                                  0.92032    
## native_countryCanada                                    0.25412    
## native_countryChina                                     0.00291 ** 
## native_countryColumbia                                  0.00114 ** 
## native_countryCuba                                      0.55815    
## `native_countryDominican-Republic`                      0.00374 ** 
## native_countryEcuador                                   0.00974 ** 
## `native_countryEl-Salvador`                             0.02096 *  
## native_countryEngland                                   0.24436    
## native_countryFrance                                    0.38591    
## native_countryGermany                                   0.28762    
## native_countryGreece                                    0.01224 *  
## native_countryGuatemala                                 0.51115    
## native_countryHaiti                                     0.36860    
## `native_countryHoland-Netherlands`                           NA    
## native_countryHonduras                                  0.28967    
## native_countryHong                                      0.20096    
## native_countryHungary                                   0.50916    
## native_countryIndia                                     0.08302 .  
## native_countryIran                                      0.13562    
## native_countryIreland                                   0.87356    
## native_countryItaly                                     0.79114    
## native_countryJamaica                                   0.29885    
## native_countryJapan                                     0.22320    
## native_countryLaos                                      0.31056    
## native_countryMexico                                    0.04145 *  
## native_countryNicaragua                                 0.11112    
## `native_countryOutlying-US(Guam-USVI-etc)`              0.94052    
## native_countryPeru                                      0.13787    
## native_countryPhilippines                               0.23646    
## native_countryPoland                                    0.09020 .  
## native_countryPortugal                                  0.07596 .  
## `native_countryPuerto-Rico`                             0.06848 .  
## native_countryScotland                                  0.16941    
## native_countrySouth                                     0.00107 ** 
## native_countryTaiwan                                    0.18795    
## native_countryThailand                                  0.05176 .  
## `native_countryTrinadad&Tobago`                         0.25793    
## `native_countryUnited-States`                           0.14041    
## native_countryVietnam                                   0.04334 *  
## native_countryYugoslavia                                0.99627    
## net_capital                                                  NA    
## capital_flag                               < 0.0000000000000002 ***
## hours_x_edu                                             0.15805    
## mid_age                                    < 0.0000000000000002 ***
## high_work_hours                              0.0000000005767818 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 31361  on 22621  degrees of freedom
## Residual deviance: 16811  on 22538  degrees of freedom
## AIC: 21383
## 
## Number of Fisher Scoring iterations: 12
pred.logit.train <- predict(adult.logit, adult.train, type = "prob")[, "wyzszy"]
pred.logit.test  <- predict(adult.logit, adult.test,  type = "prob")[, "wyzszy"]

cat("=== Logistic regression: train vs test ===\n")
## === Logistic regression: train vs test ===
rbind(train = klasyfikacja.metryki(pred.logit.train, real.train),
      test  = klasyfikacja.metryki(pred.logit.test,  real.test))
##          AUC Accuracy Sensitivity Specificity     F1
## train 0.9124   0.8151      0.8530      0.8025 0.6966
## test  0.9108   0.8182      0.8434      0.8098 0.6978
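
The model summary reports two coefficients as not defined because of singularities. For net_capital, the values in the glimpse output suggest that it equals capital_gain minus capital_loss (an inference from the printed values, not something the report states), which would make it an exact linear combination of two other predictors. The sketch below reproduces that situation on simulated data, using lm() instead of the weighted glm for simplicity; R reports the redundant column as NA.

```r
set.seed(1)
n <- 200

# Simulated predictors; net_capital is an exact linear combination of the others
capital_gain <- rpois(n, 5)
capital_loss <- rpois(n, 2)
net_capital  <- capital_gain - capital_loss

y <- 1 + 0.5 * capital_gain - 0.3 * capital_loss + rnorm(n)

# The aliased column cannot be estimated separately and gets an NA coefficient
fit <- lm(y ~ capital_gain + capital_loss + net_capital)
print(coef(fit))
```

Dropping net_capital from the formula would leave the fitted values unchanged. The NA for the Holand-Netherlands level likely has a different cause, such as that level having no rows in the training split, but the output shown does not confirm this.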

5. Classification tree (CART)

ctrl.class <- trainControl(method          = "cv",
                           number          = 5,
                           classProbs      = TRUE,
                           summaryFunction = twoClassSummary)

grid.cart <- expand.grid(cp = c(0.0001, 0.0005, 0.001, 0.005, 0.01))

set.seed(42)
adult.cart <- train(income ~ ., data = adult.train, method = "rpart",
                     metric = "ROC", trControl = ctrl.class,
                     tuneGrid = grid.cart, weights = sample_weights)
saveRDS(adult.cart, "adult.cart.rds")

adult.cart <- readRDS("adult.cart.rds")
cat("Best cp:", adult.cart$bestTune$cp, "\n")
## Best cp: 0.0001
rpart.plot(adult.cart$finalModel,
           type  = 4,
           extra = 104,
           under = TRUE,
           tweak = 1.1,
           main  = "CART classification tree (Adult Income)")

imp.cart <- varImp(adult.cart)$importance
imp.cart$Zmienna <- rownames(imp.cart)
imp.cart %>%
  arrange(desc(Overall)) %>%
  head(12) %>%
  ggplot(aes(x = reorder(Zmienna, Overall), y = Overall)) +
  geom_col(fill = "steelblue") +
  coord_flip() +
  labs(title = "Variable importance: CART",
       x = NULL, y = "Importance") +
  theme_minimal()

pred.cart.train <- predict(adult.cart, adult.train, type = "prob")[, "wyzszy"]
pred.cart.test  <- predict(adult.cart, adult.test,  type = "prob")[, "wyzszy"]

cat("=== CART: train vs test ===\n")
## === CART: train vs test ===
rbind(train = klasyfikacja.metryki(pred.cart.train, real.train),
      test  = klasyfikacja.metryki(pred.cart.test,  real.test))
##          AUC Accuracy Sensitivity Specificity     F1
## train 0.9288   0.8482      0.9052      0.8294 0.7481
## test  0.8964   0.8150      0.8348      0.8084 0.6920

6. Random Forest

grid.rf <- expand.grid(mtry          = c(3, 5, 7),
                       splitrule     = "gini",
                       min.node.size = c(5, 10))

set.seed(42)
adult.rf <- train(income ~ ., data = adult.train, method = "ranger",
                  metric = "ROC", trControl = ctrl.class,
                  tuneGrid = grid.rf, num.trees = 300,
                  importance = "impurity", weights = sample_weights)
saveRDS(adult.rf, "adult.rf.rds")

adult.rf <- readRDS("adult.rf.rds")
cat("Best RF parameters:\n")
## Best RF parameters:
print(adult.rf$bestTune)
##   mtry splitrule min.node.size
## 5    7      gini             5
imp.rf <- varImp(adult.rf)$importance
imp.rf$Zmienna <- rownames(imp.rf)
imp.rf %>%
  arrange(desc(Overall)) %>%
  head(12) %>%
  ggplot(aes(x = reorder(Zmienna, Overall), y = Overall)) +
  geom_col(fill = "#2ECC71") +
  coord_flip() +
  labs(title = "Variable importance: Random Forest",
       x = NULL, y = "Importance (Gini)") +
  theme_minimal()

pred.rf.train <- predict(adult.rf, adult.train, type = "prob")[, "wyzszy"]
pred.rf.test  <- predict(adult.rf, adult.test,  type = "prob")[, "wyzszy"]

cat("=== Random Forest: train vs test ===\n")
## === Random Forest: train vs test ===
rbind(train = klasyfikacja.metryki(pred.rf.train, real.train),
      test  = klasyfikacja.metryki(pred.rf.test,  real.test))
##          AUC Accuracy Sensitivity Specificity     F1
## train 0.9600   0.8669      0.9343      0.8446 0.7775
## test  0.9206   0.8281      0.8418      0.8236 0.7092

7. Gradient boosting (GBM)

grid.gbm <- expand.grid(
  interaction.depth = c(4, 6),
  n.trees           = c(100, 200),
  shrinkage         = c(0.05, 0.1),
  n.minobsinnode    = 10
)

set.seed(42)
adult.gbm <- train(income ~ .,
                   data      = adult.train,
                   method    = "gbm",
                   metric    = "ROC",
                   trControl = ctrl.class,
                   tuneGrid  = grid.gbm,
                   verbose   = FALSE)

saveRDS(adult.gbm, "adult.gbm.rds")
cat("Best GBM parameters:\n")
## Best GBM parameters:
print(adult.gbm$bestTune)
##   n.trees interaction.depth shrinkage n.minobsinnode
## 8     200                 6       0.1             10
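# Note: this chunk refits the CART model from section 5 without class weights
# and overwrites both adult.cart and adult.cart.rds; the CART predictions
# already computed in section 5 are not affected.
grid.cart <- expand.grid(cp = c(0.0001, 0.0005, 0.001, 0.005, 0.01))

set.seed(42)
adult.cart <- train(income ~ .,
                    data      = adult.train,
                    method    = "rpart",
                    metric    = "ROC",
                    trControl = ctrl.class,
                    tuneGrid  = grid.cart)
saveRDS(adult.cart, "adult.cart.rds")

adult.cart <- readRDS("adult.cart.rds")
cat("Best cp:", adult.cart$bestTune$cp, "\n")
## Best cp: 0.0005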
pred.gbm.train <- predict(adult.gbm, adult.train, type = "prob")[, "wyzszy"]
pred.gbm.test  <- predict(adult.gbm, adult.test,  type = "prob")[, "wyzszy"]

cat("=== GBM: train vs test ===\n")
## === GBM: train vs test ===
rbind(train = klasyfikacja.metryki(pred.gbm.train, real.train),
      test  = klasyfikacja.metryki(pred.gbm.test,  real.test))
##          AUC Accuracy Sensitivity Specificity     F1
## train 0.9321   0.8730      0.6567      0.9447 0.7202
## test  0.9242   0.8715      0.6516      0.9444 0.7163

8. Neural network (nnet)

# Standardization: parameters come from the training set only!
adult.vars.num <- c("age", "education_num", "hours_per_week",
                    "capital_gain", "capital_loss", "net_capital",
                    "hours_x_edu", "capital_flag", "mid_age", "high_work_hours")

mn.train <- colMeans(adult.train[, adult.vars.num])
sd.train <- apply(adult.train[, adult.vars.num], 2, sd)
sd.train[sd.train == 0] <- 1

adult.train.s <- adult.train
adult.test.s  <- adult.test
adult.train.s[, adult.vars.num] <- scale(adult.train[, adult.vars.num],
                                          center = mn.train, scale = sd.train)
adult.test.s[, adult.vars.num]  <- scale(adult.test[, adult.vars.num],
                                          center = mn.train, scale = sd.train)
set.seed(42)
adult.nn <- nnet(income ~ .,
                 data    = adult.train.s,
                 size    = 16,
                 decay   = 0.001,
                 maxit   = 200,
                 trace   = FALSE,
                 MaxNWts = 5000)

saveRDS(adult.nn, "adult.nn.rds")
saveRDS(list(mean = mn.train, sd = sd.train), "scaler.nn.rds")
cat("Done. Number of weights:", length(adult.nn$wts), "\n")
## Done. Number of weights: 1393
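# adult.nn is fit directly with nnet(), not through caret::train(),
# so it has no $bestTune element to report (size and decay are fixed above).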
pred.nn.train <- predict(adult.nn, adult.train.s, type = "raw")[, 1]
pred.nn.test  <- predict(adult.nn, adult.test.s,  type = "raw")[, 1]

cat("=== Neural network: train vs test ===\n")
## === Neural network: train vs test ===
rbind(train = klasyfikacja.metryki(pred.nn.train, real.train),
      test  = klasyfikacja.metryki(pred.nn.test,  real.test))
##          AUC Accuracy Sensitivity Specificity     F1
## train 0.9421   0.8795      0.7075      0.9365 0.7451
## test  0.8983   0.8463      0.6335      0.9168 0.6723
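
The weight count reported above can be checked by hand. For a network with one hidden layer, p inputs, `size` hidden units and a single output, the number of weights (biases included, no skip-layer connections) is (p + 1) * size + (size + 1). The sketch below uses a hypothetical helper, `nnet_weight_count`, that is not part of the original analysis. The value p = 85 is inferred from the reported 1393 weights and is an assumption about the number of model-matrix columns after dummy coding.

```r
# Hypothetical helper: weights in a single-hidden-layer network
# without skip-layer connections (each unit has one bias term).
nnet_weight_count <- function(p, size, n_out = 1) {
  (p + 1) * size + (size + 1) * n_out
}

size  <- 16     # hidden units, as in the nnet() call
total <- 1393   # weight count reported in the output

# Solve (p + 1) * size + (size + 1) = total for p.
p_inferred <- (total - (size + 1)) / size - 1
p_inferred

# Plugging the inferred p back in reproduces the reported total.
nnet_weight_count(p_inferred, size)
```

A count of this size stays well under the `MaxNWts = 5000` limit passed to `nnet()`.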

9. Comparison of all models

porownanie.train <- rbind(
  Logit        = klasyfikacja.metryki(pred.logit.train, real.train),
  CART         = klasyfikacja.metryki(pred.cart.train,  real.train),
  RandomForest = klasyfikacja.metryki(pred.rf.train,    real.train),
  GBM          = klasyfikacja.metryki(pred.gbm.train,   real.train),
  NeuralNet    = klasyfikacja.metryki(pred.nn.train,     real.train)
)
knitr::kable(porownanie.train, caption = "Metrics: training set")
Metrics: training set
                AUC Accuracy Sensitivity Specificity     F1
Logit        0.9124   0.8151      0.8530      0.8025 0.6966
CART         0.9288   0.8482      0.9052      0.8294 0.7481
RandomForest 0.9600   0.8669      0.9343      0.8446 0.7775
GBM          0.9321   0.8730      0.6567      0.9447 0.7202
NeuralNet    0.9421   0.8795      0.7075      0.9365 0.7451
porownanie.test <- rbind(
  Logit        = klasyfikacja.metryki(pred.logit.test, real.test),
  CART         = klasyfikacja.metryki(pred.cart.test,  real.test),
  RandomForest = klasyfikacja.metryki(pred.rf.test,    real.test),
  XGBoost      = klasyfikacja.metryki(pred.gbm.test,   real.test),
  NeuralNet    = klasyfikacja.metryki(pred.nn.test,     real.test)
)
knitr::kable(porownanie.test, caption = "Metrics: test set")
Metrics: test set
                AUC Accuracy Sensitivity Specificity     F1
Logit        0.9108   0.8182      0.8434      0.8098 0.6978
CART         0.8964   0.8150      0.8348      0.8084 0.6920
RandomForest 0.9206   0.8281      0.8418      0.8236 0.7092
XGBoost      0.9242   0.8715      0.6516      0.9444 0.7163
NeuralNet    0.8983   0.8463      0.6335      0.9168 0.6723
auc.test <- porownanie.test$AUC
names(auc.test) <- rownames(porownanie.test)

par(mar = c(9, 4, 3, 1))
barplot(sort(auc.test, decreasing = TRUE),
        main  = "AUC on the test set (higher = better)",
        ylab  = "AUC", las = 2, ylim = c(0.8, 1),
        col   = "steelblue", border = NA, cex.names = 0.9)

par(mar = c(5, 4, 4, 2))
roc.logit <- roc(as.numeric(real.test == "wyzszy"), pred.logit.test, quiet = TRUE)
roc.cart  <- roc(as.numeric(real.test == "wyzszy"), pred.cart.test,  quiet = TRUE)
roc.rf    <- roc(as.numeric(real.test == "wyzszy"), pred.rf.test,    quiet = TRUE)
roc.xgb   <- roc(as.numeric(real.test == "wyzszy"), pred.gbm.test,   quiet = TRUE)
roc.nn    <- roc(as.numeric(real.test == "wyzszy"), pred.nn.test,    quiet = TRUE)

plot(roc.logit, col = "gray50",  lwd = 1.5, main = "ROC curves: test set")
plot(roc.cart,  col = "steelblue", lwd = 1.5, add = TRUE)
plot(roc.rf,    col = "#2ECC71",  lwd = 1.5, add = TRUE)
plot(roc.xgb,   col = "coral",   lwd = 1.5, add = TRUE)
plot(roc.nn,    col = "#9B59B6", lwd = 1.5, add = TRUE)

legend("bottomright", bty = "n",
       legend = c(paste0("Logit (AUC=",       round(auc(roc.logit), 3), ")"),
                  paste0("CART (AUC=",        round(auc(roc.cart),  3), ")"),
                  paste0("Random Forest (AUC=", round(auc(roc.rf),  3), ")"),
                  paste0("XGBoost (AUC=",     round(auc(roc.xgb),  3), ")"),
                  paste0("nnet (AUC=",        round(auc(roc.nn),   3), ")")),
       col = c("gray50", "steelblue", "#2ECC71", "coral", "#9B59B6"),
       lwd = 2, cex = 0.85)

cat("AUC difference (train - test):\n")
## AUC difference (train - test):
print(round(porownanie.train$AUC - porownanie.test$AUC, 4))
## [1] 0.0016 0.0324 0.0394 0.0079 0.0438
best.idx <- which.max(porownanie.test$AUC)
cat("Best model:", rownames(porownanie.test)[best.idx], "\n")
## Best model: XGBoost
cat("AUC:", porownanie.test$AUC[best.idx], "\n")
## AUC: 0.9242
cat("F1: ", porownanie.test$F1[best.idx], "\n")
## F1:  0.7163

10. Conclusions and interpretation of results

Model ranking and rationale:

The best test-set results came from XGBoost and Random Forest (AUC 0.924 and 0.921). Both methods combine many trees: XGBoost builds them sequentially, each new tree correcting the errors of the previous ones, while Random Forest averages independently grown trees, which reduces variance. In practice the two give very similar results. XGBoost is marginally ahead (by about 0.004 AUC), plausibly because boosting concentrates on the observations that are harder to classify.

Why not accuracy?

The data are imbalanced (75%/25%). A model that always predicts <=50K would reach about 75% accuracy while being useless. We therefore use ROC-AUC as the main metric, since it measures how well the model separates the classes independently of the decision threshold.
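
The point about the majority-class baseline can be illustrated in base R. This is a minimal sketch on made-up labels with the 75%/25% split quoted above, not on the Adult data. The helper `auc_rank` is hypothetical and computes AUC from the Mann-Whitney rank statistic rather than with pROC.

```r
# Made-up labels with the 75% / 25% split (1 = ">50K"); not the real data.
y <- rep(c(0, 1), times = c(750, 250))

# Hypothetical helper: AUC from the Mann-Whitney rank statistic.
# rank() assigns average ranks to ties, so tied scores count as 0.5.
auc_rank <- function(score, y) {
  r  <- rank(score)
  n1 <- sum(y == 1)
  n0 <- sum(y == 0)
  (sum(r[y == 1]) - n1 * (n1 + 1) / 2) / (n1 * n0)
}

# A "model" that gives everyone the same low score,
# so it predicts "<=50K" for every observation at cutoff 0.5.
score_const <- rep(0.1, length(y))

acc_const <- mean((score_const > 0.5) == (y == 1))
auc_const <- auc_rank(score_const, y)

c(Accuracy = acc_const, AUC = auc_const)
```

The constant predictor scores 0.75 on accuracy but only 0.5 on AUC, the value expected from random ranking.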

Most important variables:

Across all models the dominant variables are net_capital, capital_gain and education_num: financial capital and education level are the main predictors of income above 50K. marital_status and relationship also matter because they correlate strongly with other socioeconomic variables (for example, Married-civ-spouse correlates strongly with employment type and age).

Handling class imbalance:

Applying class weights improved Sensitivity (detection of the >50K class) at the cost of Specificity. This is desirable in a business context, where missing a high-income person usually costs more than a false alarm.
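
The Sensitivity/Specificity trade-off from class weights can be reproduced on simulated data. The sketch below is an illustration under stated assumptions, not the weighting code used in this report: the data are simulated, the inverse-frequency weighting scheme is an assumed choice, and `quasibinomial()` is used only to avoid the non-integer-weights warning that `binomial()` would raise (the coefficient estimates are the same).

```r
set.seed(42)
n <- 4000
x <- rnorm(n)
# Simulated imbalanced outcome (roughly a quarter positives); not the Adult data.
y <- rbinom(n, 1, plogis(-1.5 + 1.5 * x))

# Assumed scheme: inverse-frequency weights, so that each class
# contributes the same total weight to the fit.
w <- ifelse(y == 1, 0.5 / mean(y == 1), 0.5 / mean(y == 0))

fit_plain    <- glm(y ~ x, family = binomial())
fit_weighted <- glm(y ~ x, family = quasibinomial(), weights = w)

# Sensitivity and Specificity on the fitted data at the default 0.5 cutoff.
sens <- function(fit) mean(fitted(fit)[y == 1] > 0.5)
spec <- function(fit) mean(fitted(fit)[y == 0] <= 0.5)

round(c(sens_plain    = sens(fit_plain),
        sens_weighted = sens(fit_weighted),
        spec_plain    = spec(fit_plain),
        spec_weighted = spec(fit_weighted)), 3)
```

Upweighting the minority class shifts the fitted probabilities upward, so more observations cross the 0.5 cutoff: Sensitivity rises and Specificity falls.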

Neural network:

nnet with a single hidden layer reaches a test AUC of 0.898, slightly below logistic regression (0.911) and clearly below the ensemble methods. It also shows the largest train-test AUC gap (0.044), which points to overfitting. The limited architecture is a likely cause; deeper networks (keras3) might do better, but they require a separate Python/TensorFlow environment.

Business context:

The XGBoost model, with a test AUC of about 0.92, is adequate for:

  • Credit scoring: initial qualification of premium customers,
  • Marketing: targeting campaigns for financial products at the higher-income group,
  • HR: salary benchmarking, i.e. identifying employees with a >50K profile who may be underpaid.

With a cutoff of 0.35–0.40 (lower than the default 0.5) we increase Sensitivity, which pays off when error costs are asymmetric.
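
The effect of lowering the cutoff can be sketched as follows. The scores here are simulated and purely illustrative; `metrics_at` is a hypothetical helper, and the numbers it prints say nothing about the actual model. With the real objects, `pred.gbm.test` and `real.test == "wyzszy"` would take the place of `score` and `y`.

```r
set.seed(42)
n <- 5000
y <- rbinom(n, 1, 0.25)   # 1 = ">50K", about 25% positives
# Simulated scores in (0, 1): positives tend to score higher.
score <- plogis(rnorm(n, mean = ifelse(y == 1, 0.5, -1.5), sd = 1))

# Hypothetical helper: Sensitivity and Specificity at a given cutoff.
metrics_at <- function(cutoff) {
  pred <- as.integer(score >= cutoff)
  c(cutoff      = cutoff,
    Sensitivity = mean(pred[y == 1] == 1),
    Specificity = mean(pred[y == 0] == 0))
}

cutoff_table <- t(sapply(c(0.35, 0.40, 0.50), metrics_at))
round(cutoff_table, 3)
```

Lowering the cutoff can only move observations from the predicted-negative to the predicted-positive group, so Sensitivity never decreases and Specificity never increases as the cutoff drops.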

11. Saving the models

saveRDS(adult.logit, "adult.logit.rds")
saveRDS(adult.cart,  "adult.cart.rds")
saveRDS(adult.rf,    "adult.rf.rds")
saveRDS(adult.gbm,   "adult.gbm.rds")
saveRDS(adult.nn,    "adult.nn.rds")
saveRDS(list(mean = mn.train, sd = sd.train), "scaler.nn.rds")

cat("All models saved.\n")
## All models saved.
cat("To load: model <- readRDS('adult.gbm.rds')\n")
## To load: model <- readRDS('adult.gbm.rds')
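
When the neural network is reloaded, new observations need the same training-set scaling that was applied before fitting. The sketch below shows the round trip on toy stand-ins: `train_num`, `new_obs` and the temporary file are hypothetical, and only the list(mean, sd) layout mirrors what is saved in scaler.nn.rds.

```r
# Toy stand-in for the numeric training columns (hypothetical values).
train_num <- data.frame(age = c(25, 40, 55), hours_per_week = c(20, 40, 60))
scaler <- list(mean = colMeans(train_num), sd = apply(train_num, 2, sd))

# Save to a temporary file so the example leaves no files behind.
path <- tempfile(fileext = ".rds")
saveRDS(scaler, path)

# Later session: reload the scaler and apply the training-set
# parameters to a new observation.
scaler_loaded <- readRDS(path)
new_obs <- data.frame(age = 40, hours_per_week = 50)
vars <- names(scaler_loaded$mean)
new_obs[, vars] <- scale(new_obs[, vars],
                         center = scaler_loaded$mean,
                         scale  = scaler_loaded$sd)
new_obs
```

With the real objects, the scaled data frame would then be passed to `predict(adult.nn, ..., type = "raw")`, as in section 8.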