COVID - Study 1 - Modeling - Cases

Import libraries, define helper functions, set options

# Show the R source of every chunk in the knitted report.
knitr::opts_chunk$set(echo = TRUE)
library(brms)
## Loading required package: Rcpp

## Loading 'brms' package (version 2.10.0). Useful instructions
## can be found by typing help('brms'). A more detailed introduction
## to the package is available through vignette('brms_overview').
library(bayesplot)
## This is bayesplot version 1.7.0

## - Online documentation and vignettes at mc-stan.org/bayesplot

## - bayesplot theme set to bayesplot::theme_default()

##    * Does _not_ affect other ggplot2 plots

##    * See ?bayesplot_theme_set for details on theme setting
library(tidybayes)
library(rstanarm)
## rstanarm (Version 2.19.2, packaged: 2019-10-01 20:20:33 UTC)

## - Do not expect the default priors to remain the same in future rstanarm versions.

## Thus, R scripts should specify priors explicitly, even if they are just the defaults.

## - For execution on a local, multicore CPU with excess RAM we recommend calling

## options(mc.cores = parallel::detectCores())

## - bayesplot theme set to bayesplot::theme_default()

##    * Does _not_ affect other ggplot2 plots

##    * See ?bayesplot_theme_set for details on theme setting

## 
## Attaching package: 'rstanarm'

## The following objects are masked from 'package:brms':
## 
##     dirichlet, exponential, get_y, lasso, loo_R2, ngrps
library(rstan)
## Loading required package: StanHeaders

## Loading required package: ggplot2

## rstan (Version 2.19.2, GitRev: 2e1f913d3ca3)

## For execution on a local, multicore CPU with excess RAM we recommend calling
## options(mc.cores = parallel::detectCores()).
## To avoid recompilation of unchanged Stan programs, we recommend calling
## rstan_options(auto_write = TRUE)
library(MASS)
library(dplyr)
## 
## Attaching package: 'dplyr'

## The following object is masked from 'package:MASS':
## 
##     select

## The following objects are masked from 'package:stats':
## 
##     filter, lag

## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
library(sjstats)
## 
## Attaching package: 'sjstats'

## The following object is masked from 'package:rstanarm':
## 
##     se
library(tidyr)
## 
## Attaching package: 'tidyr'

## The following object is masked from 'package:rstan':
## 
##     extract
library(ggthemes)
library(ggplot2)
library(emmeans)
library(xtable)

# Fit a model once and cache the result on disk as an .Rds file.
#
# expr  : the model-fitting expression (e.g. a brm() call). Because R
#         arguments are lazy, it is only evaluated when no cached fit is used.
# path  : cache file path WITHOUT the ".Rds" extension.
# reuse : if TRUE and "<path>.Rds" exists, return the cached object instead
#         of refitting. If the cache file is missing, the model is fitted and
#         saved (previously a silent "try-error" object was returned instead).
# Returns the fitted or cached object.
run_model <- function(expr, path, reuse = TRUE) {
    path <- paste0(path, ".Rds")
    if (reuse && file.exists(path)) {
        # Errors reading an existing cache (e.g. a corrupt file) now propagate
        # instead of being swallowed by try() and returned as the "fit".
        return(readRDS(path))
    }
    fit <- eval(expr)
    saveRDS(fit, file = path)
    fit
}
rescale.center <- function(x, m2, s2) {
    # Linearly rescale x so that its mean is m2 and its SD is s2 (NA values
    # are ignored when computing the mean/SD and stay NA in the result).
    # With m2 = 0, s2 = 1 this is an ordinary z-score.
    # http://www.stat.columbia.edu/~gelman/research/unpublished/standardizing.pdf
    # http://www.stat.columbia.edu/~gelman/research/published/priors11.pdf
    # If x has zero or undefined SD (constant, or fewer than two non-NA
    # values), the result is not finite (NaN/NA).
    # TRUE is spelled out: `T` can be reassigned, which would silently turn
    # na.rm off and make every output NA.
    m1 <- mean(x, na.rm = TRUE)
    s1 <- sd(x, na.rm = TRUE)
    m2 + (x - m1) * (s2 / s1)
}

# Shared ggplot2 theme for the exploratory plots: Tufte-style base at 10 pt
# sans-serif, explicit x/y axis lines, bold titles, and no grid lines.
my_theme <- theme_tufte(base_size = 10, base_family = "sans") +
    theme(
        axis.line.y = element_line(size = .5),
        axis.line.x = element_line(size = .5),
        legend.title = element_text(size = 12, family = "sans", face = "bold"),
        axis.text = element_text(face = "plain"),
        axis.title = element_text(size = 12, family = "sans", face = "bold"),
        panel.grid.major = element_blank(),
        panel.grid.minor = element_blank(),
        plot.title = element_text(hjust = .5, size = 13, family = "sans",
                                  face = "bold"))

