References

Load packages

## tidyverse: data wrangling (dplyr/tidyr/tibble) and plotting (ggplot2).
## survival: the leukemia data set and coxph().
## rstan: compiling and sampling the Stan model.
library(tidyverse)
library(survival)
library(rstan)
## Fix R's RNG state so random steps below (e.g. the sample() of posterior
## curves to plot) are reproducible. NOTE(review): presumably this also fixes
## rstan::sampling()'s default seed, which is drawn from R's RNG — confirm.
set.seed(13960043)

Load and prepare dataset

aml                  package:survival                  R Documentation
Acute Myelogenous Leukemia survival data
Description:
     Survival in patients with Acute Myelogenous Leukemia.  The
     question at the time was whether the standard course of
     chemotherapy should be extended ('maintainance') for additional
     cycles.
Usage:
     aml
     leukemia
Format:
       time:    survival or censoring time
       status:  censoring status
       x:       maintenance chemotherapy given? (factor)
Source:
     Rupert G. Miller (1997), _Survival Analysis_.  John Wiley & Sons.
     ISBN: 0-471-25218-2.
## Load the leukemia (aml) data shipped with survival, convert it to a
## tibble, and put a 1-based row identifier `id` in the first column.
data("leukemia", package = "survival")
leukemia <- leukemia %>%
    as_tibble() %>%
    rowid_to_column(var = "id")
leukemia
## # A tibble: 23 x 4
##       id  time status x         
##    <int> <dbl>  <dbl> <fct>     
##  1     1     9      1 Maintained
##  2     2    13      1 Maintained
##  3     3    13      0 Maintained
##  4     4    18      1 Maintained
##  5     5    23      1 Maintained
##  6     6    28      0 Maintained
##  7     7    31      1 Maintained
##  8     8    34      1 Maintained
##  9     9    45      0 Maintained
## 10    10    48      1 Maintained
## # … with 13 more rows

Check distribution of event times

## Summarize the observed event times (status == 1): number of events, mean
## event time, and the 0%, 20%, ..., 100% quantiles. The quantiles are held
## in a list-column and unnested to one row per quantile, so `n` and
## `mean_time` are repeated on every row.
leukemia_summary <- leukemia %>%
    filter(status == 1) %>%
    summarize(n = n(),
              mean_time = mean(time),
              quantiles = list(quantile(time, probs = seq(from = 0, to = 1, by = 0.2)))) %>%
    ## Name the list-column explicitly: a bare unnest() is deprecated in
    ## tidyr >= 1.0 (and warns), while unnest(quantiles) works in 0.8.x too.
    unnest(quantiles)
leukemia_summary
## # A tibble: 6 x 3
##       n mean_time quantiles
##   <int>     <dbl>     <dbl>
## 1    18      23.1       5  
## 2    18      23.1       8.4
## 3    18      23.1      17. 
## 4    18      23.1      27.6
## 5    18      23.1      33.6
## 6    18      23.1      48

Frequentist fit as a reference

## Frequentist Cox PH fit as a reference for the Bayesian model. The
## covariate is 1 for Maintained and 0 for Nonmaintained, the same coding as
## the `x` passed to Stan below, so `coef` is on the same log-hazard-ratio
## scale as the Stan parameter `beta`.
## Ties are handled with Efron's approximation, written directly instead of
## indexing into c("efron", "breslow", "exact").
coxph1 <- coxph(formula = Surv(time, status) ~ as.integer(x == "Maintained"),
                data    = leukemia,
                ties    = "efron")
summary(coxph1)
## Call:
## coxph(formula = Surv(time, status) ~ as.integer(x == "Maintained"), 
##     data = leukemia, ties = c("efron", "breslow", "exact")[1])
## 
##   n= 23, number of events= 18 
## 
##                                  coef exp(coef) se(coef)      z Pr(>|z|)  
## as.integer(x == "Maintained") -0.9155    0.4003   0.5119 -1.788   0.0737 .
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
##                               exp(coef) exp(-coef) lower .95 upper .95
## as.integer(x == "Maintained")    0.4003      2.498    0.1468     1.092
## 
## Concordance= 0.619  (se = 0.063 )
## Likelihood ratio test= 3.38  on 1 df,   p=0.07
## Wald test            = 3.2  on 1 df,   p=0.07
## Score (logrank) test = 3.42  on 1 df,   p=0.06

