library(dplyr)
## Warning: package 'dplyr' was built under R version 4.4.3
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
library(knitr)
## Warning: package 'knitr' was built under R version 4.4.3
library(kableExtra)
## 
## Attaching package: 'kableExtra'
## The following object is masked from 'package:dplyr':
## 
##     group_rows
library(lmtest)
## Loading required package: zoo
## 
## Attaching package: 'zoo'
## The following objects are masked from 'package:base':
## 
##     as.Date, as.Date.numeric
library(urca)
library(vars)
## Loading required package: MASS
## 
## Attaching package: 'MASS'
## The following object is masked from 'package:dplyr':
## 
##     select
## Loading required package: strucchange
## Loading required package: sandwich
library(tibble)
## Warning: package 'tibble' was built under R version 4.4.3
library(webshot2)
# Echo code in rendered chunks, but suppress warnings and messages in output.
knitr::opts_chunk$set(echo = TRUE, warning = FALSE, message = FALSE)

# NOTE(review): build_objects.R is not visible here; it presumably defines the
# model / test objects (adf_lag_table*, eg_*, mod_dlog_*, ecm_*, var_*, irf_*)
# that the table sections below consume.
source("./build_objects.R")
# =========================================================
# GLOBAL TABLE RENDERER FOR R MARKDOWN
# =========================================================



# Render a data frame as a styled HTML table, save it as a PNG, and return an
# image object suitable for embedding in the knitted document.
#
# df          : anything coercible with as.data.frame().
# title_text  : table caption; also used (sanitised) in the PNG file name.
# digits      : numeric columns are rounded to this many decimals.
# file_prefix : leading part of the PNG file name.
#
# Returns the value of knitr::include_graphics() for the saved PNG.
make_pretty_table <- function(df, title_text = NULL, digits = 4, file_prefix = "table") {
  rounded <- mutate(
    as.data.frame(df),
    across(where(is.numeric), function(col) round(col, digits))
  )

  styled <- kableExtra::kable_styling(
    knitr::kable(rounded, format = "html", caption = title_text, align = "l"),
    full_width = FALSE
  )

  # Every non-alphanumeric character of the title becomes "_" in the file name.
  title_slug <- gsub("[^A-Za-z0-9]", "_", title_text)
  png_path <- paste0(file_prefix, "_", title_slug, ".png")

  kableExtra::save_kable(styled, file = png_path)
  knitr::include_graphics(png_path)
}

# Row-bind data frames whose column sets differ.
#
# Every input is widened to the union of all column names (in first-seen
# order); columns a frame lacks are filled with NA. Row names of the result
# are reset to the default 1..n.
#
# ... : data frames to stack.
# Returns the combined data.frame (NULL when called with no arguments).
rbind_fill <- function(...) {
  dfs <- list(...)
  all_names <- unique(unlist(lapply(dfs, names)))

  dfs_aligned <- lapply(dfs, function(df) {
    for (m in setdiff(all_names, names(df))) {
      # rep() sizes the filler to the frame, so zero-row frames also work;
      # assigning a bare length-1 NA to a 0-row data.frame is an error.
      df[[m]] <- rep(NA, nrow(df))
    }
    df[, all_names, drop = FALSE]
  })

  out <- do.call(rbind, dfs_aligned)
  rownames(out) <- NULL
  out
}

# p-value of White's heteroskedasticity test in its "fitted values" form.
#
# Auxiliary regression: squared residuals on fitted values and their square.
# The LM statistic is n * R^2 of that regression, compared with a chi-squared
# distribution whose df is the number of slope terms actually estimated
# (2 normally). This equals the studentised (Koenker) Breusch-Pagan statistic
# with variance regressors (yhat, yhat^2). The previous code ran bptest() on the
# auxiliary model itself, which tests that model's residuals, not White's test.
#
# model : fitted lm-type model with residuals() and fitted() methods.
# Returns a single p-value, or NA_real_ when the test is undefined (e.g.
# constant fitted values make both auxiliary regressors aliased).
safe_whitefit_p <- function(model) {
  aux <- data.frame(
    u2 = residuals(model)^2,
    fitvals = fitted(model)
  )
  white_aux <- lm(u2 ~ fitvals + I(fitvals^2), data = aux)

  # Estimated slope terms (rank includes the intercept).
  test_df <- white_aux$rank - 1L
  if (test_df < 1L) {
    return(NA_real_)
  }

  lm_stat <- nobs(white_aux) * summary(white_aux)$r.squared
  pchisq(lm_stat, df = test_df, lower.tail = FALSE)
}

