Análisis Jerárquico

En este práctico vamos a escribir y ajustar modelos jerárquicos en JAGS. Para entender de qué se trata esto de las estructuras jerárquicas vamos a empezar con datos simulados.

Objetivos:

  1. Simular datos con estructura jerárquica
  2. Realizar análisis de complete pooling y no-pooling
  3. Comparar los análisis anteriores con un modelo jerárquico (partial pooling) usando JAGS

Suponiendo que estamos estudiando \(10\) sub-poblaciones de Chinchillones en la estepa Patagónica, y queremos saber cuál es la probabilidad de que estos bichos sobrevivan al invierno. Vamos a simular datos provenientes de individuos de cada una de las sub-poblaciones. Además, vamos a considerar que todas las sub-poblaciones son parte de una misma “meta-población”. Esto hace que esperemos que los parámetros de supervivencia sean parecidos entre las sub-poblaciones, pero que sin embargo tengan cierto grado de diferencias. Para modelar la variabilidad entre sub-poblaciones podemos usar una distribución Beta, por ejemplo con parámetros \(a_s = 2\) y \(b_s = 10\). Si repasamos el Bestiario de distribuciones vemos que esta combinación de parámetros resulta en un valor esperado de \(\frac{a_s}{(a_s+b_s)}\) = 0.167 y una varianza de \(\frac{(a_s \times b_s)}{(a_s + b_s)^2 \times (a_s + b_s + 1)}\) = 0.011. Entonces, cada sub-población tendrá su propia tasa de supervivencia. Por otro lado, podemos imaginar que logramos seguir un número variable de individuos en cada población para ver si sobreviven o no. Para simular los datos hacemos:

set.seed(1234)
n <- 10 # sub-poblaciones
m <- c(30, 28, 20, 30, 30, 26, 29, 5, 3, 27) # número de individuos muestraeados por sub-poblacion

a_s <- 2
b_s <- 10
 
theta <- rbeta(n, a_s, b_s) # generamos n tasas de mortalidad, una por cada sub-poblacion    
y <- rbinom(n, size = m, prob = theta)  # simulamos número de muertes por grupo

op <- par(cex.lab = 1.5 , font.lab = 2, cex.axis = 1.3, las = 1, bty = "n")

plot(table(y), xlab = "Muertes por sub-población", ylab = "Frecuencia")

par(op)

Estos datos nos muestran el número de muertes por sub-población. Hay cuatro sub-poblaciones que no tuvieron muertes, dos que tuvieron dos muertes, etc. Además, tenemos que recordar que cuando simulamos estos datos cada sub-población tenia su propia tasa de mortalidad.

Vamos a modelar estos datos suponiendo que:

  1. todas las sub-poblaciones tienen la misma mortalidad (complete pooling), es decir que ignoramos variabilidad entre sub-poblaciones

  2. cada sub-población tiene una tasa de mortalidad independiente de las otras (no pooling), es decir que ignoramos que todas las sub-poblaciones son parte de una misma “meta-población”, la misma región, la misma especie…

  3. cada sub-población es diferente pero parecida a las otras (partial pooling)

Vamos por partes:

(1) complete pooling

En este caso el modelo de datos (cuántos individuos mueren) es una Binomial:

