En este práctico vamos a escribir y ajustar modelos jerárquicos en JAGS
. Para entender de qué se trata esto de las estructuras jerárquicas vamos a empezar con datos simulados.
JAGS
Suponiendo que estamos estudiando \(10\) sub-poblaciones de Chinchillones en la estepa Patagónica, y queremos saber cuál es la probabilidad de que estos bichos sobrevivan al invierno. Vamos a simular datos provenientes de individuos de cada una de las sub-poblaciones. Además, vamos a considerar que todas las sub-poblaciones son parte de una misma “meta-población”. Esto hace que esperemos que los parámetros de supervivencia sean parecidos entre las sub-poblaciones, pero que sin embargo tengan cierto grado de diferencias. Para modelar la variabilidad entre sub-poblaciones podemos usar una distribución Beta, por ejemplo con parámetros \(a_s = 2\) y \(b_s = 10\). Si repasamos el Bestiario de distribuciones vemos que esta combinación de parámetros resulta en un valor esperado de \(\frac{a_s}{(a_s+b_s)}\) = 0.167 y una varianza de \(\frac{(a_s \times b_s)}{(a_s + b_s)^2 \times (a_s + b_s + 1)}\) = 0.011. Entonces, cada sub-población tendrá su propia tasa de supervivencia. Por otro lado, podemos imaginar que logramos seguir un número variable de individuos en cada población para ver si sobreviven o no. Para simular los datos hacemos:
set.seed(1234)
<- 10 # sub-poblaciones
n <- c(30, 28, 20, 30, 30, 26, 29, 5, 3, 27) # número de individuos muestraeados por sub-poblacion
m
<- 2
a_s <- 10
b_s
<- rbeta(n, a_s, b_s) # generamos n tasas de mortalidad, una por cada sub-poblacion
theta <- rbinom(n, size = m, prob = theta) # simulamos número de muertes por grupo
y
<- par(cex.lab = 1.5 , font.lab = 2, cex.axis = 1.3, las = 1, bty = "n")
op
plot(table(y), xlab = "Muertes por sub-población", ylab = "Frecuencia")
par(op)
Estos datos nos muestran el número de muertes por sub-población. Hay cuatro sub-poblaciones que no tuvieron muertes, dos que tuvieron dos muertes, etc. Además, tenemos que recordar que cuando simulamos estos datos cada sub-población tenia su propia tasa de mortalidad.
Vamos a modelar estos datos suponiendo que:
todas las sub-poblaciones tienen la misma mortalidad (complete pooling), es decir que ignoramos variabilidad entre sub-poblaciones
cada sub-población tiene una tasa de mortalidad independiente de las otras (no pooling), es decir que ignoramos que todas las sub-poblaciones son parte de una misma “meta-población”, la misma región, la misma especie…
cada sub-población es diferente pero parecida a las otras (partial pooling)
Vamos por partes:
En este caso el modelo de datos (cuántos individuos mueren) es una Binomial:
cat(file = "completepooling.bug",
"
model{
for( i in 1 : n ) {
y[i] ~ dbin(theta, m[i]) #Likelihood
}
theta ~ dbeta(1, 1) # previa no-informativa para la tasa de mortalidad
}
")
Como de costumbre, definimos los datos que le vamos a pasar a JAGS
, una función para generar valores iniciales para las cadenas Markovianas, los parámetros que queremos guardar, cuántas iteraciones vamos a correr, etc.
<- list(y = y,
data m = m,
n = n)
<- function() list(theta = runif(1, 0, 1))
inits <- c("theta", "pred")
params
<- 1000
ni <- 3
nc <- 1
nt <- 500 nb
Ahora llamamos a JAGS
library(jagsUI)
<- jags(data,
cp.sim
inits,
params, model.file = "completepooling.bug",
n.chains = nc,
n.iter = ni,
n.burnin = nb,
n.thin = nt)
##
## Processing function input.......
##
## Done.
##
## Compiling model graph
## Resolving undeclared variables
## Allocating nodes
## Graph information:
## Observed stochastic nodes: 10
## Unobserved stochastic nodes: 1
## Total graph size: 24
##
## Initializing model
##
## Adaptive phase.....
## Adaptive phase complete
##
##
## Burn-in phase, 500 iterations x 3 chains
##
##
## Sampling from joint posterior, 500 iterations x 3 chains
##
##
## Calculating statistics.......
##
## Done.
print(cp.sim)
## JAGS output for model 'completepooling.bug', generated by jagsUI.
## Estimates based on 3 chains of 1000 iterations,
## adaptation = 100 iterations (sufficient),
## burn-in = 500 iterations and thin rate = 1,
## yielding 1500 total samples from the joint posterior.
## MCMC ran for 0 minutes at time 2021-12-07 19:40:14.
##
## mean sd 2.5% 50% 97.5% overlap0 f Rhat n.eff
## theta 0.135 0.023 0.095 0.134 0.183 FALSE 1 1.003 1323
## deviance 47.417 1.425 46.403 46.853 51.326 FALSE 1 1.013 491
##
## Successful convergence based on Rhat values (all < 1.1).
## Rhat is the potential scale reduction factor (at convergence, Rhat=1).
## For each parameter, n.eff is a crude measure of effective sample size.
##
## overlap0 checks if 0 falls in the parameter's 95% credible interval.
## f is the proportion of the posterior with the same sign as the mean;
## i.e., our confidence that the parameter is positive or negative.
##
## DIC info: (pD = var(deviance)/2)
## pD = 1 and DIC = 48.43
## DIC is an estimate of expected predictive error (lower is better).
1- ¿Cuál es el valor esperado de la posterior de la tasa de mortalidad bajo el supuesto de “complete pooling”?
2- Comparar el valor esperado de la tasa de mortalidad con el valor promedio de las tasas de mortalidad usadas para simular los datos.
De nuevo usando previas no informativas:
cat(file="nopooling.bug",
"
model {
for( i in 1 : n ) {
y[i] ~ dbin(theta[i], m[i]) # Likelihood
theta[i] ~ dbeta(1, 1) # previas no-informativas para la tasa de mortalidad
}
}
")
library(jagsUI)
<- list(y = y,
data m = m,
n = n)
<- function() list(theta = runif(n, 0, 1))
inits <- c("theta")
params
<- 1000
ni <- 3
nc <- 1
nt <- 500
nb <- jags(data,
np.sim
inits,
params, model.file = "nopooling.bug",
n.chains = nc,
n.iter = ni,
n.burnin = nb,
n.thin = nt)
##
## Processing function input.......
##
## Done.
##
## Compiling model graph
## Resolving undeclared variables
## Allocating nodes
## Graph information:
## Observed stochastic nodes: 10
## Unobserved stochastic nodes: 10
## Total graph size: 51
##
## Initializing model
##
## Adaptive phase.....
## Adaptive phase complete
##
##
## Burn-in phase, 500 iterations x 3 chains
##
##
## Sampling from joint posterior, 500 iterations x 3 chains
##
##
## Calculating statistics.......
##
## Done.
print(np.sim)
## JAGS output for model 'nopooling.bug', generated by jagsUI.
## Estimates based on 3 chains of 1000 iterations,
## adaptation = 100 iterations (sufficient),
## burn-in = 500 iterations and thin rate = 1,
## yielding 1500 total samples from the joint posterior.
## MCMC ran for 0 minutes at time 2021-12-07 19:40:14.
##
## mean sd 2.5% 50% 97.5% overlap0 f Rhat n.eff
## theta[1] 0.032 0.032 0.001 0.022 0.115 FALSE 1 1.006 813
## theta[2] 0.099 0.054 0.021 0.089 0.228 FALSE 1 1.000 1500
## theta[3] 0.046 0.044 0.001 0.032 0.163 FALSE 1 1.002 1061
## theta[4] 0.314 0.080 0.169 0.312 0.479 FALSE 1 1.000 1500
## theta[5] 0.248 0.076 0.119 0.245 0.411 FALSE 1 1.004 473
## theta[6] 0.213 0.074 0.088 0.206 0.382 FALSE 1 0.999 1500
## theta[7] 0.195 0.070 0.079 0.188 0.355 FALSE 1 1.001 1500
## theta[8] 0.148 0.128 0.005 0.110 0.468 FALSE 1 1.000 1500
## theta[9] 0.205 0.166 0.008 0.166 0.621 FALSE 1 1.001 1500
## theta[10] 0.103 0.055 0.022 0.094 0.228 FALSE 1 1.001 1500
## deviance 31.919 5.126 24.251 31.250 43.971 FALSE 1 0.999 1500
##
## Successful convergence based on Rhat values (all < 1.1).
## Rhat is the potential scale reduction factor (at convergence, Rhat=1).
## For each parameter, n.eff is a crude measure of effective sample size.
##
## overlap0 checks if 0 falls in the parameter's 95% credible interval.
## f is the proportion of the posterior with the same sign as the mean;
## i.e., our confidence that the parameter is positive or negative.
##
## DIC info: (pD = var(deviance)/2)
## pD = 13.2 and DIC = 45.077
## DIC is an estimate of expected predictive error (lower is better).
Veamos un gráfico para comprar los dos análisis (complete-pooling y no-pooling)
<- par(cex.lab = 1.5 , font.lab = 2, cex.axis = 1.3, las = 1, bty = "n")
op
plot(density(cp.sim$sims.list$theta),
type = "l",
lwd = 3,
ylab = "Density",
xlab = "Tasa de mortalidad",
xlim = c(0, 0.5),
ylim = c(0, 30), main = "")
for (i in 1: 10){
lines(density(np.sim$sims.list$theta[,i ]),
col = "gray",
lwd = 3)
}
par(op)
1- ¿Qué se puede decir de las distintas posteriores?
2- ¿Alguno de estos dos enfoques les parece más adecuado? ¿Cuáles serían los pro y contras de cada uno?
Finalmente, podemos hacer un análisis jerárquico o “multi-nivel” (partial pooling), reconociendo explícitamente que cada sub-población es potencialmente distinta de las otras, pero que todas son parte de la misma meta-población, la misma región, la misma especie, etc. Esto último implica que las observaciones en las sub-poblaciones no son del todo independientes.
cat(file = "hier.bug",
"
model{
#Likelihood
for(i in 1: n ) {
y[i] ~ dbin(theta[i], m[i])
theta[i] ~ dbeta(a, b)
}
#Previas
a ~ dnorm(0,0.01)T(0,)
b ~ dnorm(0,0.01)T(0,)
mean_pobl <- a/(a + b)
}
")
Noten que en la última línea del modelo BUGS
, registramos el valor esperado de la tasa de mortalidad para la meta-población. Este es un ejemplo de cómo podemos calcular valores de interés (en la jerga estadística serían “derived quantities”) dentro del modelo BUGS
. En este caso, usamos esta fórmula que corresponde a la media de la distribución Beta (ver Bestiario de distribuciones). Entonces, vamos a obtener una posterior para esta “derived quantity”.
library(jagsUI)
<- list(y = y,
data m = m,
n = n)
<- function() list(a = runif(1, 1, 5), b = runif(1, 5, 20))
inits <- c("a", "b", "theta", "mean_pobl")
params <- 5000
ni <- 3
nc <- 4
nt <- 2500
nb
<- jags(data,
hier.sim
inits,
params, model.file = "hier.bug",
n.chains = nc,
n.iter = ni,
n.burnin = nb,
n.thin = nt)
##
## Processing function input.......
##
## Done.
##
## Compiling model graph
## Resolving undeclared variables
## Allocating nodes
## Graph information:
## Observed stochastic nodes: 10
## Unobserved stochastic nodes: 12
## Total graph size: 41
##
## Initializing model
##
## Adaptive phase.....
## Adaptive phase complete
##
##
## Burn-in phase, 2500 iterations x 3 chains
##
##
## Sampling from joint posterior, 2500 iterations x 3 chains
##
##
## Calculating statistics.......
##
## Done.
print(hier.sim)
## JAGS output for model 'hier.bug', generated by jagsUI.
## Estimates based on 3 chains of 5000 iterations,
## adaptation = 100 iterations (sufficient),
## burn-in = 2500 iterations and thin rate = 4,
## yielding 1875 total samples from the joint posterior.
## MCMC ran for 0.024 minutes at time 2021-12-07 19:40:15.
##
## mean sd 2.5% 50% 97.5% overlap0 f Rhat n.eff
## a 1.731 0.945 0.451 1.537 4.006 FALSE 1 0.999 1875
## b 11.536 5.505 3.192 10.625 24.401 FALSE 1 1.000 1875
## theta[1] 0.038 0.033 0.001 0.030 0.122 FALSE 1 1.001 1875
## theta[2] 0.090 0.046 0.021 0.082 0.199 FALSE 1 1.000 1875
## theta[3] 0.050 0.042 0.001 0.040 0.151 FALSE 1 1.000 1875
## theta[4] 0.251 0.068 0.131 0.247 0.395 FALSE 1 1.000 1875
## theta[5] 0.202 0.062 0.097 0.195 0.338 FALSE 1 1.000 1875
## theta[6] 0.171 0.059 0.073 0.166 0.297 FALSE 1 1.000 1875
## theta[7] 0.161 0.059 0.060 0.157 0.287 FALSE 1 1.001 1733
## theta[8] 0.092 0.074 0.002 0.075 0.273 FALSE 1 1.000 1875
## theta[9] 0.105 0.083 0.003 0.088 0.317 FALSE 1 1.000 1875
## theta[10] 0.093 0.048 0.023 0.087 0.201 FALSE 1 1.003 823
## mean_pobl 0.132 0.039 0.067 0.129 0.214 FALSE 1 1.001 1875
## deviance 30.524 4.676 23.194 30.012 41.082 FALSE 1 1.001 1875
##
## Successful convergence based on Rhat values (all < 1.1).
## Rhat is the potential scale reduction factor (at convergence, Rhat=1).
## For each parameter, n.eff is a crude measure of effective sample size.
##
## overlap0 checks if 0 falls in the parameter's 95% credible interval.
## f is the proportion of the posterior with the same sign as the mean;
## i.e., our confidence that the parameter is positive or negative.
##
## DIC info: (pD = var(deviance)/2)
## pD = 10.9 and DIC = 41.465
## DIC is an estimate of expected predictive error (lower is better).
Ahora podemos ver la posterior de la tasa de mortalidad a nivel de meta-población y compararla con el valor verdadero que usamos en las simulaciones (en rojo).
<- par(cex.lab = 1.5 , font.lab = 2, cex.axis = 1.3, las = 1, bty = "n")
op plot(density(hier.sim$sims.list$mean_pobl),
main = "",
xlab = "Tasa de mortalidad",
lwd = 3)
abline(v = a_s/(a_s + b_s),
col = 2,
lwd = 3,
lty = 2)
par(op)
Finalmente podemos comparar gráficamente los tres enfoques:
<- par(mfrow = c(1, 2), cex.lab = 1.5 , font.lab = 2, cex.axis = 1.3, las = 1, bty = "n")
op
plot(density(cp.sim$sims.list$theta),
type = "l",
lwd = 3,
ylab = "Density",
xlab = "Tasa de mortalidad",
xlim = c(0, 1),
ylim = c(0, 25),
main = "") #Complete pooling
for (i in 1: 10){
lines(density(np.sim$sims.list$theta[, i]),
col = "gray",
lwd = 2)
#No pooling
}
curve(dbeta(x, hier.sim$mean$a, hier.sim$mean$b),
lwd = 3,
col = 2,
xlab = "Tasa de mortalidad",
ylab = "",
ylim = c(0, 25)) #Tasa de mortalidad meta-poblacional estimada con partial pooling
for (i in 1: 10){
lines(density(hier.sim$sims.list$theta[, i]),
col = "blue",
lwd = 2,
main = "") #Tasa de mortalidad de cada subpoblación estimadas con partial pooling
}
par(op)
En el gráfico la linea roja representa la posterior para la tasa de mortalidad de la meta-población y cada linea azul representa la posterior para la tasa de mortalidad de cada sub-población.
¿Cómo comparan las posteriores para las distintas sub-poblaciones con no-pooling y partial pooling?
predicción para una nueva subpoblación
<- numeric(1500)
predh for(i in 1:1500){
<- rbeta(1, hier.sim$sims.list$a[i], hier.sim$sims.list$b[i])
theta <- rbinom(1, size = 30, prob = theta)
predh[i]
}
hist(predh,100)
Stan
y brms
Veamos cómo hacer el análisis jerárquico con Stan
cat(file = "hier.stan",
"
data {
int<lower=1> n;
int<lower=0> y[n];
int<lower=0> m[n];
}
parameters{
real<lower=0> a;
real<lower=0> b;
real<lower=0,upper=1> theta[n];
}
model{
a ~ normal(0,10);
b ~ normal(0,10);
theta ~ beta(a,b);
for(i in 1: n ) {
y[i] ~ binomial(m[i], theta[i]);
}
}
generated quantities {
real mean_pobl;
mean_pobl = a/(a + b);
}
")
Ahora llamamos a Stan
library(rstan)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())
<- list(n = n,
hier_dat m = m,
y = y)
<- function() list(a = runif(1, 1, 5), b = runif(1, 5, 20))
inits
<- stan(file = 'hier.stan',
fit data = hier_dat,
init = inits,
iter = 1000,
thin = 1,
chains = 3)
Podemos comparar con el ajuste de JAGS
y con el valore verdadero a nivel poblacional
<- as.data.frame(fit)
pos
<- par(cex.lab = 1.5 , font.lab = 2, cex.axis = 1.3, las = 1, bty = "n")
op plot(density(hier.sim$sims.list$mean_pobl), main = "", xlab = "Tasa de mortalidad", lwd = 3)
lines(density(pos$mean_pobl), lty = 2, lwd = 3)
abline(v = a_s/(a_s + b_s), col = 2, lwd = 3, lty = 2)
par(op)
Ahora veamos como hacer el análisis jerárquico con brms
. La gran diferencia es que con brms
formulamos en modelo es escala logit
library(brms)
= c(prior(normal(0, 1),
priors class = Intercept),
prior(cauchy(0, 1),
class = sd))
= brm(y | trials(m) ~ 1 + (1|site),
fit_brms data = data.frame(y = y, m=m, site = 1:n),
family = binomial(),
prior = priors)
## Running /Library/Frameworks/R.framework/Resources/bin/R CMD SHLIB foo.c
## clang -mmacosx-version-min=10.13 -I"/Library/Frameworks/R.framework/Resources/include" -DNDEBUG -I"/Library/Frameworks/R.framework/Versions/4.1/Resources/library/Rcpp/include/" -I"/Library/Frameworks/R.framework/Versions/4.1/Resources/library/RcppEigen/include/" -I"/Library/Frameworks/R.framework/Versions/4.1/Resources/library/RcppEigen/include/unsupported" -I"/Library/Frameworks/R.framework/Versions/4.1/Resources/library/BH/include" -I"/Library/Frameworks/R.framework/Versions/4.1/Resources/library/StanHeaders/include/src/" -I"/Library/Frameworks/R.framework/Versions/4.1/Resources/library/StanHeaders/include/" -I"/Library/Frameworks/R.framework/Versions/4.1/Resources/library/RcppParallel/include/" -I"/Library/Frameworks/R.framework/Versions/4.1/Resources/library/rstan/include" -DEIGEN_NO_DEBUG -DBOOST_DISABLE_ASSERTS -DBOOST_PENDING_INTEGER_LOG2_HPP -DSTAN_THREADS -DBOOST_NO_AUTO_PTR -include '/Library/Frameworks/R.framework/Versions/4.1/Resources/library/StanHeaders/include/stan/math/prim/mat/fun/Eigen.hpp' -D_REENTRANT -DRCPP_PARALLEL_USE_TBB=1 -I/usr/local/include -fPIC -Wall -g -O2 -c foo.c -o foo.o
## In file included from <built-in>:1:
## In file included from /Library/Frameworks/R.framework/Versions/4.1/Resources/library/StanHeaders/include/stan/math/prim/mat/fun/Eigen.hpp:13:
## In file included from /Library/Frameworks/R.framework/Versions/4.1/Resources/library/RcppEigen/include/Eigen/Dense:1:
## In file included from /Library/Frameworks/R.framework/Versions/4.1/Resources/library/RcppEigen/include/Eigen/Core:88:
## /Library/Frameworks/R.framework/Versions/4.1/Resources/library/RcppEigen/include/Eigen/src/Core/util/Macros.h:628:1: error: unknown type name 'namespace'
## namespace Eigen {
## ^
## /Library/Frameworks/R.framework/Versions/4.1/Resources/library/RcppEigen/include/Eigen/src/Core/util/Macros.h:628:16: error: expected ';' after top level declarator
## namespace Eigen {
## ^
## ;
## In file included from <built-in>:1:
## In file included from /Library/Frameworks/R.framework/Versions/4.1/Resources/library/StanHeaders/include/stan/math/prim/mat/fun/Eigen.hpp:13:
## In file included from /Library/Frameworks/R.framework/Versions/4.1/Resources/library/RcppEigen/include/Eigen/Dense:1:
## /Library/Frameworks/R.framework/Versions/4.1/Resources/library/RcppEigen/include/Eigen/Core:96:10: fatal error: 'complex' file not found
## #include <complex>
## ^~~~~~~~~
## 3 errors generated.
## make: *** [foo.o] Error 1
Para ver las estimaciones para cada sub-población podemos hacer
plogis(coef(fit_brms, robust = TRUE)$site[, , ])
## Estimate Est.Error Q2.5 Q97.5
## 1 0.04338521 0.7086898 0.003211218 0.1426554
## 2 0.08691117 0.6351344 0.023128348 0.1949758
## 3 0.05230336 0.7041928 0.003740329 0.1624295
## 4 0.24991188 0.6051472 0.125363113 0.4287400
## 5 0.20029325 0.6039299 0.098586273 0.3545137
## 6 0.16755063 0.6118494 0.071442585 0.3193849
## 7 0.15422768 0.6107950 0.065622956 0.2948793
## 8 0.08756556 0.7060376 0.006536834 0.2976353
## 9 0.10091701 0.6978283 0.009495886 0.3515520
## 10 0.08860125 0.6391461 0.023023138 0.2067523
Podemos graficar la estimación a nivel poblacional y compararla con las estimaciones de JAGS
y Stan
<- as.data.frame(fit)
pos <- posterior_samples(fit_brms)
pos_brms
<- par(cex.lab = 1.5 , font.lab = 2, cex.axis = 1.3, las = 1, bty = "n")
op plot(density(hier.sim$sims.list$mean_pobl), main = "", xlab = "Tasa de mortalidad", lwd = 3)
lines(density(pos$mean_pobl), lty = 2, lwd = 3)
lines(density(plogis(pos_brms$b_Intercept)), lty = 3, lwd = 3)
abline(v = a_s/(a_s + b_s), col = 2, lwd = 3, lty = 2)
Estimaciones de la posterior a nivel poblacional de la probabilidad de mortalidad. La línea sólida corresponde a JAGS, la barreada a Stan y la punteada a brms. En rojo se muestra el valor verdadero del parámetro.
par(op)
Juan Manuel Morales. 6 de Septiembre del 2015. última actualización: 2021-12-07