Model fitting

## Compile the piecewise-constant-hazard PH model from its Stan source file,
## then display the model (which prints its Stan code).
piecewise_ph_model <- rstan::stan_model(file = file.path(".", "bayesianideas_piecewise2.stan"))
show(piecewise_ph_model)
## S4 class stanmodel 'bayesianideas_piecewise2' coded as follows:
## data {
##     // Hypeparameters for lambda[1]
##     real<lower=0> lambda1_mean;
##     real<lower=0> lambda1_length_w;
##     // Hyperparameter for lambda[k]
##     real<lower=0> w;
##     real<lower=0> lambda_star;
##     // Hyperparameter for beta
##     real beta_mean;
##     real<lower=0> beta_sd;
##     // Number of pieces
##     int<lower=0> K;
##     // Cutopoints on time
##     //  cutpoints[1] = 0
##     //  max(event time) < cutpoints[K+1] < Inf
##     //  K+1 elements
##     real cutpoints[K+1];
##     //
##     int<lower=0> N;
##     int<lower=0,upper=1> cens[N];
##     real y[N];
##     int<lower=0,upper=1> x[N];
##     //
##     // grids for evaluating posterior predictions
##     int<lower=0> grid_size;
##     real grid[grid_size];
## }
## 
## transformed data {
## }
## 
## parameters {
##     // Baseline hazards
##     real<lower=0> lambda[K];
##     // Effect of group
##     real beta;
## }
## 
## transformed parameters {
## }
## 
## model {
##     // Prior on beta
##     target += normal_lpdf(beta | beta_mean, beta_sd);
## 
##     // Loop over pieces of time
##     for (k in 1:K) {
##         // k = 1,2,...,K
##         // cutpoints[1] = 0
##         // cutpoints[K+1] > max event time
##         real length = cutpoints[k+1] - cutpoints[k];
## 
##         // Prior on lambda
##         // BIDA 13.2.5 Priors for lambda
##         if (k == 1) {
##             // The first interval requires special handling.
##             target += gamma_lpdf(lambda[1] | lambda1_mean * lambda1_length_w, lambda1_length_w);
##         } else {
##             // Mean lambda_star
##             target += gamma_lpdf(lambda[k] | lambda_star * length * w, length * w);
##         }
## 
##         // Likelihood contribution
##         // BIDA 13.2.3 Likelihood for piecewise hazard PH model
##         for (i in 1:N) {
##             // Linear predictor
##             real lp = beta * x[i];
##             // Everyone will contribute to the survival part.
##             if (y[i] >= cutpoints[k+1]) {
##                 // If surviving beyond the end of the interval,
##                 // contribute survival throughout the interval.
##                 target += -exp(lp) * (lambda[k] * length);
##                 //
##             } else if (cutpoints[k] <= y[i] && y[i] < cutpoints[k+1]) {
##                 // If ending follow up during the interval,
##                 // contribute survival until the end of follow up.
##                 target += -exp(lp) * (lambda[k] * (y[i] - cutpoints[k]));
##                 //
##                 // Event individuals also contribute to the hazard part.
##                 if (cens[i] == 1) {
##                     target += lp + log(lambda[k]);
##                 }
##             } else {
##                 // If having ended follow up before this interval,
##                 // no contribution in this interval.
##             }
##         }
##     }
## }
## 
## generated quantities {
##     // Hazard function evaluated at grid points
##     real<lower=0> h_grid[grid_size];
##     // Cumulative hazard function at grid points
##     real<lower=0> H_grid[grid_size];
##     // Survival function at grid points
##     real<lower=0> S_grid[grid_size];
##     // Time zero cumulative hazard should be zero.
##     H_grid[1] = 0;
## 
##     // Loop over grid points
##     for (g in 1:grid_size) {
##         // Loop over cutpoints
##         for (k in 1:K) {
##             // At each k, hazard is constant at lambda[k]
##             if (cutpoints[k] <= grid[g] && grid[g] < cutpoints[k+1]) {
##                 h_grid[g] = lambda[k];
##                 break;
##             }
##         }
##         // Set grid points beyond the last time cutoff to zeros.
##         if (grid[g] >= cutpoints[K+1]) {
##             h_grid[g] = 0;
##         }
##         // Cumulative hazard
##         if (g > 1) {
##             // This double loop is very inefficient.
##             // Index starts at 2!
##             for (gg in 2:g) {
##                 // Width between current grid points
##                 real width = grid[gg] - grid[gg-1];
##                 // Width x hazard value at first grid point.
##                 // This is approximation and is incorrect for grid points
##                 // between which the hazard changes.
##                 // Previous cumulative + current contribution.
##                 H_grid[g] = H_grid[g-1] + (width * h_grid[gg-1]);
##             }
##         }
##         // Survival
##         S_grid[g] = exp(-H_grid[g]);
##     }
## }

