## For other simulations, just use a different file here.
## However, do keep the name sim.
suppressPackageStartupMessages(library(tidyverse))
options(dplyr.summarise.inform = FALSE)
## Set path to model run you want to get diagnostics for:
# sim <- readRDS("../../../Data/EpiModelSims/ergm3_ca_nohiv_sim_with15yo_boost.rds")
## `params` is presumably the R Markdown parameter list (it is not defined in
## this file). The sentinel string "None" selects the default file below;
## any other value is used directly as the simulation object. identical() is
## used instead of `==`: when params$sim is a whole simulation (a list),
## `params$sim == "None"` is not a single TRUE/FALSE and the `if` fails.
if (identical(params$sim, "None")) {
  sim <- readRDS("../../../EpiModel/AE/sim_epimodel3/sim_at2033_dec28.rds")
} else {
  sim <- params$sim
}
## Registered S3 method overwritten by 'tergm':
##   method                   from
##   simulate_formula.network ergm
## Creating the netest object (make sure it is the same as what was 
## used to run the simulation).
## No HIV
# netest_mc <- readRDS("../../../Data/stergmFits/netest_mc_nohiv_ca_whamp3.rds")
# load("../../../Data/stergmFits/netest_inst_nohiv_whamp3.rda")
## HIV
## Same "None" sentinel as above; otherwise params$netest is used directly
## as the named list of network fits (main / casl / inst).
if (identical(params$netest, "None")) {
  netest_mc <- readRDS("../../../Data/stergmFits/netest_mc_hiv_caxr_whamp3.rds")
  # The .rda is expected to define `netest_inst`, which is used just below.
  load("../../../Data/stergmFits/netest_inst_hiv_caxr_whamp3.rda")
  netest_obj <- list("main" = netest_mc$main,
                     "casl" = netest_mc$casl,
                     "inst" = netest_inst)
} else {
  netest_obj <- params$netest
}

make_num_plot <- function(sim) {
  # Line plot of total network size over time.
  # Reads sim$epi[[1]]$num (steps with NA are dropped), sim$control$year_start
  # and sim$param$demog_match_arrival_df$entry_year (the vertical line is the
  # last scheduled entry year).
  sizes <- sim$epi[[1]]$num
  sizes <- sizes[!is.na(sizes)]
  # Step index -> calendar year, at 52 steps per year.
  plot_df <- data.frame(
    "Network_Size" = sizes,
    "Year" = sim$control$year_start + seq_along(sizes) / 52
  )
  ggplot(plot_df, aes(x = Year, y = Network_Size)) +
    geom_line() +
    geom_vline(
      xintercept = max(sim$param$demog_match_arrival_df$entry_year)
    ) +
    ylab("Network size")
}

make_exit_plot <- function(sim) {
  # Compare observed annual mortality in the simulation with the target
  # age/race-specific mortality (WApopdata::asmr_by_race) on a log scale,
  # in two stacked panels: ages < 40 ("Younger") and >= 40 ("Older").
  #
  # Observed mortality at age a in year y uses the first age_dist snapshot of
  # each year:
  #   (count[a - 1, y - 1] - count[a, y]) / count[a - 1, y - 1]
  # and is then averaged over years for each race/age.
  ad <- sim$epi[[1]]$age_dist %>% as.data.frame()
  ad$year <- round(sim$control$year_start +
                     (ad$at - sim$control$start) * 7 / 365)

  # One row per (year, age, race). Cells are matched with a join rather than
  # written back by position, so age/race cells missing from some snapshots
  # cannot shift estimates onto the wrong rows. Each year counts once.
  yearly_morts <- lapply(sort(unique(ad$year))[-1], function(yrs) {
    this_yr <- ad %>%
      filter(year == yrs) %>%
      filter(at == min(at)) %>%
      select(age, race, cur_cnt = count)
    last_yr <- ad %>%
      filter(year == yrs - 1) %>%
      filter(at == min(at)) %>%
      mutate(age = age + 1) %>%
      select(age, race, lst_cnt = count)
    left_join(this_yr, last_yr, by = c("age", "race")) %>%
      # An empty cohort last year carries no information (avoids -Inf/NaN).
      mutate(obs_mort = ifelse(lst_cnt > 0,
                               (lst_cnt - cur_cnt) / lst_cnt,
                               NA_real_))
  })
  # Seed with every race/age present in age_dist so cells without any
  # estimate still appear (mean of nothing -> NaN).
  all_cells <- ad %>% distinct(race, age) %>% mutate(obs_mort = NA_real_)
  mean_obs_morts <- bind_rows(c(list(all_cells), yearly_morts)) %>%
    group_by(race, age) %>%
    summarise(mort = mean(obs_mort, na.rm = TRUE)) %>%
    select(race, age, mort)

  trg_morts <- WApopdata::asmr_by_race %>%
    filter(age > 15) %>%
    tidyr::pivot_longer(cols = 2:4, names_to = "race",
                        values_to = "mort") %>%
    # Weekly probability of death -> annual probability.
    mutate(mort = 1 - (1 - mort)^52) %>%
    mutate(race = gsub("vec.asmr.", "", race))
  all_morts <- bind_rows(
    bind_cols(mean_obs_morts, type = "observed"),
    bind_cols(trg_morts, type = "target")
  )

  # Shared panel: floor mortality at 0.001 so the log scale stays finite.
  plot_morts <- function(df, y_max, legend_pos) {
    df %>%
      mutate(mort = pmax(mort, 0.001)) %>%
      ggplot(aes(x = age, y = mort, shape = type,
                 lty = type, color = race)) +
      geom_point(size = 2) + geom_smooth(se = FALSE, method = "lm") +
      coord_cartesian(ylim = c(0.001, y_max)) + scale_y_log10() +
      facet_wrap(~ race) + theme(legend.position = legend_pos)
  }
  young <- plot_morts(all_morts %>% filter(age < 40), 0.01, "none")
  old <- plot_morts(all_morts %>% filter(age >= 40), 0.125, "bottom")
  cowplot::plot_grid(
    young, old,
    labels = c("Younger", "Older"),
    rel_heights = c(1, 1.3),
    ncol = 1
  )
}

make_mean_deg_plot <- function(sim, var_type = "race", 
                               nest_obj = NULL, simnum = 1){
  # Plot simulated mean degree over calendar time per partnership type
  # (Main / Casl / Inst), coloured by `var_type`, excluding age.grp 6.
  # If `nest_obj` (list of fitted networks) is supplied, horizontal lines
  # show the mean degree of each fit's starting network.
  ptype_labels <- c("Main",  "Casl", "Inst")
  ec <- sim$epi[[simnum]]$demog_edge_counts
  # Re-time the at == 1 snapshot to one step before the next recorded step.
  is_first <- ec$at == 1
  ec$at[is_first] <- min(ec$at[!is_first]) - 1
  ec[, "col_var"] <- ec[, var_type]
  # Rows with ptype == 0 apply to every partnership type: copy each once per
  # type, numbering the copies 1..3 in `ptype`.
  shared_rows <- ec %>%
    filter(ptype == 0) %>%
    mutate(rep = 3) %>%
    uncount(rep, .id = "ptype")
  ec <- bind_rows(filter(ec, ptype != 0), shared_rows)

  wide <- pivot_wider(
    ec,
    values_from = "count", names_from = "type",
    id_cols = colnames(ec)
  )
  cmpr_edges <- wide %>%
    filter(age.grp != 6) %>%
    group_by(ptype, col_var, at) %>%
    summarise(partner = sum(partner, na.rm = TRUE),
              total = sum(total, na.rm = TRUE)) %>%
    mutate(
      "mean_deg" = partner / total,
      "year" = round(
        (at - sim$control$start) / 52 + sim$control$year_start, 4
      )
    )
  cmpr_edges$ptype <- factor(cmpr_edges$ptype, levels = 1:3,
                             labels = ptype_labels)

  plt <- ggplot(ungroup(cmpr_edges),
                aes(x = year, y = mean_deg,
                    color = as.character(col_var))) +
    geom_point(alpha = 0.5) +
    geom_smooth()

  if (!is.null(nest_obj)) {
    net_degrees <- nest_obj %>%
      lapply("[[", "fit") %>%
      lapply("[[", "newnetwork") %>%
      sapply(sna::degree)
    egos <- ergm.ego::as.egodata(nest_obj$main$fit$newnetwork)$egos
    egos <- data.frame(egos, net_degrees) %>%
      filter(age.grp != 6) %>%
      select(age.grp, race, region,
             "Main" = main, "Casl" = casl, "Inst" = inst)
    egos[, "col_var"] <- egos[, var_type]
    egos_long <- pivot_longer(
      egos, names_to = "ptype", values_to = "deg",
      cols = c("Main", "Casl", "Inst")
    )
    # Halved, presumably because sna::degree() counts each tie at both ends
    # (in + out) -- NOTE(review): confirm.
    mean_degs <- egos_long %>%
      group_by(ptype, col_var) %>%
      summarise(Mean_deg = mean(deg) / 2)
    mean_degs$ptype <- factor(mean_degs$ptype,
                              levels = ptype_labels,
                              labels = ptype_labels)
    plt <- plt +
      geom_hline(data = mean_degs,
                 aes(yintercept = Mean_deg, color = as.factor(col_var)))
  }

  plt +
    facet_wrap(~ ptype, scales = "free_y" ) +
    geom_vline(xintercept = 2019, color = "black") +
    scale_colour_discrete(var_type) + ylab("Mean Degree") +
    xlab("Year") + theme(legend.position = "bottom")
}

make_mean_agextime <- function(sim, var_type = "race",
                               yrs = NULL, targ_year = 2019,
                               nest_obj = NULL,
                               plot_dif = FALSE){
  # Plot simulated mean degree by age, in one panel per year x partnership
  # type. Optionally overlay, or difference against, the mean degrees of the
  # starting networks in `nest_obj`.
  #
  # sim:       simulation; reads sim$epi[[1]]$demog_edge_counts and
  #            sim$control$start / sim$control$year_start.
  # var_type:  column used to colour curves, or "none" to pool everything.
  # yrs:       years to display. Default: the earliest year after the very
  #            first one, plus every multiple of 10. Only years < 2031 are kept.
  # targ_year: a single year that is always displayed (highlighted when
  #            nest_obj is given).
  # nest_obj:  optional named list of fits (main / casl / inst).
  # plot_dif:  with nest_obj, plot simulated minus start/target instead.
  if (length(targ_year) != 1) {stop("Must specify one target year")}
  ptype_labels <- c("Main", "Casl", "Inst")
  # One definition, so the label written by bind_rows() and the column read
  # back in the plot_dif branch cannot drift apart.
  target_label <- "Start/Target (from netest$newnetwork)"

  edge_counts <- sim$epi[[1]]$demog_edge_counts
  if (var_type == "none") {
    edge_counts[, "facet_var"] <- "pooled"
  } else {
    edge_counts[, "facet_var"] <- edge_counts[, var_type]
  }
  # Rows with ptype == 0 are copied once per partnership type (ptype 1..3).
  f_edg_cnts <- edge_counts %>% filter(ptype == 0) %>%
    mutate(rep = 3) %>% uncount(rep, .id = "ptype")
  edge_counts <- bind_rows(
    edge_counts %>% filter(ptype != 0),
    f_edg_cnts
  )
  cmpr_edges <- pivot_wider(
    edge_counts,
    values_from = "count", names_from = "type",
    id_cols = colnames(edge_counts)
  ) %>% group_by(ptype, age, at, facet_var) %>%
    summarise(partner = sum(partner, na.rm = TRUE),
              total = sum(total, na.rm = TRUE)) %>%
    mutate(
      "mean_deg" = partner / total,
      "year" = round(
        (at - sim$control$start) / 52 + sim$control$year_start
      )
    )
  cmpr_edges$ptype <- factor(cmpr_edges$ptype, levels = 1:3,
                             labels = ptype_labels)

  all_yrs <- sort(unique(cmpr_edges$year), decreasing = FALSE)[-1]
  if (is.null(yrs)) {
    keep_yrs <- c(min(all_yrs), all_yrs[all_yrs %% 10 == 0])
  } else {
    keep_yrs <- yrs
  }
  # Always add the target year; keep the default selection when yrs is NULL.
  keep_yrs <- unique(c(keep_yrs, targ_year))
  cmpr_edges <- cmpr_edges %>% filter(year %in% keep_yrs,
                                      year < 2031)
  # factor() rejects duplicated levels, so build the levels from unique years.
  yr_levels <- sort(unique(cmpr_edges$year))
  cmpr_edges$year <- factor(cmpr_edges$year,
                            levels = yr_levels, labels = yr_levels)

  if (is.null(nest_obj)) {
    return(
      cmpr_edges %>% ungroup() %>%
        ggplot(aes(x = age, y = mean_deg, col = facet_var)) +
        geom_smooth(se = FALSE, span = 4) +
        facet_grid(year ~ ptype) +
        # x is age here: mark age 65, as the nest_obj plots do. (A line at
        # 2019 would stretch the age axis out to 2019.)
        geom_vline(xintercept = 65, color = "gray") +
        scale_colour_discrete(var_type) +
        ylab("Mean Degree") +
        xlab("Age") + theme(legend.position = "bottom")
    )
  }

  degs <- lapply(nest_obj, "[[", "fit") %>% lapply("[[", "newnetwork") %>%
    sapply(sna::degree)
  ego_dat <- ergm.ego::as.egodata(nest_obj$main$fit$newnetwork)$egos
  ego_dat <- data.frame(ego_dat, degs) %>%
    select(age, race, region,
           "Main" = main, "Casl" = casl, "Inst" = inst) %>%
    mutate(age = floor(age))
  if (var_type == "none") {
    ego_dat[, "facet_var"] <- "pooled"
  } else {
    ego_dat[, "facet_var"] <- ego_dat[, var_type]
  }
  long_dat <- pivot_longer(
    ego_dat, names_to = "ptype", values_to = "deg",
    cols = c("Main", "Casl", "Inst")
  )
  # Halved, presumably because sna::degree() counts each tie at both ends --
  # NOTE(review): confirm.
  mean_degs <- long_dat %>% group_by(ptype, age, facet_var) %>%
    summarise(mean_deg = mean(deg) / 2)
  mean_degs$ptype <- factor(mean_degs$ptype,
                            levels = ptype_labels, labels = ptype_labels)
  # Start/target values do not vary over time: repeat them for each year.
  r_mean_degs <- mean_degs %>% mutate(rep = length(keep_yrs)) %>%
    uncount(rep, .id = "year") %>% mutate(year = keep_yrs[year]) %>%
    filter(year < 2031)
  r_yr_levels <- sort(unique(r_mean_degs$year))
  r_mean_degs$year <- factor(r_mean_degs$year,
                             levels = r_yr_levels, labels = r_yr_levels)
  all_md <- bind_rows(
    bind_cols(cmpr_edges, "type" = "Simulated"),
    bind_cols(r_mean_degs, "type" = target_label)
  )

  if (plot_dif) {
    # Several time steps can share a year. Average them per cell before
    # differencing; otherwise pivot_wider() produces list-columns.
    wd <- all_md %>% ungroup() %>%
      select(type, facet_var, mean_deg, ptype, year, age) %>%
      pivot_wider(names_from = type,
                  values_from = mean_deg,
                  values_fn = list(mean_deg = mean)) %>%
      mutate(dif = .data[["Simulated"]] - .data[[target_label]])
    return(
      wd %>%
        ggplot(aes(x = age, y = dif, col = facet_var)) +
        geom_hline(yintercept = 0, color = "black") +
        geom_vline(xintercept = 65, color = "gray") +
        geom_smooth(se = FALSE) +
        facet_wrap(~ptype + year, scales = "free_y", dir = "v",
                   ncol = 3) +
        scale_colour_discrete(var_type) +
        coord_cartesian(ylim = c(-0.6, 0.4)) +
        ylab("Mean Degree (Observed - Target)") +
        xlab("Age") + theme(legend.position = "bottom")
    )
  }

  # Shade the panels belonging to the target year.
  fac_df <- expand.grid(year = unique(all_md$year),
                        ptype = unique(all_md$ptype))
  fac_df$highl <- as.numeric(fac_df$year == targ_year)
  all_md %>% ungroup() %>%
    ggplot() +
    geom_rect(data = fac_df %>% filter(highl == 1),
              aes(fill = as.factor(highl)),
              xmin = -Inf, xmax = Inf,
              ymin = -Inf, ymax = Inf, alpha = 0.3) +
    geom_vline(xintercept = 65, color = "gray") +
    geom_smooth(aes(x = age, y = mean_deg,
                    color = facet_var,
                    linetype = type),
                se = FALSE) +
    facet_grid(ptype ~ year, scales = "free_y") +
    scale_colour_discrete(var_type) +
    scale_fill_manual(values = "black") +
    ylab("Mean Degree") + guides(fill = "none") +
    xlab("Age") + theme(legend.position = "bottom") +
    coord_cartesian(ylim = c(0, NA))
}

make_agextime_gif <- function(sim, var_type = "race",
                              nest_obj = NULL){
  # Animated (gganimate) plot of simulated mean degree by age, one frame per
  # year, per partnership type. It is overlaid with the start/target mean
  # degrees from the networks in `nest_obj`.
  #
  # var_type: column used to colour curves, or "none" to pool.
  # nest_obj: named list of fits (main / casl / inst). Required despite the
  #           NULL default.
  # library() rather than require(): a missing package fails here with a
  # clear error instead of later inside transition_time(). It still attaches
  # gganimate, as before.
  library(gganimate)
  if (is.null(nest_obj)) {
    stop("make_agextime_gif() needs `nest_obj` to compute the ",
         "start/target mean degrees.", call. = FALSE)
  }
  ptype_labels <- c("Main",  "Casl", "Inst")
  edge_counts <- sim$epi[[1]]$demog_edge_counts
  if (var_type == "none") {
    edge_counts[, "facet_var"] <- "pooled"
  } else {
    edge_counts[, "facet_var"] <- edge_counts[, var_type]
  }
  # Rows with ptype == 0 are copied once per partnership type (ptype 1..3).
  f_edg_cnts <- edge_counts %>% filter(ptype == 0) %>%
    mutate(rep = 3) %>% uncount(rep, .id = "ptype")
  edge_counts <- bind_rows(
    edge_counts %>% filter(ptype != 0),
    f_edg_cnts
  )
  cmpr_edges <- pivot_wider(
    edge_counts,
    values_from = "count", names_from = "type",
    id_cols = colnames(edge_counts)
  ) %>% group_by(ptype, age, at, facet_var) %>%
    summarise(partner = sum(partner, na.rm = TRUE),
              total = sum(total, na.rm = TRUE)) %>%
    mutate(
      "mean_deg" = partner / total,
      "year" = round(
        (at - sim$control$start) / 52 + sim$control$year_start
      )
    )
  cmpr_edges$ptype <- factor(cmpr_edges$ptype, levels = 1:3,
                             labels = ptype_labels)
  # Drop the earliest year, matching make_mean_agextime().
  all_yrs <- sort(unique(cmpr_edges$year))[-1]
  cmpr_edges <- cmpr_edges %>% filter(year %in% all_yrs)

  degs <- lapply(nest_obj, "[[", "fit") %>% lapply("[[", "newnetwork") %>%
    sapply(sna::degree)
  ego_dat <- ergm.ego::as.egodata(nest_obj$main$fit$newnetwork)$egos
  ego_dat <- data.frame(ego_dat, degs) %>%
    select(age, race, region,
           "Main" = main, "Casl" = casl, "Inst" = inst) %>%
    mutate(age = floor(age))
  if (var_type == "none") {
    ego_dat[, "facet_var"] <- "pooled"
  } else {
    ego_dat[, "facet_var"] <- ego_dat[, var_type]
  }
  long_dat <- pivot_longer(
    ego_dat, names_to = "ptype", values_to = "deg",
    cols = c("Main", "Casl", "Inst")
  )
  mean_degs <- long_dat %>% group_by(ptype, age, facet_var) %>%
    summarise(mean_deg = mean(deg) / 2)
  mean_degs$ptype <- factor(mean_degs$ptype,
                            levels = ptype_labels, labels = ptype_labels)
  # Targets are constant over time: repeat them for every animated year.
  r_mean_degs <- mean_degs %>% mutate(rep = length(all_yrs)) %>%
    uncount(rep, .id = "year") %>% mutate(year = all_yrs[year])
  all_md <- bind_rows(
    bind_cols(cmpr_edges, "type" = "Simulated"),
    bind_cols(r_mean_degs, "type" = "Start/Target (from netest$newnetwork)")
  )
  # Show raw points only in the pooled view.
  alph <- if (var_type == "none") 0.7 else 0
  all_md %>% ungroup() %>%
    ggplot(aes(x = age, y = mean_deg,
               col = facet_var, shape = type,
               linetype = type)) +
    geom_point(alpha = alph) +
    geom_smooth(se = FALSE) +
    facet_wrap(~ptype, scales = "free_y") +
    # Colour encodes var_type (not partnership type, which is the facet).
    scale_colour_discrete(var_type) +
    ylab("Mean Degree") +
    xlab("Age") + theme(legend.position = "bottom") +
    labs(title = 'Year: {round(frame_time)}', x = 'Age',
         y = 'Mean Degree') +
    transition_time(year) +
    ease_aes('linear')
}

cmpr_tstats_plot <- function(sim, type = "race", target_at = 5200){
  # Compare simulated mean degree over time with the mean degree implied by
  # the ERGM target statistics (sim$nwparam[[i]]$target.stats) for each
  # partnership type (1 = Main, 2 = Casl, 3 = Inst).
  #
  # type:      "race" breaks results out by race; any other value breaks
  #            them out by age.grp.
  # target_at: time step whose ptype == 0 rows supply the node counts used
  #            as the denominator of the target mean degree. Defaults to
  #            5200, the value that was previously hard-coded.
  ptype_labels <- c("Main",  "Casl", "Inst")
  # Time step -> calendar year (52 steps per year).
  step_to_year <- function(at) {
    round((at - sim$control$start) / 52 + sim$control$year_start)
  }

  edge_counts <- sim$epi[[1]]$demog_edge_counts
  f_edg_cnts <- edge_counts %>% filter(ptype == 0) %>%
    mutate(rep = 3) %>% uncount(rep, .id = "ptype")
  edge_counts <- bind_rows(
    edge_counts %>% filter(ptype != 0),
    f_edg_cnts
  )
  total_nodes <- sim$epi[[1]]$demog_edge_counts %>%
    filter(at == target_at, ptype == 0)
  race_tots <- total_nodes %>% group_by(race) %>%
    summarise(count = sum(count)) %>% mutate(times = 3) %>%
    uncount(times, .id = "ptype")
  age.grp_tots <- total_nodes %>%
    mutate(age.grp = as.character(age.grp)) %>%
    group_by(age.grp) %>%
    summarise(count = sum(count)) %>% mutate(times = 3) %>%
    uncount(times, .id = "ptype")

  if (type == "race") {
    race_targs <- NULL
    for (nw_idx in 1:3) {
      stats <- sim$nwparam[[nw_idx]]$target.stats
      names(stats) <- sim$nwparam[[nw_idx]]$target.stats.names
      edges_term <- stats[1]
      race_targ <- stats[grep("nodefactor.race", names(stats))]
      # The level missing from the nodefactor terms is labelled "O" and gets
      # the remaining ends: 2 * edges minus the listed levels.
      race_names <- c(gsub("nodefactor.race.",
                           "", names(race_targ)), "O")
      num_edges <- bind_cols(
        race = race_names,
        edges = c(race_targ,
                  2 * edges_term - sum(race_targ)),
        ptype = nw_idx
      )
      race_targs <- bind_rows(race_targs, num_edges)
    }

    race_counts <- left_join(race_targs, race_tots,
                             by = c("ptype", "race")) %>%
      mutate("mean_deg" = edges / count)

    race_ec <- edge_counts %>% group_by(race, ptype, type, at) %>%
      summarise(count = sum(count))
    cmpr_race_edges <- pivot_wider(race_ec,
                                   values_from = "count", names_from = "type",
                                   id_cols = colnames(race_ec)
    ) %>%
      mutate(
        "mean_deg" = partner / total,
        "year" = step_to_year(at)
      )
    cmpr_race_edges$ptype <- factor(cmpr_race_edges$ptype, levels = 1:3,
                                    labels = ptype_labels)
    race_counts$ptype <- factor(race_counts$ptype, levels = 1:3,
                                labels = ptype_labels)
    ggplot(cmpr_race_edges %>% filter(year != min(year)),
           aes(x = year, y = mean_deg, color = ptype)) +
      geom_point(alpha = 0.5, size = 2) + geom_smooth(size = 2) +
      geom_hline(data = race_counts, size = 2,
                 aes(yintercept = mean_deg, color = as.factor(ptype))) +
      facet_wrap(~race) + theme_minimal() + coord_cartesian(ylim = c(0, NA))
  } else {
    age.grp_targs <- NULL
    for (nw_idx in 1:3) {
      stats <- sim$nwparam[[nw_idx]]$target.stats
      names(stats) <- sim$nwparam[[nw_idx]]$target.stats.names
      edges_term <- stats[1]
      age.grp_targ <- stats[grep("nodefactor.age.grp", names(stats))]
      # The level missing from the nodefactor terms is labelled "1".
      num_edges <- bind_cols(
        age.grp = c(gsub("nodefactor.age.grp.", "", names(age.grp_targ)), "1"),
        edges = c(age.grp_targ, 2 * edges_term - sum(age.grp_targ)),
        ptype = nw_idx
      )
      age.grp_targs <- bind_rows(age.grp_targs, num_edges)
    }

    age.grp_counts <- left_join(age.grp_targs, age.grp_tots,
                                by = c("ptype", "age.grp")) %>%
      mutate("mean_deg" = edges / count)

    age.grp_ec <- edge_counts %>% group_by(age.grp, ptype, type, at) %>%
      summarise(count = sum(count))
    # `year` is needed by the filter below; it was previously never created,
    # so this branch always failed.
    cmpr_age.grp_edges <- pivot_wider(age.grp_ec,
                                      values_from = "count", names_from = "type",
                                      id_cols = colnames(age.grp_ec)
    ) %>%
      mutate(
        "mean_deg" = partner / total,
        "year" = step_to_year(at)
      )

    cmpr_age.grp_edges$ptype <- factor(
      cmpr_age.grp_edges$ptype, levels = 1:3,
      labels = ptype_labels)
    age.grp_counts$ptype <- factor(age.grp_counts$ptype, levels = 1:3,
                                   labels = ptype_labels)
    ggplot(cmpr_age.grp_edges %>% filter(year != min(year)),
           aes(x = at, y = mean_deg, color = as.factor(ptype))) +
      geom_point(alpha = 0.5) + geom_smooth() +
      geom_hline(data = age.grp_counts, size = 2,
                 aes(yintercept = mean_deg, color = as.factor(ptype))) +
      facet_wrap(~age.grp)
  }
}

make_split_num_plot <- function(sim, race_spec = FALSE, ofm_proj = NULL) {
  # Plot the number of active (< 65) and inactive (65+) individuals over
  # time, optionally faceted by race. With race_spec = TRUE and an
  # `ofm_proj` data frame (columns year, prop, race, type, Age), plot each
  # Age x race cell's share of the population against the OFM projection
  # instead. The vertical line marks the last scheduled entry year.
  to_year <- function(at) {
    round((at - sim$control$start) / 52 + sim$control$year_start)
  }
  last_entry_year <- max(sim$param$demog_match_arrival_df$entry_year)
  age_cnts <- sim$epi[[1]]$age_dist %>%
    mutate(Age = c("Active (< 65)",
                   "Inactive (>65)")[2 - as.numeric(age < 65)])

  if (!race_spec) {
    age_per_year <- age_cnts %>% group_by(at, Age) %>%
      summarise(count = sum(count)) %>%
      mutate(year = to_year(at))
    return(
      age_per_year %>% ggplot(aes(x = year, y = count, color = Age)) +
        geom_line(size = 2) +
        geom_vline(xintercept = last_entry_year)
    )
  }

  age_per_year <- age_cnts %>% group_by(at, Age, race) %>%
    summarise(count = sum(count)) %>%
    mutate(year = to_year(at))

  if (!is.null(ofm_proj)) {
    # Shares are taken within each time step (snapshot), so they sum to 1
    # per snapshot like the per-year OFM proportions. Normalising over the
    # whole calendar year would divide by every snapshot in that year.
    sim_props <- age_per_year %>%
      filter(year > 2010) %>%
      group_by(at) %>%
      mutate(prop = count / sum(count), type = "Simulated") %>%
      ungroup() %>%
      select(year, prop, race, type, Age)
    fin_plot <- bind_rows(sim_props, ofm_proj) %>%
      ggplot(aes(x = year, y = 100 * prop,
                 linetype = type, color = Age)) +
      geom_line(size = 2) +
      geom_vline(xintercept = last_entry_year) +
      facet_wrap(~ race, scales = "free_y") +
      ylab("Percent") + xlab("Year") +
      theme(legend.position = "bottom",
            legend.key.width = unit(3, "line"))
    return(fin_plot)
  }

  age_per_year %>%
    ggplot(aes(x = year, y = count, color = Age)) +
    geom_line(size = 2) +
    geom_vline(xintercept = last_entry_year) +
    facet_wrap(~ race) + theme(legend.position = "bottom")
}

library(ggridges)
# Years of interest at which the age distributions are compared.
yoi <- c(1943, 2011, 2016, 2019, 2020, 2025, 2030)
# NOTE(review): presumably defines expand_age_dist(), make_ofm_ll_dat() and
# expnd_df(), which are used below but not defined in this file.
source("make_line_level_data.R")

# Simulated age distribution at the years of interest.
sim_age_distr <- expand_age_dist(
  sim$epi[[1]]$age_dist,
  years_wanted = yoi,
  year_start = sim$control$year_start
)

# OFM projections at the same years.
ofm_data <- make_ofm_ll_dat(yoi)

# 2019 MSM targets by 10-year age group and race, ages under 65 only.
# `value` scales the joint proportion to a population of 10,000.
pop_props <- WApopdata::msm_all_age10_race_region_2019 %>%
  group_by(age.grp = age.grp10, race) %>%
  filter(!age.grp %in% c("65-74", "75-84", "85+")) %>%
  summarise(prop = sum(joint.pct), value = prop * 10000)
# Labels are "NN-MM": characters 1-2 and 4-5 are the two bounds.
pop_props$min_age <- substr(pop_props$age.grp, 1, 2)
pop_props$max_age <- substr(pop_props$age.grp, 4, 5)

# WA joint age/race distribution as within-race proportions.
targ_dist <- WApopdata::wa_joint_age_race %>%
  tidyr::pivot_longer(cols = c("B", "H", "O"), names_to = "race") %>%
  group_by(race) %>%
  mutate(prop = value / sum(value))

# Rewrite age labels: drop "age", then "to" -> "-", then "plus" -> "+".
targ_dist$age.grp <- gsub("plus", "+",
                          gsub("to", "-",
                               gsub("age", "", targ_dist$age.grp)))
targ_dist$min_age <- sapply(str_split(targ_dist$age.grp, "-"), "[", 1)
targ_dist$max_age <- sapply(str_split(targ_dist$age.grp, "-"), "[", 2)
# Keep the 5-year groups 15-19 through 60-64.
targ_dist <- targ_dist %>%
  filter(age.grp %in% paste(seq(15, 60, by = 5), seq(19, 64, by = 5),
                            sep = "-"))

# ACS targets repeated for the first five years of interest.
targ_dst <- bind_rows(
  lapply(yoi[1:5], function(yr) bind_cols(pop_props, "year" = yr))
)

acs_dat <- expnd_df(targ_dst)
# Add a pooled copy of every row labelled race "all".
acs_dat <- rbind(acs_dat, acs_dat %>% mutate(race = "all"))

all_entry <- bind_rows(
  # bind_cols(ent_15yr_distr, "type" = "all_15"),
  bind_cols(sim_age_distr, "type" = "Simulated"),
  bind_cols(ofm_data, type = "OFM_Proj"),
  bind_cols(acs_dat, type = "ACS_Target")
)

make_ridge_plot <- function(data_fr){
  # Ridge (density) plot of age (restricted to age < 64) for each year of
  # interest, coloured and filled by data source (`type`), one panel per race.
  # NOTE(review): suppressMessages() only wraps construction of the ggplot
  # object; messages emitted later, when the plot is rendered, still appear.
  suppressMessages({
    decade_lines <- geom_vline(xintercept = seq(from = 15, to = 85, by = 10),
                               alpha = 0.1)
    age_axis <- scale_x_continuous(
      minor_breaks = seq(from = 15, to = 90, by = 5),
      breaks = seq(from = 15, to = 85, by = 5)
    )
    no_grid <- theme(panel.grid.major = element_blank(),
                     panel.grid.minor = element_blank())
    ggplot(data_fr %>% filter(age < 64),
           aes(x = age, y = as.factor(year), color = type, fill = type)) +
      decade_lines +
      # Histogram-style alternative:
      # geom_density_ridges(alpha = 0.1, scale = 1,
      #                     bins = length(unique(all_entry$age)),
      #                     stat = "binline") +
      geom_density_ridges(alpha = 0.2, scale = 1, size = 1) +
      theme_ridges() +
      theme(legend.position = "bottom") +
      age_axis +
      ylab("Year") + xlab("Age") +
      no_grid +
      facet_wrap(~race, scales = "free_y")
  })
}

1 Introduction

This report describes how we produce the target demographic profile in the simulated population.

  • The targets are based on estimates of the WA State male population by race, age and region from the American Community Survey (5-year estimates, 2011–2015; see footnote 1). The regional distribution of MSM is based on estimates of the fraction of men who are MSM by county in Grey et al., 2016.
  • The goal is to reproduce this as a stable distribution in the simulated population, as much as possible.

  • The methods are implemented in the arrival module of EpiModel.

The primary challenge is to reproduce the target age distribution (and the specific age distributions by race). A uniform age distribution is easy to reproduce with constant inflow from the lower age boundary (15). As the plot below shows, these distributions are not uniform, and there are substantial differences in the age profiles by race. So we need to take another approach.

TODO: Add a plot of the target age distribution by race here.

Reproducing the non-uniform target age distribution can be achieved in one of two ways:

  1. In-migration into some age groups and out-migration from others at each timestep, in such a way that their relative frequency is preserved over time.

In-migration is driving much of the growth in WA state’s population, based on OFM analyses. This likely explains the increase in relative frequency in the 20-40 age range. But it is not as clear what is driving the subsequent declining frequency of older groups – it could be outmigration, or temporal variation in the rates of in-migration, or a combination of both.
Implementing migration in the simulation would require a large number of additional assumptions about the demographic attributes and HIV status of in- and out-migrants. Empirical estimates of these migration-specific attributes are not readily available, and developing them is beyond the scope of this project, so this approach is not feasible.

  2. Limiting entry to age 15, and back-calculating the size of the incoming cohort to produce the target age profile at the start year of the simulation.

This is the approach we take. It uses temporal cycling of the entering cohort size to achieve the observed age profile by race in the target start year. Note that this does not allow us to preserve the exact profile over time after the target start year, but the short simulation time frame (10 years) means that our population does not vary much from the initial target.

1.1 Targets and data sources

There are two targets that the arrival module tries to match.

1.1.1 Start Year Demographics: 2019

The first goal is to have a demographically representative MSM population at the start of the simulation. Demographically representative in this context means that the simulated network has:

  • The target number of sexually active individuals (those younger than 65). In our simulation, individuals who reach 65 are retained in the population; their existing partnerships continue until they dissolve, but no new partnerships are formed in this group.

  • Joint age/race distributions that match the WA State ACS (American Community Survey) joint age/race distribution for the target start year.

  • Regional distribution (Western Washington, King Co, Eastern Washington) based on the ACS estimates of male population size in each region adjusted by the Grey et al. estimates of the proportion of men in each region that are MSM. More information about the MSM estimates can be found in this paper, at the CAMP Website, and data can be downloaded here.

1.1.2 Demographics over time

The second goal is to have the entering cohort after the start year match the OFM projected race proportions while maintaining a constant population size of sexually active (younger than 65 years) individuals.

2 Methods

2.1 Hitting Start Year Demographics

In order to hit the target year demographics, a back calculation is carried out to determine the number of individuals who will enter the population at each time step. All individuals who enter the population do so at the age of 15.

The back calculation is carried out for each age/race category. For a single category, say 55-year-old Hispanics, the target number of individuals is calculated by multiplying the proportion of the population in this category by the size of the network. Next, the age- and race-specific mortality rates are used to calculate the proportion of individuals entering at age 15 who will still be alive when they reach their target age (here, 55).

With this information in hand, the target count divided by the survival proportion gives the number of individuals to enter at age 15. That number is entered into the population in the year that makes them the correct age in the target year. Because this method enters individuals only at age 15, and some individuals in our target population are in their 60s, the simulation must start many years before the target year.

As a simple example, say that the number of 55-year-old Hispanics in our target population is 2000. The proportion of Hispanics alive at age 15 who are still alive at age 55 is calculated from the age-specific mortality rates (asmr), using the table below:

# knitr::kable(WApopdata::asmr_by_race[15:55, c("age", "vec.asmr.H")])
# Survival from age 15 for Hispanics: asmr is a weekly probability of death,
# so each year of age contributes (1 - asmr)^52, and cumprod() chains years.
prop_alive <- WApopdata::asmr_by_race[15:55, c("age", "vec.asmr.H")]
prop_alive$prop_alive <- cumprod((1 - prop_alive$vec.asmr.H)^52)
knitr::kable(utils::head(prop_alive, n = 6))
age vec.asmr.H prop_alive
15 1.26e-05 0.9993450
16 1.46e-05 0.9985851
17 1.67e-05 0.9977205
18 1.87e-05 0.9967514
19 2.07e-05 0.9956783
20 2.27e-05 0.9945014

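We can calculate the proportion alive at each age by taking the cumulative product of the annual survival probabilities, i.e. \(S(a) = \prod_{k=15}^{a} \left(1 - \text{Pr}_k(\text{death})\right)^{52}\), where \(\text{Pr}_k(\text{death})\) is the probability of death at age \(k\). The power of 52 is used because asmr gives the probability of death for each week.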

# Entrants needed at age 15 so that 2000 remain: divide by the survival
# proportion in the last row of the table.
num_enter <- 2000 / tail(prop_alive$prop_alive, n = 1)
num_enter
## [1] 2199.84

Thus, 2200 individuals will be entered into the population at the age of 15 so that by the time this group has reached the age of 55, there will be 2000 of them. If the target year is 2020, these individuals will enter the population in the year \(2020 - (55 - 15) = 1980\) so they are the correct age during the target year.

2.1.1 Post Target year Entries

After the target year, the number of individuals entering the population is set so as to maintain a constant number of individuals younger than 65. This is done by calculating the difference between the initial under-65 population and the current under-65 population; the number of individuals who enter at each time step is equal to this difference. The proportion of entering individuals assigned to each race is determined by the OFM projected proportions for the given year in the 15–19 age group.

2.2 Simulations

2.2.1 Overall Number of individuals

Looking at the number of individuals in our population at any given time, we see that the population steadily grows, then decreases before leveling out around the target year. This is reasonable given what our population targets are and what our starting population is. Our beginning population consists only of 15- to 65-year-olds, but we allow individuals to survive until age 90 in our simulation. Also, our target (at least after the target year) is a constant 15–65 population, so it is reasonable that the total population increases as the 65–90 age range fills in.

make_num_plot(sim = sim)  # total network size over time

2.2.2 Number of individuals broken down by Active versus Inactive

Breaking the population out into those older and younger than 65, we see that the active-age population (15–65) is relatively constant before the target year (marked by the black vertical line) and nearly flat after 2019. Before the target year, we allow the population to fluctuate somewhat so that we hit our demographic targets in the target year. After the target year, the simulation aims to maintain a constant active-age population.

make_split_num_plot(sim = sim)  # active vs inactive counts

2.2.3 Number of individuals broken down by Active versus Inactive and race

Looking at the numbers broken out by race, we see that the number of individuals of each race stays relatively constant until the target year, at which point the number of Hispanics increases and the number of others decreases. This is in line with the OFM projections that the proportion of the population that is Hispanic will increase.

# Active/inactive population sizes over time, further broken out by race.
make_split_num_plot(sim, race_spec = TRUE)

2.2.4 Comparison with the OFM projections

Next, we compare the race distribution trajectories from our simulations with the OFM projections of the racial distribution. Because the OFM gives projections for the general (male) population of the state rather than for MSM, there are some differences between the overall racial distributions. Most notably, the OFM has a much smaller proportion of Black individuals than our targets (which are based on census data).

# OFM population projections restricted to males. Rows with hisp == "all" and
# the "Total" age group are dropped; presumably these are aggregate rows that
# would otherwise be double counted (NOTE(review): confirm against the
# WApopdata documentation).
ofm_proj <- WApopdata::wa_age_race_sex_hisp_proj %>%
  filter(hisp != "all", sex == "Male", age.grp != "Total")

# Collapse to the model's race categories: Hispanic ("H"), non-Hispanic
# Black ("B"), and all others ("O").
ofm_proj$racecat <- ifelse(ofm_proj$hisp == "Hispanic", "H",
                           ifelse(ofm_proj$race == "Black", "B", "O"))

# Numeric lower bound of each age group, e.g. "15-19" -> 15, "85+" -> 85;
# labels without a leading number become NA. Parsing the number (rather than
# taking the label's first character) matters: "10-14" and "15-19" share the
# first character "1", so a first-character rule would count 10-14 year olds
# (if present) as active, and "60-64"/"65-69" would both read as "6".
ofm_proj$age_lo <- as.numeric(sub("\\D.*$", "", ofm_proj$age.grp))

# Keep only ages represented in the simulation (15 and older; NA rows are
# dropped by filter()). Ages 15-64 are "active", 65 and older "inactive".
ofm_df <- ofm_proj %>% filter(age_lo >= 15)
ofm_df$act <- as.numeric(ofm_df$age_lo < 65)
ofm_df$Age <- c("Active (< 65)", "Inactive (>65)")[2 - ofm_df$act]

# Counts per age class x race x year, converted to each cell's share of the
# total population in that year. The result is left grouped by year.
ofm_df <- ofm_df %>%
  group_by(Age, race = racecat, year) %>%
  summarise(n = sum(value)) %>%
  ungroup() %>%
  group_by(year) %>%
  mutate(prop = n / sum(n), type = "OFM")

make_split_num_plot(sim, race_spec = TRUE, ofm_df)

2.2.5 Comparison with regularized OFM projections

In this plot, we have re-weighted the OFM projections so that their race distribution more closely matches that of our targets. This makes it easier to compare trends.

# Re-weight the OFM projections so that their 2019 race distribution matches
# the WApopdata MSM (ages 15-65) targets. The per-race correction factor is
# applied to every year and age class, and proportions are then renormalised
# within each year.
# NOTE(review): the OFM shares below are summed over both the active and the
# inactive (65+) classes, while the targets cover ages 15-65 only -- confirm
# this is intended.

# OFM race shares in 2019, summed over all age classes.
ofm_race_dst <- ofm_df %>%
  filter(year == 2019) %>%
  group_by(race) %>%
  summarise(prop = sum(prop))

# Target race shares from the 2019 MSM population table.
adj_race_dst <- WApopdata::msm_15_65_age10_race_region_2019 %>%
  group_by(race) %>%
  summarise(num = sum(num)) %>%
  ungroup() %>%
  mutate(c_prop = num / sum(num))

# Per-race correction factor: target share / OFM share.
bth_rd <- left_join(ofm_race_dst, adj_race_dst, by = "race") %>%
  mutate(cor = c_prop / prop) %>%
  select(race, cor)

# Store the re-weighted table under a new name instead of overwriting
# ofm_df. Overwriting made this chunk non-rerunnable: a second run would
# re-weight an already re-weighted table, and joining bth_rd again would
# yield cor.x/cor.y columns so that `prop * cor` no longer finds the factor.
ofm_df_adj <- left_join(ofm_df, bth_rd, by = "race") %>%
  mutate(prop = prop * cor) %>%
  group_by(year) %>%
  mutate(prop = prop / sum(prop)) %>%
  ungroup()

make_split_num_plot(sim, race_spec = TRUE, ofm_df_adj)

2.2.6 Age race distributions at certain years

Lastly, looking at how our simulated demographics line up with our targets, we see that we hit our ACS-based targets fairly well in 2019. The 65-90 age segment of the population is excluded from these plots so that it is easier to see how well we are hitting our targets (which only include 15-65 year olds).

# Simulated age-by-race distributions in selected years versus the targets.
# NOTE(review): all_entry is not created in this section -- presumably it is
# built earlier in the document; confirm before reordering chunks.
make_ridge_plot(all_entry)

2.3 Exiting rates

Here we plot the exit rates for each group and compare them to the target, which in this case is the ASMR. The plot uses a log scale, so zero values are replaced with a very small value (0.001).

# Observed exit rates by group compared with the ASMR targets.
make_exit_plot(sim = sim)
## `geom_smooth()` using formula 'y ~ x'
## Warning: Removed 3 rows containing non-finite values (stat_smooth).
## Warning: Removed 3 rows containing missing values (geom_point).
## `geom_smooth()` using formula 'y ~ x'

2.4 Mean Degree plots

2.4.1 Within race group

# Mean degree within each race group; netest_obj holds the main, casual and
# instantaneous network fits assembled at the top of the document.
make_mean_deg_plot(sim,
                   nest_obj = netest_obj,
                   var_type = "race")

2.4.2 Within Region

# Mean degree within each region. The network fits are passed by name
# (nest_obj =), as in the race-group call, so the binding does not depend on
# where nest_obj sits in make_mean_deg_plot's argument list.
make_mean_deg_plot(sim, var_type = "region", nest_obj = netest_obj)

2.5 Mean Degree across Age

2.5.1 Within race group

# Mean degree by age within each race group, at selected calendar years.
make_mean_agextime(sim, "race",
                   nest_obj = netest_obj,
                   yrs = c(2005, 2015, 2019, 2025, 2030))

2.5.2 Within Region

# Mean degree by age within each region, at selected calendar years.
make_mean_agextime(sim, "region",
                   nest_obj = netest_obj,
                   yrs = c(2005, 2015, 2019, 2025, 2030))

2.6 Session Info

# R version, platform, and package versions used to render this report.
utils::sessionInfo()
## R version 4.0.2 (2020-06-22)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS  10.16
## 
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] ggridges_0.5.2  forcats_0.5.0   stringr_1.4.0   dplyr_1.0.2    
##  [5] purrr_0.3.4     readr_1.3.1     tidyr_1.1.2     tibble_3.0.4   
##  [9] ggplot2_3.3.2   tidyverse_1.3.0
## 
## loaded via a namespace (and not attached):
##  [1] nlme_3.1-148             fs_1.4.2                 lubridate_1.7.9         
##  [4] doParallel_1.0.15        RColorBrewer_1.1-2       httr_1.4.2              
##  [7] rprojroot_1.3-2          tools_4.0.2              backports_1.1.10        
## [10] R6_2.5.0                 mgcv_1.8-31              DBI_1.1.0               
## [13] lazyeval_0.2.2           colorspace_1.4-1         withr_2.3.0             
## [16] tidyselect_1.1.0         compiler_4.0.2           ergm.multi_4.0-3965     
## [19] cli_2.2.0                rvest_0.3.6              xml2_1.3.2              
## [22] network_1.17.0-585       labeling_0.3             scales_1.1.1            
## [25] DEoptimR_1.0-8           robustbase_0.93-6        digest_0.6.27           
## [28] ergm.ego_0.5-471         rmarkdown_2.4            networkDynamic_0.10.2   
## [31] pkgconfig_2.0.3          htmltools_0.5.0          egonet_0.0.1.0          
## [34] highr_0.8                dbplyr_1.4.4             rlang_0.4.9             
## [37] readxl_1.3.1             rstudioapi_0.11          farver_2.0.3            
## [40] generics_0.1.0           jsonlite_1.7.1           statnet.common_4.4.0-300
## [43] magrittr_2.0.1           Matrix_1.2-18            Rcpp_1.0.5              
## [46] munsell_0.5.0            fansi_0.4.1              ape_5.4-1               
## [49] lifecycle_0.2.0          stringi_1.5.3            yaml_2.2.1              
## [52] MASS_7.3-51.6            plyr_1.8.6               grid_4.0.2              
## [55] blob_1.2.1               parallel_4.0.2           crayon_1.3.4            
## [58] lattice_0.20-41          cowplot_1.0.0            haven_2.3.1             
## [61] splines_4.0.2            hms_0.5.3                sna_2.5                 
## [64] knitr_1.30               pillar_1.4.7             EpiModel_2.0.3          
## [67] codetools_0.2-16         lpSolve_5.6.15           rle_0.9.2-223           
## [70] WApopdata_1.0.8          srvyr_0.3.10             reprex_0.3.0            
## [73] glue_1.4.2               evaluate_0.14            trust_0.1-8             
## [76] mitools_2.4              modelr_0.1.8             deSolve_1.28            
## [79] vctrs_0.3.6              foreach_1.5.0            cellranger_1.1.0        
## [82] gtable_0.3.0             assertthat_0.2.1         tergmLite_2.2.0         
## [85] tergm_4.0.0-2013         xfun_0.18                broom_0.7.1             
## [88] survey_4.0               ergm_4.0-5824            coda_0.19-4             
## [91] survival_3.2-7           iterators_1.0.12         ellipsis_0.3.1          
## [94] here_0.1

  1. We are currently working to update the ACS data, but the
    original ACS website (“American FactFinder”) has been retired, and the replacement website does not provide the tables we need to update to the 2019 5-year estimates.↩︎