library(tidyverse)             # manipulacja i wizualizacja danych
library(skimr)                 # statystyki opisowe / overview danych
library(DataExplorer)          # eksploracja danych
library(corrplot)              # macierz korelacji
library(PerformanceAnalytics)  # dodatkowe wykresy korelacji
library(janitor)               # tabele częstości i czyszczenie nazw
library(modeldata)             # zbiory danych
library(sf)                    # dane przestrzenne, jeśli używane
library(caret)                 # train/test, confusionMatrix, metryki
library(olsrr)                 # selekcja zmiennych
library(verification)          # narzędzia do oceny klasyfikacji
library(glmnet)                # modele regularyzowane
library(pROC)                  # krzywe ROC i AUC
library(rpart)                 # drzewa decyzyjne
library(rpart.plot)            # wizualizacja drzew
library(cutpointr)             # dobór punktu odcięcia
library(randomForest)          # las losowy
library(ranger)                # szybszy random forest
library(xgboost)               #pamiętajmy o starszej wersji pakietu!
library(neuralnet)             # proste sieci neuronowe
library(NeuralNetTools)        # wizualizacja sieci neuronowych

Wstęp

Poniższe badanie stanowi jeden z dwóch projektów zaliczeniowych przygotowanych w ramach przedmiotu - Uczenie maszynowe nadzorowane (2400-M1ABUMN) - prowadzonego przez mgr Monikę Kot w semestrze letnim 2026. Projekt powstał w ramach współpracy z Małgorzatą Hodurek.

Niniejsza praca skupiona będzie na analizie klasyfikacyjnej zarobków osób zarabiających powyżej 50k USD od cech ekonomiczno-społecznych. Jednak równie istotny aspekt analizy stanowić będzie porównanie modeli nadzorowanego uczenia maszynowego dla zmiennej binarnej

Struktura pracy jest następująca: najpierw przygotowujemy dane do dalszej analizy, następnie szacujemy modele podstawowe - benchmarkowe, względem, których będziemy porównywać modele nadzorowanego uczenia maszynowego. W dalszej sekcji trenowane oraz dostrajane są, kolejno: drzewa klasyfikacyjne, modele bagging, random forest oraz boosting a na końcu sieci neuronowe. Następnie wszystkie modele są ze sobą porównane i wyłonione zostają najlepsze podejścia do analizy i predykcji klasyfikacji zaróbków.

Dane

