Exploring relationship between surprisal and SPRT

initial EDA

surps.RTs %>%
  pivot_longer(
    starts_with("surp_"),
    names_to = "model",
    names_prefix = "surp_",
    values_to = "surprisal"
  ) %>%
  drop_na() %>%
  ggplot(aes(x = surprisal, y = RT)) +
  # 2-D hexbin of every token (grey = few, black = many) with a default
  # geom_smooth() trend on top; one panel per surprisal model.
  geom_hex() +
  scale_fill_gradient(low = "#DDDDDD", high = "black") +
  # geom_density(aes(x=surprisal,y=1200+750*..density..)) +
  # geom_point(alpha=.1,color="gray") +
  geom_smooth() +
  xlab("surprisal") +
  facet_wrap(~model) +
  ggtitle("surprisal vs self-paced RT on naturalstories data")
`geom_smooth()` using method = 'gam' and formula 'y ~ s(x, bs = "cs")'

# with arithmetic statistics: per-item mean RT (blue) with a +-1 sd bar (black)
surps.RTs.perword %>%
  pivot_longer(
    starts_with("surp_"),
    names_to = "model",
    names_prefix = "surp_",
    values_to = "surprisal"
  ) %>%
  drop_na() %>%
  ggplot(aes(x = surprisal, y = meanItemRT)) +
  # geom_density(data=surps.RTs %>% 
      # pivot_longer(starts_with("surp_"), names_to = "model", values_to = "surprisal", names_prefix = "surp_") %>%
      # drop_na(), 
    # aes(x=surprisal,y=-1250*..density..)) + 
  geom_errorbar(
    aes(ymin = meanItemRT - sdItemRT, ymax = meanItemRT + sdItemRT),
    alpha = .1,
    color = "black"
  ) +
  geom_point(alpha = .1, color = "blue") +
  xlab("surprisal") +
  facet_wrap(~model) +
  ggtitle("surprisal vs mean RT (blue) +-1sd (black) per item")


# with geometric statistics
# surps.RTs.perword %>% 
#   pivot_longer(starts_with("surp_"), names_to = "model", values_to = "surprisal", names_prefix = "surp_") %>%
#   drop_na() %>% 
#   ggplot(aes(x=surprisal,y=gmeanItemRT)) +
#   geom_density(aes(x=surprisal,y=1200+750*..density..)) + 
#   geom_errorbar(aes(ymin=gmeanItemRT/gsdItemRT, ymax=gmeanItemRT*gsdItemRT),alpha=.1, color="black") +
#   geom_point(alpha=.1,color="blue") +
#   xlab("surprisal") +
#   facet_wrap(~model)+
#   ggtitle("surprisal vs geometric mean RT (blue) * gsd^{+-1} (black) per item")

Almost all the data is at low surprisal. The variance increases a bit with higher surprisal.

surps.RTs %>%
  pivot_longer(
    starts_with("surp_") & !contains("boyce"),
    names_to = "model",
    names_prefix = "surp_",
    values_to = "surprisal"
  ) %>%
  drop_na() %>%
  ggplot(aes(x = surprisal, y = RT)) +
  # Notched RT boxplots per 1-unit surprisal bin (bins anchored at 0).
  stat_boxplot(
    aes(x = cut_width(surprisal, boundary = 0, width = 1)),
    notch = TRUE, varwidth = FALSE, position = "identity", alpha = .75,
    outlier.color = "gray", outlier.alpha = .5, outlier.size = .5,
    outlier.shape = 20, color = "blue"
  ) +
  # stat_summary_bin(fun.data="mean_sdl", geom = "crossbar", binwidth = .5, alpha=0.2, fill="blue", size=.1)+
  # stat_summary_bin(fun="median", binwidth = 1, alpha=0.2, size=.5, shape=3, color="red")+
  xlab("surprisal (binned)") +
  theme(axis.text.x = element_text(angle = 90, vjust = .5, hjust = 1)) +
  facet_wrap(~model, nrow = 1) +
  ggtitle("surprisal vs RT, binned by surprisal")


surps.RTs %>%
  pivot_longer(
    starts_with("surp_boyce"),
    names_to = "model",
    names_prefix = "surp_",
    values_to = "surprisal"
  ) %>%
  drop_na() %>%
  ggplot(aes(x = surprisal, y = RT)) +
  # Same binned boxplots, restricted to Boyce's surprisal columns.
  stat_boxplot(
    aes(x = cut_width(surprisal, boundary = 0, width = 1)),
    notch = TRUE, varwidth = FALSE, position = "identity", alpha = .75,
    outlier.color = "gray", outlier.alpha = .5, outlier.size = .5,
    outlier.shape = 20, color = "blue"
  ) +
  # stat_summary_bin(fun.data="mean_sdl", geom = "crossbar", binwidth = .5, alpha=0.2, fill="blue", size=.1)+
  # stat_summary_bin(fun="median", binwidth = 1, alpha=0.2, size=.5, shape=3, color="red")+
  xlab("surprisal (binned)") +
  theme(axis.text.x = element_text(angle = 90, vjust = .5, hjust = 1)) +
  facet_wrap(~model) +
  ggtitle("surprisal vs RT, binned by surprisal")

Quick GAM comparison

plot_compare_models_smooth <- function(data, what_to_predict, surp_threshold = 100, ...) {
  #' Overlay one smooth of a response against surprisal per surprisal model
  #'
  #' Pivots every `surp_<model>` column (excluding centred `*_c` copies) to
  #' long format and draws one `geom_smooth()` per model, coloured by model.
  #' @param data data frame with `surp_<model>` columns and the response column
  #' @param what_to_predict name (string) of the response column for the y axis
  #' @param surp_threshold rows with surprisal above this value are dropped
  #' @param ... passed to `geom_smooth()` (e.g. `method`, `formula`)
  #' @return a ggplot object
  gpt_levels <- c("GPT2", "GPT2-large", "GPT-Neo", "GPT-J", "GPT3")
  data %>%
    pivot_longer(starts_with("surp_") & !ends_with("_c"),
                 names_to = "model", values_to = "surprisal",
                 names_prefix = "surp_") %>%
    # .data[[ ]] looks the column up by name explicitly instead of get().
    filter(!is.na(.data[[what_to_predict]]), !is.na(surprisal),
           surprisal <= surp_threshold) %>%
    mutate(
      model = factor(model),
      # Put the GPT models, in size order, after the first three levels.
      # Only relevel GPT models actually present so fct_relevel() does not
      # complain about unknown levels.
      model = fct_relevel(model, intersect(gpt_levels, levels(model)), after = 3)
    ) %>%
    ggplot(aes(color = model)) +
    geom_smooth(aes(x = surprisal, y = .data[[what_to_predict]]), ...) +
    labs(x = "surprisal", y = what_to_predict)
}

