Predicting Employee Churn Using Organizational Network Analysis

A Comparative Machine Learning Evaluation of Workplace Attrition Risk

Author

Amanuel Gebremariam

Published

June 4, 2026

Packages

We load the packages needed to wrangle the employee data, train the predictive models, and evaluate their predictions.

Code
# fs to navigate project directory
library(fs)

# tidyverse for data wrangling
library(tidyverse)

# snakecase for naming conventions
library(snakecase)

# scales for scaling variables
library(scales)

# ggthemes for plot themes
library(ggthemes)

# tidymodels for machine learning
library(tidymodels)

# skimr to summarize data
library(skimr)

# ranger for random forest
library(ranger)

# future for future parallel architecture
library(future)

# doFuture for future parallel backend
library(doFuture)

# parallelly for parallel computation
library(parallelly)

# tictoc for timing computations
library(tictoc)

# vip for predictor importance in models
library(vip)

Import Data

We import the employee_net.rdata data file.

Code
### Import data ----
## load() restores every object stored in the .rdata file into the current
## environment; the file sits in the "data" folder one level above the
## scripts folder
load(path("..", "data", "employee_net.rdata"))

## confirm which objects are now available in the environment
ls()
[1] "edges" "nodes"

Examine Data

We examine the nodes table loaded from the employee_net.rdata file to become familiar with the data.

Code
### Churn base rate ----
## tally employees per churn status, then express each tally as a share of
## all employees
mutate(
    count(nodes, churn),
    prop = n / sum(n)
)

Clean Data

We clean the data based on our examination of the data.

Code
### Connection counts per employee ----
## outbound: number of edges on which the employee is the sender (`from`);
## the grouping column is named `id` directly so it matches the node table
outbound_connections <- count(edges, id = from, name = "outbound_count")

## inbound: number of edges on which the employee is the receiver (`to`)
inbound_connections <- count(edges, id = to, name = "inbound_count")

### Working data table ----
## One row per employee: the node attributes plus the number of connections
## the employee initiated (outbound), received (inbound), and the sum of both.
network_work <- nodes %>%
    ## attach the connection counts; employees missing from a count table
    ## receive NA from the left join
    left_join(outbound_connections, by = "id") %>%
    left_join(inbound_connections, by = "id") %>%
    mutate(
        # treat id as a label rather than a number (converted after the joins
        # so the join keys keep their original type)
        id = as.character(id),
        # outcome as a factor for classification; the former
        # "No" = "No" / "Yes" = "Yes" recode changed nothing and is dropped,
        # so labels and level order are untouched
        churn = as_factor(churn),
        # numeric 0/1 copy of the outcome for plotting
        churn_int = if_else(churn == "Yes", 1, 0),
        # a missing count means no connections of that kind; fill with an
        # integer zero so the columns stay integer
        outbound_count = replace_na(outbound_count, 0L),
        inbound_count = replace_na(inbound_count, 0L),
        # overall network size
        total_connections = outbound_count + inbound_count
    ) %>%
    as_tibble()

## preview
glimpse(network_work)
Rows: 1,000
Columns: 6
$ id                <chr> "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "…
$ churn             <fct> No, No, No, No, No, No, No, No, No, No, No, No, No, …
$ outbound_count    <int> 460, 136, 445, 431, 447, 418, 419, 427, 412, 431, 42…
$ inbound_count     <int> 0, 0, 1, 2, 1, 2, 3, 1, 6, 4, 6, 6, 6, 5, 6, 9, 9, 1…
$ churn_int         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ total_connections <int> 460, 136, 446, 433, 448, 420, 422, 428, 418, 435, 43…

Variable Relationships

We evaluate the relationships between our aggregated network predictors and our outcome of interest (churn).

Summary of Network Predictors by Churn

We summarize the distributions of outbound, inbound, and total connections for employees who stayed versus those who left the organization.

Code
### Network metrics by churn status ----
## keep the three connection counts plus the grouping column, then let skimr
## report summary statistics separately for each churn group
network_work %>%
    select(outbound_count, inbound_count, total_connections, churn) %>%
    group_by(churn) %>%
    skim_without_charts()
Data summary
Name Piped data
Number of rows 1000
Number of columns 4
_______________________
Column type frequency:
numeric 3
________________________
Group variables churn

Variable type: numeric

skim_variable churn n_missing complete_rate mean sd p0 p25 p50 p75 p100
outbound_count Yes 0 1 146.95 125.50 2 48.0 100 237.0 440
outbound_count No 0 1 205.72 131.05 0 88.5 198 319.5 460
inbound_count Yes 0 1 171.39 137.04 3 61.0 111 294.0 464
inbound_count No 0 1 201.44 128.96 0 86.0 189 313.0 455
total_connections Yes 0 1 318.34 153.86 106 130.0 428 447.0 477
total_connections No 0 1 407.17 96.74 105 425.5 438 449.0 488

