Set up environment

#####STEP 0-1: Reset environment #####
# Start the knitted run from an empty global environment.
rm(list = ls())
knitr::opts_chunk$set(echo = TRUE)
# HTTPS mirror: packages downloaded over plain HTTP can be tampered with in transit.
options(repos = c(CRAN = "https://cran.rstudio.com/"))

#####STEP 0-2: Install packages #####
# Packages attached for the analysis, in attach order (later packages mask
# earlier ones). Any that are missing are installed from CRAN first.
# unique() guards against accidental repeats in the list.
list.of.packages <- unique(c(
  "grf", "metafor", "splitstackshape", "dplyr", "tidyverse", "foreach",
  "cowplot", "reshape2", "doParallel", "survival", "readstata13", "ggplot2",
  "rsample", "DiagrammeR", "e1071", "pscl", "pROC", "caret", "ModelMetrics",
  "MatchIt", "Hmisc", "scales", "lmtest", "sandwich", "haven", "rpms",
  "randomForest", "fabricatr", "gridExtra", "VIM", "mice", "missForest",
  "ivreg", "kableExtra", "policytree"
))
new.packages <- list.of.packages[!(list.of.packages %in% installed.packages()[, "Package"])]
if (length(new.packages) > 0) install.packages(new.packages)

# invisible(): library() returns the attached-package list, which lapply would
# otherwise print once per package in the knitted output.
invisible(lapply(list.of.packages, library, character.only = TRUE))
## Warning: package 'grf' was built under R version 4.3.3
## Warning: package 'metafor' was built under R version 4.3.3
## Warning: package 'metadat' was built under R version 4.3.3
## Warning: package 'numDeriv' was built under R version 4.3.1
## Warning: package 'splitstackshape' was built under R version 4.3.3
## Warning: package 'dplyr' was built under R version 4.3.3
## Warning: package 'tidyverse' was built under R version 4.3.3
## Warning: package 'ggplot2' was built under R version 4.3.3
## Warning: package 'tibble' was built under R version 4.3.3
## Warning: package 'tidyr' was built under R version 4.3.3
## Warning: package 'readr' was built under R version 4.3.3
## Warning: package 'purrr' was built under R version 4.3.3
## Warning: package 'stringr' was built under R version 4.3.3
## Warning: package 'forcats' was built under R version 4.3.3
## Warning: package 'lubridate' was built under R version 4.3.3
## Warning: package 'foreach' was built under R version 4.3.3
## Warning: package 'cowplot' was built under R version 4.3.3
## Warning: package 'reshape2' was built under R version 4.3.3
## Warning: package 'doParallel' was built under R version 4.3.3
## Warning: package 'iterators' was built under R version 4.3.3
## Warning: package 'readstata13' was built under R version 4.3.3
## Warning: package 'rsample' was built under R version 4.3.3
## Warning: package 'DiagrammeR' was built under R version 4.3.3
## Warning: package 'e1071' was built under R version 4.3.3
## Warning: package 'pscl' was built under R version 4.3.3
## Warning: package 'pROC' was built under R version 4.3.3
## Warning: package 'caret' was built under R version 4.3.3
## Warning: package 'ModelMetrics' was built under R version 4.3.3
## Warning: package 'MatchIt' was built under R version 4.3.3
## Warning: package 'Hmisc' was built under R version 4.3.3
## Warning: package 'scales' was built under R version 4.3.3
## Warning: package 'lmtest' was built under R version 4.3.3
## Warning: package 'zoo' was built under R version 4.3.3
## Warning: package 'sandwich' was built under R version 4.3.3
## Warning: package 'haven' was built under R version 4.3.3
## Warning: package 'rpms' was built under R version 4.3.3
## Warning: package 'randomForest' was built under R version 4.3.3
## Warning: package 'fabricatr' was built under R version 4.3.3
## Warning: package 'gridExtra' was built under R version 4.3.3
## Warning: package 'VIM' was built under R version 4.3.3
## Warning: package 'colorspace' was built under R version 4.3.3
## Warning: package 'mice' was built under R version 4.3.3
## Warning in check_dep_version(): ABI version mismatch: 
## lme4 was built with Matrix ABI version 1
## Current Matrix ABI version is 0
## Please re-install lme4 from source or restore original 'Matrix' package
## Warning: package 'missForest' was built under R version 4.3.3
## Warning: package 'ivreg' was built under R version 4.3.3
## Warning: package 'kableExtra' was built under R version 4.3.3
## Warning: package 'policytree' was built under R version 4.3.3
# grf estimates depend on the package version, so record it in the output.
print(paste("Version of grf package:", as.character(packageVersion("grf"))))

