# Analyze simulated demographic data 

# NOTE: the workspace is deliberately not wiped with rm(list = ls()) here.
# Every object this script uses is assigned before it is read, and clearing
# the global environment would silently delete unrelated objects for anyone
# who source()s this file. Run the script in a fresh R session instead.


# Load R environment ---------

# Activate the renv project so the packages loaded below come from the
# project library.
renv::activate()


# Load packages ---------

library(here)
## here() starts at /users/akhann16/code/cadre/data-analysis-plotting/Simulated-Data-Analysis/r/agent-log-analysis
library(data.table)
library(yaml)
library(ggplot2)


# Read RDS file ------------

# Resolve the RDS file relative to the project root reported by here() instead
# of a hard-coded absolute path, so the script is not tied to one user's
# machine. NOTE(review): this assumes here() resolves to the
# "agent-log-analysis" directory, as the startup message recorded above shows.
agent_log_env <- readRDS(here("rds-outs", "agent_log_env.RDS"))


# Load data ------------

agent_dt <- agent_log_env[["agent_dt"]]
input_params <- agent_log_env[["input_params"]]

# The target annotations further down read input_params$RACE_DISTRIBUTION and
# input_params$FEMALE_PROP. `[[` returns NULL for a missing entry, which would
# only surface later as an obscure plotting error, so fail here with a clear
# message instead.
if (is.null(input_params)) {
  stop(
    "`input_params` is missing from agent_log_env.RDS; ",
    "it must be added for the rest of this script to work.",
    call. = FALSE
  )
}


# View data --------

# Preview the first six rows of the agent log (six is head()'s default).
head(agent_dt, n = 6L)
##    tick   id age     race female alc_use_status smoking_status
## 1:    1 2610  84    Asian      1              0          Never
## 2:    1 7868  84    White      1              3         Former
## 3:    1    0  83 Hispanic      1              1         Former
## 4:    1    1  83 Hispanic      0              3         Former
## 5:    1    2  41    White      1              0          Never
## 6:    1    3  31 Hispanic      0              0          Never
##    last_incarceration_tick last_release_tick current_incarceration_status
## 1:                      -1                -1                            0
## 2:                      -1                -1                            0
## 3:                      -1                -1                            0
## 4:                      -1                -1                            0
## 5:                      -1                -1                            0
## 6:                      -1                -1                            0
##    entry_at_tick exit_at_tick n_incarcerations n_releases n_smkg_stat_trans
## 1:             0            1                0          0                 0
## 2:             0            1                0          0                 0
## 3:             0           -1                0          0                 0
## 4:             0           -1                0          0                 0
## 5:             0           -1                0          0                 0
## 6:             0           -1                0          0                 0
##    n_alc_use_stat_trans
## 1:                    0
## 2:                    0
## 3:                    0
## 4:                    0
## 5:                    0
## 6:                    0
# Summary ------------

# Final tick present in the log, and a sampled grid of ticks: every 10th tick
# starting at 1, with the final tick appended.
last_tick <- agent_dt[, max(tick)]
selected_ticks <- c(seq(from = 1, to = last_tick, by = 10), last_tick)


# Number of distinct agent ids in the dataset
length(unique(agent_dt[["id"]]))
## [1] 15735
# Smallest and largest agent id
range(agent_dt[["id"]])
## [1]     0 15734
# Race distribution ------------

## compute
# Head count per race at the final tick, then each race's share of that
# population as a whole-number percentage, sorted by race.
agent_dt_by_race <- agent_dt[tick == last_tick, .(N = .N), by = race]
agent_dt_by_race[, per_cent := round(N / sum(N) * 100)]
agent_dt_by_race <- agent_dt_by_race[order(race)]

# Target race distribution taken from the input parameters, one row per race.
target_values <- data.frame(
  race = names(input_params$RACE_DISTRIBUTION),
  target_pct = unlist(input_params$RACE_DISTRIBUTION)
)
agent_dt_by_race[]
##        race    N per_cent
## 1:    Asian  394        4
## 2:    Black  856        9
## 3: Hispanic 1619       16
## 4:    White 7132       71
# Ticks used for the time-series plots: every 10th tick from 1, plus the last.
time_ticks <- c(seq(from = 1, to = last_tick, by = 10), last_tick)

# For every sampled tick, compute the share of agents belonging to each race.
# The named vector fixes both the output column names and their order
# (White, Black, Hispanic, Asian).
race_proportions_by_tick <- agent_dt[
  tick %in% time_ticks,
  lapply(
    c(White = "White", Black = "Black", Hispanic = "Hispanic", Asian = "Asian"),
    function(race_name) sum(race == race_name) / .N
  ),
  by = tick
]

