Uvod

Koristit ćemo tidyverse biblioteku koja sadrži R pakete za obradu i prikazivanje podataka. Detaljnije informacije o svim paketima možete pronaći na tidyverse.

library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.1.4     ✔ readr     2.1.5
## ✔ forcats   1.0.0     ✔ stringr   1.5.1
## ✔ ggplot2   3.5.1     ✔ tibble    3.2.1
## ✔ lubridate 1.9.3     ✔ tidyr     1.3.1
## ✔ purrr     1.0.2     
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter()     masks stats::filter()
## ✖ dplyr::group_rows() masks kableExtra::group_rows()
## ✖ dplyr::lag()        masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors

Za učenje pod nadzorom koristit ćemo tidymodels biblioteku koja sadrži R pakete za statističku analizu i modeliranje. Detaljnije informacije o svim paketima možete pronaći na tidymodels.

library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.2.0 ──
## ✔ broom        1.0.6     ✔ rsample      1.2.1
## ✔ dials        1.3.0     ✔ tune         1.2.1
## ✔ infer        1.0.7     ✔ workflows    1.1.4
## ✔ modeldata    1.4.0     ✔ workflowsets 1.1.0
## ✔ parsnip      1.2.1     ✔ yardstick    1.3.1
## ✔ recipes      1.1.0
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard()   masks purrr::discard()
## ✖ dplyr::filter()     masks stats::filter()
## ✖ recipes::fixed()    masks stringr::fixed()
## ✖ dplyr::group_rows() masks kableExtra::group_rows()
## ✖ dplyr::lag()        masks stats::lag()
## ✖ yardstick::spec()   masks readr::spec()
## ✖ recipes::step()     masks stats::step()
## • Search for functions across packages at https://www.tidymodels.org/find/

Dodatne biblioteke koje ćemo trebati.

library(rpart.plot)
library(vip)
library(baguette)
library(ranger)
library(Boruta)
library(xgboost)
library(DiagrammeR)
update_geom_defaults(geom = "point", new = list(color = "#F8766D", alpha = 0.5))
update_geom_defaults(geom = "tile", new = list(color = "black"))

Ovdje ćemo koristiti modele koji su bazirani na stablima. Modeli bazirani na stablima nemaju parametara, ali imaju hiperparametre. Hiperparametre algoritam za učenje ne uči, nego se njihove vrijednosti moraju postaviti prije samog procesa učenja. Vrijednosti hiperparametara mogu bitno utjecati na efikasnost modela pa ih treba pažljivo odabrati. Možemo koristiti defaultne vrijednosti koje su postavljene u određenoj implementaciji, možemo sami odabrati njihove vrijednosti ili pak možemo npr. preko cross-validation testirati ponašanje algoritma za učenje na dovoljnom broju kombinacija vrijednosti određenih hiperparametara (tuning hyperparameters) i pritom mjeriti za svaku takvu kombinaciju efikasnost promatranog algoritma za učenje. Na kraju odabiremo onu kombinaciju hiperparametara za koju algoritam za učenje ima najbolju efikasnost s obzirom na zadanu (željenu) metriku.

Za kreiranje mreže vrijednosti hiperparametara tidymodels nudi gotove funkcije grid_regular, grid_random, grid_max_entropy i grid_latin_hypercube. Detalje o tim funkcijama možete vidjeti na dials reference. Ovdje ćemo koristiti funkciju grid_latin_hypercube koja radi uzorkovanje na latinskoj hiperkocki.

tidymodels biblioteka također nudi implementirane metrike za mjerenje efikasnosti algoritama za učenje u slučaju klasifikacije i u slučaju regresije. Popis dostupnih metrika možete vidjeti na yardstick metrike. Metrike koje ćemo ovdje koristiti dane su u donjoj tablici.

metrika tip info
mae regresija mean absolute error
mape regresija mean absolute percentage error
mpe regresija mean percentage error
rmse regresija root mean square error
rsq regresija R squared
sens klasifikacija sensitivity
spec klasifikacija specificity
precision klasifikacija precision
accuracy klasifikacija accuracy
f_meas klasifikacija F-score
roc_auc klasifikacija ROC AUC1, ROC AUC2, Hand-Till
kap klasifikacija Cohen’s kappa

U tidymodels biblioteki možemo definirati koje metrike želimo koristiti pomoću metric_set naredbe.

metrike_klas <- metric_set(roc_auc, sens, precision, spec, accuracy, f_meas, kap)
metrike_reg <- metric_set(mae, mape, mpe, rmse, rsq)

Za određivanje optimalne kombinacije hiperparametara u slučaju regresije koristit ćemo rmse metriku koja daje prosječnu udaljenost između predviđenih i stvarnih vrijednosti.

Sve klasifikacijske metrike su implementirane za binarni slučaj i slučaj s više od dvije klase. Detalje o tome možete pogledati na multiclass averaging. Za određivanje optimalne kombinacije hiperparametara u slučaju klasifikacije koristit ćemo roc_auc metriku i Hand-Till metodu.

U binarnom slučaju ROC krivulja pokazuje odnos između dviju mjera, sensitivity i 1-specificity. Sensitivity mjeri u kojem postotku je algoritam za učenje točan na pozitivnoj klasi, a specificity mjeri u kojem postotku je algoritam za učenje točan na negativnoj klasi. Algoritam za učenje daje vjerojatnosti pripadanja određenoj klasi. Kako bi se odredilo kojoj klasi promatrani podatak pripada, mora se zadati određena granica (threshold) za pozitivnu klasu. Standardno se koristi granica od \(0.5\): ako je dobivena vjerojatnost za pozitivnu klasu veća ili jednaka od \(0.5\), algoritam će promatrani podatak klasificirati u pozitivnu klasu, a u protivnom u negativnu klasu. Međutim, može se zadati proizvoljna granica \(p\in[0,1]\) za pozitivnu klasu. U tom slučaju, ako je dobivena vjerojatnost za pozitivnu klasu veća ili jednaka od \(p\), algoritam će promatrani podatak klasificirati u pozitivnu klasu, a u protivnom u negativnu klasu. Upravo to je poanta ROC krivulje: za različite vrijednosti granice \(p\in[0,1]\) prikazati koliko je algoritam dobar na pozitivnoj klasi u odnosu na to koliko je loš na negativnoj klasi. Površina ispod ROC krivulje je zapravo roc_auc metrika.

U slučaju više klasa, roc_auc metriku možemo odrediti preko macro averaging. Zapravo se koristi one-vs-all pristup gdje se svaka klasa gleda u odnosu na ostale pa se na taj način svodi na više binarnih slučaja. Međutim, takva metrika je osjetljiva na raspodjelu klasa. Zbog toga je implementirana hand_till metoda koja čuva neosjetljivost na raspodjelu klasa, ali nema jednostavnu vizualnu interpretaciju. Za detalje pogledajte članak Hand-Till, a ovdje ćemo samo ukratko intuitivno objasniti glavnu ideju.

  • Neka je \(\hat{A}(i\,|\,j)\) vjerojatnost da slučajno odabrani član iz klase \(j\) ima manju procijenjenu vjerojatnost pripadnosti klasi \(i\) od slučajno odabranog člana iz klase \(i\).
  • U binarnom slučaju vrijedi \(\hat{A}(1\,|\,0)=\hat{A}(0\,|\,1)\). Međutim, u slučaju kada imamo više od dvije klase općenito je \(\hat{A}(i\,|\,j)=\hat{A}(j\,|\,i)\). Stoga se definira \[\hat{A}(i,j)=\frac{\hat{A}(i\,|\,j)+\hat{A}(j\,|\,i)}{2}.\]
  • Konačno, roc_auc s hand_till metodom računa se po formuli \[M=\frac{2}{c(c-1)}\sum_{i<j}{\hat{A}(i,j)}\] pri čemu je \(c\) ukupni broj klasa. Drugim riječima, \(M\) je zapravo aritmetička sredina svih \(\hat{A}(i,j)\). Uočite da je \[\frac{2}{c(c-1)}=\frac{1}{\binom{c}{2}},\] a \(\binom{c}{2}\) jednak je ukupnom broju svih neuređenih parova postojećih \(c\) klasa.

Iz navedenih razmatranja možemo zaključiti da je roc_auc dobiven hand_till metodom zapravo jednak prosječnoj vjerojatnosti da slučajno odabrani član iz neke klase ima manju procijenjenu vjerojatnost da pripada nekoj drugoj klasi od slučajno odabranog člana iz te druge klase.

Podaci o pingvinima

Koristit ćemo penguins podatke o pingvinima iz modeldata biblioteke.

data("penguins")
podaci <- penguins %>% na.omit()
podaci %>% glimpse()
## Rows: 333
## Columns: 7
## $ species           <fct> Adelie, Adelie, Adelie, Adelie, Adelie, Adelie, Adel…
## $ island            <fct> Torgersen, Torgersen, Torgersen, Torgersen, Torgerse…
## $ bill_length_mm    <dbl> 39.1, 39.5, 40.3, 36.7, 39.3, 38.9, 39.2, 41.1, 38.6…
## $ bill_depth_mm     <dbl> 18.7, 17.4, 18.0, 19.3, 20.6, 17.8, 19.6, 17.6, 21.2…
## $ flipper_length_mm <int> 181, 186, 195, 193, 190, 181, 195, 182, 191, 198, 18…
## $ body_mass_g       <int> 3750, 3800, 3250, 3450, 3650, 3625, 4675, 3200, 3800…
## $ sex               <fct> male, female, female, female, male, female, male, fe…

Podaci sadrže tri faktorske varijable i četiri numeričke varijable.

podaci %>% pivot_longer(c(species, island, sex), names_to = "varijabla", values_to = "vrijednost") %>%
  ggplot(aes(x = vrijednost)) + geom_bar(fill = "skyblue3") + 
  facet_wrap(vars(varijabla), scales = "free_x") + xlab("") + ylab("")
Stupčasti dijagrami faktorskih varijabli

Stupčasti dijagrami faktorskih varijabli

podaci %>% pivot_longer(where(is.numeric), names_to = "varijabla", values_to = "vrijednost") %>%
  ggplot(aes(x = vrijednost)) + geom_histogram(fill = "skyblue", color = "violet") + 
  facet_wrap(vars(varijabla), scales = "free_x") + xlab("") + ylab("")
Histogrami numeričkih varijabli

Histogrami numeričkih varijabli

podaci %>% pivot_longer(where(is.numeric), names_to = "varijabla", values_to = "vrijednost") %>%
  ggplot(aes(x = varijabla, y = vrijednost)) + geom_boxplot(color = "skyblue3") + 
  stat_summary(fun = mean, geom = "point", shape = 10, size=3, color="sienna2") +
  facet_wrap(vars(varijabla), scales = "free") + xlab("") + ylab("") +
  theme(axis.text.x = element_blank(), 
        axis.ticks.x = element_blank()) 
Brkate kutije numeričkih varijabli zajedno s aritmetičkom sredinom

Brkate kutije numeričkih varijabli zajedno s aritmetičkom sredinom

Stablo odlučivanja

Regresija. Neka su \(X_1,X_2,\dotsc,X_p\) prediktorske varijable, a \(Y\) numerička response varijabla.

  • Želimo pronaći particiju prostora prediktora \(X_1\times X_2\times\dotsb\times X_p\) koja ima \(J\) elemenata \(R_1,R_2,\dotsc,R_J\).
  • Svaki podatak koji se nalazi u skupu \(R_j\) ima istu predikciju koja je jednaka aritmetičkoj sredini vrijednosti response varijable \(Y\) za trening podatke koji se nalaze u \(R_j\).
  • Radi jednostavnosti i lakše interpretacije \(R_j\) je višedimenzionalna kutija čije su hiperstrane paralelne s koordinatnim hiperravninama. U slučaju dva prediktora \(R_j\) je pravokutnik čije su stranice paralelne s koordinatnim osima. U slučaju tri prediktora \(R_j\) je kvadar čije su strane paralelne s koordinatnim ravninama.
  • Cilj je pronaći kutije \(R_1,R_2,\dotsc,R_J\) koje minimiziraju sumu kvadrata reziduala \[RSS=\sum_{j=1}^J{\sum_{x_i\in R_j}{\big(y_i-\hat{y}_{R_j}\big)^2}}\] pri čemu je \(\hat{y}_{R_j}\) aritmetička sredina za trening podatke koji pripadaju kutiji \(R_j\).
  • Računski je prezahtjevno ispitati sve moguće particije prostora prediktora u \(J\) kutija. Zbog toga se koristi top-down pohlepni pristup preko rekurzivne binarne podjele. Počinje se u korijenu stabla i zatim se sukcesivno dijeli prostor prediktora. Svaka podjela je označena kroz dvije nove grane u stablu.
  • Najprije biramo prediktor \(X_j\) i točku rezanja \(s\) tako da podjela prostora prediktora na dva područja \[\{(X_1,X_2,\dotsc,X_p)\mid X_j<s\}\quad\text{i}\quad \{(X_1,X_2,\dotsc,X_p)\mid X_j\geqslant s\}\] daje najveće moguće smanjenje u sumi kvadrata reziduala \(RSS\).
  • Nakon toga ponavljamo postupak. Tražimo najbolji prediktor i najbolju točku rezanja za iduću podjelu prostora prediktora kako bi se minimizirala suma kvadrata reziduala unutar svake dobivene kutije. Međutim, ovaj put ne dijelimo čitav prostor prediktora, nego dijelimo jednu od prethodno dvije dobivene kutije tako da nakon toga imamo tri kutije.
  • Nakon toga nastavljamo dalje dijeliti neku od postojeće tri kutije. Postupak se nastavlja sve dok se ne zadovolji neki kriterij zaustavljanja. Na primjer, možemo zaustaviti dijeljenje postojeće kutije ukoliko je u njoj broj podataka manji od neke zadane vrijednosti.

Klasifikacija. Neka su \(X_1,X_2,\dotsc,X_p\) prediktorske varijable, a \(Y\) kvalitativna response varijabla. Klasifikacijsko stablo odlučivanja funkcionira na sličan način kao i regresijsko stablo odlučivanja.

  • Svaki podatak koji se nalazi u kutiji \(R_j\) ima istu predikciju koja je jednaka najčešćoj klasi koja se javlja na trening skupu podataka koji se nalaze u \(R_j\).
  • Kao i u slučaju regresije koristimo rekurzivnu binarnu podjelu za gradnju klasifikacijskog stabla. Jedino ne možemo koristiti \(RSS\) za rekurzivne binarne podjele. Prirodna alternativa je stopa pogreške klasifikacije koja je jednaka postotku trening podataka u promatranoj kutiji koji ne pripadaju najčešćoj klasi u toj kutiji, tj. \[E=1-\max_k{\hat{p}_{mk}}\] pri čemu je \(\hat{p}_{mk}\) postotak trening podataka u kutiji \(R_m\) koji pripadaju \(k\)-toj klasi.
  • Međutim, postoje i dvije bolje mjere koje se koriste u primjenama. Prva od tih mjera je Gini indeks koji je definiran s \[G=\sum_{k=1}^K{\hat{p}_{mk}\big(1-\hat{p}_{mk}\big)}\] pri čemu je \(K\) ukupni broj klasa. Gini indeks poprima male vrijednosti ako su svi \(\hat{p}_{mk}\) blizu \(0\) ili \(1\). Zbog toga se za Gini indeks kaže da je mjera čistoće vrha pri čemu male vrijednosti ukazuju da se u promatranom vrhu nalaze podaci koji uglavnom pripadaju istoj klasi.
  • Alternativa Gini indeksu je entropija koja se definira s \[D=-\sum_{k=1}^K{\hat{p}_{mk}\log{\hat{p}_{mk}}}.\] Pokazuje se da su Gini indeks i entropija numerički vrlo slični. Gini indeks je ipak jednostavnije izračunati jer kod entropije treba koristiti logaritamsku funkciju.