# One-row table of residual diagnostics for a single regression model.
#
# model      : fitted lm-type model accepted by lmtest's bgtest/bptest/resettest.
# model_name : label stored in the `model` column.
# bg_order   : order of the Breusch-Godfrey serial-correlation test (default 4).
# lb_lag     : number of lags in the Ljung-Box test on resid(model) (default 12).
#
# Columns (all p-values):
#   BG_p       - lmtest::bgtest(order = bg_order)
#   LjungBox_p - Box.test(type = "Ljung-Box", lag = lb_lag)
#   BP_p       - lmtest::bptest() with its default variance regressors
#   WhiteFit_p - safe_whitefit_p(model), defined earlier in this file
#   RESET2_p   - lmtest::resettest(power = 2, type = "fitted")
#   RESET23_p  - lmtest::resettest(power = 2:3, type = "fitted")
collect_diag_tests <- function(model, model_name, bg_order = 4, lb_lag = 12) {
  data.frame(
    model = model_name,
    BG_p = bgtest(model, order = bg_order)$p.value,
    LjungBox_p = Box.test(resid(model), lag = lb_lag, type = "Ljung-Box")$p.value,
    BP_p = bptest(model)$p.value,
    WhiteFit_p = safe_whitefit_p(model),
    RESET2_p = resettest(model, power = 2, type = "fitted")$p.value,
    RESET23_p = resettest(model, power = 2:3, type = "fitted")$p.value
  )
}

# ## ADF Tables
# Select, per series, the ADF lag that minimises AIC and the one that
# minimises BIC, and assemble side-by-side summaries.
#
# adf_table : data.frame with at least the columns series, type, lag,
#             tau_stat, AIC and BIC (one row per candidate lag, presumably).
#
# Returns a list:
#   full     - adf_table unchanged
#   best_AIC - per-series row with the smallest AIC (first row on ties)
#   best_BIC - per-series row with the smallest BIC (first row on ties)
#   summary  - both selections merged by series, with columns series, type,
#              best_lag_AIC, tau_stat_AIC, AIC, best_lag_BIC, tau_stat_BIC, BIC
# NOTE(review): `type` in the summary is the AIC winner's; if a series has
# several test types, the BIC winner's type is not shown.
make_best_lag_table <- function(adf_table) {
  per_series <- split(adf_table, adf_table$series)

  # Stack, across series, the row minimising the named criterion column.
  rows_minimising <- function(criterion) {
    winners <- lapply(per_series, function(rows) {
      rows[which.min(rows[[criterion]]), ]
    })
    do.call(rbind, winners)
  }

  best_AIC <- rows_minimising("AIC")
  best_BIC <- rows_minimising("BIC")

  aic_cols <- best_AIC[, c("series", "type", "lag", "tau_stat", "AIC")]
  bic_cols <- best_BIC[, c("series", "lag", "tau_stat", "BIC")]
  best_summary <- merge(aic_cols, bic_cols, by = "series", suffixes = c("_AIC", "_BIC"))
  names(best_summary) <- c(
    "series", "type",
    "best_lag_AIC", "tau_stat_AIC", "AIC",
    "best_lag_BIC", "tau_stat_BIC", "BIC"
  )

  rownames(best_AIC) <- NULL
  rownames(best_BIC) <- NULL
  rownames(best_summary) <- NULL

  list(
    full = adf_table,
    best_AIC = best_AIC,
    best_BIC = best_BIC,
    summary = best_summary
  )
}

