CELL TYPE PRIORITIZATION

Author

Akschya Sivacoumar

Load libraries

library(Seurat)
library(randomForest)
library(caret)
library(pROC)
library(dplyr)
library(sparseMatrixStats)
library(magrittr)
library(stats)
library(lmtest)
library(Matrix)
library(rsample)
library(ggplot2)
library(reshape2)
library(tibble)
library(knitr)
library(gt)

DATA LOADING

Import RDS

Load the RDS object from task 1

Code
# Load RDS and extract data
data <- readRDS("~/EPFL/Task 2/data/CNS.rds")
expr <- Seurat::GetAssayData(data)
meta <- data@meta.data %>% 
  droplevels()
cell_types <- meta[["cell_type"]]
labels <- meta[["label"]]

Examine the data

Code
unique(labels)
[1] "Disease" "Control"
Code
unique(cell_types)
 [1] "Peripheral immune cells"               
 [2] "Microglia"                             
 [3] "Myeloid, dividing"                     
 [4] "NK, T cells"                           
 [5] "B cells"                               
 [6] "Newly formed oligodendrocytes (NFOL)"  
 [7] "Mature oligodendrocytes (MOL)"         
 [8] "Myelin forming oligodendrocytes (MFOL)"
 [9] "Oligodendrocytes precursor cells (OPC)"
[10] "Vascular endothelial cells"            
[11] "Vascular cells"                        
[12] "Astrocytes"                            
[13] "Ependymal cells"                       
[14] "Dorsal"                                
[15] "Ventral"                               

Check if the data has any NA values

which(is.na(expr))
integer(0)

Check is the ‘labels’ are factors and convert to factors if not

Code
if (is.factor(labels) == "FALSE") {
  labels %<>% as.factor()                                               
  print(paste0("Labels are converted to factors: ", is.factor(labels)))
} else{
  print("Labels are factors!")
}
1
The labels need to be factors, only then Random Forrest will recognize them as distinct values and use classification mode instead of regression
[1] "Labels are converted to factors: TRUE"

SELECT HIGH-VARIANCE GENES - FUNCTION

Function to select highly variable genes to classify cells.

Calculate the Coefficient of variation (CV)

CV = mean/sds

Model Fitting:

Fit models to the data and choose the best model for predicting CV based on statistical tests.

  • Negative means -> fit a loess (local regression) model using CV vs. mean.
  • Non-negative means -> fit two models:
    • using CV vs. mean
    • using CV vs. log(mean).
  • Model selection is based on p-values by doing a comparative coxtest.

Need for Variability-Based Filtering:

  1. To Focus on High-Variance Genes
  2. To Reduce Noise
  3. To Improve Model Performance
Code
select_var <- function(mat) {
  sds <- rowSds(mat)
  sds[is.na(sds)] <- 0
  mat <- mat[sds > 0, ]
  
  means <- rowMeans(mat)
  sds <- sds[sds > 0]
  cvs <- means / sds
 
  lower <- quantile(cvs, 0.01)
  upper <- quantile(cvs, 0.99)
  keep <- cvs > lower & cvs < upper
  cv0 <- cvs[keep]
  mean0 <- means[keep]
  
  if (any(mean0 < 0)) {
    model = loess(cv0 ~ mean0)
  } else {
    fit1 = loess(cv0 ~ mean0)
    fit2 = loess(cv0 ~ log(mean0))
    cox = coxtest(fit1, fit2)
    probs = cox$`Pr(>|z|)`
    if (probs[1] < probs[2]) {
      model = fit1
    } else {
      model = fit2
    }
  }
  genes = rownames(mat)[keep]
  residuals = setNames(model$residuals, genes)
  
  genes = names(residuals)[residuals > quantile(residuals, 0.5)]
  mat %<>% extract(genes, )
  return(mat)
}

SPLIT DATA

  • The data is split based on cells
  • Although it might be more useful to split by replicates, as it would mimic more real-world scenarios while training and testing along with avoiding over-fitting, it cannot be applied here as the number of replicates is minimal
  • Stratifies split is performed to make sure all the test and train datasets include all the labels and replicates
  • Split ratio:
    • Train data - 70%
    • Test data - 30%
  • Maximum number of cells/cell type in a split = 1500
  • This subsampling was done to reduce the run time, as local machine couldn’t handle the load.
Code
set.seed(123)

meta$strata <- paste(meta$replicate, meta$label, sep = "_")

train_indices <- c()
test_indices <- c()

strata_list <- split(seq_len(nrow(meta)), meta$strata)