#####STEP 0-2: Basic information #####
# Time stamp of the run.
Sys.time()
# OS / machine details and the R session (R version, attached packages),
# returned together as a named list so knitr prints both.
system_info <- Sys.info()
session_info <- sessionInfo()
setNames(list(session_info, system_info), c("session_info", "system_info"))

Set non-run-specific parameters

#####STEP 0-3: Set non-run-specific parameters #####
# grf estimates change with the thread count; the published paper used 6.
numthreadsset <- min(6, parallel::detectCores())
if (6 != numthreadsset) {
  print("the results of grf vary by num.thread (publication paper used num.thread=6)")
}

# Seeds; not used within this chunk (presumably consumed by the IV forest and
# causal forest fits elsewhere - confirm).
seedset <- 777
seed.ivforest <- 123456789
seed.cf <- 987654321

cat("number of threads (affects grf results):", numthreadsset, "\n")
## number of threads (affects grf results): 6

# Print numbers with four significant digits.
options(digits = 4)

Set run-specific parameters

#####STEP 0-4: Set run-specific parameters #####
# 1 = replicate the published paper (As_published folders); anything else = testing run.
published_paper_run <- 0
if (published_paper_run != 1) {
  print("save intermediate files into only intdat folder")
} else {
  print("save intermediate files into both intdat and intdatAsPublished folders")
  warning("Changing this setting to 1 overwrites the input files required for replicating on different platforms.")
}
## [1] "save intermediate files into only intdat folder"
forest_files <- c("cate_results_sim_sbp_neg_alpha_5_presentation_cw0.RData") # ANALYST FORM

cw_run_name <- "med" # ANALYST FORM
cw_grid <- 5 # ANALYST FORM

# Sub-folder name shared by inputs and outputs, e.g. "cw_med_5".
cw_name <- sprintf("cw_%s_%s", cw_run_name, cw_grid)

Set file paths

#####STEP 0-6: Set file paths #####
# All three locations differ only in the run folder: "As_published" when
# replicating the paper, "Testing" otherwise.
processedpath <- file.path(
  "PP_Full_Analysis/Intermediate_data",
  if (published_paper_run == 1) "As_published" else "Testing",
  "empirical",
  cw_name
)

originaldatapath <- file.path(
  "PP_Full_Analysis/Cleaned_input_data",
  if (published_paper_run == 1) "As_published" else "Testing",
  "empirical",
  cw_name
)

# Tables and figures are not split by "empirical".
results_dir <- file.path(
  "PP_Full_Analysis/Saved_tables_figures",
  if (published_paper_run == 1) "As_published" else "Testing",
  cw_name
)

# Create the output folder (and any missing parents) on first use.
if (!dir.exists(results_dir)) {
  dir.create(results_dir, recursive = TRUE)
}

print(paste("Processed path:", processedpath))
## [1] "Processed path: PP_Full_Analysis/Intermediate_data/Testing/empirical/cw_med_5"
print(paste("Results directory:", results_dir))
## [1] "Results directory: PP_Full_Analysis/Saved_tables_figures/Testing/cw_med_5"

Create functions

Policy tree fitted on the training folds