Time cutoffs at the 20% quantiles (quintiles) of the event times

## Cutpoints from the event-time quintiles. The 0% quantile is replaced by
## time 0, and the 100% quantile (the largest event time) is moved one unit
## later so every event time falls inside [cutpoints[1], cutpoints[K+1]).
cutpoints_20 <- local({
    cuts <- as.numeric(leukemia_summary$quantiles)
    cuts[1] <- 0
    n_cuts <- length(cuts)
    cuts[n_cuts] <- cuts[n_cuts] + 1
    cuts
})
## Show
cutpoints_20
## [1]  0.0  8.4 17.0 27.6 33.6 49.0
## Evaluation grid for posterior predictions: from time 0 to the largest
## event time in steps of 0.1.
grid <- seq(from = 0, to = max(leukemia_summary$quantiles), by = 0.1)
## Fit the model with quintile cutpoints. Columns of `leukemia` (time,
## status, x) are referenced directly via with(); `cutpoints_20` and `grid`
## come from the surrounding (global) environment.
piecewise_ph_sample <- rstan::sampling(
    object = piecewise_ph_model,
    data = with(leukemia, list(
        ## Prior hyperparameters
        lambda1_mean = 0.01,
        lambda1_length_w = 10^4,
        w = 0.01,
        lambda_star = 0.05,
        beta_mean = 0,
        beta_sd = 100,
        ## Partition of the time axis
        K = length(cutpoints_20) - 1,
        cutpoints = cutpoints_20,
        ## Observations; x = 1 for Maintained, 0 for Nonmaintained
        N = length(time),
        cens = status,
        y = time,
        x = as.integer(x == "Maintained"),
        ## Grid for posterior hazard/survival curves
        grid_size = length(grid),
        grid = grid
    ))
)
print(piecewise_ph_sample, pars = c("lambda", "beta", "lp__"))
## Inference for Stan model: bayesianideas_piecewise2.
## 4 chains, each with iter=2000; warmup=1000; thin=1; 
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##              mean se_mean   sd    2.5%     25%     50%     75%   97.5% n_eff Rhat
## lambda[1]    0.01    0.00 0.00    0.01    0.01    0.01    0.01    0.01  5420    1
## lambda[2]    0.03    0.00 0.02    0.01    0.02    0.02    0.04    0.07  5482    1
## lambda[3]    0.04    0.00 0.02    0.01    0.02    0.04    0.05    0.10  4491    1
## lambda[4]    0.08    0.00 0.05    0.02    0.05    0.07    0.11    0.21  4673    1
## lambda[5]    0.09    0.00 0.05    0.02    0.05    0.08    0.12    0.22  4389    1
## beta        -0.55    0.01 0.52   -1.55   -0.89   -0.56   -0.19    0.43  3285    1
## lp__      -107.10    0.04 1.72 -111.22 -108.07 -106.80 -105.83 -104.67  1577    1
## 
## Samples were drawn using NUTS(diag_e) at Mon Jun 10 22:56:14 2019.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).
## Trace plots: first including warmup iterations, then post-warmup only.
rstan::traceplot(piecewise_ph_sample, pars = c("lambda", "beta", "lp__"), inc_warmup = TRUE)

