library(Seurat)
## Loading required package: SeuratObject
## Loading required package: sp
## 
## Attaching package: 'SeuratObject'
## The following objects are masked from 'package:base':
## 
##     intersect, t
library(pheatmap)
library(tibble)
library(dplyr)
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
library(purrr)
library(gridExtra)
## 
## Attaching package: 'gridExtra'
## The following object is masked from 'package:dplyr':
## 
##     combine
library(gtable)
library(grid)
library(ggplot2)
library(fgsea)
library(viridis)
## Loading required package: viridisLite
library(RColorBrewer)
library(this.path)

# Use the directory reported by this.path::here() as the working directory so
# that the relative file names passed to read.csv()/readRDS() below resolve
# against it.
setwd(this.path::here())

Calculation of heuristic regulatory scores

The purpose of this code is to integrate perturb-seq and Taiji data to calculate a score representing TF-gene regulatory strength. The score is a simple product of the TF-gene edge weight, as output by Taiji, and the -log2FC of expression for the TF knockout vs. the gScramble control. The underlying hypothesis is that regulated genes are likely to (1) have evidence of physical binding (the edge weight) and (2) exhibit changes in expression when the regulating TF is knocked out.

ProcessRegulatees:

inputs: - edge_weights_fname: file name for Taiji edge weights - sobj_fname: file name for perturb-seq Seurat object - cell_type_taiji: query cell type from the edge weights file - cell_type_seurat: query cell type from the Seurat object (optional) - gRNAs: gRNAs to query for visualization.

outputs: - heuristic_scores: -log2FC*edge_weight; calculated for each TF-gene pair - log2FC_scores: log2FC of TF KO vs. gScramble control for each gene - edge_weight_scores: edge weights from edge_weights_fname for each TF-gene pair - high_conf_regulatees: regulatees/genes whose edge weight for a given TF is above the 75th percentile (top quartile) and whose |log2FC| > 0.58, reported separately as 'repression' (log2FC > 0.58) and 'activation' (log2FC < -0.58)