#####STEP 8-1: Create policytree function #####
# Fit a policy tree on the training folds of a fitted grf forest.
#
# Args:
#   forest_input: fitted grf forest; its `clusters` vector holds the fold labels
#     and `X.orig` the covariate matrix.
#   output_name:  label used only in the progress message.
#   depth:        depth of the policy tree (default 2).
#   train_folds:  cluster labels used as training data (default 1:4).
# Returns: the fitted policy_tree object.
generate_policy_from_training_data <- function(forest_input, output_name,
                                               depth = 2, train_folds = 1:4) {
  print(paste("Currently processing policy tree for ", output_name))

  # Doubly robust scores for every observation (one column per action).
  dr_scores <- double_robust_scores(forest_input)

  # NOTE(review): confirm whether grf stores `clusters` with the original fold
  # labels or re-indexed; `train_folds` must match the intended training folds.
  in_train <- forest_input$clusters %in% train_folds

  # Use the exact element name: `forest_input$X` only resolved to `X.orig`
  # through `$` partial matching. drop = FALSE keeps a matrix even when a
  # single row is selected.
  X_train <- forest_input[["X.orig"]][in_train, , drop = FALSE]
  scores_train <- dr_scores[in_train, , drop = FALSE]

  policy_tree(X_train, scores_train, depth = depth)
}

Policy assignment on the test folds

#####STEP 8-2: Create predicts on testing data function #####
# Fit the policy tree on the training folds and assign actions on the test folds.
#
# Args:
#   forest_input: fitted grf forest with `clusters`, `X.orig`, `Y.orig`, `W.orig`.
#   output_name:  label used in progress messages.
# Returns: data.frame with one row per test-fold observation and columns
#   pi.hat (assigned action, 0 = first action, 1 = second), Y and W.
#   Row names are the original row indices.
get_test_predictions <- function(forest_input, output_name) {
  print(paste("Currently processing test predictions for ", output_name))

  policy <- generate_policy_from_training_data(forest_input, output_name)

  print("Policy Tree: ")
  print(policy)

  # NOTE(review): confirm these cluster labels are the intended held-out folds.
  test_folds <- forest_input$clusters %in% 5:10

  # Full-length vector so assignments stay aligned with Y.orig / W.orig;
  # rows outside the test folds stay NA and are dropped below.
  pi.hat <- rep(NA_real_, length(test_folds))

  print("Dimensions of pi.hat: ")
  print(length(pi.hat))

  # Exact element name (not `$X` partial matching); drop = FALSE keeps a matrix.
  X_test <- forest_input[["X.orig"]][test_folds, , drop = FALSE]
  # policy_tree actions are 1-based; subtract 1 to get 0/1.
  predictions <- predict(policy, X_test) - 1

  print("Dimensions of predictions: ")
  print(length(predictions))

  pi.hat[test_folds] <- predictions

  print("Dimensions of pi.hat: ")
  print(length(pi.hat))

  # Full-length outcome and treatment; the non-NA filter below restricts
  # them to the test folds.
  Y <- forest_input$Y.orig
  W <- forest_input$W.orig

  print("Dimensions of Y: ")
  print(length(Y))
  print("Dimensions of W: ")
  print(length(W))

  combined_df <- data.frame(pi.hat, Y, W)
  combined_df <- combined_df[!is.na(combined_df$pi.hat), ]

  print("Information on combined_df:")
  glimpse(combined_df)

  combined_df
}

Policy evaluation function

