# Session setup: clear the global workspace, fix the RNG seed, and print the
# current working directory.
# NOTE(review): rm(list = ls()) only removes the visible objects of the global
# environment; attached packages and options are left as they are, so this is
# not equivalent to starting from a fresh R session.
rm(list=ls())
# Seed the random number generator for the steps that follow.
set.seed(1234)
# Print the working directory; the relative data path used below is resolved
# against it.
getwd()
## [1] "D:/Projects/Live-PK/0.DataCleaning/0.Code"
library("tidyverse")
library("fixest")
library("arrow")
library(DoubleML)
library(mlr3)
library(mlr3learners)
library(data.table)
library(ggplot2)
library(ranger)
library(xgboost)
# Suppress messages during fitting: set the threshold of the "mlr3" logger to
# "warn".
lgr::get_logger("mlr3")$set_threshold("warn")
# Load the synthetic data set, derive calendar fields from p_date, and turn
# the categorical columns into factors.
df <- read_parquet("../../0.DataCleaning/1.Input/synthetic_data.parquet") %>%
  mutate(
    # Calendar fields, all stored as character strings except p_date itself.
    p_date = as_date(p_date),
    day = format(p_date, "%Y-%m-%d"), # "YYYY-MM-DD"
    month = format(p_date, "%Y-%m"), # "YYYY-MM"
    year = format(p_date, "%Y"), # "YYYY"
    quarter = paste0(year(p_date), "-Q", quarter(p_date)), # "YYYY-QN"
    # Whole days between each observation and the fixed reference date,
    # stored as a factor.
    reference_date = as_date("2022-12-31"),
    relative_day = as.factor(
      as.integer(difftime(as_date(day), reference_date, units = "days"))
    ),
    # Categorical columns are converted to factors in place.
    across(
      c(
        gender, author_type, author_income_range, age_range,
        fre_country_region, fre_city_level, is_big_v
      ),
      as.factor
    )
  )