# Best-lag selections for the level, log and differenced-log ADF lag tables.
# NOTE(review): adf_lag_table* are not defined in this file; presumably they
# are created by build_objects.R.
adf_tables_levels <- make_best_lag_table(adf_lag_table)
adf_tables_logs   <- make_best_lag_table(adf_lag_table_log)
adf_tables_dlog   <- make_best_lag_table(adf_lag_table_dlog)

# Each call is kept as a separate top-level statement so that the image object
# returned by make_pretty_table() (knitr::include_graphics) is auto-printed
# into the rendered document.
make_pretty_table(adf_tables_levels$full,     "ADF LEVEL VARIABLES - FULL LAG TABLE")

make_pretty_table(adf_tables_levels$best_AIC, "ADF LEVEL VARIABLES - BEST LAG BY AIC")

make_pretty_table(adf_tables_levels$best_BIC, "ADF LEVEL VARIABLES - BEST LAG BY BIC")

make_pretty_table(adf_tables_levels$summary,  "ADF LEVEL VARIABLES - BEST LAG SUMMARY")

make_pretty_table(adf_tables_logs$full,     "ADF LOG VARIABLES - FULL LAG TABLE")

make_pretty_table(adf_tables_logs$best_AIC, "ADF LOG VARIABLES - BEST LAG BY AIC")

make_pretty_table(adf_tables_logs$best_BIC, "ADF LOG VARIABLES - BEST LAG BY BIC")

make_pretty_table(adf_tables_logs$summary,  "ADF LOG VARIABLES - BEST LAG SUMMARY")

make_pretty_table(adf_tables_dlog$full,     "ADF DIFFERENCED VARIABLES - FULL LAG TABLE")

make_pretty_table(adf_tables_dlog$best_AIC, "ADF DIFFERENCED VARIABLES - BEST LAG BY AIC")

make_pretty_table(adf_tables_dlog$best_BIC, "ADF DIFFERENCED VARIABLES - BEST LAG BY BIC")

make_pretty_table(adf_tables_dlog$summary,  "ADF DIFFERENCED VARIABLES - BEST LAG SUMMARY")

# ## Engle-Granger Tables
# Flatten the AIC- and BIC-selected residual-ADF results of an Engle-Granger
# object into one table row (one row per row of best_AIC/best_BIC).
#
# eg_obj      : list with data frames best_AIC (lag, tau_stat, AIC) and
#               best_BIC (lag, tau_stat, BIC).
# model_label : value stored in the `model` column.
collect_eg_results <- function(eg_obj, model_label) {
  by_aic <- eg_obj$best_AIC
  by_bic <- eg_obj$best_BIC

  data.frame(
    model = model_label,
    best_lag_AIC = by_aic$lag,
    tau_stat_AIC = by_aic$tau_stat,
    AIC = by_aic$AIC,
    best_lag_BIC = by_bic$lag,
    tau_stat_BIC = by_bic$tau_stat,
    BIC = by_bic$BIC
  )
}

# Engle-Granger residual-ADF summaries, without and with structural break.
# NOTE(review): the eg_* objects are not defined in this file; presumably they
# come from build_objects.R. Plain rbind() suffices here because
# collect_eg_results() always returns the same columns.
eg_table_no_break <- rbind(
  collect_eg_results(eg_baseline, "baseline"),
  collect_eg_results(eg_pathway1, "pathway1_imports"),
  collect_eg_results(eg_pathway2, "pathway2_longrate"),
  collect_eg_results(eg_optional, "optional_both_pathways")
)
rownames(eg_table_no_break) <- NULL

eg_table_break <- rbind(
  collect_eg_results(eg_baseline_break, "baseline_break"),
  collect_eg_results(eg_pathway1_break, "pathway1_break"),
  collect_eg_results(eg_pathway2_break, "pathway2_break"),
  collect_eg_results(eg_fdi_break, "fdi_break"),
  collect_eg_results(eg_optional_break, "optional_break"),
  collect_eg_results(eg_optional_breakFDI, "optional_breakFDI")
)
rownames(eg_table_break) <- NULL

# Top-level calls so the returned images are auto-printed by knitr.
make_pretty_table(eg_table_no_break, "ENGLE-GRANGER RESULTS - NO BREAK")

