# Script setup.
# NOTE: the former `rm(list = ls())` was removed. It deletes every object in
# the caller's global environment when this script is sourced, yet it does not
# detach packages or reset options, so it never gave a true clean slate. Run
# the script in a fresh R session instead.

# Fix the RNG seed so the random steps further down are reproducible.
set.seed(1234)
# Print the working directory; the relative data path below is resolved
# against it.
getwd()
## [1] "D:/Projects/Live-PK/0.DataCleaning/0.Code"
library("tidyverse")
library("fixest")
library("arrow")
library(DoubleML)
library(mlr3)
library(mlr3learners)
library(data.table)
library(ggplot2)
library(ranger)
library(xgboost)

# Suppress messages during fitting: raise the threshold of the "mlr3" logger
# to "warn" so that only warnings and more severe records are emitted.
lgr::get_logger("mlr3")$set_threshold("warn") 

# Load the synthetic data set, derive calendar keys from `p_date`, compute the
# day offset to a fixed reference date, and encode the categorical columns as
# factors.
df <- read_parquet("../../0.DataCleaning/1.Input/synthetic_data.parquet") %>%
  mutate(
    p_date = as_date(p_date),
    # Calendar keys, stored as character strings.
    day = format(p_date, "%Y-%m-%d"),
    month = format(p_date, "%Y-%m"),
    year = format(p_date, "%Y"),
    quarter = paste0(year(p_date), "-Q", quarter(p_date)),
    # Whole days between each observation and the reference date.
    reference_date = as_date("2022-12-31"),
    relative_day = as.integer(
      difftime(as_date(day), reference_date, units = "days")
    )
  ) %>%
  mutate(
    # Categorical covariates (and the day offset) become factors in place.
    across(
      c(
        gender, author_type, author_income_range, age_range,
        fre_country_region, fre_city_level, is_big_v, relative_day
      ),
      as.factor
    )
  )

# gc()

# Model-free evidence ----

