Koristit ćemo tidyverse biblioteku koja sadrži
R pakete za obradu i prikazivanje podataka. Detaljnije
informacije o svim paketima možete pronaći na tidyverse.
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr 1.1.4 ✔ readr 2.1.5
## ✔ forcats 1.0.0 ✔ stringr 1.5.1
## ✔ ggplot2 3.5.1 ✔ tibble 3.2.1
## ✔ lubridate 1.9.3 ✔ tidyr 1.3.1
## ✔ purrr 1.0.2
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::group_rows() masks kableExtra::group_rows()
## ✖ dplyr::lag() masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
Za učenje pod nadzorom koristit ćemo tidymodels
biblioteku koja sadrži R pakete za statističku analizu i
modeliranje. Detaljnije informacije o svim paketima možete pronaći na tidymodels.
library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.2.0 ──
## ✔ broom 1.0.6 ✔ rsample 1.2.1
## ✔ dials 1.3.0 ✔ tune 1.2.1
## ✔ infer 1.0.7 ✔ workflows 1.1.4
## ✔ modeldata 1.4.0 ✔ workflowsets 1.1.0
## ✔ parsnip 1.2.1 ✔ yardstick 1.3.1
## ✔ recipes 1.1.0
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ scales::discard() masks purrr::discard()
## ✖ dplyr::filter() masks stats::filter()
## ✖ recipes::fixed() masks stringr::fixed()
## ✖ dplyr::group_rows() masks kableExtra::group_rows()
## ✖ dplyr::lag() masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step() masks stats::step()
## • Search for functions across packages at https://www.tidymodels.org/find/
Dodatne biblioteke koje ćemo trebati.
library(rpart.plot)
library(vip)
library(baguette)
library(ranger)
library(Boruta)
library(xgboost)
library(DiagrammeR)
update_geom_defaults(geom = "point", new = list(color = "#F8766D", alpha = 0.5))
update_geom_defaults(geom = "tile", new = list(color = "black"))
Ovdje ćemo koristiti modele koji su bazirani na stablima. Modeli bazirani na stablima nemaju parametara, ali imaju hiperparametre. Hiperparametre algoritam za učenje ne uči, nego se njihove vrijednosti moraju postaviti prije samog procesa učenja. Vrijednosti hiperparametara mogu bitno utjecati na efikasnost modela pa ih treba pažljivo odabrati. Možemo koristiti defaultne vrijednosti koje su postavljene u određenoj implementaciji, možemo sami odabrati njihove vrijednosti ili pak možemo npr. preko cross-validation testirati ponašanje algoritma za učenje na dovoljnom broju kombinacija vrijednosti određenih hiperparametara (tuning hyperparameters) i pritom mjeriti za svaku takvu kombinaciju efikasnost promatranog algoritma za učenje. Na kraju odabiremo onu kombinaciju hiperparametara za koju algoritam za učenje ima najbolju efikasnost s obzirom na zadanu (željenu) metriku.
Za kreiranje mreže vrijednosti hiperparametara
tidymodels nudi gotove funkcije grid_regular,
grid_random, grid_max_entropy i
grid_latin_hypercube. Detalje o tim funkcijama možete
vidjeti na dials
reference. Ovdje ćemo koristiti funkciju
grid_latin_hypercube koja radi uzorkovanje na latinskoj
hiperkocki.
tidymodels biblioteka također nudi implementirane
metrike za mjerenje efikasnosti algoritama za učenje u slučaju
klasifikacije i u slučaju regresije. Popis dostupnih metrika možete
vidjeti na yardstick
metrike. Metrike koje ćemo ovdje koristiti dane su u donjoj
tablici.
| metrika | tip | info |
|---|---|---|
mae
|
regresija | mean absolute error |
mape
|
regresija | mean absolute percentage error |
mpe
|
regresija | mean percentage error |
rmse
|
regresija | root mean square error |
rsq
|
regresija | R squared |
sens
|
klasifikacija | sensitivity |
spec
|
klasifikacija | specificity |
precision
|
klasifikacija | precision |
accuracy
|
klasifikacija | accuracy |
f_meas
|
klasifikacija | F-score |
roc_auc
|
klasifikacija | ROC AUC1, ROC AUC2, Hand-Till |
kap
|
klasifikacija | Cohen’s kappa |
U tidymodels biblioteki možemo definirati koje metrike
želimo koristiti pomoću metric_set naredbe.
metrike_klas <- metric_set(roc_auc, sens, precision, spec, accuracy, f_meas, kap)
metrike_reg <- metric_set(mae, mape, mpe, rmse, rsq)
Za određivanje optimalne kombinacije hiperparametara u slučaju
regresije koristit ćemo rmse metriku koja daje prosječnu
udaljenost između predviđenih i stvarnih vrijednosti.
Sve klasifikacijske metrike su implementirane za binarni slučaj i
slučaj s više od dvije klase. Detalje o tome možete pogledati na multiclass
averaging. Za određivanje optimalne kombinacije hiperparametara u
slučaju klasifikacije koristit ćemo roc_auc metriku i Hand-Till
metodu.
U binarnom slučaju ROC krivulja pokazuje odnos između dviju mjera,
sensitivity i 1-specificity. Sensitivity mjeri
u kojem postotku je algoritam za učenje točan na pozitivnoj klasi, a
specificity mjeri u kojem postotku je algoritam za učenje točan na
negativnoj klasi. Algoritam za učenje daje vjerojatnosti pripadanja
određenoj klasi. Kako bi se odredilo kojoj klasi promatrani podatak
pripada, mora se zadati određena granica (threshold) za pozitivnu klasu.
Standardno se koristi granica od \(0.5\): ako je dobivena vjerojatnost za
pozitivnu klasu veća ili jednaka od \(0.5\), algoritam će promatrani podatak
klasificirati u pozitivnu klasu, a u protivnom u negativnu klasu.
Međutim, može se zadati proizvoljna granica \(p\in[0,1]\) za pozitivnu klasu. U tom
slučaju, ako je dobivena vjerojatnost za pozitivnu klasu veća ili
jednaka od \(p\), algoritam će
promatrani podatak klasificirati u pozitivnu klasu, a u protivnom u
negativnu klasu. Upravo to je poanta ROC krivulje: za različite
vrijednosti granice \(p\in[0,1]\)
prikazati koliko je algoritam dobar na pozitivnoj klasi u odnosu na to
koliko je loš na negativnoj klasi. Površina ispod ROC krivulje je
zapravo roc_auc metrika.
U slučaju više klasa, roc_auc metriku možemo odrediti
preko macro
averaging. Zapravo se koristi one-vs-all pristup gdje se svaka klasa
gleda u odnosu na ostale pa se na taj način svodi na više binarnih
slučaja. Međutim, takva metrika je osjetljiva na raspodjelu klasa. Zbog
toga je implementirana hand_till metoda koja čuva
neosjetljivost na raspodjelu klasa, ali nema jednostavnu vizualnu
interpretaciju. Za detalje pogledajte članak Hand-Till,
a ovdje ćemo samo ukratko intuitivno objasniti glavnu ideju.
roc_auc s hand_till metodom
računa se po formuli \[M=\frac{2}{c(c-1)}\sum_{i<j}{\hat{A}(i,j)}\]
pri čemu je \(c\) ukupni broj klasa.
Drugim riječima, \(M\) je zapravo
aritmetička sredina svih \(\hat{A}(i,j)\). Uočite da je \[\frac{2}{c(c-1)}=\frac{1}{\binom{c}{2}},\]
a \(\binom{c}{2}\) jednak je ukupnom
broju svih neuređenih parova postojećih \(c\) klasa.Iz navedenih razmatranja možemo zaključiti da je roc_auc
dobiven hand_till metodom zapravo jednak prosječnoj
vjerojatnosti da slučajno odabrani član iz neke klase ima manju
procijenjenu vjerojatnost da pripada nekoj drugoj klasi od slučajno
odabranog člana iz te druge klase.
Koristit ćemo penguins podatke o pingvinima iz
modeldata biblioteke.
data("penguins")
podaci <- penguins %>% na.omit()
podaci %>% glimpse()
## Rows: 333
## Columns: 7
## $ species <fct> Adelie, Adelie, Adelie, Adelie, Adelie, Adelie, Adel…
## $ island <fct> Torgersen, Torgersen, Torgersen, Torgersen, Torgerse…
## $ bill_length_mm <dbl> 39.1, 39.5, 40.3, 36.7, 39.3, 38.9, 39.2, 41.1, 38.6…
## $ bill_depth_mm <dbl> 18.7, 17.4, 18.0, 19.3, 20.6, 17.8, 19.6, 17.6, 21.2…
## $ flipper_length_mm <int> 181, 186, 195, 193, 190, 181, 195, 182, 191, 198, 18…
## $ body_mass_g <int> 3750, 3800, 3250, 3450, 3650, 3625, 4675, 3200, 3800…
## $ sex <fct> male, female, female, female, male, female, male, fe…
Podaci sadrže tri faktorske varijable i četiri numeričke varijable.
podaci %>% pivot_longer(c(species, island, sex), names_to = "varijabla", values_to = "vrijednost") %>%
ggplot(aes(x = vrijednost)) + geom_bar(fill = "skyblue3") +
facet_wrap(vars(varijabla), scales = "free_x") + xlab("") + ylab("")
Stupčasti dijagrami faktorskih varijabli
podaci %>% pivot_longer(where(is.numeric), names_to = "varijabla", values_to = "vrijednost") %>%
ggplot(aes(x = vrijednost)) + geom_histogram(fill = "skyblue", color = "violet") +
facet_wrap(vars(varijabla), scales = "free_x") + xlab("") + ylab("")
Histogrami numeričkih varijabli
podaci %>% pivot_longer(where(is.numeric), names_to = "varijabla", values_to = "vrijednost") %>%
ggplot(aes(x = varijabla, y = vrijednost)) + geom_boxplot(color = "skyblue3") +
stat_summary(fun = mean, geom = "point", shape = 10, size=3, color="sienna2") +
facet_wrap(vars(varijabla), scales = "free") + xlab("") + ylab("") +
theme(axis.text.x = element_blank(),
axis.ticks.x = element_blank())
Brkate kutije numeričkih varijabli zajedno s aritmetičkom sredinom
Regresija. Neka su \(X_1,X_2,\dotsc,X_p\) prediktorske varijable, a \(Y\) numerička response varijabla.
Klasifikacija. Neka su \(X_1,X_2,\dotsc,X_p\) prediktorske varijable, a \(Y\) kvalitativna response varijabla. Klasifikacijsko stablo odlučivanja funkcionira na sličan način kao i regresijsko stablo odlučivanja.
Obrezivanje stabla (tree pruning). Postupak izgradnje stabla pomoću rekurzivne binarne podjele može dovesti do prenaučenosti (overfitting), tj. stablo odlučivanja može davati dobre predikcije na skupu na kojem je trenirano, ali loše na novom testnom skupu. S druge strane, stablo odlučivanja s manjim brojem kutija (listova) može smanjiti prenaučenost i omogućiti jednostavniju interpretaciju uz malu žrtvu da se lošije ponaša na podacima na kojima je trenirano.
Opisane gornje strategije daju manja stabla, ali uz jedan veliki nedostatak. Preranim prekidanjem grananja u nekom vrhu možemo biti “oštećeni” za neko dobro grananje koje bi slijedilo kasnije da nismo ranije prekinuli. Drugim riječima, ukoliko u nekom vrhu nema značajnog smanjenja sume kvadrata reziduala ili Gini indeksa, to ne znači da se u kasnijim grananjima ne mogu javiti značajnija smanjenja.
Bolja strategija je prvo napraviti vrlo veliko stablo \(T_0\), a nakon toga ga obrezati kako bismo dobili optimalno podstablo. Jedna takva poznata strategija je cost complexity pruning. Promatramo niz stabala indeksiranih s nenegativnim parametrom podešavanja \(\alpha\).
Pritom je \(|T|\) oznaka za ukupni broj listova u stablu \(T\). Pomoću parametra \(\alpha\) kontroliramo koliko će biti kompleksno stablo odlučivanja, odnosno koliko ćemo “kažnjavati” stablo s obzirom na broj listova koje ima. Ako je \(\alpha=0\), onda uopće nema kazne i u tom slučaju zapravo dobivamo najkompleksnije stablo. Ako je \(\alpha=\infty\), tada je kazna prevelika jer dobivamo prazno stablo, tj. stablo koje ima samo jedan vrh (korijen). Pomoću cross-validation možemo pronaći optimalnu vrijednost za parametar \(\alpha\). Za više detalja možete pogledati npr. linkove prune1 i prune2.
U tidymodels biblioteki postoje tri hiperparametra kod
stabla odlučivanja:
tree_depth - prirodni broj za najveću dozvoljenu dubinu
stablamin_n - minimalni broj podataka potrebnih u vrhu kako
bi se on dalje dijeliocost_complexity - ranije opisani parametar \(\alpha\) za obrezivanje stablaSvaki od tih hipeparametara ima već postavljene defaultne intervale unutar kojih se ti parametri biraju.
tree_depth()
## Tree Depth (quantitative)
## Range: [1, 15]
tree_depth() %>% range_get()
## $lower
## [1] 1
##
## $upper
## [1] 15
min_n()
## Minimal Node Size (quantitative)
## Range: [2, 40]
min_n() %>% range_get()
## $lower
## [1] 2
##
## $upper
## [1] 40
cost_complexity()
## Cost-Complexity Parameter (quantitative)
## Transformer: log-10 [1e-100, Inf]
## Range (transformed scale): [-10, -1]
cost_complexity() %>% range_get()
## $lower
## [1] 1e-10
##
## $upper
## [1] 0.1
Možemo po želji postaviti i druge intervale za pojedini hiperparametar.
cost_complexity(range(-20,1))
## Cost-Complexity Parameter (quantitative)
## Transformer: log-10 [1e-100, Inf]
## Range (transformed scale): [-20, 1]
cost_complexity(range(-20,1)) %>% range_get()
## $lower
## [1] 1e-20
##
## $upper
## [1] 10
Pogledajmo sada klasifikacijsko stablo na penguins
podacima koje klasificira vrstu pingvina pomoću svih preostalih
varijabli koje su prediktori. Biranje optimalnih hiperparametara i
treniranje modela je napravljeno u zasebnoj R datoteci i
sve potrebne informacije su spremljene u rds datoteke. U
spomenutoj R datoteci se nalazi sljedeći kod.
library(tidymodels)
library(modeldata)
library(doParallel)
data("penguins")
podaci <- penguins %>% na.omit()
podaci_split <- initial_split(podaci, prop = 0.75, strata = species)
# definiranje modela
stablo_model <- decision_tree() %>%
set_args(tree_depth = tune(), min_n = tune(), cost_complexity = tune()) %>%
set_engine("rpart") %>%
set_mode("classification")
# recept za transformiranje podataka
stablo_recipe <- recipe(species ~ ., data = training(podaci_split))
# workflow
stablo_work <- workflow() %>%
add_model(stablo_model) %>%
add_recipe(stablo_recipe)
# cross validation folds
stablo_folds <- vfold_cv(training(podaci_split), v = 10, strata = species)
# metrike
stablo_metrike <- metric_set(roc_auc, sens, precision, spec, accuracy, f_meas, kap)
stablo_grid <- grid_latin_hypercube(tree_depth(), min_n(), cost_complexity(), size = 5000)
cl <- makePSOCKcluster(10)
registerDoParallel(cl)
stablo_tuning <- stablo_work %>%
tune_grid(resamples = stablo_folds, grid = stablo_grid, metrics = stablo_metrike)
stopCluster(cl)
saveRDS(stablo_tuning, file = "modeli/stablo_tuning_klas.rds")
best_stablo_model <- stablo_tuning %>% select_best(metric = 'roc_auc')
final_stablo_work <- stablo_work %>% finalize_workflow(best_stablo_model)
stablo_fit <- final_stablo_work %>% last_fit(split = podaci_split)
trenirani_model_work <- stablo_fit %>% extract_workflow() %>%
augment(training(podaci_split))
saveRDS(trenirani_model_work, "modeli/stablo_work_klas.rds")
saveRDS(stablo_fit, "modeli/stablo_fit_klas.rds")
Učitavanje spremljenih podataka iz rds datoteka.
stablo_fit_klas <- readRDS("modeli/stablo_fit_klas.rds")
stablo_tuning_klas <- readRDS("modeli/stablo_tuning_klas.rds")
stablo_work_klas <- readRDS("modeli/stablo_work_klas.rds")
stablo_tuning_klas %>% autoplot()
Vrijednosti hiperparametara i metrika
20 najboljih modela za roc_auc metriku
stablo_tuning_klas %>% show_best(metric = 'roc_auc', n = 20)
Metrike na testnom i trening skupu
tab1 <- stablo_fit_klas %>% collect_predictions() %>%
metrike_klas(truth = species, estimate = .pred_class, .pred_Adelie:.pred_Gentoo)
tab2 <- stablo_work_klas %>%
metrike_klas(truth = species, estimate = .pred_class, .pred_Adelie:.pred_Gentoo)
tab1 %>% inner_join(tab2, by = ".metric") %>%
select(.metric, .estimator = .estimator.x, trening = .estimate.y, test = .estimate.x)
stablo_work_klas %>%
conf_mat(truth = species, estimate = .pred_class) %>% autoplot("heatmap") +
scale_fill_gradient(low = "#87DEE7", high = "#FFFFCC")
Confusion matrix na trening skupu
stablo_fit_klas %>% collect_predictions() %>%
conf_mat(truth = species, estimate = .pred_class) %>% autoplot("heatmap") +
scale_fill_gradient(low = "#87DEE7", high = "#FFFFCC")
Confusion matrix na testnom skupu
Pogledajmo ROC i gain krivulje
stablo_fit_klas %>% collect_predictions() %>%
roc_curve(species, .pred_Adelie:.pred_Gentoo) %>% autoplot()
ROC krivulje
stablo_fit_klas %>% collect_predictions() %>%
gain_curve(species, .pred_Adelie:.pred_Gentoo) %>% autoplot()
gain krivulje
Pogledajmo i dobiveno stablo odlučivanja.
stablo_fit_klas %>%
extract_fit_engine() %>%
rpart.plot(roundint = FALSE, digits = 3)
klasifikacijsko stablo odlučivanja
Na primjer, iz dobivenog stabla možemo zaključiti ukoliko pingvin ima peraje dugačke barem 208 milimetara i nalazi se na Biscoe otocima, tada se radi o Gentoo vrsti pingvina.
Možemo ispisati stablo u tekstualnom formatu.
stablo_fit_klas %>%
extract_fit_engine()
## n= 249
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 249 140 Adelie (0.437751004 0.204819277 0.357429719)
## 2) flipper_length_mm< 207.5 156 48 Adelie (0.692307692 0.301282051 0.006410256)
## 4) bill_length_mm< 43.35 108 2 Adelie (0.981481481 0.018518519 0.000000000) *
## 5) bill_length_mm>=43.35 48 3 Chinstrap (0.041666667 0.937500000 0.020833333)
## 10) island=Biscoe,Torgersen 3 1 Adelie (0.666666667 0.000000000 0.333333333) *
## 11) island=Dream 45 0 Chinstrap (0.000000000 1.000000000 0.000000000) *
## 3) flipper_length_mm>=207.5 93 5 Gentoo (0.010752688 0.043010753 0.946236559)
## 6) island=Dream,Torgersen 5 1 Chinstrap (0.200000000 0.800000000 0.000000000) *
## 7) island=Biscoe 88 0 Gentoo (0.000000000 0.000000000 1.000000000) *
Također, možemo dobiti detaljnije informacije o gradnji stabla kod svakog vrha.
model <- stablo_fit_klas %>% extract_fit_engine()
summary(model)
## Call:
## rpart::rpart(formula = ..y ~ ., data = data, cp = ~1.60719499929497e-10,
## maxdepth = ~3, minsplit = min_rows(6, data))
## n= 249
##
## CP nsplit rel error xerror xstd
## 1 6.214286e-01 0 1.00000000 1.00000000 0.05591773
## 2 3.071429e-01 1 0.37857143 0.40000000 0.04705925
## 3 2.857143e-02 2 0.07142857 0.11428571 0.02763823
## 4 1.428571e-02 3 0.04285714 0.05714286 0.01987585
## 5 1.607195e-10 4 0.02857143 0.05000000 0.01863069
##
## Variable importance
## flipper_length_mm bill_length_mm bill_depth_mm body_mass_g
## 24 23 20 17
## island
## 16
##
## Node number 1: 249 observations, complexity param=0.6214286
## predicted class=Adelie expected loss=0.562249 P(node) =1
## class counts: 109 51 89
## probabilities: 0.438 0.205 0.357
## left son=2 (156 obs) right son=3 (93 obs)
## Primary splits:
## flipper_length_mm < 207.5 to the left, improve=82.41562, (0 missing)
## bill_length_mm < 43.25 to the left, improve=77.82187, (0 missing)
## bill_depth_mm < 16.45 to the right, improve=75.79562, (0 missing)
## body_mass_g < 4525 to the left, improve=60.35847, (0 missing)
## island splits as RLL, improve=47.71573, (0 missing)
## Surrogate splits:
## bill_depth_mm < 16.35 to the right, agree=0.940, adj=0.839, (0 split)
## body_mass_g < 4687.5 to the left, agree=0.908, adj=0.753, (0 split)
## island splits as RLL, agree=0.831, adj=0.548, (0 split)
## bill_length_mm < 43.25 to the left, agree=0.783, adj=0.419, (0 split)
##
## Node number 2: 156 observations, complexity param=0.3071429
## predicted class=Adelie expected loss=0.3076923 P(node) =0.626506
## class counts: 108 47 1
## probabilities: 0.692 0.301 0.006
## left son=4 (108 obs) right son=5 (48 obs)
## Primary splits:
## bill_length_mm < 43.35 to the left, improve=57.4298400, (0 missing)
## island splits as LRL, improve=21.8747000, (0 missing)
## flipper_length_mm < 196.5 to the left, improve= 9.5811760, (0 missing)
## body_mass_g < 3225 to the left, improve= 2.3040520, (0 missing)
## bill_depth_mm < 19.55 to the left, improve= 0.8794872, (0 missing)
## Surrogate splits:
## flipper_length_mm < 196.5 to the left, agree=0.763, adj=0.229, (0 split)
## island splits as LRL, agree=0.712, adj=0.063, (0 split)
##
## Node number 3: 93 observations, complexity param=0.02857143
## predicted class=Gentoo expected loss=0.05376344 P(node) =0.373494
## class counts: 1 4 88
## probabilities: 0.011 0.043 0.946
## left son=6 (5 obs) right son=7 (88 obs)
## Primary splits:
## island splits as RLL, improve=7.9483870, (0 missing)
## bill_depth_mm < 17.65 to the right, improve=7.9483870, (0 missing)
## body_mass_g < 4125 to the left, improve=3.1382750, (0 missing)
## flipper_length_mm < 212.5 to the left, improve=1.1039430, (0 missing)
## bill_length_mm < 48.9 to the right, improve=0.5234619, (0 missing)
## Surrogate splits:
## bill_depth_mm < 17.65 to the right, agree=1.000, adj=1.0, (0 split)
## body_mass_g < 4125 to the left, agree=0.968, adj=0.4, (0 split)
##
## Node number 4: 108 observations
## predicted class=Adelie expected loss=0.01851852 P(node) =0.4337349
## class counts: 106 2 0
## probabilities: 0.981 0.019 0.000
##
## Node number 5: 48 observations, complexity param=0.01428571
## predicted class=Chinstrap expected loss=0.0625 P(node) =0.1927711
## class counts: 2 45 1
## probabilities: 0.042 0.937 0.021
## left son=10 (3 obs) right son=11 (45 obs)
## Primary splits:
## island splits as LRL, improve=4.3750000, (0 missing)
## body_mass_g < 4575 to the left, improve=2.7518120, (0 missing)
## bill_depth_mm < 16.5 to the right, improve=0.8822464, (0 missing)
## bill_length_mm < 45.85 to the left, improve=0.7583333, (0 missing)
## flipper_length_mm < 202.5 to the left, improve=0.1891696, (0 missing)
## Surrogate splits:
## body_mass_g < 4575 to the right, agree=0.979, adj=0.667, (0 split)
##
## Node number 6: 5 observations
## predicted class=Chinstrap expected loss=0.2 P(node) =0.02008032
## class counts: 1 4 0
## probabilities: 0.200 0.800 0.000
##
## Node number 7: 88 observations
## predicted class=Gentoo expected loss=0 P(node) =0.3534137
## class counts: 0 0 88
## probabilities: 0.000 0.000 1.000
##
## Node number 10: 3 observations
## predicted class=Adelie expected loss=0.3333333 P(node) =0.01204819
## class counts: 2 0 1
## probabilities: 0.667 0.000 0.333
##
## Node number 11: 45 observations
## predicted class=Chinstrap expected loss=0 P(node) =0.1807229
## class counts: 0 45 0
## probabilities: 0.000 1.000 0.000
model$splits
## count ncat improve index adj
## flipper_length_mm 249 -1 82.4156228 207.50 0.0000000
## bill_length_mm 249 -1 77.8218746 43.25 0.0000000
## bill_depth_mm 249 1 75.7956167 16.45 0.0000000
## body_mass_g 249 -1 60.3584658 4525.00 0.0000000
## island 249 3 47.7157254 1.00 0.0000000
## bill_depth_mm 0 1 0.9397590 16.35 0.8387097
## body_mass_g 0 -1 0.9076305 4687.50 0.7526882
## island 0 3 0.8313253 2.00 0.5483871
## bill_length_mm 0 -1 0.7831325 43.25 0.4193548
## bill_length_mm 156 -1 57.4298433 43.35 0.0000000
## island 156 3 21.8746973 3.00 0.0000000
## flipper_length_mm 156 -1 9.5811757 196.50 0.0000000
## body_mass_g 156 -1 2.3040518 3225.00 0.0000000
## bill_depth_mm 156 -1 0.8794872 19.55 0.0000000
## flipper_length_mm 0 -1 0.7628205 196.50 0.2291667
## island 0 3 0.7115385 4.00 0.0625000
## island 48 3 4.3750000 5.00 0.0000000
## body_mass_g 48 -1 2.7518116 4575.00 0.0000000
## bill_depth_mm 48 1 0.8822464 16.50 0.0000000
## bill_length_mm 48 -1 0.7583333 45.85 0.0000000
## flipper_length_mm 48 -1 0.1891696 202.50 0.0000000
## body_mass_g 0 1 0.9791667 4575.00 0.6666667
## island 93 3 7.9483871 6.00 0.0000000
## bill_depth_mm 93 1 7.9483871 17.65 0.0000000
## body_mass_g 93 -1 3.1382747 4125.00 0.0000000
## flipper_length_mm 93 -1 1.1039427 212.50 0.0000000
## bill_length_mm 93 1 0.5234619 48.90 0.0000000
## bill_depth_mm 0 1 1.0000000 17.65 1.0000000
## body_mass_g 0 -1 0.9677419 4125.00 0.4000000
Objasnimo ukratko kako se u gornjem ispisu računa
improve. Neka je \(V\)
promatrani vrh, a \(V_L\) i \(V_R\) redom njegovo lijevo i desno dijete
nakon grananja po nekoj varijabli. Neka su \(G(V), G(V_L)\) i \(G(V_R)\) redom Gini indeksi vrhova \(V,V_L\) i \(V_R\). Nadalje, neka je s \(|V|, |V_L|\) i \(|V_R|\) označen ukupni broj podataka koji
se nalaze redom u vrhovima \(V,V_L\) i
\(V_R\). Tada je \(|V|=|V_L|+|V_R|\) i vrijedi \[\mathrm{improve}=|V_L|\cdot\big(G(V)-G(V_L)\big)+|V_R|\cdot\big(G(V)-G(V_R)\big).\]
U donjem grafičkom prikazu važnosti prediktora, zapravo je prikazana
suma svih improve vrijednosti za pojedinu varijablu.
stablo_fit_klas %>% extract_fit_parsnip() %>% vip()
Važnost varijabli
Također vidimo da u summary ispisu imamo skalirane
cjelobrojne važnosti varijabli tako da je njihova suma jednaka 100.
Ukoliko želimo, možemo sami prikazati grafički skalirane važnosti
varijabli.
(varimpo <- model$variable.importance %>% as_tibble() %>%
bind_cols(varijabla = names(model$variable.importance)) %>%
mutate(rel_importance = round(100 * value / sum(value))) %>%
relocate(varijabla) %>% rename(importance = value))
ggplot(varimpo, aes(x = fct_reorder(varijabla, rel_importance), y = rel_importance)) +
geom_col() + coord_flip() + xlab("") + ylab("importance")
skalirana važnost varijabli
Pogledajmo sada regresijsko stablo na penguins podacima
koje procjenjuje masu pingvina u gramima pomoću svih preostalih
varijabli koje su prediktori. Biranje optimalnih hiperparametara i
treniranje modela je napravljeno u zasebnoj R datoteci i
sve potrebne informacije su spremljene u rds datoteke. U
spomenutoj R datoteci se nalazi sljedeći kod.
library(tidymodels)
library(modeldata)
library(doParallel)
data("penguins")
podaci <- penguins %>% na.omit()
podaci_split <- initial_split(podaci, prop = 0.75, strata = body_mass_g)
# definiranje modela
stablo_model <- decision_tree() %>%
set_args(tree_depth = tune(), min_n = tune(), cost_complexity = tune()) %>%
set_engine("rpart") %>%
set_mode("regression")
# recept za transformiranje podataka
stablo_recipe <- recipe(body_mass_g ~ ., data = training(podaci_split))
# workflow
stablo_work <- workflow() %>%
add_model(stablo_model) %>%
add_recipe(stablo_recipe)
# cross validation folds
stablo_folds <- vfold_cv(training(podaci_split), v = 10, strata = body_mass_g)
# metrike
stablo_metrike <- metric_set(mae, mape, mpe, rmse, rsq)
stablo_grid <- grid_latin_hypercube(tree_depth(), min_n(), cost_complexity(), size = 5000)
cl <- makePSOCKcluster(10)
registerDoParallel(cl)
stablo_tuning <- stablo_work %>%
tune_grid(resamples = stablo_folds, grid = stablo_grid, metrics = stablo_metrike)
stopCluster(cl)
saveRDS(stablo_tuning, file = "modeli/stablo_tuning_reg.rds")
best_stablo_model <- stablo_tuning %>% select_best(metric = 'rmse')
final_stablo_work <- stablo_work %>% finalize_workflow(best_stablo_model)
stablo_fit <- final_stablo_work %>% last_fit(split = podaci_split)
trenirani_model_work <- stablo_fit %>% extract_workflow() %>%
augment(training(podaci_split))
saveRDS(trenirani_model_work, "modeli/stablo_work_reg.rds")
saveRDS(stablo_fit, "modeli/stablo_fit_reg.rds")
Učitavanje spremljenih podataka iz rds datoteka.
stablo_fit_reg <- readRDS("modeli/stablo_fit_reg.rds")
stablo_tuning_reg <- readRDS("modeli/stablo_tuning_reg.rds")
stablo_work_reg <- readRDS("modeli/stablo_work_reg.rds")
stablo_tuning_reg %>% autoplot()
Vrijednosti hiperparametara i metrika
20 najboljih modela za rmse metriku
stablo_tuning_reg %>% show_best(metric = 'rmse', n = 20)
Metrike na testnom i trening skupu
tab1 <- stablo_fit_reg %>% collect_predictions() %>%
metrike_reg(truth = body_mass_g, estimate = .pred,)
tab2 <- stablo_work_reg %>%
metrike_reg(truth = body_mass_g, estimate = .pred)
tab1 %>% inner_join(tab2, by = ".metric") %>%
select(.metric, .estimator = .estimator.x, trening = .estimate.y, test = .estimate.x)
ggplot(stablo_work_reg, aes(x = body_mass_g, y = .pred)) +
geom_point(size = 1.5) + geom_abline(color = "#00BFC4") +
ylab("predikcija")
Predikcije na trening skupu
test_podaci <- stablo_fit_reg %>% collect_predictions()
ggplot(test_podaci, aes(x = body_mass_g, y = .pred)) +
geom_point(size = 1.5, alpha = 0.7) +
geom_abline(color = "#00BFC4") + ylab("predikcija")
Predikcije na testnom skupu
Pogledajmo dobiveno regresijsko stablo odlučivanja.
stablo_fit_reg %>%
extract_fit_engine() %>%
rpart.plot(roundint = FALSE, digits = 3, tweak = 1.2)
regresijsko stablo odlučivanja
Na primjer, iz regresijskog stabla možemo vidjeti da Gentoo muški pingvini s duljinom kljuna većom od 47.7 milimetara imaju u prosjeku težinu od 5542 grama.
Ispis stabla u tekstualnom formatu.
stablo_fit_reg %>%
extract_fit_engine()
## n= 248
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 248 151430000.0 4203.528
## 2) species=Adelie,Chinstrap 160 29311870.0 3725.156
## 4) sex=female 79 6079747.0 3438.608
## 8) bill_depth_mm< 18.75 69 4804366.0 3399.638
## 16) flipper_length_mm< 195.5 60 3937958.0 3364.167
## 32) bill_depth_mm< 16.65 9 592222.2 3119.444 *
## 33) bill_depth_mm>=16.65 51 2711618.0 3407.353 *
## 17) flipper_length_mm>=195.5 9 287638.9 3636.111 *
## 9) bill_depth_mm>=18.75 10 447562.5 3707.500 *
## 5) sex=male 81 10418890.0 4004.630
## 10) flipper_length_mm< 194.5 37 3661419.0 3880.405 *
## 11) flipper_length_mm>=194.5 44 5706364.0 4109.091
## 22) bill_length_mm>=49.25 18 2527812.0 3970.833
## 44) flipper_length_mm< 199.5 7 230535.7 3603.571 *
## 45) flipper_length_mm>=199.5 11 752272.7 4204.545 *
## 23) bill_length_mm< 49.25 26 2596274.0 4204.808 *
## 3) species=Gentoo 88 18932240.0 5073.295
## 6) sex=female 45 3109250.0 4698.333
## 12) flipper_length_mm< 211.5 18 1221944.0 4530.556 *
## 13) flipper_length_mm>=211.5 27 1042824.0 4810.185 *
## 7) sex=male 43 2875029.0 5465.698
## 14) bill_length_mm< 47.7 10 170250.0 5215.000 *
## 15) bill_length_mm>=47.7 33 1885833.0 5541.667 *
Detaljnije informacije o gradnji stabla kod svakog vrha.
model <- stablo_fit_reg %>% extract_fit_engine()
summary(model)
## Call:
## rpart::rpart(formula = ..y ~ ., data = data, cp = ~0.0037551782899252,
## maxdepth = ~14, minsplit = min_rows(9, data))
## n= 248
##
## CP nsplit rel error xerror xstd
## 1 0.681409870 0 1.0000000 1.0038092 0.06866008
## 2 0.085504603 1 0.3185901 0.3230474 0.02492202
## 3 0.084614886 2 0.2330855 0.2693748 0.02386278
## 4 0.006996383 3 0.1484706 0.1529744 0.01321357
## 5 0.005576710 6 0.1274815 0.1705304 0.01498297
## 6 0.005466672 7 0.1219048 0.1721455 0.01544124
## 7 0.005408080 8 0.1164381 0.1711754 0.01540275
## 8 0.004004777 9 0.1110300 0.1649165 0.01544877
## 9 0.003755178 11 0.1030205 0.1647426 0.01542931
##
## Variable importance
## bill_depth_mm flipper_length_mm species island
## 24 24 22 14
## bill_length_mm sex
## 11 5
##
## Node number 1: 248 observations, complexity param=0.6814099
## mean=4203.528, MSE=610605
## left son=2 (160 obs) right son=3 (88 obs)
## Primary splits:
## species splits as LLR, improve=0.6814099, (0 missing)
## flipper_length_mm < 208.5 to the left, improve=0.6649007, (0 missing)
## bill_depth_mm < 16.4 to the right, improve=0.4916466, (0 missing)
## island splits as RLL, improve=0.3908572, (0 missing)
## bill_length_mm < 42.55 to the left, improve=0.3111583, (0 missing)
## Surrogate splits:
## flipper_length_mm < 206.5 to the left, agree=0.972, adj=0.920, (0 split)
## bill_depth_mm < 16.4 to the right, agree=0.960, adj=0.886, (0 split)
## island splits as RLL, agree=0.863, adj=0.614, (0 split)
## bill_length_mm < 43.25 to the left, agree=0.774, adj=0.364, (0 split)
##
## Node number 2: 160 observations, complexity param=0.08461489
## mean=3725.156, MSE=183199.2
## left son=4 (79 obs) right son=5 (81 obs)
## Primary splits:
## sex splits as LR, improve=0.437134700, (0 missing)
## bill_depth_mm < 18.05 to the left, improve=0.236826100, (0 missing)
## flipper_length_mm < 194.5 to the left, improve=0.184746400, (0 missing)
## bill_length_mm < 39.05 to the left, improve=0.118219000, (0 missing)
## island splits as RRL, improve=0.003803564, (0 missing)
## Surrogate splits:
## bill_depth_mm < 17.95 to the left, agree=0.806, adj=0.608, (0 split)
## bill_length_mm < 38.95 to the left, agree=0.675, adj=0.342, (0 split)
## flipper_length_mm < 195.5 to the left, agree=0.669, adj=0.329, (0 split)
## island splits as RRL, agree=0.525, adj=0.038, (0 split)
## species splits as RL-, agree=0.512, adj=0.013, (0 split)
##
## Node number 3: 88 observations, complexity param=0.0855046
## mean=5073.295, MSE=215139.1
## left son=6 (45 obs) right son=7 (43 obs)
## Primary splits:
## sex splits as LR, improve=0.6839107, (0 missing)
## bill_depth_mm < 14.75 to the left, improve=0.4953340, (0 missing)
## bill_length_mm < 48.3 to the left, improve=0.4320779, (0 missing)
## flipper_length_mm < 212.5 to the left, improve=0.3651048, (0 missing)
## Surrogate splits:
## bill_depth_mm < 14.85 to the left, agree=0.909, adj=0.814, (0 split)
## bill_length_mm < 47.9 to the left, agree=0.807, adj=0.605, (0 split)
## flipper_length_mm < 217.5 to the left, agree=0.807, adj=0.605, (0 split)
##
## Node number 4: 79 observations, complexity param=0.005466672
## mean=3438.608, MSE=76958.82
## left son=8 (69 obs) right son=9 (10 obs)
## Primary splits:
## bill_depth_mm < 18.75 to the left, improve=0.13616000, (0 missing)
## flipper_length_mm < 188.5 to the left, improve=0.09458126, (0 missing)
## bill_length_mm < 44.35 to the left, improve=0.07698016, (0 missing)
## species splits as LR-, improve=0.03663995, (0 missing)
## island splits as LRL, improve=0.00384916, (0 missing)
##
## Node number 5: 81 observations, complexity param=0.006996383
## mean=4004.63, MSE=128628.3
## left son=10 (37 obs) right son=11 (44 obs)
## Primary splits:
## flipper_length_mm < 194.5 to the left, improve=0.100884700, (0 missing)
## bill_depth_mm < 20.65 to the left, improve=0.058103760, (0 missing)
## bill_length_mm < 49.45 to the right, improve=0.022226350, (0 missing)
## species splits as RL-, improve=0.017807420, (0 missing)
## island splits as RLL, improve=0.003979764, (0 missing)
## Surrogate splits:
## bill_length_mm < 41.2 to the left, agree=0.753, adj=0.459, (0 split)
## species splits as LR-, agree=0.667, adj=0.270, (0 split)
## bill_depth_mm < 18.95 to the left, agree=0.667, adj=0.270, (0 split)
## island splits as LRR, agree=0.630, adj=0.189, (0 split)
##
## Node number 6: 45 observations, complexity param=0.00557671
## mean=4698.333, MSE=69094.44
## left son=12 (18 obs) right son=13 (27 obs)
## Primary splits:
## flipper_length_mm < 211.5 to the left, improve=0.2716030, (0 missing)
## bill_depth_mm < 14.65 to the left, improve=0.1697643, (0 missing)
## bill_length_mm < 46.45 to the left, improve=0.1080713, (0 missing)
## Surrogate splits:
## bill_length_mm < 43.35 to the left, agree=0.644, adj=0.111, (0 split)
## bill_depth_mm < 13.65 to the left, agree=0.644, adj=0.111, (0 split)
##
## Node number 7: 43 observations, complexity param=0.00540808
## mean=5465.698, MSE=66861.14
## left son=14 (10 obs) right son=15 (33 obs)
## Primary splits:
## bill_length_mm < 47.7 to the left, improve=0.2848478, (0 missing)
## bill_depth_mm < 15.85 to the left, improve=0.1380727, (0 missing)
## flipper_length_mm < 228.5 to the left, improve=0.1374721, (0 missing)
## Surrogate splits:
## flipper_length_mm < 215.5 to the left, agree=0.814, adj=0.2, (0 split)
##
## Node number 8: 69 observations, complexity param=0.004004777
## mean=3399.638, MSE=69628.49
## left son=16 (60 obs) right son=17 (9 obs)
## Primary splits:
## flipper_length_mm < 195.5 to the left, improve=0.12046720, (0 missing)
## bill_depth_mm < 16.75 to the left, improve=0.11065450, (0 missing)
## bill_length_mm < 44.35 to the left, improve=0.10740950, (0 missing)
## species splits as LR-, improve=0.05913125, (0 missing)
## island splits as LRL, improve=0.01884782, (0 missing)
## Surrogate splits:
## bill_length_mm < 48.7 to the left, agree=0.884, adj=0.111, (0 split)
##
## Node number 9: 10 observations
## mean=3707.5, MSE=44756.25
##
## Node number 10: 37 observations
## mean=3880.405, MSE=98957.27
##
## Node number 11: 44 observations, complexity param=0.006996383
## mean=4109.091, MSE=129690.1
## left son=22 (18 obs) right son=23 (26 obs)
## Primary splits:
## bill_length_mm < 49.25 to the right, improve=0.10204000, (0 missing)
## bill_depth_mm < 20.4 to the left, improve=0.08950952, (0 missing)
## island splits as RLL, improve=0.08655960, (0 missing)
## species splits as RL-, improve=0.07489047, (0 missing)
## flipper_length_mm < 202.5 to the left, improve=0.05269237, (0 missing)
## Surrogate splits:
## species splits as RL-, agree=0.932, adj=0.833, (0 split)
## island splits as RLR, agree=0.750, adj=0.389, (0 split)
## flipper_length_mm < 200.5 to the right, agree=0.705, adj=0.278, (0 split)
## bill_depth_mm < 19.75 to the right, agree=0.636, adj=0.111, (0 split)
##
## Node number 12: 18 observations
## mean=4530.556, MSE=67885.8
##
## Node number 13: 27 observations
## mean=4810.185, MSE=38623.11
##
## Node number 14: 10 observations
## mean=5215, MSE=17025
##
## Node number 15: 33 observations
## mean=5541.667, MSE=57146.46
##
## Node number 16: 60 observations, complexity param=0.004004777
## mean=3364.167, MSE=65632.64
## left son=32 (9 obs) right son=33 (51 obs)
## Primary splits:
## bill_depth_mm < 16.65 to the left, improve=0.16102720, (0 missing)
## bill_length_mm < 45.3 to the left, improve=0.10953280, (0 missing)
## flipper_length_mm < 188.5 to the left, improve=0.05126286, (0 missing)
## species splits as LR-, improve=0.04810778, (0 missing)
## island splits as LRR, improve=0.04608987, (0 missing)
##
## Node number 17: 9 observations
## mean=3636.111, MSE=31959.88
##
## Node number 22: 18 observations, complexity param=0.006996383
## mean=3970.833, MSE=140434
## left son=44 (7 obs) right son=45 (11 obs)
## Primary splits:
## flipper_length_mm < 199.5 to the left, improve=0.61120200, (0 missing)
## bill_length_mm < 50.4 to the left, improve=0.13577520, (0 missing)
## bill_depth_mm < 19.85 to the left, improve=0.06731364, (0 missing)
## Surrogate splits:
## bill_length_mm < 50.4 to the left, agree=0.722, adj=0.286, (0 split)
##
## Node number 23: 26 observations
## mean=4204.808, MSE=99856.69
##
## Node number 32: 9 observations
## mean=3119.444, MSE=65802.47
##
## Node number 33: 51 observations
## mean=3407.353, MSE=53168.97
##
## Node number 44: 7 observations
## mean=3603.571, MSE=32933.67
##
## Node number 45: 11 observations
## mean=4204.545, MSE=68388.43
model$splits
## count ncat improve index adj
## species 248 3 0.681409870 1.00 0.00000000
## flipper_length_mm 248 -1 0.664900737 208.50 0.00000000
## bill_depth_mm 248 1 0.491646586 16.40 0.00000000
## island 248 3 0.390857150 2.00 0.00000000
## bill_length_mm 248 -1 0.311158277 42.55 0.00000000
## flipper_length_mm 0 -1 0.971774194 206.50 0.92045455
## bill_depth_mm 0 1 0.959677419 16.40 0.88636364
## island 0 3 0.862903226 3.00 0.61363636
## bill_length_mm 0 -1 0.774193548 43.25 0.36363636
## sex 160 2 0.437134679 4.00 0.00000000
## bill_depth_mm 160 -1 0.236826057 18.05 0.00000000
## flipper_length_mm 160 -1 0.184746367 194.50 0.00000000
## bill_length_mm 160 -1 0.118218983 39.05 0.00000000
## island 160 3 0.003803564 5.00 0.00000000
## bill_depth_mm 0 -1 0.806250000 17.95 0.60759494
## bill_length_mm 0 -1 0.675000000 38.95 0.34177215
## flipper_length_mm 0 -1 0.668750000 195.50 0.32911392
## island 0 3 0.525000000 6.00 0.03797468
## species 0 3 0.512500000 7.00 0.01265823
## bill_depth_mm 79 -1 0.136160011 18.75 0.00000000
## flipper_length_mm 79 -1 0.094581258 188.50 0.00000000
## bill_length_mm 79 -1 0.076980156 44.35 0.00000000
## species 79 3 0.036639950 8.00 0.00000000
## island 79 3 0.003849160 9.00 0.00000000
## flipper_length_mm 69 -1 0.120467243 195.50 0.00000000
## bill_depth_mm 69 -1 0.110654501 16.75 0.00000000
## bill_length_mm 69 -1 0.107409451 44.35 0.00000000
## species 69 3 0.059131250 10.00 0.00000000
## island 69 3 0.018847823 11.00 0.00000000
## bill_length_mm 0 -1 0.884057971 48.70 0.11111111
## bill_depth_mm 60 -1 0.161027215 16.65 0.00000000
## bill_length_mm 60 -1 0.109532815 45.30 0.00000000
## flipper_length_mm 60 -1 0.051262862 188.50 0.00000000
## species 60 3 0.048107779 12.00 0.00000000
## island 60 3 0.046089873 13.00 0.00000000
## flipper_length_mm 81 -1 0.100884686 194.50 0.00000000
## bill_depth_mm 81 -1 0.058103756 20.65 0.00000000
## bill_length_mm 81 1 0.022226346 49.45 0.00000000
## species 81 3 0.017807420 14.00 0.00000000
## island 81 3 0.003979764 15.00 0.00000000
## bill_length_mm 0 -1 0.753086420 41.20 0.45945946
## species 0 3 0.666666667 16.00 0.27027027
## bill_depth_mm 0 -1 0.666666667 18.95 0.27027027
## island 0 3 0.629629630 17.00 0.18918919
## bill_length_mm 44 1 0.102039957 49.25 0.00000000
## bill_depth_mm 44 -1 0.089509524 20.40 0.00000000
## island 44 3 0.086559601 18.00 0.00000000
## species 44 3 0.074890469 19.00 0.00000000
## flipper_length_mm 44 -1 0.052692369 202.50 0.00000000
## species 0 3 0.931818182 20.00 0.83333333
## island 0 3 0.750000000 21.00 0.38888889
## flipper_length_mm 0 1 0.704545455 200.50 0.27777778
## bill_depth_mm 0 1 0.636363636 19.75 0.11111111
## flipper_length_mm 18 -1 0.611202001 199.50 0.00000000
## bill_length_mm 18 -1 0.135775215 50.40 0.00000000
## bill_depth_mm 18 -1 0.067313636 19.85 0.00000000
## bill_length_mm 0 -1 0.722222222 50.40 0.28571429
## sex 88 2 0.683910741 22.00 0.00000000
## bill_depth_mm 88 -1 0.495333985 14.75 0.00000000
## bill_length_mm 88 -1 0.432077925 48.30 0.00000000
## flipper_length_mm 88 -1 0.365104843 212.50 0.00000000
## bill_depth_mm 0 -1 0.909090909 14.85 0.81395349
## bill_length_mm 0 -1 0.806818182 47.90 0.60465116
## flipper_length_mm 0 -1 0.806818182 217.50 0.60465116
## flipper_length_mm 45 -1 0.271602953 211.50 0.00000000
## bill_depth_mm 45 -1 0.169764261 14.65 0.00000000
## bill_length_mm 45 -1 0.108071346 46.45 0.00000000
## bill_length_mm 0 -1 0.644444444 43.35 0.11111111
## bill_depth_mm 0 -1 0.644444444 13.65 0.11111111
## bill_length_mm 43 -1 0.284847811 47.70 0.00000000
## bill_depth_mm 43 -1 0.138072693 15.85 0.00000000
## flipper_length_mm 43 -1 0.137472108 228.50 0.00000000
## flipper_length_mm 0 -1 0.813953488 215.50 0.20000000
Objasnimo ukratko kako se u gornjem ispisu računa
improve. Neka je \(V\)
promatrani vrh, a \(V_L\) i \(V_R\) redom njegovo lijevo i desno dijete
nakon grananja po nekoj varijabli. Neka je \[
RSS(V)=\sum_{x_i\in V}{\big(y_i-\hat{y}_V\big)^2},\quad
RSS(V_L)=\sum_{x_i\in V_L}{\big(y_i-\hat{y}_{V_L}\big)^2},\quad
RSS(V_R)=\sum_{x_i\in V_R}{\big(y_i-\hat{y}_{V_R}\big)^2}
\] pri čemu su \(\hat{y}_V,\hat{y}_{V_L}\) i \(\hat{y}_{V_R}\) aritmetičke sredine
podataka koji pripadaju redom vrhovima \(V,V_L\) i \(V_R\). Tada je \[\mathrm{improve}=1-\frac{RSS(V_L)+RSS(V_R)}{RSS(V)}.\]
U donjem grafičkom prikazu važnosti prediktora, zapravo je prikazana
suma svih deviance (RSS) vrijednosti za pojedinu
varijablu.
stablo_fit_reg %>% extract_fit_parsnip() %>% vip()
Važnost varijabli
Također vidimo da u summary ispisu imamo skalirane
cjelobrojne važnosti varijabli tako da je njihova suma jednaka 100.
Ukoliko želimo, možemo sami prikazati grafički skalirane važnosti
varijabli.
(varimpo <- model$variable.importance %>% as_tibble() %>%
bind_cols(varijabla = names(model$variable.importance)) %>%
mutate(rel_importance = round(100 * value / sum(value))) %>%
relocate(varijabla) %>% rename(importance = value))
ggplot(varimpo, aes(x = fct_reorder(varijabla, rel_importance), y = rel_importance)) +
geom_col() + coord_flip() + xlab("") + ylab("importance")
skalirana važnost varijabli
Na poznatoj izreci “Više glava je pametnije od jedne” bazira se pakiranje stabala, na engleskom bagging ili bootstrap aggregation. Pakiranje je jedna općenita metoda koja služi za smanjenje varijance algoritma za učenje bazirana na sljedećoj činjenici:
U praksi zapravo imamo samo jedan skup podataka za treniranje. Stoga na postojećem skupu za treniranje koje ima \(N\) podataka kreiramo nove skupove za treniranje s \(N\) podataka preko bootstrap metode. Svaki novi skup podataka je zapravo neka kombinacija s ponavljanjem od \(N\) elemenata početnog skupa podataka za treniranje s kojim raspolažemo.
Prilikom treniranja svako od \(B\) stabala koristi samo neke podatke iz početnog trening skupa. Podatke koje pojedino stablo ne koristi prilikom treniranja zovu se out-of-bag (OOB) podaci za promatrano stablo. Dakle, svako od \(B\) stabala ima svoje OOB podatke. Pokazuje se da u prosjeku svako stablo za treniranje koristi oko \(\frac{2}{3}\) početnih podataka, odnosno da je za svako stablo broj njegovih OOB podataka približno jednak \(\frac{1}{3}\) početnih podataka.
Za svaki podatak u trening skupu pogledamo sva stabla za koja je taj podatak OOB podatak. Svako od tih stabala daje predikciju za taj podatak. U slučaju regresije uzmemo prosječnu vrijednosti tih predikcija, a u slučaju klasifikacije klasu koja se javlja najviše puta. Na taj način dobivamo OOB predikciju za svaki podatak.
Pogledajmo sada pakiranje stabala na penguins podacima
pri čemu klasificiramo vrstu pingvina pomoću svih preostalih varijabli
koje su prediktori. Biranje optimalnih hiperparametara i treniranje
modela je napravljeno u zasebnoj R datoteci i sve potrebne
informacije su spremljene u rds datoteke. U spomenutoj
R datoteci se nalazi sljedeći kod. Zapravo pakiramo ukupno
20 stabala jer smo stavili times = 20.
library(tidyverse)
library(tidymodels)
library(modeldata)
library(baguette)
library(doParallel)
data("penguins")
podaci <- penguins %>% na.omit()
podaci_split <- initial_split(podaci, prop = 0.75, strata = species)
# definiranje modela
stablo_model <- bag_tree() %>%
set_args(tree_depth = tune(), min_n = tune(), cost_complexity = tune()) %>%
set_engine("rpart", times = 20) %>%
set_mode("classification")
# recept za transformiranje podataka
stablo_recipe <- recipe(species ~ ., data = training(podaci_split))
# workflow
stablo_work <- workflow() %>%
add_model(stablo_model) %>%
add_recipe(stablo_recipe)
# cross validation folds
stablo_folds <- vfold_cv(training(podaci_split), v = 10, strata = species)
# metrike
stablo_metrike <- metric_set(roc_auc, sens, precision, spec, accuracy, f_meas, kap)
stablo_grid <- grid_latin_hypercube(tree_depth(), min_n(), cost_complexity(), size = 50)
cl <- makePSOCKcluster(10)
registerDoParallel(cl)
stablo_tuning <- stablo_work %>%
tune_grid(resamples = stablo_folds, grid = stablo_grid, metrics = stablo_metrike)
stopCluster(cl)
saveRDS(stablo_tuning, file = "modeli/bagged_stablo_tuning_klas.rds")
best_stablo_model <- stablo_tuning %>% select_best(metric = 'roc_auc')
final_stablo_work <- stablo_work %>% finalize_workflow(best_stablo_model)
stablo_fit <- final_stablo_work %>% last_fit(split = podaci_split)
trenirani_model_work <- stablo_fit %>% extract_workflow() %>%
augment(training(podaci_split))
saveRDS(trenirani_model_work, "modeli/bagged_stablo_work_klas.rds")
saveRDS(stablo_fit, "modeli/bagged_stablo_fit_klas.rds")
Učitavanje spremljenih podataka iz rds datoteka.
bag_stablo_fit_klas <- readRDS("modeli/bagged_stablo_fit_klas.rds")
bag_stablo_tuning_klas <- readRDS("modeli/bagged_stablo_tuning_klas.rds")
bag_stablo_work_klas <- readRDS("modeli/bagged_stablo_work_klas.rds")
bag_stablo_tuning_klas %>% autoplot()
Vrijednosti hiperparametara i metrika
20 najboljih modela za roc_auc metriku
bag_stablo_tuning_klas %>% show_best(metric = 'roc_auc', n = 20)
Metrike na testnom i trening skupu
tab1 <- bag_stablo_fit_klas %>% collect_predictions() %>%
metrike_klas(truth = species, estimate = .pred_class, .pred_Adelie:.pred_Gentoo)
tab2 <- bag_stablo_work_klas %>%
metrike_klas(truth = species, estimate = .pred_class, .pred_Adelie:.pred_Gentoo)
tab1 %>% inner_join(tab2, by = ".metric") %>%
select(.metric, .estimator = .estimator.x, trening = .estimate.y, test = .estimate.x)
bag_stablo_work_klas %>%
conf_mat(truth = species, estimate = .pred_class) %>% autoplot("heatmap") +
scale_fill_gradient(low = "#87DEE7", high = "#FFFFCC")
Confusion matrix na trening skupu
bag_stablo_fit_klas %>% collect_predictions() %>%
conf_mat(truth = species, estimate = .pred_class) %>% autoplot("heatmap") +
scale_fill_gradient(low = "#87DEE7", high = "#FFFFCC")
Confusion matrix na testnom skupu
Pogledajmo ROC i gain krivulje
bag_stablo_fit_klas %>% collect_predictions() %>%
roc_curve(species, .pred_Adelie:.pred_Gentoo) %>% autoplot()
ROC krivulje
bag_stablo_fit_klas %>% collect_predictions() %>%
gain_curve(species, .pred_Adelie:.pred_Gentoo) %>% autoplot()
gain krivulje
Kod pakiranja stabala nema nekog velikog smisla prikazivati pojedino stablo, pogotovo ako tih stabala ima jako puno. No, ako želimo, možemo dobiti informacije o svakom stablu u paketu i nacrtati pojedino stablo. No, imajmo na umu da samo jedno stablo ne odlučuje o konačnoj predikciji, već u tome sudjeluju sva stabla u paketu.
model <- bag_stablo_fit_klas %>% extract_workflow() %>% extract_fit_engine()
attributes(model)
## $names
## [1] "model_df" "control" "cost" "imp" "base_model"
## [6] "blueprint"
##
## $class
## [1] "bagger" "hardhat_model" "hardhat_scalar"
Informacije o svim stablima nalaze se u atributu
model_df. Pogledajmo na primjer informacije o 1. i 8.
stablu iz paketa.
model$model_df$model[[1]]
## parsnip model object
##
## n= 249
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 249 148 1
## 2) flipper_length_mm< 206.5 148 47 1
## 4) bill_length_mm< 43.15 105 4 1
## 8) bill_length_mm< 42.3 96 1 1
## 16) bill_depth_mm>=16.7 84 0 1 *
## 17) bill_depth_mm< 16.7 12 1 1
## 34) bill_length_mm< 39.5 11 0 1 *
## 35) bill_length_mm>=39.5 1 0 2 *
## 9) bill_length_mm>=42.3 9 3 1
## 18) bill_length_mm>=42.6 6 0 1 *
## 19) bill_length_mm< 42.6 3 0 2 *
## 5) bill_length_mm>=43.15 43 1 2
## 10) bill_depth_mm>=15.4 42 0 2 *
## 11) bill_depth_mm< 15.4 1 0 3 *
## 3) flipper_length_mm>=206.5 101 6 3
## 6) bill_depth_mm>=18.15 6 0 2 *
## 7) bill_depth_mm< 18.15 95 0 3 *
model$model_df$model[[1]] %>%
extract_fit_engine() %>%
rpart.plot(roundint = FALSE, digits = 2, tweak = 1.2)
Prvo stablo iz paketa
model$model_df$model[[8]]
## parsnip model object
##
## n= 249
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 249 146 3
## 2) bill_length_mm< 43.2 105 5 1
## 4) bill_depth_mm>=15.3 103 3 1
## 8) bill_length_mm< 42.25 96 1 1
## 16) bill_depth_mm>=16.8 88 0 1 *
## 17) bill_depth_mm< 16.8 8 1 1
## 34) bill_length_mm< 39.5 7 0 1 *
## 35) bill_length_mm>=39.5 1 0 2 *
## 9) bill_length_mm>=42.25 7 2 1
## 18) bill_length_mm>=42.45 5 0 1 *
## 19) bill_length_mm< 42.45 2 0 2 *
## 5) bill_depth_mm< 15.3 2 0 3 *
## 3) bill_length_mm>=43.2 144 43 3
## 6) island=Dream 41 0 2 *
## 7) island=Biscoe,Torgersen 103 2 3
## 14) bill_depth_mm>=18.8 2 0 1 *
## 15) bill_depth_mm< 18.8 101 0 3 *
model$model_df$model[[8]] %>%
extract_fit_engine() %>%
rpart.plot(roundint = FALSE, digits = 2, tweak = 1.2)
Osmo stablo iz paketa
Funkcija var_imp iz baguette biblioteke
računa važnost prediktora za svako stablo u paketu kako je ranije
opisano kod stabla odlučivanja i vraća aritmetičku sredinu tih
vrijednosti.
vaznost <- var_imp(model)
ggplot(vaznost, aes(x = fct_reorder(term, value), y = value)) +
geom_col() + coord_flip() + xlab("") + ylab("")
Važnost varijabli
Pogledajmo sada pakiranje stabala na penguins podacima
koje procjenjuje masu pingvina u gramima pomoću svih preostalih
varijabli koje su prediktori. Biranje optimalnih hiperparametara i
treniranje modela je napravljeno u zasebnoj R datoteci i
sve potrebne informacije su spremljene u rds datoteke. U
spomenutoj R datoteci se nalazi sljedeći kod.
library(tidymodels)
library(modeldata)
library(baguette)
library(doParallel)
data("penguins")
podaci <- penguins %>% na.omit()
podaci_split <- initial_split(podaci, prop = 0.75, strata = body_mass_g)
# definiranje modela
stablo_model <- bag_tree() %>%
set_args(tree_depth = tune(), min_n = tune(), cost_complexity = tune()) %>%
set_engine("rpart", times = 20) %>%
set_mode("regression")
# recept za transformiranje podataka
stablo_recipe <- recipe(body_mass_g ~ ., data = training(podaci_split))
# workflow
stablo_work <- workflow() %>%
add_model(stablo_model) %>%
add_recipe(stablo_recipe)
# cross validation folds
stablo_folds <- vfold_cv(training(podaci_split), v = 10, strata = body_mass_g)
# metrike
stablo_metrike <- metric_set(mae, mape, mpe, rmse, rsq)
stablo_grid <- grid_latin_hypercube(tree_depth(), min_n(), cost_complexity(), size = 100)
cl <- makePSOCKcluster(10)
registerDoParallel(cl)
stablo_tuning <- stablo_work %>%
tune_grid(resamples = stablo_folds, grid = stablo_grid, metrics = stablo_metrike)
stopCluster(cl)
saveRDS(stablo_tuning, file = "modeli/bagged_stablo_tuning_reg.rds")
best_stablo_model <- stablo_tuning %>% select_best(metric = 'rmse')
final_stablo_work <- stablo_work %>% finalize_workflow(best_stablo_model)
stablo_fit <- final_stablo_work %>% last_fit(split = podaci_split)
trenirani_model_work <- stablo_fit %>% extract_workflow() %>%
augment(training(podaci_split))
saveRDS(trenirani_model_work, "modeli/bagged_stablo_work_reg.rds")
saveRDS(stablo_fit, "modeli/bagged_stablo_fit_reg.rds")
Učitavanje spremljenih podataka iz rds datoteka.
bag_stablo_fit_reg <- readRDS("modeli/bagged_stablo_fit_reg.rds")
bag_stablo_tuning_reg <- readRDS("modeli/bagged_stablo_tuning_reg.rds")
bag_stablo_work_reg <- readRDS("modeli/bagged_stablo_work_reg.rds")
bag_stablo_tuning_reg %>% autoplot()
Vrijednosti hiperparametara i metrika
20 najboljih modela za rmse metriku
bag_stablo_tuning_reg %>% show_best(metric = 'rmse', n = 20)
Metrike na testnom i trening skupu
tab1 <- bag_stablo_fit_reg %>% collect_predictions() %>%
metrike_reg(truth = body_mass_g, estimate = .pred,)
tab2 <- bag_stablo_work_reg %>%
metrike_reg(truth = body_mass_g, estimate = .pred)
tab1 %>% inner_join(tab2, by = ".metric") %>%
select(.metric, .estimator = .estimator.x, trening = .estimate.y, test = .estimate.x)
ggplot(bag_stablo_work_reg, aes(x = body_mass_g, y = .pred)) +
geom_point(size = 1.5) + geom_abline(color = "#00BFC4") +
ylab("predikcija")
Predikcije na trening skupu
test_podaci <- bag_stablo_fit_reg %>% collect_predictions()
ggplot(test_podaci, aes(x = body_mass_g, y = .pred)) +
geom_point(size = 1.5, alpha = 0.7) +
geom_abline(color = "#00BFC4") + ylab("predikcija")
Predikcije na testnom skupu
Pogledajmo na primjer informacije o 5. i 11. stablu iz paketa.
model <- bag_stablo_fit_reg %>% extract_workflow() %>% extract_fit_engine()
model$model_df$model[[5]]
## parsnip model object
##
## n= 248
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 248 155616900.0 4219.052
## 2) species=Adelie,Chinstrap 163 30085530.0 3749.233
## 4) sex=female 72 4628672.0 3446.875
## 8) flipper_length_mm< 197 64 3489023.0 3408.594
## 16) bill_depth_mm< 16.9 6 197083.3 3141.667 *
## 17) bill_depth_mm>=16.9 58 2820216.0 3436.207
## 34) bill_length_mm>=36.1 48 1818112.0 3400.521
## 68) island=Biscoe,Torgersen 18 819444.4 3294.444 *
## 69) island=Dream 30 674604.2 3464.167 *
## 35) bill_length_mm< 36.1 10 647562.5 3607.500 *
## 9) flipper_length_mm>=197 8 295546.9 3753.125 *
## 5) sex=male 91 13666630.0 3988.462
## 10) bill_depth_mm< 20.6 81 9225386.0 3916.049
## 20) flipper_length_mm< 189.5 13 725480.8 3582.692 *
## 21) flipper_length_mm>=189.5 68 6779072.0 3979.779
## 42) species=Chinstrap 28 2009286.0 3814.286
## 84) bill_length_mm< 50.4 13 758557.7 3621.154 *
## 85) bill_length_mm>=50.4 15 345583.3 3981.667 *
## 43) species=Adelie 40 3466109.0 4095.625
## 86) bill_length_mm< 43.15 33 2477197.0 4055.303 *
## 87) bill_length_mm>=43.15 7 682321.4 4285.714 *
## 11) bill_depth_mm>=20.6 10 576250.0 4575.000 *
## 3) species=Gentoo 85 20557250.0 5120.000
## 6) sex=female 42 3047396.0 4729.167
## 12) bill_length_mm< 43.35 9 865000.0 4466.667 *
## 13) bill_length_mm>=43.35 33 1393106.0 4800.758
## 26) bill_depth_mm< 14.45 18 567673.6 4715.278 *
## 27) bill_depth_mm>=14.45 15 536083.3 4903.333 *
## 7) sex=male 43 4827994.0 5501.744
## 14) flipper_length_mm< 228.5 34 3370312.0 5412.500
## 28) bill_length_mm< 46.6 6 103333.3 5116.667 *
## 29) bill_length_mm>=46.6 28 2629353.0 5475.893
## 58) bill_length_mm>=49.35 16 1617148.0 5392.188 *
## 59) bill_length_mm< 49.35 12 750625.0 5587.500 *
## 15) flipper_length_mm>=228.5 9 163888.9 5838.889 *
Desnim klikom na sliku možete otvoriti veću sliku u novom tabu.
model$model_df$model[[5]] %>%
extract_fit_engine() %>%
rpart.plot(roundint = FALSE, digits = 3, tweak = 1.15)
Peto stablo iz paketa
model$model_df$model[[11]]
## parsnip model object
##
## n= 248
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 248 167309000.00 4280.847
## 2) species=Adelie,Chinstrap 154 28888170.00 3741.234
## 4) sex=female 70 4959295.00 3423.929
## 8) bill_length_mm< 38 30 1892417.00 3281.667
## 16) bill_depth_mm< 17.65 14 223258.90 3116.071 *
## 17) bill_depth_mm>=17.65 16 949335.90 3426.562 *
## 9) bill_length_mm>=38 40 2004359.00 3530.625
## 18) flipper_length_mm< 194.5 27 1237269.00 3457.407
## 36) bill_depth_mm< 18.35 20 936375.00 3407.500
## 72) bill_length_mm< 46.15 14 752321.40 3353.571 *
## 73) bill_length_mm>=46.15 6 48333.33 3533.333 *
## 37) bill_depth_mm>=18.35 7 108750.00 3600.000 *
## 19) flipper_length_mm>=194.5 13 321730.80 3682.692 *
## 5) sex=male 84 11007940.00 4005.655
## 10) flipper_length_mm< 200.5 62 8031855.00 3928.226
## 20) bill_length_mm>=45.85 19 1127763.00 3734.211
## 40) bill_depth_mm< 18.85 9 297500.00 3633.333 *
## 41) bill_depth_mm>=18.85 10 656250.00 3825.000 *
## 21) bill_length_mm< 45.85 43 5872878.00 4013.953
## 42) bill_length_mm< 39.1 12 159166.70 3741.667 *
## 43) bill_length_mm>=39.1 31 4479637.00 4119.355
## 86) bill_length_mm>=40.15 22 2850369.00 3996.591 *
## 87) bill_length_mm< 40.15 9 487222.20 4419.444 *
## 11) flipper_length_mm>=200.5 22 1556847.00 4223.864
## 22) bill_depth_mm< 19.85 14 405892.90 4085.714 *
## 23) bill_depth_mm>=19.85 8 416171.90 4465.625 *
## 3) species=Gentoo 94 20114150.00 5164.894
## 6) sex=female 45 2741111.00 4794.444
## 12) flipper_length_mm< 211.5 15 869333.30 4576.667 *
## 13) flipper_length_mm>=211.5 30 804666.70 4903.333
## 26) bill_length_mm>=45.8 12 272291.70 4829.167 *
## 27) bill_length_mm< 45.8 18 422361.10 4952.778 *
## 7) sex=male 49 5526224.00 5505.102
## 14) bill_length_mm< 48.3 16 444375.00 5256.250 *
## 15) bill_length_mm>=48.3 33 3610606.00 5625.758
## 30) bill_length_mm>=49.45 22 1400909.00 5513.636
## 60) bill_length_mm>=50.95 6 77083.33 5358.333 *
## 61) bill_length_mm< 50.95 16 1124844.00 5571.875 *
## 31) bill_length_mm< 49.45 11 1380000.00 5850.000 *
Desnim klikom na sliku možete otvoriti veću sliku u novom tabu.
model$model_df$model[[11]] %>%
extract_fit_engine() %>%
rpart.plot(roundint = FALSE, digits = 3, tweak = 1.2)
Jedanaesto stablo iz paketa
Funkcija var_imp iz baguette biblioteke
računa važnost prediktora za svako stablo u paketu kako je ranije
opisano kod stabla odlučivanja i vraća aritmetičku sredinu tih
vrijednosti.
vaznost <- var_imp(model)
ggplot(vaznost, aes(x = fct_reorder(term, value), y = value)) +
geom_col() + coord_flip() + xlab("") + ylab("")
Važnost varijabli
Slučajna šuma funkcionira na slični način kao i pakiranje stabala. Jedina je razlika da se prilikom grananja pojedinog stabla ne koriste svi prediktori \(X_1,X_2,\dotsc,X_p\), nego se svaki put prilikom novog grananja na slučajni način bira \(m\) prediktora koji će se razmatrati u trenutnom grananju. Najčešće uzimamo \(m\approx\sqrt{p}\) što je dobro u većini slučajeva. Zapravo je \(m\) hiperparametar za slučajnu šumu čiju optimalnu vrijednost možemo također odabrati preko cross-validation.
Ovdje ćemo preko tidymodels biblioteke koristiti
ranger biblioteku u kojoj je implementirana slučajna šuma.
U toj biblioteci postoje tri hiperparametra za slučajnu šumu:
trees - broj stabala u slučajnoj šumimin_n - minimalni broj podataka potrebnih u vrhu kako
bi se on dalje dijeliomtry - broj prediktora koji će se na slučajni način
birati kod svakog grananja (ranije opisani parametar \(m\))Napomena. Ako stavimo da je mtry jednak
broju prediktora, tada zapravo dobivamo pakirana stabla. Stoga na taj
način možemo pakirati stabla i pomoću ranger
biblioteke.
ranger biblioteka nudi dva načina računanja važnosti
varijabli.
scale.permutation.importance = TRUE, tada se dobivene
vrijednosti još normaliziraju s pripadnim standardnim devijacijama
pojedinih razlika. Po defaultu je
scale.permutation.importance = FALSE.Lokalna važnost varijabli. Lokalna važnost nekog prediktora na \(i\)-tom podatku računa se slično kao i gore opisana globalna važnost. Jedina razlika je što se promatraju samo stabla iz slučajne šume koja nisu trenirana na \(i\)-tom podatku.
Pogledajmo sada slučajnu šumu na penguins podacima pri
čemu klasificiramo vrstu pingvina pomoću svih preostalih varijabli koje
su prediktori. Biranje optimalnih hiperparametara i treniranje modela je
napravljeno u zasebnoj R datoteci i sve potrebne
informacije su spremljene u rds datoteke. U spomenutoj
R datoteci se nalazi sljedeći kod.
library(tidyverse)
library(tidymodels)
library(modeldata)
library(doParallel)
data("penguins")
podaci <- penguins %>% na.omit()
podaci_split <- initial_split(podaci, prop = 0.75, strata = species)
# definiranje modela
rf_model <- rand_forest() %>%
set_args(mtry = tune(), trees = tune(), min_n = tune()) %>%
set_engine("ranger", importance = "impurity") %>%
set_mode("classification")
# recept za transformiranje podataka
rf_recipe <- recipe(species ~ ., data = training(podaci_split))
# workflow
rf_work <- workflow() %>%
add_model(rf_model) %>%
add_recipe(rf_recipe)
# cross validation folds
rf_folds <- vfold_cv(training(podaci_split), v = 10, strata = species)
# metrike
rf_metrike <- metric_set(roc_auc, sens, precision, spec, accuracy, f_meas, kap)
rf_grid <- grid_latin_hypercube(mtry(range = c(2,3)), trees(range = c(400,1500)),
min_n(), size = 100)
cl <- makePSOCKcluster(10)
registerDoParallel(cl)
rf_tuning <- rf_work %>%
tune_grid(resamples = rf_folds, grid = rf_grid, metrics = rf_metrike)
stopCluster(cl)
saveRDS(rf_tuning, file = "modeli/rf_tuning_klas.rds")
best_rf_model <- rf_tuning %>% select_best(metric = 'roc_auc')
final_rf_work <- rf_work %>% finalize_workflow(best_rf_model)
rf_fit <- final_rf_work %>% last_fit(split = podaci_split)
trenirani_model_work <- rf_fit %>% extract_workflow() %>%
augment(training(podaci_split))
saveRDS(trenirani_model_work, "modeli/rf_work_klas.rds")
saveRDS(rf_fit, "modeli/rf_fit_klas.rds")
Učitavanje spremljenih podataka iz rds datoteka.
rf_fit_klas <- readRDS("modeli/rf_fit_klas.rds")
rf_tuning_klas <- readRDS("modeli/rf_tuning_klas.rds")
rf_work_klas <- readRDS("modeli/rf_work_klas.rds")
rf_tuning_klas %>% autoplot()
Vrijednosti hiperparametara i metrika
20 najboljih modela za roc_auc metriku
rf_tuning_klas %>% show_best(metric = 'roc_auc', n = 20)
Metrike na testnom i trening skupu
tab1 <- rf_fit_klas %>% collect_predictions() %>%
metrike_klas(truth = species, estimate = .pred_class, .pred_Adelie:.pred_Gentoo)
tab2 <- rf_work_klas %>%
metrike_klas(truth = species, estimate = .pred_class, .pred_Adelie:.pred_Gentoo)
tab1 %>% inner_join(tab2, by = ".metric") %>%
select(.metric, .estimator = .estimator.x, trening = .estimate.y, test = .estimate.x)
rf_work_klas %>%
conf_mat(truth = species, estimate = .pred_class) %>% autoplot("heatmap") +
scale_fill_gradient(low = "#87DEE7", high = "#FFFFCC")
Confusion matrix na trening skupu
rf_fit_klas %>% collect_predictions() %>%
conf_mat(truth = species, estimate = .pred_class) %>% autoplot("heatmap") +
scale_fill_gradient(low = "#87DEE7", high = "#FFFFCC")
Confusion matrix na testnom skupu
Pogledajmo ROC i gain krivulje
rf_fit_klas %>% collect_predictions() %>%
roc_curve(species, .pred_Adelie:.pred_Gentoo) %>% autoplot()
ROC krivulje
rf_fit_klas %>% collect_predictions() %>%
gain_curve(species, .pred_Adelie:.pred_Gentoo) %>% autoplot()
gain krivulje
Pomoću treeInfo funkcije iz ranger
biblioteke možemo dobiti informacije o pojedinom stablu u slučajnoj
šumi. Pogledajmo na primjer informacije o 20. i 42. stablu iz slučajne
šume.
treeInfo(rf_fit_klas %>% extract_fit_engine(), 20)
treeInfo(rf_fit_klas %>% extract_fit_engine(), 42)
Važnost varijabli temeljena na gini indeksu
rf_fit_klas %>% extract_fit_parsnip() %>% vip()
Važnost varijabli (gini indeks)
Ako želimo važnost varijabli preko permutacijske metode, tada moramo
staviti importance = "permutation". Ukoliko želimo da se
izračunaju i lokalne važnosti varijabli, moramo staviti
local.importance = TRUE. Vrijednosti hiperparametara su
samo preuzete kako smo ih dobili ranije, nisu ispočetka ponovo tražene
nove optimalne vrijednosti. Treniranje modela je napravljeno u zasebnoj
R datoteci i sve potrebne informacije su spremljene u
rds datoteke. U spomenutoj R datoteci se
nalazi sljedeći kod.
library(tidyverse)
library(tidymodels)
library(modeldata)
data("penguins")
podaci <- penguins %>% na.omit()
podaci_split <- initial_split(podaci, prop = 0.75, strata = species)
# definiranje modela
rf_model <- rand_forest() %>%
set_args(mtry = 2, trees = 453, min_n = 17) %>%
set_engine("ranger", importance = "permutation", local.importance = TRUE) %>%
set_mode("classification")
# recept za transformiranje podataka
rf_recipe <- recipe(species ~ ., data = training(podaci_split))
# workflow
rf_work <- workflow() %>%
add_model(rf_model) %>%
add_recipe(rf_recipe)
# metrike
rf_metrike <- metric_set(roc_auc, sens, precision, spec, accuracy, f_meas, kap)
rf_fit <- rf_work %>% last_fit(split = podaci_split)
trenirani_model_work <- rf_fit %>% extract_workflow() %>%
augment(training(podaci_split))
saveRDS(trenirani_model_work, "modeli/rf_perm_work_klas.rds")
saveRDS(rf_fit, "modeli/rf_perm_fit_klas.rds")
Učitavanje spremljenih podataka iz rds datoteka.
rf_perm_fit_klas <- readRDS("modeli/rf_perm_fit_klas.rds")
rf_perm_work_klas <- readRDS("modeli/rf_perm_work_klas.rds")
Važnost varijabli temeljena na permutacijskoj metodi
rf_perm_fit_klas %>% extract_fit_parsnip() %>% vip()
Važnost varijabli (permutacijska metoda)
Lokalne važnosti varijabli
model <- rf_perm_fit_klas %>% extract_fit_engine()
as_tibble(model$variable.importance.local)
Različiti grafički prikazi lokalnih važnosti varijabli
tablica <- as_tibble(model$variable.importance.local) %>%
mutate(podatak = row_number()) %>%
pivot_longer(-podatak, names_to = "aktivnost", values_to = "local importance")
ggplot(tablica, aes(x = podatak, y =`local importance`)) +
geom_point(alpha = 0.3, color = "black") + facet_wrap(vars(aktivnost))
ggplot(tablica, aes(x = podatak, y =`local importance`)) +
geom_line(linewidth = 0.2) + facet_wrap(vars(aktivnost))
ggplot(tablica, aes(x = aktivnost, y = `local importance`)) + geom_boxplot()
ggplot(tablica, aes(x = `local importance`)) + geom_histogram() +
facet_wrap(vars(aktivnost), scales = "free_y")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
Pogledajmo sada slučajnu šumu na penguins podacima koja
procjenjuje masu pingvina u gramima pomoću svih preostalih varijabli
koje su prediktori. Biranje optimalnih hiperparametara i treniranje
modela je napravljeno u zasebnoj R datoteci i sve potrebne
informacije su spremljene u rds datoteke. U spomenutoj
R datoteci se nalazi sljedeći kod.
library(tidymodels)
library(modeldata)
library(baguette)
library(doParallel)
data("penguins")
podaci <- penguins %>% na.omit()
podaci_split <- initial_split(podaci, prop = 0.75, strata = body_mass_g)
# definiranje modela
rf_model <- rand_forest() %>%
set_args(mtry = tune(), trees = tune(), min_n = tune()) %>%
set_engine("ranger", importance = "impurity") %>%
set_mode("regression")
# recept za transformiranje podataka
rf_recipe <- recipe(body_mass_g ~ ., data = training(podaci_split))
# workflow
rf_work <- workflow() %>%
add_model(rf_model) %>%
add_recipe(rf_recipe)
# cross validation folds
rf_folds <- vfold_cv(training(podaci_split), v = 10, strata = body_mass_g)
# metrike
rf_metrike <- metric_set(mae, mape, mpe, rmse, rsq)
rf_grid <- grid_latin_hypercube(mtry(range = c(2,3)), trees(range = c(400,1500)),
min_n(), size = 100)
cl <- makePSOCKcluster(10)
registerDoParallel(cl)
rf_tuning <- rf_work %>%
tune_grid(resamples = rf_folds, grid = rf_grid, metrics = rf_metrike)
stopCluster(cl)
saveRDS(rf_tuning, file = "modeli/rf_tuning_reg.rds")
best_rf_model <- rf_tuning %>% select_best(metric = 'rmse')
final_rf_work <- rf_work %>% finalize_workflow(best_rf_model)
rf_fit <- final_rf_work %>% last_fit(split = podaci_split)
trenirani_model_work <- rf_fit %>% extract_workflow() %>%
augment(training(podaci_split))
saveRDS(trenirani_model_work, "modeli/rf_work_reg.rds")
saveRDS(rf_fit, "modeli/rf_fit_reg.rds")
Učitavanje spremljenih podataka iz rds datoteka.
rf_fit_reg <- readRDS("modeli/rf_fit_reg.rds")
rf_tuning_reg <- readRDS("modeli/rf_tuning_reg.rds")
rf_work_reg <- readRDS("modeli/rf_work_reg.rds")
rf_tuning_reg %>% autoplot()
Vrijednosti hiperparametara i metrika
20 najboljih modela za rmse metriku
rf_tuning_reg %>% show_best(metric = 'rmse', n = 20)
Metrike na testnom i trening skupu
tab1 <- rf_fit_reg %>% collect_predictions() %>%
metrike_reg(truth = body_mass_g, estimate = .pred,)
tab2 <- rf_work_reg %>%
metrike_reg(truth = body_mass_g, estimate = .pred)
tab1 %>% inner_join(tab2, by = ".metric") %>%
select(.metric, .estimator = .estimator.x, trening = .estimate.y, test = .estimate.x)
ggplot(rf_work_reg, aes(x = body_mass_g, y = .pred)) +
geom_point(size = 1.5) + geom_abline(color = "#00BFC4") +
ylab("predikcija")
Predikcije na trening skupu
test_podaci <- rf_fit_reg %>% collect_predictions()
ggplot(test_podaci, aes(x = body_mass_g, y = .pred)) +
geom_point(size = 1.5, alpha = 0.7) +
geom_abline(color = "#00BFC4") + ylab("predikcija")
Predikcije na testnom skupu
Pogledajmo na primjer informacije o 56. i 234. stablu iz slučajne šume.
treeInfo(rf_fit_reg %>% extract_fit_engine(), 56)
treeInfo(rf_fit_reg %>% extract_fit_engine(), 234)
Važnost varijabli temeljena na sumi kvadrata reziduala
rf_fit_reg %>% extract_fit_parsnip() %>% vip()
Važnost varijabli (suma kvadrata reziduala)
Ako želimo važnost varijabli preko permutacijske metode, tada moramo
staviti importance = "permutation". Ukoliko želimo da se
izračunaju i lokalne važnosti varijabli, moramo staviti
local.importance = TRUE. Vrijednosti hiperparametara su
samo preuzete kako smo ih dobili ranije, nisu ispočetka ponovo tražene
nove optimalne vrijednosti. Treniranje modela je napravljeno u zasebnoj
R datoteci i sve potrebne informacije su spremljene u
rds datoteke. U spomenutoj R datoteci se
nalazi sljedeći kod.
library(tidyverse)
library(tidymodels)
library(modeldata)
data("penguins")
podaci <- penguins %>% na.omit()
podaci_split <- initial_split(podaci, prop = 0.75, strata = body_mass_g)
# definiranje modela
rf_model <- rand_forest() %>%
set_args(mtry = 2, trees = 407, min_n = 6) %>%
set_engine("ranger", importance = "permutation", local.importance = TRUE) %>%
set_mode("regression")
# recept za transformiranje podataka
rf_recipe <- recipe(body_mass_g ~ ., data = training(podaci_split))
# workflow
rf_work <- workflow() %>%
add_model(rf_model) %>%
add_recipe(rf_recipe)
# metrike
rf_metrike <- metric_set(mae, mape, mpe, rmse, rsq)
rf_fit <- rf_work %>% last_fit(split = podaci_split)
trenirani_model_work <- rf_fit %>% extract_workflow() %>%
augment(training(podaci_split))
saveRDS(trenirani_model_work, "modeli/rf_perm_work_reg.rds")
saveRDS(rf_fit, "modeli/rf_perm_fit_reg.rds")
Učitavanje spremljenih podataka iz rds datoteka.
rf_perm_fit_reg <- readRDS("modeli/rf_perm_fit_reg.rds")
rf_perm_work_reg <- readRDS("modeli/rf_perm_work_reg.rds")
Važnost varijabli temeljena na permutacijskoj metodi
rf_perm_fit_reg %>% extract_fit_parsnip() %>% vip()
Važnost varijabli (permutacijska metoda)
Lokalne važnosti varijabli
model <- rf_perm_fit_reg %>% extract_fit_engine()
as_tibble(model$variable.importance.local)
Različiti grafički prikazi lokalnih važnosti varijabli
tablica <- as_tibble(model$variable.importance.local) %>%
mutate(podatak = row_number()) %>%
pivot_longer(-podatak, names_to = "aktivnost", values_to = "local importance")
ggplot(tablica, aes(x = podatak, y =`local importance`)) +
geom_point(alpha = 0.3, color = "black") + facet_wrap(vars(aktivnost))
ggplot(tablica, aes(x = podatak, y =`local importance`)) +
geom_line(linewidth = 0.2) + facet_wrap(vars(aktivnost))
ggplot(tablica, aes(x = aktivnost, y = `local importance`)) + geom_boxplot()
ggplot(tablica, aes(x = `local importance`)) + geom_histogram() +
facet_wrap(vars(aktivnost), scales = "free_y")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
Boruta algoritam služi za određivanje važnosti varijabli, a baziran je na slučajnoj šumi. Algoritam funkcionira na sljedeći način:
Za više detalja možete pogledati ovaj članak.
Primijenimo sada Boruta algoritam na penguins podatke
pri čemu želimo vidjeti koje su varijable bitne za predikciju vrste
pingvina.
boruta_klas <- Boruta(species ~ ., data = podaci, doTrace = 0)
Već nakon 10 iteracija ispada da su svi prediktori za predikciju
vrste pingvina važni, a najvažniji je bill_length_mm
(duljina kljuna u milimetrima), a najmanje važan je spol pingvina.
par(mar = c(8, 2, 2, 2))
plot(boruta_klas, cex.axis = 0.9, las = 2, xlab = "", main = "Vrsta pingvina",
colCode = c("green", "orange", "#f6546a", "#2acaea"))
Grafički i tablični prikazi kako su se kroz iteracije mijenjale važnosti varijabli.
plotImpHistory(boruta_klas, colCode = c("green", "orange", "#f6546a", "#2acaea"))
as_tibble(boruta_klas$ImpHistory)
Primijenimo još Boruta algoritam na penguins podatke pri
čemu želimo vidjeti koje su varijable bitne za predikciju mase pingvina
u gramima.
boruta_reg <- Boruta(body_mass_g ~ ., data = podaci, doTrace = 0)
Već nakon 10 iteracija ispada da su svi prediktori za predikciju mase pingvina važni, uglavnom su svi podjednake važnosti (spol pingvina ispada malo važniji od ostalih).
par(mar = c(8, 2, 2, 2))
plot(boruta_reg, cex.axis = 0.9, las = 2, xlab = "", main = "Masa pingvina",
colCode = c("green", "orange", "#f6546a", "#2acaea"))
Grafički i tablični prikazi kako su se kroz iteracije mijenjale važnosti varijabli.
plotImpHistory(boruta_reg, colCode = c("green", "orange", "#f6546a", "#2acaea"))
as_tibble(boruta_reg$ImpHistory)
Pojačana stabla, boosted trees, funkcioniraju slično kao i pakirana stabla, s time da stabla rastu sekvencijalno. Svako novo stablo se gradi tako da koristi informacije od prethodno izgrađenog stabla. U slučaju regresije općenita ideja je sljedeća:
U slučaju klasifikacije također je slična ideja. Ukratko, \(\hat{f}^b\) se bira tako da se optimizira funkcija cilja. Funkcija cilja je zapravo suma funkcije koja mjeri efikasnost modela na trening skupu i regularizacije.
Za više detalja možete pogledati xgboost ili jedan konkretni primjer.
Za dani trenutni model, gradimo novo stablo na rezidualima iz trenutnog modela. Novo stablo dodajemo u \(\hat{f}\) kako bismo ažurirali reziduale. Svako od tih stabla može biti relativno malo sa samo nekoliko listova, što određuje parametar \(d\) u algoritmu. Dodavanjem tih malih stabala polako poboljšavamo \(\hat{f}\) na dijelovima u kojima mu efikasnost nije najbolja. S parametrom \(\lambda\) možemo dodatno usporiti samo napredovanje kako bismo istražili više različito oblikovanih stabala za poboljšanje reziduala.
Ovdje ćemo preko tidymodels biblioteke koristiti
xgboost biblioteku u kojoj su implementirana pojačana
stabla. Neki od hiperparametara u toj biblioteci su sljedeći:
trees - broj stabala koja će se izgraditimin_n - minimalni broj podataka potrebnih u vrhu kako
bi se on dalje dijeliomtry - broj prediktora koji će se na slučajni način
birati kod svakog grananjatree_depth - maksimalna dubina stabla, tj. broj
grananjalearn_rate - brzina učenja (ranije spomenuti parametar
\(\lambda\))Prevelika vrijednost hiperparametra trees može ovdje
dovesti do overfittinga. Tipične vrijednosti hiperparametra
learn_rate su \(0.01\) ili
\(0.001\), ali pravi izbor može ovisiti
o samom problemu. Međutim, vrlo male vrijednosti od
learn_rate zahtijevaju vrlo velike vrijednosti od
trees ako želimo postići dobru efikasnost.
Pogledajmo sada pojačana stabla na penguins podacima pri
čemu klasificiramo vrstu pingvina pomoću svih preostalih varijabli koje
su prediktori. Biranje optimalnih hiperparametara i treniranje modela je
napravljeno u zasebnoj R datoteci i sve potrebne
informacije su spremljene u rds datoteke. Biblioteka
xgboost radi samo s numeričkim prediktorima pa su zbog toga
preko step_dummy(all_nominal_predictors()) svi nominalni
prediktori pretvoreni u numeričke varijable. U spomenutoj R
datoteci se nalazi sljedeći kod.
library(tidyverse)
library(tidymodels)
library(modeldata)
library(doParallel)
data("penguins")
podaci <- penguins %>% na.omit()
podaci_split <- initial_split(podaci, prop = 0.75, strata = species)
# definiranje modela
boost_model <- boost_tree() %>%
set_args(mtry = 2, trees = tune(), tree_depth = tune(), learn_rate = tune()) %>%
set_engine("xgboost") %>%
set_mode("classification")
# recept za transformiranje podataka
boost_recipe <- recipe(species ~ ., data = training(podaci_split)) %>%
step_dummy(all_nominal_predictors())
# workflow
boost_work <- workflow() %>%
add_model(boost_model) %>%
add_recipe(boost_recipe)
# cross validation folds
boost_folds <- vfold_cv(training(podaci_split), v = 10, strata = species)
# metrike
boost_metrike <- metric_set(roc_auc, sens, precision, spec, accuracy, f_meas, kap)
boost_grid <- grid_latin_hypercube(trees(range(1,100)), tree_depth(), learn_rate(), size = 50)
cl <- makePSOCKcluster(10)
registerDoParallel(cl)
boost_tuning <- boost_work %>%
tune_grid(resamples = boost_folds, grid = boost_grid, metrics = boost_metrike)
stopCluster(cl)
saveRDS(boost_tuning, file = "modeli/boost_tuning_klas.rds")
best_boost_model <- boost_tuning %>% select_best(metric = 'roc_auc')
final_boost_work <- boost_work %>% finalize_workflow(best_boost_model)
boost_fit <- final_boost_work %>% last_fit(split = podaci_split)
trenirani_model_work <- boost_fit %>% extract_workflow() %>%
augment(training(podaci_split))
saveRDS(trenirani_model_work, "modeli/boost_work_klas.rds")
saveRDS(boost_fit, "modeli/boost_fit_klas.rds")
Učitavanje spremljenih podataka iz rds datoteka.
boost_fit_klas <- readRDS("modeli/boost_fit_klas.rds")
boost_tuning_klas <- readRDS("modeli/boost_tuning_klas.rds")
boost_work_klas <- readRDS("modeli/boost_work_klas.rds")
boost_tuning_klas %>% autoplot()
Vrijednosti hiperparametara i metrika
20 najboljih modela za roc_auc metriku
boost_tuning_klas %>% show_best(metric = 'roc_auc', n = 20)
Metrike na testnom i trening skupu
tab1 <- boost_fit_klas %>% collect_predictions() %>%
metrike_klas(truth = species, estimate = .pred_class, .pred_Adelie:.pred_Gentoo)
tab2 <- boost_work_klas %>%
metrike_klas(truth = species, estimate = .pred_class, .pred_Adelie:.pred_Gentoo)
tab1 %>% inner_join(tab2, by = ".metric") %>%
select(.metric, .estimator = .estimator.x, trening = .estimate.y, test = .estimate.x)
boost_work_klas %>%
conf_mat(truth = species, estimate = .pred_class) %>% autoplot("heatmap") +
scale_fill_gradient(low = "#87DEE7", high = "#FFFFCC")
Confusion matrix na trening skupu
boost_fit_klas %>% collect_predictions() %>%
conf_mat(truth = species, estimate = .pred_class) %>% autoplot("heatmap") +
scale_fill_gradient(low = "#87DEE7", high = "#FFFFCC")
Confusion matrix na testnom skupu
Pogledajmo ROC i gain krivulje
boost_fit_klas %>% collect_predictions() %>%
roc_curve(species, .pred_Adelie:.pred_Gentoo) %>% autoplot()
ROC krivulje
boost_fit_klas %>% collect_predictions() %>%
gain_curve(species, .pred_Adelie:.pred_Gentoo) %>% autoplot()
gain krivulje
Nacrtajmo na primjer 9. i 21. stablo u dobivenom modelu.
xgb.plot.tree(model = boost_fit_klas %>% extract_fit_engine(), trees = 9,
plot_width = 800, plot_height = 500)
xgb.plot.tree(model = boost_fit_klas %>% extract_fit_engine(), trees = 21,
plot_width = 1000, plot_height = 400)
Možemo vizualizirati distribucije vezane uz dubinu listova stabala.
xgb.ggplot.deepness(boost_fit_klas %>% extract_fit_engine())
Možemo dobiti kompleksnost samog modela tako da se sva stabla opišu preko jednog stablo kako bi se poboljšala interpretacija modela.
gr <- xgb.plot.multi.trees(boost_fit_klas %>% extract_fit_engine(), render = FALSE)
export_graph(gr, 'tree.pdf', width=1000, height=1500)
Možete otvoriti veću PDF sliku.
xgb.plot.multi.trees(boost_fit_klas %>% extract_fit_engine(),
plot_width = 1000, plot_height = 1300)
xgboost biblioteka daje tri različite vrste važnosti
varijabli.
(vaznost <- xgb.importance(model = boost_fit_klas %>% extract_fit_engine()))
xgb.ggplot.importance(vaznost) + ggtitle("Importance (gain)")
xgb.ggplot.importance(vaznost, measure = "Cover") +
ggtitle("Importance (cover)")
xgb.ggplot.importance(vaznost, measure = "Frequency") +
ggtitle("Importance (frequency)")
Pogledajmo sada pojačana stabla na penguins podacima
koja procjenjuju masu pingvina u gramima pomoću svih preostalih
varijabli koje su prediktori. Biranje optimalnih hiperparametara i
treniranje modela je napravljeno u zasebnoj R datoteci i
sve potrebne informacije su spremljene u rds datoteke. U
spomenutoj R datoteci se nalazi sljedeći kod.
library(tidyverse)
library(tidymodels)
library(modeldata)
library(doParallel)
data("penguins")
podaci <- penguins %>% na.omit()
podaci_split <- initial_split(podaci, prop = 0.75, strata = body_mass_g)
# definiranje modela
boost_model <- boost_tree() %>%
set_args(mtry = 2, trees = tune(), tree_depth = tune(), learn_rate = tune()) %>%
set_engine("xgboost") %>%
set_mode("regression")
# recept za transformiranje podataka
boost_recipe <- recipe(body_mass_g ~ ., data = training(podaci_split)) %>%
step_dummy(all_nominal_predictors())
# workflow
boost_work <- workflow() %>%
add_model(boost_model) %>%
add_recipe(boost_recipe)
# cross validation folds
boost_folds <- vfold_cv(training(podaci_split), v = 10, strata = body_mass_g)
# metrike
boost_metrike <- metric_set(mae, mape, mpe, rmse, rsq)
boost_grid <- grid_latin_hypercube(trees(range(1,100)), tree_depth(), learn_rate(), size = 100)
cl <- makePSOCKcluster(10)
registerDoParallel(cl)
boost_tuning <- boost_work %>%
tune_grid(resamples = boost_folds, grid = boost_grid, metrics = boost_metrike)
stopCluster(cl)
saveRDS(boost_tuning, file = "modeli/boost_tuning_reg.rds")
best_boost_model <- boost_tuning %>% select_best(metric = 'rmse')
final_boost_work <- boost_work %>% finalize_workflow(best_boost_model)
boost_fit <- final_boost_work %>% last_fit(split = podaci_split)
trenirani_model_work <- boost_fit %>% extract_workflow() %>%
augment(training(podaci_split))
saveRDS(trenirani_model_work, "modeli/boost_work_reg.rds")
saveRDS(boost_fit, "modeli/boost_fit_reg.rds")
Učitavanje spremljenih podataka iz rds datoteka.
boost_fit_reg <- readRDS("modeli/boost_fit_reg.rds")
boost_tuning_reg <- readRDS("modeli/boost_tuning_reg.rds")
boost_work_reg <- readRDS("modeli/boost_work_reg.rds")
boost_tuning_reg %>% autoplot()
Vrijednosti hiperparametara i metrika
20 najboljih modela za rmse metriku
boost_tuning_reg %>% show_best(metric = 'rmse', n = 20)
Metrike na testnom i trening skupu
tab1 <- boost_fit_reg %>% collect_predictions() %>%
metrike_reg(truth = body_mass_g, estimate = .pred,)
tab2 <- boost_work_reg %>%
metrike_reg(truth = body_mass_g, estimate = .pred)
tab1 %>% inner_join(tab2, by = ".metric") %>%
select(.metric, .estimator = .estimator.x, trening = .estimate.y, test = .estimate.x)
ggplot(boost_work_reg, aes(x = body_mass_g, y = .pred)) +
geom_point(size = 1.5) + geom_abline(color = "#00BFC4") +
ylab("predikcija")
Predikcije na trening skupu
test_podaci <- boost_fit_reg %>% collect_predictions()
ggplot(test_podaci, aes(x = body_mass_g, y = .pred)) +
geom_point(size = 1.5, alpha = 0.7) +
geom_abline(color = "#00BFC4") + ylab("predikcija")
Predikcije na testnom skupu
Nacrtajmo na primjer 5. i 33. stablo u dobivenom modelu.
xgb.plot.tree(model = boost_fit_reg %>% extract_fit_engine(), trees = 5,
plot_width = 700, plot_height = 450)
gr1 <- xgb.plot.tree(model = boost_fit_reg %>% extract_fit_engine(),
trees = 33, render = FALSE)
export_graph(gr1, 'tree33.pdf', width=1000, height=900)
Možete otvoriti veću PDF sliku.
xgb.plot.tree(model = boost_fit_reg %>% extract_fit_engine(), trees = 33,
plot_width = 1000, plot_height = 900)
Možemo vizualizirati distribucije vezane uz dubinu listova stabala.
xgb.ggplot.deepness(boost_fit_reg %>% extract_fit_engine())
Možemo dobiti kompleksnost samog modela tako da se sva stabla opišu preko jednog stablo kako bi se poboljšala interpretacija modela.
gr3 <- xgb.plot.multi.trees(boost_fit_reg %>% extract_fit_engine(), render = FALSE)
export_graph(gr3, 'tree_reg.pdf', width=900, height=1800)
S obzirom da je stablo dosta veliko, stavljen je samo link na PDF sliku.
Važnost varijabli
(vaznost_reg <- xgb.importance(model = boost_fit_reg %>% extract_fit_engine()))
xgb.ggplot.importance(vaznost_reg) + ggtitle("Importance (gain)")
xgb.ggplot.importance(vaznost_reg, measure = "Cover") +
ggtitle("Importance (cover)")
xgb.ggplot.importance(vaznost_reg, measure = "Frequency") +
ggtitle("Importance (frequency)")