# Analyze simulated demographic data 

# The global environment is deliberately NOT cleared here. rm(list = ls())
# only removes user objects (attached packages, options and the working
# directory stay as they were), so it does not give a clean session and it
# destroys the workspace of anyone who source()s this script. Restart R to
# get a fresh session instead.


# Load R environment ---------

# Activate the project's renv environment before any package is attached.
renv::activate()


# Load packages ---------

library(here)
## here() starts at /users/akhann16/code/cadre/data-analysis-plotting/Simulated-Data-Analysis/r/agent-log-analysis
library(data.table)
library(yaml)
library(ggplot2)


# Read RDS file ------------

# Resolve the RDS file relative to the project root reported by here()
# instead of a user-specific absolute path, so the script also runs from
# other checkouts of the project.
agent_log_env <- readRDS(here("rds-outs", "agent_log_env.RDS"))


# Load data ------------

# Agent-level log: one row per agent per tick (see the head() preview).
agent_dt <- agent_log_env[["agent_dt"]]

# Model input parameters; the race and sex target values used in the plots
# are read from this list.
input_params <- agent_log_env[["input_params"]]

# The RDS file must contain the input parameters for the target comparisons
# to work. Flag their absence here rather than failing later inside a plot.
if (is.null(input_params)) {
  warning(
    "`input_params` is missing from agent_log_env.RDS; ",
    "target values cannot be computed.",
    call. = FALSE
  )
}


# View data --------

# Preview the first six rows of the agent log.
head(agent_dt, n = 6L)
##    tick id age  race female alc_use_status smoking_status
## 1:    1  0  53 White      1              0         Former
## 2:    1  1  63 White      0              0          Never
## 3:    1  2  30 White      1              0          Never
## 4:    1  3  27 White      1              1          Never
## 5:    1  4  60 White      0              1         Former
## 6:    1  5  18 White      0              1          Never
##    last_incarceration_tick last_release_tick current_incarceration_status
## 1:                      -1                -1                            0
## 2:                      -1                -1                            0
## 3:                      -1                -1                            0
## 4:                      -1                -1                            0
## 5:                      -1                -1                            0
## 6:                      -1                -1                            0
##    entry_at_tick exit_at_tick n_incarcerations n_releases n_smkg_stat_trans
## 1:             0           -1                0          0                 0
## 2:             0           -1                0          0                 0
## 3:             0           -1                0          0                 0
## 4:             0           -1                0          0                 0
## 5:             0           -1                0          0                 0
## 6:             0           -1                0          0                 0
##    n_alc_use_stat_trans
## 1:                    0
## 2:                    0
## 3:                    0
## 4:                    0
## 5:                    0
## 6:                    0
# Summary ------------

# Final simulated tick present in the log.
last_tick <- agent_dt[, max(tick)]

# Every 10th tick starting at tick 1, with the final tick appended.
selected_ticks <- c(seq(from = 1, to = last_tick, by = 10), last_tick)


# Number of distinct agents in the dataset.
uniqueN(agent_dt$id)
## [1] 15849
# Smallest and largest agent id.
range(agent_dt$id)
## [1]     0 15848
# Race distribution ------------

## compute
## Share of each race among the agents present at the final tick, rounded to
## whole per cent.
agent_dt_by_race <- agent_dt[tick == last_tick, .N, by = race][
  , "per_cent" := round(N / sum(N) * 100)
][order(race)]

## Target race shares taken from the input parameters, one row per race.
target_values <- data.frame(race = names(input_params$RACE_DISTRIBUTION),
                            target_pct = unlist(input_params$RACE_DISTRIBUTION))
agent_dt_by_race[]
##        race    N per_cent
## 1:    Asian  357        4
## 2:    Black  879        9
## 3: Hispanic 1626       16
## 4:    White 7138       71

# Colour of each race in the plot. Naming the palette ties every colour to
# its race explicitly instead of relying on factor-level order.
race_colors <- c(White = "#377eb8", Black = "#ff7f00",
                 Hispanic = "#4daf4a", Asian = "#e41a1c")
race_levels <- names(race_colors)

# Ticks shown in the time-series plots: every 10th tick plus the final one.
time_ticks <- c(seq(1, last_tick, by = 10), last_tick)

# Proportion of each race among the agents present at each selected tick
# (wide format: one column per race, in the order of `race_levels`).
race_proportions_by_tick <- agent_dt[
  tick %in% time_ticks,
  lapply(setNames(nm = race_levels), function(r) sum(race == r) / .N),
  by = tick
]

# Convert the data from wide to long format.
race_proportions_long <- melt(race_proportions_by_tick,
                              id.vars = "tick",
                              variable.name = "race",
                              value.name = "proportion")

