pilot_data_analysis

Scan-WM-DG Pilot Analysis

Setup (Load packages, helper functions, and data)

Load packages

library(tidyverse)
library(ggplot2)
library(here)
library(conflicted)
library(lme4)
library(lmerTest)
library(mgcv)
library(patchwork)
library(knitr)
library(RColorBrewer)
library(Hmisc)
library(slider)

# Resolve masking conflicts between the attached packages in favour of dplyr
# (and lmerTest for p-values on lmer fits).
conflict_prefer("filter", "dplyr")
conflict_prefer("select", "dplyr")
conflict_prefer("summarise", "dplyr")
conflict_prefer("summarize", "dplyr")
conflict_prefer("lmer", "lmerTest")

# Fix the RNG so random steps such as the bootstrapped CIs (mean_cl_boot)
# are reproducible across runs.
set.seed(42)

root_dir <- here::here()
analysis_dir <- here("pilot_experiment_analysis")

# Toggle for writing figures (and the session info) to disk.
SAVE_PLOTS <- TRUE

if (SAVE_PLOTS) {
  # Record package versions alongside the saved outputs.
  writeLines(capture.output(sessionInfo()),
             file.path(analysis_dir, "sessionInfo.txt"))
}

Define helper functions

# Bin a variable by version and summarise accuracy or RT
# Bin a variable by version and summarise accuracy or RT.
#
# Within each `version`, `x_var` is cut into `n_bins` equal-width intervals
# (ggplot2::cut_interval). Each (version, bin) cell is then reduced to the
# mean of `y_var` and the mean of `x_var`, ignoring NAs.
#
# data:   data frame with a `version` column plus the two named columns.
# x_var:  name (string) of the variable to bin.
# y_var:  name (string) of the outcome to average.
# n_bins: number of equal-width bins per version.
# Returns an ungrouped data frame with columns version, sim_bin, mean_y, mean_x.
bin_by_version <- function(data, x_var, y_var, n_bins) {
  x_col <- sym(x_var)
  y_col <- sym(y_var)

  binned <- group_by(data, version)
  binned <- mutate(binned, sim_bin = cut_interval(!!x_col, n = n_bins))
  binned <- group_by(binned, version, sim_bin)

  summarise(
    binned,
    mean_y  = mean(!!y_col, na.rm = TRUE),
    mean_x  = mean(!!x_col, na.rm = TRUE),
    .groups = "drop"
  )
}

# Plot accuracy or RT as a function of a similarity metric
# Plot accuracy or RT as a function of a similarity metric.
#
# Draws the pre-binned means in `binned_data` (columns mean_x, mean_y,
# version, e.g. the output of bin_by_version()) as points. A per-version
# GAM smooth `y ~ s(x)` fitted to the trial-level `data` is overlaid.
#
# family: GLM family handed to the "gam" smoother via method.args. The default
#         binomial suits 0/1 accuracy; presumably something like gaussian
#         would be passed for RT.
plot_sim <- function(data, binned_data, x_var, y_var,
                     x_lab, y_lab, family = binomial) {
  trial_aes <- aes(x = !!sym(x_var), y = !!sym(y_var), color = version)

  bin_points <- geom_point(
    data    = binned_data,
    mapping = aes(x = mean_x, y = mean_y, color = version),
    alpha   = 0.6
  )
  gam_smooth <- geom_smooth(
    method      = "gam",
    method.args = list(family = family),
    formula     = y ~ s(x)
  )

  ggplot(data, trial_aes) +
    bin_points +
    gam_smooth +
    labs(x = x_lab, y = y_lab) +
    theme_light()
}

# Centered moving average of `y_var` along `x_var`, computed separately for
# each version.
#
# Rows are sorted by `x_var` within each version. Each row's `ma_y` is the
# mean of `y_var` over three parts: the `window %/% 2` rows before it, the
# row itself, and the `window %/% 2` rows after it. Windows are truncated at
# the edges because .complete = FALSE. For an even `window` the span is
# therefore window + 1 rows. NA values of `y_var` (e.g. missing responses)
# are skipped, so one missing trial does not turn every window containing
# it into NA.
#
# Returns `data` ungrouped and re-ordered by version and then `x_var`, with
# two added columns: ma_y (the moving average) and ma_x (a copy of `x_var`).
moving_avg_by_version <- function(data, x_var, y_var, window = 25) {
  x_sym <- sym(x_var)
  y_sym <- sym(y_var)
  data %>%
    group_by(version) %>%
    arrange(!!x_sym, .by_group = TRUE) %>%
    mutate(
      # extra args after .f (na.rm) are forwarded by slide_dbl to mean()
      ma_y = slide_dbl(!!y_sym, mean, na.rm = TRUE,
                       .before = window %/% 2, .after = window %/% 2,
                       .complete = FALSE),
      ma_x = !!x_sym
    ) %>%
    ungroup()
}