Obrezivanje stabla (tree pruning). Postupak izgradnje stabla pomoću rekurzivne binarne podjele može dovesti do prenaučenosti (overfitting), tj. stablo odlučivanja može davati dobre predikcije na skupu na kojem je trenirano, ali loše na novom testnom skupu. S druge strane, stablo odlučivanja s manjim brojem kutija (listova) može smanjiti prenaučenost i omogućiti jednostavniju interpretaciju uz malu žrtvu da se lošije ponaša na podacima na kojima je trenirano.

  • Za dobivanje optimalnog regresijskog stabla možemo graditi stablo tako dugo dok se u svakom vrhu suma kvadrata reziduala više bitno ne smanjuje s obzirom na neku unaprijed zadanu vrijednost.
  • Za dobivanje optimalnog klasifikacijskog stabla možemo graditi stablo tako dugo dok se u svakom vrhu Gini indeks više bitno ne smanjuje s obzirom na neku unaprijed zadanu vrijednost.

Opisane gornje strategije daju manja stabla, ali uz jedan veliki nedostatak. Preranim prekidanjem grananja u nekom vrhu možemo biti “oštećeni” za neko dobro grananje koje bi slijedilo kasnije da nismo ranije prekinuli. Drugim riječima, ukoliko u nekom vrhu nema značajnog smanjenja sume kvadrata reziduala ili Gini indeksa, to ne znači da se u kasnijim grananjima ne mogu javiti značajnija smanjenja.

Bolja strategija je prvo napraviti vrlo veliko stablo \(T_0\), a nakon toga ga obrezati kako bismo dobili optimalno podstablo. Jedna takva poznata strategija je cost complexity pruning. Promatramo niz stabala indeksiranih s nenegativnim parametrom podešavanja \(\alpha\).

  • U slučaju regresije, za svaku vrijednost \(\alpha\) postoji podstablo \(T\subset T_0\) takvo da je \[\sum_{m=1}^{|T|}{\sum_{x_i\in R_m}{\big(y_i-\hat{y}_{R_m}\big)^2}}+\alpha|T|\] minimalna.
  • U slučaju klasifikacije, za svaku vrijednost \(\alpha\) postoji podstablo \(T\subset T_0\) takvo da je \[\sum_{m=1}^{|T|}{N_m\sum_{k=1}^K{\hat{p}_{mk}\big(1-\hat{p}_{mk}\big)}}+\alpha|T|\] minimalna, pri čemu je \(N_m\) broj podataka u kutiji \(R_m\).

Pritom je \(|T|\) oznaka za ukupni broj listova u stablu \(T\). Pomoću parametra \(\alpha\) kontroliramo koliko će biti kompleksno stablo odlučivanja, odnosno koliko ćemo “kažnjavati” stablo s obzirom na broj listova koje ima. Ako je \(\alpha=0\), onda uopće nema kazne i u tom slučaju zapravo dobivamo najkompleksnije stablo. Ako je \(\alpha=\infty\), tada je kazna prevelika jer dobivamo prazno stablo, tj. stablo koje ima samo jedan vrh (korijen). Pomoću cross-validation možemo pronaći optimalnu vrijednost za parametar \(\alpha\). Za više detalja možete pogledati npr. linkove prune1 i prune2.

U tidymodels biblioteki postoje tri hiperparametra kod stabla odlučivanja:

  • tree_depth - prirodni broj za najveću dozvoljenu dubinu stabla
  • min_n - minimalni broj podataka potrebnih u vrhu kako bi se on dalje dijelio
  • cost_complexity - ranije opisani parametar \(\alpha\) za obrezivanje stabla

Svaki od tih hipeparametara ima već postavljene defaultne intervale unutar kojih se ti parametri biraju.

tree_depth()
## Tree Depth (quantitative)
## Range: [1, 15]
tree_depth() %>% range_get()
## $lower
## [1] 1
## 
## $upper
## [1] 15
min_n()
## Minimal Node Size (quantitative)
## Range: [2, 40]
min_n() %>% range_get()
## $lower
## [1] 2
## 
## $upper
## [1] 40
cost_complexity()
## Cost-Complexity Parameter (quantitative)
## Transformer: log-10 [1e-100, Inf]
## Range (transformed scale): [-10, -1]
cost_complexity() %>% range_get()
## $lower
## [1] 1e-10
## 
## $upper
## [1] 0.1

Možemo po želji postaviti i druge intervale za pojedini hiperparametar.

cost_complexity(range(-20,1))
## Cost-Complexity Parameter (quantitative)
## Transformer: log-10 [1e-100, Inf]
## Range (transformed scale): [-20, 1]
cost_complexity(range(-20,1)) %>% range_get()
## $lower
## [1] 1e-20
## 
## $upper
## [1] 10

Primjer klasifikacijskog stabla

Pogledajmo sada klasifikacijsko stablo na penguins podacima koje klasificira vrstu pingvina pomoću svih preostalih varijabli koje su prediktori. Biranje optimalnih hiperparametara i treniranje modela je napravljeno u zasebnoj R datoteci i sve potrebne informacije su spremljene u rds datoteke. U spomenutoj R datoteci se nalazi sljedeći kod.

library(tidymodels)
library(modeldata)
library(doParallel)

data("penguins")
podaci <- penguins %>% na.omit()

podaci_split <- initial_split(podaci, prop = 0.75, strata = species)

# definiranje modela
stablo_model <- decision_tree() %>%
  set_args(tree_depth = tune(), min_n = tune(), cost_complexity = tune()) %>%
  set_engine("rpart") %>%
  set_mode("classification")

# recept za transformiranje podataka
stablo_recipe <- recipe(species ~ ., data = training(podaci_split))

# workflow
stablo_work <- workflow() %>% 
  add_model(stablo_model) %>% 
  add_recipe(stablo_recipe)

# cross validation folds
stablo_folds <- vfold_cv(training(podaci_split), v = 10, strata = species)

# metrike
stablo_metrike <- metric_set(roc_auc, sens, precision, spec, accuracy, f_meas, kap)

stablo_grid <- grid_latin_hypercube(tree_depth(), min_n(), cost_complexity(), size = 5000)

cl <- makePSOCKcluster(10)
registerDoParallel(cl)
stablo_tuning <- stablo_work %>% 
  tune_grid(resamples = stablo_folds, grid = stablo_grid, metrics = stablo_metrike)
stopCluster(cl)

saveRDS(stablo_tuning, file = "modeli/stablo_tuning_klas.rds")

best_stablo_model <- stablo_tuning %>% select_best(metric = 'roc_auc')
final_stablo_work <- stablo_work %>% finalize_workflow(best_stablo_model)
stablo_fit <- final_stablo_work %>% last_fit(split = podaci_split)

trenirani_model_work <- stablo_fit %>% extract_workflow() %>%
  augment(training(podaci_split))
saveRDS(trenirani_model_work, "modeli/stablo_work_klas.rds")
saveRDS(stablo_fit, "modeli/stablo_fit_klas.rds")

Učitavanje spremljenih podataka iz rds datoteka.

stablo_fit_klas <- readRDS("modeli/stablo_fit_klas.rds")
stablo_tuning_klas <- readRDS("modeli/stablo_tuning_klas.rds")
stablo_work_klas <- readRDS("modeli/stablo_work_klas.rds")
stablo_tuning_klas %>% autoplot()
Vrijednosti hiperparametara i metrika

Vrijednosti hiperparametara i metrika

20 najboljih modela za roc_auc metriku

stablo_tuning_klas %>% show_best(metric = 'roc_auc', n = 20)

Metrike na testnom i trening skupu

tab1 <- stablo_fit_klas %>% collect_predictions() %>% 
  metrike_klas(truth = species, estimate = .pred_class, .pred_Adelie:.pred_Gentoo)
tab2 <- stablo_work_klas %>%
  metrike_klas(truth = species, estimate = .pred_class, .pred_Adelie:.pred_Gentoo)
tab1 %>% inner_join(tab2, by = ".metric") %>%
  select(.metric, .estimator = .estimator.x, trening = .estimate.y, test = .estimate.x)
stablo_work_klas %>%  
  conf_mat(truth = species, estimate = .pred_class) %>% autoplot("heatmap") +
  scale_fill_gradient(low = "#87DEE7", high = "#FFFFCC")
Confusion matrix na trening skupu

Confusion matrix na trening skupu

stablo_fit_klas %>% collect_predictions() %>%
  conf_mat(truth = species, estimate = .pred_class) %>% autoplot("heatmap") +
  scale_fill_gradient(low = "#87DEE7", high = "#FFFFCC")
Confusion matrix na testnom skupu

Confusion matrix na testnom skupu

Pogledajmo ROC i gain krivulje

stablo_fit_klas %>% collect_predictions() %>% 
  roc_curve(species, .pred_Adelie:.pred_Gentoo) %>% autoplot()
ROC krivulje

ROC krivulje

stablo_fit_klas %>% collect_predictions() %>% 
  gain_curve(species, .pred_Adelie:.pred_Gentoo) %>% autoplot()
gain krivulje

gain krivulje

Pogledajmo i dobiveno stablo odlučivanja.

stablo_fit_klas %>% 
  extract_fit_engine() %>%
  rpart.plot(roundint = FALSE, digits = 3)
klasifikacijsko stablo odlučivanja

klasifikacijsko stablo odlučivanja

Na primjer, iz dobivenog stabla možemo zaključiti ukoliko pingvin ima peraje dugačke barem 208 milimetara i nalazi se na Biscoe otocima, tada se radi o Gentoo vrsti pingvina.

Možemo ispisati stablo u tekstualnom formatu.

stablo_fit_klas %>% 
  extract_fit_engine()
## n= 249 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
##  1) root 249 140 Adelie (0.437751004 0.204819277 0.357429719)  
##    2) flipper_length_mm< 207.5 156  48 Adelie (0.692307692 0.301282051 0.006410256)  
##      4) bill_length_mm< 43.35 108   2 Adelie (0.981481481 0.018518519 0.000000000) *
##      5) bill_length_mm>=43.35 48   3 Chinstrap (0.041666667 0.937500000 0.020833333)  
##       10) island=Biscoe,Torgersen 3   1 Adelie (0.666666667 0.000000000 0.333333333) *
##       11) island=Dream 45   0 Chinstrap (0.000000000 1.000000000 0.000000000) *
##    3) flipper_length_mm>=207.5 93   5 Gentoo (0.010752688 0.043010753 0.946236559)  
##      6) island=Dream,Torgersen 5   1 Chinstrap (0.200000000 0.800000000 0.000000000) *
##      7) island=Biscoe 88   0 Gentoo (0.000000000 0.000000000 1.000000000) *

Također, možemo dobiti detaljnije informacije o gradnji stabla kod svakog vrha.

