# Session setup ----
# NOTE: the former `rm(list = ls())` call was removed on purpose. It only
# deletes user objects from the global environment (attached packages and
# options are left untouched), so it gives no real isolation while silently
# wiping the workspace of anyone who source()s this file. Run the script in a
# fresh R session instead.

# Fix the RNG seed so that any random steps are reproducible.
set.seed(1234)

# Show the working directory; relative paths used in this script are resolved
# against it.
getwd()
## [1] "D:/Projects/Live-PK/0.DataCleaning/0.Code"
library("tidyverse")
library("fixest")
library("arrow")
library(DoubleML)
library(mlr3)
library(mlr3learners)
library(data.table)
library(ggplot2)
library(ranger)
library(xgboost)
library(plotly)
library(dplyr)

# Raise the threshold of the "mlr3" logger to "warn" so that lower-level
# messages are not reported while models are being fitted.
lgr::get_logger(name = "mlr3")$set_threshold(level = "warn")

# Read the input parquet file (path is relative to the working directory).
df <- read_parquet("../../0.DataCleaning/1.Input/synthetic_data2.parquet")

# Preprocessing that used to sit in this pipeline but is currently disabled:
#   * deriving day / month / year / quarter strings from `p_date`
#   * `relative_day`: day difference between `day` and 2022-12-31
#   * converting gender, author_type, author_income_range, age_range,
#     fre_country_region, fre_city_level, is_big_v and relative_day to factors

# Add A/B aliases for the two fan-count columns together with their natural
# logs. mutate() evaluates its arguments in order, so the log columns can be
# built from the aliases created just before them.
# NOTE(review): log() returns -Inf for a count of 0 -- confirm that later code
# tolerates this.
df <- mutate(
  df,
  A_fan_count = fans_user_num,
  B_fan_count = other_fans_user_num,
  log_A_fan_count = log(A_fan_count),
  log_B_fan_count = log(B_fan_count)
)

# gc()

# Model-free ----

# Step 1: Average `avg_fan_total_cost_amt` for every observed
# (A_fan_count, B_fan_count) pair. `.groups = "drop"` returns an ungrouped
# result and silences the summarise() grouping message.
agg_df <- df %>%
  group_by(A_fan_count, B_fan_count) %>%
  summarize(
    avg_money_A = mean(avg_fan_total_cost_amt, na.rm = TRUE),
    .groups = "drop"
  )

# Step 2: Reshape to the grid layout a plotly surface trace needs.
# A surface trace expects `z` as a 2-D matrix whose rows follow `y` and whose
# columns follow `x`; passing three long-format columns draws nothing.
# Pairs with a zero (or missing) count are left out of this plot because
# log(0) = -Inf cannot be placed on an axis.
surface_df <- dplyr::filter(agg_df, A_fan_count > 0, B_fan_count > 0)
surface_a <- sort(unique(surface_df$A_fan_count))
surface_b <- sort(unique(surface_df$B_fan_count))

# Cells for (A, B) pairs that never occur stay NA and show up as gaps.
# NOTE(review): the matrix has length(surface_b) x length(surface_a) cells;
# if the counts take very many distinct values, bin them before plotting.
surface_z <- matrix(
  NA_real_,
  nrow = length(surface_b),
  ncol = length(surface_a)
)
surface_z[cbind(
  match(surface_df$B_fan_count, surface_b),
  match(surface_df$A_fan_count, surface_a)
)] <- surface_df$avg_money_A

# Step 3: Draw the surface on log-scaled fan-count axes.
fig <- plot_ly(
  x = log(surface_a),
  y = log(surface_b),
  z = surface_z,
  type = "surface",
  colorscale = "Viridis"
)

# Step 4: Titles for the plot and its three axes.
fig <- fig %>%
  layout(
    title = "3D Surface Plot of Avg Money for A",
    scene = list(
      xaxis = list(title = "Log(A's Fan Count)"),
      yaxis = list(title = "Log(B's Fan Count)"),
      zaxis = list(title = "Avg Money for A")
    )
  )

