library(tidyverse)
## Warning: package 'ggplot2' was built under R version 4.3.3
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.1.2     ✔ readr     2.1.4
## ✔ forcats   1.0.0     ✔ stringr   1.5.0
## ✔ ggplot2   3.5.2     ✔ tibble    3.2.1
## ✔ lubridate 1.9.2     ✔ tidyr     1.3.0
## ✔ purrr     1.0.2     
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(here)
## here() starts at /Users/brialong/Documents/GitHub/object-detection
library(langcog)
## 
## Attaching package: 'langcog'
## 
## The following object is masked from 'package:base':
## 
##     scale

Low precision categories from Jane

exclude_low_precision = read_csv('low_precision.txt', col_names=FALSE)
## Rows: 67 Columns: 1
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (1): X1
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.

Load embeddings

embeddings_things_dino <- read_csv(here::here('data/embeddings/things_dino_embeddings.csv'))  %>%
  rename(category = 1) %>%
  rename_with(
    .cols = where(is.numeric),
    .fn   = ~ paste0("dim_", seq_along(.x))
  )  %>%
  mutate(
    across(
      .cols = where(is.numeric) & !matches("category"),
      .fns  = ~ as.numeric(scale(.x))
    )
  ) %>%
  filter(!category %in% exclude_low_precision$X1)
## Rows: 205 Columns: 769
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr   (1): label
## dbl (768): 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,...
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
embeddings_things_clip <- read_csv(here::here('data/embeddings/things_clip_embeddings.csv'))  %>%
  rename(category = 1) %>%
  rename_with(
    .cols = where(is.numeric),
    .fn   = ~ paste0("dim_", seq_along(.x))
  )  %>%
  mutate(
    across(
      .cols = where(is.numeric) & !matches("category"),
      .fns  = ~ as.numeric(scale(.x))
    )
  ) %>%
  filter(!category %in% exclude_low_precision$X1)  %>%
  mutate(category = as.factor(category)) %>%
  arrange(-desc(category))
## Rows: 205 Columns: 513
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr   (1): text
## dbl (512): 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,...
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
filtered_categories = unique(embeddings_things_clip$category)
embeddings_babyview_clip <- read_csv(here::here('data/embeddings/babyview_clip_filtered26_category_average_embeddings.csv'))  %>%
  rename(category = 1) %>%
  rename_with(
    .cols = where(is.numeric),
    .fn   = ~ paste0("dim_", seq_along(.x))
  )  %>%
  mutate(
    across(
      .cols = where(is.numeric) & !matches("category"),
      .fns  = ~ as.numeric(scale(.x))
    )
  ) %>%
  filter(category %in%filtered_categories) %>%
  filter(!category %in% exclude_low_precision$X1) %>%
   mutate(category = as.factor(category)) %>%
  arrange(-desc(category))
## New names:
## Rows: 291 Columns: 513
## ── Column specification
## ──────────────────────────────────────────────────────── Delimiter: "," chr
## (1): ...1 dbl (512): 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
## 17, 18,...
## ℹ Use `spec()` to retrieve the full column specification for this data. ℹ
## Specify the column types or set `show_col_types = FALSE` to quiet this message.
## • `` -> `...1`
embeddings_babyview_dino <- read_csv(here::here('data/embeddings/babyview_dinov3_filtered26_category_average_embeddings.csv'))  %>%
  rename(category = 1) %>%
  rename_with(
    .cols = where(is.numeric),
    .fn   = ~ paste0("dim_", seq_along(.x))
  )  %>%
  mutate(
    across(
      .cols = where(is.numeric) & !matches("category"),
      .fns  = ~ as.numeric(scale(.x))
    )
  ) %>%
  filter(category %in%filtered_categories) %>%
  filter(!category %in% exclude_low_precision$X1) %>%
   mutate(category = as.factor(category)) %>%
  arrange(-desc(category))