# gc()
# Per-level counts of the relative_day factor.
summary(df[["relative_day"]])
## 402 284 607 307 565 618 395 393 398 415
## 317 315 315 313 312 311 310 306 305 305
## 626 400 270 279 381 545 547 278 343 387
## 304 302 300 300 300 300 300 299 299 299
## 523 532 625 273 439 591 339 403 582 593
## 299 299 299 298 298 298 297 297 297 297
## 594 479 552 296 346 446 287 291 390 391
## 297 296 296 295 295 295 294 294 294 294
## 312 321 338 358 445 472 521 551 614 309
## 293 293 293 293 293 293 293 293 293 292
## 328 394 460 474 515 617 375 453 539 463
## 292 292 292 292 292 292 291 291 291 290
## 490 497 574 486 504 601 294 396 414 456
## 290 290 290 289 289 289 288 288 288 288
## 522 530 542 599 633 527 563 612 351 372
## 288 288 288 288 288 287 287 287 285 285
## 404 464 503 511 592 621 363 412 417 491
## 285 285 285 285 285 285 284 284 284 284
## 581 597 604 306 313 336 360 382 451 (Other)
## 284 284 284 283 283 283 283 283 283 70959
# Minimum, quartiles, mean and maximum of the is_pk_live indicator.
summary(df[["is_pk_live"]])
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 0.0000 0.0000 0.0000 0.2981 1.0000 1.0000
# Print the structure of the prepared data: column names, types and the first
# few values of each column.
str(df)
## tibble [100,000 × 48] (S3: tbl_df/tbl/data.frame)
## $ author_id : int [1:100000] 10162 10538 10074 10436 10334 10293 10007 10180 10501 10504 ...
## $ live_id : int [1:100000] 52748 56505 54926 55204 55998 51037 59513 59822 54103 51242 ...
## $ p_date : Date[1:100000], format: "2024-04-11" "2023-10-02" ...
## $ is_pk_live : int [1:100000] 0 0 0 0 0 0 0 0 0 1 ...
## $ valid_play_duration : int [1:100000] 491 990 557 601 770 330 330 738 424 56 ...
## $ valid_play_user_num : int [1:100000] 851 635 24 673 589 43 452 290 739 331 ...
## $ avg_valid_play_duration : num [1:100000] 5.56 9.19 2.82 6.2 8.05 ...
## $ total_cost_amt : num [1:100000] 608 393 915 658 568 ...
## $ total_cost_user_num : int [1:100000] 387 357 381 240 203 138 13 2 63 325 ...
## $ avg_total_cost_amt : num [1:100000] 44 28.9 19.1 21.4 41 ...
## $ comment_cnt : int [1:100000] 367 219 432 9 254 399 347 260 400 46 ...
## $ comment_user_num : int [1:100000] 86 17 255 65 299 54 244 252 218 28 ...
## $ avg_comment_cnt : num [1:100000] 2.29 4.38 7.31 7.06 9.11 ...
## $ like_cnt : int [1:100000] 4130 7452 976 1332 5915 5497 4863 2703 4290 661 ...
## $ like_user_num : int [1:100000] 3829 2541 798 4875 3264 3723 1208 4333 1481 3516 ...
## $ avg_like_cnt : num [1:100000] 2.72 8.77 12.18 1.23 7.93 ...
## $ share_success_cnt : int [1:100000] 376 342 836 259 772 595 881 18 630 333 ...
## $ share_success_user_num : int [1:100000] 37 350 220 433 338 78 230 295 225 126 ...
## $ avg_share_success_cnt : num [1:100000] 0.0863 2.0528 1.9958 4.3156 4.4424 ...
## $ follow_author_cnt : int [1:100000] 559 849 688 512 155 520 992 662 735 171 ...
## $ cancel_follow_author_cnt: int [1:100000] 64 95 30 67 74 94 60 72 36 41 ...
## $ follow_user_cnt : int [1:100000] 230 413 639 209 205 880 852 610 79 960 ...
## $ cancel_follow_user_cnt : int [1:100000] 21 76 1 86 96 93 85 3 85 69 ...
## $ join_fans_group_cnt : int [1:100000] 201 195 174 75 342 206 129 458 491 181 ...
## $ live_new_user_num : int [1:100000] 85 564 91 127 80 797 583 636 919 930 ...
## $ report_live_cnt : int [1:100000] 12 82 22 1 11 37 3 65 92 33 ...
## $ report_user_cnt : int [1:100000] 28 29 44 49 49 35 25 0 7 3 ...
## $ fans_user_num : int [1:100000] 20969 3975 40820 48247 45508 94049 92317 13586 56647 82865 ...
## $ live_duration_7d : int [1:100000] 4484 3401 3650 5361 6599 4663 9115 8842 7508 2739 ...
## $ play_duration_7d : int [1:100000] 6142 3661 1981 3639 7640 1566 8609 3151 1616 7810 ...
## $ follow_user_num : int [1:100000] 558 722 378 694 224 749 706 605 423 436 ...
## $ reg_day_cnt : int [1:100000] 2812 892 3094 261 1503 2272 246 2011 696 1099 ...
## $ pk_id : num [1:100000] NA NA NA NA NA ...
## $ gender : Factor w/ 2 levels "Female","Male": 2 1 2 1 1 1 1 2 2 1 ...
## $ fans_range : chr [1:100000] "1k-10k" "10k-50k" "1k-10k" "50k-100k" ...
## $ fans_group_fans_num : int [1:100000] 5311 4040 6103 1270 4316 4160 7940 5110 4237 1670 ...
## $ author_type : Factor w/ 3 levels "游戏","电商",..: 1 1 3 2 2 2 3 3 2 3 ...
## $ author_income_range : Factor w/ 5 levels "1","2","3","4",..: 2 4 3 2 5 3 4 5 3 4 ...
## $ age_range : Factor w/ 4 levels "18-24","25-34",..: 3 2 2 3 2 4 1 2 2 4 ...
## $ fre_country_region : Factor w/ 4 levels "CN","JP","KR",..: 1 1 1 4 1 2 4 3 1 2 ...
## $ fre_city_level : Factor w/ 4 levels "1","2","3","4": 4 4 1 1 2 1 1 1 4 1 ...
## $ is_big_v : Factor w/ 2 levels "0","1": 1 2 2 1 2 2 1 1 2 2 ...
## $ day : chr [1:100000] "2024-04-11" "2023-10-02" "2024-02-06" "2024-07-04" ...
## $ month : chr [1:100000] "2024-04" "2023-10" "2024-02" "2024-07" ...
## $ year : chr [1:100000] "2024" "2023" "2024" "2024" ...
## $ quarter : chr [1:100000] "2024-Q2" "2023-Q4" "2024-Q1" "2024-Q3" ...
## $ reference_date : Date[1:100000], format: "2022-12-31" "2022-12-31" ...
## $ relative_day : Factor w/ 366 levels "270","271","272",..: 198 6 133 282 310 141 345 46 163 226 ...
# Bar chart of the mean of avg_total_cost_amt for each value of is_pk_live,
# with error bars computed by mean_cl_normal.
# Line widths (bar outline, error bars) are set with `linewidth`: using `size`
# for lines was deprecated in ggplot2 3.4.0 and raises a deprecation warning.
ggplot(
  df,
  aes(x = factor(is_pk_live), y = avg_total_cost_amt, fill = factor(is_pk_live))
) +
  stat_summary(
    fun = mean, geom = "bar", color = "black", linewidth = 0.5, alpha = 0.7
  ) +
  stat_summary(
    fun.data = mean_cl_normal, geom = "errorbar", width = 0.2, linewidth = 1
  ) +
  theme_minimal() +
  ggtitle("Mean and Confidence Intervals of Total Cost by PK Status") +
  labs(x = "PK Live (0 = No, 1 = Yes)", y = "Mean Total Cost Amount") +
  theme(
    legend.position = "none", # fill duplicates the x axis, so no legend
    plot.title = element_text(hjust = 0.5), # centre the title
    text = element_text(size = 10)
  )