# Stan/brms settings: run chains in parallel on all detected cores, pass these
# sampler controls to brm(), and cache compiled Stan programs so unchanged
# models are not recompiled.
options(mc.cores = parallel::detectCores())
mcmc_controls <- list(adapt_delta = 0.9, max_treedepth = 15, stepsize = 0.2)
rstan_options(auto_write = TRUE)

# Successive-difference contrasts (MASS::contr.sdif) for unordered and ordered
# factors, so factor coefficients read as "level k minus level k-1"
# (e.g. lag2M1, condtMg in the model summaries).
options(contrasts = c("contr.sdif", "contr.sdif"))
color_scheme_set("mix-blue-red")

# NOTE(review): absolute, machine-specific working directory. All relative
# paths below ("../data/...", "fitted_models/...") resolve against it; an
# RStudio project or here::here() would make the script portable.
setwd("/Users/adkinsty/Box/side_projects/covid/modeling/")

Get data

# Raw long-format responses left-joined to zip-code metadata; all.x = TRUE
# keeps responses whose zip has no match (their joined columns are NA).
# Presumably population and state_id come from uszips.csv — confirm.
loc_data <- read.csv("../data/extra_data/loc/uszips.csv")
data <- read.csv("../data/mar28_covid_data_long.csv") %>%
    merge(loc_data, by = "zip", all.x = TRUE) %>%
    mutate(state = state_id)

# Modeling data: drop the "deaths" outcome, trim extreme estimates, turn lag
# and condition into factors, standardize continuous variables to mean 0 /
# SD 1, keep the model columns, and drop incomplete rows.
# NOTE(review): scaling runs before drop_na(), so in the final sample the
# standardized columns are only approximately mean 0 / SD 1.
m_data <- data %>%
    filter(outcome != "deaths",
           est < 1e6 & est > 25000) %>%   # remove extreme values
    mutate(
        lag = factor(delay, levels = c(1, 2, 3)),
        cond = factor(group, levels = c("g", "t", "ta")),
        y_est = rescale.center(est, 0, 1),
        pop = rescale.center(population, 0, 1),
        gen_anx = rescale.center(gen_anxiety, 0, 1),
        risk = rescale.center(risk_mu, 0, 1),
        cons = rescale.center(conserv_mu, 0, 1),
        age = rescale.center(age, 0, 1),
        covid_anx = rescale.center(corona_anxious_1, 0, 1),
        covid_news = rescale.center(corona_news_1, 0, 1),
        health = rescale.center(gen_health, 0, 1)
    ) %>%
    select(y_est, est, pop, delay, lag, cond, personal, state_id, id, outcome,
           age, risk, covid_anx, covid_news, health, cons, gen_anx) %>%
    tidyr::drop_na()

Priors and formulae

# Priors. Every mu-coefficient (including the explicit `intercept` term, since
# the formula uses `0 + intercept`) gets normal(0, 0.5); the state-level SD
# gets normal(0, 1); coefficients of the sigma sub-model (log link) get
# normal(0, 1) and those of the skewness (alpha) sub-model get normal(0, 5).
priors <- c(set_prior("normal(0,.5)", class = "b"),
            set_prior("normal(0,1)", class = "sd"),
            set_prior("normal(0,1)", class = "b", dpar = "sigma"),
            set_prior("normal(0,5)", class = "b", dpar = "alpha"))

# Distributional model: the location of y_est gets a random intercept per
# state plus the fixed predictors; sigma and alpha vary by outcome and lag.
# Removed a stray "+" ("0 +  + (...)") that R only accepted as a unary plus.
m_formula <- bf(
    y_est ~ 0 + (0 + intercept | state_id) + intercept + lag + outcome + cond +
        personal + cons + risk + pop + age + health + gen_anx + covid_anx +
        covid_news,
    sigma ~ 0 + intercept + outcome + lag,
    alpha ~ 0 + intercept + outcome + lag)
# get_prior(m_formula,m_data,family=skew_normal())

Prior-only model

This model ignores the likelihood of the data (`sample_prior = "only"`) and samples only from the priors, so its summaries and predictions reflect the priors alone.