What to consider in choosing our basis?

  • cubic regression spline bs = 'cr' places knots at quantiles of the covariate (probably not what we want: that concentrates knots, and hence wiggliness, in the low-surprisal region where most of the data are, which may artificially inflate the edf)
  • B-spline family (bs = 'bs', bs = 'ps', bs = 'ad') places knots evenly
  • thin plate bs = 'tp' has no knots to place as such; when there are more than max.knots (default 2000) unique covariate values, it builds the basis from a random subsample of max.knots of them (see smooth.construct.tp.smooth.spec)
# quick and dirty comparison: thin plate (tp) vs cubic regression (cr) basis
(
  surps.RTs.perword %>%
    plot_compare_models_smooth(
      "meanItemRT", # surp_threshold = 30,
      method = "gam",
      formula = y ~ s(x, bs = "tp")
    ) +
    ggtitle("tp")
) + (
  surps.RTs.perword %>%
    plot_compare_models_smooth(
      "meanItemRT", # surp_threshold = 30,
      method = "gam",
      formula = y ~ s(x, bs = "cr")
    ) +
    ggtitle("cr")
) +
  plot_layout(guides = "collect") +
  plot_annotation("SPRT vs surprisal GAMs quick model comparison")

Fitting GAMs

# Attach the per-word (centred, lagged) predictors to every RT observation,
# then add a story_word identifier and move it (and the word) to the front.
surps.RTs_lag_c <- inner_join(surps.RTs, surps.RTs_lag_c.perword) %>%
  mutate(Word_ID = as_factor(str_c(story_num, word_num_in_story, sep = "_"))) %>%
  relocate(Word_ID, word)
Joining, by = c("word_num_in_sentence", "sentence_num", "sentence", "word", "offset", "story_num", "surp_GPT2", "surp_GPT2-large", "surp_GPT-Neo", "surp_GPT-J", "surp_GPT3", "nth_occurence_in_story", "word_num_in_story", "nItem", "meanItemRT", "sdItemRT", "gmeanItemRT", "gsdItemRT", "word_human", "wordp_whole", "wordp_word", "wordp_1", "is_in_common_vocab", "wordlength", "Word", "surp_boyce_txl", "tokencount_boyce_txl", "surp_boyce_ngram", "tokencount_boyce_ngram", "surp_boyce_grnn", "tokencount_boyce_grnn", "freq", "length")

GAMs for SPRT

Fit a GAM for the GPT data

# Fit one SPRT GAM per GPT surprisal model in a job::job() background job and
# cache each fit as an .rds file under scratch/.
# NOTE(review): job::job() returns immediately, so code that reads these .rds
# files right afterwards may see missing or stale fits.
job::job({
  gam_gpt3_RT   <- fit_bam_SPRT(formula1, surps.RTs_lag_c, "GPT3")
  gam_gpt2_RT   <- fit_bam_SPRT(formula1, surps.RTs_lag_c, "GPT2")
  gam_gpt2l_RT  <- fit_bam_SPRT(formula1, surps.RTs_lag_c, "GPT2-large")
  gam_gptj_RT   <- fit_bam_SPRT(formula1, surps.RTs_lag_c, "GPT-J")
  gam_gptneo_RT <- fit_bam_SPRT(formula1, surps.RTs_lag_c, "GPT-Neo")
  
  # Cache every fit so later chunks can load it with read_rds().
  write_rds(gam_gpt3_RT  , file="scratch/gam_gpt3_RT.rds")
  write_rds(gam_gpt2_RT  , file="scratch/gam_gpt2_RT.rds")
  write_rds(gam_gpt2l_RT , file="scratch/gam_gpt2l_RT.rds")
  write_rds(gam_gptj_RT  , file="scratch/gam_gptj_RT.rds")
  write_rds(gam_gptneo_RT, file="scratch/gam_gptneo_RT.rds")
}, 
title="fitting GAMs RT v GPT surprisals")

surprisal vs mazeRT

Aligning with AMAZE TASK from Boyce

# Maze RTs before the first error per sentence, preprocessed as in Boyce and
# Levy (see data_processing_mazeRTs.Rmd).
maze_pre_error <- "natural-stories-surprisals/maze_data/maze_pre_error.csv" %>%
  read_csv()

# Maze RTs joined with surprisals: surps.mazeRTs_lag_c is the object built by
# joindata_maze_with_surps(); the previous name surps.RTs_lag_c.maze_pre_error
# is never defined.
(
  surps.mazeRTs_lag_c %>%
    plot_compare_models_smooth("mazeRT", # surp_threshold = 30,
                               method = "gam", formula = y ~ s(x, bs = "tp")) +
    ggtitle("thin plate")
) + (
  surps.mazeRTs_lag_c %>%
    plot_compare_models_smooth("mazeRT", # surp_threshold = 30,
                               method = "gam", formula = y ~ s(x, bs = "cr")) +
    ggtitle("cubic regression")
) +
  plot_layout(guides = 'collect') +
  plot_annotation("Maze RT vs surprisal GAMs quick model comparison")

GAMs for mazeRT

# TODO: Morgan suggested adding random effects like this.
# Column names match the output of prepare_data_GAM_mazeRT(): mazeRT, subject
# (factor), Word_ID (factor), surprisal, prev_surprisal, freq, length,
# prev_freq, prev_length.
# NOTE(review): bs="fs" smooths already include per-subject intercepts, so
# s(subject, bs="re") may be redundant; also gam() may be slow on this much
# data (bam() is the large-data variant) — confirm before running.
job::job({
  gam_gpt3_maze_withrandomeffects <- gam(
    mazeRT ~
      s(surprisal, subject, bs="fs") + s(subject, bs='re') + s(Word_ID, bs='re') +
      s(surprisal, bs="cr", k=20) +
      te(freq, length, bs="cr") +
      s(prev_surprisal, bs="cr", k=20) +
      te(prev_freq, prev_length, bs="cr"),
    data=gpt3_maze,
    method="REML")
},
title = "gam_gpt3_maze_withrandomeffects")
# Print the stored maze GAM-fit plots: non-Boyce models, then Boyce's models.
plot.gam_surpRTmaze
plot.gam_surpRTmaze_boyce
draw_gam_both <- function(gam_model) {
  #' GAM diagnostics plus its current- and previous-surprisal smooths
  #'
  #' Runs mgcv::gam.check(), which prints its own report and draws residual
  #' plots, then draws smooths 1 and 3 with gratia::draw(). With formula1 /
  #' mazeRTvsurp_formula1 these are s(surprisal) and s(prev_surprisal).
  #' The smooths are shifted by the model intercept, so the y axis is on the
  #' response scale, and the intercept's uncertainty is included in the bands.
  #' `constant` / `overall_uncertainty` are gratia's counterparts of
  #' mgcv::plot.gam's `shift` / `seWithMean`, which gratia::draw() ignores.
  #' @param gam_model fitted mgcv gam/bam
  #' @return the plot object returned by gratia::draw()
  gam.check(gam_model)
  draw(gam_model,
       overall_uncertainty = TRUE,
       constant = coef(gam_model)[["(Intercept)"]],
       select = c(1, 3))
}
---
title: "surprisal and reading time on Natural Stories"
author: "Jacob Louis Hoover"
date: "Fall 2021"
output:
  html_notebook: default
  html_document:
    df_print: paged
  tufte::tufte_html: default