# One "Target: ..." label per race, placed at the horizontal midpoint of the
# plot and slightly above the target value. Building the labels as their own
# data frame draws each label exactly once; mapping constants through aes()
# on the full data would draw one copy per data row.
has_color <- target_values$race %in% race_levels
target_labels <- data.frame(
  race = factor(target_values$race[has_color], levels = race_levels),
  x = rep(max(race_proportions_long$tick) / 2, sum(has_color)),
  y = target_values$target_pct[has_color] + 0.03,
  label = sprintf("Target: %.3f", target_values$target_pct[has_color])
)

# Time series of race proportions with the target value of each race.
race_plot <-
  ggplot(race_proportions_long,
         aes(x = tick, y = proportion, color = race, group = race)) +
  geom_line(linewidth = 1.2) +
  geom_text(data = target_labels,
            aes(x = x, y = y, label = label, color = race),
            inherit.aes = FALSE, show.legend = FALSE, size = 5) +
  labs(title = "",
       x = "Time (Days)",
       y = "Proportion") +
  scale_color_manual(values = race_colors) +
  scale_y_continuous(limits = c(0, 0.8), breaks = seq(0, 0.8, by = 0.1)) +
  theme_minimal() +
  theme(
    legend.title = element_blank(),
    axis.text.x = element_text(size = 14, face = "bold"),
    axis.text.y = element_text(size = 14, face = "bold"),
    legend.text = element_text(size = 14),
    axis.title.x = element_text(size = 16, face = "bold"),
    axis.title.y = element_text(size = 16, face = "bold")
  )

race_plot

# Sex distribution ------------

# Share of each sex among the agents present at the final tick. The column
# "%" holds a proportion in [0, 1], not a percentage.
gender_pct <- agent_dt[tick == last_tick, .N, by = female][
  , "%" := N / sum(N)
][order(female)]
gender_pct[]
##    female    N     %
## 1:      0 4840 0.484
## 2:      1 5160 0.516

# Target proportion of females from the input parameters.
female_target <- input_params$FEMALE_PROP

# Target proportions for each sex.
female_target_pct <- female_target
male_target_pct <- 1 - female_target

# Actual proportions for each sex at the final tick.
female_actual_pct <- gender_pct[female == 1, `%`]
male_actual_pct <- gender_pct[female == 0, `%`]

# Same sex-to-colour assignment as the sex time-series plot.
sex_fill_colors <- c(Male = "#1b9e77", Female = "#d95f02")

# Plot data: one row per sex with its actual share and its target share.
gender_plot_dt <- gender_pct[, .(
  sex = ifelse(female == 0, "Male", "Female"),
  actual = `%`,
  target = ifelse(female == 0, male_target_pct, female_target_pct)
)]
# The label reports the TARGET share; it is drawn at the top of the bar,
# whose height is the actual share.
gender_plot_dt[, target_label := paste0(sex, " Target = ",
                                        round(target * 100, 1), "%")]

# Bar chart of the sex distribution at the final tick.
ggplot(gender_plot_dt, aes(x = sex, y = actual, fill = sex)) +
  geom_col(color = "black") +
  geom_text(aes(label = target_label, color = sex),
            size = 4, vjust = 0, show.legend = FALSE) +
  scale_fill_manual(values = sex_fill_colors) +
  scale_color_manual(values = sex_fill_colors, guide = "none") +
  labs(title = "",
       x = "Sex",
       y = "Proportion",
       fill = "") +
  theme_minimal()

# Plot over time:

# Proportion of males and females among the agents present at each of the
# selected time ticks.
gender_proportions_by_tick <- agent_dt[tick %in% time_ticks, .(
  male = sum(female == 0) / .N,
  female = sum(female == 1) / .N
), by = tick]

# Convert the data from wide to long format.
gender_proportions_long <- melt(gender_proportions_by_tick,
                                id.vars = "tick",
                                variable.name = "gender",
                                value.name = "proportion")

# Colour per series, keyed by series name so the assignment does not depend
# on factor-level order.
gender_line_colors <- c(male = "#1b9e77", female = "#d95f02")

# Horizontal midpoint of the plotted ticks, used to place the target labels.
target_label_x <- max(gender_proportions_long$tick) / 2

