Análisis Jerárquico

En este trabajo práctico vamos a aprender a escribir y ajustar modelos jerárquicos en JAGS. Para entender de qué se trata esto de las estructuras jerárquicas vamos a empezar con datos simulados.

Objetivos:

  1. Simular datos con estructura jerárquica
  2. Realizar análisis de complete pooling y no-pooling
  3. Comparar los análisis anteriores con un modelo jerárquico (partial pooling) usando JAGS

Suponiendo que estamos estudiando \(10\) sub-poblaciones de Chinchillones en la estepa Patagónica, y queremos saber cuál es la probabilidad de que estos bichos sobrevivan al invierno. Vamos a simular datos provenientes de individuos de cada una de las sub-poblaciones. Además, vamos a considerar que todas las sub-poblaciones son parte de una misma “meta-población”. Esto hace que esperemos que los parámetros de supervivencia sean parecidos entre las sub-poblaciones, pero que sin embargo tengan cierto grado de diferencias. Para modelar la variabilidad entre sub-poblaciones podemos usar una distribución Beta, por ejemplo con parámetros \(a_s = 2\) y \(b_s = 10\). Si repasamos el Bestiario vemos que esta combinación de parámetros resulta en un valor esperado de \(\frac{a_s}{(a_s+b_s)}\) = 0.167 y una varianza de \(\frac{(a_s \times b_s)}{(a_s + b_s)^2 \times (a_s + b_s + 1)}\) = 0.011. Entonces, cada sub-población tendrá su propia tasa de supervivencia. Por otro lado, podemos imaginar que logramos seguir un número variable de individuos en cada población para ver si sobreviven o no. Para simular los datos hacemos:

set.seed(1234)
n <- 10  # sub-poblaciones
m <- c(30, 28, 20, 30, 30, 26, 29, 5, 3, 27)  # número de individuos muestraeados por sub-poblacion

a_s <- 2
b_s <- 10

theta <- rbeta(n, a_s, b_s)  # generamos n tasas de mortalidad, una por cada sub-poblacion    
y <- rbinom(n, size = m, prob = theta)  # simulamos número de muertes por grupo

op <- par(cex.lab = 1.5, font.lab = 2, cex.axis = 1.3, las = 1, bty = "n")

plot(table(y), xlab = "Muertes por sub-población", ylab = "Frecuencia")

par(op)

Estos datos nos muestran el número de muertes por sub-población. Hay cuatro sub-poblaciones que no tuvieron muertes, dos que tuvieron dos muertes, etc. Además, tenemos que recordar que cuando simulamos estos datos cada sub-población tenia su propia tasa de mortalidad.

Vamos a modelar estos datos suponiendo que:

  1. todas las sub-poblaciones tienen la misma mortalidad (complete pooling), es decir que ignoramos que cada sub-población tiene su propia tasa de mortalidad

  2. cada sub-población tiene una tasa de mortalidad independiente de las otras (no pooling), es decir que ignoramos que todas las sub-poblaciones son parte de una misma “meta-población”.

  3. cada sub-población es diferente pero parecida a las otras (partial pooling)

Vamos por partes:

(1) complete pooling

En este caso el modelo de datos (cuántos individuos mueren) es una Binomial:

