library(rethinking)
library(cmdstanr)

library(tidyverse)
library(tidybayes)
library(haven)

# Load every Stata (.dta) file of the NLSS 2022 release into the global
# environment, one object per file, named after the file without its
# extension (e.g. "S05.dta" -> S05). haven value labels become factors via
# as_factor(). Only .dta files are read, so stray files in the directory
# (READMEs, OS metadata, ...) are not handed to read_dta().
.data_dir <- "nlss2022/data/stata format"
for (.file in list.files(.data_dir, pattern = "\\.dta$", ignore.case = TRUE)) {
  .name <- sub("\\.dta$", "", .file, ignore.case = TRUE)
  .path <- file.path(.data_dir, .file)
  .table <- read_dta(.path) |> as_factor()
  assign(.name, .table)
}
# Household analysis table: `poverty` joined with the food_code 701 rows of
# section S05 (potatoes, judging by the derived column names -- confirm),
# plus per-capita potato measures and the standardized model variables
# P (log potato share of food spending) and S (log total spending).
df <- poverty |>
  left_join(
    S05 |> filter(food_code == 701),
    by = join_by(psu_number, hh_number)
  ) |>
  mutate(
    # NOTE(review): q05_03..q05_05 look like per-source quantities; a missing
    # source is treated as zero (the same rule potato_value uses) so one NA
    # no longer zeroes the whole household total.
    potato_consumption = (
      (replace_na(q05_03, 0) + replace_na(q05_04, 0) + replace_na(q05_05, 0)) /
        hhsize
    ) |> replace_na(0),
    potato_value = (
      replace_na(q05_03_b, 0) + replace_na(q05_04_b, 0) + replace_na(q05_05_b, 0)
    ) / hhsize,
    # -Inf when a household reports no potato spending; dropped below.
    pcep_potato = log(potato_value / pcep_food),
    spending = log(pcep)
  ) |>
  # Test the logged columns that are standardized next: pcep == 0 is finite,
  # but log(0) = -Inf would turn every value of S into NaN.
  filter(is.finite(pcep_potato) & is.finite(spending)) |>
  mutate(
    P = standardize(pcep_potato),
    S = standardize(spending)
  )
## Joining with `by = join_by(psu_number, hh_number)`
# Raw values: weighted histogram of pcep as recorded, with bars weighted by
# hhs_wt (presumably the household survey weight -- confirm). Labelled like
# the transformed histograms that follow.
df |>
  mutate(S = pcep) |>
  ggplot(aes(x = S, weight = hhs_wt)) +
  geom_histogram(bins = 20, color = "gray") +
  labs(x = "PCEP (raw)")

# Standardized values: histogram of pcep after standardize(), with bars
# weighted by hhs_wt.
ggplot(
  data = mutate(df, S = standardize(pcep)),
  mapping = aes(x = S, weight = hhs_wt)
) +
  geom_histogram(color = "gray", bins = 20) +
  labs(x = "PCEP (standardized)")

# Log-scale values: weighted histogram of log(pcep). Rows whose log is not
# finite (pcep of 0 gives -Inf; a negative or missing pcep gives NaN/NA) are
# dropped before plotting.
df |>
  mutate(S = log(pcep)) |>
  filter(is.finite(S)) |>
  ggplot(mapping = aes(x = S, weight = hhs_wt)) +
  geom_histogram(color = "gray", bins = 20) +
  labs(x = "PCEP (logscale)")

# Log-scale then standardized values: weighted histogram of
# standardize(log(pcep)). Non-finite logs are removed BEFORE standardizing.
# If they were removed afterwards, a single -Inf (pcep of 0) would drive the
# mean to -Inf, every standardized value would become NaN, and the filter
# would leave nothing to plot.
df |>
  mutate(S = log(pcep)) |>
  filter(is.finite(S)) |>
  mutate(S = standardize(S)) |>
  ggplot(aes(x = S, weight = hhs_wt)) +
  geom_histogram(bins = 20, color = "gray") +
  labs(x = "PCEP (logscale and standardized)")

# densities of the model variables P and S (log-transformed, then standardized)
# Kernel density of P, the standardized log of the potato share of food
# spending.
ggplot(df, aes(P)) +
  geom_density() +
  labs(x = "Potato consumption as a proportion of overall food consumption (standardized, logscale)")