## New names:
## Rows: 291 Columns: 769
## ── Column specification
## ──────────────────────────────────────────────────────── Delimiter: "," chr
## (1): ...1 dbl (768): 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
## 17, 18,...
## ℹ Use `spec()` to retrieve the full column specification for this data. ℹ
## Specify the column types or set `show_col_types = FALSE` to quiet this message.
## • `` -> `...1`

Functions for making rdms

# Cosine distance helper
cosine_distance <- function(x, y) {
  1 - sum(x * y) / (sqrt(sum(x^2)) * sqrt(sum(y^2)))
}

pearson_distance <- function(x, y) {
  # x and y are numeric vectors of equal length
  r <- cor(x, y, method = "pearson", use = "pairwise.complete.obs")
  1 - r
}

makeLongRDM <- function(embeddings,distance_metric) {
  
  categories <- embeddings$category
  emb_matrix <- embeddings %>%
    select(starts_with("dim_")) %>%
    as.matrix()
  
  # Cosine RDM as an n x n matrix
  rdm_cosine_mat <- outer(
    1:nrow(emb_matrix),
    1:nrow(emb_matrix),
    Vectorize(function(i, j) distance_metric(emb_matrix[i, ], emb_matrix[j, ]))
  )

  dimnames(rdm_cosine_mat) <- list(categories, categories)

  # Tidy long version (great for ggplot / RSA)
  rdm_cosine_long <- as_tibble(rdm_cosine_mat, rownames = "category1") %>%
  pivot_longer(
    cols = -category1,
    names_to = "category2",
    values_to = "distance"
  )

  return(rdm_cosine_long)
  
}
babyview_dino <- makeLongRDM(embeddings_babyview_dino,cosine_distance)
babyview_clip <- makeLongRDM(embeddings_babyview_clip,cosine_distance)
things_clip <- makeLongRDM(embeddings_things_clip,cosine_distance)
things_dino <- makeLongRDM(embeddings_things_dino,cosine_distance)

Compare RDMs

Babyview vs THINGS

cor.test(babyview_clip$distance, things_clip$distance, method='spearman')
## Warning in cor.test.default(babyview_clip$distance, things_clip$distance, :
## Cannot compute exact p-value with ties
## 
##  Spearman's rank correlation rho
## 
## data:  babyview_clip$distance and things_clip$distance
## S = 1.9584e+12, p-value < 2.2e-16
## alternative hypothesis: true rho is not equal to 0
## sample estimates:
##       rho 
## 0.3734871

Babyview vs THINGS in DINO

cor.test(babyview_dino$distance, things_dino$distance, method='spearman')
## Warning in cor.test.default(babyview_dino$distance, things_dino$distance, :
## Cannot compute exact p-value with ties
## 
##  Spearman's rank correlation rho
## 
## data:  babyview_dino$distance and things_dino$distance
## S = 3.0836e+12, p-value = 0.02744
## alternative hypothesis: true rho is not equal to 0
## sample estimates:
##        rho 
## 0.01352903

Babyview clip vs dino

cor.test(babyview_clip$distance, babyview_dino$distance, method='spearman')
## Warning in cor.test.default(babyview_clip$distance, babyview_dino$distance, :
## Cannot compute exact p-value with ties
## 
##  Spearman's rank correlation rho
## 
## data:  babyview_clip$distance and babyview_dino$distance
## S = 3.6159e+11, p-value < 2.2e-16
## alternative hypothesis: true rho is not equal to 0
## sample estimates:
##       rho 
## 0.8843237

THINGS clip vs dino