editor_options:
  chunk_output_type: inline
---
```{r "setup", include=FALSE}
# Evaluate all chunks relative to the project root, so that the relative data
# and scratch/ paths used below resolve.
# NOTE(review): machine-specific absolute path; adjust per checkout.
knitr::opts_knit$set(root.dir = "~/McGill/projects/EVAL2-processing-surprisal/")
```

```{r, include=FALSE}
library(tidyverse)
library(patchwork)
library(docstring)
# library(readr)
library(brms)
library(lme4)
library(rstan)
# library(tidybayes)
# library(knitr)
library(mgcv)
library(gratia)
# library(mgcViz)
# library(tidymv)
# library(rsample) 
# library(cowplot)
# library(scales)
library(curl)
# Black-and-white ggplot theme for every plot in the notebook.
theme_set(theme_bw())
# rstan: write compiled Stan models to disk so they can be reused.
rstan_options(auto_write = TRUE)
# Use all cores wherever mc.cores is honoured (e.g. parallel MCMC chains).
options(mc.cores = parallel::detectCores())
```


```{r, include=F}
# Load data produced by data_processing_surps.RTs.Rmd:
# per-word summaries and per-observation RTs, each with surp_<model> columns.
surps.RTs.perword <- "natural-stories-surprisals/surprisals_and_RTs_data/surps.RTs.perword.csv" %>%
  read_csv()
surps.RTs <- "natural-stories-surprisals/surprisals_and_RTs_data/surps.RTs.csv" %>%
  read_csv()

# Note, all models in one dataset. you can get just surps from one model with  somethinglike
# surps.RTs %>% select( (contains("GPT2")&!contains("-large")),!starts_with("surp"))
```


# Exploring relationship between surprisal and SPRT

## initial EDA

```{r}
surps.RTs %>%
  pivot_longer(
    starts_with("surp_"),
    names_to = "model",
    names_prefix = "surp_",
    values_to = "surprisal"
  ) %>%
  drop_na() %>%
  ggplot(aes(x = surprisal, y = RT)) +
  # 2-D hexbin of every token (grey = few, black = many) with a default
  # geom_smooth() trend on top; one panel per surprisal model.
  geom_hex() +
  scale_fill_gradient(low = "#DDDDDD", high = "black") +
  # geom_density(aes(x=surprisal,y=1200+750*..density..)) +
  # geom_point(alpha=.1,color="gray") +
  geom_smooth() +
  xlab("surprisal") +
  facet_wrap(~model) +
  ggtitle("surprisal vs self-paced RT on naturalstories data")
```




```{r}
# with arithmetic statistics: per-item mean RT (blue) with a +-1 sd bar (black)
surps.RTs.perword %>%
  pivot_longer(
    starts_with("surp_"),
    names_to = "model",
    names_prefix = "surp_",
    values_to = "surprisal"
  ) %>%
  drop_na() %>%
  ggplot(aes(x = surprisal, y = meanItemRT)) +
  # geom_density(data=surps.RTs %>% 
      # pivot_longer(starts_with("surp_"), names_to = "model", values_to = "surprisal", names_prefix = "surp_") %>%
      # drop_na(), 
    # aes(x=surprisal,y=-1250*..density..)) + 
  geom_errorbar(
    aes(ymin = meanItemRT - sdItemRT, ymax = meanItemRT + sdItemRT),
    alpha = .1,
    color = "black"
  ) +
  geom_point(alpha = .1, color = "blue") +
  xlab("surprisal") +
  facet_wrap(~model) +
  ggtitle("surprisal vs mean RT (blue) +-1sd (black) per item")

# with geometric statistics
# surps.RTs.perword %>% 
#   pivot_longer(starts_with("surp_"), names_to = "model", values_to = "surprisal", names_prefix = "surp_") %>%
#   drop_na() %>% 
#   ggplot(aes(x=surprisal,y=gmeanItemRT)) +
#   geom_density(aes(x=surprisal,y=1200+750*..density..)) + 
#   geom_errorbar(aes(ymin=gmeanItemRT/gsdItemRT, ymax=gmeanItemRT*gsdItemRT),alpha=.1, color="black") +
#   geom_point(alpha=.1,color="blue") +
#   xlab("surprisal") +
#   facet_wrap(~model)+
#   ggtitle("surprisal vs geometric mean RT (blue) * gsd^{+-1} (black) per item")
```

Almost all the data is at low surprisal. The variance increases a bit with higher surprisal.

```{r}
surps.RTs %>%
  pivot_longer(
    starts_with("surp_") & !ends_with("_c"),
    names_to = "model",
    names_prefix = "surp_",
    values_to = "surprisal"
  ) %>%
  drop_na() %>%
  # One row per (1-unit surprisal bin, model): token count, number of distinct
  # words, smallest surprisal in the bin, and the variance of RT.
  group_by(surprisalbin = cut_width(surprisal, boundary = 0, width = 1), model) %>%
  summarise(
    n = n(),
    n_words = n_distinct(word),
    binned_surprisal_min = min(surprisal),
    var = var(RT),
    .groups = "drop"
  ) %>%
  # drop bins whose tokens all come from a single word type
  filter(n_words > 1) %>%
  ggplot(aes(x = binned_surprisal_min, y = var, color = log(n))) +
  geom_point() +
  geom_smooth(method = "lm", formula = y ~ x, color = "red",
              linetype = "solid", size = .25, alpha = .25) +
  facet_wrap(~model, scales = "free_y") +
  labs(x = "binned surprisal", y = "variance in RT") +
  scale_color_continuous(low = "#DDDDDD", high = "black") +
  ggtitle("variance in RT vs binned surprisal")

```


```{r}
surps.RTs %>%
  pivot_longer(
    starts_with("surp_") & !contains("boyce"),
    names_to = "model",
    names_prefix = "surp_",
    values_to = "surprisal"
  ) %>%
  drop_na() %>%
  ggplot(aes(x = surprisal, y = RT)) +
  # Notched RT boxplots per 1-unit surprisal bin (bins anchored at 0).
  stat_boxplot(
    aes(x = cut_width(surprisal, boundary = 0, width = 1)),
    notch = TRUE, varwidth = FALSE, position = "identity", alpha = .75,
    outlier.color = "gray", outlier.alpha = .5, outlier.size = .5,
    outlier.shape = 20, color = "blue"
  ) +
  # stat_summary_bin(fun.data="mean_sdl", geom = "crossbar", binwidth = .5, alpha=0.2, fill="blue", size=.1)+
  # stat_summary_bin(fun="median", binwidth = 1, alpha=0.2, size=.5, shape=3, color="red")+
  xlab("surprisal (binned)") +
  theme(axis.text.x = element_text(angle = 90, vjust = .5, hjust = 1)) +
  facet_wrap(~model, nrow = 1) +
  ggtitle("surprisal vs RT, binned by surprisal")

surps.RTs %>%
  pivot_longer(
    starts_with("surp_boyce"),
    names_to = "model",
    names_prefix = "surp_",
    values_to = "surprisal"
  ) %>%
  drop_na() %>%
  ggplot(aes(x = surprisal, y = RT)) +
  # Same binned boxplots, restricted to Boyce's surprisal columns.
  stat_boxplot(
    aes(x = cut_width(surprisal, boundary = 0, width = 1)),
    notch = TRUE, varwidth = FALSE, position = "identity", alpha = .75,
    outlier.color = "gray", outlier.alpha = .5, outlier.size = .5,
    outlier.shape = 20, color = "blue"
  ) +
  # stat_summary_bin(fun.data="mean_sdl", geom = "crossbar", binwidth = .5, alpha=0.2, fill="blue", size=.1)+
  # stat_summary_bin(fun="median", binwidth = 1, alpha=0.2, size=.5, shape=3, color="red")+
  xlab("surprisal (binned)") +
  theme(axis.text.x = element_text(angle = 90, vjust = .5, hjust = 1)) +
  facet_wrap(~model) +
  ggtitle("surprisal vs RT, binned by surprisal")
```