model <- stablo_fit_klas %>% extract_fit_engine()
summary(model)
## Call:
## rpart::rpart(formula = ..y ~ ., data = data, cp = ~1.60719499929497e-10, 
##     maxdepth = ~3, minsplit = min_rows(6, data))
##   n= 249 
## 
##             CP nsplit  rel error     xerror       xstd
## 1 6.214286e-01      0 1.00000000 1.00000000 0.05591773
## 2 3.071429e-01      1 0.37857143 0.40000000 0.04705925
## 3 2.857143e-02      2 0.07142857 0.11428571 0.02763823
## 4 1.428571e-02      3 0.04285714 0.05714286 0.01987585
## 5 1.607195e-10      4 0.02857143 0.05000000 0.01863069
## 
## Variable importance
## flipper_length_mm    bill_length_mm     bill_depth_mm       body_mass_g 
##                24                23                20                17 
##            island 
##                16 
## 
## Node number 1: 249 observations,    complexity param=0.6214286
##   predicted class=Adelie     expected loss=0.562249  P(node) =1
##     class counts:   109    51    89
##    probabilities: 0.438 0.205 0.357 
##   left son=2 (156 obs) right son=3 (93 obs)
##   Primary splits:
##       flipper_length_mm < 207.5  to the left,  improve=82.41562, (0 missing)
##       bill_length_mm    < 43.25  to the left,  improve=77.82187, (0 missing)
##       bill_depth_mm     < 16.45  to the right, improve=75.79562, (0 missing)
##       body_mass_g       < 4525   to the left,  improve=60.35847, (0 missing)
##       island            splits as  RLL,        improve=47.71573, (0 missing)
##   Surrogate splits:
##       bill_depth_mm  < 16.35  to the right, agree=0.940, adj=0.839, (0 split)
##       body_mass_g    < 4687.5 to the left,  agree=0.908, adj=0.753, (0 split)
##       island         splits as  RLL,        agree=0.831, adj=0.548, (0 split)
##       bill_length_mm < 43.25  to the left,  agree=0.783, adj=0.419, (0 split)
## 
## Node number 2: 156 observations,    complexity param=0.3071429
##   predicted class=Adelie     expected loss=0.3076923  P(node) =0.626506
##     class counts:   108    47     1
##    probabilities: 0.692 0.301 0.006 
##   left son=4 (108 obs) right son=5 (48 obs)
##   Primary splits:
##       bill_length_mm    < 43.35  to the left,  improve=57.4298400, (0 missing)
##       island            splits as  LRL,        improve=21.8747000, (0 missing)
##       flipper_length_mm < 196.5  to the left,  improve= 9.5811760, (0 missing)
##       body_mass_g       < 3225   to the left,  improve= 2.3040520, (0 missing)
##       bill_depth_mm     < 19.55  to the left,  improve= 0.8794872, (0 missing)
##   Surrogate splits:
##       flipper_length_mm < 196.5  to the left,  agree=0.763, adj=0.229, (0 split)
##       island            splits as  LRL,        agree=0.712, adj=0.063, (0 split)
## 
## Node number 3: 93 observations,    complexity param=0.02857143
##   predicted class=Gentoo     expected loss=0.05376344  P(node) =0.373494
##     class counts:     1     4    88
##    probabilities: 0.011 0.043 0.946 
##   left son=6 (5 obs) right son=7 (88 obs)
##   Primary splits:
##       island            splits as  RLL,        improve=7.9483870, (0 missing)
##       bill_depth_mm     < 17.65  to the right, improve=7.9483870, (0 missing)
##       body_mass_g       < 4125   to the left,  improve=3.1382750, (0 missing)
##       flipper_length_mm < 212.5  to the left,  improve=1.1039430, (0 missing)
##       bill_length_mm    < 48.9   to the right, improve=0.5234619, (0 missing)
##   Surrogate splits:
##       bill_depth_mm < 17.65  to the right, agree=1.000, adj=1.0, (0 split)
##       body_mass_g   < 4125   to the left,  agree=0.968, adj=0.4, (0 split)
## 
## Node number 4: 108 observations
##   predicted class=Adelie     expected loss=0.01851852  P(node) =0.4337349
##     class counts:   106     2     0
##    probabilities: 0.981 0.019 0.000 
## 
## Node number 5: 48 observations,    complexity param=0.01428571
##   predicted class=Chinstrap  expected loss=0.0625  P(node) =0.1927711
##     class counts:     2    45     1
##    probabilities: 0.042 0.937 0.021 
##   left son=10 (3 obs) right son=11 (45 obs)
##   Primary splits:
##       island            splits as  LRL,        improve=4.3750000, (0 missing)
##       body_mass_g       < 4575   to the left,  improve=2.7518120, (0 missing)
##       bill_depth_mm     < 16.5   to the right, improve=0.8822464, (0 missing)
##       bill_length_mm    < 45.85  to the left,  improve=0.7583333, (0 missing)
##       flipper_length_mm < 202.5  to the left,  improve=0.1891696, (0 missing)
##   Surrogate splits:
##       body_mass_g < 4575   to the right, agree=0.979, adj=0.667, (0 split)
## 
## Node number 6: 5 observations
##   predicted class=Chinstrap  expected loss=0.2  P(node) =0.02008032
##     class counts:     1     4     0
##    probabilities: 0.200 0.800 0.000 
## 
## Node number 7: 88 observations
##   predicted class=Gentoo     expected loss=0  P(node) =0.3534137
##     class counts:     0     0    88
##    probabilities: 0.000 0.000 1.000 
## 
## Node number 10: 3 observations
##   predicted class=Adelie     expected loss=0.3333333  P(node) =0.01204819
##     class counts:     2     0     1
##    probabilities: 0.667 0.000 0.333 
## 
## Node number 11: 45 observations
##   predicted class=Chinstrap  expected loss=0  P(node) =0.1807229
##     class counts:     0    45     0
##    probabilities: 0.000 1.000 0.000
model$splits
##                   count ncat    improve   index       adj
## flipper_length_mm   249   -1 82.4156228  207.50 0.0000000
## bill_length_mm      249   -1 77.8218746   43.25 0.0000000
## bill_depth_mm       249    1 75.7956167   16.45 0.0000000
## body_mass_g         249   -1 60.3584658 4525.00 0.0000000
## island              249    3 47.7157254    1.00 0.0000000
## bill_depth_mm         0    1  0.9397590   16.35 0.8387097
## body_mass_g           0   -1  0.9076305 4687.50 0.7526882
## island                0    3  0.8313253    2.00 0.5483871
## bill_length_mm        0   -1  0.7831325   43.25 0.4193548
## bill_length_mm      156   -1 57.4298433   43.35 0.0000000
## island              156    3 21.8746973    3.00 0.0000000
## flipper_length_mm   156   -1  9.5811757  196.50 0.0000000
## body_mass_g         156   -1  2.3040518 3225.00 0.0000000
## bill_depth_mm       156   -1  0.8794872   19.55 0.0000000
## flipper_length_mm     0   -1  0.7628205  196.50 0.2291667
## island                0    3  0.7115385    4.00 0.0625000
## island               48    3  4.3750000    5.00 0.0000000
## body_mass_g          48   -1  2.7518116 4575.00 0.0000000
## bill_depth_mm        48    1  0.8822464   16.50 0.0000000
## bill_length_mm       48   -1  0.7583333   45.85 0.0000000
## flipper_length_mm    48   -1  0.1891696  202.50 0.0000000
## body_mass_g           0    1  0.9791667 4575.00 0.6666667
## island               93    3  7.9483871    6.00 0.0000000
## bill_depth_mm        93    1  7.9483871   17.65 0.0000000
## body_mass_g          93   -1  3.1382747 4125.00 0.0000000
## flipper_length_mm    93   -1  1.1039427  212.50 0.0000000
## bill_length_mm       93    1  0.5234619   48.90 0.0000000
## bill_depth_mm         0    1  1.0000000   17.65 1.0000000
## body_mass_g           0   -1  0.9677419 4125.00 0.4000000

Objasnimo ukratko kako se u gornjem ispisu računa improve. Neka je \(V\) promatrani vrh, a \(V_L\) i \(V_R\) redom njegovo lijevo i desno dijete nakon grananja po nekoj varijabli. Neka su \(G(V), G(V_L)\) i \(G(V_R)\) redom Gini indeksi vrhova \(V,V_L\) i \(V_R\). Nadalje, neka je s \(|V|, |V_L|\) i \(|V_R|\) označen ukupni broj podataka koji se nalaze redom u vrhovima \(V,V_L\) i \(V_R\). Tada je \(|V|=|V_L|+|V_R|\) i vrijedi \[\mathrm{improve}=|V_L|\cdot\big(G(V)-G(V_L)\big)+|V_R|\cdot\big(G(V)-G(V_R)\big).\]

U donjem grafičkom prikazu važnosti prediktora, zapravo je prikazana suma svih improve vrijednosti za pojedinu varijablu.

stablo_fit_klas %>% extract_fit_parsnip() %>% vip()
Važnost varijabli

Važnost varijabli

Također vidimo da u summary ispisu imamo skalirane cjelobrojne važnosti varijabli tako da je njihova suma jednaka 100. Ukoliko želimo, možemo sami prikazati grafički skalirane važnosti varijabli.

(varimpo <- model$variable.importance %>% as_tibble() %>%
  bind_cols(varijabla = names(model$variable.importance)) %>%
  mutate(rel_importance = round(100 * value / sum(value))) %>%
  relocate(varijabla) %>% rename(importance = value))
ggplot(varimpo, aes(x = fct_reorder(varijabla, rel_importance), y = rel_importance)) + 
  geom_col() + coord_flip() + xlab("") + ylab("importance")
skalirana važnost varijabli

skalirana važnost varijabli

Primjer regresijskog stabla

Pogledajmo sada regresijsko stablo na penguins podacima koje procjenjuje masu pingvina u gramima pomoću svih preostalih varijabli koje su prediktori. Biranje optimalnih hiperparametara i treniranje modela je napravljeno u zasebnoj R datoteci i sve potrebne informacije su spremljene u rds datoteke. U spomenutoj R datoteci se nalazi sljedeći kod.

library(tidymodels)
library(modeldata)
library(doParallel)

data("penguins")
podaci <- penguins %>% na.omit()

podaci_split <- initial_split(podaci, prop = 0.75, strata = body_mass_g)

# definiranje modela
stablo_model <- decision_tree() %>%
  set_args(tree_depth = tune(), min_n = tune(), cost_complexity = tune()) %>%
  set_engine("rpart") %>%
  set_mode("regression")

# recept za transformiranje podataka
stablo_recipe <- recipe(body_mass_g ~ ., data = training(podaci_split))

# workflow
stablo_work <- workflow() %>% 
  add_model(stablo_model) %>% 
  add_recipe(stablo_recipe)

# cross validation folds
stablo_folds <- vfold_cv(training(podaci_split), v = 10, strata = body_mass_g)

# metrike
stablo_metrike <- metric_set(mae, mape, mpe, rmse, rsq)

stablo_grid <- grid_latin_hypercube(tree_depth(), min_n(), cost_complexity(), size = 5000)

cl <- makePSOCKcluster(10)
registerDoParallel(cl)
stablo_tuning <- stablo_work %>% 
  tune_grid(resamples = stablo_folds, grid = stablo_grid, metrics = stablo_metrike)
stopCluster(cl)

saveRDS(stablo_tuning, file = "modeli/stablo_tuning_reg.rds")

best_stablo_model <- stablo_tuning %>% select_best(metric = 'rmse')
final_stablo_work <- stablo_work %>% finalize_workflow(best_stablo_model)
stablo_fit <- final_stablo_work %>% last_fit(split = podaci_split)

trenirani_model_work <- stablo_fit %>% extract_workflow() %>%
  augment(training(podaci_split))
saveRDS(trenirani_model_work, "modeli/stablo_work_reg.rds")
saveRDS(stablo_fit, "modeli/stablo_fit_reg.rds")

Učitavanje spremljenih podataka iz rds datoteka.

stablo_fit_reg <- readRDS("modeli/stablo_fit_reg.rds")
stablo_tuning_reg <- readRDS("modeli/stablo_tuning_reg.rds")
stablo_work_reg <- readRDS("modeli/stablo_work_reg.rds")
stablo_tuning_reg %>% autoplot()
Vrijednosti hiperparametara i metrika

Vrijednosti hiperparametara i metrika

20 najboljih modela za rmse metriku

stablo_tuning_reg %>% show_best(metric = 'rmse', n = 20)

Metrike na testnom i trening skupu

tab1 <- stablo_fit_reg %>% collect_predictions() %>% 
  metrike_reg(truth = body_mass_g, estimate = .pred,)
tab2 <- stablo_work_reg %>%
  metrike_reg(truth = body_mass_g, estimate = .pred)
tab1 %>% inner_join(tab2, by = ".metric") %>%
  select(.metric, .estimator = .estimator.x, trening = .estimate.y, test = .estimate.x)
ggplot(stablo_work_reg, aes(x = body_mass_g, y = .pred)) +
  geom_point(size = 1.5) + geom_abline(color = "#00BFC4") +
  ylab("predikcija")
Predikcije na trening skupu

Predikcije na trening skupu

test_podaci <- stablo_fit_reg %>% collect_predictions()

ggplot(test_podaci, aes(x = body_mass_g, y = .pred)) +
  geom_point(size = 1.5, alpha = 0.7) +
  geom_abline(color = "#00BFC4") + ylab("predikcija")
Predikcije na testnom skupu

Predikcije na testnom skupu

Pogledajmo dobiveno regresijsko stablo odlučivanja.

stablo_fit_reg %>% 
  extract_fit_engine() %>%
  rpart.plot(roundint = FALSE, digits = 3, tweak = 1.2)
regresijsko stablo odlučivanja

regresijsko stablo odlučivanja

Na primjer, iz regresijskog stabla možemo vidjeti da Gentoo muški pingvini s duljinom kljuna većom od 47.7 milimetara imaju u prosjeku težinu od 5542 grama.

Ispis stabla u tekstualnom formatu.

stablo_fit_reg %>% 
  extract_fit_engine()
## n= 248 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 248 151430000.0 4203.528  
##    2) species=Adelie,Chinstrap 160  29311870.0 3725.156  
##      4) sex=female 79   6079747.0 3438.608  
##        8) bill_depth_mm< 18.75 69   4804366.0 3399.638  
##         16) flipper_length_mm< 195.5 60   3937958.0 3364.167  
##           32) bill_depth_mm< 16.65 9    592222.2 3119.444 *
##           33) bill_depth_mm>=16.65 51   2711618.0 3407.353 *
##         17) flipper_length_mm>=195.5 9    287638.9 3636.111 *
##        9) bill_depth_mm>=18.75 10    447562.5 3707.500 *
##      5) sex=male 81  10418890.0 4004.630  
##       10) flipper_length_mm< 194.5 37   3661419.0 3880.405 *
##       11) flipper_length_mm>=194.5 44   5706364.0 4109.091  
##         22) bill_length_mm>=49.25 18   2527812.0 3970.833  
##           44) flipper_length_mm< 199.5 7    230535.7 3603.571 *
##           45) flipper_length_mm>=199.5 11    752272.7 4204.545 *
##         23) bill_length_mm< 49.25 26   2596274.0 4204.808 *
##    3) species=Gentoo 88  18932240.0 5073.295  
##      6) sex=female 45   3109250.0 4698.333  
##       12) flipper_length_mm< 211.5 18   1221944.0 4530.556 *
##       13) flipper_length_mm>=211.5 27   1042824.0 4810.185 *
##      7) sex=male 43   2875029.0 5465.698  
##       14) bill_length_mm< 47.7 10    170250.0 5215.000 *
##       15) bill_length_mm>=47.7 33   1885833.0 5541.667 *

Detaljnije informacije o gradnji stabla kod svakog vrha.

