MSI Status Prediction Modeling

Author

Sehyun Oh

Published

January 10, 2025

Abstract
Handling class imbalance

Required packages

suppressPackageStartupMessages({
    library(caret)      # For model training and evaluation
    library(pROC)       # For ROC curves
    library(ROSE)       # For handling imbalanced data
    library(glmnet)     # For regularized logistic regression
    library(randomForest)
    library(xgboost)
    library(GenomicSuperSignature)
    library(tidyverse)
})
RAVmodel <- getModel('C2', load=TRUE)
[1] "downloading"

1. Data Preprocessing

## Samples with MSI status info
combined_data <- readRDS("data/combined_data.rds") # combined_TCGA_cancer
target_attr <- "patient.microsatellite_instability_test_results.microsatellite_instability_test_result.mononucleotide_and_dinucleotide_marker_panel_analysis_status"
combined_data_msi <- combined_data[,combined_data[[target_attr]] %in% c("msi-h", "mss")]

## Predictor RAVs
RAVs_combinedTCGA <- c(
    517, 220, 2109, 1303, 324, 438, 868, # RAVs with statistically significant pairwise Wilcoxon p-values (MSS vs. MSI-H)
    834, 190, 1166, # RAVs with a significant Kruskal-Wallis (KW) test (p-value < 0.05) for COAD
    2344, # significant KW test for STAD, includes 324, 868, 517 above
    357) # UCEC KW test (p-value = 0.056)
RAVmodel_sub <- RAVmodel[,RAVs_combinedTCGA]

## Calculate sample scores
sampleScore <- calculateScore(assay(combined_data_msi), RAVmodel_sub)

The data object below combines the predictors (sample scores from the 12 RAVs) with the response variable (MSI status).

data <- as.data.frame(sampleScore)
data$status <- colData(combined_data_msi)[[target_attr]]
data$status <- ifelse(data$status == "msi-h", "MSI", "MSS")
data$status <- factor(data$status, levels = c("MSS", "MSI")) # Convert outcome to factor

Split data into training and testing sets

The split is stratified on MSI status so that the training (70%) and test (30%) sets keep the original class proportions.

set.seed(123)
index <- createDataPartition(data$status, p = 0.7, list = FALSE)
train_data <- data[index, ]
test_data <- data[-index, ]
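
One way to confirm that the stratified split preserved the class mix is to compare class proportions across the full data and both partitions. The sketch below is illustrative only: it uses a hypothetical toy data frame (`toy`, with an assumed 76/24 MSS/MSI mix and a dummy predictor) and a small helper (`class_props`) that are not part of the original analysis.

```r
library(caret)

# Hypothetical toy data: 76 MSS and 24 MSI samples with one dummy predictor
set.seed(123)
toy <- data.frame(
    score  = rnorm(100),
    status = factor(rep(c("MSS", "MSI"), times = c(76, 24)),
                    levels = c("MSS", "MSI"))
)

# Stratified 70/30 split on the outcome, mirroring the call used above
idx <- createDataPartition(toy$status, p = 0.7, list = FALSE)
toy_train <- toy[idx, ]
toy_test  <- toy[-idx, ]

# Helper (illustrative): class proportions of a data frame with a `status` column
class_props <- function(d) {
    p <- as.numeric(table(d$status)) / nrow(d)
    names(p) <- levels(d$status)
    p
}

# Proportions should be close to each other in all three sets
props <- rbind(full  = class_props(toy),
               train = class_props(toy_train),
               test  = class_props(toy_test))
print(round(props, 3))
```

Applying the same helper to data, train_data, and test_data would give the corresponding check for the real split.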

2. Handle class imbalance using ROSE

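ROSE generates a synthetic, roughly class-balanced version of the training set. Only the training data are resampled; the test set is left untouched. Alternatives include SMOTE (e.g., from the DMwR package).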

balanced_train <- ROSE(status ~ ., data = train_data)$data
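
To see what the resampling does to the class counts, the tables before and after ROSE can be compared. This is a minimal sketch on a hypothetical toy data frame (`toy`, 80 MSS vs. 20 MSI with two made-up predictors), not the actual training data; the `seed` argument is used here only to make the toy draw reproducible.

```r
library(ROSE)

# Hypothetical imbalanced toy data: 80 MSS vs. 20 MSI, two numeric predictors
set.seed(123)
toy <- data.frame(
    x1 = rnorm(100),
    x2 = rnorm(100),
    status = factor(rep(c("MSS", "MSI"), times = c(80, 20)),
                    levels = c("MSS", "MSI"))
)

# By default ROSE returns a synthetic sample of the same size as the input,
# with each row assigned to the minority class with probability p = 0.5
toy_balanced <- ROSE(status ~ ., data = toy, seed = 123)$data

# Class counts before and after resampling
print(table(toy$status))
print(table(toy_balanced$status))
```

Because class membership is drawn at random, the resulting counts are approximately, not exactly, 50/50.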

3. Basic Logistic Regression

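Logistic regression is trained on the ROSE-balanced training data with 5-fold cross-validation, using the ROC AUC as the selection metric. A random forest (Section 4) is fit for comparison.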

## With cross-validation
ctrl <- trainControl(method = "cv", 
                     number = 5,
                     classProbs = TRUE,
                     summaryFunction = twoClassSummary)

## Train logistic regression
log_model <- train(status ~ .,
                   data = balanced_train,
                   method = "glm",
                   family = "binomial",
                   trControl = ctrl,
                   metric = "ROC")

4. Random Forest with class weights

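The random forest is fit on the original (unbalanced) training data, with class weights that up-weight MSI threefold instead of resampling. It is also used to assess feature importance (Section 6).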

rf_model <- randomForest(status ~ .,
                         data = train_data,
                         ntree = 500,
                         classwt = c("MSS" = 1, "MSI" = 3))

5. Model Evaluation

Performance on the held-out test set is evaluated with ROC curves and a confusion matrix.

## Predictions on test set
log_pred <- predict(log_model, test_data, type = "prob")
rf_pred <- predict(rf_model, test_data, type = "prob")

## ROC curves
roc_log <- roc(test_data$status, log_pred[, "MSI"])
Setting levels: control = MSS, case = MSI
Setting direction: controls < cases
roc_rf <- roc(test_data$status, rf_pred[, "MSI"])
Setting levels: control = MSS, case = MSI
Setting direction: controls < cases
## Plot ROC curves
plot(roc_log, col = "blue")
lines(roc_rf, col = "red")
legend("bottomright", legend = c("Logistic", "Random Forest"),
       col = c("blue", "red"), lwd = 2)
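
The curves can also be summarized numerically by their area under the curve (AUC). The sketch below uses hypothetical toy labels and probabilities (`truth`, `p_msi`), not the model predictions above. It also passes `levels` and `direction` explicitly, which avoids the "Setting levels" and "Setting direction" messages shown earlier.

```r
library(pROC)

# Hypothetical toy labels and predicted MSI probabilities
truth <- factor(c("MSS", "MSS", "MSS", "MSI", "MSI"),
                levels = c("MSS", "MSI"))
p_msi <- c(0.10, 0.40, 0.35, 0.80, 0.30)

# Control = MSS, case = MSI; higher scores are expected for cases
toy_roc <- roc(truth, p_msi, levels = c("MSS", "MSI"), direction = "<")

# AUC: 4 of the 6 (MSI, MSS) pairs are ranked correctly here
toy_auc <- as.numeric(auc(toy_roc))
print(toy_auc)
```

The same `auc()` call applied to roc_log and roc_rf would give the test-set AUC of each model.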

6. Feature Importance (for Random Forest)

importance(rf_model)
        MeanDecreaseGini
RAV517         30.836451
RAV220         10.174085
RAV2109        13.854494
RAV1303        10.232871
RAV324         12.474643
RAV438         11.774886
RAV868         13.411812
RAV834         40.756693
RAV190         10.324844
RAV1166        16.851458
RAV2344        15.116884
RAV357          9.889015
varImpPlot(rf_model)
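
The importance scores are easier to read once ranked. The following sketch copies the MeanDecreaseGini values from the output above into a named vector (`gini`, an illustrative name) and sorts them; it uses base R only and does not refit the model.

```r
# MeanDecreaseGini values copied from the importance() output above
gini <- c(RAV517 = 30.836451, RAV220 = 10.174085, RAV2109 = 13.854494,
          RAV1303 = 10.232871, RAV324 = 12.474643, RAV438 = 11.774886,
          RAV868 = 13.411812, RAV834 = 40.756693, RAV190 = 10.324844,
          RAV1166 = 16.851458, RAV2344 = 15.116884, RAV357 = 9.889015)

# Rank RAVs from most to least important
gini_ranked <- sort(gini, decreasing = TRUE)

# Share of the total Gini decrease attributed to each RAV
gini_share <- gini_ranked / sum(gini_ranked)
print(round(head(gini_share, 3), 3))
```

RAV834 and RAV517 rank first and second, followed by RAV1166; the remaining RAVs have broadly similar scores.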

7. Confusion Matrix

confusionMatrix(predict(rf_model, test_data), test_data$status)
Confusion Matrix and Statistics

          Reference
Prediction MSS MSI
       MSS 159  28
       MSI  10  25
                                          
               Accuracy : 0.8288          
                 95% CI : (0.7727, 0.8759)
    No Information Rate : 0.7613          
    P-Value [Acc > NIR] : 0.009366        
                                          
                  Kappa : 0.467           
                                          
 Mcnemar's Test P-Value : 0.005820        
                                          
            Sensitivity : 0.9408          
            Specificity : 0.4717          
         Pos Pred Value : 0.8503          
         Neg Pred Value : 0.7143          
             Prevalence : 0.7613          
         Detection Rate : 0.7162          
   Detection Prevalence : 0.8423          
      Balanced Accuracy : 0.7063          
                                          
       'Positive' Class : MSS
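
Note that the 'Positive' class in the output above is MSS, so the reported sensitivity (0.9408) describes how well MSS samples are recovered. When MSI is the class of interest, the same counts can be re-read with MSI as the positive class. The sketch below recomputes those metrics from the table above in base R; the object names (`cm`, `sens_msi`, and so on) are illustrative.

```r
# Counts copied from the confusion matrix above
# (rows = prediction, columns = reference; filled column by column)
cm <- matrix(c(159, 10, 28, 25), nrow = 2,
             dimnames = list(Prediction = c("MSS", "MSI"),
                             Reference  = c("MSS", "MSI")))

# Metrics with MSI treated as the positive class
sens_msi <- cm["MSI", "MSI"] / sum(cm[, "MSI"])  # 25 / 53: MSI cases detected
spec_msi <- cm["MSS", "MSS"] / sum(cm[, "MSS"])  # 159 / 169: MSS cases kept
ppv_msi  <- cm["MSI", "MSI"] / sum(cm["MSI", ])  # 25 / 35: precision of MSI calls

print(round(c(sensitivity = sens_msi,
              specificity = spec_msi,
              precision   = ppv_msi), 4))
```

With MSI as the positive class, sensitivity is about 0.47, so roughly half of the MSI samples in the test set are missed. Passing `positive = "MSI"` to `confusionMatrix()` should report these values directly.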