```{r eval=FALSE, include=FALSE}

# Fitting preliminary GAMs to the data doesn't make much difference if they're fit to all 
# RT datapoints or to the means per surprisal value
# surps.RTs.perword %>%
#   pivot_longer(starts_with("surp_"), names_to = "model", values_to = "surprisal", names_prefix = "surp_") %>%
#   drop_na() %>%
#     ggplot() +
#     geom_line(aes(x=surprisal, y=meanItemRT),
#               stat="smooth", method = "gam", formula = y ~ s(x, bs = "cs"), color="red",alpha=0.3) +
#     geom_line(data=gpts.RTs %>% 
#                 pivot_longer(starts_with("surp_"), names_to = "model", values_to = "surprisal", names_prefix = "surp_") %>%
#                 drop_na(),
#               aes(x=surprisal,y=RT),
#               stat="smooth", method = "gam", formula = y ~ s(x, bs = "cs"), color="blue", alpha=0.3) +
#     xlab("surprisal") +
#     facet_wrap(~factor(model,levels = c("GPT2","GPT2-large","GPT-Neo","GPT-J","GPT3")))

# surps.RTs.perword %>%
#   pivot_longer(starts_with("surp_"), names_to = "model", values_to = "surprisal", names_prefix = "surp_") %>%
#   drop_na() %>%
#     ggplot() +
#     geom_smooth(aes(x=surprisal, y=meanItemRT) ,method = "gam", color="red") +
#     xlab("surprisal") +
#     facet_wrap(~factor(model,levels = c("GPT2","GPT2-large","GPT-Neo","GPT-J","GPT3")))
# 
# surps.RTs %>% 
#   pivot_longer(starts_with("surp_"), names_to = "model", values_to = "surprisal", names_prefix = "surp_") %>%
#     filter(!is.na(meanItemRT),!is.na(surprisal)) %>%
#     ggplot() +
#     geom_smooth(aes(x=surprisal,y=RT), method = "gam", formula = y ~ s(x, bs = "cs")) +
#     xlab("surprisal") +
#     facet_wrap(~factor(model,levels = c("GPT2","GPT2-large","GPT-Neo","GPT-J","GPT3")))
```

## Quick GAM comparison

```{r}
plot_compare_models_smooth <- function(data, what_to_predict, surp_threshold = 100, ...) {
  #' Overlay one smooth of a response against surprisal per surprisal model
  #'
  #' Pivots every `surp_<model>` column (excluding centred `*_c` copies) to
  #' long format and draws one `geom_smooth()` per model, coloured by model.
  #' @param data data frame with `surp_<model>` columns and the response column
  #' @param what_to_predict name (string) of the response column for the y axis
  #' @param surp_threshold rows with surprisal above this value are dropped
  #' @param ... passed to `geom_smooth()` (e.g. `method`, `formula`)
  #' @return a ggplot object
  gpt_levels <- c("GPT2", "GPT2-large", "GPT-Neo", "GPT-J", "GPT3")
  data %>%
    pivot_longer(starts_with("surp_") & !ends_with("_c"),
                 names_to = "model", values_to = "surprisal",
                 names_prefix = "surp_") %>%
    # .data[[ ]] looks the column up by name explicitly instead of get().
    filter(!is.na(.data[[what_to_predict]]), !is.na(surprisal),
           surprisal <= surp_threshold) %>%
    mutate(
      model = factor(model),
      # Put the GPT models, in size order, after the first three levels.
      # Only relevel GPT models actually present so fct_relevel() does not
      # complain about unknown levels.
      model = fct_relevel(model, intersect(gpt_levels, levels(model)), after = 3)
    ) %>%
    ggplot(aes(color = model)) +
    geom_smooth(aes(x = surprisal, y = .data[[what_to_predict]]), ...) +
    labs(x = "surprisal", y = what_to_predict)
}
```

What to consider in choosing our basis?

- cubic regression spline `bs = 'cr'` places knots at quantiles of the covariate (probably **not** what we want: that concentrates knots, and hence wiggliness, in the low-surprisal region where most of the data are, which may artificially inflate the edf)
- B-spline family (`bs = 'bs'`, `bs = 'ps'`, `bs = 'ad'`) places knots evenly
- thin plate `bs = 'tp'` has no knots to place as such; when there are more than `max.knots` (default 2000) unique covariate values, it builds the basis from a random subsample of `max.knots` of them (see `smooth.construct.tp.smooth.spec`)

```{r}
# quick and dirty comparison: thin plate (tp) vs cubic regression (cr) basis
(
  surps.RTs.perword %>%
    plot_compare_models_smooth(
      "meanItemRT", # surp_threshold = 30,
      method = "gam",
      formula = y ~ s(x, bs = "tp")
    ) +
    ggtitle("tp")
) + (
  surps.RTs.perword %>%
    plot_compare_models_smooth(
      "meanItemRT", # surp_threshold = 30,
      method = "gam",
      formula = y ~ s(x, bs = "cr")
    ) +
    ggtitle("cr")
) +
  plot_layout(guides = "collect") +
  plot_annotation("SPRT vs surprisal GAMs quick model comparison")
```
# Fitting GAMs

