# Session setup: clear the global environment, fix the RNG seed so that
# random steps are repeatable, and show the current working directory.
# NOTE(review): rm(list = ls()) wipes the caller's workspace; kept as-is
# because the rest of the script was written assuming a clean session.
rm(list = ls())
set.seed(seed = 1234)
getwd()
## [1] "D:/Projects/Live-PK/0.DataCleaning/0.Code"
library("WeightIt")
library("MatchIt")
library("marginaleffects")
library("cobalt")
library("tidyverse")
library("fixest")
library("arrow")
library(PanelMatch)
library(fect)
library(panelView)

# Load the synthetic panel from parquet and derive calendar keys from
# `p_date`: `day`, `month` and `year` are character strings built with
# format(), and `quarter` is a "YYYY-QN" label. The descriptive columns
# listed in across() are then converted to factors in place.
df <- read_parquet("../../0.DataCleaning/1.Input/synthetic_data.parquet") %>%
  mutate(
    p_date = as_date(p_date),
    day = format(p_date, "%Y-%m-%d"),
    month = format(p_date, "%Y-%m"),
    year = format(p_date, "%Y"),
    # `year(p_date)` is a function call, so R's lookup skips the character
    # column `year` created just above and finds the date helper instead.
    quarter = paste0(year(p_date), "-Q", quarter(p_date)),
    across(
      c(
        gender, author_type, author_income_range, age_range,
        fre_country_region, fre_city_level, is_big_v
      ),
      as.factor
    )
  )

# Build the unit-day panel used by the matching steps below.
#   1. Keep a single row per author_id/day: the row with the smallest
#      p_date, ties broken by taking the first such row.
#   2. Add `reference_date` (2022-12-31) and `relative_day`, the integer
#      number of days from that reference date to `day`.
#   3. Convert the result to a plain data.frame.
# NOTE(review): p_date is passed through as_date() when `df` is built, so
# within one author_id/day every p_date is presumably identical and the
# kept row is simply the first in file order -- confirm this is intended.
df_first_live_per_day <- df %>%
  group_by(author_id, day) %>%
  slice_min(order_by = p_date, n = 1, with_ties = FALSE) %>%
  ungroup() %>%
  mutate(
    reference_date = as_date("2022-12-31"),
    relative_day = as.integer(
      difftime(as_date(day), reference_date, units = "days")
    )
  ) %>%
  as.data.frame()
# gc()
# Quick sanity checks: five-number summaries of the time index and of the
# treatment indicator (recorded console output is kept below each call).
summary(df_first_live_per_day[["relative_day"]])
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##   270.0   361.0   453.0   452.8   545.0   635.0
summary(df_first_live_per_day[["is_pk_live"]])
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##  0.0000  0.0000  0.0000  0.2983  1.0000  1.0000
# DisplayTreatment(unit.id = "author_id", time.id = "relative_day", 
#                  legend.position = "none", xlab = "relative_day", ylab = "author_id",
#                  treatment = "is_pk_live", data = df_first_live_per_day,
#                  hide.x.tick.label = TRUE, hide.y.tick.label = TRUE)
# DisplayTreatment(unit.id = "wbcode2",
# time.id = "year", legend.position = "none",
# xlab = "year", ylab = "Country Code",
# treatment = "dem", data = dem)

# Plot the treatment status (is_pk_live) per author over relative_day with
# panelview(); units are grouped by treatment timing (by.timing = TRUE).
# NOTE(review): the formula uses total_cost_amt, whereas the PanelMatch
# calls in this script use avg_total_cost_amt -- confirm which is intended.
panelview(
  total_cost_amt ~ is_pk_live,
  data = df_first_live_per_day,
  index = c("author_id", "relative_day"),
  main = "Simulated Data: Treatment Status",
  xlab = "Time",
  ylab = "Unit",
  axis.lab = "time",
  by.timing = TRUE,
  gridOff = TRUE,
  background = "white"
)
## If the number of units is more than 500, we randomly select 500 units to present.
##         You can set "display.all = TRUE" to show all units.

# Baseline matched sets: match on the 5-period treatment history only
# (refinement.method = "none"), targeting the ATT at lead 0.
PM.results.none <- PanelMatch(
  data = df_first_live_per_day,
  unit.id = "author_id",
  time.id = "relative_day",
  treatment = "is_pk_live",
  outcome.var = "avg_total_cost_amt",
  qoi = "att",
  lag = 5,
  lead = 0,
  refinement.method = "none",
  size.match = 5,
  match.missing = TRUE,
  forbid.treatment.reversal = FALSE,
  use.diagonal.variance.matrix = TRUE
)
# In real data, lag can be larger, like 10; size.match can be larger, like 10.

# Mahalanobis distance matching

