I commented out the function header so we can interact with the objects inside.
In this version, I take 10,000 draws of the intercept, the age coefficient (beta), and tau, including the uncertainty around tau (later I’ll remove the uncertainty around tau for comparison).
# get_draws_mixed_model <- function(
# n_draws, coefs, vcov, tau2, se_tau2) {
# Development switch: while the function header is commented out, the block
# below assigns example values to what would be the function's arguments.
dev <- TRUE
if (dev) {
  # number of draws to take
  n_draws <- 10000
  # coefficient estimates; must be a named vector
  coefs <- c(intercept = -0.4238, age_mid = 0.0301)
  # 2 x 2 variance-covariance matrix matching 'coefs', entered row by row
  vcov <- matrix(
    c(
       0.0264985, -0.0003751,
      -0.0003751,  0.0000060
    ),
    nrow = 2,
    byrow = TRUE
  )
  # tau-squared and the standard deviation used when drawing it
  tau2 <- 0.0785
  se_tau2 <- 0.0422
}
# library() stops with an error if MASS is not installed; require() would only
# warn and return FALSE, postponing the failure to the mvrnorm() call below
library(MASS)
# draw n_draws joint samples of the intercept and age coefficients from a
# multivariate normal centred on 'coefs' with covariance matrix 'vcov'
draws_fe <- mvrnorm(n = n_draws, mu = coefs, Sigma = vcov)
head(draws_fe) # look at the first six rows (there is one row per draw)
## intercept age_mid
## [1,] -0.2987436 0.02813097
## [2,] -0.1304666 0.02666264
## [3,] -0.1977718 0.02775403
## [4,] -0.4962770 0.03246446
## [5,] -0.3287720 0.03073573
## [6,] -0.3495126 0.02807102
# draws of tau-squared: when se_tau2 is available (0.0422 here, so this is the
# branch that runs) they are sampled from a normal centred on tau2; when
# se_tau2 is NA, tau2 is simply repeated n_draws times
draws_tau2 <- if (!is.na(se_tau2)) {
  rnorm(n_draws, mean = tau2, sd = se_tau2)
} else {
  rep(tau2, n_draws)
}
str(draws_tau2) # a vector with one tau-squared value per draw (n = 10000)
## num [1:10000] 0.1238 0.1114 0.0641 0.1167 0.0715 ...
hist(draws_tau2) # normally distributed around the value of tau-squared, 0.0785
# square root of the tau-squared draws; negative draws give NaN, and the
# accompanying warning from sqrt() is silenced on purpose
draws_tau_tmp <- suppressWarnings(sqrt(draws_tau2))
str(draws_tau_tmp) # one tau value per draw (n = 10000), with some NaN values
## num [1:10000] 0.352 0.334 0.253 0.342 0.267 ...
table(is.nan(draws_tau_tmp))
##
## FALSE TRUE
## 9687 313
# copy the vector and overwrite the NaN entries with zero
draws_tau_tmp2 <- draws_tau_tmp
draws_tau_tmp2[is.nan(draws_tau_tmp2)] <- 0
str(draws_tau_tmp2)
## num [1:10000] 0.352 0.334 0.253 0.342 0.267 ...
table(is.nan(draws_tau_tmp2))
##
## FALSE
## 10000
# final draws of tau: one normal deviate per draw, centred on zero, each with
# its own standard deviation taken from 'draws_tau_tmp2'
draws_tau <- rnorm(n_draws, mean = 0, sd = draws_tau_tmp2)
str(draws_tau)
## num [1:10000] -0.4314 -0.0439 0.3838 -0.4453 0.1489 ...
# rnorm() pairs draws with 'sd' element-wise: draw 1 uses sd[1], draw 2 uses
# sd[2], and so on
# quick check of that behaviour with alternating small and large sd values
rnorm(5, mean = 0, sd = rep(c(1, 999), length.out = 5))
## [1] 0.2245817 -1083.0556283 -2.1399444 853.1916769 0.4104656
# the final output: the intercept/age draws and the tau draws side by side,
# one row per draw
out <- as.data.frame(cbind(draws_fe, draws_tau = draws_tau))
str(out)
## 'data.frame': 10000 obs. of 3 variables:
## $ intercept: num -0.299 -0.13 -0.198 -0.496 -0.329 ...
## $ age_mid : num 0.0281 0.0267 0.0278 0.0325 0.0307 ...
## $ draws_tau: num -0.4314 -0.0439 0.3838 -0.4453 0.1489 ...
# }
# I add the levels of 'x_age' for which I want predictions: the first half of
# the draws is assigned age 25 and the second half age 100
# (with n_draws = 10000 that is 5000 draws for each age in this example)
# the per-age count is derived from nrow(out) instead of being hard-coded as
# 5000, so the assignment still works if n_draws is changed; length.out trims
# the vector to nrow(out) when the number of draws is odd
n_per_age <- ceiling(nrow(out) / 2)
out$x_age <- rep(c(25, 100), each = n_per_age, length.out = nrow(out))
str(out)
## 'data.frame': 10000 obs. of 4 variables:
## $ intercept: num -0.299 -0.13 -0.198 -0.496 -0.329 ...
## $ age_mid : num 0.0281 0.0267 0.0278 0.0325 0.0307 ...
## $ draws_tau: num -0.4314 -0.0439 0.3838 -0.4453 0.1489 ...
## $ x_age : num 25 25 25 25 25 25 25 25 25 25 ...
# library() stops with an error if dplyr is not installed; require() would only
# warn and return FALSE, so the failure would surface later at the first pipe
library(dplyr)
## Warning: package 'dplyr' was built under R version 3.4.4
##
## Attaching package: 'dplyr'
## The following object is masked from 'package:MASS':
##
## select
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
# generate the predictions: for every draw, intercept + age slope * x_age,
# plus that draw's tau deviate
out2 <- out %>%
  mutate(pred = intercept + x_age * age_mid + draws_tau)
## Warning: package 'bindrcpp' was built under R version 3.4.4
str(out2)
## 'data.frame': 10000 obs. of 5 variables:
## $ intercept: num -0.299 -0.13 -0.198 -0.496 -0.329 ...
## $ age_mid : num 0.0281 0.0267 0.0278 0.0325 0.0307 ...
## $ draws_tau: num -0.4314 -0.0439 0.3838 -0.4453 0.1489 ...
## $ x_age : num 25 25 25 25 25 25 25 25 25 25 ...
## $ pred : num -0.0268 0.4922 0.8799 -0.1299 0.5885 ...
# collapse the draws within each age into a point estimate (the mean) and a
# UI bounded by the 2.5th and 97.5th percentiles of the predictions
out3 <- summarise(
  group_by(out2, x_age),
  pred_mean = mean(pred),
  pred_lo = quantile(pred, probs = 0.025),
  pred_hi = quantile(pred, probs = 0.975)
)
print(out3)
## # A tibble: 2 x 4
## x_age pred_mean pred_lo pred_hi
## <dbl> <dbl> <dbl> <dbl>
## 1 25.0 0.324 -0.292 0.923
## 2 100 2.59 1.98 3.22