library(dplyr)
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
library(data.table)
## 
## Attaching package: 'data.table'
## The following objects are masked from 'package:dplyr':
## 
##     between, first, last
# Fix the RNG state so the simulated dataset is reproducible.
set.seed(123)

df <- data.frame(id = 1:400)

# Case-definition components (names) and the coefficient that each one
# contributes to the simulated outcome (values).
case_defs <- c(A = 1, B = 2, C = 3, D = 4)

# Draw one combined case definition: a random, non-empty subset of
# `components`, sorted and pasted together (e.g. "ABD").
draw_case_def <- function(components) {
  # The previous code passed the whole permutation `sample(1:4)` as `size`
  # and relied on sample() silently using only its first element. Taking the
  # first element explicitly states the intent and draws from the RNG in the
  # same way, so the simulated values are unchanged.
  n_components <- sample(seq_along(components))[1]
  paste(sort(sample(x = components, size = n_components)), collapse = "")
}

# randomly sample from case definition components to make combined case
# definitions; vapply() always returns a character vector and seq_len()
# avoids the 1:0 trap of `1:nrow(df)` when there are no rows
df$alt <- vapply(
  seq_len(nrow(df)),
  function(i) draw_case_def(names(case_defs)),
  character(1)
)

df$ref <- vapply(
  seq_len(nrow(df)),
  function(i) draw_case_def(names(case_defs)),
  character(1)
)

# Relabel the reference definition as "r" (the gold standard) on every
# even-numbered id, i.e. half of the rows.
df$ref[df$id %% 2 == 0] <- "r"

# Drop self-comparisons: rows that compare a case definition with itself
# (reference string equal to alternative string).
df2 <- filter(df, ref != alt)


# Simulated outcome: the sum of the coefficients of the components present in
# the alternative definition (numerator) minus the sum for the components
# present in the reference definition (denominator). The gold standard "r"
# contains none of the component letters, so it contributes nothing.
#
# grepl() is called directly rather than through sapply(): sapply() over a
# single pattern returns an n x 1 matrix, which silently turned `y` into a
# one-column matrix column instead of a plain numeric vector.
df2$y <- 0
for (component in names(case_defs)) {
  in_ref <- grepl(component, df2$ref, fixed = TRUE)
  in_alt <- grepl(component, df2$alt, fixed = TRUE)
  # [[ ]] drops the coefficient's name so it cannot leak into `y`
  df2$y <- df2$y + (in_alt - in_ref) * case_defs[[component]]
}

# add noise to outcome variable from N(0,1)
df2$y <- df2$y + rnorm(n = nrow(df2), mean = 0, sd = 1)
# every observation is assigned the same standard error
df2$y_se <- 0.5

# encode network: one column per component, holding +1 when the component
# appears only in the alternative definition, -1 when it appears only in the
# reference definition, and 0 when it appears in both or in neither.
#
# grepl() is called directly rather than through sapply(): sapply() over a
# single pattern returns an n x 1 matrix, which would be stored as a matrix
# column; these are plain numeric columns instead.
for (component in names(case_defs)) {
  in_ref <- grepl(component, df2$ref, fixed = TRUE)
  in_alt <- grepl(component, df2$alt, fixed = TRUE)
  df2[[component]] <- as.numeric(in_alt) - as.numeric(in_ref)
}

head(as.data.table(df2))
##    id  alt ref         y y_se A  B C D
## 1:  1   AD   B  2.421825  0.5 1 -1 0 1
## 2:  2  BCD   r  8.461829  0.5 0  1 1 1
## 3:  3  ABD  AD  2.563249  0.5 0  1 0 0
## 4:  4 ABCD   r 12.751104  0.5 1  1 1 1
## 5:  5   CD   C  3.285564  0.5 0  0 0 1
## 6:  6    B   r  1.598300  0.5 0  1 0 0
# run MR-BRT model
# Source the MR-BRT helper scripts from the shared project directory; the
# functions called further down (cov_info, run_mr_brt, check_for_outputs,
# load_mr_brt_outputs) are expected to come from these files.
mr_brt_dir <- "/home/j/temp/reed/prog/projects/run_mr_brt/"
mr_brt_scripts <- c(
  "cov_info_function.R",
  "run_mr_brt_function.R",
  "check_for_outputs_function.R",
  "load_mr_brt_outputs_function.R"
)
for (script in mr_brt_scripts) {
  source(paste0(mr_brt_dir, script))
}


# One covariate specification per case-definition component; "X" is passed as
# cov_info()'s second argument (the model metadata printed by run_mr_brt()
# lists these covariates under design matrix X).
cov_list <- lapply(names(case_defs), function(component) {
  cov_info(component, "X")
})

# Submit the MR-BRT fit of the simulated outcome on the network-encoded
# component columns. Arguments are grouped as data / model / job settings;
# all are passed by name, so their order does not affect matching.
fit1 <- run_mr_brt(
  data = df2,
  mean_var = "y",
  se_var = "y_se",
  covs = cov_list,
  remove_x_intercept = TRUE,
  method = "remL",
  output_dir = "/home/j/temp/reed/prog/data/",
  model_label = "crosswalk_network_simulation_v1",
  overwrite_previous = TRUE,
  project = "proj_custom_models"
)
## Model metadata 
## 
##   covariate design_matrix prior_mean prior_var type
## 1 intercept             Z          0       inf base
## 2         A             X          0       inf     
## 3         B             X          0       inf     
## 4         C             X          0       inf     
## 5         D             X          0       inf     
## 
## If outputs do not appear in '/home/j/temp/reed/prog/data/crosswalk_network_simulation_v1/' after the job finishes (or if your job is stuck in qwait), submit the following command in a qlogin session: 
## 
## sh /home/j/temp/reed/prog/utility_scripts/run_mr_brt.sh /home/j/temp/reed/prog/data/crosswalk_network_simulation_v1/ y y_se obs_id FALSE reml fit 0 0
# Check for the model's output files before loading them (result is
# auto-printed, as in the recorded output below).
check_for_outputs(fit1)
## Waiting up to 30 seconds for outputs...
## Outputs are now available in '/home/j/temp/reed/prog/data/crosswalk_network_simulation_v1/'
## [1] TRUE
# Load the fitted model's outputs and print the coefficient table, to compare
# the estimates with the simulation coefficients in `case_defs`.
results1 <- load_mr_brt_outputs(fit1)
results1[["model_coefs"]]
##   x_cov beta_soln    beta_var beta_prior_mean beta_prior_var     z_cov
## 1     A  1.024828 0.005625361               0            Inf intercept
## 2     B  1.967761 0.005014438               0            Inf          
## 3     C  2.988982 0.005463565               0            Inf          
## 4     D  4.030223 0.004999711               0            Inf          
##   gamma_soln gamma_prior_mean gamma_prior_var
## 1  0.7801801                0             Inf
## 2         NA               NA              NA
## 3         NA               NA              NA
## 4         NA               NA              NA