for (stratum_indices in strata_list) {
  
  shuffled_indices <- sample(stratum_indices)
  train_size <- floor(length(shuffled_indices) * 0.7)
  train_indices <- c(train_indices, shuffled_indices[1:train_size])
  test_indices <- c(test_indices, shuffled_indices[(train_size + 1):length(shuffled_indices)])
  
}

# Create train and test datasets
train_meta <- meta[train_indices, ]
test_meta <- meta[test_indices, ]
train_expr <- expr[, rownames(train_meta)]
test_expr <- expr[, rownames(test_meta)]

# Subsample cell types in the training set if they exceed a maximum threshold
max_cells_per_type <- 1500  
train_meta_subsampled <- train_meta %>%
  tibble::rownames_to_column(var = "original_rowname") %>%
  group_by(cell_type) %>%
  group_modify(~ {
    if (nrow(.x) > max_cells_per_type) {
      sampled_indices <- sample(nrow(.x), max_cells_per_type)
      .x <- .x[sampled_indices, ]
    }
    return(.x)
  }) %>%
  ungroup() %>%
  tibble::column_to_rownames(var = "original_rowname")

train_meta <- train_meta_subsampled
train_expr <- train_expr[, rownames(train_meta)]

if (!all(rownames(train_meta) %in% colnames(train_expr))) {
  stop("Mismatch between rownames of train_meta and columns of train_expr")
}
#Check metrics
table(train_meta$label)

Control Disease 
   6498    6635 
table(test_meta$label)

Control Disease 
   6894    4789 
table(train_meta$replicate)

Control_10  Control_2  Control_3  Disease_4  Disease_9 
      1598       2885       2015       3844       2791 
table(test_meta$replicate)

Control_10  Control_2  Control_3  Disease_4  Disease_9 
      1832       2950       2112       2815       1974 

CELL TYPE-SPECIFIC CLASSIFICATION

  • Random forest classifier was trained and tested for each cell type

  • Metrics calculated:

    • Accuracy
    • Area under the ROC
    • Confusion matrix
Code
results <- list()
cell_types_list <- unique(cell_types)
total_cell_types <- length(cell_types_list)
auc_results <- list()
feature_importance <- list()


for (i in seq_along(cell_types_list)) {
  cell_type <- cell_types_list[i]
  #cat(sprintf("Processing cell type %d/%d: %s\n", i, total_cell_types, cell_type))
  
  # Subset data for current cell type
  train_cells <- train_meta$cell_type == cell_type
  test_cells <- test_meta$cell_type == cell_type
  
  if (sum(train_cells) > 0 & sum(test_cells) > 0) {
    train_expr_cell <- train_expr[, train_cells]
    test_expr_cell <- test_expr[, test_cells]
      
    # Select highly variable genes for this cell type
    train_expr_cell_var <- select_var(train_expr_cell)
    
    # Prepare training data
    train_data <- as.data.frame(t(train_expr_cell_var))
    train_labels <- train_meta$label[train_cells]
    train_labels <- as.factor(train_labels)
    
    # Fit Random Forest
    classifier_RF <- randomForest(x = train_data, y = train_labels, ntree = 500)
    
    # Prepare test data
    test_data <- as.data.frame(t(test_expr_cell))
    test_labels <- test_meta$label[test_cells]
    
    # Predict test labels
    predictions <- predict(classifier_RF, newdata = test_data)
    # Predict probabilities for the positive class
    prob_predictions <- predict(classifier_RF, newdata = test_data, type = "prob")[, 2]
    
    
    # Compute metrics
    confusion_mtx <- table(test_labels, predictions)
    accuracy <- sum(diag(confusion_mtx)) / sum(confusion_mtx)
    roc_obj <- roc(test_labels, as.numeric(prob_predictions))
    auc_value <- auc(roc_obj)
    
    # Store AUC result
    auc_results[[cell_type]] <- data.frame(cell_type = cell_type, auc = auc_value)
    
    # Extract feature importance
    importance_df <- as.data.frame(importance(classifier_RF))
    importance_df <- rownames_to_column(importance_df, var = "gene")
    importance_df$cell_type <- cell_type
    
    # Store feature importance
    feature_importance[[cell_type]] <- importance_df
    
    # Store results   
    results[[cell_type]] <- list(
      confusion_matrix = confusion_mtx,
      accuracy = accuracy,
      auc = auc_value,
      predictions = predictions,
      roc = roc_obj
    )
  }
}
1
Note that predictions are used to calculate accuracy and confusion matrix, but prediction probabilities are used for ROC and AUC calculation.

Results for each cell type

Code
results_summary <- data.frame(
  Cell_Type = character(),
  Accuracy = numeric(),
  AUC = numeric(),
  stringsAsFactors = FALSE
)