# Display the plot
fig
# 3D scatter of the raw observations: log fan counts on the x / y axes and
# `avg_fan_total_cost_amt` on the z axis, with markers coloured by the same
# z value. Trace and layout are built in a single pipeline.
fig <- df %>%
  plot_ly(
    x = ~log(A_fan_count),
    y = ~log(B_fan_count),
    z = ~avg_fan_total_cost_amt,
    type = "scatter3d",
    mode = "markers",
    marker = list(
      size = 2,
      color = ~avg_fan_total_cost_amt,
      colorscale = "Viridis",
      showscale = TRUE
    )
  ) %>%
  layout(
    title = "3D Scatter Plot of Avg Money for A",
    scene = list(
      xaxis = list(title = "Log(A's Fan Count)"),
      yaxis = list(title = "Log(B's Fan Count)"),
      zaxis = list(title = "Avg Money for A")
    )
  )

# Display the plot
fig
# Step 1: Give both fan-range columns the same explicit level order so the
# heatmap axes are not sorted alphabetically. Values outside these levels
# become NA.
fan_range_levels <- c("0", "0-100", "100-1k", "1k-1w", "1w-10w", "10w+")
df <- df %>%
  mutate(across(
    c(fans_range, other_fans_range),
    ~ factor(.x, levels = fan_range_levels)
  ))

# Step 2: Mean and 95% t-based confidence interval per pair of ranges.
# The count and standard deviation are computed BEFORE the mean is stored
# under the name `avg_fan_total_cost_amt`: inside summarize() a later
# expression sees the freshly created scalar instead of the original column,
# so calling sd() after the mean always returned NA. The interval uses the
# number of non-missing observations and n - 1 degrees of freedom; max(., 1)
# only avoids a qt() warning for groups where the interval is NA anyway.
agg_df <- df %>%
  group_by(fans_range, other_fans_range) %>%
  summarize(
    n_obs = sum(!is.na(avg_fan_total_cost_amt)),
    sd_cost = sd(avg_fan_total_cost_amt, na.rm = TRUE),
    avg_fan_total_cost_amt = mean(avg_fan_total_cost_amt, na.rm = TRUE),
    ci_half_width = qt(0.975, df = max(n_obs - 1, 1)) * sd_cost / sqrt(n_obs),
    ci_lower = avg_fan_total_cost_amt - ci_half_width,
    ci_upper = avg_fan_total_cost_amt + ci_half_width,
    .groups = "drop"
  ) %>%
  dplyr::select(-n_obs, -sd_cost, -ci_half_width)

# Step 3: Heatmap of the means. The interval is printed inside each tile:
# its bounds are in the units of the fill variable, so they cannot be drawn
# as error bars against the categorical y axis.
ggplot(
  agg_df,
  aes(x = fans_range, y = other_fans_range, fill = avg_fan_total_cost_amt)
) +
  geom_tile() +
  geom_label(
    aes(label = sprintf("[%.2f, %.2f]", ci_lower, ci_upper)),
    fill = "white",
    size = 2.5,
    na.rm = TRUE
  ) +
  scale_fill_viridis_c() +
  labs(title = "Average Fan Total Cost Amount by Fan Ranges",
       x = "Fans Range",
       y = "Other Fans Range",
       fill = "Avg Fan Total Cost Amt") +
  theme_minimal() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))

library(dplyr)
library(ggplot2)

# Step 1: log10(x + 1)-transform both fan counts (the + 1 keeps zero counts
# finite) and cut each transformed variable into 20 equal-width bins.
# `labels = FALSE` makes cut() return the integer bin index (1..20).
df <- df %>%
  mutate(
    log_fans_user_num = log10(fans_user_num + 1),
    log_other_fans_user_num = log10(other_fans_user_num + 1),
    log_fans_bins = cut(log_fans_user_num, breaks = 20, labels = FALSE),
    log_other_fans_bins = cut(log_other_fans_user_num, breaks = 20,
                              labels = FALSE)
  )