rstan::traceplot(piecewise_ph_sample, pars = c("lambda", "beta", "lp__"), inc_warmup = FALSE)

Posterior survival estimates for the Nonmaintained (baseline, x = 0) group.

## Posterior draws of the survival function on the evaluation grid: one row
## per draw, one column per grid time (columns renamed to the grid times).
## S_grid is built from the baseline hazard lambda alone, and x = 1 codes
## Maintained, so these are survival curves for the Nonmaintained group.
piecewise_ph_S_sample <- piecewise_ph_sample %>%
    as.matrix(pars = "S_grid") %>%
    as_tibble()
names(piecewise_ph_S_sample) <- as.character(grid)
## Plot a random subset of at most 500 posterior curves.
## Draws are subsampled *before* reshaping to long format, so only the
## plotted curves are gathered. sample.int() consumes the RNG exactly as
## sample(1:n, size) does, and min() avoids an error when fewer than 500
## draws are available.
piecewise_ph_S_sample %>%
    mutate(iter = seq_len(n())) %>%
    filter(iter %in% sample.int(max(iter), size = min(500L, max(iter)))) %>%
    gather(key = time, value = survival, -iter) %>%
    mutate(time = as.numeric(time)) %>%
    ggplot(mapping = aes(x = time, y = survival, group = iter)) +
    geom_line(alpha = 0.1) +
    labs(title = "Nonmaintained group: posterior survival (quintile cutpoints)",
         x = "Time",
         y = "Survival probability") +
    theme_bw() +
    theme(axis.text.x = element_text(angle = 90, vjust = 0.5),
          legend.key = element_blank(),
          plot.title = element_text(hjust = 0.5),
          strip.background = element_blank())

Time cutoffs at all event times

## Cutpoints at every distinct observed event time, bracketed by time 0 and
## one unit past the largest event time.
cutpoints_all <- leukemia %>%
    filter(status == 1) %>%
    pull(time) %>%
    unique() %>%
    sort()