#' Compute heuristic TF-gene regulatory scores for one cell type.
#'
#' Combines Taiji edge weights with perturb-seq differential expression: for
#' every perturbed TF and every gene present in both data sets the score is
#' `-edge_weight * avg_log2FC` (knockout vs. gScramble).
#'
#' @param edge_weights_fname CSV of edge weights; its `X` column is used as the
#'   gene row names and the remaining columns are matched against TF names.
#' @param sobj_fname RDS file holding the Seurat object; its metadata must
#'   provide `guide_ID` and `perturbed_gene`.
#' @param cell_type_taiji Accepted for interface compatibility; not read by
#'   the function body.
#' @param cell_type_seurat Accepted for interface compatibility; not read by
#'   the function body.
#' @param gRNAs Guide IDs to keep (matched against `guide_ID`).
#' @param query_genes_fname CSV whose `gene` column defines the genes to keep
#'   from the edge-weight table. Defaults to the file the original code used.
#' @return A list with `heuristic_scores`, `log2FC_scores` and
#'   `edge_weight_scores` (each a list of one-column data frames keyed by TF)
#'   and `high_conf_regulatees` (a list with `repression` and `activation`,
#'   each a list of gene-name vectors keyed by TF).
ProcessRegulatees <- function(edge_weights_fname, sobj_fname, cell_type_taiji, cell_type_seurat, gRNAs,
                              query_genes_fname = 'kmeans_cluster_k_16_log2FC_btw_TEX_and_TRM_mean_edge_weight_top500_subset_TFs_v2_v2_ordered_by_group.csv') {
  # Perturbation labels look like "gPrdm1"; strip the leading character and
  # upper-case the rest to obtain the TF name used in the edge-weight table.
  guide_to_tf <- function(x) toupper(substring(x, 2))

  query_table <- read.csv(query_genes_fname)[-1]
  query_genes <- unique(query_table$gene) # author's note: 17023 genes, uppercase

  edge_weights <- read.csv(edge_weights_fname)
  rownames(edge_weights) <- edge_weights$X # author's note: also uppercase
  edge_weights <- edge_weights[-1]

  sobj <- readRDS(sobj_fname)
  sobj <- sobj[, sobj$guide_ID %in% gRNAs]
  perturbed_genes <- as.character(unique(sobj$perturbed_gene))
  perturbed_genes <- perturbed_genes[perturbed_genes != 'gScramble']

  # Remove perturbations with very low cell counts (< 3 cells) from the query.
  # One table() call replaces re-subsetting the Seurat object once per gene.
  cells_per_gene <- table(as.character(sobj$perturbed_gene))
  perturbed_genes <- perturbed_genes[as.vector(cells_per_gene[perturbed_genes]) >= 3]

  query_tfs <- guide_to_tf(perturbed_genes)

  # drop = FALSE keeps a data frame even when only one TF column matches;
  # without it the result collapses to a vector and the column handling below
  # breaks.
  edge_weights <- edge_weights[rownames(edge_weights) %in% query_genes,
                               colnames(edge_weights) %in% query_tfs,
                               drop = FALSE] # author's note: 14466 genes x 16 TFs

  Idents(sobj) <- sobj@meta.data$perturbed_gene
  de_results <- list()

  for (gRNA in perturbed_genes) {
    markers <- FindMarkers(sobj, ident.1 = gRNA, ident.2 = "gScramble")
    # Threshold of 1e5 is effectively no filter; lower it to enforce significance.
    de_results[[gRNA]] <- markers[markers$p_val_adj < 1e5, ]
  }

  log2FC_res <- lapply(de_results, function(x) {
    data.frame(avg_log2FC = x$avg_log2FC, row.names = toupper(rownames(x)))
  })

  # Convert names from perturbation labels to uppercase TFs.
  names(log2FC_res) <- guide_to_tf(names(log2FC_res))

  # Split the edge-weight table into one single-column data frame per TF.
  tf_names <- colnames(edge_weights)
  edge_weights <- lapply(setNames(tf_names, tf_names), function(tf) {
    edge_weights[, tf, drop = FALSE]
  })

  log2FC_scores <- list()
  heuristic_scores <- list()
  edge_weight_scores <- list()
  # Pre-create both sub-lists so that per-TF assignment always targets a list,
  # whatever the length of the first value stored.
  high_conf_regulatees <- list(repression = list(), activation = list())

  for (name in names(edge_weights)) {
    # Restrict both tables to the genes they share, in the same order.
    shared_genes <- intersect(rownames(log2FC_res[[name]]), rownames(edge_weights[[name]]))
    edge_weights[[name]] <- edge_weights[[name]][shared_genes, , drop = FALSE]
    log2FC_res[[name]] <- log2FC_res[[name]][shared_genes, , drop = FALSE]

    top_quantile_threshold <- quantile(edge_weights[[name]][[1]], 0.75)
    mask <- edge_weights[[name]][[1]] > top_quantile_threshold

    # High-confidence regulatees: edge weight above the 75th percentile of the
    # shared genes AND abs(log2FC) > 0.58, corresponding to 1.5 or 1/1.5 FC.
    pos_regulatees <- rownames(edge_weights[[name]])[which(mask & (log2FC_res[[name]][[1]] > 0.58))]
    neg_regulatees <- rownames(edge_weights[[name]])[which(mask & (log2FC_res[[name]][[1]] < -0.58))]

    log2FC_scores[[name]] <- log2FC_res[[name]]
    edge_weight_scores[[name]] <- edge_weights[[name]]
    # Flip the sign since a positive log2FC after knockout implies repression.
    heuristic_scores[[name]] <- -edge_weights[[name]] * log2FC_res[[name]]

    high_conf_regulatees[['repression']][[name]] <- pos_regulatees # positive log2FC vs. gScramble --> repression
    high_conf_regulatees[['activation']][[name]] <- neg_regulatees # negative log2FC vs. gScramble --> activation
  }

  list(heuristic_scores = heuristic_scores,
       log2FC_scores = log2FC_scores,
       edge_weight_scores = edge_weight_scores,
       high_conf_regulatees = high_conf_regulatees)
}

Now that the function is defined, calculate heuristic regulatory scores for the CD8+ terminally exhausted (LCMV Clone 13) and TRM (LCMV Armstrong) cell types.