# Step 2: Mean, 95% t-based confidence interval and row count per bin pair.
# The standard deviation is computed BEFORE the mean is stored under the name
# `avg_fan_total_cost_amt`: inside summarize() a later expression sees the
# freshly created scalar instead of the original column, so calling sd()
# after the mean always returned NA. The interval uses the number of
# non-missing observations and n - 1 degrees of freedom; max(., 1) only
# avoids a qt() warning for groups where the interval is NA anyway.
# `count` still counts every row of the group, including rows whose cost is
# missing.
agg_df <- df %>%
  group_by(log_fans_bins, log_other_fans_bins) %>%
  summarize(
    n_obs = sum(!is.na(avg_fan_total_cost_amt)),
    sd_cost = sd(avg_fan_total_cost_amt, na.rm = TRUE),
    avg_fan_total_cost_amt = mean(avg_fan_total_cost_amt, na.rm = TRUE),
    ci_half_width = qt(0.975, df = max(n_obs - 1, 1)) * sd_cost / sqrt(n_obs),
    ci_lower = avg_fan_total_cost_amt - ci_half_width,
    ci_upper = avg_fan_total_cost_amt + ci_half_width,
    count = n(),
    .groups = "drop"
  ) %>%
  dplyr::select(-n_obs, -sd_cost, -ci_half_width)

# Total number of rows, used as the denominator for the percentage plot.
total_data_points <- nrow(df)

# Step 3: Heatmap of the mean cost amount per pair of bins, with the raw bin
# indices on both axes and a white-to-red fill gradient.
agg_df %>%
  ggplot(aes(log_fans_bins, log_other_fans_bins,
             fill = avg_fan_total_cost_amt)) +
  geom_tile() +
  labs(
    fill = "Avg Fan Total Cost Amt",
    y = "Log(Binned Other Fans User Num)",
    x = "Log(Binned Fans User Num)",
    title = "Heatmap of Avg Fan Total Cost Amt by Log-Transformed Fan Bins"
  ) +
  scale_fill_gradient(high = "red", low = "white") +
  theme_minimal() +
  theme(axis.text.x = element_text(hjust = 1, angle = 45))

# Build an axis labeller that turns a bin index into a value on the original
# (untransformed) scale.
#
# `log_values` are the log10(x + 1) values that were cut into `n_bins`
# equal-width bins. The returned function maps bin index i to the centre of
# bin i on the log scale, undoes the log10(x + 1) transform and rounds to two
# decimals. The range is taken from the data with `na.rm = TRUE`, so missing
# counts do not turn every label into NA and the lowest bin is not assumed to
# start at zero. (cut() pads the outermost breaks by 0.1% of the range; that
# padding is ignored here.)
make_bin_labeller <- function(log_values, n_bins = 20L) {
  log_range <- range(log_values, na.rm = TRUE)
  bin_width <- diff(log_range) / n_bins
  function(bin) {
    round(10^(log_range[1] + (bin - 0.5) * bin_width) - 1, 2)
  }
}

fans_bin_labels <- make_bin_labeller(df$log_fans_user_num)
other_fans_bin_labels <- make_bin_labeller(df$log_other_fans_user_num)

# Step 3: Heatmap of the mean cost amount per pair of bins, with the axes
# labelled on the original fan-count scale.
ggplot(
  agg_df,
  aes(
    x = as.numeric(log_fans_bins),
    y = as.numeric(log_other_fans_bins),
    fill = avg_fan_total_cost_amt
  )
) +
  geom_tile() +
  scale_fill_gradient(low = "white", high = "red") +  # White to red gradient
  scale_x_continuous(breaks = seq_len(20), labels = fans_bin_labels) +
  scale_y_continuous(breaks = seq_len(20), labels = other_fans_bin_labels) +
  labs(title = "Heatmap of Avg Fan Total Cost Amt by Log-Transformed Fan Bins",
       x = "Fans User Num (Original Scale)",
       y = "Other Fans User Num (Original Scale)",
       fill = "Avg Fan Total Cost Amt") +
  theme_minimal() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))

# Step 4: Share of all rows that falls into each pair of bins, in percent.
agg_df <- agg_df %>%
  mutate(pct_of_total = (count / total_data_points) * 100)

ggplot(
  agg_df,
  aes(
    x = as.numeric(log_fans_bins),
    y = as.numeric(log_other_fans_bins),
    fill = pct_of_total
  )
) +
  geom_tile() +
  scale_fill_viridis_c() +
  scale_x_continuous(breaks = seq_len(20), labels = fans_bin_labels) +
  scale_y_continuous(breaks = seq_len(20), labels = other_fans_bin_labels) +
  labs(title = "Percentage of Data in Each Log-Transformed Bin",
       x = "Fans User Num (Original Scale)",
       y = "Other Fans User Num (Original Scale)",
       fill = "Percentage of Total Data") +
  theme_minimal() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))