References

Background

Survival analysis is at the core of epidemiological data analysis. There are multiple well-known Bayesian data analysis textbooks, but they typically do not cover survival analysis. Here we will showcase some R examples of Bayesian survival analysis.

Load packages

library(tidyverse)
library(survival)
library(survminer)
## library(bayesSurv)
## devtools::install_github('jburos/biostan', build_vignettes = TRUE, dependencies = TRUE)
library(biostan)
## Warning: replacing previous import 'rstan::loo' by 'rstanarm::loo' when loading 'biostan'
library(rstan)
library(bayesplot)
library(tidybayes)

Descriptive analysis example

Load a simple dataset

## Load the AML maintenance-chemotherapy data shipped with the survival
## package (columns time, status, x) and convert it to a tibble.
## as_tibble() replaces the deprecated as_data_frame().
data(leukemia, package = "survival")
leukemia <- as_tibble(leukemia)
leukemia
## # A tibble: 23 x 3
##     time status x         
##  * <dbl>  <dbl> <fct>     
##  1     9      1 Maintained
##  2    13      1 Maintained
##  3    13      0 Maintained
##  4    18      1 Maintained
##  5    23      1 Maintained
##  6    28      0 Maintained
##  7    31      1 Maintained
##  8    34      1 Maintained
##  9    45      0 Maintained
## 10    48      1 Maintained
## # ... with 13 more rows
aml                  package:survival                  R Documentation
Acute Myelogenous Leukemia survival data
Description:
     Survival in patients with Acute Myelogenous Leukemia.  The
     question at the time was whether the standard course of
     chemotherapy should be extended ('maintainance') for additional
     cycles.
Usage:
     aml
     leukemia
Format:
       time:    survival or censoring time
       status:  censoring status
       x:       maintenance chemotherapy given? (factor)
Source:
     Rupert G. Miller (1997), _Survival Analysis_.  John Wiley & Sons.
     ISBN: 0-471-25218-2.

Regular Kaplan-Meier plot

## Kaplan-Meier estimate of survival by maintenance arm.
## survfit() and Surv() come from the survival package, which has to be
## attached with library(survival); loading survminer alone does not make
## survfit() available.
km_fit <- survfit(Surv(time, status) ~ x, data = leukemia)
km_fit
## http://www.sthda.com/english/wiki/survminer-0-2-4
ggsurvplot(km_fit,
           conf.int = TRUE,
           break.time.by = 20,
           risk.table = TRUE)

Stan Weibull fit

Here we will use the Weibull model code available in biostan.

## Path to the Weibull survival model shipped with biostan.
## mustWork = TRUE turns a missing file into an immediate error instead of
## silently returning "" and failing later inside rstan::stan().
stan_weibull_survival_model_file <- system.file('stan', 'weibull_survival_model.stan',
                                                package = 'biostan',
                                                mustWork = TRUE)
biostan::print_stan_file(stan_weibull_survival_model_file)
## /*  Variable naming: 
##  obs       = observed 
##  cen       = (right) censored 
##  N         = number of samples 
##  M         = number of covariates 
##  bg        = established risk (or protective) factors 
##  tau       = scale parameter 
## */ 
## // Tomi Peltola, tomi.peltola@aalto.fi 
##  
## functions { 
##   vector sqrt_vec(vector x) { 
##     vector[dims(x)[1]] res; 
##  
##     for (m in 1:dims(x)[1]){ 
##       res[m] = sqrt(x[m]); 
##     } 
##  
##     return res; 
##   } 
##  
##   vector bg_prior_lp(real r_global, vector r_local) { 
##     r_global ~ normal(0.0, 10.0); 
##     r_local ~ inv_chi_square(1.0); 
##  
##     return r_global * sqrt_vec(r_local); 
##   } 
## } 
##  
## data { 
##   int<lower=0> Nobs; 
##   int<lower=0> Ncen; 
##   int<lower=0> M_bg; 
##   vector[Nobs] yobs; 
##   vector[Ncen] ycen; 
##   matrix[Nobs, M_bg] Xobs_bg; 
##   matrix[Ncen, M_bg] Xcen_bg; 
## } 
##  
## transformed data { 
##   real<lower=0> tau_mu; 
##   real<lower=0> tau_al; 
##  
##   tau_mu = 10.0; 
##   tau_al = 10.0; 
## } 
##  
## parameters { 
##   real<lower=0> tau_s_bg_raw; 
##   vector<lower=0>[M_bg] tau_bg_raw; 
##  
##   real alpha_raw; 
##   vector[M_bg] beta_bg_raw; 
##  
##   real mu; 
## } 
##  
## transformed parameters { 
##   vector[M_bg] beta_bg; 
##   real alpha; 
##  
##   beta_bg = bg_prior_lp(tau_s_bg_raw, tau_bg_raw) .* beta_bg_raw; 
##   alpha = exp(tau_al * alpha_raw); 
## } 
##  
## model { 
##   yobs ~ weibull(alpha, exp(-(mu + Xobs_bg * beta_bg)/alpha)); 
##   target += weibull_lccdf(ycen | alpha, exp(-(mu + Xcen_bg * beta_bg)/alpha)); 
##  
##   beta_bg_raw ~ normal(0.0, 1.0); 
##   alpha_raw ~ normal(0.0, 1.0); 
##  
##   mu ~ normal(0.0, tau_mu); 
## } 
##  
## generated quantities { 
##     real yhat_uncens[Nobs + Ncen]; 
##     real log_lik[Nobs + Ncen]; 
##     real lp[Nobs + Ncen]; 
##  
##     for (i in 1:Nobs) { 
##         lp[i] = mu + Xobs_bg[i,] * beta_bg; 
##         yhat_uncens[i] = weibull_rng(alpha, exp(-(mu + Xobs_bg[i,] * beta_bg)/alpha)); 
##         log_lik[i] = weibull_lpdf(yobs[i] | alpha, exp(-(mu + Xobs_bg[i,] * beta_bg)/alpha)); 
##     } 
##     for (i in 1:Ncen) { 
##         lp[Nobs + i] = mu + Xcen_bg[i,] * beta_bg; 
##         yhat_uncens[Nobs + i] = weibull_rng(alpha, exp(-(mu + Xcen_bg[i,] * beta_bg)/alpha)); 
##         log_lik[Nobs + i] = weibull_lccdf(ycen[i] | alpha, exp(-(mu + Xcen_bg[i,] * beta_bg)/alpha)); 
##     } 
## } 
## 

Stan parameterizes this probability density function for the survival time \(y\) as follows.

\[f(y|\alpha,\sigma) = \frac{\alpha}{\sigma}\left(\frac{y}{\sigma}\right)^{\alpha-1}e^{-(y/\sigma)^{\alpha}}\]

where \(\alpha\) is the shape parameter and \(\sigma\) is the scale parameter. The average survival time increases with an increasing \(\sigma\). To incorporate covariates, the scale parameter is defined as follows in the Stan program used here.

\[\sigma_{i} = \exp{\left( - \frac{\mu + X_{i}^{T}\beta}{\alpha} \right)}\]

Because \(\sigma_i\) is a decreasing function of the linear predictor \(X_{i}^{T}\beta\), a positive \(\beta\) means a shorter average survival time with a unit increase in \(X_i\), and a negative \(\beta\) means a longer average survival time with a unit increase in \(X_i\).

As the data block shows, individuals with observed events and censored individuals are passed to Stan separately. No prior hyperparameters are passed as data; they are hard-coded in the Stan program.

## Read the Stan program from disk and print only its data block.
stan_weibull_survival_model_code <-
    stan_weibull_survival_model_file %>%
    biostan::read_stan_file()
stan_weibull_survival_model_code %>%
    biostan::print_stan_code(section = "data")
## data { 
##   int<lower=0> Nobs; 
##   int<lower=0> Ncen; 
##   int<lower=0> M_bg; 
##   vector[Nobs] yobs; 
##   vector[Ncen] ycen; 
##   matrix[Nobs, M_bg] Xobs_bg; 
##   matrix[Ncen, M_bg] Xcen_bg; 
## }

Here we structure the data accordingly.

## Assemble the data list expected by the Stan data block. Event and
## censored individuals are passed separately; the helper variables below
## exist only inside with() and do not leak into the global environment.
stan_weibull_survival_model_data <- with(leukemia, {
    ## Event (observed) vs right-censored individuals
    is_event <- status == 1
    is_censored <- status == 0
    ## Treatment indicator: 1 = Maintained, 0 = Nonmaintained
    maintained <- as.numeric(x == "Maintained")
    list(
        ## Number of event individuals
        Nobs = sum(is_event),
        ## Number of censored individuals
        Ncen = sum(is_censored),
        ## Number of covariates
        M_bg = 1,
        ## Times for event individuals
        yobs = time[is_event],
        ## Times for censored individuals
        ycen = time[is_censored],
        ## Covariates for event individuals as a one-column matrix
        Xobs_bg = matrix(maintained[is_event]),
        ## Covariates for censored individuals as a one-column matrix
        Xcen_bg = matrix(maintained[is_censored])
    )
})
stan_weibull_survival_model_data
## $Nobs
## [1] 18
## 
## $Ncen
## [1] 5
## 
## $M_bg
## [1] 1
## 
## $yobs
##  [1]  9 13 18 23 31 34 48  5  5  8  8 12 23 27 30 33 43 45
## 
## $ycen
## [1]  13  28  45 161  16
## 
## $Xobs_bg
##       [,1]
##  [1,]    1
##  [2,]    1
##  [3,]    1
##  [4,]    1
##  [5,]    1
##  [6,]    1
##  [7,]    1
##  [8,]    0
##  [9,]    0
## [10,]    0
## [11,]    0
## [12,]    0
## [13,]    0
## [14,]    0
## [15,]    0
## [16,]    0
## [17,]    0
## [18,]    0
## 
## $Xcen_bg
##      [,1]
## [1,]    1
## [2,]    1
## [3,]    1
## [4,]    1
## [5,]    0

Here we fit the model.

## Fit the Weibull survival model.
## With the default adapt_delta (0.8) this model produced many divergent
## transitions, presumably driven by the heavy-tailed inv_chi_square prior
## on the local scale. A higher target acceptance rate makes the sampler
## take smaller steps, which is the standard first remedy for divergences.
## A fixed seed makes the draws (and everything computed from them)
## reproducible.
stan_weibull_survival_model_fit <-
    rstan::stan(file = stan_weibull_survival_model_file,
                data = stan_weibull_survival_model_data,
                seed = 1234,
                control = list(adapt_delta = 0.99))
## Warning: There were 198 divergent transitions after warmup. Increasing adapt_delta above 0.8 may help. See
## http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
## Warning: Examine the pairs() plot to diagnose sampling problems

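Here we check the results. The Rhat values are all around 1, suggesting that the chains agree with each other. However, the divergent transitions reported above indicate that the sampler may not have explored parts of the posterior reliably, so these estimates should be interpreted with some caution. Each element of yhat_uncens (a vector of 23 elements) contains MCMC samples of the event time for one individual, drawn from the Weibull distribution with the sampled \((\alpha, \sigma_{i})\).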

## Posterior summary (mean, quantiles, n_eff, Rhat) for every parameter
## and generated quantity.
print(stan_weibull_survival_model_fit)
## Inference for Stan model: weibull_survival_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1; 
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##                   mean se_mean     sd   2.5%    25%    50%    75%  97.5% n_eff Rhat
## tau_s_bg_raw      4.05    0.16   3.86   0.18   1.26   2.70   5.81  14.22   576    1
## tau_bg_raw[1]    10.41    4.17 221.57   0.15   0.43   0.92   2.42  29.65  2826    1
## alpha_raw         0.02    0.00   0.02  -0.02   0.01   0.02   0.03   0.05   930    1
## beta_bg_raw[1]   -0.52    0.02   0.49  -1.81  -0.76  -0.36  -0.17  -0.01   730    1
## mu               -4.00    0.03   0.80  -5.68  -4.53  -3.96  -3.42  -2.57   978    1
## beta_bg[1]       -1.07    0.01   0.55  -2.23  -1.43  -1.06  -0.70  -0.02  1786    1
## alpha             1.23    0.01   0.22   0.85   1.07   1.21   1.37   1.69   920    1
## yhat_uncens[1]   63.73    1.05  63.24   2.64  22.38  46.93  85.89 214.35  3651    1
## yhat_uncens[2]   62.50    1.05  65.60   2.64  21.21  44.82  81.67 224.04  3939    1
## yhat_uncens[3]   62.30    1.19  64.83   2.01  20.80  45.44  83.24 219.83  2987    1
## yhat_uncens[4]   63.81    1.05  66.20   2.71  21.88  47.17  84.74 221.33  4003    1
## yhat_uncens[5]   62.70    1.03  65.79   2.01  21.25  44.56  83.08 223.17  4076    1
## yhat_uncens[6]   63.49    1.08  68.60   2.49  21.20  46.32  83.08 227.42  4012    1
## yhat_uncens[7]   62.56    1.11  66.09   2.56  21.43  45.74  81.67 227.46  3561    1
## yhat_uncens[8]   25.81    0.42  25.37   0.96   9.32  19.16  35.18  87.81  3594    1
## yhat_uncens[9]   25.58    0.40  24.49   1.00   8.89  19.43  33.99  89.68  3736    1
## yhat_uncens[10]  25.31    0.40  24.36   1.11   8.93  18.75  33.87  90.30  3652    1
## yhat_uncens[11]  25.16    0.38  24.10   1.13   8.91  18.47  34.05  85.33  4007    1
## yhat_uncens[12]  25.58    0.39  24.43   1.04   8.85  19.00  34.65  88.01  3916    1
## yhat_uncens[13]  25.82    0.45  26.16   0.96   8.96  19.72  34.44  87.60  3448    1
## yhat_uncens[14]  26.12    0.40  25.22   0.94   8.87  19.32  34.94  92.70  3906    1
## yhat_uncens[15]  25.81    0.41  25.33   1.12   9.23  19.08  34.55  89.76  3839    1
## yhat_uncens[16]  25.29    0.39  23.94   0.98   8.98  18.94  34.37  84.33  3740    1
## yhat_uncens[17]  25.90    0.42  24.78   1.23   9.06  19.42  35.28  88.64  3540    1
## yhat_uncens[18]  25.48    0.40  24.33   1.02   8.87  19.00  34.59  85.89  3773    1
## yhat_uncens[19]  63.92    1.13  69.44   2.33  21.72  44.73  84.32 232.87  3751    1
## yhat_uncens[20]  61.76    1.05  62.75   2.28  20.75  44.03  81.93 227.45  3589    1
## yhat_uncens[21]  65.00    1.06  66.04   2.63  22.26  46.27  85.18 241.88  3855    1
## yhat_uncens[22]  63.52    1.07  67.36   2.58  21.44  44.77  83.45 227.18  3938    1
## yhat_uncens[23]  26.29    0.43  24.85   1.15   9.47  19.71  35.30  90.71  3355    1
## log_lik[1]       -4.49    0.01   0.42  -5.51  -4.74  -4.45  -4.19  -3.81  1169    1
## log_lik[2]       -4.46    0.01   0.36  -5.35  -4.67  -4.42  -4.21  -3.88  1417    1
## log_lik[3]       -4.47    0.01   0.31  -5.23  -4.64  -4.43  -4.24  -3.98  1860    1
## log_lik[4]       -4.49    0.01   0.28  -5.18  -4.64  -4.45  -4.30  -4.06  2128    1
## log_lik[5]       -4.56    0.01   0.24  -5.13  -4.70  -4.53  -4.39  -4.18  2079    1
## log_lik[6]       -4.60    0.01   0.23  -5.14  -4.73  -4.56  -4.43  -4.22  1978    1
## log_lik[7]       -4.78    0.01   0.21  -5.24  -4.91  -4.77  -4.64  -4.42  1296    1
## log_lik[8]       -3.59    0.01   0.31  -4.30  -3.78  -3.56  -3.36  -3.07  1488    1
## log_lik[9]       -3.59    0.01   0.31  -4.30  -3.78  -3.56  -3.36  -3.07  1488    1
## log_lik[10]      -3.59    0.01   0.25  -4.15  -3.74  -3.57  -3.42  -3.18  1886    1
## log_lik[11]      -3.59    0.01   0.25  -4.15  -3.74  -3.57  -3.42  -3.18  1886    1
## log_lik[12]      -3.66    0.01   0.21  -4.13  -3.79  -3.64  -3.50  -3.31  1711    1
## log_lik[13]      -4.00    0.01   0.19  -4.40  -4.13  -3.99  -3.87  -3.66   795    1
## log_lik[14]      -4.16    0.01   0.20  -4.57  -4.30  -4.15  -4.02  -3.81   770    1
## log_lik[15]      -4.29    0.01   0.21  -4.72  -4.42  -4.28  -4.15  -3.92   829    1
## log_lik[16]      -4.42    0.01   0.22  -4.92  -4.56  -4.41  -4.27  -4.03   971    1
## log_lik[17]      -4.91    0.01   0.33  -5.73  -5.08  -4.87  -4.69  -4.41  2188    1
## log_lik[18]      -5.02    0.01   0.36  -5.90  -5.19  -4.96  -4.77  -4.48  2427    1
## log_lik[19]      -0.17    0.00   0.09  -0.38  -0.21  -0.15  -0.11  -0.04  1544    1
## log_lik[20]      -0.41    0.00   0.17  -0.81  -0.51  -0.39  -0.29  -0.14  2015    1
## log_lik[21]      -0.73    0.01   0.28  -1.38  -0.89  -0.69  -0.53  -0.28  2541    1
## log_lik[22]      -3.52    0.03   1.48  -6.96  -4.32  -3.32  -2.44  -1.29  2350    1
## log_lik[23]      -0.58    0.00   0.20  -1.04  -0.70  -0.56  -0.44  -0.27  1794    1
## lp[1]            -5.07    0.04   1.02  -7.36  -5.71  -5.00  -4.33  -3.29   817    1
## lp[2]            -5.07    0.04   1.02  -7.36  -5.71  -5.00  -4.33  -3.29   817    1
## lp[3]            -5.07    0.04   1.02  -7.36  -5.71  -5.00  -4.33  -3.29   817    1
## lp[4]            -5.07    0.04   1.02  -7.36  -5.71  -5.00  -4.33  -3.29   817    1
## lp[5]            -5.07    0.04   1.02  -7.36  -5.71  -5.00  -4.33  -3.29   817    1
## lp[6]            -5.07    0.04   1.02  -7.36  -5.71  -5.00  -4.33  -3.29   817    1
## lp[7]            -5.07    0.04   1.02  -7.36  -5.71  -5.00  -4.33  -3.29   817    1
## lp[8]            -4.00    0.03   0.80  -5.68  -4.53  -3.96  -3.42  -2.57   978    1
## lp[9]            -4.00    0.03   0.80  -5.68  -4.53  -3.96  -3.42  -2.57   978    1
## lp[10]           -4.00    0.03   0.80  -5.68  -4.53  -3.96  -3.42  -2.57   978    1
## lp[11]           -4.00    0.03   0.80  -5.68  -4.53  -3.96  -3.42  -2.57   978    1
## lp[12]           -4.00    0.03   0.80  -5.68  -4.53  -3.96  -3.42  -2.57   978    1
## lp[13]           -4.00    0.03   0.80  -5.68  -4.53  -3.96  -3.42  -2.57   978    1
## lp[14]           -4.00    0.03   0.80  -5.68  -4.53  -3.96  -3.42  -2.57   978    1
## lp[15]           -4.00    0.03   0.80  -5.68  -4.53  -3.96  -3.42  -2.57   978    1
## lp[16]           -4.00    0.03   0.80  -5.68  -4.53  -3.96  -3.42  -2.57   978    1
## lp[17]           -4.00    0.03   0.80  -5.68  -4.53  -3.96  -3.42  -2.57   978    1
## lp[18]           -4.00    0.03   0.80  -5.68  -4.53  -3.96  -3.42  -2.57   978    1
## lp[19]           -5.07    0.04   1.02  -7.36  -5.71  -5.00  -4.33  -3.29   817    1
## lp[20]           -5.07    0.04   1.02  -7.36  -5.71  -5.00  -4.33  -3.29   817    1
## lp[21]           -5.07    0.04   1.02  -7.36  -5.71  -5.00  -4.33  -3.29   817    1
## lp[22]           -5.07    0.04   1.02  -7.36  -5.71  -5.00  -4.33  -3.29   817    1
## lp[23]           -4.00    0.03   0.80  -5.68  -4.53  -3.96  -3.42  -2.57   978    1
## lp__            -82.61    0.08   2.00 -87.38 -83.72 -82.20 -81.14 -79.89   668    1
## 
## Samples were drawn using NUTS(diag_e) at Sun Oct 28 07:26:15 2018.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).

The traceplots for the parameters of interest appear to indicate reasonable mixing.

## Trace plots of the parameters of interest. The argument is spelled out
## as `pars` instead of relying on partial matching of `par`.
rstan::traceplot(stan_weibull_survival_model_fit, pars = c("alpha", "mu", "beta_bg"))

Some auto-correlation is seen for the parameters of interest.

## Autocorrelation of the parameters of interest, shown separately for each
## chain. as.array() keeps chains apart (iterations x chains x parameters);
## as.matrix() would concatenate the chains and compute the ACF across chain
## boundaries.
bayesplot::mcmc_acf(as.array(stan_weibull_survival_model_fit),
                    pars = c("alpha", "mu", "beta_bg[1]"))

The plot below shows the posterior densities with the 95% credible intervals shaded. The effect of interest, beta_bg[1], has most of its posterior probability in the negative range (a survival benefit with the Maintained treatment).

## Posterior densities with the central 95% interval shaded.
stan_weibull_survival_model_fit %>%
    as.matrix() %>%
    bayesplot::mcmc_areas(pars = c("alpha", "mu", "beta_bg[1]"), prob = 0.95)

The parameter values do not give an intuitive understanding of the survival time distributions in each group. As the Stan program sampled an event time for each individual, we can examine these draws directly.

## Tidy the posterior: one row per post-warmup draw, one column per
## parameter / generated quantity, plus sampler diagnostic columns.
stan_weibull_survival_model_draws <-
    stan_weibull_survival_model_fit %>%
    tidybayes::tidy_draws()
stan_weibull_survival_model_draws
## # A tibble: 4,000 x 86
##    .chain .iteration .draw tau_s_bg_raw `tau_bg_raw[1]` alpha_raw `beta_bg_raw[1]`    mu `beta_bg[1]` alpha
##     <int>      <int> <int>        <dbl>           <dbl>     <dbl>            <dbl> <dbl>        <dbl> <dbl>
##  1      1          1     1        7.02            2.60    0.0192           -0.0192 -4.10      -0.217  1.21 
##  2      1          2     2        7.43            2.48    0.0176           -0.0292 -4.09      -0.342  1.19 
##  3      1          3     3        2.92            7.22    0.0244           -0.162  -4.19      -1.27   1.28 
##  4      1          4     4        1.78            2.14    0.00269          -0.270  -3.58      -0.706  1.03 
##  5      1          5     5        6.27            0.926   0.00673          -0.181  -3.49      -1.09   1.07 
##  6      1          6     6        3.67            0.559  -0.00409          -0.100  -3.42      -0.275  0.960
##  7      1          7     7        3.06            0.436  -0.00534          -0.231  -3.56      -0.467  0.948
##  8      1          8     8        2.48            0.418  -0.00642          -0.271  -3.55      -0.434  0.938
##  9      1          9     9        0.464           4.59    0.00696           0.0338 -3.74       0.0336 1.07 
## 10      1         10    10        0.198           1.35    0.0115           -0.254  -3.73      -0.0585 1.12 
## # ... with 3,990 more rows, and 76 more variables: `yhat_uncens[1]` <dbl>, `yhat_uncens[2]` <dbl>,
## #   `yhat_uncens[3]` <dbl>, `yhat_uncens[4]` <dbl>, `yhat_uncens[5]` <dbl>, `yhat_uncens[6]` <dbl>,
## #   `yhat_uncens[7]` <dbl>, `yhat_uncens[8]` <dbl>, `yhat_uncens[9]` <dbl>, `yhat_uncens[10]` <dbl>,
## #   `yhat_uncens[11]` <dbl>, `yhat_uncens[12]` <dbl>, `yhat_uncens[13]` <dbl>, `yhat_uncens[14]` <dbl>,
## #   `yhat_uncens[15]` <dbl>, `yhat_uncens[16]` <dbl>, `yhat_uncens[17]` <dbl>, `yhat_uncens[18]` <dbl>,
## #   `yhat_uncens[19]` <dbl>, `yhat_uncens[20]` <dbl>, `yhat_uncens[21]` <dbl>, `yhat_uncens[22]` <dbl>,
## #   `yhat_uncens[23]` <dbl>, `log_lik[1]` <dbl>, `log_lik[2]` <dbl>, `log_lik[3]` <dbl>, `log_lik[4]` <dbl>,
## #   `log_lik[5]` <dbl>, `log_lik[6]` <dbl>, `log_lik[7]` <dbl>, `log_lik[8]` <dbl>, `log_lik[9]` <dbl>,
## #   `log_lik[10]` <dbl>, `log_lik[11]` <dbl>, `log_lik[12]` <dbl>, `log_lik[13]` <dbl>, `log_lik[14]` <dbl>,
## #   `log_lik[15]` <dbl>, `log_lik[16]` <dbl>, `log_lik[17]` <dbl>, `log_lik[18]` <dbl>, `log_lik[19]` <dbl>,
## #   `log_lik[20]` <dbl>, `log_lik[21]` <dbl>, `log_lik[22]` <dbl>, `log_lik[23]` <dbl>, `lp[1]` <dbl>,
## #   `lp[2]` <dbl>, `lp[3]` <dbl>, `lp[4]` <dbl>, `lp[5]` <dbl>, `lp[6]` <dbl>, `lp[7]` <dbl>, `lp[8]` <dbl>,
## #   `lp[9]` <dbl>, `lp[10]` <dbl>, `lp[11]` <dbl>, `lp[12]` <dbl>, `lp[13]` <dbl>, `lp[14]` <dbl>, `lp[15]` <dbl>,
## #   `lp[16]` <dbl>, `lp[17]` <dbl>, `lp[18]` <dbl>, `lp[19]` <dbl>, `lp[20]` <dbl>, `lp[21]` <dbl>, `lp[22]` <dbl>,
## #   `lp[23]` <dbl>, lp__ <dbl>, accept_stat__ <dbl>, stepsize__ <dbl>, treedepth__ <dbl>, n_leapfrog__ <dbl>,
## #   divergent__ <dbl>, energy__ <dbl>

The ordering of yhat_uncens does not follow the original data ordering; it follows the observed-then-censored ordering of the data passed to Stan. We therefore need to construct a treatment vector in the same order.

## Treatment indicator (1 = Maintained, 0 = Nonmaintained) in the same
## observed-then-censored order used to build the Stan data, so that
## obs i lines up with yhat_uncens[i].
treatment_assignment <- c(as.numeric(leukemia$x == "Maintained")[leukemia$status == 1],
                          as.numeric(leukemia$x == "Maintained")[leukemia$status == 0])
## tibble() replaces the deprecated data_frame(); seq_along() avoids
## hard-coding the number of individuals.
treatment_assignment_df <-
    tibble(obs = seq_along(treatment_assignment),
           treatment = treatment_assignment)
treatment_assignment_df
## # A tibble: 23 x 2
##      obs treatment
##    <int>     <dbl>
##  1     1         1
##  2     2         1
##  3     3         1
##  4     4         1
##  5     5         1
##  6     6         1
##  7     7         1
##  8     8         0
##  9     9         0
## 10    10         0
## # ... with 13 more rows

The draws have to be reorganized into the long format and combined with the treatment assignment.

## Reshape the posterior predictive draws to long format (one row per
## draw x individual) and attach each individual's treatment assignment.
stan_weibull_survival_model_draws_yhat_uncens <-
    stan_weibull_survival_model_draws %>%
    select(.chain, .iteration, .draw, starts_with("yhat_uncens")) %>%
    gather(key = key, value = yhat_uncens, starts_with("yhat_uncens")) %>%
    ## "yhat_uncens[12]" is split into "yhat_" and "[12]"
    separate(col = key, sep = "uncens", into = c("key","obs")) %>%
    select(-key) %>%
    ## Avoid using regular expressions with square brackets (syntax highlighter broke).
    ## https://stringr.tidyverse.org/articles/stringr.html
    ## Drop the surrounding brackets: "[12]" -> 12L
    mutate(obs = as.integer(str_sub(obs, 2, -2))) %>%
    ## Join explicitly on obs rather than on whatever columns happen to share names
    left_join(y = treatment_assignment_df, by = "obs")
stan_weibull_survival_model_draws_yhat_uncens
## # A tibble: 92,000 x 6
##    .chain .iteration .draw   obs yhat_uncens treatment
##     <int>      <int> <int> <int>       <dbl>     <dbl>
##  1      1          1     1     1       45.1          1
##  2      1          2     2     1       90.2          1
##  3      1          3     3     1       63.2          1
##  4      1          4     4     1       29.2          1
##  5      1          5     5     1       13.3          1
##  6      1          6     6     1       18.9          1
##  7      1          7     7     1        2.16         1
##  8      1          8     8     1      345.           1
##  9      1          9     9     1       19.0          1
## 10      1         10    10     1      118.           1
## # ... with 91,990 more rows

Now we can plot the posterior predictive distributions of survival times for each group. This plot compares two estimated conditional density functions, \(\hat{f}_{Y|X}(t|1)\) and \(\hat{f}_{Y|X}(t|0)\).

## Posterior predictive densities of (uncensored) survival times by arm.
## Treatment is labelled Maintained / Nonmaintained (in that level order) so
## the legend and colours match the survival-function plots further down.
ggplot(data = stan_weibull_survival_model_draws_yhat_uncens,
       mapping = aes(x = yhat_uncens,
                     color = factor(treatment,
                                    levels = c(1, 0),
                                    labels = c("Maintained", "Nonmaintained")))) +
    geom_density(n = 512*10) +
    coord_cartesian(xlim = c(0,160)) +
    labs(color = "treatment") +
    theme_bw() +
    theme(axis.text.x = element_text(angle = 90, vjust = 0.5),
          legend.key = element_blank(),
          plot.title = element_text(hjust = 0.5),
          strip.background = element_blank())

We can see the event times are shifted to the right (longer survival times) for the treated group (Maintained group).

To compare more familiar conditional survival functions we can use the Weibull survival function.

\[S(t | x) = e^{- \left( \frac{t}{\sigma_{i}} \right)^{\alpha}}\]

where \(\sigma_{i}\) is a function of \(x\).

\[\sigma_{i} = \exp{\left( - \frac{\mu + x_{i}^{T}\beta}{\alpha} \right)}\]

Therefore, each MCMC sample of the \((\alpha, \mu, \beta)\) triplet yields two survival functions, one for each treatment group. To plot these random functions, we evaluate them on a grid of time points in [0, 160].

## Constructor for treatment-specific survival function
##
## Returns S(t | x) = exp(-(t / sigma)^alpha) for a Weibull model with shape
## `alpha` and covariate-dependent scale sigma = exp(-(mu + beta * x) / alpha),
## the same scale parameterization used in the Stan model.
##
## alpha: Weibull shape parameter.
## mu:    intercept of the linear predictor.
## beta:  coefficient of the covariate.
## x:     covariate value (in this analysis 1 = Maintained, 0 = Nonmaintained).
## Returns a function of t (vectorized over t) giving survival probabilities.
construct_survival_function <- function(alpha, mu, beta, x) {
    ## Compute the scale once at construction time. This forces all lazily
    ## evaluated arguments, so the returned closure is not affected if the
    ## caller's variables change before it is called, and sigma is not
    ## recomputed on every evaluation.
    sigma_i <- exp(-1 * (mu + beta * x) / alpha)
    function(t) {
        exp(- (t / sigma_i)^alpha)
    }
}

## Random functions: for every posterior draw, one survival function for
## x = 1 (Maintained) and one for x = 0 (Nonmaintained), stored in the
## list-columns `S(t|1)` and `S(t|0)`.
stan_weibull_survival_model_survival_functins <-
    stan_weibull_survival_model_draws %>%
    select(.chain, .iteration, .draw, alpha, mu, beta = `beta_bg[1]`) %>%
    mutate(`S(t|1)` = pmap(list(alpha, mu, beta), construct_survival_function, x = 1),
           `S(t|0)` = pmap(list(alpha, mu, beta), construct_survival_function, x = 0))
stan_weibull_survival_model_survival_functins
## # A tibble: 4,000 x 8
##    .chain .iteration .draw alpha    mu    beta `S(t|1)` `S(t|0)`
##     <int>      <int> <int> <dbl> <dbl>   <dbl> <list>   <list>  
##  1      1          1     1 1.21  -4.10 -0.217  <fn>     <fn>    
##  2      1          2     2 1.19  -4.09 -0.342  <fn>     <fn>    
##  3      1          3     3 1.28  -4.19 -1.27   <fn>     <fn>    
##  4      1          4     4 1.03  -3.58 -0.706  <fn>     <fn>    
##  5      1          5     5 1.07  -3.49 -1.09   <fn>     <fn>    
##  6      1          6     6 0.960 -3.42 -0.275  <fn>     <fn>    
##  7      1          7     7 0.948 -3.56 -0.467  <fn>     <fn>    
##  8      1          8     8 0.938 -3.55 -0.434  <fn>     <fn>    
##  9      1          9     9 1.07  -3.74  0.0336 <fn>     <fn>    
## 10      1         10    10 1.12  -3.73 -0.0585 <fn>     <fn>    
## # ... with 3,990 more rows
## Common time grid on which every random survival function is evaluated.
## tibble() replaces the deprecated data_frame().
times <- seq(from = 0, to = 160, by = 0.1)
times_df <- tibble(t = times)

## Try first realizations
stan_weibull_survival_model_survival_functins$`S(t|1)`[[1]](times[1:10])
##  [1] 1.0000000 0.9991831 0.9981083 0.9969093 0.9956225 0.9942667 0.9928537 0.9913917 0.9898868 0.9883437
stan_weibull_survival_model_survival_functins$`S(t|0)`[[1]](times[1:10])
##  [1] 1.0000000 0.9989850 0.9976498 0.9961607 0.9945630 0.9928802 0.9911270 0.9893137 0.9874478 0.9855353
## Apply all realizations: evaluate both survival functions of every
## posterior draw on the common time grid, then reshape to long format
## with one row per draw x time point x treatment.
stan_weibull_survival_model_survival <-
    stan_weibull_survival_model_survival_functins %>%
    mutate(times_df = list(times_df)) %>%
    mutate(times_df = pmap(list(times_df, `S(t|1)`, `S(t|0)`),
                           function(df, s1, s0) {df %>% mutate(s1 = s1(t),
                                                               s0 = s0(t))})) %>%
    select(-`S(t|1)`, -`S(t|0)`) %>%
    ## Name the list-column to unnest explicitly; calling unnest() without
    ## columns is deprecated in current tidyr.
    unnest(times_df) %>%
    gather(key = treatment, value = survival, s1, s0) %>%
    mutate(treatment = factor(treatment,
                              levels = c("s1","s0"),
                              labels = c("Maintained","Nonmaintained")))

## Average on survival scale: pointwise posterior mean and 95% interval of
## S(t) for each treatment and time point. ungroup() so the result does not
## carry a leftover grouping into later operations.
stan_weibull_survival_model_survival_mean <-
    stan_weibull_survival_model_survival %>%
    group_by(treatment, t) %>%
    summarize(survival_mean = mean(survival),
              survival_95upper = quantile(survival, probs = 0.975),
              survival_95lower = quantile(survival, probs = 0.025)) %>%
    ungroup()

## Posterior survival curves per arm: one faint line per draw, the pointwise
## posterior mean (solid) and the pointwise 95% interval (dotted).
ggplot(stan_weibull_survival_model_survival,
       aes(x = t, y = survival,
           colour = treatment,
           group = interaction(.chain, .draw, treatment))) +
    geom_line(size = 0.1, alpha = 0.02) +
    geom_line(aes(y = survival_mean, group = treatment),
              data = stan_weibull_survival_model_survival_mean) +
    geom_line(aes(y = survival_95upper, group = treatment),
              data = stan_weibull_survival_model_survival_mean,
              linetype = "dotted") +
    geom_line(aes(y = survival_95lower, group = treatment),
              data = stan_weibull_survival_model_survival_mean,
              linetype = "dotted") +
    facet_grid(. ~ treatment) +
    theme_bw() +
    theme(axis.text.x = element_text(angle = 90, vjust = 0.5),
          legend.key = element_blank(),
          plot.title = element_text(hjust = 0.5),
          strip.background = element_blank())

The average can be taken either in the parameter space \((\alpha, \mu, \beta)\) or on the survival scale. Because the survival function is a nonlinear function of the parameters, these two averages generally differ. Here we calculate the average parameter vector and construct the corresponding survival functions.

## Average on parameter space: posterior means of alpha, mu and beta.
stan_weibull_survival_model_average_parameters <-
    summarize(stan_weibull_survival_model_draws,
              alpha = mean(alpha),
              mu    = mean(mu),
              beta  = mean(`beta_bg[1]`))
stan_weibull_survival_model_average_parameters
## # A tibble: 1 x 3
##   alpha    mu  beta
##   <dbl> <dbl> <dbl>
## 1  1.23 -4.00 -1.07
## Survival functions evaluated at the posterior-mean parameter vector,
## one for each arm (x = 1 Maintained, x = 0 Nonmaintained).
stan_weibull_average_params_survival1 <- with(stan_weibull_survival_model_average_parameters,
                                              construct_survival_function(alpha, mu, beta, 1))
stan_weibull_average_params_survival0 <- with(stan_weibull_survival_model_average_parameters,
                                              construct_survival_function(alpha, mu, beta, 0))
## Evaluate on the same `times` grid used for the per-draw curves so the two
## kinds of average line up exactly; tibble() replaces data_frame().
stan_weibull_average_params_survival <-
    tibble(t = times,
           s1 = stan_weibull_average_params_survival1(t),
           s0 = stan_weibull_average_params_survival0(t)) %>%
    gather(key = treatment, value = survival, -t) %>%
    mutate(treatment = factor(treatment,
                              levels = c("s1","s0"),
                              labels = c("Maintained","Nonmaintained")))

## Survival curves implied by the posterior-mean parameter vector.
ggplot(data = stan_weibull_average_params_survival,
       mapping = aes(x = t, y = survival,
                     color = treatment, group = treatment)) +
    geom_line() +
    theme_bw() +
    theme(axis.text.x = element_text(angle = 90, vjust = 0.5),
          legend.key = element_blank(),
          plot.title = element_text(hjust = 0.5),
          strip.background = element_blank())

We plot both to compare them. The dotted lines are the survival functions evaluated at the parameter-space average. The solid lines are averages on the survival scale.

## Compare the two ways of averaging per arm: solid = mean of S(t) over
## draws (survival scale), dotted = S(t) at the posterior-mean parameters.
ggplot(stan_weibull_survival_model_survival,
       aes(x = t, y = survival,
           colour = treatment,
           group = interaction(.chain, .draw, treatment))) +
    geom_line(size = 0.1, alpha = 0.02) +
    geom_line(aes(y = survival_mean, group = treatment),
              data = stan_weibull_survival_model_survival_mean) +
    geom_line(aes(group = treatment),
              data = stan_weibull_average_params_survival,
              linetype = "dotted") +
    facet_grid(. ~ treatment) +
    theme_bw() +
    theme(axis.text.x = element_text(angle = 90, vjust = 0.5),
          legend.key = element_blank(),
          plot.title = element_text(hjust = 0.5),
          strip.background = element_blank())