#
# 20220603_rebecca_xwalk_check.R
#


# library(crosswalk002, lib.loc = path_to_r_version4_packages)
library(vctrs, lib.loc="/ihme/singularity-images/rstudio/lib/4.1.3.4") #for dplyr in MRBRT library (only need sometimes)
library(crosswalk002, lib.loc = "/ihme/code/mscm/Rv4/packages/")
## Loading required package: reticulate
library(dplyr)
## 
## Attaching package: 'dplyr'
## The following object is masked from 'package:vctrs':
## 
##     data_frame
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
set.seed(123)

# Coefficients used to simulate the alt-vs-ref logit difference; the model
# fitted below should recover these.
beta0_true <- -3
beta1_true <- 1
beta2_true <- 2

# Variance (gamma_tmp) and SD (sigma_tmp) of the two noise terms added to each
# simulated logit difference.
# NOTE(review): the gamma_tmp term is drawn independently for every row, not
# once per group_id, so it is not a group-level random effect.
gamma_tmp <- 0.6
sigma_tmp <- 1

n_matched <- 200L # number of simulated matched alt/ref pairs
group_size <- 10L # consecutive pairs sharing one group_id

# Simulated matched data: one row per alt/ref pair, where
# logit_diff = logit(alt) - logit(ref) is linear in x1 and x2 plus noise.
df_matched <- data.frame(
  x1 = runif(n = n_matched, min = 0, max = 4),
  x2 = rbinom(n = n_matched, prob = 0.5, size = 1)
) %>%
  mutate(
    logit_diff = beta0_true + x1 * beta1_true + x2 * beta2_true +
      rnorm(n = n(), mean = 0, sd = sigma_tmp) +      # noise with SD sigma_tmp
      rnorm(n = n(), mean = 0, sd = sqrt(gamma_tmp)), # noise with variance gamma_tmp
    logit_diff_se = runif(n(), min = 0.9, max = 1.1),
    altvar = "selfreported",
    refvar = "measured",
    # 1, 1, ..., 2, 2, ...: blocks of `group_size` rows; derived from the row
    # count so it cannot drift out of sync with n_matched
    group_id = (seq_len(n()) - 1L) %/% group_size + 1L
  )
head(df_matched)
##         x1 x2 logit_diff logit_diff_se       altvar   refvar group_id
## 1 1.150310  0  0.2921442     1.0717183 selfreported measured        1
## 2 3.153221  1  2.5604000     1.0774770 selfreported measured        1
## 3 1.635908  1 -0.1209113     0.9978183 selfreported measured        1
## 4 3.532070  1  3.0529231     1.0436184 selfreported measured        1
## 5 3.761869  0  0.8670481     0.9973411 selfreported measured        1
## 6 0.182226  1 -2.5725288     1.0977418 selfreported measured        1
n_orig <- 400L # number of original (unadjusted) observations

# Simulated original data to be adjusted: each row is a prevalence estimate
# produced by one of the two methods named in obs_method.
df_orig <- data.frame(
  meanvar = runif(n_orig, min = 0.2, max = 0.8), # original prevalence values; between 0 and 1
  sdvar = runif(n_orig, min = 0.1, max = 0.5),   # standard errors of the original prevalence values; > 0
  x1 = runif(n_orig, min = 0, max = 4),
  x2 = rbinom(n_orig, prob = 0.5, size = 1),
  obs_method = sample(c("selfreported", "measured"), size = n_orig, replace = TRUE),
  stringsAsFactors = FALSE
)
# seq_len() rather than 1:nrow(): 1:0 would yield two IDs for an empty frame
df_orig$row_id <- paste0("row", seq_len(nrow(df_orig)))
head(df_orig)
##     meanvar     sdvar        x1 x2   obs_method row_id
## 1 0.3399438 0.1621057 3.2806579  0     measured   row1
## 2 0.3383922 0.4383404 2.8668819  1 selfreported   row2
## 3 0.2370358 0.1857522 3.1532994  0     measured   row3
## 4 0.4982712 0.3679493 0.0732640  1     measured   row4
## 5 0.3464751 0.3471026 0.6979606  0 selfreported   row5
## 6 0.6549113 0.1199999 2.3945959  0     measured   row6
# Package the matched pairs for crosswalk002: one row per alt/ref pair with
# the observed logit difference and its standard error.
df1 <- CWData(
  df = df_matched,           # matched-pairs data for the meta-regression
  obs = "logit_diff",        # column with the observed difference
  obs_se = "logit_diff_se",  # column with the SE of that difference
  alt_dorms = "altvar",      # column naming the alternative method
  ref_dorms = "refvar",      # column naming the reference method
  covs = list("x1", "x2"),   # covariate columns that CWModel() may use
  study_id = "group_id",     # grouping column (here, the matching groups)
  add_intercept = TRUE       # also create an "intercept" column for CovModel()
)