model <- stablo_fit_reg %>% extract_fit_engine()
summary(model)
## Call:
## rpart::rpart(formula = ..y ~ ., data = data, cp = ~0.0037551782899252, 
##     maxdepth = ~14, minsplit = min_rows(9, data))
##   n= 248 
## 
##            CP nsplit rel error    xerror       xstd
## 1 0.681409870      0 1.0000000 1.0038092 0.06866008
## 2 0.085504603      1 0.3185901 0.3230474 0.02492202
## 3 0.084614886      2 0.2330855 0.2693748 0.02386278
## 4 0.006996383      3 0.1484706 0.1529744 0.01321357
## 5 0.005576710      6 0.1274815 0.1705304 0.01498297
## 6 0.005466672      7 0.1219048 0.1721455 0.01544124
## 7 0.005408080      8 0.1164381 0.1711754 0.01540275
## 8 0.004004777      9 0.1110300 0.1649165 0.01544877
## 9 0.003755178     11 0.1030205 0.1647426 0.01542931
## 
## Variable importance
##     bill_depth_mm flipper_length_mm           species            island 
##                24                24                22                14 
##    bill_length_mm               sex 
##                11                 5 
## 
## Node number 1: 248 observations,    complexity param=0.6814099
##   mean=4203.528, MSE=610605 
##   left son=2 (160 obs) right son=3 (88 obs)
##   Primary splits:
##       species           splits as  LLR,       improve=0.6814099, (0 missing)
##       flipper_length_mm < 208.5 to the left,  improve=0.6649007, (0 missing)
##       bill_depth_mm     < 16.4  to the right, improve=0.4916466, (0 missing)
##       island            splits as  RLL,       improve=0.3908572, (0 missing)
##       bill_length_mm    < 42.55 to the left,  improve=0.3111583, (0 missing)
##   Surrogate splits:
##       flipper_length_mm < 206.5 to the left,  agree=0.972, adj=0.920, (0 split)
##       bill_depth_mm     < 16.4  to the right, agree=0.960, adj=0.886, (0 split)
##       island            splits as  RLL,       agree=0.863, adj=0.614, (0 split)
##       bill_length_mm    < 43.25 to the left,  agree=0.774, adj=0.364, (0 split)
## 
## Node number 2: 160 observations,    complexity param=0.08461489
##   mean=3725.156, MSE=183199.2 
##   left son=4 (79 obs) right son=5 (81 obs)
##   Primary splits:
##       sex               splits as  LR,        improve=0.437134700, (0 missing)
##       bill_depth_mm     < 18.05 to the left,  improve=0.236826100, (0 missing)
##       flipper_length_mm < 194.5 to the left,  improve=0.184746400, (0 missing)
##       bill_length_mm    < 39.05 to the left,  improve=0.118219000, (0 missing)
##       island            splits as  RRL,       improve=0.003803564, (0 missing)
##   Surrogate splits:
##       bill_depth_mm     < 17.95 to the left,  agree=0.806, adj=0.608, (0 split)
##       bill_length_mm    < 38.95 to the left,  agree=0.675, adj=0.342, (0 split)
##       flipper_length_mm < 195.5 to the left,  agree=0.669, adj=0.329, (0 split)
##       island            splits as  RRL,       agree=0.525, adj=0.038, (0 split)
##       species           splits as  RL-,       agree=0.512, adj=0.013, (0 split)
## 
## Node number 3: 88 observations,    complexity param=0.0855046
##   mean=5073.295, MSE=215139.1 
##   left son=6 (45 obs) right son=7 (43 obs)
##   Primary splits:
##       sex               splits as  LR,        improve=0.6839107, (0 missing)
##       bill_depth_mm     < 14.75 to the left,  improve=0.4953340, (0 missing)
##       bill_length_mm    < 48.3  to the left,  improve=0.4320779, (0 missing)
##       flipper_length_mm < 212.5 to the left,  improve=0.3651048, (0 missing)
##   Surrogate splits:
##       bill_depth_mm     < 14.85 to the left,  agree=0.909, adj=0.814, (0 split)
##       bill_length_mm    < 47.9  to the left,  agree=0.807, adj=0.605, (0 split)
##       flipper_length_mm < 217.5 to the left,  agree=0.807, adj=0.605, (0 split)
## 
## Node number 4: 79 observations,    complexity param=0.005466672
##   mean=3438.608, MSE=76958.82 
##   left son=8 (69 obs) right son=9 (10 obs)
##   Primary splits:
##       bill_depth_mm     < 18.75 to the left,  improve=0.13616000, (0 missing)
##       flipper_length_mm < 188.5 to the left,  improve=0.09458126, (0 missing)
##       bill_length_mm    < 44.35 to the left,  improve=0.07698016, (0 missing)
##       species           splits as  LR-,       improve=0.03663995, (0 missing)
##       island            splits as  LRL,       improve=0.00384916, (0 missing)
## 
## Node number 5: 81 observations,    complexity param=0.006996383
##   mean=4004.63, MSE=128628.3 
##   left son=10 (37 obs) right son=11 (44 obs)
##   Primary splits:
##       flipper_length_mm < 194.5 to the left,  improve=0.100884700, (0 missing)
##       bill_depth_mm     < 20.65 to the left,  improve=0.058103760, (0 missing)
##       bill_length_mm    < 49.45 to the right, improve=0.022226350, (0 missing)
##       species           splits as  RL-,       improve=0.017807420, (0 missing)
##       island            splits as  RLL,       improve=0.003979764, (0 missing)
##   Surrogate splits:
##       bill_length_mm < 41.2  to the left,  agree=0.753, adj=0.459, (0 split)
##       species        splits as  LR-,       agree=0.667, adj=0.270, (0 split)
##       bill_depth_mm  < 18.95 to the left,  agree=0.667, adj=0.270, (0 split)
##       island         splits as  LRR,       agree=0.630, adj=0.189, (0 split)
## 
## Node number 6: 45 observations,    complexity param=0.00557671
##   mean=4698.333, MSE=69094.44 
##   left son=12 (18 obs) right son=13 (27 obs)
##   Primary splits:
##       flipper_length_mm < 211.5 to the left,  improve=0.2716030, (0 missing)
##       bill_depth_mm     < 14.65 to the left,  improve=0.1697643, (0 missing)
##       bill_length_mm    < 46.45 to the left,  improve=0.1080713, (0 missing)
##   Surrogate splits:
##       bill_length_mm < 43.35 to the left,  agree=0.644, adj=0.111, (0 split)
##       bill_depth_mm  < 13.65 to the left,  agree=0.644, adj=0.111, (0 split)
## 
## Node number 7: 43 observations,    complexity param=0.00540808
##   mean=5465.698, MSE=66861.14 
##   left son=14 (10 obs) right son=15 (33 obs)
##   Primary splits:
##       bill_length_mm    < 47.7  to the left,  improve=0.2848478, (0 missing)
##       bill_depth_mm     < 15.85 to the left,  improve=0.1380727, (0 missing)
##       flipper_length_mm < 228.5 to the left,  improve=0.1374721, (0 missing)
##   Surrogate splits:
##       flipper_length_mm < 215.5 to the left,  agree=0.814, adj=0.2, (0 split)
## 
## Node number 8: 69 observations,    complexity param=0.004004777
##   mean=3399.638, MSE=69628.49 
##   left son=16 (60 obs) right son=17 (9 obs)
##   Primary splits:
##       flipper_length_mm < 195.5 to the left,  improve=0.12046720, (0 missing)
##       bill_depth_mm     < 16.75 to the left,  improve=0.11065450, (0 missing)
##       bill_length_mm    < 44.35 to the left,  improve=0.10740950, (0 missing)
##       species           splits as  LR-,       improve=0.05913125, (0 missing)
##       island            splits as  LRL,       improve=0.01884782, (0 missing)
##   Surrogate splits:
##       bill_length_mm < 48.7  to the left,  agree=0.884, adj=0.111, (0 split)
## 
## Node number 9: 10 observations
##   mean=3707.5, MSE=44756.25 
## 
## Node number 10: 37 observations
##   mean=3880.405, MSE=98957.27 
## 
## Node number 11: 44 observations,    complexity param=0.006996383
##   mean=4109.091, MSE=129690.1 
##   left son=22 (18 obs) right son=23 (26 obs)
##   Primary splits:
##       bill_length_mm    < 49.25 to the right, improve=0.10204000, (0 missing)
##       bill_depth_mm     < 20.4  to the left,  improve=0.08950952, (0 missing)
##       island            splits as  RLL,       improve=0.08655960, (0 missing)
##       species           splits as  RL-,       improve=0.07489047, (0 missing)
##       flipper_length_mm < 202.5 to the left,  improve=0.05269237, (0 missing)
##   Surrogate splits:
##       species           splits as  RL-,       agree=0.932, adj=0.833, (0 split)
##       island            splits as  RLR,       agree=0.750, adj=0.389, (0 split)
##       flipper_length_mm < 200.5 to the right, agree=0.705, adj=0.278, (0 split)
##       bill_depth_mm     < 19.75 to the right, agree=0.636, adj=0.111, (0 split)
## 
## Node number 12: 18 observations
##   mean=4530.556, MSE=67885.8 
## 
## Node number 13: 27 observations
##   mean=4810.185, MSE=38623.11 
## 
## Node number 14: 10 observations
##   mean=5215, MSE=17025 
## 
## Node number 15: 33 observations
##   mean=5541.667, MSE=57146.46 
## 
## Node number 16: 60 observations,    complexity param=0.004004777
##   mean=3364.167, MSE=65632.64 
##   left son=32 (9 obs) right son=33 (51 obs)
##   Primary splits:
##       bill_depth_mm     < 16.65 to the left,  improve=0.16102720, (0 missing)
##       bill_length_mm    < 45.3  to the left,  improve=0.10953280, (0 missing)
##       flipper_length_mm < 188.5 to the left,  improve=0.05126286, (0 missing)
##       species           splits as  LR-,       improve=0.04810778, (0 missing)
##       island            splits as  LRR,       improve=0.04608987, (0 missing)
## 
## Node number 17: 9 observations
##   mean=3636.111, MSE=31959.88 
## 
## Node number 22: 18 observations,    complexity param=0.006996383
##   mean=3970.833, MSE=140434 
##   left son=44 (7 obs) right son=45 (11 obs)
##   Primary splits:
##       flipper_length_mm < 199.5 to the left,  improve=0.61120200, (0 missing)
##       bill_length_mm    < 50.4  to the left,  improve=0.13577520, (0 missing)
##       bill_depth_mm     < 19.85 to the left,  improve=0.06731364, (0 missing)
##   Surrogate splits:
##       bill_length_mm < 50.4  to the left,  agree=0.722, adj=0.286, (0 split)
## 
## Node number 23: 26 observations
##   mean=4204.808, MSE=99856.69 
## 
## Node number 32: 9 observations
##   mean=3119.444, MSE=65802.47 
## 
## Node number 33: 51 observations
##   mean=3407.353, MSE=53168.97 
## 
## Node number 44: 7 observations
##   mean=3603.571, MSE=32933.67 
## 
## Node number 45: 11 observations
##   mean=4204.545, MSE=68388.43
model$splits
##                   count ncat     improve  index        adj
## species             248    3 0.681409870   1.00 0.00000000
## flipper_length_mm   248   -1 0.664900737 208.50 0.00000000
## bill_depth_mm       248    1 0.491646586  16.40 0.00000000
## island              248    3 0.390857150   2.00 0.00000000
## bill_length_mm      248   -1 0.311158277  42.55 0.00000000
## flipper_length_mm     0   -1 0.971774194 206.50 0.92045455
## bill_depth_mm         0    1 0.959677419  16.40 0.88636364
## island                0    3 0.862903226   3.00 0.61363636
## bill_length_mm        0   -1 0.774193548  43.25 0.36363636
## sex                 160    2 0.437134679   4.00 0.00000000
## bill_depth_mm       160   -1 0.236826057  18.05 0.00000000
## flipper_length_mm   160   -1 0.184746367 194.50 0.00000000
## bill_length_mm      160   -1 0.118218983  39.05 0.00000000
## island              160    3 0.003803564   5.00 0.00000000
## bill_depth_mm         0   -1 0.806250000  17.95 0.60759494
## bill_length_mm        0   -1 0.675000000  38.95 0.34177215
## flipper_length_mm     0   -1 0.668750000 195.50 0.32911392
## island                0    3 0.525000000   6.00 0.03797468
## species               0    3 0.512500000   7.00 0.01265823
## bill_depth_mm        79   -1 0.136160011  18.75 0.00000000
## flipper_length_mm    79   -1 0.094581258 188.50 0.00000000
## bill_length_mm       79   -1 0.076980156  44.35 0.00000000
## species              79    3 0.036639950   8.00 0.00000000
## island               79    3 0.003849160   9.00 0.00000000
## flipper_length_mm    69   -1 0.120467243 195.50 0.00000000
## bill_depth_mm        69   -1 0.110654501  16.75 0.00000000
## bill_length_mm       69   -1 0.107409451  44.35 0.00000000
## species              69    3 0.059131250  10.00 0.00000000
## island               69    3 0.018847823  11.00 0.00000000
## bill_length_mm        0   -1 0.884057971  48.70 0.11111111
## bill_depth_mm        60   -1 0.161027215  16.65 0.00000000
## bill_length_mm       60   -1 0.109532815  45.30 0.00000000
## flipper_length_mm    60   -1 0.051262862 188.50 0.00000000
## species              60    3 0.048107779  12.00 0.00000000
## island               60    3 0.046089873  13.00 0.00000000
## flipper_length_mm    81   -1 0.100884686 194.50 0.00000000
## bill_depth_mm        81   -1 0.058103756  20.65 0.00000000
## bill_length_mm       81    1 0.022226346  49.45 0.00000000
## species              81    3 0.017807420  14.00 0.00000000
## island               81    3 0.003979764  15.00 0.00000000
## bill_length_mm        0   -1 0.753086420  41.20 0.45945946
## species               0    3 0.666666667  16.00 0.27027027
## bill_depth_mm         0   -1 0.666666667  18.95 0.27027027
## island                0    3 0.629629630  17.00 0.18918919
## bill_length_mm       44    1 0.102039957  49.25 0.00000000
## bill_depth_mm        44   -1 0.089509524  20.40 0.00000000
## island               44    3 0.086559601  18.00 0.00000000
## species              44    3 0.074890469  19.00 0.00000000
## flipper_length_mm    44   -1 0.052692369 202.50 0.00000000
## species               0    3 0.931818182  20.00 0.83333333
## island                0    3 0.750000000  21.00 0.38888889
## flipper_length_mm     0    1 0.704545455 200.50 0.27777778
## bill_depth_mm         0    1 0.636363636  19.75 0.11111111
## flipper_length_mm    18   -1 0.611202001 199.50 0.00000000
## bill_length_mm       18   -1 0.135775215  50.40 0.00000000
## bill_depth_mm        18   -1 0.067313636  19.85 0.00000000
## bill_length_mm        0   -1 0.722222222  50.40 0.28571429
## sex                  88    2 0.683910741  22.00 0.00000000
## bill_depth_mm        88   -1 0.495333985  14.75 0.00000000
## bill_length_mm       88   -1 0.432077925  48.30 0.00000000
## flipper_length_mm    88   -1 0.365104843 212.50 0.00000000
## bill_depth_mm         0   -1 0.909090909  14.85 0.81395349
## bill_length_mm        0   -1 0.806818182  47.90 0.60465116
## flipper_length_mm     0   -1 0.806818182 217.50 0.60465116
## flipper_length_mm    45   -1 0.271602953 211.50 0.00000000
## bill_depth_mm        45   -1 0.169764261  14.65 0.00000000
## bill_length_mm       45   -1 0.108071346  46.45 0.00000000
## bill_length_mm        0   -1 0.644444444  43.35 0.11111111
## bill_depth_mm         0   -1 0.644444444  13.65 0.11111111
## bill_length_mm       43   -1 0.284847811  47.70 0.00000000
## bill_depth_mm        43   -1 0.138072693  15.85 0.00000000
## flipper_length_mm    43   -1 0.137472108 228.50 0.00000000
## flipper_length_mm     0   -1 0.813953488 215.50 0.20000000

Objasnimo ukratko kako se u gornjem ispisu računa improve. Neka je \(V\) promatrani vrh, a \(V_L\) i \(V_R\) redom njegovo lijevo i desno dijete nakon grananja po nekoj varijabli. Neka je \[ RSS(V)=\sum_{x_i\in V}{\big(y_i-\hat{y}_V\big)^2},\quad RSS(V_L)=\sum_{x_i\in V_L}{\big(y_i-\hat{y}_{V_L}\big)^2},\quad RSS(V_R)=\sum_{x_i\in V_R}{\big(y_i-\hat{y}_{V_R}\big)^2} \] pri čemu su \(\hat{y}_V,\hat{y}_{V_L}\) i \(\hat{y}_{V_R}\) aritmetičke sredine podataka koji pripadaju redom vrhovima \(V,V_L\) i \(V_R\). Tada je \[\mathrm{improve}=1-\frac{RSS(V_L)+RSS(V_R)}{RSS(V)}.\]

U donjem grafičkom prikazu važnosti prediktora, zapravo je prikazana suma svih deviance (RSS) vrijednosti za pojedinu varijablu.