# Plot accuracy or RT as a function of a similarity metric
# Plot accuracy or RT as a function of a similarity metric, using a moving
# average instead of bins.
#
# `ma_data` must carry ma_x, ma_y and version (e.g. the output of
# moving_avg_by_version()). It is drawn as a solid line per version. A dashed
# per-version GAM smooth `y ~ s(x)`, fitted to the trial-level `data`, is
# drawn on top. `family` is passed to the smoother through method.args.
plot_sim_mov_avg <- function(data, ma_data, x_var, y_var,
                             x_lab, y_lab, family = binomial) {
  ma_line <- geom_line(
    data      = ma_data,
    mapping   = aes(x = ma_x, y = ma_y, color = version),
    linewidth = 0.8,
    alpha     = 0.8
  )
  gam_smooth <- geom_smooth(
    method      = "gam",
    method.args = list(family = family),
    formula     = y ~ s(x),
    linetype    = "dashed",
    alpha       = 0.2
  )

  ggplot(data, aes(x = !!sym(x_var), y = !!sym(y_var), color = version)) +
    ma_line +
    gam_smooth +
    labs(x = x_lab, y = y_lab) +
    theme_light()
}

# Smooth-only variant of plot_sim_mov_avg(): draws just the dashed
# per-version GAM smooth `y ~ s(x)` fitted to the trial-level `data`.
#
# NOTE: `ma_data` is accepted so calls stay interchangeable with
# plot_sim_mov_avg(), but it is not used.
plot_sim_mov_avg_simplify <- function(data, ma_data, x_var, y_var,
                                      x_lab, y_lab, family = binomial) {
  smooth_layer <- geom_smooth(
    method      = "gam",
    method.args = list(family = family),
    formula     = y ~ s(x),
    linetype    = "dashed",
    alpha       = 0.2
  )

  ggplot(data, aes(x = !!sym(x_var), y = !!sym(y_var), color = version)) +
    smooth_layer +
    labs(x = x_lab, y = y_lab) +
    theme_light()
}

# Fit linear, quadratic, cubic glmer and return AIC/BIC tables
# Fit linear, quadratic and cubic binomial glmer models of `response` on
# `x_var` and return the fits together with AIC/BIC comparison tables.
#
# data:     trial-level data frame containing `response`, `x_var` and the
#           grouping variables named in `random`.
# x_var:    name (string) of the predictor. Degree 1 uses the raw predictor;
#           degrees 2 and 3 use poly(), i.e. orthogonal polynomials.
# random:   random-effects part of the formula, as a string.
# response: name (string) of the binary outcome column. Defaults to
#           "correct", which was previously hard-coded.
#
# Returns a list with `models` (linear / quadratic / cubic fits) and `aic` /
# `bic` data frames whose rows m1, m2, m3 correspond to those fits.
compare_poly_models <- function(data, x_var,
                                random = "(1 | unique_id) + (1 | triplet_id)",
                                response = "correct") {
  make_formula <- function(degree) {
    pred <- if (degree == 1) x_var else paste0("poly(", x_var, ", ", degree, ")")
    as.formula(paste(response, "~", pred, "+", random))
  }
  m1 <- glmer(make_formula(1), data = data, family = binomial)
  m2 <- glmer(make_formula(2), data = data, family = binomial)
  m3 <- glmer(make_formula(3), data = data, family = binomial)

  list(
    models = list(linear = m1, quadratic = m2, cubic = m3),
    aic    = AIC(m1, m2, m3),
    bic    = BIC(m1, m2, m3)
  )
}

# Wrapper around ggsave with consistent defaults
# Wrapper around ggsave with consistent defaults.
#
# Saves `plot` to here(plot_dir, filename) as a 300-dpi image of w x h
# inches. The target directory is created first if it is missing, because
# ggsave() does not create directories by default.
save_plot <- function(plot, filename, plot_dir, w = 5, h = 4) {
  out_path <- here(plot_dir, filename)
  dir.create(dirname(out_path), recursive = TRUE, showWarnings = FALSE)
  ggsave(
    filename = out_path,  # exact arg name; `file =` relied on partial matching
    plot     = plot,
    width    = w, height = h, units = "in", dpi = 300
  )
}

Load and combine data across versions

# Per-version data directories and file-name prefixes
data_dir_v1 <- here("pilot_experiment_analysis", "data_v1")
data_dir_v2 <- here("pilot_experiment_analysis", "data_v2")
data_dir_v3 <- here("pilot_experiment_analysis", "data_v3")
data_dir_v4 <- here("pilot_experiment_analysis", "data_v4")

exp_string_v1 <- "scan_wm_dg_pilot_v1"
exp_string_v2 <- "scan_wm_dg_pilot_v2"
exp_string_v3 <- "scan_wm_dg_pilot_v3"
exp_string_v4 <- "scan_wm_dg_pilot_v4"

# Output directory for figures that combine all versions
plot_dir_all <- here("pilot_experiment_analysis", "plots_combined")

# Processed working-memory (WM) data, tagged with its pilot version
v1_wm <- file.path(data_dir_v1, paste0(exp_string_v1, "_wm_data_proc.csv")) %>%
  read.csv() %>%
  mutate(version = "v1")