make_pretty_table(eg_table_break, "ENGLE-GRANGER RESULTS - BREAK MODELS")

# ## Long-Run Regression Fit Tables
# One-row goodness-of-fit summary for a (long-run) lm-type regression.
#
# model       : fitted model whose summary() has lm-style components
#               (r.squared, adj.r.squared, sigma, fstatistic).
# model_label : value stored in the `model` column.
#
# Returns a data.frame with columns model, n, R2, Adj_R2, Resid_SE, F_stat,
# F_df1, F_df2, F_p. n is nobs(model), the number of observations actually
# used, so rows dropped under na.exclude are not counted (length(resid())
# would count their NA padding). When summary() has no F statistic
# (intercept-only model) the F columns are NA instead of raising an error.
collect_lm_fit_stats <- function(model, model_label) {
  s <- summary(model)

  fstat <- s$fstatistic
  if (is.null(fstat)) {
    fstat <- c(NA_real_, NA_real_, NA_real_)
  }
  fstat <- unname(fstat)

  data.frame(
    model = model_label,
    n = nobs(model),
    R2 = s$r.squared,
    Adj_R2 = s$adj.r.squared,
    Resid_SE = s$sigma,
    F_stat = fstat[1],
    F_df1 = fstat[2],
    F_df2 = fstat[3],
    F_p = pf(fstat[1], fstat[2], fstat[3], lower.tail = FALSE)
  )
}

# Fit statistics of the long-run (cointegrating) regressions stored in the
# $long_run element of each Engle-Granger object.
# NOTE(review): the eg_* objects are presumably built in build_objects.R.
longrun_fit_no_break <- rbind(
  collect_lm_fit_stats(eg_baseline$long_run, "baseline"),
  collect_lm_fit_stats(eg_pathway1$long_run, "pathway1_imports"),
  collect_lm_fit_stats(eg_pathway2$long_run, "pathway2_longrate"),
  collect_lm_fit_stats(eg_optional$long_run, "optional_both_pathways")
)
rownames(longrun_fit_no_break) <- NULL

longrun_fit_break <- rbind(
  collect_lm_fit_stats(eg_baseline_break$long_run, "baseline_break"),
  collect_lm_fit_stats(eg_pathway1_break$long_run, "pathway1_break"),
  collect_lm_fit_stats(eg_pathway2_break$long_run, "pathway2_break"),
  collect_lm_fit_stats(eg_fdi_break$long_run, "fdi_break"),
  collect_lm_fit_stats(eg_optional_break$long_run, "optional_break"),
  collect_lm_fit_stats(eg_optional_breakFDI$long_run, "optional_breakFDI")
)
rownames(longrun_fit_break) <- NULL

# Top-level calls so the returned images are auto-printed by knitr.
make_pretty_table(longrun_fit_no_break, "LONG-RUN REGRESSION FIT - NO BREAK")

make_pretty_table(longrun_fit_break, "LONG-RUN REGRESSION FIT - BREAK MODELS")

# ## Short-Run Regression Tables
# One-row fit summary of a short-run (differenced) regression plus the
# estimate and p-value of selected coefficients.
#
# model       : fitted lm-type model.
# model_label : value stored in the `model` column.
# key_vars    : coefficient names to report; each adds columns
#               "<name>_coef" (estimate) and "<name>_p" (p-value), NA when the
#               term is not among the model's estimated coefficients.
#
# Base columns: model, n, R2, Adj_R2, Resid_SE, F_stat, F_p. n is nobs(model)
# (observations actually used, not NA-padded residual length). F columns are
# NA when summary() reports no F statistic (intercept-only model).
collect_short_run_stats <- function(model, model_label, key_vars = NULL) {
  s <- summary(model)
  coefs <- s$coefficients

  fstat <- s$fstatistic
  if (is.null(fstat)) {
    fstat <- c(NA_real_, NA_real_, NA_real_)
  }
  fstat <- unname(fstat)

  out <- data.frame(
    model = model_label,
    n = nobs(model),
    R2 = s$r.squared,
    Adj_R2 = s$adj.r.squared,
    Resid_SE = s$sigma,
    F_stat = fstat[1],
    F_p = pf(fstat[1], fstat[2], fstat[3], lower.tail = FALSE)
  )

  # Column 1 of the coefficient table is the estimate, column 4 the p-value.
  for (v in key_vars) {
    found <- v %in% rownames(coefs)
    out[[paste0(v, "_coef")]] <- if (found) coefs[v, 1] else NA_real_
    out[[paste0(v, "_p")]] <- if (found) coefs[v, 4] else NA_real_
  }

  out
}

