library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.2 ──
## ✔ ggplot2 3.4.1      ✔ purrr   1.0.1 
## ✔ tibble  3.1.8      ✔ dplyr   1.0.10
## ✔ tidyr   1.3.0      ✔ stringr 1.5.0 
## ✔ readr   2.1.3      ✔ forcats 0.5.2 
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
library(here)
## here() starts at /Users/caoanjie/Desktop/projects/looking_time_models/03_pyGRANCH_multi/behavioral_fit
library(ggthemes)
# Simulated surprisal trajectory: surprisal is 0.5^t for time steps t = 1..30.
surprirsal_df <- tibble(
  t = 1:30,
  surprisal = 0.5^t
)

# Scatter plot of surprisal against time step.
ggplot(surprirsal_df, aes(x = t, y = surprisal)) +
  geom_point()

1 Constant world EIG - Regular Luce Choice Rule

# Candidate constant values swept below; each one is passed to
# luce_choice_rule() as `base`, the competing term in the denominator.
world_eigs <- c(0.000001, 0.001, 0.01, 0.05, 2, 5)


# Luce choice rule with a constant competing term.
#
# Args:
#   t:    time steps, copied into the `time` column.
#   s:    strength of the focal option (callers in this file pass surprisal).
#   base: competing term in the denominator, copied into the `weig` column.
#
# Returns a tibble with columns `time`, `weig` and `prob`,
# where prob = s / (s + base).
luce_choice_rule <- function(t, s, base) {
  choice_prob <- s / (s + base)
  tibble(
    "time" = t,
    "weig" = base,
    "prob" = choice_prob
  )
}

# Plot the Luce-rule probability over time for every constant world EIG.
# luce_choice_rule() is vectorised, so it is applied once to the full
# world_eigs x time grid. The earlier nest -> map -> nest() -> unnest ->
# unnest round trip produced the same rows but depended on how nest()
# behaves when called without a column specification.
crossing(world_eigs, surprirsal_df) %>%
  with(luce_choice_rule(t, surprisal, world_eigs)) %>%
  ggplot(aes(x = time, y = prob, color = as.factor(weig))) +
  geom_point() +
  theme_few()

2 changing EIG, regular luce choice rule

2.1 Linear

Here we assume the world is getting more interesting linearly over time.

# Parameter sweep for a world EIG that grows linearly with time.
intercept <- c(0.0001, 0.02, 0.04, 0.08)
slope <- c(0.005, 0.01, 0.05, 1)
t <- seq(1, 30)

# Inside mutate(), `t` resolves to the column contributed by surprirsal_df.
crossing(intercept, slope, surprirsal_df) %>%
  mutate(
    world_eig = intercept + slope * t,
    world_eig_formula = paste("y ~ ", intercept, " + ", slope, " * t"),
    prob = surprisal / (surprisal + world_eig)
  ) %>%
  ggplot(aes(x = t, y = prob, color = world_eig_formula)) +
  geom_point() +
  geom_line()

2.2 Exponential

# Sweep of bases for an exponentially shaped world EIG (a^t).
a <- c(0.000000001, 0.2, 0.4, 0.5, 0.9, 0.99)

# One curve per base: Luce probability of surprisal against a^t.
crossing(a, surprirsal_df) %>%
  mutate(
    world_eig = a^t,
    world_eig_formula = paste("y ~ ", a, " ** t"),
    prob = surprisal / (surprisal + world_eig)
  ) %>%
  ggplot(aes(x = t, y = prob, color = as.factor(a), group = a)) +
  geom_point() +
  geom_line()

2.3 Log

The world is getting more interesting, but slowly.

# Sweep of log bases for a logarithmically shaped world EIG.
# With 0 < a < 1, -log(t, base = a) is 0 at t = 1 and non-negative afterwards.
a <- c(0.1, 0.2, 0.4, 0.5, 0.9, 0.99)

crossing(a, surprirsal_df) %>%
  mutate(
    world_eig = -log(t, base = a),
    # The label now describes the log curve; it had been copied unchanged
    # from the exponential sweep ("** t").
    world_eig_formula = paste("y ~ -log(t, base = ", a, ")"),
    prob = surprisal / (surprisal + world_eig)
  ) %>%
  ggplot(aes(x = t, y = prob, color = as.factor(a), group = a)) +
  geom_point() +
  geom_line()

2.4 Quadratic