v2_wm <- file.path(data_dir_v2, paste0(exp_string_v2, "_wm_data_proc.csv")) %>%
  read.csv() %>%
  mutate(version = "v2")
v3_wm <- file.path(data_dir_v3, paste0(exp_string_v3, "_wm_data_proc.csv")) %>%
  read.csv() %>%
  mutate(version = "v3")
v4_wm <- file.path(data_dir_v4, paste0(exp_string_v4, "_wm_data_proc.csv")) %>%
  read.csv() %>%
  mutate(version = "v4")

# Processed long-term-memory (LTM) data, tagged with its pilot version
v1_ltm <- file.path(data_dir_v1, paste0(exp_string_v1, "_ltm_data_proc.csv")) %>%
  read.csv() %>%
  mutate(version = "v1")
v2_ltm <- file.path(data_dir_v2, paste0(exp_string_v2, "_ltm_data_proc.csv")) %>%
  read.csv() %>%
  mutate(version = "v2")
v3_ltm <- file.path(data_dir_v3, paste0(exp_string_v3, "_ltm_data_proc.csv")) %>%
  read.csv() %>%
  mutate(version = "v3")
v4_ltm <- file.path(data_dir_v4, paste0(exp_string_v4, "_ltm_data_proc.csv")) %>%
  read.csv() %>%
  mutate(version = "v4")

# Stack the versions. Participant IDs may repeat across versions, so build a
# version-prefixed unique_id ("<version>_<participant_id>").
wm_data <- rbind(v1_wm, v2_wm, v3_wm, v4_wm) %>%
  unite("unique_id", version, participant_id, sep = "_", remove = FALSE)
ltm_data <- rbind(v1_ltm, v2_ltm, v3_ltm, v4_ltm) %>%
  unite("unique_id", version, participant_id, sep = "_", remove = FALSE)

n_subj <- n_distinct(wm_data$unique_id)

# Number of distinct participants contributing WM data, per version
n_subj_by_version <- wm_data %>%
  group_by(version) %>%
  dplyr::summarise(subj = n_distinct(unique_id))

kable(n_subj_by_version)
version subj
v1 26
v2 24
v3 21
v4 47
# Drop the per-version intermediates now that they have been combined.
rm(list = c("v1_wm", "v1_ltm", "v2_wm", "v2_ltm",
            "v3_wm", "v3_ltm", "v4_wm", "v4_ltm",
            "data_dir_v1", "data_dir_v2", "data_dir_v3", "data_dir_v4",
            "exp_string_v1", "exp_string_v2", "exp_string_v3", "exp_string_v4"))

Wrangling

Target/foil similarity metrics

# Re-express the img1/img2-based trial metrics relative to the target and
# foil, then classify each response by what was chosen.
wm_data <- wm_data %>%
  mutate(correct = as.numeric(correct)) %>%
  # Target/foil versions of the similarity metrics. Neutral trials get NA.
  # `target_index != "img1"` selects trials whose target was not img1.
  mutate(
    V2_targ_root = ifelse(
      is_neutral, NA,
      ifelse(target_index != "img1", V2_root_im2, V2_root_im1)
    ),
    V2_foil_root = ifelse(
      is_neutral, NA,
      ifelse(target_index != "img1", V2_root_im1, V2_root_im2)
    ),
    IT_targ_root = ifelse(
      is_neutral, NA,
      ifelse(target_index != "img1", IT_root_im2, IT_root_im1)
    ),
    IT_foil_root = ifelse(
      is_neutral, NA,
      ifelse(target_index != "img1", IT_root_im1, IT_root_im2)
    ),
    # signed_*_diff is used as-is when the target is img1 (so presumably it is
    # oriented img1-minus-img2) and negated otherwise.
    V2_targ_minus_foil = ifelse(
      is_neutral, NA,
      ifelse(target_index != "img1", -signed_V2_diff, signed_V2_diff)
    ),
    IT_targ_minus_foil = ifelse(
      is_neutral, NA,
      ifelse(target_index != "img1", -signed_IT_diff, signed_IT_diff)
    )
  ) %>%
  # Response relative to the target. The options are the target, the foil
  # (the other sample image), a distractor on the target's side (dist0/dist1
  # pair with img1, dist2/dist3 with img2), or a distractor on the foil's
  # side. Anything else is NA.
  mutate(choice_label_targ_foil = case_when(
    choice_label == target_index ~ "target",
    choice_label == ifelse(target_index != "img1", "img1", "img2") ~ "foil",
    target_index == "img1" & choice_label %in% c("dist0", "dist1") ~ "dist_match",
    target_index == "img2" & choice_label %in% c("dist2", "dist3") ~ "dist_match",
    target_index == "img2" & choice_label %in% c("dist0", "dist1") ~ "dist_nonmatch",
    target_index == "img1" & choice_label %in% c("dist2", "dist3") ~ "dist_nonmatch"
  )) %>%
  # Defined for errors only. TRUE when the chosen item was a matching
  # distractor (dist_match); FALSE for foil or non-matching distractor
  # choices; NA on correct trials.
  mutate(within_category_error = case_when(
    correct == 1 ~ NA,
    choice_label_targ_foil %in% c("foil", "dist_nonmatch") ~ FALSE,
    choice_label_targ_foil == "dist_match" ~ TRUE
  ))