cat(file = "completepooling.bug",
"
  model{
    for( i in 1 : n ) {
      y[i] ~ dbin(theta, m[i]) #Likelihood  
    }
     theta ~ dbeta(1, 1) # previa no-informativa para la tasa de mortalidad
  }
")

Como de costumbre, definimos los datos que le vamos a pasar a JAGS, una función para generar valores iniciales para las cadenas Markovianas, los parámetros que queremos guardar, cuántas iteraciones vamos a correr, etc.

data   <- list(y = y, 
               m = m,
               n = n)
inits  <- function() list(theta = runif(1, 0, 1))
params <- c("theta", "pred")

ni <- 1000
nc <- 3
nt <- 1
nb <- 500

Ahora llamamos a JAGS

library(jagsUI)
cp.sim <- jags(data, 
               inits, 
               params, 
               model.file = "completepooling.bug",
               n.chains = nc, 
               n.iter = ni, 
               n.burnin = nb, 
               n.thin = nt)
## 
## Processing function input....... 
## 
## Done. 
##  
## Compiling model graph
##    Resolving undeclared variables
##    Allocating nodes
## Graph information:
##    Observed stochastic nodes: 10
##    Unobserved stochastic nodes: 1
##    Total graph size: 24
## 
## Initializing model
## 
## Adaptive phase..... 
## Adaptive phase complete 
##  
## 
##  Burn-in phase, 500 iterations x 3 chains 
##  
## 
## Sampling from joint posterior, 500 iterations x 3 chains 
##  
## 
## Calculating statistics....... 
## 
## Done.
print(cp.sim)
## JAGS output for model 'completepooling.bug', generated by jagsUI.
## Estimates based on 3 chains of 1000 iterations,
## adaptation = 100 iterations (sufficient),
## burn-in = 500 iterations and thin rate = 1,
## yielding 1500 total samples from the joint posterior. 
## MCMC ran for 0 minutes at time 2021-12-07 19:40:14.
## 
##            mean    sd   2.5%    50%  97.5% overlap0 f  Rhat n.eff
## theta     0.135 0.023  0.095  0.134  0.183    FALSE 1 1.003  1323
## deviance 47.417 1.425 46.403 46.853 51.326    FALSE 1 1.013   491
## 
## Successful convergence based on Rhat values (all < 1.1). 
## Rhat is the potential scale reduction factor (at convergence, Rhat=1). 
## For each parameter, n.eff is a crude measure of effective sample size. 
## 
## overlap0 checks if 0 falls in the parameter's 95% credible interval.
## f is the proportion of the posterior with the same sign as the mean;
## i.e., our confidence that the parameter is positive or negative.
## 
## DIC info: (pD = var(deviance)/2) 
## pD = 1 and DIC = 48.43 
## DIC is an estimate of expected predictive error (lower is better).

Preguntas:

1- ¿Cuál es el valor esperado de la posterior de la tasa de mortalidad bajo el supuesto de “complete pooling”?

2- Comparar el valor esperado de la tasa de mortalidad con el valor promedio de las tasas de mortalidad usadas para simular los datos.

(2) No-pooling

De nuevo usando previas no informativas:

cat(file="nopooling.bug",
"
  model   {
    for( i in 1 : n ) {
        y[i] ~ dbin(theta[i], m[i]) # Likelihood
        theta[i] ~ dbeta(1, 1) # previas no-informativas para la tasa de mortalidad
        }
}
")
library(jagsUI)
data   <- list(y = y, 
               m = m,
               n = n)
inits  <- function() list(theta = runif(n, 0, 1))
params <- c("theta")

ni <- 1000
nc <- 3
nt <- 1
nb <- 500
np.sim <- jags(data, 
               inits,
               params, 
               model.file = "nopooling.bug",
               n.chains = nc, 
               n.iter = ni, 
               n.burnin = nb, 
               n.thin = nt)
## 
## Processing function input....... 
## 
## Done. 
##  
## Compiling model graph
##    Resolving undeclared variables
##    Allocating nodes
## Graph information:
##    Observed stochastic nodes: 10
##    Unobserved stochastic nodes: 10
##    Total graph size: 51
## 
## Initializing model
## 
## Adaptive phase..... 
## Adaptive phase complete 
##  
## 
##  Burn-in phase, 500 iterations x 3 chains 
##  
## 
## Sampling from joint posterior, 500 iterations x 3 chains 
##  
## 
## Calculating statistics....... 
## 
## Done.
print(np.sim)
## JAGS output for model 'nopooling.bug', generated by jagsUI.
## Estimates based on 3 chains of 1000 iterations,
## adaptation = 100 iterations (sufficient),
## burn-in = 500 iterations and thin rate = 1,
## yielding 1500 total samples from the joint posterior. 
## MCMC ran for 0 minutes at time 2021-12-07 19:40:14.
## 
##             mean    sd   2.5%    50%  97.5% overlap0 f  Rhat n.eff
## theta[1]   0.032 0.032  0.001  0.022  0.115    FALSE 1 1.006   813
## theta[2]   0.099 0.054  0.021  0.089  0.228    FALSE 1 1.000  1500
## theta[3]   0.046 0.044  0.001  0.032  0.163    FALSE 1 1.002  1061
## theta[4]   0.314 0.080  0.169  0.312  0.479    FALSE 1 1.000  1500
## theta[5]   0.248 0.076  0.119  0.245  0.411    FALSE 1 1.004   473
## theta[6]   0.213 0.074  0.088  0.206  0.382    FALSE 1 0.999  1500
## theta[7]   0.195 0.070  0.079  0.188  0.355    FALSE 1 1.001  1500
## theta[8]   0.148 0.128  0.005  0.110  0.468    FALSE 1 1.000  1500
## theta[9]   0.205 0.166  0.008  0.166  0.621    FALSE 1 1.001  1500
## theta[10]  0.103 0.055  0.022  0.094  0.228    FALSE 1 1.001  1500
## deviance  31.919 5.126 24.251 31.250 43.971    FALSE 1 0.999  1500
## 
## Successful convergence based on Rhat values (all < 1.1). 
## Rhat is the potential scale reduction factor (at convergence, Rhat=1). 
## For each parameter, n.eff is a crude measure of effective sample size. 
## 
## overlap0 checks if 0 falls in the parameter's 95% credible interval.
## f is the proportion of the posterior with the same sign as the mean;
## i.e., our confidence that the parameter is positive or negative.
## 
## DIC info: (pD = var(deviance)/2) 
## pD = 13.2 and DIC = 45.077 
## DIC is an estimate of expected predictive error (lower is better).

Veamos un gráfico para comprar los dos análisis (complete-pooling y no-pooling)

op <- par(cex.lab = 1.5 , font.lab = 2, cex.axis = 1.3, las = 1, bty = "n")

plot(density(cp.sim$sims.list$theta), 
     type = "l", 
     lwd = 3, 
     ylab = "Density", 
     xlab = "Tasa de mortalidad", 
     xlim = c(0, 0.5), 
     ylim = c(0, 30), main = "")

for (i in 1: 10){
 lines(density(np.sim$sims.list$theta[,i ]), 
       col = "gray", 
       lwd = 3) 
}

par(op)

Preguntas

1- ¿Qué se puede decir de las distintas posteriores?

2- ¿Alguno de estos dos enfoques les parece más adecuado? ¿Cuáles serían los pro y contras de cada uno?

(3) Partial pooling

Finalmente, podemos hacer un análisis jerárquico o “multi-nivel” (partial pooling), reconociendo explícitamente que cada sub-población es potencialmente distinta de las otras, pero que todas son parte de la misma meta-población, la misma región, la misma especie, etc. Esto último implica que las observaciones en las sub-poblaciones no son del todo independientes.

cat(file = "hier.bug",
    "
    model{
      #Likelihood
      for(i in 1: n ) {
            y[i] ~ dbin(theta[i], m[i]) 
            theta[i] ~ dbeta(a, b) 
      }
    #Previas
        a ~ dnorm(0,0.01)T(0,)
        b ~ dnorm(0,0.01)T(0,)
    mean_pobl <- a/(a + b) 
}
")

Noten que en la última línea del modelo BUGS, registramos el valor esperado de la tasa de mortalidad para la meta-población. Este es un ejemplo de cómo podemos calcular valores de interés (en la jerga estadística serían “derived quantities”) dentro del modelo BUGS. En este caso, usamos esta fórmula que corresponde a la media de la distribución Beta (ver Bestiario de distribuciones). Entonces, vamos a obtener una posterior para esta “derived quantity”.

library(jagsUI)
data   <- list(y = y, 
               m = m,
               n = n)
inits  <- function() list(a = runif(1, 1, 5), b = runif(1, 5, 20))
params <- c("a", "b", "theta", "mean_pobl")
ni <- 5000
nc <- 3
nt <- 4
nb <- 2500

hier.sim <- jags(data, 
                 inits, 
                 params, 
                 model.file = "hier.bug", 
                 n.chains = nc, 
                 n.iter = ni, 
                 n.burnin = nb, 
                 n.thin = nt)
## 
## Processing function input....... 
## 
## Done. 
##  
## Compiling model graph
##    Resolving undeclared variables
##    Allocating nodes
## Graph information:
##    Observed stochastic nodes: 10
##    Unobserved stochastic nodes: 12
##    Total graph size: 41
## 
## Initializing model
## 
## Adaptive phase..... 
## Adaptive phase complete 
##  
## 
##  Burn-in phase, 2500 iterations x 3 chains 
##  
## 
## Sampling from joint posterior, 2500 iterations x 3 chains 
##  
## 
## Calculating statistics....... 
## 
## Done.
print(hier.sim)
## JAGS output for model 'hier.bug', generated by jagsUI.
## Estimates based on 3 chains of 5000 iterations,
## adaptation = 100 iterations (sufficient),
## burn-in = 2500 iterations and thin rate = 4,
## yielding 1875 total samples from the joint posterior. 
## MCMC ran for 0.024 minutes at time 2021-12-07 19:40:15.
## 
##             mean    sd   2.5%    50%  97.5% overlap0 f  Rhat n.eff
## a          1.731 0.945  0.451  1.537  4.006    FALSE 1 0.999  1875
## b         11.536 5.505  3.192 10.625 24.401    FALSE 1 1.000  1875
## theta[1]   0.038 0.033  0.001  0.030  0.122    FALSE 1 1.001  1875
## theta[2]   0.090 0.046  0.021  0.082  0.199    FALSE 1 1.000  1875
## theta[3]   0.050 0.042  0.001  0.040  0.151    FALSE 1 1.000  1875
## theta[4]   0.251 0.068  0.131  0.247  0.395    FALSE 1 1.000  1875
## theta[5]   0.202 0.062  0.097  0.195  0.338    FALSE 1 1.000  1875
## theta[6]   0.171 0.059  0.073  0.166  0.297    FALSE 1 1.000  1875
## theta[7]   0.161 0.059  0.060  0.157  0.287    FALSE 1 1.001  1733
## theta[8]   0.092 0.074  0.002  0.075  0.273    FALSE 1 1.000  1875
## theta[9]   0.105 0.083  0.003  0.088  0.317    FALSE 1 1.000  1875
## theta[10]  0.093 0.048  0.023  0.087  0.201    FALSE 1 1.003   823
## mean_pobl  0.132 0.039  0.067  0.129  0.214    FALSE 1 1.001  1875
## deviance  30.524 4.676 23.194 30.012 41.082    FALSE 1 1.001  1875
## 
## Successful convergence based on Rhat values (all < 1.1). 
## Rhat is the potential scale reduction factor (at convergence, Rhat=1). 
## For each parameter, n.eff is a crude measure of effective sample size. 
## 
## overlap0 checks if 0 falls in the parameter's 95% credible interval.
## f is the proportion of the posterior with the same sign as the mean;
## i.e., our confidence that the parameter is positive or negative.
## 
## DIC info: (pD = var(deviance)/2) 
## pD = 10.9 and DIC = 41.465 
## DIC is an estimate of expected predictive error (lower is better).

Ahora podemos ver la posterior de la tasa de mortalidad a nivel de meta-población y compararla con el valor verdadero que usamos en las simulaciones (en rojo).

op <- par(cex.lab = 1.5 , font.lab = 2, cex.axis = 1.3, las = 1, bty = "n")
plot(density(hier.sim$sims.list$mean_pobl), 
     main = "", 
     xlab = "Tasa de  mortalidad", 
     lwd = 3)
abline(v = a_s/(a_s + b_s), 
       col = 2, 
       lwd = 3, 
       lty = 2)

par(op)

Finalmente podemos comparar gráficamente los tres enfoques:

op <- par(mfrow = c(1, 2), cex.lab = 1.5 , font.lab = 2, cex.axis = 1.3, las = 1, bty = "n")

plot(density(cp.sim$sims.list$theta), 
     type = "l", 
     lwd = 3, 
     ylab = "Density", 
     xlab = "Tasa de mortalidad", 
     xlim = c(0, 1), 
     ylim = c(0, 25), 
     main = "") #Complete pooling

for (i in 1: 10){
 lines(density(np.sim$sims.list$theta[, i]), 
       col = "gray", 
       lwd = 2)
} #No pooling

curve(dbeta(x, hier.sim$mean$a, hier.sim$mean$b), 
      lwd = 3, 
      col = 2, 
      xlab = "Tasa de mortalidad", 
      ylab = "", 
      ylim = c(0, 25)) #Tasa de mortalidad meta-poblacional estimada con partial pooling

for (i in 1: 10){
  lines(density(hier.sim$sims.list$theta[, i]), 
        col = "blue", 
        lwd = 2, 
        main = "") #Tasa de mortalidad de cada subpoblación estimadas con partial pooling
 } 

par(op)  

En el gráfico la linea roja representa la posterior para la tasa de mortalidad de la meta-población y cada linea azul representa la posterior para la tasa de mortalidad de cada sub-población.

¿Cómo comparan las posteriores para las distintas sub-poblaciones con no-pooling y partial pooling?


predicción para una nueva subpoblación

predh <- numeric(1500)
for(i in 1:1500){
  theta <- rbeta(1, hier.sim$sims.list$a[i], hier.sim$sims.list$b[i])
  predh[i] <- rbinom(1, size = 30, prob = theta)
}

hist(predh,100)


Opcional: Análisis con Stan y brms

Veamos cómo hacer el análisis jerárquico con Stan

cat(file = "hier.stan",
    "
data { 
  int<lower=1> n;
  int<lower=0> y[n];
  int<lower=0> m[n];
}

parameters{
  real<lower=0> a;
  real<lower=0> b;
  real<lower=0,upper=1> theta[n];
}

model{
  
  a ~ normal(0,10);
  b ~ normal(0,10);
  theta ~ beta(a,b);
  for(i in 1: n ) {
    y[i] ~ binomial(m[i], theta[i]);    
    }
}

generated quantities {
  real mean_pobl;
  mean_pobl = a/(a + b);
}
")

Ahora llamamos a Stan

library(rstan)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())

hier_dat <- list(n = n,
               m = m,
               y = y)

inits  <- function() list(a = runif(1, 1, 5), b = runif(1, 5, 20))

fit <- stan(file = 'hier.stan', 
            data = hier_dat, 
            init = inits,
            iter = 1000, 
            thin = 1, 
            chains = 3)

Podemos comparar con el ajuste de JAGS y con el valore verdadero a nivel poblacional

pos <- as.data.frame(fit)

op <- par(cex.lab = 1.5 , font.lab = 2, cex.axis = 1.3, las = 1, bty = "n")
plot(density(hier.sim$sims.list$mean_pobl), main = "", xlab = "Tasa de  mortalidad", lwd = 3)
lines(density(pos$mean_pobl), lty = 2, lwd = 3)
abline(v = a_s/(a_s + b_s), col = 2, lwd = 3, lty = 2)

par(op)

Ahora veamos como hacer el análisis jerárquico con brms. La gran diferencia es que con brms formulamos en modelo es escala logit

library(brms)

priors = c(prior(normal(0, 1), 
                 class = Intercept),
                prior(cauchy(0, 1), 
                      class = sd))

fit_brms = brm(y | trials(m) ~ 1 + (1|site), 
          data = data.frame(y = y, m=m, site = 1:n), 
          family = binomial(),
          prior = priors)
## Running /Library/Frameworks/R.framework/Resources/bin/R CMD SHLIB foo.c
## clang -mmacosx-version-min=10.13 -I"/Library/Frameworks/R.framework/Resources/include" -DNDEBUG   -I"/Library/Frameworks/R.framework/Versions/4.1/Resources/library/Rcpp/include/"  -I"/Library/Frameworks/R.framework/Versions/4.1/Resources/library/RcppEigen/include/"  -I"/Library/Frameworks/R.framework/Versions/4.1/Resources/library/RcppEigen/include/unsupported"  -I"/Library/Frameworks/R.framework/Versions/4.1/Resources/library/BH/include" -I"/Library/Frameworks/R.framework/Versions/4.1/Resources/library/StanHeaders/include/src/"  -I"/Library/Frameworks/R.framework/Versions/4.1/Resources/library/StanHeaders/include/"  -I"/Library/Frameworks/R.framework/Versions/4.1/Resources/library/RcppParallel/include/"  -I"/Library/Frameworks/R.framework/Versions/4.1/Resources/library/rstan/include" -DEIGEN_NO_DEBUG  -DBOOST_DISABLE_ASSERTS  -DBOOST_PENDING_INTEGER_LOG2_HPP  -DSTAN_THREADS  -DBOOST_NO_AUTO_PTR  -include '/Library/Frameworks/R.framework/Versions/4.1/Resources/library/StanHeaders/include/stan/math/prim/mat/fun/Eigen.hpp'  -D_REENTRANT -DRCPP_PARALLEL_USE_TBB=1   -I/usr/local/include   -fPIC  -Wall -g -O2  -c foo.c -o foo.o
## In file included from <built-in>:1:
## In file included from /Library/Frameworks/R.framework/Versions/4.1/Resources/library/StanHeaders/include/stan/math/prim/mat/fun/Eigen.hpp:13:
## In file included from /Library/Frameworks/R.framework/Versions/4.1/Resources/library/RcppEigen/include/Eigen/Dense:1:
## In file included from /Library/Frameworks/R.framework/Versions/4.1/Resources/library/RcppEigen/include/Eigen/Core:88:
## /Library/Frameworks/R.framework/Versions/4.1/Resources/library/RcppEigen/include/Eigen/src/Core/util/Macros.h:628:1: error: unknown type name 'namespace'
## namespace Eigen {
## ^
## /Library/Frameworks/R.framework/Versions/4.1/Resources/library/RcppEigen/include/Eigen/src/Core/util/Macros.h:628:16: error: expected ';' after top level declarator
## namespace Eigen {
##                ^
##                ;
## In file included from <built-in>:1:
## In file included from /Library/Frameworks/R.framework/Versions/4.1/Resources/library/StanHeaders/include/stan/math/prim/mat/fun/Eigen.hpp:13:
## In file included from /Library/Frameworks/R.framework/Versions/4.1/Resources/library/RcppEigen/include/Eigen/Dense:1:
## /Library/Frameworks/R.framework/Versions/4.1/Resources/library/RcppEigen/include/Eigen/Core:96:10: fatal error: 'complex' file not found
## #include <complex>
##          ^~~~~~~~~
## 3 errors generated.
## make: *** [foo.o] Error 1

Para ver las estimaciones para cada sub-población podemos hacer

plogis(coef(fit_brms, robust = TRUE)$site[, , ])
##      Estimate Est.Error        Q2.5     Q97.5
## 1  0.04338521 0.7086898 0.003211218 0.1426554
## 2  0.08691117 0.6351344 0.023128348 0.1949758
## 3  0.05230336 0.7041928 0.003740329 0.1624295
## 4  0.24991188 0.6051472 0.125363113 0.4287400
## 5  0.20029325 0.6039299 0.098586273 0.3545137
## 6  0.16755063 0.6118494 0.071442585 0.3193849
## 7  0.15422768 0.6107950 0.065622956 0.2948793
## 8  0.08756556 0.7060376 0.006536834 0.2976353
## 9  0.10091701 0.6978283 0.009495886 0.3515520
## 10 0.08860125 0.6391461 0.023023138 0.2067523

Podemos graficar la estimación a nivel poblacional y compararla con las estimaciones de JAGS y Stan

pos <- as.data.frame(fit)
pos_brms <- posterior_samples(fit_brms)

op <- par(cex.lab = 1.5 , font.lab = 2, cex.axis = 1.3, las = 1, bty = "n")
plot(density(hier.sim$sims.list$mean_pobl), main = "", xlab = "Tasa de  mortalidad", lwd = 3)
lines(density(pos$mean_pobl), lty = 2, lwd = 3)
lines(density(plogis(pos_brms$b_Intercept)), lty = 3, lwd = 3)
abline(v = a_s/(a_s + b_s), col = 2, lwd = 3, lty = 2)
Estimaciones de la posterior a nivel poblacional de la probabilidad de mortalidad. La línea sólida corresponde a JAGS, la barreada a Stan y la punteada a brms. En rojo se muestra el valor verdadero del parámetro.

Estimaciones de la posterior a nivel poblacional de la probabilidad de mortalidad. La línea sólida corresponde a JAGS, la barreada a Stan y la punteada a brms. En rojo se muestra el valor verdadero del parámetro.

par(op)

Juan Manuel Morales. 6 de Septiembre del 2015. última actualización: 2021-12-07