```{r}
# Per-word predictors for the GAMs. After dropping rows with missing
# surprisal / meanItemRT / freq / length:
# - add centred versions of predictors, with suffix "_c"
# - add the previous word's value of each predictor, with prefix "prev_"
# (dplyr's mutate(across()) is the newer replacement for mutate_at()).
#
# Lags are taken within each story, in reading order, so the first word of a
# story gets NA instead of the last word of the preceding story.
# NOTE(review): assumes word_num_in_story is numeric and gives reading order.
# Because NAs are dropped before lagging, a word whose predecessor was dropped
# gets the word before that as "prev_"; confirm this is acceptable.
surps.RTs_lag_c.perword <- surps.RTs.perword %>%
  drop_na(contains("surp_"), meanItemRT, freq, length) %>%
  mutate(across(
    c(contains("surp_"), freq, length),
    .fns = list("c" = function(x) x - mean(x, na.rm = TRUE)),
    .names = "{.col}_{.fn}")) %>%
  arrange(story_num, word_num_in_story) %>%
  group_by(story_num) %>%
  mutate(across(
    c(contains("surp_"), freq, freq_c, length, length_c),
    .fns = list("prev" = lag),
    .names = "{.fn}_{.col}")) %>%
  ungroup()

# Attach the per-word predictors to every self-paced RT observation. The keys
# are all shared columns (what a natural join uses), spelled out explicitly.
surps.RTs_lag_c <- surps.RTs %>%
  inner_join(surps.RTs_lag_c.perword,
             by = intersect(names(surps.RTs), names(surps.RTs_lag_c.perword))) %>%
  mutate(Word_ID = as_factor(str_c(story_num, word_num_in_story, sep = "_"))) %>%
  relocate(Word_ID, word)
```



## GAMs for SPRT

Fit a GAM for the GPT data

```{r gam_GPTs, results = "hide"}
# Boyce and Levy's formula: smooths of current and previous-word surprisal
# (cubic regression splines, k = 20) plus tensor-product smooths of frequency
# and length for the current and previous word. Term order fixes the smooth
# indices: 1 = s(surprisal), 3 = s(prev_surprisal), which draw_gam_both() selects.
formula1 <- response ~
  s(surprisal, bs="cr", k=20) + te(freq, length, bs="cr") +
  s(prev_surprisal, bs="cr", k=20) + te(prev_freq, prev_length, bs="cr")

# With mgcv defaults: s() uses bs='tp', k=10. te() already defaults to 'cr'
# marginal bases, so in practice only the s() terms change. This should not
# change things much, and it avoids cr's quantile-based knot placement.
# NOTE(review): tp is not evenly spaced either. With more than max.knots unique
# values it uses a random subsample of them (see smooth.construct.tp.smooth.spec).
formula1.default <- response ~
  s(surprisal) + te(freq, length) +
  s(prev_surprisal) + te(prev_freq, prev_length)

prepare_data_GAM_SP <- function(surps.SPRTs_df, modelname, surp_colname, prev_surp_colname) {
  #' Select and rename the columns an SPRT GAM needs, for one surprisal model
  #'
  #' @param surps.SPRTs_df RT observations joined with per-word predictors
  #' @param modelname label stored in the `model` column
  #' @param surp_colname column holding this model's surprisal
  #' @param prev_surp_colname column holding the previous word's surprisal
  #' @return tibble with RT, WorkerId, Word_ID, surprisal, prev_surprisal, the
  #'   centred freq/length (current and previous), model, subject, response
  selected <- select(
    surps.SPRTs_df,
    RT, WorkerId, Word_ID,
    surprisal = all_of(surp_colname),
    prev_surprisal = all_of(prev_surp_colname),
    freq = freq_c,
    length = length_c,
    prev_freq = prev_freq_c,
    prev_length = prev_length_c
  )
  mutate(selected,
         model = modelname,
         subject = factor(WorkerId),
         response = RT)
}

fit_bam_SPRT <- function(formula, data, modelname){
  #' Fit a GAM (mgcv::bam, REML) of self-paced RT as a function of surprisal
  #'
  #' @param formula mgcv formula in terms of `response`, `surprisal`,
  #'   `prev_surprisal`, `freq`, `length`, `prev_freq`, `prev_length`
  #' @param data RT observations with `RT`, `WorkerId`, `Word_ID`, `freq_c`,
  #'   `length_c`, `prev_freq_c`, `prev_length_c`, `surp_<modelname>` and
  #'   `prev_surp_<modelname>`
  #' @param modelname name of the surprisal model, e.g. "GPT3"
  #' @return the fitted `bam` object
  # Reuse the shared column selection/renaming instead of duplicating it.
  data_prepared <- prepare_data_GAM_SP(
    data, modelname,
    surp_colname = paste0("surp_", modelname),
    prev_surp_colname = paste0("prev_surp_", modelname))
  bam(formula, method = "REML", data = data_prepared)
}


# Fit each SPRT GAM only if its cached .rds is missing, then load all of them.
# Fitting runs synchronously so the read_rds() calls below always see a
# finished fit. A background job::job() followed directly by read_rds() would
# race the job: files are missing on a fresh checkout, and otherwise the
# previous run's fits are read while new ones are still being computed.
# Delete the corresponding scratch/*.rds file to force a refit.
gam_SPRT_cache <- c(
  "GPT3"       = "scratch/gam_gpt3_RT.rds",
  "GPT2"       = "scratch/gam_gpt2_RT.rds",
  "GPT2-large" = "scratch/gam_gpt2l_RT.rds",
  "GPT-J"      = "scratch/gam_gptj_RT.rds",
  "GPT-Neo"    = "scratch/gam_gptneo_RT.rds"
)
dir.create("scratch", showWarnings = FALSE)
for (surp_model in names(gam_SPRT_cache)) {
  cache_file <- gam_SPRT_cache[[surp_model]]
  if (!file.exists(cache_file)) {
    write_rds(fit_bam_SPRT(formula1, surps.RTs_lag_c, surp_model), file = cache_file)
  }
}

gam_gpt3_RT   <- read_rds(gam_SPRT_cache[["GPT3"]])
gam_gpt2_RT   <- read_rds(gam_SPRT_cache[["GPT2"]])
gam_gpt2l_RT  <- read_rds(gam_SPRT_cache[["GPT2-large"]])
gam_gptj_RT   <- read_rds(gam_SPRT_cache[["GPT-J"]])
gam_gptneo_RT <- read_rds(gam_SPRT_cache[["GPT-Neo"]])
```