# Reshape to one row per (tick, race) pair for plotting.
race_proportions_long <- melt(
  race_proportions_by_tick,
  id.vars = "tick",
  measure.vars = c("White", "Black", "Hispanic", "Asian"),
  variable.name = "race",
  value.name = "proportion"
)

# Time series of race proportions, with each race's target value annotated.

# Colors are keyed by race name so the lines and their target labels always
# use the same color, independent of factor level order.
race_colors <- c(
  White = "#377eb8",
  Black = "#ff7f00",
  Hispanic = "#4daf4a",
  Asian = "#e41a1c"
)

# Target value for each race, looked up in the same order as race_colors.
race_targets <- target_values$target_pct[
  match(names(race_colors), target_values$race)
]

race_plot <-
  ggplot(
    race_proportions_long,
    aes(x = tick, y = proportion, color = race, group = race)
  ) +
  geom_line(linewidth = 1.2) +
  labs(title = "",
       x = "Time (Days)",
       y = "Proportion") +
  scale_color_manual(values = race_colors) +
  scale_y_continuous(limits = c(0, 0.8), breaks = seq(0, 0.8, by = 0.1)) +
  theme_minimal() +
  theme(
    legend.title = element_blank(),
    axis.text.x = element_text(size = 14, face = "bold"),
    axis.text.y = element_text(size = 14, face = "bold"),
    legend.text = element_text(size = 14),
    axis.title.x = element_text(size = 16, face = "bold"),
    axis.title.y = element_text(size = 16, face = "bold")
  ) +
  # annotate() draws exactly one label per race. The previous geom_text()
  # layers used constant values inside aes(), which draws one label per data
  # row and relied on check_overlap to hide the duplicates.
  annotate(
    "text",
    x = max(race_proportions_long$tick) / 2,  # horizontal midpoint of the plot
    y = race_targets + 0.03,                  # just above the target level
    label = sprintf("Target: %.3f", race_targets),
    color = unname(race_colors),
    size = 5
  )

race_plot

# Sex distribution ------------

# Count and share of agents by sex at the final tick. Computed once and then
# printed; the plot below labels female == 0 as "Male" and anything else as
# "Female".
gender_pct <- agent_dt[tick == last_tick, .N, by = female][, "%" := N / sum(N)][order(female)]
gender_pct[]
##    female    N         %
## 1:      0 4834 0.4833517
## 2:      1 5167 0.5166483

# Target share of females from the input parameters (a proportion in [0, 1],
# since the male target is derived as its complement).
female_target <- input_params$FEMALE_PROP

# calculate target percentages for each gender
female_target_pct <- female_target
male_target_pct <- 1 - female_target

# calculate actual percentages for each gender
female_actual_pct <- gender_pct[female == 1, `%`]
male_actual_pct <- gender_pct[female == 0, `%`]


# create plot for last tick
# Each bar shows the simulated share. The label sits on top of the bar
# (y = actual share, vjust = 0) and reports the *target* share. Previously the
# labels said "Target" but printed the actual value, leaving the target
# variables unused.
ggplot(gender_pct, aes(x = ifelse(female == 0, "Male", "Female"), y = `%`, fill = ifelse(female == 0, "Male", "Female"))) +
  geom_bar(stat = "identity", color = "black") +
  # Unnamed values follow the alphabetical category order: Female, then Male.
  scale_fill_manual(values = c("#1b9e77", "#d95f02")) +
  annotate("text", x = "Female", y = female_actual_pct,
           label = paste0("Female Target = ", round(female_target_pct * 100, 1), "%"),
           color = "#1b9e77", size = 4, vjust = 0) +
  annotate("text", x = "Male", y = male_actual_pct,
           label = paste0("Male Target = ", round(male_target_pct * 100, 1), "%"),
           color = "#d95f02", size = 4, vjust = 0) +
  # The x axis is the sex category, not time.
  labs(title = "",
       x = "Sex",
       y = "Sex Distributions",
       fill = "") +
  theme_minimal()

# plot over time:
  
# Share of males (female == 0) and females (female == 1) at each sampled tick.
gender_proportions_by_tick <- agent_dt[tick %in% time_ticks, .(
  male = sum(female == 0) / .N,
  female = sum(female == 1) / .N
), by = tick]

# Reshape to one row per (tick, gender) pair for plotting.
gender_proportions_long <- melt(
  gender_proportions_by_tick,
  id.vars = "tick",
  variable.name = "gender",
  value.name = "proportion"
)

# Colors keyed by series name so lines and target labels cannot drift apart.
gender_colors <- c(male = "#1b9e77", female = "#d95f02")