#####STEP 8-3: Create policy evaluation function #####
# Compare the learned policy with "treat nobody" and "treat everybody".
#
# Args:
#   combined_df: data.frame with columns pi.hat (0/1 assigned action),
#     Y (outcome) and W (0/1 observed treatment), e.g. from get_test_predictions().
# Returns: data.frame with columns Metric and Value ("estimate (SE)" strings).
evaluate_policy <- function(combined_df) {
  print("Starting policy evaluation")

  # Sample mean of y and its standard error sd(y) / sqrt(n).
  mean_and_se <- function(y) {
    c(estimate = mean(y), se = sd(y) / sqrt(length(y)))
  }
  # Difference of two estimates; SEs combined as if the samples were independent.
  difference <- function(a, b) {
    c(estimate = a[["estimate"]] - b[["estimate"]],
      se = sqrt(a[["se"]]^2 + b[["se"]]^2))
  }

  # Policy value: mean outcome among units whose observed treatment matches
  # the policy's recommendation.
  # NOTE(review): this unweighted mean targets the policy value only if W is
  # randomized with P(W = 1) = 0.5 (otherwise matched units need inverse-
  # propensity weights). The matched sample also overlaps the control and
  # treated samples, so the independent-sample difference SEs are conservative.
  # Confirm the design.
  policy_value <- mean_and_se(combined_df$Y[combined_df$W == combined_df$pi.hat])
  control <- mean_and_se(combined_df$Y[combined_df$W == 0])
  treated <- mean_and_se(combined_df$Y[combined_df$W == 1])

  rows <- list(
    policy_value,
    control,
    difference(policy_value, control),
    treated,
    difference(policy_value, treated)
  )

  results <- data.frame(
    Metric = c("Estimated Policy Value",
               "Uniform Control",
               "Difference from Control",
               "Uniform Treatment",
               "Difference from Treatment"),
    Estimate = vapply(rows, function(r) r[["estimate"]], numeric(1)),
    SE = vapply(rows, function(r) r[["se"]], numeric(1))
  )

  # Report each estimate with its SE in parentheses.
  results$Value <- sprintf("%.3f (%.3f)", results$Estimate, results$SE)
  final_results <- results[, c("Metric", "Value")]

  print("Policy Evaluation Results:")
  print(kable(final_results))

  final_results
}
#####STEP 8-3b: Merge policy assignments with the original dataset #####
# Attach test-fold policy assignments to the original dataset by person_id.
# Named separately so it does not mask evaluate_policy(); its inputs are now
# explicit arguments instead of undefined free variables.
#
# Args:
#   pi_hat:        policy assignments, one per test-fold row of original_data,
#                  in row order (assumption - TODO confirm it matches the
#                  forest's row order).
#   original_data: data.frame with `person_id` and `folds` columns, e.g.
#                  read.csv(file.path(originaldatapath, "3_Penalty_Outcomes_Wide_Dataset.csv")).
#   test_folds:    fold labels treated as the test set (default 5:10).
# Returns: inner join of original_data with the assignments (column pi_hat).
merge_policy_with_original_data <- function(pi_hat, original_data, test_folds = 5:10) {
  in_test <- original_data[["folds"]] %in% test_folds
  if (length(pi_hat) != sum(in_test)) {
    stop("pi_hat has ", length(pi_hat), " values but original_data has ",
         sum(in_test), " test-fold rows", call. = FALSE)
  }

  policy_df <- data.frame(
    person_id = original_data[["person_id"]][in_test],
    pi_hat = pi_hat
  )

  print("Merging policy assignments with original data")
  combined_df <- merge(original_data, policy_df, by = "person_id")

  print("Dimensions of merged dataset:")
  print(dim(combined_df))

  combined_df
}
# Diagnostics for a policy tree learned on the training folds and applied to the
# test folds: (1) treatment-effect difference across policy assignments,
# (2) treatment effects by tree leaf, (3) heatmap of covariate means by assignment.
#
# Args:
#   forest_input:   fitted grf forest with clusters, X.orig, Y.orig, W.orig.
#   output_name:    label used in progress messages.
#   covariates:     covariate names for the heatmap. NULL (default) uses a global
#                   `covariates` vector if one exists, else every column of X.orig.
#   treatment_cost: cost subtracted from the per-leaf effects (default 0).
# Returns: the heatmap ggplot (invisibly) after printing it, or NULL (invisibly)
#   when the policy assigns the same action to every remaining test unit.
policytree_cate_plots <- function(forest_input, output_name, covariates = NULL,
                                  treatment_cost = 0) {
  print(paste("Currently processing policy tree for ", output_name))
  policy <- generate_policy_from_training_data(forest_input, output_name)

  # NOTE(review): confirm these cluster labels are the intended held-out folds.
  test_folds <- forest_input$clusters %in% 5:10
  # Exact element name (not `$X` partial matching); drop = FALSE keeps a matrix.
  X_test <- forest_input[["X.orig"]][test_folds, , drop = FALSE]

  # policy_tree actions are 1-based; shift to 0/1.
  pi.hat <- predict(policy, X_test) - 1
  leaf <- predict(policy, X_test, type = "node.id")

  print("Policy Tree: ")
  print(policy)

  # Keep assignments and leaves as columns so they stay aligned with the rows
  # that survive na.omit(). The frame already holds only test-fold rows, so it
  # must not be indexed by the full-length test_folds mask again.
  combined_df <- data.frame(
    as.data.frame(X_test),
    Y = forest_input[["Y.orig"]][test_folds],
    W = forest_input[["W.orig"]][test_folds],
    pi.hat = pi.hat,
    leaf = leaf,
    check.names = FALSE
  )
  combined_df <- na.omit(combined_df)

  print("Dimensions of data: ")
  print(dim(combined_df))
  print("Length of pi.hat: ")
  print(length(pi.hat))

  # The contrasts below need both assignments to be present.
  if (length(unique(combined_df$pi.hat)) < 2) {
    print("Policy assigns a single action to every test unit; skipping subgroup diagnostics.")
    return(invisible(NULL))
  }

  # (1) Is the treatment effect different across regions defined by assignment?
  print("Test if treatment effect is different across regions defined by assignment")
  cat("$$ H_0 : \\mathbb{E} \\left[ Y_i(1) - Y_i(0) \\mid \\hat{\\pi}(X_i) = 1 \\right] = \\mathbb{E} \\left[ Y_i(1) - Y_i(0) \\mid \\hat{\\pi}(X_i) = 0 \\right] $$\n")
  ols <- lm(Y ~ 0 + pi.hat + W:pi.hat,
            data = transform(combined_df, pi.hat = factor(pi.hat)))
  coefs <- coeftest(ols, vcov = vcovHC(ols, "HC2"))[, 1:2]
  # Interaction rows hold the per-assignment treatment effects.
  print(coefs[grepl(":", rownames(coefs)), , drop = FALSE])

  # (2) Are treatment effects different across leaves?
  print("Test if treatment effects are different across leaves")
  cat("$$ H_0 : \\mathbb{E} \\left[ Y_i(1) - Y_i(0) \\mid \\text{Leaf} = 1 \\right] = \\mathbb{E} \\left[ Y_i(1) - Y_i(0) \\mid \\text{Leaf} = \\ell \\right] \\text{ for } \\ell \\geq 2 $$\n")
  ols <- lm(Y ~ 0 + leaf + W:leaf,
            data = transform(combined_df, leaf = factor(leaf)))
  coefs <- coeftest(ols, vcov = vcovHC(ols, "HC2"))[, 1:2]
  interact <- grepl(":", rownames(coefs))
  coefs[interact, 1] <- coefs[interact, 1] - treatment_cost
  print(coefs[interact, , drop = FALSE])

  # (3) How do covariate averages vary across the assigned groups?
  print("check how covariate averages vary across subgroups")
  cat("$$ H_0 : \\mathbb{E} \\left[ X_{ij} \\mid \\hat{\\pi}(X_i) = 1 \\right] = \\mathbb{E} \\left[ X_{ij} \\mid \\hat{\\pi}(X_i) = 0 \\right] \\text{ for each covariate } j $$\n")

  if (is.null(covariates)) {
    covariates <- get0("covariates", envir = globalenv(), inherits = FALSE,
                       ifnotfound = colnames(as.data.frame(X_test)))
  }
  stopifnot(length(covariates) > 0)

  df <- lapply(covariates, function(covariate) {
    # Backticks allow non-syntactic covariate names in the formula.
    fmla <- formula(paste0("`", covariate, "` ~ 0 + factor(pi.hat)"))
    ols <- lm(fmla, data = combined_df)
    ols.res <- coeftest(ols, vcov = vcovHC(ols, "HC2"))
    avg <- ols.res[, 1]
    std_err <- ols.res[, 2]

    data.frame(
      covariate = covariate,
      avg = avg,
      stderr = std_err,
      # factor(pi.hat) levels sort as "0", "1".
      pi.hat = factor(c("control", "treatment")),
      # Used for coloring.
      scaling = pnorm((avg - mean(avg)) / sd(avg)),
      # Variation "explained" by the group means relative to the covariate's
      # total variation; used to order the heatmap rows.
      variation = sd(avg) / sd(combined_df[[covariate]]),
      # Text printed in each heatmap cell.
      labels = paste0(signif(avg, 3), "\n", "(", signif(std_err, 3), ")")
    )
  })
  df <- do.call(rbind, df)

  # Order heatmap rows by 'variation'.
  df$covariate <- reorder(df$covariate, order(df$variation))

  heatmap_plot <- ggplot(df) +
    aes(pi.hat, covariate) +
    geom_tile(aes(fill = scaling)) +
    geom_text(aes(label = labels)) +
    scale_fill_gradient(low = "#E1BE6A", high = "#40B0A6") +
    theme_minimal() +
    ylab("") + xlab("") +
    theme(plot.title = element_text(size = 12, face = "bold"),
          axis.text = element_text(size = 11))
  print(heatmap_plot)

  invisible(heatmap_plot)
}