```{r, results = "hide"}
get_gam_predictions.prev_and_curr <- function(gam_model, surp_model_name, series_length=100, response_name=NULL) {
  #' Predicted response along current and previous surprisal from one GAM
  #'
  #' Calls tidymv::get_gam_predictions() along `surprisal` and along
  #' `prev_surprisal` and stacks both, tagged by `s` ("Current"/"Previous").
  #' The prediction column keeps the name of the model's response variable
  #' (left-hand side of the model formula). For example it is `response` for
  #' the SPRT GAMs and `mazeRT` for the maze GAMs.
  #'
  #' Also accepts the call form f(response_name, gam_model, surp_model_name),
  #' with the response name given first.
  #' @param gam_model fitted mgcv GAM with smooths of `surprisal` and `prev_surprisal`
  #' @param surp_model_name label stored in the `model` column
  #' @param series_length number of points along each surprisal series
  #' @param response_name name of the prediction column; defaults to the formula LHS
  #' @return tibble: surprisal, <response>, CI_upper, CI_lower, model, s
  if (is.character(gam_model)) {
    # Leading response-name form: shift the remaining arguments left by one.
    response_name <- gam_model
    gam_model <- surp_model_name
    surp_model_name <- series_length
    series_length <- 100
  }
  if (is.null(response_name)) {
    response_name <- as.character(gam_model$formula[[2]])
  }

  curr <- tidymv::get_gam_predictions(model = gam_model, series = surprisal,
                                      series_length = series_length) %>%
    select(surprisal, all_of(response_name), CI_upper, CI_lower) %>%
    distinct() %>%
    mutate(model = surp_model_name, s = "Current")

  prev <- tidymv::get_gam_predictions(model = gam_model, series = prev_surprisal,
                                      series_length = series_length) %>%
    select(surprisal = prev_surprisal, all_of(response_name), CI_upper, CI_lower) %>%
    distinct() %>%
    mutate(model = surp_model_name, s = "Previous")

  # The two halves differ in `s`, so stacking cannot create duplicate rows.
  bind_rows(curr, prev)
}

# Current/previous-surprisal predictions from every SPRT GAM, stacked.
all_gam_surpRT_predictions <- Reduce(
  union,
  list(
    get_gam_predictions.prev_and_curr(gam_gpt3_RT, "GPT3"),
    get_gam_predictions.prev_and_curr(gam_gpt2_RT, "GPT2"),
    get_gam_predictions.prev_and_curr(gam_gpt2l_RT, "GPT2-large"),
    get_gam_predictions.prev_and_curr(gam_gptj_RT, "GPT-J"),
    get_gam_predictions.prev_and_curr(gam_gptneo_RT, "GPT-Neo")
  )
)

# Fitted curve with CI ribbon; rows = current/previous word, columns = model.
plot.gam_surpRT <- all_gam_surpRT_predictions %>%
  ggplot(aes(x = surprisal, y = response, ymin = CI_lower, ymax = CI_upper)) +
  geom_line() +
  geom_ribbon(alpha = .3, fill = "blue") +
  # geom_density(data=surps.RTs_lag_c.perword, aes(x=surp, y=1000*..density..), fill="gray", inherit.aes = F)+
  facet_grid(s ~ model) +
  # coord_cartesian(ylim=c(330,580), xlim=c(0,15))+
  labs(x = "Surprisal", y = "Reading Time (ms)") +
  ggtitle("GAM fits for RTvsurp_formula1")

plot.gam_surpRT

# gptdens2 <- ggplot(surps.RTs_lag_c.perword, aes(x=surp_GPT3))+
#   geom_density(fill="gray",)+
#   labs(x="Surprisal", y="")+
#   theme(axis.text.y = element_blank(), 
#         strip.text=element_blank(), 
#         axis.ticks.y =element_blank(), 
#         axis.title.y=element_blank(), 
#         plot.margin = unit(c(0, 0, 0, 0), "cm"))
# gptbot <- plot_grid(NA, gptdens2, NA, nrow=1, rel_widths = c(.075, 1, .035))
# p<-plot_grid(plot.gam_surpRT, gptbot, nrow=2, rel_heights = c(1, .3))


```

# surprisal vs mazeRT

## Aligning with AMAZE TASK from Boyce

```{r, results = "hide"}
# get surprisal (and RT) dataset
# surps.RTs_lag_c.perword 
# from above

# joindata_maze_with_surps <- function(maze_task_data, surps_data) {
#   maze_task_data %>% rename(mazeRT=rt) %>% 
#     inner_join(surps_data, by=c("word_num_in_sentence", "word", "sentence")) %>%
#     filter(word_num_in_sentence>1) %>% 
#     select(mazeRT, subject, word_num_in_story, story_num, word,
#            freq,length, meanItemRT, sdItemRT,
#            contains("surp_") | starts_with("prev") | ends_with("_c"),
#            ) %>% 
#     mutate(Word_ID=as_factor(str_c(story_num, word_num_in_story, sep="_"))) %>% relocate(Word_ID, word)
#     # select(-word_num_in_story, -story_num) %>% write_rds("scratch/processed_amaze/pre_error.rds")
# }
joindata_maze_with_surps <- function(maze_task_data, surps_data) {
  #' Join Maze-task RTs with per-word surprisal data
  #'
  #' Renames the maze `rt` to `mazeRT` and its `word` to `word_human`, then
  #' full-joins on sentence, word position in sentence and `word_human`. It
  #' drops the first word of every sentence, keeps the RT, subject, word,
  #' freq/length, per-item SPRT summaries and all surprisal / prev_ / centred
  #' columns, and adds a `story_word` Word_ID factor in front.
  #' @param maze_task_data maze RTs with `rt`, `word`, `subject`, `sentence`,
  #'   `word_num_in_sentence`
  #' @param surps_data per-word predictors (e.g. surps.RTs_lag_c.perword)
  maze_renamed <- rename(maze_task_data, mazeRT = rt, word_human = word)
  joined <- full_join(maze_renamed, surps_data,
                      by = c("word_num_in_sentence", "word_human", "sentence"))
  joined %>%
    filter(word_num_in_sentence > 1) %>%
    select(mazeRT, subject, word_num_in_story, story_num, word,
           freq, length, meanItemRT, sdItemRT,
           contains("surp_") | starts_with("prev") | ends_with("_c")) %>%
    mutate(Word_ID = as_factor(str_c(story_num, word_num_in_story, sep = "_"))) %>%
    relocate(Word_ID, word)
    # select(-word_num_in_story, -story_num) %>% write_rds("scratch/processed_amaze/pre_error.rds")
}
```


```{r results='hide'}
# Data processing as in Boyce and Levy, in data_processing_mazeRTs.Rmd
maze_pre_error <- "natural-stories-surprisals/maze_data/maze_pre_error.csv" %>%
  read_csv()


# pre_error: maze RTs joined with the per-word predictors, saved to disk
surps.mazeRTs_lag_c <- joindata_maze_with_surps(maze_pre_error, surps.RTs_lag_c.perword)
surps.mazeRTs_lag_c %>%
  write_csv("natural-stories-surprisals/surprisals_and_RTs_data/surps.mazeRTs_lag_c.csv")

# # post_error (see data_processing_mazeRTs.Rmd)
# surps.RTs_lag_c.maze_post_error <- joindata_maze_with_surps(maze_ready, surps.RTs_lag_c.perword)
# 
# # post_only  (see data_processing_mazeRTs.Rmd)
# surps.RTs_lag_c.maze_post_only <- joindata_maze_with_surps(maze_post_only, surps.RTs_lag_c.perword)

```