for (cell_type in names(results)) {
  results_summary <- rbind(results_summary, data.frame(
    Cell_Type = cell_type,
    Accuracy = results[[cell_type]]$accuracy,
    AUC = results[[cell_type]]$auc
  ))
}

# Alternatively, you can use the gt package for a more stylized table
gt(results_summary) %>%
   tab_header(title = "Summary of Results for Each Cell Type")
Summary of Results for Each Cell Type
Cell_Type Accuracy AUC
Peripheral immune cells 0.9759036 0.9957562
Microglia 0.9545455 0.9997586
Myeloid, dividing 0.9696970 1.0000000
NK, T cells 0.6000000 0.8133333
B cells 0.8823529 0.8000000
Newly formed oligodendrocytes (NFOL) 0.8825397 0.9948366
Mature oligodendrocytes (MOL) 0.9342438 0.9942582
Myelin forming oligodendrocytes (MFOL) 0.9452771 0.9989944
Oligodendrocytes precursor cells (OPC) 0.9714286 0.9893410
Vascular endothelial cells 0.7977528 0.9635057
Vascular cells 0.8732782 0.9428398
Astrocytes 0.9896641 0.9992138
Ependymal cells 0.8387097 0.8618421
Dorsal 0.9750727 0.9997634
Ventral 0.9882904 0.9999573

Confusion Matrices

           predictions
test_labels Control Disease
    Control       0       8
    Disease       0     324
           predictions
test_labels Control Disease
    Control     145      31
    Disease       0     506
           predictions
test_labels Control Disease
    Control       0       1
    Disease       0      32
           predictions
test_labels Control Disease
    Control       9       1
    Disease       9       6
           predictions
test_labels Control Disease
    Control      15       0
    Disease       2       0
           predictions
test_labels Control Disease
    Control      80      74
    Disease       0     476
           predictions
test_labels Control Disease
    Control    1364       8
    Disease     122     483
           predictions
test_labels Control Disease
    Control    1928       4
    Disease     153     784
           predictions
test_labels Control Disease
    Control     240      15
    Disease       1     304
           predictions
test_labels Control Disease
    Control      60       0
    Disease      18      11
           predictions
test_labels Control Disease
    Control      94      30
    Disease      16     223
           predictions
test_labels Control Disease
    Control     226       2
    Disease       2     157
           predictions
test_labels Control Disease
    Control      17       2
    Disease       3       9
           predictions
test_labels Control Disease
    Control    1611       2
    Disease      58     736
           predictions
test_labels Control Disease
    Control     927       0
    Disease      15     339
Code
auc_results <- lapply(auc_results, function(df) {
  df <- df %>%
    rename(cell_type = `cell_type`, auc = `auc`) %>% 
    mutate(cell_type = as.character(cell_type), auc = as.numeric(auc)) 
  return(df)
})

# Combine AUC results into a single data frame
auc_summary <- bind_rows(auc_results) %>%
  group_by(cell_type) %>%       
  summarise(mean_auc = mean(auc, na.rm = TRUE)) %>%  
  arrange(desc(mean_auc))
auc_summary <- as.data.frame(auc_summary)

# Combine feature importance results into a single data frame
importance_summary <- bind_rows(feature_importance) %>%
  dplyr::select(cell_type, gene, MeanDecreaseGini) %>%
  rename(importance = MeanDecreaseGini)
importance_summary <- as.data.frame(importance_summary)

AUC Summary per Cell Type

Code
print(auc_summary)
                                cell_type  mean_auc
1                       Myeloid, dividing 1.0000000
2                                 Ventral 0.9999573
3                                  Dorsal 0.9997634
4                               Microglia 0.9997586
5                              Astrocytes 0.9992138
6  Myelin forming oligodendrocytes (MFOL) 0.9989944
7                 Peripheral immune cells 0.9957562
8    Newly formed oligodendrocytes (NFOL) 0.9948366
9           Mature oligodendrocytes (MOL) 0.9942582
10 Oligodendrocytes precursor cells (OPC) 0.9893410
11             Vascular endothelial cells 0.9635057
12                         Vascular cells 0.9428398
13                        Ependymal cells 0.8618421
14                            NK, T cells 0.8133333
15                                B cells 0.8000000

Feature Importance Summary

