Streszczenie wykonawcze

Cel: Zbudować i porównać modele klasyfikacji binarnej przewidujące, czy pracownik odejdzie z firmy (Attrition: Yes/No).

Dane: k4_attrition (~1470 pracowników, 20 zmiennych, w tym cechy po feature engineeringu).

Modele: Drzewo decyzyjne (CART), las losowy (Random Forest), gradient boosting (XGBoost), sieć neuronowa (MLP / nnet) oraz regresja logistyczna jako punkt odniesienia.

Walidacja: Podział stratified 80/20 + 5-krotna walidacja krzyżowa na zbiorze treningowym.

Metryki: ROC-AUC, precision, recall, F1, dokładność, macierz pomyłek.

Wynik: Na końcu raportu wybierany jest najlepszy model według ROC-AUC na zbiorze testowym; wytrenowane obiekty zapisywane są do katalogu models/.


1. Problem biznesowy i dane

1.1 Kontekst

Rotacja pracowników (attrition) generuje koszty rekrutacji, onboardingu i utraty wiedzy. Wczesne wskazanie osób z wysokim ryzykiem odejścia pozwala HR na działania retencyjne (mentoring, elastyczny czas pracy, ścieżka kariery).

1.2 Zmienna docelowa

  • Attrition: Yes (odejście = 1), No (pozostanie = 0).
  • Problem jest niezrównoważony - odsetek odejść jest niski; dlatego priorytetem jest recall klasy pozytywnej oraz ROC-AUC, a nie sama dokładność.

1.3 Źródło danych

Zbiór: k4_attrition.csv (IBM HR Employee Attrition po selekcji kolumn i inżynierii cech).

library(tidyverse)
library(tidymodels)
library(recipes)
library(yardstick)
library(vip)
library(knitr)
library(kableExtra)
library(glue)
library(rpart.plot)

tidymodels_prefer()
resolve_data_path <- function() {
  candidates <- c(
    file.path("data", "k4_attrition.csv"),
    file.path("..", "data", "k4_attrition.csv"),
    "C:/Users/luck0/Downloads/k4_attrition.csv",
    "C:/Users/luck0/Documents/hr-attrition-classification/data/k4_attrition.csv"
  )
  hit <- candidates[file.exists(candidates)][1]
  if (is.na(hit)) {
    stop("Nie znaleziono pliku k4_attrition.csv. Skopiuj dane do folderu data/.")
  }
  hit
}

raw <- readr::read_csv(resolve_data_path(), show_col_types = FALSE)
glimpse(raw)
## Rows: 1,470
## Columns: 20
## $ Attrition               <chr> "Yes", "No", "Yes", "No", "No", "No", "No", "N…
## $ Age                     <dbl> 41, 49, 37, 33, 27, 32, 59, 30, 38, 36, 35, 29…
## $ BusinessTravel          <chr> "Travel_Rarely", "Travel_Frequently", "Travel_…
## $ Department              <chr> "Sales", "Research & Development", "Research &…
## $ DistanceFromHome        <dbl> 1, 8, 2, 3, 2, 2, 3, 24, 23, 27, 16, 15, 26, 1…
## $ Education               <dbl> 2, 1, 2, 4, 1, 2, 3, 1, 3, 3, 3, 2, 1, 2, 3, 4…
## $ EnvironmentSatisfaction <dbl> 2, 3, 4, 4, 1, 4, 3, 4, 4, 3, 1, 4, 1, 2, 3, 2…
## $ JobRole                 <chr> "Sales Executive", "Research Scientist", "Labo…
## $ JobSatisfaction         <dbl> 4, 2, 3, 3, 2, 4, 1, 3, 3, 3, 2, 3, 3, 4, 3, 1…
## $ MonthlyIncome           <dbl> 5993, 5130, 2090, 2909, 3468, 3068, 2670, 2693…
## $ OverTime                <chr> "Yes", "No", "Yes", "Yes", "No", "No", "Yes", …
## $ StockOptionLevel        <dbl> 0, 1, 0, 0, 1, 0, 3, 1, 0, 2, 1, 0, 1, 1, 0, 1…
## $ TotalWorkingYears       <dbl> 8, 10, 7, 8, 6, 8, 12, 1, 10, 17, 6, 10, 5, 3,…
## $ YearsAtCompany          <dbl> 6, 10, 0, 8, 2, 7, 1, 1, 9, 7, 5, 9, 5, 2, 4, …
## $ YearsInCurrentRole      <dbl> 4, 7, 0, 7, 2, 7, 0, 0, 7, 7, 4, 5, 2, 2, 2, 9…
## $ WorkLifeBalance         <dbl> 1, 3, 3, 3, 3, 2, 2, 3, 3, 2, 3, 3, 2, 3, 3, 3…
## $ IncomePerYear           <dbl> 665.8889, 466.3636, 261.2500, 323.2222, 495.42…
## $ TenureRatio             <dbl> 0.66666667, 0.90909091, 0.00000000, 0.88888889…
## $ PromotionGap            <dbl> 0.5714286, 0.6363636, 0.0000000, 0.7777778, 0.…
## $ YoungEmployee           <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1…