## Warning: Using `size` aesthetic for lines was deprecated in ggplot2 3.4.0.
## ℹ Please use `linewidth` instead.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.
# Dummy-code the categorical covariates. The first column of the model matrix
# (the intercept) is dropped before the result is converted to a data.table.
dummy_vars <- as.data.table(
  model.matrix(
    ~ gender + author_type + author_income_range + age_range +
      fre_country_region + fre_city_level, # + relative_day
    data = df
  )[, -1]
)
# Translate the Chinese labels that occur in column names into English and
# make every name syntactically valid.
#
# names: character vector of column names.
# Returns a character vector of the same length in which "电商" is replaced by
# "Commerce" and "秀场" by "Showcase", passed through make.names().
clean_column_names <- function(names) {
  # Paired vectors: chinese_labels[i] is rewritten as english_labels[i].
  # Further translations can be appended to both vectors.
  chinese_labels <- c("电商", "秀场")
  english_labels <- c("Commerce", "Showcase")
  for (i in seq_along(chinese_labels)) {
    names <- gsub(chinese_labels[i], english_labels[i], names)
  }
  # make.names() rewrites whatever is still not a syntactically valid R name.
  make.names(names)
}
# Rename the dummy columns in place with their cleaned-up names.
setnames(
  dummy_vars,
  old = names(dummy_vars),
  new = clean_column_names(names(dummy_vars))
)

# Covariates: the two 7-day duration columns plus every dummy column.
features_base <- c("live_duration_7d", "play_duration_7d", names(dummy_vars))

# Append the dummies to the original data and store Date columns as numbers.
data_with_dummies <- cbind(df, dummy_vars) %>%
  mutate(across(where(is.Date), as.numeric))

# Data object consumed by the DoubleML models below.
data_dml_base <- DoubleMLData$new(
  data_with_dummies,
  y_col = "avg_total_cost_amt", # outcome variable
  d_cols = "is_pk_live", # treatment variable
  x_cols = features_base # covariates (continuous variables and dummies)
)
# Here, the partially linear regression (PLR) model is the most suitable DoubleML model for the causal analysis.
# Nuisance learners: cross-validated glmnet (5 folds, s = "lambda.min"), one
# regression learner for ml_l and one classifier for ml_m.
# set.seed(123)
lasso <- lrn("regr.cv_glmnet", nfolds = 5, s = "lambda.min")
lasso_class <- lrn("classif.cv_glmnet", nfolds = 5, s = "lambda.min")

# DoubleMLPLR model with the glmnet learners and n_folds = 3.
dml_plr_lasso <- DoubleMLPLR$new(
  data_dml_base,
  ml_l = lasso,
  ml_m = lasso_class,
  n_folds = 3
)
dml_plr_lasso$fit()
dml_plr_lasso$summary()
## Estimates and significance testing of the effect of target variables
## Estimate. Std. Error t value Pr(>|t|)
## is_pk_live -0.04455 0.09991 -0.446 0.656
# Nuisance learners: ranger random forests, one regression learner for ml_l
# and one classifier for ml_m.
randomForest <- lrn(
  "regr.ranger",
  max.depth = 7, mtry = 3, min.node.size = 3
)
randomForest_class <- lrn(
  "classif.ranger",
  max.depth = 5, mtry = 4, min.node.size = 7
)

