# ------------------------------------------------------------
# Plot 1: Higher-income households spend more per shopping trip
# Uses: transactions_sample + demographics (inner join on household_id)
# Aggregation: basket-level totals, then income-segment averages
# ------------------------------------------------------------

# Plot 1 data: one row per income segment holding the mean basket sales and
# mean basket quantity over that segment's baskets. Only transactions whose
# household has a non-missing income in `demographics` contribute; segments
# with fewer than 50 baskets are dropped, and `income` is re-leveled so the
# segments run from "Under 15K" up to "250K+".
basket_df <- local({
  income_order <- c(
    "Under 15K", "15-24K", "25-34K", "35-49K",
    "50-74K", "75-99K", "100-124K",
    "125-149K", "150-174K", "175-199K",
    "200-249K", "250K+"
  )

  # Step 1: collapse line items into one total per basket.
  basket_totals <- transactions_sample %>%
    inner_join(demographics, by = "household_id") %>%
    filter(!is.na(income)) %>%
    group_by(income, basket_id) %>%
    summarise(
      basket_sales = sum(sales_value, na.rm = TRUE),
      basket_quantity = sum(quantity, na.rm = TRUE),
      .groups = "drop"
    )

  # Step 2: average the basket totals within each income segment.
  segment_means <- basket_totals %>%
    group_by(income) %>%
    summarise(
      avg_sales = mean(basket_sales),
      avg_quantity = mean(basket_quantity),
      baskets = n(),
      .groups = "drop"
    )

  segment_means %>%
    filter(baskets >= 50) %>%
    mutate(income = fct_relevel(income, income_order))
})

# Rows for the lowest and highest income segments still in basket_df; income
# is a factor, so min/max follow its level order rather than alphabetical.
low <- slice_min(basket_df, order_by = income, n = 1)
high <- slice_max(basket_df, order_by = income, n = 1)

# Gap in average basket sales from lowest to highest segment: absolute dollars
# (2 dp) and relative lift in percent (1 dp).
difference <- round(high$avg_sales - low$avg_sales, digits = 2)
percent_lift <- round(
  (high$avg_sales - low$avg_sales) / low$avg_sales * 100,
  digits = 1
)

# Plot 1 chart: bars show average basket sales per income segment; the line
# and points (group = 1 joins every category into a single path) trace the
# trend across segments in factor order. The subtitle reports the
# lowest-to-highest segment gap held in `difference` / `percent_lift`.
ggplot(basket_df, aes(x = income, y = avg_sales, group = 1)) +
  geom_col(fill = "steelblue") +
  geom_line(color = "black", linewidth = 1) +
  geom_point(size = 3) +
  scale_y_continuous(labels = dollar_format()) +
  labs(
    title = "Higher income households spend more per shopping trip",
    # dollar()/percent() with a fixed accuracy keep trailing zeros ("$1.50",
    # not "$1.5") and place a negative sign before the "$" ("-$1.50", not
    # "$-1.5"). percent_lift is already in percent units, hence the / 100.
    subtitle = paste0(
      "Average basket sales increase by ",
      dollar(difference, accuracy = 0.01),
      " (", percent(percent_lift / 100, accuracy = 0.1),
      ") from lowest to highest income segment"
    ),
    x = "Household income segment",
    y = "Average basket sales ($)",
    # Disclose the basket-count filter applied when building basket_df.
    caption = "Source: completejourney (transactions_sample + demographics). Basket-level sales averaged per income segment; segments with ≥ 50 baskets."
  ) +
  theme_minimal(base_size = 12) +
  theme(
    plot.title.position = "plot",
    plot.caption = element_text(hjust = 0.5),
    axis.text.x = element_text(angle = 25, hjust = 1)
  )


# ------------------------------------------------------------
# Plot 2: Average spend per distinct product in each basket, by income group
# ------------------------------------------------------------