W badaniu wykorzystujemy próbkę danych na temat zarobków danych osób ze strony [archive.ics.uci.edu] [https://archive.ics.uci.edu/dataset/2/adult]. Zbiór ten opiera się na danych pobranych zebranych w 1994 roku.

# ustawienie ścieżki
getwd()
## [1] "C:/Users/aleks/OneDrive - SGH/Pulpit/R/ML"
# ostrzeżenia w języku polskim
Sys.setenv(LANG = "pl")

# wyłączenie notacji naukowej
options(scipen = 999)

# wczytanie danych 
salary <- read.csv("02_dane_w_SML-20260311/dane/k1_adult.csv")
glimpse(salary)
## Rows: 30,162
## Columns: 18
## $ class           <chr> "<=50K", "<=50K", "<=50K", "<=50K", "<=50K", "<=50K", …
## $ age             <int> 39, 50, 38, 53, 28, 37, 49, 52, 31, 42, 37, 30, 23, 32…
## $ education_num   <int> 13, 13, 9, 7, 13, 14, 5, 9, 14, 13, 10, 13, 13, 12, 4,…
## $ marital_status  <chr> "Never-married", "Married-civ-spouse", "Divorced", "Ma…
## $ occupation      <chr> "Adm-clerical", "Exec-managerial", "Handlers-cleaners"…
## $ relationship    <chr> "Not-in-family", "Husband", "Not-in-family", "Husband"…
## $ race            <chr> "White", "White", "White", "Black", "Black", "White", …
## $ sex             <chr> "Male", "Male", "Male", "Male", "Female", "Female", "F…
## $ hours_per_week  <int> 40, 13, 40, 40, 40, 40, 16, 45, 50, 40, 80, 40, 30, 50…
## $ capital_gain    <int> 2174, 0, 0, 0, 0, 0, 0, 0, 14084, 5178, 0, 0, 0, 0, 0,…
## $ capital_loss    <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ workclass       <chr> "State-gov", "Self-emp-not-inc", "Private", "Private",…
## $ native_country  <chr> "United-States", "United-States", "United-States", "Un…
## $ net_capital     <int> 2174, 0, 0, 0, 0, 0, 0, 0, 14084, 5178, 0, 0, 0, 0, 0,…
## $ capital_flag    <int> 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ hours_x_edu     <int> 520, 169, 360, 280, 520, 560, 80, 405, 700, 520, 800, …
## $ mid_age         <int> 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, …
## $ high_work_hours <int> 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, …

W kolejnym etapie przeprowadzono przygotowanie danych do modelowania. Na początku zmieniono nazwy zmiennych na bardziej czytelne oraz przekodowano wartości kategorii na polskie, opisowe etykiety. Dzięki temu dalsza analiza i interpretacja wyników są bardziej intuicyjne.

Następnie zmienne tekstowe zostały przekształcone do typu factor, ponieważ reprezentują cechy kategoryczne, takie jak płeć, stan cywilny, zawód, rasa, typ zatrudnienia czy kraj pochodzenia. Przed dalszym modelowaniem sprawdzono również występowanie braków danych.

W ramach feature engineeringu utworzono kilka nowych zmiennych. Pierwszą z nich była zmienna poziom_edukacji_grupa, która grupuje szczegółowe poziomy edukacji w szersze kategorie: wykształcenie podstawowe, średnie lub niepełne wyższe, associate/licencjackie oraz magisterskie lub wyższe. Takie przekształcenie pozwala uprościć strukturę danych i ułatwia interpretację wpływu edukacji na poziom dochodu. W kodzie widać, że grupowanie zostało wykonane na podstawie zmiennej liczbowej poziom_edukacji za pomocą instrukcji case_when() .

Drugą utworzoną zmienną była kraj_pochodzenia_grupa, która upraszcza zmienną kraju pochodzenia do dwóch kategorii: USA oraz Poza_USA. Zastosowano takie rozwiązanie, ponieważ zdecydowana większość obserwacji pochodzi ze Stanów Zjednoczonych, a pozostałe kraje występują znacznie rzadziej. Połączenie ich w jedną kategorię ogranicza problem małych liczebności w poszczególnych grupach oraz poprawia czytelność modelu .

Dodatkowo zmienna objaśniana klasa_dochodu została przekształcona do postaci binarnej klasa_dochodu_01, gdzie wartość tak oznacza dochód powyżej 50 tys., a wartość nie oznacza dochód nieprzekraczający tego progu. Na potrzeby regresji logistycznej zmienna ta została następnie zakodowana jako 0/1, co jest wymagane przy modelowaniu binarnym za pomocą funkcji glm(…, family = binomial) .

Na końcu przygotowano finalny zbiór danych salary_prep, z którego dla modelu logitowego usunięto zmienne zbędne lub zastąpione przez nowe konstrukcje, takie jak oryginalny kraj_pochodzenia, poziom_edukacji, klasa_dochodu, a także składowe kapitału, które zostały zastąpione zmienną syntetyczną dotyczącą dodatniego kapitału netto.

# Wczytanie słowników z nazwami zmiennych i wartościami kategorii
source("02_dane_w_SML-20260311/slowniki_funkcje/slowniki_wartosci_adult.R")

# Zmiana nazw zmiennych na polskie
salary_pl <- salary %>% 
  rename(any_of(nazwy_zmiennych_pl))

# Przekodowanie wartości kategorii na bardziej opisowe etykiety
salary_pl_2 <- salary_pl %>%
  mutate(
    klasa_dochodu = recode(trimws(as.character(klasa_dochodu)), !!!dict_klasa_dochodu),
    stan_cywilny = recode(trimws(as.character(stan_cywilny)), !!!dict_stan_cywilny),
    zawod = recode(trimws(as.character(zawod)), !!!dict_zawod),
    relacja_w_rodzinie = recode(trimws(as.character(relacja_w_rodzinie)), !!!dict_relacja_w_rodzinie),
    rasa = recode(trimws(as.character(rasa)), !!!dict_rasa),
    plec = recode(trimws(as.character(plec)), !!!dict_plec),
    typ_zatrudnienia = recode(trimws(as.character(typ_zatrudnienia)), !!!dict_typ_zatrudnienia),
    kraj_pochodzenia = recode(trimws(as.character(kraj_pochodzenia)), !!!dict_kraj_pochodzenia),
    czy_kapital_netto_dodatni = recode(trimws(as.character(czy_kapital_netto_dodatni)), !!!dict_czy_kapital_netto_dodatni),
    wiek_sredni = recode(trimws(as.character(wiek_sredni)), !!!dict_wiek_sredni),
    dlugie_godziny_pracy = recode(trimws(as.character(dlugie_godziny_pracy)), !!!dict_dlugie_godziny_pracy)
  )


# Zamiana zmiennych tekstowych na zmienne kategoryczne
salary_pl_3 <- salary_pl_2 %>% 
  mutate(across(where(is.character), as.factor))

glimpse(salary_pl_3)
## Rows: 30,162
## Columns: 18
## $ klasa_dochodu             <fct> dochód_do_50k, dochód_do_50k, dochód_do_50k,…
## $ wiek                      <int> 39, 50, 38, 53, 28, 37, 49, 52, 31, 42, 37, …
## $ poziom_edukacji           <int> 13, 13, 9, 7, 13, 14, 5, 9, 14, 13, 10, 13, …
## $ stan_cywilny              <fct> nigdy_niezamężny_niezonaty, małżeństwo_cywil…
## $ zawod                     <fct> administracja_biurowa, kadra_kierownicza, pr…
## $ relacja_w_rodzinie        <fct> poza_rodziną, mąż, poza_rodziną, mąż, żona, …
## $ rasa                      <fct> biała, biała, biała, czarna, czarna, biała, …
## $ plec                      <fct> mężczyzna, mężczyzna, mężczyzna, mężczyzna, …
## $ godziny_pracy_tygodniowo  <int> 40, 13, 40, 40, 40, 40, 16, 45, 50, 40, 80, …
## $ zysk_kapitalowy           <int> 2174, 0, 0, 0, 0, 0, 0, 0, 14084, 5178, 0, 0…
## $ strata_kapitalowa         <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ typ_zatrudnienia          <fct> administracja_stanowa, samozatrudniony_bez_o…
## $ kraj_pochodzenia          <fct> Stany_Zjednoczone, Stany_Zjednoczone, Stany_…
## $ kapital_netto             <int> 2174, 0, 0, 0, 0, 0, 0, 0, 14084, 5178, 0, 0…
## $ czy_kapital_netto_dodatni <fct> kapital_netto_dodatni, kapital_netto_niedoda…
## $ godziny_razy_edukacja     <int> 520, 169, 360, 280, 520, 560, 80, 405, 700, …
## $ wiek_sredni               <fct> w_grupie_wieku_średniego, w_grupie_wieku_śre…
## $ dlugie_godziny_pracy      <fct> nie, nie, nie, nie, nie, nie, nie, tak, tak,…

Feature engineerig

W kolejnym kroku przeanalizowano zależności pomiędzy zmiennymi numerycznymi. W tym celu utworzono macierz korelacji, która pozwala sprawdzić, czy między wybranymi cechami występują silne zależności liniowe. Jest to istotne, ponieważ bardzo silnie skorelowane zmienne mogą powielać podobną informację w modelu.

# Wybór zmiennych numerycznych
num_data_salary <- salary_pl_3 %>% 
  dplyr::select(where(is.numeric))

glimpse(num_data_salary)
## Rows: 30,162
## Columns: 7
## $ wiek                     <int> 39, 50, 38, 53, 28, 37, 49, 52, 31, 42, 37, 3…
## $ poziom_edukacji          <int> 13, 13, 9, 7, 13, 14, 5, 9, 14, 13, 10, 13, 1…
## $ godziny_pracy_tygodniowo <int> 40, 13, 40, 40, 40, 40, 16, 45, 50, 40, 80, 4…
## $ zysk_kapitalowy          <int> 2174, 0, 0, 0, 0, 0, 0, 0, 14084, 5178, 0, 0,…
## $ strata_kapitalowa        <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ kapital_netto            <int> 2174, 0, 0, 0, 0, 0, 0, 0, 14084, 5178, 0, 0,…
## $ godziny_razy_edukacja    <int> 520, 169, 360, 280, 520, 560, 80, 405, 700, 5…
# Macierz korelacji
cor_mat_salary <- cor(num_data_salary, use = "pairwise.complete.obs")

# Wizualizacja macierzy korelacji
corrplot(
  cor_mat_salary,
  type = "lower",
  tl.cex = 0.7
)

Warto również przeanalizować częstości występowania poszczególnych poziomów zmiennych kategorycznych. W przypadku zmiennej kraju pochodzenia widoczne jest bardzo silne skupienie obserwacji w jednej kategorii, czyli w Stanach Zjednoczonych. Pozostałe kraje występują znacznie rzadziej. Dlatego dla logitu zdecydowano się na połączenie tych kategorii

salary_pl_3 %>%
  count(kraj_pochodzenia, sort = TRUE) %>%
  mutate(kraj_pochodzenia = fct_reorder(kraj_pochodzenia, n)) %>%
  ggplot(aes(x = kraj_pochodzenia, y = n)) +
  geom_col() +
  coord_flip() +
  labs(
    title = "Liczba obserwacji według kraju pochodzenia",
    x = "Kraj pochodzenia",
    y = "Liczba obserwacji"
  ) +
  theme_minimal()

# Uproszczenie kraju pochodzenia do dwóch kategorii
salary_pl_3 <- salary_pl_3 %>%
  mutate(
    kraj_pochodzenia_grupa = if_else(
      kraj_pochodzenia == "Stany_Zjednoczone",
      "USA",
      "Poza_USA",
      missing = NA_character_
    ),
    kraj_pochodzenia_grupa = as.factor(kraj_pochodzenia_grupa)
  )

# Grupowanie poziomu edukacji do szerszych kategorii
salary_pl_3 <- salary_pl_3 %>%
  mutate(
    poziom_edukacji_grupa = case_when(
      poziom_edukacji <= 8 ~ "podstawowe",
      poziom_edukacji %in% c(9, 10) ~ "średnie_lub_niepełne_wyższe",
      poziom_edukacji %in% c(11, 12, 13) ~ "associate_lub_licencjackie",
      poziom_edukacji >= 14 ~ "magisterskie_lub_wyższe",
      TRUE ~ NA_character_
    ),
    poziom_edukacji_grupa = factor(
      poziom_edukacji_grupa,
      levels = c(
        "podstawowe",
        "średnie_lub_niepełne_wyższe",
        "associate_lub_licencjackie",
        "magisterskie_lub_wyższe"
      ),
      ordered = FALSE
    )
  )


# Utworzenie binarnej zmiennej celu
salary_pl_3 <- salary_pl_3 %>%
  mutate(
    klasa_dochodu_01 = if_else(
      klasa_dochodu == "dochód_powyżej_50k",
      "tak",
      "nie"
    ),
    klasa_dochodu_01 = as.factor(klasa_dochodu_01)
  )

# Finalny zbiór do modelu logitowego
salary_prep <- salary_pl_3 %>% 
  dplyr::select(-c(
    kraj_pochodzenia,
    poziom_edukacji,
    klasa_dochodu,
    kapital_netto,
    zysk_kapitalowy,
    strata_kapitalowa
  ))

# Kodowanie zmiennej celu jako 0/1 do regresji logistycznej
salary_prep <- salary_prep %>%
  mutate(
    klasa_dochodu_01 = if_else(klasa_dochodu_01 == "tak", 1, 0)
  )

glimpse(salary_prep)
## Rows: 30,162
## Columns: 15
## $ wiek                      <int> 39, 50, 38, 53, 28, 37, 49, 52, 31, 42, 37, …
## $ stan_cywilny              <fct> nigdy_niezamężny_niezonaty, małżeństwo_cywil…
## $ zawod                     <fct> administracja_biurowa, kadra_kierownicza, pr…
## $ relacja_w_rodzinie        <fct> poza_rodziną, mąż, poza_rodziną, mąż, żona, …
## $ rasa                      <fct> biała, biała, biała, czarna, czarna, biała, …
## $ plec                      <fct> mężczyzna, mężczyzna, mężczyzna, mężczyzna, …
## $ godziny_pracy_tygodniowo  <int> 40, 13, 40, 40, 40, 40, 16, 45, 50, 40, 80, …
## $ typ_zatrudnienia          <fct> administracja_stanowa, samozatrudniony_bez_o…
## $ czy_kapital_netto_dodatni <fct> kapital_netto_dodatni, kapital_netto_niedoda…
## $ godziny_razy_edukacja     <int> 520, 169, 360, 280, 520, 560, 80, 405, 700, …
## $ wiek_sredni               <fct> w_grupie_wieku_średniego, w_grupie_wieku_śre…
## $ dlugie_godziny_pracy      <fct> nie, nie, nie, nie, nie, nie, nie, tak, tak,…
## $ kraj_pochodzenia_grupa    <fct> USA, USA, USA, USA, Poza_USA, USA, Poza_USA,…
## $ poziom_edukacji_grupa     <fct> associate_lub_licencjackie, associate_lub_li…
## $ klasa_dochodu_01          <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,…

Sprawdźmy jeszcze czy w naszym zbiorze nie ma pustych wartości. Z poniższego outputu wynika, że nie mamy do czynienia z żadnymi brakami w danych. Zbiór do dalszej analizy jest gotowy, jak widać usunięto w nim również zmienne które powielały informacje jak kapitał_netto, zysk czy strata netto a pozostawiono zmienną czy kapitał netto dodatni.

colSums(is.na(salary_prep)) %>% sort() # nie ma braków
##                      wiek              stan_cywilny                     zawod 
##                         0                         0                         0 
##        relacja_w_rodzinie                      rasa                      plec 
##                         0                         0                         0 
##  godziny_pracy_tygodniowo          typ_zatrudnienia czy_kapital_netto_dodatni 
##                         0                         0                         0 
##     godziny_razy_edukacja               wiek_sredni      dlugie_godziny_pracy 
##                         0                         0                         0 
##    kraj_pochodzenia_grupa     poziom_edukacji_grupa          klasa_dochodu_01 
##                         0                         0                         0

Podział na zbiór treningowy i testowy

Kluczowym etapem przygotowania zbioru do analizy przy wykorzystaniu technik uczenia maszynowego jest podział całego zbioru na zbiór treningowy (służące do estymacji modelu, wyboru zmiennych oraz strojenia hiperparametrów) oraz zbiór testowy (służące jedynie końcowej ocenie jakości modelu).

W przypadku naszego zbioru, który nie kieruje się żadną chronologią, możemy posłużyć się podziałem losowym. W taki przyapdku zakładamy, że obserwacje są niezależne i pochodzą z tego samego rozkładu (i.i.d.). Jest to prosty sposób podziału, który jest równocześnie zgodny z klasycznym podejściem Uczenia Maszynowego, estymacja błędu w tym przypadku jest stabilna a dane są wykorzystane efektywnie. Natomiast jeśli warunki rynkowe silnie zmieniały się w czasie model uwzględnia te zmiany zarówno na etapie treningu, jak i testowania.

W naszje analizie posłużymy się podziałem 70/30.

set.seed(123)
trening <- createDataPartition(salary_prep$klasa_dochodu_01,
                               p = 0.7,
                               list = FALSE) 

salary.train <- salary_prep[c(trening),]
salary.test <- salary_prep[-c(trening),]

Modele Benchmarkowe

Model Logit

Jako pierwszy model benchmarkowy zastosowano logit, czyli klasyczny model wykorzystywany wtedy, gdy zmienna objaśniana ma charakter binarny. W naszym przypadku zmienna klasa_dochodu_01 przyjmuje wartość 1, gdy dana osoba osiąga dochód powyżej 50 tys. USD rocznie, oraz wartość 0 w przeciwnym przypadku.

Model logitowy nie przewiduje bezpośrednio klasy, lecz prawdopodobieństwo przynależności do klasy pozytywnej. Oznacza to, że dla każdej obserwacji model zwraca wartość z przedziału od 0 do 1, którą można interpretować jako prawdopodobieństwo osiągania dochodu powyżej 50 tys. USD.

Budowa modelu

W pierwszym kroku oszacowano pełny model regresji logistycznej, wykorzystując wszystkie dostępne zmienne objaśniające znajdujące się w przygotowanym zbiorze danych.

glm.full.salary <- glm(
  klasa_dochodu_01 ~ .,
  data = salary.train,
  family = binomial
)

summary(glm.full.salary)
## 
## Call:
## glm(formula = klasa_dochodu_01 ~ ., family = binomial, data = salary.train)
## 
## Coefficients:
##                                                           Estimate  Std. Error
## (Intercept)                                             -2.6754588   0.3807457
## wiek                                                     0.0294486   0.0020325
## stan_cywilnymałżeństwo_z_osobą_z_sił_zbrojnych           1.2242191   0.5656714
## stan_cywilnymałżonek_nieobecny                          -2.1825524   0.4047527
## stan_cywilnynigdy_niezamężny_niezonaty                  -2.4892457   0.3120541
## stan_cywilnyrozwiedziony_rozwiedziona                   -2.2613339   0.3189205
## stan_cywilnyw_separacji                                 -2.2664291   0.3536429
## stan_cywilnywdowiec_wdowa                               -1.9068868   0.3539218
## zawodinne_usługi                                        -0.9122487   0.1384000
## zawodkadra_kierownicza                                   0.8144842   0.0895568
## zawodoperatorzy_maszyn_i_inspektorzy                    -0.3997932   0.1179369
## zawodpracownicy_fizyczni_i_sprzątający                  -0.8191188   0.1708117
## zawodprywatne_usługi_domowe                            -12.6204938 127.4793512
## zawodrolnictwo_i_rybołówstwo                            -0.8856543   0.1583252
## zawodrzemiosło_i_naprawy                                -0.0157938   0.0933926
## zawodsiły_zbrojne                                       -0.3556816   1.4053650
## zawodspecjaliści                                         0.5813813   0.0946097
## zawodsprzedaż                                            0.3530584   0.0958303
## zawodtransport_i_przemieszczanie                        -0.0892285   0.1154147
## zawodusługi_ochrony                                      0.5981511   0.1467605
## zawodwsparcie_techniczne                                 0.5923707   0.1301816
## relacja_w_rodziniemąż                                    0.2829615   0.2793491
## relacja_w_rodzinieniezamężny_niezonaty                   0.4970554   0.3075370
## relacja_w_rodziniepoza_rodziną                           0.8819446   0.2914316
## relacja_w_rodziniewłasne_dziecko                        -0.1253493   0.3071854
## relacja_w_rodzinieżona                                   1.5351340   0.2913368
## rasabiała                                               -0.0635356   0.1394250
## rasaczarna                                              -0.1422131   0.1611755
## rasainna                                                -0.6782135   0.3413477
## rasardzenni_amerykanie_lub_eskimosi                     -0.5841652   0.2987833
## plecmężczyzna                                            0.7672716   0.0903320
## godziny_pracy_tygodniowo                                -0.0256287   0.0061979
## typ_zatrudnieniaadministracja_lokalna                   -0.6376198   0.1326488
## typ_zatrudnieniaadministracja_stanowa                   -0.7654349   0.1455334
## typ_zatrudnieniabez_wynagrodzenia                      -13.7862623 401.4844552
## typ_zatrudnieniasamozatrudniony_bez_osobowości_prawnej  -0.8419172   0.1279671
## typ_zatrudnieniasamozatrudniony_z_osobowością_prawną    -0.2362106   0.1426511
## typ_zatrudnieniasektor_prywatny                         -0.4448359   0.1102385
## czy_kapital_netto_dodatnikapital_netto_niedodatni       -1.5661187   0.0672901
## godziny_razy_edukacja                                    0.0039352   0.0005241
## wiek_sredniw_grupie_wieku_średniego                      0.7070402   0.0505085
## dlugie_godziny_pracytak                                  0.3014832   0.0609898
## kraj_pochodzenia_grupaUSA                                0.2247643   0.0921824
## poziom_edukacji_grupaśrednie_lub_niepełne_wyższe         0.3165938   0.1228674
## poziom_edukacji_grupaassociate_lub_licencjackie          0.5127905   0.1777039
## poziom_edukacji_grupamagisterskie_lub_wyższe             0.8351007   0.2232604
##                                                        z value
## (Intercept)                                             -7.027
## wiek                                                    14.489
## stan_cywilnymałżeństwo_z_osobą_z_sił_zbrojnych           2.164
## stan_cywilnymałżonek_nieobecny                          -5.392
## stan_cywilnynigdy_niezamężny_niezonaty                  -7.977
## stan_cywilnyrozwiedziony_rozwiedziona                   -7.091
## stan_cywilnyw_separacji                                 -6.409
## stan_cywilnywdowiec_wdowa                               -5.388
## zawodinne_usługi                                        -6.591
## zawodkadra_kierownicza                                   9.095
## zawodoperatorzy_maszyn_i_inspektorzy                    -3.390
## zawodpracownicy_fizyczni_i_sprzątający                  -4.795
## zawodprywatne_usługi_domowe                             -0.099
## zawodrolnictwo_i_rybołówstwo                            -5.594
## zawodrzemiosło_i_naprawy                                -0.169
## zawodsiły_zbrojne                                       -0.253
## zawodspecjaliści                                         6.145
## zawodsprzedaż                                            3.684
## zawodtransport_i_przemieszczanie                        -0.773
## zawodusługi_ochrony                                      4.076
## zawodwsparcie_techniczne                                 4.550
## relacja_w_rodziniemąż                                    1.013
## relacja_w_rodzinieniezamężny_niezonaty                   1.616
## relacja_w_rodziniepoza_rodziną                           3.026
## relacja_w_rodziniewłasne_dziecko                        -0.408
## relacja_w_rodzinieżona                                   5.269
## rasabiała                                               -0.456
## rasaczarna                                              -0.882
## rasainna                                                -1.987
## rasardzenni_amerykanie_lub_eskimosi                     -1.955
## plecmężczyzna                                            8.494
## godziny_pracy_tygodniowo                                -4.135
## typ_zatrudnieniaadministracja_lokalna                   -4.807
## typ_zatrudnieniaadministracja_stanowa                   -5.260
## typ_zatrudnieniabez_wynagrodzenia                       -0.034
## typ_zatrudnieniasamozatrudniony_bez_osobowości_prawnej  -6.579
## typ_zatrudnieniasamozatrudniony_z_osobowością_prawną    -1.656
## typ_zatrudnieniasektor_prywatny                         -4.035
## czy_kapital_netto_dodatnikapital_netto_niedodatni      -23.274
## godziny_razy_edukacja                                    7.509
## wiek_sredniw_grupie_wieku_średniego                     13.998
## dlugie_godziny_pracytak                                  4.943
## kraj_pochodzenia_grupaUSA                                2.438
## poziom_edukacji_grupaśrednie_lub_niepełne_wyższe         2.577
## poziom_edukacji_grupaassociate_lub_licencjackie          2.886
## poziom_edukacji_grupamagisterskie_lub_wyższe             3.740
##                                                                    Pr(>|z|)    
## (Intercept)                                              0.0000000000021118 ***
## wiek                                                   < 0.0000000000000002 ***
## stan_cywilnymałżeństwo_z_osobą_z_sił_zbrojnych                     0.030450 *  
## stan_cywilnymałżonek_nieobecny                           0.0000000695572516 ***
## stan_cywilnynigdy_niezamężny_niezonaty                   0.0000000000000015 ***
## stan_cywilnyrozwiedziony_rozwiedziona                    0.0000000000013354 ***
## stan_cywilnyw_separacji                                  0.0000000001466635 ***
## stan_cywilnywdowiec_wdowa                                0.0000000712951907 ***
## zawodinne_usługi                                         0.0000000000435724 ***
## zawodkadra_kierownicza                                 < 0.0000000000000002 ***
## zawodoperatorzy_maszyn_i_inspektorzy                               0.000699 ***
## zawodpracownicy_fizyczni_i_sprzątający                   0.0000016231096794 ***
## zawodprywatne_usługi_domowe                                        0.921138    
## zawodrolnictwo_i_rybołówstwo                             0.0000000222033528 ***
## zawodrzemiosło_i_naprawy                                           0.865709    
## zawodsiły_zbrojne                                                  0.800200    
## zawodspecjaliści                                         0.0000000007993977 ***
## zawodsprzedaż                                                      0.000229 ***
## zawodtransport_i_przemieszczanie                                   0.439456    
## zawodusługi_ochrony                                      0.0000458768836760 ***
## zawodwsparcie_techniczne                                 0.0000053559060923 ***
## relacja_w_rodziniemąż                                              0.311093    
## relacja_w_rodzinieniezamężny_niezonaty                             0.106041    
## relacja_w_rodziniepoza_rodziną                                     0.002476 ** 
## relacja_w_rodziniewłasne_dziecko                                   0.683232    
## relacja_w_rodzinieżona                                   0.0000001369627385 ***
## rasabiała                                                          0.648607    
## rasaczarna                                                         0.377588    
## rasainna                                                           0.046937 *  
## rasardzenni_amerykanie_lub_eskimosi                                0.050566 .  
## plecmężczyzna                                          < 0.0000000000000002 ***
## godziny_pracy_tygodniowo                                 0.0000354877337519 ***
## typ_zatrudnieniaadministracja_lokalna                    0.0000015334482761 ***
## typ_zatrudnieniaadministracja_stanowa                    0.0000001444358345 ***
## typ_zatrudnieniabez_wynagrodzenia                                  0.972607    
## typ_zatrudnieniasamozatrudniony_bez_osobowości_prawnej   0.0000000000473082 ***
## typ_zatrudnieniasamozatrudniony_z_osobowością_prawną               0.097750 .  
## typ_zatrudnieniasektor_prywatny                          0.0000545525206521 ***
## czy_kapital_netto_dodatnikapital_netto_niedodatni      < 0.0000000000000002 ***
## godziny_razy_edukacja                                    0.0000000000000596 ***
## wiek_sredniw_grupie_wieku_średniego                    < 0.0000000000000002 ***
## dlugie_godziny_pracytak                                  0.0000007685947367 ***
## kraj_pochodzenia_grupaUSA                                          0.014758 *  
## poziom_edukacji_grupaśrednie_lub_niepełne_wyższe                   0.009974 ** 
## poziom_edukacji_grupaassociate_lub_licencjackie                    0.003906 ** 
## poziom_edukacji_grupamagisterskie_lub_wyższe                       0.000184 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 23695  on 21113  degrees of freedom
## Residual deviance: 14413  on 21068  degrees of freedom
## AIC: 14505
## 
## Number of Fisher Scoring iterations: 14

Predykcja

Następnie wyznaczono przewidywane prawdopodobieństwa dla zbioru treningowego. W funkcji predict() zastosowano argument type = “response”, ponieważ dla modelu logistycznego pozwala on otrzymać prawdopodobieństwa przynależności do klasy pozytywnej.

pred.glm.train.prob <- predict( glm.full.salary, newdata = salary.train, type = "response" ) 

summary(pred.glm.train.prob) 
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
## 0.00000 0.01913 0.12163 0.24889 0.42026 0.99238
head(pred.glm.train.prob)
##          2          3          4          5          7          9 
## 0.49164947 0.03160871 0.11627241 0.41251300 0.00678185 0.56088961

Krzywa ROC i optymalny próg odcięcia

W klasyfikacji binarnej standardowo przyjmuje się próg odcięcia równy 0.5. W przypadku danych niezbalansowanych taki próg może jednak prowadzić do zbyt słabego wykrywania klasy mniejszościowej. Ponieważ osoby zarabiające powyżej 50 tys. USD stanowią mniejszą część obserwacji, wyznaczono optymalny próg odcięcia na podstawie krzywej ROC.

krzywa_roc <- pROC::roc(
  response = salary.train$klasa_dochodu_01,
  predictor = pred.glm.train.prob
)

plot(
  krzywa_roc,
  main = "Krzywa ROC - zbiór treningowy",
  col = "blue",
  lwd = 2,
  print.auc = TRUE
)

optymalny_prog <- pROC::coords(
  krzywa_roc,
  x = "best",
  best.method = "youden",
  ret = c("threshold", "specificity", "sensitivity")
)

optymalny_prog
##   threshold specificity sensitivity
## 1 0.2346138   0.7711079   0.8614653
optymalny_prog_glm <- optymalny_prog$threshold[1]

Optymalnym tresholdem= został punkt 0,2346

Klasyfikacja na zbiorze treningowym

Po wyznaczeniu optymalnego progu odcięcia przewidywane prawdopodobieństwa dla zbioru treningowego zostały przekształcone na klasy. Jeżeli przewidywane prawdopodobieństwo było większe lub równe progowi optymalny_prog_glm, obserwacja została przypisana do klasy 1, czyli do grupy osób zarabiających powyżej 50 tys. USD. W przeciwnym przypadku obserwacja została przypisana do klasy 0.

pred.glm.train.class <- ifelse(
  pred.glm.train.prob >= optymalny_prog_glm,
  1,
  0
)

pred.glm.train.class <- factor(pred.glm.train.class, levels = c(0, 1))
salary.train$klasa_dochodu_01 <- factor(salary.train$klasa_dochodu_01, levels = c(0, 1))

cm.train <- confusionMatrix(
  pred.glm.train.class,
  salary.train$klasa_dochodu_01,
  positive = "1"
)

cm.train
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction     0     1
##          0 12229   728
##          1  3630  4527
##                                                
##                Accuracy : 0.7936               
##                  95% CI : (0.7881, 0.799)      
##     No Information Rate : 0.7511               
##     P-Value [Acc > NIR] : < 0.00000000000000022
##                                                
##                   Kappa : 0.534                
##                                                
##  Mcnemar's Test P-Value : < 0.00000000000000022
##                                                
##             Sensitivity : 0.8615               
##             Specificity : 0.7711               
##          Pos Pred Value : 0.5550               
##          Neg Pred Value : 0.9438               
##              Prevalence : 0.2489               
##          Detection Rate : 0.2144               
##    Detection Prevalence : 0.3863               
##       Balanced Accuracy : 0.8163               
##                                                
##        'Positive' Class : 1                    
## 
cm.train$table
##           Reference
## Prediction     0     1
##          0 12229   728
##          1  3630  4527

Macierz pomyłek pozwala sprawdzić, ile obserwacji zostało sklasyfikowanych poprawnie oraz jakie rodzaje błędów popełnia model. W przypadku tego problemu szczególnie ważna jest czułość, czyli zdolność modelu do wykrywania osób rzeczywiście zarabiających powyżej 50 tys. USD.

ggplot(
  data.frame(
    p = pred.glm.train.prob,
    klasa = salary.train$klasa_dochodu_01
  ),
  aes(x = p, fill = klasa)
) +
  geom_histogram(alpha = 0.5, bins = 30, position = "identity") +
  labs(
    title = "Rozkład przewidywanych prawdopodobieństw - zbiór treningowy",
    x = "P(dochód powyżej 50 tys. USD)",
    y = "Liczba obserwacji",
    fill = "Klasa"
  ) +
  theme_minimal()

Powyższy histogram przedstawia rozkład przewidywanych prawdopodobieństw w podziale na rzeczywiste klasy. Im bardziej rozkłady dla klas 0 i 1 są od siebie oddzielone, tym lepsza jest zdolność modelu do rozróżniania osób o niższych i wyższych dochodach.

Predykcja na zbiorze testowym

Następnie model zastosowano do zbioru testowego, który nie był wykorzystywany podczas estymacji. Dzięki temu można ocenić, jak model radzi sobie na nowych danych.

pred.glm.test.prob <- predict(
  glm.full.salary,
  newdata = salary.test,
  type = "response"
)

head(pred.glm.test.prob)
##          1          6          8         12         24         25 
## 0.32515206 0.83682007 0.55955963 0.43759321 0.03727325 0.52690441
summary(pred.glm.test.prob)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
## 0.00000 0.01922 0.11915 0.24933 0.42175 0.99545

Po otrzymaniu prawdopodobieństw dla zbioru testowego zastosowano ten sam próg odcięcia, który został wyznaczony na podstawie zbioru treningowego. Takie podejście jest poprawne, ponieważ próg nie powinien być dobierany na podstawie danych testowych.

pred.glm.test.class <- ifelse(
  pred.glm.test.prob >= optymalny_prog_glm,
  1,
  0
)

pred.glm.test.class <- factor(pred.glm.test.class, levels = c(0, 1))
salary.test$klasa_dochodu_01 <- factor(salary.test$klasa_dochodu_01, levels = c(0, 1))

cm.test <- confusionMatrix(
  pred.glm.test.class,
  salary.test$klasa_dochodu_01,
  positive = "1"
)

cm.test
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction    0    1
##          0 5227  296
##          1 1568 1957
##                                                
##                Accuracy : 0.794                
##                  95% CI : (0.7855, 0.8023)     
##     No Information Rate : 0.751                
##     P-Value [Acc > NIR] : < 0.00000000000000022
##                                                
##                   Kappa : 0.5366               
##                                                
##  Mcnemar's Test P-Value : < 0.00000000000000022
##                                                
##             Sensitivity : 0.8686               
##             Specificity : 0.7692               
##          Pos Pred Value : 0.5552               
##          Neg Pred Value : 0.9464               
##              Prevalence : 0.2490               
##          Detection Rate : 0.2163               
##    Detection Prevalence : 0.3896               
##       Balanced Accuracy : 0.8189               
##                                                
##        'Positive' Class : 1                    
## 
cm.test$table
##           Reference
## Prediction    0    1
##          0 5227  296
##          1 1568 1957

Ocena na zbiorze testowym jest kluczowa, ponieważ pozwala sprawdzić zdolność generalizacji modelu. Jeżeli wyniki na zbiorze treningowym i testowym są do siebie zbliżone, można uznać, że model nie wykazuje silnych oznak przeuczenia.

ggplot(
  data.frame(
    p = pred.glm.test.prob,
    klasa = salary.test$klasa_dochodu_01
  ),
  aes(x = p, fill = klasa)
) +
  geom_histogram(alpha = 0.5, bins = 30, position = "identity") +
  labs(
    title = "Rozkład przewidywanych prawdopodobieństw - zbiór testowy",
    x = "P(dochód powyżej 50 tys. USD)",
    y = "Liczba obserwacji",
    fill = "Klasa"
  ) +
  theme_minimal()

Funkcja do oceny jakości klasyfikacji

W celu uporządkowania oceny jakości modeli przygotowano funkcję, która dla zadanych prawdopodobieństw, wartości rzeczywistych oraz progu odcięcia oblicza najważniejsze metryki klasyfikacji binarnej. Funkcja zwraca między innymi Accuracy, Sensitivity, Specificity, Balanced Accuracy, PPV, NPV, F1 oraz Kappa.

klasyfikacja.metryki <- function(predicted_probabilities, 
                                 real, 
                                 cutoff = 0.5,
                                 level_positive, 
                                 level_negative) {
  
  predicted_class <- ifelse(
    predicted_probabilities >= cutoff,
    level_positive,
    level_negative
  )
  
  predicted_class <- factor(
    predicted_class,
    levels = c(level_negative, level_positive)
  )
  
  real <- factor(
    real,
    levels = c(level_negative, level_positive)
  )
  
  ctable <- confusionMatrix(
    data = predicted_class,
    reference = real,
    positive = as.character(level_positive)
  )
  
  statystyki <- c(
    Accuracy = unname(ctable$overall["Accuracy"]),
    Sensitivity = unname(ctable$byClass["Sensitivity"]),
    Specificity = unname(ctable$byClass["Specificity"]),
    BalancedAccuracy = unname(ctable$byClass["Balanced Accuracy"]),
    PPV = unname(ctable$byClass["Pos Pred Value"]),
    NPV = unname(ctable$byClass["Neg Pred Value"]),
    F1 = unname(ctable$byClass["F1"]),
    Kappa = unname(ctable$overall["Kappa"])
  )
  
  wynik <- round(statystyki, 4)
  
  return(wynik)
}

Ocena jakości modelu na zbiorze treningowym i testowym

Następnie obliczono metryki jakości dla zbioru treningowego i testowego przy wykorzystaniu optymalnego progu odcięcia.

metryki.glm.train <- klasyfikacja.metryki(
  predicted_probabilities = pred.glm.train.prob,
  real = salary.train$klasa_dochodu_01,
  cutoff = optymalny_prog_glm,
  level_positive = "1",
  level_negative = "0"
)

metryki.glm.test <- klasyfikacja.metryki(
  predicted_probabilities = pred.glm.test.prob,
  real = salary.test$klasa_dochodu_01,
  cutoff = optymalny_prog_glm,
  level_positive = "1",
  level_negative = "0"
)

rbind(
  train = metryki.glm.train,
  test = metryki.glm.test
)
##       Accuracy Sensitivity Specificity BalancedAccuracy    PPV    NPV     F1
## train   0.7936      0.8615      0.7711           0.8163 0.5550 0.9438 0.6751
## test    0.7940      0.8686      0.7692           0.8189 0.5552 0.9464 0.6774
##        Kappa
## train 0.5340
## test  0.5366

Porównanie wyników dla zbioru treningowego i testowego pozwala ocenić stabilność modelu. Co ważne niewidoczny jest overfiting, model przyjmuje również całkiem zadowaloające wartości predyckyjne. A co zaskakujące minimalnie lepsze Sensivity przyjmuje nawet dla danych testowych w porównaniu z treningowymi.

Porównanie progu 0.5 i progu optymalnego

Na końcu porównano wyniki modelu dla standardowego progu 0.5 oraz progu optymalnego wyznaczonego na podstawie krzywej ROC.

metryki.glm.test.standard <- klasyfikacja.metryki(
  predicted_probabilities = pred.glm.test.prob,
  real = salary.test$klasa_dochodu_01,
  cutoff = 0.5,
  level_positive = "1",
  level_negative = "0"
)

rbind(
  logit.cutoff.0.5 = metryki.glm.test.standard,
  logit.cutoff.optymalny = metryki.glm.test
)
##                        Accuracy Sensitivity Specificity BalancedAccuracy    PPV
## logit.cutoff.0.5         0.8443      0.5992      0.9255           0.7624 0.7274
## logit.cutoff.optymalny   0.7940      0.8686      0.7692           0.8189 0.5552
##                           NPV     F1  Kappa
## logit.cutoff.0.5       0.8744 0.6571 0.5576
## logit.cutoff.optymalny 0.9464 0.6774 0.5366

Zauważmy, jak zmiana progu z domyślnego 0.5 na wyliczone 0.2246 zmieniła wyniki dla modelu Ogólne Accuracy minimalnie spadło (z około 84,4% na 79,40%), ale za to Sensitivity (czułość) drastycznie wzrosła! Model z nowym progiem poprawnie wykrywa ponad 86% osób zarabiających powyżej 50k (TP: 4527), podczas gdy na progu 0.5 wykrywał ich tylko około 59% (wtedy TP wynosiło 3107). To świetny dowód na to, dlaczego szukanie optymalnego progu odcięcia ma tak ogromne znaczenie na niezbalansowanych zbiorach.

Drzewa klasyfikacyjne

Drugim modelem benchmarkowym wykorzystanym w analizie jest drzewo klasyfikacyjne CART. W przeciwieństwie do modelu logitowego, drzewo decyzyjne nie zakłada liniowej zależności między zmiennymi objaśniającymi a zmienną celu. Model ten dzieli obserwacje na coraz bardziej jednorodne grupy, wykorzystując kolejne reguły decyzyjne. Jego zaletą jest wysoka interpretowalność, ponieważ wyniki można przedstawić w postaci graficznego drzewa.

Przygotowanie danych do modeli drzewiastych

Na potrzeby modeli uczenia maszynowego przygotowano osobny zbiór danych. W przypadku drzew decyzyjnych zmienna celu pozostaje faktorem z poziomami nie oraz tak, ponieważ modele klasyfikacyjne w pakiecie rpart oraz caret wymagają, aby zmienna objaśniana była zmienną kategoryczną.

salary_ML <- salary_pl_3 %>% 
  dplyr::select(-c(kraj_pochodzenia_grupa, poziom_edukacji_grupa, klasa_dochodu))

Ponadto zmienne dla edukacji postanowiono odgrupować, aby modele ML miały cały zakres zmiennych do dyspozycji

salary_ML <- salary_ML %>%
  mutate(
    poziom_edukacji_factor = case_when(
      poziom_edukacji == 1  ~ "przedszkole",
      poziom_edukacji == 2  ~ "klasy_1_4",
      poziom_edukacji == 3  ~ "klasy_5_6",
      poziom_edukacji == 4  ~ "klasy_7_8",
      poziom_edukacji == 5  ~ "klasa_9",
      poziom_edukacji == 6  ~ "klasa_10",
      poziom_edukacji == 7  ~ "klasa_11",
      poziom_edukacji == 8  ~ "klasa_12",
      poziom_edukacji == 9  ~ "szkoła_średnia",
      poziom_edukacji == 10 ~ "część_studiów",
      poziom_edukacji == 11 ~ "associate_zawodowe",
      poziom_edukacji == 12 ~ "associate_akademickie",
      poziom_edukacji == 13 ~ "licencjat",
      poziom_edukacji == 14 ~ "magister",
      poziom_edukacji == 15 ~ "szkoła_profesjonalna",
      poziom_edukacji == 16 ~ "doktorat",
      TRUE ~ NA_character_
    ),
    
    poziom_edukacji_factor = factor(
      poziom_edukacji_factor,
      levels = c(
        "przedszkole",
        "klasy_1_4",
        "klasy_5_6",
        "klasy_7_8",
        "klasa_9",
        "klasa_10",
        "klasa_11",
        "klasa_12",
        "szkoła_średnia",
        "część_studiów",
        "associate_zawodowe",
        "associate_akademickie",
        "licencjat",
        "magister",
        "szkoła_profesjonalna",
        "doktorat"
      ),
      ordered = FALSE
    )
  )

salary_ML <- salary_ML %>% 
  dplyr::select(-poziom_edukacji)

Nie potraktowano poziomu edukacji jako zmiennej porządkowej, ponieważ niektóre poziomy nie muszą tworzyć jednoznacznej hierarchii jakościowej. Przykładowo poziom szkoła_profesjonalna nie musi być wprost porównywalny z poziomem licencjat, dlatego zmienną potraktowano jako zwykły factor.

Końcowe przygotowanie danych dla ML

salary_ML$klasa_dochodu_01 <- factor(
  salary_ML$klasa_dochodu_01,
  levels = c("nie", "tak")
)

set.seed(123)

trening_ML <- createDataPartition(
  salary_ML$klasa_dochodu_01,
  p = 0.7,
  list = FALSE
) 

salary.train.ML <- salary_ML[c(trening_ML), ]
salary.test.ML <- salary_ML[-c(trening_ML), ]

Drzewo domyślne

W pierwszym kroku oszacowano podstawowe drzewo klasyfikacyjne z domyślnymi ustawieniami funkcji rpart(). Model ten stanowi punkt odniesienia dla późniejszego strojenia hiperparametrów.

tree.salary.default <- rpart(
  klasa_dochodu_01 ~ .,
  data = salary.train.ML,
  method = "class"
)

tree.salary.default
## n= 21114 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
##  1) root 21114 5256 nie (0.75106564 0.24893436)  
##    2) relacja_w_rodzinie=inny_krewny,niezamężny_niezonaty,poza_rodziną,własne_dziecko 11367  766 nie (0.93261195 0.06738805)  
##      4) zysk_kapitalowy< 7073.5 11166  570 nie (0.94895218 0.05104782) *
##      5) zysk_kapitalowy>=7073.5 201    5 tak (0.02487562 0.97512438) *
##    3) relacja_w_rodzinie=mąż,żona 9747 4490 nie (0.53934544 0.46065456)  
##      6) poziom_edukacji_factor=przedszkole,klasy_1_4,klasy_5_6,klasy_7_8,klasa_9,klasa_10,klasa_11,klasa_12,szkoła_średnia,część_studiów,associate_zawodowe,associate_akademickie 6856 2355 nie (0.65650525 0.34349475)  
##       12) zysk_kapitalowy< 5095.5 6500 2006 nie (0.69138462 0.30861538) *
##       13) zysk_kapitalowy>=5095.5 356    7 tak (0.01966292 0.98033708) *
##      7) poziom_edukacji_factor=licencjat,magister,szkoła_profesjonalna,doktorat 2891  756 tak (0.26150121 0.73849879) *
summary(tree.salary.default)
## Call:
## rpart(formula = klasa_dochodu_01 ~ ., data = salary.train.ML, 
##     method = "class")
##   n= 21114 
## 
##           CP nsplit rel error    xerror       xstd
## 1 0.13118341      0 1.0000000 1.0000000 0.01195395
## 2 0.06506849      2 0.7376332 0.7376332 0.01070380
## 3 0.03633942      3 0.6725647 0.6725647 0.01032170
## 4 0.01000000      4 0.6362253 0.6362253 0.01009337
## 
## Variable importance
##        relacja_w_rodzinie              stan_cywilny             kapital_netto 
##                        22                        22                         9 
##           zysk_kapitalowy    poziom_edukacji_factor                      plec 
##                         9                         9                         7 
##                     zawod     godziny_razy_edukacja                      wiek 
##                         6                         5                         5 
##               wiek_sredni czy_kapital_netto_dodatni 
##                         4                         1 
## 
## Node number 1: 21114 observations,    complexity param=0.1311834
##   predicted class=nie  expected loss=0.2489344  P(node) =1
##     class counts: 15858  5256
##    probabilities: 0.751 0.249 
##   left son=2 (11367 obs) right son=3 (9747 obs)
##   Primary splits:
##       relacja_w_rodzinie    splits as  LRLLLR,     improve=1623.1180, (0 missing)
##       stan_cywilny          splits as  RRLLLLL,    improve=1597.2620, (0 missing)
##       zysk_kapitalowy       < 5119   to the left,  improve=1081.9280, (0 missing)
##       kapital_netto         < 5119   to the left,  improve=1081.9280, (0 missing)
##       godziny_razy_edukacja < 488    to the left,  improve= 933.1938, (0 missing)
##   Surrogate splits:
##       stan_cywilny splits as  RRLLLLL, agree=0.993, adj=0.986, (0 split)
##       plec         splits as  LR, agree=0.690, adj=0.327, (0 split)
##       wiek         < 33.5   to the left,  agree=0.645, adj=0.231, (0 split)
##       zawod        splits as  LLRLLLRRLRLRRL, agree=0.619, adj=0.174, (0 split)
##       wiek_sredni  splits as  LR, agree=0.615, adj=0.166, (0 split)
## 
## Node number 2: 11367 observations,    complexity param=0.03633942
##   predicted class=nie  expected loss=0.06738805  P(node) =0.5383632
##     class counts: 10601   766
##    probabilities: 0.933 0.067 
##   left son=4 (11166 obs) right son=5 (201 obs)
##   Primary splits:
##       zysk_kapitalowy           < 7073.5 to the left,  improve=337.20480, (0 missing)
##       kapital_netto             < 7073.5 to the left,  improve=337.20480, (0 missing)
##       godziny_razy_edukacja     < 559.5  to the left,  improve=128.56960, (0 missing)
##       czy_kapital_netto_dodatni splits as  RL, improve=106.51450, (0 missing)
##       poziom_edukacji_factor    splits as  LLLLLLLLLLLLLRRR, improve= 95.72629, (0 missing)
##   Surrogate splits:
##       kapital_netto < 7073.5 to the left,  agree=1, adj=1, (0 split)
## 
## Node number 3: 9747 observations,    complexity param=0.1311834
##   predicted class=nie  expected loss=0.4606546  P(node) =0.4616368
##     class counts:  5257  4490
##    probabilities: 0.539 0.461 
##   left son=6 (6856 obs) right son=7 (2891 obs)
##   Primary splits:
##       poziom_edukacji_factor splits as  LLLLLLLLLLLLRRRR, improve=634.5721, (0 missing)
##       zawod                  splits as  RLRLLLLLRRRLRR, improve=600.1650, (0 missing)
##       godziny_razy_edukacja  < 482.5  to the left,  improve=559.9230, (0 missing)
##       zysk_kapitalowy        < 5095.5 to the left,  improve=492.1049, (0 missing)
##       kapital_netto          < 5095.5 to the left,  improve=492.1049, (0 missing)
##   Surrogate splits:
##       godziny_razy_edukacja < 517    to the left,  agree=0.880, adj=0.597, (0 split)
##       zawod                 splits as  LLRLLLLLRRLLLL, agree=0.792, adj=0.300, (0 split)
##       zysk_kapitalowy       < 7493   to the left,  agree=0.720, adj=0.055, (0 split)
##       kapital_netto         < 7493   to the left,  agree=0.720, adj=0.055, (0 split)
##       kraj_pochodzenia      splits as  RRLLRLLLL-RLRRLLRLLLLRLLLLLLLLLLLLLRLLLLL, agree=0.713, adj=0.031, (0 split)
## 
## Node number 4: 11166 observations
##   predicted class=nie  expected loss=0.05104782  P(node) =0.5288434
##     class counts: 10596   570
##    probabilities: 0.949 0.051 
## 
## Node number 5: 201 observations
##   predicted class=tak  expected loss=0.02487562  P(node) =0.00951975
##     class counts:     5   196
##    probabilities: 0.025 0.975 
## 
## Node number 6: 6856 observations,    complexity param=0.06506849
##   predicted class=nie  expected loss=0.3434947  P(node) =0.3247135
##     class counts:  4501  2355
##    probabilities: 0.657 0.343 
##   left son=12 (6500 obs) right son=13 (356 obs)
##   Primary splits:
##       zysk_kapitalowy        < 5095.5 to the left,  improve=304.5799, (0 missing)
##       kapital_netto          < 5095.5 to the left,  improve=304.5799, (0 missing)
##       zawod                  splits as  RLRLLLLL-RRLRR, improve=174.9527, (0 missing)
##       godziny_razy_edukacja  < 347    to the left,  improve=143.5213, (0 missing)
##       poziom_edukacji_factor splits as  LLLLLLLRRRRR----, improve=121.2879, (0 missing)
##   Surrogate splits:
##       kapital_netto             < 5095.5 to the left,  agree=1.000, adj=1.000, (0 split)
##       czy_kapital_netto_dodatni splits as  RL,         agree=0.956, adj=0.143, (0 split)
## 
## Node number 7: 2891 observations
##   predicted class=tak  expected loss=0.2615012  P(node) =0.1369234
##     class counts:   756  2135
##    probabilities: 0.262 0.738 
## 
## Node number 12: 6500 observations
##   predicted class=nie  expected loss=0.3086154  P(node) =0.3078526
##     class counts:  4494  2006
##    probabilities: 0.691 0.309 
## 
## Node number 13: 356 observations
##   predicted class=tak  expected loss=0.01966292  P(node) =0.01686085
##     class counts:     7   349
##    probabilities: 0.020 0.980

Poniżej przedstawiono podstawową wizualizację drzewa decyzyjnego.

rpart.plot(tree.salary.default)

Dla większej czytelności przygotowano również bardziej rozbudowaną wizualizację drzewa, pokazującą dodatkowe informacje w węzłach końcowych.

rpart.plot(
  tree.salary.default,
  type = 2,
  extra = 104,
  fallen.leaves = TRUE,
  nn = TRUE,
  box.palette = "Blues",
  branch.lty = 2,
  shadow.col = "gray"
)

Predykcja dla drzewa domyślnego

Następnie wyznaczono przewidywane prawdopodobieństwa przynależności do klasy tak, osobno dla zbioru treningowego i testowego.

pred.tree.default.train.prob <- predict(
  tree.salary.default,
  newdata = salary.train.ML,
  type = "prob"
)[, "tak"]

pred.tree.default.test.prob <- predict(
  tree.salary.default,
  newdata = salary.test.ML,
  type = "prob"
)[, "tak"]

Ocena drzewa domyślnego została przeprowadzona przy domyślnym progu odcięcia równym 0.5.

metryki.tree.default.train <- klasyfikacja.metryki(
  predicted_probabilities = pred.tree.default.train.prob,
  real = salary.train.ML$klasa_dochodu_01,
  cutoff = 0.5,
  level_positive = "tak",
  level_negative = "nie"
)

metryki.tree.default.test <- klasyfikacja.metryki(
  predicted_probabilities = pred.tree.default.test.prob,
  real = salary.test.ML$klasa_dochodu_01,
  cutoff = 0.5,
  level_positive = "tak",
  level_negative = "nie"
)

rbind(
  train = metryki.tree.default.train,
  test = metryki.tree.default.test
)
##       Accuracy Sensitivity Specificity BalancedAccuracy    PPV    NPV     F1
## train   0.8416      0.5099      0.9516           0.7307 0.7773 0.8542 0.6158
## test    0.8397      0.5115      0.9485           0.7300 0.7670 0.8542 0.6137
##        Kappa
## train 0.5214
## test  0.5177

Wyniki drzewa domyślnego są bardzo podobne do wyników modelu logitowego przy standardowym progu odcięcia 0.5. Oznacza to, że proste drzewo jest w stanie uchwycić podstawowe zależności w danych, jednak podobnie jak logit przy progu 0.5 może mieć problem z odpowiednio dobrym wykrywaniem klasy mniejszościowej.

Optymalny próg odcięcia dla drzewa domyślnego

Podobnie jak w przypadku modelu logitowego, dla drzewa decyzyjnego wyznaczono optymalny próg odcięcia na podstawie krzywej ROC i metody Youdena.

krzywa_roc_tree <- pROC::roc(
  response = salary.train.ML$klasa_dochodu_01,
  predictor = pred.tree.default.train.prob
)

optymalny_prog_tree <- pROC::coords(
  krzywa_roc_tree,
  x = "best",
  best.method = "youden",
  ret = c("threshold", "specificity", "sensitivity")
)

prog_tree <- optymalny_prog_tree$threshold[1]

print(paste("Optymalny próg dla drzewa to:", round(prog_tree, 4)))
## [1] "Optymalny próg dla drzewa to: 0.1798"

Optymalny próg dla drzewa domyślnego wyniósł 0,1798. Jest to wartość niższa od standardowego progu 0.5, co wynika z niezbalansowania klas. Obniżenie progu zwiększa skłonność modelu do klasyfikowania obserwacji jako tak, a więc poprawia wykrywanie osób zarabiających powyżej 50 tys. USD.

metryki.tree.optimal.train <- klasyfikacja.metryki(
  predicted_probabilities = pred.tree.default.train.prob,
  real = salary.train.ML$klasa_dochodu_01,
  cutoff = prog_tree,
  level_positive = "tak",
  level_negative = "nie"
)

metryki.tree.optimal.test <- klasyfikacja.metryki(
  predicted_probabilities = pred.tree.default.test.prob,
  real = salary.test.ML$klasa_dochodu_01,
  cutoff = prog_tree,
  level_positive = "tak",
  level_negative = "nie"
)

rbind(
  train = metryki.tree.optimal.train,
  test = metryki.tree.optimal.test
)
##       Accuracy Sensitivity Specificity BalancedAccuracy    PPV   NPV     F1
## train   0.7238      0.8916      0.6682           0.7799 0.4710 0.949 0.6164
## test    0.7216      0.8779      0.6698           0.7738 0.4684 0.943 0.6108
##        Kappa
## train 0.4311
## test  0.4238

Porównanie prostego drzewa z modelem logitowym

Przeprowadzona analiza klasyfikacyjna na zbiorze danych Adult Census Income pozwoliła na porównanie dwóch odmiennych podejść algorytmicznych: tradycyjnego modelu regresji logistycznej oraz nieparametrycznego drzewa decyzyjnego CART. Ze względu na silne niezbalansowanie klas, w proporcji około 75% do 25%, kluczowym elementem podnoszącym jakość obu modeli była rezygnacja z domyślnego progu klasyfikacji 0.5 na rzecz progów optymalizowanych metodą Youdena. Progi te wyniosły około 0,2346 dla logitu oraz około 0,18 dla drzewa.

W ostatecznym porównaniu na zbiorze testowym model regresji logistycznej okazał się silniejszy ze względu na zbalansowaną dokładność. Model logitowy osiągnął Balanced Accuracy około 81,9%, podczas gdy proste drzewo uzyskało około 72,16%.

Jednocześnie drzewo decyzyjne osiągnęło bardzo wysoką czułość. Czułość modelu logitowego wyniosła około 86,8%, natomiast dla drzewa było to około 87,79%. Oznacza to, że jeżeli najważniejszym celem byłoby wykrywanie klasy pozytywnej, czyli osób zarabiających powyżej 50 tys. USD, drzewo może być atrakcyjnym rozwiązaniem.

Warto jednak podkreślić brak wyraźnego zjawiska przeuczenia w obu modelach, ponieważ metryki na zbiorach treningowych i testowych są do siebie zbliżone. Oznacza to, że modele zachowują zdolność generalizacji na nowych danych. Logit osiąga lepszą skuteczność predykcyjną w sensie zbalansowanej dokładności, natomiast drzewo decyzyjne oferuje większą interpretowalność i możliwość prześledzenia ścieżki decyzyjnej.

Drzewo klasyfikacyjne z kontrolą parametrów

Domyślne drzewo okazało się dość proste, dlatego w kolejnym kroku zbudowano model z kontrolą parametrów wzrostu drzewa. Celem było sprawdzenie, czy odpowiednie strojenie hiperparametrów pozwoli poprawić jakość klasyfikacji.

W modelach drzewiastych szczególne znaczenie mają następujące parametry:

  • cp, czyli complexity parameter, określający minimalną poprawę jakości modelu wymaganą do dodania kolejnego podziału,
  • minsplit, czyli minimalna liczba obserwacji w węźle potrzebna do rozważenia dalszego podziału,
  • minbucket, czyli minimalna liczba obserwacji w liściu końcowym,
  • maxdepth, czyli maksymalna głębokość drzewa,
  • xval, czyli liczba części wykorzystywanych w wewnętrznej walidacji krzyżowej.

W niniejszej pracy zastosowano kontrolę złożoności drzewa. Parametr cp dobrano automatycznie przy użyciu walidacji krzyżowej.

salary.train.caret <- salary.train.ML
salary.test.caret <- salary.test.ML

salary.train.caret$klasa_dochodu_01 <- factor(
  salary.train.caret$klasa_dochodu_01,
  labels = c("nie", "tak")
)

salary.test.caret$klasa_dochodu_01 <- factor(
  salary.test.caret$klasa_dochodu_01,
  labels = c("nie", "tak")
)

ctrl <- trainControl(
  method = "cv",
  number = 5,
  classProbs = TRUE,
  summaryFunction = twoClassSummary,
  savePredictions = "final"
)

grid.cp <- expand.grid(
  cp = seq(0.0000, 0.001, length = 50)
)

Strojenie drzewa decyzyjnego

Następnie przeprowadzono strojenie drzewa decyzyjnego. Jako główną metrykę optymalizacji przyjęto ROC, ponieważ w przypadku niezbalansowanych danych jest ona bardziej informacyjna niż sama ogólna trafność klasyfikacji.

set.seed(123)

tree.salary.tuned <- train(
  klasa_dochodu_01 ~ .,
  data = salary.train.caret,
  method = "rpart",
  trControl = ctrl,
  tuneGrid = grid.cp,
  metric = "ROC",
  control = rpart.control(
    minsplit = 30,
    minbucket = 10,
    maxdepth = 10
  )
)

tree.salary.tuned$bestTune
##              cp
## 3 0.00004081633
plot(tree.salary.tuned)

W celu maksymalizacji skuteczności drzewa decyzyjnego przeprowadzono proces strojenia hiperparametrów metodą walidacji krzyżowej. Szukano optymalnej wartości parametru złożoności cp, który stanowi karę za nadmierny rozrost drzewa. Im wyższy parametr cp, tym silniej drzewo jest przycinane.

Z analizy wynika, że najwyższą skuteczność rozróżniania klas, mierzoną wartością ROC na poziomie około 0,89, osiągnięto dla bardzo niskiej wartości parametru: cp = 0.00004081633. Zwiększanie tej wartości, czyli silniejsze obcinanie gałęzi, prowadziło do widocznego spadku wydajności modelu. Oznacza to, że struktura danych jest dość złożona i wymaga zbudowania bardziej rozbudowanego drzewa, aby precyzyjnie uchwycić wzorce decydujące o wysokich zarobkach.

Wizualizacja dostrojonego drzewa

Jednak wizualizacja w tym przypadku nie przynosi tak naprawdę żadnych rezultatów. Obraz jest po prostu niewidoczny.

tree.salary.tuned$finalModel
## n= 21114 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
##    1) root 21114 5256 nie (0.751065644 0.248934356)  
##      2) relacja_w_rodziniemąż< 0.5 12368 1269 nie (0.897396507 0.102603493)  
##        4) zysk_kapitalowy< 7073.5 12087  994 nie (0.917762886 0.082237114)  
##          8) relacja_w_rodzinieżona< 0.5 11166  570 nie (0.948952176 0.051047824)  
##           16) godziny_razy_edukacja< 558.5 9823  285 nie (0.970986460 0.029013540)  
##             32) strata_kapitalowa< 2218.5 9779  263 nie (0.973105635 0.026894365) *
##             33) strata_kapitalowa>=2218.5 44   22 nie (0.500000000 0.500000000)  
##               66) godziny_razy_edukacja< 370 21    6 nie (0.714285714 0.285714286) *
##               67) godziny_razy_edukacja>=370 23    7 tak (0.304347826 0.695652174) *
##           17) godziny_razy_edukacja>=558.5 1343  285 nie (0.787788533 0.212211467)  
##             34) wiek< 27.5 262   10 nie (0.961832061 0.038167939) *
##             35) wiek>=27.5 1081  275 nie (0.745605920 0.254394080)  
##               70) strata_kapitalowa< 2351 1068  262 nie (0.754681648 0.245318352)  
##                140) zawodkadra_kierownicza< 0.5 835  170 nie (0.796407186 0.203592814)  
##                  280) poziom_edukacji_factorszkoła_profesjonalna< 0.5 766  140 nie (0.817232376 0.182767624) *
##                  281) poziom_edukacji_factorszkoła_profesjonalna>=0.5 69   30 nie (0.565217391 0.434782609)  
##                    562) wiek< 32.5 20    3 nie (0.850000000 0.150000000) *
##                    563) wiek>=32.5 49   22 tak (0.448979592 0.551020408)  
##                     1126) godziny_pracy_tygodniowo< 44 23    8 nie (0.652173913 0.347826087) *
##                     1127) godziny_pracy_tygodniowo>=44 26    7 tak (0.269230769 0.730769231) *
##                141) zawodkadra_kierownicza>=0.5 233   92 nie (0.605150215 0.394849785)  
##                  282) poziom_edukacji_factormagister< 0.5 167   57 nie (0.658682635 0.341317365)  
##                    564) wiek< 39.5 87   22 nie (0.747126437 0.252873563) *
##                    565) wiek>=39.5 80   35 nie (0.562500000 0.437500000)  
##                     1130) godziny_pracy_tygodniowo>=59 31    6 nie (0.806451613 0.193548387) *
##                     1131) godziny_pracy_tygodniowo< 59 49   20 tak (0.408163265 0.591836735) *
##                  283) poziom_edukacji_factormagister>=0.5 66   31 tak (0.469696970 0.530303030)  
##                    566) typ_zatrudnieniasektor_prywatny< 0.5 26    9 nie (0.653846154 0.346153846) *
##                    567) typ_zatrudnieniasektor_prywatny>=0.5 40   14 tak (0.350000000 0.650000000)  
##                     1134) wiek< 36 12    4 nie (0.666666667 0.333333333) *
##                     1135) wiek>=36 28    6 tak (0.214285714 0.785714286) *
##               71) strata_kapitalowa>=2351 13    0 tak (0.000000000 1.000000000) *
##          9) relacja_w_rodzinieżona>=0.5 921  424 nie (0.539630836 0.460369164)  
##           18) godziny_razy_edukacja< 374 455  151 nie (0.668131868 0.331868132)  
##             36) strata_kapitalowa< 1794 438  136 nie (0.689497717 0.310502283)  
##               72) wiek< 34.5 139   23 nie (0.834532374 0.165467626) *
##               73) wiek>=34.5 299  113 nie (0.622073579 0.377926421)  
##                146) zawodinne_usługi>=0.5 61    9 nie (0.852459016 0.147540984) *
##                147) zawodinne_usługi< 0.5 238  104 nie (0.563025210 0.436974790)  
##                  294) wiek>=45.5 126   41 nie (0.674603175 0.325396825)  
##                    588) zawodkadra_kierownicza< 0.5 110   30 nie (0.727272727 0.272727273) *
##                    589) zawodkadra_kierownicza>=0.5 16    5 tak (0.312500000 0.687500000) *
##                  295) wiek< 45.5 112   49 tak (0.437500000 0.562500000)  
##                    590) zawodoperatorzy_maszyn_i_inspektorzy>=0.5 13    4 nie (0.692307692 0.307692308) *
##                    591) zawodoperatorzy_maszyn_i_inspektorzy< 0.5 99   40 tak (0.404040404 0.595959596)  
##                     1182) zawodspecjaliści>=0.5 21    9 nie (0.571428571 0.428571429) *
##                     1183) zawodspecjaliści< 0.5 78   28 tak (0.358974359 0.641025641) *
##             37) strata_kapitalowa>=1794 17    2 tak (0.117647059 0.882352941) *
##           19) godziny_razy_edukacja>=374 466  193 tak (0.414163090 0.585836910)  
##             38) zawodspecjaliści< 0.5 331  161 tak (0.486404834 0.513595166)  
##               76) zawodkadra_kierownicza< 0.5 222   94 nie (0.576576577 0.423423423)  
##                152) godziny_pracy_tygodniowo>=43 64   14 nie (0.781250000 0.218750000) *
##                153) godziny_pracy_tygodniowo< 43 158   78 tak (0.493670886 0.506329114)  
##                  306) wiek< 33.5 61   22 nie (0.639344262 0.360655738)  
##                    612) typ_zatrudnieniasektor_prywatny>=0.5 51   16 nie (0.686274510 0.313725490)  
##                     1224) godziny_razy_edukacja< 500 40   10 nie (0.750000000 0.250000000) *
##                     1225) godziny_razy_edukacja>=500 11    5 tak (0.454545455 0.545454545) *
##                    613) typ_zatrudnieniasektor_prywatny< 0.5 10    4 tak (0.400000000 0.600000000) *
##                  307) wiek>=33.5 97   39 tak (0.402061856 0.597938144)  
##                    614) rasabiała< 0.5 14    6 nie (0.571428571 0.428571429) *
##                    615) rasabiała>=0.5 83   31 tak (0.373493976 0.626506024) *
##               77) zawodkadra_kierownicza>=0.5 109   33 tak (0.302752294 0.697247706)  
##                154) wiek_sredniw_grupie_wieku_średniego< 0.5 30   13 nie (0.566666667 0.433333333)  
##                  308) poziom_edukacji_factorlicencjat< 0.5 14    4 nie (0.714285714 0.285714286) *
##                  309) poziom_edukacji_factorlicencjat>=0.5 16    7 tak (0.437500000 0.562500000) *
##                155) wiek_sredniw_grupie_wieku_średniego>=0.5 79   16 tak (0.202531646 0.797468354) *
##             39) zawodspecjaliści>=0.5 135   32 tak (0.237037037 0.762962963) *
##        5) zysk_kapitalowy>=7073.5 281    6 tak (0.021352313 0.978647687) *
##      3) relacja_w_rodziniemąż>=0.5 8746 3987 nie (0.544134461 0.455865539)  
##        6) godziny_razy_edukacja< 488 5152 1587 nie (0.691964286 0.308035714)  
##         12) zysk_kapitalowy< 5095.5 4916 1356 nie (0.724165989 0.275834011)  
##           24) godziny_razy_edukacja< 307 1134  122 nie (0.892416226 0.107583774)  
##             48) strata_kapitalowa< 1780 1104  107 nie (0.903079710 0.096920290) *
##             49) strata_kapitalowa>=1780 30   15 nie (0.500000000 0.500000000)  
##               98) kapital_netto< -1989.5 17    4 nie (0.764705882 0.235294118) *
##               99) kapital_netto>=-1989.5 13    2 tak (0.153846154 0.846153846) *
##           25) godziny_razy_edukacja>=307 3782 1234 nie (0.673717610 0.326282390)  
##             50) wiek< 35.5 1234  237 nie (0.807941653 0.192058347)  
##              100) wiek< 26.5 272   21 nie (0.922794118 0.077205882) *
##              101) wiek>=26.5 962  216 nie (0.775467775 0.224532225)  
##                202) strata_kapitalowa< 1794 935  199 nie (0.787165775 0.212834225)  
##                  404) poziom_edukacji_factorszkoła_średnia>=0.5 559   93 nie (0.833631485 0.166368515) *
##                  405) poziom_edukacji_factorszkoła_średnia< 0.5 376  106 nie (0.718085106 0.281914894)  
##                    810) zawodspecjaliści< 0.5 353   93 nie (0.736543909 0.263456091) *
##                    811) zawodspecjaliści>=0.5 23   10 tak (0.434782609 0.565217391) *
##                203) strata_kapitalowa>=1794 27   10 tak (0.370370370 0.629629630) *
##             51) wiek>=35.5 2548  997 nie (0.608712716 0.391287284)  
##              102) strata_kapitalowa< 1794 2435  911 nie (0.625872690 0.374127310)  
##                204) godziny_razy_edukacja< 376.5 1117  345 nie (0.691136974 0.308863026)  
##                  408) zawodkadra_kierownicza< 0.5 1017  297 nie (0.707964602 0.292035398)  
##                    816) wiek< 45.5 415  100 nie (0.759036145 0.240963855) *
##                    817) wiek>=45.5 602  197 nie (0.672757475 0.327242525)  
##                     1634) zawodspecjaliści< 0.5 579  183 nie (0.683937824 0.316062176) *
##                     1635) zawodspecjaliści>=0.5 23    9 tak (0.391304348 0.608695652) *
##                  409) zawodkadra_kierownicza>=0.5 100   48 nie (0.520000000 0.480000000)  
##                    818) typ_zatrudnieniasamozatrudniony_bez_osobowości_prawnej>=0.5 19    3 nie (0.842105263 0.157894737) *
##                    819) typ_zatrudnieniasamozatrudniony_bez_osobowości_prawnej< 0.5 81   36 tak (0.444444444 0.555555556)  
##                     1638) wiek< 59.5 65   32 nie (0.507692308 0.492307692) *
##                     1639) wiek>=59.5 16    3 tak (0.187500000 0.812500000) *
##                205) godziny_razy_edukacja>=376.5 1318  566 nie (0.570561457 0.429438543) *
##              103) strata_kapitalowa>=1794 113   27 tak (0.238938053 0.761061947)  
##                206) strata_kapitalowa>=1989.5 34   10 nie (0.705882353 0.294117647) *
##                207) strata_kapitalowa< 1989.5 79    3 tak (0.037974684 0.962025316) *
##         13) zysk_kapitalowy>=5095.5 236    5 tak (0.021186441 0.978813559) *
##        7) godziny_razy_edukacja>=488 3594 1194 tak (0.332220367 0.667779633)  
##         14) zysk_kapitalowy< 5095.5 3115 1193 tak (0.382985554 0.617014446)  
##           28) kapital_netto>=-1846 2817 1177 tak (0.417820376 0.582179624)  
##             56) poziom_edukacji_factorszkoła_średnia>=0.5 357  127 nie (0.644257703 0.355742297)  
##              112) wiek< 31.5 66   11 nie (0.833333333 0.166666667) *
##              113) wiek>=31.5 291  116 nie (0.601374570 0.398625430)  
##                226) zysk_kapitalowy>=3120 11    0 nie (1.000000000 0.000000000) *
##                227) zysk_kapitalowy< 3120 280  116 nie (0.585714286 0.414285714)  
##                  454) zawodrolnictwo_i_rybołówstwo>=0.5 40    9 nie (0.775000000 0.225000000) *
##                  455) zawodrolnictwo_i_rybołówstwo< 0.5 240  107 nie (0.554166667 0.445833333)  
##                    910) zawodsprzedaż< 0.5 190   78 nie (0.589473684 0.410526316) *
##                    911) zawodsprzedaż>=0.5 50   21 tak (0.420000000 0.580000000)  
##                     1822) wiek< 45.5 24   11 nie (0.541666667 0.458333333) *
##                     1823) wiek>=45.5 26    8 tak (0.307692308 0.692307692) *
##             57) poziom_edukacji_factorszkoła_średnia< 0.5 2460  947 tak (0.384959350 0.615040650)  
##              114) wiek< 33.5 507  229 nie (0.548323471 0.451676529)  
##                228) wiek< 29.5 209   74 nie (0.645933014 0.354066986) *
##                229) wiek>=29.5 298  143 tak (0.479865772 0.520134228)  
##                  458) godziny_razy_edukacja< 862.5 278  137 nie (0.507194245 0.492805755)  
##                    916) zawodkadra_kierownicza< 0.5 218   94 nie (0.568807339 0.431192661)  
##                     1832) godziny_pracy_tygodniowo>=45.5 113   38 nie (0.663716814 0.336283186) *
##                     1833) godziny_pracy_tygodniowo< 45.5 105   49 tak (0.466666667 0.533333333) *
##                    917) zawodkadra_kierownicza>=0.5 60   17 tak (0.283333333 0.716666667) *
##                  459) godziny_razy_edukacja>=862.5 20    2 tak (0.100000000 0.900000000) *
##              115) wiek>=33.5 1953  669 tak (0.342549923 0.657450077)  
##                230) typ_zatrudnieniasamozatrudniony_bez_osobowości_prawnej>=0.5 263  117 nie (0.555133080 0.444866920)  
##                  460) poziom_edukacji_factorszkoła_profesjonalna< 0.5 228   88 nie (0.614035088 0.385964912)  
##                    920) zawodrolnictwo_i_rybołówstwo>=0.5 45    7 nie (0.844444444 0.155555556) *
##                    921) zawodrolnictwo_i_rybołówstwo< 0.5 183   81 nie (0.557377049 0.442622951)  
##                     1842) godziny_pracy_tygodniowo>=49 126   48 nie (0.619047619 0.380952381) *
##                     1843) godziny_pracy_tygodniowo< 49 57   24 tak (0.421052632 0.578947368) *
##                  461) poziom_edukacji_factorszkoła_profesjonalna>=0.5 35    6 tak (0.171428571 0.828571429) *
##                231) typ_zatrudnieniasamozatrudniony_bez_osobowości_prawnej< 0.5 1690  523 tak (0.309467456 0.690532544) *
##           29) kapital_netto< -1846 298   16 tak (0.053691275 0.946308725)  
##             58) strata_kapitalowa>=1989.5 51   15 tak (0.294117647 0.705882353)  
##              116) strata_kapitalowa< 2384.5 20    5 nie (0.750000000 0.250000000) *
##              117) strata_kapitalowa>=2384.5 31    0 tak (0.000000000 1.000000000) *
##             59) strata_kapitalowa< 1989.5 247    1 tak (0.004048583 0.995951417) *
##         15) zysk_kapitalowy>=5095.5 479    1 tak (0.002087683 0.997912317) *
rpart.plot(
  tree.salary.tuned$finalModel,
  type = 2,
  extra = 104,
  fallen.leaves = TRUE,
  nn = FALSE,
  box.palette = "Blues",
  shadow.col = 0,
  cex = 0.8
)

Predykcja dla dostrojonego drzewa

pred.tree.tuned.train.prob <- predict(
  tree.salary.tuned,
  newdata = salary.train.caret,
  type = "prob"
)[, "tak"]

pred.tree.tuned.test.prob <- predict(
  tree.salary.tuned,
  newdata = salary.test.caret,
  type = "prob"
)[, "tak"]

Ocena dostrojonego drzewa przy progu 0.5

metryki.tree.tuned.train <- klasyfikacja.metryki(
  predicted_probabilities = pred.tree.tuned.train.prob,
  real = salary.train.caret$klasa_dochodu_01,
  cutoff = 0.5,
  level_positive = "tak",
  level_negative = "nie"
)

metryki.tree.tuned.test <- klasyfikacja.metryki(
  predicted_probabilities = pred.tree.tuned.test.prob,
  real = salary.test.caret$klasa_dochodu_01,
  cutoff = 0.5,
  level_positive = "tak",
  level_negative = "nie"
)

rbind(
  train = metryki.tree.tuned.train,
  test = metryki.tree.tuned.test
)
##       Accuracy Sensitivity Specificity BalancedAccuracy    PPV    NPV     F1
## train   0.8616      0.6056      0.9465           0.7760 0.7894 0.8786 0.6854
## test    0.8457      0.5799      0.9338           0.7569 0.7437 0.8703 0.6517
##        Kappa
## train 0.5987
## test  0.5545

Optymalny cutoff dla dostrojonego drzewa

Dla dostrojonego drzewa ponownie wyznaczono optymalny próg klasyfikacji. Tym razem zastosowano funkcję cutpointr, wykorzystując kryterium roc01.

cut.tree.tuned <- cutpointr(
  data = data.frame(
    pred = pred.tree.tuned.train.prob,
    real = salary.train.caret$klasa_dochodu_01
  ),
  x = pred,
  class = real,
  pos_class = "tak",
  neg_class = "nie",
  direction = ">=",
  method = minimize_metric,
  metric = roc01
)

cutoff.tree.tuned <- cut.tree.tuned$optimal_cutpoint
cutoff.tree.tuned
## [1] 0.2634561

Optymalny cutoff dla dostrojonego drzewa wyniósł około 0,26.

metryki.tree.tuned.train.opt <- klasyfikacja.metryki(
  predicted_probabilities = pred.tree.tuned.train.prob,
  real = salary.train.caret$klasa_dochodu_01,
  cutoff = cutoff.tree.tuned,
  level_positive = "tak",
  level_negative = "nie"
)

metryki.tree.tuned.test.opt <- klasyfikacja.metryki(
  predicted_probabilities = pred.tree.tuned.test.prob,
  real = salary.test.caret$klasa_dochodu_01,
  cutoff = cutoff.tree.tuned,
  level_positive = "tak",
  level_negative = "nie"
)

rbind(
  train = metryki.tree.tuned.train.opt,
  test = metryki.tree.tuned.test.opt
)
##       Accuracy Sensitivity Specificity BalancedAccuracy    PPV    NPV     F1
## train   0.8219      0.8364      0.8171           0.8267 0.6024 0.9378 0.7004
## test    0.8147      0.8099      0.8162           0.8131 0.5936 0.9284 0.6851
##        Kappa
## train 0.5784
## test  0.5581

Porównanie cutoffów dla dostrojonego drzewa

rbind(
  train.0.5 = metryki.tree.tuned.train,
  test.0.5 = metryki.tree.tuned.test,
  train.opt = metryki.tree.tuned.train.opt,
  test.opt = metryki.tree.tuned.test.opt
)
##           Accuracy Sensitivity Specificity BalancedAccuracy    PPV    NPV
## train.0.5   0.8616      0.6056      0.9465           0.7760 0.7894 0.8786
## test.0.5    0.8457      0.5799      0.9338           0.7569 0.7437 0.8703
## train.opt   0.8219      0.8364      0.8171           0.8267 0.6024 0.9378
## test.opt    0.8147      0.8099      0.8162           0.8131 0.5936 0.9284
##               F1  Kappa
## train.0.5 0.6854 0.5987
## test.0.5  0.6517 0.5545
## train.opt 0.7004 0.5784
## test.opt  0.6851 0.5581

Zastosowanie optymalnego progu odcięcia poprawiło wartość Balanced Accuracy w porównaniu z progiem 0.5. Jednak największa zmiana widoczna jest dla specifity.

Porównanie modeli logitowych i drzew klasyfikacyjnych

porownanie.control.test <- rbind(
  logit.0.5 = metryki.glm.test.standard,
  logit.opt = metryki.glm.test,
  tree.default.0.5 = metryki.tree.default.test,
  tree.default.opt = metryki.tree.optimal.test,
  tree.tuned.0.5 = metryki.tree.tuned.test,
  tree.tuned.opt = metryki.tree.tuned.test.opt
)

porownanie.control.test
##                  Accuracy Sensitivity Specificity BalancedAccuracy    PPV
## logit.0.5          0.8443      0.5992      0.9255           0.7624 0.7274
## logit.opt          0.7940      0.8686      0.7692           0.8189 0.5552
## tree.default.0.5   0.8397      0.5115      0.9485           0.7300 0.7670
## tree.default.opt   0.7216      0.8779      0.6698           0.7738 0.4684
## tree.tuned.0.5     0.8457      0.5799      0.9338           0.7569 0.7437
## tree.tuned.opt     0.8147      0.8099      0.8162           0.8131 0.5936
##                     NPV     F1  Kappa
## logit.0.5        0.8744 0.6571 0.5576
## logit.opt        0.9464 0.6774 0.5366
## tree.default.0.5 0.8542 0.6137 0.5177
## tree.default.opt 0.9430 0.6108 0.4238
## tree.tuned.0.5   0.8703 0.6517 0.5545
## tree.tuned.opt   0.9284 0.6851 0.5581

Porównanie modeli pokazuje, że domyślny próg 0.5 może tworzyć złudzenie wysokiej skuteczności. Taki model osiąga wysoką ogólną trafność, ale głównie dlatego, że dobrze rozpoznaje klasę większościową. Jednocześnie może ignorować znaczną część osób zarabiających powyżej 50 tys. USD, czyli klasę mniejszościową.

Przejście na próg optymalny, mieszczący się w przybliżeniu w przedziale 0,18–0,26, pozwala znacząco zwiększyć czułość modelu. Oznacza to lepszą wykrywalność osób z wyższym dochodem, przy akceptowalnym spadku specyficzności. W efekcie rośnie również Balanced Accuracy.

Strojenie drzewa poprzez dobór parametru cp sprawiło, że model stał się stabilniejszy. Po tuningu czułość i specyficzność wyrównały się na poziomie około 80%, co czyni model bardziej rzetelnym.

Ostatecznie zawsze powinniśmy wybrac model w zależności od decyzji biznesowej czy najbardziej zależy nam na wykrywaniu grupy pozytynwej czy negatywnej a może ogólnej Accuracy.

Porównanie modeli za pomocą krzywej ROC

W kolejnym kroku porównano modele za pomocą krzywych ROC oraz indeksu Giniego. Dla każdego modelu wyznaczono przewidywane prawdopodobieństwa na zbiorze testowym, a następnie obliczono odpowiadające im krzywe ROC.

pred.test.logit <- predict(
  glm.full.salary,
  newdata = salary.test,
  type = "response"
)

pred.test.tree.default <- predict(
  tree.salary.default,
  newdata = salary.test.ML,
  type = "prob"
)[, "tak"]

pred.test.tree.tuned <- predict(
  tree.salary.tuned,
  newdata = salary.test.caret,
  type = "prob"
)[, "tak"]

ROC.test.logit <- pROC::roc(
  as.numeric(salary.test$klasa_dochodu_01) - 1,
  pred.test.logit
)

ROC.test.tree.default <- pROC::roc(
  salary.test.ML$klasa_dochodu_01,
  pred.test.tree.default
)

ROC.test.tree.tuned <- pROC::roc(
  salary.test.caret$klasa_dochodu_01,
  pred.test.tree.tuned
)

list(
  logit = ROC.test.logit,
  tree.default = ROC.test.tree.default,
  tree.tuned = ROC.test.tree.tuned
) %>%
  ggroc(alpha = 0.6, linewidth = 1) + 
  geom_segment(
    aes(x = 1, xend = 0, y = 0, yend = 1),
    color = "grey",
    linetype = "dashed"
  ) +
  labs(
    title = "ROC – zbiór testowy",
    subtitle = paste0(
      "Gini TEST: ",
      "logit = ", round(100 * (2 * pROC::auc(ROC.test.logit) - 1), 1), "%, ",
      "default = ", round(100 * (2 * pROC::auc(ROC.test.tree.default) - 1), 1), "%, ",
      "tuned = ", round(100 * (2 * pROC::auc(ROC.test.tree.tuned) - 1), 1), "%"
    )
  ) +
  theme_bw() +
  coord_fixed()

Na podstawie wartości indeksu Giniego można zauważyć, że model logitowy osiągnął bardzo wysoką zdolność rozróżniania klas, równą około 80,1%. Drzewo domyślne uzyskało wynik około 67,6%, natomiast drzewo dostrojone około 77,8%. Oznacza to, że w tym porównaniu logit charakteryzuje się najlepszą zdolnością separacji klas na zbiorze testowym.

Downsampling

Ostatnim krokiem w tej części analizy było sprawdzenie, czy zrównoważenie klas w zbiorze treningowym poprawi jakość drzewa klasyfikacyjnego. W tym celu zastosowano technikę downsamplingu, która polega na zmniejszeniu liczebności klasy większościowej tak, aby odpowiadała liczebności klasy mniejszościowej.

set.seed(123)

salary.train.down <- downSample(
  x = salary.train.caret %>% dplyr::select(-klasa_dochodu_01),
  y = salary.train.caret$klasa_dochodu_01,
  yname = "klasa_dochodu_01"
)

table(salary.train.down$klasa_dochodu_01)
## 
##  nie  tak 
## 5256 5256
prop.table(table(salary.train.down$klasa_dochodu_01))
## 
## nie tak 
## 0.5 0.5

Po zastosowaniu downsamplingu klasy w zbiorze treningowym są zbalansowane. Następnie na tak przygotowanych danych ponownie oszacowano drzewo decyzyjne.

tree.downsampled <- train(
  klasa_dochodu_01 ~ .,
  data = salary.train.down,
  method = "rpart",
  trControl = ctrl,
  tuneGrid = grid.cp,
  metric = "ROC",
  control = rpart.control(
    minsplit = 30,
    minbucket = 10,
    maxdepth = 10
  )
)

rpart.plot(
  tree.downsampled$finalModel,
  type = 2,
  extra = 104,
  fallen.leaves = TRUE,
  nn = FALSE,
  box.palette = "Blues",
  shadow.col = 0,
  cex = 0.8
)

Predykcja i ocena modelu po downsamplingu

pred.tree.down.train.prob <- predict(
  tree.downsampled,
  newdata = salary.train.down,
  type = "prob"
)[, "tak"]

pred.tree.down.test.prob <- predict(
  tree.downsampled,
  newdata = salary.test.caret,
  type = "prob"
)[, "tak"]

metryki.tree.down.train <- klasyfikacja.metryki(
  predicted_probabilities = pred.tree.down.train.prob,
  real = salary.train.down$klasa_dochodu_01,
  cutoff = 0.5,
  level_positive = "tak",
  level_negative = "nie"
)

metryki.tree.down.test <- klasyfikacja.metryki(
  predicted_probabilities = pred.tree.down.test.prob,
  real = salary.test.caret$klasa_dochodu_01,
  cutoff = 0.5,
  level_positive = "tak",
  level_negative = "nie"
)

rbind(
  train = metryki.tree.down.train,
  test = metryki.tree.down.test
)
##       Accuracy Sensitivity Specificity BalancedAccuracy    PPV    NPV     F1
## train   0.8394      0.8824      0.7964           0.8394 0.8125 0.8714 0.8460
## test    0.7928      0.8535      0.7727           0.8131 0.5544 0.9409 0.6721
##        Kappa
## train 0.6788
## test  0.5304

Zastosowanie techniki downsamplingu poprawiło czułość modelu, czyli jego zdolność do wykrywania klasy pozytywnej. Jednocześnie spowodowało spadek zdolności generalizacyjnych, co objawia się rozbieżnością wyników między zbiorem treningowym a testowym. Można więc mówić o lekkim przeuczeniu modelu.

Dlatego też następnym krokiem analizy było skupienie się na Baggingu oraz Random Forescie.

Wejdźmy do lasu: Bagging i RF

W kolejnej części analizy wykorzystano modele zespołowe oparte na drzewach decyzyjnych: bagging oraz Random Forest. Pojedyncze drzewo decyzyjne jest łatwe do interpretacji, ale może być niestabilne — niewielka zmiana danych może prowadzić do powstania zupełnie innej struktury drzewa.

Bagging, czyli Bootstrap Aggregating, polega na losowaniu wielu prób bootstrapowych ze zbioru treningowego, budowaniu osobnego drzewa na każdej z nich, a następnie łączeniu predykcji wielu modeli. W przypadku klasyfikacji końcowa predykcja może być wyznaczana przez głosowanie większościowe albo przez uśrednienie przewidywanych prawdopodobieństw.

Random Forest jest rozwinięciem baggingu. Również wykorzystuje wiele drzew budowanych na próbach bootstrapowych, ale dodatkowo przy każdym podziale drzewa bierze pod uwagę tylko losowy podzbiór zmiennych. Dzięki temu drzewa są bardziej zróżnicowane, co zwykle poprawia stabilność i jakość predykcji.

Kluczowa różnica między tymi metodami polega więc na tym, że w baggingu każde drzewo przy każdym podziale może korzystać ze wszystkich zmiennych, natomiast w Random Forest dostępny jest tylko losowy podzbiór predyktorów.

Bagging

W pierwszym kroku zbudowano model baggingowy. W tym celu w funkcji randomForest() ustawiono parametr mtry równy liczbie wszystkich zmiennych objaśniających. Oznacza to, że każde drzewo przy każdym podziale mogło korzystać ze wszystkich dostępnych predyktorów.

p.salary <- ncol(salary.train.caret) - 1
p.salary
## [1] 17
# Trenowanie modelu zostało wyhasztagowane, aby raport nie budował modelu od nowa.
# Model został wcześniej zapisany do pliku RDS.

# set.seed(123)
# salary.bag <- randomForest(
#   klasa_dochodu_01 ~ .,
#   data = salary.train.caret,
#   mtry = p.salary,
#   importance = TRUE,
#   ntree = 500
# )

#salary.bag %>% saveRDS("trenowane_modele/salary.bagging.rds")

salary.bag <- readRDS("trenowane_modele/salary.bagging.rds")

Następnie przeanalizowano błąd OOB, czyli Out-of-Bag. Jest to wewnętrzna miara jakości modelu, liczona na obserwacjach, które nie zostały wykorzystane do budowy danego drzewa.

print(salary.bag)
## 
## Call:
##  randomForest(formula = klasa_dochodu_01 ~ ., data = salary.train.caret,      mtry = p.salary, importance = TRUE, ntree = 500) 
##                Type of random forest: classification
##                      Number of trees: 500
## No. of variables tried at each split: 17
## 
##         OOB estimate of  error rate: 16.26%
## Confusion matrix:
##       nie  tak class.error
## nie 14339 1519  0.09578762
## tak  1915 3341  0.36434551
plot(salary.bag)

Wykres przedstawia przebieg błędu OOB w funkcji liczby drzew. Widoczna jest szybka stabilizacja błędu już po osiągnięciu około 100 drzew, co potwierdza zasadność przyjęcia parametru ntree = 500. Jednocześnie wyraźna różnica między błędem dla klas wskazuje, że model nadal ma większą trudność w klasyfikowaniu klasy tak, czyli osób zarabiających powyżej 50 tys. USD.

Predykcja i ocena modelu baggingowego

W kolejnym kroku wyznaczono przewidywane prawdopodobieństwa dla zbioru treningowego i testowego. Ponieważ model został zbudowany przy użyciu funkcji randomForest(), do uzyskania prawdopodobieństw zastosowano argument type = "prob".

pred.salary.bag.train.prob <- predict(
  salary.bag,
  newdata = salary.train.caret,
  type = "prob"
)[, "tak"]

pred.salary.bag.test.prob <- predict(
  salary.bag,
  newdata = salary.test.caret,
  type = "prob"
)[, "tak"]

cutoff.bag <- prop.table(table(salary.train.caret$klasa_dochodu_01))["tak"]
cutoff.bag
##       tak 
## 0.2489344

Jako próg odcięcia przyjęto udział klasy tak w zbiorze treningowym. Jest to rozwiązanie dopasowane do niezbalansowanych danych, ponieważ próg ten znajduje się bliżej rzeczywistego udziału klasy pozytywnej niż standardowe 0.5.

metryki.salary.bag.train <- klasyfikacja.metryki(
  predicted_probabilities = pred.salary.bag.train.prob,
  real = salary.train.caret$klasa_dochodu_01,
  cutoff = cutoff.bag,
  level_positive = "tak",
  level_negative = "nie"
)

metryki.salary.bag.test <- klasyfikacja.metryki(
  predicted_probabilities = pred.salary.bag.test.prob,
  real = salary.test.caret$klasa_dochodu_01,
  cutoff = cutoff.bag,
  level_positive = "tak",
  level_negative = "nie"
)

rbind(
  train = metryki.salary.bag.train,
  test = metryki.salary.bag.test
)
##       Accuracy Sensitivity Specificity BalancedAccuracy    PPV    NPV     F1
## train   0.9517      0.9812      0.9420           0.9616 0.8486 0.9934 0.9101
## test    0.8015      0.8215      0.7949           0.8082 0.5703 0.9307 0.6732
##        Kappa
## train 0.8773
## test  0.5373

Wyniki wskazują na widoczne przeuczenie modelu baggingowego. Model bardzo dobrze przewiduje wyniki dla danych treningowych, ale osiąga wyraźnie słabsze rezultaty na zbiorze testowym. Oznacza to, że konieczne jest zastosowanie bardziej kontrolowanego modelu, dlatego w kolejnym kroku przeanalizowano Random Forest.

Random Forest

Random Forest bez strojenia

Random Forest jest rozwinięciem baggingu. W klasyfikacji domyślna wartość parametru mtry jest zwykle zbliżona do pierwiastka z liczby zmiennych objaśniających. Dzięki temu każde drzewo korzysta z innego losowego podzbioru zmiennych, co zwiększa różnorodność drzew i może ograniczać przeuczenie.

# Trenowanie modelu zostało wyhasztagowane, aby raport nie budował modelu od nowa.
# Model został wcześniej zapisany do pliku RDS.

# set.seed(123)
# salary.rf <- randomForest(
#   klasa_dochodu_01 ~ .,
#   data = salary.train.caret,
#   importance = TRUE
# )

# saveRDS(
#   salary.rf,
#   "trenowane_modele/salary.rf.rds"
# )

salary.rf <- readRDS("trenowane_modele/salary.rf.rds")
salary.rf
## 
## Call:
##  randomForest(formula = klasa_dochodu_01 ~ ., data = salary.train.caret,      importance = TRUE) 
##                Type of random forest: classification
##                      Number of trees: 500
## No. of variables tried at each split: 4
## 
##         OOB estimate of  error rate: 14.12%
## Confusion matrix:
##       nie  tak class.error
## nie 14694 1164  0.07340144
## tak  1817 3439  0.34570015
plot(salary.rf)

salary.rf$err.rate[10, ]
##        OOB        nie        tak 
## 0.16334147 0.09610853 0.36634615
salary.rf$err.rate[50, ]
##        OOB        nie        tak 
## 0.14630103 0.08084248 0.34379756
salary.rf$err.rate[100, ]
##        OOB        nie        tak 
## 0.14298570 0.07554547 0.34646119
plot(salary.rf)
abline(h = salary.rf$err.rate[50, "OOB"], col = "red")

Na wykresie błędu OOB widać, że po przekroczeniu około 50 drzew dalsze zwiększanie liczby drzew daje już niewielką poprawę. Oznacza to, że model dość szybko stabilizuje swoje działanie.

Predykcja i ocena Random Forest bez strojenia

pred.salary.rf.train.prob <- predict(
  salary.rf,
  newdata = salary.train.caret,
  type = "prob"
)[, "tak"]

pred.salary.rf.test.prob <- predict(
  salary.rf,
  newdata = salary.test.caret,
  type = "prob"
)[, "tak"]
metryki.salary.rf.train <- klasyfikacja.metryki(
  predicted_probabilities = pred.salary.rf.train.prob,
  real = salary.train.caret$klasa_dochodu_01,
  cutoff = cutoff.bag,
  level_positive = "tak",
  level_negative = "nie"
)

metryki.salary.rf.test <- klasyfikacja.metryki(
  predicted_probabilities = pred.salary.rf.test.prob,
  real = salary.test.caret$klasa_dochodu_01,
  cutoff = cutoff.bag,
  level_positive = "tak",
  level_negative = "nie"
)

rbind(
  train = metryki.salary.rf.train,
  test = metryki.salary.rf.test
)
##       Accuracy Sensitivity Specificity BalancedAccuracy    PPV    NPV     F1
## train   0.9299      0.9553      0.9215           0.9384 0.8013 0.9842 0.8716
## test    0.8196      0.8139      0.8215           0.8177 0.6018 0.9302 0.6920
##        Kappa
## train 0.8239
## test  0.5685

Wyniki wskazują, że w modelu Random Forest bez strojenia nadal widoczne jest przeuczenie. Accuracy wynosi około 0,93 dla zbioru treningowego oraz około 0,82 dla zbioru testowego. Podobną różnicę widać dla czułości, która wynosi około 0,95 na zbiorze treningowym oraz około 0,81 na zbiorze testowym. Z tego powodu w kolejnym kroku zastosowano strojenie hiperparametrów.

Strojenie Random Forest

W przypadku klasyfikacji jako główną metrykę strojenia wykorzystano ROC. Jest to szczególnie ważne przy niezbalansowanych danych, ponieważ sama Accuracy może dawać mylący obraz jakości modelu.

W procesie strojenia sprawdzano różne wartości parametrów mtry oraz min.node.size. Parametr mtry określa liczbę zmiennych losowanych przy każdym podziale drzewa, natomiast min.node.size określa minimalną liczebność węzła końcowego.

ctrl.salary.rf <- trainControl(
  method = "cv",
  number = 5,
  classProbs = TRUE,
  summaryFunction = twoClassSummary,
  savePredictions = "final"
)

grid.salary.rf <- expand.grid(
  mtry = seq(from = 2, to = 10, by = 2),
  min.node.size = c(5, 10, 25, 50),
  splitrule = "gini"
)

grid.salary.rf
##    mtry min.node.size splitrule
## 1     2             5      gini
## 2     4             5      gini
## 3     6             5      gini
## 4     8             5      gini
## 5    10             5      gini
## 6     2            10      gini
## 7     4            10      gini
## 8     6            10      gini
## 9     8            10      gini
## 10   10            10      gini
## 11    2            25      gini
## 12    4            25      gini
## 13    6            25      gini
## 14    8            25      gini
## 15   10            25      gini
## 16    2            50      gini
## 17    4            50      gini
## 18    6            50      gini
## 19    8            50      gini
## 20   10            50      gini
# Trenowanie modelu zostało wyhasztagowane, aby raport nie budował modelu od nowa.
# Model został wcześniej zapisany do pliku RDS.

# set.seed(123)
# salary.rf.tuned <- train(
#   klasa_dochodu_01 ~ .,
#   data = salary.train.caret,
#   method = "ranger",
#   trControl = ctrl.salary.rf,
#   tuneGrid = grid.salary.rf,
#   metric = "ROC",
#   num.trees = 50
# )

# saveRDS(
#   salary.rf.tuned,
#   "trenowane_modele/salary.rf.tuned.rds"
# )

salary.rf.tuned <- readRDS("trenowane_modele/salary.rf.tuned.rds")
salary.rf.tuned
## Random Forest 
## 
## 21114 samples
##    17 predictor
##     2 classes: 'nie', 'tak' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 16891, 16892, 16891, 16891, 16891 
## Resampling results across tuning parameters:
## 
##   mtry  min.node.size  ROC        Sens       Spec     
##    2     5             0.9079080  0.9872620  0.3544547
##    2    10             0.9074837  0.9883341  0.3464620
##    2    25             0.9076500  0.9871358  0.3529315
##    2    50             0.9074293  0.9877664  0.3485534
##    4     5             0.9140599  0.9610922  0.5340574
##    4    10             0.9140022  0.9610921  0.5340585
##    4    25             0.9136749  0.9609659  0.5336768
##    4    50             0.9136811  0.9610290  0.5321535
##    6     5             0.9157692  0.9510658  0.5810506
##    6    10             0.9160098  0.9511919  0.5774345
##    6    25             0.9155840  0.9511919  0.5766737
##    6    50             0.9151235  0.9516333  0.5745826
##    8     5             0.9160909  0.9463996  0.6031182
##    8    10             0.9161528  0.9464624  0.6027378
##    8    25             0.9158625  0.9474715  0.5983628
##    8    50             0.9156369  0.9480391  0.5939858
##   10     5             0.9157612  0.9429312  0.6139632
##   10    10             0.9154328  0.9432464  0.6114901
##   10    25             0.9159407  0.9453275  0.6073041
##   10    50             0.9155349  0.9458319  0.6025478
## 
## Tuning parameter 'splitrule' was held constant at a value of gini
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were mtry = 8, splitrule = gini
##  and min.node.size = 10.
plot(salary.rf.tuned)

salary.rf.tuned$bestTune
##    mtry splitrule min.node.size
## 14    8      gini            10

W wyniku strojenia jako najlepszy wybrano model o parametrach: mtry = 8, splitrule = gini oraz min.node.size = 10. Do wyboru modelu wykorzystano największą wartość ROC.

Predykcja i ocena dostrojonego Random Forest

pred.salary.rf.tuned.train.prob <- predict(
  salary.rf.tuned,
  newdata = salary.train.caret,
  type = "prob"
)[, "tak"]

pred.salary.rf.tuned.test.prob <- predict(
  salary.rf.tuned,
  newdata = salary.test.caret,
  type = "prob"
)[, "tak"]
metryki.salary.rf.tuned.train <- klasyfikacja.metryki(
  predicted_probabilities = pred.salary.rf.tuned.train.prob,
  real = salary.train.caret$klasa_dochodu_01,
  cutoff = cutoff.bag,
  level_positive = "tak",
  level_negative = "nie"
)

metryki.salary.rf.tuned.test <- klasyfikacja.metryki(
  predicted_probabilities = pred.salary.rf.tuned.test.prob,
  real = salary.test.caret$klasa_dochodu_01,
  cutoff = cutoff.bag,
  level_positive = "tak",
  level_negative = "nie"
)

rbind(
  train = metryki.salary.rf.tuned.train,
  test = metryki.salary.rf.tuned.test
)
##       Accuracy Sensitivity Specificity BalancedAccuracy    PPV    NPV     F1
## train   0.8500      0.9488      0.8173           0.8830 0.6325 0.9797 0.7590
## test    0.8112      0.8752      0.7900           0.8326 0.5800 0.9503 0.6977
##        Kappa
## train 0.6563
## test  0.5685

Dostrojony Random Forest osiąga większą czułość, wysoką wartość Balanced Accuracy oraz nie wykazuje tak wyraźnego przeuczenia jak wcześniejsze modele zespołowe. Oznacza to, że odpowiednie dobranie hiperparametrów znacząco poprawia stabilność modelu.

Porównanie krzywych ROC dla baggingu i Random Forest

ROC.bagging.train <- pROC::roc(
  as.numeric(salary.train.caret$klasa_dochodu_01 == "tak"),
  pred.salary.bag.train.prob
)

ROC.bagging.test <- pROC::roc(
  as.numeric(salary.test.caret$klasa_dochodu_01 == "tak"),
  pred.salary.bag.test.prob
)

ROC.rf.train <- pROC::roc(
  as.numeric(salary.train.caret$klasa_dochodu_01 == "tak"),
  pred.salary.rf.train.prob
)

ROC.rf.test <- pROC::roc(
  as.numeric(salary.test.caret$klasa_dochodu_01 == "tak"),
  pred.salary.rf.test.prob
)

ROC.rf.tuned.train <- pROC::roc(
  as.numeric(salary.train.caret$klasa_dochodu_01 == "tak"),
  pred.salary.rf.tuned.train.prob
)

ROC.rf.tuned.test <- pROC::roc(
  as.numeric(salary.test.caret$klasa_dochodu_01 == "tak"),
  pred.salary.rf.tuned.test.prob
)

list(
  bagging.train = ROC.bagging.train,
  bagging.test = ROC.bagging.test,
  rf.train = ROC.rf.train,
  rf.test = ROC.rf.test,
  rf.tuned.train = ROC.rf.tuned.train,
  rf.tuned.test = ROC.rf.tuned.test
) %>%
  ggroc(alpha = 0.5, linewidth = 1) + 
  geom_segment(
    aes(x = 1, xend = 0, y = 0, yend = 1),
    color = "grey",
    linetype = "dashed"
  ) +
  labs(
    title = paste0(
      "Gini TEST: ",
      "bagging = ", round(100 * (2 * pROC::auc(ROC.bagging.test) - 1), 1), "%, ",
      "RF = ", round(100 * (2 * pROC::auc(ROC.rf.test) - 1), 1), "%, ",
      "RF tuned = ", round(100 * (2 * pROC::auc(ROC.rf.tuned.test) - 1), 1), "%"
    ),
    subtitle = paste0(
      "Gini TRAIN: ",
      "bagging = ", round(100 * (2 * pROC::auc(ROC.bagging.train) - 1), 1), "%, ",
      "RF = ", round(100 * (2 * pROC::auc(ROC.rf.train) - 1), 1), "%, ",
      "RF tuned = ", round(100 * (2 * pROC::auc(ROC.rf.tuned.train) - 1), 1), "%"
    )
  ) +
  theme_bw() +
  coord_fixed() +
  scale_color_brewer(palette = "Paired")

Na wykresie dobrze widoczne jest przeuczenie modelu baggingowego. Dla danych treningowych bagging osiąga indeks Giniego na poziomie około 99%, natomiast dla danych testowych tylko około 79. Najlepszy wynik dla danych testowych uzyskuje dostrojony Random Forest, dla którego indeks Giniego wynosi około 83%.

Porównanie metryk modeli

Na końcu porównano najważniejsze metryki dla modeli oszacowanych do tej pory. W przypadku danych niezbalansowanych progi odcięcia znajdują się zwykle w okolicach udziału klasy pozytywnej, czyli około 25%. Dla modeli trenowanych na danych zbalansowanych naturalnym progiem jest 0.5.

porownanie.salary.train <- rbind(
  logit = metryki.glm.train,
  decision.tree = metryki.tree.default.train,
  decision.tree.downsample = metryki.tree.down.train,
  bagging = metryki.salary.bag.train,
  rf = metryki.salary.rf.train,
  rf.tuned = metryki.salary.rf.tuned.train
)

porownanie.salary.test <- rbind(
  logit = metryki.glm.test,
  decision.tree = metryki.tree.default.test,
  decision.tree.downsample = metryki.tree.down.test,
  bagging = metryki.salary.bag.test,
  rf = metryki.salary.rf.test,
  rf.tuned = metryki.salary.rf.tuned.test
)

porownanie.salary.train
##                          Accuracy Sensitivity Specificity BalancedAccuracy
## logit                      0.7936      0.8615      0.7711           0.8163
## decision.tree              0.8416      0.5099      0.9516           0.7307
## decision.tree.downsample   0.8394      0.8824      0.7964           0.8394
## bagging                    0.9517      0.9812      0.9420           0.9616
## rf                         0.9299      0.9553      0.9215           0.9384
## rf.tuned                   0.8500      0.9488      0.8173           0.8830
##                             PPV    NPV     F1  Kappa
## logit                    0.5550 0.9438 0.6751 0.5340
## decision.tree            0.7773 0.8542 0.6158 0.5214
## decision.tree.downsample 0.8125 0.8714 0.8460 0.6788
## bagging                  0.8486 0.9934 0.9101 0.8773
## rf                       0.8013 0.9842 0.8716 0.8239
## rf.tuned                 0.6325 0.9797 0.7590 0.6563
porownanie.salary.test
##                          Accuracy Sensitivity Specificity BalancedAccuracy
## logit                      0.7940      0.8686      0.7692           0.8189
## decision.tree              0.8397      0.5115      0.9485           0.7300
## decision.tree.downsample   0.7928      0.8535      0.7727           0.8131
## bagging                    0.8015      0.8215      0.7949           0.8082
## rf                         0.8196      0.8139      0.8215           0.8177
## rf.tuned                   0.8112      0.8752      0.7900           0.8326
##                             PPV    NPV     F1  Kappa
## logit                    0.5552 0.9464 0.6774 0.5366
## decision.tree            0.7670 0.8542 0.6137 0.5177
## decision.tree.downsample 0.5544 0.9409 0.6721 0.5304
## bagging                  0.5703 0.9307 0.6732 0.5373
## rf                       0.6018 0.9302 0.6920 0.5685
## rf.tuned                 0.5800 0.9503 0.6977 0.5685

Przeprowadzona analiza porównawcza wskazuje, że wybór optymalnego modelu predykcyjnego dla dochodów w zbiorze Adult zależy od priorytetów analizy. Model logitowy oferuje wysoką stabilność i dobrą zdolność generalizacji, podczas gdy dostrojony Random Forest wykazuje szczególnie dobrą czułość w identyfikacji klasy mniejszościowej. Mimo zaobserwowanego przeuczenia w bazowych wariantach modeli zespołowych, takich jak bagging, proces strojenia hiperparametrów pozwolił uzyskać wyniki o wysokiej wartości predykcyjnej przy zachowaniu akceptowalnego poziomu dopasowania.

Warto podkreślić, iż dla baggingu metryki na zbiorze treningowym są bardzo wysokie: Accuracy wynosi 0.9517, a Balanced Accuracy 0.9616. Na zbiorze testowym wartości te spadają odpowiednio do 0.8015 oraz 0.8082. Oznacza to, że model bardzo dobrze dopasował się do danych treningowych, ale jego zdolność generalizacji na nowych danych jest wyraźnie słabsza.

Podobne, choć mniej silne zjawisko widać dla domyślnego Random Forest. Model osiąga Accuracy 0.9299 i Balanced Accuracy 0.9384 na zbiorze treningowym, natomiast na zbiorze testowym odpowiednio 0.8196 oraz 0.8177. Różnica ta sugeruje, że model również częściowo przeuczył się na danych treningowych.

Najlepszy kompromis między skutecznością a stabilnością daje dostrojony Random Forest. Dla tego modelu Balanced Accuracy wynosi 0.8830 na zbiorze treningowym oraz 0.8326 na zbiorze testowym. Różnica między wynikami jest zdecydowanie mniejsza niż w przypadku baggingu i domyślnego Random Forest, co oznacza lepszą zdolność generalizacji. Jednocześnie model ten osiąga najwyższą czułość na zbiorze testowym, równą 0.8770, co oznacza bardzo dobrą zdolność wykrywania osób zarabiających powyżej 50 tys. USD.

Ważność zmiennych w modelu Random Forest

Na końcu przeanalizowano ważność zmiennych w modelu Random Forest. W tym celu wykorzystano miarę Mean Decrease Gini, która pokazuje, jak bardzo dana zmienna przyczynia się do poprawy jakości podziałów w drzewach.

importance.salary.rf <- salary.rf$importance[, "MeanDecreaseGini"]

importance.salary.rf <- sort(importance.salary.rf, decreasing = TRUE)
importance.top <- head(importance.salary.rf, 20)

importance.df <- data.frame(
  Zmienna = names(importance.top),
  Waznosc = as.numeric(importance.top)
)

importance.df$Zmienna <- gsub("_", " ", importance.df$Zmienna)

ggplot(
  importance.df,
  aes(
    x = reorder(Zmienna, Waznosc),
    y = Waznosc,
    fill = Waznosc
  )
) +
  geom_col(width = 0.68, show.legend = FALSE) +
  coord_flip() +
  scale_fill_gradient(
    low = "#B7DDE8",
    high = "#2F6690"
  ) +
  labs(
    title = "Ważność zmiennych w modelu Random Forest",
    subtitle = "Najważniejsze predyktory dochodu",
    x = NULL,
    y = "Mean Decrease Gini"
  ) +
  theme_minimal(base_size = 9) +
  theme(
    plot.title = element_text(
      face = "bold",
      size = 12,
      hjust = 0
    ),
    plot.subtitle = element_text(
      size = 8.5,
      color = "grey35",
      margin = ggplot2::margin(b = 8)
    ),
    axis.text.y = element_text(
      size = 7.5,
      color = "grey20"
    ),
    axis.text.x = element_text(
      size = 7.5,
      color = "grey30"
    ),
    axis.title.x = element_text(
      size = 8.5,
      margin = ggplot2::margin(t = 6)
    ),
    panel.grid.major.y = element_blank(),
    panel.grid.minor = element_blank(),
    plot.margin = ggplot2::margin(6, 10, 6, 6)
  )

Wykres ważności zmiennych pokazuje, które predyktory miały największe znaczenie w modelu Random Forest. Najwyższe wartości Mean Decrease Gini oznaczają, że dana zmienna była często wykorzystywana do tworzenia skutecznych podziałów w drzewach i silnie wspierała proces klasyfikacji. W tym przyapdku kluczowymi zmiennymi była relacja w rodzinie oraz wiek.

Boosting drzew: XGBoost

W kolejnym etapie analizy zastosowano metodę XGBoost, czyli jedną z najczęściej wykorzystywanych metod boostingu drzew decyzyjnych. Nazwa XGBoost oznacza Extreme Gradient Boosting.

W przeciwieństwie do Random Forest, który buduje wiele drzew równolegle i następnie uśrednia ich predykcje, XGBoost buduje drzewa sekwencyjnie. Oznacza to, że każde kolejne drzewo próbuje poprawić błędy popełnione przez wcześniejsze drzewa. Model stopniowo uczy się więc na swoich pomyłkach.

Intuicyjnie można to wyjaśnić następująco:

  • pojedyncze drzewo decyzyjne może być zbyt proste lub niestabilne,
  • Random Forest tworzy wiele niezależnych drzew i ogranicza wariancję modelu,
  • XGBoost buduje drzewa jedno po drugim, a każde następne skupia się na trudniejszych przypadkach,
  • dzięki temu XGBoost często osiąga bardzo wysoką skuteczność predykcyjną.

W przypadku naszego problemu klasyfikacyjnego XGBoost będzie przewidywał prawdopodobieństwo przynależności obserwacji do klasy tak, czyli do grupy osób zarabiających powyżej 50 tys. USD rocznie.

W analizie zastosowano podejście podobne jak przy wcześniejszych modelach: najpierw zbudowano model startowy, a następnie przeprowadzono etapowe strojenie hiperparametrów.

Najważniejsze hiperparametry XGBoost

Przed budową modelu warto krótko wyjaśnić najważniejsze parametry:

  • nrounds — liczba kolejnych drzew budowanych przez model,
  • max_depth — maksymalna głębokość pojedynczego drzewa,
  • eta — tempo uczenia; im mniejsza wartość, tym wolniej, ale często stabilniej uczy się model,
  • gamma — minimalna poprawa jakości wymagana do wykonania kolejnego podziału w drzewie,
  • subsample — część obserwacji losowana do budowy każdego drzewa,
  • colsample_bytree — część zmiennych losowana do budowy każdego drzewa,
  • min_child_weight — minimalna liczba obserwacji / waga wymagana w liściu drzewa.

Niższe wartości eta, ograniczenia w subsample oraz colsample_bytree mogą zmniejszać ryzyko przeuczenia modelu. Z kolei zbyt duża liczba drzew lub zbyt głębokie drzewa mogą prowadzić do bardzo dobrego dopasowania na zbiorze treningowym, ale słabszej generalizacji na zbiorze testowym.

Model startowy XGBoost

Na początku zbudowano model startowy z jedną, bazową kombinacją parametrów. Model ten stanowi punkt odniesienia dla kolejnych etapów strojenia.

# Siatka parametrów dla modelu startowego XGBoost.
# Na tym etapie używamy jednej kombinacji parametrów,
# aby sprawdzić podstawową jakość modelu.

grid.salary.xgb.start <- expand.grid(
  nrounds = 100,             # liczba drzew
  max_depth = 3,             # maksymalna głębokość drzewa
  eta = 0.05,                # tempo uczenia
  gamma = 0,                 # brak dodatkowej kary za podział
  colsample_bytree = 0.8,    # 80% zmiennych dostępnych dla drzewa
  min_child_weight = 100,    # minimalna waga w liściu
  subsample = 0.8            # 80% obserwacji dla każdego drzewa
)

grid.salary.xgb.start
##   nrounds max_depth  eta gamma colsample_bytree min_child_weight subsample
## 1     100         3 0.05     0              0.8              100       0.8
# Trenowanie modelu zostało wyhasztagowane, aby raport nie budował modelu od nowa.
# Model został wcześniej zapisany do pliku RDS.
# Jeśli uruchamiasz projekt pierwszy raz i plik RDS jeszcze nie istnieje,
# trzeba tymczasowo odkomentować trening i saveRDS().

# set.seed(123)
# salary.xgb.start <- train(
#   klasa_dochodu_01 ~ .,
#   data = salary.train.caret,
#   method = "xgbTree",
#   trControl = ctrl,
#   tuneGrid = grid.salary.xgb.start,
#   metric = "ROC"
# )
# 
# saveRDS(salary.xgb.start, "trenowane_modele/salary.xgb.start.rds")

# Wczytanie wcześniej zapisanego modelu:
salary.xgb.start <- readRDS("trenowane_modele/salary.xgb.start.rds")

salary.xgb.start
## eXtreme Gradient Boosting 
## 
## 21114 samples
##    17 predictor
##     2 classes: 'nie', 'tak' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 16891, 16892, 16891, 16891, 16891 
## Resampling results:
## 
##   ROC        Sens      Spec     
##   0.9057785  0.936247  0.5780075
## 
## Tuning parameter 'nrounds' was held constant at a value of 100
## Tuning
##  held constant at a value of 100
## Tuning parameter 'subsample' was held
##  constant at a value of 0.8

Predykcja dla modelu startowego XGBoost

Po wczytaniu modelu wyznaczono przewidywane prawdopodobieństwa dla zbioru treningowego i testowego. Interesuje nas prawdopodobieństwo klasy tak.

pred.salary.xgb.start.train.prob <- predict(
  salary.xgb.start,
  newdata = salary.train.caret,
  type = "prob"
)[, "tak"]

pred.salary.xgb.start.test.prob <- predict(
  salary.xgb.start,
  newdata = salary.test.caret,
  type = "prob"
)[, "tak"]

summary(pred.salary.xgb.start.train.prob)
##     Min.  1st Qu.   Median     Mean  3rd Qu.     Max. 
## 0.009907 0.044268 0.144113 0.250579 0.386196 0.976476
summary(pred.salary.xgb.start.test.prob)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
## 0.01003 0.04343 0.14305 0.24841 0.38494 0.97648

Ocena jakości modelu startowego XGBoost

Do oceny modelu wykorzystano tę samą funkcję klasyfikacja.metryki, która była stosowana wcześniej. Dzięki temu wyniki XGBoost można bezpośrednio porównać z wynikami logitu, drzew, baggingu oraz Random Forest.

Jako próg odcięcia zastosowano cutoff.bag, czyli próg wyznaczony wcześniej dla modeli zespołowych. Pozwala to zachować spójność porównania.

metryki.salary.xgb.start.train <- klasyfikacja.metryki(
  predicted_probabilities = pred.salary.xgb.start.train.prob,
  real = salary.train.caret$klasa_dochodu_01,
  cutoff = cutoff.bag,
  level_positive = "tak",
  level_negative = "nie"
)

metryki.salary.xgb.start.test <- klasyfikacja.metryki(
  predicted_probabilities = pred.salary.xgb.start.test.prob,
  real = salary.test.caret$klasa_dochodu_01,
  cutoff = cutoff.bag,
  level_positive = "tak",
  level_negative = "nie"
)

rbind(
  train = metryki.salary.xgb.start.train,
  test = metryki.salary.xgb.start.test
)
##       Accuracy Sensitivity Specificity BalancedAccuracy    PPV    NPV     F1
## train   0.7986      0.8727      0.7740           0.8234 0.5614 0.9483 0.6833
## test    0.7982      0.8601      0.7777           0.8189 0.5618 0.9438 0.6796
##        Kappa
## train 0.5456
## test  0.5416

Na podstawie wyników modelu startowego można ocenić, jak XGBoost radzi sobie bez dokładniejszego strojenia hiperparametrów. Kolejnym krokiem będzie poprawa modelu przez stopniowe dobieranie parametrów.


XGBoost: krok 1 — strojenie max_depth i min_child_weight

W pierwszym etapie strojenia sprawdzono różne wartości max_depth oraz min_child_weight.

Parametr max_depth kontroluje złożoność pojedynczych drzew. Głębsze drzewa mogą uchwycić bardziej skomplikowane zależności, ale mogą też zwiększać ryzyko przeuczenia.

Parametr min_child_weight ogranicza tworzenie zbyt małych liści. Większa wartość tego parametru sprawia, że model staje się bardziej ostrożny.

grid.salary.xgb.depth_child <- expand.grid(
  nrounds = 100,
  max_depth = c(2, 3, 4, 5),
  eta = 0.05,
  gamma = 0,
  colsample_bytree = 0.8,
  min_child_weight = c(50, 100, 200, 400),
  subsample = 0.8
)

grid.salary.xgb.depth_child
##    nrounds max_depth  eta gamma colsample_bytree min_child_weight subsample
## 1      100         2 0.05     0              0.8               50       0.8
## 2      100         3 0.05     0              0.8               50       0.8
## 3      100         4 0.05     0              0.8               50       0.8
## 4      100         5 0.05     0              0.8               50       0.8
## 5      100         2 0.05     0              0.8              100       0.8
## 6      100         3 0.05     0              0.8              100       0.8
## 7      100         4 0.05     0              0.8              100       0.8
## 8      100         5 0.05     0              0.8              100       0.8
## 9      100         2 0.05     0              0.8              200       0.8
## 10     100         3 0.05     0              0.8              200       0.8
## 11     100         4 0.05     0              0.8              200       0.8
## 12     100         5 0.05     0              0.8              200       0.8
## 13     100         2 0.05     0              0.8              400       0.8
## 14     100         3 0.05     0              0.8              400       0.8
## 15     100         4 0.05     0              0.8              400       0.8
## 16     100         5 0.05     0              0.8              400       0.8
# Trenowanie modelu zostało wyhasztagowane.
# Jeśli plik "salary.xgb.1.rds" jeszcze nie istnieje,
# należy tymczasowo odkomentować poniższy fragment.

# set.seed(123)
# salary.xgb.1 <- train(
#   klasa_dochodu_01 ~ .,
#   data = salary.train.caret,
#   method = "xgbTree",
#   trControl = ctrl,
#   tuneGrid = grid.salary.xgb.depth_child,
#   metric = "ROC"
# )
# 
# saveRDS(salary.xgb.1, "trenowane_modele/salary.xgb.1.rds")

salary.xgb.1 <- readRDS("trenowane_modele/salary.xgb.1.rds")

salary.xgb.1
## eXtreme Gradient Boosting 
## 
## 21114 samples
##    17 predictor
##     2 classes: 'nie', 'tak' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 16892, 16891, 16891, 16891, 16891 
## Resampling results across tuning parameters:
## 
##   max_depth  min_child_weight  ROC        Sens       Spec     
##   2           50               0.9020865  0.9387101  0.5543292
##   2          100               0.8996022  0.9384578  0.5436727
##   2          200               0.8799907  0.9274231  0.5259753
##   2          400               0.8476504  0.9345483  0.4289248
##   3           50               0.9066564  0.9404756  0.5691722
##   3          100               0.9035268  0.9351160  0.5646051
##   3          200               0.8840750  0.9260990  0.5495718
##   3          400               0.8479697  0.9317737  0.4441484
##   4           50               0.9088985  0.9382057  0.5828735
##   4          100               0.9050568  0.9341701  0.5824929
##   4          200               0.8856024  0.9278015  0.5507136
##   4          400               0.8474548  0.9295664  0.4513796
##   5           50               0.9098359  0.9368815  0.5923882
##   5          100               0.9059827  0.9344852  0.5794481
##   5          200               0.8860554  0.9274860  0.5512845
##   5          400               0.8482110  0.9312064  0.4496670
## 
## Tuning parameter 'nrounds' was held constant at a value of 100
## Tuning
##  'colsample_bytree' was held constant at a value of 0.8
## Tuning
##  parameter 'subsample' was held constant at a value of 0.8
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were nrounds = 100, max_depth = 5, eta
##  = 0.05, gamma = 0, colsample_bytree = 0.8, min_child_weight = 50 and
##  subsample = 0.8.
salary.xgb.1$bestTune
##    nrounds max_depth  eta gamma colsample_bytree min_child_weight subsample
## 13     100         5 0.05     0              0.8               50       0.8

Na podstawie wyników pierwszego etapu strojenia przyjęto wartości max_depth = 5 oraz min_child_weight = 50.


XGBoost: krok 2 — strojenie nrounds i eta

W drugim etapie strojenia sprawdzono liczbę drzew oraz tempo uczenia.

Parametr nrounds określa liczbę kolejnych drzew budowanych przez model. Parametr eta odpowiada za tempo uczenia. Niższe wartości eta powodują wolniejsze, ale często stabilniejsze uczenie modelu.

grid.salary.xgb.nrounds_eta <- expand.grid(
  nrounds = c(200, 300, 500, 800),
  max_depth = 5,
  eta = c(0.01, 0.025, 0.05, 0.1),
  gamma = 0,
  colsample_bytree = 0.8,
  min_child_weight = 50,
  subsample = 0.8
)

grid.salary.xgb.nrounds_eta
##    nrounds max_depth   eta gamma colsample_bytree min_child_weight subsample
## 1      200         5 0.010     0              0.8               50       0.8
## 2      300         5 0.010     0              0.8               50       0.8
## 3      500         5 0.010     0              0.8               50       0.8
## 4      800         5 0.010     0              0.8               50       0.8
## 5      200         5 0.025     0              0.8               50       0.8
## 6      300         5 0.025     0              0.8               50       0.8
## 7      500         5 0.025     0              0.8               50       0.8
## 8      800         5 0.025     0              0.8               50       0.8
## 9      200         5 0.050     0              0.8               50       0.8
## 10     300         5 0.050     0              0.8               50       0.8
## 11     500         5 0.050     0              0.8               50       0.8
## 12     800         5 0.050     0              0.8               50       0.8
## 13     200         5 0.100     0              0.8               50       0.8
## 14     300         5 0.100     0              0.8               50       0.8
## 15     500         5 0.100     0              0.8               50       0.8
## 16     800         5 0.100     0              0.8               50       0.8
# Trenowanie modelu zostało wyhasztagowane.
# Jeśli plik "salary.xgb.2.rds" jeszcze nie istnieje,
# należy tymczasowo odkomentować poniższy fragment.

# set.seed(123)
# salary.xgb.2 <- train(
#   klasa_dochodu_01 ~ .,
#   data = salary.train.caret,
#   method = "xgbTree",
#   trControl = ctrl,
#   tuneGrid = grid.salary.xgb.nrounds_eta,
#   metric = "ROC"
# )
# 
# saveRDS(salary.xgb.2, "trenowane_modele/salary.xgb.2.rds")

salary.xgb.2 <- readRDS("trenowane_modele/salary.xgb.2.rds")

salary.xgb.2
## eXtreme Gradient Boosting 
## 
## 21114 samples
##    17 predictor
##     2 classes: 'nie', 'tak' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 16891, 16892, 16891, 16891, 16891 
## Resampling results across tuning parameters:
## 
##   eta    nrounds  ROC        Sens       Spec     
##   0.010  200      0.9051048  0.9444446  0.5643112
##   0.010  300      0.9082290  0.9444445  0.5705879
##   0.010  500      0.9113829  0.9398414  0.5920890
##   0.010  800      0.9130497  0.9361209  0.6092110
##   0.025  200      0.9114442  0.9395892  0.5938009
##   0.025  300      0.9127783  0.9364362  0.6052154
##   0.025  500      0.9142923  0.9342292  0.6213877
##   0.025  800      0.9151878  0.9318960  0.6309003
##   0.050  200      0.9135127  0.9345446  0.6183439
##   0.050  300      0.9145132  0.9322114  0.6259532
##   0.050  500      0.9153463  0.9319591  0.6326113
##   0.050  800      0.9159019  0.9317698  0.6360356
##   0.100  200      0.9146305  0.9329681  0.6299476
##   0.100  300      0.9153241  0.9325897  0.6339440
##   0.100  500      0.9156742  0.9312022  0.6364171
##   0.100  800      0.9153414  0.9291844  0.6407920
## 
## Tuning parameter 'max_depth' was held constant at a value of 5
## Tuning
## 
## Tuning parameter 'min_child_weight' was held constant at a value of 50
## 
## Tuning parameter 'subsample' was held constant at a value of 0.8
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were nrounds = 800, max_depth = 5, eta
##  = 0.05, gamma = 0, colsample_bytree = 0.8, min_child_weight = 50 and
##  subsample = 0.8.
salary.xgb.2$bestTune
##    nrounds max_depth  eta gamma colsample_bytree min_child_weight subsample
## 12     800         5 0.05     0              0.8               50       0.8

Na podstawie wyników drugiego etapu strojenia przyjęto wartości nrounds = 800 oraz eta = 0.05.


XGBoost: krok 3 — strojenie subsample i colsample_bytree

W trzecim etapie sprawdzono parametry odpowiedzialne za losowanie obserwacji i zmiennych.

Parametr subsample określa, jaka część obserwacji jest wykorzystywana do budowy każdego drzewa. Parametr colsample_bytree określa, jaka część zmiennych jest wykorzystywana w każdym drzewie.

Takie losowanie może działać regularyzująco, czyli ograniczać przeuczenie modelu.

grid.salary.xgb.sample <- expand.grid(
  nrounds = 800,
  max_depth = 5,
  eta = 0.05,
  gamma = 0,
  colsample_bytree = c(0.2, 0.4, 0.6, 0.8, 0.9),
  min_child_weight = 50,
  subsample = c(0.2, 0.4, 0.6, 0.8, 0.9)
)

grid.salary.xgb.sample
##    nrounds max_depth  eta gamma colsample_bytree min_child_weight subsample
## 1      800         5 0.05     0              0.2               50       0.2
## 2      800         5 0.05     0              0.4               50       0.2
## 3      800         5 0.05     0              0.6               50       0.2
## 4      800         5 0.05     0              0.8               50       0.2
## 5      800         5 0.05     0              0.9               50       0.2
## 6      800         5 0.05     0              0.2               50       0.4
## 7      800         5 0.05     0              0.4               50       0.4
## 8      800         5 0.05     0              0.6               50       0.4
## 9      800         5 0.05     0              0.8               50       0.4
## 10     800         5 0.05     0              0.9               50       0.4
## 11     800         5 0.05     0              0.2               50       0.6
## 12     800         5 0.05     0              0.4               50       0.6
## 13     800         5 0.05     0              0.6               50       0.6
## 14     800         5 0.05     0              0.8               50       0.6
## 15     800         5 0.05     0              0.9               50       0.6
## 16     800         5 0.05     0              0.2               50       0.8
## 17     800         5 0.05     0              0.4               50       0.8
## 18     800         5 0.05     0              0.6               50       0.8
## 19     800         5 0.05     0              0.8               50       0.8
## 20     800         5 0.05     0              0.9               50       0.8
## 21     800         5 0.05     0              0.2               50       0.9
## 22     800         5 0.05     0              0.4               50       0.9
## 23     800         5 0.05     0              0.6               50       0.9
## 24     800         5 0.05     0              0.8               50       0.9
## 25     800         5 0.05     0              0.9               50       0.9
# Trenowanie modelu zostało wyhasztagowane.
# Jeśli plik "salary.xgb.3.rds" jeszcze nie istnieje,
# należy tymczasowo odkomentować poniższy fragment.

# set.seed(123)
# salary.xgb.3 <- train(
#   klasa_dochodu_01 ~ .,
#   data = salary.train.caret,
#   method = "xgbTree",
#   trControl = ctrl,
#   tuneGrid = grid.salary.xgb.sample,
#   metric = "ROC"
# )
# 
# saveRDS(salary.xgb.3, "trenowane_modele/salary.xgb.3.rds")

salary.xgb.3 <- readRDS("trenowane_modele/salary.xgb.3.rds")

salary.xgb.3
## eXtreme Gradient Boosting 
## 
## 21114 samples
##    17 predictor
##     2 classes: 'nie', 'tak' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 16891, 16892, 16891, 16891, 16891 
## Resampling results across tuning parameters:
## 
##   colsample_bytree  subsample  ROC        Sens       Spec     
##   0.2               0.2        0.8870639  0.9310128  0.5487088
##   0.2               0.4        0.9091399  0.9320219  0.6059771
##   0.2               0.6        0.9126803  0.9335987  0.6200544
##   0.2               0.8        0.9155636  0.9342292  0.6299479
##   0.2               0.9        0.9162185  0.9338508  0.6331821
##   0.4               0.2        0.8878455  0.9300670  0.5570803
##   0.4               0.4        0.9096997  0.9299410  0.6141556
##   0.4               0.6        0.9129322  0.9301936  0.6270941
##   0.4               0.8        0.9159209  0.9334094  0.6316602
##   0.4               0.9        0.9169222  0.9326526  0.6358469
##   0.6               0.2        0.8878807  0.9274186  0.5631678
##   0.6               0.4        0.9092376  0.9300041  0.6185342
##   0.6               0.6        0.9128037  0.9303827  0.6303276
##   0.6               0.8        0.9161457  0.9326526  0.6388891
##   0.6               0.9        0.9171412  0.9325896  0.6375579
##   0.8               0.2        0.8882967  0.9279859  0.5610746
##   0.8               0.4        0.9089835  0.9302562  0.6139668
##   0.8               0.6        0.9128501  0.9309502  0.6307082
##   0.8               0.8        0.9157144  0.9328420  0.6366061
##   0.8               0.9        0.9169748  0.9317068  0.6428846
##   0.9               0.2        0.8881539  0.9287427  0.5637385
##   0.9               0.4        0.9087760  0.9295627  0.6166305
##   0.9               0.6        0.9123321  0.9309502  0.6289963
##   0.9               0.8        0.9157602  0.9313915  0.6388895
##   0.9               0.9        0.9169218  0.9315177  0.6436463
## 
## Tuning parameter 'nrounds' was held constant at a value of 800
## Tuning
##  held constant at a value of 0
## Tuning parameter 'min_child_weight' was
##  held constant at a value of 50
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were nrounds = 800, max_depth = 5, eta
##  = 0.05, gamma = 0, colsample_bytree = 0.6, min_child_weight = 50 and
##  subsample = 0.9.
salary.xgb.3$bestTune
##    nrounds max_depth  eta gamma colsample_bytree min_child_weight subsample
## 15     800         5 0.05     0              0.6               50       0.9

Na podstawie wyników trzeciego etapu strojenia przyjęto wartości subsample = 0.9 oraz colsample_bytree = 0.6.


XGBoost: krok 4 — strojenie gamma

W ostatnim etapie strojenia sprawdzono parametr gamma. Parametr ten określa minimalną poprawę jakości modelu wymaganą do wykonania kolejnego podziału w drzewie.

Im większa wartość gamma, tym bardziej konserwatywny jest model, ponieważ trudniej mu tworzyć kolejne podziały.

grid.salary.xgb.gamma <- expand.grid(
  nrounds = 800,
  max_depth = 5,
  eta = 0.05,
  gamma = c(0, 0.5, 1, 2),
  colsample_bytree = 0.6,
  min_child_weight = 50,
  subsample = 0.9
)

grid.salary.xgb.gamma
##   nrounds max_depth  eta gamma colsample_bytree min_child_weight subsample
## 1     800         5 0.05   0.0              0.6               50       0.9
## 2     800         5 0.05   0.5              0.6               50       0.9
## 3     800         5 0.05   1.0              0.6               50       0.9
## 4     800         5 0.05   2.0              0.6               50       0.9
# Trenowanie modelu zostało wyhasztagowane.
# Jeśli plik "salary.xgb.4.rds" jeszcze nie istnieje,
# należy tymczasowo odkomentować poniższy fragment.

# set.seed(123)
# salary.xgb.4 <- train(
#   klasa_dochodu_01 ~ .,
#   data = salary.train.caret,
#   method = "xgbTree",
#   trControl = ctrl,
#   tuneGrid = grid.salary.xgb.gamma,
#   metric = "ROC"
# )
# 
# saveRDS(salary.xgb.4, "trenowane_modele/salary.xgb.4.rds")

salary.xgb.4 <- readRDS("trenowane_modele/salary.xgb.4.rds")

salary.xgb.4
## eXtreme Gradient Boosting 
## 
## 21114 samples
##    17 predictor
##     2 classes: 'nie', 'tak' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 16891, 16892, 16891, 16891, 16891 
## Resampling results across tuning parameters:
## 
##   gamma  ROC        Sens       Spec     
##   0.0    0.9171491  0.9320852  0.6385101
##   0.5    0.9172577  0.9324004  0.6413624
##   1.0    0.9172496  0.9324004  0.6396503
##   2.0    0.9168120  0.9333463  0.6331822
## 
## Tuning parameter 'nrounds' was held constant at a value of 800
## Tuning
##  parameter 'min_child_weight' was held constant at a value of 50
## 
## Tuning parameter 'subsample' was held constant at a value of 0.9
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were nrounds = 800, max_depth = 5, eta
##  = 0.05, gamma = 0.5, colsample_bytree = 0.6, min_child_weight = 50
##  and subsample = 0.9.
salary.xgb.4$bestTune
##   nrounds max_depth  eta gamma colsample_bytree min_child_weight subsample
## 2     800         5 0.05   0.5              0.6               50       0.9

Na podstawie wyników czwartego etapu strojenia jako model finalny przyjęto model salary.xgb.4.


XGBoost: model finalny

Po przeprowadzeniu strojenia wybrano finalny model XGBoost. Następnie wyznaczono prawdopodobieństwa dla zbioru treningowego i testowego.

salary.xgb.tuned <- salary.xgb.4

pred.salary.xgb.tuned.train.prob <- predict(
  salary.xgb.tuned,
  newdata = salary.train.caret,
  type = "prob"
)[, "tak"]

pred.salary.xgb.tuned.test.prob <- predict(
  salary.xgb.tuned,
  newdata = salary.test.caret,
  type = "prob"
)[, "tak"]

summary(pred.salary.xgb.tuned.train.prob)
##      Min.   1st Qu.    Median      Mean   3rd Qu.      Max. 
## 0.0002398 0.0119718 0.0927959 0.2490277 0.4206244 0.9948304
summary(pred.salary.xgb.tuned.test.prob)
##      Min.   1st Qu.    Median      Mean   3rd Qu.      Max. 
## 0.0003001 0.0114394 0.0919788 0.2452645 0.4131365 0.9946068

Ocena finalnego modelu XGBoost

metryki.salary.xgb.tuned.train <- klasyfikacja.metryki(
  predicted_probabilities = pred.salary.xgb.tuned.train.prob,
  real = salary.train.caret$klasa_dochodu_01,
  cutoff = cutoff.bag,
  level_positive = "tak",
  level_negative = "nie"
)

metryki.salary.xgb.tuned.test <- klasyfikacja.metryki(
  predicted_probabilities = pred.salary.xgb.tuned.test.prob,
  real = salary.test.caret$klasa_dochodu_01,
  cutoff = cutoff.bag,
  level_positive = "tak",
  level_negative = "nie"
)

rbind(
  xgb.start.train = metryki.salary.xgb.start.train,
  xgb.tuned.train = metryki.salary.xgb.tuned.train,
  xgb.start.test = metryki.salary.xgb.start.test,
  xgb.tuned.test = metryki.salary.xgb.tuned.test
)
##                 Accuracy Sensitivity Specificity BalancedAccuracy    PPV    NPV
## xgb.start.train   0.7986      0.8727      0.7740           0.8234 0.5614 0.9483
## xgb.tuned.train   0.8286      0.8786      0.8120           0.8453 0.6077 0.9528
## xgb.start.test    0.7982      0.8601      0.7777           0.8189 0.5618 0.9438
## xgb.tuned.test    0.8239      0.8583      0.8125           0.8354 0.6027 0.9454
##                     F1  Kappa
## xgb.start.train 0.6833 0.5456
## xgb.tuned.train 0.7185 0.6011
## xgb.start.test  0.6796 0.5416
## xgb.tuned.test  0.7082 0.5876

Porównując model startowy oraz model dostrojony, można ocenić, czy strojenie hiperparametrów poprawiło jakość klasyfikacji. Szczególnie ważna jest wartość Balanced Accuracy, ponieważ w analizowanym zbiorze klasy są niezbalansowane. Balanced Accuracy dla danych testwoych w finalnym modelu wyniosła niemalże 84%, a dodatkowo co niezwykle ważne nie widać uznak jakiegolwiek przeuczenia modelu.

Krzywe ROC dla modeli XGBoost

Na końcu porównano model startowy i model dostrojony za pomocą krzywych ROC oraz indeksu Giniego.

ROC.xgb.start.train <- pROC::roc(
  as.numeric(salary.train.caret$klasa_dochodu_01 == "tak"),
  pred.salary.xgb.start.train.prob
)

ROC.xgb.start.test <- pROC::roc(
  as.numeric(salary.test.caret$klasa_dochodu_01 == "tak"),
  pred.salary.xgb.start.test.prob
)

ROC.xgb.tuned.train <- pROC::roc(
  as.numeric(salary.train.caret$klasa_dochodu_01 == "tak"),
  pred.salary.xgb.tuned.train.prob
)

ROC.xgb.tuned.test <- pROC::roc(
  as.numeric(salary.test.caret$klasa_dochodu_01 == "tak"),
  pred.salary.xgb.tuned.test.prob
)

list(
  xgb.start.train = ROC.xgb.start.train,
  xgb.start.test = ROC.xgb.start.test,
  xgb.tuned.train = ROC.xgb.tuned.train,
  xgb.tuned.test = ROC.xgb.tuned.test
) %>%
  ggroc(alpha = 0.6, linewidth = 1) +
  geom_segment(
    aes(x = 1, xend = 0, y = 0, yend = 1),
    color = "grey",
    linetype = "dashed"
  ) +
  labs(
    title = paste0(
      "Gini TEST: ",
      "XGB start = ", round(100 * (2 * pROC::auc(ROC.xgb.start.test) - 1), 1), "%, ",
      "XGB tuned = ", round(100 * (2 * pROC::auc(ROC.xgb.tuned.test) - 1), 1), "%"
    ),
    subtitle = paste0(
      "Gini TRAIN: ",
      "XGB start = ", round(100 * (2 * pROC::auc(ROC.xgb.start.train) - 1), 1), "%, ",
      "XGB tuned = ", round(100 * (2 * pROC::auc(ROC.xgb.tuned.train) - 1), 1), "%"
    ),
    color = "Model",
    x = "Specyficzność",
    y = "Wrażliwość"
  ) +
  theme_bw() +
  coord_fixed() +
  scale_color_brewer(palette = "Paired")

Na podstawie krzywych ROC oraz indeksu Giniego można ocenić, czy model XGBoost poprawił zdolność rozróżniania klas względem wcześniejszych modeli. Jeżeli indeks Giniego dla modelu dostrojonego jest wyższy niż dla modelu startowego, oznacza to, że strojenie hiperparametrów poprawiło zdolność separacji klas.

####Końcowe porównanie dla drzew i logitu

# Krzywe ROC - modele bazowe

ROC.train.logit <- pROC::roc(as.numeric(salary.train$klasa_dochodu_01) - 1, 
                            pred.glm.train.prob)

ROC.train.tree.tuned <- pROC::roc(as.numeric(salary.train.ML$klasa_dochodu_01) -1, 
                                 pred.tree.tuned.train.prob)


gini <- function(roc) {
  paste0(round(100 * (2 * as.numeric(pROC::auc(roc)) - 1), 1), "%")
}

pROC::ggroc(
  list(
    logit.train = ROC.train.logit,
    logit.test = ROC.test.logit,
    tree.tuned.train = ROC.train.tree.tuned,
    tree.tuned.test = ROC.test.tree.tuned,
    bagging.train = ROC.bagging.train,
    bagging.test = ROC.bagging.test,
    rf.train = ROC.rf.train,
    rf.test = ROC.rf.test,
    rf.tuned.train = ROC.rf.tuned.train,
    rf.tuned.test = ROC.rf.tuned.test,
    xgb.train = ROC.xgb.tuned.train,
    xgb.test = ROC.xgb.tuned.test
  ),
  alpha = 0.65,
  linewidth = 1
) +
  geom_segment(
    aes(x = 1, xend = 0, y = 0, yend = 1),
    color = "grey",
    linetype = "dashed"
  ) +
  labs(
    title = paste0(
      "Gini TEST:  ",
      "logit = ", gini(ROC.test.logit), "   |   ",
      "decision tree = ", gini(ROC.test.tree.tuned), "   |   ",
      "rf = ", gini(ROC.rf.test), "\n",
      "rf_tuned = ", gini(ROC.rf.tuned.test), "   |   ",
      "bagging = ", gini(ROC.bagging.test), "   |   ",
      "xgb = ", gini(ROC.xgb.tuned.test)
    ),
    subtitle = paste0(
      "Gini TRAIN:  ",
      "logit = ", gini(ROC.train.logit), "   |   ",
      "decision tree = ", gini(ROC.train.tree.tuned), "   |   ",
      "rf = ", gini(ROC.rf.train), "\n",
      "rf_tuned = ", gini(ROC.rf.tuned.train), "   |   ",
      "bagging = ", gini(ROC.bagging.train), "   |   ",
      "xgb = ", gini(ROC.xgb.tuned.train)
    ),
    color = "Model",
    x = "Specyficzność",
    y = "Wrażliwość"
  ) +
  theme_bw() +
  coord_fixed() +
  scale_color_brewer(palette = "Paired") +
  theme(
    plot.title = element_text(
      size = 10,
      face = "bold",
      lineheight = 0.95,
      margin = ggplot2::margin(b = 5)
    ),
    plot.subtitle = element_text(
      size = 9,
      lineheight = 0.95,
      margin = ggplot2::margin(b = 10)
    ),
    legend.position = "bottom",
    legend.title = element_blank(),
    legend.text = element_text(size = 8),
    axis.title = element_text(size = 11),
    axis.text = element_text(size = 9)
  ) +
  guides(
    color = guide_legend(nrow = 2, byrow = TRUE)
  )

Można zauważyć, iż dla danych testowych największe wartości Gini przyjęły kolejno modele: XGB, rf_tuned oraz rf.

Sieci neuronowe

Ostatnią grupą modeli wykorzystaną w analizie są sieci neuronowe. W przeciwieństwie do modeli liniowych, takich jak regresja logistyczna, sieci neuronowe mogą uchwycić bardziej złożone, nieliniowe zależności pomiędzy zmiennymi objaśniającymi a zmienną celu.

W naszym przypadku sieć neuronowa będzie wykorzystywana do klasyfikacji binarnej. Celem modelu jest przewidywanie, czy dana osoba należy do klasy tak, czyli czy osiąga dochód powyżej 50 tys. USD rocznie.

Sieci neuronowe wymagają jednak odpowiedniego przygotowania danych. Wszystkie zmienne wejściowe muszą być numeryczne, dlatego zmienne kategoryczne należy przekształcić za pomocą kodowania zero-jedynkowego, czyli one-hot encoding. Dodatkowo dane wejściowe zostały przeskalowane do zakresu 0–1, ponieważ sieci neuronowe są wrażliwe na skalę zmiennych.

W analizie wykorzystano pakiet neuralnet. Jest to pakiet dobry do celów dydaktycznych i pokazania podstawowej logiki działania sieci neuronowych, ale przy większych zbiorach danych może trenować się stosunkowo długo. Z tego powodu modele sieci neuronowych zostały wytrenowane na zmniejszonym podzbiorze danych treningowych.

library(neuralnet)             # proste sieci neuronowe
library(NeuralNetTools)        # wizualizacja sieci neuronowych

Przygotowanie danych do sieci neuronowych

Przed rozpoczęciem modelowania sprawdzono typy zmiennych w zbiorze treningowym. Jest to ważne, ponieważ sieci neuronowe wymagają danych numerycznych. Zmienne typu factor muszą więc zostać przekształcone do postaci zmiennych zero-jedynkowych.

# Liczba zmiennych typu factor
sum(sapply(salary.train.ML, is.factor))
## [1] 12
# Liczba zmiennych typu ordered factor
sum(sapply(salary.train.ML, is.ordered))
## [1] 0
# Nazwy zmiennych typu factor
names(salary.train.ML)[sapply(salary.train.ML, is.factor)]
##  [1] "stan_cywilny"              "zawod"                    
##  [3] "relacja_w_rodzinie"        "rasa"                     
##  [5] "plec"                      "typ_zatrudnienia"         
##  [7] "kraj_pochodzenia"          "czy_kapital_netto_dodatni"
##  [9] "wiek_sredni"               "dlugie_godziny_pracy"     
## [11] "klasa_dochodu_01"          "poziom_edukacji_factor"
# Nazwy zmiennych typu ordered factor
names(salary.train.ML)[sapply(salary.train.ML, is.ordered)]
## character(0)

Następnie dokonano podziału danych na zmienne objaśniające X oraz zmienną celu y. Zmienną objaśnianą jest klasa_dochodu_01, która informuje, czy dana osoba zarabia powyżej 50 tys. USD.

salary.x.train <- salary.train.ML %>% 
  dplyr::select(-klasa_dochodu_01)

salary.y.train <- salary.train.ML$klasa_dochodu_01

salary.x.test <- salary.test.ML %>% 
  dplyr::select(-klasa_dochodu_01)

salary.y.test <- salary.test.ML$klasa_dochodu_01

# Kontrola rozkładu zmiennej celu
table(salary.y.train)
## salary.y.train
##   nie   tak 
## 15858  5256
table(salary.y.test)
## salary.y.test
##  nie  tak 
## 6796 2252
prop.table(table(salary.y.train))
## salary.y.train
##       nie       tak 
## 0.7510656 0.2489344
prop.table(table(salary.y.test))
## salary.y.test
##       nie       tak 
## 0.7511052 0.2488948

One-hot encoding

Ponieważ sieci neuronowe nie przyjmują bezpośrednio zmiennych typu factor, zmienne kategoryczne zostały przekształcone za pomocą funkcji model.matrix(). W efekcie każda kategoria zostaje zapisana jako osobna zmienna zero-jedynkowa.

Ważne jest również wyrównanie kolumn w zbiorze testowym do kolumn ze zbioru treningowego. Jeżeli jakaś kategoria występuje w treningu, ale nie pojawia się w teście, dodajemy odpowiadającą jej kolumnę z zerami. Jeżeli natomiast w teście pojawią się dodatkowe kolumny, których nie było w treningu, zostają one usunięte.

salary.x.train.mm <- model.matrix(~ ., data = salary.x.train)[, -1]
salary.x.test.mm  <- model.matrix(~ ., data = salary.x.test)[, -1]

# Zapamiętujemy kolumny ze zbioru treningowego
train_cols <- colnames(salary.x.train.mm)

# Dodajemy brakujące kolumny w teście
missing_cols <- setdiff(train_cols, colnames(salary.x.test.mm))

if (length(missing_cols) > 0) {
  for (col in missing_cols) {
    salary.x.test.mm <- cbind(salary.x.test.mm, 0)
    colnames(salary.x.test.mm)[ncol(salary.x.test.mm)] <- col
  }
}

# Usuwamy ewentualne dodatkowe kolumny z testu i ustawiamy kolejność jak w train
salary.x.test.mm <- salary.x.test.mm[, train_cols]

# Kontrola wymiarów przed i po kodowaniu
dim(salary.x.train)
## [1] 21114    17
dim(salary.x.train.mm)
## [1] 21114    99
dim(salary.x.test)
## [1] 9048   17
dim(salary.x.test.mm)
## [1] 9048   99

Skalowanie danych

Kolejnym krokiem było przeskalowanie danych wejściowych do zakresu 0–1. Parametry skalowania, czyli minimum i maksimum, wyznaczono wyłącznie na zbiorze treningowym. Jest to bardzo ważne, ponieważ zbiór testowy powinien symulować nowe dane, których model nie znał podczas uczenia.

# Minimum i maksimum liczymy wyłącznie na zbiorze treningowym
salary.mins <- apply(salary.x.train.mm, 2, min)
salary.maxs <- apply(salary.x.train.mm, 2, max)

# Zabezpieczenie przed dzieleniem przez zero
salary.ranges <- salary.maxs - salary.mins
salary.ranges[salary.ranges == 0] <- 1

# Skalowanie zbioru treningowego
salary.x.train.scaled <- scale(
  salary.x.train.mm,
  center = salary.mins,
  scale = salary.ranges
)

# Skalowanie zbioru testowego tymi samymi parametrami
salary.x.test.scaled <- scale(
  salary.x.test.mm,
  center = salary.mins,
  scale = salary.ranges
)

summary(salary.x.train.scaled)
##       wiek        stan_cywilnymałżeństwo_z_osobą_z_sił_zbrojnych
##  Min.   :0.0000   Min.   :0.0000000                             
##  1st Qu.:0.1507   1st Qu.:0.0000000                             
##  Median :0.2740   Median :0.0000000                             
##  Mean   :0.2948   Mean   :0.0006631                             
##  3rd Qu.:0.4110   3rd Qu.:0.0000000                             
##  Max.   :1.0000   Max.   :1.0000000                             
##  stan_cywilnymałżonek_nieobecny stan_cywilnynigdy_niezamężny_niezonaty
##  Min.   :0.00000                Min.   :0.000                         
##  1st Qu.:0.00000                1st Qu.:0.000                         
##  Median :0.00000                Median :0.000                         
##  Mean   :0.01246                Mean   :0.323                         
##  3rd Qu.:0.00000                3rd Qu.:1.000                         
##  Max.   :1.00000                Max.   :1.000                         
##  stan_cywilnyrozwiedziony_rozwiedziona stan_cywilnyw_separacji
##  Min.   :0.0000                        Min.   :0.0000         
##  1st Qu.:0.0000                        1st Qu.:0.0000         
##  Median :0.0000                        Median :0.0000         
##  Mean   :0.1377                        Mean   :0.0306         
##  3rd Qu.:0.0000                        3rd Qu.:0.0000         
##  Max.   :1.0000                        Max.   :1.0000         
##  stan_cywilnywdowiec_wdowa zawodinne_usługi zawodkadra_kierownicza
##  Min.   :0.00000           Min.   :0.0000   Min.   :0.0000        
##  1st Qu.:0.00000           1st Qu.:0.0000   1st Qu.:0.0000        
##  Median :0.00000           Median :0.0000   Median :0.0000        
##  Mean   :0.02804           Mean   :0.1048   Mean   :0.1327        
##  3rd Qu.:0.00000           3rd Qu.:0.0000   3rd Qu.:0.0000        
##  Max.   :1.00000           Max.   :1.0000   Max.   :1.0000        
##  zawodoperatorzy_maszyn_i_inspektorzy zawodpracownicy_fizyczni_i_sprzątający
##  Min.   :0.00000                      Min.   :0.00000                       
##  1st Qu.:0.00000                      1st Qu.:0.00000                       
##  Median :0.00000                      Median :0.00000                       
##  Mean   :0.06635                      Mean   :0.04476                       
##  3rd Qu.:0.00000                      3rd Qu.:0.00000                       
##  Max.   :1.00000                      Max.   :1.00000                       
##  zawodprywatne_usługi_domowe zawodrolnictwo_i_rybołówstwo
##  Min.   :0.000000            Min.   :0.00000             
##  1st Qu.:0.000000            1st Qu.:0.00000             
##  Median :0.000000            Median :0.00000             
##  Mean   :0.004547            Mean   :0.03254             
##  3rd Qu.:0.000000            3rd Qu.:0.00000             
##  Max.   :1.000000            Max.   :1.00000             
##  zawodrzemiosło_i_naprawy zawodsiły_zbrojne   zawodspecjaliści zawodsprzedaż   
##  Min.   :0.0000           Min.   :0.0000000   Min.   :0.0000   Min.   :0.0000  
##  1st Qu.:0.0000           1st Qu.:0.0000000   1st Qu.:0.0000   1st Qu.:0.0000  
##  Median :0.0000           Median :0.0000000   Median :0.0000   Median :0.0000  
##  Mean   :0.1353           Mean   :0.0003315   Mean   :0.1332   Mean   :0.1171  
##  3rd Qu.:0.0000           3rd Qu.:0.0000000   3rd Qu.:0.0000   3rd Qu.:0.0000  
##  Max.   :1.0000           Max.   :1.0000000   Max.   :1.0000   Max.   :1.0000  
##  zawodtransport_i_przemieszczanie zawodusługi_ochrony zawodwsparcie_techniczne
##  Min.   :0.00000                  Min.   :0.00000     Min.   :0.00000         
##  1st Qu.:0.00000                  1st Qu.:0.00000     1st Qu.:0.00000         
##  Median :0.00000                  Median :0.00000     Median :0.00000         
##  Mean   :0.05172                  Mean   :0.02169     Mean   :0.02951         
##  3rd Qu.:0.00000                  3rd Qu.:0.00000     3rd Qu.:0.00000         
##  Max.   :1.00000                  Max.   :1.00000     Max.   :1.00000         
##  relacja_w_rodziniemąż relacja_w_rodzinieniezamężny_niezonaty
##  Min.   :0.0000        Min.   :0.0000                        
##  1st Qu.:0.0000        1st Qu.:0.0000                        
##  Median :0.0000        Median :0.0000                        
##  Mean   :0.4142        Mean   :0.1063                        
##  3rd Qu.:1.0000        3rd Qu.:0.0000                        
##  Max.   :1.0000        Max.   :1.0000                        
##  relacja_w_rodziniepoza_rodziną relacja_w_rodziniewłasne_dziecko
##  Min.   :0.0000                 Min.   :0.0000                  
##  1st Qu.:0.0000                 1st Qu.:0.0000                  
##  Median :0.0000                 Median :0.0000                  
##  Mean   :0.2556                 Mean   :0.1461                  
##  3rd Qu.:1.0000                 3rd Qu.:0.0000                  
##  Max.   :1.0000                 Max.   :1.0000                  
##  relacja_w_rodzinieżona   rasabiała        rasaczarna         rasainna       
##  Min.   :0.00000        Min.   :0.0000   Min.   :0.00000   Min.   :0.000000  
##  1st Qu.:0.00000        1st Qu.:1.0000   1st Qu.:0.00000   1st Qu.:0.000000  
##  Median :0.00000        Median :1.0000   Median :0.00000   Median :0.000000  
##  Mean   :0.04741        Mean   :0.8622   Mean   :0.09297   Mean   :0.006962  
##  3rd Qu.:0.00000        3rd Qu.:1.0000   3rd Qu.:0.00000   3rd Qu.:0.000000  
##  Max.   :1.00000        Max.   :1.0000   Max.   :1.00000   Max.   :1.000000  
##  rasardzenni_amerykanie_lub_eskimosi plecmężczyzna    godziny_pracy_tygodniowo
##  Min.   :0.000000                    Min.   :0.0000   Min.   :0.0000          
##  1st Qu.:0.000000                    1st Qu.:0.0000   1st Qu.:0.3980          
##  Median :0.000000                    Median :1.0000   Median :0.3980          
##  Mean   :0.009472                    Mean   :0.6773   Mean   :0.4082          
##  3rd Qu.:0.000000                    3rd Qu.:1.0000   3rd Qu.:0.4490          
##  Max.   :1.000000                    Max.   :1.0000   Max.   :1.0000          
##  zysk_kapitalowy   strata_kapitalowa typ_zatrudnieniaadministracja_lokalna
##  Min.   :0.00000   Min.   :0.0000    Min.   :0.00000                      
##  1st Qu.:0.00000   1st Qu.:0.0000    1st Qu.:0.00000                      
##  Median :0.00000   Median :0.0000    Median :0.00000                      
##  Mean   :0.01069   Mean   :0.0206    Mean   :0.06801                      
##  3rd Qu.:0.00000   3rd Qu.:0.0000    3rd Qu.:0.00000                      
##  Max.   :1.00000   Max.   :1.0000    Max.   :1.00000                      
##  typ_zatrudnieniaadministracja_stanowa typ_zatrudnieniabez_wynagrodzenia
##  Min.   :0.00000                       Min.   :0.0000000                
##  1st Qu.:0.00000                       1st Qu.:0.0000000                
##  Median :0.00000                       Median :0.0000000                
##  Mean   :0.04244                       Mean   :0.0004263                
##  3rd Qu.:0.00000                       3rd Qu.:0.0000000                
##  Max.   :1.00000                       Max.   :1.0000000                
##  typ_zatrudnieniasamozatrudniony_bez_osobowości_prawnej
##  Min.   :0.0000                                        
##  1st Qu.:0.0000                                        
##  Median :0.0000                                        
##  Mean   :0.0853                                        
##  3rd Qu.:0.0000                                        
##  Max.   :1.0000                                        
##  typ_zatrudnieniasamozatrudniony_z_osobowością_prawną
##  Min.   :0.00000                                     
##  1st Qu.:0.00000                                     
##  Median :0.00000                                     
##  Mean   :0.03538                                     
##  3rd Qu.:0.00000                                     
##  Max.   :1.00000                                     
##  typ_zatrudnieniasektor_prywatny kraj_pochodzeniaChiny
##  Min.   :0.0000                  Min.   :0.000000     
##  1st Qu.:0.0000                  1st Qu.:0.000000     
##  Median :1.0000                  Median :0.000000     
##  Mean   :0.7377                  Mean   :0.002415     
##  3rd Qu.:1.0000                  3rd Qu.:0.000000     
##  Max.   :1.0000                  Max.   :1.000000     
##  kraj_pochodzeniaDominikana kraj_pochodzeniaEkwador kraj_pochodzeniaFilipiny
##  Min.   :0.0000             Min.   :0.0000000       Min.   :0.000000        
##  1st Qu.:0.0000             1st Qu.:0.0000000       1st Qu.:0.000000        
##  Median :0.0000             Median :0.0000000       Median :0.000000        
##  Mean   :0.0018             Mean   :0.0009472       Mean   :0.005257        
##  3rd Qu.:0.0000             3rd Qu.:0.0000000       3rd Qu.:0.000000        
##  Max.   :1.0000             Max.   :1.0000000       Max.   :1.000000        
##  kraj_pochodzeniaFrancja kraj_pochodzeniaGrecja kraj_pochodzeniaGwatemala
##  Min.   :0.0000000       Min.   :0.000000       Min.   :0.000000         
##  1st Qu.:0.0000000       1st Qu.:0.000000       1st Qu.:0.000000         
##  Median :0.0000000       Median :0.000000       Median :0.000000         
##  Mean   :0.0009472       Mean   :0.001042       Mean   :0.002273         
##  3rd Qu.:0.0000000       3rd Qu.:0.000000       3rd Qu.:0.000000         
##  Max.   :1.0000000       Max.   :1.000000       Max.   :1.000000         
##  kraj_pochodzeniaHaiti kraj_pochodzeniaHolandia kraj_pochodzeniaHonduras
##  Min.   :0.000000      Min.   :0.00000000       Min.   :0.0000000       
##  1st Qu.:0.000000      1st Qu.:0.00000000       1st Qu.:0.0000000       
##  Median :0.000000      Median :0.00000000       Median :0.0000000       
##  Mean   :0.001373      Mean   :0.00004736       Mean   :0.0003789       
##  3rd Qu.:0.000000      3rd Qu.:0.00000000       3rd Qu.:0.0000000       
##  Max.   :1.000000      Max.   :1.00000000       Max.   :1.0000000       
##  kraj_pochodzeniaHongkong kraj_pochodzeniaIndie kraj_pochodzeniaIran
##  Min.   :0.0000000        Min.   :0.000000      Min.   :0.000000    
##  1st Qu.:0.0000000        1st Qu.:0.000000      1st Qu.:0.000000    
##  Median :0.0000000        Median :0.000000      Median :0.000000    
##  Mean   :0.0006631        Mean   :0.003505      Mean   :0.001373    
##  3rd Qu.:0.0000000        3rd Qu.:0.000000      3rd Qu.:0.000000    
##  Max.   :1.0000000        Max.   :1.000000      Max.   :1.000000    
##  kraj_pochodzeniaIrlandia kraj_pochodzeniaJamajka kraj_pochodzeniaJaponia
##  Min.   :0.0000000        Min.   :0.000000        Min.   :0.000000       
##  1st Qu.:0.0000000        1st Qu.:0.000000        1st Qu.:0.000000       
##  Median :0.0000000        Median :0.000000        Median :0.000000       
##  Mean   :0.0009472        Mean   :0.002889        Mean   :0.002037       
##  3rd Qu.:0.0000000        3rd Qu.:0.000000        3rd Qu.:0.000000       
##  Max.   :1.0000000        Max.   :1.000000        Max.   :1.000000       
##  kraj_pochodzeniaJugosławia kraj_pochodzeniaKambodża kraj_pochodzeniaKanada
##  Min.   :0.000000           Min.   :0.000000         Min.   :0.00000       
##  1st Qu.:0.000000           1st Qu.:0.000000         1st Qu.:0.00000       
##  Median :0.000000           Median :0.000000         Median :0.00000       
##  Mean   :0.000521           Mean   :0.000521         Mean   :0.00341       
##  3rd Qu.:0.000000           3rd Qu.:0.000000         3rd Qu.:0.00000       
##  Max.   :1.000000           Max.   :1.000000         Max.   :1.00000       
##  kraj_pochodzeniaKolumbia kraj_pochodzeniaKorea_Południowa kraj_pochodzeniaKuba
##  Min.   :0.000000         Min.   :0.000000                 Min.   :0.000000    
##  1st Qu.:0.000000         1st Qu.:0.000000                 1st Qu.:0.000000    
##  Median :0.000000         Median :0.000000                 Median :0.000000    
##  Mean   :0.001942         Mean   :0.002605                 Mean   :0.003505    
##  3rd Qu.:0.000000         3rd Qu.:0.000000                 3rd Qu.:0.000000    
##  Max.   :1.000000         Max.   :1.000000                 Max.   :1.000000    
##  kraj_pochodzeniaLaos kraj_pochodzeniaMeksyk kraj_pochodzeniaNicaragua
##  Min.   :0.0000000    Min.   :0.00000        Min.   :0.000000         
##  1st Qu.:0.0000000    1st Qu.:0.00000        1st Qu.:0.000000         
##  Median :0.0000000    Median :0.00000        Median :0.000000         
##  Mean   :0.0004736    Mean   :0.02122        Mean   :0.001089         
##  3rd Qu.:0.0000000    3rd Qu.:0.00000        3rd Qu.:0.000000         
##  Max.   :1.0000000    Max.   :1.00000        Max.   :1.000000         
##  kraj_pochodzeniaNiemcy kraj_pochodzeniaPeru kraj_pochodzeniaPolska
##  Min.   :0.000000       Min.   :0.000000     Min.   :0.000000      
##  1st Qu.:0.000000       1st Qu.:0.000000     1st Qu.:0.000000      
##  Median :0.000000       Median :0.000000     Median :0.000000      
##  Mean   :0.004784       Mean   :0.001042     Mean   :0.001658      
##  3rd Qu.:0.000000       3rd Qu.:0.000000     3rd Qu.:0.000000      
##  Max.   :1.000000       Max.   :1.000000     Max.   :1.000000      
##  kraj_pochodzeniaPortoryko kraj_pochodzeniaPortugalia kraj_pochodzeniaSalwador
##  Min.   :0.000000          Min.   :0.000000           Min.   :0.000000        
##  1st Qu.:0.000000          1st Qu.:0.000000           1st Qu.:0.000000        
##  Median :0.000000          Median :0.000000           Median :0.000000        
##  Mean   :0.003505          Mean   :0.001184           Mean   :0.003363        
##  3rd Qu.:0.000000          3rd Qu.:0.000000           3rd Qu.:0.000000        
##  Max.   :1.000000          Max.   :1.000000           Max.   :1.000000        
##  kraj_pochodzeniaStany_Zjednoczone kraj_pochodzeniaSzkocja
##  Min.   :0.0000                    Min.   :0.0000000      
##  1st Qu.:1.0000                    1st Qu.:0.0000000      
##  Median :1.0000                    Median :0.0000000      
##  Mean   :0.9098                    Mean   :0.0004263      
##  3rd Qu.:1.0000                    3rd Qu.:0.0000000      
##  Max.   :1.0000                    Max.   :1.0000000      
##  kraj_pochodzeniaTajlandia kraj_pochodzeniaTajwan
##  Min.   :0.000000          Min.   :0.000000      
##  1st Qu.:0.000000          1st Qu.:0.000000      
##  Median :0.000000          Median :0.000000      
##  Mean   :0.000521          Mean   :0.001373      
##  3rd Qu.:0.000000          3rd Qu.:0.000000      
##  Max.   :1.000000          Max.   :1.000000      
##  kraj_pochodzeniaTerytoria_Zależne_USA kraj_pochodzeniaTrynidad_i_Tobago
##  Min.   :0.0000000                     Min.   :0.0000000                
##  1st Qu.:0.0000000                     1st Qu.:0.0000000                
##  Median :0.0000000                     Median :0.0000000                
##  Mean   :0.0004263                     Mean   :0.0004736                
##  3rd Qu.:0.0000000                     3rd Qu.:0.0000000                
##  Max.   :1.0000000                     Max.   :1.0000000                
##  kraj_pochodzeniaWęgry kraj_pochodzeniaWietnam kraj_pochodzeniaWłochy
##  Min.   :0.000000      Min.   :0.000000        Min.   :0.000000      
##  1st Qu.:0.000000      1st Qu.:0.000000        1st Qu.:0.000000      
##  Median :0.000000      Median :0.000000        Median :0.000000      
##  Mean   :0.000521      Mean   :0.001989        Mean   :0.002415      
##  3rd Qu.:0.000000      3rd Qu.:0.000000        3rd Qu.:0.000000      
##  Max.   :1.000000      Max.   :1.000000        Max.   :1.000000      
##  kapital_netto     czy_kapital_netto_dodatnikapital_netto_niedodatni
##  Min.   :0.00000   Min.   :0.0000                                   
##  1st Qu.:0.04174   1st Qu.:1.0000                                   
##  Median :0.04174   Median :1.0000                                   
##  Mean   :0.05112   Mean   :0.9156                                   
##  3rd Qu.:0.04174   3rd Qu.:1.0000                                   
##  Max.   :1.00000   Max.   :1.0000                                   
##  godziny_razy_edukacja wiek_sredniw_grupie_wieku_średniego
##  Min.   :0.0000        Min.   :0.000                      
##  1st Qu.:0.2180        1st Qu.:0.000                      
##  Median :0.2497        Median :1.000                      
##  Mean   :0.2620        Mean   :0.594                      
##  3rd Qu.:0.3257        3rd Qu.:1.000                      
##  Max.   :1.0000        Max.   :1.000                      
##  dlugie_godziny_pracytak poziom_edukacji_factorklasy_1_4
##  Min.   :0.0000          Min.   :0.000000               
##  1st Qu.:0.0000          1st Qu.:0.000000               
##  Median :0.0000          Median :0.000000               
##  Mean   :0.3054          Mean   :0.004973               
##  3rd Qu.:1.0000          3rd Qu.:0.000000               
##  Max.   :1.0000          Max.   :1.000000               
##  poziom_edukacji_factorklasy_5_6 poziom_edukacji_factorklasy_7_8
##  Min.   :0.000000                Min.   :0.00000                
##  1st Qu.:0.000000                1st Qu.:0.00000                
##  Median :0.000000                Median :0.00000                
##  Mean   :0.009804                Mean   :0.01894                
##  3rd Qu.:0.000000                3rd Qu.:0.00000                
##  Max.   :1.000000                Max.   :1.00000                
##  poziom_edukacji_factorklasa_9 poziom_edukacji_factorklasa_10
##  Min.   :0.00000               Min.   :0.00000               
##  1st Qu.:0.00000               1st Qu.:0.00000               
##  Median :0.00000               Median :0.00000               
##  Mean   :0.01501               Mean   :0.02633               
##  3rd Qu.:0.00000               3rd Qu.:0.00000               
##  Max.   :1.00000               Max.   :1.00000               
##  poziom_edukacji_factorklasa_11 poziom_edukacji_factorklasa_12
##  Min.   :0.00000                Min.   :0.00000               
##  1st Qu.:0.00000                1st Qu.:0.00000               
##  Median :0.00000                Median :0.00000               
##  Mean   :0.03429                Mean   :0.01302               
##  3rd Qu.:0.00000                3rd Qu.:0.00000               
##  Max.   :1.00000                Max.   :1.00000               
##  poziom_edukacji_factorszkoła_średnia poziom_edukacji_factorczęść_studiów
##  Min.   :0.0000                       Min.   :0.0000                     
##  1st Qu.:0.0000                       1st Qu.:0.0000                     
##  Median :0.0000                       Median :0.0000                     
##  Mean   :0.3283                       Mean   :0.2196                     
##  3rd Qu.:1.0000                       3rd Qu.:0.0000                     
##  Max.   :1.0000                       Max.   :1.0000                     
##  poziom_edukacji_factorassociate_zawodowe
##  Min.   :0.00000                         
##  1st Qu.:0.00000                         
##  Median :0.00000                         
##  Mean   :0.04424                         
##  3rd Qu.:0.00000                         
##  Max.   :1.00000                         
##  poziom_edukacji_factorassociate_akademickie poziom_edukacji_factorlicencjat
##  Min.   :0.00000                             Min.   :0.0000                 
##  1st Qu.:0.00000                             1st Qu.:0.0000                 
##  Median :0.00000                             Median :0.0000                 
##  Mean   :0.03396                             Mean   :0.1653                 
##  3rd Qu.:0.00000                             3rd Qu.:0.0000                 
##  Max.   :1.00000                             Max.   :1.0000                 
##  poziom_edukacji_factormagister poziom_edukacji_factorszkoła_profesjonalna
##  Min.   :0.00000                Min.   :0.00000                           
##  1st Qu.:0.00000                1st Qu.:0.00000                           
##  Median :0.00000                Median :0.00000                           
##  Mean   :0.05314                Mean   :0.01894                           
##  3rd Qu.:0.00000                3rd Qu.:0.00000                           
##  Max.   :1.00000                Max.   :1.00000                           
##  poziom_edukacji_factordoktorat
##  Min.   :0.00000               
##  1st Qu.:0.00000               
##  Median :0.00000               
##  Mean   :0.01265               
##  3rd Qu.:0.00000               
##  Max.   :1.00000

Przygotowanie zmiennej celu

Pakiet neuralnet wymaga, aby zmienna celu w klasyfikacji binarnej była zapisana numerycznie jako 0/1. W tym przypadku klasa tak, czyli osoby zarabiające powyżej 50 tys. USD, została oznaczona jako 1, a klasa nie jako 0.

salary.y.train.nn <- ifelse(salary.y.train == "tak", 1, 0)
salary.y.test.nn  <- ifelse(salary.y.test == "tak", 1, 0)

table(salary.y.train.nn)
## salary.y.train.nn
##     0     1 
## 15858  5256
table(salary.y.test.nn)
## salary.y.test.nn
##    0    1 
## 6796 2252

Następnie przygotowano końcowe zbiory danych dla sieci neuronowych. Nazwy kolumn zostały dodatkowo oczyszczone funkcją make.names(), ponieważ pakiet neuralnet może mieć problem z niektórymi znakami specjalnymi w nazwach zmiennych.

salary.train.nn <- as.data.frame(salary.x.train.scaled)
salary.train.nn$y <- salary.y.train.nn

salary.test.nn <- as.data.frame(salary.x.test.scaled)
salary.test.nn$y <- salary.y.test.nn

names(salary.train.nn) <- make.names(names(salary.train.nn), unique = TRUE)
names(salary.test.nn)  <- make.names(names(salary.test.nn), unique = TRUE)

glimpse(salary.train.nn)
## Rows: 21,114
## Columns: 100
## $ wiek                                                   <dbl> 0.30136986, 0.4…
## $ stan_cywilnymałżeństwo_z_osobą_z_sił_zbrojnych         <dbl> 0, 0, 0, 0, 0, …
## $ stan_cywilnymałżonek_nieobecny                         <dbl> 0, 0, 0, 0, 0, …
## $ stan_cywilnynigdy_niezamężny_niezonaty                 <dbl> 1, 0, 0, 0, 0, …
## $ stan_cywilnyrozwiedziony_rozwiedziona                  <dbl> 0, 0, 0, 0, 0, …
## $ stan_cywilnyw_separacji                                <dbl> 0, 0, 0, 0, 0, …
## $ stan_cywilnywdowiec_wdowa                              <dbl> 0, 0, 0, 0, 0, …
## $ zawodinne_usługi                                       <dbl> 0, 0, 0, 0, 0, …
## $ zawodkadra_kierownicza                                 <dbl> 0, 1, 0, 0, 1, …
## $ zawodoperatorzy_maszyn_i_inspektorzy                   <dbl> 0, 0, 0, 0, 0, …
## $ zawodpracownicy_fizyczni_i_sprzątający                 <dbl> 0, 0, 1, 0, 0, …
## $ zawodprywatne_usługi_domowe                            <dbl> 0, 0, 0, 0, 0, …
## $ zawodrolnictwo_i_rybołówstwo                           <dbl> 0, 0, 0, 0, 0, …
## $ zawodrzemiosło_i_naprawy                               <dbl> 0, 0, 0, 0, 0, …
## $ zawodsiły_zbrojne                                      <dbl> 0, 0, 0, 0, 0, …
## $ zawodspecjaliści                                       <dbl> 0, 0, 0, 1, 0, …
## $ zawodsprzedaż                                          <dbl> 0, 0, 0, 0, 0, …
## $ zawodtransport_i_przemieszczanie                       <dbl> 0, 0, 0, 0, 0, …
## $ zawodusługi_ochrony                                    <dbl> 0, 0, 0, 0, 0, …
## $ zawodwsparcie_techniczne                               <dbl> 0, 0, 0, 0, 0, …
## $ relacja_w_rodziniemąż                                  <dbl> 0, 1, 1, 0, 0, …
## $ relacja_w_rodzinieniezamężny_niezonaty                 <dbl> 0, 0, 0, 0, 0, …
## $ relacja_w_rodziniepoza_rodziną                         <dbl> 1, 0, 0, 0, 0, …
## $ relacja_w_rodziniewłasne_dziecko                       <dbl> 0, 0, 0, 0, 0, …
## $ relacja_w_rodzinieżona                                 <dbl> 0, 0, 0, 1, 1, …
## $ rasabiała                                              <dbl> 1, 1, 0, 0, 1, …
## $ rasaczarna                                             <dbl> 0, 0, 1, 1, 0, …
## $ rasainna                                               <dbl> 0, 0, 0, 0, 0, …
## $ rasardzenni_amerykanie_lub_eskimosi                    <dbl> 0, 0, 0, 0, 0, …
## $ plecmężczyzna                                          <dbl> 1, 1, 1, 0, 0, …
## $ godziny_pracy_tygodniowo                               <dbl> 0.3979592, 0.12…
## $ zysk_kapitalowy                                        <dbl> 0.02174022, 0.0…
## $ strata_kapitalowa                                      <dbl> 0.0000000, 0.00…
## $ typ_zatrudnieniaadministracja_lokalna                  <dbl> 0, 0, 0, 0, 0, …
## $ typ_zatrudnieniaadministracja_stanowa                  <dbl> 1, 0, 0, 0, 0, …
## $ typ_zatrudnieniabez_wynagrodzenia                      <dbl> 0, 0, 0, 0, 0, …
## $ typ_zatrudnieniasamozatrudniony_bez_osobowości_prawnej <dbl> 0, 1, 0, 0, 0, …
## $ typ_zatrudnieniasamozatrudniony_z_osobowością_prawną   <dbl> 0, 0, 0, 0, 0, …
## $ typ_zatrudnieniasektor_prywatny                        <dbl> 0, 0, 1, 1, 1, …
## $ kraj_pochodzeniaChiny                                  <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaDominikana                             <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaEkwador                                <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaFilipiny                               <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaFrancja                                <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaGrecja                                 <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaGwatemala                              <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaHaiti                                  <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaHolandia                               <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaHonduras                               <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaHongkong                               <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaIndie                                  <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaIran                                   <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaIrlandia                               <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaJamajka                                <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaJaponia                                <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaJugosławia                             <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaKambodża                               <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaKanada                                 <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaKolumbia                               <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaKorea_Południowa                       <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaKuba                                   <dbl> 0, 0, 0, 1, 0, …
## $ kraj_pochodzeniaLaos                                   <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaMeksyk                                 <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaNicaragua                              <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaNiemcy                                 <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaPeru                                   <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaPolska                                 <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaPortoryko                              <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaPortugalia                             <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaSalwador                               <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaStany_Zjednoczone                      <dbl> 1, 1, 1, 0, 1, …
## $ kraj_pochodzeniaSzkocja                                <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaTajlandia                              <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaTajwan                                 <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaTerytoria_Zależne_USA                  <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaTrynidad_i_Tobago                      <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaWęgry                                  <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaWietnam                                <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaWłochy                                 <dbl> 0, 0, 0, 0, 0, …
## $ kapital_netto                                          <dbl> 0.06257486, 0.0…
## $ czy_kapital_netto_dodatnikapital_netto_niedodatni      <dbl> 0, 1, 1, 1, 1, …
## $ godziny_razy_edukacja                                  <dbl> 0.32572877, 0.1…
## $ wiek_sredniw_grupie_wieku_średniego                    <dbl> 1, 1, 1, 0, 1, …
## $ dlugie_godziny_pracytak                                <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorklasy_1_4                        <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorklasy_5_6                        <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorklasy_7_8                        <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorklasa_9                          <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorklasa_10                         <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorklasa_11                         <dbl> 0, 0, 1, 0, 0, …
## $ poziom_edukacji_factorklasa_12                         <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorszkoła_średnia                   <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorczęść_studiów                    <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorassociate_zawodowe               <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorassociate_akademickie            <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorlicencjat                        <dbl> 1, 1, 0, 1, 0, …
## $ poziom_edukacji_factormagister                         <dbl> 0, 0, 0, 0, 1, …
## $ poziom_edukacji_factorszkoła_profesjonalna             <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factordoktorat                         <dbl> 0, 0, 0, 0, 0, …
## $ y                                                      <dbl> 0, 0, 0, 0, 0, …
glimpse(salary.test.nn)
## Rows: 9,048
## Columns: 100
## $ wiek                                                   <dbl> 0.28767123, 0.4…
## $ stan_cywilnymałżeństwo_z_osobą_z_sił_zbrojnych         <dbl> 0, 0, 0, 0, 0, …
## $ stan_cywilnymałżonek_nieobecny                         <dbl> 0, 0, 0, 0, 0, …
## $ stan_cywilnynigdy_niezamężny_niezonaty                 <dbl> 0, 0, 1, 1, 0, …
## $ stan_cywilnyrozwiedziony_rozwiedziona                  <dbl> 1, 0, 0, 0, 1, …
## $ stan_cywilnyw_separacji                                <dbl> 0, 0, 0, 0, 0, …
## $ stan_cywilnywdowiec_wdowa                              <dbl> 0, 0, 0, 0, 0, …
## $ zawodinne_usługi                                       <dbl> 0, 0, 0, 0, 0, …
## $ zawodkadra_kierownicza                                 <dbl> 0, 1, 0, 0, 1, …
## $ zawodoperatorzy_maszyn_i_inspektorzy                   <dbl> 0, 0, 0, 1, 0, …
## $ zawodpracownicy_fizyczni_i_sprzątający                 <dbl> 1, 0, 0, 0, 0, …
## $ zawodprywatne_usługi_domowe                            <dbl> 0, 0, 0, 0, 0, …
## $ zawodrolnictwo_i_rybołówstwo                           <dbl> 0, 0, 0, 0, 0, …
## $ zawodrzemiosło_i_naprawy                               <dbl> 0, 0, 0, 0, 0, …
## $ zawodsiły_zbrojne                                      <dbl> 0, 0, 0, 0, 0, …
## $ zawodspecjaliści                                       <dbl> 0, 0, 0, 0, 0, …
## $ zawodsprzedaż                                          <dbl> 0, 0, 0, 0, 0, …
## $ zawodtransport_i_przemieszczanie                       <dbl> 0, 0, 0, 0, 0, …
## $ zawodusługi_ochrony                                    <dbl> 0, 0, 0, 0, 0, …
## $ zawodwsparcie_techniczne                               <dbl> 0, 0, 0, 0, 0, …
## $ relacja_w_rodziniemąż                                  <dbl> 0, 1, 0, 0, 0, …
## $ relacja_w_rodzinieniezamężny_niezonaty                 <dbl> 0, 0, 0, 1, 1, …
## $ relacja_w_rodziniepoza_rodziną                         <dbl> 1, 0, 0, 0, 0, …
## $ relacja_w_rodziniewłasne_dziecko                       <dbl> 0, 0, 1, 0, 0, …
## $ relacja_w_rodzinieżona                                 <dbl> 0, 0, 0, 0, 0, …
## $ rasabiała                                              <dbl> 1, 1, 1, 1, 1, …
## $ rasaczarna                                             <dbl> 0, 0, 0, 0, 0, …
## $ rasainna                                               <dbl> 0, 0, 0, 0, 0, …
## $ rasardzenni_amerykanie_lub_eskimosi                    <dbl> 0, 0, 0, 0, 0, …
## $ plecmężczyzna                                          <dbl> 1, 1, 0, 1, 0, …
## $ godziny_pracy_tygodniowo                               <dbl> 0.3979592, 0.44…
## $ zysk_kapitalowy                                        <dbl> 0.0000000, 0.00…
## $ strata_kapitalowa                                      <dbl> 0, 0, 0, 0, 0, …
## $ typ_zatrudnieniaadministracja_lokalna                  <dbl> 0, 0, 0, 0, 0, …
## $ typ_zatrudnieniaadministracja_stanowa                  <dbl> 0, 0, 0, 0, 0, …
## $ typ_zatrudnieniabez_wynagrodzenia                      <dbl> 0, 0, 0, 0, 0, …
## $ typ_zatrudnieniasamozatrudniony_bez_osobowości_prawnej <dbl> 0, 1, 0, 0, 1, …
## $ typ_zatrudnieniasamozatrudniony_z_osobowością_prawną   <dbl> 0, 0, 0, 0, 0, …
## $ typ_zatrudnieniasektor_prywatny                        <dbl> 1, 0, 1, 1, 0, …
## $ kraj_pochodzeniaChiny                                  <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaDominikana                             <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaEkwador                                <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaFilipiny                               <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaFrancja                                <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaGrecja                                 <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaGwatemala                              <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaHaiti                                  <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaHolandia                               <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaHonduras                               <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaHongkong                               <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaIndie                                  <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaIran                                   <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaIrlandia                               <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaJamajka                                <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaJaponia                                <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaJugosławia                             <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaKambodża                               <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaKanada                                 <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaKolumbia                               <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaKorea_Południowa                       <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaKuba                                   <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaLaos                                   <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaMeksyk                                 <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaNicaragua                              <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaNiemcy                                 <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaPeru                                   <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaPolska                                 <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaPortoryko                              <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaPortugalia                             <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaSalwador                               <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaStany_Zjednoczone                      <dbl> 1, 1, 1, 1, 1, …
## $ kraj_pochodzeniaSzkocja                                <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaTajlandia                              <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaTajwan                                 <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaTerytoria_Zależne_USA                  <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaTrynidad_i_Tobago                      <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaWęgry                                  <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaWietnam                                <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaWłochy                                 <dbl> 0, 0, 0, 0, 0, …
## $ kapital_netto                                          <dbl> 0.04174213, 0.0…
## $ czy_kapital_netto_dodatnikapital_netto_niedodatni      <dbl> 1, 1, 1, 1, 1, …
## $ godziny_razy_edukacja                                  <dbl> 0.2243346, 0.25…
## $ wiek_sredniw_grupie_wieku_średniego                    <dbl> 1, 1, 0, 1, 1, …
## $ dlugie_godziny_pracytak                                <dbl> 0, 1, 0, 0, 1, …
## $ poziom_edukacji_factorklasy_1_4                        <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorklasy_5_6                        <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorklasy_7_8                        <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorklasa_9                          <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorklasa_10                         <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorklasa_11                         <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorklasa_12                         <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorszkoła_średnia                   <dbl> 1, 1, 0, 1, 0, …
## $ poziom_edukacji_factorczęść_studiów                    <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorassociate_zawodowe               <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorassociate_akademickie            <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorlicencjat                        <dbl> 0, 0, 1, 0, 0, …
## $ poziom_edukacji_factormagister                         <dbl> 0, 0, 0, 0, 1, …
## $ poziom_edukacji_factorszkoła_profesjonalna             <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factordoktorat                         <dbl> 0, 0, 0, 0, 0, …
## $ y                                                      <dbl> 0, 1, 0, 0, 1, …
# Opcjonalny zapis przygotowanych danych
save(
  salary.train.nn,
  salary.test.nn,
  file = "02_dane_w_SML-20260311/salary_train_test.RData"
)

Zmniejszenie zbioru treningowego

Ze względu na czasochłonność uczenia sieci neuronowych w pakiecie neuralnet, modele zostały wytrenowane na losowym, stratyfikowanym podzbiorze zbioru treningowego liczącym maksymalnie 5000 obserwacji. Należy jednak podkreślić, że końcowa ocena modeli została przeprowadzona na pełnym zbiorze testowym, dzięki czemu możliwa jest ocena zdolności generalizacji na danych niewykorzystanych podczas uczenia.

set.seed(123)

idx_salary_nn <- createDataPartition(
  salary.train.nn$y,
  p = min(5000 / nrow(salary.train.nn), 1),
  list = FALSE
)

salary.train.small.nn <- salary.train.nn[idx_salary_nn, ]

dim(salary.train.small.nn)
## [1] 5000  100
table(salary.train.small.nn$y)
## 
##    0    1 
## 3749 1251
prop.table(table(salary.train.small.nn$y))
## 
##      0      1 
## 0.7498 0.2502

Dobór architektury sieci

Architektura sieci neuronowej, czyli liczba warstw ukrytych i liczba neuronów w każdej z nich, jest hiperparametrem. Większa liczba neuronów zwiększa elastyczność modelu, ale jednocześnie zwiększa liczbę wag do oszacowania, czas uczenia oraz ryzyko przeuczenia.

W celu porównania zastosowano trzy proste architektury:

  • NN_3 — jedna warstwa ukryta z 3 neuronami,
  • NN_5 — jedna warstwa ukryta z 5 neuronami,
  • NN_3_2 — dwie warstwy ukryte: pierwsza z 3 neuronami, druga z 2 neuronami.

Ponadto co ważne, w zastosowano logistyczną funkcję aktywacji, czyli funkcję sigmoidalną. Jest ona odpowiednia dla problemu klasyfikacji binarnej, ponieważ przekształca wynik modelu do przedziału od 0 do 1. Dzięki temu predykcję sieci można interpretować jako prawdopodobieństwo przynależności obserwacji do klasy pozytywnej, czyli osoby zarabiającej powyżej 50 tys. USD.

count_weights <- function(n_inputs, hidden, n_outputs = 1) {
  layers <- c(n_inputs, hidden, n_outputs)
  total <- 0
  
  for (i in 1:(length(layers) - 1)) {
    total <- total + (layers[i] + 1) * layers[i + 1]
  }
  
  return(total)
}

salary.predictors <- setdiff(names(salary.train.small.nn), "y")

n_inputs_salary <- length(salary.predictors)
n_inputs_salary
## [1] 99
salary_architectures <- list(
  NN_3 = c(3),
  NN_5 = c(5),
  NN_3_2 = c(3, 2)
)

for (name in names(salary_architectures)) {
  arch <- salary_architectures[[name]]
  cat(
    "Architektura:", name,
    "| hidden =", paste(arch, collapse = "-"),
    "| liczba wag:", count_weights(n_inputs_salary, arch, n_outputs = 1),
    "\n"
  )
}
## Architektura: NN_3 | hidden = 3 | liczba wag: 304 
## Architektura: NN_5 | hidden = 5 | liczba wag: 506 
## Architektura: NN_3_2 | hidden = 3-2 | liczba wag: 311

Na potrzeby funkcji neuralnet() zbudowano formułę modelu automatycznie na podstawie nazw predyktorów.

salary.formula <- as.formula(
  paste("y ~", paste(salary.predictors, collapse = " + "))
)

salary.formula
## y ~ wiek + stan_cywilnymałżeństwo_z_osobą_z_sił_zbrojnych + 
##     stan_cywilnymałżonek_nieobecny + stan_cywilnynigdy_niezamężny_niezonaty + 
##     stan_cywilnyrozwiedziony_rozwiedziona + stan_cywilnyw_separacji + 
##     stan_cywilnywdowiec_wdowa + zawodinne_usługi + zawodkadra_kierownicza + 
##     zawodoperatorzy_maszyn_i_inspektorzy + zawodpracownicy_fizyczni_i_sprzątający + 
##     zawodprywatne_usługi_domowe + zawodrolnictwo_i_rybołówstwo + 
##     zawodrzemiosło_i_naprawy + zawodsiły_zbrojne + zawodspecjaliści + 
##     zawodsprzedaż + zawodtransport_i_przemieszczanie + zawodusługi_ochrony + 
##     zawodwsparcie_techniczne + relacja_w_rodziniemąż + relacja_w_rodzinieniezamężny_niezonaty + 
##     relacja_w_rodziniepoza_rodziną + relacja_w_rodziniewłasne_dziecko + 
##     relacja_w_rodzinieżona + rasabiała + rasaczarna + rasainna + 
##     rasardzenni_amerykanie_lub_eskimosi + plecmężczyzna + godziny_pracy_tygodniowo + 
##     zysk_kapitalowy + strata_kapitalowa + typ_zatrudnieniaadministracja_lokalna + 
##     typ_zatrudnieniaadministracja_stanowa + typ_zatrudnieniabez_wynagrodzenia + 
##     typ_zatrudnieniasamozatrudniony_bez_osobowości_prawnej + 
##     typ_zatrudnieniasamozatrudniony_z_osobowością_prawną + 
##     typ_zatrudnieniasektor_prywatny + kraj_pochodzeniaChiny + 
##     kraj_pochodzeniaDominikana + kraj_pochodzeniaEkwador + kraj_pochodzeniaFilipiny + 
##     kraj_pochodzeniaFrancja + kraj_pochodzeniaGrecja + kraj_pochodzeniaGwatemala + 
##     kraj_pochodzeniaHaiti + kraj_pochodzeniaHolandia + kraj_pochodzeniaHonduras + 
##     kraj_pochodzeniaHongkong + kraj_pochodzeniaIndie + kraj_pochodzeniaIran + 
##     kraj_pochodzeniaIrlandia + kraj_pochodzeniaJamajka + kraj_pochodzeniaJaponia + 
##     kraj_pochodzeniaJugosławia + kraj_pochodzeniaKambodża + 
##     kraj_pochodzeniaKanada + kraj_pochodzeniaKolumbia + kraj_pochodzeniaKorea_Południowa + 
##     kraj_pochodzeniaKuba + kraj_pochodzeniaLaos + kraj_pochodzeniaMeksyk + 
##     kraj_pochodzeniaNicaragua + kraj_pochodzeniaNiemcy + kraj_pochodzeniaPeru + 
##     kraj_pochodzeniaPolska + kraj_pochodzeniaPortoryko + kraj_pochodzeniaPortugalia + 
##     kraj_pochodzeniaSalwador + kraj_pochodzeniaStany_Zjednoczone + 
##     kraj_pochodzeniaSzkocja + kraj_pochodzeniaTajlandia + kraj_pochodzeniaTajwan + 
##     kraj_pochodzeniaTerytoria_Zależne_USA + kraj_pochodzeniaTrynidad_i_Tobago + 
##     kraj_pochodzeniaWęgry + kraj_pochodzeniaWietnam + kraj_pochodzeniaWłochy + 
##     kapital_netto + czy_kapital_netto_dodatnikapital_netto_niedodatni + 
##     godziny_razy_edukacja + wiek_sredniw_grupie_wieku_średniego + 
##     dlugie_godziny_pracytak + poziom_edukacji_factorklasy_1_4 + 
##     poziom_edukacji_factorklasy_5_6 + poziom_edukacji_factorklasy_7_8 + 
##     poziom_edukacji_factorklasa_9 + poziom_edukacji_factorklasa_10 + 
##     poziom_edukacji_factorklasa_11 + poziom_edukacji_factorklasa_12 + 
##     poziom_edukacji_factorszkoła_średnia + poziom_edukacji_factorczęść_studiów + 
##     poziom_edukacji_factorassociate_zawodowe + poziom_edukacji_factorassociate_akademickie + 
##     poziom_edukacji_factorlicencjat + poziom_edukacji_factormagister + 
##     poziom_edukacji_factorszkoła_profesjonalna + poziom_edukacji_factordoktorat

Model NN1: jedna warstwa ukryta z 3 neuronami

Pierwszy model ma jedną warstwę ukrytą z 3 neuronami. Dla klasyfikacji binarnej ustawiono linear.output = FALSE, dzięki czemu wynik sieci można interpretować jako wartość zbliżoną do prawdopodobieństwa przynależności do klasy pozytywnej.

# set.seed(123)
# 
# nn_salary_1 <- neuralnet(
#   formula = salary.formula,
#   data = salary.train.small.nn,
#   hidden = c(3),
#   linear.output = FALSE,
#   act.fct = "logistic",
#   threshold = 0.05,
#   stepmax = 1e6,
#   rep = 1,
#   lifesign = "minimal"
# )
# 
# saveRDS(nn_salary_1, "trenowane_modele/salary.nn1.rds")

# Wczytanie wcześniej zapisanego modelu
nn_salary_1 <- readRDS("trenowane_modele/salary.nn1.rds")

Dla pierwszego modelu przygotowano również wizualizację sieci. Przy dużej liczbie zmiennych wejściowych wykres może być mniej czytelny, ale pozwala zobaczyć ogólną strukturę połączeń między warstwami.

#Zablokowano ze względu na bardzo długi okres kompilacji

#plotnet(
#  nn_salary_1,
#  circle_col = list("lightyellow", "lightblue"),
#  bord_col = "grey40",
#  pos_col = "darkgreen",
#  neg_col = "red3",
#  circle_cex = 2,
#  cex_val = 0.5,
#  alpha_val = 0.8,
#  max_sp = TRUE,
#  pad_x = 0.9
#)

Predykcje i ocena modelu NN1

Predykcje zostały wyznaczone zarówno dla zmniejszonego zbioru treningowego, na którym model był uczony, jak i dla pełnego zbioru testowego.

pred_salary_nn1_train <- compute(
  nn_salary_1,
  salary.train.small.nn[, salary.predictors]
)$net.result

pred_salary_nn1_test <- compute(
  nn_salary_1,
  salary.test.nn[, salary.predictors]
)$net.result

pred_salary_nn1_train <- as.numeric(pred_salary_nn1_train)
pred_salary_nn1_test  <- as.numeric(pred_salary_nn1_test)

summary(pred_salary_nn1_train)
##       Min.    1st Qu.     Median       Mean    3rd Qu.       Max. 
## 0.00004445 0.00214147 0.04628933 0.24439997 0.42383922 0.96674992
summary(pred_salary_nn1_test)
##       Min.    1st Qu.     Median       Mean    3rd Qu.       Max. 
## 0.00004445 0.00209347 0.04403567 0.24352639 0.42018547 0.96674992

Ze względu na niezbalansowanie klas nie zastosowano domyślnego progu 0.5. Jako prosty próg odcięcia przyjęto udział klasy pozytywnej w zmniejszonym zbiorze treningowym.

cutoff.nn <- mean(salary.train.small.nn$y)
cutoff.nn
## [1] 0.2502
metryki.salary.nn1.train <- klasyfikacja.metryki(
  predicted_probabilities = pred_salary_nn1_train,
  real = as.character(salary.train.small.nn$y),
  cutoff = cutoff.nn,
  level_positive = "1",
  level_negative = "0"
)

metryki.salary.nn1.test <- klasyfikacja.metryki(
  predicted_probabilities = pred_salary_nn1_test,
  real = as.character(salary.test.nn$y),
  cutoff = cutoff.nn,
  level_positive = "1",
  level_negative = "0"
)

rbind(
  nn1.train = metryki.salary.nn1.train,
  nn1.test = metryki.salary.nn1.test
)
##           Accuracy Sensitivity Specificity BalancedAccuracy    PPV    NPV
## nn1.train   0.8490      0.8753      0.8402           0.8578 0.6464 0.9528
## nn1.test    0.8031      0.7909      0.8071           0.7990 0.5760 0.9209
##               F1  Kappa
## nn1.train 0.7436 0.6400
## nn1.test  0.6665 0.5316

Model NN2: jedna warstwa ukryta z 5 neuronami

Drugi model również posiada jedną warstwę ukrytą, ale z większą liczbą neuronów. Celem jest sprawdzenie, czy nieco bardziej elastyczna sieć poprawi jakość klasyfikacji.

# Model NN2: jedna warstwa ukryta z 5 neuronami
#
# Trenowanie modelu zostało wyhasztagowane, aby raport nie budował sieci od nowa.
# Model został wcześniej zapisany do pliku RDS.

# set.seed(123)
# 
# nn_salary_2 <- neuralnet(
#   formula = salary.formula,
#   data = salary.train.small.nn,
#   hidden = c(5),
#   linear.output = FALSE,
#   act.fct = "logistic",
#   threshold = 0.05,
#   stepmax = 1e6,
#   rep = 1,
#   lifesign = "minimal"
# )
# 
# saveRDS(nn_salary_2, "trenowane_modele/salary.nn2.rds")

# Wczytanie wcześniej zapisanego modelu
nn_salary_2 <- readRDS("trenowane_modele/salary.nn2.rds")

#nn_salary_2

Model NN3: dwie warstwy ukryte

Trzeci model ma dwie warstwy ukryte. Pierwsza warstwa zawiera 3 neurony, a druga 2 neurony. Taka architektura pozwala sprawdzić, czy dodanie drugiej warstwy poprawi zdolność modelu do uchwycenia złożonych zależności.

# Model NN3: dwie warstwy ukryte, 3 i 2 neurony
#
# Trenowanie modelu zostało wyhasztagowane, aby raport nie budował sieci od nowa.
# Model został wcześniej zapisany do pliku RDS.

# set.seed(123)
# 
# nn_salary_3 <- neuralnet(
#   formula = salary.formula,
#   data = salary.train.small.nn,
#   hidden = c(3, 2),
#   linear.output = FALSE,
#   act.fct = "logistic",
#   threshold = 0.05,
#   stepmax = 1e6,
#   rep = 1,
#   lifesign = "minimal"
# )
# 
# saveRDS(nn_salary_3, "trenowane_modele/salary.nn3.rds")

# Wczytanie wcześniej zapisanego modelu
nn_salary_3 <- readRDS("trenowane_modele/salary.nn3.rds")

#nn_salary_3

Predykcje dla modeli NN2 i NN3

pred_salary_nn2_train <- compute(
  nn_salary_2,
  salary.train.small.nn[, salary.predictors]
)$net.result

pred_salary_nn2_test <- compute(
  nn_salary_2,
  salary.test.nn[, salary.predictors]
)$net.result

pred_salary_nn3_train <- compute(
  nn_salary_3,
  salary.train.small.nn[, salary.predictors]
)$net.result

pred_salary_nn3_test <- compute(
  nn_salary_3,
  salary.test.nn[, salary.predictors]
)$net.result

pred_salary_nn2_train <- as.numeric(pred_salary_nn2_train)
pred_salary_nn2_test  <- as.numeric(pred_salary_nn2_test)

pred_salary_nn3_train <- as.numeric(pred_salary_nn3_train)
pred_salary_nn3_test  <- as.numeric(pred_salary_nn3_test)

Porównanie modeli sieci neuronowych

Poniżej zestawiono wyniki trzech modeli sieci neuronowych. Porównanie obejmuje osobno wyniki na zbiorze treningowym i testowym.

metryki.salary.nn2.train <- klasyfikacja.metryki(
  predicted_probabilities = pred_salary_nn2_train,
  real = as.character(salary.train.small.nn$y),
  cutoff = cutoff.nn,
  level_positive = "1",
  level_negative = "0"
)

metryki.salary.nn2.test <- klasyfikacja.metryki(
  predicted_probabilities = pred_salary_nn2_test,
  real = as.character(salary.test.nn$y),
  cutoff = cutoff.nn,
  level_positive = "1",
  level_negative = "0"
)
metryki.salary.nn3.train <- klasyfikacja.metryki(
  predicted_probabilities = pred_salary_nn3_train,
  real = as.character(salary.train.small.nn$y),
  cutoff = cutoff.nn,
  level_positive = "1",
  level_negative = "0"
)

metryki.salary.nn3.test <- klasyfikacja.metryki(
  predicted_probabilities = pred_salary_nn3_test,
  real = as.character(salary.test.nn$y),
  cutoff = cutoff.nn,
  level_positive = "1",
  level_negative = "0"
)

rbind(
  nn.3 = metryki.salary.nn1.train,
  nn.5 = metryki.salary.nn2.train,
  nn.3.2 = metryki.salary.nn3.train
)
##        Accuracy Sensitivity Specificity BalancedAccuracy    PPV    NPV     F1
## nn.3     0.8490      0.8753      0.8402           0.8578 0.6464 0.9528 0.7436
## nn.5     0.8892      0.8729      0.8946           0.8838 0.7344 0.9547 0.7977
## nn.3.2   0.8498      0.8649      0.8448           0.8548 0.6502 0.9493 0.7424
##         Kappa
## nn.3   0.6400
## nn.5   0.7222
## nn.3.2 0.6393
rbind(
  nn.3 = metryki.salary.nn1.test,
  nn.5 = metryki.salary.nn2.test,
  nn.3.2 = metryki.salary.nn3.test
)
##        Accuracy Sensitivity Specificity BalancedAccuracy    PPV    NPV     F1
## nn.3     0.8031      0.7909      0.8071           0.7990 0.5760 0.9209 0.6665
## nn.5     0.7993      0.6932      0.8345           0.7638 0.5812 0.8914 0.6322
## nn.3.2   0.8031      0.7815      0.8102           0.7959 0.5770 0.9180 0.6639
##         Kappa
## nn.3   0.5316
## nn.5   0.4957
## nn.3.2 0.5290

Na podstawie uzyskanych wyników można porównać wpływ architektury sieci na jakość klasyfikacji. Warto zwrócić uwagę nie tylko na Accuracy, ale również na Balanced Accuracy, Sensitivity, Specificity oraz F1.

W przypadku danych niezbalansowanych zwykła trafność klasyfikacji może być myląca. Model może osiągać wysoką wartość Accuracy, ponieważ dobrze rozpoznaje klasę większościową, ale jednocześnie słabo wykrywać klasę pozytywną. Dlatego większe znaczenie ma Balanced Accuracy, która uwzględnia zarówno czułość, jak i specyficzność. Jednak widać na modelach zjawisko przeuczenia dla modelu np 1 - 87% do *80** jeśli chodzi o Sensivity, jednak ze względu na i tak już długi czas kompilowania dla 5000 danych zdecydowano się na pozostawienie takiej ilości.

Krzywe ROC dla sieci neuronowych

Dodatkowo dla każdej sieci neuronowej wyznaczono krzywą ROC oraz indeks Giniego. Krzywa ROC pozwala ocenić zdolność modelu do rozróżniania klas niezależnie od konkretnego progu odcięcia.

ROC.salary.nn1.train <- pROC::roc(
  as.numeric(salary.train.small.nn$y == 1),
  pred_salary_nn1_train
)

ROC.salary.nn1.test <- pROC::roc(
  as.numeric(salary.test.nn$y == 1),
  pred_salary_nn1_test
)

ROC.salary.nn2.train <- pROC::roc(
  as.numeric(salary.train.small.nn$y == 1),
  pred_salary_nn2_train
)

ROC.salary.nn2.test <- pROC::roc(
  as.numeric(salary.test.nn$y == 1),
  pred_salary_nn2_test
)

ROC.salary.nn3.train <- pROC::roc(
  as.numeric(salary.train.small.nn$y == 1),
  pred_salary_nn3_train
)

ROC.salary.nn3.test <- pROC::roc(
  as.numeric(salary.test.nn$y == 1),
  pred_salary_nn3_test
)

ROC.salary.nn1.test$auc
## Area under the curve: 0.8799
ROC.salary.nn2.test$auc
## Area under the curve: 0.8497
ROC.salary.nn3.test$auc
## Area under the curve: 0.8647
gini <- function(roc) {
  paste0(round(100 * (2 * as.numeric(pROC::auc(roc)) - 1), 1), "%")
}

list(
  nn1.train = ROC.salary.nn1.train,
  nn1.test = ROC.salary.nn1.test,
  nn2.train = ROC.salary.nn2.train,
  nn2.test = ROC.salary.nn2.test,
  nn3.train = ROC.salary.nn3.train,
  nn3.test = ROC.salary.nn3.test
) %>%
  ggroc(alpha = 0.7, linewidth = 1.1) +
  geom_segment(
    aes(x = 1, xend = 0, y = 0, yend = 1),
    color = "grey60",
    linetype = "dashed"
  ) +
  labs(
    title = paste0(
      "Gini TEST: ",
      "NN1 (3) = ", gini(ROC.salary.nn1.test), ", ",
      "NN2 (5) = ", gini(ROC.salary.nn2.test), ", ",
      "NN3 (3,2) = ", gini(ROC.salary.nn3.test)
    ),
    subtitle = paste0(
      "Gini TRAIN: ",
      "NN1 (3) = ", gini(ROC.salary.nn1.train), ", ",
      "NN2 (5) = ", gini(ROC.salary.nn2.train), ", ",
      "NN3 (3,2) = ", gini(ROC.salary.nn3.train)
    ),
    color = "Model",
    x = "Specyficzność",
    y = "Wrażliwość"
  ) +
  theme_bw() +
  coord_fixed() +
  scale_color_brewer(palette = "Paired")

Wnioski z modeli sieci neuronowych

Sieci neuronowe zostały wytrenowane na zmniejszonym zbiorze treningowym, dlatego ich wyniki należy interpretować ostrożnie. Z jednej strony ograniczenie liczby obserwacji znacząco skraca czas uczenia, ale z drugiej strony może pogarszać zdolność generalizacji modelu.

W porównaniu trzech architektur można zauważyć, że większa liczba neuronów lub dodatkowa warstwa ukryta nie musi automatycznie prowadzić do lepszych wyników. Bardziej złożony model ma więcej wag do oszacowania, co może zwiększać ryzyko przeuczenia.

Jeżeli wyniki na zbiorze treningowym są wyraźnie lepsze niż na zbiorze testowym, może to oznaczać overfitting. Jeżeli natomiast wyniki treningowe i testowe są podobne, ale relatywnie słabe, model może być zbyt prosty albo trenowany na zbyt małej próbie.

W tej analizie sieci neuronowe pełnią rolę dodatkowego punktu odniesienia względem modeli takich jak regresja logistyczna, drzewa decyzyjne, bagging, Random Forest oraz XGBoost. Ich zaletą jest możliwość modelowania nieliniowych zależności, jednak przy użyciu pakietu neuralnet oraz dużej liczbie zmiennych po one-hot encodingu trening może być czasochłonny.

Porównanie wszystkich modeli

W ostatnim etapie analizy porównano wszystkie zbudowane modele klasyfikacyjne. Zestawienie obejmuje modele bazowe, drzewa decyzyjne, modele zespołowe, boosting oraz sieci neuronowe.

Porównanie wykonano osobno dla zbioru treningowego oraz testowego. Wyniki na zbiorze treningowym pokazują, jak dobrze model dopasował się do danych uczących, natomiast wyniki na zbiorze testowym są ważniejsze z punktu widzenia oceny jakości predykcyjnej, ponieważ pokazują skuteczność modelu na nowych obserwacjach.

W analizie wykorzystano kilka miar jakości klasyfikacji. Accuracy oznacza ogólną trafność klasyfikacji. Sensitivity pokazuje, jak dobrze model wykrywa klasę pozytywną, czyli osoby zarabiające powyżej 50 tys. USD. Specificity informuje, jak dobrze model rozpoznaje klasę negatywną. BalancedAccuracy jest średnią z czułości i specyficzności, dlatego jest szczególnie ważna przy niezbalansowanych danych. F1 łączy precyzję oraz czułość i jest użyteczna przy ocenie jakości klasyfikacji klasy pozytywnej.

Ponieważ w analizowanym zbiorze klasy nie są równoliczne, sama wartość Accuracy może być myląca. Model może osiągać wysoką trafność głównie dlatego, że dobrze rozpoznaje klasę większościową. Z tego powodu w końcowej ocenie szczególną uwagę zwrócono na BalancedAccuracy, Sensitivity oraz F1.

################################################################################
# 13. Porównanie wszystkich modeli ----
################################################################################

# W jednej tabeli zbieramy metryki dla wszystkich modeli.
# Tworzymy osobne zestawienie dla zbioru treningowego i testowego.
# Dzięki temu możemy sprawdzić, który model najlepiej działa na danych uczących,
# a który najlepiej generalizuje na dane testowe.

porownanie.wszystkie.train <- rbind(
  logit = metryki.glm.train,
  decision.tree.default = metryki.tree.default.train,
  decision.tree.tuned = metryki.tree.tuned.train,
  decision.tree.downsample = metryki.tree.down.train,
  bagging = metryki.salary.bag.train,
  random.forest = metryki.salary.rf.train,
  random.forest.tuned = metryki.salary.rf.tuned.train,
  xgb.start = metryki.salary.xgb.start.train,
  xgb.tuned = metryki.salary.xgb.tuned.train,
  neural.net.3 = metryki.salary.nn1.train,
  neural.net.5 = metryki.salary.nn2.train,
  neural.net.3.2 = metryki.salary.nn3.train
)

porownanie.wszystkie.test <- rbind(
  logit = metryki.glm.test,
  decision.tree.default = metryki.tree.default.test,
  decision.tree.tuned = metryki.tree.tuned.test,
  decision.tree.downsample = metryki.tree.down.test,
  bagging = metryki.salary.bag.test,
  random.forest = metryki.salary.rf.test,
  random.forest.tuned = metryki.salary.rf.tuned.test,
  xgb.start = metryki.salary.xgb.start.test,
  xgb.tuned = metryki.salary.xgb.tuned.test,
  neural.net.3 = metryki.salary.nn1.test,
  neural.net.5 = metryki.salary.nn2.test,
  neural.net.3.2 = metryki.salary.nn3.test
)

porownanie.wszystkie.train
##                          Accuracy Sensitivity Specificity BalancedAccuracy
## logit                      0.7936      0.8615      0.7711           0.8163
## decision.tree.default      0.8416      0.5099      0.9516           0.7307
## decision.tree.tuned        0.8616      0.6056      0.9465           0.7760
## decision.tree.downsample   0.8394      0.8824      0.7964           0.8394
## bagging                    0.9517      0.9812      0.9420           0.9616
## random.forest              0.9299      0.9553      0.9215           0.9384
## random.forest.tuned        0.8500      0.9488      0.8173           0.8830
## xgb.start                  0.7986      0.8727      0.7740           0.8234
## xgb.tuned                  0.8286      0.8786      0.8120           0.8453
## neural.net.3               0.8490      0.8753      0.8402           0.8578
## neural.net.5               0.8892      0.8729      0.8946           0.8838
## neural.net.3.2             0.8498      0.8649      0.8448           0.8548
##                             PPV    NPV     F1  Kappa
## logit                    0.5550 0.9438 0.6751 0.5340
## decision.tree.default    0.7773 0.8542 0.6158 0.5214
## decision.tree.tuned      0.7894 0.8786 0.6854 0.5987
## decision.tree.downsample 0.8125 0.8714 0.8460 0.6788
## bagging                  0.8486 0.9934 0.9101 0.8773
## random.forest            0.8013 0.9842 0.8716 0.8239
## random.forest.tuned      0.6325 0.9797 0.7590 0.6563
## xgb.start                0.5614 0.9483 0.6833 0.5456
## xgb.tuned                0.6077 0.9528 0.7185 0.6011
## neural.net.3             0.6464 0.9528 0.7436 0.6400
## neural.net.5             0.7344 0.9547 0.7977 0.7222
## neural.net.3.2           0.6502 0.9493 0.7424 0.6393
porownanie.wszystkie.test
##                          Accuracy Sensitivity Specificity BalancedAccuracy
## logit                      0.7940      0.8686      0.7692           0.8189
## decision.tree.default      0.8397      0.5115      0.9485           0.7300
## decision.tree.tuned        0.8457      0.5799      0.9338           0.7569
## decision.tree.downsample   0.7928      0.8535      0.7727           0.8131
## bagging                    0.8015      0.8215      0.7949           0.8082
## random.forest              0.8196      0.8139      0.8215           0.8177
## random.forest.tuned        0.8112      0.8752      0.7900           0.8326
## xgb.start                  0.7982      0.8601      0.7777           0.8189
## xgb.tuned                  0.8239      0.8583      0.8125           0.8354
## neural.net.3               0.8031      0.7909      0.8071           0.7990
## neural.net.5               0.7993      0.6932      0.8345           0.7638
## neural.net.3.2             0.8031      0.7815      0.8102           0.7959
##                             PPV    NPV     F1  Kappa
## logit                    0.5552 0.9464 0.6774 0.5366
## decision.tree.default    0.7670 0.8542 0.6137 0.5177
## decision.tree.tuned      0.7437 0.8703 0.6517 0.5545
## decision.tree.downsample 0.5544 0.9409 0.6721 0.5304
## bagging                  0.5703 0.9307 0.6732 0.5373
## random.forest            0.6018 0.9302 0.6920 0.5685
## random.forest.tuned      0.5800 0.9503 0.6977 0.5685
## xgb.start                0.5618 0.9438 0.6796 0.5416
## xgb.tuned                0.6027 0.9454 0.7082 0.5876
## neural.net.3             0.5760 0.9209 0.6665 0.5316
## neural.net.5             0.5812 0.8914 0.6322 0.4957
## neural.net.3.2           0.5770 0.9180 0.6639 0.5290

Wyniki modeli na zbiorze treningowym

Poniższa tabela przedstawia wyniki wszystkich modeli na zbiorze treningowym.

Na zbiorze treningowym najwyższe wartości większości metryk osiągnął model bagging. Uzyskał on bardzo wysoką wartość Accuracy = 0.9517, BalancedAccuracy = 0.9616, F1 = 0.9101 oraz Kappa = 0.8773. Oznacza to, że model bardzo dobrze dopasował się do danych uczących.

Jeżeli natomiast skupiamy się na predykcji klasy negatywnej, czyli osób niezarabiających powyżej 50 tys. USD, ważniejsza jest Specificity. Najwyższą specyficzność na zbiorze treningowym osiągnął model decision tree default (Specificity = 0.9516), bardzo zbliżony wynik uzyskał również decision tree tuned (Specificity = 0.9465) oraz bagging (Specificity = 0.9420). Oznacza to, że modele drzewiaste bardzo dobrze rozpoznają klasę większościową, czyli osoby z dochodem nieprzekraczającym 50 tys. USD.

Wysokie wartości metryk na tym zbiorze oznaczają dobre dopasowanie modelu do danych uczących. Należy jednak pamiętać, że bardzo dobre wyniki na treningu nie zawsze oznaczają dobrą jakość modelu, ponieważ model może być przeuczony.

knitr::kable(
  porownanie.wszystkie.train,
  digits = 4,
  caption = "Porównanie wszystkich modeli na zbiorze treningowym"
)
Porównanie wszystkich modeli na zbiorze treningowym
Accuracy Sensitivity Specificity BalancedAccuracy PPV NPV F1 Kappa
logit 0.7936 0.8615 0.7711 0.8163 0.5550 0.9438 0.6751 0.5340
decision.tree.default 0.8416 0.5099 0.9516 0.7307 0.7773 0.8542 0.6158 0.5214
decision.tree.tuned 0.8616 0.6056 0.9465 0.7760 0.7894 0.8786 0.6854 0.5987
decision.tree.downsample 0.8394 0.8824 0.7964 0.8394 0.8125 0.8714 0.8460 0.6788
bagging 0.9517 0.9812 0.9420 0.9616 0.8486 0.9934 0.9101 0.8773
random.forest 0.9299 0.9553 0.9215 0.9384 0.8013 0.9842 0.8716 0.8239
random.forest.tuned 0.8500 0.9488 0.8173 0.8830 0.6325 0.9797 0.7590 0.6563
xgb.start 0.7986 0.8727 0.7740 0.8234 0.5614 0.9483 0.6833 0.5456
xgb.tuned 0.8286 0.8786 0.8120 0.8453 0.6077 0.9528 0.7185 0.6011
neural.net.3 0.8490 0.8753 0.8402 0.8578 0.6464 0.9528 0.7436 0.6400
neural.net.5 0.8892 0.8729 0.8946 0.8838 0.7344 0.9547 0.7977 0.7222
neural.net.3.2 0.8498 0.8649 0.8448 0.8548 0.6502 0.9493 0.7424 0.6393

Wyniki modeli na zbiorze testowym

Na zbiorze testowym wyniki są bardziej wyrównane. Najwyższą wartość BalancedAccuracy osiągnął model xgb.tuned (BalancedAccuracy = 0.8354). Bardzo dobry wynik uzyskał również random.forest.tuned (BalancedAccuracy = 0.8326). Oznacza to, że po dostrojeniu modele bardziej zaawansowane, szczególnie XGBoost i Random Forest, lepiej generalizują na nowe dane niż modele bardzo mocno dopasowane do zbioru treningowego.

Warto zauważyć, że bagging, który był zdecydowanie najlepszy na zbiorze treningowym, na zbiorze testowym osiągnął BalancedAccuracy = 0.8082. Jest to wynik dobry, ale wyraźnie niższy niż na treningu. Różnica między wynikami treningowymi i testowymi sugeruje, że bagging mógł częściowo przeuczyć się na danych uczących.

Jeżeli koncentrujemy się na wykrywaniu klasy pozytywnej, czyli osób zarabiających powyżej 50 tys. USD, najważniejsza jest Sensitivity. Na zbiorze testowym najwyższą czułość osiągnął random.forest.tuned (Sensitivity = 0.8752). Bardzo blisko są również logit (Sensitivity = 0.8686), xgb.start (Sensitivity = 0.8601) oraz xgb.tuned (Sensitivity = 0.8583). Oznacza to, że jeśli celem analizy byłoby wykrycie jak największej liczby osób z wysokim dochodem, najlepszym wyborem byłby random.forest.tuned.

Jeżeli skupiamy się na predykcji klasy negatywnej, ważniejsza jest Specificity. Na zbiorze testowym najwyższą specyficzność osiągnął decision tree default (Specificity = 0.9485), a następnie decision tree tuned (Specificity = 0.9338). Oznacza to, że drzewa decyzyjne najlepiej rozpoznają osoby, które nie zarabiają powyżej 50 tys. USD. Trzeba jednak zauważyć, że robią to kosztem dużo niższej czułości, czyli słabiej wykrywają klasę pozytywną.

Jeżeli patrzymy na F1, czyli miarę łączącą precyzję i czułość dla klasy pozytywnej, najlepszy wynik na zbiorze testowym osiągnął xgb.tuned (F1 = 0.7082). Jest to ważna informacja, ponieważ oznacza, że model XGBoost po strojeniu dobrze równoważy wykrywanie klasy pozytywnej oraz jakość tych wskazań.

knitr::kable(
  porownanie.wszystkie.test,
  digits = 4,
  caption = "Porównanie wszystkich modeli na zbiorze testowym"
)
Porównanie wszystkich modeli na zbiorze testowym
Accuracy Sensitivity Specificity BalancedAccuracy PPV NPV F1 Kappa
logit 0.7940 0.8686 0.7692 0.8189 0.5552 0.9464 0.6774 0.5366
decision.tree.default 0.8397 0.5115 0.9485 0.7300 0.7670 0.8542 0.6137 0.5177
decision.tree.tuned 0.8457 0.5799 0.9338 0.7569 0.7437 0.8703 0.6517 0.5545
decision.tree.downsample 0.7928 0.8535 0.7727 0.8131 0.5544 0.9409 0.6721 0.5304
bagging 0.8015 0.8215 0.7949 0.8082 0.5703 0.9307 0.6732 0.5373
random.forest 0.8196 0.8139 0.8215 0.8177 0.6018 0.9302 0.6920 0.5685
random.forest.tuned 0.8112 0.8752 0.7900 0.8326 0.5800 0.9503 0.6977 0.5685
xgb.start 0.7982 0.8601 0.7777 0.8189 0.5618 0.9438 0.6796 0.5416
xgb.tuned 0.8239 0.8583 0.8125 0.8354 0.6027 0.9454 0.7082 0.5876
neural.net.3 0.8031 0.7909 0.8071 0.7990 0.5760 0.9209 0.6665 0.5316
neural.net.5 0.7993 0.6932 0.8345 0.7638 0.5812 0.8914 0.6322 0.4957
neural.net.3.2 0.8031 0.7815 0.8102 0.7959 0.5770 0.9180 0.6639 0.5290

Wyniki na zbiorze testowym są podstawą końcowej oceny jakości modeli. Model, który osiąga dobre wyniki na teście, ma większą zdolność generalizacji, czyli lepiej radzi sobie z nowymi danymi.

Ranking modeli według Balanced Accuracy

W pierwszej kolejności modele uporządkowano według BalancedAccuracy na zbiorze testowym. Ta miara jest szczególnie ważna w przypadku niezbalansowanych klas, ponieważ uwzględnia zarówno poprawne rozpoznawanie klasy pozytywnej, jak i negatywnej.

ranking.test.balanced <- porownanie.wszystkie.test %>%
  as.data.frame() %>%
  rownames_to_column("model") %>%
  arrange(desc(BalancedAccuracy))

knitr::kable(
  ranking.test.balanced,
  digits = 4,
  caption = "Ranking modeli według Balanced Accuracy na zbiorze testowym"
)
Ranking modeli według Balanced Accuracy na zbiorze testowym
model Accuracy Sensitivity Specificity BalancedAccuracy PPV NPV F1 Kappa
xgb.tuned 0.8239 0.8583 0.8125 0.8354 0.6027 0.9454 0.7082 0.5876
random.forest.tuned 0.8112 0.8752 0.7900 0.8326 0.5800 0.9503 0.6977 0.5685
logit 0.7940 0.8686 0.7692 0.8189 0.5552 0.9464 0.6774 0.5366
xgb.start 0.7982 0.8601 0.7777 0.8189 0.5618 0.9438 0.6796 0.5416
random.forest 0.8196 0.8139 0.8215 0.8177 0.6018 0.9302 0.6920 0.5685
decision.tree.downsample 0.7928 0.8535 0.7727 0.8131 0.5544 0.9409 0.6721 0.5304
bagging 0.8015 0.8215 0.7949 0.8082 0.5703 0.9307 0.6732 0.5373
neural.net.3 0.8031 0.7909 0.8071 0.7990 0.5760 0.9209 0.6665 0.5316
neural.net.3.2 0.8031 0.7815 0.8102 0.7959 0.5770 0.9180 0.6639 0.5290
neural.net.5 0.7993 0.6932 0.8345 0.7638 0.5812 0.8914 0.6322 0.4957
decision.tree.tuned 0.8457 0.5799 0.9338 0.7569 0.7437 0.8703 0.6517 0.5545
decision.tree.default 0.8397 0.5115 0.9485 0.7300 0.7670 0.8542 0.6137 0.5177

Model znajdujący się najwyżej w tym rankingu osiąga najlepszy kompromis pomiędzy czułością i specyficznością. Oznacza to, że dobrze radzi sobie zarówno z wykrywaniem osób zarabiających powyżej 50 tys. USD, jak i osób z niższym dochodem.

Ranking modeli według F1

Drugim rankingiem jest zestawienie modeli według miary F1. Jest ona szczególnie użyteczna wtedy, gdy zależy nam na jakości klasyfikacji klasy pozytywnej.

ranking.test.f1 <- porownanie.wszystkie.test %>%
  as.data.frame() %>%
  rownames_to_column("model") %>%
  arrange(desc(F1))

knitr::kable(
  ranking.test.f1,
  digits = 4,
  caption = "Ranking modeli według F1 na zbiorze testowym"
)
Ranking modeli według F1 na zbiorze testowym
model Accuracy Sensitivity Specificity BalancedAccuracy PPV NPV F1 Kappa
xgb.tuned 0.8239 0.8583 0.8125 0.8354 0.6027 0.9454 0.7082 0.5876
random.forest.tuned 0.8112 0.8752 0.7900 0.8326 0.5800 0.9503 0.6977 0.5685
random.forest 0.8196 0.8139 0.8215 0.8177 0.6018 0.9302 0.6920 0.5685
xgb.start 0.7982 0.8601 0.7777 0.8189 0.5618 0.9438 0.6796 0.5416
logit 0.7940 0.8686 0.7692 0.8189 0.5552 0.9464 0.6774 0.5366
bagging 0.8015 0.8215 0.7949 0.8082 0.5703 0.9307 0.6732 0.5373
decision.tree.downsample 0.7928 0.8535 0.7727 0.8131 0.5544 0.9409 0.6721 0.5304
neural.net.3 0.8031 0.7909 0.8071 0.7990 0.5760 0.9209 0.6665 0.5316
neural.net.3.2 0.8031 0.7815 0.8102 0.7959 0.5770 0.9180 0.6639 0.5290
decision.tree.tuned 0.8457 0.5799 0.9338 0.7569 0.7437 0.8703 0.6517 0.5545
neural.net.5 0.7993 0.6932 0.8345 0.7638 0.5812 0.8914 0.6322 0.4957
decision.tree.default 0.8397 0.5115 0.9485 0.7300 0.7670 0.8542 0.6137 0.5177

Wysoka wartość F1 oznacza, że model osiąga dobry kompromis między wykrywaniem klasy pozytywnej a precyzją tych wskazań. W tym projekcie jest to istotne, ponieważ klasa osób zarabiających powyżej 50 tys. USD jest mniej liczna.

Porównanie wyników treningowych i testowych

Aby ocenić, czy modele nie są przeuczone, przygotowano również porównanie wyników na zbiorze treningowym i testowym. Duża różnica między wynikiem treningowym a testowym może sugerować overfitting, czyli zbyt silne dopasowanie modelu do danych uczących.

porownanie.train.test <- porownanie.wszystkie.test %>%
  as.data.frame() %>%
  rownames_to_column("model") %>%
  rename_with(~ paste0(.x, "_test"), -model) %>%
  left_join(
    porownanie.wszystkie.train %>%
      as.data.frame() %>%
      rownames_to_column("model") %>%
      rename_with(~ paste0(.x, "_train"), -model),
    by = "model"
  ) %>%
  mutate(
    roznica_Accuracy = Accuracy_train - Accuracy_test,
    roznica_BalancedAccuracy = BalancedAccuracy_train - BalancedAccuracy_test,
    roznica_F1 = F1_train - F1_test
  ) %>%
  arrange(desc(BalancedAccuracy_test))

knitr::kable(
  porownanie.train.test,
  digits = 4,
  caption = "Porównanie wyników treningowych i testowych"
)
Porównanie wyników treningowych i testowych
model Accuracy_test Sensitivity_test Specificity_test BalancedAccuracy_test PPV_test NPV_test F1_test Kappa_test Accuracy_train Sensitivity_train Specificity_train BalancedAccuracy_train PPV_train NPV_train F1_train Kappa_train roznica_Accuracy roznica_BalancedAccuracy roznica_F1
xgb.tuned 0.8239 0.8583 0.8125 0.8354 0.6027 0.9454 0.7082 0.5876 0.8286 0.8786 0.8120 0.8453 0.6077 0.9528 0.7185 0.6011 0.0047 0.0099 0.0103
random.forest.tuned 0.8112 0.8752 0.7900 0.8326 0.5800 0.9503 0.6977 0.5685 0.8500 0.9488 0.8173 0.8830 0.6325 0.9797 0.7590 0.6563 0.0388 0.0504 0.0613
logit 0.7940 0.8686 0.7692 0.8189 0.5552 0.9464 0.6774 0.5366 0.7936 0.8615 0.7711 0.8163 0.5550 0.9438 0.6751 0.5340 -0.0004 -0.0026 -0.0023
xgb.start 0.7982 0.8601 0.7777 0.8189 0.5618 0.9438 0.6796 0.5416 0.7986 0.8727 0.7740 0.8234 0.5614 0.9483 0.6833 0.5456 0.0004 0.0045 0.0037
random.forest 0.8196 0.8139 0.8215 0.8177 0.6018 0.9302 0.6920 0.5685 0.9299 0.9553 0.9215 0.9384 0.8013 0.9842 0.8716 0.8239 0.1103 0.1207 0.1796
decision.tree.downsample 0.7928 0.8535 0.7727 0.8131 0.5544 0.9409 0.6721 0.5304 0.8394 0.8824 0.7964 0.8394 0.8125 0.8714 0.8460 0.6788 0.0466 0.0263 0.1739
bagging 0.8015 0.8215 0.7949 0.8082 0.5703 0.9307 0.6732 0.5373 0.9517 0.9812 0.9420 0.9616 0.8486 0.9934 0.9101 0.8773 0.1502 0.1534 0.2369
neural.net.3 0.8031 0.7909 0.8071 0.7990 0.5760 0.9209 0.6665 0.5316 0.8490 0.8753 0.8402 0.8578 0.6464 0.9528 0.7436 0.6400 0.0459 0.0588 0.0771
neural.net.3.2 0.8031 0.7815 0.8102 0.7959 0.5770 0.9180 0.6639 0.5290 0.8498 0.8649 0.8448 0.8548 0.6502 0.9493 0.7424 0.6393 0.0467 0.0589 0.0785
neural.net.5 0.7993 0.6932 0.8345 0.7638 0.5812 0.8914 0.6322 0.4957 0.8892 0.8729 0.8946 0.8838 0.7344 0.9547 0.7977 0.7222 0.0899 0.1200 0.1655
decision.tree.tuned 0.8457 0.5799 0.9338 0.7569 0.7437 0.8703 0.6517 0.5545 0.8616 0.6056 0.9465 0.7760 0.7894 0.8786 0.6854 0.5987 0.0159 0.0191 0.0337
decision.tree.default 0.8397 0.5115 0.9485 0.7300 0.7670 0.8542 0.6137 0.5177 0.8416 0.5099 0.9516 0.7307 0.7773 0.8542 0.6158 0.5214 0.0019 0.0007 0.0021

Im większa dodatnia różnica między wynikiem treningowym i testowym, tym większe podejrzenie przeuczenia. Szczególnie warto obserwować różnice dla BalancedAccuracy i F1, ponieważ są one bardziej informacyjne niż sama ogólna trafność.

Najlepsze modele według wybranych metryk

Poniżej automatycznie wskazano najlepszy model według BalancedAccuracy oraz najlepszy model według F1 na zbiorze testowym.

najlepszy.model.balanced <- ranking.test.balanced[1, ]
najlepszy.model.f1 <- ranking.test.f1[1, ]

najlepszy.model.balanced
##       model Accuracy Sensitivity Specificity BalancedAccuracy    PPV    NPV
## 1 xgb.tuned   0.8239      0.8583      0.8125           0.8354 0.6027 0.9454
##       F1  Kappa
## 1 0.7082 0.5876
najlepszy.model.f1
##       model Accuracy Sensitivity Specificity BalancedAccuracy    PPV    NPV
## 1 xgb.tuned   0.8239      0.8583      0.8125           0.8354 0.6027 0.9454
##       F1  Kappa
## 1 0.7082 0.5876
cat(
  "Najlepszy model według Balanced Accuracy na zbiorze testowym to:",
  najlepszy.model.balanced$model,
  "z wynikiem",
  najlepszy.model.balanced$BalancedAccuracy,
  "\n"
)
## Najlepszy model według Balanced Accuracy na zbiorze testowym to: xgb.tuned z wynikiem 0.8354
cat(
  "Najlepszy model według F1 na zbiorze testowym to:",
  najlepszy.model.f1$model,
  "z wynikiem",
  najlepszy.model.f1$F1,
  "\n"
)
## Najlepszy model według F1 na zbiorze testowym to: xgb.tuned z wynikiem 0.7082

Interpretacja końcowa

Na podstawie zestawienia wyników można stwierdzić, że wybór najlepszego modelu zależy od celu analizy. Jeżeli najważniejsza byłaby ogólna trafność klasyfikacji, można analizować Accuracy, jednak przy niezbalansowanych danych nie powinna być ona jedynym kryterium wyboru modelu.

W tym projekcie większe znaczenie ma BalancedAccuracy, ponieważ uwzględnia zarówno zdolność wykrywania klasy pozytywnej, jak i negatywnej. Model o wysokiej wartości tej miary dobrze radzi sobie z obiema klasami, a nie tylko z klasą większościową.

Warto również analizować Sensitivity, ponieważ pokazuje, jaka część osób rzeczywiście zarabiających powyżej 50 tys. USD została poprawnie wykryta przez model. Jeżeli celem byłoby jak najlepsze identyfikowanie osób z wysokim dochodem, wysoka czułość byłaby szczególnie pożądana.

Z kolei Specificity pokazuje, jak dobrze model rozpoznaje osoby, które nie zarabiają powyżej 50 tys. USD. Wysoka specyficzność oznacza, że model rzadziej błędnie przypisuje osoby do klasy wysokiego dochodu.

Miara F1 jest dobrym uzupełnieniem analizy, ponieważ łączy precyzję i czułość. Jest przydatna wtedy, gdy chcemy ocenić jakość klasyfikacji klasy pozytywnej przy nierównych proporcjach klas.

Porównanie wyników treningowych i testowych pozwala dodatkowo ocenić stabilność modeli. Modele, które osiągają bardzo wysokie wyniki na zbiorze treningowym, ale wyraźnie słabsze na zbiorze testowym, mogą być przeuczone. Chociażby tak jak w przypadku pojedynczych drzew. Najbardziej pożądany jest model, który osiąga wysokie wyniki na zbiorze testowym i jednocześnie nie wykazuje dużej różnicy między treningiem a testem.

Ostatecznie za najlepszy model należy uznać ten, który osiąga wysoką jakość na zbiorze testowym, dobrą wartość BalancedAccuracy, odpowiednio wysoką czułość oraz akceptowalną różnicę między wynikami treningowymi i testowymi: W opini autorek projektów byłby to model XG BOOST.