ggplot(gender_proportions_long, aes(x = tick, y = proportion, color = gender, group = gender)) +
  geom_line(linewidth = 1.5) +
  labs(title = "",
       x = "Time (Days)",
       y = "Proportion") +
  scale_color_manual(
    values = gender_colors,
    breaks = names(gender_colors),
    labels = c("Male", "Female")
  ) +
  theme_minimal() +
  theme(
    legend.title = element_blank(),
    axis.text.x = element_text(size = 16, face = "bold"),
    axis.text.y = element_text(size = 16, face = "bold"),
    legend.text = element_text(size = 16),
    axis.title.x = element_text(size = 16, face = "bold"),
    axis.title.y = element_text(size = 16, face = "bold")
  ) +
  scale_y_continuous(limits = c(0, 1), breaks = seq(0.4, 0.6, by = 0.1)) +
  # annotate() draws each target label once. The previous geom_text() layers
  # put constants inside aes(), producing one label per data row that
  # check_overlap then had to hide.
  # First label: female target, placed slightly above the target level.
  # Second label: male target (1 - female target), placed slightly below.
  annotate(
    "text",
    x = max(gender_proportions_long$tick) / 2,
    y = c(female_target + 0.03, 1 - female_target - 0.03),
    label = sprintf("Target: %.3f", c(female_target, 1 - female_target)),
    color = unname(gender_colors[c("female", "male")]),
    size = 5
  )

# Age distribution ------------

# Age-group boundaries; each label is "<lower>-<upper - 1>" for consecutive
# breaks, i.e. "18-24", "25-34", "35-44", "45-54", "55-64".
agebreaks <- c(18, 25, 35, 45, 55, 65)
agelabels <- paste(head(agebreaks, -1), tail(agebreaks, -1) - 1, sep = "-")

nrow(agent_dt[tick == last_tick])
## [1] 10001
# Add the age_groups column by reference. Bins are left-closed ([18, 25), ...).
# NOTE(review): with right = FALSE and include.lowest = TRUE the last bin is
# closed on both ends, so age 65 is labelled "55-64"; ages outside 18-65 get
# NA (the NA group in the recorded tables below). Confirm whether a 65+ group
# is wanted.
setDT(agent_dt)
agent_dt[, age_groups := cut(
  age,
  breaks = agebreaks,
  labels = agelabels,
  right = FALSE,
  include.lowest = TRUE
)]
agent_dt[tick == last_tick, .N]
## [1] 10001
# Final-tick head count for every age group x race x sex cell, fully sorted.
agent_dt[tick == last_tick, .N, by = .(age_groups, race, female)][order(age_groups, race, female)]
##     age_groups     race female    N
##  1:      18-24    Asian      0    1
##  2:      18-24    Asian      1    2
##  3:      18-24    Black      0    4
##  4:      18-24    Black      1    3
##  5:      18-24 Hispanic      0   10
##  6:      18-24 Hispanic      1    9
##  7:      18-24    White      0   27
##  8:      18-24    White      1   33
##  9:      25-34    Asian      0    7
## 10:      25-34    Asian      1   12
## 11:      25-34    Black      0   14
## 12:      25-34    Black      1   13
## 13:      25-34 Hispanic      0   34
## 14:      25-34 Hispanic      1   44
## 15:      25-34    White      0  134
## 16:      25-34    White      1  142
## 17:      35-44    Asian      0    6
## 18:      35-44    Asian      1   23
## 19:      35-44    Black      0   24
## 20:      35-44    Black      1   37
## 21:      35-44 Hispanic      0   43
## 22:      35-44 Hispanic      1   60
## 23:      35-44    White      0  224
## 24:      35-44    White      1  250
## 25:      45-54    Asian      0   40
## 26:      45-54    Asian      1   36
## 27:      45-54    Black      0   61
## 28:      45-54    Black      1   62
## 29:      45-54 Hispanic      0  140
## 30:      45-54 Hispanic      1  164
## 31:      45-54    White      0  608
## 32:      45-54    White      1  653
## 33:      55-64    Asian      0   60
## 34:      55-64    Asian      1   54
## 35:      55-64    Black      0  126
## 36:      55-64    Black      1  117
## 37:      55-64 Hispanic      0  222
## 38:      55-64 Hispanic      1  193
## 39:      55-64    White      0  923
## 40:      55-64    White      1 1006
## 41:       <NA>    Asian      0   76
## 42:       <NA>    Asian      1   77
## 43:       <NA>    Black      0  190
## 44:       <NA>    Black      1  205
## 45:       <NA> Hispanic      0  346
## 46:       <NA> Hispanic      1  354
## 47:       <NA>    White      0 1514
## 48:       <NA>    White      1 1618
##     age_groups     race female    N
# Same cells with race as the leading column, sorted by race then age group
# (sex is left in first-appearance order within each pair).
agent_dt[tick == last_tick, .N, by = .(race, age_groups, female)][order(race, age_groups)]
##         race age_groups female    N
##  1:    Asian      18-24      1    2
##  2:    Asian      18-24      0    1
##  3:    Asian      25-34      1   12
##  4:    Asian      25-34      0    7
##  5:    Asian      35-44      1   23
##  6:    Asian      35-44      0    6
##  7:    Asian      45-54      1   36
##  8:    Asian      45-54      0   40
##  9:    Asian      55-64      1   54
## 10:    Asian      55-64      0   60
## 11:    Asian       <NA>      1   77
## 12:    Asian       <NA>      0   76
## 13:    Black      18-24      0    4
## 14:    Black      18-24      1    3
## 15:    Black      25-34      0   14
## 16:    Black      25-34      1   13
## 17:    Black      35-44      1   37
## 18:    Black      35-44      0   24
## 19:    Black      45-54      0   61
## 20:    Black      45-54      1   62
## 21:    Black      55-64      1  117
## 22:    Black      55-64      0  126
## 23:    Black       <NA>      1  205
## 24:    Black       <NA>      0  190
## 25: Hispanic      18-24      1    9
## 26: Hispanic      18-24      0   10
## 27: Hispanic      25-34      1   44
## 28: Hispanic      25-34      0   34
## 29: Hispanic      35-44      0   43
## 30: Hispanic      35-44      1   60
## 31: Hispanic      45-54      1  164
## 32: Hispanic      45-54      0  140
## 33: Hispanic      55-64      0  222
## 34: Hispanic      55-64      1  193
## 35: Hispanic       <NA>      0  346
## 36: Hispanic       <NA>      1  354
## 37:    White      18-24      1   33
## 38:    White      18-24      0   27
## 39:    White      25-34      1  142
## 40:    White      25-34      0  134
## 41:    White      35-44      1  250
## 42:    White      35-44      0  224
## 43:    White      45-54      1  653
## 44:    White      45-54      0  608
## 45:    White      55-64      0  923
## 46:    White      55-64      1 1006
## 47:    White       <NA>      0 1514
## 48:    White       <NA>      1 1618
##         race age_groups female    N
# Final-tick count and percentage of the whole population per race x sex cell.
agent_dt[tick == last_tick, .N, by = .(race, female)][, `%` := N / sum(N) * 100][order(race)]
##        race female    N         %
## 1:    Asian      1  204  2.039796
## 2:    Asian      0  190  1.899810
## 3:    Black      1  437  4.369563
## 4:    Black      0  419  4.189581
## 5: Hispanic      0  795  7.949205
## 6: Hispanic      1  824  8.239176
## 7:    White      0 3430 34.296570
## 8:    White      1 3702 37.016298
# Final-tick count and whole-number percentage per age group.
agent_dt[tick == last_tick, .N, by = .(age_groups)][, `%` := round(N / sum(N) * 100)][order(age_groups)]
##    age_groups    N  %
## 1:      18-24   89  1
## 2:      25-34  400  4
## 3:      35-44  667  7
## 4:      45-54 1764 18
## 5:      55-64 2701 27
## 6:       <NA> 4380 44
# Median age at the start (tick 1):