# Prior-only model: sample_prior = "only" drops the likelihood, so the draws
# come from the priors alone. With reuse = TRUE, run_model() reads the cached
# fit from the .Rds file rather than refitting.
prior_m <- run_model(
    expr = brm(
        formula = m_formula,
        data = m_data,
        family = skew_normal(),
        prior = priors,
        sample_prior = "only",
        control = mcmc_controls,
        chains = 4, iter = 10000, warmup = 5000,
        seed = 420),
    path = "fitted_models/cases_model_prior_Apr3",
    reuse = TRUE)
prior_m
##  Family: skew_normal 
##   Links: mu = identity; sigma = log; alpha = identity 
## Formula: y_est ~ 0 + intercept + lag + outcome + cond + personal + cons + risk + pop + age + health + gen_anx + covid_anx + covid_news + (0 + intercept | state_id) 
##          sigma ~ 0 + intercept + outcome + lag
##          alpha ~ 0 + intercept + outcome + lag
##    Data: m_data (Number of observations: 4764) 
## Samples: 4 chains, each with iter = 10000; warmup = 5000; thin = 1;
##          total post-warmup samples = 20000
## 
## Group-Level Effects: 
## ~state_id (Number of levels: 50) 
##               Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sd(intercept)     0.80      0.61     0.03     2.26 1.00    18622    10005
## 
## Population-Level Effects: 
##                            Estimate Est.Error l-95% CI u-95% CI Rhat
## intercept                     -0.00      0.50    -0.99     0.98 1.00
## lag2M1                         0.00      0.50    -0.99     0.99 1.00
## lag3M2                         0.00      0.50    -0.97     0.97 1.00
## outcomecCasesMaCases           0.00      0.50    -0.97     0.97 1.00
## condtMg                       -0.00      0.50    -0.97     0.97 1.00
## condtaMt                      -0.00      0.50    -0.99     0.98 1.00
## personal                       0.00      0.50    -0.97     0.98 1.00
## cons                           0.00      0.49    -0.96     0.96 1.00
## risk                          -0.00      0.50    -0.97     0.99 1.00
## pop                           -0.00      0.50    -0.97     0.96 1.00
## age                           -0.00      0.49    -0.97     0.96 1.00
## health                         0.00      0.50    -0.97     0.99 1.00
## gen_anx                       -0.00      0.50    -0.98     0.98 1.00
## covid_anx                     -0.00      0.49    -0.97     0.97 1.00
## covid_news                    -0.00      0.50    -1.00     1.01 1.00
## sigma_intercept                0.00      1.01    -1.98     1.98 1.00
## sigma_outcomecCasesMaCases    -0.00      0.99    -1.96     1.96 1.00
## sigma_lag2M1                  -0.00      0.99    -1.95     1.95 1.00
## sigma_lag3M2                  -0.00      0.99    -1.93     1.94 1.00
## alpha_intercept               -0.01      5.02    -9.83     9.76 1.00
## alpha_outcomecCasesMaCases    -0.04      5.03    -9.96     9.86 1.00
## alpha_lag2M1                   0.02      4.93    -9.70     9.62 1.00
## alpha_lag3M2                  -0.03      4.99    -9.84     9.68 1.00
##                            Bulk_ESS Tail_ESS
## intercept                     46311    14507
## lag2M1                        47445    13804
## lag3M2                        50269    14352
## outcomecCasesMaCases          47606    13689
## condtMg                       50632    14098
## condtaMt                      48681    13787
## personal                      47287    13804
## cons                          47996    13501
## risk                          49838    14482
## pop                           50112    12843
## age                           46694    13628
## health                        44571    14785
## gen_anx                       46953    12615
## covid_anx                     46409    14499
## covid_news                    43265    13040
## sigma_intercept               46908    13647
## sigma_outcomecCasesMaCases    49963    13858
## sigma_lag2M1                  45910    14749
## sigma_lag3M2                  47287    14360
## alpha_intercept               49534    14236
## alpha_outcomecCasesMaCases    47839    13699
## alpha_lag2M1                  51446    13665
## alpha_lag3M2                  48804    13768
## 
## Samples were drawn using sampling(NUTS). For each parameter, Eff.Sample 
## is a crude measure of effective sample size, and Rhat is the potential 
## scale reduction factor on split chains (at convergence, Rhat = 1).
# launch_shinystan(prior_m) # for diagnostics

Prior predictive check