# Short-run (differenced) model summaries. rbind_fill() is used because each
# model reports a different set of key_vars columns; columns a model lacks are
# filled with NA.
# NOTE(review): the mod_dlog_* models are presumably built in build_objects.R.
short_run_table_basic <- rbind_fill(
  collect_short_run_stats(
    mod_dlog_baseline, "baseline",
    key_vars = c("d_log_treasury_gdp", "d_log_trade_balance", "d_log_min_wage_real")
  ),
  collect_short_run_stats(
    mod_dlog_fdi, "pathway_fdi",
    key_vars = c("d_log_fdi_outward_gdp")
  ),
  collect_short_run_stats(
    mod_dlog_imports, "pathway_imports",
    key_vars = c("d_log_imports_ind_real_gdp")
  ),
  collect_short_run_stats(
    mod_dlog_longrate, "pathway_longrate",
    key_vars = c("d_long_rate")
  ),
  collect_short_run_stats(
    mod_dlog_full_with_fdi, "full_with_fdi",
    key_vars = c("d_log_fdi_outward_gdp", "d_log_imports_ind_real_gdp", "d_long_rate")
  ),
  collect_short_run_stats(
    mod_dlog_full_no_fdi, "full_no_fdi",
    key_vars = c("d_log_imports_ind_real_gdp", "d_long_rate")
  )
)

# Same summary for the models with a pandemic dummy and/or break terms.
short_run_table_pandemic <- rbind_fill(
  collect_short_run_stats(
    mod_dlog_fdi_pandemic,
    "fdi_pandemic",
    key_vars = c("d_log_fdi_outward_gdp", "pandemic_q1")
  ),
  collect_short_run_stats(
    mod_dlog_full_with_fdi_pandemic,
    "full_with_fdi_pandemic",
    key_vars = c("d_log_fdi_outward_gdp", "d_log_imports_ind_real_gdp", "d_long_rate", "pandemic_q1")
  ),
  collect_short_run_stats(
    mod_dlog_fdi_break,
    "fdi_break",
    key_vars = c("d_log_fdi_outward_gdp", "pandemic_q1", "D2000", "DT2000")
  )
)

# Top-level calls so the returned images are auto-printed by knitr.
make_pretty_table(short_run_table_basic, "SHORT-RUN DIFFERENCED MODELS - FIT AND KEY COEFFICIENTS")

make_pretty_table(short_run_table_pandemic, "SHORT-RUN PANDEMIC / BREAK MODELS - FIT AND KEY COEFFICIENTS")

# ## Short-Run Diagnostics Table
# Residual diagnostics (collect_diag_tests) for a subset of the short-run
# models; all rows share the same columns, so plain rbind() is enough.
short_run_diag_table <- rbind(
  collect_diag_tests(mod_dlog_fdi, "mod_dlog_fdi"),
  collect_diag_tests(mod_dlog_full_with_fdi, "mod_dlog_full_with_fdi"),
  collect_diag_tests(mod_dlog_fdi_pandemic, "mod_dlog_fdi_pandemic"),
  collect_diag_tests(mod_dlog_full_with_fdi_pandemic, "mod_dlog_full_with_fdi_pandemic"),
  collect_diag_tests(mod_dlog_fdi_break, "mod_dlog_fdi_break")
)
rownames(short_run_diag_table) <- NULL

# Top-level call so the returned image is auto-printed by knitr.
make_pretty_table(short_run_diag_table, "SHORT-RUN MODEL DIAGNOSTICS")