# Guides with the higher dynamic range knockouts (plus the scramble controls)
query_gRNAs <- c(
  "gScramble-2", "gScramble-4",
  "gHinfp-4", "gEtv5-4", "gArid3a-2", "gZbtb49-4",
  "gTfdp1-2", "gZfp410-2", "gFoxd2-2", "gZscan20-4",
  "gJdp2-2", "gPrdm4-2", "gNfil3-2", "gNfatc1-4",
  "gGfi1-2", "gZfp143-4", "gNr4a2-2", "gIrf8-4",
  "gZfp324-4", "gPrdm1-2", "gStat3-4", "gHic1-4",
  "gIkzf3-2", "gPrdm1-4"
)

# Per-cell-type accumulators, filled in after each ProcessRegulatees() run
heuristic_scores <- list()
edge_weight_scores <- list()
log2FC_scores <- list()
high_conf_regulatees <- list()

# TexTerm run
# (alternative Seurat object: sobj_fname = 'Cl13_23TP04_20231010.rds')
res <- ProcessRegulatees(
  edge_weights_fname = 'mean_TexTerm_edge_weight.csv',
  sobj_fname = 'Cl13_integration_4clusters_22TP09_23TP03_23TP04_23TP16_20231010.rds',
  cell_type_taiji = 'TexTerm',
  cell_type_seurat = NULL,
  gRNAs = query_gRNAs
)
## Warning in mean.fxn(object[features, cells.1, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.2, drop = FALSE]): NaNs produced
## For a (much!) faster implementation of the Wilcoxon Rank Sum Test,
## (default method for FindMarkers) please install the presto package
## --------------------------------------------
## install.packages('devtools')
## devtools::install_github('immunogenomics/presto')
## --------------------------------------------
## After installation of presto, Seurat will automatically use the more 
## efficient implementation (no further action necessary).
## This message will be shown once per session
## Warning in mean.fxn(object[features, cells.1, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.1, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.1, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.2, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.1, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.2, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.1, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.2, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.1, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.2, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.1, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.2, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.1, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.2, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.1, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.2, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.1, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.2, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.1, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.2, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.1, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.2, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.1, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.2, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.1, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.2, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.1, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.2, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.1, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.2, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.1, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.2, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.1, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.2, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.1, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.2, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.1, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.2, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.1, drop = FALSE]): NaNs produced
## Warning in mean.fxn(object[features, cells.2, drop = FALSE]): NaNs produced
# File the TexTerm outputs under their cell-type key
heuristic_scores[['TexTerm']] <- res[['heuristic_scores']]
log2FC_scores[['TexTerm']] <- res[['log2FC_scores']]
edge_weight_scores[['TexTerm']] <- res[['edge_weight_scores']]
high_conf_regulatees[['TexTerm']] <- res[['high_conf_regulatees']]

# Same computation for TRM, using the same guide list
res <- ProcessRegulatees(
  edge_weights_fname = 'mean_TRM_edge_weight.csv',
  sobj_fname = 'TP16_Arm_clean_simpler4clustersn_20231020.rds',
  cell_type_taiji = 'TRM',
  cell_type_seurat = NULL,
  gRNAs = query_gRNAs
)

heuristic_scores[['TRM']] <- res[['heuristic_scores']]
log2FC_scores[['TRM']] <- res[['log2FC_scores']]
edge_weight_scores[['TRM']] <- res[['edge_weight_scores']]
high_conf_regulatees[['TRM']] <- res[['high_conf_regulatees']]

Given the heuristic regulatory scores, we can quantify the genes (regulatees) that exhibit flipped directions of regulation between the LCMV Clone 13 (TexTerm) and Armstrong (TRM) conditions.

### Quantify changes in TF-gene regulation direction between TRM and TexTerm ###

# regulatee_flip_counts: gene -> number of TFs for which its direction flips
# flipped_regulatees_by_TF: TF -> labelled flipped genes (TFs without flips get no entry)
regulatee_flip_counts <- list()
flipped_regulatees_by_TF <- list()
shared_TFs <- intersect(
  names(high_conf_regulatees[['TexTerm']][['repression']]),
  names(high_conf_regulatees[['TRM']][['repression']])
)