# Prior predictive check: overlay the observed y_est density with 200 randomly
# chosen prior-predictive datasets (rows of prior_yrep).
# coord_cartesian() zooms to [-5, 10] without discarding draws. The previous
# xlim() removed out-of-range values BEFORE the densities were estimated,
# which distorted the curves (and raised "Removed ... rows" warnings).
prior_y <- prior_m$data$y_est
prior_yrep <- posterior_predict(prior_m, nsamples = 4000)
prior_dens <- ppc_dens_overlay(prior_y,
                               prior_yrep[sample(seq_len(nrow(prior_yrep)), 200), ],
                               size = .1, alpha = .5) +
    coord_cartesian(xlim = c(-5, 10))
ggsave(filename = "../visualization/figures/bayesian/prior_pc_case_model_Apr3.pdf",
       plot = prior_dens, units = "in", height = 4, width = 4)
prior_dens

Posterior model

# Posterior model: refit prior_m via update(), changing only sample_prior to
# "no" so that the data likelihood is included. With reuse = TRUE,
# run_model() reads the cached fit from the .Rds file rather than refitting.
m <- run_model(
    expr = update(prior_m, sample_prior = "no"),
    path = "fitted_models/cases_model_Apr3",
    reuse = TRUE)
# launch_shinystan(m) # for diagnostics
m
##  Family: skew_normal 
##   Links: mu = identity; sigma = log; alpha = identity 
## Formula: y_est ~ 0 + intercept + lag + outcome + cond + personal + cons + risk + pop + age + health + gen_anx + covid_anx + covid_news + (0 + intercept | state_id) 
##          sigma ~ 0 + intercept + outcome + lag
##          alpha ~ 0 + intercept + outcome + lag
##    Data: m_data (Number of observations: 4764) 
## Samples: 4 chains, each with iter = 10000; warmup = 5000; thin = 1;
##          total post-warmup samples = 20000
## 
## Group-Level Effects: 
## ~state_id (Number of levels: 50) 
##               Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sd(intercept)     0.07      0.01     0.04     0.09 1.00     5971    10592
## 
## Population-Level Effects: 
##                            Estimate Est.Error l-95% CI u-95% CI Rhat
## intercept                      0.10      0.02     0.07     0.13 1.00
## lag2M1                         0.43      0.02     0.39     0.47 1.00
## lag3M2                         0.26      0.03     0.21     0.32 1.00
## outcomecCasesMaCases          -0.33      0.02    -0.36    -0.29 1.00
## condtMg                        0.02      0.02    -0.01     0.05 1.00
## condtaMt                       0.05      0.01     0.02     0.08 1.00
## personal                      -0.01      0.02    -0.05     0.03 1.00
## cons                          -0.01      0.01    -0.02     0.00 1.00
## risk                          -0.01      0.01    -0.02    -0.00 1.00
## pop                            0.01      0.01    -0.00     0.02 1.00
## age                           -0.01      0.01    -0.02     0.00 1.00
## health                        -0.02      0.01    -0.03    -0.01 1.00
## gen_anx                       -0.01      0.01    -0.03     0.00 1.00
## covid_anx                     -0.01      0.01    -0.02     0.01 1.00
## covid_news                     0.04      0.01     0.03     0.06 1.00
## sigma_intercept               -0.31      0.01    -0.33    -0.29 1.00
## sigma_outcomecCasesMaCases    -0.40      0.02    -0.44    -0.36 1.00
## sigma_lag2M1                   0.53      0.03     0.48     0.58 1.00
## sigma_lag3M2                   0.25      0.03     0.20     0.30 1.00
## alpha_intercept               10.17      0.54     9.17    11.28 1.00
## alpha_outcomecCasesMaCases    -2.61      0.90    -4.42    -0.87 1.00
## alpha_lag2M1                   2.39      0.94     0.55     4.22 1.00
## alpha_lag3M2                   1.27      1.20    -0.98     3.73 1.00
##                            Bulk_ESS Tail_ESS
## intercept                      7425    10876
## lag2M1                        10210    13843
## lag3M2                        11573    13855
## outcomecCasesMaCases          15127    16092
## condtMg                       16260    15669
## condtaMt                      15996    16116
## personal                      21013    15465
## cons                          23096    15974
## risk                          21111    15709
## pop                           19425    13938
## age                           21362    14798
## health                        20197    16205
## gen_anx                       16222    15323
## covid_anx                     15285    14563
## covid_news                    18394    15992
## sigma_intercept               16148    15727
## sigma_outcomecCasesMaCases    15528    15515
## sigma_lag2M1                  11736    14118
## sigma_lag3M2                  11163    13325
## alpha_intercept               15449    14034
## alpha_outcomecCasesMaCases    18791    15695
## alpha_lag2M1                  15988    14543
## alpha_lag3M2                  17370    14968
## 
## Samples were drawn using sampling(NUTS). For each parameter, Eff.Sample 
## is a crude measure of effective sample size, and Rhat is the potential 
## scale reduction factor on split chains (at convergence, Rhat = 1).