# relative_day is a factor, so summary() prints per-level observation counts.
summary(df[["relative_day"]])
##     402     284     607     307     565     618     395     393     398     415 
##     317     315     315     313     312     311     310     306     305     305 
##     626     400     270     279     381     545     547     278     343     387 
##     304     302     300     300     300     300     300     299     299     299 
##     523     532     625     273     439     591     339     403     582     593 
##     299     299     299     298     298     298     297     297     297     297 
##     594     479     552     296     346     446     287     291     390     391 
##     297     296     296     295     295     295     294     294     294     294 
##     312     321     338     358     445     472     521     551     614     309 
##     293     293     293     293     293     293     293     293     293     292 
##     328     394     460     474     515     617     375     453     539     463 
##     292     292     292     292     292     292     291     291     291     290 
##     490     497     574     486     504     601     294     396     414     456 
##     290     290     290     289     289     289     288     288     288     288 
##     522     530     542     599     633     527     563     612     351     372 
##     288     288     288     288     288     287     287     287     285     285 
##     404     464     503     511     592     621     363     412     417     491 
##     285     285     285     285     285     285     284     284     284     284 
##     581     597     604     306     313     336     360     382     451 (Other) 
##     284     284     284     283     283     283     283     283     283   70959
# Five-number summary plus mean of the 0/1 treatment indicator.
summary(df[["is_pk_live"]])
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##  0.0000  0.0000  0.0000  0.2981  1.0000  1.0000
# Show the type and the first few values of every column in df.
str(df)
## tibble [100,000 × 48] (S3: tbl_df/tbl/data.frame)
##  $ author_id               : int [1:100000] 10162 10538 10074 10436 10334 10293 10007 10180 10501 10504 ...
##  $ live_id                 : int [1:100000] 52748 56505 54926 55204 55998 51037 59513 59822 54103 51242 ...
##  $ p_date                  : Date[1:100000], format: "2024-04-11" "2023-10-02" ...
##  $ is_pk_live              : int [1:100000] 0 0 0 0 0 0 0 0 0 1 ...
##  $ valid_play_duration     : int [1:100000] 491 990 557 601 770 330 330 738 424 56 ...
##  $ valid_play_user_num     : int [1:100000] 851 635 24 673 589 43 452 290 739 331 ...
##  $ avg_valid_play_duration : num [1:100000] 5.56 9.19 2.82 6.2 8.05 ...
##  $ total_cost_amt          : num [1:100000] 608 393 915 658 568 ...
##  $ total_cost_user_num     : int [1:100000] 387 357 381 240 203 138 13 2 63 325 ...
##  $ avg_total_cost_amt      : num [1:100000] 44 28.9 19.1 21.4 41 ...
##  $ comment_cnt             : int [1:100000] 367 219 432 9 254 399 347 260 400 46 ...
##  $ comment_user_num        : int [1:100000] 86 17 255 65 299 54 244 252 218 28 ...
##  $ avg_comment_cnt         : num [1:100000] 2.29 4.38 7.31 7.06 9.11 ...
##  $ like_cnt                : int [1:100000] 4130 7452 976 1332 5915 5497 4863 2703 4290 661 ...
##  $ like_user_num           : int [1:100000] 3829 2541 798 4875 3264 3723 1208 4333 1481 3516 ...
##  $ avg_like_cnt            : num [1:100000] 2.72 8.77 12.18 1.23 7.93 ...
##  $ share_success_cnt       : int [1:100000] 376 342 836 259 772 595 881 18 630 333 ...
##  $ share_success_user_num  : int [1:100000] 37 350 220 433 338 78 230 295 225 126 ...
##  $ avg_share_success_cnt   : num [1:100000] 0.0863 2.0528 1.9958 4.3156 4.4424 ...
##  $ follow_author_cnt       : int [1:100000] 559 849 688 512 155 520 992 662 735 171 ...
##  $ cancel_follow_author_cnt: int [1:100000] 64 95 30 67 74 94 60 72 36 41 ...
##  $ follow_user_cnt         : int [1:100000] 230 413 639 209 205 880 852 610 79 960 ...
##  $ cancel_follow_user_cnt  : int [1:100000] 21 76 1 86 96 93 85 3 85 69 ...
##  $ join_fans_group_cnt     : int [1:100000] 201 195 174 75 342 206 129 458 491 181 ...
##  $ live_new_user_num       : int [1:100000] 85 564 91 127 80 797 583 636 919 930 ...
##  $ report_live_cnt         : int [1:100000] 12 82 22 1 11 37 3 65 92 33 ...
##  $ report_user_cnt         : int [1:100000] 28 29 44 49 49 35 25 0 7 3 ...
##  $ fans_user_num           : int [1:100000] 20969 3975 40820 48247 45508 94049 92317 13586 56647 82865 ...
##  $ live_duration_7d        : int [1:100000] 4484 3401 3650 5361 6599 4663 9115 8842 7508 2739 ...
##  $ play_duration_7d        : int [1:100000] 6142 3661 1981 3639 7640 1566 8609 3151 1616 7810 ...
##  $ follow_user_num         : int [1:100000] 558 722 378 694 224 749 706 605 423 436 ...
##  $ reg_day_cnt             : int [1:100000] 2812 892 3094 261 1503 2272 246 2011 696 1099 ...
##  $ pk_id                   : num [1:100000] NA NA NA NA NA ...
##  $ gender                  : Factor w/ 2 levels "Female","Male": 2 1 2 1 1 1 1 2 2 1 ...
##  $ fans_range              : chr [1:100000] "1k-10k" "10k-50k" "1k-10k" "50k-100k" ...
##  $ fans_group_fans_num     : int [1:100000] 5311 4040 6103 1270 4316 4160 7940 5110 4237 1670 ...
##  $ author_type             : Factor w/ 3 levels "游戏","电商",..: 1 1 3 2 2 2 3 3 2 3 ...
##  $ author_income_range     : Factor w/ 5 levels "1","2","3","4",..: 2 4 3 2 5 3 4 5 3 4 ...
##  $ age_range               : Factor w/ 4 levels "18-24","25-34",..: 3 2 2 3 2 4 1 2 2 4 ...
##  $ fre_country_region      : Factor w/ 4 levels "CN","JP","KR",..: 1 1 1 4 1 2 4 3 1 2 ...
##  $ fre_city_level          : Factor w/ 4 levels "1","2","3","4": 4 4 1 1 2 1 1 1 4 1 ...
##  $ is_big_v                : Factor w/ 2 levels "0","1": 1 2 2 1 2 2 1 1 2 2 ...
##  $ day                     : chr [1:100000] "2024-04-11" "2023-10-02" "2024-02-06" "2024-07-04" ...
##  $ month                   : chr [1:100000] "2024-04" "2023-10" "2024-02" "2024-07" ...
##  $ year                    : chr [1:100000] "2024" "2023" "2024" "2024" ...
##  $ quarter                 : chr [1:100000] "2024-Q2" "2023-Q4" "2024-Q1" "2024-Q3" ...
##  $ reference_date          : Date[1:100000], format: "2022-12-31" "2022-12-31" ...
##  $ relative_day            : Factor w/ 366 levels "270","271","272",..: 198 6 133 282 310 141 345 46 163 226 ...
# Bar chart of the mean of avg_total_cost_amt for each value of is_pk_live,
# with error bars computed by mean_cl_normal.
# Line widths are set with `linewidth`: since ggplot2 3.4.0 the `size`
# aesthetic is deprecated for lines (bar outlines and error bars).
ggplot(df, aes(x = factor(is_pk_live), y = avg_total_cost_amt, fill = factor(is_pk_live))) +
  stat_summary(fun = mean, geom = "bar", color = "black", linewidth = 0.5, alpha = 0.7) +
  stat_summary(fun.data = mean_cl_normal, geom = "errorbar", width = 0.2, linewidth = 1) +
  theme_minimal() +
  ggtitle("Mean and Confidence Intervals of Total Cost by PK Status") +
  labs(x = "PK Live (0 = No, 1 = Yes)", y = "Mean Total Cost Amount") +
  theme(legend.position = "none",
        plot.title = element_text(hjust = 0.5),
        text = element_text(size = 10))