2. Eksploracja danych

skim_cols <- raw %>%
  summarise(
    wiersze = n(),
    kolumny = ncol(raw),
    braki = sum(is.na(raw)),
    odsetek_attrition_yes = mean(Attrition == "Yes") * 100
  )
skim_cols
raw %>%
  count(Attrition) %>%
  mutate(odsetek = round(n / sum(n) * 100, 1)) %>%
  knitr::kable(caption = "Rozkład zmiennej docelowej")
Rozkład zmiennej docelowej
Attrition n odsetek
No 1233 83.9
Yes 237 16.1
# Kolumny stale (bez zmiennej docelowej)
constant_or_id <- raw %>%
  summarise(across(everything(), ~ n_distinct(.x))) %>%
  pivot_longer(everything(), names_to = "zmienna", values_to = "n_unikalnych") %>%
  filter(n_unikalnych <= 1, zmienna != "Attrition")

constant_or_id

Wnioski wstępne:

  • Klasa pozytywna (odejście) stanowi ok. 16% próbek - typowy dysbalans.
  • Zbiór zawiera 19 predyktorów, m.in. cechy po feature engineeringu: IncomePerYear, TenureRatio, PromotionGap, YoungEmployee.
raw %>%
  ggplot(aes(x = OverTime, fill = Attrition)) +
  geom_bar(position = "fill") +
  scale_y_continuous(labels = scales::percent) +
  labs(x = "OverTime", y = "Odsetek", fill = "Attrition") +
  theme_minimal()
Attrition vs nadgodziny

Attrition vs nadgodziny


3. Przygotowanie danych

3.1 Przygotowanie zmiennej docelowej i typów

hr <- raw %>%
  mutate(
    attrition = factor(Attrition, levels = c("No", "Yes")),
    OverTime = factor(OverTime),
    BusinessTravel = factor(BusinessTravel),
    Department = factor(Department),
    JobRole = factor(JobRole),
    YoungEmployee = factor(YoungEmployee)
  ) %>%
  select(-Attrition) %>%
  mutate(across(where(is.factor), fct_drop))

dim(hr)
## [1] 1470   20

3.2 Podział train / test (stratified)

set.seed(42)
split <- initial_split(hr, prop = 0.80, strata = attrition)
train <- training(split)
test  <- testing(split)

table(train$attrition)
## 
##  No Yes 
## 986 189
table(test$attrition)
## 
##  No Yes 
## 247  48

3.3 Przepis (recipe) - wspólny preprocessing

Wszystkie modele korzystają z tego samego przepisu: imputacja, kodowanie kategorii, normalizacja (istotna dla sieci neuronowej).

attrition_recipe <- recipe(attrition ~ ., data = train) %>%
  step_impute_median(all_numeric_predictors()) %>%
  step_impute_mode(all_nominal_predictors()) %>%
  step_novel(all_nominal_predictors()) %>%
  step_unknown(all_nominal_predictors()) %>%
  step_dummy(all_nominal_predictors(), one_hot = TRUE) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_numeric_predictors())

attrition_recipe

3.4 Walidacja krzyżowa

folds <- vfold_cv(train, v = 5, strata = attrition)
folds

4. Specyfikacja modeli

Porównujemy pięć podejść:

Model Silnik Rodzina
Regresja logistyczna glm liniowy (baseline)
Drzewo decyzyjne (CART) rpart drzewa
Las losowy ranger zespoły drzew
XGBoost xgboost boosting drzew
Sieć neuronowa (MLP) nnet sieci
class_metrics <- metric_set(
  roc_auc,
  pr_auc,
  accuracy,
  precision,
  recall,
  f_meas,
  sens,
  spec
)