Code
print(importance_summary %>% group_by(cell_type) %>% top_n(10, importance))
# A tibble: 150 × 3
# Groups:   cell_type [15]
   cell_type               gene    importance
   <chr>                   <chr>        <dbl>
 1 Peripheral immune cells Acss2        0.397
 2 Peripheral immune cells Klhl13       0.327
 3 Peripheral immune cells Dnajc6       0.464
 4 Peripheral immune cells Syn3         0.300
 5 Peripheral immune cells Meis1        0.334
 6 Peripheral immune cells Carmil1      0.763
 7 Peripheral immune cells C3           0.351
 8 Peripheral immune cells Eif2s3y      0.575
 9 Peripheral immune cells Pip5k1b      0.676
10 Peripheral immune cells Slc1a1       0.308
# ℹ 140 more rows

Metrics of best classified cell type

Code
best_auc_cell_type <- auc_summary$cell_type[which.max(auc_summary$mean_auc)]
best_results <- results[[best_auc_cell_type]]

# Extract metrics
best_accuracy <- best_results$accuracy
best_auc <- best_results$auc
best_confusion_matrix <- best_results$confusion_matrix

# Print metrics
cat("Best Classified cell type Metrics:\n")
Best Classified cell type Metrics:
Code
cat(sprintf("Cell Type: %s\n", best_auc_cell_type))
Cell Type: Myeloid, dividing
Code
cat(sprintf("Accuracy: %.4f\n", best_accuracy))
Accuracy: 0.9697
Code
cat(sprintf("AUC: %.4f\n", best_auc))
AUC: 1.0000
Code
print(best_confusion_matrix)
           predictions
test_labels Control Disease
    Control       0       1
    Disease       0      32

VISUALIZATIONS

Plot Model Performance for Different Cell Types

Confusion Matrix Plot (Heatmap)

Code
confusion_data <- lapply(names(results), function(ct) {
  cm <- as.data.frame(results[[ct]]$confusion_matrix)
  cm$Cell_Type <- ct
  return(cm)
})
confusion_data <- do.call(rbind, confusion_data)

ggplot(confusion_data, aes(x = predictions, y = test_labels, fill = Freq)) +
  geom_tile(color = "black") +
  facet_wrap(~ Cell_Type, scales = "free") +
  scale_fill_gradient(low = "white", high = "blue") +
  labs(title = "Confusion Matrix for Each Cell Type",
       x = "Predicted", y = "Actual") +
  theme_minimal()

Accuracy Bar Plot for Each Cell Type

Code
accuracy_data <- data.frame(
  Cell_Type = names(results),
  Accuracy = sapply(results, function(x) x$accuracy)
)

ggplot(accuracy_data, aes(x = reorder(Cell_Type, -Accuracy), y = Accuracy)) +
  geom_bar(stat = "identity", fill = "steelblue") +
  coord_flip() +
  labs(title = "Model Accuracy for Each Cell Type",
       x = "Cell Type", y = "Accuracy") +
  theme_minimal()

Identify Perturbed Cell Types

Bar Plot to Rank Cell Types Based on Perturbation

Code
perturbation_scores <- data.frame(
  Cell_Type = names(results),
  Perturbation_Score = sapply(results, function(x) mean(as.numeric(x$predictions == "Disease")))
)

ggplot(perturbation_scores, aes(x = reorder(Cell_Type, -Perturbation_Score), y = Perturbation_Score)) +
  geom_bar(stat = "identity", fill = "red") +
  coord_flip() +
  labs(title = "Perturbation Scores by Cell Type",
       x = "Cell Type", y = "Perturbation Score") +
  theme_minimal()

UMAP Visualization

Code
meta$predictions <- NA
for (cell_type in names(results)) {
  test_cells <- rownames(test_meta[test_meta$cell_type == cell_type, ])
  meta[test_cells, "predictions"] <- results[[cell_type]]$predictions
}

data@meta.data <- meta

umap_data <- as.data.frame(Embeddings(data, "umap"))
umap_data$cell_type <- data@meta.data$cell_type
umap_data$predictions <- data@meta.data$predictions

# Plot UMAP with predictions and facet by cell type
ggplot(umap_data, aes(x = umap_1, y = umap_2, color = predictions)) +
  geom_point(size = 0.5) +
  facet_wrap(~ cell_type) +
  labs(title = "UMAP Visualization of Disease vs Control Predictions") +
  theme_minimal()

ROC curve for the best classifier

Code
roc_curve <- best_results$roc
ggroc(roc_curve) +
  ggtitle(sprintf("ROC Curve for %s", best_auc_cell_type)) +
  theme_minimal()

SESSION INFO

Code
sessionInfo()
R version 4.4.1 (2024-06-14 ucrt)
Platform: x86_64-w64-mingw32/x64
Running under: Windows 11 x64 (build 22631)

Matrix products: default


