📌 Load Libraries
library(dplyr)
##
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
library(tidyr)
library(ggplot2)
library(xgboost)
##
## Attaching package: 'xgboost'
## The following object is masked from 'package:dplyr':
##
## slice
🔧 Simulate Discrete-Time Survival Data
set.seed(123)
n <- 100
max_time <- 10

# Static covariates: one row per subject.
data <- data.frame(
  id = seq_len(n),
  age = rnorm(n, 50, 10),
  gender = sample(0:1, n, replace = TRUE),
  income = rnorm(n, 50000, 10000)
)

# Simulate true event time and censoring: each subject is followed up to
# interval `true_time`; `event` is 1 if follow-up ends in an event and 0 if the
# subject is censored there.
# NOTE: both are drawn independently of age/gender/income, so the covariates
# carry no real signal about the hazard.
data$true_time <- sample(seq_len(max_time), n, replace = TRUE)
data$event <- rbinom(n, 1, 0.7)

# Expand to person-period ("long") format, which the model below trains on:
# one row per subject per interval 1..true_time. The per-row `event` is 1 only
# in the subject's final interval, and only if that subject had an event; all
# earlier (survived) intervals are 0.
long_data <- data[rep(seq_len(nrow(data)), times = data$true_time), ]
long_data$time <- sequence(data$true_time)
long_data$event <- as.integer(
  long_data$event == 1 & long_data$time == long_data$true_time
)
rownames(long_data) <- NULL
🚀 Train XGBoost Model
# Prepare data for XGBoost
# Prepare data for XGBoost: `time` is included as a feature so the model can
# learn how the per-interval event probability changes across intervals.
# Assumes `long_data` has one row per subject-interval with a 0/1 `event`.
X <- as.matrix(long_data[, c("time", "age", "gender", "income")])
y <- long_data$event
dtrain <- xgb.DMatrix(data = X, label = y)

# Logistic objective on the 0/1 per-interval outcome.
params <- list(
  objective = "binary:logistic",
  eval_metric = "logloss"
)

# `xgb.train()` is the low-level training interface and accepts an
# `xgb.DMatrix` directly; the `xgboost()` convenience wrapper changed its
# signature (to `x` / `y`) in newer releases, so avoid relying on it here.
xgb_model <- xgb.train(
  params = params,
  data = dtrain,
  nrounds = 50,
  verbose = 0
)
📈 Predict Per-Interval Hazards (In-Sample)
# Store the model's predicted probability for every row of `long_data`.
# These are in-sample predictions: `X` is the same feature matrix used for
# training. NOTE(review): reading them as discrete-time hazards assumes
# `long_data` is in person-period format (one row per at-risk interval).
long_data[["pred_xgb"]] <- predict(xgb_model, newdata = X)
📊 Plot Predicted Hazards Over Time
# Average the predicted hazard within each time interval. Only rows present in
# `long_data` for an interval contribute, so if it is in person-period format
# this is the mean over subjects still at risk at that interval.
hazard_plot_data <- long_data %>%
  group_by(time) %>%
  summarise(mean_hazard = mean(pred_xgb, na.rm = TRUE), .groups = "drop")

ggplot(hazard_plot_data, aes(x = time, y = mean_hazard)) +
  # Since ggplot2 3.4.0 line thickness is `linewidth`; `size` is deprecated
  # for lines (it remains correct for points below).
  geom_line(color = "darkgreen", linewidth = 1.2) +
  geom_point(color = "black", size = 2) +
  labs(
    title = "Predicted Hazard Over Time (XGBoost)",
    x = "Time Interval",
    y = "Predicted Hazard"
  ) +
  theme_minimal()