stablo_fit_reg %>% extract_fit_parsnip() %>% vip()
Važnost varijabli

Važnost varijabli

Također vidimo da u summary ispisu imamo skalirane cjelobrojne važnosti varijabli tako da je njihova suma jednaka 100. Ukoliko želimo, možemo sami prikazati grafički skalirane važnosti varijabli.

(varimpo <- model$variable.importance %>% as_tibble() %>%
  bind_cols(varijabla = names(model$variable.importance)) %>%
  mutate(rel_importance = round(100 * value / sum(value))) %>%
  relocate(varijabla) %>% rename(importance = value))
ggplot(varimpo, aes(x = fct_reorder(varijabla, rel_importance), y = rel_importance)) + 
  geom_col() + coord_flip() + xlab("") + ylab("importance")
skalirana važnost varijabli

skalirana važnost varijabli

Pakiranje stabala

Na poznatoj izreci “Više glava je pametnije od jedne” bazira se pakiranje stabala, na engleskom bagging ili bootstrap aggregation. Pakiranje je jedna općenita metoda koja služi za smanjenje varijance algoritma za učenje bazirana na sljedećoj činjenici:

  • Neka su \(Z_1,Z_2,\dotsc,Z_n\) nezavisne slučajne varijable s varijancom \(\sigma^2\). Tada je varijanca slučajne varijable \[\bar{Z}=\frac{Z_1+Z_2+\dotsc+Z_n}{n}\] jednaka \(\frac{\sigma^2}{n}\). Jednostavno rečeno, usrednjavanjem se smanjuje varijanca.

U praksi zapravo imamo samo jedan skup podataka za treniranje. Stoga na postojećem skupu za treniranje koje ima \(N\) podataka kreiramo nove skupove za treniranje s \(N\) podataka preko bootstrap metode. Svaki novi skup podataka je zapravo neka kombinacija s ponavljanjem od \(N\) elemenata početnog skupa podataka za treniranje s kojim raspolažemo.

  • Generiramo \(B\) različitih skupova podataka za treniranje preko bootstrap metode.
  • Za svaki \(b=1,2,\dotsc,B\) treniramo stablo odlučivanja na \(b\)-tom bootstrap trening skupu i na taj način dobivamo predikciju \(\hat{f}^{\ast b}(x)\) u svakoj točki \(x\).
  • U slučaju regresije napravimo prosječne vrijednosti svih predikcija, tj. za konačnu predikciju uzimamo \[\hat{f}_{\mathrm{bag}}(x)=\frac{1}{B}\sum_{b=1}^B{\hat{f}^{\ast b}(x)}.\]
  • U slučaju klasifikacije svako od \(B\) stabala predviđa određenu klasu. Za konačnu predikciju koristimo većinu glasova, tj. uzimamo onu klasu koja se najviše javlja u dobivenih \(B\) predikcija.

Prilikom treniranja svako od \(B\) stabala koristi samo neke podatke iz početnog trening skupa. Podatke koje pojedino stablo ne koristi prilikom treniranja zovu se out-of-bag (OOB) podaci za promatrano stablo. Dakle, svako od \(B\) stabala ima svoje OOB podatke. Pokazuje se da u prosjeku svako stablo za treniranje koristi oko \(\frac{2}{3}\) početnih podataka, odnosno da je za svako stablo broj njegovih OOB podataka približno jednak \(\frac{1}{3}\) početnih podataka.

Za svaki podatak u trening skupu pogledamo sva stabla za koja je taj podatak OOB podatak. Svako od tih stabala daje predikciju za taj podatak. U slučaju regresije uzmemo prosječnu vrijednosti tih predikcija, a u slučaju klasifikacije klasu koja se javlja najviše puta. Na taj način dobivamo OOB predikciju za svaki podatak.

Primjer klasifikacije

Pogledajmo sada pakiranje stabala na penguins podacima pri čemu klasificiramo vrstu pingvina pomoću svih preostalih varijabli koje su prediktori. Biranje optimalnih hiperparametara i treniranje modela je napravljeno u zasebnoj R datoteci i sve potrebne informacije su spremljene u rds datoteke. U spomenutoj R datoteci se nalazi sljedeći kod. Zapravo pakiramo ukupno 20 stabala jer smo stavili times = 20.

library(tidyverse)
library(tidymodels)
library(modeldata)
library(baguette)
library(doParallel)

data("penguins")
podaci <- penguins %>% na.omit()

podaci_split <- initial_split(podaci, prop = 0.75, strata = species)

# definiranje modela
stablo_model <- bag_tree() %>%
  set_args(tree_depth = tune(), min_n = tune(), cost_complexity = tune()) %>%
  set_engine("rpart", times = 20) %>%
  set_mode("classification")

# recept za transformiranje podataka
stablo_recipe <- recipe(species ~ ., data = training(podaci_split))

# workflow
stablo_work <- workflow() %>% 
  add_model(stablo_model) %>% 
  add_recipe(stablo_recipe)

# cross validation folds
stablo_folds <- vfold_cv(training(podaci_split), v = 10, strata = species)

# metrike
stablo_metrike <- metric_set(roc_auc, sens, precision, spec, accuracy, f_meas, kap)

stablo_grid <- grid_latin_hypercube(tree_depth(), min_n(), cost_complexity(), size = 50)

cl <- makePSOCKcluster(10)
registerDoParallel(cl)
stablo_tuning <- stablo_work %>% 
  tune_grid(resamples = stablo_folds, grid = stablo_grid, metrics = stablo_metrike)
stopCluster(cl)

saveRDS(stablo_tuning, file = "modeli/bagged_stablo_tuning_klas.rds")

best_stablo_model <- stablo_tuning %>% select_best(metric = 'roc_auc')
final_stablo_work <- stablo_work %>% finalize_workflow(best_stablo_model)
stablo_fit <- final_stablo_work %>% last_fit(split = podaci_split)

trenirani_model_work <- stablo_fit %>% extract_workflow() %>%
  augment(training(podaci_split))
saveRDS(trenirani_model_work, "modeli/bagged_stablo_work_klas.rds")
saveRDS(stablo_fit, "modeli/bagged_stablo_fit_klas.rds")

Učitavanje spremljenih podataka iz rds datoteka.

bag_stablo_fit_klas <- readRDS("modeli/bagged_stablo_fit_klas.rds")
bag_stablo_tuning_klas <- readRDS("modeli/bagged_stablo_tuning_klas.rds")
bag_stablo_work_klas <- readRDS("modeli/bagged_stablo_work_klas.rds")
bag_stablo_tuning_klas %>% autoplot()
Vrijednosti hiperparametara i metrika

Vrijednosti hiperparametara i metrika

20 najboljih modela za roc_auc metriku

bag_stablo_tuning_klas %>% show_best(metric = 'roc_auc', n = 20)

Metrike na testnom i trening skupu

tab1 <- bag_stablo_fit_klas %>% collect_predictions() %>% 
  metrike_klas(truth = species, estimate = .pred_class, .pred_Adelie:.pred_Gentoo)
tab2 <- bag_stablo_work_klas %>%
  metrike_klas(truth = species, estimate = .pred_class, .pred_Adelie:.pred_Gentoo)
tab1 %>% inner_join(tab2, by = ".metric") %>%
  select(.metric, .estimator = .estimator.x, trening = .estimate.y, test = .estimate.x)
bag_stablo_work_klas %>%  
  conf_mat(truth = species, estimate = .pred_class) %>% autoplot("heatmap") +
  scale_fill_gradient(low = "#87DEE7", high = "#FFFFCC")
Confusion matrix na trening skupu

Confusion matrix na trening skupu

bag_stablo_fit_klas %>% collect_predictions() %>%
  conf_mat(truth = species, estimate = .pred_class) %>% autoplot("heatmap") +
  scale_fill_gradient(low = "#87DEE7", high = "#FFFFCC")
Confusion matrix na testnom skupu

Confusion matrix na testnom skupu

Pogledajmo ROC i gain krivulje

bag_stablo_fit_klas %>% collect_predictions() %>% 
  roc_curve(species, .pred_Adelie:.pred_Gentoo) %>% autoplot()
ROC krivulje

ROC krivulje

bag_stablo_fit_klas %>% collect_predictions() %>% 
  gain_curve(species, .pred_Adelie:.pred_Gentoo) %>% autoplot()
gain krivulje

gain krivulje

Kod pakiranja stabala nema nekog velikog smisla prikazivati pojedino stablo, pogotovo ako tih stabala ima jako puno. No, ako želimo, možemo dobiti informacije o svakom stablu u paketu i nacrtati pojedino stablo. No, imajmo na umu da samo jedno stablo ne odlučuje o konačnoj predikciji, već u tome sudjeluju sva stabla u paketu.

model <- bag_stablo_fit_klas %>% extract_workflow() %>% extract_fit_engine()
attributes(model)
## $names
## [1] "model_df"   "control"    "cost"       "imp"        "base_model"
## [6] "blueprint" 
## 
## $class
## [1] "bagger"         "hardhat_model"  "hardhat_scalar"

Informacije o svim stablima nalaze se u atributu model_df. Pogledajmo na primjer informacije o 1. i 8. stablu iz paketa.

model$model_df$model[[1]]
## parsnip model object
## 
## n= 249 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
##  1) root 249 148 1  
##    2) flipper_length_mm< 206.5 148  47 1  
##      4) bill_length_mm< 43.15 105   4 1  
##        8) bill_length_mm< 42.3 96   1 1  
##         16) bill_depth_mm>=16.7 84   0 1 *
##         17) bill_depth_mm< 16.7 12   1 1  
##           34) bill_length_mm< 39.5 11   0 1 *
##           35) bill_length_mm>=39.5 1   0 2 *
##        9) bill_length_mm>=42.3 9   3 1  
##         18) bill_length_mm>=42.6 6   0 1 *
##         19) bill_length_mm< 42.6 3   0 2 *
##      5) bill_length_mm>=43.15 43   1 2  
##       10) bill_depth_mm>=15.4 42   0 2 *
##       11) bill_depth_mm< 15.4 1   0 3 *
##    3) flipper_length_mm>=206.5 101   6 3  
##      6) bill_depth_mm>=18.15 6   0 2 *
##      7) bill_depth_mm< 18.15 95   0 3 *
model$model_df$model[[1]] %>%
  extract_fit_engine() %>%
  rpart.plot(roundint = FALSE, digits = 2, tweak = 1.2)
Prvo stablo iz paketa

Prvo stablo iz paketa

model$model_df$model[[8]]
## parsnip model object
## 
## n= 249 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
##  1) root 249 146 3  
##    2) bill_length_mm< 43.2 105   5 1  
##      4) bill_depth_mm>=15.3 103   3 1  
##        8) bill_length_mm< 42.25 96   1 1  
##         16) bill_depth_mm>=16.8 88   0 1 *
##         17) bill_depth_mm< 16.8 8   1 1  
##           34) bill_length_mm< 39.5 7   0 1 *
##           35) bill_length_mm>=39.5 1   0 2 *
##        9) bill_length_mm>=42.25 7   2 1  
##         18) bill_length_mm>=42.45 5   0 1 *
##         19) bill_length_mm< 42.45 2   0 2 *
##      5) bill_depth_mm< 15.3 2   0 3 *
##    3) bill_length_mm>=43.2 144  43 3  
##      6) island=Dream 41   0 2 *
##      7) island=Biscoe,Torgersen 103   2 3  
##       14) bill_depth_mm>=18.8 2   0 1 *
##       15) bill_depth_mm< 18.8 101   0 3 *
model$model_df$model[[8]] %>%
  extract_fit_engine() %>%
  rpart.plot(roundint = FALSE, digits = 2, tweak = 1.2)
Osmo stablo iz paketa

Osmo stablo iz paketa

Funkcija var_imp iz baguette biblioteke računa važnost prediktora za svako stablo u paketu kako je ranije opisano kod stabla odlučivanja i vraća aritmetičku sredinu tih vrijednosti.

vaznost <- var_imp(model)
ggplot(vaznost, aes(x = fct_reorder(term, value), y = value)) + 
  geom_col() + coord_flip() + xlab("") + ylab("")
Važnost varijabli

Važnost varijabli

Primjer regresije

Pogledajmo sada pakiranje stabala na penguins podacima koje procjenjuje masu pingvina u gramima pomoću svih preostalih varijabli koje su prediktori. Biranje optimalnih hiperparametara i treniranje modela je napravljeno u zasebnoj R datoteci i sve potrebne informacije su spremljene u rds datoteke. U spomenutoj R datoteci se nalazi sljedeći kod.

library(tidymodels)
library(modeldata)
library(baguette)
library(doParallel)

data("penguins")
podaci <- penguins %>% na.omit()

podaci_split <- initial_split(podaci, prop = 0.75, strata = body_mass_g)

# definiranje modela
stablo_model <- bag_tree() %>%
  set_args(tree_depth = tune(), min_n = tune(), cost_complexity = tune()) %>%
  set_engine("rpart", times = 20) %>%
  set_mode("regression")

# recept za transformiranje podataka
stablo_recipe <- recipe(body_mass_g ~ ., data = training(podaci_split))

# workflow
stablo_work <- workflow() %>% 
  add_model(stablo_model) %>% 
  add_recipe(stablo_recipe)

# cross validation folds
stablo_folds <- vfold_cv(training(podaci_split), v = 10, strata = body_mass_g)

# metrike
stablo_metrike <- metric_set(mae, mape, mpe, rmse, rsq)

stablo_grid <- grid_latin_hypercube(tree_depth(), min_n(), cost_complexity(), size = 100)

cl <- makePSOCKcluster(10)
registerDoParallel(cl)
stablo_tuning <- stablo_work %>% 
  tune_grid(resamples = stablo_folds, grid = stablo_grid, metrics = stablo_metrike)
stopCluster(cl)

saveRDS(stablo_tuning, file = "modeli/bagged_stablo_tuning_reg.rds")

best_stablo_model <- stablo_tuning %>% select_best(metric = 'rmse')
final_stablo_work <- stablo_work %>% finalize_workflow(best_stablo_model)
stablo_fit <- final_stablo_work %>% last_fit(split = podaci_split)

trenirani_model_work <- stablo_fit %>% extract_workflow() %>%
  augment(training(podaci_split))
saveRDS(trenirani_model_work, "modeli/bagged_stablo_work_reg.rds")
saveRDS(stablo_fit, "modeli/bagged_stablo_fit_reg.rds")

Učitavanje spremljenih podataka iz rds datoteka.

bag_stablo_fit_reg <- readRDS("modeli/bagged_stablo_fit_reg.rds")
bag_stablo_tuning_reg <- readRDS("modeli/bagged_stablo_tuning_reg.rds")
bag_stablo_work_reg <- readRDS("modeli/bagged_stablo_work_reg.rds")
bag_stablo_tuning_reg %>% autoplot()
Vrijednosti hiperparametara i metrika

Vrijednosti hiperparametara i metrika

20 najboljih modela za rmse metriku

bag_stablo_tuning_reg %>% show_best(metric = 'rmse', n = 20)

