Introduction

There are \(J\) participants, each of whom is observed multiple times in both the control state and the intervention state. A random intercept, random slope model can be specified as:

\[ \begin{aligned} y_{i} &\sim Normal(\mu_{i}, \sigma) \\ \mu_{i} &= \alpha_{x(i)} + \beta_{j(i),x(i)} \\ \end{aligned} \]

Each \(y_{i}\) records a response for participant \(j(i)\), where \(j = 1, \dots, J\), under the treatment group recorded in \(x(i)\); here \(x = 1, 2\) with \(1\) for control. Within a participant, the observations in each group are considered exchangeable.

The model includes a parameter for each treatment group mean (the \(\alpha_{x(i)}\)), and each person \(j\) can have a random offset from the control group mean (the \(\beta_{j(i),1}\)) and a random offset from the treatment group mean (the \(\beta_{j(i),2}\)).

We could assume the \(\beta_{j(i),1}\) and \(\beta_{j(i),2}\) to be independent, but we can also assume (and model) them as correlated:

\[ \begin{aligned} \begin{bmatrix} \beta_{j,1} \\ \beta_{j,2} \end{bmatrix} &\sim MVNormal(0, \Sigma) \\ \Sigma &= \begin{pmatrix} \sigma_{\beta_{1}} & 0 \\ 0 & \sigma_{\beta_{2}} \\ \end{pmatrix} R \begin{pmatrix} \sigma_{\beta_{1}} & 0 \\ 0 & \sigma_{\beta_{2}} \\ \end{pmatrix} \\ R &= \begin{pmatrix} 1 & \rho \\ \rho & 1 \\ \end{pmatrix} \\ \end{aligned} \]

For simplicity, assume that we are working on a standardised scale where values between -3 and 3 are the plausible bounds.

# Simulate repeated measures for J participants under a ctl and a trt state.
#
# Participant-level means for (ctl, trt) are drawn from a bivariate normal
# with means `mu`, standard deviations `s` and correlation `r`. Each
# participant then contributes `reps` observations per group, each with
# normal residual error (sd `s_e`).
#
# Arguments:
#   mu   - length-2 vector of population means (ctl, trt)
#   s    - sds of the participant-level means; a scalar is used for both
#   r    - correlation between a participant's ctl and trt means
#   J    - number of participants
#   s_e  - residual sd
#   reps - number of observations per participant per group
# Returns a data.table with one row per observation and columns id_pt,
# id_obs, group (factor: ctl, trt), mu (participant-level mean), id_trt
# (1 = ctl, 2 = trt), y (observed response) and id (row number).
get_data <- function(
    mu = c(-1, 2), s = c(2, 2), r = 0.5, J = 250, s_e = 2, reps = 5){

  R <- rbind(
    c(1, r),
    c(r, 1)
  )
  # Scale matrix diag(s). nrow = 2 matters for a scalar s: diag(2) on its
  # own would return the 2 x 2 identity rather than 2 * I.
  D <- diag(s, nrow = 2)
  S <- D %*% R %*% D

  # random effects, truth by pt to each trt
  u_p <- mvtnorm::rmvnorm(J, mean = mu, sigma = S)
  # expand so that we have mult obs per pt; drop = FALSE keeps theta a
  # matrix even when only a single row is selected (J = 1, reps = 1)
  id_pt <- rep(seq_len(nrow(u_p)), each = reps)
  theta <- u_p[id_pt, , drop = FALSE]

  d <- data.table(id_pt = id_pt,
                  ctl = theta[, 1], trt = theta[, 2])
  d[, id_obs := seq_len(.N), keyby = id_pt]
  # long format: the ctl rows are stacked above the trt rows
  d <- melt(d, id.vars = c("id_pt", "id_obs"),
            value.name = "mu", variable.name = "group")

  d[group == "ctl", id_trt := 1]
  d[group == "trt", id_trt := 2]
  # simulate the observed data around each participant-level mean
  d[, y := rnorm(.N, mu, s_e)]
  d[, id := seq_len(.N)]
  d
}
# One simulated data set, with every get_data() default spelled out
d <- get_data(
  mu = c(-1, 2), s = c(2, 2), r = 0.5,
  J = 250, s_e = 2, reps = 5
)

Independent random effects mean that a participant's response under the ctl group in no way informs us about their response under the trt group. In the plot below there is only participant-level variation (the residual error, s_e, is set to zero). The true population mean is shown in red.

# Independent (r = 0) participant-level responses, no residual noise
mu <- c(-1, 2)
d <- get_data(mu = mu, r = 0, s_e = 0.0)
# one row per participant: mean response under ctl and under trt
dfig <- dcast(d, id_pt ~ group, fun.aggregate = mean, value.var = "y")
ggplot(data = dfig, mapping = aes(x = ctl, y = trt)) +
  geom_point() +
  # population means used to simulate the data, highlighted in colour 2
  geom_point(
    data = data.table(ctl = mu[1], trt = mu[2]),
    size = 4, col = 2
  )

When the responses under the two groups are strongly positively correlated, a participant's response under ctl gives some insight into their response under trt. Participants that have high responses (relative to the population mean) under ctl might be expected to have high responses (relative to the population mean) under trt, and vice versa.

# Strongly positively correlated (r = 0.98) responses, no residual noise
d <- get_data(mu = mu, r = 0.98, s_e = 0.0)
# one row per participant: mean response under ctl and under trt
dfig <- dcast(d, id_pt ~ group, fun.aggregate = mean, value.var = "y")
ggplot(data = dfig, mapping = aes(x = ctl, y = trt)) +
  geom_point() +
  # population means used to simulate the data, highlighted in colour 2
  geom_point(
    data = data.table(ctl = mu[1], trt = mu[2]),
    size = 4, col = 2
  )

The same is true when the responses are strongly negatively correlated. Participants that have low responses (relative to the population mean) under ctl might be expected to have high responses (relative to the population mean) under trt, and vice versa.

# Strongly negatively correlated (r = -0.98) responses, no residual noise
d <- get_data(mu = mu, r = -0.98, s_e = 0.0)
# one row per participant: mean response under ctl and under trt
dfig <- dcast(d, id_pt ~ group, fun.aggregate = mean, value.var = "y")
ggplot(data = dfig, mapping = aes(x = ctl, y = trt)) +
  geom_point() +
  # population means used to simulate the data, highlighted in colour 2
  geom_point(
    data = data.table(ctl = mu[1], trt = mu[2]),
    size = 4, col = 2
  )

Another way to look at how this plays out is to plot the responses for a random sample of participants, although the pattern isn't quite as obvious as in the plots above.

High positive correlation between individual level ctl and trt response.

# Strong positive correlation (r = 0.98) with a little residual noise
d <- get_data(mu = mu, r = 0.98, s_e = 0.4)
# keep a random subset of 40 participants, one facet each
dfig <- d[id_pt %in% sample(unique(d$id_pt), size = 40)]
ggplot(data = dfig, mapping = aes(x = group, y = y, group = id_obs)) +
  geom_line() +
  geom_point() +
  facet_wrap(facets = ~id_pt)

And independence between individual level ctl and trt response.

# Independent (r = 0) responses with a little residual noise
d <- get_data(mu = mu, r = 0.0, s_e = 0.4)
# keep a random subset of 40 participants, one facet each
dfig <- d[id_pt %in% sample(unique(d$id_pt), size = 40)]
ggplot(data = dfig, mapping = aes(x = group, y = y, group = id_obs)) +
  geom_line() +
  geom_point() +
  facet_wrap(facets = ~id_pt)

Below is a model implementation. Much of it has an obvious interpretation, but some comments are in order:

  1. A Cholesky factor is a lower triangular matrix \(L\) from which the original correlation matrix can be restored by \(L L^\prime\), i.e. \(L\) multiplied by its transpose.
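  2. Parameterising the correlation matrix by its Cholesky factor (cholesky_factor_corr) lets Stan restrict sampling to valid correlation matrices (symmetric, positive definite, unit diagonal); see the Stan User's Guide section on multivariate priors for hierarchical models.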
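  3. diag_pre_multiply(s_p, L_Rho_p) combines the standard deviations and the representation of the correlation matrix (its Cholesky factor) into the Cholesky factor of the covariance matrix, \(\mathrm{diag}(s_p) L\); see the diag_pre_multiply entry in the Stan Functions Reference.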
  4. Once we do the multiplication by z_p we transform the independent standard normal variates into a set of offsets that have the correct scale and correlations.
  5. The multiply_lower_tri_self_transpose reconstructs the correlation matrix from the Cholesky factors.
  6. The quad_form_diag reconstructs the covariance matrix from the reconstructed correlation matrix and random effects scale parameters.
# Build the Stan model from <path_pref>/stan/mvn_01.stan and print its code
m1 <- cmdstanr::cmdstan_model(file.path(path_pref, "stan", "mvn_01.stan"))
m1
## data {
##   int N;
##   // num pt
##   int J;
##   // num trt
##   int I;
##   // resp
##   vector[N] y;
##   array[N] int id_pt;
##   array[N] int id_trt;
##   // priors
##   int prionly;
##   vector[I] pri_a_m;
##   vector[I] pri_a_s;
##   real pri_rho;
##   real pri_s;
## }
## parameters {
##   // treatment group means
##   vector[I] a;
##   // this needs to be num trt x num pt
##   // so that the matrix mult works
##   matrix[I,J] z_p;
##   vector<lower=0>[I] s_p;
##   cholesky_factor_corr[I] L_Rho_p;
##   real<lower=0> s_r;
## }
## transformed parameters{
##   matrix[J,I] u_p;
##   // transform so that we can index u_p as id_pt, id_trt
##   u_p = (diag_pre_multiply(s_p, L_Rho_p) * z_p)';
## }
## model {
##   for(i in 1:I){
##     // different prior by grp
##     target += normal_lpdf(a[i] | pri_a_m[i], pri_a_s[i]);
##   }
##   target += lkj_corr_cholesky_lpdf(L_Rho_p | pri_rho);
##   target += exponential_lpdf(s_p | pri_s);
##   target += normal_lpdf(to_vector(z_p) | 0, 1);
##   target += exponential_lpdf(s_r | 1);
##   if(!prionly){
##     vector[N] mu;
##     for(i in 1:N){
##       mu[i] = a[id_trt[i]] + u_p[id_pt[i],id_trt[i]];
##     }
##     target += normal_lpdf(y | mu, s_r);
##   }
## }
## generated quantities{
##   vector[N] log_lik;
##   vector[N] y_rep;
##   matrix[I,I] Rho_p;
##   matrix[I,I] S_p;
##   matrix[J,I] mu_pt;
##   vector[I] mu_marg;
##   // reconstruct the correl and cov matrices
##   Rho_p = multiply_lower_tri_self_transpose(L_Rho_p);
##   S_p = quad_form_diag(Rho_p, s_p);
##   {
##     real theta;
##     for (i in 1:N){
##       theta = a[id_trt[i]] + u_p[id_pt[i],id_trt[i]];
##       log_lik[i] = normal_lpdf(y[i] | theta , s_r);
##       y_rep[i] = normal_rng(theta, s_r);
##     } 
##     // conditional estimates
##     for(i in 1:I){
##       for(j in 1:J){
##         mu_pt[j,i] = a[i] + u_p[j,i];
##       }
##     }
##     mu_marg = multi_normal_rng(a,S_p); 
##   }
## }
# Moderate positive correlation b/w individual level resp.
# Prior-only run: the data are passed in but, with prionly = 1, the Stan
# model skips the likelihood, so the draws come from the prior.
set.seed(2)
d <- get_data(r = 0.7, J = 100)
lsd <- list(
  N = nrow(d),
  J = length(unique(d$id_pt)),
  I = length(unique(d$id_trt)),
  y = d$y,
  id_pt = d$id_pt,
  id_trt = d$id_trt,
  # 1 = skip the likelihood (sample from the prior only)
  prionly = 1,
  # treatment group prior means and sd (normal prior on each a[i])
  pri_a_m = c(0, 0), pri_a_s = c(2, 2),
  # lkj_corr_cholesky shape parameter; the higher this value, the more
  # the prior on the correlation is concentrated around zero
  pri_rho = 3,
  # rate of the exponential prior on the random effect sds (s_p); note
  # the residual sd (s_r) has a fixed exponential(1) prior in the model
  pri_s = 1
)

f1 <- m1$sample(lsd,
    iter_warmup = 1000,
    iter_sampling = 2000,
    parallel_chains = 3,
    chains = 3,
    refresh = 0)
## Running MCMC with 3 parallel chains...
## 
## Chain 1 finished in 1.5 seconds.
## Chain 2 finished in 1.5 seconds.
## Chain 3 finished in 1.5 seconds.
## 
## All 3 chains finished successfully.
## Mean chain execution time: 1.5 seconds.
## Total execution time: 1.5 seconds.

The prior predictive distribution (simulated data based on our a-priori assumptions) can be generated by running the model without conditioning on the data.

Each draw (row) of y_rep has \(N\) values; because the simulated data are stacked by group, the first \(N/2\) correspond to the control state and the last \(N/2\) to the trt state.

The black points show samples from the prior predictive distribution (i.e. without conditioning on the observed data), with one panel per draw.

The utility in this is to get a sense of whether these replicates are within the realms of possibility; we do not care whether they align with what was actually observed. It is an exercise about whether our assumptions align with reality as it is presently understood.

Given that I have prior knowledge that the response will most likely be in the range -3 to 3, these seem OK for now: they more than cover those kinds of values and are not producing anything crazy such as values from -100 to 100.

# Prior predictive replicates: one row per draw, one column per observation
d_yrep <- data.table(f1$draws(variables = "y_rep", format = "matrix"))
d_yrep[, id_rep := seq_len(.N)]
d_yrep <- melt(d_yrep, id.vars = "id_rep", value.name = "y_rep")
# "y_rep[17]" -> 17, which is the row (and id) of the observation in d
d_yrep[, id := gsub("y_rep[", "", variable, fixed = TRUE)]
d_yrep[, id := as.integer(gsub("]", "", id, fixed = TRUE))]
# index the group column directly; d[id, group] would depend on
# data.table's scoping rules for `id`, which is also a column of d
d_yrep[, group := d$group[id]]

# 20 randomly chosen draws, one panel each
ggplot(d_yrep[id_rep %in% sample(seq_len(max(d_yrep$id_rep)), size = 20)],
       aes(x = group, y = y_rep)) +
  geom_jitter(height = 0, width = 0.1, size = 0.3, alpha = 0.4) +
  facet_wrap(~id_rep)

And these are the distributions for the parameters we are interested in. Because we haven't conditioned on the data, these are draws from the prior and should simply reflect what we have specified for the prior.

Reminder: a corresponds to the treatment group means, the s_p are the standard deviations that are used to construct the covariance matrix for the random effects, the Rho_p[1,2] is the correlation between the individual level responses and the s_r is the standard deviation used for the residual error.

# Draws of the main parameters (full argument name `variables`, rather
# than relying on partial matching of `variable`)
d_par <- data.table(f1$draws(variables = c("a", "s_p", "s_r",
                                           "Rho_p[1,2]"),
                             format = "matrix"))
# long format: one row per (draw, parameter)
d_par <- melt(d_par, measure.vars = names(d_par))
ggplot(d_par, aes(x = value)) +
  geom_density() +
  facet_wrap(~variable, scales = "free")

Next, fit the model to a data set with control and treatment group means of -1 and 1 and a moderate correlation (0.4) between the random effects. You can never say from a single run whether things worked or not, but based on the estimates, there do not appear to be any major problems.

# Simulation settings: ctl/trt means of -1 and 1, moderate correlation
# between a participant's ctl and trt means, two obs per pt per group
set.seed(9)
mu <- c(-1, 1)
s <- c(0.4, 0.4)
r <- 0.4
J <- 100
s_e <- 0.6
reps <- 2
d <- get_data(mu = mu, s = s, r = r, J = J, s_e = s_e, reps = reps)

# Frequentist reference fit: fixed effect for group plus correlated
# participant-level random intercept and group effect
l1 <- lme4::lmer(y ~ group + (1 + group | id_pt), data = d)
# "grouptrt" is the estimated trt - ctl difference and its Wald 95% CI
mu_freq_trt <- lme4::fixef(l1)["grouptrt"]
mu_freq_ci <- as.numeric(confint(l1, method = "Wald")["grouptrt", ])
# if re.form is NA or ~0, include no random effects
# predict(l1, newdata = data.table(group = c("ctl", "trt")), re.form = NA)
# confint(l1, method = "boot")

# quick check of the raw group means:
# d[, mean(y), keyby = id_trt]
lsd <- list(
  N = nrow(d),
  J = length(unique(d$id_pt)),
  I = length(unique(d$id_trt)),
  y = d$y,
  id_pt = d$id_pt,
  id_trt = d$id_trt,
  # 0 = include the likelihood (condition on the data)
  prionly = 0,
  # treatment group prior means and sd (normal prior on each a[i])
  pri_a_m = c(0, 0), pri_a_s = c(10, 10),
  # lkj_corr_cholesky shape parameter; the higher this value, the more
  # the prior on the correlation is concentrated around zero
  pri_rho = 1,
  # rate of the exponential prior on the random effect sds (s_p); note
  # the residual sd (s_r) has a fixed exponential(1) prior in the model
  pri_s = 1
)

f1 <- m1$sample(lsd,
    iter_warmup = 1000,
    iter_sampling = 3000,
    parallel_chains = 3,
    chains = 3,
    refresh = 0,
    adapt_delta = 0.95, max_treedepth = 10)
## Running MCMC with 3 parallel chains...
## 
## Chain 1 finished in 2.5 seconds.
## Chain 2 finished in 2.6 seconds.
## Chain 3 finished in 2.8 seconds.
## 
## All 3 chains finished successfully.
## Mean chain execution time: 2.6 seconds.
## Total execution time: 2.9 seconds.
# Posterior summaries (mean, median, sd, quantiles, rhat, ess) for the
# main parameters
f1$summary(
  variables = c("a", "s_p", "s_r", "Rho_p[1,2]")
)
## # A tibble: 6 × 10
##   variable     mean median     sd    mad     q5    q95  rhat ess_bulk ess_tail
##   <chr>       <num>  <num>  <num>  <num>  <num>  <num> <num>    <num>    <num>
## 1 a[1]       -0.974 -0.973 0.0581 0.0585 -1.07  -0.878  1.00    8557.    7075.
## 2 a[2]        0.943  0.943 0.0546 0.0545  0.852  1.03   1.00    9649.    7208.
## 3 s_p[1]      0.357  0.360 0.0780 0.0724  0.225  0.477  1.00    2050.    2045.
## 4 s_p[2]      0.295  0.300 0.0815 0.0769  0.155  0.422  1.00    1877.    2178.
## 5 s_r         0.633  0.632 0.0319 0.0321  0.582  0.687  1.00    2104.    3556.
## 6 Rho_p[1,2]  0.553  0.571 0.261  0.264   0.105  0.935  1.00    2453.    3525.

The diagnostics plus the trace (caterpillar) plots suggest that the chains converged OK. Basically, the fact that the traces are stable and flat, and that the chains substantially overlap, are heuristics that suggest convergence. If the chains are wandering all over the place, or they are not, more or less, sitting on top of one another, then your estimates are probably going to be unreliable and you need to figure out what has gone wrong. Similarly, if your parameter values look odd in any way, then investigate until you fully understand why they look odd.

There are many ways to approach debugging. The first thing is to carefully check all your data definitions and the data that is being passed to Stan. For example, have you declared something in the data block as an int when it should be a real (or vice versa)? Type and dimension mismatches like these do not always produce an obvious error message, so check them explicitly. Stan's print() statement can be useful for tracking things down. If that doesn't highlight anything, then simplify your model as much as you can until things are working as you expect, and then progressively add complexity until you hit the point where things go wrong.

# Trace plots for the main parameters, one panel per parameter
bayesplot::mcmc_trace(
  f1$draws(variables = c("a", "s_p", "Rho_p[1,2]", "s_r"))
)

And the posterior predictive distribution, at least superficially, looks OK. The posterior predictive draws are replicates (simulated data) conditional on the observed data. They represent other experiments, in this case with the same number of participants and observations per participant. Major deviations between the replicates and the observed data (red) indicate that something in the model structure is not aligned with the data.

# Posterior predictive replicates: one row per draw, one column per obs
d_yrep <- data.table(f1$draws(variables = "y_rep", format = "matrix"))
d_yrep[, id_rep := seq_len(.N)]
d_yrep <- melt(d_yrep, id.vars = "id_rep", value.name = "y_rep")
# "y_rep[17]" -> 17, which is the row (and id) of the observation in d
d_yrep[, id := gsub("y_rep[", "", variable, fixed = TRUE)]
d_yrep[, id := as.integer(gsub("]", "", id, fixed = TRUE))]
# index the group column directly; d[id, group] would depend on
# data.table's scoping rules for `id`, which is also a column of d
d_yrep[, group := d$group[id]]

# densities of 50 random replicates (black) against the observed data (red)
ggplot(d_yrep[id_rep %in% sample(seq_len(max(d_yrep$id_rep)), size = 50)],
       aes(x = y_rep)) +
  geom_line(aes(group = id_rep), stat = "density", alpha = 0.1) +
  geom_density(data = d,
               aes(x = y), col = 2) +
  facet_wrap(~group)

The conditional estimates of the group means under both the ctl and trt states can be obtained for each of the \(J\) participants. The black lines show the population means used to simulate the data; we shouldn't expect to recover these exactly from any given dataset. The red points show the observed participant-level means; the relatively strong regularisation (shrinkage towards the group mean) is apparent for those participants that had the more extreme observations.

# Participant-level (conditional) means mu_pt[j, i] = a[i] + u_p[j, i],
# reshaped to one row per (draw, participant, group)
d_pt <- data.table(f1$draws(variables = c("mu_pt"),
                             format = "matrix"))
d_pt[, id_smpl := seq_len(.N)]
d_pt <- melt(d_pt, id.vars = "id_smpl")
# "mu_pt[12,2]" -> id_pt = 12, id_trt = 2
d_pt[, id_pt := gsub("mu_pt[", "", variable, fixed = TRUE)]
d_pt[, id_pt := as.integer(gsub(",.*", "", id_pt))]
d_pt[, id_trt := gsub(".*,", "", variable)]
d_pt[, id_trt := as.integer(gsub("]", "", id_trt, fixed = TRUE))]
# posterior mean and 95% interval per participant and group
dfig1 <- d_pt[, .(
  mu_mod = mean(value),
  q_025 = quantile(value, probs = 0.025),
  q_975 = quantile(value, probs = 0.975)
  ), keyby = .(id_pt, id_trt)]
# observed participant-level means
dfig2 <- d[, .(mu_obs = mean(y)), keyby = .(id_pt, id_trt)]
dfig <- merge(dfig1, dfig2, by = c("id_pt", "id_trt"))
# order participants by posterior mean within each group for plotting
setkey(dfig, id_trt, mu_mod)
dfig[, id_plot := seq_len(.N), keyby = id_trt]
dfig[, group := levels(d$group)[id_trt]]

ggplot(dfig, aes(x = id_plot, y = mu_mod)) +
  geom_linerange(aes(ymin = q_025, ymax = q_975),
                 linewidth = 0.1) +
  geom_point() +
  geom_point(aes(y = mu_obs), col = 2, size = 0.4) +
  # population means used to simulate the data, one per facet
  geom_hline(data = data.table(group = c("ctl", "trt"), mu = mu),
             aes(yintercept = mu),
             linewidth = 0.2) +
  scale_x_continuous("Ordered index") +
  scale_y_continuous("Posterior mean") +
  facet_wrap(~group)

Similarly, the conditional estimates of the treatment effect can be obtained for each of the \(J\) participants. The solid line shows the true treatment effect used to simulate the data and the dashed line shows the observed average treatment effect.

# Per-draw, per-participant treatment effect: mu_pt[j, 2] - mu_pt[j, 1]
d_del <- dcast(
  d_pt,
  id_smpl + id_pt ~ id_trt, value.var = "value"
  )
d_del[, rd := `2` - `1`]

# posterior mean and 95% interval of each participant's treatment effect
dfig1 <- d_del[, .(
  mu_mod = mean(rd),
  q_025 = quantile(rd, probs = 0.025),
  q_975 = quantile(rd, probs = 0.975)
  ), keyby = .(id_pt)]

# observed per-participant difference in mean response (trt - ctl)
dfig2 <- dcast(
  d[, .(mu = mean(y)), keyby = .(id_pt, id_trt)],
  id_pt ~ id_trt, value.var = "mu")[, mu_obs := `2` - `1`]

dfig <- merge(dfig1, dfig2[, .SD, .SDcols = !c("1", "2")], by = c("id_pt"))
# order participants by posterior mean for plotting
setkey(dfig, mu_mod)
dfig[, id_plot := seq_len(.N)]

# observed average treatment effect across participants
mu_trt_effect <- mean(dfig$mu_obs)

ggplot(dfig, aes(x = id_plot, y = mu_mod)) +
  geom_linerange(aes(ymin = q_025, ymax = q_975),
                 linewidth = 0.1) +
  geom_point() +
  geom_point(aes(y = mu_obs), col = 2, size = 0.4) +
  # true treatment effect (solid) and observed average (dashed)
  geom_hline(yintercept = mu[2] - mu[1],
             linewidth = 0.2) +
  geom_hline(yintercept = mu_trt_effect,
             linewidth = 0.2, lty = 2) +
  scale_x_continuous("Ordered index") +
  scale_y_continuous("Treatment effect")

In addition to the conditional treatment effects, the treatment effect associated with the ‘average’ within-sample participant might be of interest (LHS panel below). This is the usual focus of a clinical trial. It parallels what you would get from a frequentist approach using the lme4 package when you run confint on the fixed effect for the treatment group parameter. The 95% confidence interval from lme4 is shown in red, and the 95% credible interval from the posterior distribution of the treatment effect is bounded by the dashed vertical lines. The true value used to simulate the data is represented by the solid black vertical line.

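For an out-of-sample estimate of the treatment effect (i.e. for a new participant), we need to integrate over the random effects, and therefore much greater uncertainty is apparent (RHS panel below). In that panel the dotted line shows the observed average treatment effect and the solid line the true value.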

# Posterior draws of the group means a[1] (ctl) and a[2] (trt)
d_a <- data.table(f1$draws(variables = c("a"),
                           format = "matrix"))
d_a[, id_smpl := seq_len(.N)]
d_a <- melt(d_a, id.vars = "id_smpl")
# "a[2]" -> 2
d_a[, id_trt := gsub("a[", "", variable, fixed = TRUE)]
d_a[, id_trt := as.integer(gsub("]", "", id_trt, fixed = TRUE))]

# within-sample treatment effect per draw: a[2] - a[1]
dfig1 <- dcast(
  d_a, id_smpl ~ id_trt, value.var = "value")[, mu_mod := `2` - `1`]

# 95% credible interval for the within-sample treatment effect
cri_q025 <- quantile(dfig1$mu_mod, probs = 0.025)
cri_q975 <- quantile(dfig1$mu_mod, probs = 0.975)

# marginal draws for a new participant, mu_marg ~ multi_normal(a, S_p)
d_marg <- data.table(f1$draws(variables = c("mu_marg"),
                              format = "matrix"))
d_marg[, id_smpl := seq_len(.N)]
d_marg <- melt(d_marg, id.vars = "id_smpl")
d_marg[, id_trt := gsub("mu_marg[", "", variable, fixed = TRUE)]
d_marg[, id_trt := as.integer(gsub("]", "", id_trt, fixed = TRUE))]

# out-of-sample treatment effect per draw
dfig2 <- dcast(
  d_marg, id_smpl ~ id_trt, value.var = "value")[, mu_mod := `2` - `1`]

# LHS: solid = true effect, dashed = 95% CrI, red segment = lme4 Wald CI.
# annotate() draws the CI segment once; a geom_segment() layer on dfig1
# with constant aesthetics would draw one copy per posterior draw.
p1 <- ggplot(dfig1, aes(x = mu_mod)) +
  geom_density() +
  geom_vline(xintercept = mu[2] - mu[1],
             linewidth = 0.2, lty = 1) +
  annotate("segment",
           x = mu_freq_ci[1], xend = mu_freq_ci[2],
           y = 2, yend = 2, colour = 2) +
  geom_vline(xintercept = c(cri_q025, cri_q975),
             linewidth = 0.2, lty = 2)

# RHS: dotted = observed average effect, solid = true effect
p2 <- ggplot(dfig2, aes(x = mu_mod)) +
  geom_density() +
  geom_vline(xintercept = mu_trt_effect,
             linewidth = 0.2, lty = 3) +
  geom_vline(xintercept = mu[2] - mu[1],
             linewidth = 0.2, lty = 1)

pg1 <- ggplot2::ggplotGrob(p1)
pg2 <- ggplot2::ggplotGrob(p2)

# side-by-side layout; every grid function is namespace-qualified so this
# does not depend on grid being attached
grid::grid.newpage()
vp <- grid::viewport(x = 0.0, y = 0.0,
                     width = 0.5, height = 1,
                     just = c("left", "bottom"))
grid::pushViewport(vp)
grid::grid.draw(pg1)
grid::popViewport()

vp <- grid::viewport(x = 0.5, y = 0.0,
                     width = 0.5, height = 1,
                     just = c("left", "bottom"))
grid::pushViewport(vp)
grid::grid.draw(pg2)
grid::popViewport()