median(agent_dt[tick == 1, age])
## [1] 51
# Median age at the end (final tick), preceded by the number of agents there:

agent_dt[tick == last_tick, .N]
## [1] 10001
median(agent_dt[tick == last_tick, age])
## [1] 63
# Count agents per age group at the final tick.
# Bins are left-closed with the lowest/highest boundary included
# (right = FALSE, include.lowest = TRUE): [18, 25), [25, 35), ..., [80, 84].
# This matches the "18-24"-style convention used for `agelabels` above and
# keeps 18-year-olds in the first bin. The previous default (a, b] bins sent
# age 18 to NA and put age 25 in the youngest group.
# NOTE(review): labels assume whole-number ages; ages above 84 still map to NA.
agent_dt_by_age_group <- agent_dt[tick == last_tick, .N,
                                  by = .(age_group = cut(
                                    age,
                                    breaks = c(18, 25, 35, 45, 55, 65, 75, 80, 84),
                                    labels = c("18-24", "25-34", "35-44", "45-54",
                                               "55-64", "65-74", "75-79", "80-84"),
                                    right = FALSE,
                                    include.lowest = TRUE
                                  ))]

# Bar chart of the age-group counts, one fill color per group.
ggplot(agent_dt_by_age_group, aes(x = age_group, y = N, fill = age_group)) +
  geom_bar(stat = "identity", color = "black") +
  scale_fill_manual(values = c("#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#ffff33", "#a65628", "#f781bf")) +
  labs(title = "Agent Age Distribution",
       x = "Age Group",
       y = "Count",
       fill = "Age Group") +
  theme_minimal()