Metrike na testnom i trening skupu

tab1 <- bag_stablo_fit_reg %>% collect_predictions() %>% 
  metrike_reg(truth = body_mass_g, estimate = .pred,)
tab2 <- bag_stablo_work_reg %>%
  metrike_reg(truth = body_mass_g, estimate = .pred)
tab1 %>% inner_join(tab2, by = ".metric") %>%
  select(.metric, .estimator = .estimator.x, trening = .estimate.y, test = .estimate.x)
ggplot(bag_stablo_work_reg, aes(x = body_mass_g, y = .pred)) +
  geom_point(size = 1.5) + geom_abline(color = "#00BFC4") +
  ylab("predikcija")
Predikcije na trening skupu

Predikcije na trening skupu

test_podaci <- bag_stablo_fit_reg %>% collect_predictions()

ggplot(test_podaci, aes(x = body_mass_g, y = .pred)) +
  geom_point(size = 1.5, alpha = 0.7) +
  geom_abline(color = "#00BFC4") + ylab("predikcija")
Predikcije na testnom skupu

Predikcije na testnom skupu

Pogledajmo na primjer informacije o 5. i 11. stablu iz paketa.

model <- bag_stablo_fit_reg %>% extract_workflow() %>% extract_fit_engine()
model$model_df$model[[5]]
## parsnip model object
## 
## n= 248 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 248 155616900.0 4219.052  
##    2) species=Adelie,Chinstrap 163  30085530.0 3749.233  
##      4) sex=female 72   4628672.0 3446.875  
##        8) flipper_length_mm< 197 64   3489023.0 3408.594  
##         16) bill_depth_mm< 16.9 6    197083.3 3141.667 *
##         17) bill_depth_mm>=16.9 58   2820216.0 3436.207  
##           34) bill_length_mm>=36.1 48   1818112.0 3400.521  
##             68) island=Biscoe,Torgersen 18    819444.4 3294.444 *
##             69) island=Dream 30    674604.2 3464.167 *
##           35) bill_length_mm< 36.1 10    647562.5 3607.500 *
##        9) flipper_length_mm>=197 8    295546.9 3753.125 *
##      5) sex=male 91  13666630.0 3988.462  
##       10) bill_depth_mm< 20.6 81   9225386.0 3916.049  
##         20) flipper_length_mm< 189.5 13    725480.8 3582.692 *
##         21) flipper_length_mm>=189.5 68   6779072.0 3979.779  
##           42) species=Chinstrap 28   2009286.0 3814.286  
##             84) bill_length_mm< 50.4 13    758557.7 3621.154 *
##             85) bill_length_mm>=50.4 15    345583.3 3981.667 *
##           43) species=Adelie 40   3466109.0 4095.625  
##             86) bill_length_mm< 43.15 33   2477197.0 4055.303 *
##             87) bill_length_mm>=43.15 7    682321.4 4285.714 *
##       11) bill_depth_mm>=20.6 10    576250.0 4575.000 *
##    3) species=Gentoo 85  20557250.0 5120.000  
##      6) sex=female 42   3047396.0 4729.167  
##       12) bill_length_mm< 43.35 9    865000.0 4466.667 *
##       13) bill_length_mm>=43.35 33   1393106.0 4800.758  
##         26) bill_depth_mm< 14.45 18    567673.6 4715.278 *
##         27) bill_depth_mm>=14.45 15    536083.3 4903.333 *
##      7) sex=male 43   4827994.0 5501.744  
##       14) flipper_length_mm< 228.5 34   3370312.0 5412.500  
##         28) bill_length_mm< 46.6 6    103333.3 5116.667 *
##         29) bill_length_mm>=46.6 28   2629353.0 5475.893  
##           58) bill_length_mm>=49.35 16   1617148.0 5392.188 *
##           59) bill_length_mm< 49.35 12    750625.0 5587.500 *
##       15) flipper_length_mm>=228.5 9    163888.9 5838.889 *

Desnim klikom na sliku možete otvoriti veću sliku u novom tabu.

model$model_df$model[[5]] %>%
  extract_fit_engine() %>%
  rpart.plot(roundint = FALSE, digits = 3, tweak = 1.15)
Peto stablo iz paketa

Peto stablo iz paketa

model$model_df$model[[11]]
## parsnip model object
## 
## n= 248 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 248 167309000.00 4280.847  
##    2) species=Adelie,Chinstrap 154  28888170.00 3741.234  
##      4) sex=female 70   4959295.00 3423.929  
##        8) bill_length_mm< 38 30   1892417.00 3281.667  
##         16) bill_depth_mm< 17.65 14    223258.90 3116.071 *
##         17) bill_depth_mm>=17.65 16    949335.90 3426.562 *
##        9) bill_length_mm>=38 40   2004359.00 3530.625  
##         18) flipper_length_mm< 194.5 27   1237269.00 3457.407  
##           36) bill_depth_mm< 18.35 20    936375.00 3407.500  
##             72) bill_length_mm< 46.15 14    752321.40 3353.571 *
##             73) bill_length_mm>=46.15 6     48333.33 3533.333 *
##           37) bill_depth_mm>=18.35 7    108750.00 3600.000 *
##         19) flipper_length_mm>=194.5 13    321730.80 3682.692 *
##      5) sex=male 84  11007940.00 4005.655  
##       10) flipper_length_mm< 200.5 62   8031855.00 3928.226  
##         20) bill_length_mm>=45.85 19   1127763.00 3734.211  
##           40) bill_depth_mm< 18.85 9    297500.00 3633.333 *
##           41) bill_depth_mm>=18.85 10    656250.00 3825.000 *
##         21) bill_length_mm< 45.85 43   5872878.00 4013.953  
##           42) bill_length_mm< 39.1 12    159166.70 3741.667 *
##           43) bill_length_mm>=39.1 31   4479637.00 4119.355  
##             86) bill_length_mm>=40.15 22   2850369.00 3996.591 *
##             87) bill_length_mm< 40.15 9    487222.20 4419.444 *
##       11) flipper_length_mm>=200.5 22   1556847.00 4223.864  
##         22) bill_depth_mm< 19.85 14    405892.90 4085.714 *
##         23) bill_depth_mm>=19.85 8    416171.90 4465.625 *
##    3) species=Gentoo 94  20114150.00 5164.894  
##      6) sex=female 45   2741111.00 4794.444  
##       12) flipper_length_mm< 211.5 15    869333.30 4576.667 *
##       13) flipper_length_mm>=211.5 30    804666.70 4903.333  
##         26) bill_length_mm>=45.8 12    272291.70 4829.167 *
##         27) bill_length_mm< 45.8 18    422361.10 4952.778 *
##      7) sex=male 49   5526224.00 5505.102  
##       14) bill_length_mm< 48.3 16    444375.00 5256.250 *
##       15) bill_length_mm>=48.3 33   3610606.00 5625.758  
##         30) bill_length_mm>=49.45 22   1400909.00 5513.636  
##           60) bill_length_mm>=50.95 6     77083.33 5358.333 *
##           61) bill_length_mm< 50.95 16   1124844.00 5571.875 *
##         31) bill_length_mm< 49.45 11   1380000.00 5850.000 *

Desnim klikom na sliku možete otvoriti veću sliku u novom tabu.

model$model_df$model[[11]] %>%
  extract_fit_engine() %>%
  rpart.plot(roundint = FALSE, digits = 3, tweak = 1.2)
Jedanaesto stablo iz paketa

Jedanaesto stablo iz paketa

Funkcija var_imp iz baguette biblioteke računa važnost prediktora za svako stablo u paketu kako je ranije opisano kod stabla odlučivanja i vraća aritmetičku sredinu tih vrijednosti.

vaznost <- var_imp(model)
ggplot(vaznost, aes(x = fct_reorder(term, value), y = value)) + 
  geom_col() + coord_flip() + xlab("") + ylab("")
Važnost varijabli

Važnost varijabli

Slučajna šuma

Slučajna šuma funkcionira na slični način kao i pakiranje stabala. Jedina je razlika da se prilikom grananja pojedinog stabla ne koriste svi prediktori \(X_1,X_2,\dotsc,X_p\), nego se svaki put prilikom novog grananja na slučajni način bira \(m\) prediktora koji će se razmatrati u trenutnom grananju. Najčešće uzimamo \(m\approx\sqrt{p}\) što je dobro u većini slučajeva. Zapravo je \(m\) hiperparametar za slučajnu šumu čiju optimalnu vrijednost možemo također odabrati preko cross-validation.

Ovdje ćemo preko tidymodels biblioteke koristiti ranger biblioteku u kojoj je implementirana slučajna šuma. U toj biblioteci postoje tri hiperparametra za slučajnu šumu:

  • trees - broj stabala u slučajnoj šumi
  • min_n - minimalni broj podataka potrebnih u vrhu kako bi se on dalje dijelio
  • mtry - broj prediktora koji će se na slučajni način birati kod svakog grananja (ranije opisani parametar \(m\))

Napomena. Ako stavimo da je mtry jednak broju prediktora, tada zapravo dobivamo pakirana stabla. Stoga na taj način možemo pakirati stabla i pomoću ranger biblioteke.

ranger biblioteka nudi dva načina računanja važnosti varijabli.

  1. način je već ranije opisan kod pakiranja stabala i stabla odlučivanja koristeći gini indeks kod klasifikacije i sumu kvadrata reziduala kod regresije.
  2. način je metoda permutacije koja funkcionira na sljedeći način. Za svako stablo u slučajnoj šumi izračuna se greška predikcije na njegovim OOB podacima (postotak pogrešno klasificiranih podataka kod klasifikacije, suma kvadrata reziduala kod regresije). Nakon toga se ponovi postupak, ali tako da se prije toga na slučajni način permutiraju vrijednosti pojedine varijable. Na kraju se uzmu prosječne vrijednosti tih razlika od svih stabala. Dobivene vrijednosti su važnosti pojedinih varijabli. Ukoliko stavimo scale.permutation.importance = TRUE, tada se dobivene vrijednosti još normaliziraju s pripadnim standardnim devijacijama pojedinih razlika. Po defaultu je scale.permutation.importance = FALSE.

Lokalna važnost varijabli. Lokalna važnost nekog prediktora na \(i\)-tom podatku računa se slično kao i gore opisana globalna važnost. Jedina razlika je što se promatraju samo stabla iz slučajne šume koja nisu trenirana na \(i\)-tom podatku.

Primjer klasifikacije

Pogledajmo sada slučajnu šumu na penguins podacima pri čemu klasificiramo vrstu pingvina pomoću svih preostalih varijabli koje su prediktori. Biranje optimalnih hiperparametara i treniranje modela je napravljeno u zasebnoj R datoteci i sve potrebne informacije su spremljene u rds datoteke. U spomenutoj R datoteci se nalazi sljedeći kod.

library(tidyverse)
library(tidymodels)
library(modeldata)
library(doParallel)

data("penguins")
podaci <- penguins %>% na.omit()

podaci_split <- initial_split(podaci, prop = 0.75, strata = species)

# definiranje modela
rf_model <- rand_forest() %>%
  set_args(mtry = tune(), trees = tune(), min_n = tune()) %>%
  set_engine("ranger", importance = "impurity") %>%
  set_mode("classification")

# recept za transformiranje podataka
rf_recipe <- recipe(species ~ ., data = training(podaci_split))

# workflow
rf_work <- workflow() %>% 
  add_model(rf_model) %>% 
  add_recipe(rf_recipe)

# cross validation folds
rf_folds <- vfold_cv(training(podaci_split), v = 10, strata = species)

# metrike
rf_metrike <- metric_set(roc_auc, sens, precision, spec, accuracy, f_meas, kap)

rf_grid <- grid_latin_hypercube(mtry(range = c(2,3)), trees(range = c(400,1500)), 
                                min_n(), size = 100)

cl <- makePSOCKcluster(10)
registerDoParallel(cl)
rf_tuning <- rf_work %>% 
  tune_grid(resamples = rf_folds, grid = rf_grid, metrics = rf_metrike)
stopCluster(cl)

saveRDS(rf_tuning, file = "modeli/rf_tuning_klas.rds")

best_rf_model <- rf_tuning %>% select_best(metric = 'roc_auc')
final_rf_work <- rf_work %>% finalize_workflow(best_rf_model)
rf_fit <- final_rf_work %>% last_fit(split = podaci_split)

trenirani_model_work <- rf_fit %>% extract_workflow() %>%
  augment(training(podaci_split))
saveRDS(trenirani_model_work, "modeli/rf_work_klas.rds")
saveRDS(rf_fit, "modeli/rf_fit_klas.rds")

Učitavanje spremljenih podataka iz rds datoteka.

rf_fit_klas <- readRDS("modeli/rf_fit_klas.rds")
rf_tuning_klas <- readRDS("modeli/rf_tuning_klas.rds")
rf_work_klas <- readRDS("modeli/rf_work_klas.rds")
rf_tuning_klas %>% autoplot()
Vrijednosti hiperparametara i metrika

Vrijednosti hiperparametara i metrika

20 najboljih modela za roc_auc metriku

rf_tuning_klas %>% show_best(metric = 'roc_auc', n = 20)

Metrike na testnom i trening skupu

tab1 <- rf_fit_klas %>% collect_predictions() %>% 
  metrike_klas(truth = species, estimate = .pred_class, .pred_Adelie:.pred_Gentoo)
tab2 <- rf_work_klas %>%
  metrike_klas(truth = species, estimate = .pred_class, .pred_Adelie:.pred_Gentoo)
tab1 %>% inner_join(tab2, by = ".metric") %>%
  select(.metric, .estimator = .estimator.x, trening = .estimate.y, test = .estimate.x)
rf_work_klas %>%  
  conf_mat(truth = species, estimate = .pred_class) %>% autoplot("heatmap") +
  scale_fill_gradient(low = "#87DEE7", high = "#FFFFCC")
Confusion matrix na trening skupu

Confusion matrix na trening skupu

rf_fit_klas %>% collect_predictions() %>%
  conf_mat(truth = species, estimate = .pred_class) %>% autoplot("heatmap") +
  scale_fill_gradient(low = "#87DEE7", high = "#FFFFCC")
Confusion matrix na testnom skupu

Confusion matrix na testnom skupu

Pogledajmo ROC i gain krivulje

rf_fit_klas %>% collect_predictions() %>% 
  roc_curve(species, .pred_Adelie:.pred_Gentoo) %>% autoplot()
ROC krivulje

ROC krivulje

rf_fit_klas %>% collect_predictions() %>% 
  gain_curve(species, .pred_Adelie:.pred_Gentoo) %>% autoplot()
gain krivulje

gain krivulje

Pomoću treeInfo funkcije iz ranger biblioteke možemo dobiti informacije o pojedinom stablu u slučajnoj šumi. Pogledajmo na primjer informacije o 20. i 42. stablu iz slučajne šume.

treeInfo(rf_fit_klas %>% extract_fit_engine(), 20)
treeInfo(rf_fit_klas %>% extract_fit_engine(), 42)