# Fit the crosswalk meta-regression of the logit difference on an intercept,
# x1 and x2.
fit1 <- CWModel(
  cwdata = df1,
  # "diff_logit" for quantities bounded in [0, 1];
  # "diff_log" for quantities bounded in [0, Inf)
  obs_type = "diff_logit",
  # one CovModel() per predictor (see help(CovModel)); the coefficients in
  # fit1$fixed_vars follow this order
  cov_models = list(
    CovModel(cov_name = "intercept"),
    CovModel(cov_name = "x1"),
    # spline alternative for x1:
    # CovModel(cov_name = "x1", spline = XSpline(knots = c(0,1,2,3,4), degree = 3L, l_linear = TRUE, r_linear = TRUE), spline_monotonicity = "increasing"),
    CovModel(cov_name = "x2")
  ),
  # the method label treated as the gold standard ("measured" is the value
  # stored in refvar above); this matters once several "reference" methods
  # can coexist in a network meta-analysis (NMA)
  gold_dorm = "measured"
)
# Sanity check that the model recovers the simulated coefficients.
# Elements 4:6 of fit1$beta are the selfreported coefficients (the same values
# reported in fit1$fixed_vars$selfreported).
print(
  data.frame(
    beta_true = c(beta0_true, beta1_true, beta2_true),
    beta_mean = fit1$beta[4:6],
    beta_se   = fit1$beta_sd[4:6]
  )
)
##   beta_true beta_mean    beta_se
## 1        -3 -3.061389 0.17137642
## 2         1  1.084463 0.06535806
## 3         2  1.927910 0.14241179
# Apply the fitted crosswalk to the original data; rows already measured with
# the gold-standard method get a zero adjustment (see the output below).
preds1 <- adjust_orig_vals(
  fit_object = fit1,           # object returned by `CWModel()`
  df = df_orig,                # data whose values are adjusted
  orig_dorms = "obs_method",   # column naming each row's method
  orig_vals_mean = "meanvar",  # column with the original values
  orig_vals_se = "sdvar",      # column with their standard errors
  data_id = "row_id"           # optional: column of user-defined IDs that is
                               # echoed back in the output as `data_id`
)