ggplot(gender_proportions_long,
       aes(x = tick, y = proportion, color = gender, group = gender)) +
  geom_line(linewidth = 1.5) +
  labs(title = "",
       x = "Time (Days)",
       y = "Proportion") +
  scale_color_manual(values = gender_line_colors,
                     breaks = c("male", "female"),
                     labels = c("Male", "Female")) +
  theme_minimal() +
  theme(
    legend.title = element_blank(),
    axis.text.x = element_text(size = 16, face = "bold"),
    axis.text.y = element_text(size = 16, face = "bold"),
    legend.text = element_text(size = 16),
    axis.title.x = element_text(size = 16, face = "bold"),
    axis.title.y = element_text(size = 16, face = "bold")
  ) +
  scale_y_continuous(limits = c(0, 1), breaks = seq(0.4, 0.6, by = 0.1)) +
  # annotate() draws each target label exactly once; mapping constants
  # through aes() in geom_text() would draw one copy per data row.
  # The female target label sits just above its target value ...
  annotate("text", x = target_label_x, y = female_target + 0.03,
           label = sprintf("Target: %.3f", female_target),
           color = "#d95f02", size = 5) +
  # ... and the male target label just below its target value.
  annotate("text", x = target_label_x, y = 1 - female_target - 0.03,
           label = sprintf("Target: %.3f", 1 - female_target),
           color = "#1b9e77", size = 5)

# Age distribution ------------

# Age bands are left-closed: [18,25), [25,35), [35,45), [45,55). Because
# include.lowest = TRUE is combined with right = FALSE, the last band is
# closed on both sides: [55,65].
# NOTE(review): ages outside 18-65 get NA. The recorded final-tick output in
# this file shows 4332 of 10000 agents with age_groups == NA; confirm whether
# older age bands should be added here.
agebreaks <- c(18, 25, 35, 45, 55, 65)
agelabels <- c("18-24", "25-34", "35-44", "45-54", "55-64")

# Number of agents present at the final tick.
agent_dt[tick == last_tick, .N]
## [1] 10000
# Add the age band of every row as a new column, by reference.
setDT(agent_dt)
agent_dt[, age_groups := cut(age,
                             breaks = agebreaks,
                             labels = agelabels,
                             right = FALSE,
                             include.lowest = TRUE)]