It is unclear what the best way to spin this is; also, the parameter space is a little bit unclear (????).

# Coefficient sweep for a quadratic world EIG: (a * t^2 + b * t + c) / 10.
a <- c(-2, -0.5, -0.1, 0.1, 0.5, 2)
b <- c(-5, -0.1, 0.1, 20)
c <- seq(-2, 2, 1.5)

# Single-combination alternative kept for reference:
#a <- c(2)
#b <- c(-5)
#c <- c(5)

# Keep only upward-opening parabolas (a > 0) and time steps where the
# world EIG is positive, then draw one panel per coefficient combination.
crossing(a, b, c, surprirsal_df) %>%
  mutate(
    world_eig = (a * t^2 + b * t + c) / 10,
    formula = paste0(a, "* t **2 + ", b, " * t + ", c)
  ) %>%
  filter(a > 0, world_eig > 0) %>%
  mutate(prob = surprisal / (surprisal + world_eig)) %>%
  ggplot(aes(x = t, y = prob, group = formula)) +
  geom_point() +
  geom_line() +
  facet_wrap(~formula)

3 Constant EIG, with luce choice parameter

lt ~ e^f(surprisal) / (e^f(surprisal) + c)

3.1 Linear function on surprisal

3.1.1 constant base

lt ~ (intercept + slope * surprisal) / ((intercept + slope * surprisal) + weig)

# Constant competing terms swept in this section (passed to linear_choice()
# as `base`).
world_eigs <- c(0.000001, 0.001, 0.01, 0.05, 2, 5)

# Intercepts and slopes of the linear transform applied to surprisal.
s_intercept <- c(0.0001, 0.02, 0.04, 0.08)
s_slope <- c(0.005, 0.01, 0.05, 1, 3)


# Luce choice rule applied to a linear transform of the input strength.
#
# Args:
#   t:           time steps, copied into the `time` column.
#   s_slope:     slope of the linear transform.
#   s_intercept: intercept of the linear transform.
#   s:           raw strength (callers in this file pass surprisal).
#   base:        competing term in the denominator, copied into `weig`.
#
# Returns a tibble with columns `time`, `weig`, `prob`, `s_intercept` and
# `s_slope`, where prob = f(s) / (f(s) + base), f(s) = s_intercept + s_slope * s.
linear_choice <- function(t, s_slope, s_intercept, s, base) {
  strength <- s_intercept + s_slope * s
  choice_prob <- strength / (strength + base)
  tibble(
    "time" = t,
    "weig" = base,
    "prob" = choice_prob,
    "s_intercept" = s_intercept,
    "s_slope" = s_slope
  )
}

# Plot choice probability over time for every combination of competing term,
# intercept and slope. linear_choice() is vectorised, so it is applied once to
# the full parameter x time grid instead of nesting per group and relying on
# nest() being called without a column specification.
crossing(world_eigs, s_intercept, s_slope, surprirsal_df) %>%
  with(linear_choice(t, s_slope, s_intercept, surprisal, world_eigs)) %>%
  mutate(
    s_intercept_print = paste("intercept: ", s_intercept),
    s_slope_prnt = paste("slope: ", s_slope)
  ) %>%
  ggplot(aes(x = time, y = prob, color = as.factor(weig))) +
  geom_point() +
  theme_few() +
  facet_grid(s_intercept_print ~ s_slope_prnt)

3.1.2 varying base

lt ~ base ** (intercept + slope * surprisal) / (base ** (intercept + slope * surprisal) + weig)

# Single competing term used in this section; the wider sweep is kept below
# as a commented-out alternative.
world_eigs <- c(0.0001)
#world_eigs <- c(0.000001, 0.001, 0.01, 0.05, 2, 5)


# Intercepts and slopes of the linear exponent (wider slope sweep commented out).
s_intercept <- c(0.0001, 0.02, 0.04, 0.08)
#s_slope <- c(0.005, 0.01, 0.05, 1, 3, 5, 10)
s_slope <- c(2, 3, 5, 10)

# Bases that are raised to the linear exponent in linear_choice_base_change().
s_base <- c(0.5,  0.8, 1, 1.2)