# set.seed(123)
# DoubleMLPLR model with the ranger learners and n_folds = 3.
dml_plr_forest <- DoubleMLPLR$new(
  data_dml_base,
  ml_l = randomForest,
  ml_m = randomForest_class,
  n_folds = 3
)
dml_plr_forest$fit()
dml_plr_forest$summary()
## Estimates and significance testing of the effect of target variables
## Estimate. Std. Error t value Pr(>|t|)
## is_pk_live -0.04356 0.09991 -0.436 0.663
# Nuisance learners: single rpart trees, one regression learner for ml_l and
# one classifier for ml_m.
trees <- lrn("regr.rpart", cp = 0.0047, minsplit = 203)
trees_class <- lrn("classif.rpart", cp = 0.0042, minsplit = 104)

# set.seed(123)
# DoubleMLPLR model with the rpart learners and n_folds = 3.
dml_plr_tree <- DoubleMLPLR$new(
  data_dml_base,
  ml_l = trees,
  ml_m = trees_class,
  n_folds = 3
)
dml_plr_tree$fit()
dml_plr_tree$summary()
## Estimates and significance testing of the effect of target variables
## Estimate. Std. Error t value Pr(>|t|)
## is_pk_live -0.04298 0.09990 -0.43 0.667
# Nuisance learners: xgboost, one regression learner for ml_l and one
# classifier for ml_m.
boost <- lrn(
  "regr.xgboost",
  objective = "reg:squarederror",
  eta = 0.1,
  nrounds = 35
)
boost_class <- lrn(
  "classif.xgboost",
  objective = "binary:logistic",
  eval_metric = "logloss",
  eta = 0.1,
  nrounds = 34
)

# set.seed(123)
# DoubleMLPLR model with the xgboost learners and n_folds = 3.
dml_plr_boost <- DoubleMLPLR$new(
  data_dml_base,
  ml_l = boost,
  ml_m = boost_class,
  n_folds = 3
)
dml_plr_boost$fit()
dml_plr_boost$summary()
## Estimates and significance testing of the effect of target variables
## Estimate. Std. Error t value Pr(>|t|)
## is_pk_live -0.06315 0.10006 -0.631 0.528
# Collect the confidence intervals and coefficients of the four fitted PLR
# models into one table, in the order glmnet, ranger, rpart, xgboost.
plr_fits <- list(dml_plr_lasso, dml_plr_forest, dml_plr_tree, dml_plr_boost)

# One row of interval bounds per model.
confints <- do.call(rbind, lapply(plr_fits, function(fit) fit$confint()))
# One coefficient per model.
estimates <- unlist(lapply(plr_fits, function(fit) fit$coef))

result_plr <- data.table(
  model = "PLR",
  ML = c("glmnet", "ranger", "rpart", "xgboost"),
  Estimate = estimates,
  lower = confints[, 1], # first column of confint(): lower bound
  upper = confints[, 2] # second column of confint(): upper bound
)
result_plr
## model ML Estimate lower upper
## <char> <char> <num> <num> <num>
## 1: PLR glmnet -0.04455103 -0.2403674 0.1512654
## 2: PLR ranger -0.04356096 -0.2393837 0.1522618
## 3: PLR rpart -0.04298217 -0.2387867 0.1528224
## 4: PLR xgboost -0.06315042 -0.2592666 0.1329658
# Plot each learner's coefficient with its confidence interval; the grey
# horizontal line marks zero.
g_ci <- ggplot(result_plr, aes(x = ML, y = Estimate, color = ML)) +
  geom_point() +
  geom_errorbar(aes(ymin = lower, ymax = upper, color = ML)) +
  geom_hline(yintercept = 0, color = "grey") +
  theme_minimal() +
  labs(x = "", y = "Coefficients and 0.95- confidence interval") +
  theme(
    axis.text.x = element_text(angle = 90),
    legend.position = "none",
    text = element_text(size = 20)
  )
g_ci