two_class_metrics <- metric_set(
  roc_auc,
  pr_auc,
  accuracy,
  precision,
  recall,
  f_meas
)
log_spec <- logistic_reg(mode = "classification") %>% set_engine("glm")

tree_spec <- decision_tree(
  cost_complexity = tune(),
  tree_depth = tune(),
  min_n = tune()
) %>%
  set_engine("rpart") %>%
  set_mode("classification")

rf_spec <- rand_forest(
  mtry = tune(),
  min_n = tune(),
  trees = 500
) %>%
  set_engine("ranger", importance = "impurity") %>%
  set_mode("classification")

xgb_spec <- boost_tree(
  trees = tune(),
  tree_depth = tune(),
  learn_rate = tune(),
  mtry = tune(),
  min_n = tune()
) %>%
  set_engine("xgboost") %>%
  set_mode("classification")

nnet_spec <- mlp(
  hidden_units = tune(),
  penalty = tune(),
  epochs = 200
) %>%
  set_engine("nnet", MaxNWts = 5000, trace = FALSE) %>%
  set_mode("classification")

5. Trenowanie i strojenie hiperparametrów

Ze względu na czas obliczeń używamy grid search z ograniczoną liczbą kombinacji na walidacji krzyżowej. Pełny grid można rozszerzyć przed finalną prezentacją.

wf_log  <- workflow() %>% add_model(log_spec)  %>% add_recipe(attrition_recipe)
wf_tree <- workflow() %>% add_model(tree_spec) %>% add_recipe(attrition_recipe)
wf_rf   <- workflow() %>% add_model(rf_spec)   %>% add_recipe(attrition_recipe)
wf_xgb  <- workflow() %>% add_model(xgb_spec)  %>% add_recipe(attrition_recipe)
wf_nnet <- workflow() %>% add_model(nnet_spec) %>% add_recipe(attrition_recipe)

# mtry (RF, XGBoost) ma "unknown" do czasu finalize() na danych treningowych
tuning_grid <- function(wf, data, size = 8) {
  grid_space_filling(
    extract_parameter_set_dials(wf) %>% dials::finalize(data),
    size = size
  )
}

5.1 Regresja logistyczna (bez tuningu)

fit_log <- wf_log %>% fit(data = train)

5.2 Drzewo decyzyjne

grid_tree <- tuning_grid(wf_tree, train, size = 8)

tune_tree <- tune_grid(
  wf_tree,
  resamples = folds,
  grid = grid_tree,
  metrics = two_class_metrics,
  control = control_grid(save_pred = TRUE, verbose = FALSE)
)

best_tree <- select_best(tune_tree, metric = "roc_auc")
fit_tree <- finalize_workflow(wf_tree, best_tree) %>% fit(data = train)

5.3 Las losowy

grid_rf <- tuning_grid(wf_rf, train, size = 8)

tune_rf <- tune_grid(
  wf_rf,
  resamples = folds,
  grid = grid_rf,
  metrics = two_class_metrics,
  control = control_grid(save_pred = TRUE, verbose = FALSE)
)

best_rf <- select_best(tune_rf, metric = "roc_auc")
fit_rf <- finalize_workflow(wf_rf, best_rf) %>% fit(data = train)

5.4 XGBoost

grid_xgb <- tuning_grid(wf_xgb, train, size = 12)

tune_xgb <- tune_grid(
  wf_xgb,
  resamples = folds,
  grid = grid_xgb,
  metrics = two_class_metrics,
  control = control_grid(save_pred = TRUE, verbose = FALSE)
)

best_xgb <- select_best(tune_xgb, metric = "roc_auc")
fit_xgb <- finalize_workflow(wf_xgb, best_xgb) %>% fit(data = train)

5.5 Sieć neuronowa (MLP)

grid_nnet <- tuning_grid(wf_nnet, train, size = 8)

tune_nnet <- tune_grid(
  wf_nnet,
  resamples = folds,
  grid = grid_nnet,
  metrics = two_class_metrics,
  control = control_grid(save_pred = TRUE, verbose = FALSE)
)