# Luce choice rule where the strength is s_base raised to a linear function
# of the input.
#
# Args:
#   t:           time steps, copied into the `time` column.
#   s_base:      base of the exponentiation.
#   s_slope:     slope of the linear exponent.
#   s_intercept: intercept of the linear exponent.
#   s:           raw strength (callers in this file pass surprisal).
#   base:        competing term in the denominator, copied into `weig`.
#
# Returns a tibble with columns `time`, `weig`, `prob`, `s_base`,
# `s_intercept` and `s_slope`, where
# prob = g / (g + base) and g = s_base^(s_intercept + s_slope * s).
linear_choice_base_change <- function(t, s_base, s_slope, s_intercept, s, base) {
  strength <- s_base^(s_intercept + s_slope * s)
  choice_prob <- strength / (strength + base)
  tibble(
    "time" = t,
    "weig" = base,
    "prob" = choice_prob,
    "s_base" = s_base,
    "s_intercept" = s_intercept,
    "s_slope" = s_slope
  )
}

# Plot choice probability over time, coloured by s_base and faceted by
# intercept x slope. linear_choice_base_change() is vectorised, so it is
# applied once to the full parameter x time grid instead of nesting per group
# and relying on nest() being called without a column specification.
crossing(world_eigs, s_base, s_intercept, s_slope, surprirsal_df) %>%
  with(
    linear_choice_base_change(
      t, s_base, s_slope, s_intercept, surprisal, world_eigs
    )
  ) %>%
  mutate(
    s_intercept_print = paste("intercept: ", s_intercept),
    s_slope_prnt = paste("slope: ", s_slope)
  ) %>%
  ggplot(aes(x = time, y = prob, color = as.factor(s_base))) +
  geom_point(position = position_dodge(width = .8)) +
  theme_few() +
  facet_grid(s_intercept_print ~ s_slope_prnt, scales = "free")

3.2 Exponential on surprisal

lt ~ base ** (exp ** surprisal) / (base ** (exp ** surprisal) + weig)

# Constant competing terms swept in this section (passed to exp_choice() as `base`).
world_eigs <- c(0.000001, 0.001, 0.01, 0.05, 2, 5)
# Values raised to the power of surprisal inside exp_choice().
# NOTE(review): this variable shares its name with base R's exp() function.
exp <- c(0.000000001, 0.2, 0.4, 0.5, 0.9, 0.99)
# Outer bases raised to exp^surprisal inside exp_choice().
s_base <- c(0.5,  0.8, 1, 1.2)



# Luce choice rule on a doubly exponentiated strength.
#
# Args:
#   t:      time steps, copied into the `time` column.
#   exp:    value raised to the power of `s`.
#   s:      raw strength (callers in this file pass surprisal).
#   base:   competing term in the denominator, copied into `weig`.
#   s_base: outer base raised to exp^s.
#
# Returns a tibble with columns `time`, `weig`, `prob`, `exp` and `s_base`,
# where prob = g / (g + base) and g = s_base^(exp^s).
exp_choice <- function(t, exp, s, base, s_base) {
  strength <- s_base^(exp^s)
  choice_prob <- strength / (strength + base)
  tibble(
    "time" = t,
    "weig" = base,
    "prob" = choice_prob,
    "exp" = exp,
    "s_base" = s_base
  )
}

# Plot choice probability over time, coloured by competing term and faceted
# by s_base x exp. exp_choice() is vectorised, so it is applied once to the
# full parameter x time grid instead of nesting per group and relying on
# nest() being called without a column specification.
# The grid columns are listed in the order the groups used to be formed
# (world_eigs, s_base, exp) so the rows come out in the same order.
crossing(world_eigs, s_base, exp, surprirsal_df) %>%
  with(exp_choice(t, exp, surprisal, world_eigs, s_base)) %>%
  mutate(
    s_base_print = paste("base: ", s_base),
    s_exp_print = paste("exp: ", exp)
  ) %>%
  ggplot(aes(x = time, y = prob, color = as.factor(weig))) +
  geom_point() +
  theme_few() +
  facet_grid(s_base_print ~ s_exp_print)

3.3 Log on surprisal

# Constant competing terms swept in this section (passed to log_choice() as `base`).
world_eigs <- c(0.000001, 0.001, 0.01, 0.05, 2, 5)
# Logarithm bases applied to surprisal inside log_choice().
log_base <- c(0.1, 0.2, 0.4, 0.5, 0.9, 0.99)