# Kernel density of S, the standardized log of per-capita total spending.
# (Axis label typo fixed: "Per Capital" -> "Per Capita".)
df |>
  ggplot(mapping = aes(x = S)) +
  geom_density() +
  labs(
    x = "Per Capita total spending (standardized, logscale)"
  )

# Scatter of P against S, one hollow, semi-transparent point per row of df.
ggplot(data = df, mapping = aes(x = S, y = P)) +
  geom_point(alpha = 0.2, color = "blue", shape = 1) +
  labs(y = "Potato Consumption", x = "Total Spending")

# Gaussian linear regression of P on S, fit by quadratic approximation:
#   P ~ Normal(a + b * S, sigma),
# with weakly informative priors on a and b and a uniform prior on sigma.
model <- quap(
  alist(
    P ~ dnorm(mu, sigma),
    mu <- a + b * S,
    a ~ dnorm(0, 2),
    b ~ dnorm(0, 2),
    sigma ~ dunif(0, 5)
  ),
  data = df
)
# precis() in rethinking 2.x has no `corr` argument: it was absorbed by `...`
# and silently ignored (the summary printed no correlation columns). Print the
# posterior parameter correlations explicitly instead.
precis(model, digits = 4)
round(cov2cor(vcov(model)), 3)
##                mean          sd        5.5%       94.5%
## a     -0.0000178637 0.009821625 -0.01571472  0.01567899
## b     -0.3843162628 0.009822181 -0.40001400 -0.36861852
## sigma  0.9231916433 0.006945525  0.91209135  0.93429193
# Posterior mean of mu with its 95% interval, and the 95% posterior predictive
# interval for P, over an evenly spaced grid of S, overlaid on the observed
# data. link() and sim() are each called once on the whole grid (one column
# per grid point). All grid points therefore share one set of posterior
# draws, which gives coherent curves, rather than resampling the posterior
# separately for every point.
.grid <- seq(from = -3, to = 6.5, by = 0.1)
.mu_draws <- link(model, data = list(S = .grid))
.sim_draws <- sim(model, data = list(S = .grid))
tibble(
  S = .grid,
  mu.avg = .mu_draws |> apply(2, mean),
  mu.min = .mu_draws |> apply(2, quantile, 0.025),
  mu.max = .mu_draws |> apply(2, quantile, 0.975),
  sim.min = .sim_draws |> apply(2, quantile, 0.025),
  sim.max = .sim_draws |> apply(2, quantile, 0.975)
) |>
  ggplot(mapping = aes(x = S)) +
  geom_ribbon(aes(ymin = sim.min, ymax = sim.max), color = "lightgray", alpha = 0.2) +
  geom_ribbon(aes(ymin = mu.min, ymax = mu.max), color = "gray", alpha = 0.2) +
  geom_line(aes(y = mu.avg)) +
  geom_point(data = df, aes(y = P), shape = 1, color = "blue", alpha = 0.2) +
  labs(x = "Total Spending", y = "Potato Consumption")

# Posterior summaries evaluated at the observed rows of df.
# - link(): draws of the linear predictor mu, one column per row of df,
#   summarized by the mean and the 2.5%-97.5% quantiles.
# - sim(): simulated P values with the same layout, summarized by the
#   10%-90% quantiles. That is an 80% band, narrower than the 95% bands used
#   for mu. NOTE(review): confirm the 80% predictive band is intentional.
.linked <- link(model)
.simmed <- sim(model)
df |>
  mutate(
    reg.avg = apply(.linked, 2, mean),
    reg.min = apply(.linked, 2, quantile, 0.025),
    reg.max = apply(.linked, 2, quantile, 0.975),
    sim.min = apply(.simmed, 2, quantile, 0.1),
    sim.max = apply(.simmed, 2, quantile, 0.9)
  ) |>
  ggplot(aes(x = S, y = P)) +
  geom_ribbon(aes(ymin = sim.min, ymax = sim.max), color = "lightgray", alpha = 0.2) +
  geom_ribbon(aes(ymin = reg.min, ymax = reg.max), color = "gray", alpha = 0.2) +
  geom_line(aes(y = reg.avg)) +
  geom_point(shape = 1, color = "blue", alpha = 0.2) +
  labs(x = "Total Spending", y = "Potato Consumption")