Visualize Churn Probability by Outbound Connections

We visualize the relationship between the number of outbound connections initiated by an employee and their probability of churning.

Code
### Outbound connections vs. churn ----
## Each employee is drawn as a jittered, semi-transparent point on the 0/1
## outcome; the line is a logistic regression of churn_int on outbound_count.
network_work %>%
    ggplot(aes(x = outbound_count, y = churn_int)) +
    geom_jitter(height = 0.05, alpha = 0.3, color = "gray30") +
    geom_smooth(
        method = "glm",
        method.args = list(family = "binomial"),
        se = FALSE,
        color = "royalblue",
        linewidth = 1.2
    ) +
    labs(
        title = "Churn Probability by Outbound Connections",
        subtitle = "Does initiating more connections reduce an employee's likelihood of leaving?",
        x = "Number of Outbound Connections",
        y = "Probability of Churn (0 to 1)"
    ) +
    theme_minimal()

Visualize Churn Probability by Total Network Connections

We visualize the relationship between the total number of connections (inbound + outbound) an employee has and their probability of churning.

Code
### Total connections vs. churn ----
## Same layout as the outbound plot, with the combined inbound + outbound
## count on the x-axis and the logistic fit drawn in red.
network_work %>%
    ggplot(aes(x = total_connections, y = churn_int)) +
    geom_jitter(height = 0.05, alpha = 0.3, color = "gray30") +
    geom_smooth(
        method = "glm",
        method.args = list(family = "binomial"),
        se = FALSE,
        color = "firebrick",
        linewidth = 1.2
    ) +
    labs(
        title = "Churn Probability by Total Network Connections",
        subtitle = "Does an employee's total network size impact their retention?",
        x = "Total Network Connections (Inbound + Outbound)",
        y = "Probability of Churn (0 to 1)"
    ) +
    theme_minimal()

Training Predictive Models

In this section, we train several models to predict which employees are at risk of leaving, using our network_work data table.

Split Data

We split the original network_work data into subsets for training and testing. We additionally split the training data into cross-validation subsets.

Code
### Data partitions ----
## fix the random number generator so the split and folds are reproducible
set.seed(1001)

## 75% of employees for training, the rest for testing; stratifying on churn
## keeps the outcome shares similar in both partitions
network_split <- initial_split(network_work, prop = 0.75, strata = churn)
network_train <- training(network_split)
network_test <- testing(network_split)

## 5-fold cross-validation of the training data, repeated twice and
## stratified on churn, for resampled model evaluation and tuning
network_folds <- vfold_cv(network_train, v = 5, repeats = 2, strata = churn)

Create Model Recipe

We create a base recipe for our various predictive models.

Code
### Preprocessing recipe ----
## churn is the outcome and every other column starts out as a predictor
network_recipe <- recipe(churn ~ ., data = network_train) %>%
    ## drop the employee id, the 0/1 copy of the outcome, and
    ## total_connections (the sum of the two counts that remain)
    step_rm(id, churn_int, total_connections) %>%
    ## center and scale the remaining numeric predictors
    step_normalize(all_numeric_predictors())

### Preview the preprocessed training data ----
## prep() estimates the recipe on the training data; bake(new_data = NULL)
## returns that processed training data
network_recipe %>%
    prep() %>%
    bake(new_data = NULL) %>%
    print(width = Inf)
# A tibble: 749 × 3
   outbound_count inbound_count churn
            <dbl>         <dbl> <fct>
 1          2.03          -1.51 No   
 2         -0.464         -1.51 No   
 3          1.91          -1.51 No   
 4          1.80          -1.50 No   
 5          1.93          -1.51 No   
 6          1.70          -1.50 No   
 7          1.71          -1.49 No   
 8          1.77          -1.51 No   
 9          1.77          -1.47 No   
10          1.88          -1.47 No   
# ℹ 739 more rows

Metrics

We specify the metrics to track as we evaluate the quality of the models to make accurate predictions on the outcome of interest.

Code
### Evaluation metrics ----
## one metric set shared by every model so results are directly comparable
class_metrics <- metric_set(
    # metrics computed from hard class predictions
    accuracy,
    bal_accuracy,
    detection_prevalence,
    f_meas,
    j_index,
    mcc,
    npv,
    spec,
    precision,
    recall,
    # metrics computed from predicted class probabilities
    pr_auc,
    roc_auc
)

Train Logistic Regression

We train a logistic regression model to predict employee churn in our training data.

Code
### Logistic regression: specification ----
## classification model estimated with the glm engine
log_reg_spec <- logistic_reg() %>%
    set_mode("classification") %>%
    set_engine("glm")

### Logistic regression: workflow ----
## bundle the preprocessing recipe with the model specification
log_reg_wf <- workflow() %>%
    add_recipe(network_recipe) %>%
    add_model(log_reg_spec)