# Expand the categorical covariates into dummy columns. The first column of
# the model matrix (the intercept) is dropped.
dummy_vars <- model.matrix(
  ~ gender + author_type + author_income_range + age_range +
    fre_country_region + fre_city_level, # + relative_day
  data = df
)[, -1]
# Under the default na.action, model.matrix() silently drops every row that
# has an NA in one of the covariates. Fail loudly here instead of letting the
# dummies be bound to the wrong rows of df later on.
stopifnot(nrow(dummy_vars) == nrow(df))
dummy_vars <- as.data.table(dummy_vars)

# Turn column names containing known Chinese category labels into
# syntactically valid R names.
#
# names: character vector of column names.
# Returns a character vector of the same length: each known Chinese label is
# replaced by its English equivalent, then make.names() is applied.
clean_column_names <- function(names) {
  # Chinese labels and their English replacements, matched by position.
  # Further labels can be appended to both vectors.
  chinese_labels <- c("电商", "秀场")
  english_labels <- c("Commerce", "Showcase")

  for (i in seq_along(chinese_labels)) {
    names <- gsub(chinese_labels[i], english_labels[i], names)
  }

  # make.names() rewrites any remaining invalid characters.
  make.names(names)
}

# Rename the dummy columns in place using the cleaned names.
setnames(
  dummy_vars,
  old = names(dummy_vars),
  new = clean_column_names(names(dummy_vars))
)

# Covariates of the base specification: the two duration columns plus every
# dummy column.
features_base <- c("live_duration_7d", "play_duration_7d", names(dummy_vars))

# Attach the dummies to the original data and store Date columns as numbers.
data_with_dummies <- mutate(
  cbind(df, dummy_vars),
  across(where(is.Date), as.numeric)
)


# DoubleML data object for the base specification: outcome, treatment and the
# covariates listed in features_base (continuous variables and dummies).
data_dml_base <- DoubleMLData$new(
  data = data_with_dummies,
  x_cols = features_base,
  y_col = "avg_total_cost_amt",
  d_cols = "is_pk_live"
)

# Partially Linear Regression Model (PLR) ----

# The PLR is the most suitable specification for the causal analysis here.

# LASSO ----

# Cross-validated glmnet learners (5 internal folds, predictions taken at
# lambda.min): a regression learner passed as ml_l and a classification
# learner passed as ml_m.
# set.seed(123)
lasso <- lrn("regr.cv_glmnet", s = "lambda.min", nfolds = 5)
lasso_class <- lrn("classif.cv_glmnet", s = "lambda.min", nfolds = 5)

# PLR model with 3 cross-fitting folds; fit it and print the estimate.
dml_plr_lasso <- DoubleMLPLR$new(
  data_dml_base,
  ml_l = lasso,
  ml_m = lasso_class,
  n_folds = 3
)
dml_plr_lasso$fit()
dml_plr_lasso$summary()
## Estimates and significance testing of the effect of target variables
##            Estimate. Std. Error t value Pr(>|t|)
## is_pk_live  -0.04455    0.09991  -0.446    0.656

# Random Forest ----

