library(tidyverse)
library(rethinking)
theme_set(theme_bw())

Updating the posterior

Homework

We will cover the first 2 exercises from week 2’s problem set.

Drawing the owl steps:

  1. Research question/Estimand
  2. Scientific model
  3. Statistical model (estimator)
  4. Validate model
  5. Analyse data (i.e. getting the estimates?)

Exercise 1

From the Howell1 dataset, consider only the people younger than 13 years old. Estimate the causal association between age and weight. Assume that age influences weight through two paths. First, age influences height, and height influences weight. Second, age directly influences weight through age related changes in muscle growth and body proportions.

Draw the DAG that represents these causal relationships.

We are given (1) the research question: the causal association between age and weight. For (2), the scientific model, we draw the DAG and then write a generative simulation that takes age as an input and simulates height and weight, obeying the relationships in the DAG.

library(dagitty)
library(ggdag)
dagify(
  weight ~ height,
  height ~ age,
  weight ~ age
  ) %>%
  ggplot(aes(x = x, y = y, xend = xend, yend = yend)) +
  geom_dag_point() + 
  geom_dag_edges() +
  geom_dag_text() +
  theme_dag()

# assume no intercept for simplicity
sim_data <- function(age, b_ah = 10, b_aw = 1, b_hw = 0.1, sd_height = 4, sd_weight = 2) {
  N <- length(age)
  height <- b_ah * age + rnorm(N, 0, sd_height) # equivalently rnorm(N, b_ah * age, sd_height)
  weight <- b_aw * age + b_hw * height + rnorm(N, 0, sd_weight)
  
  return(tibble(age, height, weight))
}

age <- runif(500, 0, 12)

d_sim <- sim_data(age)  

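# qplot() is deprecated since ggplot2 3.4.0, so use ggplot() directly
ggplot(d_sim, aes(age, height)) + geom_point()

ggplot(d_sim, aes(age, weight)) + geom_point()

ggplot(d_sim, aes(height, weight)) + geom_point()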

Exercise 2

Use a linear regression to estimate the total causal effect of each year of growth on weight.

While I guess we shouldn’t do it, let’s load the data to get some info for the figures.

data(Howell1)
# select only kids
kids <- Howell1 |> filter(age < 13)

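Since age has a direct effect on weight and an indirect one through height, the total effect is the sum of both paths. To estimate it we regress weight on age alone and leave height out: height is a mediator, not a confounder, and conditioning on it would block the indirect path.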
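
To convince myself of this, here is a small sketch in base R. It assumes the same coefficients as the defaults of sim_data() (b_ah = 10, b_aw = 1, b_hw = 0.1) and uses lm() rather than quap() for speed. Under those assumptions the total effect is b_aw + b_hw * b_ah = 2 and the direct effect is b_aw = 1.

```r
set.seed(10)

# Simulate from the DAG with the same default coefficients as sim_data()
n <- 5000
b_ah <- 10; b_aw <- 1; b_hw <- 0.1
age <- runif(n, 0, 12)
height <- b_ah * age + rnorm(n, 0, 4)
weight <- b_aw * age + b_hw * height + rnorm(n, 0, 2)

# Total effect of age: leave the mediator (height) out of the regression
total_effect <- coef(lm(weight ~ age))[["age"]]

# Direct effect of age: conditioning on height blocks the indirect path
direct_effect <- coef(lm(weight ~ age + height))[["age"]]

c(total = total_effect, direct = direct_effect)
```

The first estimate should land near 2 and the second near 1. The second is noisier because age and height are highly collinear in this simulation.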

First, let’s set up our priors for the intercept and slope and explore what they look like.

N <- 20

a <- rnorm(N, 2.5, 1) # newborns weigh about 2.5 kg on average?
b <- runif(N, 2, 8) # on average, kids put on 2-8 kg every year?


kids |> 
  ggplot(
    aes(x = age, y = weight)
  ) + 
  geom_point(alpha = 0.5) +
  coord_cartesian(
    xlim = range(kids$age),
    ylim = range(kids$weight)
  ) +
  geom_abline(
    data = tibble(a, b), aes(intercept = a, slope = b), alpha = 0.5
  ) 
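
A numeric prior predictive check can complement the plot. The sketch below assumes the same two priors as above and asks what average weight they imply for a 12-year-old. The names n_prior and w12 are just for illustration.

```r
set.seed(5)

# Draw many intercept/slope pairs from the priors used for the plot
n_prior <- 1e4
a_prior <- rnorm(n_prior, 2.5, 1)
b_prior <- runif(n_prior, 2, 8)

# Implied expected weight (kg) at age 12 for each prior draw
w12 <- a_prior + b_prior * 12

# Spread of the implied expected weights
quantile(w12, c(0.055, 0.5, 0.945))
```

The prior mean of b is 5, so the implied weights are centered around 2.5 + 5 * 12 = 62.5 kg. Comparing that with range(kids$weight) is a quick way to judge whether the slope prior is too steep.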

(3) Statistical model/estimator and (4) validation. We’ll use quadratic approximation. Can anyone explain it in more detail? As far as I understand, it is a second-order (rather than first-order) Taylor expansion of the log-posterior around its mode. A quadratic log-density is a Gaussian, so the approximate posterior is normal by construction.
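
To make the idea concrete, here is a hand-rolled quadratic approximation in base R. It is only a sketch: the data are simulated with made-up values, and optim() stands in for what I assume quap() does internally. It finds the posterior mode and then uses the curvature (Hessian) at the mode as the inverse covariance of a Gaussian.

```r
set.seed(1)

# Simulated data (hypothetical values, not the Howell1 data)
n <- 200
age <- runif(n, 0, 12)
weight <- 2 * age + rnorm(n, 0, 2)

# Negative log-posterior for weight ~ Normal(a + b * age, sigma)
# with flat priors a ~ U(-5, 10), b ~ U(0, 5), sigma ~ U(0, 10)
neg_log_post <- function(par) {
  a <- par[[1]]; b <- par[[2]]; sigma <- par[[3]]
  log_prior <- dunif(a, -5, 10, log = TRUE) +
    dunif(b, 0, 5, log = TRUE) +
    dunif(sigma, 0, 10, log = TRUE)
  # Outside the prior support: return a large penalty instead of Inf
  if (!is.finite(log_prior)) return(1e10)
  -(sum(dnorm(weight, a + b * age, sigma, log = TRUE)) + log_prior)
}

# Step 1: find the posterior mode; step 2: get the curvature at the mode
fit <- optim(
  c(a = 0, b = 1, sigma = 1), neg_log_post,
  hessian = TRUE, control = list(maxit = 5000, reltol = 1e-12)
)

post_mode <- fit$par
# The inverse Hessian of the negative log-posterior is the Gaussian covariance
post_sd <- sqrt(diag(solve(fit$hessian)))
names(post_sd) <- names(post_mode)

round(cbind(mode = post_mode, sd = post_sd), 3)
```

With flat priors the mode for a and b coincides with the least-squares estimates, which is why quap() and lm() agree so closely in the comparison that follows.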

We’ll apply quap to the simulated data. Note that the model below uses flat priors that differ from the ones plotted above: a ~ Uniform(-5, 10) and b ~ Uniform(0, 5).

prior <-  alist(
  weight ~ dnorm(mu, sigma),
  mu <- a + b * age,
  a ~ dunif(-5, 10),
  b ~ dunif(0, 5),
  sigma ~ dunif(0, 10)
)

# Let's test it on our simulated data
m <- quap(prior, data = d_sim)
precis(m, prob = 0.95)
##          mean     sd   2.5% 97.5%
## a     -0.0839 0.1835 -0.444 0.276
## b      2.0048 0.0265  1.953 2.057
## sigma  2.0553 0.0650  1.928 2.183
# compare it to a linear regression
lm(
  weight ~ age, data = d_sim
  ) |>  
  broom::tidy() |> 
  select(term, mean = estimate, sd = std.error) |> 
  mutate(
    `2.5%` = mean - 1.96 * sd,
    `97.5%` = mean + 1.96 * sd
  )
## # A tibble: 2 × 5
##   term           mean     sd `2.5%` `97.5%`
##   <chr>         <dbl>  <dbl>  <dbl>   <dbl>
## 1 (Intercept) -0.0836 0.184  -0.444   0.277
## 2 age          2.00   0.0266  1.95    2.06
# playing with the sample size

# small sample
m <- quap(prior, data = d_sim |> slice_sample(n = 10, replace = TRUE))
precis(m) # still recovers b well, because age and weight are very strongly correlated in the simulated data
##        mean    sd  5.5% 94.5%
## a     -1.32 1.721 -4.07  1.43
## b      2.01 0.291  1.54  2.47
## sigma  2.91 0.651  1.87  3.95
# very large sample
m <- quap(prior, data = d_sim |> slice_sample(n = 5000, replace = TRUE))
precis(m)
##         mean      sd   5.5% 94.5%
## a     -0.061 0.05815 -0.154 0.032
## b      2.008 0.00845  1.994 2.021
## sigma  2.056 0.02056  2.023 2.089

Question - is this a good way to validate a model? In my experience, you would do it on many samples and check bias and MSE.
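
Here is a sketch of what I mean by checking bias and MSE over many samples. It assumes the default coefficients of sim_data(), so the true total effect is 1 + 0.1 * 10 = 2, and it uses lm() as a fast stand-in for quap(), since the two gave nearly identical estimates above. The helper sim_weight() and the choice of 1000 replications of size 50 are mine.

```r
set.seed(2)

# Generative simulation with the same structure as sim_data(), in base R
sim_weight <- function(n, b_ah = 10, b_aw = 1, b_hw = 0.1,
                       sd_height = 4, sd_weight = 2) {
  age <- runif(n, 0, 12)
  height <- b_ah * age + rnorm(n, 0, sd_height)
  weight <- b_aw * age + b_hw * height + rnorm(n, 0, sd_weight)
  data.frame(age, weight)
}

# True total effect of age implied by the default coefficients
true_total <- 1 + 0.1 * 10

# Refit the model on many independent simulated samples of size 50
est <- replicate(
  1000,
  coef(lm(weight ~ age, data = sim_weight(50)))[["age"]]
)

# Bias and root mean squared error of the slope estimate
bias <- mean(est) - true_total
rmse <- sqrt(mean((est - true_total)^2))

c(bias = bias, rmse = rmse)
```

A bias close to zero and an RMSE that shrinks as the sample size grows would be more convincing than a single good fit.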

(5) Analyse data. We basically do the same thing as in the validation, but on the real data.

m <- quap(prior, data = kids)
precis(m)
##       mean     sd 5.5% 94.5%
## a     7.45 0.3623 6.87  8.03
## b     1.34 0.0548 1.25  1.43
## sigma 2.52 0.1477 2.29  2.76

First law of statistical interpretation: the parameters are not independent of each other. What does it mean? From the book: “In the practice problems at the end of the chapter, you’ll see that the lack of covariance among the parameters results from centering.” Here age is not centered, which is why a and b come out strongly negatively correlated below.

vcov(m) |> round(2)
##           a     b sigma
## a      0.13 -0.02  0.00
## b     -0.02  0.00  0.00
## sigma  0.00  0.00  0.02
cov2cor(vcov(m))  |> round(2)
##           a     b sigma
## a      1.00 -0.82     0
## b     -0.82  1.00     0
## sigma  0.00  0.00     1
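
My reading of the centering remark: with uncentered age the intercept sits at age 0, far from the bulk of the data, so a steeper line must have a lower intercept. The sketch below illustrates this on simulated data (values loosely inspired by the estimates above, but made up), using lm() instead of quap().

```r
set.seed(3)

# Simulated data with hypothetical coefficients
age <- runif(150, 0, 12)
weight <- 7 + 1.3 * age + rnorm(150, 0, 2.5)

# Uncentered predictor: intercept is the expected weight at age 0
fit_raw <- lm(weight ~ age)

# Centered predictor: intercept is the expected weight at the mean age
age_c <- age - mean(age)
fit_centered <- lm(weight ~ age_c)

# Correlation between the intercept and slope estimates in each fit
cor_raw <- cov2cor(vcov(fit_raw))[1, 2]
cor_centered <- cov2cor(vcov(fit_centered))[1, 2]

round(c(raw = cor_raw, centered = cor_centered), 2)
```

The slope is the same in both fits. Only the meaning of the intercept changes, along with its correlation with the slope.
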
post <- extract.samples(m)


kids |> 
  ggplot(
    aes(x = age, y = weight)
  ) + 
  geom_point(alpha = 0.5) +
  coord_cartesian(
    xlim = range(kids$age),
    ylim = range(kids$weight)
  ) +
  geom_abline(
    data = post |> slice_sample(n = 20), aes(intercept = a, slope = b), alpha = 0.5
  ) 

# what if we had only a small sample? (quick helper function)
plot_post <- function(data) {
  m <- quap(prior, data = data)
  post <- extract.samples(m)
  data |> 
    ggplot(
      aes(x = age, y = weight)
    ) + 
    geom_point(alpha = 0.5) +
    coord_cartesian(
      xlim = range(kids$age),
      ylim = range(kids$weight)
    ) +
    geom_abline(
      data = post |> slice_sample(n = 20), aes(intercept = a, slope = b), alpha = 0.5
    ) 
}

kids |> slice_sample(n = 5) |> plot_post()

kids |> slice_sample(n = 2000, replace = TRUE) |> plot_post()

age_seq <- seq(0, 12)

weight_postpred <- rethinking::sim(m, data = list(age = age_seq))

# compute percentile intervals (89% by default)
weight_PI <- apply(weight_postpred, 2, PI)

data_pi <- tibble(
  lb = weight_PI[1, ],
  ub = weight_PI[2, ],
  age = age_seq
)

kids |> plot_post()  + 
  geom_line(data = data_pi, aes(x = age, y = lb), linetype = "dashed") + 
  geom_line(data = data_pi, aes(x = age, y = ub), linetype = "dashed")

# posterior distribution of the expected weight (mu) at age 10
mu_at_10 <- post$a + post$b * 10
dens(mu_at_10)
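
The density of mu_at_10 is much narrower than the dashed prediction band in the plot, because it only reflects uncertainty about the mean, not the scatter of individual kids around it. The sketch below shows the difference without needing the fitted model. As a stand-in for extract.samples(m), it draws from a multivariate normal built from the precis() and correlation output above; that reconstruction is my assumption, and it is stored in post_approx so it does not overwrite post.

```r
set.seed(4)

# Posterior means, sds and the a-b correlation, copied from the output above
post_mean <- c(a = 7.45, b = 1.34, sigma = 2.52)
post_sds <- c(0.3623, 0.0548, 0.1477)
R <- diag(3)
R[1, 2] <- R[2, 1] <- -0.82
Sigma <- diag(post_sds) %*% R %*% diag(post_sds)

# Draw correlated normal samples via the Cholesky factor of Sigma
n_draws <- 1e4
z <- matrix(rnorm(n_draws * 3), ncol = 3)
draws <- sweep(z %*% chol(Sigma), 2, post_mean, "+")
colnames(draws) <- names(post_mean)
post_approx <- as.data.frame(draws)

# Uncertainty about the expected weight at age 10
mu_10 <- post_approx$a + post_approx$b * 10

# Posterior predictive: add individual-level noise with sd sigma
weight_10 <- rnorm(n_draws, mu_10, post_approx$sigma)

# 89% percentile intervals, matching the default of PI()
probs <- c(0.055, 0.945)
rbind(mu = quantile(mu_10, probs), weight = quantile(weight_10, probs))
```

The interval for mu is roughly 1 kg wide, while the predictive interval spans several kg, which is what the dashed lines in the plot show.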