LTM image role information from WM task

# One row per (WM trial, image) for the two sample images and the root,
# labelled with that image's role on the trial. This lets each LTM test item
# be joined back to its WM history by stimulus path.
wm_long <- wm_data %>%
  mutate(foil_index = ifelse(target_index == "img1", "img2", "img1")) %>%
  select(version, unique_id, triplet_id, cue_valid, wm_correct = correct,
         is_neutral, target_index, foil_index, img1, img2, root) %>%
  pivot_longer(
    cols      = c(img1, img2, root),
    names_to  = "image_role_raw",
    values_to = "stimulus_wm"
  ) %>%
  mutate(
    image_role = case_when(
      image_role_raw == target_index ~ "target",
      image_role_raw == foil_index   ~ "foil",
      image_role_raw == "root"       ~ "root"
    ),
    # Map the WM stimulus path onto the LTM one ("stimuli_wm" -> "stimuli_ltm")
    stimulus = str_replace(stimulus_wm, "stimuli_wm", "stimuli_ltm")
  ) %>%
  select(version, unique_id, triplet_id, cue_valid, wm_correct, is_neutral,
         image_role, stimulus)

ltm_with_wm_info <- ltm_data %>%
  left_join(wm_long, by = c("version", "unique_id", "stimulus")) %>%
  # Seen LTM items whose joined cue_valid is NA (e.g. no matching WM row)
  mutate(wm_response_missing = seen & is.na(cue_valid))

# Study-design list giving the seen/unseen type of every LTM image
ltm_stim_dir <- here("scripts_for_study_design")
ltm_stim_list <- file.path(ltm_stim_dir, "jsPsych_full_ltm_images.csv") %>%
  read.csv() %>%
  rename(stimulus = image) %>%
  select(stimulus, seen_unseen_type)

ltm_data <- ltm_with_wm_info %>%
  left_join(ltm_stim_list, by = "stimulus")

rm(ltm_with_wm_info, ltm_stim_list, wm_long)

Are prioritized images better remembered?

WM task: Yes!

Note: a number of participants appeared to be essentially dropping the uncued item…

# Per-participant WM summary, split by cue validity
summary_valid <- wm_data %>%
  group_by(version, unique_id, cue_valid) %>%
  dplyr::summarise(
    mean_acc            = mean(correct, na.rm = TRUE),
    mean_rt             = mean(trial_rt, na.rm = TRUE),
    prop_within_cat_err = mean(within_category_error, na.rm = TRUE),
    # proportion of trials on which the foil was reported
    prop_swap           = mean(choice_label_targ_foil == "foil", na.rm = TRUE),
    .groups             = "drop"
  )

# Jittered per-participant values of `y_var` by cue validity, one facet per
# version, with the bootstrapped mean and CI (mean_cl_boot) overlaid in black.
make_validity_plot <- function(data, y_var, y_lab, title) {
  mapping <- aes(x = cue_valid, y = !!sym(y_var), color = cue_valid)
  ci_layer <- stat_summary(fun.data = mean_cl_boot, geom = "errorbar",
                           width = 0.2, color = "black")
  mean_layer <- stat_summary(fun = mean, geom = "point", size = 3,
                             color = "black")

  ggplot(data, mapping) +
    geom_jitter(width = 0.1, alpha = 0.6) +
    facet_wrap(~version) +
    labs(x = "Cue validity", y = y_lab, title = title) +
    ci_layer +
    mean_layer +
    guides(color = "none") +
    theme_light()
}

# One figure per WM outcome, each split by cue validity and faceted by version.
p_acc <- make_validity_plot(summary_valid, "mean_acc",
                            "Mean accuracy", "WM: Acc by cue validity")
p_rt  <- make_validity_plot(summary_valid, "mean_rt",
                            "Mean response time", "WM: RT by cue validity")
p_err <- make_validity_plot(summary_valid, "prop_within_cat_err",
                            "Prop. within-category errors", "WM: Errors by cue validity")
# prop_swap is the proportion of trials on which the foil was reported
p_swap <- make_validity_plot(summary_valid, "prop_swap",
                             "Prop. swap errors", "WM: Swaps by cue validity")

p_acc

p_rt

p_err

p_swap

if (SAVE_PLOTS) {
  save_plot(p_acc, "acc_validity_wm.png", plot_dir_all, w = 6)
  save_plot(p_rt, "rt_validity_wm.png",  plot_dir_all, w = 6)
  save_plot(p_err, "err_validity_wm.png", plot_dir_all, w = 6)
  save_plot(p_swap, "swap_validity_wm.png", plot_dir_all, w = 6)
}

rm(summary_valid)
rm(p_acc, p_rt, p_err, p_swap)

How does the role of the image in WM affect LTM?

