library(tidyverse) # manipulacja i wizualizacja danych
library(skimr) # statystyki opisowe / overview danych
library(DataExplorer) # eksploracja danych
library(corrplot) # macierz korelacji
library(PerformanceAnalytics) # dodatkowe wykresy korelacji
library(janitor) # tabele częstości i czyszczenie nazw
library(modeldata) # zbiory danych
library(sf) # dane przestrzenne, jeśli używane
library(caret) # train/test, confusionMatrix, metryki
library(olsrr) # selekcja zmiennych
library(verification) # narzędzia do oceny klasyfikacji
library(glmnet) # modele regularyzowane
library(pROC) # krzywe ROC i AUC
library(rpart) # drzewa decyzyjne
library(rpart.plot) # wizualizacja drzew
library(cutpointr) # dobór punktu odcięcia
library(randomForest) # las losowy
library(ranger) # szybszy random forest
library(xgboost) #pamiętajmy o starszej wersji pakietu!
library(neuralnet) # proste sieci neuronowe
library(NeuralNetTools) # wizualizacja sieci neuronowych
Poniższe badanie stanowi jeden z dwóch projektów zaliczeniowych przygotowanych w ramach przedmiotu - Uczenie maszynowe nadzorowane (2400-M1ABUMN) - prowadzonego przez mgr Monikę Kot w semestrze letnim 2026. Projekt powstał w ramach współpracy z Małgorzatą Hodurek.
Niniejsza praca skupiona będzie na analizie klasyfikacyjnej zarobków osób zarabiających powyżej 50k USD od cech ekonomiczno-społecznych. Jednak równie istotny aspekt analizy stanowić będzie porównanie modeli nadzorowanego uczenia maszynowego dla zmiennej binarnej
Struktura pracy jest następująca: najpierw przygotowujemy dane do dalszej analizy, następnie szacujemy modele podstawowe - benchmarkowe, względem, których będziemy porównywać modele nadzorowanego uczenia maszynowego. W dalszej sekcji trenowane oraz dostrajane są, kolejno: drzewa klasyfikacyjne, modele bagging, random forest oraz boosting a na końcu sieci neuronowe. Następnie wszystkie modele są ze sobą porównane i wyłonione zostają najlepsze podejścia do analizy i predykcji klasyfikacji zaróbków.
W badaniu wykorzystujemy próbkę danych na temat zarobków danych osób ze strony [archive.ics.uci.edu] [https://archive.ics.uci.edu/dataset/2/adult]. Zbiór ten opiera się na danych pobranych zebranych w 1994 roku.
# ustawienie ścieżki
getwd()
## [1] "C:/Users/aleks/OneDrive - SGH/Pulpit/R/ML"
# ostrzeżenia w języku polskim
Sys.setenv(LANG = "pl")
# wyłączenie notacji naukowej
options(scipen = 999)
# wczytanie danych
salary <- read.csv("02_dane_w_SML-20260311/dane/k1_adult.csv")
glimpse(salary)
## Rows: 30,162
## Columns: 18
## $ class <chr> "<=50K", "<=50K", "<=50K", "<=50K", "<=50K", "<=50K", …
## $ age <int> 39, 50, 38, 53, 28, 37, 49, 52, 31, 42, 37, 30, 23, 32…
## $ education_num <int> 13, 13, 9, 7, 13, 14, 5, 9, 14, 13, 10, 13, 13, 12, 4,…
## $ marital_status <chr> "Never-married", "Married-civ-spouse", "Divorced", "Ma…
## $ occupation <chr> "Adm-clerical", "Exec-managerial", "Handlers-cleaners"…
## $ relationship <chr> "Not-in-family", "Husband", "Not-in-family", "Husband"…
## $ race <chr> "White", "White", "White", "Black", "Black", "White", …
## $ sex <chr> "Male", "Male", "Male", "Male", "Female", "Female", "F…
## $ hours_per_week <int> 40, 13, 40, 40, 40, 40, 16, 45, 50, 40, 80, 40, 30, 50…
## $ capital_gain <int> 2174, 0, 0, 0, 0, 0, 0, 0, 14084, 5178, 0, 0, 0, 0, 0,…
## $ capital_loss <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ workclass <chr> "State-gov", "Self-emp-not-inc", "Private", "Private",…
## $ native_country <chr> "United-States", "United-States", "United-States", "Un…
## $ net_capital <int> 2174, 0, 0, 0, 0, 0, 0, 0, 14084, 5178, 0, 0, 0, 0, 0,…
## $ capital_flag <int> 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ hours_x_edu <int> 520, 169, 360, 280, 520, 560, 80, 405, 700, 520, 800, …
## $ mid_age <int> 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, …
## $ high_work_hours <int> 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, …
W kolejnym etapie przeprowadzono przygotowanie danych do modelowania. Na początku zmieniono nazwy zmiennych na bardziej czytelne oraz przekodowano wartości kategorii na polskie, opisowe etykiety. Dzięki temu dalsza analiza i interpretacja wyników są bardziej intuicyjne.
Następnie zmienne tekstowe zostały przekształcone do typu factor, ponieważ reprezentują cechy kategoryczne, takie jak płeć, stan cywilny, zawód, rasa, typ zatrudnienia czy kraj pochodzenia. Przed dalszym modelowaniem sprawdzono również występowanie braków danych.
W ramach feature engineeringu utworzono kilka nowych zmiennych. Pierwszą z nich była zmienna poziom_edukacji_grupa, która grupuje szczegółowe poziomy edukacji w szersze kategorie: wykształcenie podstawowe, średnie lub niepełne wyższe, associate/licencjackie oraz magisterskie lub wyższe. Takie przekształcenie pozwala uprościć strukturę danych i ułatwia interpretację wpływu edukacji na poziom dochodu. W kodzie widać, że grupowanie zostało wykonane na podstawie zmiennej liczbowej poziom_edukacji za pomocą instrukcji case_when() .
Drugą utworzoną zmienną była kraj_pochodzenia_grupa, która upraszcza zmienną kraju pochodzenia do dwóch kategorii: USA oraz Poza_USA. Zastosowano takie rozwiązanie, ponieważ zdecydowana większość obserwacji pochodzi ze Stanów Zjednoczonych, a pozostałe kraje występują znacznie rzadziej. Połączenie ich w jedną kategorię ogranicza problem małych liczebności w poszczególnych grupach oraz poprawia czytelność modelu .
Dodatkowo zmienna objaśniana klasa_dochodu została przekształcona do postaci binarnej klasa_dochodu_01, gdzie wartość tak oznacza dochód powyżej 50 tys., a wartość nie oznacza dochód nieprzekraczający tego progu. Na potrzeby regresji logistycznej zmienna ta została następnie zakodowana jako 0/1, co jest wymagane przy modelowaniu binarnym za pomocą funkcji glm(…, family = binomial) .
Na końcu przygotowano finalny zbiór danych salary_prep, z którego dla modelu logitowego usunięto zmienne zbędne lub zastąpione przez nowe konstrukcje, takie jak oryginalny kraj_pochodzenia, poziom_edukacji, klasa_dochodu, a także składowe kapitału, które zostały zastąpione zmienną syntetyczną dotyczącą dodatniego kapitału netto.
# Wczytanie słowników z nazwami zmiennych i wartościami kategorii
source("02_dane_w_SML-20260311/slowniki_funkcje/slowniki_wartosci_adult.R")
# Zmiana nazw zmiennych na polskie
salary_pl <- salary %>%
rename(any_of(nazwy_zmiennych_pl))
# Przekodowanie wartości kategorii na bardziej opisowe etykiety
salary_pl_2 <- salary_pl %>%
mutate(
klasa_dochodu = recode(trimws(as.character(klasa_dochodu)), !!!dict_klasa_dochodu),
stan_cywilny = recode(trimws(as.character(stan_cywilny)), !!!dict_stan_cywilny),
zawod = recode(trimws(as.character(zawod)), !!!dict_zawod),
relacja_w_rodzinie = recode(trimws(as.character(relacja_w_rodzinie)), !!!dict_relacja_w_rodzinie),
rasa = recode(trimws(as.character(rasa)), !!!dict_rasa),
plec = recode(trimws(as.character(plec)), !!!dict_plec),
typ_zatrudnienia = recode(trimws(as.character(typ_zatrudnienia)), !!!dict_typ_zatrudnienia),
kraj_pochodzenia = recode(trimws(as.character(kraj_pochodzenia)), !!!dict_kraj_pochodzenia),
czy_kapital_netto_dodatni = recode(trimws(as.character(czy_kapital_netto_dodatni)), !!!dict_czy_kapital_netto_dodatni),
wiek_sredni = recode(trimws(as.character(wiek_sredni)), !!!dict_wiek_sredni),
dlugie_godziny_pracy = recode(trimws(as.character(dlugie_godziny_pracy)), !!!dict_dlugie_godziny_pracy)
)
# Zamiana zmiennych tekstowych na zmienne kategoryczne
salary_pl_3 <- salary_pl_2 %>%
mutate(across(where(is.character), as.factor))
glimpse(salary_pl_3)
## Rows: 30,162
## Columns: 18
## $ klasa_dochodu <fct> dochód_do_50k, dochód_do_50k, dochód_do_50k,…
## $ wiek <int> 39, 50, 38, 53, 28, 37, 49, 52, 31, 42, 37, …
## $ poziom_edukacji <int> 13, 13, 9, 7, 13, 14, 5, 9, 14, 13, 10, 13, …
## $ stan_cywilny <fct> nigdy_niezamężny_niezonaty, małżeństwo_cywil…
## $ zawod <fct> administracja_biurowa, kadra_kierownicza, pr…
## $ relacja_w_rodzinie <fct> poza_rodziną, mąż, poza_rodziną, mąż, żona, …
## $ rasa <fct> biała, biała, biała, czarna, czarna, biała, …
## $ plec <fct> mężczyzna, mężczyzna, mężczyzna, mężczyzna, …
## $ godziny_pracy_tygodniowo <int> 40, 13, 40, 40, 40, 40, 16, 45, 50, 40, 80, …
## $ zysk_kapitalowy <int> 2174, 0, 0, 0, 0, 0, 0, 0, 14084, 5178, 0, 0…
## $ strata_kapitalowa <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ typ_zatrudnienia <fct> administracja_stanowa, samozatrudniony_bez_o…
## $ kraj_pochodzenia <fct> Stany_Zjednoczone, Stany_Zjednoczone, Stany_…
## $ kapital_netto <int> 2174, 0, 0, 0, 0, 0, 0, 0, 14084, 5178, 0, 0…
## $ czy_kapital_netto_dodatni <fct> kapital_netto_dodatni, kapital_netto_niedoda…
## $ godziny_razy_edukacja <int> 520, 169, 360, 280, 520, 560, 80, 405, 700, …
## $ wiek_sredni <fct> w_grupie_wieku_średniego, w_grupie_wieku_śre…
## $ dlugie_godziny_pracy <fct> nie, nie, nie, nie, nie, nie, nie, tak, tak,…
W kolejnym kroku przeanalizowano zależności pomiędzy zmiennymi numerycznymi. W tym celu utworzono macierz korelacji, która pozwala sprawdzić, czy między wybranymi cechami występują silne zależności liniowe. Jest to istotne, ponieważ bardzo silnie skorelowane zmienne mogą powielać podobną informację w modelu.
# Wybór zmiennych numerycznych
num_data_salary <- salary_pl_3 %>%
dplyr::select(where(is.numeric))
glimpse(num_data_salary)
## Rows: 30,162
## Columns: 7
## $ wiek <int> 39, 50, 38, 53, 28, 37, 49, 52, 31, 42, 37, 3…
## $ poziom_edukacji <int> 13, 13, 9, 7, 13, 14, 5, 9, 14, 13, 10, 13, 1…
## $ godziny_pracy_tygodniowo <int> 40, 13, 40, 40, 40, 40, 16, 45, 50, 40, 80, 4…
## $ zysk_kapitalowy <int> 2174, 0, 0, 0, 0, 0, 0, 0, 14084, 5178, 0, 0,…
## $ strata_kapitalowa <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ kapital_netto <int> 2174, 0, 0, 0, 0, 0, 0, 0, 14084, 5178, 0, 0,…
## $ godziny_razy_edukacja <int> 520, 169, 360, 280, 520, 560, 80, 405, 700, 5…
# Macierz korelacji
cor_mat_salary <- cor(num_data_salary, use = "pairwise.complete.obs")
# Wizualizacja macierzy korelacji
corrplot(
cor_mat_salary,
type = "lower",
tl.cex = 0.7
)
Warto również przeanalizować częstości występowania poszczególnych poziomów zmiennych kategorycznych. W przypadku zmiennej kraju pochodzenia widoczne jest bardzo silne skupienie obserwacji w jednej kategorii, czyli w Stanach Zjednoczonych. Pozostałe kraje występują znacznie rzadziej. Dlatego dla logitu zdecydowano się na połączenie tych kategorii
salary_pl_3 %>%
count(kraj_pochodzenia, sort = TRUE) %>%
mutate(kraj_pochodzenia = fct_reorder(kraj_pochodzenia, n)) %>%
ggplot(aes(x = kraj_pochodzenia, y = n)) +
geom_col() +
coord_flip() +
labs(
title = "Liczba obserwacji według kraju pochodzenia",
x = "Kraj pochodzenia",
y = "Liczba obserwacji"
) +
theme_minimal()
# Uproszczenie kraju pochodzenia do dwóch kategorii
salary_pl_3 <- salary_pl_3 %>%
mutate(
kraj_pochodzenia_grupa = if_else(
kraj_pochodzenia == "Stany_Zjednoczone",
"USA",
"Poza_USA",
missing = NA_character_
),
kraj_pochodzenia_grupa = as.factor(kraj_pochodzenia_grupa)
)
# Grupowanie poziomu edukacji do szerszych kategorii
salary_pl_3 <- salary_pl_3 %>%
mutate(
poziom_edukacji_grupa = case_when(
poziom_edukacji <= 8 ~ "podstawowe",
poziom_edukacji %in% c(9, 10) ~ "średnie_lub_niepełne_wyższe",
poziom_edukacji %in% c(11, 12, 13) ~ "associate_lub_licencjackie",
poziom_edukacji >= 14 ~ "magisterskie_lub_wyższe",
TRUE ~ NA_character_
),
poziom_edukacji_grupa = factor(
poziom_edukacji_grupa,
levels = c(
"podstawowe",
"średnie_lub_niepełne_wyższe",
"associate_lub_licencjackie",
"magisterskie_lub_wyższe"
),
ordered = FALSE
)
)
# Utworzenie binarnej zmiennej celu
salary_pl_3 <- salary_pl_3 %>%
mutate(
klasa_dochodu_01 = if_else(
klasa_dochodu == "dochód_powyżej_50k",
"tak",
"nie"
),
klasa_dochodu_01 = as.factor(klasa_dochodu_01)
)
# Finalny zbiór do modelu logitowego
salary_prep <- salary_pl_3 %>%
dplyr::select(-c(
kraj_pochodzenia,
poziom_edukacji,
klasa_dochodu,
kapital_netto,
zysk_kapitalowy,
strata_kapitalowa
))
# Kodowanie zmiennej celu jako 0/1 do regresji logistycznej
salary_prep <- salary_prep %>%
mutate(
klasa_dochodu_01 = if_else(klasa_dochodu_01 == "tak", 1, 0)
)
glimpse(salary_prep)
## Rows: 30,162
## Columns: 15
## $ wiek <int> 39, 50, 38, 53, 28, 37, 49, 52, 31, 42, 37, …
## $ stan_cywilny <fct> nigdy_niezamężny_niezonaty, małżeństwo_cywil…
## $ zawod <fct> administracja_biurowa, kadra_kierownicza, pr…
## $ relacja_w_rodzinie <fct> poza_rodziną, mąż, poza_rodziną, mąż, żona, …
## $ rasa <fct> biała, biała, biała, czarna, czarna, biała, …
## $ plec <fct> mężczyzna, mężczyzna, mężczyzna, mężczyzna, …
## $ godziny_pracy_tygodniowo <int> 40, 13, 40, 40, 40, 40, 16, 45, 50, 40, 80, …
## $ typ_zatrudnienia <fct> administracja_stanowa, samozatrudniony_bez_o…
## $ czy_kapital_netto_dodatni <fct> kapital_netto_dodatni, kapital_netto_niedoda…
## $ godziny_razy_edukacja <int> 520, 169, 360, 280, 520, 560, 80, 405, 700, …
## $ wiek_sredni <fct> w_grupie_wieku_średniego, w_grupie_wieku_śre…
## $ dlugie_godziny_pracy <fct> nie, nie, nie, nie, nie, nie, nie, tak, tak,…
## $ kraj_pochodzenia_grupa <fct> USA, USA, USA, USA, Poza_USA, USA, Poza_USA,…
## $ poziom_edukacji_grupa <fct> associate_lub_licencjackie, associate_lub_li…
## $ klasa_dochodu_01 <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,…
Sprawdźmy jeszcze czy w naszym zbiorze nie ma pustych wartości. Z poniższego outputu wynika, że nie mamy do czynienia z żadnymi brakami w danych. Zbiór do dalszej analizy jest gotowy, jak widać usunięto w nim również zmienne które powielały informacje jak kapitał_netto, zysk czy strata netto a pozostawiono zmienną czy kapitał netto dodatni.
colSums(is.na(salary_prep)) %>% sort() # nie ma braków
## wiek stan_cywilny zawod
## 0 0 0
## relacja_w_rodzinie rasa plec
## 0 0 0
## godziny_pracy_tygodniowo typ_zatrudnienia czy_kapital_netto_dodatni
## 0 0 0
## godziny_razy_edukacja wiek_sredni dlugie_godziny_pracy
## 0 0 0
## kraj_pochodzenia_grupa poziom_edukacji_grupa klasa_dochodu_01
## 0 0 0
Kluczowym etapem przygotowania zbioru do analizy przy wykorzystaniu technik uczenia maszynowego jest podział całego zbioru na zbiór treningowy (służące do estymacji modelu, wyboru zmiennych oraz strojenia hiperparametrów) oraz zbiór testowy (służące jedynie końcowej ocenie jakości modelu).
W przypadku naszego zbioru, który nie kieruje się żadną chronologią, możemy posłużyć się podziałem losowym. W taki przyapdku zakładamy, że obserwacje są niezależne i pochodzą z tego samego rozkładu (i.i.d.). Jest to prosty sposób podziału, który jest równocześnie zgodny z klasycznym podejściem Uczenia Maszynowego, estymacja błędu w tym przypadku jest stabilna a dane są wykorzystane efektywnie. Natomiast jeśli warunki rynkowe silnie zmieniały się w czasie model uwzględnia te zmiany zarówno na etapie treningu, jak i testowania.
W naszje analizie posłużymy się podziałem 70/30.
set.seed(123)
trening <- createDataPartition(salary_prep$klasa_dochodu_01,
p = 0.7,
list = FALSE)
salary.train <- salary_prep[c(trening),]
salary.test <- salary_prep[-c(trening),]
Jako pierwszy model benchmarkowy zastosowano logit, czyli klasyczny model wykorzystywany wtedy, gdy zmienna objaśniana ma charakter binarny. W naszym przypadku zmienna klasa_dochodu_01 przyjmuje wartość 1, gdy dana osoba osiąga dochód powyżej 50 tys. USD rocznie, oraz wartość 0 w przeciwnym przypadku.
Model logitowy nie przewiduje bezpośrednio klasy, lecz prawdopodobieństwo przynależności do klasy pozytywnej. Oznacza to, że dla każdej obserwacji model zwraca wartość z przedziału od 0 do 1, którą można interpretować jako prawdopodobieństwo osiągania dochodu powyżej 50 tys. USD.
W pierwszym kroku oszacowano pełny model regresji logistycznej, wykorzystując wszystkie dostępne zmienne objaśniające znajdujące się w przygotowanym zbiorze danych.
glm.full.salary <- glm(
klasa_dochodu_01 ~ .,
data = salary.train,
family = binomial
)
summary(glm.full.salary)
##
## Call:
## glm(formula = klasa_dochodu_01 ~ ., family = binomial, data = salary.train)
##
## Coefficients:
## Estimate Std. Error
## (Intercept) -2.6754588 0.3807457
## wiek 0.0294486 0.0020325
## stan_cywilnymałżeństwo_z_osobą_z_sił_zbrojnych 1.2242191 0.5656714
## stan_cywilnymałżonek_nieobecny -2.1825524 0.4047527
## stan_cywilnynigdy_niezamężny_niezonaty -2.4892457 0.3120541
## stan_cywilnyrozwiedziony_rozwiedziona -2.2613339 0.3189205
## stan_cywilnyw_separacji -2.2664291 0.3536429
## stan_cywilnywdowiec_wdowa -1.9068868 0.3539218
## zawodinne_usługi -0.9122487 0.1384000
## zawodkadra_kierownicza 0.8144842 0.0895568
## zawodoperatorzy_maszyn_i_inspektorzy -0.3997932 0.1179369
## zawodpracownicy_fizyczni_i_sprzątający -0.8191188 0.1708117
## zawodprywatne_usługi_domowe -12.6204938 127.4793512
## zawodrolnictwo_i_rybołówstwo -0.8856543 0.1583252
## zawodrzemiosło_i_naprawy -0.0157938 0.0933926
## zawodsiły_zbrojne -0.3556816 1.4053650
## zawodspecjaliści 0.5813813 0.0946097
## zawodsprzedaż 0.3530584 0.0958303
## zawodtransport_i_przemieszczanie -0.0892285 0.1154147
## zawodusługi_ochrony 0.5981511 0.1467605
## zawodwsparcie_techniczne 0.5923707 0.1301816
## relacja_w_rodziniemąż 0.2829615 0.2793491
## relacja_w_rodzinieniezamężny_niezonaty 0.4970554 0.3075370
## relacja_w_rodziniepoza_rodziną 0.8819446 0.2914316
## relacja_w_rodziniewłasne_dziecko -0.1253493 0.3071854
## relacja_w_rodzinieżona 1.5351340 0.2913368
## rasabiała -0.0635356 0.1394250
## rasaczarna -0.1422131 0.1611755
## rasainna -0.6782135 0.3413477
## rasardzenni_amerykanie_lub_eskimosi -0.5841652 0.2987833
## plecmężczyzna 0.7672716 0.0903320
## godziny_pracy_tygodniowo -0.0256287 0.0061979
## typ_zatrudnieniaadministracja_lokalna -0.6376198 0.1326488
## typ_zatrudnieniaadministracja_stanowa -0.7654349 0.1455334
## typ_zatrudnieniabez_wynagrodzenia -13.7862623 401.4844552
## typ_zatrudnieniasamozatrudniony_bez_osobowości_prawnej -0.8419172 0.1279671
## typ_zatrudnieniasamozatrudniony_z_osobowością_prawną -0.2362106 0.1426511
## typ_zatrudnieniasektor_prywatny -0.4448359 0.1102385
## czy_kapital_netto_dodatnikapital_netto_niedodatni -1.5661187 0.0672901
## godziny_razy_edukacja 0.0039352 0.0005241
## wiek_sredniw_grupie_wieku_średniego 0.7070402 0.0505085
## dlugie_godziny_pracytak 0.3014832 0.0609898
## kraj_pochodzenia_grupaUSA 0.2247643 0.0921824
## poziom_edukacji_grupaśrednie_lub_niepełne_wyższe 0.3165938 0.1228674
## poziom_edukacji_grupaassociate_lub_licencjackie 0.5127905 0.1777039
## poziom_edukacji_grupamagisterskie_lub_wyższe 0.8351007 0.2232604
## z value
## (Intercept) -7.027
## wiek 14.489
## stan_cywilnymałżeństwo_z_osobą_z_sił_zbrojnych 2.164
## stan_cywilnymałżonek_nieobecny -5.392
## stan_cywilnynigdy_niezamężny_niezonaty -7.977
## stan_cywilnyrozwiedziony_rozwiedziona -7.091
## stan_cywilnyw_separacji -6.409
## stan_cywilnywdowiec_wdowa -5.388
## zawodinne_usługi -6.591
## zawodkadra_kierownicza 9.095
## zawodoperatorzy_maszyn_i_inspektorzy -3.390
## zawodpracownicy_fizyczni_i_sprzątający -4.795
## zawodprywatne_usługi_domowe -0.099
## zawodrolnictwo_i_rybołówstwo -5.594
## zawodrzemiosło_i_naprawy -0.169
## zawodsiły_zbrojne -0.253
## zawodspecjaliści 6.145
## zawodsprzedaż 3.684
## zawodtransport_i_przemieszczanie -0.773
## zawodusługi_ochrony 4.076
## zawodwsparcie_techniczne 4.550
## relacja_w_rodziniemąż 1.013
## relacja_w_rodzinieniezamężny_niezonaty 1.616
## relacja_w_rodziniepoza_rodziną 3.026
## relacja_w_rodziniewłasne_dziecko -0.408
## relacja_w_rodzinieżona 5.269
## rasabiała -0.456
## rasaczarna -0.882
## rasainna -1.987
## rasardzenni_amerykanie_lub_eskimosi -1.955
## plecmężczyzna 8.494
## godziny_pracy_tygodniowo -4.135
## typ_zatrudnieniaadministracja_lokalna -4.807
## typ_zatrudnieniaadministracja_stanowa -5.260
## typ_zatrudnieniabez_wynagrodzenia -0.034
## typ_zatrudnieniasamozatrudniony_bez_osobowości_prawnej -6.579
## typ_zatrudnieniasamozatrudniony_z_osobowością_prawną -1.656
## typ_zatrudnieniasektor_prywatny -4.035
## czy_kapital_netto_dodatnikapital_netto_niedodatni -23.274
## godziny_razy_edukacja 7.509
## wiek_sredniw_grupie_wieku_średniego 13.998
## dlugie_godziny_pracytak 4.943
## kraj_pochodzenia_grupaUSA 2.438
## poziom_edukacji_grupaśrednie_lub_niepełne_wyższe 2.577
## poziom_edukacji_grupaassociate_lub_licencjackie 2.886
## poziom_edukacji_grupamagisterskie_lub_wyższe 3.740
## Pr(>|z|)
## (Intercept) 0.0000000000021118 ***
## wiek < 0.0000000000000002 ***
## stan_cywilnymałżeństwo_z_osobą_z_sił_zbrojnych 0.030450 *
## stan_cywilnymałżonek_nieobecny 0.0000000695572516 ***
## stan_cywilnynigdy_niezamężny_niezonaty 0.0000000000000015 ***
## stan_cywilnyrozwiedziony_rozwiedziona 0.0000000000013354 ***
## stan_cywilnyw_separacji 0.0000000001466635 ***
## stan_cywilnywdowiec_wdowa 0.0000000712951907 ***
## zawodinne_usługi 0.0000000000435724 ***
## zawodkadra_kierownicza < 0.0000000000000002 ***
## zawodoperatorzy_maszyn_i_inspektorzy 0.000699 ***
## zawodpracownicy_fizyczni_i_sprzątający 0.0000016231096794 ***
## zawodprywatne_usługi_domowe 0.921138
## zawodrolnictwo_i_rybołówstwo 0.0000000222033528 ***
## zawodrzemiosło_i_naprawy 0.865709
## zawodsiły_zbrojne 0.800200
## zawodspecjaliści 0.0000000007993977 ***
## zawodsprzedaż 0.000229 ***
## zawodtransport_i_przemieszczanie 0.439456
## zawodusługi_ochrony 0.0000458768836760 ***
## zawodwsparcie_techniczne 0.0000053559060923 ***
## relacja_w_rodziniemąż 0.311093
## relacja_w_rodzinieniezamężny_niezonaty 0.106041
## relacja_w_rodziniepoza_rodziną 0.002476 **
## relacja_w_rodziniewłasne_dziecko 0.683232
## relacja_w_rodzinieżona 0.0000001369627385 ***
## rasabiała 0.648607
## rasaczarna 0.377588
## rasainna 0.046937 *
## rasardzenni_amerykanie_lub_eskimosi 0.050566 .
## plecmężczyzna < 0.0000000000000002 ***
## godziny_pracy_tygodniowo 0.0000354877337519 ***
## typ_zatrudnieniaadministracja_lokalna 0.0000015334482761 ***
## typ_zatrudnieniaadministracja_stanowa 0.0000001444358345 ***
## typ_zatrudnieniabez_wynagrodzenia 0.972607
## typ_zatrudnieniasamozatrudniony_bez_osobowości_prawnej 0.0000000000473082 ***
## typ_zatrudnieniasamozatrudniony_z_osobowością_prawną 0.097750 .
## typ_zatrudnieniasektor_prywatny 0.0000545525206521 ***
## czy_kapital_netto_dodatnikapital_netto_niedodatni < 0.0000000000000002 ***
## godziny_razy_edukacja 0.0000000000000596 ***
## wiek_sredniw_grupie_wieku_średniego < 0.0000000000000002 ***
## dlugie_godziny_pracytak 0.0000007685947367 ***
## kraj_pochodzenia_grupaUSA 0.014758 *
## poziom_edukacji_grupaśrednie_lub_niepełne_wyższe 0.009974 **
## poziom_edukacji_grupaassociate_lub_licencjackie 0.003906 **
## poziom_edukacji_grupamagisterskie_lub_wyższe 0.000184 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for binomial family taken to be 1)
##
## Null deviance: 23695 on 21113 degrees of freedom
## Residual deviance: 14413 on 21068 degrees of freedom
## AIC: 14505
##
## Number of Fisher Scoring iterations: 14
Następnie wyznaczono przewidywane prawdopodobieństwa dla zbioru treningowego. W funkcji predict() zastosowano argument type = “response”, ponieważ dla modelu logistycznego pozwala on otrzymać prawdopodobieństwa przynależności do klasy pozytywnej.
pred.glm.train.prob <- predict( glm.full.salary, newdata = salary.train, type = "response" )
summary(pred.glm.train.prob)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 0.00000 0.01913 0.12163 0.24889 0.42026 0.99238
head(pred.glm.train.prob)
## 2 3 4 5 7 9
## 0.49164947 0.03160871 0.11627241 0.41251300 0.00678185 0.56088961
W klasyfikacji binarnej standardowo przyjmuje się próg odcięcia równy 0.5. W przypadku danych niezbalansowanych taki próg może jednak prowadzić do zbyt słabego wykrywania klasy mniejszościowej. Ponieważ osoby zarabiające powyżej 50 tys. USD stanowią mniejszą część obserwacji, wyznaczono optymalny próg odcięcia na podstawie krzywej ROC.
krzywa_roc <- pROC::roc(
response = salary.train$klasa_dochodu_01,
predictor = pred.glm.train.prob
)
plot(
krzywa_roc,
main = "Krzywa ROC - zbiór treningowy",
col = "blue",
lwd = 2,
print.auc = TRUE
)
optymalny_prog <- pROC::coords(
krzywa_roc,
x = "best",
best.method = "youden",
ret = c("threshold", "specificity", "sensitivity")
)
optymalny_prog
## threshold specificity sensitivity
## 1 0.2346138 0.7711079 0.8614653
optymalny_prog_glm <- optymalny_prog$threshold[1]
Optymalnym tresholdem= został punkt 0,2346
Po wyznaczeniu optymalnego progu odcięcia przewidywane
prawdopodobieństwa dla zbioru treningowego zostały przekształcone na
klasy. Jeżeli przewidywane prawdopodobieństwo było większe lub równe
progowi optymalny_prog_glm, obserwacja została przypisana
do klasy 1, czyli do grupy osób zarabiających powyżej 50 tys. USD. W
przeciwnym przypadku obserwacja została przypisana do klasy 0.
pred.glm.train.class <- ifelse(
pred.glm.train.prob >= optymalny_prog_glm,
1,
0
)
pred.glm.train.class <- factor(pred.glm.train.class, levels = c(0, 1))
salary.train$klasa_dochodu_01 <- factor(salary.train$klasa_dochodu_01, levels = c(0, 1))
cm.train <- confusionMatrix(
pred.glm.train.class,
salary.train$klasa_dochodu_01,
positive = "1"
)
cm.train
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 12229 728
## 1 3630 4527
##
## Accuracy : 0.7936
## 95% CI : (0.7881, 0.799)
## No Information Rate : 0.7511
## P-Value [Acc > NIR] : < 0.00000000000000022
##
## Kappa : 0.534
##
## Mcnemar's Test P-Value : < 0.00000000000000022
##
## Sensitivity : 0.8615
## Specificity : 0.7711
## Pos Pred Value : 0.5550
## Neg Pred Value : 0.9438
## Prevalence : 0.2489
## Detection Rate : 0.2144
## Detection Prevalence : 0.3863
## Balanced Accuracy : 0.8163
##
## 'Positive' Class : 1
##
cm.train$table
## Reference
## Prediction 0 1
## 0 12229 728
## 1 3630 4527
Macierz pomyłek pozwala sprawdzić, ile obserwacji zostało sklasyfikowanych poprawnie oraz jakie rodzaje błędów popełnia model. W przypadku tego problemu szczególnie ważna jest czułość, czyli zdolność modelu do wykrywania osób rzeczywiście zarabiających powyżej 50 tys. USD.
ggplot(
data.frame(
p = pred.glm.train.prob,
klasa = salary.train$klasa_dochodu_01
),
aes(x = p, fill = klasa)
) +
geom_histogram(alpha = 0.5, bins = 30, position = "identity") +
labs(
title = "Rozkład przewidywanych prawdopodobieństw - zbiór treningowy",
x = "P(dochód powyżej 50 tys. USD)",
y = "Liczba obserwacji",
fill = "Klasa"
) +
theme_minimal()
Powyższy histogram przedstawia rozkład przewidywanych prawdopodobieństw w podziale na rzeczywiste klasy. Im bardziej rozkłady dla klas 0 i 1 są od siebie oddzielone, tym lepsza jest zdolność modelu do rozróżniania osób o niższych i wyższych dochodach.
Następnie model zastosowano do zbioru testowego, który nie był wykorzystywany podczas estymacji. Dzięki temu można ocenić, jak model radzi sobie na nowych danych.
pred.glm.test.prob <- predict(
glm.full.salary,
newdata = salary.test,
type = "response"
)
head(pred.glm.test.prob)
## 1 6 8 12 24 25
## 0.32515206 0.83682007 0.55955963 0.43759321 0.03727325 0.52690441
summary(pred.glm.test.prob)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 0.00000 0.01922 0.11915 0.24933 0.42175 0.99545
Po otrzymaniu prawdopodobieństw dla zbioru testowego zastosowano ten sam próg odcięcia, który został wyznaczony na podstawie zbioru treningowego. Takie podejście jest poprawne, ponieważ próg nie powinien być dobierany na podstawie danych testowych.
pred.glm.test.class <- ifelse(
pred.glm.test.prob >= optymalny_prog_glm,
1,
0
)
pred.glm.test.class <- factor(pred.glm.test.class, levels = c(0, 1))
salary.test$klasa_dochodu_01 <- factor(salary.test$klasa_dochodu_01, levels = c(0, 1))
cm.test <- confusionMatrix(
pred.glm.test.class,
salary.test$klasa_dochodu_01,
positive = "1"
)
cm.test
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 5227 296
## 1 1568 1957
##
## Accuracy : 0.794
## 95% CI : (0.7855, 0.8023)
## No Information Rate : 0.751
## P-Value [Acc > NIR] : < 0.00000000000000022
##
## Kappa : 0.5366
##
## Mcnemar's Test P-Value : < 0.00000000000000022
##
## Sensitivity : 0.8686
## Specificity : 0.7692
## Pos Pred Value : 0.5552
## Neg Pred Value : 0.9464
## Prevalence : 0.2490
## Detection Rate : 0.2163
## Detection Prevalence : 0.3896
## Balanced Accuracy : 0.8189
##
## 'Positive' Class : 1
##
cm.test$table
## Reference
## Prediction 0 1
## 0 5227 296
## 1 1568 1957
Ocena na zbiorze testowym jest kluczowa, ponieważ pozwala sprawdzić zdolność generalizacji modelu. Jeżeli wyniki na zbiorze treningowym i testowym są do siebie zbliżone, można uznać, że model nie wykazuje silnych oznak przeuczenia.
ggplot(
data.frame(
p = pred.glm.test.prob,
klasa = salary.test$klasa_dochodu_01
),
aes(x = p, fill = klasa)
) +
geom_histogram(alpha = 0.5, bins = 30, position = "identity") +
labs(
title = "Rozkład przewidywanych prawdopodobieństw - zbiór testowy",
x = "P(dochód powyżej 50 tys. USD)",
y = "Liczba obserwacji",
fill = "Klasa"
) +
theme_minimal()
W celu uporządkowania oceny jakości modeli przygotowano funkcję, która dla zadanych prawdopodobieństw, wartości rzeczywistych oraz progu odcięcia oblicza najważniejsze metryki klasyfikacji binarnej. Funkcja zwraca między innymi Accuracy, Sensitivity, Specificity, Balanced Accuracy, PPV, NPV, F1 oraz Kappa.
klasyfikacja.metryki <- function(predicted_probabilities,
real,
cutoff = 0.5,
level_positive,
level_negative) {
predicted_class <- ifelse(
predicted_probabilities >= cutoff,
level_positive,
level_negative
)
predicted_class <- factor(
predicted_class,
levels = c(level_negative, level_positive)
)
real <- factor(
real,
levels = c(level_negative, level_positive)
)
ctable <- confusionMatrix(
data = predicted_class,
reference = real,
positive = as.character(level_positive)
)
statystyki <- c(
Accuracy = unname(ctable$overall["Accuracy"]),
Sensitivity = unname(ctable$byClass["Sensitivity"]),
Specificity = unname(ctable$byClass["Specificity"]),
BalancedAccuracy = unname(ctable$byClass["Balanced Accuracy"]),
PPV = unname(ctable$byClass["Pos Pred Value"]),
NPV = unname(ctable$byClass["Neg Pred Value"]),
F1 = unname(ctable$byClass["F1"]),
Kappa = unname(ctable$overall["Kappa"])
)
wynik <- round(statystyki, 4)
return(wynik)
}
Następnie obliczono metryki jakości dla zbioru treningowego i testowego przy wykorzystaniu optymalnego progu odcięcia.
metryki.glm.train <- klasyfikacja.metryki(
predicted_probabilities = pred.glm.train.prob,
real = salary.train$klasa_dochodu_01,
cutoff = optymalny_prog_glm,
level_positive = "1",
level_negative = "0"
)
metryki.glm.test <- klasyfikacja.metryki(
predicted_probabilities = pred.glm.test.prob,
real = salary.test$klasa_dochodu_01,
cutoff = optymalny_prog_glm,
level_positive = "1",
level_negative = "0"
)
rbind(
train = metryki.glm.train,
test = metryki.glm.test
)
## Accuracy Sensitivity Specificity BalancedAccuracy PPV NPV F1
## train 0.7936 0.8615 0.7711 0.8163 0.5550 0.9438 0.6751
## test 0.7940 0.8686 0.7692 0.8189 0.5552 0.9464 0.6774
## Kappa
## train 0.5340
## test 0.5366
Porównanie wyników dla zbioru treningowego i testowego pozwala ocenić stabilność modelu. Co ważne niewidoczny jest overfiting, model przyjmuje również całkiem zadowaloające wartości predyckyjne. A co zaskakujące minimalnie lepsze Sensivity przyjmuje nawet dla danych testowych w porównaniu z treningowymi.
Na końcu porównano wyniki modelu dla standardowego progu 0.5 oraz progu optymalnego wyznaczonego na podstawie krzywej ROC.
metryki.glm.test.standard <- klasyfikacja.metryki(
predicted_probabilities = pred.glm.test.prob,
real = salary.test$klasa_dochodu_01,
cutoff = 0.5,
level_positive = "1",
level_negative = "0"
)
rbind(
logit.cutoff.0.5 = metryki.glm.test.standard,
logit.cutoff.optymalny = metryki.glm.test
)
## Accuracy Sensitivity Specificity BalancedAccuracy PPV
## logit.cutoff.0.5 0.8443 0.5992 0.9255 0.7624 0.7274
## logit.cutoff.optymalny 0.7940 0.8686 0.7692 0.8189 0.5552
## NPV F1 Kappa
## logit.cutoff.0.5 0.8744 0.6571 0.5576
## logit.cutoff.optymalny 0.9464 0.6774 0.5366
Zauważmy, jak zmiana progu z domyślnego 0.5 na wyliczone 0.2246 zmieniła wyniki dla modelu Ogólne Accuracy minimalnie spadło (z około 84,4% na 79,40%), ale za to Sensitivity (czułość) drastycznie wzrosła! Model z nowym progiem poprawnie wykrywa ponad 86% osób zarabiających powyżej 50k (TP: 4527), podczas gdy na progu 0.5 wykrywał ich tylko około 59% (wtedy TP wynosiło 3107). To świetny dowód na to, dlaczego szukanie optymalnego progu odcięcia ma tak ogromne znaczenie na niezbalansowanych zbiorach.
Drugim modelem benchmarkowym wykorzystanym w analizie jest drzewo klasyfikacyjne CART. W przeciwieństwie do modelu logitowego, drzewo decyzyjne nie zakłada liniowej zależności między zmiennymi objaśniającymi a zmienną celu. Model ten dzieli obserwacje na coraz bardziej jednorodne grupy, wykorzystując kolejne reguły decyzyjne. Jego zaletą jest wysoka interpretowalność, ponieważ wyniki można przedstawić w postaci graficznego drzewa.
Na potrzeby modeli uczenia maszynowego przygotowano osobny zbiór
danych. W przypadku drzew decyzyjnych zmienna celu pozostaje faktorem z
poziomami nie oraz tak, ponieważ modele
klasyfikacyjne w pakiecie rpart oraz caret
wymagają, aby zmienna objaśniana była zmienną kategoryczną.
salary_ML <- salary_pl_3 %>%
dplyr::select(-c(kraj_pochodzenia_grupa, poziom_edukacji_grupa, klasa_dochodu))
Ponadto zmienne dla edukacji postanowiono odgrupować, aby modele ML miały cały zakres zmiennych do dyspozycji
salary_ML <- salary_ML %>%
mutate(
poziom_edukacji_factor = case_when(
poziom_edukacji == 1 ~ "przedszkole",
poziom_edukacji == 2 ~ "klasy_1_4",
poziom_edukacji == 3 ~ "klasy_5_6",
poziom_edukacji == 4 ~ "klasy_7_8",
poziom_edukacji == 5 ~ "klasa_9",
poziom_edukacji == 6 ~ "klasa_10",
poziom_edukacji == 7 ~ "klasa_11",
poziom_edukacji == 8 ~ "klasa_12",
poziom_edukacji == 9 ~ "szkoła_średnia",
poziom_edukacji == 10 ~ "część_studiów",
poziom_edukacji == 11 ~ "associate_zawodowe",
poziom_edukacji == 12 ~ "associate_akademickie",
poziom_edukacji == 13 ~ "licencjat",
poziom_edukacji == 14 ~ "magister",
poziom_edukacji == 15 ~ "szkoła_profesjonalna",
poziom_edukacji == 16 ~ "doktorat",
TRUE ~ NA_character_
),
poziom_edukacji_factor = factor(
poziom_edukacji_factor,
levels = c(
"przedszkole",
"klasy_1_4",
"klasy_5_6",
"klasy_7_8",
"klasa_9",
"klasa_10",
"klasa_11",
"klasa_12",
"szkoła_średnia",
"część_studiów",
"associate_zawodowe",
"associate_akademickie",
"licencjat",
"magister",
"szkoła_profesjonalna",
"doktorat"
),
ordered = FALSE
)
)
salary_ML <- salary_ML %>%
dplyr::select(-poziom_edukacji)
Nie potraktowano poziomu edukacji jako zmiennej porządkowej, ponieważ
niektóre poziomy nie muszą tworzyć jednoznacznej hierarchii jakościowej.
Przykładowo poziom szkoła_profesjonalna nie musi być wprost
porównywalny z poziomem licencjat, dlatego zmienną
potraktowano jako zwykły factor.
salary_ML$klasa_dochodu_01 <- factor(
salary_ML$klasa_dochodu_01,
levels = c("nie", "tak")
)
set.seed(123)
trening_ML <- createDataPartition(
salary_ML$klasa_dochodu_01,
p = 0.7,
list = FALSE
)
salary.train.ML <- salary_ML[c(trening_ML), ]
salary.test.ML <- salary_ML[-c(trening_ML), ]
W pierwszym kroku oszacowano podstawowe drzewo klasyfikacyjne z
domyślnymi ustawieniami funkcji rpart(). Model ten stanowi
punkt odniesienia dla późniejszego strojenia hiperparametrów.
tree.salary.default <- rpart(
klasa_dochodu_01 ~ .,
data = salary.train.ML,
method = "class"
)
tree.salary.default
## n= 21114
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 21114 5256 nie (0.75106564 0.24893436)
## 2) relacja_w_rodzinie=inny_krewny,niezamężny_niezonaty,poza_rodziną,własne_dziecko 11367 766 nie (0.93261195 0.06738805)
## 4) zysk_kapitalowy< 7073.5 11166 570 nie (0.94895218 0.05104782) *
## 5) zysk_kapitalowy>=7073.5 201 5 tak (0.02487562 0.97512438) *
## 3) relacja_w_rodzinie=mąż,żona 9747 4490 nie (0.53934544 0.46065456)
## 6) poziom_edukacji_factor=przedszkole,klasy_1_4,klasy_5_6,klasy_7_8,klasa_9,klasa_10,klasa_11,klasa_12,szkoła_średnia,część_studiów,associate_zawodowe,associate_akademickie 6856 2355 nie (0.65650525 0.34349475)
## 12) zysk_kapitalowy< 5095.5 6500 2006 nie (0.69138462 0.30861538) *
## 13) zysk_kapitalowy>=5095.5 356 7 tak (0.01966292 0.98033708) *
## 7) poziom_edukacji_factor=licencjat,magister,szkoła_profesjonalna,doktorat 2891 756 tak (0.26150121 0.73849879) *
summary(tree.salary.default)
## Call:
## rpart(formula = klasa_dochodu_01 ~ ., data = salary.train.ML,
## method = "class")
## n= 21114
##
## CP nsplit rel error xerror xstd
## 1 0.13118341 0 1.0000000 1.0000000 0.01195395
## 2 0.06506849 2 0.7376332 0.7376332 0.01070380
## 3 0.03633942 3 0.6725647 0.6725647 0.01032170
## 4 0.01000000 4 0.6362253 0.6362253 0.01009337
##
## Variable importance
## relacja_w_rodzinie stan_cywilny kapital_netto
## 22 22 9
## zysk_kapitalowy poziom_edukacji_factor plec
## 9 9 7
## zawod godziny_razy_edukacja wiek
## 6 5 5
## wiek_sredni czy_kapital_netto_dodatni
## 4 1
##
## Node number 1: 21114 observations, complexity param=0.1311834
## predicted class=nie expected loss=0.2489344 P(node) =1
## class counts: 15858 5256
## probabilities: 0.751 0.249
## left son=2 (11367 obs) right son=3 (9747 obs)
## Primary splits:
## relacja_w_rodzinie splits as LRLLLR, improve=1623.1180, (0 missing)
## stan_cywilny splits as RRLLLLL, improve=1597.2620, (0 missing)
## zysk_kapitalowy < 5119 to the left, improve=1081.9280, (0 missing)
## kapital_netto < 5119 to the left, improve=1081.9280, (0 missing)
## godziny_razy_edukacja < 488 to the left, improve= 933.1938, (0 missing)
## Surrogate splits:
## stan_cywilny splits as RRLLLLL, agree=0.993, adj=0.986, (0 split)
## plec splits as LR, agree=0.690, adj=0.327, (0 split)
## wiek < 33.5 to the left, agree=0.645, adj=0.231, (0 split)
## zawod splits as LLRLLLRRLRLRRL, agree=0.619, adj=0.174, (0 split)
## wiek_sredni splits as LR, agree=0.615, adj=0.166, (0 split)
##
## Node number 2: 11367 observations, complexity param=0.03633942
## predicted class=nie expected loss=0.06738805 P(node) =0.5383632
## class counts: 10601 766
## probabilities: 0.933 0.067
## left son=4 (11166 obs) right son=5 (201 obs)
## Primary splits:
## zysk_kapitalowy < 7073.5 to the left, improve=337.20480, (0 missing)
## kapital_netto < 7073.5 to the left, improve=337.20480, (0 missing)
## godziny_razy_edukacja < 559.5 to the left, improve=128.56960, (0 missing)
## czy_kapital_netto_dodatni splits as RL, improve=106.51450, (0 missing)
## poziom_edukacji_factor splits as LLLLLLLLLLLLLRRR, improve= 95.72629, (0 missing)
## Surrogate splits:
## kapital_netto < 7073.5 to the left, agree=1, adj=1, (0 split)
##
## Node number 3: 9747 observations, complexity param=0.1311834
## predicted class=nie expected loss=0.4606546 P(node) =0.4616368
## class counts: 5257 4490
## probabilities: 0.539 0.461
## left son=6 (6856 obs) right son=7 (2891 obs)
## Primary splits:
## poziom_edukacji_factor splits as LLLLLLLLLLLLRRRR, improve=634.5721, (0 missing)
## zawod splits as RLRLLLLLRRRLRR, improve=600.1650, (0 missing)
## godziny_razy_edukacja < 482.5 to the left, improve=559.9230, (0 missing)
## zysk_kapitalowy < 5095.5 to the left, improve=492.1049, (0 missing)
## kapital_netto < 5095.5 to the left, improve=492.1049, (0 missing)
## Surrogate splits:
## godziny_razy_edukacja < 517 to the left, agree=0.880, adj=0.597, (0 split)
## zawod splits as LLRLLLLLRRLLLL, agree=0.792, adj=0.300, (0 split)
## zysk_kapitalowy < 7493 to the left, agree=0.720, adj=0.055, (0 split)
## kapital_netto < 7493 to the left, agree=0.720, adj=0.055, (0 split)
## kraj_pochodzenia splits as RRLLRLLLL-RLRRLLRLLLLRLLLLLLLLLLLLLRLLLLL, agree=0.713, adj=0.031, (0 split)
##
## Node number 4: 11166 observations
## predicted class=nie expected loss=0.05104782 P(node) =0.5288434
## class counts: 10596 570
## probabilities: 0.949 0.051
##
## Node number 5: 201 observations
## predicted class=tak expected loss=0.02487562 P(node) =0.00951975
## class counts: 5 196
## probabilities: 0.025 0.975
##
## Node number 6: 6856 observations, complexity param=0.06506849
## predicted class=nie expected loss=0.3434947 P(node) =0.3247135
## class counts: 4501 2355
## probabilities: 0.657 0.343
## left son=12 (6500 obs) right son=13 (356 obs)
## Primary splits:
## zysk_kapitalowy < 5095.5 to the left, improve=304.5799, (0 missing)
## kapital_netto < 5095.5 to the left, improve=304.5799, (0 missing)
## zawod splits as RLRLLLLL-RRLRR, improve=174.9527, (0 missing)
## godziny_razy_edukacja < 347 to the left, improve=143.5213, (0 missing)
## poziom_edukacji_factor splits as LLLLLLLRRRRR----, improve=121.2879, (0 missing)
## Surrogate splits:
## kapital_netto < 5095.5 to the left, agree=1.000, adj=1.000, (0 split)
## czy_kapital_netto_dodatni splits as RL, agree=0.956, adj=0.143, (0 split)
##
## Node number 7: 2891 observations
## predicted class=tak expected loss=0.2615012 P(node) =0.1369234
## class counts: 756 2135
## probabilities: 0.262 0.738
##
## Node number 12: 6500 observations
## predicted class=nie expected loss=0.3086154 P(node) =0.3078526
## class counts: 4494 2006
## probabilities: 0.691 0.309
##
## Node number 13: 356 observations
## predicted class=tak expected loss=0.01966292 P(node) =0.01686085
## class counts: 7 349
## probabilities: 0.020 0.980
Poniżej przedstawiono podstawową wizualizację drzewa decyzyjnego.
rpart.plot(tree.salary.default)
Dla większej czytelności przygotowano również bardziej rozbudowaną wizualizację drzewa, pokazującą dodatkowe informacje w węzłach końcowych.
rpart.plot(
tree.salary.default,
type = 2,
extra = 104,
fallen.leaves = TRUE,
nn = TRUE,
box.palette = "Blues",
branch.lty = 2,
shadow.col = "gray"
)
Następnie wyznaczono przewidywane prawdopodobieństwa przynależności
do klasy tak, osobno dla zbioru treningowego i
testowego.
pred.tree.default.train.prob <- predict(
tree.salary.default,
newdata = salary.train.ML,
type = "prob"
)[, "tak"]
pred.tree.default.test.prob <- predict(
tree.salary.default,
newdata = salary.test.ML,
type = "prob"
)[, "tak"]
Ocena drzewa domyślnego została przeprowadzona przy domyślnym progu odcięcia równym 0.5.
metryki.tree.default.train <- klasyfikacja.metryki(
predicted_probabilities = pred.tree.default.train.prob,
real = salary.train.ML$klasa_dochodu_01,
cutoff = 0.5,
level_positive = "tak",
level_negative = "nie"
)
metryki.tree.default.test <- klasyfikacja.metryki(
predicted_probabilities = pred.tree.default.test.prob,
real = salary.test.ML$klasa_dochodu_01,
cutoff = 0.5,
level_positive = "tak",
level_negative = "nie"
)
rbind(
train = metryki.tree.default.train,
test = metryki.tree.default.test
)
## Accuracy Sensitivity Specificity BalancedAccuracy PPV NPV F1
## train 0.8416 0.5099 0.9516 0.7307 0.7773 0.8542 0.6158
## test 0.8397 0.5115 0.9485 0.7300 0.7670 0.8542 0.6137
## Kappa
## train 0.5214
## test 0.5177
Wyniki drzewa domyślnego są bardzo podobne do wyników modelu logitowego przy standardowym progu odcięcia 0.5. Oznacza to, że proste drzewo jest w stanie uchwycić podstawowe zależności w danych, jednak podobnie jak logit przy progu 0.5 może mieć problem z odpowiednio dobrym wykrywaniem klasy mniejszościowej.
Podobnie jak w przypadku modelu logitowego, dla drzewa decyzyjnego wyznaczono optymalny próg odcięcia na podstawie krzywej ROC i metody Youdena.
krzywa_roc_tree <- pROC::roc(
response = salary.train.ML$klasa_dochodu_01,
predictor = pred.tree.default.train.prob
)
optymalny_prog_tree <- pROC::coords(
krzywa_roc_tree,
x = "best",
best.method = "youden",
ret = c("threshold", "specificity", "sensitivity")
)
prog_tree <- optymalny_prog_tree$threshold[1]
print(paste("Optymalny próg dla drzewa to:", round(prog_tree, 4)))
## [1] "Optymalny próg dla drzewa to: 0.1798"
Optymalny próg dla drzewa domyślnego wyniósł 0,1798.
Jest to wartość niższa od standardowego progu 0.5, co wynika z
niezbalansowania klas. Obniżenie progu zwiększa skłonność modelu do
klasyfikowania obserwacji jako tak, a więc poprawia
wykrywanie osób zarabiających powyżej 50 tys. USD.
metryki.tree.optimal.train <- klasyfikacja.metryki(
predicted_probabilities = pred.tree.default.train.prob,
real = salary.train.ML$klasa_dochodu_01,
cutoff = prog_tree,
level_positive = "tak",
level_negative = "nie"
)
metryki.tree.optimal.test <- klasyfikacja.metryki(
predicted_probabilities = pred.tree.default.test.prob,
real = salary.test.ML$klasa_dochodu_01,
cutoff = prog_tree,
level_positive = "tak",
level_negative = "nie"
)
rbind(
train = metryki.tree.optimal.train,
test = metryki.tree.optimal.test
)
## Accuracy Sensitivity Specificity BalancedAccuracy PPV NPV F1
## train 0.7238 0.8916 0.6682 0.7799 0.4710 0.949 0.6164
## test 0.7216 0.8779 0.6698 0.7738 0.4684 0.943 0.6108
## Kappa
## train 0.4311
## test 0.4238
Przeprowadzona analiza klasyfikacyjna na zbiorze danych Adult Census Income pozwoliła na porównanie dwóch odmiennych podejść algorytmicznych: tradycyjnego modelu regresji logistycznej oraz nieparametrycznego drzewa decyzyjnego CART. Ze względu na silne niezbalansowanie klas, w proporcji około 75% do 25%, kluczowym elementem podnoszącym jakość obu modeli była rezygnacja z domyślnego progu klasyfikacji 0.5 na rzecz progów optymalizowanych metodą Youdena. Progi te wyniosły około 0,2346 dla logitu oraz około 0,18 dla drzewa.
W ostatecznym porównaniu na zbiorze testowym model regresji logistycznej okazał się silniejszy ze względu na zbalansowaną dokładność. Model logitowy osiągnął Balanced Accuracy około 81,9%, podczas gdy proste drzewo uzyskało około 72,16%.
Jednocześnie drzewo decyzyjne osiągnęło bardzo wysoką czułość. Czułość modelu logitowego wyniosła około 86,8%, natomiast dla drzewa było to około 87,79%. Oznacza to, że jeżeli najważniejszym celem byłoby wykrywanie klasy pozytywnej, czyli osób zarabiających powyżej 50 tys. USD, drzewo może być atrakcyjnym rozwiązaniem.
Warto jednak podkreślić brak wyraźnego zjawiska przeuczenia w obu modelach, ponieważ metryki na zbiorach treningowych i testowych są do siebie zbliżone. Oznacza to, że modele zachowują zdolność generalizacji na nowych danych. Logit osiąga lepszą skuteczność predykcyjną w sensie zbalansowanej dokładności, natomiast drzewo decyzyjne oferuje większą interpretowalność i możliwość prześledzenia ścieżki decyzyjnej.
Domyślne drzewo okazało się dość proste, dlatego w kolejnym kroku zbudowano model z kontrolą parametrów wzrostu drzewa. Celem było sprawdzenie, czy odpowiednie strojenie hiperparametrów pozwoli poprawić jakość klasyfikacji.
W modelach drzewiastych szczególne znaczenie mają następujące parametry:
cp, czyli complexity parameter, określający minimalną
poprawę jakości modelu wymaganą do dodania kolejnego podziału,minsplit, czyli minimalna liczba obserwacji w węźle
potrzebna do rozważenia dalszego podziału,minbucket, czyli minimalna liczba obserwacji w liściu
końcowym,maxdepth, czyli maksymalna głębokość drzewa,xval, czyli liczba części wykorzystywanych w
wewnętrznej walidacji krzyżowej.W niniejszej pracy zastosowano kontrolę złożoności drzewa. Parametr
cp dobrano automatycznie przy użyciu walidacji
krzyżowej.
salary.train.caret <- salary.train.ML
salary.test.caret <- salary.test.ML
salary.train.caret$klasa_dochodu_01 <- factor(
salary.train.caret$klasa_dochodu_01,
labels = c("nie", "tak")
)
salary.test.caret$klasa_dochodu_01 <- factor(
salary.test.caret$klasa_dochodu_01,
labels = c("nie", "tak")
)
ctrl <- trainControl(
method = "cv",
number = 5,
classProbs = TRUE,
summaryFunction = twoClassSummary,
savePredictions = "final"
)
grid.cp <- expand.grid(
cp = seq(0.0000, 0.001, length = 50)
)
Następnie przeprowadzono strojenie drzewa decyzyjnego. Jako główną metrykę optymalizacji przyjęto ROC, ponieważ w przypadku niezbalansowanych danych jest ona bardziej informacyjna niż sama ogólna trafność klasyfikacji.
set.seed(123)
tree.salary.tuned <- train(
klasa_dochodu_01 ~ .,
data = salary.train.caret,
method = "rpart",
trControl = ctrl,
tuneGrid = grid.cp,
metric = "ROC",
control = rpart.control(
minsplit = 30,
minbucket = 10,
maxdepth = 10
)
)
tree.salary.tuned$bestTune
## cp
## 3 0.00004081633
plot(tree.salary.tuned)
W celu maksymalizacji skuteczności drzewa decyzyjnego przeprowadzono
proces strojenia hiperparametrów metodą walidacji krzyżowej. Szukano
optymalnej wartości parametru złożoności cp, który stanowi
karę za nadmierny rozrost drzewa. Im wyższy parametr cp,
tym silniej drzewo jest przycinane.
Z analizy wynika, że najwyższą skuteczność rozróżniania klas, mierzoną wartością ROC na poziomie około 0,89, osiągnięto dla bardzo niskiej wartości parametru: cp = 0.00004081633. Zwiększanie tej wartości, czyli silniejsze obcinanie gałęzi, prowadziło do widocznego spadku wydajności modelu. Oznacza to, że struktura danych jest dość złożona i wymaga zbudowania bardziej rozbudowanego drzewa, aby precyzyjnie uchwycić wzorce decydujące o wysokich zarobkach.
Jednak wizualizacja w tym przypadku nie przynosi tak naprawdę żadnych rezultatów. Obraz jest po prostu niewidoczny.
tree.salary.tuned$finalModel
## n= 21114
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 21114 5256 nie (0.751065644 0.248934356)
## 2) relacja_w_rodziniemąż< 0.5 12368 1269 nie (0.897396507 0.102603493)
## 4) zysk_kapitalowy< 7073.5 12087 994 nie (0.917762886 0.082237114)
## 8) relacja_w_rodzinieżona< 0.5 11166 570 nie (0.948952176 0.051047824)
## 16) godziny_razy_edukacja< 558.5 9823 285 nie (0.970986460 0.029013540)
## 32) strata_kapitalowa< 2218.5 9779 263 nie (0.973105635 0.026894365) *
## 33) strata_kapitalowa>=2218.5 44 22 nie (0.500000000 0.500000000)
## 66) godziny_razy_edukacja< 370 21 6 nie (0.714285714 0.285714286) *
## 67) godziny_razy_edukacja>=370 23 7 tak (0.304347826 0.695652174) *
## 17) godziny_razy_edukacja>=558.5 1343 285 nie (0.787788533 0.212211467)
## 34) wiek< 27.5 262 10 nie (0.961832061 0.038167939) *
## 35) wiek>=27.5 1081 275 nie (0.745605920 0.254394080)
## 70) strata_kapitalowa< 2351 1068 262 nie (0.754681648 0.245318352)
## 140) zawodkadra_kierownicza< 0.5 835 170 nie (0.796407186 0.203592814)
## 280) poziom_edukacji_factorszkoła_profesjonalna< 0.5 766 140 nie (0.817232376 0.182767624) *
## 281) poziom_edukacji_factorszkoła_profesjonalna>=0.5 69 30 nie (0.565217391 0.434782609)
## 562) wiek< 32.5 20 3 nie (0.850000000 0.150000000) *
## 563) wiek>=32.5 49 22 tak (0.448979592 0.551020408)
## 1126) godziny_pracy_tygodniowo< 44 23 8 nie (0.652173913 0.347826087) *
## 1127) godziny_pracy_tygodniowo>=44 26 7 tak (0.269230769 0.730769231) *
## 141) zawodkadra_kierownicza>=0.5 233 92 nie (0.605150215 0.394849785)
## 282) poziom_edukacji_factormagister< 0.5 167 57 nie (0.658682635 0.341317365)
## 564) wiek< 39.5 87 22 nie (0.747126437 0.252873563) *
## 565) wiek>=39.5 80 35 nie (0.562500000 0.437500000)
## 1130) godziny_pracy_tygodniowo>=59 31 6 nie (0.806451613 0.193548387) *
## 1131) godziny_pracy_tygodniowo< 59 49 20 tak (0.408163265 0.591836735) *
## 283) poziom_edukacji_factormagister>=0.5 66 31 tak (0.469696970 0.530303030)
## 566) typ_zatrudnieniasektor_prywatny< 0.5 26 9 nie (0.653846154 0.346153846) *
## 567) typ_zatrudnieniasektor_prywatny>=0.5 40 14 tak (0.350000000 0.650000000)
## 1134) wiek< 36 12 4 nie (0.666666667 0.333333333) *
## 1135) wiek>=36 28 6 tak (0.214285714 0.785714286) *
## 71) strata_kapitalowa>=2351 13 0 tak (0.000000000 1.000000000) *
## 9) relacja_w_rodzinieżona>=0.5 921 424 nie (0.539630836 0.460369164)
## 18) godziny_razy_edukacja< 374 455 151 nie (0.668131868 0.331868132)
## 36) strata_kapitalowa< 1794 438 136 nie (0.689497717 0.310502283)
## 72) wiek< 34.5 139 23 nie (0.834532374 0.165467626) *
## 73) wiek>=34.5 299 113 nie (0.622073579 0.377926421)
## 146) zawodinne_usługi>=0.5 61 9 nie (0.852459016 0.147540984) *
## 147) zawodinne_usługi< 0.5 238 104 nie (0.563025210 0.436974790)
## 294) wiek>=45.5 126 41 nie (0.674603175 0.325396825)
## 588) zawodkadra_kierownicza< 0.5 110 30 nie (0.727272727 0.272727273) *
## 589) zawodkadra_kierownicza>=0.5 16 5 tak (0.312500000 0.687500000) *
## 295) wiek< 45.5 112 49 tak (0.437500000 0.562500000)
## 590) zawodoperatorzy_maszyn_i_inspektorzy>=0.5 13 4 nie (0.692307692 0.307692308) *
## 591) zawodoperatorzy_maszyn_i_inspektorzy< 0.5 99 40 tak (0.404040404 0.595959596)
## 1182) zawodspecjaliści>=0.5 21 9 nie (0.571428571 0.428571429) *
## 1183) zawodspecjaliści< 0.5 78 28 tak (0.358974359 0.641025641) *
## 37) strata_kapitalowa>=1794 17 2 tak (0.117647059 0.882352941) *
## 19) godziny_razy_edukacja>=374 466 193 tak (0.414163090 0.585836910)
## 38) zawodspecjaliści< 0.5 331 161 tak (0.486404834 0.513595166)
## 76) zawodkadra_kierownicza< 0.5 222 94 nie (0.576576577 0.423423423)
## 152) godziny_pracy_tygodniowo>=43 64 14 nie (0.781250000 0.218750000) *
## 153) godziny_pracy_tygodniowo< 43 158 78 tak (0.493670886 0.506329114)
## 306) wiek< 33.5 61 22 nie (0.639344262 0.360655738)
## 612) typ_zatrudnieniasektor_prywatny>=0.5 51 16 nie (0.686274510 0.313725490)
## 1224) godziny_razy_edukacja< 500 40 10 nie (0.750000000 0.250000000) *
## 1225) godziny_razy_edukacja>=500 11 5 tak (0.454545455 0.545454545) *
## 613) typ_zatrudnieniasektor_prywatny< 0.5 10 4 tak (0.400000000 0.600000000) *
## 307) wiek>=33.5 97 39 tak (0.402061856 0.597938144)
## 614) rasabiała< 0.5 14 6 nie (0.571428571 0.428571429) *
## 615) rasabiała>=0.5 83 31 tak (0.373493976 0.626506024) *
## 77) zawodkadra_kierownicza>=0.5 109 33 tak (0.302752294 0.697247706)
## 154) wiek_sredniw_grupie_wieku_średniego< 0.5 30 13 nie (0.566666667 0.433333333)
## 308) poziom_edukacji_factorlicencjat< 0.5 14 4 nie (0.714285714 0.285714286) *
## 309) poziom_edukacji_factorlicencjat>=0.5 16 7 tak (0.437500000 0.562500000) *
## 155) wiek_sredniw_grupie_wieku_średniego>=0.5 79 16 tak (0.202531646 0.797468354) *
## 39) zawodspecjaliści>=0.5 135 32 tak (0.237037037 0.762962963) *
## 5) zysk_kapitalowy>=7073.5 281 6 tak (0.021352313 0.978647687) *
## 3) relacja_w_rodziniemąż>=0.5 8746 3987 nie (0.544134461 0.455865539)
## 6) godziny_razy_edukacja< 488 5152 1587 nie (0.691964286 0.308035714)
## 12) zysk_kapitalowy< 5095.5 4916 1356 nie (0.724165989 0.275834011)
## 24) godziny_razy_edukacja< 307 1134 122 nie (0.892416226 0.107583774)
## 48) strata_kapitalowa< 1780 1104 107 nie (0.903079710 0.096920290) *
## 49) strata_kapitalowa>=1780 30 15 nie (0.500000000 0.500000000)
## 98) kapital_netto< -1989.5 17 4 nie (0.764705882 0.235294118) *
## 99) kapital_netto>=-1989.5 13 2 tak (0.153846154 0.846153846) *
## 25) godziny_razy_edukacja>=307 3782 1234 nie (0.673717610 0.326282390)
## 50) wiek< 35.5 1234 237 nie (0.807941653 0.192058347)
## 100) wiek< 26.5 272 21 nie (0.922794118 0.077205882) *
## 101) wiek>=26.5 962 216 nie (0.775467775 0.224532225)
## 202) strata_kapitalowa< 1794 935 199 nie (0.787165775 0.212834225)
## 404) poziom_edukacji_factorszkoła_średnia>=0.5 559 93 nie (0.833631485 0.166368515) *
## 405) poziom_edukacji_factorszkoła_średnia< 0.5 376 106 nie (0.718085106 0.281914894)
## 810) zawodspecjaliści< 0.5 353 93 nie (0.736543909 0.263456091) *
## 811) zawodspecjaliści>=0.5 23 10 tak (0.434782609 0.565217391) *
## 203) strata_kapitalowa>=1794 27 10 tak (0.370370370 0.629629630) *
## 51) wiek>=35.5 2548 997 nie (0.608712716 0.391287284)
## 102) strata_kapitalowa< 1794 2435 911 nie (0.625872690 0.374127310)
## 204) godziny_razy_edukacja< 376.5 1117 345 nie (0.691136974 0.308863026)
## 408) zawodkadra_kierownicza< 0.5 1017 297 nie (0.707964602 0.292035398)
## 816) wiek< 45.5 415 100 nie (0.759036145 0.240963855) *
## 817) wiek>=45.5 602 197 nie (0.672757475 0.327242525)
## 1634) zawodspecjaliści< 0.5 579 183 nie (0.683937824 0.316062176) *
## 1635) zawodspecjaliści>=0.5 23 9 tak (0.391304348 0.608695652) *
## 409) zawodkadra_kierownicza>=0.5 100 48 nie (0.520000000 0.480000000)
## 818) typ_zatrudnieniasamozatrudniony_bez_osobowości_prawnej>=0.5 19 3 nie (0.842105263 0.157894737) *
## 819) typ_zatrudnieniasamozatrudniony_bez_osobowości_prawnej< 0.5 81 36 tak (0.444444444 0.555555556)
## 1638) wiek< 59.5 65 32 nie (0.507692308 0.492307692) *
## 1639) wiek>=59.5 16 3 tak (0.187500000 0.812500000) *
## 205) godziny_razy_edukacja>=376.5 1318 566 nie (0.570561457 0.429438543) *
## 103) strata_kapitalowa>=1794 113 27 tak (0.238938053 0.761061947)
## 206) strata_kapitalowa>=1989.5 34 10 nie (0.705882353 0.294117647) *
## 207) strata_kapitalowa< 1989.5 79 3 tak (0.037974684 0.962025316) *
## 13) zysk_kapitalowy>=5095.5 236 5 tak (0.021186441 0.978813559) *
## 7) godziny_razy_edukacja>=488 3594 1194 tak (0.332220367 0.667779633)
## 14) zysk_kapitalowy< 5095.5 3115 1193 tak (0.382985554 0.617014446)
## 28) kapital_netto>=-1846 2817 1177 tak (0.417820376 0.582179624)
## 56) poziom_edukacji_factorszkoła_średnia>=0.5 357 127 nie (0.644257703 0.355742297)
## 112) wiek< 31.5 66 11 nie (0.833333333 0.166666667) *
## 113) wiek>=31.5 291 116 nie (0.601374570 0.398625430)
## 226) zysk_kapitalowy>=3120 11 0 nie (1.000000000 0.000000000) *
## 227) zysk_kapitalowy< 3120 280 116 nie (0.585714286 0.414285714)
## 454) zawodrolnictwo_i_rybołówstwo>=0.5 40 9 nie (0.775000000 0.225000000) *
## 455) zawodrolnictwo_i_rybołówstwo< 0.5 240 107 nie (0.554166667 0.445833333)
## 910) zawodsprzedaż< 0.5 190 78 nie (0.589473684 0.410526316) *
## 911) zawodsprzedaż>=0.5 50 21 tak (0.420000000 0.580000000)
## 1822) wiek< 45.5 24 11 nie (0.541666667 0.458333333) *
## 1823) wiek>=45.5 26 8 tak (0.307692308 0.692307692) *
## 57) poziom_edukacji_factorszkoła_średnia< 0.5 2460 947 tak (0.384959350 0.615040650)
## 114) wiek< 33.5 507 229 nie (0.548323471 0.451676529)
## 228) wiek< 29.5 209 74 nie (0.645933014 0.354066986) *
## 229) wiek>=29.5 298 143 tak (0.479865772 0.520134228)
## 458) godziny_razy_edukacja< 862.5 278 137 nie (0.507194245 0.492805755)
## 916) zawodkadra_kierownicza< 0.5 218 94 nie (0.568807339 0.431192661)
## 1832) godziny_pracy_tygodniowo>=45.5 113 38 nie (0.663716814 0.336283186) *
## 1833) godziny_pracy_tygodniowo< 45.5 105 49 tak (0.466666667 0.533333333) *
## 917) zawodkadra_kierownicza>=0.5 60 17 tak (0.283333333 0.716666667) *
## 459) godziny_razy_edukacja>=862.5 20 2 tak (0.100000000 0.900000000) *
## 115) wiek>=33.5 1953 669 tak (0.342549923 0.657450077)
## 230) typ_zatrudnieniasamozatrudniony_bez_osobowości_prawnej>=0.5 263 117 nie (0.555133080 0.444866920)
## 460) poziom_edukacji_factorszkoła_profesjonalna< 0.5 228 88 nie (0.614035088 0.385964912)
## 920) zawodrolnictwo_i_rybołówstwo>=0.5 45 7 nie (0.844444444 0.155555556) *
## 921) zawodrolnictwo_i_rybołówstwo< 0.5 183 81 nie (0.557377049 0.442622951)
## 1842) godziny_pracy_tygodniowo>=49 126 48 nie (0.619047619 0.380952381) *
## 1843) godziny_pracy_tygodniowo< 49 57 24 tak (0.421052632 0.578947368) *
## 461) poziom_edukacji_factorszkoła_profesjonalna>=0.5 35 6 tak (0.171428571 0.828571429) *
## 231) typ_zatrudnieniasamozatrudniony_bez_osobowości_prawnej< 0.5 1690 523 tak (0.309467456 0.690532544) *
## 29) kapital_netto< -1846 298 16 tak (0.053691275 0.946308725)
## 58) strata_kapitalowa>=1989.5 51 15 tak (0.294117647 0.705882353)
## 116) strata_kapitalowa< 2384.5 20 5 nie (0.750000000 0.250000000) *
## 117) strata_kapitalowa>=2384.5 31 0 tak (0.000000000 1.000000000) *
## 59) strata_kapitalowa< 1989.5 247 1 tak (0.004048583 0.995951417) *
## 15) zysk_kapitalowy>=5095.5 479 1 tak (0.002087683 0.997912317) *
rpart.plot(
tree.salary.tuned$finalModel,
type = 2,
extra = 104,
fallen.leaves = TRUE,
nn = FALSE,
box.palette = "Blues",
shadow.col = 0,
cex = 0.8
)
pred.tree.tuned.train.prob <- predict(
tree.salary.tuned,
newdata = salary.train.caret,
type = "prob"
)[, "tak"]
pred.tree.tuned.test.prob <- predict(
tree.salary.tuned,
newdata = salary.test.caret,
type = "prob"
)[, "tak"]
metryki.tree.tuned.train <- klasyfikacja.metryki(
predicted_probabilities = pred.tree.tuned.train.prob,
real = salary.train.caret$klasa_dochodu_01,
cutoff = 0.5,
level_positive = "tak",
level_negative = "nie"
)
metryki.tree.tuned.test <- klasyfikacja.metryki(
predicted_probabilities = pred.tree.tuned.test.prob,
real = salary.test.caret$klasa_dochodu_01,
cutoff = 0.5,
level_positive = "tak",
level_negative = "nie"
)
rbind(
train = metryki.tree.tuned.train,
test = metryki.tree.tuned.test
)
## Accuracy Sensitivity Specificity BalancedAccuracy PPV NPV F1
## train 0.8616 0.6056 0.9465 0.7760 0.7894 0.8786 0.6854
## test 0.8457 0.5799 0.9338 0.7569 0.7437 0.8703 0.6517
## Kappa
## train 0.5987
## test 0.5545
Dla dostrojonego drzewa ponownie wyznaczono optymalny próg
klasyfikacji. Tym razem zastosowano funkcję cutpointr,
wykorzystując kryterium roc01.
cut.tree.tuned <- cutpointr(
data = data.frame(
pred = pred.tree.tuned.train.prob,
real = salary.train.caret$klasa_dochodu_01
),
x = pred,
class = real,
pos_class = "tak",
neg_class = "nie",
direction = ">=",
method = minimize_metric,
metric = roc01
)
cutoff.tree.tuned <- cut.tree.tuned$optimal_cutpoint
cutoff.tree.tuned
## [1] 0.2634561
Optymalny cutoff dla dostrojonego drzewa wyniósł około 0,26.
metryki.tree.tuned.train.opt <- klasyfikacja.metryki(
predicted_probabilities = pred.tree.tuned.train.prob,
real = salary.train.caret$klasa_dochodu_01,
cutoff = cutoff.tree.tuned,
level_positive = "tak",
level_negative = "nie"
)
metryki.tree.tuned.test.opt <- klasyfikacja.metryki(
predicted_probabilities = pred.tree.tuned.test.prob,
real = salary.test.caret$klasa_dochodu_01,
cutoff = cutoff.tree.tuned,
level_positive = "tak",
level_negative = "nie"
)
rbind(
train = metryki.tree.tuned.train.opt,
test = metryki.tree.tuned.test.opt
)
## Accuracy Sensitivity Specificity BalancedAccuracy PPV NPV F1
## train 0.8219 0.8364 0.8171 0.8267 0.6024 0.9378 0.7004
## test 0.8147 0.8099 0.8162 0.8131 0.5936 0.9284 0.6851
## Kappa
## train 0.5784
## test 0.5581
rbind(
train.0.5 = metryki.tree.tuned.train,
test.0.5 = metryki.tree.tuned.test,
train.opt = metryki.tree.tuned.train.opt,
test.opt = metryki.tree.tuned.test.opt
)
## Accuracy Sensitivity Specificity BalancedAccuracy PPV NPV
## train.0.5 0.8616 0.6056 0.9465 0.7760 0.7894 0.8786
## test.0.5 0.8457 0.5799 0.9338 0.7569 0.7437 0.8703
## train.opt 0.8219 0.8364 0.8171 0.8267 0.6024 0.9378
## test.opt 0.8147 0.8099 0.8162 0.8131 0.5936 0.9284
## F1 Kappa
## train.0.5 0.6854 0.5987
## test.0.5 0.6517 0.5545
## train.opt 0.7004 0.5784
## test.opt 0.6851 0.5581
Zastosowanie optymalnego progu odcięcia poprawiło wartość Balanced Accuracy w porównaniu z progiem 0.5. Jednak największa zmiana widoczna jest dla specifity.
porownanie.control.test <- rbind(
logit.0.5 = metryki.glm.test.standard,
logit.opt = metryki.glm.test,
tree.default.0.5 = metryki.tree.default.test,
tree.default.opt = metryki.tree.optimal.test,
tree.tuned.0.5 = metryki.tree.tuned.test,
tree.tuned.opt = metryki.tree.tuned.test.opt
)
porownanie.control.test
## Accuracy Sensitivity Specificity BalancedAccuracy PPV
## logit.0.5 0.8443 0.5992 0.9255 0.7624 0.7274
## logit.opt 0.7940 0.8686 0.7692 0.8189 0.5552
## tree.default.0.5 0.8397 0.5115 0.9485 0.7300 0.7670
## tree.default.opt 0.7216 0.8779 0.6698 0.7738 0.4684
## tree.tuned.0.5 0.8457 0.5799 0.9338 0.7569 0.7437
## tree.tuned.opt 0.8147 0.8099 0.8162 0.8131 0.5936
## NPV F1 Kappa
## logit.0.5 0.8744 0.6571 0.5576
## logit.opt 0.9464 0.6774 0.5366
## tree.default.0.5 0.8542 0.6137 0.5177
## tree.default.opt 0.9430 0.6108 0.4238
## tree.tuned.0.5 0.8703 0.6517 0.5545
## tree.tuned.opt 0.9284 0.6851 0.5581
Porównanie modeli pokazuje, że domyślny próg 0.5 może tworzyć złudzenie wysokiej skuteczności. Taki model osiąga wysoką ogólną trafność, ale głównie dlatego, że dobrze rozpoznaje klasę większościową. Jednocześnie może ignorować znaczną część osób zarabiających powyżej 50 tys. USD, czyli klasę mniejszościową.
Przejście na próg optymalny, mieszczący się w przybliżeniu w przedziale 0,18–0,26, pozwala znacząco zwiększyć czułość modelu. Oznacza to lepszą wykrywalność osób z wyższym dochodem, przy akceptowalnym spadku specyficzności. W efekcie rośnie również Balanced Accuracy.
Strojenie drzewa poprzez dobór parametru cp sprawiło, że
model stał się stabilniejszy. Po tuningu czułość i specyficzność
wyrównały się na poziomie około 80%, co czyni model
bardziej rzetelnym.
Ostatecznie zawsze powinniśmy wybrac model w zależności od decyzji biznesowej czy najbardziej zależy nam na wykrywaniu grupy pozytynwej czy negatywnej a może ogólnej Accuracy.
W kolejnym kroku porównano modele za pomocą krzywych ROC oraz indeksu Giniego. Dla każdego modelu wyznaczono przewidywane prawdopodobieństwa na zbiorze testowym, a następnie obliczono odpowiadające im krzywe ROC.
pred.test.logit <- predict(
glm.full.salary,
newdata = salary.test,
type = "response"
)
pred.test.tree.default <- predict(
tree.salary.default,
newdata = salary.test.ML,
type = "prob"
)[, "tak"]
pred.test.tree.tuned <- predict(
tree.salary.tuned,
newdata = salary.test.caret,
type = "prob"
)[, "tak"]
ROC.test.logit <- pROC::roc(
as.numeric(salary.test$klasa_dochodu_01) - 1,
pred.test.logit
)
ROC.test.tree.default <- pROC::roc(
salary.test.ML$klasa_dochodu_01,
pred.test.tree.default
)
ROC.test.tree.tuned <- pROC::roc(
salary.test.caret$klasa_dochodu_01,
pred.test.tree.tuned
)
list(
logit = ROC.test.logit,
tree.default = ROC.test.tree.default,
tree.tuned = ROC.test.tree.tuned
) %>%
ggroc(alpha = 0.6, linewidth = 1) +
geom_segment(
aes(x = 1, xend = 0, y = 0, yend = 1),
color = "grey",
linetype = "dashed"
) +
labs(
title = "ROC – zbiór testowy",
subtitle = paste0(
"Gini TEST: ",
"logit = ", round(100 * (2 * pROC::auc(ROC.test.logit) - 1), 1), "%, ",
"default = ", round(100 * (2 * pROC::auc(ROC.test.tree.default) - 1), 1), "%, ",
"tuned = ", round(100 * (2 * pROC::auc(ROC.test.tree.tuned) - 1), 1), "%"
)
) +
theme_bw() +
coord_fixed()
Na podstawie wartości indeksu Giniego można zauważyć, że model logitowy osiągnął bardzo wysoką zdolność rozróżniania klas, równą około 80,1%. Drzewo domyślne uzyskało wynik około 67,6%, natomiast drzewo dostrojone około 77,8%. Oznacza to, że w tym porównaniu logit charakteryzuje się najlepszą zdolnością separacji klas na zbiorze testowym.
Ostatnim krokiem w tej części analizy było sprawdzenie, czy zrównoważenie klas w zbiorze treningowym poprawi jakość drzewa klasyfikacyjnego. W tym celu zastosowano technikę downsamplingu, która polega na zmniejszeniu liczebności klasy większościowej tak, aby odpowiadała liczebności klasy mniejszościowej.
set.seed(123)
salary.train.down <- downSample(
x = salary.train.caret %>% dplyr::select(-klasa_dochodu_01),
y = salary.train.caret$klasa_dochodu_01,
yname = "klasa_dochodu_01"
)
table(salary.train.down$klasa_dochodu_01)
##
## nie tak
## 5256 5256
prop.table(table(salary.train.down$klasa_dochodu_01))
##
## nie tak
## 0.5 0.5
Po zastosowaniu downsamplingu klasy w zbiorze treningowym są zbalansowane. Następnie na tak przygotowanych danych ponownie oszacowano drzewo decyzyjne.
tree.downsampled <- train(
klasa_dochodu_01 ~ .,
data = salary.train.down,
method = "rpart",
trControl = ctrl,
tuneGrid = grid.cp,
metric = "ROC",
control = rpart.control(
minsplit = 30,
minbucket = 10,
maxdepth = 10
)
)
rpart.plot(
tree.downsampled$finalModel,
type = 2,
extra = 104,
fallen.leaves = TRUE,
nn = FALSE,
box.palette = "Blues",
shadow.col = 0,
cex = 0.8
)
pred.tree.down.train.prob <- predict(
tree.downsampled,
newdata = salary.train.down,
type = "prob"
)[, "tak"]
pred.tree.down.test.prob <- predict(
tree.downsampled,
newdata = salary.test.caret,
type = "prob"
)[, "tak"]
metryki.tree.down.train <- klasyfikacja.metryki(
predicted_probabilities = pred.tree.down.train.prob,
real = salary.train.down$klasa_dochodu_01,
cutoff = 0.5,
level_positive = "tak",
level_negative = "nie"
)
metryki.tree.down.test <- klasyfikacja.metryki(
predicted_probabilities = pred.tree.down.test.prob,
real = salary.test.caret$klasa_dochodu_01,
cutoff = 0.5,
level_positive = "tak",
level_negative = "nie"
)
rbind(
train = metryki.tree.down.train,
test = metryki.tree.down.test
)
## Accuracy Sensitivity Specificity BalancedAccuracy PPV NPV F1
## train 0.8394 0.8824 0.7964 0.8394 0.8125 0.8714 0.8460
## test 0.7928 0.8535 0.7727 0.8131 0.5544 0.9409 0.6721
## Kappa
## train 0.6788
## test 0.5304
Zastosowanie techniki downsamplingu poprawiło czułość modelu, czyli jego zdolność do wykrywania klasy pozytywnej. Jednocześnie spowodowało spadek zdolności generalizacyjnych, co objawia się rozbieżnością wyników między zbiorem treningowym a testowym. Można więc mówić o lekkim przeuczeniu modelu.
Dlatego też następnym krokiem analizy było skupienie się na Baggingu oraz Random Forescie.
W kolejnej części analizy wykorzystano modele zespołowe oparte na drzewach decyzyjnych: bagging oraz Random Forest. Pojedyncze drzewo decyzyjne jest łatwe do interpretacji, ale może być niestabilne — niewielka zmiana danych może prowadzić do powstania zupełnie innej struktury drzewa.
Bagging, czyli Bootstrap Aggregating, polega na losowaniu wielu prób bootstrapowych ze zbioru treningowego, budowaniu osobnego drzewa na każdej z nich, a następnie łączeniu predykcji wielu modeli. W przypadku klasyfikacji końcowa predykcja może być wyznaczana przez głosowanie większościowe albo przez uśrednienie przewidywanych prawdopodobieństw.
Random Forest jest rozwinięciem baggingu. Również wykorzystuje wiele drzew budowanych na próbach bootstrapowych, ale dodatkowo przy każdym podziale drzewa bierze pod uwagę tylko losowy podzbiór zmiennych. Dzięki temu drzewa są bardziej zróżnicowane, co zwykle poprawia stabilność i jakość predykcji.
Kluczowa różnica między tymi metodami polega więc na tym, że w baggingu każde drzewo przy każdym podziale może korzystać ze wszystkich zmiennych, natomiast w Random Forest dostępny jest tylko losowy podzbiór predyktorów.
W pierwszym kroku zbudowano model baggingowy. W tym celu w funkcji
randomForest() ustawiono parametr mtry równy
liczbie wszystkich zmiennych objaśniających. Oznacza to, że każde drzewo
przy każdym podziale mogło korzystać ze wszystkich dostępnych
predyktorów.
p.salary <- ncol(salary.train.caret) - 1
p.salary
## [1] 17
# Trenowanie modelu zostało wyhasztagowane, aby raport nie budował modelu od nowa.
# Model został wcześniej zapisany do pliku RDS.
# set.seed(123)
# salary.bag <- randomForest(
# klasa_dochodu_01 ~ .,
# data = salary.train.caret,
# mtry = p.salary,
# importance = TRUE,
# ntree = 500
# )
#salary.bag %>% saveRDS("trenowane_modele/salary.bagging.rds")
salary.bag <- readRDS("trenowane_modele/salary.bagging.rds")
Następnie przeanalizowano błąd OOB, czyli Out-of-Bag. Jest to wewnętrzna miara jakości modelu, liczona na obserwacjach, które nie zostały wykorzystane do budowy danego drzewa.
print(salary.bag)
##
## Call:
## randomForest(formula = klasa_dochodu_01 ~ ., data = salary.train.caret, mtry = p.salary, importance = TRUE, ntree = 500)
## Type of random forest: classification
## Number of trees: 500
## No. of variables tried at each split: 17
##
## OOB estimate of error rate: 16.26%
## Confusion matrix:
## nie tak class.error
## nie 14339 1519 0.09578762
## tak 1915 3341 0.36434551
plot(salary.bag)
Wykres przedstawia przebieg błędu OOB w funkcji liczby drzew.
Widoczna jest szybka stabilizacja błędu już po osiągnięciu około 100
drzew, co potwierdza zasadność przyjęcia parametru
ntree = 500. Jednocześnie wyraźna różnica między błędem dla
klas wskazuje, że model nadal ma większą trudność w klasyfikowaniu klasy
tak, czyli osób zarabiających powyżej 50 tys. USD.
W kolejnym kroku wyznaczono przewidywane prawdopodobieństwa dla
zbioru treningowego i testowego. Ponieważ model został zbudowany przy
użyciu funkcji randomForest(), do uzyskania
prawdopodobieństw zastosowano argument type = "prob".
pred.salary.bag.train.prob <- predict(
salary.bag,
newdata = salary.train.caret,
type = "prob"
)[, "tak"]
pred.salary.bag.test.prob <- predict(
salary.bag,
newdata = salary.test.caret,
type = "prob"
)[, "tak"]
cutoff.bag <- prop.table(table(salary.train.caret$klasa_dochodu_01))["tak"]
cutoff.bag
## tak
## 0.2489344
Jako próg odcięcia przyjęto udział klasy tak w zbiorze
treningowym. Jest to rozwiązanie dopasowane do niezbalansowanych danych,
ponieważ próg ten znajduje się bliżej rzeczywistego udziału klasy
pozytywnej niż standardowe 0.5.
metryki.salary.bag.train <- klasyfikacja.metryki(
predicted_probabilities = pred.salary.bag.train.prob,
real = salary.train.caret$klasa_dochodu_01,
cutoff = cutoff.bag,
level_positive = "tak",
level_negative = "nie"
)
metryki.salary.bag.test <- klasyfikacja.metryki(
predicted_probabilities = pred.salary.bag.test.prob,
real = salary.test.caret$klasa_dochodu_01,
cutoff = cutoff.bag,
level_positive = "tak",
level_negative = "nie"
)
rbind(
train = metryki.salary.bag.train,
test = metryki.salary.bag.test
)
## Accuracy Sensitivity Specificity BalancedAccuracy PPV NPV F1
## train 0.9517 0.9812 0.9420 0.9616 0.8486 0.9934 0.9101
## test 0.8015 0.8215 0.7949 0.8082 0.5703 0.9307 0.6732
## Kappa
## train 0.8773
## test 0.5373
Wyniki wskazują na widoczne przeuczenie modelu baggingowego. Model bardzo dobrze przewiduje wyniki dla danych treningowych, ale osiąga wyraźnie słabsze rezultaty na zbiorze testowym. Oznacza to, że konieczne jest zastosowanie bardziej kontrolowanego modelu, dlatego w kolejnym kroku przeanalizowano Random Forest.
Random Forest jest rozwinięciem baggingu. W klasyfikacji domyślna
wartość parametru mtry jest zwykle zbliżona do pierwiastka
z liczby zmiennych objaśniających. Dzięki temu każde drzewo korzysta z
innego losowego podzbioru zmiennych, co zwiększa różnorodność drzew i
może ograniczać przeuczenie.
# Trenowanie modelu zostało wyhasztagowane, aby raport nie budował modelu od nowa.
# Model został wcześniej zapisany do pliku RDS.
# set.seed(123)
# salary.rf <- randomForest(
# klasa_dochodu_01 ~ .,
# data = salary.train.caret,
# importance = TRUE
# )
# saveRDS(
# salary.rf,
# "trenowane_modele/salary.rf.rds"
# )
salary.rf <- readRDS("trenowane_modele/salary.rf.rds")
salary.rf
##
## Call:
## randomForest(formula = klasa_dochodu_01 ~ ., data = salary.train.caret, importance = TRUE)
## Type of random forest: classification
## Number of trees: 500
## No. of variables tried at each split: 4
##
## OOB estimate of error rate: 14.12%
## Confusion matrix:
## nie tak class.error
## nie 14694 1164 0.07340144
## tak 1817 3439 0.34570015
plot(salary.rf)
salary.rf$err.rate[10, ]
## OOB nie tak
## 0.16334147 0.09610853 0.36634615
salary.rf$err.rate[50, ]
## OOB nie tak
## 0.14630103 0.08084248 0.34379756
salary.rf$err.rate[100, ]
## OOB nie tak
## 0.14298570 0.07554547 0.34646119
plot(salary.rf)
abline(h = salary.rf$err.rate[50, "OOB"], col = "red")
Na wykresie błędu OOB widać, że po przekroczeniu około 50 drzew dalsze zwiększanie liczby drzew daje już niewielką poprawę. Oznacza to, że model dość szybko stabilizuje swoje działanie.
pred.salary.rf.train.prob <- predict(
salary.rf,
newdata = salary.train.caret,
type = "prob"
)[, "tak"]
pred.salary.rf.test.prob <- predict(
salary.rf,
newdata = salary.test.caret,
type = "prob"
)[, "tak"]
metryki.salary.rf.train <- klasyfikacja.metryki(
predicted_probabilities = pred.salary.rf.train.prob,
real = salary.train.caret$klasa_dochodu_01,
cutoff = cutoff.bag,
level_positive = "tak",
level_negative = "nie"
)
metryki.salary.rf.test <- klasyfikacja.metryki(
predicted_probabilities = pred.salary.rf.test.prob,
real = salary.test.caret$klasa_dochodu_01,
cutoff = cutoff.bag,
level_positive = "tak",
level_negative = "nie"
)
rbind(
train = metryki.salary.rf.train,
test = metryki.salary.rf.test
)
## Accuracy Sensitivity Specificity BalancedAccuracy PPV NPV F1
## train 0.9299 0.9553 0.9215 0.9384 0.8013 0.9842 0.8716
## test 0.8196 0.8139 0.8215 0.8177 0.6018 0.9302 0.6920
## Kappa
## train 0.8239
## test 0.5685
Wyniki wskazują, że w modelu Random Forest bez strojenia nadal widoczne jest przeuczenie. Accuracy wynosi około 0,93 dla zbioru treningowego oraz około 0,82 dla zbioru testowego. Podobną różnicę widać dla czułości, która wynosi około 0,95 na zbiorze treningowym oraz około 0,81 na zbiorze testowym. Z tego powodu w kolejnym kroku zastosowano strojenie hiperparametrów.
W przypadku klasyfikacji jako główną metrykę strojenia wykorzystano ROC. Jest to szczególnie ważne przy niezbalansowanych danych, ponieważ sama Accuracy może dawać mylący obraz jakości modelu.
W procesie strojenia sprawdzano różne wartości parametrów
mtry oraz min.node.size. Parametr
mtry określa liczbę zmiennych losowanych przy każdym
podziale drzewa, natomiast min.node.size określa minimalną
liczebność węzła końcowego.
ctrl.salary.rf <- trainControl(
method = "cv",
number = 5,
classProbs = TRUE,
summaryFunction = twoClassSummary,
savePredictions = "final"
)
grid.salary.rf <- expand.grid(
mtry = seq(from = 2, to = 10, by = 2),
min.node.size = c(5, 10, 25, 50),
splitrule = "gini"
)
grid.salary.rf
## mtry min.node.size splitrule
## 1 2 5 gini
## 2 4 5 gini
## 3 6 5 gini
## 4 8 5 gini
## 5 10 5 gini
## 6 2 10 gini
## 7 4 10 gini
## 8 6 10 gini
## 9 8 10 gini
## 10 10 10 gini
## 11 2 25 gini
## 12 4 25 gini
## 13 6 25 gini
## 14 8 25 gini
## 15 10 25 gini
## 16 2 50 gini
## 17 4 50 gini
## 18 6 50 gini
## 19 8 50 gini
## 20 10 50 gini
# Trenowanie modelu zostało wyhasztagowane, aby raport nie budował modelu od nowa.
# Model został wcześniej zapisany do pliku RDS.
# set.seed(123)
# salary.rf.tuned <- train(
# klasa_dochodu_01 ~ .,
# data = salary.train.caret,
# method = "ranger",
# trControl = ctrl.salary.rf,
# tuneGrid = grid.salary.rf,
# metric = "ROC",
# num.trees = 50
# )
# saveRDS(
# salary.rf.tuned,
# "trenowane_modele/salary.rf.tuned.rds"
# )
salary.rf.tuned <- readRDS("trenowane_modele/salary.rf.tuned.rds")
salary.rf.tuned
## Random Forest
##
## 21114 samples
## 17 predictor
## 2 classes: 'nie', 'tak'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 16891, 16892, 16891, 16891, 16891
## Resampling results across tuning parameters:
##
## mtry min.node.size ROC Sens Spec
## 2 5 0.9079080 0.9872620 0.3544547
## 2 10 0.9074837 0.9883341 0.3464620
## 2 25 0.9076500 0.9871358 0.3529315
## 2 50 0.9074293 0.9877664 0.3485534
## 4 5 0.9140599 0.9610922 0.5340574
## 4 10 0.9140022 0.9610921 0.5340585
## 4 25 0.9136749 0.9609659 0.5336768
## 4 50 0.9136811 0.9610290 0.5321535
## 6 5 0.9157692 0.9510658 0.5810506
## 6 10 0.9160098 0.9511919 0.5774345
## 6 25 0.9155840 0.9511919 0.5766737
## 6 50 0.9151235 0.9516333 0.5745826
## 8 5 0.9160909 0.9463996 0.6031182
## 8 10 0.9161528 0.9464624 0.6027378
## 8 25 0.9158625 0.9474715 0.5983628
## 8 50 0.9156369 0.9480391 0.5939858
## 10 5 0.9157612 0.9429312 0.6139632
## 10 10 0.9154328 0.9432464 0.6114901
## 10 25 0.9159407 0.9453275 0.6073041
## 10 50 0.9155349 0.9458319 0.6025478
##
## Tuning parameter 'splitrule' was held constant at a value of gini
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were mtry = 8, splitrule = gini
## and min.node.size = 10.
plot(salary.rf.tuned)
salary.rf.tuned$bestTune
## mtry splitrule min.node.size
## 14 8 gini 10
W wyniku strojenia jako najlepszy wybrano model o parametrach:
mtry = 8, splitrule = gini oraz
min.node.size = 10. Do wyboru modelu wykorzystano
największą wartość ROC.
pred.salary.rf.tuned.train.prob <- predict(
salary.rf.tuned,
newdata = salary.train.caret,
type = "prob"
)[, "tak"]
pred.salary.rf.tuned.test.prob <- predict(
salary.rf.tuned,
newdata = salary.test.caret,
type = "prob"
)[, "tak"]
metryki.salary.rf.tuned.train <- klasyfikacja.metryki(
predicted_probabilities = pred.salary.rf.tuned.train.prob,
real = salary.train.caret$klasa_dochodu_01,
cutoff = cutoff.bag,
level_positive = "tak",
level_negative = "nie"
)
metryki.salary.rf.tuned.test <- klasyfikacja.metryki(
predicted_probabilities = pred.salary.rf.tuned.test.prob,
real = salary.test.caret$klasa_dochodu_01,
cutoff = cutoff.bag,
level_positive = "tak",
level_negative = "nie"
)
rbind(
train = metryki.salary.rf.tuned.train,
test = metryki.salary.rf.tuned.test
)
## Accuracy Sensitivity Specificity BalancedAccuracy PPV NPV F1
## train 0.8500 0.9488 0.8173 0.8830 0.6325 0.9797 0.7590
## test 0.8112 0.8752 0.7900 0.8326 0.5800 0.9503 0.6977
## Kappa
## train 0.6563
## test 0.5685
Dostrojony Random Forest osiąga większą czułość, wysoką wartość Balanced Accuracy oraz nie wykazuje tak wyraźnego przeuczenia jak wcześniejsze modele zespołowe. Oznacza to, że odpowiednie dobranie hiperparametrów znacząco poprawia stabilność modelu.
ROC.bagging.train <- pROC::roc(
as.numeric(salary.train.caret$klasa_dochodu_01 == "tak"),
pred.salary.bag.train.prob
)
ROC.bagging.test <- pROC::roc(
as.numeric(salary.test.caret$klasa_dochodu_01 == "tak"),
pred.salary.bag.test.prob
)
ROC.rf.train <- pROC::roc(
as.numeric(salary.train.caret$klasa_dochodu_01 == "tak"),
pred.salary.rf.train.prob
)
ROC.rf.test <- pROC::roc(
as.numeric(salary.test.caret$klasa_dochodu_01 == "tak"),
pred.salary.rf.test.prob
)
ROC.rf.tuned.train <- pROC::roc(
as.numeric(salary.train.caret$klasa_dochodu_01 == "tak"),
pred.salary.rf.tuned.train.prob
)
ROC.rf.tuned.test <- pROC::roc(
as.numeric(salary.test.caret$klasa_dochodu_01 == "tak"),
pred.salary.rf.tuned.test.prob
)
list(
bagging.train = ROC.bagging.train,
bagging.test = ROC.bagging.test,
rf.train = ROC.rf.train,
rf.test = ROC.rf.test,
rf.tuned.train = ROC.rf.tuned.train,
rf.tuned.test = ROC.rf.tuned.test
) %>%
ggroc(alpha = 0.5, linewidth = 1) +
geom_segment(
aes(x = 1, xend = 0, y = 0, yend = 1),
color = "grey",
linetype = "dashed"
) +
labs(
title = paste0(
"Gini TEST: ",
"bagging = ", round(100 * (2 * pROC::auc(ROC.bagging.test) - 1), 1), "%, ",
"RF = ", round(100 * (2 * pROC::auc(ROC.rf.test) - 1), 1), "%, ",
"RF tuned = ", round(100 * (2 * pROC::auc(ROC.rf.tuned.test) - 1), 1), "%"
),
subtitle = paste0(
"Gini TRAIN: ",
"bagging = ", round(100 * (2 * pROC::auc(ROC.bagging.train) - 1), 1), "%, ",
"RF = ", round(100 * (2 * pROC::auc(ROC.rf.train) - 1), 1), "%, ",
"RF tuned = ", round(100 * (2 * pROC::auc(ROC.rf.tuned.train) - 1), 1), "%"
)
) +
theme_bw() +
coord_fixed() +
scale_color_brewer(palette = "Paired")
Na wykresie dobrze widoczne jest przeuczenie modelu baggingowego. Dla danych treningowych bagging osiąga indeks Giniego na poziomie około 99%, natomiast dla danych testowych tylko około 79. Najlepszy wynik dla danych testowych uzyskuje dostrojony Random Forest, dla którego indeks Giniego wynosi około 83%.
Na końcu porównano najważniejsze metryki dla modeli oszacowanych do tej pory. W przypadku danych niezbalansowanych progi odcięcia znajdują się zwykle w okolicach udziału klasy pozytywnej, czyli około 25%. Dla modeli trenowanych na danych zbalansowanych naturalnym progiem jest 0.5.
porownanie.salary.train <- rbind(
logit = metryki.glm.train,
decision.tree = metryki.tree.default.train,
decision.tree.downsample = metryki.tree.down.train,
bagging = metryki.salary.bag.train,
rf = metryki.salary.rf.train,
rf.tuned = metryki.salary.rf.tuned.train
)
porownanie.salary.test <- rbind(
logit = metryki.glm.test,
decision.tree = metryki.tree.default.test,
decision.tree.downsample = metryki.tree.down.test,
bagging = metryki.salary.bag.test,
rf = metryki.salary.rf.test,
rf.tuned = metryki.salary.rf.tuned.test
)
porownanie.salary.train
## Accuracy Sensitivity Specificity BalancedAccuracy
## logit 0.7936 0.8615 0.7711 0.8163
## decision.tree 0.8416 0.5099 0.9516 0.7307
## decision.tree.downsample 0.8394 0.8824 0.7964 0.8394
## bagging 0.9517 0.9812 0.9420 0.9616
## rf 0.9299 0.9553 0.9215 0.9384
## rf.tuned 0.8500 0.9488 0.8173 0.8830
## PPV NPV F1 Kappa
## logit 0.5550 0.9438 0.6751 0.5340
## decision.tree 0.7773 0.8542 0.6158 0.5214
## decision.tree.downsample 0.8125 0.8714 0.8460 0.6788
## bagging 0.8486 0.9934 0.9101 0.8773
## rf 0.8013 0.9842 0.8716 0.8239
## rf.tuned 0.6325 0.9797 0.7590 0.6563
porownanie.salary.test
## Accuracy Sensitivity Specificity BalancedAccuracy
## logit 0.7940 0.8686 0.7692 0.8189
## decision.tree 0.8397 0.5115 0.9485 0.7300
## decision.tree.downsample 0.7928 0.8535 0.7727 0.8131
## bagging 0.8015 0.8215 0.7949 0.8082
## rf 0.8196 0.8139 0.8215 0.8177
## rf.tuned 0.8112 0.8752 0.7900 0.8326
## PPV NPV F1 Kappa
## logit 0.5552 0.9464 0.6774 0.5366
## decision.tree 0.7670 0.8542 0.6137 0.5177
## decision.tree.downsample 0.5544 0.9409 0.6721 0.5304
## bagging 0.5703 0.9307 0.6732 0.5373
## rf 0.6018 0.9302 0.6920 0.5685
## rf.tuned 0.5800 0.9503 0.6977 0.5685
Przeprowadzona analiza porównawcza wskazuje, że wybór optymalnego modelu predykcyjnego dla dochodów w zbiorze Adult zależy od priorytetów analizy. Model logitowy oferuje wysoką stabilność i dobrą zdolność generalizacji, podczas gdy dostrojony Random Forest wykazuje szczególnie dobrą czułość w identyfikacji klasy mniejszościowej. Mimo zaobserwowanego przeuczenia w bazowych wariantach modeli zespołowych, takich jak bagging, proces strojenia hiperparametrów pozwolił uzyskać wyniki o wysokiej wartości predykcyjnej przy zachowaniu akceptowalnego poziomu dopasowania.
Warto podkreślić, iż dla baggingu metryki na zbiorze treningowym są bardzo wysokie: Accuracy wynosi 0.9517, a Balanced Accuracy 0.9616. Na zbiorze testowym wartości te spadają odpowiednio do 0.8015 oraz 0.8082. Oznacza to, że model bardzo dobrze dopasował się do danych treningowych, ale jego zdolność generalizacji na nowych danych jest wyraźnie słabsza.
Podobne, choć mniej silne zjawisko widać dla domyślnego Random Forest. Model osiąga Accuracy 0.9299 i Balanced Accuracy 0.9384 na zbiorze treningowym, natomiast na zbiorze testowym odpowiednio 0.8196 oraz 0.8177. Różnica ta sugeruje, że model również częściowo przeuczył się na danych treningowych.
Najlepszy kompromis między skutecznością a stabilnością daje dostrojony Random Forest. Dla tego modelu Balanced Accuracy wynosi 0.8830 na zbiorze treningowym oraz 0.8326 na zbiorze testowym. Różnica między wynikami jest zdecydowanie mniejsza niż w przypadku baggingu i domyślnego Random Forest, co oznacza lepszą zdolność generalizacji. Jednocześnie model ten osiąga najwyższą czułość na zbiorze testowym, równą 0.8770, co oznacza bardzo dobrą zdolność wykrywania osób zarabiających powyżej 50 tys. USD.
Na końcu przeanalizowano ważność zmiennych w modelu Random Forest. W tym celu wykorzystano miarę Mean Decrease Gini, która pokazuje, jak bardzo dana zmienna przyczynia się do poprawy jakości podziałów w drzewach.
importance.salary.rf <- salary.rf$importance[, "MeanDecreaseGini"]
importance.salary.rf <- sort(importance.salary.rf, decreasing = TRUE)
importance.top <- head(importance.salary.rf, 20)
importance.df <- data.frame(
Zmienna = names(importance.top),
Waznosc = as.numeric(importance.top)
)
importance.df$Zmienna <- gsub("_", " ", importance.df$Zmienna)
ggplot(
importance.df,
aes(
x = reorder(Zmienna, Waznosc),
y = Waznosc,
fill = Waznosc
)
) +
geom_col(width = 0.68, show.legend = FALSE) +
coord_flip() +
scale_fill_gradient(
low = "#B7DDE8",
high = "#2F6690"
) +
labs(
title = "Ważność zmiennych w modelu Random Forest",
subtitle = "Najważniejsze predyktory dochodu",
x = NULL,
y = "Mean Decrease Gini"
) +
theme_minimal(base_size = 9) +
theme(
plot.title = element_text(
face = "bold",
size = 12,
hjust = 0
),
plot.subtitle = element_text(
size = 8.5,
color = "grey35",
margin = ggplot2::margin(b = 8)
),
axis.text.y = element_text(
size = 7.5,
color = "grey20"
),
axis.text.x = element_text(
size = 7.5,
color = "grey30"
),
axis.title.x = element_text(
size = 8.5,
margin = ggplot2::margin(t = 6)
),
panel.grid.major.y = element_blank(),
panel.grid.minor = element_blank(),
plot.margin = ggplot2::margin(6, 10, 6, 6)
)
Wykres ważności zmiennych pokazuje, które predyktory miały największe znaczenie w modelu Random Forest. Najwyższe wartości Mean Decrease Gini oznaczają, że dana zmienna była często wykorzystywana do tworzenia skutecznych podziałów w drzewach i silnie wspierała proces klasyfikacji. W tym przyapdku kluczowymi zmiennymi była relacja w rodzinie oraz wiek.
W kolejnym etapie analizy zastosowano metodę XGBoost, czyli jedną z najczęściej wykorzystywanych metod boostingu drzew decyzyjnych. Nazwa XGBoost oznacza Extreme Gradient Boosting.
W przeciwieństwie do Random Forest, który buduje wiele drzew równolegle i następnie uśrednia ich predykcje, XGBoost buduje drzewa sekwencyjnie. Oznacza to, że każde kolejne drzewo próbuje poprawić błędy popełnione przez wcześniejsze drzewa. Model stopniowo uczy się więc na swoich pomyłkach.
Intuicyjnie można to wyjaśnić następująco:
W przypadku naszego problemu klasyfikacyjnego XGBoost będzie
przewidywał prawdopodobieństwo przynależności obserwacji do klasy
tak, czyli do grupy osób zarabiających powyżej 50 tys. USD
rocznie.
W analizie zastosowano podejście podobne jak przy wcześniejszych modelach: najpierw zbudowano model startowy, a następnie przeprowadzono etapowe strojenie hiperparametrów.
Przed budową modelu warto krótko wyjaśnić najważniejsze parametry:
nrounds — liczba kolejnych drzew budowanych przez
model,max_depth — maksymalna głębokość pojedynczego
drzewa,eta — tempo uczenia; im mniejsza wartość, tym wolniej,
ale często stabilniej uczy się model,gamma — minimalna poprawa jakości wymagana do wykonania
kolejnego podziału w drzewie,subsample — część obserwacji losowana do budowy każdego
drzewa,colsample_bytree — część zmiennych losowana do budowy
każdego drzewa,min_child_weight — minimalna liczba obserwacji / waga
wymagana w liściu drzewa.Niższe wartości eta, ograniczenia w
subsample oraz colsample_bytree mogą
zmniejszać ryzyko przeuczenia modelu. Z kolei zbyt duża liczba drzew lub
zbyt głębokie drzewa mogą prowadzić do bardzo dobrego dopasowania na
zbiorze treningowym, ale słabszej generalizacji na zbiorze testowym.
Na początku zbudowano model startowy z jedną, bazową kombinacją parametrów. Model ten stanowi punkt odniesienia dla kolejnych etapów strojenia.
# Siatka parametrów dla modelu startowego XGBoost.
# Na tym etapie używamy jednej kombinacji parametrów,
# aby sprawdzić podstawową jakość modelu.
grid.salary.xgb.start <- expand.grid(
nrounds = 100, # liczba drzew
max_depth = 3, # maksymalna głębokość drzewa
eta = 0.05, # tempo uczenia
gamma = 0, # brak dodatkowej kary za podział
colsample_bytree = 0.8, # 80% zmiennych dostępnych dla drzewa
min_child_weight = 100, # minimalna waga w liściu
subsample = 0.8 # 80% obserwacji dla każdego drzewa
)
grid.salary.xgb.start
## nrounds max_depth eta gamma colsample_bytree min_child_weight subsample
## 1 100 3 0.05 0 0.8 100 0.8
# Trenowanie modelu zostało wyhasztagowane, aby raport nie budował modelu od nowa.
# Model został wcześniej zapisany do pliku RDS.
# Jeśli uruchamiasz projekt pierwszy raz i plik RDS jeszcze nie istnieje,
# trzeba tymczasowo odkomentować trening i saveRDS().
# set.seed(123)
# salary.xgb.start <- train(
# klasa_dochodu_01 ~ .,
# data = salary.train.caret,
# method = "xgbTree",
# trControl = ctrl,
# tuneGrid = grid.salary.xgb.start,
# metric = "ROC"
# )
#
# saveRDS(salary.xgb.start, "trenowane_modele/salary.xgb.start.rds")
# Wczytanie wcześniej zapisanego modelu:
salary.xgb.start <- readRDS("trenowane_modele/salary.xgb.start.rds")
salary.xgb.start
## eXtreme Gradient Boosting
##
## 21114 samples
## 17 predictor
## 2 classes: 'nie', 'tak'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 16891, 16892, 16891, 16891, 16891
## Resampling results:
##
## ROC Sens Spec
## 0.9057785 0.936247 0.5780075
##
## Tuning parameter 'nrounds' was held constant at a value of 100
## Tuning
## held constant at a value of 100
## Tuning parameter 'subsample' was held
## constant at a value of 0.8
Po wczytaniu modelu wyznaczono przewidywane prawdopodobieństwa dla
zbioru treningowego i testowego. Interesuje nas prawdopodobieństwo klasy
tak.
pred.salary.xgb.start.train.prob <- predict(
salary.xgb.start,
newdata = salary.train.caret,
type = "prob"
)[, "tak"]
pred.salary.xgb.start.test.prob <- predict(
salary.xgb.start,
newdata = salary.test.caret,
type = "prob"
)[, "tak"]
summary(pred.salary.xgb.start.train.prob)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 0.009907 0.044268 0.144113 0.250579 0.386196 0.976476
summary(pred.salary.xgb.start.test.prob)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 0.01003 0.04343 0.14305 0.24841 0.38494 0.97648
Do oceny modelu wykorzystano tę samą funkcję
klasyfikacja.metryki, która była stosowana wcześniej.
Dzięki temu wyniki XGBoost można bezpośrednio porównać z wynikami
logitu, drzew, baggingu oraz Random Forest.
Jako próg odcięcia zastosowano cutoff.bag, czyli próg
wyznaczony wcześniej dla modeli zespołowych. Pozwala to zachować
spójność porównania.
metryki.salary.xgb.start.train <- klasyfikacja.metryki(
predicted_probabilities = pred.salary.xgb.start.train.prob,
real = salary.train.caret$klasa_dochodu_01,
cutoff = cutoff.bag,
level_positive = "tak",
level_negative = "nie"
)
metryki.salary.xgb.start.test <- klasyfikacja.metryki(
predicted_probabilities = pred.salary.xgb.start.test.prob,
real = salary.test.caret$klasa_dochodu_01,
cutoff = cutoff.bag,
level_positive = "tak",
level_negative = "nie"
)
rbind(
train = metryki.salary.xgb.start.train,
test = metryki.salary.xgb.start.test
)
## Accuracy Sensitivity Specificity BalancedAccuracy PPV NPV F1
## train 0.7986 0.8727 0.7740 0.8234 0.5614 0.9483 0.6833
## test 0.7982 0.8601 0.7777 0.8189 0.5618 0.9438 0.6796
## Kappa
## train 0.5456
## test 0.5416
Na podstawie wyników modelu startowego można ocenić, jak XGBoost radzi sobie bez dokładniejszego strojenia hiperparametrów. Kolejnym krokiem będzie poprawa modelu przez stopniowe dobieranie parametrów.
max_depth i
min_child_weightW pierwszym etapie strojenia sprawdzono różne wartości
max_depth oraz min_child_weight.
Parametr max_depth kontroluje złożoność pojedynczych
drzew. Głębsze drzewa mogą uchwycić bardziej skomplikowane zależności,
ale mogą też zwiększać ryzyko przeuczenia.
Parametr min_child_weight ogranicza tworzenie zbyt
małych liści. Większa wartość tego parametru sprawia, że model staje się
bardziej ostrożny.
grid.salary.xgb.depth_child <- expand.grid(
nrounds = 100,
max_depth = c(2, 3, 4, 5),
eta = 0.05,
gamma = 0,
colsample_bytree = 0.8,
min_child_weight = c(50, 100, 200, 400),
subsample = 0.8
)
grid.salary.xgb.depth_child
## nrounds max_depth eta gamma colsample_bytree min_child_weight subsample
## 1 100 2 0.05 0 0.8 50 0.8
## 2 100 3 0.05 0 0.8 50 0.8
## 3 100 4 0.05 0 0.8 50 0.8
## 4 100 5 0.05 0 0.8 50 0.8
## 5 100 2 0.05 0 0.8 100 0.8
## 6 100 3 0.05 0 0.8 100 0.8
## 7 100 4 0.05 0 0.8 100 0.8
## 8 100 5 0.05 0 0.8 100 0.8
## 9 100 2 0.05 0 0.8 200 0.8
## 10 100 3 0.05 0 0.8 200 0.8
## 11 100 4 0.05 0 0.8 200 0.8
## 12 100 5 0.05 0 0.8 200 0.8
## 13 100 2 0.05 0 0.8 400 0.8
## 14 100 3 0.05 0 0.8 400 0.8
## 15 100 4 0.05 0 0.8 400 0.8
## 16 100 5 0.05 0 0.8 400 0.8
# Trenowanie modelu zostało wyhasztagowane.
# Jeśli plik "salary.xgb.1.rds" jeszcze nie istnieje,
# należy tymczasowo odkomentować poniższy fragment.
# set.seed(123)
# salary.xgb.1 <- train(
# klasa_dochodu_01 ~ .,
# data = salary.train.caret,
# method = "xgbTree",
# trControl = ctrl,
# tuneGrid = grid.salary.xgb.depth_child,
# metric = "ROC"
# )
#
# saveRDS(salary.xgb.1, "trenowane_modele/salary.xgb.1.rds")
salary.xgb.1 <- readRDS("trenowane_modele/salary.xgb.1.rds")
salary.xgb.1
## eXtreme Gradient Boosting
##
## 21114 samples
## 17 predictor
## 2 classes: 'nie', 'tak'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 16892, 16891, 16891, 16891, 16891
## Resampling results across tuning parameters:
##
## max_depth min_child_weight ROC Sens Spec
## 2 50 0.9020865 0.9387101 0.5543292
## 2 100 0.8996022 0.9384578 0.5436727
## 2 200 0.8799907 0.9274231 0.5259753
## 2 400 0.8476504 0.9345483 0.4289248
## 3 50 0.9066564 0.9404756 0.5691722
## 3 100 0.9035268 0.9351160 0.5646051
## 3 200 0.8840750 0.9260990 0.5495718
## 3 400 0.8479697 0.9317737 0.4441484
## 4 50 0.9088985 0.9382057 0.5828735
## 4 100 0.9050568 0.9341701 0.5824929
## 4 200 0.8856024 0.9278015 0.5507136
## 4 400 0.8474548 0.9295664 0.4513796
## 5 50 0.9098359 0.9368815 0.5923882
## 5 100 0.9059827 0.9344852 0.5794481
## 5 200 0.8860554 0.9274860 0.5512845
## 5 400 0.8482110 0.9312064 0.4496670
##
## Tuning parameter 'nrounds' was held constant at a value of 100
## Tuning
## 'colsample_bytree' was held constant at a value of 0.8
## Tuning
## parameter 'subsample' was held constant at a value of 0.8
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were nrounds = 100, max_depth = 5, eta
## = 0.05, gamma = 0, colsample_bytree = 0.8, min_child_weight = 50 and
## subsample = 0.8.
salary.xgb.1$bestTune
## nrounds max_depth eta gamma colsample_bytree min_child_weight subsample
## 13 100 5 0.05 0 0.8 50 0.8
Na podstawie wyników pierwszego etapu strojenia przyjęto wartości
max_depth = 5 oraz min_child_weight = 50.
nrounds i
etaW drugim etapie strojenia sprawdzono liczbę drzew oraz tempo uczenia.
Parametr nrounds określa liczbę kolejnych drzew
budowanych przez model. Parametr eta odpowiada za tempo
uczenia. Niższe wartości eta powodują wolniejsze, ale
często stabilniejsze uczenie modelu.
grid.salary.xgb.nrounds_eta <- expand.grid(
nrounds = c(200, 300, 500, 800),
max_depth = 5,
eta = c(0.01, 0.025, 0.05, 0.1),
gamma = 0,
colsample_bytree = 0.8,
min_child_weight = 50,
subsample = 0.8
)
grid.salary.xgb.nrounds_eta
## nrounds max_depth eta gamma colsample_bytree min_child_weight subsample
## 1 200 5 0.010 0 0.8 50 0.8
## 2 300 5 0.010 0 0.8 50 0.8
## 3 500 5 0.010 0 0.8 50 0.8
## 4 800 5 0.010 0 0.8 50 0.8
## 5 200 5 0.025 0 0.8 50 0.8
## 6 300 5 0.025 0 0.8 50 0.8
## 7 500 5 0.025 0 0.8 50 0.8
## 8 800 5 0.025 0 0.8 50 0.8
## 9 200 5 0.050 0 0.8 50 0.8
## 10 300 5 0.050 0 0.8 50 0.8
## 11 500 5 0.050 0 0.8 50 0.8
## 12 800 5 0.050 0 0.8 50 0.8
## 13 200 5 0.100 0 0.8 50 0.8
## 14 300 5 0.100 0 0.8 50 0.8
## 15 500 5 0.100 0 0.8 50 0.8
## 16 800 5 0.100 0 0.8 50 0.8
# Trenowanie modelu zostało wyhasztagowane.
# Jeśli plik "salary.xgb.2.rds" jeszcze nie istnieje,
# należy tymczasowo odkomentować poniższy fragment.
# set.seed(123)
# salary.xgb.2 <- train(
# klasa_dochodu_01 ~ .,
# data = salary.train.caret,
# method = "xgbTree",
# trControl = ctrl,
# tuneGrid = grid.salary.xgb.nrounds_eta,
# metric = "ROC"
# )
#
# saveRDS(salary.xgb.2, "trenowane_modele/salary.xgb.2.rds")
salary.xgb.2 <- readRDS("trenowane_modele/salary.xgb.2.rds")
salary.xgb.2
## eXtreme Gradient Boosting
##
## 21114 samples
## 17 predictor
## 2 classes: 'nie', 'tak'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 16891, 16892, 16891, 16891, 16891
## Resampling results across tuning parameters:
##
## eta nrounds ROC Sens Spec
## 0.010 200 0.9051048 0.9444446 0.5643112
## 0.010 300 0.9082290 0.9444445 0.5705879
## 0.010 500 0.9113829 0.9398414 0.5920890
## 0.010 800 0.9130497 0.9361209 0.6092110
## 0.025 200 0.9114442 0.9395892 0.5938009
## 0.025 300 0.9127783 0.9364362 0.6052154
## 0.025 500 0.9142923 0.9342292 0.6213877
## 0.025 800 0.9151878 0.9318960 0.6309003
## 0.050 200 0.9135127 0.9345446 0.6183439
## 0.050 300 0.9145132 0.9322114 0.6259532
## 0.050 500 0.9153463 0.9319591 0.6326113
## 0.050 800 0.9159019 0.9317698 0.6360356
## 0.100 200 0.9146305 0.9329681 0.6299476
## 0.100 300 0.9153241 0.9325897 0.6339440
## 0.100 500 0.9156742 0.9312022 0.6364171
## 0.100 800 0.9153414 0.9291844 0.6407920
##
## Tuning parameter 'max_depth' was held constant at a value of 5
## Tuning
##
## Tuning parameter 'min_child_weight' was held constant at a value of 50
##
## Tuning parameter 'subsample' was held constant at a value of 0.8
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were nrounds = 800, max_depth = 5, eta
## = 0.05, gamma = 0, colsample_bytree = 0.8, min_child_weight = 50 and
## subsample = 0.8.
salary.xgb.2$bestTune
## nrounds max_depth eta gamma colsample_bytree min_child_weight subsample
## 12 800 5 0.05 0 0.8 50 0.8
Na podstawie wyników drugiego etapu strojenia przyjęto wartości
nrounds = 800 oraz eta = 0.05.
subsample i
colsample_bytreeW trzecim etapie sprawdzono parametry odpowiedzialne za losowanie obserwacji i zmiennych.
Parametr subsample określa, jaka część obserwacji jest
wykorzystywana do budowy każdego drzewa. Parametr
colsample_bytree określa, jaka część zmiennych jest
wykorzystywana w każdym drzewie.
Takie losowanie może działać regularyzująco, czyli ograniczać przeuczenie modelu.
grid.salary.xgb.sample <- expand.grid(
nrounds = 800,
max_depth = 5,
eta = 0.05,
gamma = 0,
colsample_bytree = c(0.2, 0.4, 0.6, 0.8, 0.9),
min_child_weight = 50,
subsample = c(0.2, 0.4, 0.6, 0.8, 0.9)
)
grid.salary.xgb.sample
## nrounds max_depth eta gamma colsample_bytree min_child_weight subsample
## 1 800 5 0.05 0 0.2 50 0.2
## 2 800 5 0.05 0 0.4 50 0.2
## 3 800 5 0.05 0 0.6 50 0.2
## 4 800 5 0.05 0 0.8 50 0.2
## 5 800 5 0.05 0 0.9 50 0.2
## 6 800 5 0.05 0 0.2 50 0.4
## 7 800 5 0.05 0 0.4 50 0.4
## 8 800 5 0.05 0 0.6 50 0.4
## 9 800 5 0.05 0 0.8 50 0.4
## 10 800 5 0.05 0 0.9 50 0.4
## 11 800 5 0.05 0 0.2 50 0.6
## 12 800 5 0.05 0 0.4 50 0.6
## 13 800 5 0.05 0 0.6 50 0.6
## 14 800 5 0.05 0 0.8 50 0.6
## 15 800 5 0.05 0 0.9 50 0.6
## 16 800 5 0.05 0 0.2 50 0.8
## 17 800 5 0.05 0 0.4 50 0.8
## 18 800 5 0.05 0 0.6 50 0.8
## 19 800 5 0.05 0 0.8 50 0.8
## 20 800 5 0.05 0 0.9 50 0.8
## 21 800 5 0.05 0 0.2 50 0.9
## 22 800 5 0.05 0 0.4 50 0.9
## 23 800 5 0.05 0 0.6 50 0.9
## 24 800 5 0.05 0 0.8 50 0.9
## 25 800 5 0.05 0 0.9 50 0.9
# Trenowanie modelu zostało wyhasztagowane.
# Jeśli plik "salary.xgb.3.rds" jeszcze nie istnieje,
# należy tymczasowo odkomentować poniższy fragment.
# set.seed(123)
# salary.xgb.3 <- train(
# klasa_dochodu_01 ~ .,
# data = salary.train.caret,
# method = "xgbTree",
# trControl = ctrl,
# tuneGrid = grid.salary.xgb.sample,
# metric = "ROC"
# )
#
# saveRDS(salary.xgb.3, "trenowane_modele/salary.xgb.3.rds")
salary.xgb.3 <- readRDS("trenowane_modele/salary.xgb.3.rds")
salary.xgb.3
## eXtreme Gradient Boosting
##
## 21114 samples
## 17 predictor
## 2 classes: 'nie', 'tak'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 16891, 16892, 16891, 16891, 16891
## Resampling results across tuning parameters:
##
## colsample_bytree subsample ROC Sens Spec
## 0.2 0.2 0.8870639 0.9310128 0.5487088
## 0.2 0.4 0.9091399 0.9320219 0.6059771
## 0.2 0.6 0.9126803 0.9335987 0.6200544
## 0.2 0.8 0.9155636 0.9342292 0.6299479
## 0.2 0.9 0.9162185 0.9338508 0.6331821
## 0.4 0.2 0.8878455 0.9300670 0.5570803
## 0.4 0.4 0.9096997 0.9299410 0.6141556
## 0.4 0.6 0.9129322 0.9301936 0.6270941
## 0.4 0.8 0.9159209 0.9334094 0.6316602
## 0.4 0.9 0.9169222 0.9326526 0.6358469
## 0.6 0.2 0.8878807 0.9274186 0.5631678
## 0.6 0.4 0.9092376 0.9300041 0.6185342
## 0.6 0.6 0.9128037 0.9303827 0.6303276
## 0.6 0.8 0.9161457 0.9326526 0.6388891
## 0.6 0.9 0.9171412 0.9325896 0.6375579
## 0.8 0.2 0.8882967 0.9279859 0.5610746
## 0.8 0.4 0.9089835 0.9302562 0.6139668
## 0.8 0.6 0.9128501 0.9309502 0.6307082
## 0.8 0.8 0.9157144 0.9328420 0.6366061
## 0.8 0.9 0.9169748 0.9317068 0.6428846
## 0.9 0.2 0.8881539 0.9287427 0.5637385
## 0.9 0.4 0.9087760 0.9295627 0.6166305
## 0.9 0.6 0.9123321 0.9309502 0.6289963
## 0.9 0.8 0.9157602 0.9313915 0.6388895
## 0.9 0.9 0.9169218 0.9315177 0.6436463
##
## Tuning parameter 'nrounds' was held constant at a value of 800
## Tuning
## held constant at a value of 0
## Tuning parameter 'min_child_weight' was
## held constant at a value of 50
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were nrounds = 800, max_depth = 5, eta
## = 0.05, gamma = 0, colsample_bytree = 0.6, min_child_weight = 50 and
## subsample = 0.9.
salary.xgb.3$bestTune
## nrounds max_depth eta gamma colsample_bytree min_child_weight subsample
## 15 800 5 0.05 0 0.6 50 0.9
Na podstawie wyników trzeciego etapu strojenia przyjęto wartości
subsample = 0.9 oraz
colsample_bytree = 0.6.
gammaW ostatnim etapie strojenia sprawdzono parametr gamma.
Parametr ten określa minimalną poprawę jakości modelu wymaganą do
wykonania kolejnego podziału w drzewie.
Im większa wartość gamma, tym bardziej konserwatywny
jest model, ponieważ trudniej mu tworzyć kolejne podziały.
grid.salary.xgb.gamma <- expand.grid(
nrounds = 800,
max_depth = 5,
eta = 0.05,
gamma = c(0, 0.5, 1, 2),
colsample_bytree = 0.6,
min_child_weight = 50,
subsample = 0.9
)
grid.salary.xgb.gamma
## nrounds max_depth eta gamma colsample_bytree min_child_weight subsample
## 1 800 5 0.05 0.0 0.6 50 0.9
## 2 800 5 0.05 0.5 0.6 50 0.9
## 3 800 5 0.05 1.0 0.6 50 0.9
## 4 800 5 0.05 2.0 0.6 50 0.9
# Trenowanie modelu zostało wyhasztagowane.
# Jeśli plik "salary.xgb.4.rds" jeszcze nie istnieje,
# należy tymczasowo odkomentować poniższy fragment.
# set.seed(123)
# salary.xgb.4 <- train(
# klasa_dochodu_01 ~ .,
# data = salary.train.caret,
# method = "xgbTree",
# trControl = ctrl,
# tuneGrid = grid.salary.xgb.gamma,
# metric = "ROC"
# )
#
# saveRDS(salary.xgb.4, "trenowane_modele/salary.xgb.4.rds")
salary.xgb.4 <- readRDS("trenowane_modele/salary.xgb.4.rds")
salary.xgb.4
## eXtreme Gradient Boosting
##
## 21114 samples
## 17 predictor
## 2 classes: 'nie', 'tak'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 16891, 16892, 16891, 16891, 16891
## Resampling results across tuning parameters:
##
## gamma ROC Sens Spec
## 0.0 0.9171491 0.9320852 0.6385101
## 0.5 0.9172577 0.9324004 0.6413624
## 1.0 0.9172496 0.9324004 0.6396503
## 2.0 0.9168120 0.9333463 0.6331822
##
## Tuning parameter 'nrounds' was held constant at a value of 800
## Tuning
## parameter 'min_child_weight' was held constant at a value of 50
##
## Tuning parameter 'subsample' was held constant at a value of 0.9
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were nrounds = 800, max_depth = 5, eta
## = 0.05, gamma = 0.5, colsample_bytree = 0.6, min_child_weight = 50
## and subsample = 0.9.
salary.xgb.4$bestTune
## nrounds max_depth eta gamma colsample_bytree min_child_weight subsample
## 2 800 5 0.05 0.5 0.6 50 0.9
Na podstawie wyników czwartego etapu strojenia jako model finalny
przyjęto model salary.xgb.4.
Po przeprowadzeniu strojenia wybrano finalny model XGBoost. Następnie wyznaczono prawdopodobieństwa dla zbioru treningowego i testowego.
salary.xgb.tuned <- salary.xgb.4
pred.salary.xgb.tuned.train.prob <- predict(
salary.xgb.tuned,
newdata = salary.train.caret,
type = "prob"
)[, "tak"]
pred.salary.xgb.tuned.test.prob <- predict(
salary.xgb.tuned,
newdata = salary.test.caret,
type = "prob"
)[, "tak"]
summary(pred.salary.xgb.tuned.train.prob)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 0.0002398 0.0119718 0.0927959 0.2490277 0.4206244 0.9948304
summary(pred.salary.xgb.tuned.test.prob)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 0.0003001 0.0114394 0.0919788 0.2452645 0.4131365 0.9946068
metryki.salary.xgb.tuned.train <- klasyfikacja.metryki(
predicted_probabilities = pred.salary.xgb.tuned.train.prob,
real = salary.train.caret$klasa_dochodu_01,
cutoff = cutoff.bag,
level_positive = "tak",
level_negative = "nie"
)
metryki.salary.xgb.tuned.test <- klasyfikacja.metryki(
predicted_probabilities = pred.salary.xgb.tuned.test.prob,
real = salary.test.caret$klasa_dochodu_01,
cutoff = cutoff.bag,
level_positive = "tak",
level_negative = "nie"
)
rbind(
xgb.start.train = metryki.salary.xgb.start.train,
xgb.tuned.train = metryki.salary.xgb.tuned.train,
xgb.start.test = metryki.salary.xgb.start.test,
xgb.tuned.test = metryki.salary.xgb.tuned.test
)
## Accuracy Sensitivity Specificity BalancedAccuracy PPV NPV
## xgb.start.train 0.7986 0.8727 0.7740 0.8234 0.5614 0.9483
## xgb.tuned.train 0.8286 0.8786 0.8120 0.8453 0.6077 0.9528
## xgb.start.test 0.7982 0.8601 0.7777 0.8189 0.5618 0.9438
## xgb.tuned.test 0.8239 0.8583 0.8125 0.8354 0.6027 0.9454
## F1 Kappa
## xgb.start.train 0.6833 0.5456
## xgb.tuned.train 0.7185 0.6011
## xgb.start.test 0.6796 0.5416
## xgb.tuned.test 0.7082 0.5876
Porównując model startowy oraz model dostrojony, można ocenić, czy strojenie hiperparametrów poprawiło jakość klasyfikacji. Szczególnie ważna jest wartość Balanced Accuracy, ponieważ w analizowanym zbiorze klasy są niezbalansowane. Balanced Accuracy dla danych testwoych w finalnym modelu wyniosła niemalże 84%, a dodatkowo co niezwykle ważne nie widać uznak jakiegolwiek przeuczenia modelu.
Na końcu porównano model startowy i model dostrojony za pomocą krzywych ROC oraz indeksu Giniego.
ROC.xgb.start.train <- pROC::roc(
as.numeric(salary.train.caret$klasa_dochodu_01 == "tak"),
pred.salary.xgb.start.train.prob
)
ROC.xgb.start.test <- pROC::roc(
as.numeric(salary.test.caret$klasa_dochodu_01 == "tak"),
pred.salary.xgb.start.test.prob
)
ROC.xgb.tuned.train <- pROC::roc(
as.numeric(salary.train.caret$klasa_dochodu_01 == "tak"),
pred.salary.xgb.tuned.train.prob
)
ROC.xgb.tuned.test <- pROC::roc(
as.numeric(salary.test.caret$klasa_dochodu_01 == "tak"),
pred.salary.xgb.tuned.test.prob
)
list(
xgb.start.train = ROC.xgb.start.train,
xgb.start.test = ROC.xgb.start.test,
xgb.tuned.train = ROC.xgb.tuned.train,
xgb.tuned.test = ROC.xgb.tuned.test
) %>%
ggroc(alpha = 0.6, linewidth = 1) +
geom_segment(
aes(x = 1, xend = 0, y = 0, yend = 1),
color = "grey",
linetype = "dashed"
) +
labs(
title = paste0(
"Gini TEST: ",
"XGB start = ", round(100 * (2 * pROC::auc(ROC.xgb.start.test) - 1), 1), "%, ",
"XGB tuned = ", round(100 * (2 * pROC::auc(ROC.xgb.tuned.test) - 1), 1), "%"
),
subtitle = paste0(
"Gini TRAIN: ",
"XGB start = ", round(100 * (2 * pROC::auc(ROC.xgb.start.train) - 1), 1), "%, ",
"XGB tuned = ", round(100 * (2 * pROC::auc(ROC.xgb.tuned.train) - 1), 1), "%"
),
color = "Model",
x = "Specyficzność",
y = "Wrażliwość"
) +
theme_bw() +
coord_fixed() +
scale_color_brewer(palette = "Paired")
Na podstawie krzywych ROC oraz indeksu Giniego można ocenić, czy model XGBoost poprawił zdolność rozróżniania klas względem wcześniejszych modeli. Jeżeli indeks Giniego dla modelu dostrojonego jest wyższy niż dla modelu startowego, oznacza to, że strojenie hiperparametrów poprawiło zdolność separacji klas.
####Końcowe porównanie dla drzew i logitu
# Krzywe ROC - modele bazowe
ROC.train.logit <- pROC::roc(as.numeric(salary.train$klasa_dochodu_01) - 1,
pred.glm.train.prob)
ROC.train.tree.tuned <- pROC::roc(as.numeric(salary.train.ML$klasa_dochodu_01) -1,
pred.tree.tuned.train.prob)
gini <- function(roc) {
paste0(round(100 * (2 * as.numeric(pROC::auc(roc)) - 1), 1), "%")
}
pROC::ggroc(
list(
logit.train = ROC.train.logit,
logit.test = ROC.test.logit,
tree.tuned.train = ROC.train.tree.tuned,
tree.tuned.test = ROC.test.tree.tuned,
bagging.train = ROC.bagging.train,
bagging.test = ROC.bagging.test,
rf.train = ROC.rf.train,
rf.test = ROC.rf.test,
rf.tuned.train = ROC.rf.tuned.train,
rf.tuned.test = ROC.rf.tuned.test,
xgb.train = ROC.xgb.tuned.train,
xgb.test = ROC.xgb.tuned.test
),
alpha = 0.65,
linewidth = 1
) +
geom_segment(
aes(x = 1, xend = 0, y = 0, yend = 1),
color = "grey",
linetype = "dashed"
) +
labs(
title = paste0(
"Gini TEST: ",
"logit = ", gini(ROC.test.logit), " | ",
"decision tree = ", gini(ROC.test.tree.tuned), " | ",
"rf = ", gini(ROC.rf.test), "\n",
"rf_tuned = ", gini(ROC.rf.tuned.test), " | ",
"bagging = ", gini(ROC.bagging.test), " | ",
"xgb = ", gini(ROC.xgb.tuned.test)
),
subtitle = paste0(
"Gini TRAIN: ",
"logit = ", gini(ROC.train.logit), " | ",
"decision tree = ", gini(ROC.train.tree.tuned), " | ",
"rf = ", gini(ROC.rf.train), "\n",
"rf_tuned = ", gini(ROC.rf.tuned.train), " | ",
"bagging = ", gini(ROC.bagging.train), " | ",
"xgb = ", gini(ROC.xgb.tuned.train)
),
color = "Model",
x = "Specyficzność",
y = "Wrażliwość"
) +
theme_bw() +
coord_fixed() +
scale_color_brewer(palette = "Paired") +
theme(
plot.title = element_text(
size = 10,
face = "bold",
lineheight = 0.95,
margin = ggplot2::margin(b = 5)
),
plot.subtitle = element_text(
size = 9,
lineheight = 0.95,
margin = ggplot2::margin(b = 10)
),
legend.position = "bottom",
legend.title = element_blank(),
legend.text = element_text(size = 8),
axis.title = element_text(size = 11),
axis.text = element_text(size = 9)
) +
guides(
color = guide_legend(nrow = 2, byrow = TRUE)
)
Można zauważyć, iż dla danych testowych największe wartości Gini przyjęły kolejno modele: XGB, rf_tuned oraz rf.
Ostatnią grupą modeli wykorzystaną w analizie są sieci neuronowe. W przeciwieństwie do modeli liniowych, takich jak regresja logistyczna, sieci neuronowe mogą uchwycić bardziej złożone, nieliniowe zależności pomiędzy zmiennymi objaśniającymi a zmienną celu.
W naszym przypadku sieć neuronowa będzie wykorzystywana do
klasyfikacji binarnej. Celem modelu jest przewidywanie, czy dana osoba
należy do klasy tak, czyli czy osiąga dochód powyżej 50
tys. USD rocznie.
Sieci neuronowe wymagają jednak odpowiedniego przygotowania danych. Wszystkie zmienne wejściowe muszą być numeryczne, dlatego zmienne kategoryczne należy przekształcić za pomocą kodowania zero-jedynkowego, czyli one-hot encoding. Dodatkowo dane wejściowe zostały przeskalowane do zakresu 0–1, ponieważ sieci neuronowe są wrażliwe na skalę zmiennych.
W analizie wykorzystano pakiet neuralnet. Jest to pakiet
dobry do celów dydaktycznych i pokazania podstawowej logiki działania
sieci neuronowych, ale przy większych zbiorach danych może trenować się
stosunkowo długo. Z tego powodu modele sieci neuronowych zostały
wytrenowane na zmniejszonym podzbiorze danych treningowych.
library(neuralnet) # proste sieci neuronowe
library(NeuralNetTools) # wizualizacja sieci neuronowych
Przed rozpoczęciem modelowania sprawdzono typy zmiennych w zbiorze
treningowym. Jest to ważne, ponieważ sieci neuronowe wymagają danych
numerycznych. Zmienne typu factor muszą więc zostać
przekształcone do postaci zmiennych zero-jedynkowych.
# Liczba zmiennych typu factor
sum(sapply(salary.train.ML, is.factor))
## [1] 12
# Liczba zmiennych typu ordered factor
sum(sapply(salary.train.ML, is.ordered))
## [1] 0
# Nazwy zmiennych typu factor
names(salary.train.ML)[sapply(salary.train.ML, is.factor)]
## [1] "stan_cywilny" "zawod"
## [3] "relacja_w_rodzinie" "rasa"
## [5] "plec" "typ_zatrudnienia"
## [7] "kraj_pochodzenia" "czy_kapital_netto_dodatni"
## [9] "wiek_sredni" "dlugie_godziny_pracy"
## [11] "klasa_dochodu_01" "poziom_edukacji_factor"
# Nazwy zmiennych typu ordered factor
names(salary.train.ML)[sapply(salary.train.ML, is.ordered)]
## character(0)
Następnie dokonano podziału danych na zmienne objaśniające
X oraz zmienną celu y. Zmienną objaśnianą jest
klasa_dochodu_01, która informuje, czy dana osoba zarabia
powyżej 50 tys. USD.
salary.x.train <- salary.train.ML %>%
dplyr::select(-klasa_dochodu_01)
salary.y.train <- salary.train.ML$klasa_dochodu_01
salary.x.test <- salary.test.ML %>%
dplyr::select(-klasa_dochodu_01)
salary.y.test <- salary.test.ML$klasa_dochodu_01
# Kontrola rozkładu zmiennej celu
table(salary.y.train)
## salary.y.train
## nie tak
## 15858 5256
table(salary.y.test)
## salary.y.test
## nie tak
## 6796 2252
prop.table(table(salary.y.train))
## salary.y.train
## nie tak
## 0.7510656 0.2489344
prop.table(table(salary.y.test))
## salary.y.test
## nie tak
## 0.7511052 0.2488948
Ponieważ sieci neuronowe nie przyjmują bezpośrednio zmiennych typu
factor, zmienne kategoryczne zostały przekształcone za
pomocą funkcji model.matrix(). W efekcie każda kategoria
zostaje zapisana jako osobna zmienna zero-jedynkowa.
Ważne jest również wyrównanie kolumn w zbiorze testowym do kolumn ze zbioru treningowego. Jeżeli jakaś kategoria występuje w treningu, ale nie pojawia się w teście, dodajemy odpowiadającą jej kolumnę z zerami. Jeżeli natomiast w teście pojawią się dodatkowe kolumny, których nie było w treningu, zostają one usunięte.
salary.x.train.mm <- model.matrix(~ ., data = salary.x.train)[, -1]
salary.x.test.mm <- model.matrix(~ ., data = salary.x.test)[, -1]
# Zapamiętujemy kolumny ze zbioru treningowego
train_cols <- colnames(salary.x.train.mm)
# Dodajemy brakujące kolumny w teście
missing_cols <- setdiff(train_cols, colnames(salary.x.test.mm))
if (length(missing_cols) > 0) {
for (col in missing_cols) {
salary.x.test.mm <- cbind(salary.x.test.mm, 0)
colnames(salary.x.test.mm)[ncol(salary.x.test.mm)] <- col
}
}
# Usuwamy ewentualne dodatkowe kolumny z testu i ustawiamy kolejność jak w train
salary.x.test.mm <- salary.x.test.mm[, train_cols]
# Kontrola wymiarów przed i po kodowaniu
dim(salary.x.train)
## [1] 21114 17
dim(salary.x.train.mm)
## [1] 21114 99
dim(salary.x.test)
## [1] 9048 17
dim(salary.x.test.mm)
## [1] 9048 99
Kolejnym krokiem było przeskalowanie danych wejściowych do zakresu 0–1. Parametry skalowania, czyli minimum i maksimum, wyznaczono wyłącznie na zbiorze treningowym. Jest to bardzo ważne, ponieważ zbiór testowy powinien symulować nowe dane, których model nie znał podczas uczenia.
# Minimum i maksimum liczymy wyłącznie na zbiorze treningowym
salary.mins <- apply(salary.x.train.mm, 2, min)
salary.maxs <- apply(salary.x.train.mm, 2, max)
# Zabezpieczenie przed dzieleniem przez zero
salary.ranges <- salary.maxs - salary.mins
salary.ranges[salary.ranges == 0] <- 1
# Skalowanie zbioru treningowego
salary.x.train.scaled <- scale(
salary.x.train.mm,
center = salary.mins,
scale = salary.ranges
)
# Skalowanie zbioru testowego tymi samymi parametrami
salary.x.test.scaled <- scale(
salary.x.test.mm,
center = salary.mins,
scale = salary.ranges
)
summary(salary.x.train.scaled)
## wiek stan_cywilnymałżeństwo_z_osobą_z_sił_zbrojnych
## Min. :0.0000 Min. :0.0000000
## 1st Qu.:0.1507 1st Qu.:0.0000000
## Median :0.2740 Median :0.0000000
## Mean :0.2948 Mean :0.0006631
## 3rd Qu.:0.4110 3rd Qu.:0.0000000
## Max. :1.0000 Max. :1.0000000
## stan_cywilnymałżonek_nieobecny stan_cywilnynigdy_niezamężny_niezonaty
## Min. :0.00000 Min. :0.000
## 1st Qu.:0.00000 1st Qu.:0.000
## Median :0.00000 Median :0.000
## Mean :0.01246 Mean :0.323
## 3rd Qu.:0.00000 3rd Qu.:1.000
## Max. :1.00000 Max. :1.000
## stan_cywilnyrozwiedziony_rozwiedziona stan_cywilnyw_separacji
## Min. :0.0000 Min. :0.0000
## 1st Qu.:0.0000 1st Qu.:0.0000
## Median :0.0000 Median :0.0000
## Mean :0.1377 Mean :0.0306
## 3rd Qu.:0.0000 3rd Qu.:0.0000
## Max. :1.0000 Max. :1.0000
## stan_cywilnywdowiec_wdowa zawodinne_usługi zawodkadra_kierownicza
## Min. :0.00000 Min. :0.0000 Min. :0.0000
## 1st Qu.:0.00000 1st Qu.:0.0000 1st Qu.:0.0000
## Median :0.00000 Median :0.0000 Median :0.0000
## Mean :0.02804 Mean :0.1048 Mean :0.1327
## 3rd Qu.:0.00000 3rd Qu.:0.0000 3rd Qu.:0.0000
## Max. :1.00000 Max. :1.0000 Max. :1.0000
## zawodoperatorzy_maszyn_i_inspektorzy zawodpracownicy_fizyczni_i_sprzątający
## Min. :0.00000 Min. :0.00000
## 1st Qu.:0.00000 1st Qu.:0.00000
## Median :0.00000 Median :0.00000
## Mean :0.06635 Mean :0.04476
## 3rd Qu.:0.00000 3rd Qu.:0.00000
## Max. :1.00000 Max. :1.00000
## zawodprywatne_usługi_domowe zawodrolnictwo_i_rybołówstwo
## Min. :0.000000 Min. :0.00000
## 1st Qu.:0.000000 1st Qu.:0.00000
## Median :0.000000 Median :0.00000
## Mean :0.004547 Mean :0.03254
## 3rd Qu.:0.000000 3rd Qu.:0.00000
## Max. :1.000000 Max. :1.00000
## zawodrzemiosło_i_naprawy zawodsiły_zbrojne zawodspecjaliści zawodsprzedaż
## Min. :0.0000 Min. :0.0000000 Min. :0.0000 Min. :0.0000
## 1st Qu.:0.0000 1st Qu.:0.0000000 1st Qu.:0.0000 1st Qu.:0.0000
## Median :0.0000 Median :0.0000000 Median :0.0000 Median :0.0000
## Mean :0.1353 Mean :0.0003315 Mean :0.1332 Mean :0.1171
## 3rd Qu.:0.0000 3rd Qu.:0.0000000 3rd Qu.:0.0000 3rd Qu.:0.0000
## Max. :1.0000 Max. :1.0000000 Max. :1.0000 Max. :1.0000
## zawodtransport_i_przemieszczanie zawodusługi_ochrony zawodwsparcie_techniczne
## Min. :0.00000 Min. :0.00000 Min. :0.00000
## 1st Qu.:0.00000 1st Qu.:0.00000 1st Qu.:0.00000
## Median :0.00000 Median :0.00000 Median :0.00000
## Mean :0.05172 Mean :0.02169 Mean :0.02951
## 3rd Qu.:0.00000 3rd Qu.:0.00000 3rd Qu.:0.00000
## Max. :1.00000 Max. :1.00000 Max. :1.00000
## relacja_w_rodziniemąż relacja_w_rodzinieniezamężny_niezonaty
## Min. :0.0000 Min. :0.0000
## 1st Qu.:0.0000 1st Qu.:0.0000
## Median :0.0000 Median :0.0000
## Mean :0.4142 Mean :0.1063
## 3rd Qu.:1.0000 3rd Qu.:0.0000
## Max. :1.0000 Max. :1.0000
## relacja_w_rodziniepoza_rodziną relacja_w_rodziniewłasne_dziecko
## Min. :0.0000 Min. :0.0000
## 1st Qu.:0.0000 1st Qu.:0.0000
## Median :0.0000 Median :0.0000
## Mean :0.2556 Mean :0.1461
## 3rd Qu.:1.0000 3rd Qu.:0.0000
## Max. :1.0000 Max. :1.0000
## relacja_w_rodzinieżona rasabiała rasaczarna rasainna
## Min. :0.00000 Min. :0.0000 Min. :0.00000 Min. :0.000000
## 1st Qu.:0.00000 1st Qu.:1.0000 1st Qu.:0.00000 1st Qu.:0.000000
## Median :0.00000 Median :1.0000 Median :0.00000 Median :0.000000
## Mean :0.04741 Mean :0.8622 Mean :0.09297 Mean :0.006962
## 3rd Qu.:0.00000 3rd Qu.:1.0000 3rd Qu.:0.00000 3rd Qu.:0.000000
## Max. :1.00000 Max. :1.0000 Max. :1.00000 Max. :1.000000
## rasardzenni_amerykanie_lub_eskimosi plecmężczyzna godziny_pracy_tygodniowo
## Min. :0.000000 Min. :0.0000 Min. :0.0000
## 1st Qu.:0.000000 1st Qu.:0.0000 1st Qu.:0.3980
## Median :0.000000 Median :1.0000 Median :0.3980
## Mean :0.009472 Mean :0.6773 Mean :0.4082
## 3rd Qu.:0.000000 3rd Qu.:1.0000 3rd Qu.:0.4490
## Max. :1.000000 Max. :1.0000 Max. :1.0000
## zysk_kapitalowy strata_kapitalowa typ_zatrudnieniaadministracja_lokalna
## Min. :0.00000 Min. :0.0000 Min. :0.00000
## 1st Qu.:0.00000 1st Qu.:0.0000 1st Qu.:0.00000
## Median :0.00000 Median :0.0000 Median :0.00000
## Mean :0.01069 Mean :0.0206 Mean :0.06801
## 3rd Qu.:0.00000 3rd Qu.:0.0000 3rd Qu.:0.00000
## Max. :1.00000 Max. :1.0000 Max. :1.00000
## typ_zatrudnieniaadministracja_stanowa typ_zatrudnieniabez_wynagrodzenia
## Min. :0.00000 Min. :0.0000000
## 1st Qu.:0.00000 1st Qu.:0.0000000
## Median :0.00000 Median :0.0000000
## Mean :0.04244 Mean :0.0004263
## 3rd Qu.:0.00000 3rd Qu.:0.0000000
## Max. :1.00000 Max. :1.0000000
## typ_zatrudnieniasamozatrudniony_bez_osobowości_prawnej
## Min. :0.0000
## 1st Qu.:0.0000
## Median :0.0000
## Mean :0.0853
## 3rd Qu.:0.0000
## Max. :1.0000
## typ_zatrudnieniasamozatrudniony_z_osobowością_prawną
## Min. :0.00000
## 1st Qu.:0.00000
## Median :0.00000
## Mean :0.03538
## 3rd Qu.:0.00000
## Max. :1.00000
## typ_zatrudnieniasektor_prywatny kraj_pochodzeniaChiny
## Min. :0.0000 Min. :0.000000
## 1st Qu.:0.0000 1st Qu.:0.000000
## Median :1.0000 Median :0.000000
## Mean :0.7377 Mean :0.002415
## 3rd Qu.:1.0000 3rd Qu.:0.000000
## Max. :1.0000 Max. :1.000000
## kraj_pochodzeniaDominikana kraj_pochodzeniaEkwador kraj_pochodzeniaFilipiny
## Min. :0.0000 Min. :0.0000000 Min. :0.000000
## 1st Qu.:0.0000 1st Qu.:0.0000000 1st Qu.:0.000000
## Median :0.0000 Median :0.0000000 Median :0.000000
## Mean :0.0018 Mean :0.0009472 Mean :0.005257
## 3rd Qu.:0.0000 3rd Qu.:0.0000000 3rd Qu.:0.000000
## Max. :1.0000 Max. :1.0000000 Max. :1.000000
## kraj_pochodzeniaFrancja kraj_pochodzeniaGrecja kraj_pochodzeniaGwatemala
## Min. :0.0000000 Min. :0.000000 Min. :0.000000
## 1st Qu.:0.0000000 1st Qu.:0.000000 1st Qu.:0.000000
## Median :0.0000000 Median :0.000000 Median :0.000000
## Mean :0.0009472 Mean :0.001042 Mean :0.002273
## 3rd Qu.:0.0000000 3rd Qu.:0.000000 3rd Qu.:0.000000
## Max. :1.0000000 Max. :1.000000 Max. :1.000000
## kraj_pochodzeniaHaiti kraj_pochodzeniaHolandia kraj_pochodzeniaHonduras
## Min. :0.000000 Min. :0.00000000 Min. :0.0000000
## 1st Qu.:0.000000 1st Qu.:0.00000000 1st Qu.:0.0000000
## Median :0.000000 Median :0.00000000 Median :0.0000000
## Mean :0.001373 Mean :0.00004736 Mean :0.0003789
## 3rd Qu.:0.000000 3rd Qu.:0.00000000 3rd Qu.:0.0000000
## Max. :1.000000 Max. :1.00000000 Max. :1.0000000
## kraj_pochodzeniaHongkong kraj_pochodzeniaIndie kraj_pochodzeniaIran
## Min. :0.0000000 Min. :0.000000 Min. :0.000000
## 1st Qu.:0.0000000 1st Qu.:0.000000 1st Qu.:0.000000
## Median :0.0000000 Median :0.000000 Median :0.000000
## Mean :0.0006631 Mean :0.003505 Mean :0.001373
## 3rd Qu.:0.0000000 3rd Qu.:0.000000 3rd Qu.:0.000000
## Max. :1.0000000 Max. :1.000000 Max. :1.000000
## kraj_pochodzeniaIrlandia kraj_pochodzeniaJamajka kraj_pochodzeniaJaponia
## Min. :0.0000000 Min. :0.000000 Min. :0.000000
## 1st Qu.:0.0000000 1st Qu.:0.000000 1st Qu.:0.000000
## Median :0.0000000 Median :0.000000 Median :0.000000
## Mean :0.0009472 Mean :0.002889 Mean :0.002037
## 3rd Qu.:0.0000000 3rd Qu.:0.000000 3rd Qu.:0.000000
## Max. :1.0000000 Max. :1.000000 Max. :1.000000
## kraj_pochodzeniaJugosławia kraj_pochodzeniaKambodża kraj_pochodzeniaKanada
## Min. :0.000000 Min. :0.000000 Min. :0.00000
## 1st Qu.:0.000000 1st Qu.:0.000000 1st Qu.:0.00000
## Median :0.000000 Median :0.000000 Median :0.00000
## Mean :0.000521 Mean :0.000521 Mean :0.00341
## 3rd Qu.:0.000000 3rd Qu.:0.000000 3rd Qu.:0.00000
## Max. :1.000000 Max. :1.000000 Max. :1.00000
## kraj_pochodzeniaKolumbia kraj_pochodzeniaKorea_Południowa kraj_pochodzeniaKuba
## Min. :0.000000 Min. :0.000000 Min. :0.000000
## 1st Qu.:0.000000 1st Qu.:0.000000 1st Qu.:0.000000
## Median :0.000000 Median :0.000000 Median :0.000000
## Mean :0.001942 Mean :0.002605 Mean :0.003505
## 3rd Qu.:0.000000 3rd Qu.:0.000000 3rd Qu.:0.000000
## Max. :1.000000 Max. :1.000000 Max. :1.000000
## kraj_pochodzeniaLaos kraj_pochodzeniaMeksyk kraj_pochodzeniaNicaragua
## Min. :0.0000000 Min. :0.00000 Min. :0.000000
## 1st Qu.:0.0000000 1st Qu.:0.00000 1st Qu.:0.000000
## Median :0.0000000 Median :0.00000 Median :0.000000
## Mean :0.0004736 Mean :0.02122 Mean :0.001089
## 3rd Qu.:0.0000000 3rd Qu.:0.00000 3rd Qu.:0.000000
## Max. :1.0000000 Max. :1.00000 Max. :1.000000
## kraj_pochodzeniaNiemcy kraj_pochodzeniaPeru kraj_pochodzeniaPolska
## Min. :0.000000 Min. :0.000000 Min. :0.000000
## 1st Qu.:0.000000 1st Qu.:0.000000 1st Qu.:0.000000
## Median :0.000000 Median :0.000000 Median :0.000000
## Mean :0.004784 Mean :0.001042 Mean :0.001658
## 3rd Qu.:0.000000 3rd Qu.:0.000000 3rd Qu.:0.000000
## Max. :1.000000 Max. :1.000000 Max. :1.000000
## kraj_pochodzeniaPortoryko kraj_pochodzeniaPortugalia kraj_pochodzeniaSalwador
## Min. :0.000000 Min. :0.000000 Min. :0.000000
## 1st Qu.:0.000000 1st Qu.:0.000000 1st Qu.:0.000000
## Median :0.000000 Median :0.000000 Median :0.000000
## Mean :0.003505 Mean :0.001184 Mean :0.003363
## 3rd Qu.:0.000000 3rd Qu.:0.000000 3rd Qu.:0.000000
## Max. :1.000000 Max. :1.000000 Max. :1.000000
## kraj_pochodzeniaStany_Zjednoczone kraj_pochodzeniaSzkocja
## Min. :0.0000 Min. :0.0000000
## 1st Qu.:1.0000 1st Qu.:0.0000000
## Median :1.0000 Median :0.0000000
## Mean :0.9098 Mean :0.0004263
## 3rd Qu.:1.0000 3rd Qu.:0.0000000
## Max. :1.0000 Max. :1.0000000
## kraj_pochodzeniaTajlandia kraj_pochodzeniaTajwan
## Min. :0.000000 Min. :0.000000
## 1st Qu.:0.000000 1st Qu.:0.000000
## Median :0.000000 Median :0.000000
## Mean :0.000521 Mean :0.001373
## 3rd Qu.:0.000000 3rd Qu.:0.000000
## Max. :1.000000 Max. :1.000000
## kraj_pochodzeniaTerytoria_Zależne_USA kraj_pochodzeniaTrynidad_i_Tobago
## Min. :0.0000000 Min. :0.0000000
## 1st Qu.:0.0000000 1st Qu.:0.0000000
## Median :0.0000000 Median :0.0000000
## Mean :0.0004263 Mean :0.0004736
## 3rd Qu.:0.0000000 3rd Qu.:0.0000000
## Max. :1.0000000 Max. :1.0000000
## kraj_pochodzeniaWęgry kraj_pochodzeniaWietnam kraj_pochodzeniaWłochy
## Min. :0.000000 Min. :0.000000 Min. :0.000000
## 1st Qu.:0.000000 1st Qu.:0.000000 1st Qu.:0.000000
## Median :0.000000 Median :0.000000 Median :0.000000
## Mean :0.000521 Mean :0.001989 Mean :0.002415
## 3rd Qu.:0.000000 3rd Qu.:0.000000 3rd Qu.:0.000000
## Max. :1.000000 Max. :1.000000 Max. :1.000000
## kapital_netto czy_kapital_netto_dodatnikapital_netto_niedodatni
## Min. :0.00000 Min. :0.0000
## 1st Qu.:0.04174 1st Qu.:1.0000
## Median :0.04174 Median :1.0000
## Mean :0.05112 Mean :0.9156
## 3rd Qu.:0.04174 3rd Qu.:1.0000
## Max. :1.00000 Max. :1.0000
## godziny_razy_edukacja wiek_sredniw_grupie_wieku_średniego
## Min. :0.0000 Min. :0.000
## 1st Qu.:0.2180 1st Qu.:0.000
## Median :0.2497 Median :1.000
## Mean :0.2620 Mean :0.594
## 3rd Qu.:0.3257 3rd Qu.:1.000
## Max. :1.0000 Max. :1.000
## dlugie_godziny_pracytak poziom_edukacji_factorklasy_1_4
## Min. :0.0000 Min. :0.000000
## 1st Qu.:0.0000 1st Qu.:0.000000
## Median :0.0000 Median :0.000000
## Mean :0.3054 Mean :0.004973
## 3rd Qu.:1.0000 3rd Qu.:0.000000
## Max. :1.0000 Max. :1.000000
## poziom_edukacji_factorklasy_5_6 poziom_edukacji_factorklasy_7_8
## Min. :0.000000 Min. :0.00000
## 1st Qu.:0.000000 1st Qu.:0.00000
## Median :0.000000 Median :0.00000
## Mean :0.009804 Mean :0.01894
## 3rd Qu.:0.000000 3rd Qu.:0.00000
## Max. :1.000000 Max. :1.00000
## poziom_edukacji_factorklasa_9 poziom_edukacji_factorklasa_10
## Min. :0.00000 Min. :0.00000
## 1st Qu.:0.00000 1st Qu.:0.00000
## Median :0.00000 Median :0.00000
## Mean :0.01501 Mean :0.02633
## 3rd Qu.:0.00000 3rd Qu.:0.00000
## Max. :1.00000 Max. :1.00000
## poziom_edukacji_factorklasa_11 poziom_edukacji_factorklasa_12
## Min. :0.00000 Min. :0.00000
## 1st Qu.:0.00000 1st Qu.:0.00000
## Median :0.00000 Median :0.00000
## Mean :0.03429 Mean :0.01302
## 3rd Qu.:0.00000 3rd Qu.:0.00000
## Max. :1.00000 Max. :1.00000
## poziom_edukacji_factorszkoła_średnia poziom_edukacji_factorczęść_studiów
## Min. :0.0000 Min. :0.0000
## 1st Qu.:0.0000 1st Qu.:0.0000
## Median :0.0000 Median :0.0000
## Mean :0.3283 Mean :0.2196
## 3rd Qu.:1.0000 3rd Qu.:0.0000
## Max. :1.0000 Max. :1.0000
## poziom_edukacji_factorassociate_zawodowe
## Min. :0.00000
## 1st Qu.:0.00000
## Median :0.00000
## Mean :0.04424
## 3rd Qu.:0.00000
## Max. :1.00000
## poziom_edukacji_factorassociate_akademickie poziom_edukacji_factorlicencjat
## Min. :0.00000 Min. :0.0000
## 1st Qu.:0.00000 1st Qu.:0.0000
## Median :0.00000 Median :0.0000
## Mean :0.03396 Mean :0.1653
## 3rd Qu.:0.00000 3rd Qu.:0.0000
## Max. :1.00000 Max. :1.0000
## poziom_edukacji_factormagister poziom_edukacji_factorszkoła_profesjonalna
## Min. :0.00000 Min. :0.00000
## 1st Qu.:0.00000 1st Qu.:0.00000
## Median :0.00000 Median :0.00000
## Mean :0.05314 Mean :0.01894
## 3rd Qu.:0.00000 3rd Qu.:0.00000
## Max. :1.00000 Max. :1.00000
## poziom_edukacji_factordoktorat
## Min. :0.00000
## 1st Qu.:0.00000
## Median :0.00000
## Mean :0.01265
## 3rd Qu.:0.00000
## Max. :1.00000
Pakiet neuralnet wymaga, aby zmienna celu w klasyfikacji
binarnej była zapisana numerycznie jako 0/1. W tym przypadku klasa
tak, czyli osoby zarabiające powyżej 50 tys. USD, została
oznaczona jako 1, a klasa nie jako 0.
salary.y.train.nn <- ifelse(salary.y.train == "tak", 1, 0)
salary.y.test.nn <- ifelse(salary.y.test == "tak", 1, 0)
table(salary.y.train.nn)
## salary.y.train.nn
## 0 1
## 15858 5256
table(salary.y.test.nn)
## salary.y.test.nn
## 0 1
## 6796 2252
Następnie przygotowano końcowe zbiory danych dla sieci neuronowych.
Nazwy kolumn zostały dodatkowo oczyszczone funkcją
make.names(), ponieważ pakiet neuralnet może
mieć problem z niektórymi znakami specjalnymi w nazwach zmiennych.
salary.train.nn <- as.data.frame(salary.x.train.scaled)
salary.train.nn$y <- salary.y.train.nn
salary.test.nn <- as.data.frame(salary.x.test.scaled)
salary.test.nn$y <- salary.y.test.nn
names(salary.train.nn) <- make.names(names(salary.train.nn), unique = TRUE)
names(salary.test.nn) <- make.names(names(salary.test.nn), unique = TRUE)
glimpse(salary.train.nn)
## Rows: 21,114
## Columns: 100
## $ wiek <dbl> 0.30136986, 0.4…
## $ stan_cywilnymałżeństwo_z_osobą_z_sił_zbrojnych <dbl> 0, 0, 0, 0, 0, …
## $ stan_cywilnymałżonek_nieobecny <dbl> 0, 0, 0, 0, 0, …
## $ stan_cywilnynigdy_niezamężny_niezonaty <dbl> 1, 0, 0, 0, 0, …
## $ stan_cywilnyrozwiedziony_rozwiedziona <dbl> 0, 0, 0, 0, 0, …
## $ stan_cywilnyw_separacji <dbl> 0, 0, 0, 0, 0, …
## $ stan_cywilnywdowiec_wdowa <dbl> 0, 0, 0, 0, 0, …
## $ zawodinne_usługi <dbl> 0, 0, 0, 0, 0, …
## $ zawodkadra_kierownicza <dbl> 0, 1, 0, 0, 1, …
## $ zawodoperatorzy_maszyn_i_inspektorzy <dbl> 0, 0, 0, 0, 0, …
## $ zawodpracownicy_fizyczni_i_sprzątający <dbl> 0, 0, 1, 0, 0, …
## $ zawodprywatne_usługi_domowe <dbl> 0, 0, 0, 0, 0, …
## $ zawodrolnictwo_i_rybołówstwo <dbl> 0, 0, 0, 0, 0, …
## $ zawodrzemiosło_i_naprawy <dbl> 0, 0, 0, 0, 0, …
## $ zawodsiły_zbrojne <dbl> 0, 0, 0, 0, 0, …
## $ zawodspecjaliści <dbl> 0, 0, 0, 1, 0, …
## $ zawodsprzedaż <dbl> 0, 0, 0, 0, 0, …
## $ zawodtransport_i_przemieszczanie <dbl> 0, 0, 0, 0, 0, …
## $ zawodusługi_ochrony <dbl> 0, 0, 0, 0, 0, …
## $ zawodwsparcie_techniczne <dbl> 0, 0, 0, 0, 0, …
## $ relacja_w_rodziniemąż <dbl> 0, 1, 1, 0, 0, …
## $ relacja_w_rodzinieniezamężny_niezonaty <dbl> 0, 0, 0, 0, 0, …
## $ relacja_w_rodziniepoza_rodziną <dbl> 1, 0, 0, 0, 0, …
## $ relacja_w_rodziniewłasne_dziecko <dbl> 0, 0, 0, 0, 0, …
## $ relacja_w_rodzinieżona <dbl> 0, 0, 0, 1, 1, …
## $ rasabiała <dbl> 1, 1, 0, 0, 1, …
## $ rasaczarna <dbl> 0, 0, 1, 1, 0, …
## $ rasainna <dbl> 0, 0, 0, 0, 0, …
## $ rasardzenni_amerykanie_lub_eskimosi <dbl> 0, 0, 0, 0, 0, …
## $ plecmężczyzna <dbl> 1, 1, 1, 0, 0, …
## $ godziny_pracy_tygodniowo <dbl> 0.3979592, 0.12…
## $ zysk_kapitalowy <dbl> 0.02174022, 0.0…
## $ strata_kapitalowa <dbl> 0.0000000, 0.00…
## $ typ_zatrudnieniaadministracja_lokalna <dbl> 0, 0, 0, 0, 0, …
## $ typ_zatrudnieniaadministracja_stanowa <dbl> 1, 0, 0, 0, 0, …
## $ typ_zatrudnieniabez_wynagrodzenia <dbl> 0, 0, 0, 0, 0, …
## $ typ_zatrudnieniasamozatrudniony_bez_osobowości_prawnej <dbl> 0, 1, 0, 0, 0, …
## $ typ_zatrudnieniasamozatrudniony_z_osobowością_prawną <dbl> 0, 0, 0, 0, 0, …
## $ typ_zatrudnieniasektor_prywatny <dbl> 0, 0, 1, 1, 1, …
## $ kraj_pochodzeniaChiny <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaDominikana <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaEkwador <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaFilipiny <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaFrancja <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaGrecja <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaGwatemala <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaHaiti <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaHolandia <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaHonduras <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaHongkong <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaIndie <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaIran <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaIrlandia <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaJamajka <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaJaponia <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaJugosławia <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaKambodża <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaKanada <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaKolumbia <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaKorea_Południowa <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaKuba <dbl> 0, 0, 0, 1, 0, …
## $ kraj_pochodzeniaLaos <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaMeksyk <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaNicaragua <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaNiemcy <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaPeru <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaPolska <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaPortoryko <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaPortugalia <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaSalwador <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaStany_Zjednoczone <dbl> 1, 1, 1, 0, 1, …
## $ kraj_pochodzeniaSzkocja <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaTajlandia <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaTajwan <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaTerytoria_Zależne_USA <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaTrynidad_i_Tobago <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaWęgry <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaWietnam <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaWłochy <dbl> 0, 0, 0, 0, 0, …
## $ kapital_netto <dbl> 0.06257486, 0.0…
## $ czy_kapital_netto_dodatnikapital_netto_niedodatni <dbl> 0, 1, 1, 1, 1, …
## $ godziny_razy_edukacja <dbl> 0.32572877, 0.1…
## $ wiek_sredniw_grupie_wieku_średniego <dbl> 1, 1, 1, 0, 1, …
## $ dlugie_godziny_pracytak <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorklasy_1_4 <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorklasy_5_6 <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorklasy_7_8 <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorklasa_9 <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorklasa_10 <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorklasa_11 <dbl> 0, 0, 1, 0, 0, …
## $ poziom_edukacji_factorklasa_12 <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorszkoła_średnia <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorczęść_studiów <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorassociate_zawodowe <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorassociate_akademickie <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorlicencjat <dbl> 1, 1, 0, 1, 0, …
## $ poziom_edukacji_factormagister <dbl> 0, 0, 0, 0, 1, …
## $ poziom_edukacji_factorszkoła_profesjonalna <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factordoktorat <dbl> 0, 0, 0, 0, 0, …
## $ y <dbl> 0, 0, 0, 0, 0, …
glimpse(salary.test.nn)
## Rows: 9,048
## Columns: 100
## $ wiek <dbl> 0.28767123, 0.4…
## $ stan_cywilnymałżeństwo_z_osobą_z_sił_zbrojnych <dbl> 0, 0, 0, 0, 0, …
## $ stan_cywilnymałżonek_nieobecny <dbl> 0, 0, 0, 0, 0, …
## $ stan_cywilnynigdy_niezamężny_niezonaty <dbl> 0, 0, 1, 1, 0, …
## $ stan_cywilnyrozwiedziony_rozwiedziona <dbl> 1, 0, 0, 0, 1, …
## $ stan_cywilnyw_separacji <dbl> 0, 0, 0, 0, 0, …
## $ stan_cywilnywdowiec_wdowa <dbl> 0, 0, 0, 0, 0, …
## $ zawodinne_usługi <dbl> 0, 0, 0, 0, 0, …
## $ zawodkadra_kierownicza <dbl> 0, 1, 0, 0, 1, …
## $ zawodoperatorzy_maszyn_i_inspektorzy <dbl> 0, 0, 0, 1, 0, …
## $ zawodpracownicy_fizyczni_i_sprzątający <dbl> 1, 0, 0, 0, 0, …
## $ zawodprywatne_usługi_domowe <dbl> 0, 0, 0, 0, 0, …
## $ zawodrolnictwo_i_rybołówstwo <dbl> 0, 0, 0, 0, 0, …
## $ zawodrzemiosło_i_naprawy <dbl> 0, 0, 0, 0, 0, …
## $ zawodsiły_zbrojne <dbl> 0, 0, 0, 0, 0, …
## $ zawodspecjaliści <dbl> 0, 0, 0, 0, 0, …
## $ zawodsprzedaż <dbl> 0, 0, 0, 0, 0, …
## $ zawodtransport_i_przemieszczanie <dbl> 0, 0, 0, 0, 0, …
## $ zawodusługi_ochrony <dbl> 0, 0, 0, 0, 0, …
## $ zawodwsparcie_techniczne <dbl> 0, 0, 0, 0, 0, …
## $ relacja_w_rodziniemąż <dbl> 0, 1, 0, 0, 0, …
## $ relacja_w_rodzinieniezamężny_niezonaty <dbl> 0, 0, 0, 1, 1, …
## $ relacja_w_rodziniepoza_rodziną <dbl> 1, 0, 0, 0, 0, …
## $ relacja_w_rodziniewłasne_dziecko <dbl> 0, 0, 1, 0, 0, …
## $ relacja_w_rodzinieżona <dbl> 0, 0, 0, 0, 0, …
## $ rasabiała <dbl> 1, 1, 1, 1, 1, …
## $ rasaczarna <dbl> 0, 0, 0, 0, 0, …
## $ rasainna <dbl> 0, 0, 0, 0, 0, …
## $ rasardzenni_amerykanie_lub_eskimosi <dbl> 0, 0, 0, 0, 0, …
## $ plecmężczyzna <dbl> 1, 1, 0, 1, 0, …
## $ godziny_pracy_tygodniowo <dbl> 0.3979592, 0.44…
## $ zysk_kapitalowy <dbl> 0.0000000, 0.00…
## $ strata_kapitalowa <dbl> 0, 0, 0, 0, 0, …
## $ typ_zatrudnieniaadministracja_lokalna <dbl> 0, 0, 0, 0, 0, …
## $ typ_zatrudnieniaadministracja_stanowa <dbl> 0, 0, 0, 0, 0, …
## $ typ_zatrudnieniabez_wynagrodzenia <dbl> 0, 0, 0, 0, 0, …
## $ typ_zatrudnieniasamozatrudniony_bez_osobowości_prawnej <dbl> 0, 1, 0, 0, 1, …
## $ typ_zatrudnieniasamozatrudniony_z_osobowością_prawną <dbl> 0, 0, 0, 0, 0, …
## $ typ_zatrudnieniasektor_prywatny <dbl> 1, 0, 1, 1, 0, …
## $ kraj_pochodzeniaChiny <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaDominikana <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaEkwador <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaFilipiny <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaFrancja <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaGrecja <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaGwatemala <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaHaiti <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaHolandia <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaHonduras <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaHongkong <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaIndie <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaIran <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaIrlandia <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaJamajka <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaJaponia <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaJugosławia <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaKambodża <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaKanada <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaKolumbia <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaKorea_Południowa <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaKuba <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaLaos <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaMeksyk <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaNicaragua <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaNiemcy <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaPeru <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaPolska <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaPortoryko <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaPortugalia <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaSalwador <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaStany_Zjednoczone <dbl> 1, 1, 1, 1, 1, …
## $ kraj_pochodzeniaSzkocja <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaTajlandia <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaTajwan <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaTerytoria_Zależne_USA <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaTrynidad_i_Tobago <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaWęgry <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaWietnam <dbl> 0, 0, 0, 0, 0, …
## $ kraj_pochodzeniaWłochy <dbl> 0, 0, 0, 0, 0, …
## $ kapital_netto <dbl> 0.04174213, 0.0…
## $ czy_kapital_netto_dodatnikapital_netto_niedodatni <dbl> 1, 1, 1, 1, 1, …
## $ godziny_razy_edukacja <dbl> 0.2243346, 0.25…
## $ wiek_sredniw_grupie_wieku_średniego <dbl> 1, 1, 0, 1, 1, …
## $ dlugie_godziny_pracytak <dbl> 0, 1, 0, 0, 1, …
## $ poziom_edukacji_factorklasy_1_4 <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorklasy_5_6 <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorklasy_7_8 <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorklasa_9 <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorklasa_10 <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorklasa_11 <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorklasa_12 <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorszkoła_średnia <dbl> 1, 1, 0, 1, 0, …
## $ poziom_edukacji_factorczęść_studiów <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorassociate_zawodowe <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorassociate_akademickie <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factorlicencjat <dbl> 0, 0, 1, 0, 0, …
## $ poziom_edukacji_factormagister <dbl> 0, 0, 0, 0, 1, …
## $ poziom_edukacji_factorszkoła_profesjonalna <dbl> 0, 0, 0, 0, 0, …
## $ poziom_edukacji_factordoktorat <dbl> 0, 0, 0, 0, 0, …
## $ y <dbl> 0, 1, 0, 0, 1, …
# Opcjonalny zapis przygotowanych danych
save(
salary.train.nn,
salary.test.nn,
file = "02_dane_w_SML-20260311/salary_train_test.RData"
)
Ze względu na czasochłonność uczenia sieci neuronowych w pakiecie neuralnet, modele zostały wytrenowane na losowym, stratyfikowanym podzbiorze zbioru treningowego liczącym maksymalnie 5000 obserwacji. Należy jednak podkreślić, że końcowa ocena modeli została przeprowadzona na pełnym zbiorze testowym, dzięki czemu możliwa jest ocena zdolności generalizacji na danych niewykorzystanych podczas uczenia.
set.seed(123)
idx_salary_nn <- createDataPartition(
salary.train.nn$y,
p = min(5000 / nrow(salary.train.nn), 1),
list = FALSE
)
salary.train.small.nn <- salary.train.nn[idx_salary_nn, ]
dim(salary.train.small.nn)
## [1] 5000 100
table(salary.train.small.nn$y)
##
## 0 1
## 3749 1251
prop.table(table(salary.train.small.nn$y))
##
## 0 1
## 0.7498 0.2502
Architektura sieci neuronowej, czyli liczba warstw ukrytych i liczba neuronów w każdej z nich, jest hiperparametrem. Większa liczba neuronów zwiększa elastyczność modelu, ale jednocześnie zwiększa liczbę wag do oszacowania, czas uczenia oraz ryzyko przeuczenia.
W celu porównania zastosowano trzy proste architektury:
NN_3 — jedna warstwa ukryta z 3 neuronami,NN_5 — jedna warstwa ukryta z 5 neuronami,NN_3_2 — dwie warstwy ukryte: pierwsza z 3 neuronami,
druga z 2 neuronami.Ponadto co ważne, w zastosowano logistyczną funkcję aktywacji, czyli funkcję sigmoidalną. Jest ona odpowiednia dla problemu klasyfikacji binarnej, ponieważ przekształca wynik modelu do przedziału od 0 do 1. Dzięki temu predykcję sieci można interpretować jako prawdopodobieństwo przynależności obserwacji do klasy pozytywnej, czyli osoby zarabiającej powyżej 50 tys. USD.
count_weights <- function(n_inputs, hidden, n_outputs = 1) {
layers <- c(n_inputs, hidden, n_outputs)
total <- 0
for (i in 1:(length(layers) - 1)) {
total <- total + (layers[i] + 1) * layers[i + 1]
}
return(total)
}
salary.predictors <- setdiff(names(salary.train.small.nn), "y")
n_inputs_salary <- length(salary.predictors)
n_inputs_salary
## [1] 99
salary_architectures <- list(
NN_3 = c(3),
NN_5 = c(5),
NN_3_2 = c(3, 2)
)
for (name in names(salary_architectures)) {
arch <- salary_architectures[[name]]
cat(
"Architektura:", name,
"| hidden =", paste(arch, collapse = "-"),
"| liczba wag:", count_weights(n_inputs_salary, arch, n_outputs = 1),
"\n"
)
}
## Architektura: NN_3 | hidden = 3 | liczba wag: 304
## Architektura: NN_5 | hidden = 5 | liczba wag: 506
## Architektura: NN_3_2 | hidden = 3-2 | liczba wag: 311
Na potrzeby funkcji neuralnet() zbudowano formułę modelu
automatycznie na podstawie nazw predyktorów.
salary.formula <- as.formula(
paste("y ~", paste(salary.predictors, collapse = " + "))
)
salary.formula
## y ~ wiek + stan_cywilnymałżeństwo_z_osobą_z_sił_zbrojnych +
## stan_cywilnymałżonek_nieobecny + stan_cywilnynigdy_niezamężny_niezonaty +
## stan_cywilnyrozwiedziony_rozwiedziona + stan_cywilnyw_separacji +
## stan_cywilnywdowiec_wdowa + zawodinne_usługi + zawodkadra_kierownicza +
## zawodoperatorzy_maszyn_i_inspektorzy + zawodpracownicy_fizyczni_i_sprzątający +
## zawodprywatne_usługi_domowe + zawodrolnictwo_i_rybołówstwo +
## zawodrzemiosło_i_naprawy + zawodsiły_zbrojne + zawodspecjaliści +
## zawodsprzedaż + zawodtransport_i_przemieszczanie + zawodusługi_ochrony +
## zawodwsparcie_techniczne + relacja_w_rodziniemąż + relacja_w_rodzinieniezamężny_niezonaty +
## relacja_w_rodziniepoza_rodziną + relacja_w_rodziniewłasne_dziecko +
## relacja_w_rodzinieżona + rasabiała + rasaczarna + rasainna +
## rasardzenni_amerykanie_lub_eskimosi + plecmężczyzna + godziny_pracy_tygodniowo +
## zysk_kapitalowy + strata_kapitalowa + typ_zatrudnieniaadministracja_lokalna +
## typ_zatrudnieniaadministracja_stanowa + typ_zatrudnieniabez_wynagrodzenia +
## typ_zatrudnieniasamozatrudniony_bez_osobowości_prawnej +
## typ_zatrudnieniasamozatrudniony_z_osobowością_prawną +
## typ_zatrudnieniasektor_prywatny + kraj_pochodzeniaChiny +
## kraj_pochodzeniaDominikana + kraj_pochodzeniaEkwador + kraj_pochodzeniaFilipiny +
## kraj_pochodzeniaFrancja + kraj_pochodzeniaGrecja + kraj_pochodzeniaGwatemala +
## kraj_pochodzeniaHaiti + kraj_pochodzeniaHolandia + kraj_pochodzeniaHonduras +
## kraj_pochodzeniaHongkong + kraj_pochodzeniaIndie + kraj_pochodzeniaIran +
## kraj_pochodzeniaIrlandia + kraj_pochodzeniaJamajka + kraj_pochodzeniaJaponia +
## kraj_pochodzeniaJugosławia + kraj_pochodzeniaKambodża +
## kraj_pochodzeniaKanada + kraj_pochodzeniaKolumbia + kraj_pochodzeniaKorea_Południowa +
## kraj_pochodzeniaKuba + kraj_pochodzeniaLaos + kraj_pochodzeniaMeksyk +
## kraj_pochodzeniaNicaragua + kraj_pochodzeniaNiemcy + kraj_pochodzeniaPeru +
## kraj_pochodzeniaPolska + kraj_pochodzeniaPortoryko + kraj_pochodzeniaPortugalia +
## kraj_pochodzeniaSalwador + kraj_pochodzeniaStany_Zjednoczone +
## kraj_pochodzeniaSzkocja + kraj_pochodzeniaTajlandia + kraj_pochodzeniaTajwan +
## kraj_pochodzeniaTerytoria_Zależne_USA + kraj_pochodzeniaTrynidad_i_Tobago +
## kraj_pochodzeniaWęgry + kraj_pochodzeniaWietnam + kraj_pochodzeniaWłochy +
## kapital_netto + czy_kapital_netto_dodatnikapital_netto_niedodatni +
## godziny_razy_edukacja + wiek_sredniw_grupie_wieku_średniego +
## dlugie_godziny_pracytak + poziom_edukacji_factorklasy_1_4 +
## poziom_edukacji_factorklasy_5_6 + poziom_edukacji_factorklasy_7_8 +
## poziom_edukacji_factorklasa_9 + poziom_edukacji_factorklasa_10 +
## poziom_edukacji_factorklasa_11 + poziom_edukacji_factorklasa_12 +
## poziom_edukacji_factorszkoła_średnia + poziom_edukacji_factorczęść_studiów +
## poziom_edukacji_factorassociate_zawodowe + poziom_edukacji_factorassociate_akademickie +
## poziom_edukacji_factorlicencjat + poziom_edukacji_factormagister +
## poziom_edukacji_factorszkoła_profesjonalna + poziom_edukacji_factordoktorat
Pierwszy model ma jedną warstwę ukrytą z 3 neuronami. Dla
klasyfikacji binarnej ustawiono linear.output = FALSE,
dzięki czemu wynik sieci można interpretować jako wartość zbliżoną do
prawdopodobieństwa przynależności do klasy pozytywnej.
# set.seed(123)
#
# nn_salary_1 <- neuralnet(
# formula = salary.formula,
# data = salary.train.small.nn,
# hidden = c(3),
# linear.output = FALSE,
# act.fct = "logistic",
# threshold = 0.05,
# stepmax = 1e6,
# rep = 1,
# lifesign = "minimal"
# )
#
# saveRDS(nn_salary_1, "trenowane_modele/salary.nn1.rds")
# Wczytanie wcześniej zapisanego modelu
nn_salary_1 <- readRDS("trenowane_modele/salary.nn1.rds")
Dla pierwszego modelu przygotowano również wizualizację sieci. Przy dużej liczbie zmiennych wejściowych wykres może być mniej czytelny, ale pozwala zobaczyć ogólną strukturę połączeń między warstwami.
#Zablokowano ze względu na bardzo długi okres kompilacji
#plotnet(
# nn_salary_1,
# circle_col = list("lightyellow", "lightblue"),
# bord_col = "grey40",
# pos_col = "darkgreen",
# neg_col = "red3",
# circle_cex = 2,
# cex_val = 0.5,
# alpha_val = 0.8,
# max_sp = TRUE,
# pad_x = 0.9
#)
Predykcje zostały wyznaczone zarówno dla zmniejszonego zbioru treningowego, na którym model był uczony, jak i dla pełnego zbioru testowego.
pred_salary_nn1_train <- compute(
nn_salary_1,
salary.train.small.nn[, salary.predictors]
)$net.result
pred_salary_nn1_test <- compute(
nn_salary_1,
salary.test.nn[, salary.predictors]
)$net.result
pred_salary_nn1_train <- as.numeric(pred_salary_nn1_train)
pred_salary_nn1_test <- as.numeric(pred_salary_nn1_test)
summary(pred_salary_nn1_train)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 0.00004445 0.00214147 0.04628933 0.24439997 0.42383922 0.96674992
summary(pred_salary_nn1_test)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 0.00004445 0.00209347 0.04403567 0.24352639 0.42018547 0.96674992
Ze względu na niezbalansowanie klas nie zastosowano domyślnego progu 0.5. Jako prosty próg odcięcia przyjęto udział klasy pozytywnej w zmniejszonym zbiorze treningowym.
cutoff.nn <- mean(salary.train.small.nn$y)
cutoff.nn
## [1] 0.2502
metryki.salary.nn1.train <- klasyfikacja.metryki(
predicted_probabilities = pred_salary_nn1_train,
real = as.character(salary.train.small.nn$y),
cutoff = cutoff.nn,
level_positive = "1",
level_negative = "0"
)
metryki.salary.nn1.test <- klasyfikacja.metryki(
predicted_probabilities = pred_salary_nn1_test,
real = as.character(salary.test.nn$y),
cutoff = cutoff.nn,
level_positive = "1",
level_negative = "0"
)
rbind(
nn1.train = metryki.salary.nn1.train,
nn1.test = metryki.salary.nn1.test
)
## Accuracy Sensitivity Specificity BalancedAccuracy PPV NPV
## nn1.train 0.8490 0.8753 0.8402 0.8578 0.6464 0.9528
## nn1.test 0.8031 0.7909 0.8071 0.7990 0.5760 0.9209
## F1 Kappa
## nn1.train 0.7436 0.6400
## nn1.test 0.6665 0.5316
Drugi model również posiada jedną warstwę ukrytą, ale z większą liczbą neuronów. Celem jest sprawdzenie, czy nieco bardziej elastyczna sieć poprawi jakość klasyfikacji.
# Model NN2: jedna warstwa ukryta z 5 neuronami
#
# Trenowanie modelu zostało wyhasztagowane, aby raport nie budował sieci od nowa.
# Model został wcześniej zapisany do pliku RDS.
# set.seed(123)
#
# nn_salary_2 <- neuralnet(
# formula = salary.formula,
# data = salary.train.small.nn,
# hidden = c(5),
# linear.output = FALSE,
# act.fct = "logistic",
# threshold = 0.05,
# stepmax = 1e6,
# rep = 1,
# lifesign = "minimal"
# )
#
# saveRDS(nn_salary_2, "trenowane_modele/salary.nn2.rds")
# Wczytanie wcześniej zapisanego modelu
nn_salary_2 <- readRDS("trenowane_modele/salary.nn2.rds")
#nn_salary_2
Trzeci model ma dwie warstwy ukryte. Pierwsza warstwa zawiera 3 neurony, a druga 2 neurony. Taka architektura pozwala sprawdzić, czy dodanie drugiej warstwy poprawi zdolność modelu do uchwycenia złożonych zależności.
# Model NN3: dwie warstwy ukryte, 3 i 2 neurony
#
# Trenowanie modelu zostało wyhasztagowane, aby raport nie budował sieci od nowa.
# Model został wcześniej zapisany do pliku RDS.
# set.seed(123)
#
# nn_salary_3 <- neuralnet(
# formula = salary.formula,
# data = salary.train.small.nn,
# hidden = c(3, 2),
# linear.output = FALSE,
# act.fct = "logistic",
# threshold = 0.05,
# stepmax = 1e6,
# rep = 1,
# lifesign = "minimal"
# )
#
# saveRDS(nn_salary_3, "trenowane_modele/salary.nn3.rds")
# Wczytanie wcześniej zapisanego modelu
nn_salary_3 <- readRDS("trenowane_modele/salary.nn3.rds")
#nn_salary_3
pred_salary_nn2_train <- compute(
nn_salary_2,
salary.train.small.nn[, salary.predictors]
)$net.result
pred_salary_nn2_test <- compute(
nn_salary_2,
salary.test.nn[, salary.predictors]
)$net.result
pred_salary_nn3_train <- compute(
nn_salary_3,
salary.train.small.nn[, salary.predictors]
)$net.result
pred_salary_nn3_test <- compute(
nn_salary_3,
salary.test.nn[, salary.predictors]
)$net.result
pred_salary_nn2_train <- as.numeric(pred_salary_nn2_train)
pred_salary_nn2_test <- as.numeric(pred_salary_nn2_test)
pred_salary_nn3_train <- as.numeric(pred_salary_nn3_train)
pred_salary_nn3_test <- as.numeric(pred_salary_nn3_test)
Poniżej zestawiono wyniki trzech modeli sieci neuronowych. Porównanie obejmuje osobno wyniki na zbiorze treningowym i testowym.
metryki.salary.nn2.train <- klasyfikacja.metryki(
predicted_probabilities = pred_salary_nn2_train,
real = as.character(salary.train.small.nn$y),
cutoff = cutoff.nn,
level_positive = "1",
level_negative = "0"
)
metryki.salary.nn2.test <- klasyfikacja.metryki(
predicted_probabilities = pred_salary_nn2_test,
real = as.character(salary.test.nn$y),
cutoff = cutoff.nn,
level_positive = "1",
level_negative = "0"
)
metryki.salary.nn3.train <- klasyfikacja.metryki(
predicted_probabilities = pred_salary_nn3_train,
real = as.character(salary.train.small.nn$y),
cutoff = cutoff.nn,
level_positive = "1",
level_negative = "0"
)
metryki.salary.nn3.test <- klasyfikacja.metryki(
predicted_probabilities = pred_salary_nn3_test,
real = as.character(salary.test.nn$y),
cutoff = cutoff.nn,
level_positive = "1",
level_negative = "0"
)
rbind(
nn.3 = metryki.salary.nn1.train,
nn.5 = metryki.salary.nn2.train,
nn.3.2 = metryki.salary.nn3.train
)
## Accuracy Sensitivity Specificity BalancedAccuracy PPV NPV F1
## nn.3 0.8490 0.8753 0.8402 0.8578 0.6464 0.9528 0.7436
## nn.5 0.8892 0.8729 0.8946 0.8838 0.7344 0.9547 0.7977
## nn.3.2 0.8498 0.8649 0.8448 0.8548 0.6502 0.9493 0.7424
## Kappa
## nn.3 0.6400
## nn.5 0.7222
## nn.3.2 0.6393
rbind(
nn.3 = metryki.salary.nn1.test,
nn.5 = metryki.salary.nn2.test,
nn.3.2 = metryki.salary.nn3.test
)
## Accuracy Sensitivity Specificity BalancedAccuracy PPV NPV F1
## nn.3 0.8031 0.7909 0.8071 0.7990 0.5760 0.9209 0.6665
## nn.5 0.7993 0.6932 0.8345 0.7638 0.5812 0.8914 0.6322
## nn.3.2 0.8031 0.7815 0.8102 0.7959 0.5770 0.9180 0.6639
## Kappa
## nn.3 0.5316
## nn.5 0.4957
## nn.3.2 0.5290
Na podstawie uzyskanych wyników można porównać wpływ architektury
sieci na jakość klasyfikacji. Warto zwrócić uwagę nie tylko na
Accuracy, ale również na Balanced Accuracy,
Sensitivity, Specificity oraz
F1.
W przypadku danych niezbalansowanych zwykła trafność klasyfikacji
może być myląca. Model może osiągać wysoką wartość
Accuracy, ponieważ dobrze rozpoznaje klasę większościową,
ale jednocześnie słabo wykrywać klasę pozytywną. Dlatego większe
znaczenie ma Balanced Accuracy, która uwzględnia zarówno
czułość, jak i specyficzność. Jednak widać na modelach zjawisko
przeuczenia dla modelu np 1 - 87% do *80** jeśli chodzi
o Sensivity, jednak ze względu na i tak już długi czas kompilowania dla
5000 danych zdecydowano się na pozostawienie takiej ilości.
Dodatkowo dla każdej sieci neuronowej wyznaczono krzywą ROC oraz indeks Giniego. Krzywa ROC pozwala ocenić zdolność modelu do rozróżniania klas niezależnie od konkretnego progu odcięcia.
ROC.salary.nn1.train <- pROC::roc(
as.numeric(salary.train.small.nn$y == 1),
pred_salary_nn1_train
)
ROC.salary.nn1.test <- pROC::roc(
as.numeric(salary.test.nn$y == 1),
pred_salary_nn1_test
)
ROC.salary.nn2.train <- pROC::roc(
as.numeric(salary.train.small.nn$y == 1),
pred_salary_nn2_train
)
ROC.salary.nn2.test <- pROC::roc(
as.numeric(salary.test.nn$y == 1),
pred_salary_nn2_test
)
ROC.salary.nn3.train <- pROC::roc(
as.numeric(salary.train.small.nn$y == 1),
pred_salary_nn3_train
)
ROC.salary.nn3.test <- pROC::roc(
as.numeric(salary.test.nn$y == 1),
pred_salary_nn3_test
)
ROC.salary.nn1.test$auc
## Area under the curve: 0.8799
ROC.salary.nn2.test$auc
## Area under the curve: 0.8497
ROC.salary.nn3.test$auc
## Area under the curve: 0.8647
gini <- function(roc) {
paste0(round(100 * (2 * as.numeric(pROC::auc(roc)) - 1), 1), "%")
}
list(
nn1.train = ROC.salary.nn1.train,
nn1.test = ROC.salary.nn1.test,
nn2.train = ROC.salary.nn2.train,
nn2.test = ROC.salary.nn2.test,
nn3.train = ROC.salary.nn3.train,
nn3.test = ROC.salary.nn3.test
) %>%
ggroc(alpha = 0.7, linewidth = 1.1) +
geom_segment(
aes(x = 1, xend = 0, y = 0, yend = 1),
color = "grey60",
linetype = "dashed"
) +
labs(
title = paste0(
"Gini TEST: ",
"NN1 (3) = ", gini(ROC.salary.nn1.test), ", ",
"NN2 (5) = ", gini(ROC.salary.nn2.test), ", ",
"NN3 (3,2) = ", gini(ROC.salary.nn3.test)
),
subtitle = paste0(
"Gini TRAIN: ",
"NN1 (3) = ", gini(ROC.salary.nn1.train), ", ",
"NN2 (5) = ", gini(ROC.salary.nn2.train), ", ",
"NN3 (3,2) = ", gini(ROC.salary.nn3.train)
),
color = "Model",
x = "Specyficzność",
y = "Wrażliwość"
) +
theme_bw() +
coord_fixed() +
scale_color_brewer(palette = "Paired")
Sieci neuronowe zostały wytrenowane na zmniejszonym zbiorze treningowym, dlatego ich wyniki należy interpretować ostrożnie. Z jednej strony ograniczenie liczby obserwacji znacząco skraca czas uczenia, ale z drugiej strony może pogarszać zdolność generalizacji modelu.
W porównaniu trzech architektur można zauważyć, że większa liczba neuronów lub dodatkowa warstwa ukryta nie musi automatycznie prowadzić do lepszych wyników. Bardziej złożony model ma więcej wag do oszacowania, co może zwiększać ryzyko przeuczenia.
Jeżeli wyniki na zbiorze treningowym są wyraźnie lepsze niż na zbiorze testowym, może to oznaczać overfitting. Jeżeli natomiast wyniki treningowe i testowe są podobne, ale relatywnie słabe, model może być zbyt prosty albo trenowany na zbyt małej próbie.
W tej analizie sieci neuronowe pełnią rolę dodatkowego punktu
odniesienia względem modeli takich jak regresja logistyczna, drzewa
decyzyjne, bagging, Random Forest oraz XGBoost. Ich zaletą jest
możliwość modelowania nieliniowych zależności, jednak przy użyciu
pakietu neuralnet oraz dużej liczbie zmiennych po one-hot
encodingu trening może być czasochłonny.
W ostatnim etapie analizy porównano wszystkie zbudowane modele klasyfikacyjne. Zestawienie obejmuje modele bazowe, drzewa decyzyjne, modele zespołowe, boosting oraz sieci neuronowe.
Porównanie wykonano osobno dla zbioru treningowego oraz testowego. Wyniki na zbiorze treningowym pokazują, jak dobrze model dopasował się do danych uczących, natomiast wyniki na zbiorze testowym są ważniejsze z punktu widzenia oceny jakości predykcyjnej, ponieważ pokazują skuteczność modelu na nowych obserwacjach.
W analizie wykorzystano kilka miar jakości klasyfikacji.
Accuracy oznacza ogólną trafność klasyfikacji.
Sensitivity pokazuje, jak dobrze model wykrywa klasę
pozytywną, czyli osoby zarabiające powyżej 50 tys. USD.
Specificity informuje, jak dobrze model rozpoznaje klasę
negatywną. BalancedAccuracy jest średnią z czułości i
specyficzności, dlatego jest szczególnie ważna przy niezbalansowanych
danych. F1 łączy precyzję oraz czułość i jest użyteczna
przy ocenie jakości klasyfikacji klasy pozytywnej.
Ponieważ w analizowanym zbiorze klasy nie są równoliczne, sama
wartość Accuracy może być myląca. Model może osiągać wysoką
trafność głównie dlatego, że dobrze rozpoznaje klasę większościową. Z
tego powodu w końcowej ocenie szczególną uwagę zwrócono na
BalancedAccuracy, Sensitivity oraz
F1.
################################################################################
# 13. Porównanie wszystkich modeli ----
################################################################################
# W jednej tabeli zbieramy metryki dla wszystkich modeli.
# Tworzymy osobne zestawienie dla zbioru treningowego i testowego.
# Dzięki temu możemy sprawdzić, który model najlepiej działa na danych uczących,
# a który najlepiej generalizuje na dane testowe.
porownanie.wszystkie.train <- rbind(
logit = metryki.glm.train,
decision.tree.default = metryki.tree.default.train,
decision.tree.tuned = metryki.tree.tuned.train,
decision.tree.downsample = metryki.tree.down.train,
bagging = metryki.salary.bag.train,
random.forest = metryki.salary.rf.train,
random.forest.tuned = metryki.salary.rf.tuned.train,
xgb.start = metryki.salary.xgb.start.train,
xgb.tuned = metryki.salary.xgb.tuned.train,
neural.net.3 = metryki.salary.nn1.train,
neural.net.5 = metryki.salary.nn2.train,
neural.net.3.2 = metryki.salary.nn3.train
)
porownanie.wszystkie.test <- rbind(
logit = metryki.glm.test,
decision.tree.default = metryki.tree.default.test,
decision.tree.tuned = metryki.tree.tuned.test,
decision.tree.downsample = metryki.tree.down.test,
bagging = metryki.salary.bag.test,
random.forest = metryki.salary.rf.test,
random.forest.tuned = metryki.salary.rf.tuned.test,
xgb.start = metryki.salary.xgb.start.test,
xgb.tuned = metryki.salary.xgb.tuned.test,
neural.net.3 = metryki.salary.nn1.test,
neural.net.5 = metryki.salary.nn2.test,
neural.net.3.2 = metryki.salary.nn3.test
)
porownanie.wszystkie.train
## Accuracy Sensitivity Specificity BalancedAccuracy
## logit 0.7936 0.8615 0.7711 0.8163
## decision.tree.default 0.8416 0.5099 0.9516 0.7307
## decision.tree.tuned 0.8616 0.6056 0.9465 0.7760
## decision.tree.downsample 0.8394 0.8824 0.7964 0.8394
## bagging 0.9517 0.9812 0.9420 0.9616
## random.forest 0.9299 0.9553 0.9215 0.9384
## random.forest.tuned 0.8500 0.9488 0.8173 0.8830
## xgb.start 0.7986 0.8727 0.7740 0.8234
## xgb.tuned 0.8286 0.8786 0.8120 0.8453
## neural.net.3 0.8490 0.8753 0.8402 0.8578
## neural.net.5 0.8892 0.8729 0.8946 0.8838
## neural.net.3.2 0.8498 0.8649 0.8448 0.8548
## PPV NPV F1 Kappa
## logit 0.5550 0.9438 0.6751 0.5340
## decision.tree.default 0.7773 0.8542 0.6158 0.5214
## decision.tree.tuned 0.7894 0.8786 0.6854 0.5987
## decision.tree.downsample 0.8125 0.8714 0.8460 0.6788
## bagging 0.8486 0.9934 0.9101 0.8773
## random.forest 0.8013 0.9842 0.8716 0.8239
## random.forest.tuned 0.6325 0.9797 0.7590 0.6563
## xgb.start 0.5614 0.9483 0.6833 0.5456
## xgb.tuned 0.6077 0.9528 0.7185 0.6011
## neural.net.3 0.6464 0.9528 0.7436 0.6400
## neural.net.5 0.7344 0.9547 0.7977 0.7222
## neural.net.3.2 0.6502 0.9493 0.7424 0.6393
porownanie.wszystkie.test
## Accuracy Sensitivity Specificity BalancedAccuracy
## logit 0.7940 0.8686 0.7692 0.8189
## decision.tree.default 0.8397 0.5115 0.9485 0.7300
## decision.tree.tuned 0.8457 0.5799 0.9338 0.7569
## decision.tree.downsample 0.7928 0.8535 0.7727 0.8131
## bagging 0.8015 0.8215 0.7949 0.8082
## random.forest 0.8196 0.8139 0.8215 0.8177
## random.forest.tuned 0.8112 0.8752 0.7900 0.8326
## xgb.start 0.7982 0.8601 0.7777 0.8189
## xgb.tuned 0.8239 0.8583 0.8125 0.8354
## neural.net.3 0.8031 0.7909 0.8071 0.7990
## neural.net.5 0.7993 0.6932 0.8345 0.7638
## neural.net.3.2 0.8031 0.7815 0.8102 0.7959
## PPV NPV F1 Kappa
## logit 0.5552 0.9464 0.6774 0.5366
## decision.tree.default 0.7670 0.8542 0.6137 0.5177
## decision.tree.tuned 0.7437 0.8703 0.6517 0.5545
## decision.tree.downsample 0.5544 0.9409 0.6721 0.5304
## bagging 0.5703 0.9307 0.6732 0.5373
## random.forest 0.6018 0.9302 0.6920 0.5685
## random.forest.tuned 0.5800 0.9503 0.6977 0.5685
## xgb.start 0.5618 0.9438 0.6796 0.5416
## xgb.tuned 0.6027 0.9454 0.7082 0.5876
## neural.net.3 0.5760 0.9209 0.6665 0.5316
## neural.net.5 0.5812 0.8914 0.6322 0.4957
## neural.net.3.2 0.5770 0.9180 0.6639 0.5290
Poniższa tabela przedstawia wyniki wszystkich modeli na zbiorze treningowym.
Na zbiorze treningowym najwyższe wartości większości metryk osiągnął model bagging. Uzyskał on bardzo wysoką wartość Accuracy = 0.9517, BalancedAccuracy = 0.9616, F1 = 0.9101 oraz Kappa = 0.8773. Oznacza to, że model bardzo dobrze dopasował się do danych uczących.
Jeżeli natomiast skupiamy się na predykcji klasy negatywnej, czyli osób niezarabiających powyżej 50 tys. USD, ważniejsza jest Specificity. Najwyższą specyficzność na zbiorze treningowym osiągnął model decision tree default (Specificity = 0.9516), bardzo zbliżony wynik uzyskał również decision tree tuned (Specificity = 0.9465) oraz bagging (Specificity = 0.9420). Oznacza to, że modele drzewiaste bardzo dobrze rozpoznają klasę większościową, czyli osoby z dochodem nieprzekraczającym 50 tys. USD.
Wysokie wartości metryk na tym zbiorze oznaczają dobre dopasowanie modelu do danych uczących. Należy jednak pamiętać, że bardzo dobre wyniki na treningu nie zawsze oznaczają dobrą jakość modelu, ponieważ model może być przeuczony.
knitr::kable(
porownanie.wszystkie.train,
digits = 4,
caption = "Porównanie wszystkich modeli na zbiorze treningowym"
)
| Accuracy | Sensitivity | Specificity | BalancedAccuracy | PPV | NPV | F1 | Kappa | |
|---|---|---|---|---|---|---|---|---|
| logit | 0.7936 | 0.8615 | 0.7711 | 0.8163 | 0.5550 | 0.9438 | 0.6751 | 0.5340 |
| decision.tree.default | 0.8416 | 0.5099 | 0.9516 | 0.7307 | 0.7773 | 0.8542 | 0.6158 | 0.5214 |
| decision.tree.tuned | 0.8616 | 0.6056 | 0.9465 | 0.7760 | 0.7894 | 0.8786 | 0.6854 | 0.5987 |
| decision.tree.downsample | 0.8394 | 0.8824 | 0.7964 | 0.8394 | 0.8125 | 0.8714 | 0.8460 | 0.6788 |
| bagging | 0.9517 | 0.9812 | 0.9420 | 0.9616 | 0.8486 | 0.9934 | 0.9101 | 0.8773 |
| random.forest | 0.9299 | 0.9553 | 0.9215 | 0.9384 | 0.8013 | 0.9842 | 0.8716 | 0.8239 |
| random.forest.tuned | 0.8500 | 0.9488 | 0.8173 | 0.8830 | 0.6325 | 0.9797 | 0.7590 | 0.6563 |
| xgb.start | 0.7986 | 0.8727 | 0.7740 | 0.8234 | 0.5614 | 0.9483 | 0.6833 | 0.5456 |
| xgb.tuned | 0.8286 | 0.8786 | 0.8120 | 0.8453 | 0.6077 | 0.9528 | 0.7185 | 0.6011 |
| neural.net.3 | 0.8490 | 0.8753 | 0.8402 | 0.8578 | 0.6464 | 0.9528 | 0.7436 | 0.6400 |
| neural.net.5 | 0.8892 | 0.8729 | 0.8946 | 0.8838 | 0.7344 | 0.9547 | 0.7977 | 0.7222 |
| neural.net.3.2 | 0.8498 | 0.8649 | 0.8448 | 0.8548 | 0.6502 | 0.9493 | 0.7424 | 0.6393 |
Na zbiorze testowym wyniki są bardziej wyrównane. Najwyższą wartość BalancedAccuracy osiągnął model xgb.tuned (BalancedAccuracy = 0.8354). Bardzo dobry wynik uzyskał również random.forest.tuned (BalancedAccuracy = 0.8326). Oznacza to, że po dostrojeniu modele bardziej zaawansowane, szczególnie XGBoost i Random Forest, lepiej generalizują na nowe dane niż modele bardzo mocno dopasowane do zbioru treningowego.
Warto zauważyć, że bagging, który był zdecydowanie najlepszy na zbiorze treningowym, na zbiorze testowym osiągnął BalancedAccuracy = 0.8082. Jest to wynik dobry, ale wyraźnie niższy niż na treningu. Różnica między wynikami treningowymi i testowymi sugeruje, że bagging mógł częściowo przeuczyć się na danych uczących.
Jeżeli koncentrujemy się na wykrywaniu klasy pozytywnej, czyli osób zarabiających powyżej 50 tys. USD, najważniejsza jest Sensitivity. Na zbiorze testowym najwyższą czułość osiągnął random.forest.tuned (Sensitivity = 0.8752). Bardzo blisko są również logit (Sensitivity = 0.8686), xgb.start (Sensitivity = 0.8601) oraz xgb.tuned (Sensitivity = 0.8583). Oznacza to, że jeśli celem analizy byłoby wykrycie jak największej liczby osób z wysokim dochodem, najlepszym wyborem byłby random.forest.tuned.
Jeżeli skupiamy się na predykcji klasy negatywnej, ważniejsza jest Specificity. Na zbiorze testowym najwyższą specyficzność osiągnął decision tree default (Specificity = 0.9485), a następnie decision tree tuned (Specificity = 0.9338). Oznacza to, że drzewa decyzyjne najlepiej rozpoznają osoby, które nie zarabiają powyżej 50 tys. USD. Trzeba jednak zauważyć, że robią to kosztem dużo niższej czułości, czyli słabiej wykrywają klasę pozytywną.
Jeżeli patrzymy na F1, czyli miarę łączącą precyzję i czułość dla klasy pozytywnej, najlepszy wynik na zbiorze testowym osiągnął xgb.tuned (F1 = 0.7082). Jest to ważna informacja, ponieważ oznacza, że model XGBoost po strojeniu dobrze równoważy wykrywanie klasy pozytywnej oraz jakość tych wskazań.
knitr::kable(
porownanie.wszystkie.test,
digits = 4,
caption = "Porównanie wszystkich modeli na zbiorze testowym"
)
| Accuracy | Sensitivity | Specificity | BalancedAccuracy | PPV | NPV | F1 | Kappa | |
|---|---|---|---|---|---|---|---|---|
| logit | 0.7940 | 0.8686 | 0.7692 | 0.8189 | 0.5552 | 0.9464 | 0.6774 | 0.5366 |
| decision.tree.default | 0.8397 | 0.5115 | 0.9485 | 0.7300 | 0.7670 | 0.8542 | 0.6137 | 0.5177 |
| decision.tree.tuned | 0.8457 | 0.5799 | 0.9338 | 0.7569 | 0.7437 | 0.8703 | 0.6517 | 0.5545 |
| decision.tree.downsample | 0.7928 | 0.8535 | 0.7727 | 0.8131 | 0.5544 | 0.9409 | 0.6721 | 0.5304 |
| bagging | 0.8015 | 0.8215 | 0.7949 | 0.8082 | 0.5703 | 0.9307 | 0.6732 | 0.5373 |
| random.forest | 0.8196 | 0.8139 | 0.8215 | 0.8177 | 0.6018 | 0.9302 | 0.6920 | 0.5685 |
| random.forest.tuned | 0.8112 | 0.8752 | 0.7900 | 0.8326 | 0.5800 | 0.9503 | 0.6977 | 0.5685 |
| xgb.start | 0.7982 | 0.8601 | 0.7777 | 0.8189 | 0.5618 | 0.9438 | 0.6796 | 0.5416 |
| xgb.tuned | 0.8239 | 0.8583 | 0.8125 | 0.8354 | 0.6027 | 0.9454 | 0.7082 | 0.5876 |
| neural.net.3 | 0.8031 | 0.7909 | 0.8071 | 0.7990 | 0.5760 | 0.9209 | 0.6665 | 0.5316 |
| neural.net.5 | 0.7993 | 0.6932 | 0.8345 | 0.7638 | 0.5812 | 0.8914 | 0.6322 | 0.4957 |
| neural.net.3.2 | 0.8031 | 0.7815 | 0.8102 | 0.7959 | 0.5770 | 0.9180 | 0.6639 | 0.5290 |
Wyniki na zbiorze testowym są podstawą końcowej oceny jakości modeli. Model, który osiąga dobre wyniki na teście, ma większą zdolność generalizacji, czyli lepiej radzi sobie z nowymi danymi.
W pierwszej kolejności modele uporządkowano według
BalancedAccuracy na zbiorze testowym. Ta miara jest
szczególnie ważna w przypadku niezbalansowanych klas, ponieważ
uwzględnia zarówno poprawne rozpoznawanie klasy pozytywnej, jak i
negatywnej.
ranking.test.balanced <- porownanie.wszystkie.test %>%
as.data.frame() %>%
rownames_to_column("model") %>%
arrange(desc(BalancedAccuracy))
knitr::kable(
ranking.test.balanced,
digits = 4,
caption = "Ranking modeli według Balanced Accuracy na zbiorze testowym"
)
| model | Accuracy | Sensitivity | Specificity | BalancedAccuracy | PPV | NPV | F1 | Kappa |
|---|---|---|---|---|---|---|---|---|
| xgb.tuned | 0.8239 | 0.8583 | 0.8125 | 0.8354 | 0.6027 | 0.9454 | 0.7082 | 0.5876 |
| random.forest.tuned | 0.8112 | 0.8752 | 0.7900 | 0.8326 | 0.5800 | 0.9503 | 0.6977 | 0.5685 |
| logit | 0.7940 | 0.8686 | 0.7692 | 0.8189 | 0.5552 | 0.9464 | 0.6774 | 0.5366 |
| xgb.start | 0.7982 | 0.8601 | 0.7777 | 0.8189 | 0.5618 | 0.9438 | 0.6796 | 0.5416 |
| random.forest | 0.8196 | 0.8139 | 0.8215 | 0.8177 | 0.6018 | 0.9302 | 0.6920 | 0.5685 |
| decision.tree.downsample | 0.7928 | 0.8535 | 0.7727 | 0.8131 | 0.5544 | 0.9409 | 0.6721 | 0.5304 |
| bagging | 0.8015 | 0.8215 | 0.7949 | 0.8082 | 0.5703 | 0.9307 | 0.6732 | 0.5373 |
| neural.net.3 | 0.8031 | 0.7909 | 0.8071 | 0.7990 | 0.5760 | 0.9209 | 0.6665 | 0.5316 |
| neural.net.3.2 | 0.8031 | 0.7815 | 0.8102 | 0.7959 | 0.5770 | 0.9180 | 0.6639 | 0.5290 |
| neural.net.5 | 0.7993 | 0.6932 | 0.8345 | 0.7638 | 0.5812 | 0.8914 | 0.6322 | 0.4957 |
| decision.tree.tuned | 0.8457 | 0.5799 | 0.9338 | 0.7569 | 0.7437 | 0.8703 | 0.6517 | 0.5545 |
| decision.tree.default | 0.8397 | 0.5115 | 0.9485 | 0.7300 | 0.7670 | 0.8542 | 0.6137 | 0.5177 |
Model znajdujący się najwyżej w tym rankingu osiąga najlepszy kompromis pomiędzy czułością i specyficznością. Oznacza to, że dobrze radzi sobie zarówno z wykrywaniem osób zarabiających powyżej 50 tys. USD, jak i osób z niższym dochodem.
Drugim rankingiem jest zestawienie modeli według miary
F1. Jest ona szczególnie użyteczna wtedy, gdy zależy nam na
jakości klasyfikacji klasy pozytywnej.
ranking.test.f1 <- porownanie.wszystkie.test %>%
as.data.frame() %>%
rownames_to_column("model") %>%
arrange(desc(F1))
knitr::kable(
ranking.test.f1,
digits = 4,
caption = "Ranking modeli według F1 na zbiorze testowym"
)
| model | Accuracy | Sensitivity | Specificity | BalancedAccuracy | PPV | NPV | F1 | Kappa |
|---|---|---|---|---|---|---|---|---|
| xgb.tuned | 0.8239 | 0.8583 | 0.8125 | 0.8354 | 0.6027 | 0.9454 | 0.7082 | 0.5876 |
| random.forest.tuned | 0.8112 | 0.8752 | 0.7900 | 0.8326 | 0.5800 | 0.9503 | 0.6977 | 0.5685 |
| random.forest | 0.8196 | 0.8139 | 0.8215 | 0.8177 | 0.6018 | 0.9302 | 0.6920 | 0.5685 |
| xgb.start | 0.7982 | 0.8601 | 0.7777 | 0.8189 | 0.5618 | 0.9438 | 0.6796 | 0.5416 |
| logit | 0.7940 | 0.8686 | 0.7692 | 0.8189 | 0.5552 | 0.9464 | 0.6774 | 0.5366 |
| bagging | 0.8015 | 0.8215 | 0.7949 | 0.8082 | 0.5703 | 0.9307 | 0.6732 | 0.5373 |
| decision.tree.downsample | 0.7928 | 0.8535 | 0.7727 | 0.8131 | 0.5544 | 0.9409 | 0.6721 | 0.5304 |
| neural.net.3 | 0.8031 | 0.7909 | 0.8071 | 0.7990 | 0.5760 | 0.9209 | 0.6665 | 0.5316 |
| neural.net.3.2 | 0.8031 | 0.7815 | 0.8102 | 0.7959 | 0.5770 | 0.9180 | 0.6639 | 0.5290 |
| decision.tree.tuned | 0.8457 | 0.5799 | 0.9338 | 0.7569 | 0.7437 | 0.8703 | 0.6517 | 0.5545 |
| neural.net.5 | 0.7993 | 0.6932 | 0.8345 | 0.7638 | 0.5812 | 0.8914 | 0.6322 | 0.4957 |
| decision.tree.default | 0.8397 | 0.5115 | 0.9485 | 0.7300 | 0.7670 | 0.8542 | 0.6137 | 0.5177 |
Wysoka wartość F1 oznacza, że model osiąga dobry
kompromis między wykrywaniem klasy pozytywnej a precyzją tych wskazań. W
tym projekcie jest to istotne, ponieważ klasa osób zarabiających powyżej
50 tys. USD jest mniej liczna.
Aby ocenić, czy modele nie są przeuczone, przygotowano również porównanie wyników na zbiorze treningowym i testowym. Duża różnica między wynikiem treningowym a testowym może sugerować overfitting, czyli zbyt silne dopasowanie modelu do danych uczących.
porownanie.train.test <- porownanie.wszystkie.test %>%
as.data.frame() %>%
rownames_to_column("model") %>%
rename_with(~ paste0(.x, "_test"), -model) %>%
left_join(
porownanie.wszystkie.train %>%
as.data.frame() %>%
rownames_to_column("model") %>%
rename_with(~ paste0(.x, "_train"), -model),
by = "model"
) %>%
mutate(
roznica_Accuracy = Accuracy_train - Accuracy_test,
roznica_BalancedAccuracy = BalancedAccuracy_train - BalancedAccuracy_test,
roznica_F1 = F1_train - F1_test
) %>%
arrange(desc(BalancedAccuracy_test))
knitr::kable(
porownanie.train.test,
digits = 4,
caption = "Porównanie wyników treningowych i testowych"
)
| model | Accuracy_test | Sensitivity_test | Specificity_test | BalancedAccuracy_test | PPV_test | NPV_test | F1_test | Kappa_test | Accuracy_train | Sensitivity_train | Specificity_train | BalancedAccuracy_train | PPV_train | NPV_train | F1_train | Kappa_train | roznica_Accuracy | roznica_BalancedAccuracy | roznica_F1 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| xgb.tuned | 0.8239 | 0.8583 | 0.8125 | 0.8354 | 0.6027 | 0.9454 | 0.7082 | 0.5876 | 0.8286 | 0.8786 | 0.8120 | 0.8453 | 0.6077 | 0.9528 | 0.7185 | 0.6011 | 0.0047 | 0.0099 | 0.0103 |
| random.forest.tuned | 0.8112 | 0.8752 | 0.7900 | 0.8326 | 0.5800 | 0.9503 | 0.6977 | 0.5685 | 0.8500 | 0.9488 | 0.8173 | 0.8830 | 0.6325 | 0.9797 | 0.7590 | 0.6563 | 0.0388 | 0.0504 | 0.0613 |
| logit | 0.7940 | 0.8686 | 0.7692 | 0.8189 | 0.5552 | 0.9464 | 0.6774 | 0.5366 | 0.7936 | 0.8615 | 0.7711 | 0.8163 | 0.5550 | 0.9438 | 0.6751 | 0.5340 | -0.0004 | -0.0026 | -0.0023 |
| xgb.start | 0.7982 | 0.8601 | 0.7777 | 0.8189 | 0.5618 | 0.9438 | 0.6796 | 0.5416 | 0.7986 | 0.8727 | 0.7740 | 0.8234 | 0.5614 | 0.9483 | 0.6833 | 0.5456 | 0.0004 | 0.0045 | 0.0037 |
| random.forest | 0.8196 | 0.8139 | 0.8215 | 0.8177 | 0.6018 | 0.9302 | 0.6920 | 0.5685 | 0.9299 | 0.9553 | 0.9215 | 0.9384 | 0.8013 | 0.9842 | 0.8716 | 0.8239 | 0.1103 | 0.1207 | 0.1796 |
| decision.tree.downsample | 0.7928 | 0.8535 | 0.7727 | 0.8131 | 0.5544 | 0.9409 | 0.6721 | 0.5304 | 0.8394 | 0.8824 | 0.7964 | 0.8394 | 0.8125 | 0.8714 | 0.8460 | 0.6788 | 0.0466 | 0.0263 | 0.1739 |
| bagging | 0.8015 | 0.8215 | 0.7949 | 0.8082 | 0.5703 | 0.9307 | 0.6732 | 0.5373 | 0.9517 | 0.9812 | 0.9420 | 0.9616 | 0.8486 | 0.9934 | 0.9101 | 0.8773 | 0.1502 | 0.1534 | 0.2369 |
| neural.net.3 | 0.8031 | 0.7909 | 0.8071 | 0.7990 | 0.5760 | 0.9209 | 0.6665 | 0.5316 | 0.8490 | 0.8753 | 0.8402 | 0.8578 | 0.6464 | 0.9528 | 0.7436 | 0.6400 | 0.0459 | 0.0588 | 0.0771 |
| neural.net.3.2 | 0.8031 | 0.7815 | 0.8102 | 0.7959 | 0.5770 | 0.9180 | 0.6639 | 0.5290 | 0.8498 | 0.8649 | 0.8448 | 0.8548 | 0.6502 | 0.9493 | 0.7424 | 0.6393 | 0.0467 | 0.0589 | 0.0785 |
| neural.net.5 | 0.7993 | 0.6932 | 0.8345 | 0.7638 | 0.5812 | 0.8914 | 0.6322 | 0.4957 | 0.8892 | 0.8729 | 0.8946 | 0.8838 | 0.7344 | 0.9547 | 0.7977 | 0.7222 | 0.0899 | 0.1200 | 0.1655 |
| decision.tree.tuned | 0.8457 | 0.5799 | 0.9338 | 0.7569 | 0.7437 | 0.8703 | 0.6517 | 0.5545 | 0.8616 | 0.6056 | 0.9465 | 0.7760 | 0.7894 | 0.8786 | 0.6854 | 0.5987 | 0.0159 | 0.0191 | 0.0337 |
| decision.tree.default | 0.8397 | 0.5115 | 0.9485 | 0.7300 | 0.7670 | 0.8542 | 0.6137 | 0.5177 | 0.8416 | 0.5099 | 0.9516 | 0.7307 | 0.7773 | 0.8542 | 0.6158 | 0.5214 | 0.0019 | 0.0007 | 0.0021 |
Im większa dodatnia różnica między wynikiem treningowym i testowym,
tym większe podejrzenie przeuczenia. Szczególnie warto obserwować
różnice dla BalancedAccuracy i F1, ponieważ są
one bardziej informacyjne niż sama ogólna trafność.
Poniżej automatycznie wskazano najlepszy model według
BalancedAccuracy oraz najlepszy model według
F1 na zbiorze testowym.
najlepszy.model.balanced <- ranking.test.balanced[1, ]
najlepszy.model.f1 <- ranking.test.f1[1, ]
najlepszy.model.balanced
## model Accuracy Sensitivity Specificity BalancedAccuracy PPV NPV
## 1 xgb.tuned 0.8239 0.8583 0.8125 0.8354 0.6027 0.9454
## F1 Kappa
## 1 0.7082 0.5876
najlepszy.model.f1
## model Accuracy Sensitivity Specificity BalancedAccuracy PPV NPV
## 1 xgb.tuned 0.8239 0.8583 0.8125 0.8354 0.6027 0.9454
## F1 Kappa
## 1 0.7082 0.5876
cat(
"Najlepszy model według Balanced Accuracy na zbiorze testowym to:",
najlepszy.model.balanced$model,
"z wynikiem",
najlepszy.model.balanced$BalancedAccuracy,
"\n"
)
## Najlepszy model według Balanced Accuracy na zbiorze testowym to: xgb.tuned z wynikiem 0.8354
cat(
"Najlepszy model według F1 na zbiorze testowym to:",
najlepszy.model.f1$model,
"z wynikiem",
najlepszy.model.f1$F1,
"\n"
)
## Najlepszy model według F1 na zbiorze testowym to: xgb.tuned z wynikiem 0.7082
Na podstawie zestawienia wyników można stwierdzić, że wybór
najlepszego modelu zależy od celu analizy. Jeżeli najważniejsza byłaby
ogólna trafność klasyfikacji, można analizować Accuracy,
jednak przy niezbalansowanych danych nie powinna być ona jedynym
kryterium wyboru modelu.
W tym projekcie większe znaczenie ma BalancedAccuracy,
ponieważ uwzględnia zarówno zdolność wykrywania klasy pozytywnej, jak i
negatywnej. Model o wysokiej wartości tej miary dobrze radzi sobie z
obiema klasami, a nie tylko z klasą większościową.
Warto również analizować Sensitivity, ponieważ pokazuje,
jaka część osób rzeczywiście zarabiających powyżej 50 tys. USD została
poprawnie wykryta przez model. Jeżeli celem byłoby jak najlepsze
identyfikowanie osób z wysokim dochodem, wysoka czułość byłaby
szczególnie pożądana.
Z kolei Specificity pokazuje, jak dobrze model
rozpoznaje osoby, które nie zarabiają powyżej 50 tys. USD. Wysoka
specyficzność oznacza, że model rzadziej błędnie przypisuje osoby do
klasy wysokiego dochodu.
Miara F1 jest dobrym uzupełnieniem analizy, ponieważ
łączy precyzję i czułość. Jest przydatna wtedy, gdy chcemy ocenić jakość
klasyfikacji klasy pozytywnej przy nierównych proporcjach klas.
Porównanie wyników treningowych i testowych pozwala dodatkowo ocenić stabilność modeli. Modele, które osiągają bardzo wysokie wyniki na zbiorze treningowym, ale wyraźnie słabsze na zbiorze testowym, mogą być przeuczone. Chociażby tak jak w przypadku pojedynczych drzew. Najbardziej pożądany jest model, który osiąga wysokie wyniki na zbiorze testowym i jednocześnie nie wykazuje dużej różnicy między treningiem a testem.
Ostatecznie za najlepszy model należy uznać ten, który osiąga wysoką
jakość na zbiorze testowym, dobrą wartość BalancedAccuracy,
odpowiednio wysoką czułość oraz akceptowalną różnicę między wynikami
treningowymi i testowymi: W opini autorek projektów byłby to model XG
BOOST.