Važnost varijabli temeljena na gini indeksu

rf_fit_klas %>% extract_fit_parsnip() %>% vip()
Važnost varijabli (gini indeks)

Važnost varijabli (gini indeks)

Ako želimo važnost varijabli preko permutacijske metode, tada moramo staviti importance = "permutation". Ukoliko želimo da se izračunaju i lokalne važnosti varijabli, moramo staviti local.importance = TRUE. Vrijednosti hiperparametara su samo preuzete kako smo ih dobili ranije, nisu ispočetka ponovo tražene nove optimalne vrijednosti. Treniranje modela je napravljeno u zasebnoj R datoteci i sve potrebne informacije su spremljene u rds datoteke. U spomenutoj R datoteci se nalazi sljedeći kod.

library(tidyverse)
library(tidymodels)
library(modeldata)

data("penguins")
podaci <- penguins %>% na.omit()

podaci_split <- initial_split(podaci, prop = 0.75, strata = species)

# definiranje modela
rf_model <- rand_forest() %>%
  set_args(mtry = 2, trees = 453, min_n = 17) %>%
  set_engine("ranger", importance = "permutation", local.importance = TRUE) %>%
  set_mode("classification")

# recept za transformiranje podataka
rf_recipe <- recipe(species ~ ., data = training(podaci_split))

# workflow
rf_work <- workflow() %>% 
  add_model(rf_model) %>% 
  add_recipe(rf_recipe)

# metrike
rf_metrike <- metric_set(roc_auc, sens, precision, spec, accuracy, f_meas, kap)

rf_fit <- rf_work %>% last_fit(split = podaci_split)

trenirani_model_work <- rf_fit %>% extract_workflow() %>%
  augment(training(podaci_split))
saveRDS(trenirani_model_work, "modeli/rf_perm_work_klas.rds")
saveRDS(rf_fit, "modeli/rf_perm_fit_klas.rds")

Učitavanje spremljenih podataka iz rds datoteka.

rf_perm_fit_klas <- readRDS("modeli/rf_perm_fit_klas.rds")
rf_perm_work_klas <- readRDS("modeli/rf_perm_work_klas.rds")

Važnost varijabli temeljena na permutacijskoj metodi

rf_perm_fit_klas %>% extract_fit_parsnip() %>% vip()
Važnost varijabli (permutacijska metoda)

Važnost varijabli (permutacijska metoda)

Lokalne važnosti varijabli

model <- rf_perm_fit_klas %>% extract_fit_engine()
as_tibble(model$variable.importance.local)

Različiti grafički prikazi lokalnih važnosti varijabli

tablica <- as_tibble(model$variable.importance.local) %>% 
  mutate(podatak = row_number()) %>%
  pivot_longer(-podatak, names_to = "aktivnost", values_to = "local importance")
ggplot(tablica, aes(x = podatak, y =`local importance`)) + 
  geom_point(alpha = 0.3, color = "black") + facet_wrap(vars(aktivnost))

ggplot(tablica, aes(x = podatak, y =`local importance`)) + 
  geom_line(linewidth = 0.2) + facet_wrap(vars(aktivnost))

ggplot(tablica, aes(x = aktivnost, y = `local importance`)) + geom_boxplot()

ggplot(tablica, aes(x = `local importance`)) + geom_histogram() +
  facet_wrap(vars(aktivnost), scales = "free_y")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

Primjer regresije

Pogledajmo sada slučajnu šumu na penguins podacima koja procjenjuje masu pingvina u gramima pomoću svih preostalih varijabli koje su prediktori. Biranje optimalnih hiperparametara i treniranje modela je napravljeno u zasebnoj R datoteci i sve potrebne informacije su spremljene u rds datoteke. U spomenutoj R datoteci se nalazi sljedeći kod.

library(tidymodels)
library(modeldata)
library(baguette)
library(doParallel)

data("penguins")
podaci <- penguins %>% na.omit()

podaci_split <- initial_split(podaci, prop = 0.75, strata = body_mass_g)

# definiranje modela
rf_model <- rand_forest() %>%
  set_args(mtry = tune(), trees = tune(), min_n = tune()) %>%
  set_engine("ranger", importance = "impurity") %>%
  set_mode("regression")

# recept za transformiranje podataka
rf_recipe <- recipe(body_mass_g ~ ., data = training(podaci_split))

# workflow
rf_work <- workflow() %>% 
  add_model(rf_model) %>% 
  add_recipe(rf_recipe)

# cross validation folds
rf_folds <- vfold_cv(training(podaci_split), v = 10, strata = body_mass_g)

# metrike
rf_metrike <- metric_set(mae, mape, mpe, rmse, rsq)

rf_grid <- grid_latin_hypercube(mtry(range = c(2,3)), trees(range = c(400,1500)), 
                                min_n(), size = 100)

cl <- makePSOCKcluster(10)
registerDoParallel(cl)
rf_tuning <- rf_work %>% 
  tune_grid(resamples = rf_folds, grid = rf_grid, metrics = rf_metrike)
stopCluster(cl)

saveRDS(rf_tuning, file = "modeli/rf_tuning_reg.rds")

best_rf_model <- rf_tuning %>% select_best(metric = 'rmse')
final_rf_work <- rf_work %>% finalize_workflow(best_rf_model)
rf_fit <- final_rf_work %>% last_fit(split = podaci_split)

trenirani_model_work <- rf_fit %>% extract_workflow() %>%
  augment(training(podaci_split))
saveRDS(trenirani_model_work, "modeli/rf_work_reg.rds")
saveRDS(rf_fit, "modeli/rf_fit_reg.rds")

Učitavanje spremljenih podataka iz rds datoteka.

rf_fit_reg <- readRDS("modeli/rf_fit_reg.rds")
rf_tuning_reg <- readRDS("modeli/rf_tuning_reg.rds")
rf_work_reg <- readRDS("modeli/rf_work_reg.rds")
rf_tuning_reg %>% autoplot()
Vrijednosti hiperparametara i metrika

Vrijednosti hiperparametara i metrika

20 najboljih modela za rmse metriku

rf_tuning_reg %>% show_best(metric = 'rmse', n = 20)

Metrike na testnom i trening skupu

tab1 <- rf_fit_reg %>% collect_predictions() %>% 
  metrike_reg(truth = body_mass_g, estimate = .pred,)
tab2 <- rf_work_reg %>%
  metrike_reg(truth = body_mass_g, estimate = .pred)
tab1 %>% inner_join(tab2, by = ".metric") %>%
  select(.metric, .estimator = .estimator.x, trening = .estimate.y, test = .estimate.x)
ggplot(rf_work_reg, aes(x = body_mass_g, y = .pred)) +
  geom_point(size = 1.5) + geom_abline(color = "#00BFC4") +
  ylab("predikcija")
Predikcije na trening skupu

Predikcije na trening skupu

test_podaci <- rf_fit_reg %>% collect_predictions()

ggplot(test_podaci, aes(x = body_mass_g, y = .pred)) +
  geom_point(size = 1.5, alpha = 0.7) +
  geom_abline(color = "#00BFC4") + ylab("predikcija")
Predikcije na testnom skupu

Predikcije na testnom skupu

Pogledajmo na primjer informacije o 56. i 234. stablu iz slučajne šume.

treeInfo(rf_fit_reg %>% extract_fit_engine(), 56)
treeInfo(rf_fit_reg %>% extract_fit_engine(), 234)

Važnost varijabli temeljena na sumi kvadrata reziduala

rf_fit_reg %>% extract_fit_parsnip() %>% vip()
Važnost varijabli (suma kvadrata reziduala)

Važnost varijabli (suma kvadrata reziduala)

Ako želimo važnost varijabli preko permutacijske metode, tada moramo staviti importance = "permutation". Ukoliko želimo da se izračunaju i lokalne važnosti varijabli, moramo staviti local.importance = TRUE. Vrijednosti hiperparametara su samo preuzete kako smo ih dobili ranije, nisu ispočetka ponovo tražene nove optimalne vrijednosti. Treniranje modela je napravljeno u zasebnoj R datoteci i sve potrebne informacije su spremljene u rds datoteke. U spomenutoj R datoteci se nalazi sljedeći kod.

library(tidyverse)
library(tidymodels)
library(modeldata)

data("penguins")
podaci <- penguins %>% na.omit()

podaci_split <- initial_split(podaci, prop = 0.75, strata = body_mass_g)

# definiranje modela
rf_model <- rand_forest() %>%
  set_args(mtry = 2, trees = 407, min_n = 6) %>%
  set_engine("ranger", importance = "permutation", local.importance = TRUE) %>%
  set_mode("regression")

# recept za transformiranje podataka
rf_recipe <- recipe(body_mass_g ~ ., data = training(podaci_split))

# workflow
rf_work <- workflow() %>% 
  add_model(rf_model) %>% 
  add_recipe(rf_recipe)

# metrike
rf_metrike <- metric_set(mae, mape, mpe, rmse, rsq)

rf_fit <- rf_work %>% last_fit(split = podaci_split)

trenirani_model_work <- rf_fit %>% extract_workflow() %>%
  augment(training(podaci_split))
saveRDS(trenirani_model_work, "modeli/rf_perm_work_reg.rds")
saveRDS(rf_fit, "modeli/rf_perm_fit_reg.rds")

Učitavanje spremljenih podataka iz rds datoteka.

rf_perm_fit_reg <- readRDS("modeli/rf_perm_fit_reg.rds")
rf_perm_work_reg <- readRDS("modeli/rf_perm_work_reg.rds")

Važnost varijabli temeljena na permutacijskoj metodi

rf_perm_fit_reg %>% extract_fit_parsnip() %>% vip()
Važnost varijabli (permutacijska metoda)

Važnost varijabli (permutacijska metoda)

Lokalne važnosti varijabli

model <- rf_perm_fit_reg %>% extract_fit_engine()
as_tibble(model$variable.importance.local)

Različiti grafički prikazi lokalnih važnosti varijabli

tablica <- as_tibble(model$variable.importance.local) %>% 
  mutate(podatak = row_number()) %>%
  pivot_longer(-podatak, names_to = "aktivnost", values_to = "local importance")
ggplot(tablica, aes(x = podatak, y =`local importance`)) + 
  geom_point(alpha = 0.3, color = "black") + facet_wrap(vars(aktivnost))

ggplot(tablica, aes(x = podatak, y =`local importance`)) + 
  geom_line(linewidth = 0.2) + facet_wrap(vars(aktivnost))

ggplot(tablica, aes(x = aktivnost, y = `local importance`)) + geom_boxplot()

ggplot(tablica, aes(x = `local importance`)) + geom_histogram() +
  facet_wrap(vars(aktivnost), scales = "free_y")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

Boruta

Boruta algoritam služi za određivanje važnosti varijabli, a baziran je na slučajnoj šumi. Algoritam funkcionira na sljedeći način:

  1. Svaki prediktor u početnim podacima permutiramo na slučajni način i dodamo kao dodatnu varijablu u početne podatke.
  2. Na takvim podacima (zajedno s dodanim permutacijskim varijablama) treniramo slučajnu šumu.
  3. Među permutacijskim varijablama pronađemo onu koja ima najveću važnost.
  4. Svim prediktorima koji imaju veću važnost od najvažnije permutacijske varijable povećamo uspjeh za 1. Ukoliko neki od prediktora ima puno manju važnost od najvažnije permutacijske varijable, možemo ga odmah eliminirati i smatrati nebitnim.
  5. Makni iz podataka permutacijske varijable i ponovi korake 1. do 4. Po potrebi se napravi barem 20 ili više ponavljanja.
  6. Pomoću binomne distribucije donesi zaključak koji su prediktori zaista važni na temelju dobivenih uspjeha prediktora i broja iteracija.

Za više detalja možete pogledati ovaj članak.

Primijenimo sada Boruta algoritam na penguins podatke pri čemu želimo vidjeti koje su varijable bitne za predikciju vrste pingvina.

boruta_klas <- Boruta(species ~ ., data = podaci, doTrace = 0)

Već nakon 10 iteracija ispada da su svi prediktori za predikciju vrste pingvina važni, a najvažniji je bill_length_mm (duljina kljuna u milimetrima), a najmanje važan je spol pingvina.

par(mar = c(8, 2, 2, 2))
plot(boruta_klas, cex.axis = 0.9, las = 2, xlab = "", main = "Vrsta pingvina",
     colCode = c("green", "orange", "#f6546a", "#2acaea"))

Grafički i tablični prikazi kako su se kroz iteracije mijenjale važnosti varijabli.

plotImpHistory(boruta_klas, colCode = c("green", "orange", "#f6546a", "#2acaea"))

as_tibble(boruta_klas$ImpHistory)

Primijenimo još Boruta algoritam na penguins podatke pri čemu želimo vidjeti koje su varijable bitne za predikciju mase pingvina u gramima.

boruta_reg <- Boruta(body_mass_g ~ ., data = podaci, doTrace = 0)

Već nakon 10 iteracija ispada da su svi prediktori za predikciju mase pingvina važni, uglavnom su svi podjednake važnosti (spol pingvina ispada malo važniji od ostalih).

par(mar = c(8, 2, 2, 2))
plot(boruta_reg, cex.axis = 0.9, las = 2, xlab = "", main = "Masa pingvina",
     colCode = c("green", "orange", "#f6546a", "#2acaea"))

Grafički i tablični prikazi kako su se kroz iteracije mijenjale važnosti varijabli.

plotImpHistory(boruta_reg, colCode = c("green", "orange", "#f6546a", "#2acaea"))

as_tibble(boruta_reg$ImpHistory)

Pojačana stabla

Pojačana stabla, boosted trees, funkcioniraju slično kao i pakirana stabla, s time da stabla rastu sekvencijalno. Svako novo stablo se gradi tako da koristi informacije od prethodno izgrađenog stabla. U slučaju regresije općenita ideja je sljedeća:

  1. Stavi \(\hat{f}(x)=0\) i \(r_i=y_i\) za svaki \(i\) na skupu za treniranje.
  2. Za \(b=1,2,\dotsc,B\) ponavljaj:
    1. Izgradi stablo \(\hat{f}^b\) s \(d\) grananja (\(d+1\) listova) na trening skupu \((X,r)\).
    2. Ažuriraj \(\,\hat{f}\): \(\ \hat{f}(x)\leftarrow\hat{f}(x)+\lambda\hat{f}^b(x)\).
    3. Ažuriraj reziduale: \(r_i\leftarrow r_i-\lambda\hat{f}^b(x_i)\).
  3. Vrati pojačani model \[\hat{f}(x)=\sum_{b=1}^B{\lambda\hat{f}^b(x)}.\]

