En este trabajo práctico vamos a aprender a escribir y ajustar modelos jerárquicos en JAGS. Para entender de qué se trata esto de las estructuras jerárquicas vamos a empezar con datos simulados.
JAGSSuponiendo que estamos estudiando \(10\) sub-poblaciones de Chinchillones en la estepa Patagónica, y queremos saber cuál es la probabilidad de que estos bichos sobrevivan al invierno. Vamos a simular datos provenientes de individuos de cada una de las sub-poblaciones. Además, vamos a considerar que todas las sub-poblaciones son parte de una misma “meta-población”. Esto hace que esperemos que los parámetros de supervivencia sean parecidos entre las sub-poblaciones, pero que sin embargo tengan cierto grado de diferencias. Para modelar la variabilidad entre sub-poblaciones podemos usar una distribución Beta, por ejemplo con parámetros \(a_s = 2\) y \(b_s = 10\). Si repasamos el Bestiario vemos que esta combinación de parámetros resulta en un valor esperado de \(\frac{a_s}{(a_s+b_s)}\) = 0.167 y una varianza de \(\frac{(a_s \times b_s)}{(a_s + b_s)^2 \times (a_s + b_s + 1)}\) = 0.011. Entonces, cada sub-población tendrá su propia tasa de supervivencia. Por otro lado, podemos imaginar que logramos seguir un número variable de individuos en cada población para ver si sobreviven o no. Para simular los datos hacemos:
set.seed(1234)
n <- 10 # sub-poblaciones
m <- c(30, 28, 20, 30, 30, 26, 29, 5, 3, 27) # número de individuos muestraeados por sub-poblacion
a_s <- 2
b_s <- 10
theta <- rbeta(n, a_s, b_s) # generamos n tasas de mortalidad, una por cada sub-poblacion
y <- rbinom(n, size = m, prob = theta) # simulamos número de muertes por grupo
op <- par(cex.lab = 1.5, font.lab = 2, cex.axis = 1.3, las = 1, bty = "n")
plot(table(y), xlab = "Muertes por sub-población", ylab = "Frecuencia")par(op)Estos datos nos muestran el número de muertes por sub-población. Hay cuatro sub-poblaciones que no tuvieron muertes, dos que tuvieron dos muertes, etc. Además, tenemos que recordar que cuando simulamos estos datos cada sub-población tenia su propia tasa de mortalidad.
Vamos a modelar estos datos suponiendo que:
todas las sub-poblaciones tienen la misma mortalidad (complete pooling), es decir que ignoramos que cada sub-población tiene su propia tasa de mortalidad
cada sub-población tiene una tasa de mortalidad independiente de las otras (no pooling), es decir que ignoramos que todas las sub-poblaciones son parte de una misma “meta-población”.
cada sub-población es diferente pero parecida a las otras (partial pooling)
Vamos por partes:
En este caso el modelo de datos (cuántos individuos mueren) es una Binomial:
cat(file = "completepooling.bug", "
model{
for( i in 1 : n ) {
y[i] ~ dbin(theta, m[i]) #Likelihood\t
\t\t}
\t theta ~ dbeta(1, 1) # previa no-informativa para la tasa de mortalidad
}
")Como de costumbre, definimos los datos que le vamos a pasar a JAGS, una función para generar valores iniciales para las cadenas Markovianas, los parámetros que queremos guardar, cuántas iteraciones vamos a correr, etc.
data <- list("y", "m", "n")
inits <- function() list(theta = runif(1, 0, 1))
params <- c("theta")
ni <- 1000
nc <- 3
nt <- 1
nb <- 500Ahora llamamos a JAGS
library(jagsUI)
cp.sim <- jags(data, inits, params, model.file = "completepooling.bug", n.chains = nc,
n.iter = ni, n.burnin = nb, n.thin = nt)##
## Processing function input.......
##
## Done.
##
## Compiling model graph
## Resolving undeclared variables
## Allocating nodes
## Graph information:
## Observed stochastic nodes: 10
## Unobserved stochastic nodes: 1
## Total graph size: 24
##
## Initializing model
##
## Adaptive phase.....
## Adaptive phase complete
##
##
## Burn-in phase, 500 iterations x 3 chains
##
##
## Sampling from joint posterior, 500 iterations x 3 chains
##
##
## Calculating statistics.......
##
## Done.
print(cp.sim)## JAGS output for model 'completepooling.bug', generated by jagsUI.
## Estimates based on 3 chains of 1000 iterations,
## adaptation = 100 iterations (sufficient),
## burn-in = 500 iterations and thin rate = 1,
## yielding 1500 total samples from the joint posterior.
## MCMC ran for 0.004 minutes at time 2019-03-01 14:17:46.
##
## mean sd 2.5% 50% 97.5% overlap0 f Rhat n.eff
## theta 0.135 0.023 0.093 0.134 0.183 FALSE 1 1.001 1133
## deviance 47.443 1.490 46.403 46.888 51.386 FALSE 1 1.009 1500
##
## Successful convergence based on Rhat values (all < 1.1).
## Rhat is the potential scale reduction factor (at convergence, Rhat=1).
## For each parameter, n.eff is a crude measure of effective sample size.
##
## overlap0 checks if 0 falls in the parameter's 95% credible interval.
## f is the proportion of the posterior with the same sign as the mean;
## i.e., our confidence that the parameter is positive or negative.
##
## DIC info: (pD = var(deviance)/2)
## pD = 1.1 and DIC = 48.553
## DIC is an estimate of expected predictive error (lower is better).
1- ¿Cuál es el valor esperado de la posterior de la tasa de mortalidad bajo el supuesto de “complete pooling”?
2- Comparar el valor esperado de la tasa de mortalidad con el valor promedio de las tasas de mortalidad usadas para simular los datos.
De nuevo usando previas no informativas:
cat(file = "nopooling.bug", "
model {
for( i in 1 : n ) {
y[i] ~ dbin(theta[i], m[i]) # Likelihood
theta[i] ~ dbeta(1, 1) # previas no-informativas para la tasa de mortalidad
\t\t}
}
")library(jagsUI)
data <- list("y", "m", "n")
inits <- function() list(theta = runif(n, 0, 1))
params <- c("theta")
ni <- 1000
nc <- 3
nt <- 1
nb <- 500
np.sim <- jags(data, inits, params, model.file = "nopooling.bug", n.chains = nc,
n.iter = ni, n.burnin = nb, n.thin = nt)##
## Processing function input.......
##
## Done.
##
## Compiling model graph
## Resolving undeclared variables
## Allocating nodes
## Graph information:
## Observed stochastic nodes: 10
## Unobserved stochastic nodes: 10
## Total graph size: 51
##
## Initializing model
##
## Adaptive phase.....
## Adaptive phase complete
##
##
## Burn-in phase, 500 iterations x 3 chains
##
##
## Sampling from joint posterior, 500 iterations x 3 chains
##
##
## Calculating statistics.......
##
## Done.
print(np.sim)## JAGS output for model 'nopooling.bug', generated by jagsUI.
## Estimates based on 3 chains of 1000 iterations,
## adaptation = 100 iterations (sufficient),
## burn-in = 500 iterations and thin rate = 1,
## yielding 1500 total samples from the joint posterior.
## MCMC ran for 0.002 minutes at time 2019-03-01 14:17:47.
##
## mean sd 2.5% 50% 97.5% overlap0 f Rhat n.eff
## theta[1] 0.031 0.030 0.001 0.022 0.110 FALSE 1 1.001 1500
## theta[2] 0.100 0.053 0.021 0.092 0.225 FALSE 1 1.000 1500
## theta[3] 0.043 0.042 0.001 0.030 0.158 FALSE 1 1.006 753
## theta[4] 0.313 0.080 0.164 0.311 0.479 FALSE 1 1.000 1500
## theta[5] 0.252 0.076 0.123 0.246 0.412 FALSE 1 0.999 1500
## theta[6] 0.214 0.077 0.088 0.206 0.385 FALSE 1 1.005 658
## theta[7] 0.194 0.070 0.073 0.187 0.344 FALSE 1 1.001 1095
## theta[8] 0.137 0.117 0.004 0.106 0.414 FALSE 1 1.003 934
## theta[9] 0.195 0.157 0.006 0.152 0.560 FALSE 1 0.999 1500
## theta[10] 0.103 0.056 0.023 0.094 0.232 FALSE 1 1.002 1216
## deviance 31.534 4.757 24.037 30.865 42.234 FALSE 1 1.000 1500
##
## Successful convergence based on Rhat values (all < 1.1).
## Rhat is the potential scale reduction factor (at convergence, Rhat=1).
## For each parameter, n.eff is a crude measure of effective sample size.
##
## overlap0 checks if 0 falls in the parameter's 95% credible interval.
## f is the proportion of the posterior with the same sign as the mean;
## i.e., our confidence that the parameter is positive or negative.
##
## DIC info: (pD = var(deviance)/2)
## pD = 11.3 and DIC = 42.855
## DIC is an estimate of expected predictive error (lower is better).
Veamos un gráfico para comprar los dos análisis (complete-pooling y no-pooling)
op <- par(cex.lab = 1.5, font.lab = 2, cex.axis = 1.3, las = 1, bty = "n")
plot(density(cp.sim$sims.list$theta), type = "l", lwd = 3, ylab = "Density",
xlab = "Tasa de mortalidad", xlim = c(0, 0.5), ylim = c(0, 30), main = "") #Posterior de theta con complete pooling
for (i in 1:10) {
lines(density(np.sim$sims.list$theta[, i]), type = "l", col = "gray", lwd = 3) #Posteriores de theta con no pooling
}par(op)1- ¿Qué se puede decir de las distintas posteriores?
2- ¿Alguno de estos dos enfoques les parece más adecuado? ¿Cuáles serían los pro y contras de cada uno?
Finalmente, podemos hacer un análisis jerárquico o “multi-nivel” (partial pooling), reconociendo explícitamente que cada sub-población es potencialmente distinta de las otras, pero que todas son parte de la misma meta-población. Esto último implica que las observaciones en las sub-poblaciones no son del todo independientes.
cat(file = "hier.bug", "
model{
#Likelihood
for(i in 1: n ) {
\t\t y[i] ~ dbin(theta[i], m[i])\t
\t\t theta[i] ~ dbeta(a, b)
}
#Previas
\t\ta ~ dnorm(0,0.01)T(0,)
\t\tb ~ dnorm(0,0.01)T(0,)
mean_pobl <- a/(a + b)
}
")Noten que en la última línea del modelo BUGS, registramos el valor esperado de la tasa de mortalidad para la meta-población. Este es un ejemplo de cómo podemos calcular valores de interés (en la jerga estadística serían “derived quantities”) dentro del modelo BUGS. En este caso, usamos esta fórmula que corresponde a la media de la distribución Beta (ver Bestiario de distribuciones). Entonces, vamos a obtener una posterior para esta “derived quantity”.
library(jagsUI)
data <- list("y", "m", "n")
inits <- function() list(a = runif(1, 1, 5), b = runif(1, 5, 20))
params <- c("a", "b", "theta", "mean_pobl")
ni <- 5000
nc <- 3
nt <- 4
nb <- 2500
hier.sim <- jags(data, inits, params, model.file = "hier.bug", n.chains = nc,
n.iter = ni, n.burnin = nb, n.thin = nt)##
## Processing function input.......
##
## Done.
##
## Compiling model graph
## Resolving undeclared variables
## Allocating nodes
## Graph information:
## Observed stochastic nodes: 10
## Unobserved stochastic nodes: 12
## Total graph size: 41
##
## Initializing model
##
## Adaptive phase.....
## Adaptive phase complete
##
##
## Burn-in phase, 2500 iterations x 3 chains
##
##
## Sampling from joint posterior, 2500 iterations x 3 chains
##
##
## Calculating statistics.......
##
## Done.
print(hier.sim)## JAGS output for model 'hier.bug', generated by jagsUI.
## Estimates based on 3 chains of 5000 iterations,
## adaptation = 100 iterations (sufficient),
## burn-in = 2500 iterations and thin rate = 4,
## yielding 1875 total samples from the joint posterior.
## MCMC ran for 0.028 minutes at time 2019-03-01 14:17:49.
##
## mean sd 2.5% 50% 97.5% overlap0 f Rhat n.eff
## a 1.676 0.977 0.395 1.466 4.161 FALSE 1 1.014 313
## b 11.340 5.769 2.794 10.449 24.406 FALSE 1 1.010 315
## theta[1] 0.037 0.032 0.000 0.030 0.117 FALSE 1 1.005 568
## theta[2] 0.090 0.045 0.020 0.083 0.190 FALSE 1 1.001 1555
## theta[3] 0.048 0.042 0.001 0.038 0.155 FALSE 1 1.004 681
## theta[4] 0.254 0.069 0.138 0.249 0.405 FALSE 1 1.000 1875
## theta[5] 0.204 0.062 0.100 0.199 0.338 FALSE 1 1.002 788
## theta[6] 0.172 0.062 0.072 0.166 0.313 FALSE 1 1.001 1420
## theta[7] 0.160 0.058 0.063 0.154 0.289 FALSE 1 1.003 568
## theta[8] 0.087 0.072 0.001 0.071 0.275 FALSE 1 1.001 1875
## theta[9] 0.098 0.077 0.001 0.083 0.286 FALSE 1 1.000 1851
## theta[10] 0.091 0.048 0.020 0.084 0.203 FALSE 1 1.002 880
## mean_pobl 0.130 0.038 0.066 0.128 0.213 FALSE 1 1.001 1626
## deviance 30.207 4.513 22.803 29.582 40.411 FALSE 1 1.008 303
##
## Successful convergence based on Rhat values (all < 1.1).
## Rhat is the potential scale reduction factor (at convergence, Rhat=1).
## For each parameter, n.eff is a crude measure of effective sample size.
##
## overlap0 checks if 0 falls in the parameter's 95% credible interval.
## f is the proportion of the posterior with the same sign as the mean;
## i.e., our confidence that the parameter is positive or negative.
##
## DIC info: (pD = var(deviance)/2)
## pD = 10.1 and DIC = 40.336
## DIC is an estimate of expected predictive error (lower is better).
Ahora podemos ver la posterior de la tasa de mortalidad a nivel de meta-población y compararla con el valor verdadero que usamos en las simulaciones (en rojo).
op <- par(cex.lab = 1.5, font.lab = 2, cex.axis = 1.3, las = 1, bty = "n")
plot(density(hier.sim$sims.list$mean_pobl), main = "", xlab = "Tasa de mortalidad",
lwd = 3)
abline(v = a_s/(a_s + b_s), col = 2, lwd = 3, lty = 2)par(op)Finalmente podemos comparar gráficamente los tres enfoques:
op <- par(mfrow = c(1, 2), cex.lab = 1.5, font.lab = 2, cex.axis = 1.3, las = 1,
bty = "n")
plot(density(cp.sim$sims.list$theta), type = "l", lwd = 3, ylab = "Density",
xlab = "Tasa de mortalidad", xlim = c(0, 1), ylim = c(0, 25), main = "") #Complete pooling
for (i in 1:10) {
lines(density(np.sim$sims.list$theta[, i]), type = "l", col = "gray", lwd = 2)
} #No pooling
curve(dbeta(x, hier.sim$mean$a, hier.sim$mean$b), lwd = 3, col = 2, xlab = "Tasa de mortalidad",
ylab = "", ylim = c(0, 25)) #Tasa de mortalidad meta-poblacional estimada con partial pooling
for (i in 1:10) {
lines(density(hier.sim$sims.list$theta[, i]), col = "blue", lwd = 2, main = "") #Tasa de mortalidad de cada subpoblación estimadas con partial pooling
}par(op)En el gráfico la linea roja representa la posterior para la tasa de mortalidad de la meta-población y cada linea azul representa la posterior para la tasa de mortalidad de cada sub-población.
¿Cómo comparan las posteriores para las distintas sub-poblaciones con no-pooling y partial pooling?
Stan y brmsVeamos cómo hacer el análisis jerárquico con Stan
cat(file = "hier.stan", "
data {
int<lower=1> n;
int<lower=0> y[n];
int<lower=0> m[n];
}
parameters{
real<lower=0> a;
real<lower=0> b;
real<lower=0,upper=1> theta[n];
}
model{
a ~ normal(0,10);
b ~ normal(0,10);
theta ~ beta(a,b);
for(i in 1: n ) {
y[i] ~ binomial(m[i], theta[i]);\t
\t}
}
generated quantities {
real mean_pobl;
mean_pobl = a/(a + b);
}
")Ahora llamamos a Stan
library(rstan)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())
hier_dat <- list(n = n, m = m, y = y)
inits <- function() list(a = runif(1, 1, 5), b = runif(1, 5, 20))
fit <- stan(file = "hier.stan", data = hier_dat, init = inits, iter = 1000,
thin = 1, chains = 3)Podemos comparar con el ajuste de JAGS y con el valore verdadero a nivel poblacional
pos <- as.data.frame(fit)
op <- par(cex.lab = 1.5, font.lab = 2, cex.axis = 1.3, las = 1, bty = "n")
plot(density(hier.sim$sims.list$mean_pobl), main = "", xlab = "Tasa de mortalidad",
lwd = 3)
lines(density(pos$mean_pobl), lty = 2, lwd = 3)
abline(v = a_s/(a_s + b_s), col = 2, lwd = 3, lty = 2)par(op)Ahora veamos como hacer el análisis jerárquico con brms. La gran diferencia es que con brms formulamos en modelo es escala logit
library(brms)
priors = c(prior(normal(0, 1), class = Intercept), prior(cauchy(0, 1), class = sd))
fit_brms = brm(y | trials(m) ~ 1 + (1 | site), data = data.frame(y = y, m = m,
site = 1:n), family = binomial(), prior = priors)Para ver las estimaciones para cada sub-población podemos hacer
plogis(coef(fit_brms, robust = TRUE)$site[, , ])## Estimate Est.Error Q2.5 Q97.5
## 1 0.04435602 0.7048795 0.002918079 0.1425000
## 2 0.08897024 0.6395497 0.023207112 0.1964227
## 3 0.05440197 0.7121480 0.003329889 0.1667172
## 4 0.25140673 0.6027072 0.126697361 0.4188041
## 5 0.19888246 0.6040404 0.100137723 0.3507151
## 6 0.16407096 0.6147617 0.067465115 0.3218250
## 7 0.15406484 0.6088120 0.065749125 0.2966351
## 8 0.08784399 0.7031740 0.006358927 0.2943714
## 9 0.09791512 0.7095773 0.005933390 0.3689950
## 10 0.09060340 0.6352696 0.024233502 0.2002980
Podemos graficar la estimación a nivel poblacional y compararla con las estimaciones de JAGS y Stan
pos <- as.data.frame(fit)
pos_brms <- posterior_samples(fit_brms)
op <- par(cex.lab = 1.5, font.lab = 2, cex.axis = 1.3, las = 1, bty = "n")
plot(density(hier.sim$sims.list$mean_pobl), main = "", xlab = "Tasa de mortalidad",
lwd = 3)
lines(density(pos$mean_pobl), lty = 2, lwd = 3)
lines(density(plogis(pos_brms$b_Intercept)), lty = 3, lwd = 3)
abline(v = a_s/(a_s + b_s), col = 2, lwd = 3, lty = 2)Figure 1: Estimaciones de la posterior a nivel poblacional de la probabilidad de mortalidad. La línea sólida corresponde a JAGS, la barreada a Stan y la punteada a brms. En rojo se muestra el valor verdadero del parámetro.
par(op)Juan Manuel Morales. 6 de Septiembre del 2015. última actualización: 2019-03-01