### Logistic regression: resampled performance ----
## fit on each cross-validation split, keeping the out-of-sample predictions
log_reg_fit_rs <- log_reg_wf %>%
    fit_resamples(
        resamples = network_folds,
        metrics = class_metrics,
        control = control_resamples(save_pred = TRUE)
    )

## metrics averaged across the resamples
collect_metrics(log_reg_fit_rs)
Code
### Logistic regression: final fit ----
## estimate the workflow once on the complete training partition
log_reg_fit <- log_reg_wf %>%
    fit(data = network_train)

## print the trained workflow (recipe steps and glm coefficients)
log_reg_fit
══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: logistic_reg()

── Preprocessor ────────────────────────────────────────────────────────────────
2 Recipe Steps

• step_rm()
• step_normalize()

── Model ───────────────────────────────────────────────────────────────────────

Call:  stats::glm(formula = ..y ~ ., family = stats::binomial, data = data)

Coefficients:
   (Intercept)  outbound_count   inbound_count  
        1.8831          0.7484          0.6153  

Degrees of Freedom: 748 Total (i.e. Null);  746 Residual
Null Deviance:      628.5 
Residual Deviance: 581.5    AIC: 587.5

Train Random Forest

We tune a random forest across the cross-validation folds and then fit the best configuration to the full training data to predict employee churn.

Code
### Random forest: specification ----
## all three hyperparameters are left open for tuning; impurity importance is
## requested from ranger so predictor importance can be extracted later
rf_spec <- rand_forest(mtry = tune(), trees = tune(), min_n = tune()) %>%
    set_mode("classification") %>%
    set_engine("ranger", importance = "impurity")

### Random forest: workflow ----
## bundle the preprocessing recipe with the model specification
rf_wf <- workflow() %>%
    add_recipe(network_recipe) %>%
    add_model(rf_spec)

### Random forest: tuning grid ----
## two values per hyperparameter, fully crossed
rf_grid <- grid_regular(
    # predictors sampled at each split; the recipe leaves 2 predictors
    # (outbound_count and inbound_count), so the range is 1 to 2
    mtry(range = c(1, 2)),
    # number of trees in the ensemble
    trees(range = c(500, 1500)),
    # minimum number of observations required in a node
    min_n(range = c(5, 25)),
    levels = 2
)

## view the tuning grid
rf_grid
Code
### Tune the random forest in parallel ----
## start timer
tic()

## Run the tuning in background R sessions, leaving one core free so the
## system stays responsive. max() keeps the worker count at 1 or more: on a
## single-core machine `availableCores() - 1` evaluates to 0, which cannot be
## used as a number of workers.
plan(
    multisession,
    workers = max(1, availableCores() - 1)
)

## evaluate every candidate in rf_grid on every cross-validation resample
rf_tune_rs <- tune_grid(
    # workflow object
    rf_wf,
    # cross-validation folds
    resamples = network_folds,
    # hyperparameter grid
    grid = rf_grid,
    # metrics to evaluate
    metrics = class_metrics,
    # keep out-of-sample predictions and allow parallel execution over both
    # resamples and grid candidates
    control = control_grid(
        save_pred = TRUE,
        parallel_over = "everything",
        allow_par = TRUE
    )
)

## return to single-session processing
plan(sequential)

## stop timer and report the elapsed time
toc()
40.46 sec elapsed
Code
### Random forest: tuning results ----
## resample-averaged metrics for every grid candidate
rf_metrics <- rf_tune_rs %>%
    collect_metrics()

## candidates ranked by area under the ROC curve
rf_tune_rs %>%
    show_best(metric = "roc_auc")
Code
### Random forest: final model ----
## pick the candidate with the best cross-validated ROC AUC and plug its
## hyperparameters into the tunable workflow
rf_wf_final <- rf_tune_rs %>%
    select_best(metric = "roc_auc") %>%
    finalize_workflow(x = rf_wf, parameters = .)

## estimate the finalized workflow on the complete training partition
rf_fit <- rf_wf_final %>%
    fit(data = network_train)

### Random forest: predictor importance ----
## pull the parsnip model out of the workflow and plot its importance scores
rf_vip <- rf_fit %>%
    extract_fit_parsnip() %>%
    vip(num_features = 3)

## show results
rf_vip

Testing Predictive Models

We evaluate the logistic regression model and the best-performing random forest model on the testing dataset.

Evaluate Against Metrics

We evaluate the logistic regression and random forest models on the testing data against the classification and probability metrics we specified.

Code
### Test-set class probabilities ----
## logistic regression: predicted probability of each churn class, with the
## columns prefixed so they can sit next to the random forest columns
log_reg_pred <- log_reg_fit %>%
    predict(new_data = network_test, type = "prob") %>%
    rename(lr_prob_yes = .pred_Yes, lr_prob_no = .pred_No)