# Plot 2 data: one row per basket with its total sales, number of distinct
# products, and spend per distinct product, tagged with the household's income
# and a coarser five-band `income_collapsed` factor. Baskets whose income does
# not map to one of the bands are dropped.
basket_df <- local({
  # Fine-grained income label -> collapsed band.
  band_lookup <- c(
    "Under 15K" = "Under 35K", "15-24K" = "Under 35K", "25-34K" = "Under 35K",
    "35-49K" = "35-74K", "50-74K" = "35-74K",
    "75-99K" = "75-124K", "100-124K" = "75-124K",
    "125-149K" = "125-174K", "150-174K" = "125-174K",
    "175-199K" = "175K+", "200-249K" = "175K+", "250K+" = "175K+"
  )
  band_order <- c("Under 35K", "35-74K", "75-124K", "125-174K", "175K+")

  per_basket <- transactions_sample %>%
    inner_join(demographics, by = "household_id") %>%
    filter(!is.na(income)) %>%
    group_by(basket_id, income) %>%
    summarise(
      total_sales = sum(sales_value, na.rm = TRUE),
      distinct_products = n_distinct(product_id),
      spend_per_product = total_sales / distinct_products,
      .groups = "drop"
    )

  # Named-vector lookup: incomes with no entry come back as NA and are
  # removed by the final filter.
  per_basket %>%
    mutate(
      income_collapsed = factor(
        unname(band_lookup[as.character(income)]),
        levels = band_order
      )
    ) %>%
    filter(!is.na(income_collapsed))
})

# Plot 2 summary: per income band, the mean and SD of spend per distinct
# product and the basket count. Bands with fewer than 50 baskets are dropped,
# then a normal-approximation 95% CI (mean +/- 1.96 * SE) is attached.
# `label_y` places the n= label slightly above the upper CI.
plot2_df <- basket_df %>%
  group_by(income_collapsed) %>%
  summarise(
    avg_spend = mean(spend_per_product, na.rm = TRUE),
    sd_spend  = sd(spend_per_product, na.rm = TRUE),
    baskets   = n(),
    .groups = "drop"
  ) %>%
  filter(baskets >= 50) %>%
  mutate(
    se       = sd_spend / sqrt(baskets),
    ci_lower = avg_spend - 1.96 * se,
    ci_upper = avg_spend + 1.96 * se,
    label_y  = ci_upper + 0.06
  )

# Subtitle numbers for Plot 2: change in average spend per distinct product
# from the LOWEST to the HIGHEST income band still present in plot2_df
# (income_collapsed factor order). Using max(avg) - min(avg) instead would
# report the range of the averages, which comes from middle bands whenever the
# trend is not monotonic and would misstate the "from lowest to highest income
# group" wording in the subtitle.
spend_by_income <- plot2_df$avg_spend[order(plot2_df$income_collapsed)]
low_val <- spend_by_income[1]
high_val <- spend_by_income[length(spend_by_income)]
diff_val <- high_val - low_val
pct_change <- diff_val / low_val

# Upper y bound: 5% headroom above the highest n= label position.
y_top <- max(plot2_df$label_y) * 1.05

# Plot 2 chart: bars = mean spend per distinct product per income band,
# error bars = 95% CI, text = basket count drawn at label_y above each bar.
ggplot(plot2_df, aes(x = income_collapsed, y = avg_spend)) +
  geom_col(fill = "#4C78A8", width = 0.7) +
  geom_errorbar(aes(ymin = ci_lower, ymax = ci_upper), width = 0.15, linewidth = 0.7) +
  geom_text(aes(y = label_y, label = paste0("n=", comma(baskets))),
            hjust = 0.5, vjust = 0, size = 3.5) +
  scale_y_continuous(labels = dollar_format()) +
  # Zoom via the coordinate system rather than scale limits: scale limits
  # turn any value outside the range into NA and drop it (e.g. an error bar
  # whose lower CI falls below zero), whereas coord_cartesian() only crops.
  coord_cartesian(ylim = c(0, y_top)) +
  labs(
    title = "Higher income households spend more per product in each basket",
    # dollar() keeps trailing zeros and puts a negative sign before the "$".
    subtitle = paste0(
      "Average spend per distinct product increases by ",
      dollar(diff_val, accuracy = 0.01),
      " (", percent(pct_change, accuracy = 0.1), ") from lowest to highest income group\n",
      "Error bars represent 95% confidence intervals"
    ),
    x = "Household income segment (collapsed)",
    y = "Average spend per distinct product ($)",
    caption = "Source: completejourney (transactions_sample + demographics). Basket-level metrics; income groups with ≥ 50 baskets."
  ) +
  theme_minimal(base_size = 12) +
  theme(
    plot.title.position = "plot",
    plot.caption = element_text(hjust = 0.5),
    axis.title.x = element_text(hjust = 0.5),
    axis.title.y = element_text(hjust = 0.5)
  )