We’re breaking things down by cue validity!

# Display labels for WM image roles (seen items) and lure types (unseen items)
role_labels_seen <- c(root = "Root", target = "Target", foil = "Foil")
role_labels_unseen <- c(unseen_new_category = "New Cat.",
                        unseen_old_category = "Old Cat.")

# Seen images: per-participant LTM accuracy by WM image role and cue validity
ltm_data_role_summary1 <- ltm_data %>%
  dplyr::filter(!is.na(image_role)) %>%
  group_by(version, unique_id, image_role, cue_valid) %>%
  dplyr::summarise(mean_acc = mean(correct, na.rm = TRUE), .groups = "drop")

# Unseen images: per-participant LTM accuracy (correct rejections) by lure type
ltm_data_role_summary2 <- ltm_data %>%
  dplyr::filter(!seen) %>%
  group_by(version, unique_id, seen_unseen_type) %>%
  dplyr::summarise(mean_acc = mean(correct, na.rm = TRUE), .groups = "drop")

# Shared dodge width. The jittered points (position_jitterdodge) and the black
# summary layers (position_dodge) must use the same value to line up.
dodge_width <- 0.8
dodge <- position_dodge(width = dodge_width)

# Seen items: LTM accuracy by cue condition, dodged by WM image role
p_seen_role <- ltm_data_role_summary1 %>%
  ggplot(aes(x = cue_valid, y = mean_acc,
             group = fct_reorder(image_role, mean_acc),
             color = image_role)) +
  geom_jitter(alpha = 0.6, position = position_jitterdodge(
    jitter.width = 0.1,
    dodge.width = dodge_width
  )) +
  facet_wrap(~version) +
  labs(x = "Cue condition", y = "Mean accuracy", color = "Image role") +
  scale_x_discrete(labels = c("FALSE" = "Invalid", "TRUE" = "Valid")) +
  scale_color_brewer(palette = "Dark2", labels = role_labels_seen) +
  stat_summary(fun.data = mean_cl_boot, geom = "errorbar", width = 0.2,
               color = "black", position = dodge) +
  stat_summary(fun = mean, geom = "point", size = 3,
               color = "black", position = dodge) +
  theme_light()
p_seen_role

# Unseen items: correct-rejection rate by lure type
p_unseen_role <- ltm_data_role_summary2 %>%
  ggplot(aes(x = fct_reorder(seen_unseen_type, mean_acc),
             y = mean_acc, color = seen_unseen_type)) +
  geom_jitter(width = 0.1, alpha = 0.6) +
  facet_wrap(~version) +
  scale_x_discrete(labels = role_labels_unseen) +
  scale_color_brewer(palette = "PuRd", guide = "none") +
  labs(x = "Role", y = "Mean Correct Rejections") +
  stat_summary(fun.data = mean_cl_boot, geom = "errorbar", width = 0.2, color = "black") +
  stat_summary(fun = mean, geom = "point", size = 3, color = "black") +
  theme_light()

p_unseen_role

if (SAVE_PLOTS) {
  save_plot(p_seen_role, "acc_ltm_image_role_seen.png", plot_dir_all, w = 6)
  save_plot(p_unseen_role, "acc_ltm_image_role_unseen.png", plot_dir_all, w = 6)
}

rm(p_unseen_role, p_seen_role)

Accuracy as a function of relative similarity

Plots

# Reference level: accuracy on neutral, validly cued trials, pooled over
# all versions.
mean_neutral <- wm_data %>%
  filter(is_neutral, cue_valid) %>%
  summarise(mean_correct = mean(correct, na.rm = TRUE)) %>%
  pull(mean_correct)

# Non-neutral, validly cued trials; IDs are converted to factors.
wm_filtered <- wm_data %>%
  filter(!is_neutral, cue_valid) %>%
  mutate(
    correct        = as.numeric(correct),
    participant_id = as.factor(participant_id),
    triplet_id     = as.factor(triplet_id)
  )

# Centered moving averages of accuracy along each relative-similarity axis
win_size <- 100
wm_ma_V2_relative <- wm_filtered %>%
  moving_avg_by_version("V2_targ_minus_foil", "correct", window = win_size)
wm_ma_IT_relative <- wm_filtered %>%
  moving_avg_by_version("IT_targ_minus_foil", "correct", window = win_size)

plot_sim_mov_avg(
  wm_filtered, wm_ma_V2_relative,
  "V2_targ_minus_foil", "correct",
  "V2 Relative Similarity", "P(Correct)"
)

plot_sim_mov_avg(
  wm_filtered, wm_ma_IT_relative,
  "IT_targ_minus_foil", "correct",
  "IT Relative Similarity", "P(Correct)"
)

# Smooth-only versions with the neutral-trial accuracy as a dashed baseline
plot_sim_mov_avg_simplify(
  wm_filtered, wm_ma_V2_relative,
  "V2_targ_minus_foil", "correct",
  "V2 Relative Similarity", "P(Correct)"
) +
  geom_hline(yintercept = mean_neutral, linetype = "dashed", color = "black")