# ## ECM Fit Table
# One-row summary of an error-correction model: fit statistics, the
# error-correction term, a pandemic dummy, and optional extra coefficients.
#
# model         : fitted lm-type ECM.
# model_label   : value stored in the `model` column.
# ect_name      : coefficient name of the error-correction term
#                 (reported as ECT_coef / ECT_p).
# extra_vars    : further coefficient names; each adds "<name>_coef" and
#                 "<name>_p" columns.
# pandemic_name : coefficient name reported as pandemic_coef / pandemic_p
#                 (default "pandemic_q1_d", the previously hard-coded term).
#
# Any requested term not among the estimated coefficients yields NA. n is
# nobs(model), not the NA-padded residual length. F columns are NA when
# summary() reports no F statistic (intercept-only model).
collect_ecm_stats <- function(model, model_label, ect_name, extra_vars = NULL,
                              pandemic_name = "pandemic_q1_d") {
  s <- summary(model)
  coefs <- s$coefficients

  # Column 1 of the coefficient table is the estimate, column 4 the p-value.
  term_value <- function(term, column) {
    if (term %in% rownames(coefs)) coefs[term, column] else NA_real_
  }

  fstat <- s$fstatistic
  if (is.null(fstat)) {
    fstat <- c(NA_real_, NA_real_, NA_real_)
  }
  fstat <- unname(fstat)

  out <- data.frame(
    model = model_label,
    n = nobs(model),
    R2 = s$r.squared,
    Adj_R2 = s$adj.r.squared,
    Resid_SE = s$sigma,
    F_stat = fstat[1],
    F_p = pf(fstat[1], fstat[2], fstat[3], lower.tail = FALSE),
    ECT_coef = term_value(ect_name, 1),
    ECT_p = term_value(ect_name, 4),
    pandemic_coef = term_value(pandemic_name, 1),
    pandemic_p = term_value(pandemic_name, 4)
  )

  for (v in extra_vars) {
    out[[paste0(v, "_coef")]] <- term_value(v, 1)
    out[[paste0(v, "_p")]] <- term_value(v, 4)
  }

  out
}

# ECM fit and key-coefficient summary. rbind_fill() aligns the differing
# extra_vars columns across models (missing ones become NA).
# NOTE(review): ecm_* models and their "ect_*_dd" term names are presumably
# defined in build_objects.R; a misspelled name silently yields NA columns.
ecm_fit_table <- rbind_fill(
  collect_ecm_stats(
    ecm_pathway1_break,
    "ecm_pathway1_break",
    "ect_pathway1_break_dd",
    extra_vars = c("d_log_imports_ind_real_gdp")
  ),
  collect_ecm_stats(
    ecm_pathway2_break,
    "ecm_pathway2_break",
    "ect_pathway2_break_dd",
    extra_vars = c("d_long_rate")
  ),
  collect_ecm_stats(
    ecm_optional_break,
    "ecm_optional_break",
    "ect_optional_break_dd",
    extra_vars = c("d_log_imports_ind_real_gdp", "d_long_rate")
  ),
  collect_ecm_stats(
    ecm_optional_breakFDI,
    "ecm_optional_breakFDI",
    "ect_optional_breakFDI_dd",
    extra_vars = c("d_log_imports_ind_real_gdp", "d_long_rate", "d_log_fdi_outward_gdp")
  )
)

# Top-level call so the returned image is auto-printed by knitr.
make_pretty_table(ecm_fit_table, "ECM MODELS - FIT AND KEY COEFFICIENTS")

# ## ECM Diagnostics Table
# Residual diagnostics (collect_diag_tests) for each ECM.
ecm_test_table <- rbind(
  collect_diag_tests(ecm_pathway1_break, "ecm_pathway1_break"),
  collect_diag_tests(ecm_pathway2_break, "ecm_pathway2_break"),
  collect_diag_tests(ecm_optional_break, "ecm_optional_break"),
  collect_diag_tests(ecm_optional_breakFDI, "ecm_optional_breakFDI")
)
rownames(ecm_test_table) <- NULL

# Top-level call so the returned image is auto-printed by knitr.
make_pretty_table(ecm_test_table, "ECM MODEL DIAGNOSTICS")

