Setup

This section contains the setup and the various utility functions used throughout.

Libraries used:

library(rstan)
library(flexsurv)
library(data.table)
# See https://github.com/maj-biostat/pwexp
library(pwexp)
library(splines2)
library(muhaz)
# Ask rstan to write compiled models to disk (auto_write) so that an
# unchanged Stan program is not recompiled on each run.
rstan::rstan_options(auto_write = TRUE)
# Extra C++ compiler flags, passed via an environment variable.
# NOTE(review): '-march=native' targets the CPU of the machine doing the
# compiling, so compiled objects may not be portable - confirm intended.
Sys.setenv(LOCAL_CPPFLAGS = '-march=native')
# Restrict parallel execution to a single core.
options(mc.cores = 1)

Compiled code (any models used are shown later):

# Compile the Stan program once; the resulting model object `mod1` is what
# gets handed to rstan::sampling() later on. The path is relative, so it is
# resolved against the current working directory.
mod1 <- rstan::stan_model("../stan/surv_mspline4.stan", verbose = FALSE)

Introduction

In an earlier post, I showed that the Weibull model is suitable for situations where we are fairly certain that the hazard is constant or monotonic. Flexible survival modelling using splines offers an alternative approach for more complex hazard functions, such as the one shown below.

# Piecewise-constant hazard that serves as the data-generating truth.
# `bins` and `rate` are passed together to pwexp::hpw() here and to
# pwexp::rpw() in the simulation; presumably `bins` gives the start of each
# interval and `rate` the hazard within it (verify against pwexp docs).
tt <- seq(1e-06, 200, length.out = 250)
bins <- c(0, 20, 40, 100, 150)
rate <- c(1/90, 1/40, 1/15, 1/50, 1/100)
# Evaluate the hazard on the time grid and draw it.
hh <- pwexp::hpw(tt, bins, rate)
plot(tt, hh, type = "l", ylim = c(0, max(hh)), ylab = "Hazard", xlab = "Time")

Simulate data having this underlying hazard.

set.seed(2)
n <- 1000
# Randomise subjects 1:1 to control (0) or treatment (1).
trt <- rbinom(n, size = 1, prob = 0.5)
# NOTE(review): logb0/logb1 are not referenced anywhere in the code shown
# here; they are left untouched in case something else relies on them.
logb0 <- -3
logb1 <- 1
# True event times, drawn one subject at a time (keeps the random number
# stream in subject order). Control subjects follow the piecewise-exponential
# hazard `rate`; for treated subjects every piece of the hazard is multiplied
# by 1.2.
y <- vapply(seq_len(n), function(i) {
  arm_rate <- if (trt[i] == 0) rate else 1.2 * rate
  pwexp::rpw(1, bins, arm_rate)
}, numeric(1))
# Independent exponential censoring times with rate 1/100.
cc <- rexp(n, 1/100)
# evt = 1 if the event happens before the censoring time, otherwise 0.
evt <- as.numeric(y < cc)
# Observed follow-up time: the event time when evt = 1, else the censoring
# time, i.e. whichever of the two comes first.
u <- pmin(y, cc)
head(data.table(trt, y, cc, u, evt))
##    trt        y        cc         u evt
## 1:   0 46.12594 35.892796 35.892796   0
## 2:   1 68.01862 31.443914 31.443914   0
## 3:   1 21.52448 66.025358 21.524482   1
## 4:   0 39.44779 81.220687 39.447787   1
## 5:   1 47.23519  3.076126  3.076126   0
## 6:   1 55.05666 14.072278 14.072278   0

When modeling with splines, we need to construct a basis that will be used to model the response. You can read an introductory post on splines here. In this case I am using my knowledge of the true hazard to decide where to place the internal knots. In real life, however, you typically will not have this knowledge. We create the m-spline and i-spline bases with the splines2 package and then evaluate them at the observed times.

# set up basis
# NOTE(review): `df` and `nk` are not used to build the basis below. The
# number of basis columns follows from `iknots`, `degree` and the intercept
# (5 + 3 + 1 = 9 here, matching the nine gammas reported by the fit).
df <- 6 # assumes intercept included
degree <- 3L # cubic splines
tmax <- 500
bknots <- c(0, tmax)
nk <- df - degree - 1 # subtract 1 as we will have an intercept
# knot spacing (somewhat sensibly placed by eye alternatively use
# a small set of quantiles to select internal knots)
iknots <- c(25, 50, 75, 100, 150)
# the intercept is included in the mspline so that the hazard is not
# constrained to zero at the intercept.
# in order to ensure identifiability of the mspline coefs and the intercept
# in the linear predictor, a simplex is used for the msplines coefs
tt <- seq(0, tmax, length.out = 1000)
basis <- splines2::mSpline(tt,
                           knots = iknots,
                           Boundary.knots = bknots,
                           degree = degree,
                           intercept = TRUE)

nvars  <- ncol(basis)  # number of aux parameters, basis terms

# I-spline basis (the integrated m-splines) on the same knots, degree and
# intercept setting as `basis`. It is built through the documented
# splines2::iSpline() interface rather than by overwriting the class
# attribute of an m-spline object, which depends on package internals.
ispline_at <- function(x) {
  splines2::iSpline(x,
                    knots = iknots,
                    Boundary.knots = bknots,
                    degree = degree,
                    intercept = TRUE)
}
# Kept as a spline object (not a bare matrix) because predict(ibasis, t) is
# called on it later when simulating from the fitted model.
ibasis <- ispline_at(tt)

# Now evaluate the bases at the observed times.
# The m-spline (hazard) is only needed at event times; the i-spline
# (cumulative hazard) is needed for both observed and censored times.
basis_event <- as.array(predict(basis, u[evt == 1]))
ibasis_event <- as.array(ispline_at(u[evt == 1]))
ibasis_rcens <- as.array(ispline_at(u[evt == 0]))

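With this knot placement, the basis allows a flexible hazard shape at early times, where the internal knots are concentrated, and much less flexibility later on, between the last internal knot at 150 and the boundary knot at 500.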

# Two panels side by side: the m-spline basis columns on the left and the
# i-spline basis columns on the right, both against the evaluation grid.
graphics::par(mfrow = c(1, 2))
graphics::matplot(x = tt, y = basis, type = "l")
graphics::matplot(x = tt, y = ibasis, type = "l")

# Back to a single-panel layout for subsequent figures.
graphics::par(mfrow = c(1, 1))

Modelling

The stan model is surprisingly simple. For right censoring, we simply include the likelihood in the standard survival form that you can review here.

// Proportional hazards model with an m-spline baseline hazard and right
// censoring. With lp = X*betas + intercept:
//   hazard            h(u) = (mspline_basis * gammas) * exp(lp)
//   cumulative hazard H(u) = (ispline_basis * gammas) * exp(lp)
//   log survival  log S(u) = -H(u)
data {
    int<lower=0> N_uncen;       // number of observed (uncensored) events
    int<lower=0> N_cen;         // number of right-censored observations
    int<lower=1> m;   // number cols in basis                                               
    int<lower=1> NC;  // number of covariates excl intercept                                              
    matrix[N_cen,NC] X_cen;     // design matrix for censored events                          
    matrix[N_uncen,NC] X_uncen; // design matrix for observed events
    matrix[N_uncen,m] mspline_basis_uncen;  // m-spline basis at event times
    matrix[N_uncen,m] ispline_basis_uncen;  // i-spline basis at event times
    matrix[N_cen,m] ispline_basis_cen;      // i-spline basis at censoring times
    // hyper parameter for dirichlet 
    vector<lower=0>[m] conc; 
}
parameters {
    simplex[m] gammas;   // spline coefficients, constrained to a simplex
    vector[NC] betas;    // covariate effects (log hazard ratios)
    real intercept;      // intercept of the linear predictor
}
model {
    // priors
    target += normal_lpdf(betas | 0, 1);
    target += normal_lpdf(intercept | 0, 5);
    target += dirichlet_lpdf(gammas | conc);
    
    // The likelihood has the form L = prod h(u)^d S(u), where d = 1 for an
    // observed event and d = 0 for a censored observation, so we need to add
    // all these components to the log posterior.
    // Another way to write the likelihood is L = prod f(u)^d (1-F(u))^(1-d).
    // We can get to the other version by using the fact that f(u) = h(u)S(u)
    // and S(u) = 1-F(u).
    // see ~/Documents/lib/edu/lutian_2020_survival
    
    // mspline_log_surv - censored
    // for these obs d = 0, so h(u)^d equals one and adds nothing on the log
    // scale; we just add the log survival part, log S(u) = -H(u)
    target += -(ispline_basis_cen*gammas) .* exp(X_cen*betas + intercept);
    // mspline_log_surv - events
    // for these obs d = 1, so we need to add both the log hazard and the
    // log survival parts
    target += -(ispline_basis_uncen*gammas) .* exp(X_uncen*betas + intercept);
    // mspline_log_haz - events
    // log h(u) = log(baseline m-spline hazard) + linear predictor
    target +=  log(mspline_basis_uncen*gammas) + X_uncen*betas + intercept;
}

Sample from the model.

# Data list for the Stan program: a single treatment covariate plus an
# intercept. The event/censoring times themselves are not passed; they enter
# only through the spline bases evaluated at those times.
is_event <- evt == 1
n_event <- sum(evt)
l <- list(
  # sample sizes for the two likelihood contributions
  N_uncen = n_event,
  N_cen = length(y) - n_event,
  # number of basis columns and of covariates (excluding the intercept)
  m = nvars,
  NC = 1,
  # one-column design matrices holding the treatment indicator
  X_cen = matrix(trt[!is_event], ncol = 1),
  X_uncen = matrix(trt[is_event], ncol = 1),
  # bases at the observed event times (hazard and cumulative hazard)
  mspline_basis_uncen = basis_event,
  ispline_basis_uncen = ibasis_event,
  # basis at the censoring times (cumulative hazard only)
  ispline_basis_cen = ibasis_rcens,
  # Dirichlet concentration: all ones
  conc = rep(1, nvars)
)


# Draw from the posterior with a single chain: 10000 iterations, of which the
# first 2000 are warmup, no thinning, and progress output switched off.
f1 <- rstan::sampling(
  mod1,
  data = l,
  chains = 1,
  iter = 10000,
  warmup = 2000,
  thin = 1,
  refresh = 0
)

# Posterior summary for the parameters of interest.
print(f1, pars = c("intercept", "betas", "gammas"))
## Inference for Stan model: surv_mspline4.
## 1 chains, each with iter=10000; warmup=2000; thin=1; 
## post-warmup draws per chain=8000, total post-warmup draws=8000.
## 
##           mean se_mean   sd  2.5%  25%  50%  75% 97.5% n_eff Rhat
## intercept 2.30    0.01 0.31  1.79 2.08 2.27 2.48  3.00  1605    1
## betas[1]  0.11    0.00 0.08 -0.04 0.06 0.11 0.16  0.26  5625    1
## gammas[1] 0.01    0.00 0.00  0.00 0.01 0.01 0.01  0.02  2286    1
## gammas[2] 0.02    0.00 0.01  0.00 0.01 0.01 0.02  0.03  2525    1
## gammas[3] 0.02    0.00 0.01  0.00 0.01 0.02 0.03  0.04  3419    1
## gammas[4] 0.18    0.00 0.06  0.08 0.14 0.18 0.22  0.31  2018    1
## gammas[5] 0.26    0.00 0.08  0.12 0.20 0.26 0.31  0.43  2534    1
## gammas[6] 0.12    0.00 0.08  0.01 0.05 0.10 0.17  0.31  6847    1
## gammas[7] 0.17    0.00 0.13  0.01 0.07 0.14 0.24  0.46  4267    1
## gammas[8] 0.12    0.00 0.10  0.00 0.04 0.09 0.17  0.38  4269    1
## gammas[9] 0.11    0.00 0.10  0.00 0.03 0.08 0.16  0.38  4894    1
## 
## Samples were drawn using NUTS(diag_e) at Tue May 25 11:38:30 2021.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).

Extract the posterior to visualise the implied hazard function. Note that while the estimated hazard does not reproduce the discontinuities present in the original hazard function, the overall form is captured quite well. The dashed lines are the estimates from the muhaz R package, and the thin solid line is the true control-arm hazard used to simulate the data.

post <- rstan::extract(f1)
tt <- seq(0, 190, length.out = 300)
# m-spline basis evaluated on the plotting grid
bb <- as.array(predict(basis, tt))

# Hazard implied by the posterior means:
#   h(t) = (M(t) %*% gammas) * exp(linear predictor)
# where the linear predictor is the intercept (control) or the intercept
# plus the treatment coefficient (treatment).
gam_hat <- colMeans(post$gammas)
int_hat <- mean(post$intercept)
beta_hat <- mean(post$betas)
base_haz <- bb %*% gam_hat
h0 <- base_haz * exp(int_hat)
h1 <- base_haz * exp(beta_hat + int_hat)

# Non-parametric hazard estimates from muhaz, one per arm, for comparison.
fit0 <- muhaz(u[trt == 0], evt[trt == 0])
fit1 <- muhaz(u[trt == 1], evt[trt == 1])
h0_mh <- fit0$haz.est
t0_mh <- fit0$est.grid
h1_mh <- fit1$haz.est
t1_mh <- fit1$est.grid

# True control-arm hazard used to simulate the data. Computed before the
# plot is set up so that the y-axis range covers it as well.
hh <- pwexp::hpw(tt, bins, rate)

plot(1, type = "n",
     xlim = range(c(0, tt, t0_mh, t1_mh)),
     ylim = range(c(0, h0, h1, h0_mh, h1_mh, hh)),
     xlab = "Time",
     ylab = "Hazard rate",
     main = "Hazard function")
# model-based estimates (solid) and muhaz estimates (dashed)
lines(tt, h0, col = 1)
lines(tt, h1, col = 2)
lines(t0_mh, h0_mh, col = 1, lty = 2)
lines(t1_mh, h1_mh, col = 2, lty = 2)
# mark the internal knots; abline() accepts a vector of positions
abline(v = iknots, col = "grey", lty = 2)
# true hazard as a thin reference line
lines(tt, hh, lty = 1, lwd = 0.5)
legend("topright", legend = c("Control", "Treatment"), lty = 1, col = 1:2)

Predictive checks

As shown earlier it is a good idea to check the posterior predictive distribution in order to determine if your underlying model is reasonable. We can simulate using the basis we defined and a generic root finding procedure.

# Range searched by the root finder when simulating survival times.
# The lower end needs to be > 0; the upper end is `tmax`, which is also the
# right boundary knot of the spline basis.
interval <- c(1E-8, tmax)
# Clamp a single number into the range of finite doubles, so that +Inf and
# -Inf are replaced by the largest representable magnitude with that sign.
return_finite <- function(x){
  big <- .Machine$double.xmax
  max(min(x, big), -big)
}
# The function to be solved: log S(t) - log(u), which is zero at the time t
# where the model survival probability equals u.
#
# The log survival is computed directly as
#   log S(t) = -(I(t) %*% gammas) * exp(linear predictor)
# rather than by exponentiating and taking the log again, so it cannot
# underflow to -Inf when the cumulative hazard is large.
#
# Arguments:
#   t         time at which to evaluate (scalar)
#   u         target survival probability in (0, 1)
#   ibasis    i-spline basis object; evaluated at t via predict()
#   gammas    spline coefficients (one per basis column)
#   beta      treatment coefficient, or NULL for the control arm
#   intercept intercept of the linear predictor
# Returns a single finite number.
rootfn <- function(t, u, ibasis, gammas, beta = NULL, intercept){
  itmp <- as.array(predict(ibasis, t))
  # baseline cumulative hazard at t
  cum_base <- as.numeric(itmp %*% gammas)
  # Linear predictor. Note only considering single treatment covariate.
  # This would need to be extended for more complex models.
  lp <- if (is.null(beta)) intercept else beta + intercept
  log_surv <- -cum_base * exp(lp)
  return_finite(log_surv - log(u))
}
# Draw one survival time from the fitted model by inverting the survival
# function: sample u ~ U(0, 1) and solve S(t) = u for t with uniroot().
# The search is restricted to the global `interval`; if the survival
# probability at its upper end is still above u, that upper end is returned.
# `beta = NULL` gives a control-arm draw (see rootfn).
r_ispline <- function(ibasis, gammas, beta = NULL, intercept){
  u_i <- runif(1)
  upper <- interval[2]
  f_upper <- rootfn(t = upper,
                    u = u_i,
                    ibasis = ibasis,
                    gammas = gammas,
                    beta = beta,
                    intercept = intercept)
  # No sign change inside the interval: the root lies beyond the upper end.
  if (f_upper > 0) {
    return(upper)
  }
  # trying to find the value for t at which S(t) - u = 0
  sol <- stats::uniroot(rootfn,
                        interval = interval,
                        u = u_i,
                        ibasis = ibasis,
                        gammas = gammas,
                        beta = beta,
                        intercept = intercept,
                        check.conv = TRUE,
                        tol = 0.001)
  # our random sample from the surv dist
  sol$root
}

This function simply wraps the random number generating functions and places the posterior predictive draws into two matrices, one for the control group and one for the treatment group.

# Posterior predictive draws of survival times for both arms.
#
# Arguments:
#   ff    fitted stan object; posterior draws are pulled with rstan::extract
#   nsim  number of posterior draws to simulate from (sampled without
#         replacement)
# Returns a list with two n x nsim matrices: `m0` (control) and `m1`
# (treatment). Column i of each holds n simulated times generated from the
# i-th sampled posterior draw.
# Relies on the globals `n` (draws per column) and `ibasis` (i-spline basis).
post_predictive <- function(ff, nsim = 20){
  post <- rstan::extract(ff)
  idx_smpl <- sample(seq_along(post$intercept), replace = FALSE, size = nsim)
  m0 <- matrix(NA, nrow = n, ncol = nsim)
  m1 <- matrix(NA, nrow = n, ncol = nsim)
  for (i in seq_len(nsim)) {
    k <- idx_smpl[i]
    gam <- post$gammas[k, ]
    b0 <- post$intercept[k]
    # control arm: no treatment coefficient
    m0[, i] <- replicate(n, r_ispline(ibasis,
                                      gammas = gam,
                                      beta = NULL,
                                      intercept = b0))
    # treatment arm: use the full element name `betas` instead of relying
    # on `$` partial matching of `beta`
    m1[, i] <- replicate(n, r_ispline(ibasis,
                                      gammas = gam,
                                      beta = post$betas[k, 1],
                                      intercept = b0))
  }
  list(m0 = m0, m1 = m1)
}

After generating the posterior predictive distribution, we can visually inspect how well the model fits the data. As can be seen, the spline model does a better job of representing the underlying data generating mechanism than the Weibull model discussed earlier.

# Posterior predictive check: histogram of the simulated (uncensored) event
# times per arm, overlaid with the density of each posterior predictive
# data set.
nsim <- 20
pp <- post_predictive(f1, nsim)
xlim <- c(0, 300)
par(mfrow = c(1, 2))
# control arm
hist(y[trt == 0], probability = TRUE, main = "Control group",
     xlim = xlim, ylim = c(0, 0.025))
for (i in seq_len(nsim)) {
  lines(density(pp$m0[, i]), col = 2, lwd = 0.3)
}
# treatment arm
hist(y[trt == 1], probability = TRUE, main = "Treatment group",
     xlim = xlim, ylim = c(0, 0.025))
for (i in seq_len(nsim)) {
  lines(density(pp$m1[, i]), col = 2, lwd = 0.3)
}

# back to a single-panel layout
par(mfrow = c(1, 1))