plot_sim_mov_avg_simplify(
  wm_filtered, wm_ma_IT_relative,
  "IT_targ_minus_foil", "correct",
  "IT Relative Similarity", "P(Correct)"
) +
  geom_hline(yintercept = mean_neutral, linetype = "dashed", color = "black")

Model comparison

set.seed(42)
curr_version <- "v4"

# Restrict to a single pilot version
wm_filtered_curr <- filter(wm_filtered, version == curr_version)

# Binomial GAM: smooths of IT and V2 relative similarity plus random
# intercepts (bs = "re") for participant and triplet
gam_fit <- bam(
  correct ~ s(IT_targ_minus_foil, k = 10) + s(V2_targ_minus_foil, k = 10) +
    s(participant_id, bs = "re") + s(triplet_id, bs = "re"),
  data     = wm_filtered_curr,
  family   = binomial,
  discrete = TRUE
)
summary(gam_fit)

Family: binomial 
Link function: logit 

Formula:
correct ~ s(IT_targ_minus_foil, k = 10) + s(V2_targ_minus_foil, 
    k = 10) + s(participant_id, bs = "re") + s(triplet_id, bs = "re")

Parametric coefficients:
            Estimate Std. Error z value Pr(>|z|)    
(Intercept)    2.029      0.139    14.6   <2e-16 ***
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Approximate significance of smooth terms:
                         edf  Ref.df  Chi.sq  p-value    
s(IT_targ_minus_foil)  5.867   6.888  24.397 0.000775 ***
s(V2_targ_minus_foil)  1.000   1.000   0.499 0.480296    
s(participant_id)     39.894  47.000 384.747  < 2e-16 ***
s(triplet_id)         30.947 128.000  43.144 0.125044    
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

R-sq.(adj) =  0.115   Deviance explained = 14.2%
fREML = 1754.4  Scale est. = 1         n = 4690
# Family: binomial 
# Link function: logit 
# 
# Formula:
# correct ~ s(IT_targ_minus_foil, k = 10) + s(V2_targ_minus_foil, 
#     k = 10) + s(participant_id, bs = "re") + s(triplet_id, bs = "re")
# 
# Parametric coefficients:
#             Estimate Std. Error z value Pr(>|z|)    
# (Intercept)    2.029      0.139    14.6   <2e-16 ***
# ---
# Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
# 
# Approximate significance of smooth terms:
#                          edf  Ref.df  Chi.sq  p-value    
# s(IT_targ_minus_foil)  5.867   6.888  24.397 0.000775 ***
# s(V2_targ_minus_foil)  1.000   1.000   0.499 0.480296    
# s(participant_id)     39.894  47.000 384.747  < 2e-16 ***
# s(triplet_id)         30.947 128.000  43.144 0.125044    
# ---
# Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
# 
# R-sq.(adj) =  0.115   Deviance explained = 14.2%
# fREML = 1754.4  Scale est. = 1         n = 4690

IT: When the distractor is more similar to the target…

# Current version, trials with a positive IT target-minus-foil value only
wm_filtered_curr <- wm_filtered %>%
  filter(version == curr_version, IT_targ_minus_foil > 0)

cubic_mod <- glmer(
  correct ~ poly(IT_targ_minus_foil, 3) + (1 | participant_id) + (1 | triplet_id),
  data   = wm_filtered_curr,
  family = binomial
)
summary(cubic_mod)
Generalized linear mixed model fit by maximum likelihood (Laplace
  Approximation) [glmerMod]
 Family: binomial  ( logit )
Formula: correct ~ poly(IT_targ_minus_foil, 3) + (1 | participant_id) +  
    (1 | triplet_id)
   Data: wm_filtered_curr

      AIC       BIC    logLik -2*log(L)  df.resid 
   1729.3    1763.8    -858.6    1717.3      2338 

Scaled residuals: 
    Min      1Q  Median      3Q     Max 
-4.8951  0.2310  0.2955  0.3788  1.4923 

Random effects:
 Groups         Name        Variance Std.Dev.
 triplet_id     (Intercept) 0.1218   0.3489  
 participant_id (Intercept) 0.8719   0.9337  
Number of obs: 2344, groups:  triplet_id, 128; participant_id, 47

Fixed effects:
                             Estimate Std. Error z value Pr(>|z|)    
(Intercept)                    2.1561     0.1608  13.411   <2e-16 ***
poly(IT_targ_minus_foil, 3)1   1.5627     3.3958   0.460   0.6454    
poly(IT_targ_minus_foil, 3)2  -2.6580     3.4138  -0.779   0.4362    
poly(IT_targ_minus_foil, 3)3  -8.7003     3.4102  -2.551   0.0107 *  
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Correlation of Fixed Effects:
            (Intr) p(IT___,3)1 p(IT___,3)2