locale:
[1] LC_COLLATE=English_United States.utf8 
[2] LC_CTYPE=English_United States.utf8   
[3] LC_MONETARY=English_United States.utf8
[4] LC_NUMERIC=C                          
[5] LC_TIME=English_United States.utf8    

time zone: Asia/Calcutta
tzcode source: internal

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] gt_0.11.0                knitr_1.48               tibble_3.2.1            
 [4] reshape2_1.4.4           rsample_1.2.1            Matrix_1.7-0            
 [7] lmtest_0.9-40            zoo_1.8-12               magrittr_2.0.3          
[10] sparseMatrixStats_1.16.0 MatrixGenerics_1.16.0    matrixStats_1.4.1       
[13] dplyr_1.1.4              pROC_1.18.5              caret_6.0-94            
[16] lattice_0.22-6           ggplot2_3.5.1            randomForest_4.7-1.1    
[19] Seurat_5.1.0             SeuratObject_5.0.2       sp_2.1-4                

loaded via a namespace (and not attached):
  [1] RColorBrewer_1.1-3     rstudioapi_0.16.0      jsonlite_1.8.8        
  [4] spatstat.utils_3.1-0   farver_2.1.2           rmarkdown_2.28        
  [7] vctrs_0.6.5            ROCR_1.0-11            spatstat.explore_3.3-2
 [10] htmltools_0.5.8.1      sass_0.4.9             sctransform_0.4.1     
 [13] parallelly_1.38.0      KernSmooth_2.23-24     htmlwidgets_1.6.4     
 [16] ica_1.0-3              plyr_1.8.9             lubridate_1.9.3       
 [19] plotly_4.10.4          igraph_2.0.3           mime_0.12             
 [22] lifecycle_1.0.4        iterators_1.0.14       pkgconfig_2.0.3       
 [25] R6_2.5.1               fastmap_1.2.0          fitdistrplus_1.2-1    
 [28] future_1.34.0          shiny_1.9.1            digest_0.6.37         
 [31] colorspace_2.1-1       furrr_0.3.1            patchwork_1.2.0       
 [34] tensor_1.5             RSpectra_0.16-2        irlba_2.3.5.1         
 [37] labeling_0.4.3         progressr_0.14.0       timechange_0.3.0      
 [40] fansi_1.0.6            spatstat.sparse_3.1-0  httr_1.4.7            
 [43] polyclip_1.10-7        abind_1.4-8            compiler_4.4.1        
 [46] withr_3.0.1            fastDummies_1.7.4      lava_1.8.0            
 [49] MASS_7.3-60.2          ModelMetrics_1.2.2.2   tools_4.4.1           
 [52] httpuv_1.6.15          future.apply_1.11.2    nnet_7.3-19           
 [55] goftest_1.2-3          glue_1.7.0             nlme_3.1-164          
 [58] promises_1.3.0         grid_4.4.1             Rtsne_0.17            
 [61] cluster_2.1.6          generics_0.1.3         recipes_1.1.0         
 [64] gtable_0.3.5           spatstat.data_3.1-2    class_7.3-22          
 [67] tidyr_1.3.1            data.table_1.16.0      xml2_1.3.6            
 [70] utf8_1.2.4             spatstat.geom_3.3-2    RcppAnnoy_0.0.22      
 [73] ggrepel_0.9.6          RANN_2.6.2             foreach_1.5.2         
 [76] pillar_1.9.0           stringr_1.5.1          spam_2.10-0           
 [79] RcppHNSW_0.6.0         later_1.3.2            splines_4.4.1         
 [82] survival_3.6-4         deldir_2.0-4           tidyselect_1.2.1      
 [85] miniUI_0.1.1.1         pbapply_1.7-2          gridExtra_2.3         
 [88] scattermore_1.2        stats4_4.4.1           xfun_0.47             
 [91] hardhat_1.4.0          timeDate_4032.109      stringi_1.8.4         
 [94] lazyeval_0.2.2         yaml_2.3.10            evaluate_0.24.0       
 [97] codetools_0.2-20       cli_3.6.3              uwot_0.2.2            
[100] rpart_4.1.23           xtable_1.8-4           reticulate_1.39.0     
[103] munsell_0.5.1          Rcpp_1.0.13            globals_0.16.3        
[106] spatstat.random_3.3-1  png_0.1-8              spatstat.univar_3.0-1 
[109] parallel_4.4.1         gower_1.0.1            dotCall64_1.1-1       
[112] listenv_0.9.1          viridisLite_0.4.2      ipred_0.9-15          
[115] prodlim_2024.06.25     scales_1.3.0           ggridges_0.5.6        
[118] leiden_0.4.3.1         purrr_1.0.2            rlang_1.1.4           
[121] cowplot_1.1.3