# Matched sets refined by Mahalanobis distance. The refinement covariates
# are four author attributes plus lags 1-5 of the two 7-day duration
# measures; every other setting mirrors the unrefined specification.
PM.results.maha <- PanelMatch(
  data = df_first_live_per_day,
  unit.id = "author_id",
  time.id = "relative_day",
  treatment = "is_pk_live",
  outcome.var = "avg_total_cost_amt",
  qoi = "att",
  lag = 5,
  lead = 0,
  refinement.method = "mahalanobis",
  covs.formula = ~ gender + author_type + author_income_range + age_range +
    I(lag(play_duration_7d, 1:5)) + I(lag(live_duration_7d, 1:5)),
  size.match = 5,
  match.missing = TRUE,
  forbid.treatment.reversal = FALSE,
  use.diagonal.variance.matrix = TRUE
)
# Pull out a single matched set for inspection: element 333 of the ATT
# matched sets (an arbitrary hard-coded index, not the first set).
mset <- PM.results.maha$att[333]
# Draw the treatment-status plot restricted to that matched set, so only
# the treated unit and its matched controls are shown.
DisplayTreatment(
  data = df_first_live_per_day,
  unit.id = "author_id",
  time.id = "relative_day",
  treatment = "is_pk_live",
  matched.set = mset,
  show.set.only = TRUE,
  xlab = "relative_day",
  ylab = "author_id",
  legend.position = "none"
)

# Covariate balance table for the Mahalanobis-refined matched sets, for
# the two 7-day duration covariates (table only, no plot). The recorded
# output below has one row per period, t_5 through t_0.
get_covariate_balance(
  PM.results.maha$att,
  data = df_first_live_per_day,
  covariates = c("play_duration_7d", "live_duration_7d"),
  plot = FALSE
)
##     play_duration_7d live_duration_7d
## t_5      0.019386089     -0.044670937
## t_4     -0.021214794     -0.028396664
## t_3      0.046270605     -0.049592010
## t_2      0.023305289      0.013694906
## t_1      0.005190167     -0.003737995
## t_0     -0.012462018     -0.003525231

# Propensity-score (PS) weighting

# Matched sets refined by propensity-score weighting
# (refinement.method = "ps.weight"), using the same covariate formula and
# settings as the Mahalanobis specification.
PM.results.ps.weight <- PanelMatch(
  data = df_first_live_per_day,
  unit.id = "author_id",
  time.id = "relative_day",
  treatment = "is_pk_live",
  outcome.var = "avg_total_cost_amt",
  qoi = "att",
  lag = 5,
  lead = 0,
  refinement.method = "ps.weight",
  covs.formula = ~ gender + author_type + author_income_range + age_range +
    I(lag(play_duration_7d, 1:5)) + I(lag(live_duration_7d, 1:5)),
  size.match = 5,
  match.missing = TRUE,
  forbid.treatment.reversal = FALSE,
  use.diagonal.variance.matrix = TRUE
)
# Covariate balance for the PS-weighted matched sets, three ways.
# (1) Balance table for the two 7-day duration covariates, no plot.
get_covariate_balance(
  PM.results.ps.weight$att,
  data = df_first_live_per_day,
  covariates = c("play_duration_7d", "live_duration_7d"),
  plot = FALSE
)
##     play_duration_7d live_duration_7d
## t_5      0.015901890     -0.033042619
## t_4     -0.015223210     -0.009537915
## t_3      0.034324536     -0.027919140
## t_2      0.020825703      0.013666586
## t_1      0.004880736     -0.004623150
## t_0     -0.013304107     -0.001494143
# (2) Balance plot with use.equal.weights = TRUE.
# NOTE(review): presumably this is the pre-refinement benchmark (equal
# weights on all matched controls) -- confirm against the PanelMatch docs.
get_covariate_balance(
  PM.results.ps.weight$att,
  data = df_first_live_per_day,
  covariates = c("play_duration_7d", "live_duration_7d"),
  use.equal.weights = TRUE,
  plot = TRUE,
  ylim = c(-0.1, 0.1)
)

# (3) Balance plot with the default weights, on the same y-axis range so
# the two plots can be compared directly.
get_covariate_balance(
  PM.results.ps.weight$att,
  data = df_first_live_per_day,
  covariates = c("play_duration_7d", "live_duration_7d"),
  plot = TRUE,
  ylim = c(-0.1, 0.1)
)

# Balance scatter plot covering both refinements (Mahalanobis and PS
# weighting) for the two 7-day duration covariates.
balance_scatter(
  matched_set_list = list(
    PM.results.maha$att,
    PM.results.ps.weight$att
  ),
  data = df_first_live_per_day,
  covariates = c("play_duration_7d", "live_duration_7d")
)

# Estimate the ATT from the PS-weighted matched sets, with bootstrap
# standard errors (1000 iterations) and a 95% confidence level.
# NOTE(review): num.cores = 6 is hard-coded; adjust it to the machine the
# script runs on.
PE.results <- PanelEstimate(
  sets = PM.results.ps.weight,
  data = df_first_live_per_day,
  se.method = "bootstrap",
  number.iterations = 1000,
  confidence.level = 0.95,
  parallel = TRUE,
  num.cores = 6
)

# Inspect the estimation results (recorded console output kept below).
# Point estimate per lead period
PE.results$estimates
##        t+0 
## -0.3458811
# Bootstrap standard error per lead period
PE.results$standard.error
##       t+0 
## 0.4401584
# Full summary: estimate, standard error and confidence bounds
summary(PE.results)
## Weighted Difference-in-Differences with Propensity Score
## Matches created with 5 lags
## 
## Standard errors computed with 1000 Weighted bootstrap samples
## 
## Estimate of Average Treatment Effect on the Treated (ATT) by Period:
## $summary
##       estimate std.error      2.5%     97.5%
## t+0 -0.3458811 0.4401584 -1.233423 0.5214371
## 
## $lag
## [1] 5
## 
## $iterations
## [1] 1000
## 
## $qoi
## [1] "att"
# Plot the estimates
plot(PE.results)