# Number of agents present at the final tick.
agent_dt[tick == last_tick, .N]
## [1] 10000
# Final-tick counts by age band, race and sex, sorted on all three.
agent_dt[
  tick == last_tick, .N, by = .(age_groups, race, female)
][order(age_groups, race, female)]
##     age_groups     race female    N
##  1:      18-24    Asian      0    1
##  2:      18-24    Asian      1    2
##  3:      18-24    Black      0    6
##  4:      18-24    Black      1    1
##  5:      18-24 Hispanic      0    5
##  6:      18-24 Hispanic      1    8
##  7:      18-24    White      0   34
##  8:      18-24    White      1   31
##  9:      25-34    Asian      0    3
## 10:      25-34    Asian      1    7
## 11:      25-34    Black      0   22
## 12:      25-34    Black      1   17
## 13:      25-34 Hispanic      0   28
## 14:      25-34 Hispanic      1   35
## 15:      25-34    White      0  132
## 16:      25-34    White      1  144
## 17:      35-44    Asian      0    9
## 18:      35-44    Asian      1   12
## 19:      35-44    Black      0   26
## 20:      35-44    Black      1   32
## 21:      35-44 Hispanic      0   64
## 22:      35-44 Hispanic      1   65
## 23:      35-44    White      0  240
## 24:      35-44    White      1  276
## 25:      45-54    Asian      0   31
## 26:      45-54    Asian      1   28
## 27:      45-54    Black      0   75
## 28:      45-54    Black      1   85
## 29:      45-54 Hispanic      0  157
## 30:      45-54 Hispanic      1  153
## 31:      45-54    White      0  613
## 32:      45-54    White      1  709
## 33:      55-64    Asian      0   47
## 34:      55-64    Asian      1   52
## 35:      55-64    Black      0  127
## 36:      55-64    Black      1  115
## 37:      55-64 Hispanic      0  187
## 38:      55-64 Hispanic      1  204
## 39:      55-64    White      0  912
## 40:      55-64    White      1  973
## 41:       <NA>    Asian      0   85
## 42:       <NA>    Asian      1   80
## 43:       <NA>    Black      0  184
## 44:       <NA>    Black      1  189
## 45:       <NA> Hispanic      0  342
## 46:       <NA> Hispanic      1  378
## 47:       <NA>    White      0 1510
## 48:       <NA>    White      1 1564
##     age_groups     race female    N
# Final-tick counts by race, age band and sex, sorted by race then age band
# (rows within a race/age band are not sorted by sex).
agent_dt[
  tick == last_tick, .N, by = .(race, age_groups, female)
][order(race, age_groups)]
##         race age_groups female    N
##  1:    Asian      18-24      1    2
##  2:    Asian      18-24      0    1
##  3:    Asian      25-34      0    3
##  4:    Asian      25-34      1    7
##  5:    Asian      35-44      1   12
##  6:    Asian      35-44      0    9
##  7:    Asian      45-54      0   31
##  8:    Asian      45-54      1   28
##  9:    Asian      55-64      1   52
## 10:    Asian      55-64      0   47
## 11:    Asian       <NA>      0   85
## 12:    Asian       <NA>      1   80
## 13:    Black      18-24      1    1
## 14:    Black      18-24      0    6
## 15:    Black      25-34      1   17
## 16:    Black      25-34      0   22
## 17:    Black      35-44      1   32
## 18:    Black      35-44      0   26
## 19:    Black      45-54      0   75
## 20:    Black      45-54      1   85
## 21:    Black      55-64      1  115
## 22:    Black      55-64      0  127
## 23:    Black       <NA>      0  184
## 24:    Black       <NA>      1  189
## 25: Hispanic      18-24      0    5
## 26: Hispanic      18-24      1    8
## 27: Hispanic      25-34      0   28
## 28: Hispanic      25-34      1   35
## 29: Hispanic      35-44      1   65
## 30: Hispanic      35-44      0   64
## 31: Hispanic      45-54      1  153
## 32: Hispanic      45-54      0  157
## 33: Hispanic      55-64      1  204
## 34: Hispanic      55-64      0  187
## 35: Hispanic       <NA>      0  342
## 36: Hispanic       <NA>      1  378
## 37:    White      18-24      1   31
## 38:    White      18-24      0   34
## 39:    White      25-34      1  144
## 40:    White      25-34      0  132
## 41:    White      35-44      0  240
## 42:    White      35-44      1  276
## 43:    White      45-54      0  613
## 44:    White      45-54      1  709
## 45:    White      55-64      1  973
## 46:    White      55-64      0  912
## 47:    White       <NA>      1 1564
## 48:    White       <NA>      0 1510
##         race age_groups female    N
# Final-tick share (per cent of all agents) of each race-by-sex cell.
agent_dt[
  tick == last_tick, .N, by = .(race, female)
][, "%" := N / sum(N) * 100][order(race)]
##        race female    N     %
## 1:    Asian      0  176  1.76
## 2:    Asian      1  181  1.81
## 3:    Black      0  440  4.40
## 4:    Black      1  439  4.39
## 5: Hispanic      0  783  7.83
## 6: Hispanic      1  843  8.43
## 7:    White      1 3697 36.97
## 8:    White      0 3441 34.41
# Final-tick share (whole per cent) of each age band.
agent_dt[
  tick == last_tick, .N, by = .(age_groups)
][, "%" := round(N / sum(N) * 100)][order(age_groups)]
##    age_groups    N  %
## 1:      18-24   88  1
## 2:      25-34  388  4
## 3:      35-44  724  7
## 4:      45-54 1851 19
## 5:      55-64 2617 26
## 6:       <NA> 4332 43
# Median age at the start (tick 1):

median(agent_dt[tick == 1, age])
## [1] 51
# Median age at the end (final tick):

# Number of agents present at the final tick.
agent_dt[tick == last_tick, .N]
## [1] 10000
median(agent_dt[tick == last_tick, age])
## [1] 63
# Create age groups for the bar chart.
# The bands are left-closed, matching the "18-24"-style bands used for
# `age_groups`: [18,25), [25,35), ..., [75,80), with the last band closed on
# both sides as [80,84]. cut()'s default of right-closed bands would make the
# first band (18,25] and drop agents aged exactly 18 into NA.
age_chart_breaks <- c(18, 25, 35, 45, 55, 65, 75, 80, 84)
age_chart_labels <- c("18-24", "25-34", "35-44", "45-54",
                      "55-64", "65-74", "75-79", "80-84")

# Number of agents per age band at the final tick.
agent_dt_by_age_group <- agent_dt[
  tick == last_tick, .N,
  by = .(age_group = cut(age,
                         breaks = age_chart_breaks,
                         labels = age_chart_labels,
                         right = FALSE,
                         include.lowest = TRUE))
]

# One fill colour per age band, keyed by band label so a band keeps its
# colour even when another band has no agents.
age_group_colors <- setNames(
  c("#e41a1c", "#377eb8", "#4daf4a", "#984ea3",
    "#ff7f00", "#ffff33", "#a65628", "#f781bf"),
  age_chart_labels
)

# Create bar chart.
ggplot(agent_dt_by_age_group, aes(x = age_group, y = N, fill = age_group)) +
  geom_col(color = "black") +
  scale_fill_manual(values = age_group_colors) +
  labs(title = "Agent Age Distribution",
       x = "Age Group",
       y = "Count",
       fill = "Age Group") +
  theme_minimal()