# adjust_orig_vals() returns a five-element list of vectors:
# -- ref_vals_mean / ref_vals_sd: the adjusted mean and its SE, in linear space
# -- pred_diff_mean / pred_diff_sd: the adjustment factor and its SE, in
#    transformed space. The adjustment factor is the predicted alt - ref
#    difference, so it is *subtracted* to make the adjustment
# -- data_id: the identifier of the input row that each prediction corresponds to
lapply(X = preds1, FUN = head)
## $ref_vals_mean
## [1] 0.3399438 0.0662351 0.2370358 0.4982712 0.8415635 0.6549113
## 
## $ref_vals_sd
## [1] 0.1621057 0.1230136 0.1857522 0.3679493 0.2073851 0.1199999
## 
## $pred_diff_mean
## [1]  0.000000  1.975547  0.000000  0.000000 -2.304477  0.000000
## 
## $pred_diff_sd
## [1] 0.0000000 0.2911356 0.0000000 0.0000000 0.1773438 0.0000000
## 
## $data_id
## [1] "row1" "row2" "row3" "row4" "row5" "row6"
# Copy the predictions into df_orig by element name rather than by position,
# so the mapping does not depend on the order of the list returned by
# adjust_orig_vals().
df_orig$meanvar_adjusted <- preds1$ref_vals_mean
df_orig$sdvar_adjusted <- preds1$ref_vals_sd
df_orig$pred_logit <- preds1$pred_diff_mean
df_orig$pred_se_logit <- preds1$pred_diff_sd
df_orig$data_id <- preds1$data_id
# the script relies on the predictions being in df_orig's row order; fail
# loudly instead of silently misaligning rows
stopifnot(all(as.character(df_orig$data_id) == df_orig$row_id))
# note that the gold standard observations remain untouched
head(df_orig)
##     meanvar     sdvar        x1 x2   obs_method row_id meanvar_adjusted
## 1 0.3399438 0.1621057 3.2806579  0     measured   row1        0.3399438
## 2 0.3383922 0.4383404 2.8668819  1 selfreported   row2        0.0662351
## 3 0.2370358 0.1857522 3.1532994  0     measured   row3        0.2370358
## 4 0.4982712 0.3679493 0.0732640  1     measured   row4        0.4982712
## 5 0.3464751 0.3471026 0.6979606  0 selfreported   row5        0.8415635
## 6 0.6549113 0.1199999 2.3945959  0     measured   row6        0.6549113
##   sdvar_adjusted pred_logit pred_se_logit data_id
## 1      0.1621057   0.000000     0.0000000    row1
## 2      0.1230136   1.975547     0.2911356    row2
## 3      0.1857522   0.000000     0.0000000    row3
## 4      0.3679493   0.000000     0.0000000    row4
## 5      0.2073851  -2.304477     0.1773438    row5
## 6      0.1199999   0.000000     0.0000000    row6
print(df_orig[2, ]) # row 2 is a self-reported observation with prevalence 0.338 (or 33.8%)
##     meanvar     sdvar       x1 x2   obs_method row_id meanvar_adjusted
## 2 0.3383922 0.4383404 2.866882  1 selfreported   row2        0.0662351
##   sdvar_adjusted pred_logit pred_se_logit data_id
## 2      0.1230136   1.975547     0.2911356    row2
fit1$fixed_vars # estimated betas
## $measured
## [1] 0 0 0
## 
## $selfreported
## [1] -3.061389  1.084463  1.927910
# NOTE: the "#>" lines below are output from an earlier run; they do not match
# the coefficients printed above.
#> $measured
#> [1] 0 0 0
#> 
#> $selfreported
#> [1] -3.164455  1.081816  2.076966
# the predicted adjustment for row 2 (x1 = 2.866882, x2 = 1) should be
# b0 + b1*x1 + b2*x2. Use the full-precision coefficients and covariates
# rather than hand-copied, rounded numbers, so that the result can be compared
# directly with the package's pred_logit.
beta_alt <- fit1$fixed_vars$selfreported           # c(b0, b1, b2)
x_row2 <- c(1, df_orig[2, "x1"], df_orig[2, "x2"]) # intercept, x1, x2
(pred <- sum(beta_alt * x_row2)) # b0 + b1*x1 + b2*x2
# should equal pred_logit for row 2 (1.975547 above); plugging in the rounded
# coefficients printed above instead gives 1.975548
isTRUE(all.equal(pred, df_orig[2, "pred_logit"], tolerance = 1e-6))
# the prediction is defined as logit(alt) - logit(ref), so the final adjusted value should be
# logit(mean_adjusted) = logit(mean_alt) - prediction
# Logit and inverse-logit transforms. stats::qlogis() computes log(p / (1 - p)),
# and stats::plogis() computes exp(x) / (1 + exp(x)). Unlike the literal
# formula, plogis() stays finite for large x, where exp(x) / (1 + exp(x))
# evaluates to Inf / Inf = NaN.
logit <- function(p) qlogis(p)
inv_logit <- function(x) plogis(x)
# adjusted mean for row 2: subtract the predicted adjustment on the logit
# scale, then back-transform to linear space
logit_mean_adjusted <- logit(df_orig[2, "meanvar"]) - pred
inv_logit(logit_mean_adjusted)
## [1] 0.06623499
# compare with the package value using a relative tolerance rather than
# rounding both sides and testing with `==` (two close values can round to
# different digits when they straddle a rounding boundary)
isTRUE(all.equal(inv_logit(logit_mean_adjusted), df_orig[2, "meanvar_adjusted"], tolerance = 1e-4))
## [1] TRUE
#> [1] TRUE
# SE of the adjusted data point is calculated as sqrt(a^2 + b^2 + c^2), where
# a is the (log or logit) standard error of the original data point,
# b is the standard error of the predicted adjustment, and
# c is the standard deviation of between-group heterogeneity, a.k.a. sqrt(gamma)
#
# note that a, b, and c are all in transformed (log or logit) space
#
# this method increases an adjusted observation's uncertainty and effectively
# downweights it in subsequent analyses, such as an ST-GPR model that
# estimates prevalence globally
#

fit1$gamma
## [1] 0.03785378
# a: SE of the original value, converted from linear to logit space
orig_se_logit <- crosswalk002::delta_transform(
  mean = df_orig[2, "meanvar"],
  sd = df_orig[2, "sdvar"],
  transformation = "linear_to_logit"
)[1, "sd_logit"]

# sqrt(a^2 + b^2 + c^2); since c = sqrt(gamma), c^2 is gamma itself
total_se_logit <- sqrt(orig_se_logit^2 + df_orig[2, "pred_se_logit"]^2 + fit1$gamma)

# manually calculated linear-space SE
manual_linear <- crosswalk002::delta_transform(
  mean = logit_mean_adjusted,
  sd = total_se_logit,
  transformation = "logit_to_linear"
)
manual_linear
##      mean_linear sd_linear
## [1,]  0.06623499 0.1230134
# package calculated linear-space SE
df_orig[2, "sdvar_adjusted"]
## [1] 0.1230136
# agreement within a relative tolerance, as for the adjusted mean above
isTRUE(all.equal(manual_linear[1, "sd_linear"], df_orig[2, "sdvar_adjusted"], tolerance = 1e-4))