Survival analysis is at the core of epidemiological data analysis. There are multiple well-known Bayesian data analysis textbooks, but they typically do not cover survival analysis. Here we will showcase some R examples of Bayesian survival analysis.
library(tidyverse)
library(survival)
library(survminer)
## library(bayesSurv)
## devtools::install_github('jburos/biostan', build_vignettes = TRUE, dependencies = TRUE)
library(biostan)
## Warning: replacing previous import 'rstan::loo' by 'rstanarm::loo' when loading 'biostan'
library(rstan)
library(bayesplot)
library(tidybayes)
## Load the leukemia (AML) data set from the survival package and convert it
## to a tibble for compact printing.
data(leukemia, package = "survival")
## as_tibble() replaces as_data_frame(), which is deprecated in tibble.
leukemia <- as_tibble(leukemia)
leukemia
## # A tibble: 23 x 3
## time status x
## * <dbl> <dbl> <fct>
## 1 9 1 Maintained
## 2 13 1 Maintained
## 3 13 0 Maintained
## 4 18 1 Maintained
## 5 23 1 Maintained
## 6 28 0 Maintained
## 7 31 1 Maintained
## 8 34 1 Maintained
## 9 45 0 Maintained
## 10 48 1 Maintained
## # ... with 13 more rows
aml package:survival R Documentation
Acute Myelogenous Leukemia survival data
Description:
Survival in patients with Acute Myelogenous Leukemia. The
question at the time was whether the standard course of
chemotherapy should be extended ('maintainance') for additional
cycles.
Usage:
aml
leukemia
Format:
time: survival or censoring time
status: censoring status
x: maintenance chemotherapy given? (factor)
Source:
Rupert G. Miller (1997), _Survival Analysis_. John Wiley & Sons.
ISBN: 0-471-25218-2.
## Kaplan-Meier estimate of the survival function, stratified by treatment (x).
## survfit() and Surv() live in the survival package, which has to be attached
## with library(survival): loading survminer alone did not make them available
## (the original run stopped with 'could not find function "survfit"').
km_fit <- survfit(Surv(time, status) ~ x, data = leukemia)
km_fit
## http://www.sthda.com/english/wiki/survminer-0-2-4
## Kaplan-Meier curves with confidence bands, x-axis breaks every 20 time
## units, and a risk table underneath.
ggsurvplot(km_fit,
           conf.int = TRUE,
           break.time.by = 20,
           risk.table = TRUE)
Here we will use the Weibull model code available in biostan.
## Path to the Weibull survival model bundled with the biostan package
stan_weibull_survival_model_file <- system.file(
  file.path('stan', 'weibull_survival_model.stan'),
  package = 'biostan'
)
## Echo the Stan program
biostan::print_stan_file(stan_weibull_survival_model_file)
## /* Variable naming:
## obs = observed
## cen = (right) censored
## N = number of samples
## M = number of covariates
## bg = established risk (or protective) factors
## tau = scale parameter
## */
## // Tomi Peltola, tomi.peltola@aalto.fi
##
## functions {
## vector sqrt_vec(vector x) {
## vector[dims(x)[1]] res;
##
## for (m in 1:dims(x)[1]){
## res[m] = sqrt(x[m]);
## }
##
## return res;
## }
##
## vector bg_prior_lp(real r_global, vector r_local) {
## r_global ~ normal(0.0, 10.0);
## r_local ~ inv_chi_square(1.0);
##
## return r_global * sqrt_vec(r_local);
## }
## }
##
## data {
## int<lower=0> Nobs;
## int<lower=0> Ncen;
## int<lower=0> M_bg;
## vector[Nobs] yobs;
## vector[Ncen] ycen;
## matrix[Nobs, M_bg] Xobs_bg;
## matrix[Ncen, M_bg] Xcen_bg;
## }
##
## transformed data {
## real<lower=0> tau_mu;
## real<lower=0> tau_al;
##
## tau_mu = 10.0;
## tau_al = 10.0;
## }
##
## parameters {
## real<lower=0> tau_s_bg_raw;
## vector<lower=0>[M_bg] tau_bg_raw;
##
## real alpha_raw;
## vector[M_bg] beta_bg_raw;
##
## real mu;
## }
##
## transformed parameters {
## vector[M_bg] beta_bg;
## real alpha;
##
## beta_bg = bg_prior_lp(tau_s_bg_raw, tau_bg_raw) .* beta_bg_raw;
## alpha = exp(tau_al * alpha_raw);
## }
##
## model {
## yobs ~ weibull(alpha, exp(-(mu + Xobs_bg * beta_bg)/alpha));
## target += weibull_lccdf(ycen | alpha, exp(-(mu + Xcen_bg * beta_bg)/alpha));
##
## beta_bg_raw ~ normal(0.0, 1.0);
## alpha_raw ~ normal(0.0, 1.0);
##
## mu ~ normal(0.0, tau_mu);
## }
##
## generated quantities {
## real yhat_uncens[Nobs + Ncen];
## real log_lik[Nobs + Ncen];
## real lp[Nobs + Ncen];
##
## for (i in 1:Nobs) {
## lp[i] = mu + Xobs_bg[i,] * beta_bg;
## yhat_uncens[i] = weibull_rng(alpha, exp(-(mu + Xobs_bg[i,] * beta_bg)/alpha));
## log_lik[i] = weibull_lpdf(yobs[i] | alpha, exp(-(mu + Xobs_bg[i,] * beta_bg)/alpha));
## }
## for (i in 1:Ncen) {
## lp[Nobs + i] = mu + Xcen_bg[i,] * beta_bg;
## yhat_uncens[Nobs + i] = weibull_rng(alpha, exp(-(mu + Xcen_bg[i,] * beta_bg)/alpha));
## log_lik[Nobs + i] = weibull_lccdf(ycen[i] | alpha, exp(-(mu + Xcen_bg[i,] * beta_bg)/alpha));
## }
## }
##
Stan parameterizes the Weibull probability density function for the survival time \(y\) as follows.
\[f(y|\alpha,\sigma) = \frac{\alpha}{\sigma}\left(\frac{y}{\sigma}\right)^{\alpha-1}e^{-(y/\sigma)^{\alpha}}\]
where \(\alpha\) is the shape parameter and \(\sigma\) is the scale parameter. The average survival time increases with an increasing \(\sigma\). To incorporate covariates, the scale parameter is defined as follows in the Stan program used here.
\[\sigma_{i} = \exp{\left( - \frac{\mu + X_{i}^{T}\beta}{\alpha} \right)}\]
As \(\sigma\) is a decreasing function of \(\beta\), a positive \(\beta\) means a shorter average survival time with a unit increase in \(X_i\) and a negative \(\beta\) means a longer average survival time with a unit increase in \(X_i\).
As the data block shows, individuals with observed events and censored individuals are passed to the model separately. No hyperparameters for the priors are supplied as data; they are hard-coded in the Stan program.
## Read the Stan program and show only its data block
stan_weibull_survival_model_code <-
  biostan::read_stan_file(stan_weibull_survival_model_file)
biostan::print_stan_code(
  stan_weibull_survival_model_code,
  section = "data"
)
## data {
## int<lower=0> Nobs;
## int<lower=0> Ncen;
## int<lower=0> M_bg;
## vector[Nobs] yobs;
## vector[Ncen] ycen;
## matrix[Nobs, M_bg] Xobs_bg;
## matrix[Ncen, M_bg] Xcen_bg;
## }
Here we structure the data accordingly.
## Assemble the data list expected by the Stan program's data block.
## Event and censored individuals are split into separate vectors/matrices.
stan_weibull_survival_model_data <- with(leukemia, {
  ## Row masks for individuals with an observed event and for censored ones
  had_event <- status == 1
  was_censored <- status == 0
  ## Treatment indicator: 1 = Maintained, 0 = otherwise
  maintained <- as.numeric(x == "Maintained")
  list(
    ## Number of event individuals
    Nobs = sum(had_event),
    ## Number of censored individuals
    Ncen = sum(was_censored),
    ## Number of covariates
    M_bg = 1,
    ## Times for event individuals
    yobs = time[had_event],
    ## Times for censored individuals
    ycen = time[was_censored],
    ## Covariates for event individuals as a one-column matrix
    Xobs_bg = matrix(maintained[had_event]),
    ## Covariates for censored individuals as a one-column matrix
    Xcen_bg = matrix(maintained[was_censored])
  )
})
stan_weibull_survival_model_data
## $Nobs
## [1] 18
##
## $Ncen
## [1] 5
##
## $M_bg
## [1] 1
##
## $yobs
## [1] 9 13 18 23 31 34 48 5 5 8 8 12 23 27 30 33 43 45
##
## $ycen
## [1] 13 28 45 161 16
##
## $Xobs_bg
## [,1]
## [1,] 1
## [2,] 1
## [3,] 1
## [4,] 1
## [5,] 1
## [6,] 1
## [7,] 1
## [8,] 0
## [9,] 0
## [10,] 0
## [11,] 0
## [12,] 0
## [13,] 0
## [14,] 0
## [15,] 0
## [16,] 0
## [17,] 0
## [18,] 0
##
## $Xcen_bg
## [,1]
## [1,] 1
## [2,] 1
## [3,] 1
## [4,] 1
## [5,] 0
Here we fit the model.
## Fit the Weibull survival model by MCMC.
stan_weibull_survival_model_fit <-
  rstan::stan(file = stan_weibull_survival_model_file,
              data = stan_weibull_survival_model_data,
              ## Fixed seed so that the draws are reproducible across runs
              seed = 20181028,
              ## The run with the default adapt_delta (0.8) reported 198
              ## divergent transitions; a higher target acceptance rate forces
              ## smaller step sizes. NOTE(review): 0.99 is a starting point,
              ## re-check the sampler warnings after refitting.
              control = list(adapt_delta = 0.99))
## Warning: There were 198 divergent transitions after warmup. Increasing adapt_delta above 0.8 may help. See
## http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
## Warning: Examine the pairs() plot to diagnose sampling problems
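Here we check the results. The Rhat
values are all around 1, which is consistent with convergence of the chains. Note, however, that the sampler reported 198 divergent transitions above, so the posterior summaries should be interpreted with caution until that warning is resolved. Each element of yhat_uncens
(a vector of 23 elements) holds the MCMC samples of the event time for one individual, generated from the \((\alpha, \sigma_{i})\) MCMC samples.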
## Print the posterior summary (mean, sd, quantiles, n_eff, Rhat) of the fit
stan_weibull_survival_model_fit
## Inference for Stan model: weibull_survival_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
## mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
## tau_s_bg_raw 4.05 0.16 3.86 0.18 1.26 2.70 5.81 14.22 576 1
## tau_bg_raw[1] 10.41 4.17 221.57 0.15 0.43 0.92 2.42 29.65 2826 1
## alpha_raw 0.02 0.00 0.02 -0.02 0.01 0.02 0.03 0.05 930 1
## beta_bg_raw[1] -0.52 0.02 0.49 -1.81 -0.76 -0.36 -0.17 -0.01 730 1
## mu -4.00 0.03 0.80 -5.68 -4.53 -3.96 -3.42 -2.57 978 1
## beta_bg[1] -1.07 0.01 0.55 -2.23 -1.43 -1.06 -0.70 -0.02 1786 1
## alpha 1.23 0.01 0.22 0.85 1.07 1.21 1.37 1.69 920 1
## yhat_uncens[1] 63.73 1.05 63.24 2.64 22.38 46.93 85.89 214.35 3651 1
## yhat_uncens[2] 62.50 1.05 65.60 2.64 21.21 44.82 81.67 224.04 3939 1
## yhat_uncens[3] 62.30 1.19 64.83 2.01 20.80 45.44 83.24 219.83 2987 1
## yhat_uncens[4] 63.81 1.05 66.20 2.71 21.88 47.17 84.74 221.33 4003 1
## yhat_uncens[5] 62.70 1.03 65.79 2.01 21.25 44.56 83.08 223.17 4076 1
## yhat_uncens[6] 63.49 1.08 68.60 2.49 21.20 46.32 83.08 227.42 4012 1
## yhat_uncens[7] 62.56 1.11 66.09 2.56 21.43 45.74 81.67 227.46 3561 1
## yhat_uncens[8] 25.81 0.42 25.37 0.96 9.32 19.16 35.18 87.81 3594 1
## yhat_uncens[9] 25.58 0.40 24.49 1.00 8.89 19.43 33.99 89.68 3736 1
## yhat_uncens[10] 25.31 0.40 24.36 1.11 8.93 18.75 33.87 90.30 3652 1
## yhat_uncens[11] 25.16 0.38 24.10 1.13 8.91 18.47 34.05 85.33 4007 1
## yhat_uncens[12] 25.58 0.39 24.43 1.04 8.85 19.00 34.65 88.01 3916 1
## yhat_uncens[13] 25.82 0.45 26.16 0.96 8.96 19.72 34.44 87.60 3448 1
## yhat_uncens[14] 26.12 0.40 25.22 0.94 8.87 19.32 34.94 92.70 3906 1
## yhat_uncens[15] 25.81 0.41 25.33 1.12 9.23 19.08 34.55 89.76 3839 1
## yhat_uncens[16] 25.29 0.39 23.94 0.98 8.98 18.94 34.37 84.33 3740 1
## yhat_uncens[17] 25.90 0.42 24.78 1.23 9.06 19.42 35.28 88.64 3540 1
## yhat_uncens[18] 25.48 0.40 24.33 1.02 8.87 19.00 34.59 85.89 3773 1
## yhat_uncens[19] 63.92 1.13 69.44 2.33 21.72 44.73 84.32 232.87 3751 1
## yhat_uncens[20] 61.76 1.05 62.75 2.28 20.75 44.03 81.93 227.45 3589 1
## yhat_uncens[21] 65.00 1.06 66.04 2.63 22.26 46.27 85.18 241.88 3855 1
## yhat_uncens[22] 63.52 1.07 67.36 2.58 21.44 44.77 83.45 227.18 3938 1
## yhat_uncens[23] 26.29 0.43 24.85 1.15 9.47 19.71 35.30 90.71 3355 1
## log_lik[1] -4.49 0.01 0.42 -5.51 -4.74 -4.45 -4.19 -3.81 1169 1
## log_lik[2] -4.46 0.01 0.36 -5.35 -4.67 -4.42 -4.21 -3.88 1417 1
## log_lik[3] -4.47 0.01 0.31 -5.23 -4.64 -4.43 -4.24 -3.98 1860 1
## log_lik[4] -4.49 0.01 0.28 -5.18 -4.64 -4.45 -4.30 -4.06 2128 1
## log_lik[5] -4.56 0.01 0.24 -5.13 -4.70 -4.53 -4.39 -4.18 2079 1
## log_lik[6] -4.60 0.01 0.23 -5.14 -4.73 -4.56 -4.43 -4.22 1978 1
## log_lik[7] -4.78 0.01 0.21 -5.24 -4.91 -4.77 -4.64 -4.42 1296 1
## log_lik[8] -3.59 0.01 0.31 -4.30 -3.78 -3.56 -3.36 -3.07 1488 1
## log_lik[9] -3.59 0.01 0.31 -4.30 -3.78 -3.56 -3.36 -3.07 1488 1
## log_lik[10] -3.59 0.01 0.25 -4.15 -3.74 -3.57 -3.42 -3.18 1886 1
## log_lik[11] -3.59 0.01 0.25 -4.15 -3.74 -3.57 -3.42 -3.18 1886 1
## log_lik[12] -3.66 0.01 0.21 -4.13 -3.79 -3.64 -3.50 -3.31 1711 1
## log_lik[13] -4.00 0.01 0.19 -4.40 -4.13 -3.99 -3.87 -3.66 795 1
## log_lik[14] -4.16 0.01 0.20 -4.57 -4.30 -4.15 -4.02 -3.81 770 1
## log_lik[15] -4.29 0.01 0.21 -4.72 -4.42 -4.28 -4.15 -3.92 829 1
## log_lik[16] -4.42 0.01 0.22 -4.92 -4.56 -4.41 -4.27 -4.03 971 1
## log_lik[17] -4.91 0.01 0.33 -5.73 -5.08 -4.87 -4.69 -4.41 2188 1
## log_lik[18] -5.02 0.01 0.36 -5.90 -5.19 -4.96 -4.77 -4.48 2427 1
## log_lik[19] -0.17 0.00 0.09 -0.38 -0.21 -0.15 -0.11 -0.04 1544 1
## log_lik[20] -0.41 0.00 0.17 -0.81 -0.51 -0.39 -0.29 -0.14 2015 1
## log_lik[21] -0.73 0.01 0.28 -1.38 -0.89 -0.69 -0.53 -0.28 2541 1
## log_lik[22] -3.52 0.03 1.48 -6.96 -4.32 -3.32 -2.44 -1.29 2350 1
## log_lik[23] -0.58 0.00 0.20 -1.04 -0.70 -0.56 -0.44 -0.27 1794 1
## lp[1] -5.07 0.04 1.02 -7.36 -5.71 -5.00 -4.33 -3.29 817 1
## lp[2] -5.07 0.04 1.02 -7.36 -5.71 -5.00 -4.33 -3.29 817 1
## lp[3] -5.07 0.04 1.02 -7.36 -5.71 -5.00 -4.33 -3.29 817 1
## lp[4] -5.07 0.04 1.02 -7.36 -5.71 -5.00 -4.33 -3.29 817 1
## lp[5] -5.07 0.04 1.02 -7.36 -5.71 -5.00 -4.33 -3.29 817 1
## lp[6] -5.07 0.04 1.02 -7.36 -5.71 -5.00 -4.33 -3.29 817 1
## lp[7] -5.07 0.04 1.02 -7.36 -5.71 -5.00 -4.33 -3.29 817 1
## lp[8] -4.00 0.03 0.80 -5.68 -4.53 -3.96 -3.42 -2.57 978 1
## lp[9] -4.00 0.03 0.80 -5.68 -4.53 -3.96 -3.42 -2.57 978 1
## lp[10] -4.00 0.03 0.80 -5.68 -4.53 -3.96 -3.42 -2.57 978 1
## lp[11] -4.00 0.03 0.80 -5.68 -4.53 -3.96 -3.42 -2.57 978 1
## lp[12] -4.00 0.03 0.80 -5.68 -4.53 -3.96 -3.42 -2.57 978 1
## lp[13] -4.00 0.03 0.80 -5.68 -4.53 -3.96 -3.42 -2.57 978 1
## lp[14] -4.00 0.03 0.80 -5.68 -4.53 -3.96 -3.42 -2.57 978 1
## lp[15] -4.00 0.03 0.80 -5.68 -4.53 -3.96 -3.42 -2.57 978 1
## lp[16] -4.00 0.03 0.80 -5.68 -4.53 -3.96 -3.42 -2.57 978 1
## lp[17] -4.00 0.03 0.80 -5.68 -4.53 -3.96 -3.42 -2.57 978 1
## lp[18] -4.00 0.03 0.80 -5.68 -4.53 -3.96 -3.42 -2.57 978 1
## lp[19] -5.07 0.04 1.02 -7.36 -5.71 -5.00 -4.33 -3.29 817 1
## lp[20] -5.07 0.04 1.02 -7.36 -5.71 -5.00 -4.33 -3.29 817 1
## lp[21] -5.07 0.04 1.02 -7.36 -5.71 -5.00 -4.33 -3.29 817 1
## lp[22] -5.07 0.04 1.02 -7.36 -5.71 -5.00 -4.33 -3.29 817 1
## lp[23] -4.00 0.03 0.80 -5.68 -4.53 -3.96 -3.42 -2.57 978 1
## lp__ -82.61 0.08 2.00 -87.38 -83.72 -82.20 -81.14 -79.89 668 1
##
## Samples were drawn using NUTS(diag_e) at Sun Oct 28 07:26:15 2018.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
The traceplots for the parameters of interest appear to indicate reasonable mixing.
## Trace plots of the chains for the parameters of interest.
## The argument is spelled out as `pars`; `par` only worked through partial
## argument matching.
rstan::traceplot(stan_weibull_survival_model_fit, pars = c("alpha","mu","beta_bg"))
Some auto-correlation is seen for the parameters of interest.
## Autocorrelation plots for the parameters of interest
bayesplot::mcmc_acf(
  x = as.matrix(stan_weibull_survival_model_fit),
  pars = c("alpha", "mu", "beta_bg[1]")
)
The plot below shows the posterior densities of the parameters with their 95% credible intervals. The effect of interest beta_bg[1]
seems to have most of its posterior probability in the negative range (survival benefit with Maintained treatment).
## Posterior densities with the central 95% interval highlighted
bayesplot::mcmc_areas(
  x = as.matrix(stan_weibull_survival_model_fit),
  pars = c("alpha", "mu", "beta_bg[1]"),
  prob = 0.95
)
The parameter values do not give intuitive understanding of the survival time distributions for each group. As the Stan code sampled the event times for each individual, we can examine these directly.
## Extract the posterior draws as a tibble (one row per draw)
stan_weibull_survival_model_draws <-
  stan_weibull_survival_model_fit %>%
  tidybayes::tidy_draws()
stan_weibull_survival_model_draws
## # A tibble: 4,000 x 86
## .chain .iteration .draw tau_s_bg_raw `tau_bg_raw[1]` alpha_raw `beta_bg_raw[1]` mu `beta_bg[1]` alpha
## <int> <int> <int> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 1 1 1 7.02 2.60 0.0192 -0.0192 -4.10 -0.217 1.21
## 2 1 2 2 7.43 2.48 0.0176 -0.0292 -4.09 -0.342 1.19
## 3 1 3 3 2.92 7.22 0.0244 -0.162 -4.19 -1.27 1.28
## 4 1 4 4 1.78 2.14 0.00269 -0.270 -3.58 -0.706 1.03
## 5 1 5 5 6.27 0.926 0.00673 -0.181 -3.49 -1.09 1.07
## 6 1 6 6 3.67 0.559 -0.00409 -0.100 -3.42 -0.275 0.960
## 7 1 7 7 3.06 0.436 -0.00534 -0.231 -3.56 -0.467 0.948
## 8 1 8 8 2.48 0.418 -0.00642 -0.271 -3.55 -0.434 0.938
## 9 1 9 9 0.464 4.59 0.00696 0.0338 -3.74 0.0336 1.07
## 10 1 10 10 0.198 1.35 0.0115 -0.254 -3.73 -0.0585 1.12
## # ... with 3,990 more rows, and 76 more variables: `yhat_uncens[1]` <dbl>, `yhat_uncens[2]` <dbl>,
## # `yhat_uncens[3]` <dbl>, `yhat_uncens[4]` <dbl>, `yhat_uncens[5]` <dbl>, `yhat_uncens[6]` <dbl>,
## # `yhat_uncens[7]` <dbl>, `yhat_uncens[8]` <dbl>, `yhat_uncens[9]` <dbl>, `yhat_uncens[10]` <dbl>,
## # `yhat_uncens[11]` <dbl>, `yhat_uncens[12]` <dbl>, `yhat_uncens[13]` <dbl>, `yhat_uncens[14]` <dbl>,
## # `yhat_uncens[15]` <dbl>, `yhat_uncens[16]` <dbl>, `yhat_uncens[17]` <dbl>, `yhat_uncens[18]` <dbl>,
## # `yhat_uncens[19]` <dbl>, `yhat_uncens[20]` <dbl>, `yhat_uncens[21]` <dbl>, `yhat_uncens[22]` <dbl>,
## # `yhat_uncens[23]` <dbl>, `log_lik[1]` <dbl>, `log_lik[2]` <dbl>, `log_lik[3]` <dbl>, `log_lik[4]` <dbl>,
## # `log_lik[5]` <dbl>, `log_lik[6]` <dbl>, `log_lik[7]` <dbl>, `log_lik[8]` <dbl>, `log_lik[9]` <dbl>,
## # `log_lik[10]` <dbl>, `log_lik[11]` <dbl>, `log_lik[12]` <dbl>, `log_lik[13]` <dbl>, `log_lik[14]` <dbl>,
## # `log_lik[15]` <dbl>, `log_lik[16]` <dbl>, `log_lik[17]` <dbl>, `log_lik[18]` <dbl>, `log_lik[19]` <dbl>,
## # `log_lik[20]` <dbl>, `log_lik[21]` <dbl>, `log_lik[22]` <dbl>, `log_lik[23]` <dbl>, `lp[1]` <dbl>,
## # `lp[2]` <dbl>, `lp[3]` <dbl>, `lp[4]` <dbl>, `lp[5]` <dbl>, `lp[6]` <dbl>, `lp[7]` <dbl>, `lp[8]` <dbl>,
## # `lp[9]` <dbl>, `lp[10]` <dbl>, `lp[11]` <dbl>, `lp[12]` <dbl>, `lp[13]` <dbl>, `lp[14]` <dbl>, `lp[15]` <dbl>,
## # `lp[16]` <dbl>, `lp[17]` <dbl>, `lp[18]` <dbl>, `lp[19]` <dbl>, `lp[20]` <dbl>, `lp[21]` <dbl>, `lp[22]` <dbl>,
## # `lp[23]` <dbl>, lp__ <dbl>, accept_stat__ <dbl>, stepsize__ <dbl>, treedepth__ <dbl>, n_leapfrog__ <dbl>,
## # divergent__ <dbl>, energy__ <dbl>
The ordering of yhat_uncens
does not respect the original data ordering, but is in the observed-then-censored ordering of data fed to Stan. We need to create the corresponding treatment vector.
## Treatment indicator (1 = Maintained) in the same order as the data passed to
## Stan: individuals with observed events first, then censored individuals.
treatment_assignment <- c(as.numeric(leukemia$x == "Maintained")[leukemia$status == 1],
                          as.numeric(leukemia$x == "Maintained")[leukemia$status == 0])
## Lookup table from the yhat_uncens index (obs) to the treatment indicator.
## tibble() replaces the deprecated data_frame(); the index is derived from the
## vector length rather than hard-coded as 1:23.
treatment_assignment_df <-
  tibble(obs = seq_along(treatment_assignment),
         treatment = treatment_assignment)
treatment_assignment_df
## # A tibble: 23 x 2
## obs treatment
## <int> <dbl>
## 1 1 1
## 2 2 1
## 3 3 1
## 4 4 1
## 5 5 1
## 6 6 1
## 7 7 1
## 8 8 0
## 9 9 0
## 10 10 0
## # ... with 13 more rows
The draws have to be reorganized into the long format and combined with the treatment assignment.
## Reshape the posterior predictive event times to long format (one row per
## draw and individual) and attach each individual's treatment indicator.
stan_weibull_survival_model_draws_yhat_uncens <-
  stan_weibull_survival_model_draws %>%
  select(.chain, .iteration, .draw, starts_with("yhat_uncens")) %>%
  gather(key = key, value = yhat_uncens, starts_with("yhat_uncens")) %>%
  ## Split e.g. "yhat_uncens[12]" at "uncens" into "yhat_" and "[12]"
  separate(col = key, sep = "uncens", into = c("key","obs")) %>%
  select(-key) %>%
  ## Avoid using regular expressions with square brackets (syntax highlighter broke).
  ## https://stringr.tidyverse.org/articles/stringr.html
  ## Strip the first and last character (the brackets): "[12]" -> 12
  mutate(obs = as.integer(str_sub(obs, 2, -2))) %>%
  ## Name the join key explicitly instead of relying on the common-column guess
  left_join(y = treatment_assignment_df, by = "obs")
stan_weibull_survival_model_draws_yhat_uncens
## # A tibble: 92,000 x 6
## .chain .iteration .draw obs yhat_uncens treatment
## <int> <int> <int> <int> <dbl> <dbl>
## 1 1 1 1 1 45.1 1
## 2 1 2 2 1 90.2 1
## 3 1 3 3 1 63.2 1
## 4 1 4 4 1 29.2 1
## 5 1 5 5 1 13.3 1
## 6 1 6 6 1 18.9 1
## 7 1 7 7 1 2.16 1
## 8 1 8 8 1 345. 1
## 9 1 9 9 1 19.0 1
## 10 1 10 10 1 118. 1
## # ... with 91,990 more rows
Now we can plot the posterior predictive distributions of survival times for each group. This plot is comparing two conditional empirical density functions \(\hat{f}_{Y|X}(t|1)\) and \(\hat{f}_{Y|X}(t|0)\).
## Density of the posterior predictive event times, one curve per treatment arm
stan_weibull_survival_model_draws_yhat_uncens %>%
  ggplot(aes(x = yhat_uncens, colour = factor(treatment))) +
  ## Fine evaluation grid for the density estimate
  geom_density(n = 5120) +
  ## Zoom to [0, 160] without dropping draws beyond the limit
  coord_cartesian(xlim = c(0, 160)) +
  theme_bw() +
  theme(
    plot.title = element_text(hjust = 0.5),
    axis.text.x = element_text(angle = 90, vjust = 0.5),
    legend.key = element_blank(),
    strip.background = element_blank()
  )
We can see the event times are shifted to the right (longer survival times) for the treated group (Maintained group).
To compare more familiar conditional survival functions we can use the Weibull survival function.
\[S(t | x) = e^{- \left( \frac{t}{\sigma_{i}} \right)^{\alpha}}\]
where \(\sigma_{i}\) is a function of \(x\).
\[\sigma_{i} = \exp{\left( - \frac{\mu + x_{i}^{T}\beta}{\alpha} \right)}\]
Therefore, for each MCMC sample of the \((\alpha, \mu, \beta)\) triplet, we will have two random survival functions. To plot these functions, we then need to evaluate these functions at various time points in [0,160].
## Constructor for treatment-specific survival function
##
## Builds the Weibull survival function S(t) = exp(-(t / sigma)^alpha) with
## scale sigma = exp(-(mu + beta * x) / alpha).
##
## Arguments:
##   alpha  Weibull shape parameter
##   mu     intercept of the linear predictor
##   beta   coefficient of the covariate
##   x      covariate value (here 1 = Maintained, 0 = Nonmaintained)
## Returns:
##   A function of t (vectorized over t) giving the survival probability.
construct_survival_function <- function(alpha, mu, beta, x) {
  ## Force the arguments now. R evaluates arguments lazily, so without this a
  ## closure created in a loop would pick up whatever values the caller's
  ## variables hold when the closure is first called, not when it was built.
  force(alpha)
  force(mu)
  force(beta)
  force(x)
  ## The scale does not depend on t, so compute it once per closure
  sigma_i <- exp(-1 * (mu + beta * x) / alpha)
  function(t) {
    exp(- (t / sigma_i)^alpha)
  }
}
## Random functions
## One pair of survival functions (treated / untreated) per posterior draw,
## stored as list-columns next to the parameters that define them.
stan_weibull_survival_model_survival_functins <-
  stan_weibull_survival_model_draws %>%
  ## Keep the draw identifiers and parameters; shorten `beta_bg[1]` to beta
  select(.chain, .iteration, .draw, alpha, mu, beta = `beta_bg[1]`) %>%
  ## Construct realization of random functions
  mutate(
    `S(t|1)` = pmap(
      list(alpha, mu, beta),
      function(shape, intercept, slope) {
        construct_survival_function(shape, intercept, slope, 1)
      }
    ),
    `S(t|0)` = pmap(
      list(alpha, mu, beta),
      function(shape, intercept, slope) {
        construct_survival_function(shape, intercept, slope, 0)
      }
    )
  )
stan_weibull_survival_model_survival_functins
## # A tibble: 4,000 x 8
## .chain .iteration .draw alpha mu beta `S(t|1)` `S(t|0)`
## <int> <int> <int> <dbl> <dbl> <dbl> <list> <list>
## 1 1 1 1 1.21 -4.10 -0.217 <fn> <fn>
## 2 1 2 2 1.19 -4.09 -0.342 <fn> <fn>
## 3 1 3 3 1.28 -4.19 -1.27 <fn> <fn>
## 4 1 4 4 1.03 -3.58 -0.706 <fn> <fn>
## 5 1 5 5 1.07 -3.49 -1.09 <fn> <fn>
## 6 1 6 6 0.960 -3.42 -0.275 <fn> <fn>
## 7 1 7 7 0.948 -3.56 -0.467 <fn> <fn>
## 8 1 8 8 0.938 -3.55 -0.434 <fn> <fn>
## 9 1 9 9 1.07 -3.74 0.0336 <fn> <fn>
## 10 1 10 10 1.12 -3.73 -0.0585 <fn> <fn>
## # ... with 3,990 more rows
## Common time grid on [0, 160] at which the survival functions are evaluated
times <- seq(from = 0, to = 160, by = 0.1)
## tibble() replaces the deprecated data_frame()
times_df <- tibble(t = times)
## Try first realizations
stan_weibull_survival_model_survival_functins$`S(t|1)`[[1]](times[1:10])
## [1] 1.0000000 0.9991831 0.9981083 0.9969093 0.9956225 0.9942667 0.9928537 0.9913917 0.9898868 0.9883437
stan_weibull_survival_model_survival_functins$`S(t|0)`[[1]](times[1:10])
## [1] 1.0000000 0.9989850 0.9976498 0.9961607 0.9945630 0.9928802 0.9911270 0.9893137 0.9874478 0.9855353
## Apply all realizations
## Evaluate each draw's two survival functions on the common time grid and
## stack the results in long format (one row per draw, time and treatment).
stan_weibull_survival_model_survival <-
  stan_weibull_survival_model_survival_functins %>%
  ## Give every draw its own copy of the time grid as a list-column
  mutate(times_df = list(times_df)) %>%
  mutate(times_df = pmap(list(times_df, `S(t|1)`, `S(t|0)`),
                         function(df, s1, s0) {df %>% mutate(s1 = s1(t),
                                                             s0 = s0(t))})) %>%
  select(-`S(t|1)`, -`S(t|0)`) %>%
  ## Name the list-column to unnest; a bare unnest() warns in tidyr >= 1.0
  unnest(times_df) %>%
  gather(key = treatment, value = survival, s1, s0) %>%
  mutate(treatment = factor(treatment,
                            levels = c("s1","s0"),
                            labels = c("Maintained","Nonmaintained")))
## Average on survival scale
## Pointwise posterior mean and 2.5% / 97.5% quantiles of the survival
## probability at each time point, separately for each treatment arm.
stan_weibull_survival_model_survival_mean <-
  stan_weibull_survival_model_survival %>%
  group_by(treatment, t) %>%
  summarize(survival_mean = mean(survival),
            survival_95upper = quantile(survival, probs = 0.975),
            survival_95lower = quantile(survival, probs = 0.025)) %>%
  ## summarize() only peels off the last grouping level; drop the remaining
  ## grouping so later operations on this table are not silently per-treatment
  ungroup()
## Survival curves: one faint line per posterior draw, the pointwise posterior
## mean as a solid line and the pointwise 95% limits as dotted lines.
stan_weibull_survival_model_survival %>%
  ggplot(aes(x = t, y = survival, colour = treatment,
             group = interaction(.chain, .draw, treatment))) +
  ## Individual draws
  geom_line(alpha = 0.02, size = 0.1) +
  ## Posterior mean
  geom_line(aes(y = survival_mean, group = treatment),
            data = stan_weibull_survival_model_survival_mean) +
  ## Upper 95% limit
  geom_line(aes(y = survival_95upper, group = treatment),
            data = stan_weibull_survival_model_survival_mean,
            linetype = "dotted") +
  ## Lower 95% limit
  geom_line(aes(y = survival_95lower, group = treatment),
            data = stan_weibull_survival_model_survival_mean,
            linetype = "dotted") +
  facet_grid(. ~ treatment) +
  theme_bw() +
  theme(
    plot.title = element_text(hjust = 0.5),
    axis.text.x = element_text(angle = 90, vjust = 0.5),
    legend.key = element_blank(),
    strip.background = element_blank()
  )
The space on which the average is taken can be the parameter space \((\alpha, \mu, \beta)\) or the survival space. Here we will calculate the average parameter vector and construct the corresponding survival functions.
## Average on parameter space
## Posterior means of the three parameters, as a one-row tibble
stan_weibull_survival_model_average_parameters <-
  stan_weibull_survival_model_draws %>%
  select(alpha, mu, beta = `beta_bg[1]`) %>%
  summarise(
    alpha = mean(alpha),
    mu = mean(mu),
    beta = mean(beta)
  )
stan_weibull_survival_model_average_parameters
## # A tibble: 1 x 3
## alpha mu beta
## <dbl> <dbl> <dbl>
## 1 1.23 -4.00 -1.07
## Survival functions evaluated at the posterior-mean parameter vector,
## for the Maintained (x = 1) and Nonmaintained (x = 0) groups.
stan_weibull_average_params_survival1 <- with(stan_weibull_survival_model_average_parameters,
                                              construct_survival_function(alpha, mu, beta, 1))
stan_weibull_average_params_survival0 <- with(stan_weibull_survival_model_average_parameters,
                                              construct_survival_function(alpha, mu, beta, 0))
## Evaluate both functions on the time grid and reshape to long format.
## tibble() replaces the deprecated data_frame().
stan_weibull_average_params_survival <-
  tibble(t = seq(from = 0, to = 160, by = 0.1),
         s1 = stan_weibull_average_params_survival1(t),
         s0 = stan_weibull_average_params_survival0(t)) %>%
  gather(key = treatment, value = survival, -t) %>%
  mutate(treatment = factor(treatment,
                            levels = c("s1","s0"),
                            labels = c("Maintained","Nonmaintained")))
## Survival curves at the posterior-mean parameters, one line per treatment arm
ggplot(data = stan_weibull_average_params_survival,
       mapping = aes(x = t, y = survival,
                     colour = treatment, group = treatment)) +
  geom_line() +
  theme_bw() +
  theme(
    plot.title = element_text(hjust = 0.5),
    axis.text.x = element_text(angle = 90, vjust = 0.5),
    legend.key = element_blank(),
    strip.background = element_blank()
  )
Plot both of them to compare. The dotted lines are averaged in the parameter space. The solid lines are averaged on the survival scale.
## Compare the two ways of averaging: individual draws as faint lines, the mean
## on the survival scale as a solid line, and the curve at the mean parameters
## as a dotted line.
stan_weibull_survival_model_survival %>%
  ggplot(aes(x = t, y = survival, colour = treatment,
             group = interaction(.chain, .draw, treatment))) +
  ## Individual draws
  geom_line(alpha = 0.02, size = 0.1) +
  ## Averaged on the survival scale
  geom_line(aes(y = survival_mean, group = treatment),
            data = stan_weibull_survival_model_survival_mean) +
  ## Averaged in the parameter space
  geom_line(aes(group = treatment),
            data = stan_weibull_average_params_survival,
            linetype = "dotted") +
  facet_grid(. ~ treatment) +
  theme_bw() +
  theme(
    plot.title = element_text(hjust = 0.5),
    axis.text.x = element_text(angle = 90, vjust = 0.5),
    legend.key = element_blank(),
    strip.background = element_blank()
  )