Main loop

# Main Loop
# For each saved forest file: load it, fit the policy tree on the training
# folds, evaluate it on the test folds and write the evaluation table to CSV.
for (i in seq_along(forest_files)) {
  # `final_results` is overwritten with the evaluation table below; drop any
  # copy left from an earlier iteration or chunk so the existence check
  # reflects only the file loaded next.
  if (exists("final_results", inherits = FALSE)) rm(final_results)

  load(file.path(processedpath, forest_files[i]))
  print(paste("Successfully loaded:", forest_files[i]))

  if (!exists("final_results", inherits = FALSE)) {
    warning("'final_results' not found in loaded data: ", forest_files[i], call. = FALSE)
    next
  }

  forest_input <- final_results$results$forest
  if (is.null(forest_input)) {
    warning("No forest at final_results$results$forest in: ", forest_files[i], call. = FALSE)
    next
  }
  print("Successfully extracted forest from final_results")

  # File title used for naming outputs.
  file_title <- sub("\\.RData$", "", basename(forest_files[i]))

  if (grepl("clate", file_title)) {
    # TODO: replicate the CATE evaluation with IV regression instead of lm.
    print(paste("Skipping CLATE file (not implemented yet):", file_title))
    next
  }

  combined_df <- get_test_predictions(forest_input, file_title)
  final_results <- evaluate_policy(combined_df)

  # NOTE(review): this pattern is the full title of the only configured file,
  # so results_title is "" and the CSV is written as "_final_results.csv".
  # Confirm whether only a "cate_results_"/"clate_results_" prefix was meant
  # to be stripped.
  results_title <- sub("cate_results_sim_sbp_neg_alpha_5_presentation_cw0", "", file_title)

  write.csv(final_results,
            file.path(results_dir, paste0(results_title, "_final_results.csv")),
            row.names = FALSE)
}
## [1] "Successfully loaded: cate_results_sim_sbp_neg_alpha_5_presentation_cw0.RData"
## [1] "Successfully extracted forest from final_results"
## [1] "Currently processing test predictions for  cate_results_sim_sbp_neg_alpha_5_presentation_cw0"
## [1] "Currently processing policy tree for  cate_results_sim_sbp_neg_alpha_5_presentation_cw0"
## [1] "Policy Tree: "
## policy_tree object 
## Tree depth:  2 
## Actions:  1: control 2: treated 
## Variable splits: 
## (1) split_variable: ed_charg_tot_pre_ed  split_value: 1107.95 
##   (2) split_variable: ed_charg_tot_pre_ed  split_value: 1052.9 
##     (4) * action: 2 
##     (5) * action: 1 
##   (3) split_variable: ed_charg_tot_pre_ed  split_value: 27182.6 
##     (6) * action: 2 
##     (7) * action: 1 
## [1] "Dimensions of pi.hat: "
## [1] 12167
## [1] "Dimensions of predictions: "
## [1] 6035
## [1] "Dimensions of pi.hat: "
## [1] 12167
## [1] "Dimensions of Y: "
## [1] 12167
## [1] "Dimensions of W: "
## [1] 12167
## [1] "Information on combined_df:"
## Rows: 6,035
## Columns: 3
## $ pi.hat <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
## $ Y      <dbl> -144.00, -84.61, -160.39, -98.00, -108.00, -104.00, -111.00, -1…
## $ W      <int> 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, …
## [1] "Starting policy evaluation"
## [1] "Policy Evaluation Results:"
## 
## 
## |Metric                    |Value            |
## |:-------------------------|:----------------|
## |Estimated Policy Value    |-119.084 (0.469) |
## |Uniform Control           |-124.528 (0.489) |
## |Difference from Control   |5.444 (0.678)    |
## |Uniform Treatment         |-119.004 (0.468) |
## |Difference from Treatment |-0.079 (0.663)   |