```{r}
surps.mazeRTs_lag_c %>%
  pivot_longer(
    starts_with("surp_") & !ends_with("_c"),
    names_to = "model",
    names_prefix = "surp_",
    values_to = "surprisal"
  ) %>%
  drop_na() %>%
  # One row per (1-unit surprisal bin, model): observation count, distinct
  # Word_IDs, smallest surprisal in the bin, and the variance of mazeRT.
  group_by(surprisalbin = cut_width(surprisal, boundary = 0, width = 1), model) %>%
  summarise(
    n = n(),
    n_Word_IDs = n_distinct(Word_ID),
    binned_surprisal_min = min(surprisal),
    var = var(mazeRT),
    .groups = "drop"
  ) %>%
  # drop bins whose observations all come from a single corpus word
  filter(n_Word_IDs > 1) %>%
  ggplot(aes(x = binned_surprisal_min, y = var, color = log(n))) +
  geom_point() +
  geom_smooth(method = "lm", formula = y ~ x, color = "red",
              linetype = "solid", size = .25, alpha = .25) +
  facet_wrap(~model, scales = "free_y") +
  labs(x = "binned surprisal", y = "variance in mazeRT") +
  scale_color_continuous(low = "#DDDDDD", high = "black") +
  ggtitle("variance in mazeRT vs binned surprisal")
```



```{r}
# Maze RTs joined with surprisals: surps.mazeRTs_lag_c is the object built by
# joindata_maze_with_surps(); the previous name surps.RTs_lag_c.maze_pre_error
# is never defined.
(
  surps.mazeRTs_lag_c %>%
    plot_compare_models_smooth("mazeRT", # surp_threshold = 30,
                               method = "gam", formula = y ~ s(x, bs = "tp")) +
    ggtitle("thin plate")
) + (
  surps.mazeRTs_lag_c %>%
    plot_compare_models_smooth("mazeRT", # surp_threshold = 30,
                               method = "gam", formula = y ~ s(x, bs = "cr")) +
    ggtitle("cubic regression")
) +
  plot_layout(guides = 'collect') +
  plot_annotation("Maze RT vs surprisal GAMs quick model comparison")
```


## GAMs for mazeRT

```{r gam_tibbles_permodel,results = "hide"}


prepare_data_GAM_mazeRT <- function(surps.mazeRTs_df, modelname, surp_colname, prev_surp_colname) {
  #' Select and rename the columns a maze-RT GAM needs, for one surprisal model
  #'
  #' @param surps.mazeRTs_df maze RTs joined with per-word predictors
  #'   (output of joindata_maze_with_surps())
  #' @param modelname label stored in the `model` column
  #' @param surp_colname column holding this model's surprisal
  #' @param prev_surp_colname column holding the previous word's surprisal
  #' @return tibble with mazeRT, subject (factor), Word_ID, surprisal,
  #'   prev_surprisal, centred freq/length/prev_freq/prev_length, model
  surps.mazeRTs_df %>%
    select(mazeRT, subject, Word_ID,
           surprisal=all_of(surp_colname), prev_surprisal=all_of(prev_surp_colname),
           freq=freq_c, length=length_c, prev_freq=prev_freq_c, prev_length=prev_length_c) %>%
    mutate(model=modelname, subject=factor(subject))
}

# One GAM data set per surprisal model, from the maze data built by
# joindata_maze_with_surps(). The previous-word column is always "prev_" +
# the surprisal column (from the lagging step on surps.RTs_lag_c.perword).
# Boyce's surprisals are in surp_boyce_<lm>.
prepare_maze_for_model <- function(modelname, surp_colname) {
  prepare_data_GAM_mazeRT(surps.mazeRTs_lag_c,
                          modelname = modelname,
                          surp_colname = surp_colname,
                          prev_surp_colname = paste0("prev_", surp_colname))
}

Bngram_maze  <- prepare_maze_for_model("B_5-gram",   "surp_boyce_ngram")
Bgrnn_maze   <- prepare_maze_for_model("B_GRNN",     "surp_boyce_grnn")
Btxl_maze    <- prepare_maze_for_model("B_TXL",      "surp_boyce_txl")

gpt2_maze    <- prepare_maze_for_model("GPT2",       "surp_GPT2")
gpt2l_maze   <- prepare_maze_for_model("GPT2-large", "surp_GPT2-large")
gptj_maze    <- prepare_maze_for_model("GPT-J",      "surp_GPT-J")
gptneo_maze  <- prepare_maze_for_model("GPT-Neo",    "surp_GPT-Neo")
gpt3_maze    <- prepare_maze_for_model("GPT3",       "surp_GPT3")


# Long-format surprisal values (current and previous word) for every model,
# e.g. for density plots. bind_rows() keeps every row as-is; union() would
# also spend time de-duplicating rows across these large tables.
all_maze <-
  bind_rows(Bngram_maze, Bgrnn_maze, Btxl_maze,
            gpt2_maze, gpt2l_maze, gptj_maze, gptneo_maze, gpt3_maze) %>%
  select(surprisal, prev_surprisal, model) %>%
  pivot_longer(cols = `surprisal`:`prev_surprisal`, names_to = "surprisal_or_prev_surprisal") %>%
  mutate(s = ifelse(surprisal_or_prev_surprisal == "surprisal", "Current", "Previous"))
```

```{r gams, results = "hide", include=F}
# Maze-RT GAM formula: the same terms as formula1 (Boyce and Levy's), with
# mazeRT as the response. Term order fixes the smooth indices:
# 1 = s(surprisal), 3 = s(prev_surprisal), which draw_gam_both() selects.
mazeRTvsurp_formula1 <- mazeRT ~
  s(surprisal, bs="cr", k=20) + te(freq, length, bs="cr") +
  s(prev_surprisal, bs="cr", k=20) + te(prev_freq, prev_length, bs="cr")

# job::job({
#   gam_Bngram_maze <- bam(mazeRTvsurp_formula1, method="REML", data=Bngram_maze)
#   }, title = "gam_ngram_maze")
# job::job({
#   gam_Bgrnn_maze <- bam(mazeRTvsurp_formula1, method="REML", data=Bgrnn_maze)
#   }, title = "gam_grnn_maze")
# job::job({
#   gam_Btxl_maze <- bam(mazeRTvsurp_formula1, method="REML", data=Btxl_maze)
#   }, title = "gam_txl_maze")
# 
# job::job({
#   gam_gpt2_maze <- bam(mazeRTvsurp_formula1, method="REML", data=gpt2_maze)
#   }, title = "gam_GPT2_maze")
# job::job({
#   gam_gpt2l_maze <- bam(mazeRTvsurp_formula1, method="REML", data=gpt2l_maze)
#   }, title = "gam_GPT2-large_maze")
# job::job({
#   gam_gptj_maze <- bam(mazeRTvsurp_formula1, method="REML", data=gptj_maze)
#   }, title = "gam_GPT2-J_maze")
# job::job({
#   gam_gptneo_maze <- bam(mazeRTvsurp_formula1, method="REML", data=gptneo_maze)
#   }, title = "gam_GPT2-Neo_maze")
# job::job({
#   gam_gpt3_maze <- bam(mazeRTvsurp_formula1, method="REML", data=gpt3_maze)
#   }, title = "gam_GPT3_maze")

# 
# write_rds(gam_Bngram_maze, file="scratch/gam_Bngram_maze.rds")
# write_rds(gam_Bgrnn_maze, file="scratch/gam_Bgrnn_maze.rds")
# write_rds(gam_Btxl_maze, file="scratch/gam_Btxl_maze.rds")
# write_rds(gam_gpt2_maze, file="scratch/gam_gpt2_maze.rds")
# write_rds(gam_gpt2l_maze, file="scratch/gam_gpt2l_maze.rds")
# write_rds(gam_gptj_maze, file="scratch/gam_gptj_maze.rds")
# write_rds(gam_gptneo_maze, file="scratch/gam_gptneo_maze.rds")
# write_rds(gam_gpt3_maze, file="scratch/gam_gpt3_maze.rds")
# 
# Load the cached maze-RT GAM fits written by the (commented-out) fitting jobs.
gam_Bngram_maze <- "scratch/gam_Bngram_maze.rds" %>% read_rds()
gam_Bgrnn_maze  <- "scratch/gam_Bgrnn_maze.rds"  %>% read_rds()
gam_Btxl_maze   <- "scratch/gam_Btxl_maze.rds"   %>% read_rds()
gam_gpt2_maze   <- "scratch/gam_gpt2_maze.rds"   %>% read_rds()
gam_gpt2l_maze  <- "scratch/gam_gpt2l_maze.rds"  %>% read_rds()
gam_gptj_maze   <- "scratch/gam_gptj_maze.rds"   %>% read_rds()
gam_gptneo_maze <- "scratch/gam_gptneo_maze.rds" %>% read_rds()
gam_gpt3_maze   <- "scratch/gam_gpt3_maze.rds"   %>% read_rds()

```