## random forest: same layout with an rf_ prefix
rf_pred <- rf_fit %>%
    predict(new_data = network_test, type = "prob") %>%
    rename(rf_prob_yes = .pred_Yes, rf_prob_no = .pred_No)

### One table with the observed outcome and both models' probabilities ----
test_pred <- bind_cols(
    select(network_test, churn),
    log_reg_pred,
    rf_pred
)

### Test-set metrics for both models ----
## Result: one row per model ("LR", "RF") and one column per metric.
test_pred_metrics <- test_pred %>%
    ## keep the outcome and each model's probability of "Yes", naming the
    ## probability columns with the model label so the label travels with
    ## the data instead of being assigned by position afterwards
    select(churn, LR = lr_prob_yes, RF = rf_prob_yes) %>%
    ## one row per employee and model
    pivot_longer(
        cols = c(LR, RF),
        names_to = "model",
        values_to = "prob_yes"
    ) %>%
    mutate(
        # order the outcome levels as No, Yes so "Yes" is the second level
        churn = fct_relevel(as_factor(churn), "No", "Yes"),
        # hard prediction at the 0.5 threshold; factor(levels = ) keeps both
        # levels even when no probability reaches 0.5, so truth and estimate
        # always share the same levels
        pred_yes = factor(
            if_else(prob_yes < 0.5, "No", "Yes"),
            levels = c("No", "Yes")
        )
    ) %>%
    ## the metric set is evaluated once per group
    group_by(model) %>%
    class_metrics(
        truth = churn,
        estimate = pred_yes,
        # predicted probability of the event
        prob_yes,
        # "Yes" is the second factor level
        event_level = "second"
    ) %>%
    ungroup() %>%
    ## one column per metric
    pivot_wider(
        id_cols = model,
        names_from = .metric,
        values_from = .estimate
    )

Visualize ROC Curves

We visualize the ROC (receiver operating characteristic) curves for the logistic regression and random forest models on the testing data.

Code
### Test-set ROC curve coordinates for both models ----
## Result: one row per model ("LR", "RF") and threshold, with the
## sensitivity and specificity at that threshold.
test_roc_curve <- test_pred %>%
    ## keep the outcome and each model's probability of "Yes", naming the
    ## probability columns with the model label so the label travels with
    ## the data instead of being assigned by position afterwards
    select(churn, LR = lr_prob_yes, RF = rf_prob_yes) %>%
    ## one row per employee and model
    pivot_longer(
        cols = c(LR, RF),
        names_to = "model",
        values_to = "prob_yes"
    ) %>%
    ## fix the level order so the event level below does not depend on how
    ## the outcome factor happens to be stored
    mutate(
        churn = fct_relevel(as_factor(churn), "No", "Yes")
    ) %>%
    ## the curve is computed once per group
    group_by(model) %>%
    roc_curve(
        truth = churn,
        # predicted probability of the event
        prob_yes,
        # "Yes" is the second factor level
        event_level = "second"
    ) %>%
    ungroup()

### ROC curve plot ----
## one line per model; the dashed diagonal marks random prediction
roc_curve_plot <- test_roc_curve %>%
    ggplot(aes(x = 1 - specificity, y = sensitivity, color = model)) +
    geom_line(linewidth = 1.5) +
    geom_abline(
        slope = 1,
        intercept = 0,
        linetype = "dashed",
        color = "gray50"
    ) +
    labs(
        title = "ROC Curves: Logistic Regression vs. Random Forest",
        x = "False Positive Rate",
        y = "True Positive Rate",
        color = "Model"
    ) +
    ## fixed color per model
    scale_color_manual(values = c("LR" = "red", "RF" = "blue")) +
    ## same scale on both axes
    coord_equal() +
    theme_bw() +
    theme(legend.position = "bottom")

## show the plot
roc_curve_plot

Save Objects

We save objects of interest.

Code
### Save data partitions and fitted models ----
## training partition
write_rds(network_train, path("..", "data", "network_train.rds"))

## testing partition
write_rds(network_test, path("..", "data", "network_test.rds"))

## both fitted workflows together in a single .rdata file
save(
    log_reg_fit,
    rf_fit,
    file = path("..", "data", "model_fits.rdata")
)

### Save the ROC curve plot ----
## make sure the output folder exists first; dir_create() leaves an existing
## folder untouched, so this is safe to run repeatedly
dir_create(path("..", "plots"))

## write the plot as an 8 x 8 inch PNG at 300 dots per inch
ggsave(
    # file path
    filename = path(
        "..",
        # folder
        "plots",
        # file
        "roc_curves.png"
    ),
    # plot object
    plot = roc_curve_plot,
    # width in inches
    width = 8,
    # height in inches
    height = 8,
    # dots per inch
    dpi = 300
)