Posterior predictive check

# Posterior predictive check: overlay the observed y_est density with 200
# randomly chosen posterior-predictive datasets (rows of yrep).
# coord_cartesian() zooms to [-5, 10] without discarding draws; xlim() would
# drop out-of-range values before the densities are estimated.
y <- m$data$y_est
yrep <- posterior_predict(m, nsamples = 4000)
dens <- ppc_dens_overlay(y, yrep[sample(seq_len(nrow(yrep)), 200), ],
                         size = .1, alpha = .5) +
    coord_cartesian(xlim = c(-5, 10))
ggsave(filename = "../visualization/figures/bayesian/post_pc_case_model_Apr3.pdf",
       plot = dens, units = "in", height = 4, width = 4)
dens

Posterior parameter estimates

# Ridge plot of posterior densities (inner 89% interval shaded) for a subset
# of population-level coefficients, with an orange reference line at 0.
# NOTE(review): params[5:15] assumes parnames(m) starts with b_intercept, the
# two lag contrasts and the outcome contrast, so that 5:15 spans condtMg ...
# covid_news (matching the summary order). Re-check if the formula changes.
params <- brms::parnames(m)

color_scheme_set("teal")
areas <- bayesplot::mcmc_areas_ridges(x = m, pars = params[5:15], prob = .89) +
    vline_0(colour = "orange") +
    theme_tufte(base_size = 11, base_family = "sans") +
    theme(
        axis.line.y = element_line(size = .5),
        axis.line.x = element_line(size = .5),
        legend.title = element_text(size = 12, family = "sans", face = "bold"),
        axis.text = element_text(face = "plain"),
        axis.title = element_text(size = 12, family = "sans", face = "bold"),
        panel.grid.major = element_blank(),
        panel.grid.minor = element_blank(),
        plot.title = element_text(hjust = .5, size = 13, family = "sans",
                                  face = "bold"))
# Save this plot explicitly rather than relying on ggsave()'s last_plot()
# default, which breaks as soon as another plot is built in between.
ggsave("../visualization/figures/bayesian/areas_ridges_model_Apr3.pdf",
       plot = areas, units = "in", height = 6, width = 6)
areas

Individual effects (raw data: mean unstandardized estimate `est` by each predictor, not model-based)

# Raw-data plots of the unstandardized estimate `est` against each predictor.
# Factors: mean +/- 1 SE (mean_se, passed explicitly so ggplot2 no longer
# prints "No summary function supplied"). Continuous predictors: a linear
# fit (formula given explicitly) with geom_smooth()'s default confidence band.

# Estimates for 3,6, or 9 days future
m_data %>% ggplot(aes(x = lag, y = est)) + stat_summary(fun.data = mean_se) + my_theme
# outcome estimated: confirmed vs actual cases
m_data %>% ggplot(aes(x = outcome, y = est)) + stat_summary(fun.data = mean_se) + my_theme
# data format group: graph, table, text
m_data %>% ggplot(aes(x = cond, y = est)) + stat_summary(fun.data = mean_se) + my_theme
# know someone with covid
m_data %>% ggplot(aes(x = personal, y = est)) + stat_summary(fun.data = mean_se) + my_theme
# conservatism
m_data %>% ggplot(aes(x = cons, y = est)) + geom_smooth(method = "lm", formula = y ~ x) + my_theme
# risk-aversion
m_data %>% ggplot(aes(x = risk, y = est)) + geom_smooth(method = "lm", formula = y ~ x) + my_theme
# zip-code population
m_data %>% ggplot(aes(x = pop, y = est)) + geom_smooth(method = "lm", formula = y ~ x) + my_theme
# age
m_data %>% ggplot(aes(x = age, y = est)) + geom_smooth(method = "lm", formula = y ~ x) + my_theme
# general anxiety
m_data %>% ggplot(aes(x = gen_anx, y = est)) + geom_smooth(method = "lm", formula = y ~ x) + my_theme
# general health
m_data %>% ggplot(aes(x = health, y = est)) + geom_smooth(method = "lm", formula = y ~ x) + my_theme
# anxiety about coronavirus
m_data %>% ggplot(aes(x = covid_anx, y = est)) + geom_smooth(method = "lm", formula = y ~ x) + my_theme
# watch news about coronavirus
m_data %>% ggplot(aes(x = covid_news, y = est)) + geom_smooth(method = "lm", formula = y ~ x) + my_theme