```
# TODO: Morgan suggested adding random effects like this.
# Column names match the output of prepare_data_GAM_mazeRT(): mazeRT, subject
# (factor), Word_ID (factor), surprisal, prev_surprisal, freq, length,
# prev_freq, prev_length.
# NOTE(review): bs="fs" smooths already include per-subject intercepts, so
# s(subject, bs="re") may be redundant; also gam() may be slow on this much
# data (bam() is the large-data variant) — confirm before running.
job::job({
  gam_gpt3_maze_withrandomeffects <- gam(
    mazeRT ~
      s(surprisal, subject, bs="fs") + s(subject, bs='re') + s(Word_ID, bs='re') +
      s(surprisal, bs="cr", k=20) +
      te(freq, length, bs="cr") +
      s(prev_surprisal, bs="cr", k=20) +
      te(prev_freq, prev_length, bs="cr"),
    data=gpt3_maze,
    method="REML")
},
title = "gam_gpt3_maze_withrandomeffects")
```
```{r, include=F}


# Current/previous-surprisal predictions from every maze GAM, stacked.
all_gam_surpRTmaze_predictions <- Reduce(
  union,
  list(
    get_gam_predictions.prev_and_curr("mazeRT", gam_Bngram_maze, "B_5-gram"),
    get_gam_predictions.prev_and_curr("mazeRT", gam_Bgrnn_maze, "B_GRNN"),
    get_gam_predictions.prev_and_curr("mazeRT", gam_Btxl_maze, "B_TXL"),
    get_gam_predictions.prev_and_curr("mazeRT", gam_gpt3_maze, "GPT3"),
    get_gam_predictions.prev_and_curr("mazeRT", gam_gpt2_maze, "GPT2"),
    get_gam_predictions.prev_and_curr("mazeRT", gam_gpt2l_maze, "GPT2-large"),
    get_gam_predictions.prev_and_curr("mazeRT", gam_gptj_maze, "GPT-J"),
    get_gam_predictions.prev_and_curr("mazeRT", gam_gptneo_maze, "GPT-Neo")
  )
)

# Non-Boyce models (model labels not starting with "B_").
plot.gam_surpRTmaze <- all_gam_surpRTmaze_predictions %>%
  filter(!str_detect(model, "^B_")) %>%
  ggplot(aes(x = surprisal, y = mazeRT, ymin = CI_lower, ymax = CI_upper)) +
  geom_line() +
  geom_ribbon(alpha = .3, fill = "blue") +
  # geom_density(data=surps.RTs_lag_c.perword, aes(x=surp, y=1000*..density..), fill="gray", inherit.aes = F)+
  facet_grid(s ~ model) +
  # coord_cartesian(ylim=c(330,580), xlim=c(0,15))+
  labs(x = "Surprisal", y = "Maze Reaction Time (ms)") +
  ggtitle("GAM fits for mazeRTvsurp_formula1")

# replicating Boyce: only the "B_" models.
plot.gam_surpRTmaze_boyce <- all_gam_surpRTmaze_predictions %>%
  filter(str_detect(model, "^B_")) %>%
  ggplot(aes(x = surprisal, y = mazeRT, ymin = CI_lower, ymax = CI_upper)) +
  geom_line() +
  geom_ribbon(alpha = .3, fill = "blue") +
  # geom_density(data=surps.RTs_lag_c.perword, aes(x=surp, y=1000*..density..), fill="gray", inherit.aes = F)+
  facet_grid(s ~ model) +
  # coord_cartesian(ylim=c(330,580), xlim=c(0,15))+
  labs(x = "Surprisal", y = "Maze Reaction Time (ms)") +
  ggtitle("GAM fits for mazeRTvsurp_formula1 (replicating Boyce)")


# dens2 <-   ggplot(all_maze, aes(x=value))+
#   geom_density(fill="gray",)+
#   facet_grid(.~model)+
#   labs(x="Surprisal (bits)", y="")+
#   theme(axis.text.y = element_blank(), strip.text=element_blank(), axis.ticks.y =element_blank(), axis.title.y=element_blank(), plot.margin = unit(c(0, 0, 0, 0), "cm"))
# 
# bot <- plot_grid(NA, dens2, NA, nrow=1, rel_widths = c(.075, 1, .035))
# 
# plot_grid(plot.gam_surpRTmaze, bot, nrow=2, rel_heights = c(1, .3))

```

```{r}
# Maze GAM fits for the non-Boyce ("B_"-less) surprisal models.
plot.gam_surpRTmaze
```

```{r}
# Maze GAM fits for Boyce's surprisal models ("B_" labels).
plot.gam_surpRTmaze_boyce
```


```{r}
draw_gam_both <- function(gam_model) {
  #' GAM diagnostics plus its current- and previous-surprisal smooths
  #'
  #' Runs mgcv::gam.check(), which prints its own report and draws residual
  #' plots, then draws smooths 1 and 3 with gratia::draw(). With formula1 /
  #' mazeRTvsurp_formula1 these are s(surprisal) and s(prev_surprisal).
  #' The smooths are shifted by the model intercept, so the y axis is on the
  #' response scale, and the intercept's uncertainty is included in the bands.
  #' `constant` / `overall_uncertainty` are gratia's counterparts of
  #' mgcv::plot.gam's `shift` / `seWithMean`, which gratia::draw() ignores.
  #' @param gam_model fitted mgcv gam/bam
  #' @return the plot object returned by gratia::draw()
  gam.check(gam_model)
  draw(gam_model,
       overall_uncertainty = TRUE,
       constant = coef(gam_model)[["(Intercept)"]],
       select = c(1, 3))
}
```
