DID Machine Learning Approaches

Author

Dor Leventer

Published

March 8, 2023

This document is mainly code-oriented: its purpose is to show, from scratch, how to estimate a residualized difference-in-differences (DID) regression.

Setup

# general setup
rm(list = ls()) # del all objects and functions
gc() # cleans memory
options(scipen = 999) # tell R not to use scientific notation
options(digits = 5) # controls how many digits are printed by default

# libraries
# library() is used instead of require(): require() only returns FALSE for a
# missing package, so the script would fail later at the first unresolved
# function; library() stops right here with a clear error.
pac.vec <-
  c(
    "did",
    "broom",
    "data.table",
    "tidyverse",
    "grf",
    "patchwork"
  )
# invisible(): the list returned by lapply() is not worth printing
invisible(lapply(pac.vec, library, character.only = TRUE))

# setup for ggplot graph layout:
# every plot starts from theme_bw(); on top of it enlarge axis/title/legend/facet
# text, centre the (sub)titles, draw gray major grid lines and box the legend.
theme_set(theme_bw())
invisible(local({
  # shorthand for a sized text element
  sized_text <- function(size) element_text(size = size)
  theme_update(
    axis.text = sized_text(14),
    axis.title = sized_text(16),
    plot.title = element_text(hjust = 0.5, size = 20),
    plot.subtitle = element_text(hjust = 0.5, size = 18),
    panel.grid.major.y = element_line(color = "gray"),
    panel.grid.major.x = element_line(color = "gray"),
    legend.text = sized_text(14),
    legend.background = element_rect(
      fill = "white", colour = "black", linewidth = 0.2, linetype = "solid"
    ),
    strip.text.x = sized_text(16),
    strip.text.y = sized_text(16)
  )
}))

## create custom theme for ggplot
# theme_nice(): theme_bw() at a given base font size, without minor grid
# lines, with size-12 axis text and horizontally centred title/subtitle.
#
# base_size  base font size passed to theme_bw() (default 16, the size this
#            theme was originally fixed to)
# Returns a ggplot2 theme object to be added to a plot with `+`.
theme_nice <- function(base_size = 16) {
  theme_bw(base_size = base_size) +
    theme(
      panel.grid.minor = element_blank(),
      axis.text = element_text(size = 12),
      plot.title = element_text(hjust = 0.5),
      plot.subtitle = element_text(hjust = 0.5)
    )
}

# setup for graph colors (hex codes)
orange <- "#f46d43"
lightorange <- "#fdae61"
blue <- "#0072B2"
lightblue <- "#74add1"

# figure width / height used when saving graphs for slides
width <- 8
height <- 4

# parameters: fix the RNG state so the document is reproducible
set.seed(1L)

The data-generating process (DGP)

# Simulate a balanced unit-by-period panel for the DID exercise.
#
# nobs        number of units
# years       c(first, last) period; every unit is observed in each period
# time_treat  period in which treated units start being treated (their cohort)
# rho         value used for all six pairwise entries of `cor` passed to
#             gendata::genmvnorm() for the four covariates (presumably the
#             pairwise correlations -- see gendata docs)
# seed        passed to set.seed(); note this resets the global RNG state
#
# Columns added (in this order): X1..X4 (unit-level covariates), logit_error,
# treat_prop, treat (1 if treat_prop > .5), cohort (time_treat if treated,
# otherwise 99), time_fe, unit_fe, tau (per-period effect, nonzero only for
# treated units from time_treat on), cumtau (running sum of tau within unit),
# error, Y0 (untreated outcome), Y (observed outcome).
# Returns a data.table.
make_data <- function(nobs = 1000,
                      years = c(1, 10),
                      time_treat = 5,
                      rho = 0.2,
                      seed = 1) {
  set.seed(seed)

  # panel skeleton: all (id, time) combinations, sorted by id then time
  panel <- CJ(id = 1:nobs, time = years[1]:years[2])

  # time-invariant covariates, one row per unit; 4 variables -> 6 pairs
  covars <- as.data.table(
    gendata::genmvnorm(cor = rep(rho, 6), k = 4, n = nobs)
  )
  covars[, id := 1:nobs]
  panel <- panel[covars, on = "id"]

  # treatment: logistic index in X3, X4 plus one shock per unit, thresholded
  panel[, logit_error := rnorm(1), by = "id"]
  panel[, treat_prop := 1 / (1 + exp(1)^(
    -(-0.5 + (X3 > 0) + 0.5 * X3 * X4 + logit_error)
  ))]
  panel[, treat := as.numeric(treat_prop > .5)]
  panel[, cohort := fifelse(treat == 1, time_treat, 99)]

  # fixed effects: one draw per period (mean time / 5) and one per unit
  # (mean = treat, so treated units sit higher on average)
  panel[, time_fe := rnorm(1, mean = time / 5, sd = 1), by = "time"]
  panel[, unit_fe := rnorm(1, mean = treat, sd = 1), by = "id"]

  # heterogeneous effect, accumulated within unit; cumsum relies on rows being
  # ordered by time within id (as CJ() builds them)
  panel[, tau := fifelse(time >= cohort & treat == 1,
                         as.numeric(X1 > 0) + as.numeric(X2 > 0), 0)]
  panel[, cumtau := cumsum(tau), by = "id"]

  # outcomes: untreated outcome plus accumulated effect plus N(0, 0.5^2) noise
  panel[, error := rnorm(.N, 0, .5)]
  panel[, Y0 := unit_fe + time_fe + 0.2 * (X4 > 0) + 0.1 * X3^2 + 0.1 * X2 * X1]
  panel[, Y := Y0 + cumtau + error]

  panel
}