cor.test(things_clip$distance, things_dino$distance, method='spearman')
## Warning in cor.test.default(things_clip$distance, things_dino$distance, :
## Cannot compute exact p-value with ties
## 
##  Spearman's rank correlation rho
## 
## data:  things_clip$distance and things_dino$distance
## S = 3.0922e+12, p-value = 0.0788
## alternative hypothesis: true rho is not equal to 0
## sample estimates:
##        rho 
## 0.01078353
makeLongOrderedRDM <- function(rdm_cosine_long){
# Make a wide matrix again for clustering
rdm_mat <- rdm_cosine_long %>%
  pivot_wider(
    names_from = category2,
    values_from = distance
  ) %>%
  column_to_rownames("category1") %>%
  as.matrix()

# Hierarchical clustering order
hc <- hclust(as.dist(rdm_mat))
ord <- rownames(rdm_mat)[hc$order]

rdm_cosine_long_ordered <- rdm_cosine_long %>%
  mutate(
    category1 = factor(category1, levels = ord),
    category2 = factor(category2, levels = ord)
  )
}
babyview_clip_ordered <- makeLongOrderedRDM(babyview_clip)
ordered_categories = levels(babyview_clip_ordered$category1)

Make ordered RDMs

babyview_dino_ordered <- babyview_dino %>%
  mutate(
    category1 = factor(category1, levels = ordered_categories),
    category2 = factor(category2, levels = ordered_categories)
  )

things_dino_ordered <- things_dino %>%
  mutate(
    category1 = factor(category1, levels = ordered_categories),
    category2 = factor(category2, levels = ordered_categories)
  )

things_clip_ordered <- things_clip %>%
  mutate(
    category1 = factor(category1, levels = ordered_categories),
    category2 = factor(category2, levels = ordered_categories)
  )
ggplot(babyview_clip_ordered, aes(x = category1, y = category2, fill = distance)) +
  geom_tile() +
  scale_fill_viridis_c(option = "magma", direction = -1) +
  coord_equal() +
  # theme_minimal(base_size = 10) +
  theme(
    axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1, size=4),
    axis.text.y = element_text(hjust = 1, vjust = 1, size=4),
    panel.grid = element_blank()
  ) +
  labs(
    x = NULL,
    y = NULL,
    fill = "Cosine distance",
    title = "CLIP category RDM (cluster-ordered)"
  )

ggsave('babyview-clip-ordered.pdf', width=6, height=6, units='in')
ggplot(babyview_dino_ordered, aes(x = category1, y = category2, fill = distance)) +
  geom_tile() +
  scale_fill_viridis_c(option = "magma", direction = -1) +
  coord_equal() +
  # theme_minimal(base_size = 10) +
  theme(
    axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1, size=4),
    axis.text.y = element_text(hjust = 1, vjust = 1, size=4),
    panel.grid = element_blank()
  ) +
  labs(
    x = NULL,
    y = NULL,
    fill = "Cosine distance",
    title = "DINO category RDM (clip cluster-ordered)"
  )

ggsave('babyview-dino-ordered.pdf', width=6, height=6, units='in')
ggplot(things_dino_ordered, aes(x = category1, y = category2, fill = distance)) +
  geom_tile() +
  scale_fill_viridis_c(option = "magma", direction = -1) +
  coord_equal() +
   # ylim(0,1.8) +
  # theme_minimal(base_size = 10) +
  theme(
    axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1, size=4),
    axis.text.y = element_text(hjust = 1, vjust = 1, size=4),
    panel.grid = element_blank()
  ) +
  labs(
    x = NULL,
    y = NULL,
    fill = "Cosine distance",
    title = "THINGS Dino category RDM (CLIP cluster-ordered)"
  )

ggsave('things-dino-ordered.pdf', width=6, height=6, units='in')
ggplot(things_clip_ordered, aes(x = category1, y = category2, fill = distance)) +
  geom_tile() +
  scale_fill_viridis_c(option = "magma", direction = -1) +
  coord_equal() +
  # ylim(0,1.8) +
  # theme_minimal(base_size = 10) +
  theme(
    axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1, size=4),
    axis.text.y = element_text(hjust = 1, vjust = 1, size=4),
    panel.grid = element_blank()
  ) +
  labs(
    x = NULL,
    y = NULL,
    fill = "Cosine distance",
    title = "Things CLIP category RDM (CLIP cluster-ordered)"
  )

ggsave('things-clip-ordered.pdf', width=6, height=6, units='in')