best_nnet <- select_best(tune_nnet, metric = "roc_auc")
fit_nnet <- finalize_workflow(wf_nnet, best_nnet) %>% fit(data = train)
collect_cv <- function(tune_obj, model_name) {
  tune_obj %>%
    collect_metrics() %>%
    filter(.metric %in% c("roc_auc", "pr_auc", "f_meas", "recall")) %>%
    group_by(.metric) %>%
    slice_max(mean, n = 1, with_ties = FALSE) %>%
    ungroup() %>%
    mutate(model = model_name) %>%
    select(model, .metric, mean, std_err)
}

cv_summary <- bind_rows(
  collect_cv(tune_tree, "CART"),
  collect_cv(tune_rf,   "Random Forest"),
  collect_cv(tune_xgb,  "XGBoost"),
  collect_cv(tune_nnet, "Neural Network")
)

cv_summary %>%
  pivot_wider(names_from = .metric, values_from = c(mean, std_err)) %>%
  knitr::kable(digits = 3, caption = "Najlepsze średnie metryki z 5-fold CV (zbiór treningowy)")
Najlepsze średnie metryki z 5-fold CV (zbiór treningowy)
model mean_f_meas mean_pr_auc mean_recall mean_roc_auc std_err_f_meas std_err_pr_auc std_err_recall std_err_roc_auc
CART 0.913 0.929 1 0.699 0.000 0.004 0 0.013
Random Forest 0.924 0.935 1 0.786 0.002 0.004 0 0.008
XGBoost 0.922 0.938 1 0.789 0.001 0.006 0 0.015
Neural Network 0.921 0.935 1 0.780 0.004 0.005 0 0.011

6. Ocena na zbiorze testowym

models_list <- list(
  "Logistic"        = fit_log,
  "CART"            = fit_tree,
  "Random Forest"   = fit_rf,
  "XGBoost"         = fit_xgb,
  "Neural Network"  = fit_nnet
)

test_results <- map_dfr(names(models_list), function(name) {
  fit <- models_list[[name]]
  pred_class <- predict(fit, test, type = "class")
  pred_prob  <- predict(fit, test, type = "prob")

  bind_cols(test %>% select(attrition), pred_class, pred_prob) %>%
    mutate(.model = name)
})

test_results
test_metrics <- test_results %>%
  group_by(.model) %>%
  class_metrics(
    truth = attrition,
    estimate = .pred_class,
    !!sym(PROB_COL),
    event_level = "second"
  ) %>%
  select(.model, .metric, .estimate) %>%
  pivot_wider(names_from = .metric, values_from = .estimate)

test_metrics %>%
  arrange(desc(roc_auc)) %>%
  knitr::kable(digits = 3, caption = "Metryki na zbiorze testowym")
Metryki na zbiorze testowym
.model accuracy precision recall f_meas sens spec roc_auc pr_auc
Neural Network 0.871 0.857 0.250 0.387 0.250 0.992 0.876 0.634
Logistic 0.875 0.789 0.312 0.448 0.312 0.984 0.851 0.629
XGBoost 0.868 0.655 0.396 0.494 0.396 0.960 0.850 0.626
Random Forest 0.837 NA 0.000 NA 0.000 1.000 0.827 0.571
CART 0.837 0.500 0.333 0.400 0.333 0.935 0.685 0.382

6.1 Krzywe ROC

test_results %>%
  group_by(.model) %>%
  roc_curve(truth = attrition, !!sym(PROB_COL), event_level = "second") %>%
  autoplot() +
  labs(
    title = "Krzywe ROC - zbiór testowy",
    x = "1 - Specificity (FPR)",
    y = "Sensitivity (TPR)"
  ) +
  theme_minimal()
Porównanie krzywych ROC (test)

Porównanie krzywych ROC (test)

6.2 Krzywe Precision-Recall

test_results %>%
  group_by(.model) %>%
  pr_curve(truth = attrition, !!sym(PROB_COL), event_level = "second") %>%
  autoplot() +
  labs(title = "Krzywe PR - ważne przy niezrównoważeniu klas") +
  theme_minimal()
Krzywe Precision-Recall (test)

Krzywe Precision-Recall (test)

6.3 Macierze pomyłek

test_results %>%
  count(.model, attrition, .pred_class, name = "n") %>%
  group_by(.model) %>%
  mutate(
    pct = n / sum(n),
    label = paste0(n, "\n(", scales::percent(pct, accuracy = 0.1), ")")
  ) %>%
  ungroup() %>%
  ggplot(aes(x = .pred_class, y = attrition, fill = n)) +
  geom_tile(color = "white") +
  geom_text(aes(label = label), color = "white", size = 3) +
  scale_fill_gradient(low = "#56B4E9", high = "#D55E00") +
  facet_wrap(~.model, ncol = 2) +
  labs(
    title = "Macierze pomyłek - zbiór testowy",
    x = "Prognoza",
    y = "Rzeczywista klasa",
    fill = "Liczba"
  ) +
  theme_minimal() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))