# Luce choice rule on a log-transformed strength.
#
# Args:
#   t:        time steps, copied into the `time` column.
#   log_base: base of the logarithm applied to `s`.
#   s:        raw strength (callers in this file pass surprisal).
#   base:     competing term in the denominator, copied into `weig`.
#
# Returns a tibble with columns `time`, `weig`, `prob` and `log_base`,
# where prob = g / (g + base) and g = log(s, base = log_base).
log_choice <- function(t, log_base, s, base) {
  strength <- log(s, base = log_base)
  choice_prob <- strength / (strength + base)
  tibble(
    "time" = t,
    "weig" = base,
    "prob" = choice_prob,
    "log_base" = log_base
  )
}

# Plot choice probability over time, coloured by competing term, one panel
# per log base. log_choice() is vectorised, so it is applied once to the full
# parameter x time grid instead of nesting per group and relying on nest()
# being called without a column specification.
crossing(world_eigs, log_base, surprirsal_df) %>%
  with(log_choice(t, log_base, surprisal, world_eigs)) %>%
  ggplot(aes(x = time, y = prob, color = as.factor(weig))) +
  geom_point() +
  theme_few() +
  facet_wrap(~log_base)

3.4 quadratic

prob = (a * (surprisal ** 2) + b * surprisal + c) / ((a * (surprisal ** 2) + b * surprisal + c) + constant)

# Wider coefficient sweep, kept commented out:
#sa <- c(-2, -0.5, -0.1, 0.1, 0.5, 2)
#sb <- c(-5, -0.1, 0.1, 20)
#sc <- seq(-2, 2, 1.5, 0)

# Quadratic coefficients (two values each) passed to quad_choice().
sa <- c(-0.1,10)
sb <- c(-0.1, 2)
sc <- c(-0.5, 5)
# Single competing term used in this section (passed to quad_choice() as `base`).
world_eigs <-  c(0.001)

# Luce choice rule on a quadratic transform of the input strength.
#
# Args:
#   t:          time steps, copied into the `time` column.
#   s:          raw strength (callers in this file pass surprisal).
#   sa, sb, sc: quadratic, linear and constant coefficients.
#   base:       competing term in the denominator, copied into `weig`.
#
# Returns a tibble with columns `time`, `surprisal`, `weig`, `prob`, `sa`,
# `sb` and `sc`, where prob = q / (q + base) and q = sa * s^2 + sb * s + sc.
quad_choice <- function(t, s, sa, sb, sc, base) {
  # The linear term is sb * s, as in the documented formula and the plot
  # labels ("sa * s **2 + sb * s + sc"); it was previously sb * (s * 2).
  strength <- sa * s^2 + sb * s + sc
  choice_prob <- strength / (strength + base)
  tibble(
    "time" = t,
    # NOTE(review): as before, this column holds the transformed value q,
    # not the raw `s` that was passed in -- confirm this is intended.
    "surprisal" = strength,
    "weig" = base,
    "prob" = choice_prob,
    "sa" = sa,
    "sb" = sb,
    "sc" = sc
  )
}


# surprisal by prob
# x is the `surprisal` column returned by quad_choice(); one panel per
# coefficient combination. quad_choice() is vectorised, so it is applied once
# to the full parameter x time grid instead of nesting per group and relying
# on nest() being called without a column specification. The panel label is
# rebuilt from the sa/sb/sc columns that quad_choice() returns.
crossing(world_eigs, sa, sb, sc, surprirsal_df) %>%
  with(quad_choice(t, surprisal, sa, sb, sc, world_eigs)) %>%
  mutate(formula = paste0(sa, "* s **2 + ", sb, " * s + ", sc)) %>%
  ggplot(aes(x = surprisal, y = prob, color = as.factor(weig))) +
  geom_point() +
  geom_line() +
  theme_few() +
  facet_wrap(~formula, scales = "free")

# t by prob
# Same simulation as the surprisal-by-prob plot, with time on the x axis.
# quad_choice() is vectorised, so it is applied once to the full
# parameter x time grid instead of nesting per group and relying on nest()
# being called without a column specification. The panel label is rebuilt
# from the sa/sb/sc columns that quad_choice() returns.
crossing(world_eigs, sa, sb, sc, surprirsal_df) %>%
  with(quad_choice(t, surprisal, sa, sb, sc, world_eigs)) %>%
  mutate(formula = paste0(sa, "* s **2 + ", sb, " * s + ", sc)) %>%
  ggplot(aes(x = time, y = prob, color = as.factor(weig))) +
  geom_point() +
  theme_few() +
  facet_wrap(~formula, scales = "free")

#a <- c(2)
#b <- c(-5)
#c <- c(5)