# ## VAR Overview and Diagnostics Tables
# One-row overview of a fitted VAR.
#
# var_obj    : VAR estimation object (vars package) exposing $obs, $K, $p and
#              accepted by vars::roots().
# model_name : value stored in the `model` column.
#
# Columns: model, n (= $obs), K (= $K), p (= $p), and the largest / smallest
# modulus of roots(var_obj) (per the vars docs, moduli below 1 indicate a
# stable VAR).
collect_var_info <- function(var_obj, model_name) {
  root_span <- range(Mod(roots(var_obj)))

  data.frame(
    model = model_name,
    n = var_obj$obs,
    K = var_obj$K,
    p = var_obj$p,
    max_root = root_span[2],
    min_root = root_span[1]
  )
}

# One-row table of multivariate residual diagnostics for a fitted VAR.
#
# var_obj     : VAR estimation object accepted by vars' serial.test,
#               arch.test, normality.test and roots.
# model_name  : value stored in the `model` column.
# serial_lags : lags of the asymptotic Portmanteau test (default 12).
# arch_lags   : lags of the multivariate ARCH-LM test (default 4).
#
# Columns: model, Serial_p, ARCH_p, Normality_p (multivariate Jarque-Bera),
# and max_root (largest modulus of roots(var_obj)).
collect_var_tests <- function(var_obj, model_name, serial_lags = 12, arch_lags = 4) {
  serial_p <- serial.test(var_obj, lags.pt = serial_lags, type = "PT.asymptotic")$serial$p.value
  arch_p <- arch.test(var_obj, lags.multi = arch_lags)$arch.mul$p.value
  normality_p <- normality.test(var_obj)$jb.mul$JB$p.value

  data.frame(
    model = model_name,
    Serial_p = serial_p,
    ARCH_p = arch_p,
    Normality_p = normality_p,
    max_root = max(Mod(roots(var_obj)))
  )
}

# Overview and diagnostics for the two VAR specifications.
# NOTE(review): var_fdi / var_no_fdi are presumably estimated in build_objects.R.
var_overview_table <- rbind(
  collect_var_info(var_fdi, "var_fdi"),
  collect_var_info(var_no_fdi, "var_no_fdi")
)
rownames(var_overview_table) <- NULL

var_diag_table <- rbind(
  collect_var_tests(var_fdi, "var_fdi"),
  collect_var_tests(var_no_fdi, "var_no_fdi")
)
rownames(var_diag_table) <- NULL

# Top-level calls so the returned images are auto-printed by knitr.
make_pretty_table(var_overview_table, "VAR MODELS - OVERVIEW")

make_pretty_table(var_diag_table, "VAR MODEL DIAGNOSTICS")

# ## VAR Granger Tables
# Granger and instantaneous causality p-values for one causing variable.
#
# var_obj    : VAR estimation object accepted by vars::causality().
# cause_var  : name of the causing variable.
# model_name : value stored in the `model` column.
#
# Returns a one-row data.frame: model, cause, Granger_p, Instant_p.
collect_granger_result <- function(var_obj, cause_var, model_name) {
  tests <- causality(var_obj, cause = cause_var)
  granger_p <- tests$Granger$p.value
  instant_p <- tests$Instant$p.value

  data.frame(
    model = model_name,
    cause = cause_var,
    Granger_p = granger_p,
    Instant_p = instant_p
  )
}

# Causality tests for each candidate cause in the VAR with FDI ...
var_fdi_granger_table <- rbind(
  collect_granger_result(var_fdi, "d_log_fdi_outward_gdp", "var_fdi"),
  collect_granger_result(var_fdi, "d_log_wage_gap", "var_fdi")
)
rownames(var_fdi_granger_table) <- NULL

# ... and in the VAR without FDI.
var_no_fdi_granger_table <- rbind(
  collect_granger_result(var_no_fdi, "d_log_imports_ind_real_gdp", "var_no_fdi"),
  collect_granger_result(var_no_fdi, "d_long_rate", "var_no_fdi"),
  collect_granger_result(var_no_fdi, "d_log_wage_gap", "var_no_fdi")
)
rownames(var_no_fdi_granger_table) <- NULL

