library(survival)
data(veteran)
surv_data = veteran
Dataset to be used: veteran from the survival package
The veteran dataset contains data from a clinical trial of lung cancer patients. Below is a description of each variable:
Variable | Description |
---|---|
trt | Treatment group: 1 = standard, 2 = test |
celltype | Type of lung cancer cell (factor): "squamous", "smallcell", "adeno" (adenocarcinoma), "large" |
time | Survival time in days |
status | Censoring status: 0 = censored, 1 = dead |
karno | Karnofsky performance score (0–100), where higher = better health |
diagtime | Time from diagnosis to randomization (in months) |
age | Patient age (in years) |
prior | Prior therapy: 0 = no, 10 = yes |
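Before modelling, it can help to confirm the size of the data and how the event indicator is coded. The following is a minimal sketch, assuming the survival package is installed; the names vet, y_check and n_events are only illustrative and do not appear elsewhere in this tutorial.

```r
library(survival)

# veteran is lazy-loaded with the survival package
vet <- survival::veteran

dim(vet)             # rows = patients, columns = variables
table(vet$status)    # shows how censored/event status is coded

# Surv() pairs each follow-up time with its event indicator
y_check  <- Surv(vet$time, vet$status)
n_events <- sum(y_check[, "status"])   # number of observed deaths
n_events
```

The event count obtained this way should agree with the "number of events" line printed by summary() for any model fitted to this data.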
1. Standard Cox PH Model
Let us fit a standard Cox PH Model.
cox_model = coxph(Surv(time, status) ~ ., data = surv_data, x = TRUE)
summary(cox_model)
Call:
coxph(formula = Surv(time, status) ~ ., data = surv_data, x = TRUE)
n= 137, number of events= 128
coef exp(coef) se(coef) z Pr(>|z|)
trt 0.29460282 1.34259300 0.20754960 1.419 0.15577
celltypesmallcell 0.86156046 2.36685120 0.27528447 3.130 0.00175 **
celltypeadeno 1.19606637 3.30708248 0.30091699 3.975 0.00007045662 ***
celltypelarge 0.40129165 1.49375286 0.28268864 1.420 0.15574
karno -0.03281533 0.96771726 0.00550776 -5.958 0.00000000255 ***
diagtime 0.00008132 1.00008132 0.00913606 0.009 0.99290
age -0.00870647 0.99133132 0.00930030 -0.936 0.34920
prior 0.00715936 1.00718505 0.02323054 0.308 0.75794
---
Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
exp(coef) exp(-coef) lower .95 upper .95
trt 1.3426 0.7448 0.8939 2.0166
celltypesmallcell 2.3669 0.4225 1.3799 4.0597
celltypeadeno 3.3071 0.3024 1.8336 5.9647
celltypelarge 1.4938 0.6695 0.8583 2.5996
karno 0.9677 1.0334 0.9573 0.9782
diagtime 1.0001 0.9999 0.9823 1.0182
age 0.9913 1.0087 0.9734 1.0096
prior 1.0072 0.9929 0.9624 1.0541
Concordance= 0.736 (se = 0.021 )
Likelihood ratio test= 62.1 on 8 df, p=0.0000000002
Wald test = 62.37 on 8 df, p=0.0000000002
Score (logrank) test = 66.74 on 8 df, p=0.00000000002
n = 137: Number of patients (observations).
Number of events = 128: Patients who died (non-censored observations).
C-index = 0.736: The model has good discriminative ability; it correctly ranks patient survival in ~74% of comparable pairs.
Overall model significance: All three global tests (Likelihood ratio, Wald, and Score/logrank) are highly significant (p < 0.001), meaning the covariates, taken together, are clearly associated with survival.
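To see how the columns of the summary relate to each other, the hazard ratio, Wald statistic and 95% confidence interval for karno can be recomputed by hand from its coefficient and standard error. This is only a sketch in base R; the two input numbers are copied from the output above and the variable names are illustrative.

```r
# Values copied from the karno row of the summary output above
coef_karno <- -0.03281533
se_karno   <- 0.00550776

hr    <- exp(coef_karno)                                       # hazard ratio per 1-point increase
z     <- coef_karno / se_karno                                 # Wald statistic
ci    <- exp(coef_karno + c(-1, 1) * qnorm(0.975) * se_karno)  # 95% CI for the hazard ratio
p_val <- 2 * pnorm(-abs(z))                                    # two-sided Wald p-value

round(c(HR = hr, lower = ci[1], upper = ci[2], z = z), 4)
```

On this scale, a 10-point higher Karnofsky score corresponds to exp(10 * coef) ≈ 0.72, i.e. roughly a 28% lower hazard with the other covariates held fixed.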
We can also get additional concordance measures.
cox_model$concordance
concordant discordant tied.x tied.y tied.xy
6480.00000000 2324.00000000 0.00000000 39.00000000 0.00000000
concordance std
0.73602908 0.02117012
concordant = 6,480: These are pairs of patients where the model correctly predicted who survived longer. More is better.
discordant = 2,324: Pairs where the model got it wrong, i.e. it predicted the higher-risk patient to survive longer. Fewer is better.
tied.x = 0: Pairs where patients had the same predicted risk. None in this case.
tied.y = 39: Pairs where patients had the same survival time (actual time), so no informative comparison.
tied.xy = 0: Pairs where both predicted risks and actual survival times are tied.
concordance (std) = 0.736 (0.021): The model correctly ranks survival pairs ~73.6% of the time, with a standard error of ~2.1%. A small SE like this indicates a stable estimate.
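The concordance value can be reproduced from the pair counts. The sketch below uses the usual convention that pairs tied on the predictor count as half-concordant and pairs tied on time are left out; the counts are copied from the output above.

```r
# Pair counts copied from cox_model$concordance above
concordant <- 6480
discordant <- 2324
tied_x     <- 0

# Ties in the predictor count as half-concordant; pairs tied in time are dropped
c_index <- (concordant + 0.5 * tied_x) / (concordant + discordant + tied_x)
c_index
```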
2. Machine Learning Survival Models
Target function: \(y=f(x)\)
In classical statistical models such as Cox PH, \(f\) is explicitly specified.
ML learns both the function \(f\) and the model parameters.
Random Survival Forest (randomForestSRC, ranger)
randomForestSRC: Extensive (competing risks, time-dependent covariates, missing data imputation).
ranger: Basic RSF (Kaplan-Meier or Nelson-Aalen only).
library(randomForestSRC)
# rfsrc's tree-count argument is ntree (num.trees is ranger's name); 500 matches the output below
rsf_model = rfsrc(Surv(time, status) ~ ., data = surv_data, ntree = 500)
rsf_model
Sample size: 137
Number of deaths: 128
Number of trees: 500
Forest terminal node size: 15
Average no. of terminal nodes: 6.464
No. of variables tried at each split: 3
Total no. of variables: 6
Resampling used to grow trees: swor
Resample size used to grow trees: 87
Analysis: RSF
Family: surv
Splitting rule: logrank *random*
Number of random split points: 10
(OOB) CRPS: 61.2483391
(OOB) stand. CRPS: 0.06130965
(OOB) Requested performance error: 0.30176411
- (OOB) CRPS = 61.25: Continuous Ranked Probability Score; lower is better. Measures prediction accuracy in survival settings.
- (OOB) standardised CRPS = 0.0613: CRPS scaled by time range or event frequency. Lower is better; used to compare across datasets.
- (OOB) Requested performance error = 0.302: Estimate of prediction error from the OOB samples, around 30.2% in this case. For survival forests this error is 1 minus Harrell's C-index, so the OOB C-index is about 0.698.
Model predictions
rsf_pred = predict(rsf_model, newdata=surv_data)
rsf_pred
Sample size of test (predict) data: 137
Number of grow trees: 500
Average no. of grow terminal nodes: 6.464
Total no. of grow variables: 6
Resampling used to grow trees: swor
Resample size used to grow trees: 87
Analysis: RSF
Family: surv
CRPS: 52.25934453
stand. CRPS: 0.05231166
Requested performance error: 0.23651476
Hmisc::rcorr.cens(-rsf_pred$predicted, Surv(surv_data$time, surv_data$status))
C Index Dxy S.D. n missing
0.76465243 0.52930486 0.04302583 137.00000000 0.00000000
uncensored Relevant Pairs Concordant Uncertain
128.00000000 17608.00000000 13464.00000000 946.00000000
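Correctly ranks survival time for ~76% of usable pairs. This is measured on the training data, so it is optimistic compared with the OOB error of 0.302 reported above (OOB C-index ≈ 0.698).
Has a low uncertainty (S.D. = 0.043), indicating stable performance.
Most events are observed (low censoring), which supports model reliability.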
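The C Index and Dxy columns carry the same information on different scales. A short sketch of the conversion, using values copied from the rcorr.cens() output above; the halving of the S.D. rests on our reading of the Hmisc documentation, which describes S.D. as the standard deviation of Dxy, so treat it as an assumption to verify.

```r
# Values copied from the rcorr.cens() output above
c_idx  <- 0.76465243
sd_dxy <- 0.04302583

dxy  <- 2 * (c_idx - 0.5)   # Somers' Dxy rescales C from [0, 1] to [-1, 1]
se_c <- sd_dxy / 2          # assumption: S.D. refers to Dxy, so halve it for C

c(Dxy = dxy, approx_se_of_C = se_c)
```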
Survival Support Vectors (survivalsvm)
library(survivalsvm)
ssvm_model = survivalsvm(Surv(time, status) ~ .,
                         data = surv_data,
                         gamma.mu = 0.001, # required!
                         kernel = "lin_kernel") # other options: 'add_kernel', 'rbf_kernel', 'poly_kernel'
ssvm_model
survivalsvm result
Call:
survivalsvm(Surv(time, status) ~ ., data = surv_data, gamma.mu = 0.001, kernel = "lin_kernel")
Survival svm approach : regression
Type of Kernel : lin_kernel
Optimization solver used : quadprog
Number of support vectors retained : 100
survivalsvm version : 0.0.6
svm_lp = predict(ssvm_model, surv_data)
Hmisc::rcorr.cens(-svm_lp$predicted, Surv(surv_data$time, surv_data$status))
C Index Dxy S.D. n missing
0.29327578 -0.41344843 0.04796709 137.00000000 0.00000000
uncensored Relevant Pairs Concordant Uncertain
128.00000000 17608.00000000 5164.00000000 946.00000000
A C-index below 0.5 here most likely reflects the sign convention rather than a poor model: in the regression approach the prediction behaves like a survival time (higher = longer survival), so negating it reverses the ranking. Without the minus sign the C-index would be roughly 1 - 0.293 ≈ 0.71.
Gradient Boosting for Survival (gbm)
library(gbm)
gbms_model = gbm(Surv(time, status) ~ ., data = surv_data)
Distribution not specified, assuming coxph ...
gbms_model
gbm(formula = Surv(time, status) ~ ., data = surv_data)
A gradient boosted model with coxph loss function.
100 iterations were performed.
There were 6 predictors of which 6 had non-zero influence.
gbm_lp = predict(gbms_model, surv_data)
Hmisc::rcorr.cens(-gbm_lp, Surv(surv_data$time, surv_data$status))
C Index Dxy S.D. n missing
0.73540436 0.47080872 0.04537116 137.00000000 0.00000000
uncensored Relevant Pairs Concordant Uncertain
128.00000000 17608.00000000 12949.00000000 946.00000000
3. Explainable ML for Survival Models (survex)
Explainability is the degree to which we can understand and communicate how a model, system, or process makes its decisions.
It means making the model’s behavior transparent and interpretable to humans.
It involves showing which features influenced a prediction, how much, and in what direction.
Explainability vs. Interpretability
Interpretability: How well a human can understand the internal mechanics of the model (e.g., linear regression is inherently interpretable).
Explainability: How well we can communicate the reasoning behind predictions, even if the model is complex (like a deep neural network).
survex is an R package developed to explain machine-learning survival models using explainable AI (XAI) techniques.
Many survival models, especially tree-based ones (e.g. random survival forest), produce nonlinear and time-varying effects that are hard to interpret.
survex brings interpretability by explaining both individual predicted survival curves and global model behavior, assessing variable importance, bias, and reliability.
It is particularly useful in healthcare, clinical research, and other sensitive domains where transparency matters.
- For more about the package, see the survex documentation.
library(survex)
3.1. Creating Explainer
General Purpose Explainer - explain()
- It wraps any predictive model (regression, classification, or survival) in a standardized interface for explainability.
explainer = explain(model)
Now let’s create an explainer
cph_model <- coxph(Surv(time, status)~., data = surv_data, model = TRUE, x = TRUE)
cph_explainer = explain(cph_model)
Preparation of a new explainer is initiated
-> model label : coxph ( default )
-> data : 137 rows 6 cols ( extracted from the model )
-> target variable : 137 values ( 128 events and 9 censored , censoring rate = 0.066 ) ( extracted from the model )
-> times : 50 unique time points , min = 1.5 , median survival time = 80 , max = 999
-> times : ( generated from y as uniformly distributed survival quantiles based on Kaplan-Meier estimator )
-> predict function : predict.coxph with type = 'risk' will be used ( default )
-> predict survival function : predictSurvProb.coxph will be used ( default )
-> predict cumulative hazard function : -log(predict_survival_function) will be used ( default )
-> model_info : package survival , ver. 3.8.3 , task survival ( default )
A new explainer has been created!
The only mandatory input for the explain() function is the proportional hazards model itself. However, when creating the model, it is essential to set model = TRUE and x = TRUE. Omitting these arguments will result in an error.
rsf_model <- rfsrc(Surv(time, status)~., data = surv_data)
rsf_explainer = explain(rsf_model)
Preparation of a new explainer is initiated
-> model label : rfsrc ( default )
-> data : 137 rows 6 cols ( extracted from the model )
-> target variable : 137 values ( 128 events and 9 censored , censoring rate = 0.066 ) ( extracted from the model )
-> times : 50 unique time points , min = 1.5 , median survival time = 80 , max = 999
-> times : ( generated from y as uniformly distributed survival quantiles based on Kaplan-Meier estimator )
-> predict function : sum over the predict_cumulative_hazard_function will be used ( default )
-> predict survival function : stepfun based on predict.rfsrc()$survival will be used ( default )
-> predict cumulative hazard function : stepfun based on predict.rfsrc()$chf will be used ( default )
-> model_info : package randomForestSRC , ver. 3.4.1 , task survival ( default )
A new explainer has been created!
For some models, data cannot be extracted automatically. For instance, when creating an explainer for a Random Survival Forest using the ranger package, we must manually provide the data X and the target variable y.
library(ranger)
ranger_rsf <- ranger(Surv(time, status)~., data = surv_data)
ranger_rsf_exp <- explain(ranger_rsf)
Preparation of a new explainer is initiated
-> model label : ranger ( default )
-> no data available! ( WARNING )
-> target variable : not specified! ( WARNING )
-> times : not specified and automatic generation is impossible ('y' is NULL)! ( WARNING )
-> predict function : sum over the predict_cumulative_hazard_function will be used ( default )
-> predict survival function : stepfun based on predict.ranger()$survival will be used ( default )
-> predict cumulative hazard function : stepfun based on predict.ranger()$chf will be used ( default )
-> model_info : package ranger , ver. 0.17.0 , task survival ( default )
-> model_info : survival task detected but 'y' is a NULL ( WARNING )
-> model_info : by deafult survival tasks supports only 'y' parameter of 'survival::Surv' class
A new explainer has been created!
We should supply the data parameter X without the columns containing survival information.
X = surv_data[, -c(3,4)]
y = Surv(surv_data$time, surv_data$status)
rang_rf_explainer <- explain(ranger_rsf, data = X, y = y)
Preparation of a new explainer is initiated
-> model label : ranger ( default )
-> data : 137 rows 6 cols
-> target variable : 137 values ( 128 events and 9 censored )
-> times : 50 unique time points , min = 1.5 , median survival time = 80 , max = 999
-> times : ( generated from y as uniformly distributed survival quantiles based on Kaplan-Meier estimator )
-> predict function : sum over the predict_cumulative_hazard_function will be used ( default )
-> predict survival function : stepfun based on predict.ranger()$survival will be used ( default )
-> predict cumulative hazard function : stepfun based on predict.ranger()$chf will be used ( default )
-> model_info : package ranger , ver. 0.17.0 , task survival ( default )
A new explainer has been created!
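The positional index -c(3,4) used above depends on time and status being the third and fourth columns. A name-based selection is a little more robust if the column order ever changes. The sketch below shows the idea on a made-up two-row data frame with the same column names; toy, outcome_cols and X_toy are illustrative names, not objects used elsewhere in this tutorial.

```r
# Toy data frame with the same column names as veteran (values are made up)
toy <- data.frame(trt = c(1, 2), celltype = c("large", "adeno"),
                  time = c(72, 411), status = c(1, 0),
                  karno = c(60, 70), diagtime = c(7, 5),
                  age = c(69, 64), prior = c(0, 10))

outcome_cols <- c("time", "status")
X_toy <- toy[, !(names(toy) %in% outcome_cols)]   # keep covariates only
names(X_toy)
```

The same expression applied to surv_data would give the six covariate columns expected by explain().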
3.2. Making Predictions
Now we are going to work exclusively with explainer objects, which serve as standardized wrappers around the models. The key advantage of using an explainer is its ability to generate predictions—whether risk scores, survival probabilities, or cumulative hazard functions—in a consistent manner, regardless of the underlying model.
predict(): Computes predicted survival probabilities, cumulative hazard, or risk scores from the model wrapped in the explainer.
predict(rsf_explainer, newdata = X_new)
Survival Function
pred_surv = predict(rsf_explainer, newdata = surv_data, output_type='survival')
#pred_surv
The survival output is a matrix in which element \((i,j)\) is the predicted probability that individual \(i\) survives past the \(j\)-th time point of the explainer.
- The probability that an individual survives beyond a certain time \(t\).
- E.g. if \(S(5)=0.8\), there is an 80% chance the person survives past 5 units of time.
Cumulative Hazard Function
new_obs = surv_data[1, -c(3,4)]
pred_chf = predict(rsf_explainer, newdata = new_obs, output_type='chf')
pred_chf
[,1] [,2] [,3] [,4] [,5] [,6]
[1,] 0.0110632 0.01130605 0.01705345 0.02697063 0.03075422 0.05468718
[,7] [,8] [,9] [,10] [,11] [,12]
[1,] 0.05612141 0.07083183 0.07311615 0.07527041 0.07740642 0.08649309
[,13] [,14] [,15] [,16] [,17] [,18] [,19]
[1,] 0.09797097 0.10067 0.1215816 0.1296288 0.1308685 0.2275782 0.2339819
[,20] [,21] [,22] [,23] [,24] [,25] [,26]
[1,] 0.2436291 0.2627047 0.3044986 0.3053727 0.3801407 0.4114505 0.4431012
[,27] [,28] [,29] [,30] [,31] [,32] [,33]
[1,] 0.4499907 0.4516038 0.4902727 0.5422882 0.6690722 0.6971438 0.7308551
[,34] [,35] [,36] [,37] [,38] [,39] [,40]
[1,] 0.7742135 0.8083247 0.8470295 0.9379728 0.977424 1.128483 1.168013
[,41] [,42] [,43] [,44] [,45] [,46] [,47] [,48]
[1,] 1.266984 1.580142 1.717938 1.783009 1.955323 2.169437 2.281171 2.559837
[,49] [,50]
[1,] 2.950171 3.327171
- The accumulated risk of the event occurring up to time \(t\).
- Non-negative and generally increases over time.
- The survival probability is the exponential of the negative cumulative hazard.
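The relationship between cumulative hazard and survival, \(S(t) = \exp(-H(t))\), can be checked on a few of the values printed above. This is a small base R sketch; the three numbers are copied from columns 1, 18 and 50 of pred_chf.

```r
# A few cumulative hazard values copied from pred_chf above (columns 1, 18, 50)
chf <- c(0.0110632, 0.2275782, 3.327171)

surv <- exp(-chf)   # S(t) = exp(-H(t))
round(surv, 3)

# Going back the other way: H(t) = -log(S(t))
all.equal(-log(surv), chf)
```

Because the cumulative hazard never decreases, the implied survival probabilities never increase over time.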
Risk Score / Prognostic Index
pred_risk = predict(cph_explainer, newdata = surv_data[1:10,], output_type='risk')
pred_risk
1 2 3 4 5 6 7 8
0.7354128 0.5942155 0.9629558 0.8324949 0.5893519 3.2519560 1.5232113 0.3855308
9 10
1.2815790 0.5250487
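- A model-derived score indicating relative risk compared to a baseline.
- On the linear-predictor scale it can be any real number (positive or negative); the values printed above come from predict.coxph with type = 'risk', i.e. exp(linear predictor), and are therefore always positive.
- Higher risk score → higher hazard (higher chance of event sooner).
- In PH models, the hazard for an individual is proportional to \(\exp(\text{linear predictor})\), which is what type = 'risk' returns.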
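Since these Cox risk scores are on the exp(linear predictor) scale, the ratio of two scores can be read as the hazard ratio between the two patients. A small sketch using two values copied from pred_risk above; the variable names are illustrative.

```r
# Risk scores copied from pred_risk above (observations 6 and 8)
risk_6 <- 3.2519560
risk_8 <- 0.3855308

hazard_ratio <- risk_6 / risk_8   # relative hazard of patient 6 vs patient 8
lp <- log(c(risk_6, risk_8))      # back to the linear-predictor scale

hazard_ratio
lp
```

Under the proportional hazards assumption this ratio is the same at every time point.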
3.3. Model Performance - model_performance()
model_performance(): Evaluates the overall quality of the model on a survival prediction task. It computes metrics like Concordance index (C-index), Time-dependent AUC, Integrated Brier Score.
C-index (Concordance Index)
- Measures how well the model ranks survival times.
- The probability that, for a randomly selected pair of individuals, the one who died earlier had a higher predicted risk.
- Higher is better.
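The pairwise definition of the C-index can be written out directly. The function below is a simplified sketch in base R, not the implementation used by survex or survival: harrell_c is a hypothetical helper, it skips pairs with tied times, and it runs in quadratic time, which is fine for toy data only.

```r
# Simplified Harrell's C: a pair (i, j) is usable when i has an observed
# event strictly before j's time; it is concordant when i has the higher risk.
harrell_c <- function(time, status, risk) {
  concordant <- 0; discordant <- 0; tied <- 0
  n <- length(time)
  for (i in seq_len(n)) {
    for (j in seq_len(n)) {
      if (status[i] == 1 && time[i] < time[j]) {
        if (risk[i] > risk[j]) {
          concordant <- concordant + 1
        } else if (risk[i] < risk[j]) {
          discordant <- discordant + 1
        } else {
          tied <- tied + 1
        }
      }
    }
  }
  (concordant + 0.5 * tied) / (concordant + discordant + tied)
}

# Toy data: the third patient is censored, so it never starts a usable pair
toy_time   <- c(5, 8, 12, 20)
toy_status <- c(1, 1, 0, 1)

c_perfect  <- harrell_c(toy_time, toy_status, risk = c(4, 3, 2, 1))
c_reversed <- harrell_c(toy_time, toy_status, risk = c(1, 2, 3, 4))
c_mixed    <- harrell_c(toy_time, toy_status, risk = c(4, 1, 2, 3))
c(c_perfect, c_reversed, c_mixed)
```

A perfectly ordered risk score gives 1, a perfectly reversed one gives 0, and a score carrying no information would sit near 0.5.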
Integrated C/D AUC (Cumulative/Dynamic AUC)
- A time-dependent version of AUC (Area Under the Curve), averaged over all follow-up times.
- Higher (strong discrimination) is better.
Integrated Brier Score (IBS)
- Measures prediction error over time, averaged across all time points.
- Combines calibration and discrimination.
- At a given time point, the Brier score is the mean squared error between predicted survival probability and actual outcome (0 = dead, 1 = survived).
- Lower is better (0 = perfect prediction, 0.1-0.2 = very good, 0.25+ = poor; 0.25 is what a constant prediction of 0.5 scores).
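The Brier score at a single time point can be illustrated on toy numbers. The sketch below deliberately ignores censoring; real implementations, including the one behind model_performance(), reweight observations to account for it, so this shows only the core idea. All values and names are made up for illustration.

```r
# Toy data: observed event times (no censoring) and predicted S(t) at t = 6
event_time  <- c(2, 5, 7, 10)
pred_surv_t <- c(0.1, 0.3, 0.8, 0.9)   # hypothetical model predictions of S(6)
t_eval      <- 6

still_alive <- as.numeric(event_time > t_eval)   # 1 = survived past t, 0 = dead
brier <- mean((pred_surv_t - still_alive)^2)     # mean squared error at t
brier
```

The integrated Brier score averages this quantity over the whole grid of evaluation times.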
model_performance(explainer)
cox_model <- coxph(Surv(time, status) ~ ., data = surv_data, model=TRUE, x = TRUE)
cox_explainer <- explain(cox_model, label = "Cox PH")
Preparation of a new explainer is initiated
-> model label : Cox PH
-> data : 137 rows 6 cols ( extracted from the model )
-> target variable : 137 values ( 128 events and 9 censored , censoring rate = 0.066 ) ( extracted from the model )
-> times : 50 unique time points , min = 1.5 , median survival time = 80 , max = 999
-> times : ( generated from y as uniformly distributed survival quantiles based on Kaplan-Meier estimator )
-> predict function : predict.coxph with type = 'risk' will be used ( default )
-> predict survival function : predictSurvProb.coxph will be used ( default )
-> predict cumulative hazard function : -log(predict_survival_function) will be used ( default )
-> model_info : package survival , ver. 3.8.3 , task survival ( default )
A new explainer has been created!
rsf_model <- rfsrc(Surv(time, status)~., data = surv_data)
rsf_explainer <- explain(rsf_model, data = X, y = y, label = "RSF I")
Preparation of a new explainer is initiated
-> model label : RSF I
-> data : 137 rows 6 cols
-> target variable : 137 values ( 128 events and 9 censored )
-> times : 50 unique time points , min = 1.5 , median survival time = 80 , max = 999
-> times : ( generated from y as uniformly distributed survival quantiles based on Kaplan-Meier estimator )
-> predict function : sum over the predict_cumulative_hazard_function will be used ( default )
-> predict survival function : stepfun based on predict.rfsrc()$survival will be used ( default )
-> predict cumulative hazard function : stepfun based on predict.rfsrc()$chf will be used ( default )
-> model_info : package randomForestSRC , ver. 3.4.1 , task survival ( default )
A new explainer has been created!
ranger_rsf <- ranger(Surv(time, status)~., data = surv_data)
ranger_rsf_explainer <- explain(ranger_rsf, data = X, y = y, label = "RSF II")
Preparation of a new explainer is initiated
-> model label : RSF II
-> data : 137 rows 6 cols
-> target variable : 137 values ( 128 events and 9 censored )
-> times : 50 unique time points , min = 1.5 , median survival time = 80 , max = 999
-> times : ( generated from y as uniformly distributed survival quantiles based on Kaplan-Meier estimator )
-> predict function : sum over the predict_cumulative_hazard_function will be used ( default )
-> predict survival function : stepfun based on predict.ranger()$survival will be used ( default )
-> predict cumulative hazard function : stepfun based on predict.ranger()$chf will be used ( default )
-> model_info : package ranger , ver. 0.17.0 , task survival ( default )
A new explainer has been created!
cox_perf = model_performance(cox_explainer)
rsf_perf = model_performance(rsf_explainer)
ran_perf = model_performance(ranger_rsf_explainer)
plot(cox_perf, rsf_perf, ran_perf)
We can also plot the scalar metrics in the form of bar plots.
plot(cox_perf, rsf_perf, ran_perf, metrics_type = 'scalar')
If our priority is better overall prediction accuracy (calibration), lower IBS is more important, so the RSF with lower IBS is preferable.
If our priority is better ranking of risk or discrimination, then higher C-index and AUC matter more.
3.4. Global Explanations
We check how each variable influences the models’ predictions on a global level, i.e. which variables are important to the model.
Variable Importance
model_parts(): It tells which variables most affect predictions by shuffling them and measuring the performance drop. It calculates permutational variable importance, with the difference being that the loss function is time-dependent (by default it is loss_brier_score()), so the influence of each variable can be different at each considered time point.
model_parts(explainer)
cph_m_parts <- model_parts(cox_explainer)
rsf_m_parts <- model_parts(rsf_explainer)
ran_m_parts <- model_parts(ranger_rsf_explainer)
plot(cph_m_parts, rsf_m_parts, ran_m_parts)
Across the three models, the permutation of the karno variable leads to the highest increase in the loss function, with the second being celltype. These two variables are the most important for the models’ predictions.
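The permutation idea behind these importance scores can be shown on a much simpler problem. The sketch below uses an ordinary linear model and mean squared error instead of a survival model and a time-dependent loss, so it illustrates the mechanism only; all data and names are simulated for this example.

```r
set.seed(1)
n <- 200
d <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
d$y <- 3 * d$x1 + rnorm(n, sd = 0.5)   # only x1 actually drives y

fit <- lm(y ~ x1 + x2, data = d)
mse <- function(data) mean((data$y - predict(fit, newdata = data))^2)
baseline <- mse(d)

# Shuffle one column at a time and record how much the loss increases
perm_importance <- sapply(c("x1", "x2"), function(v) {
  shuffled <- d
  shuffled[[v]] <- sample(shuffled[[v]])
  mse(shuffled) - baseline
})
perm_importance
```

Shuffling the informative variable breaks its link with the outcome and the loss rises sharply, while shuffling the irrelevant one changes almost nothing. survex applies the same logic at each time point of the loss.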
Partial Dependence
Partial dependence plots are generated using the model_profile() function. These plots illustrate how changing the value of a single variable affects the model’s prediction on average.
To prevent nonsensical values (like a treatment value of 0.5) we must specify the categorical_variables parameter. While all factors are automatically recognized as categorical, if we want to treat a numeric variable as categorical, we need to include it explicitly in this parameter.
model_profile(): How does a variable affect the average prediction?
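The averaging behind a partial dependence profile can be sketched by hand: fix one variable at a grid value for every row, predict, and take the mean. The example below uses a simulated data set and a linear model rather than a survival model, so it is an illustration of the mechanism under simplified assumptions; all names are made up.

```r
set.seed(2)
d <- data.frame(x1 = runif(100), x2 = runif(100))
d$y <- 2 * d$x1 - d$x2 + rnorm(100, sd = 0.1)
fit <- lm(y ~ x1 + x2, data = d)

grid <- seq(0, 1, by = 0.25)

# For each grid value: overwrite x1 in every row, predict, and average
pdp <- sapply(grid, function(v) {
  tmp <- d
  tmp$x1 <- v
  mean(predict(fit, newdata = tmp))
})
data.frame(x1 = grid, mean_prediction = pdp)
```

For a linear model the profile is a straight line whose slope equals the fitted coefficient; for a survival model the averaged quantity is a whole survival curve per grid value, which is why the plots show bands of curves.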
model_profile(explainer)
cox_m_profile = model_profile(cox_explainer, categorical_variables=c("trt", "prior"))
plot(cox_m_profile, numerical_plot_type='lines')
From the plot, we observe that in the proportional hazards model, the prior and diagtime variables have minimal impact: the prediction bands are narrow and nearly overlapping, indicating that changes in these variables have little effect on the predicted survival function. In contrast, the karno variable shows a much wider band, suggesting that changes in its value lead to substantial differences in the predicted survival. Lower karno values are associated with poorer survival prospects, as evidenced by a faster decline in the survival function.
We can also plot the same information for the random survival forest models.
rsf_m_profile = model_profile(rsf_explainer, categorical_variables=c("trt", "prior"))
plot(rsf_m_profile, numerical_plot_type = "contours") #facet_ncol = 2
This type of plot also gives us valuable insight that is easy to overlook in the other type. For example, we see a sharp drop in survival function values around diagtime=25. We also observe that karno has the strongest influence in both the proportional hazards model and the random survival forest.
rang_m_profile = model_profile(ranger_rsf_explainer, categorical_variables=c("trt", "prior"))
plot(rang_m_profile, numerical_plot_type='lines')
3.5. Local Explanations
- Local explanations explain the model’s predictions for a single observation. The predict_parts() function can be used to assess the importance of variables while making predictions for a selected observation. This can be done by two methods, SurvSHAP(t) and SurvLIME.
Variable Attributions: Which variables contribute to the prediction?
predict_parts(): Provides local explanations for a single prediction. Breaks down how each feature contributes to an individual’s predicted survival curve or risk. Uses techniques like SurvSHAP(t) (time-dependent SHAP values for survival) and SurvLIME (local linear approximation).
SurvSHAP(t)
predict_parts(explainer, new_observation)
new_obs1 = surv_data[32,]
cox_p_parts1 = predict_parts(cox_explainer, new_obs1)
new_obs2 = surv_data[12,]
cox_p_parts2 = predict_parts(cox_explainer, new_obs2)
plot(cox_p_parts1)
plot(cox_p_parts2)
On the first plot for observation 32, the value of the karno variable improves the chances of survival of this individual. In contrast, the value of the celltype variable decreases them.
On the second plot, the situation is flipped for observation 12: celltype increases the chances of survival while karno decreases them.
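SurvSHAP(t) attributions build on Shapley values, which split the gap between one observation's prediction and the average prediction among the features. The sketch below computes exact Shapley values for a toy two-feature function at a single time point. It is an illustration of the general idea, not the survex algorithm; f, background and obs are hypothetical.

```r
# Toy "survival probability at one time point" with an interaction term
f <- function(x1, x2) 0.9 - 0.3 * x1 - 0.2 * x2 - 0.1 * x1 * x2

background <- data.frame(x1 = c(0, 0, 1, 1), x2 = c(0, 1, 0, 1))
obs <- list(x1 = 1, x2 = 1)   # observation being explained

# Value of a coalition: average prediction with the chosen features
# fixed to the observation's values and the rest taken from the background
value <- function(fix_x1, fix_x2) {
  x1 <- if (fix_x1) obs$x1 else background$x1
  x2 <- if (fix_x2) obs$x2 else background$x2
  mean(f(x1, x2))
}

# With two features, each Shapley value averages two marginal contributions
phi_x1 <- 0.5 * ((value(TRUE, FALSE) - value(FALSE, FALSE)) +
                 (value(TRUE, TRUE)  - value(FALSE, TRUE)))
phi_x2 <- 0.5 * ((value(FALSE, TRUE) - value(FALSE, FALSE)) +
                 (value(TRUE, TRUE)  - value(TRUE, FALSE)))
c(phi_x1 = phi_x1, phi_x2 = phi_x2)
```

The two attributions add up to the difference between this observation's prediction and the average prediction. In the time-dependent version this decomposition is repeated at every time point, which gives one attribution curve per variable.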
For the random survival forest:
rsf_p_parts = predict_parts(rsf_explainer, new_obs2)
rang_p_parts = predict_parts(ranger_rsf_explainer, new_obs2)
plot(rsf_p_parts)
plot(rang_p_parts)
SurvLIME
A different way of attributing variable importance is provided by the SurvLIME method. It works by finding a surrogate proportional hazards model that approximates the survival model in the local area around an observation of interest. Variable importance is then attributed using the coefficients of the found model.
new_obs = surv_data[12,]
cox_p_parts = predict_parts(cox_explainer, new_obs, type="survlime")
rsf_p_parts = predict_parts(rsf_explainer, new_obs, type="survlime")
rang_p_parts = predict_parts(ranger_rsf_explainer, new_obs, type="survlime")
plot(cox_p_parts)
plot(rsf_p_parts)
plot(rang_p_parts)
The left part of the plot shows which variables are most important and if their value increases or lowers the chances of survival, whereas the right shows the black-box model prediction together with the one from the found surrogate model. This is useful information because the closer these functions are, the more accurate the explanation can be.
Ceteris paribus
Another explanation technique provided by this package is ceteris paribus profiles. They show how the prediction changes when we change the value of one variable at a time. We can think of them as the equivalent of partial dependence plots but applied to a single observation. The predict_profile() function is used to make these explanations.
new_obs = surv_data[12,]
cox_p_profile = predict_profile(cox_explainer, new_obs, categorical_variables=c("trt", "prior"))
rsf_p_profile = predict_profile(rsf_explainer, new_obs, categorical_variables=c("trt", "prior"))
rang_p_profile = predict_profile(ranger_rsf_explainer, new_obs, categorical_variables=c("trt", "prior"))
plot(cox_p_profile)
plot(rsf_p_profile)
plot(rang_p_profile)
These plots also give a lot of valuable insight. For example, we see that the prior variable does not influence the predictions very much, as the lines representing survival functions for its different values almost overlap. We also observe that the celltype values of “smallcell” and “adeno”, as well as “large” and “squamous”, result in almost the same prediction for this observation. It can be seen that the most important variable for this observation is karno, as the differences in the survival function are the greatest, with low values indicating lower chances of survival. For diagtime, by contrast, it is the high values that seem to indicate lower chances of survival.
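The mechanics of a ceteris paribus profile can also be reproduced by hand: copy one observation, vary a single variable over a grid, and predict. The sketch below does this with a deliberately small Cox model on two covariates, assuming the survival package is installed; it returns risk scores rather than the full survival curves shown in the plots, and the object names are illustrative.

```r
library(survival)

# Small Cox model on two numeric covariates (a simplification of the models above)
fit <- coxph(Surv(time, status) ~ karno + age, data = veteran)

obs <- veteran[12, c("karno", "age")]   # observation of interest
karno_grid <- seq(20, 90, by = 10)

# Ceteris paribus: replicate the observation and change only karno
profile <- obs[rep(1, length(karno_grid)), ]
profile$karno <- karno_grid

cp_risk <- predict(fit, newdata = profile, type = "risk")
data.frame(karno = karno_grid, risk = round(cp_risk, 3))
```

Each row differs from the original patient only in karno, so any change in the predicted risk is attributable to that variable alone, which is the reading the plots above support.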