cat(file = "completepooling.bug", "
  model{
    for( i in 1 : n ) {
      y[i] ~ dbin(theta, m[i]) #Likelihood\t
\t\t}
\t theta ~ dbeta(1, 1) # previa no-informativa para la tasa de mortalidad
  }
")

Como de costumbre, definimos los datos que le vamos a pasar a JAGS, una función para generar valores iniciales para las cadenas Markovianas, los parámetros que queremos guardar, cuántas iteraciones vamos a correr, etc.

data <- list("y", "m", "n")
inits <- function() list(theta = runif(1, 0, 1))
params <- c("theta")

ni <- 1000
nc <- 3
nt <- 1
nb <- 500

Ahora llamamos a JAGS

library(jagsUI)
cp.sim <- jags(data, inits, params, model.file = "completepooling.bug", n.chains = nc, 
    n.iter = ni, n.burnin = nb, n.thin = nt)
## 
## Processing function input....... 
## 
## Done. 
##  
## Compiling model graph
##    Resolving undeclared variables
##    Allocating nodes
## Graph information:
##    Observed stochastic nodes: 10
##    Unobserved stochastic nodes: 1
##    Total graph size: 24
## 
## Initializing model
## 
## Adaptive phase..... 
## Adaptive phase complete 
##  
## 
##  Burn-in phase, 500 iterations x 3 chains 
##  
## 
## Sampling from joint posterior, 500 iterations x 3 chains 
##  
## 
## Calculating statistics....... 
## 
## Done.
print(cp.sim)
## JAGS output for model 'completepooling.bug', generated by jagsUI.
## Estimates based on 3 chains of 1000 iterations,
## adaptation = 100 iterations (sufficient),
## burn-in = 500 iterations and thin rate = 1,
## yielding 1500 total samples from the joint posterior. 
## MCMC ran for 0.004 minutes at time 2019-03-01 14:17:46.
## 
##            mean    sd   2.5%    50%  97.5% overlap0 f  Rhat n.eff
## theta     0.135 0.023  0.093  0.134  0.183    FALSE 1 1.001  1133
## deviance 47.443 1.490 46.403 46.888 51.386    FALSE 1 1.009  1500
## 
## Successful convergence based on Rhat values (all < 1.1). 
## Rhat is the potential scale reduction factor (at convergence, Rhat=1). 
## For each parameter, n.eff is a crude measure of effective sample size. 
## 
## overlap0 checks if 0 falls in the parameter's 95% credible interval.
## f is the proportion of the posterior with the same sign as the mean;
## i.e., our confidence that the parameter is positive or negative.
## 
## DIC info: (pD = var(deviance)/2) 
## pD = 1.1 and DIC = 48.553 
## DIC is an estimate of expected predictive error (lower is better).

Ejercicios:

1- ¿Cuál es el valor esperado de la posterior de la tasa de mortalidad bajo el supuesto de “complete pooling”?

2- Comparar el valor esperado de la tasa de mortalidad con el valor promedio de las tasas de mortalidad usadas para simular los datos.

(2) No-pooling

De nuevo usando previas no informativas:

cat(file = "nopooling.bug", "
  model   {
    for( i in 1 : n ) {
        y[i] ~ dbin(theta[i], m[i]) # Likelihood
        theta[i] ~ dbeta(1, 1) # previas no-informativas para la tasa de mortalidad
\t\t}
}
")
library(jagsUI)
data <- list("y", "m", "n")
inits <- function() list(theta = runif(n, 0, 1))
params <- c("theta")

ni <- 1000
nc <- 3
nt <- 1
nb <- 500
np.sim <- jags(data, inits, params, model.file = "nopooling.bug", n.chains = nc, 
    n.iter = ni, n.burnin = nb, n.thin = nt)
## 
## Processing function input....... 
## 
## Done. 
##  
## Compiling model graph
##    Resolving undeclared variables
##    Allocating nodes
## Graph information:
##    Observed stochastic nodes: 10
##    Unobserved stochastic nodes: 10
##    Total graph size: 51
## 
## Initializing model
## 
## Adaptive phase..... 
## Adaptive phase complete 
##  
## 
##  Burn-in phase, 500 iterations x 3 chains 
##  
## 
## Sampling from joint posterior, 500 iterations x 3 chains 
##  
## 
## Calculating statistics....... 
## 
## Done.
print(np.sim)
## JAGS output for model 'nopooling.bug', generated by jagsUI.
## Estimates based on 3 chains of 1000 iterations,
## adaptation = 100 iterations (sufficient),
## burn-in = 500 iterations and thin rate = 1,
## yielding 1500 total samples from the joint posterior. 
## MCMC ran for 0.002 minutes at time 2019-03-01 14:17:47.
## 
##             mean    sd   2.5%    50%  97.5% overlap0 f  Rhat n.eff
## theta[1]   0.031 0.030  0.001  0.022  0.110    FALSE 1 1.001  1500
## theta[2]   0.100 0.053  0.021  0.092  0.225    FALSE 1 1.000  1500
## theta[3]   0.043 0.042  0.001  0.030  0.158    FALSE 1 1.006   753
## theta[4]   0.313 0.080  0.164  0.311  0.479    FALSE 1 1.000  1500
## theta[5]   0.252 0.076  0.123  0.246  0.412    FALSE 1 0.999  1500
## theta[6]   0.214 0.077  0.088  0.206  0.385    FALSE 1 1.005   658
## theta[7]   0.194 0.070  0.073  0.187  0.344    FALSE 1 1.001  1095
## theta[8]   0.137 0.117  0.004  0.106  0.414    FALSE 1 1.003   934
## theta[9]   0.195 0.157  0.006  0.152  0.560    FALSE 1 0.999  1500
## theta[10]  0.103 0.056  0.023  0.094  0.232    FALSE 1 1.002  1216
## deviance  31.534 4.757 24.037 30.865 42.234    FALSE 1 1.000  1500
## 
## Successful convergence based on Rhat values (all < 1.1). 
## Rhat is the potential scale reduction factor (at convergence, Rhat=1). 
## For each parameter, n.eff is a crude measure of effective sample size. 
## 
## overlap0 checks if 0 falls in the parameter's 95% credible interval.
## f is the proportion of the posterior with the same sign as the mean;
## i.e., our confidence that the parameter is positive or negative.
## 
## DIC info: (pD = var(deviance)/2) 
## pD = 11.3 and DIC = 42.855 
## DIC is an estimate of expected predictive error (lower is better).

Veamos un gráfico para comprar los dos análisis (complete-pooling y no-pooling)

op <- par(cex.lab = 1.5, font.lab = 2, cex.axis = 1.3, las = 1, bty = "n")

plot(density(cp.sim$sims.list$theta), type = "l", lwd = 3, ylab = "Density", 
    xlab = "Tasa de mortalidad", xlim = c(0, 0.5), ylim = c(0, 30), main = "")  #Posterior de theta con complete pooling

for (i in 1:10) {
    lines(density(np.sim$sims.list$theta[, i]), type = "l", col = "gray", lwd = 3)  #Posteriores de theta con no pooling
}

par(op)

Ejercicios

1- ¿Qué se puede decir de las distintas posteriores?

2- ¿Alguno de estos dos enfoques les parece más adecuado? ¿Cuáles serían los pro y contras de cada uno?

(3) Partial pooling

Finalmente, podemos hacer un análisis jerárquico o “multi-nivel” (partial pooling), reconociendo explícitamente que cada sub-población es potencialmente distinta de las otras, pero que todas son parte de la misma meta-población. Esto último implica que las observaciones en las sub-poblaciones no son del todo independientes.

cat(file = "hier.bug", "
    model{
      #Likelihood
      for(i in 1: n ) {
\t\t    y[i] ~ dbin(theta[i], m[i])\t
\t\t    theta[i] ~ dbeta(a, b) 
      }
    #Previas
\t\ta ~ dnorm(0,0.01)T(0,)
\t\tb ~ dnorm(0,0.01)T(0,)
    mean_pobl <- a/(a + b) 
}
")

Noten que en la última línea del modelo BUGS, registramos el valor esperado de la tasa de mortalidad para la meta-población. Este es un ejemplo de cómo podemos calcular valores de interés (en la jerga estadística serían “derived quantities”) dentro del modelo BUGS. En este caso, usamos esta fórmula que corresponde a la media de la distribución Beta (ver Bestiario de distribuciones). Entonces, vamos a obtener una posterior para esta “derived quantity”.

library(jagsUI)
data <- list("y", "m", "n")
inits <- function() list(a = runif(1, 1, 5), b = runif(1, 5, 20))
params <- c("a", "b", "theta", "mean_pobl")
ni <- 5000
nc <- 3
nt <- 4
nb <- 2500

hier.sim <- jags(data, inits, params, model.file = "hier.bug", n.chains = nc, 
    n.iter = ni, n.burnin = nb, n.thin = nt)
## 
## Processing function input....... 
## 
## Done. 
##  
## Compiling model graph
##    Resolving undeclared variables
##    Allocating nodes
## Graph information:
##    Observed stochastic nodes: 10
##    Unobserved stochastic nodes: 12
##    Total graph size: 41
## 
## Initializing model
## 
## Adaptive phase..... 
## Adaptive phase complete 
##  
## 
##  Burn-in phase, 2500 iterations x 3 chains 
##  
## 
## Sampling from joint posterior, 2500 iterations x 3 chains 
##  
## 
## Calculating statistics....... 
## 
## Done.
print(hier.sim)
## JAGS output for model 'hier.bug', generated by jagsUI.
## Estimates based on 3 chains of 5000 iterations,
## adaptation = 100 iterations (sufficient),
## burn-in = 2500 iterations and thin rate = 4,
## yielding 1875 total samples from the joint posterior. 
## MCMC ran for 0.028 minutes at time 2019-03-01 14:17:49.
## 
##             mean    sd   2.5%    50%  97.5% overlap0 f  Rhat n.eff
## a          1.676 0.977  0.395  1.466  4.161    FALSE 1 1.014   313
## b         11.340 5.769  2.794 10.449 24.406    FALSE 1 1.010   315
## theta[1]   0.037 0.032  0.000  0.030  0.117    FALSE 1 1.005   568
## theta[2]   0.090 0.045  0.020  0.083  0.190    FALSE 1 1.001  1555
## theta[3]   0.048 0.042  0.001  0.038  0.155    FALSE 1 1.004   681
## theta[4]   0.254 0.069  0.138  0.249  0.405    FALSE 1 1.000  1875
## theta[5]   0.204 0.062  0.100  0.199  0.338    FALSE 1 1.002   788
## theta[6]   0.172 0.062  0.072  0.166  0.313    FALSE 1 1.001  1420
## theta[7]   0.160 0.058  0.063  0.154  0.289    FALSE 1 1.003   568
## theta[8]   0.087 0.072  0.001  0.071  0.275    FALSE 1 1.001  1875
## theta[9]   0.098 0.077  0.001  0.083  0.286    FALSE 1 1.000  1851
## theta[10]  0.091 0.048  0.020  0.084  0.203    FALSE 1 1.002   880
## mean_pobl  0.130 0.038  0.066  0.128  0.213    FALSE 1 1.001  1626
## deviance  30.207 4.513 22.803 29.582 40.411    FALSE 1 1.008   303
## 
## Successful convergence based on Rhat values (all < 1.1). 
## Rhat is the potential scale reduction factor (at convergence, Rhat=1). 
## For each parameter, n.eff is a crude measure of effective sample size. 
## 
## overlap0 checks if 0 falls in the parameter's 95% credible interval.
## f is the proportion of the posterior with the same sign as the mean;
## i.e., our confidence that the parameter is positive or negative.
## 
## DIC info: (pD = var(deviance)/2) 
## pD = 10.1 and DIC = 40.336 
## DIC is an estimate of expected predictive error (lower is better).

Ahora podemos ver la posterior de la tasa de mortalidad a nivel de meta-población y compararla con el valor verdadero que usamos en las simulaciones (en rojo).

op <- par(cex.lab = 1.5, font.lab = 2, cex.axis = 1.3, las = 1, bty = "n")
plot(density(hier.sim$sims.list$mean_pobl), main = "", xlab = "Tasa de  mortalidad", 
    lwd = 3)
abline(v = a_s/(a_s + b_s), col = 2, lwd = 3, lty = 2)

par(op)

Finalmente podemos comparar gráficamente los tres enfoques:

op <- par(mfrow = c(1, 2), cex.lab = 1.5, font.lab = 2, cex.axis = 1.3, las = 1, 
    bty = "n")

plot(density(cp.sim$sims.list$theta), type = "l", lwd = 3, ylab = "Density", 
    xlab = "Tasa de mortalidad", xlim = c(0, 1), ylim = c(0, 25), main = "")  #Complete pooling

for (i in 1:10) {
    lines(density(np.sim$sims.list$theta[, i]), type = "l", col = "gray", lwd = 2)
}  #No pooling

curve(dbeta(x, hier.sim$mean$a, hier.sim$mean$b), lwd = 3, col = 2, xlab = "Tasa de mortalidad", 
    ylab = "", ylim = c(0, 25))  #Tasa de mortalidad meta-poblacional estimada con partial pooling

for (i in 1:10) {
    lines(density(hier.sim$sims.list$theta[, i]), col = "blue", lwd = 2, main = "")  #Tasa de mortalidad de cada subpoblación estimadas con partial pooling
}

par(op)

En el gráfico la linea roja representa la posterior para la tasa de mortalidad de la meta-población y cada linea azul representa la posterior para la tasa de mortalidad de cada sub-población.

¿Cómo comparan las posteriores para las distintas sub-poblaciones con no-pooling y partial pooling?


Opcional: Regresiones con Stan y brms

Veamos cómo hacer el análisis jerárquico con Stan

cat(file = "hier.stan", "
data { 
  int<lower=1> n;
  int<lower=0> y[n];
  int<lower=0> m[n];
}

parameters{
  real<lower=0> a;
  real<lower=0> b;
  real<lower=0,upper=1> theta[n];
}

model{
  
  a ~ normal(0,10);
  b ~ normal(0,10);
  theta ~ beta(a,b);
  for(i in 1: n ) {
    y[i] ~ binomial(m[i], theta[i]);\t
\t}
}

generated quantities {
  real mean_pobl;
  mean_pobl = a/(a + b);
}
")

Ahora llamamos a Stan

library(rstan)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())

hier_dat <- list(n = n, m = m, y = y)

inits <- function() list(a = runif(1, 1, 5), b = runif(1, 5, 20))

fit <- stan(file = "hier.stan", data = hier_dat, init = inits, iter = 1000, 
    thin = 1, chains = 3)

Podemos comparar con el ajuste de JAGS y con el valore verdadero a nivel poblacional

pos <- as.data.frame(fit)

op <- par(cex.lab = 1.5, font.lab = 2, cex.axis = 1.3, las = 1, bty = "n")
plot(density(hier.sim$sims.list$mean_pobl), main = "", xlab = "Tasa de  mortalidad", 
    lwd = 3)
lines(density(pos$mean_pobl), lty = 2, lwd = 3)
abline(v = a_s/(a_s + b_s), col = 2, lwd = 3, lty = 2)

par(op)

Ahora veamos como hacer el análisis jerárquico con brms. La gran diferencia es que con brms formulamos en modelo es escala logit

library(brms)

priors = c(prior(normal(0, 1), class = Intercept), prior(cauchy(0, 1), class = sd))

fit_brms = brm(y | trials(m) ~ 1 + (1 | site), data = data.frame(y = y, m = m, 
    site = 1:n), family = binomial(), prior = priors)

Para ver las estimaciones para cada sub-población podemos hacer

plogis(coef(fit_brms, robust = TRUE)$site[, , ])
##      Estimate Est.Error        Q2.5     Q97.5
## 1  0.04435602 0.7048795 0.002918079 0.1425000
## 2  0.08897024 0.6395497 0.023207112 0.1964227
## 3  0.05440197 0.7121480 0.003329889 0.1667172
## 4  0.25140673 0.6027072 0.126697361 0.4188041
## 5  0.19888246 0.6040404 0.100137723 0.3507151
## 6  0.16407096 0.6147617 0.067465115 0.3218250
## 7  0.15406484 0.6088120 0.065749125 0.2966351
## 8  0.08784399 0.7031740 0.006358927 0.2943714
## 9  0.09791512 0.7095773 0.005933390 0.3689950
## 10 0.09060340 0.6352696 0.024233502 0.2002980

Podemos graficar la estimación a nivel poblacional y compararla con las estimaciones de JAGS y Stan

pos <- as.data.frame(fit)
pos_brms <- posterior_samples(fit_brms)

op <- par(cex.lab = 1.5, font.lab = 2, cex.axis = 1.3, las = 1, bty = "n")
plot(density(hier.sim$sims.list$mean_pobl), main = "", xlab = "Tasa de  mortalidad", 
    lwd = 3)
lines(density(pos$mean_pobl), lty = 2, lwd = 3)
lines(density(plogis(pos_brms$b_Intercept)), lty = 3, lwd = 3)
abline(v = a_s/(a_s + b_s), col = 2, lwd = 3, lty = 2)
Estimaciones de la posterior a nivel poblacional de la probabilidad de mortalidad. La línea sólida corresponde a JAGS, la barreada a Stan y la punteada a brms. En rojo se muestra el valor verdadero del parámetro.

Figure 1: Estimaciones de la posterior a nivel poblacional de la probabilidad de mortalidad. La línea sólida corresponde a JAGS, la barreada a Stan y la punteada a brms. En rojo se muestra el valor verdadero del parámetro.

par(op)

Juan Manuel Morales. 6 de Septiembre del 2015. última actualización: 2019-03-01