# ranger learners with fixed hyperparameters: a regression forest passed as
# ml_l and a classification forest passed as ml_m.
randomForest <- lrn(
  "regr.ranger",
  max.depth = 7, mtry = 3, min.node.size = 3
)
randomForest_class <- lrn(
  "classif.ranger",
  max.depth = 5, mtry = 4, min.node.size = 7
)

# set.seed(123)
# PLR model with 3 cross-fitting folds; fit it and print the estimate.
dml_plr_forest <- DoubleMLPLR$new(
  data_dml_base,
  ml_l = randomForest,
  ml_m = randomForest_class,
  n_folds = 3
)
dml_plr_forest$fit()
dml_plr_forest$summary()
## Estimates and significance testing of the effect of target variables
##            Estimate. Std. Error t value Pr(>|t|)
## is_pk_live  -0.04356    0.09991  -0.436    0.663

# Regression tree ----

# rpart learners with fixed complexity parameter (cp) and minimum split size:
# a regression tree passed as ml_l and a classification tree passed as ml_m.
trees <- lrn("regr.rpart", minsplit = 203, cp = 0.0047)
trees_class <- lrn("classif.rpart", minsplit = 104, cp = 0.0042)

# set.seed(123)
# PLR model with 3 cross-fitting folds; fit it and print the estimate.
dml_plr_tree <- DoubleMLPLR$new(
  data_dml_base,
  ml_l = trees,
  ml_m = trees_class,
  n_folds = 3
)
dml_plr_tree$fit()
dml_plr_tree$summary()
## Estimates and significance testing of the effect of target variables
##            Estimate. Std. Error t value Pr(>|t|)
## is_pk_live  -0.04298    0.09990   -0.43    0.667

# Boosted trees ----

# xgboost learners with learning rate 0.1: a squared-error regressor with 35
# rounds passed as ml_l and a logistic classifier with 34 rounds passed as
# ml_m.
boost <- lrn(
  "regr.xgboost",
  objective = "reg:squarederror",
  nrounds = 35, eta = 0.1
)
boost_class <- lrn(
  "classif.xgboost",
  objective = "binary:logistic", eval_metric = "logloss",
  nrounds = 34, eta = 0.1
)

# set.seed(123)
# PLR model with 3 cross-fitting folds; fit it and print the estimate.
dml_plr_boost <- DoubleMLPLR$new(
  data_dml_base,
  ml_l = boost,
  ml_m = boost_class,
  n_folds = 3
)
dml_plr_boost$fit()
dml_plr_boost$summary()
## Estimates and significance testing of the effect of target variables
##            Estimate. Std. Error t value Pr(>|t|)
## is_pk_live  -0.06315    0.10006  -0.631    0.528
# Stack the confidence intervals and point estimates of the four PLR fits.
# The row order (glmnet, ranger, rpart, xgboost) must match the ML labels.
confints <- rbind(
  dml_plr_lasso$confint(),
  dml_plr_forest$confint(),
  dml_plr_tree$confint(),
  dml_plr_boost$confint()
)
estimates <- c(
  dml_plr_lasso$coef,
  dml_plr_forest$coef,
  dml_plr_tree$coef,
  dml_plr_boost$coef
)

# One row per learner: estimate plus the lower and upper interval bounds
# (first and second column of the stacked intervals).
result_plr <- data.table(
  model = "PLR",
  ML = c("glmnet", "ranger", "rpart", "xgboost"),
  Estimate = estimates,
  lower = confints[, 1],
  upper = confints[, 2]
)
result_plr
##     model      ML    Estimate      lower     upper
##    <char>  <char>       <num>      <num>     <num>
## 1:    PLR  glmnet -0.04455103 -0.2403674 0.1512654
## 2:    PLR  ranger -0.04356096 -0.2393837 0.1522618
## 3:    PLR   rpart -0.04298217 -0.2387867 0.1528224
## 4:    PLR xgboost -0.06315042 -0.2592666 0.1329658
# Point estimates with their confidence intervals, one learner per x position,
# and a grey reference line at zero.
g_ci <- ggplot(result_plr, aes(x = ML, y = Estimate, color = ML)) +
  geom_point() +
  geom_errorbar(aes(ymin = lower, ymax = upper)) +
  geom_hline(yintercept = 0, color = "grey") +
  theme_minimal() +
  labs(x = "", y = "Coefficients and 0.95- confidence interval") +
  theme(
    axis.text.x = element_text(angle = 90),
    legend.position = "none",
    text = element_text(size = 20)
  )

g_ci