p(IT___,3)1  0.015                        
p(IT___,3)2 -0.014 -0.072                 
p(IT___,3)3 -0.053 -0.024      -0.044     
# Fixed effects:
#                              Estimate Std. Error z value Pr(>|z|)    
# (Intercept)                    2.1561     0.1608  13.411   <2e-16 ***
# poly(IT_targ_minus_foil, 3)1   1.5627     3.3958   0.460   0.6454    
# poly(IT_targ_minus_foil, 3)2  -2.6580     3.4138  -0.779   0.4362    
# poly(IT_targ_minus_foil, 3)3  -8.7003     3.4102  -2.551   0.0107 *  

IT: When the distractor is more similar to the foil…

# Current version, trials with a negative IT target-minus-foil value only
wm_filtered_curr <- wm_filtered %>%
  filter(version == curr_version, IT_targ_minus_foil < 0)
cubic_mod <- glmer(
  correct ~ poly(IT_targ_minus_foil, 3) + (1 | participant_id) + (1 | triplet_id),
  data   = wm_filtered_curr,
  family = binomial
)
summary(cubic_mod)
Generalized linear mixed model fit by maximum likelihood (Laplace
  Approximation) [glmerMod]
 Family: binomial  ( logit )
Formula: correct ~ poly(IT_targ_minus_foil, 3) + (1 | participant_id) +  
    (1 | triplet_id)
   Data: wm_filtered_curr

      AIC       BIC    logLik -2*log(L)  df.resid 
   1794.7    1829.3    -891.4    1782.7      2340 

Scaled residuals: 
    Min      1Q  Median      3Q     Max 
-4.2308  0.2052  0.3024  0.4068  1.5147 

Random effects:
 Groups         Name        Variance Std.Dev.
 triplet_id     (Intercept) 0.4330   0.658   
 participant_id (Intercept) 0.7327   0.856   
Number of obs: 2346, groups:  triplet_id, 128; participant_id, 47

Fixed effects:
                             Estimate Std. Error z value Pr(>|z|)    
(Intercept)                    2.1649     0.1594  13.579   <2e-16 ***
poly(IT_targ_minus_foil, 3)1  10.5263     4.3247   2.434   0.0149 *  
poly(IT_targ_minus_foil, 3)2   7.8239     4.2845   1.826   0.0678 .  
poly(IT_targ_minus_foil, 3)3  -7.4976     4.2464  -1.766   0.0775 .  
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Correlation of Fixed Effects:
            (Intr) p(IT___,3)1 p(IT___,3)2
p(IT___,3)1  0.047                        
p(IT___,3)2  0.033  0.027                 
p(IT___,3)3 -0.035  0.027       0.023     
# Fixed effects:
#                              Estimate Std. Error z value Pr(>|z|)    
# (Intercept)                    2.1649     0.1594  13.579   <2e-16 ***
# poly(IT_targ_minus_foil, 3)1  10.5263     4.3247   2.434   0.0149 *  
# poly(IT_targ_minus_foil, 3)2   7.8239     4.2845   1.826   0.0678 .  
# poly(IT_targ_minus_foil, 3)3  -7.4976     4.2464  -1.766   0.0775 .  

WM accuracy vs. LTM accuracy

# Per-participant mean accuracy in each task. Missing responses are ignored,
# matching the na.rm = TRUE used by the other accuracy summaries; otherwise
# one NA trial would drop the whole participant from the plot.
wm_by_id <- wm_data %>%
  group_by(version, unique_id) %>%
  summarise(mean_wm_acc = mean(correct, na.rm = TRUE), .groups = "drop")

ltm_by_id <- ltm_data %>%
  group_by(version, unique_id) %>%
  summarise(mean_ltm_acc = mean(correct, na.rm = TRUE), .groups = "drop")

acc_by_id <- full_join(wm_by_id, ltm_by_id, by = c("version", "unique_id"))

p_wm_vs_ltm <- acc_by_id %>%
  ggplot(aes(x = mean_wm_acc, y = mean_ltm_acc, color = version)) +
  geom_point() +
  geom_smooth(method = "gam",
              method.args = list(family = gaussian),
              formula = y ~ s(x)) +
  labs(x = "Mean WM accuracy", y = "Mean LTM accuracy") +
  theme_light()

p_wm_vs_ltm

# Honour the global toggle like every other figure in this analysis
if (SAVE_PLOTS) {
  save_plot(p_wm_vs_ltm, "wm_vs_ltm.png", plot_dir_all)
}

rm(wm_by_id, ltm_by_id, acc_by_id, p_wm_vs_ltm)

WM competition index and LTM benefit

Fit relative similarity models

curr_version <- "v4"
wm_filtered_curr <- filter(wm_filtered, version == curr_version)

# Binomial GAM with the IT relative-similarity smooth only (no V2 term) and
# random intercepts for participant and triplet
gam_fit <- bam(
  correct ~ s(IT_targ_minus_foil, k = 10) +
    s(participant_id, bs = "re") + s(triplet_id, bs = "re"),
  data     = wm_filtered_curr,
  family   = binomial,
  discrete = TRUE
)
summary(gam_fit)

Family: binomial 
Link function: logit 

Formula:
correct ~ s(IT_targ_minus_foil, k = 10) + s(participant_id, bs = "re") + 
    s(triplet_id, bs = "re")