cutpoints_all <- c(0, cutpoints_all, max(cutpoints_all) + 1)
cutpoints_all
##  [1]  0  5  8  9 12 13 18 23 27 30 31 33 34 43 45 48 49
## Refit with a cutpoint at every distinct event time, using the same
## hyperparameters and evaluation grid as the quintile fit.
## NOTE(review): intervals are left-closed [c_k, c_{k+1}) in the model's
## likelihood loop, so an event with y == cutpoints[k] adds log(lambda[k])
## but zero exposure in interval k. With cutpoints placed exactly at event
## times, this may explain the large, diffuse posteriors of some lambda[k]
## — confirm.
piecewise_ph_all_sample <- rstan::sampling(
    object = piecewise_ph_model,
    data = with(leukemia, list(
        ## Prior hyperparameters
        lambda1_mean = 0.01,
        lambda1_length_w = 10^4,
        w = 0.01,
        lambda_star = 0.05,
        beta_mean = 0,
        beta_sd = 100,
        ## Partition of the time axis
        K = length(cutpoints_all) - 1,
        cutpoints = cutpoints_all,
        ## Observations; x = 1 for Maintained, 0 for Nonmaintained
        N = length(time),
        cens = status,
        y = time,
        x = as.integer(x == "Maintained"),
        ## Grid for posterior hazard/survival curves
        grid_size = length(grid),
        grid = grid
    ))
)
print(piecewise_ph_all_sample, pars = c("lambda", "beta", "lp__"))
## Inference for Stan model: bayesianideas_piecewise2.
## 4 chains, each with iter=2000; warmup=1000; thin=1; 
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##               mean se_mean   sd    2.5%     25%     50%     75%   97.5% n_eff Rhat
## lambda[1]     0.01    0.00 0.00    0.01    0.01    0.01    0.01    0.01  4540    1
## lambda[2]     0.05    0.00 0.04    0.01    0.02    0.04    0.07    0.15  4946    1
## lambda[3]     0.18    0.00 0.13    0.02    0.08    0.15    0.23    0.52  4178    1
## lambda[4]     0.03    0.00 0.03    0.00    0.01    0.02    0.04    0.12  4926    1
## lambda[5]     0.10    0.00 0.10    0.00    0.03    0.07    0.14    0.38  5763    1
## lambda[6]     0.02    0.00 0.02    0.00    0.01    0.02    0.03    0.09  5823    1
## lambda[7]     0.02    0.00 0.02    0.00    0.01    0.02    0.03    0.09  5276    1
## lambda[8]     0.07    0.00 0.05    0.01    0.03    0.06    0.10    0.21  4601    1
## lambda[9]     0.06    0.00 0.06    0.00    0.02    0.04    0.08    0.21  4772    1
## lambda[10]    0.22    0.00 0.23    0.00    0.06    0.14    0.30    0.84  5274    1
## lambda[11]    0.11    0.00 0.12    0.00    0.03    0.08    0.15    0.43  4249    1
## lambda[12]    0.31    0.00 0.31    0.01    0.08    0.21    0.43    1.16  4620    1
## lambda[13]    0.04    0.00 0.04    0.00    0.01    0.03    0.05    0.14  4592    1
## lambda[14]    0.26    0.00 0.27    0.01    0.07    0.18    0.37    1.00  4562    1
## lambda[15]    0.66    0.01 0.89    0.01    0.14    0.37    0.83    2.95  4005    1
## lambda[16]    3.87    0.09 5.12    0.08    0.89    2.23    4.96   17.37  3494    1
## beta         -1.23    0.01 0.55   -2.35   -1.59   -1.20   -0.86   -0.17  3203    1
## lp__       -182.23    0.09 3.33 -189.80 -184.24 -181.79 -179.81 -176.98  1372    1
## 
## Samples were drawn using NUTS(diag_e) at Mon Jun 10 22:56:52 2019.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).
## Trace plots: first including warmup iterations, then post-warmup only.
rstan::traceplot(piecewise_ph_all_sample, pars = c("lambda", "beta", "lp__"), inc_warmup = TRUE)

rstan::traceplot(piecewise_ph_all_sample, pars = c("lambda", "beta", "lp__"), inc_warmup = FALSE)

Posterior survival estimates for the Nonmaintained (baseline, x = 0) group.

## Posterior draws of the survival function on the evaluation grid: one row
## per draw, one column per grid time (columns renamed to the grid times).
## S_grid is built from the baseline hazard lambda alone, and x = 1 codes
## Maintained, so these are survival curves for the Nonmaintained group.
piecewise_ph_all_S_sample <- piecewise_ph_all_sample %>%
    as.matrix(pars = "S_grid") %>%
    as_tibble()
names(piecewise_ph_all_S_sample) <- as.character(grid)
## Plot a random subset of at most 500 posterior curves.
## Draws are subsampled *before* reshaping to long format, so only the
## plotted curves are gathered. sample.int() consumes the RNG exactly as
## sample(1:n, size) does, and min() avoids an error when fewer than 500
## draws are available.
piecewise_ph_all_S_sample %>%
    mutate(iter = seq_len(n())) %>%
    filter(iter %in% sample.int(max(iter), size = min(500L, max(iter)))) %>%
    gather(key = time, value = survival, -iter) %>%
    mutate(time = as.numeric(time)) %>%
    ggplot(mapping = aes(x = time, y = survival, group = iter)) +
    geom_line(alpha = 0.1) +
    labs(title = "Nonmaintained group: posterior survival (event-time cutpoints)",
         x = "Time",
         y = "Survival probability") +
    theme_bw() +
    theme(axis.text.x = element_text(angle = 90, vjust = 0.5),
          legend.key = element_blank(),
          plot.title = element_text(hjust = 0.5),
          strip.background = element_blank())