# Top-level calls so the returned images are auto-printed by knitr.
make_pretty_table(var_fdi_granger_table, "VAR WITH FDI - GRANGER CAUSALITY")

make_pretty_table(var_no_fdi_granger_table, "VAR WITHOUT FDI - GRANGER CAUSALITY")

# ## IRF Tables
# Tidy one impulse-response path (with bands, when present) into a table.
#
# irf_obj       : object whose $irf, $Lower and $Upper are lists of matrices,
#                 one element per impulse (named by impulse), one column per
#                 response (named by response) and one row per step.
#                 NOTE(review): this is the layout of vars::irf() output.
# model_name    : value stored in the `model` column.
# impulse_name  : impulse to extract; if it is not a name of the list, the
#                 first element is used (the previous behaviour).
# response_name : response column to extract; if it is not a column name and
#                 the matrix has one column, that column is used.
#
# Returns one row per step with columns model, impulse, response, horizon,
# irf, lower, upper. horizon starts at 0 because the first row is the impact
# response. lower/upper are NA when the object carries no bands (NULL
# $Lower/$Upper, e.g. boot = FALSE). Errors if the response cannot be
# located in a multi-column matrix (previously such matrices were silently
# flattened, misaligning horizons).
collect_irf_table <- function(irf_obj, model_name, impulse_name, response_name) {
  # Numeric path for the requested impulse/response in one component, or NULL
  # when the component is absent.
  extract_path <- function(component) {
    if (is.null(component)) {
      return(NULL)
    }
    by_impulse <- if (impulse_name %in% names(component)) {
      component[[impulse_name]]
    } else {
      component[[1]]
    }
    by_impulse <- as.matrix(by_impulse)
    if (response_name %in% colnames(by_impulse)) {
      as.numeric(by_impulse[, response_name])
    } else if (ncol(by_impulse) == 1L) {
      as.numeric(by_impulse[, 1])
    } else {
      stop("response '", response_name, "' not found in IRF object",
           call. = FALSE)
    }
  }

  path <- extract_path(irf_obj$irf)
  n_steps <- length(path)

  lower <- extract_path(irf_obj$Lower)
  upper <- extract_path(irf_obj$Upper)
  if (is.null(lower)) {
    lower <- rep(NA_real_, n_steps)
  }
  if (is.null(upper)) {
    upper <- rep(NA_real_, n_steps)
  }

  data.frame(
    model = model_name,
    impulse = impulse_name,
    response = response_name,
    horizon = seq_len(n_steps) - 1L,
    irf = path,
    lower = lower,
    upper = upper
  )
}

# Tabulate the impulse-response paths of interest.
# NOTE(review): the irf_* objects are presumably computed in build_objects.R;
# the impulse/response labels passed here should match how they were built.
irf_table_fdi_on_wage <- collect_irf_table(
  irf_fdi_on_wage,
  "var_fdi",
  "d_log_fdi_outward_gdp",
  "d_log_wage_gap"
)

irf_table_wage_on_fdi <- collect_irf_table(
  irf_wage_on_fdi,
  "var_fdi",
  "d_log_wage_gap",
  "d_log_fdi_outward_gdp"
)

irf_table_imports_on_wage <- collect_irf_table(
  irf_imports_on_wage,
  "var_no_fdi",
  "d_log_imports_ind_real_gdp",
  "d_log_wage_gap"
)

irf_table_longrate_on_wage <- collect_irf_table(
  irf_longrate_on_wage,
  "var_no_fdi",
  "d_long_rate",
  "d_log_wage_gap"
)

# Top-level calls so the returned images are auto-printed by knitr.
make_pretty_table(irf_table_fdi_on_wage, "IRF TABLE - FDI SHOCK ON WAGE GAP")

make_pretty_table(irf_table_wage_on_fdi, "IRF TABLE - WAGE GAP SHOCK ON FDI")

make_pretty_table(irf_table_imports_on_wage, "IRF TABLE - IMPORTS SHOCK ON WAGE GAP")

make_pretty_table(irf_table_longrate_on_wage, "IRF TABLE - LONG RATE SHOCK ON WAGE GAP")