for (TF in shared_TFs) {
  # TODO: double check cell type assumption
  tex_repressed <- high_conf_regulatees$TexTerm$repression[[TF]]
  trm_repressed <- high_conf_regulatees$TRM$repression[[TF]]
  tex_activated <- high_conf_regulatees$TexTerm$activation[[TF]]
  trm_activated <- high_conf_regulatees$TRM$activation[[TF]]

  repressed_then_activated <- intersect(tex_repressed, trm_activated) # TexTerm-repressed, TRM-activated
  activated_then_repressed <- intersect(trm_repressed, tex_activated) # TexTerm-activated, TRM-repressed
  flipped <- c(repressed_then_activated, activated_then_repressed)

  # Label every flipped gene with its direction; skipping the assignment when
  # nothing flipped leaves the TF out of the list altogether.
  if (length(flipped) > 0) {
    direction <- ifelse(flipped %in% repressed_then_activated,
                        ' (TexTerm-Repressed_&_TRM-Activated)',
                        ' (TexTerm-Activated_&_TRM-Repressed)')
    flipped_regulatees_by_TF[[TF]] <- paste0(flipped, direction)
  }

  # Tally, per gene, how many TFs show a flip.
  for (regulatee in flipped) {
    seen_so_far <- regulatee_flip_counts[[regulatee]]
    regulatee_flip_counts[[regulatee]] <- if (is.null(seen_so_far)) 1 else seen_so_far + 1
  }
}

Finally, visualize the specific TF-regulatee heuristic scores shown in extended Figure 8c.

# All flipped TF-regulatee pairs for TexTerm vs. TRM comparison
# query_tf_regulatee_pairs = c()
# 
# for (tf in names(flipped_regulatees_by_TF)) {
#   for (regulatee in flipped_regulatees_by_TF[[tf]]) {
#     query_tf_regulatee_pairs = c(query_tf_regulatee_pairs, paste0(tf,'-',strsplit(regulatee, ' ')[[1]][1]))
#   }
# }

### Query individual heuristic scores
# Each entry is "<TF>-<regulatee>"; the text before the first '-' is the TF.
query_tf_regulatee_pairs <- c('PRDM1-EGR2', 'PRDM1-IRF4', 'PRDM1-NR4A3',
                              'HIC1-CREB3L2', 'HIC1-RBPJ', 'HIC1-MYC', 'HIC1-JUNB')

# Rows = TF-regulatee pairs, columns = conditions; cells stay NA when a pair
# has no score in a condition.
query_mat <- matrix(NA_real_,
                    nrow = length(query_tf_regulatee_pairs), ncol = 2,
                    dimnames = list(query_tf_regulatee_pairs, c('TRM', 'TexTerm')))

for (pair in query_tf_regulatee_pairs) {
  pair_parts <- strsplit(pair, '-')[[1]]
  tf <- pair_parts[1]
  regulatee <- pair_parts[2]

  for (condition in colnames(query_mat)) {
    tf_scores <- heuristic_scores[[condition]][[tf]]

    # A TF that was not scored gives NULL and a gene outside the shared set
    # gives no row; either would previously abort or poison the colour scale.
    if (is.null(tf_scores) || !(regulatee %in% rownames(tf_scores))) {
      warning('No heuristic score for ', pair, ' in ', condition, '; leaving NA', call. = FALSE)
      next
    }

    query_mat[pair, condition] <- tf_scores[regulatee, 1]
  }
}

if (all(is.na(query_mat))) {
  stop('None of the queried TF-regulatee pairs has a heuristic score', call. = FALSE)
}

# Symmetric colour scale, saturating at one third of the largest absolute score.
val_max <- max(abs(query_mat), na.rm = TRUE) / 3
val_min <- -val_max

heatmap <- pheatmap(query_mat, cluster_cols = FALSE, show_colnames = TRUE, main = 'Individual heuristic scores',
                    border_color = NA, treeheight_row = 0, treeheight_col = 0, fontsize_col = 8, angle_col = '315',
                    cluster_rows = FALSE, silent = TRUE, breaks = seq(val_min, val_max, length.out = 101),
                    color = colorRampPalette(rev(brewer.pal(11, "RdBu")))(100), margins = c(5,5,0,5), cellheight = 10, cellwidth = 14)

heatmap

write.csv(query_mat, paste0('Individual heuristic scores', ".csv"), row.names = TRUE)