## Record the R version, platform, and attached/loaded package versions.
print(utils::sessionInfo())
## R version 3.6.0 (2019-04-26)
## Platform: x86_64-apple-darwin15.6.0 (64-bit)
## Running under: macOS Mojave 10.14.5
## 
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## attached base packages:
## [1] parallel  stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] rstan_2.18.2       StanHeaders_2.18.1 survival_2.44-1.1  forcats_0.4.0      stringr_1.4.0     
##  [6] dplyr_0.8.1        purrr_0.3.2        readr_1.3.1        tidyr_0.8.3        tibble_2.1.3      
## [11] ggplot2_3.1.1      tidyverse_1.2.1    doRNG_1.7.1        rngtools_1.3.1.1   pkgmaker_0.27     
## [16] registry_0.5-1     doParallel_1.0.14  iterators_1.0.10   foreach_1.4.4      knitr_1.23        
## 
## loaded via a namespace (and not attached):
##  [1] httr_1.4.0         jsonlite_1.6       splines_3.6.0      modelr_0.1.4       assertthat_0.2.1  
##  [6] stats4_3.6.0       cellranger_1.1.0   yaml_2.2.0         pillar_1.4.1       backports_1.1.4   
## [11] lattice_0.20-38    glue_1.3.1         digest_0.6.19      rvest_0.3.4        colorspace_1.4-1  
## [16] htmltools_0.3.6    Matrix_1.2-17      plyr_1.8.4         pkgconfig_2.0.2    bibtex_0.4.2      
## [21] broom_0.5.2        haven_2.1.0        xtable_1.8-4       scales_1.0.0       processx_3.3.1    
## [26] generics_0.0.2     withr_2.1.2        lazyeval_0.2.2     cli_1.1.0          magrittr_1.5      
## [31] crayon_1.3.4       readxl_1.3.1       evaluate_0.14      ps_1.3.0           fansi_0.4.0       
## [36] nlme_3.1-140       xml2_1.2.0         pkgbuild_1.0.3     tools_3.6.0        loo_2.1.0         
## [41] prettyunits_1.0.2  hms_0.4.2          matrixStats_0.54.0 munsell_0.5.0      callr_3.2.0       
## [46] compiler_3.6.0     rlang_0.3.4        grid_3.6.0         rstudioapi_0.10    labeling_0.3      
## [51] rmarkdown_1.13     gtable_0.3.0       codetools_0.2-16   inline_0.3.15      R6_2.4.0          
## [56] gridExtra_2.3      lubridate_1.7.4    zeallot_0.1.0      utf8_1.1.4         stringi_1.4.3     
## [61] Rcpp_1.0.1         vctrs_0.1.0        tidyselect_0.2.5   xfun_0.7
## Record execution time and multicore use.
## NOTE(review): start_time is not defined in the visible code; presumably it
## is set in an earlier (hidden) setup chunk — confirm.
end_time <- Sys.time()
diff_time <- difftime(time1 = end_time, time2 = start_time, units = "auto")
local({
    time_units <- attr(diff_time, "units")
    n_workers <- foreach::getDoParWorkers()
    backend <- foreach::getDoParName()
    cat("Started  ", as.character(start_time), "\n",
        "Finished ", as.character(end_time), "\n",
        "Time difference of ", diff_time, " ", time_units, "\n",
        "Used ", n_workers, " cores\n",
        "Used ", backend, " as backend\n",
        sep = "")
})
## Started  2019-06-10 22:54:32
## Finished 2019-06-10 22:57:27
## Time difference of 2.922999 mins
## Used 4 cores
## Used doParallelMC as backend