Parametric coefficients:
            Estimate Std. Error z value Pr(>|z|)    
(Intercept)    2.028      0.139   14.59   <2e-16 ***
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Approximate significance of smooth terms:
                         edf  Ref.df Chi.sq  p-value    
s(IT_targ_minus_foil)  5.833   6.853  24.56 0.000871 ***
s(participant_id)     39.900  47.000 387.02  < 2e-16 ***
s(triplet_id)         30.664 128.000  41.67 0.042723 *  
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

R-sq.(adj) =  0.115   Deviance explained = 14.2%
fREML = 1751.5  Scale est. = 1         n = 4690

Get competition index

# take all LTM stimuli that were WM samples AND were answered correctly
# (targets/foils only, non-neutral trials)
ltm_samples_from_wm <- ltm_data %>%
  dplyr::filter(!is.na(image_role)) %>%
  dplyr::filter(image_role != "root") %>%
  dplyr::filter(is_neutral == FALSE) %>%
  dplyr::filter(wm_correct == TRUE) %>%
  dplyr::filter(triplet_id != 214) # 214 had both targs and foils

# so first we'll need to re-combine some image stats...
# NOTE(review): this join assumes one WM row per (unique_id, triplet_id);
# repeated triplets would duplicate LTM rows. Confirm against the design.
wm_image_stats <- wm_data %>%
  select(unique_id, triplet_id, V2_targ_minus_foil, IT_targ_minus_foil)

# Re-orient relative similarity around the image later tested in LTM:
# unchanged for targets, sign-flipped for foils.
ltm_samples_with_stats <- ltm_samples_from_wm %>%
  left_join(wm_image_stats, by = c("unique_id", "triplet_id")) %>%
  mutate(
    V2_tested_minus_untested = ifelse(image_role == "target",
                                      V2_targ_minus_foil, -V2_targ_minus_foil),
    IT_tested_minus_untested = ifelse(image_role == "target",
                                      IT_targ_minus_foil, -IT_targ_minus_foil)
  )

# Predict P(WM correct) from gam_fit, as if the tested image had been the WM
# target. Then take 1 - P(WM correct) as the competition index.
# The version filter uses curr_version so it stays in sync with the version
# gam_fit was fitted on.
# mgcv's predict() has no `allow.new.levels` (that is lme4). Participant or
# triplet levels absent from the fit make predict() stop with a
# "not in original fit" error.
ltm_samples_with_stats <- ltm_samples_with_stats %>%
  dplyr::filter(version == curr_version) %>%
  mutate(
    correct = as.numeric(correct)
  ) %>%
  mutate(
    pred_wm_correct = predict(gam_fit,
                              newdata = data.frame(
                                IT_targ_minus_foil = IT_tested_minus_untested,
                                participant_id = participant_id,
                                triplet_id = triplet_id
                              ),
                              type = "response"),
    comp_index = 1 - pred_wm_correct
  )

Plot competition index vs LTM accuracy (within-subjects)

This effect will likely be harder to see in this dataset: many subjects span only a narrow range of competition-index values, so their within-subject bins differ only slightly.

n_bins <- 3

# Within each participant, split the competition index into n_bins
# equal-width intervals with cut(breaks = n_bins). These are not
# equal-count tertiles. Then average LTM accuracy per bin, ignoring
# missing responses.
binned_comp <- ltm_samples_with_stats %>%
  group_by(version, unique_id) %>%
  mutate(comp_bin = cut(comp_index, breaks = n_bins, labels = FALSE)) %>%
  group_by(version, unique_id, comp_bin) %>%
  summarise(
    mean_ltm  = mean(correct, na.rm = TRUE),
    mean_comp = mean(comp_index),
    .groups   = "drop"
  ) %>%
  mutate(comp_bin = as.factor(comp_bin))

p_comp <- binned_comp %>%
  ggplot(aes(x = comp_bin, y = mean_ltm, color = comp_bin)) +
  geom_jitter(width = 0.1, alpha = 0.6) +
  facet_wrap(~version) +
  scale_color_brewer(palette = "Spectral", guide = "none") +
  stat_summary(fun.data = mean_cl_boot, geom = "errorbar",
               width = 0.2, color = "black") +
  stat_summary(fun = mean, geom = "point", size = 3, color = "black") +
  labs(x = "WM Competition Index (within-subject equal-width bin)",
       y = "Mean LTM Accuracy") +
  theme_light()

p_comp

if (SAVE_PLOTS) {
  save_plot(p_comp, "comp_index_ltm.png", plot_dir_all)
}

rm(binned_comp, p_comp)

## But if we look at raw competition index versus probability correct...
# Moving average (50-trial window) of LTM accuracy along the unbinned index
raw_comp_ma <- ltm_samples_with_stats %>%
  moving_avg_by_version("comp_index", "correct", window = 50)

plot_sim_mov_avg(
  ltm_samples_with_stats, raw_comp_ma,
  "comp_index", "correct",
  "competition index", "P(Correct)"
)