# ------------------------------------------------------------
# Plot 3: Higher-income households buy a wider variety per trip
# ------------------------------------------------------------

# Plot 3 data: one row per basket with its count of distinct products, tagged
# with the household's income and the five-band `income_collapsed` factor.
# Baskets whose income does not map to one of the bands are dropped.
basket_variety <- local({
  # Fine-grained income label -> collapsed band.
  band_lookup <- c(
    "Under 15K" = "Under 35K", "15-24K" = "Under 35K", "25-34K" = "Under 35K",
    "35-49K" = "35-74K", "50-74K" = "35-74K",
    "75-99K" = "75-124K", "100-124K" = "75-124K",
    "125-149K" = "125-174K", "150-174K" = "125-174K",
    "175-199K" = "175K+", "200-249K" = "175K+", "250K+" = "175K+"
  )
  band_order <- c("Under 35K", "35-74K", "75-124K", "125-174K", "175K+")

  per_basket <- transactions_sample %>%
    inner_join(demographics, by = "household_id") %>%
    filter(!is.na(income)) %>%
    group_by(basket_id, income) %>%
    summarise(
      distinct_products = n_distinct(product_id),
      .groups = "drop"
    )

  # Named-vector lookup: incomes with no entry come back as NA and are
  # removed by the final filter.
  per_basket %>%
    mutate(
      income_collapsed = factor(
        unname(band_lookup[as.character(income)]),
        levels = band_order
      )
    ) %>%
    filter(!is.na(income_collapsed))
})

# Plot 3 summary: mean number of distinct products per basket and basket count
# for each income band, keeping only bands with at least 50 baskets.
plot3_df <- local({
  band_summary <- basket_variety %>%
    group_by(income_collapsed) %>%
    summarise(
      avg_variety = mean(distinct_products, na.rm = TRUE),
      baskets = n(),
      .groups = "drop"
    )

  filter(band_summary, baskets >= 50)
})

# Subtitle numbers for Plot 3: change in average distinct products per basket
# from the LOWEST to the HIGHEST income band still present in plot3_df
# (income_collapsed factor order). max(avg) - min(avg) would instead report
# the range of the averages, which can come from middle bands and misstate the
# "from lowest to highest income group" wording in the subtitle.
variety_by_income <- plot3_df$avg_variety[order(plot3_df$income_collapsed)]
low_val <- variety_by_income[1]
high_val <- variety_by_income[length(variety_by_income)]
diff_val <- high_val - low_val
pct_change <- diff_val / low_val

# Plot 3 chart: one line plus points (group = 1 joins the categories into a
# single path) tracing mean distinct products per basket across income bands,
# with the change held in diff_val / pct_change reported in the subtitle.
# The plot object is the last value of local(), so it still auto-prints.
local({
  line_colour <- "#2F4B7C"
  variety_subtitle <- paste0(
    "Average distinct products per basket generally increase by ",
    round(diff_val, 2),
    " (", percent(pct_change, accuracy = 0.1), ") from lowest to highest income group"
  )

  ggplot(plot3_df, aes(x = income_collapsed, y = avg_variety, group = 1)) +
    geom_line(linewidth = 1.2, color = line_colour) +
    geom_point(size = 3, color = line_colour) +
    theme_minimal(base_size = 12) +
    theme(
      plot.title.position = "plot",
      plot.caption = element_text(hjust = 0.5),
      axis.title.x = element_text(hjust = 0.5),
      axis.title.y = element_text(hjust = 0.5)
    ) +
    labs(
      title = "Higher income households buy a wider variety per shopping trip",
      subtitle = variety_subtitle,
      x = "Household income segment (collapsed)",
      y = "Average distinct products per basket",
      caption = "Source: completejourney (transactions_sample + demographics). Basket breadth measured as distinct product_id count; income groups with ≥ 50 baskets."
    )
})