Look at the data

data <- tibble(make_data(seed = 1))
# number of control (treat = 0) and treated (treat = 1) units, using period 1
count(filter(data, time == 1), treat)
# A tibble: 2 × 2
  treat     n
  <dbl> <int>
1     0   491
2     1   509
data <- make_data(seed = 1) |> tibble()

# Attach each (cohort, time) cell's mean outcome to every unit-period row.
# .groups = "drop" so `data` is not left grouped by (cohort, time) afterwards.
data <- data |>
  group_by(cohort, time) |>
  summarise(Y_mean = mean(Y), .groups = "drop") |>
  right_join(data, by = c("cohort", "time"))

# One row per (arm, period) for the mean paths, instead of drawing the same
# mean once per unit.
arm_means <- data |> distinct(treat, time, Y_mean)

# Individual trajectories (thin, transparent) with the arm means on top;
# the dashed line marks the start of treatment.
desc_plot <- ggplot() +
  geom_line(data = data |> filter(treat == 1),
            mapping = aes(x = time, y = Y, group = id),
            color = "#FC9272", alpha = .2, linewidth = .1) +
  geom_line(data = data |> filter(treat == 0),
            mapping = aes(x = time, y = Y, group = id),
            color = "#9ECAE1", alpha = .2, linewidth = .1) +
  geom_line(data = arm_means |> filter(treat == 1),
            mapping = aes(x = time, y = Y_mean, color = "Treated"),
            alpha = .8, linewidth = 2) +
  geom_line(data = arm_means |> filter(treat == 0),
            mapping = aes(x = time, y = Y_mean, color = "Control"),
            alpha = .8, linewidth = 2) +
  theme_nice() +
  labs(x = "Time Period") +
  geom_vline(xintercept = 4.5, linetype = "dashed") +
  scale_x_continuous(breaks = seq(1, 10, 1)) +
  scale_color_manual(name = NULL, breaks = c("Treated", "Control"),
                     values = c("#CB181D", "#2171B5"))
desc_plot

# Save the plot object explicitly (not whatever last_plot() happens to be),
# using the slide dimensions defined in the setup.
ggsave(
  filename = "grf_desc.pdf",
  plot = desc_plot,
  width = width,
  height = height
)

Into the DID R-learner

Data and estimand

data <- make_data()
# True ATT(5, t) for every period t: the mean accumulated effect `cumtau`
# among units in cohort 5 (first treated in period 5).
data |> filter(cohort == 5) |> group_by(time) |> summarise(att = mean(cumtau))
# A tibble: 10 × 2
    time   att
   <int> <dbl>
 1     1  0   
 2     2  0   
 3     3  0   
 4     4  0   
 5     5  1.10
 6     6  2.20
 7     7  3.29
 8     8  4.39
 9     9  5.49
10    10  6.59
# Benchmark on the full panel: Callaway & Sant'Anna group-time ATTs
did::att_gt(
  yname = "Y",
  tname = "time",
  idname = "id",
  gname = "cohort",
  data = data
)

Call:
did::att_gt(yname = "Y", tname = "time", idname = "id", gname = "cohort", 
    data = data)

Reference: Callaway, Brantly and Pedro H.C. Sant'Anna.  "Difference-in-Differences with Multiple Time Periods." Journal of Econometrics, Vol. 225, No. 2, pp. 200-230, 2021. <https://doi.org/10.1016/j.jeconom.2020.12.001>, <https://arxiv.org/abs/1803.09015> 