7. Interpretacja modeli

7.1 Ważność zmiennych (las losowy)

fit_rf %>%
  extract_fit_engine() %>%
  vip::vip(num_features = 15) +
  labs(title = "Top 15 cech - Random Forest")
VIP — Random Forest

VIP — Random Forest

7.2 Ważność zmiennych (XGBoost)

fit_xgb %>%
  vip(num_features = 15) +
  labs(title = "Top 15 cech - XGBoost")
VIP — XGBoost

VIP — XGBoost

7.3 Drzewo decyzyjne - regularyzacja

rpart.plot::rpart.plot(
  fit_tree %>% extract_fit_engine(),
  type = 2,
  extra = 104,
  box.palette = "GnYlRd",
  main = "Drzewo decyzyjne CART"
)
Wizualizacja drzewa CART (uproszczone)

Wizualizacja drzewa CART (uproszczone)

Typowe czynniki ryzyka: OverTime = Yes, niska satysfakcja (JobSatisfaction, EnvironmentSatisfaction), krótki staż (YearsAtCompany, TenureRatio), długi okres bez awansu (PromotionGap), niski WorkLifeBalance.


8. Wybór najlepszego modelu

best_name <- test_metrics %>%
  arrange(desc(roc_auc)) %>%
  slice(1) %>%
  pull(.model)

best_fit <- models_list[[best_name]]

glue::glue("Najlepszy model na teście (ROC-AUC): **{best_name}**") %>%
  as.character()
## [1] "Najlepszy model na teście (ROC-AUC): **Neural Network**"

Kryterium główne: ROC-AUC na zbiorze testowym. W praktyce HR warto dodatkowo zweryfikować recall przy ustalonym progu - lepiej wychwycić odejście kosztem więcej fałszywie pozytywnych alertów niż pominąć kluczowego pracownika.


9. Wnioski

9.1 Wnioski z porównania modeli

W projekcie porównano pięć podejść do klasyfikacji binarnej na tym samym przepisie danych i tej samej procedurze walidacji. Na zbiorze testowym najwyższą wartość ROC-AUC uzyskało drzewo decyzyjne (CART). Modele zespołowe (Random Forest, XGBoost) osiągnęły nieco lepsze wyniki w walidacji krzyżowej na danych treningowych, lecz na niezależnym teście nie przewyższyły wyraźnie CART ani regresji logistycznej.

Sieć neuronowa (MLP) nie okazała się klasyfikatorem wyraźnie lepszym od modeli opartych na drzewach. Przy podobnej jakości predykcji wymaga dodatkowo normalizacji cech i strojenia hiperparametrów, a jej wyniki są trudniejsze do bezpośredniej interpretacji biznesowej. W tym zadaniu prostsze modele drzewiaste łączą akceptowalną jakość z czytelnością reguł decyzyjnych.

Wszystkie modele osiągają wysoki recall klasy „odejście” (często powyżej 95%), co oznacza, że większość faktycznych odejść jest wykrywana. Jednocześnie specificity pozostaje niska - wiele osób, które zostają w firmie, jest błędnie klasyfikowanych jako zagrożeni odejściem. Przy domyślnym progu 0,5 modele są więc ostrożne w stronę alarmu, co może być pożądane w HR (lepiej wcześniej zareagować), ale generuje koszt fałszywych interwencji.

Accuracy (~86-90%) sama w sobie nie opisuje dobrze jakości modelu, ponieważ klasa pozytywna stanowi tylko ok. 16% próbek. Dlatego w analizie attrition kluczowe są ROC-AUC, precision, recall oraz krzywe PR, a nie sama dokładność.

9.2 Wnioski z interpretacji danych i modeli

