Explainable Survival Modeling

Author
Affiliation

Bongani Ncube (3002164)

University of the Witwatersrand (School of Public Health)

Published

August 8, 2025

Keywords

Explainable survival analysis, Statistical modeling, Model explainability, Gradient boosting, Proportional hazards

Dataset: veteran (from the survival package)

The veteran dataset contains data from a clinical trial of lung cancer patients. Below is a description of each variable:

Variable   Description
trt        Treatment group: 1 = standard, 2 = test
celltype   Type of lung cancer cell (factor): "squamous", "smallcell", "adeno" (adenocarcinoma), "large"
time       Survival time in days
status     Censoring status: 0 = censored, 1 = dead
karno      Karnofsky performance score (0–100), where higher = better health
diagtime   Time from diagnosis to randomization (in months)
age        Patient age (in years)
prior      Prior therapy: 0 = no, 10 = yes

library(survival)
data(veteran)
surv_data = veteran
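
Because Surv() and the models below depend on how the outcome and the factor are coded, a quick tabulation is worth doing first. A minimal check using only the survival package:

```r
library(survival)

data(veteran)
# status: 1 = death observed, 0 = censored
table(veteran$status)
# celltype is a factor; its first level (squamous) is the Cox reference level
levels(veteran$celltype)
```

The counts match the 128 events and 9 censored observations reported by every model in this document.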

1. Standard Cox PH Model

Let us fit a standard Cox PH Model.

cox_model = coxph(Surv(time, status) ~ ., data = surv_data, x = TRUE)
summary(cox_model)
Call:
coxph(formula = Surv(time, status) ~ ., data = surv_data, x = TRUE)

  n= 137, number of events= 128 

                         coef   exp(coef)    se(coef)      z      Pr(>|z|)    
trt                0.29460282  1.34259300  0.20754960  1.419       0.15577    
celltypesmallcell  0.86156046  2.36685120  0.27528447  3.130       0.00175 ** 
celltypeadeno      1.19606637  3.30708248  0.30091699  3.975 0.00007045662 ***
celltypelarge      0.40129165  1.49375286  0.28268864  1.420       0.15574    
karno             -0.03281533  0.96771726  0.00550776 -5.958 0.00000000255 ***
diagtime           0.00008132  1.00008132  0.00913606  0.009       0.99290    
age               -0.00870647  0.99133132  0.00930030 -0.936       0.34920    
prior              0.00715936  1.00718505  0.02323054  0.308       0.75794    
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

                  exp(coef) exp(-coef) lower .95 upper .95
trt                  1.3426     0.7448    0.8939    2.0166
celltypesmallcell    2.3669     0.4225    1.3799    4.0597
celltypeadeno        3.3071     0.3024    1.8336    5.9647
celltypelarge        1.4938     0.6695    0.8583    2.5996
karno                0.9677     1.0334    0.9573    0.9782
diagtime             1.0001     0.9999    0.9823    1.0182
age                  0.9913     1.0087    0.9734    1.0096
prior                1.0072     0.9929    0.9624    1.0541

Concordance= 0.736  (se = 0.021 )
Likelihood ratio test= 62.1  on 8 df,   p=0.0000000002
Wald test            = 62.37  on 8 df,   p=0.0000000002
Score (logrank) test = 66.74  on 8 df,   p=0.00000000002
  • n = 137: Number of patients (observations).

  • Number of events = 128: Patients who died (non-censored observations).

  • Concordance (C-index) = 0.736: The model has good discriminative ability; it correctly ranks pairs of patients by survival time about 74% of the time.

  • Overall model significance: All three global tests (likelihood ratio, Wald, and score/log-rank) are highly significant (p < 0.001), meaning that the covariates are jointly associated with survival. This is evidence against a model with no covariates, not by itself a measure of predictive accuracy.

We can also get additional concordance measures.

cox_model$concordance
   concordant    discordant        tied.x        tied.y       tied.xy 
6480.00000000 2324.00000000    0.00000000   39.00000000    0.00000000 
  concordance           std 
   0.73602908    0.02117012 
  • concordant = 6,480: Comparable pairs in which the model assigned the higher risk to the patient who died first. More is better.

  • discordant = 2,324: Pairs where the model got it wrong, assigning the higher risk to the patient who survived longer. Fewer is better.

  • tied.x = 0: Pairs where patients had the same predicted risk. None in this case.

  • tied.y = 39: Pairs where patients had the same survival time, so they cannot be ordered and are excluded from the calculation.

  • tied.xy = 0: Pairs tied on both predicted risk and survival time.

  • concordance (std) = 0.736 (0.021): The model correctly ranks ~73.6% of comparable pairs (6,480 / (6,480 + 2,324)), with a standard error of ~0.021. A standard error this small indicates a stable estimate.
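
Harrell's C can be recomputed directly from the pair counts printed above. A minimal sketch that refits the same model on veteran and applies the usual formula, in which ties on predicted risk count as half-concordant:

```r
library(survival)

fit <- coxph(Surv(time, status) ~ ., data = veteran)
cnt <- fit[["concordance"]]   # named vector: concordant, discordant, tied.x, tied.y, ...

# Harrell's C: pairs tied on predicted risk (tied.x) count as half concordant;
# pairs tied on survival time (tied.y) are not comparable and are left out
c_manual <- (cnt[["concordant"]] + cnt[["tied.x"]] / 2) /
  (cnt[["concordant"]] + cnt[["discordant"]] + cnt[["tied.x"]])

c(manual = c_manual, reported = cnt[["concordance"]])
# With the counts printed above: 6480 / (6480 + 2324)
6480 / (6480 + 2324)
```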

2. Machine Learning Survival Models

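Target function: \(y=f(x)\)

  • In classical statistical models such as the Cox PH model, the form of \(f\) is explicitly specified and only its parameters are estimated.

  • ML methods learn both the function \(f\) and the model parameters from the data.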

Random Survival Forest (randomForestSRC, ranger)

  • randomForestSRC: Extensive (competing risks, time-dependent covariates, missing data imputation).

  • ranger: Basic RSF (Kaplan-Meier or Nelson-Aalen only).

library(randomForestSRC)

rsf_model = rfsrc(Surv(time, status) ~ ., data = surv_data, ntree = 500)  # rfsrc() takes ntree; num.trees is the ranger argument
rsf_model
                         Sample size: 137
                    Number of deaths: 128
                     Number of trees: 500
           Forest terminal node size: 15
       Average no. of terminal nodes: 6.464
No. of variables tried at each split: 3
              Total no. of variables: 6
       Resampling used to grow trees: swor
    Resample size used to grow trees: 87
                            Analysis: RSF
                              Family: surv
                      Splitting rule: logrank *random*
       Number of random split points: 10
                          (OOB) CRPS: 61.2483391
                   (OOB) stand. CRPS: 0.06130965
   (OOB) Requested performance error: 0.30176411
  • (OOB) CRPS = 61.25: Continuous Ranked Probability Score, i.e. the Brier score integrated over follow-up time. Lower is better.
  • (OOB) standardised CRPS = 0.0613: CRPS scaled by the follow-up range (here 61.25 / 999 days). Lower is better; the scaling makes values comparable across datasets.
  • (OOB) Requested performance error = 0.302: For survival forests this is 1 - Harrell's C-index, estimated on the out-of-bag samples, so the OOB C-index is about 0.70.
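
The printed forest summaries are linked by simple arithmetic. A minimal sketch, assuming (as the numbers suggest) that rfsrc standardises CRPS by dividing by the largest follow-up time and reports 1 - Harrell's C as the survival error rate:

```r
library(survival)

crps_oob     <- 61.2483391   # (OOB) CRPS printed by rfsrc
std_crps_oob <- 0.06130965   # (OOB) stand. CRPS
err_oob      <- 0.30176411   # (OOB) Requested performance error

max_time <- max(veteran$time)   # largest follow-up time in days
crps_oob / max_time             # reproduces the standardised CRPS
1 - err_oob                     # OOB Harrell's C-index
```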

Model predictions

rsf_pred = predict(rsf_model, newdata=surv_data)
rsf_pred
  Sample size of test (predict) data: 137
                Number of grow trees: 500
  Average no. of grow terminal nodes: 6.464
         Total no. of grow variables: 6
       Resampling used to grow trees: swor
    Resample size used to grow trees: 87
                            Analysis: RSF
                              Family: surv
                                CRPS: 52.25934453
                         stand. CRPS: 0.05231166
         Requested performance error: 0.23651476
Hmisc::rcorr.cens(-rsf_pred$predicted, Surv(surv_data$time, surv_data$status))
       C Index            Dxy           S.D.              n        missing 
    0.76465243     0.52930486     0.04302583   137.00000000     0.00000000 
    uncensored Relevant Pairs     Concordant      Uncertain 
  128.00000000 17608.00000000 13464.00000000   946.00000000 
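  • Correctly ranks survival times for ~76% of usable pairs (C = 0.765).

  • The S.D. column (0.043) is the standard deviation of Dxy = 2(C - 0.5), so the standard error of C itself is about 0.022, which indicates a fairly stable estimate.

  • Censoring is light (9 of 137 observations), so few pairs are lost as non-comparable.

  • This C-index, like the CRPS of 52.26 and error of 0.237 above, is computed on the same data used to grow the forest, so it is optimistic; the OOB error of 0.302 (C ≈ 0.70) is a more honest estimate.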
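
rcorr.cens() reports Somers' Dxy alongside C, and its S.D. column refers to Dxy. Since Dxy = 2(C - 0.5), the C-index and its standard error follow by arithmetic. A minimal sketch using the values copied from the RSF output above:

```r
dxy    <- 0.52930486   # Dxy reported by rcorr.cens() for the RSF
sd_dxy <- 0.04302583   # the S.D. column, which refers to Dxy

c_index <- (dxy + 1) / 2   # inverts Dxy = 2 * (C - 0.5)
se_c    <- sd_dxy / 2      # C is Dxy rescaled by 1/2, so its SD halves too
c(C = c_index, SE_C = se_c)

# The same C from the printed pair counts: concordant / relevant pairs
13464 / 17608
```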

Survival Support Vectors (survivalsvm)

library(survivalsvm)
ssvm_model = survivalsvm(Surv(time, status) ~ ., 
                     data = surv_data,
                     gamma.mu = 0.001,            # regularization parameter; required (no default)
                     kernel = "lin_kernel")       # available kernels: 'lin_kernel', 'add_kernel', 'rbf_kernel', 'poly_kernel'
ssvm_model

survivalsvm result

Call:

 survivalsvm(Surv(time, status) ~ ., data = surv_data, gamma.mu = 0.001, kernel = "lin_kernel") 

Survival svm approach              : regression 
Type of Kernel                     : lin_kernel 
Optimization solver used           : quadprog 
Number of support vectors retained : 100 
survivalsvm version                : 0.0.6 
svm_lp = predict(ssvm_model, surv_data)

Hmisc::rcorr.cens(-svm_lp$predicted, Surv(surv_data$time, surv_data$status))
       C Index            Dxy           S.D.              n        missing 
    0.29327578    -0.41344843     0.04796709   137.00000000     0.00000000 
    uncensored Relevant Pairs     Concordant      Uncertain 
  128.00000000 17608.00000000  5164.00000000   946.00000000 

The C-index of 0.293 is below 0.5, i.e. worse than a random ranking. This points to the sign convention rather than to a poor model: with the regression approach used here (see the output above), survivalsvm returns predictions on the survival-time scale, where higher values mean longer survival, so they should be passed to rcorr.cens() without the minus sign. Dropping the minus sign should give a C-index of roughly 1 - 0.293 ≈ 0.71 (exactly, if no predictions are tied).

Gradient Boosting for Survival (gbm)

library(gbm)
gbms_model = gbm(Surv(time, status) ~ ., data = surv_data)
Distribution not specified, assuming coxph ...
gbms_model
gbm(formula = Surv(time, status) ~ ., data = surv_data)
A gradient boosted model with coxph loss function.
100 iterations were performed.
There were 6 predictors of which 6 had non-zero influence.
gbm_lp = predict(gbms_model, surv_data, n.trees = 100)  # log-hazard scale: higher = higher risk

Hmisc::rcorr.cens(-gbm_lp, Surv(surv_data$time, surv_data$status))
       C Index            Dxy           S.D.              n        missing 
    0.73540436     0.47080872     0.04537116   137.00000000     0.00000000 
    uncensored Relevant Pairs     Concordant      Uncertain 
  128.00000000 17608.00000000 12949.00000000   946.00000000 
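
All three rcorr.cens() calls above negate the model output, which is only correct when a larger value means higher risk. Harrell's C measures agreement between a score and survival time, so flipping the sign of a score turns C into 1 - C (when no two scores are tied). A minimal sketch with the Cox linear predictor and survival::concordance(), where reverse = TRUE declares that a larger score means higher risk:

```r
library(survival)

fit <- coxph(Surv(time, status) ~ ., data = veteran)
d <- data.frame(time = veteran$time, status = veteran$status,
                lp = predict(fit, type = "lp"))

# reverse = TRUE: a larger score means higher risk, i.e. shorter survival
c_risk <- concordance(Surv(time, status) ~ lp, data = d, reverse = TRUE)$concordance
# reverse = FALSE: the same score read as if it were a predicted survival time
c_time <- concordance(Surv(time, status) ~ lp, data = d)$concordance

c(as_risk = c_risk, as_time = c_time, sum = c_risk + c_time)
```

The first value reproduces the Cox concordance of 0.736; the second is its mirror image.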

3. Explainable ML for Survival Models (survex)

  • Explainability is the degree to which we can understand and communicate how a model, system, or process makes its decisions.

  • It means making the model’s behavior transparent and interpretable to humans.

  • It involves showing which features influenced a prediction, how much, and in what direction.

Explainability vs. Interpretability

  • Interpretability: How well a human can understand the internal mechanics of the model (e.g., linear regression is inherently interpretable).

  • Explainability: How well we can communicate the reasoning behind predictions, even if the model is complex (like a deep neural network).

survex is an R package for explaining machine-learning survival models using explainable AI (XAI) techniques.

Many survival models, especially tree-based ones such as random survival forests, produce nonlinear and time-varying effects that are hard to interpret.

survex addresses this by explaining both individual predicted survival curves and global model behavior, and by assessing variable importance, bias, and reliability.

It is particularly useful in healthcare, clinical research, and other sensitive domains where transparency matters.

  • For more about the package, see the survex documentation.
library(survex)

3.1. Creating Explainer

General Purpose Explainer - explain()

  • It wraps any predictive model (regression, classification, or survival) in a standardized interface for explainability.
explainer = explain(model)

Now let’s create an explainer

cph_model <- coxph(Surv(time, status)~., data = surv_data, model = TRUE, x = TRUE)

cph_explainer = explain(cph_model)
Preparation of a new explainer is initiated 
  -> model label       :  coxph (  default  ) 
  -> data              :  137  rows  6  cols (  extracted from the model  ) 
  -> target variable   :  137  values ( 128 events and 9 censored , censoring rate = 0.066 ) (  extracted from the model  ) 
  -> times             :  50 unique time points , min = 1.5 , median survival time = 80 , max = 999 
  -> times             :  (  generated from y as uniformly distributed survival quantiles based on Kaplan-Meier estimator  ) 
  -> predict function  :  predict.coxph with type = 'risk' will be used (  default  ) 
  -> predict survival function  :  predictSurvProb.coxph will be used (  default  ) 
  -> predict cumulative hazard function  :  -log(predict_survival_function) will be used (  default  ) 
  -> model_info        :  package survival , ver. 3.8.3 , task survival (  default  ) 
  A new explainer has been created!  

The only mandatory input for the explain() function is the proportional hazards model itself. However, when creating the model, it is essential to set model = TRUE and x = TRUE. Omitting these arguments will result in an error.
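
To see what model = TRUE and x = TRUE change, inspect the fitted objects. A minimal sketch; we assume survex reads the covariates and response from these stored components, and the code only shows that they are absent without the flags:

```r
library(survival)

fit_plain <- coxph(Surv(time, status) ~ ., data = veteran)
fit_full  <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE)

# [[ ]] avoids partial matching (fit$x can partially match other components such as xlevels)
is.null(fit_plain[["model"]])   # no model frame stored
is.null(fit_plain[["x"]])       # no design matrix stored
dim(fit_full[["model"]])        # response + 6 covariates
dim(fit_full[["x"]])            # celltype expands to 3 dummy columns
```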

rsf_model <- rfsrc(Surv(time, status)~., data = surv_data)

rsf_explainer = explain(rsf_model)
Preparation of a new explainer is initiated 
  -> model label       :  rfsrc (  default  ) 
  -> data              :  137  rows  6  cols (  extracted from the model  ) 
  -> target variable   :  137  values ( 128 events and 9 censored , censoring rate = 0.066 ) (  extracted from the model  ) 
  -> times             :  50 unique time points , min = 1.5 , median survival time = 80 , max = 999 
  -> times             :  (  generated from y as uniformly distributed survival quantiles based on Kaplan-Meier estimator  ) 
  -> predict function  :  sum over the predict_cumulative_hazard_function will be used (  default  ) 
  -> predict survival function  :  stepfun based on predict.rfsrc()$survival will be used (  default  ) 
  -> predict cumulative hazard function  :  stepfun based on predict.rfsrc()$chf will be used (  default  ) 
  -> model_info        :  package randomForestSRC , ver. 3.4.1 , task survival (  default  ) 
  A new explainer has been created!  

For some models, data cannot be extracted automatically. For instance, when creating an explainer for a Random Survival Forest using the ranger package, we must manually provide the data X and the target variable y.

library(ranger)

ranger_rsf <- ranger(Surv(time, status)~., data = surv_data)

ranger_rsf_exp <- explain(ranger_rsf)
Preparation of a new explainer is initiated 
  -> model label       :  ranger (  default  ) 
  -> no data available! (  WARNING  ) 
  -> target variable   :  not specified! (  WARNING  ) 
  -> times   :  not specified and automatic generation is impossible ('y' is NULL)! (  WARNING  ) 
  -> predict function  :  sum over the predict_cumulative_hazard_function will be used (  default  ) 
  -> predict survival function  :  stepfun based on predict.ranger()$survival will be used (  default  ) 
  -> predict cumulative hazard function  :  stepfun based on predict.ranger()$chf will be used (  default  ) 
  -> model_info        :  package ranger , ver. 0.17.0 , task survival (  default  ) 
  -> model_info        :  survival task detected but 'y' is a NULL   (  WARNING  ) 
  -> model_info        :  by deafult survival tasks supports only 'y' parameter of 'survival::Surv' class 
  A new explainer has been created!  

We should supply the data parameter X without the columns containing survival information.

X = surv_data[, -c(3,4)]  # drop columns 3 and 4 (time and status)
y = Surv(surv_data$time, surv_data$status)

rang_rf_explainer <- explain(ranger_rsf, data = X, y = y)
Preparation of a new explainer is initiated 
  -> model label       :  ranger (  default  ) 
  -> data              :  137  rows  6  cols 
  -> target variable   :  137  values ( 128 events and 9 censored ) 
  -> times             :  50 unique time points , min = 1.5 , median survival time = 80 , max = 999 
  -> times             :  (  generated from y as uniformly distributed survival quantiles based on Kaplan-Meier estimator  ) 
  -> predict function  :  sum over the predict_cumulative_hazard_function will be used (  default  ) 
  -> predict survival function  :  stepfun based on predict.ranger()$survival will be used (  default  ) 
  -> predict cumulative hazard function  :  stepfun based on predict.ranger()$chf will be used (  default  ) 
  -> model_info        :  package ranger , ver. 0.17.0 , task survival (  default  ) 
  A new explainer has been created!  

3.2. Making Predictions

Now we are going to work exclusively with explainer objects, which serve as standardized wrappers around the models. The key advantage of using an explainer is its ability to generate predictions—whether risk scores, survival probabilities, or cumulative hazard functions—in a consistent manner, regardless of the underlying model.

  • predict(): Computes predicted survival probabilities, cumulative hazard, or risk scores from the model wrapped in the explainer.
predict(rsf_explainer, newdata = X_new)

Survival Function

pred_surv = predict(rsf_explainer, newdata = surv_data, output_type='survival')
#pred_surv

So the survival output is a matrix in which element \((i,j)\) is the predicted probability that individual \(i\) survives beyond the \(j\)-th time point \(t_j\): one row per individual and one column per time point, as in the cumulative hazard output below.

  • The survival function \(S(t)\) is the probability that an individual survives beyond time \(t\).
  • For example, if \(S(5)=0.8\), there is an 80% chance that the person survives past 5 units of time.
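
For the Cox model, survival::survfit() produces the same kind of individuals-by-times matrix. A minimal sketch; the time points are arbitrary illustrative choices, and because summary() returns times in rows we transpose:

```r
library(survival)

fit <- coxph(Surv(time, status) ~ ., data = veteran)
sf <- survfit(fit, newdata = veteran[1:3, ])   # one predicted curve per patient
times <- c(30, 90, 180, 365)                  # illustrative time points (days)

# summary() gives times x patients; transpose to patients x times
s_mat <- t(summary(sf, times = times)$surv)
dimnames(s_mat) <- list(paste0("patient_", 1:3), paste0("t=", times))
round(s_mat, 3)
# Entry (1, 2): probability that patient 1 survives beyond 90 days
s_mat[1, 2]
```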

Cumulative Hazard Function

new_obs = surv_data[1, -c(3,4)]
pred_chf = predict(rsf_explainer, newdata = new_obs, output_type='chf')
pred_chf
          [,1]       [,2]       [,3]       [,4]       [,5]       [,6]
[1,] 0.0110632 0.01130605 0.01705345 0.02697063 0.03075422 0.05468718
           [,7]       [,8]       [,9]      [,10]      [,11]      [,12]
[1,] 0.05612141 0.07083183 0.07311615 0.07527041 0.07740642 0.08649309
          [,13]   [,14]     [,15]     [,16]     [,17]     [,18]     [,19]
[1,] 0.09797097 0.10067 0.1215816 0.1296288 0.1308685 0.2275782 0.2339819
         [,20]     [,21]     [,22]     [,23]     [,24]     [,25]     [,26]
[1,] 0.2436291 0.2627047 0.3044986 0.3053727 0.3801407 0.4114505 0.4431012
         [,27]     [,28]     [,29]     [,30]     [,31]     [,32]     [,33]
[1,] 0.4499907 0.4516038 0.4902727 0.5422882 0.6690722 0.6971438 0.7308551
         [,34]     [,35]     [,36]     [,37]    [,38]    [,39]    [,40]
[1,] 0.7742135 0.8083247 0.8470295 0.9379728 0.977424 1.128483 1.168013
        [,41]    [,42]    [,43]    [,44]    [,45]    [,46]    [,47]    [,48]
[1,] 1.266984 1.580142 1.717938 1.783009 1.955323 2.169437 2.281171 2.559837
        [,49]    [,50]
[1,] 2.950171 3.327171
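  • The accumulated risk of the event occurring up to time \(t\), written \(H(t)\).
  • Non-negative and non-decreasing over time.
  • In continuous time, the survival probability is the exponential of the negative cumulative hazard, \(S(t)=\exp(-H(t))\). Estimators such as Kaplan-Meier (for \(S\)) and Nelson-Aalen (for \(H\)) satisfy this relation only approximately.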
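
How close is the approximation in practice? A minimal sketch on the veteran data comparing the Kaplan-Meier curve with exp(-H) from the Nelson-Aalen cumulative hazard (a population-level illustration, not the RSF estimator itself):

```r
library(survival)

km <- survfit(Surv(time, status) ~ 1, data = veteran)
# km$surv: Kaplan-Meier estimate of S(t); km$cumhaz: Nelson-Aalen estimate of H(t)
s_from_h <- exp(-km$cumhaz)

head(round(cbind(time = km$time, km = km$surv, exp_neg_H = s_from_h), 4))
max(abs(s_from_h - km$surv))   # largest gap between the two estimates
```

Because exp(-x) >= 1 - x, exp(-H) never falls below the Kaplan-Meier curve, but the gap stays small here.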

Risk Score / Prognostic Index

pred_risk = predict(cph_explainer, newdata = surv_data[1:10,], output_type='risk')
pred_risk
        1         2         3         4         5         6         7         8 
0.7354128 0.5942155 0.9629558 0.8324949 0.5893519 3.2519560 1.5232113 0.3855308 
        9        10 
1.2815790 0.5250487 
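  • A model-derived score indicating a patient's risk relative to a reference patient (for coxph, by default one with the sample-mean covariate values).
  • For coxph, the explainer uses predict.coxph() with type = 'risk', which returns \(\exp(\text{lp})\). The linear predictor \(\text{lp}=x^\top\beta\) can be any real number, but the risk score itself is always positive, as in the output above.
  • Higher risk score → higher hazard (higher chance of the event occurring sooner).
  • In PH models, \(h(t \mid x) = h_0(t)\exp(\text{lp})\), so the hazard is proportional to the risk score; e.g. patient 6 (risk 3.25) has about 3.25 times the hazard of the reference patient.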
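
The risk score returned for coxph is the exponentiated linear predictor. A minimal check with predict.coxph() for the first ten patients:

```r
library(survival)

fit <- coxph(Surv(time, status) ~ ., data = veteran)
new <- veteran[1:10, ]

lp   <- predict(fit, newdata = new, type = "lp")     # linear predictor: any real number
risk <- predict(fit, newdata = new, type = "risk")   # relative hazard exp(lp): always > 0

round(cbind(lp = lp, risk = risk, exp_lp = exp(lp)), 4)
```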

3.3. Model Performance - model_performance()

  • model_performance(): Evaluates the overall quality of the model on a survival prediction task. It computes metrics like Concordance index (C-index), Time-dependent AUC, Integrated Brier Score.

  • C-index (Concordance Index)

    • Measures how well the model ranks survival times.
    • The probability that, for a randomly selected pair of individuals, the one who died earlier had a higher predicted risk.
    • Higher is better.
  • Integrated C/D AUC (cumulative/dynamic AUC)

    • Measures time-dependent version of AUC (Area Under the Curve), averaged over all follow-up times.
    • Higher (strong discrimination) is better.
  • Integrated Brier Score (IBS)

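    • Measures prediction error over time, averaged across all time points.
    • Reflects both calibration and discrimination.
    • The Brier score at time \(t\) is the mean squared difference between the predicted survival probability and the observed status at \(t\) (1 = still alive, 0 = dead), with censored observations handled by inverse probability of censoring weights.
    • Lower is better: 0 is a perfect prediction, and predicting 0.5 for everyone scores 0.25, so values around 0.25 or above indicate a poor model (as a rough guide, 0.1-0.2 is very good).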
model_performance(explainer)
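
To make the Brier score concrete, here is a minimal sketch of the inverse-probability-of-censoring-weighted (IPCW) Brier score at a single time point, in the style of Graf et al., comparing the Cox model with a covariate-free Kaplan-Meier benchmark. The evaluation time of 100 days and the helper brier_at() are hypothetical illustrative choices; this is not survex's loss_brier_score() implementation. The IBS averages such values over a grid of times.

```r
library(survival)

fit <- coxph(Surv(time, status) ~ ., data = veteran)
t_star <- 100   # illustrative evaluation time (days)

# Predicted S(t_star | x_i): Cox model vs a covariate-free Kaplan-Meier benchmark
s_cox <- as.numeric(summary(survfit(fit, newdata = veteran), times = t_star)$surv)
km <- survfit(Surv(time, status) ~ 1, data = veteran)
s_km <- rep(summary(km, times = t_star)$surv, nrow(veteran))

# Kaplan-Meier estimate of the censoring survival function G(t) = P(C > t)
cens <- survfit(Surv(time, 1 - status) ~ 1, data = veteran)
G      <- stepfun(cens$time, c(1, cens$surv))               # G(t)
G_left <- stepfun(cens$time, c(1, cens$surv), right = TRUE)  # G(t-)

# Hypothetical helper: IPCW Brier score at time t
brier_at <- function(s_hat, time, status, t) {
  died  <- time <= t & status == 1   # event observed by t: true status 0
  alive <- time > t                  # known alive after t: true status 1
  w <- numeric(length(time))         # censored before t: weight 0
  w[died]  <- 1 / G_left(time[died])
  w[alive] <- 1 / G(t)
  mean(ifelse(died, s_hat^2, (1 - s_hat)^2) * w)
}

bs <- c(cox = brier_at(s_cox, veteran$time, veteran$status, t_star),
        km  = brier_at(s_km,  veteran$time, veteran$status, t_star))
bs
```

Both scores are computed on the training data, so the Cox model's advantage is somewhat optimistic.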
cox_model <- coxph(Surv(time, status) ~ ., data = surv_data, model=TRUE, x = TRUE)
cox_explainer <- explain(cox_model, label = "Cox PH")
Preparation of a new explainer is initiated 
  -> model label       :  Cox PH 
  -> data              :  137  rows  6  cols (  extracted from the model  ) 
  -> target variable   :  137  values ( 128 events and 9 censored , censoring rate = 0.066 ) (  extracted from the model  ) 
  -> times             :  50 unique time points , min = 1.5 , median survival time = 80 , max = 999 
  -> times             :  (  generated from y as uniformly distributed survival quantiles based on Kaplan-Meier estimator  ) 
  -> predict function  :  predict.coxph with type = 'risk' will be used (  default  ) 
  -> predict survival function  :  predictSurvProb.coxph will be used (  default  ) 
  -> predict cumulative hazard function  :  -log(predict_survival_function) will be used (  default  ) 
  -> model_info        :  package survival , ver. 3.8.3 , task survival (  default  ) 
  A new explainer has been created!  
rsf_model <- rfsrc(Surv(time, status)~., data = surv_data)
rsf_explainer <- explain(rsf_model, data = X, y = y, label = "RSF I")
Preparation of a new explainer is initiated 
  -> model label       :  RSF I 
  -> data              :  137  rows  6  cols 
  -> target variable   :  137  values ( 128 events and 9 censored ) 
  -> times             :  50 unique time points , min = 1.5 , median survival time = 80 , max = 999 
  -> times             :  (  generated from y as uniformly distributed survival quantiles based on Kaplan-Meier estimator  ) 
  -> predict function  :  sum over the predict_cumulative_hazard_function will be used (  default  ) 
  -> predict survival function  :  stepfun based on predict.rfsrc()$survival will be used (  default  ) 
  -> predict cumulative hazard function  :  stepfun based on predict.rfsrc()$chf will be used (  default  ) 
  -> model_info        :  package randomForestSRC , ver. 3.4.1 , task survival (  default  ) 
  A new explainer has been created!  
ranger_rsf <- ranger(Surv(time, status)~., data = surv_data)
ranger_rsf_explainer <- explain(ranger_rsf, data = X, y = y, label = "RSF II")
Preparation of a new explainer is initiated 
  -> model label       :  RSF II 
  -> data              :  137  rows  6  cols 
  -> target variable   :  137  values ( 128 events and 9 censored ) 
  -> times             :  50 unique time points , min = 1.5 , median survival time = 80 , max = 999 
  -> times             :  (  generated from y as uniformly distributed survival quantiles based on Kaplan-Meier estimator  ) 
  -> predict function  :  sum over the predict_cumulative_hazard_function will be used (  default  ) 
  -> predict survival function  :  stepfun based on predict.ranger()$survival will be used (  default  ) 
  -> predict cumulative hazard function  :  stepfun based on predict.ranger()$chf will be used (  default  ) 
  -> model_info        :  package ranger , ver. 0.17.0 , task survival (  default  ) 
  A new explainer has been created!  
cox_perf = model_performance(cox_explainer)
rsf_perf = model_performance(rsf_explainer)
ran_perf = model_performance(ranger_rsf_explainer)
plot(cox_perf, rsf_perf, ran_perf)

We can also plot the scalar metrics in the form of bar plots.

plot(cox_perf, rsf_perf, ran_perf, metrics_type = 'scalar')

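  • If our priority is better overall prediction accuracy (calibration), the IBS matters most, so the RSF with the lower IBS is preferable.

  • If our priority is better ranking of risk (discrimination), the C-index and AUC matter more.

  • These metrics are computed on the same data the models were trained on, which favors flexible models such as random forests; a held-out test set or cross-validation gives a fairer comparison.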

3.4. Global Explanations

Global explanations describe how each variable influences the models' predictions across the whole dataset, i.e. which variables are important to each model.

Variable Importance

  • model_parts(): Shows which variables most affect predictions by shuffling them and measuring the resulting drop in performance. It calculates permutation-based variable importance, except that the loss function is time-dependent (by default loss_brier_score()), so the influence of each variable can differ at each considered time point.
model_parts(explainer)
cph_m_parts <- model_parts(cox_explainer)
rsf_m_parts <- model_parts(rsf_explainer)
ran_m_parts <- model_parts(ranger_rsf_explainer)

plot(cph_m_parts, rsf_m_parts, ran_m_parts)

Across the three models, permuting the karno variable leads to the largest increase in the loss function, followed by celltype. These two variables are therefore the most important for the models' predictions.
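
The idea behind model_parts() can be sketched by hand. To keep it short, this version swaps the time-dependent Brier loss for a single scalar metric, Harrell's C, and measures how much C drops when one covariate is shuffled (Cox model only, 10 permutations per variable; the helper c_index() is hypothetical):

```r
library(survival)

set.seed(1)
fit <- coxph(Surv(time, status) ~ ., data = veteran)

# Hypothetical helper: Harrell's C for a risk score (larger = higher risk)
c_index <- function(score) {
  d <- data.frame(time = veteran$time, status = veteran$status, score = score)
  concordance(Surv(time, status) ~ score, data = d, reverse = TRUE)$concordance
}

base_c <- c_index(predict(fit, newdata = veteran, type = "lp"))
covariates <- c("trt", "celltype", "karno", "diagtime", "age", "prior")

# Importance = average drop in C after shuffling one covariate
importance <- sapply(covariates, function(v) {
  mean(replicate(10, {
    shuffled <- veteran
    shuffled[[v]] <- sample(shuffled[[v]])
    base_c - c_index(predict(fit, newdata = shuffled, type = "lp"))
  }))
})
round(sort(importance, decreasing = TRUE), 3)
```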

Partial Dependence

Partial dependence plots are generated using the model_profile() function. These plots illustrate how changing the value of a single variable affects the model’s prediction on average.

To prevent nonsensical values—like a treatment value of 0.5—we must specify the categorical_variables parameter. While all factors are automatically recognized as categorical, if we want to treat a numeric variable as categorical, we need to include it explicitly in this parameter.

  • model_profile(): How does a variable affect the average prediction?
model_profile(explainer)
cox_m_profile = model_profile(cox_explainer, categorical_variables=c("trt", "prior"))
plot(cox_m_profile, numerical_plot_type='lines')

From the plot, we observe that in the proportional hazards model, the prior and diagtime variables have minimal impact: the prediction bands are narrow and nearly overlapping, indicating that changes in these variables have little effect on the predicted survival function. In contrast, the karno variable shows a much wider band, suggesting that even small changes in its value lead to large differences in the predicted survival. Lower karno values are associated with poorer survival prospects, as evidenced by a faster decline in the survival function.
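
A partial dependence curve can also be computed by hand for the Cox model. This minimal sketch fixes one evaluation time (100 days, an illustrative choice) rather than tracing the whole survival curve: it sets karno to each grid value for every patient and averages the predicted survival probabilities.

```r
library(survival)

fit <- coxph(Surv(time, status) ~ ., data = veteran)
t_star <- 100                        # illustrative evaluation time (days)
karno_grid <- seq(20, 90, by = 10)   # illustrative grid of Karnofsky scores

# Partial dependence: same karno for everyone, then average S(t_star | x)
pd <- sapply(karno_grid, function(k) {
  d <- veteran
  d$karno <- k
  mean(summary(survfit(fit, newdata = d), times = t_star)$surv)
})
round(data.frame(karno = karno_grid, mean_surv = pd), 3)
```

Because the karno coefficient is negative, the averaged survival probability rises steadily with karno, matching the plot.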

We can also plot the same information for the random survival forest models.

rsf_m_profile = model_profile(rsf_explainer, categorical_variables=c("trt", "prior"))
plot(rsf_m_profile, numerical_plot_type = "contours") #facet_ncol = 2

This type of plot also gives us valuable insight that is easy to overlook in the other type. For example, we see a sharp drop in survival function values around diagtime = 25. We also observe that karno has the strongest influence in both the proportional hazards and the random survival forest models.

rang_m_profile = model_profile(ranger_rsf_explainer, categorical_variables=c("trt", "prior"))
plot(rang_m_profile, numerical_plot_type='lines')

3.5. Local Explanations

  • Local explanations explain a model's prediction for a single observation. The predict_parts() function can be used to assess the importance of variables in the prediction for a selected observation. This can be done with two methods: SurvSHAP(t) and SurvLIME.

Variable Attributions: Which variables contribute to the prediction?

  • predict_parts(): Provides local explanations for a single prediction by breaking down how each feature contributes to an individual's predicted survival curve or risk. It uses techniques such as SurvSHAP(t) (time-dependent SHAP values for survival models) and SurvLIME (a local surrogate Cox proportional hazards model).

SurvSHAP(t)

predict_parts(explainer, new_observation)
new_obs1 = surv_data[32,]
cox_p_parts1 = predict_parts(cox_explainer, new_obs1)
new_obs2 = surv_data[12,]
cox_p_parts2 = predict_parts(cox_explainer, new_obs2)
plot(cox_p_parts1)

plot(cox_p_parts2)

In the first plot, for observation 32, the value of the karno variable improves this individual's chances of survival, whereas the value of the celltype variable decreases them.

In the second plot, for observation 12, the situation is reversed: celltype increases the chances of survival, while karno decreases them.

For the random survival forest:

rsf_p_parts = predict_parts(rsf_explainer, new_obs2)
rang_p_parts = predict_parts(ranger_rsf_explainer, new_obs2)
plot(rsf_p_parts)

plot(rang_p_parts)
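
SurvSHAP(t) estimates Shapley values for the whole survival function over time. For intuition, a model that is linear in its features has a closed form: treating features as independent, the Shapley value of feature j on the log-relative-hazard scale is beta_j (x_j - mean(x_j)). The sketch below applies this to observation 32 of the Cox model; it illustrates the additivity idea and is not what predict_parts() computes.

```r
library(survival)

fit <- coxph(Surv(time, status) ~ ., data = veteran, x = TRUE)
X <- fit[["x"]]                    # design matrix (celltype as 3 dummy columns)
beta <- coef(fit)[colnames(X)]
i <- 32                            # observation explained in the plot above

# Linear-model SHAP on the log relative hazard: beta_j * (x_ij - mean_j)
phi_cols <- beta * (X[i, ] - colMeans(X))
# Add the celltype dummies together to get one value per original variable
group <- sub("^celltype.*", "celltype", names(phi_cols))
phi <- tapply(phi_cols, group, sum)
round(sort(phi), 3)

# Local accuracy: contributions sum to f(x_i) - mean f(X)
f <- drop(X %*% beta)
c(sum_phi = sum(phi), f_i_minus_mean = f[i] - mean(f))
```

Positive values raise the log hazard (worse survival); negative values lower it.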

SurvLIME

A different way of attributing variable importance is provided by the SurvLIME method. It works by finding a surrogate proportional hazards model that approximates the survival model at the local area around an observation of interest. Variable importance is then attributed using the coefficients of the found model.

new_obs = surv_data[12,]
cox_p_parts = predict_parts(cox_explainer, new_obs, type="survlime")
rsf_p_parts = predict_parts(rsf_explainer, new_obs, type="survlime")
rang_p_parts = predict_parts(ranger_rsf_explainer, new_obs, type="survlime")
plot(cox_p_parts)

plot(rsf_p_parts)

plot(rang_p_parts)

The left part of the plot shows which variables are most important and whether their values increase or lower the chances of survival, whereas the right part shows the black-box model prediction together with the prediction of the surrogate model. This is useful because the closer these two functions are, the more faithfully the surrogate (and hence the explanation) reflects the black-box model.
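
SurvLIME fits a local Cox surrogate by matching cumulative hazard functions. When the black box is itself a PH model, log H(t|x) = log H0(t) + f(x), so matching the log cumulative hazard reduces to approximating f(x) near the observation with a linear function. The sketch below does this simplified, LIME-style version with a weighted linear regression. The black-box formula, the perturbation scale, and the Gaussian kernel are our own illustrative assumptions, not survex's defaults, and categorical features are held fixed for brevity.

```r
library(survival)

set.seed(42)
# Hypothetical black box: a PH model with a quadratic karno effect and an interaction
bb <- coxph(Surv(time, status) ~ karno + I(karno^2) + celltype * trt + age + diagtime + prior,
            data = veteran)

x0 <- veteran[12, ]                      # observation to explain
num_vars <- c("karno", "age", "diagtime")

# 1. Sample a neighbourhood by perturbing the numeric features of x0
n <- 1000
nb <- x0[rep(1, n), ]
for (v in num_vars) nb[[v]] <- x0[[v]] + rnorm(n, sd = sd(veteran[[v]]) / 2)

# 2. Black-box log relative hazard at every sampled point
nb$f <- predict(bb, newdata = nb, type = "lp")

# 3. Gaussian kernel weights on the standardised distance to x0
z <- sapply(num_vars, function(v) (nb[[v]] - x0[[v]]) / sd(veteran[[v]]))
nb$w <- exp(-rowSums(z^2) / 2)

# 4. Weighted linear surrogate: its slopes are the local attributions
surrogate <- lm(f ~ karno + age + diagtime, data = nb, weights = w)
round(coef(surrogate), 4)

# For comparison: the exact local slope of the black box in karno at x0
b <- coef(bb)
karno_slope <- b[["karno"]] + 2 * b[["I(karno^2)"]] * x0$karno
karno_slope
```

The surrogate's karno slope should land close to the exact derivative, which is the sense in which a faithful local surrogate "explains" the black box.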

Ceteris paribus

Another explanation technique provided by this package is ceteris paribus profiles. They show how the prediction changes when we change the value of one variable at a time. We can think of them as the equivalent of partial dependence plots but applied to a single observation. The predict_profile() function is used to make these explanations.

new_obs = surv_data[12,]
cox_p_profile = predict_profile(cox_explainer, new_obs, categorical_variables=c("trt", "prior"))
rsf_p_profile = predict_profile(rsf_explainer, new_obs, categorical_variables=c("trt", "prior"))
rang_p_profile = predict_profile(ranger_rsf_explainer, new_obs, categorical_variables=c("trt", "prior"))
plot(cox_p_profile)

plot(rsf_p_profile)

plot(rang_p_profile)

These plots also give a lot of valuable insight. For example, we see that the prior variable does not influence the predictions very much, as the lines representing survival functions for its different values almost overlap. We also observe that the celltype values “smallcell” and “adeno”, as well as “large” and “squamous”, result in almost the same prediction for this observation. The most important variable for this observation is karno, as the differences in the survival function are the greatest, with low values indicating lower chances of survival. In addition, high values of the diagtime variable seem to indicate lower chances of survival.
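
A ceteris paribus profile for the Cox model can likewise be computed by hand: copy observation 12, vary only karno over a grid, and read off the predicted survival at a few time points. A minimal sketch (the grid and time points are illustrative choices):

```r
library(survival)

fit <- coxph(Surv(time, status) ~ ., data = veteran)
x0 <- veteran[12, ]
karno_grid <- seq(20, 90, by = 10)   # illustrative grid
times <- c(30, 90, 180)              # illustrative time points (days)

# Ceteris paribus: one copy of observation 12 per grid value, only karno changes
cp_data <- x0[rep(1, length(karno_grid)), ]
cp_data$karno <- karno_grid

# Rows = karno values, columns = time points
cp <- t(summary(survfit(fit, newdata = cp_data), times = times)$surv)
dimnames(cp) <- list(paste0("karno=", karno_grid), paste0("t=", times))
round(cp, 3)
```

Unlike the partial dependence sketch, nothing is averaged over other patients: every row describes the same individual.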