U slučaju klasifikacije također je slična ideja. Ukratko, \(\hat{f}^b\) se bira tako da se optimizira funkcija cilja. Funkcija cilja je zapravo suma funkcije koja mjeri efikasnost modela na trening skupu i regularizacije.

  • U slučaju regresije najčešća mjera efikasnosti modela je \(l(y_i,\hat{y}_i)=\frac{1}{2}(y_i-\hat{y}_i)^2\) pri čemu je \[g_i=\frac{\partial l}{\partial\hat{y}_i}=-\big(y_i-\hat{y}_i\big),\quad h_i=\frac{\partial^2l}{\partial\hat{y}_i^2}=1.\]
  • U slučaju binarne klasifikacije najčešća mjera efikasnosti modela je
    \[l(y_i,\hat{y}_i)=y_i\ln{\big(1+e^{-\hat{y}_i}\big)}+(1-y_i)\ln{\big(1+e^{\hat{y}_i}\big)}\] pri čemu se laganim računom dobiva \[g_i=\frac{\partial l}{\partial\hat{y}_i}=-\big(y_i-p_i\big),\quad h_i=\frac{\partial^2l}{\partial\hat{y}_i^2}=p_i(1-p_i),\quad p_i=\frac{e^{\hat{y}_i}}{1+e^{\hat{y}_i}}\]

Za više detalja možete pogledati xgboost ili jedan konkretni primjer.

Za dani trenutni model, gradimo novo stablo na rezidualima iz trenutnog modela. Novo stablo dodajemo u \(\hat{f}\) kako bismo ažurirali reziduale. Svako od tih stabla može biti relativno malo sa samo nekoliko listova, što određuje parametar \(d\) u algoritmu. Dodavanjem tih malih stabala polako poboljšavamo \(\hat{f}\) na dijelovima u kojima mu efikasnost nije najbolja. S parametrom \(\lambda\) možemo dodatno usporiti samo napredovanje kako bismo istražili više različito oblikovanih stabala za poboljšanje reziduala.

Ovdje ćemo preko tidymodels biblioteke koristiti xgboost biblioteku u kojoj su implementirana pojačana stabla. Neki od hiperparametara u toj biblioteci su sljedeći:

  • trees - broj stabala koja će se izgraditi
  • min_n - minimalni broj podataka potrebnih u vrhu kako bi se on dalje dijelio
  • mtry - broj prediktora koji će se na slučajni način birati kod svakog grananja
  • tree_depth - maksimalna dubina stabla, tj. broj grananja
  • learn_rate - brzina učenja (ranije spomenuti parametar \(\lambda\))

Prevelika vrijednost hiperparametra trees može ovdje dovesti do overfittinga. Tipične vrijednosti hiperparametra learn_rate su \(0.01\) ili \(0.001\), ali pravi izbor može ovisiti o samom problemu. Međutim, vrlo male vrijednosti od learn_rate zahtijevaju vrlo velike vrijednosti od trees ako želimo postići dobru efikasnost.

Primjer klasifikacije

Pogledajmo sada pojačana stabla na penguins podacima pri čemu klasificiramo vrstu pingvina pomoću svih preostalih varijabli koje su prediktori. Biranje optimalnih hiperparametara i treniranje modela je napravljeno u zasebnoj R datoteci i sve potrebne informacije su spremljene u rds datoteke. Biblioteka xgboost radi samo s numeričkim prediktorima pa su zbog toga preko step_dummy(all_nominal_predictors()) svi nominalni prediktori pretvoreni u numeričke varijable. U spomenutoj R datoteci se nalazi sljedeći kod.

library(tidyverse)
library(tidymodels)
library(modeldata)
library(doParallel)

data("penguins")
podaci <- penguins %>% na.omit()

podaci_split <- initial_split(podaci, prop = 0.75, strata = species)

# definiranje modela
boost_model <- boost_tree() %>%
  set_args(mtry = 2, trees = tune(), tree_depth = tune(), learn_rate = tune()) %>%
  set_engine("xgboost") %>%
  set_mode("classification")

# recept za transformiranje podataka
boost_recipe <- recipe(species ~ ., data = training(podaci_split)) %>%
  step_dummy(all_nominal_predictors())

# workflow
boost_work <- workflow() %>% 
  add_model(boost_model) %>% 
  add_recipe(boost_recipe)

# cross validation folds
boost_folds <- vfold_cv(training(podaci_split), v = 10, strata = species)

# metrike
boost_metrike <- metric_set(roc_auc, sens, precision, spec, accuracy, f_meas, kap)

boost_grid <- grid_latin_hypercube(trees(range(1,100)), tree_depth(), learn_rate(), size = 50)

cl <- makePSOCKcluster(10)
registerDoParallel(cl)
boost_tuning <- boost_work %>% 
  tune_grid(resamples = boost_folds, grid = boost_grid, metrics = boost_metrike)
stopCluster(cl)

saveRDS(boost_tuning, file = "modeli/boost_tuning_klas.rds")

best_boost_model <- boost_tuning %>% select_best(metric = 'roc_auc')
final_boost_work <- boost_work %>% finalize_workflow(best_boost_model)
boost_fit <- final_boost_work %>% last_fit(split = podaci_split)

trenirani_model_work <- boost_fit %>% extract_workflow() %>%
  augment(training(podaci_split))
saveRDS(trenirani_model_work, "modeli/boost_work_klas.rds")
saveRDS(boost_fit, "modeli/boost_fit_klas.rds")

Učitavanje spremljenih podataka iz rds datoteka.

boost_fit_klas <- readRDS("modeli/boost_fit_klas.rds")
boost_tuning_klas <- readRDS("modeli/boost_tuning_klas.rds")
boost_work_klas <- readRDS("modeli/boost_work_klas.rds")
boost_tuning_klas %>% autoplot()
Vrijednosti hiperparametara i metrika

Vrijednosti hiperparametara i metrika

20 najboljih modela za roc_auc metriku

boost_tuning_klas %>% show_best(metric = 'roc_auc', n = 20)

Metrike na testnom i trening skupu

tab1 <- boost_fit_klas %>% collect_predictions() %>% 
  metrike_klas(truth = species, estimate = .pred_class, .pred_Adelie:.pred_Gentoo)
tab2 <- boost_work_klas %>%
  metrike_klas(truth = species, estimate = .pred_class, .pred_Adelie:.pred_Gentoo)
tab1 %>% inner_join(tab2, by = ".metric") %>%
  select(.metric, .estimator = .estimator.x, trening = .estimate.y, test = .estimate.x)
boost_work_klas %>%  
  conf_mat(truth = species, estimate = .pred_class) %>% autoplot("heatmap") +
  scale_fill_gradient(low = "#87DEE7", high = "#FFFFCC")
Confusion matrix na trening skupu

Confusion matrix na trening skupu

boost_fit_klas %>% collect_predictions() %>%
  conf_mat(truth = species, estimate = .pred_class) %>% autoplot("heatmap") +
  scale_fill_gradient(low = "#87DEE7", high = "#FFFFCC")
Confusion matrix na testnom skupu

Confusion matrix na testnom skupu

Pogledajmo ROC i gain krivulje

boost_fit_klas %>% collect_predictions() %>% 
  roc_curve(species, .pred_Adelie:.pred_Gentoo) %>% autoplot()
ROC krivulje

ROC krivulje

boost_fit_klas %>% collect_predictions() %>% 
  gain_curve(species, .pred_Adelie:.pred_Gentoo) %>% autoplot()
gain krivulje

gain krivulje

Nacrtajmo na primjer 9. i 21. stablo u dobivenom modelu.

xgb.plot.tree(model = boost_fit_klas %>% extract_fit_engine(), trees = 9,
              plot_width = 800, plot_height = 500)
xgb.plot.tree(model = boost_fit_klas %>% extract_fit_engine(), trees = 21,
              plot_width = 1000, plot_height = 400)

Možemo vizualizirati distribucije vezane uz dubinu listova stabala.

xgb.ggplot.deepness(boost_fit_klas %>% extract_fit_engine())

Možemo dobiti kompleksnost samog modela tako da se sva stabla opišu preko jednog stablo kako bi se poboljšala interpretacija modela.

gr <- xgb.plot.multi.trees(boost_fit_klas %>% extract_fit_engine(), render = FALSE)
export_graph(gr, 'tree.pdf', width=1000, height=1500)

Možete otvoriti veću PDF sliku.

xgb.plot.multi.trees(boost_fit_klas %>% extract_fit_engine(), 
                     plot_width = 1000, plot_height = 1300)

xgboost biblioteka daje tri različite vrste važnosti varijabli.

  • Gain. Mjeri važnost varijable na temelju toga koliko se smanji neizvjesnost u svakom pojedinom vrhu ako se grana po toj varijabli gledano u svim stablima. Daje relativni broj u odnosu na sve varijable.
  • Cover. Mjeri važnost varijable na temelju pokrivenosti vrha. Pokrivenost nekog vrha jednaka je sumi svih Hesijana \(h_i\) podataka koji pripadaju tom vrhu. U slučaju regresije svi su Hesijani jednaki 1 pa je pokrivenost vrha jednaka broju podataka u tom vrhu. Za svaku varijablu se gledaju pokrivenosti onih vrhova u kojima se granalo po promatranoj varijabli. Daje relativni broj u odnosu na sve varijable.
  • Frequency. Mjeri koliko se puta pojedina varijabla javlja u svakom od stabala. Daje relativni broj u odnosu na sve varijable.
(vaznost <- xgb.importance(model = boost_fit_klas %>% extract_fit_engine()))
xgb.ggplot.importance(vaznost) + ggtitle("Importance (gain)")

xgb.ggplot.importance(vaznost, measure = "Cover") + 
  ggtitle("Importance (cover)")

xgb.ggplot.importance(vaznost, measure = "Frequency") + 
  ggtitle("Importance (frequency)")

Primjer regresije

Pogledajmo sada pojačana stabla na penguins podacima koja procjenjuju masu pingvina u gramima pomoću svih preostalih varijabli koje su prediktori. Biranje optimalnih hiperparametara i treniranje modela je napravljeno u zasebnoj R datoteci i sve potrebne informacije su spremljene u rds datoteke. U spomenutoj R datoteci se nalazi sljedeći kod.

library(tidyverse)
library(tidymodels)
library(modeldata)
library(doParallel)

data("penguins")
podaci <- penguins %>% na.omit()

podaci_split <- initial_split(podaci, prop = 0.75, strata = body_mass_g)

# definiranje modela
boost_model <- boost_tree() %>%
  set_args(mtry = 2, trees = tune(), tree_depth = tune(), learn_rate = tune()) %>%
  set_engine("xgboost") %>%
  set_mode("regression")

# recept za transformiranje podataka
boost_recipe <- recipe(body_mass_g ~ ., data = training(podaci_split)) %>%
  step_dummy(all_nominal_predictors())

# workflow
boost_work <- workflow() %>% 
  add_model(boost_model) %>% 
  add_recipe(boost_recipe)

# cross validation folds
boost_folds <- vfold_cv(training(podaci_split), v = 10, strata = body_mass_g)

# metrike
boost_metrike <- metric_set(mae, mape, mpe, rmse, rsq)

boost_grid <- grid_latin_hypercube(trees(range(1,100)), tree_depth(), learn_rate(), size = 100)

cl <- makePSOCKcluster(10)
registerDoParallel(cl)
boost_tuning <- boost_work %>% 
  tune_grid(resamples = boost_folds, grid = boost_grid, metrics = boost_metrike)
stopCluster(cl)

saveRDS(boost_tuning, file = "modeli/boost_tuning_reg.rds")

best_boost_model <- boost_tuning %>% select_best(metric = 'rmse')
final_boost_work <- boost_work %>% finalize_workflow(best_boost_model)
boost_fit <- final_boost_work %>% last_fit(split = podaci_split)

trenirani_model_work <- boost_fit %>% extract_workflow() %>%
  augment(training(podaci_split))
saveRDS(trenirani_model_work, "modeli/boost_work_reg.rds")
saveRDS(boost_fit, "modeli/boost_fit_reg.rds")

Učitavanje spremljenih podataka iz rds datoteka.

boost_fit_reg <- readRDS("modeli/boost_fit_reg.rds")
boost_tuning_reg <- readRDS("modeli/boost_tuning_reg.rds")
boost_work_reg <- readRDS("modeli/boost_work_reg.rds")
boost_tuning_reg %>% autoplot()
Vrijednosti hiperparametara i metrika

Vrijednosti hiperparametara i metrika

20 najboljih modela za rmse metriku

boost_tuning_reg %>% show_best(metric = 'rmse', n = 20)

Metrike na testnom i trening skupu

tab1 <- boost_fit_reg %>% collect_predictions() %>% 
  metrike_reg(truth = body_mass_g, estimate = .pred,)
tab2 <- boost_work_reg %>%
  metrike_reg(truth = body_mass_g, estimate = .pred)
tab1 %>% inner_join(tab2, by = ".metric") %>%
  select(.metric, .estimator = .estimator.x, trening = .estimate.y, test = .estimate.x)
ggplot(boost_work_reg, aes(x = body_mass_g, y = .pred)) +
  geom_point(size = 1.5) + geom_abline(color = "#00BFC4") +
  ylab("predikcija")
Predikcije na trening skupu

Predikcije na trening skupu

test_podaci <- boost_fit_reg %>% collect_predictions()

ggplot(test_podaci, aes(x = body_mass_g, y = .pred)) +
  geom_point(size = 1.5, alpha = 0.7) +
  geom_abline(color = "#00BFC4") + ylab("predikcija")
Predikcije na testnom skupu

Predikcije na testnom skupu

Nacrtajmo na primjer 5. i 33. stablo u dobivenom modelu.

xgb.plot.tree(model = boost_fit_reg %>% extract_fit_engine(), trees = 5,
              plot_width = 700, plot_height = 450)
gr1 <- xgb.plot.tree(model = boost_fit_reg %>% extract_fit_engine(), 
                     trees = 33, render = FALSE)
export_graph(gr1, 'tree33.pdf', width=1000, height=900)

Možete otvoriti veću PDF sliku.

xgb.plot.tree(model = boost_fit_reg %>% extract_fit_engine(), trees = 33,
              plot_width = 1000, plot_height = 900)

Možemo vizualizirati distribucije vezane uz dubinu listova stabala.

xgb.ggplot.deepness(boost_fit_reg %>% extract_fit_engine())

Možemo dobiti kompleksnost samog modela tako da se sva stabla opišu preko jednog stablo kako bi se poboljšala interpretacija modela.

gr3 <- xgb.plot.multi.trees(boost_fit_reg %>% extract_fit_engine(), render = FALSE)
export_graph(gr3, 'tree_reg.pdf', width=900, height=1800)

S obzirom da je stablo dosta veliko, stavljen je samo link na PDF sliku.

Važnost varijabli

(vaznost_reg <- xgb.importance(model = boost_fit_reg %>% extract_fit_engine()))
xgb.ggplot.importance(vaznost_reg) + ggtitle("Importance (gain)")

xgb.ggplot.importance(vaznost_reg, measure = "Cover") + 
  ggtitle("Importance (cover)")

xgb.ggplot.importance(vaznost_reg, measure = "Frequency") + 
  ggtitle("Importance (frequency)")