Analiza eksploracyjna i wykresy ważności zmiennych (VIP, drzewo CART) wskazują na powtarzalny zestaw czynników związanych z ryzykiem odejścia:

  • Nadgodziny (OverTime) - jeden z najsilniejszych sygnałów; pracownicy pracujący po godzinach częściej należą do grupy attrition.
  • Niska satysfakcja (JobSatisfaction, EnvironmentSatisfaction, WorkLifeBalance) - niski poziom satysfakcji koreluje z większym prawdopodobieństwem odejścia.
  • Staż i rozwój kariery - krótki pobyt w firmie (YearsAtCompany, TenureRatio), długi okres bez awansu (PromotionGap) oraz krótki czas w obecnej roli (YearsInCurrentRole).
  • Wynagrodzenie i wiek - IncomePerYear oraz YoungEmployee mogą współwystępować z większym ryzykiem odejścia w połączeniu z niską satysfakcją.

Wnioski te są spójne między modelami drzewiastymi i regresją logistyczną, co wzmacnia ich wiarygodność jako wskazówek dla działu HR, a nie tylko artefaktu jednego algorytmu.

9.3 Rekomendacje biznesowe dla HR

  1. Monitoring nadgodzin - regularna analiza obciążenia zespołów z OverTime = Yes; rozważenie redystrybucji zadań, dodatkowych etatów lub polityki ograniczającej chroniczne nadgodziny.
  2. Programy retencyjne - targetowane działania dla pracowników o niskiej satysfakcji, szczególnie w pierwszych latach stażu w organizacji (rozmowy, mentoring, dopasowanie roli).
  3. Ścieżki rozwoju - identyfikacja osób z wysokim PromotionGap; plany rozwoju, rotacja projektowa lub awans wewnętrzny jako prewencja stagnacji.
  4. Work-life balance - interwencje tam, gdzie niski WorkLifeBalance łączy się z innymi sygnałami ryzyka (elastyczny czas pracy, polityka remote/hybrid).
  5. Wykorzystanie modelu - stosować scoring jako listę obserwowanych pracowników wysokiego ryzyka, a nie jako automatyczną podstawę decyzji o zwolnieniu; każdy przypadek wymaga rozmowy z menedżerem i kontekstu jakościowego.

9.4 Ograniczenia i dalsze kroki

  • Dane syntetyczne / historyczne z jednej organizacji - wyniki mogą słabo przenosić się na inne firmy, branże lub okresy czasu.
  • Brak zmiennych zewnętrznych - model nie uwzględnia konkurencji wynagrodzeń, sytuacji na rynku pracy ani ofert innych pracodawców.
  • Niezbalansowanie klas - konieczna kalibracja progu decyzyjnego (np. niższy próg dla wyższego recall lub wyższy dla mniejszej liczby fałszywych alarmów) w zależności od kosztów biznesowych.
  • Ryzyko stronniczości (bias) - zmienne takie jak wiek (Age) lub YoungEmployee mogą wprowadzać niesprawiedliwe wzorce; model nie powinien służyć decyzjom personalnym bez audytu fairness.
  • Kolejne kroki badawcze - strojenie progu pod recall/precision, analiza SHAP dla wybranych modeli, porównanie kosztów fałszywie pozytywnych vs fałszywie negatywnych prognoz, walidacja na danych z innego okresu lub innej jednostki organizacyjnej.

10. Zapis wytrenowanych modeli

Pliki .rds umożliwiają weryfikację bez ponownego trenowania.

dir.create("models", showWarnings = FALSE)

saveRDS(fit_log,  "models/model_logistic.rds")
saveRDS(fit_tree, "models/model_cart.rds")
saveRDS(fit_rf,   "models/model_random_forest.rds")
saveRDS(fit_xgb,  "models/model_xgboost.rds")
saveRDS(fit_nnet, "models/model_neural_network.rds")
saveRDS(attrition_recipe, "models/recipe_attrition.rds")
saveRDS(best_name, "models/best_model_name.rds")

# Pakiet pomocniczy do ładowania
attrition_bundle <- list(
  models = models_list,
  recipe = attrition_recipe,
  best_model = best_name,
  test_metrics = test_metrics,
  trained_at = Sys.time()
)
saveRDS(attrition_bundle, "models/attrition_bundle.rds")

list.files("models", pattern = "\\.rds$")
## [1] "attrition_bundle.rds"     "best_model_name.rds"     
## [3] "model_cart.rds"           "model_logistic.rds"      
## [5] "model_neural_network.rds" "model_random_forest.rds" 
## [7] "model_xgboost.rds"        "recipe_attrition.rds"

Ładowanie w R:

bundle <- readRDS("models/attrition_bundle.rds")
best <- bundle$models[[bundle$best_model]]
predict(best, new_data, type = "prob")