Group-Time Average Treatment Effects:
 Group Time ATT(g,t) Std. Error [95% Simult.  Conf. Band]  
     5    2  -0.0141     0.0444       -0.1248      0.0966  
     5    3  -0.0439     0.0426       -0.1503      0.0624  
     5    4   0.0260     0.0467       -0.0904      0.1425  
     5    5   1.1239     0.0578        0.9798      1.2679 *
     5    6   2.2386     0.0816        2.0351      2.4422 *
     5    7   3.3127     0.1132        3.0304      3.5951 *
     5    8   4.4067     0.1488        4.0355      4.7778 *
     5    9   5.5607     0.1827        5.1049      6.0165 *
     5   10   6.5845     0.2153        6.0476      7.1215 *
---
Signif. codes: `*' confidence band does not cover 0

P-value for pre-test of parallel trends assumption:  0.60754
Control Group:  Never Treated,  Anticipation Periods:  0
Estimation Method:  Doubly Robust

Functions for the estimator

# Choose the xgboost objective for a label vector: "binary:logistic" for a 0/1
# indicator that takes both values, "reg:squarederror" otherwise. A two-valued
# label that is not coded 0/1 (e.g. 1/2) is treated as a regression target,
# since the logistic objective requires labels in [0, 1].
ml_objective <- function(label) {
  is_binary <- length(unique(label)) == 2 && all(label %in% c(0, 1))
  if (is_binary) "binary:logistic" else "reg:squarederror"
}

# Fit a gradient-boosted tree model (xgboost) of `label` on `data`.
#
# label      numeric response vector
# data       covariate matrix, one row per element of `label`
# max_depth, eta, nrounds
#            xgboost tuning parameters (defaults: 2, 1, 10)
# Returns the fitted xgboost model; the objective is picked by ml_objective().
ml_train <- function(label, data, max_depth = 2, eta = 1, nrounds = 10) {
  xgboost::xgboost(
    data = data, label = label,
    max_depth = max_depth, eta = eta, nrounds = nrounds,
    objective = ml_objective(label), verbose = FALSE
  )
}

# Predict from a fitted model (e.g. the xgboost model returned by ml_train())
# at covariates X; dispatches on the model's class via predict().
ml_pred <- function(mod, X) predict(mod, X)

K-fold cross-fitting

# Treatment-effect estimate on fold k in residual-on-residual (R-learner) form.
# Models of Y and of W given X are trained on all other folds, evaluated on
# fold k, and the estimate is the no-intercept slope of the outcome residual
# on the treatment residual:
#   tau_k = sum((Y - mu_hat) * (W - p_hat)) / sum((W - p_hat)^2)
#
# k      fold to hold out
# folds  fold label per observation (same length as Y)
# Y, W   outcome and treatment vectors
# X      covariate matrix, one row per observation
# Returns a single numeric value (NaN if fold k is empty).
single_fold <- function(k, folds, Y, W, X) {
  in_k <- folds == k
  # drop = FALSE keeps X a matrix when a fold (or its complement) has a single
  # row, or when X has a single column; otherwise it collapses to a vector
  X_train <- X[!in_k, , drop = FALSE]
  X_test <- X[in_k, , drop = FALSE]

  mu_hat_train <- ml_train(Y[!in_k], X_train)
  p_hat_train <- ml_train(W[!in_k], X_train)

  mu_hat_test <- ml_pred(mu_hat_train, X_test)
  p_hat_test <- ml_pred(p_hat_train, X_test)

  Y_res <- Y[in_k] - mu_hat_test
  W_res <- W[in_k] - p_hat_test
  sum(Y_res * W_res) / sum(W_res^2)
}
# Cross-fitted R-learner estimate of the treatment effect.
# Observations are randomly assigned to K folds; each fold's estimate comes
# from single_fold() (models trained on the other folds), and fold estimates
# are averaged with weights proportional to fold size.
#
# K        number of folds
# delta_Y  outcome per unit (here the first difference Y_post - Y_pre)
# W        treatment indicator, same length as delta_Y
# X        covariate matrix, one row per unit
# Returns a single numeric estimate.
cross_fit_est <- function(K = 2, delta_Y, W, X) {
  n <- length(delta_Y)
  folds <- sample(seq_len(K), n, replace = TRUE)
  # Random assignment can leave a fold empty (small n or large K). An empty
  # fold has no test data, so only folds that received observations are used.
  used <- sort(unique(folds))
  # anonymous function rather than named `...` args: vapply() has its own
  # formal called `X`, which would capture an `X = X` argument
  taus <- vapply(used, function(k) single_fold(k, folds, delta_Y, W, X),
                 numeric(1))
  sizes <- tabulate(folds, nbins = K)[used]
  sum(taus * sizes) / sum(sizes)
}

Monte Carlo simulation

# One Monte Carlo draw: simulate a panel with `seed`, keep periods year_pre and
# year_post, and estimate the treated units' ATT in year_post with several
# estimators.
#
# seed       passed to make_data()
# year_pre   pre-treatment period used for the first difference
# year_post  period whose ATT is estimated
# Returns a one-row tibble: true_att, then one column per estimator
# (ols_est, cs_est, debias1_est, debias2_est, grf_est, debias_cf_est).
single_iteration <- function(seed, year_pre = 4, year_post = 5) {
  data <- make_data(seed = seed) |>
    filter(time %in% c(year_pre, year_post)) |>
    mutate(W = treat) |>
    tibble()

  # True ATT: mean accumulated effect among treated units in year_post
  true_att <- data |>
    filter(W == 1, time == year_post) |>
    summarise(att = mean(cumtau)) |>
    pull(att)

  # Callaway & Sant'Anna estimator with covariates, evaluated at year_post
  cs_est <- did::att_gt(yname = "Y", tname = "time", idname = "id",
                        gname = "cohort", xformla = ~ X1 + X2 + X3 + X4,
                        data = data) |>
    broom::tidy() |>
    filter(time == year_post) |>
    pull(estimate)

  # One row per unit in each period, explicitly aligned by id so that the
  # first differences pair each unit's post and pre outcomes
  post <- data |> filter(time == year_post) |> arrange(id)
  pre <- data |> filter(time == year_pre) |> arrange(id)
  stopifnot(identical(post$id, pre$id))
  delta_Y <- post$Y - pre$Y
  W <- post$W
  X <- post |> select(X1:X4) |> as.matrix()

  # naive: regression of the first difference on treatment
  ols_est <- unname(coef(lm(delta_Y ~ W))[2])

  # outcome regression: subtract the trend predicted from control units, then
  # the no-intercept coefficient on the 0/1 W is the mean over treated units
  g_mod <- ml_train(delta_Y[W == 0], X[W == 0, , drop = FALSE])
  g_hat <- ml_pred(g_mod, X)
  delta_Y_2 <- delta_Y - g_hat
  debias1_est <- as.numeric(coef(lm(delta_Y_2 ~ W - 1))[1])

  # residual-on-residual regression with in-sample (not cross-fitted) models
  mu_mod <- ml_train(delta_Y, X)
  mu_hat <- ml_pred(mu_mod, X)
  p_mod <- ml_train(W, X)
  p_hat <- ml_pred(p_mod, X)
  delta_Y_3 <- delta_Y - mu_hat
  W_2 <- W - p_hat
  debias2_est <- as.numeric(coef(lm(delta_Y_3 ~ W_2 - 1))[1])

  # causal forest on the first differences; target.sample = "treated" makes
  # average_treatment_effect() target the ATT, the same estimand as true_att
  cf_mod <- grf::causal_forest(X = X, W = W, Y = delta_Y)
  grf_est <- grf::average_treatment_effect(
    cf_mod, target.sample = "treated"
  )[["estimate"]]

  # cross-fitted residual-on-residual estimator
  debias_cf_est <- cross_fit_est(K = 5, delta_Y, W, X)

  tibble(true_att, ols_est, cs_est, debias1_est, debias2_est, grf_est,
         debias_cf_est)
}
# Run the simulation: one single_iteration() per seed 1..num_it, stack the
# draws, and reshape to one row per (draw, estimator) with its bias.
num_it <- 1000
full <- lapply(seq_len(num_it), single_iteration, year_pre = 4, year_post = 5)
full_long <- dplyr::bind_rows(full) |>
  mutate(seed = row_number()) |>
  pivot_longer(
    cols = !c(true_att, seed),
    names_to = "method",
    values_to = "est"
  ) |>
  mutate(bias = est - true_att)

# Density of bias (estimate - true ATT) across Monte Carlo draws, one fill per
# method. coord_cartesian() only zooms the view: xlim()/ylim() would instead
# discard out-of-range draws *before* the density is estimated, distorting it.
#
# df          rows of full_long to plot (needs columns `bias` and `method`)
# xlim, ylim  visible ranges of the x and y axes
# Returns a ggplot object.
plot_bias_density <- function(df, xlim = c(-0.2, 0.2), ylim = c(0, 10)) {
  ggplot(df) +
    geom_density(aes(x = bias, fill = method), alpha = .3) +
    geom_vline(xintercept = 0, color = "black", linetype = "dashed") +
    coord_cartesian(xlim = xlim, ylim = ylim)
}

# g1: OLS, Callaway & Sant'Anna, outcome regression; g2: all other estimators
baseline_methods <- c("ols_est", "cs_est", "debias1_est")
g1 <- full_long |>
  filter(method %in% baseline_methods) |>
  plot_bias_density()
g2 <- full_long |>
  filter(!method %in% baseline_methods) |>
  plot_bias_density()

print(g1)
Warning: Removed 3 rows containing non-finite values (`stat_density()`).

print(g2)
Warning: Removed 329 rows containing non-finite values (`stat_density()`).