Training of a random forest model on synthesized relative abundance data and metadata indicating either absence or presence of disease in each individual

# Load Packages

library(tidyverse)
## Warning: package 'tidyverse' was built under R version 4.5.1
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.1.4     ✔ readr     2.1.5
## ✔ forcats   1.0.0     ✔ stringr   1.5.1
## ✔ ggplot2   3.5.2     ✔ tibble    3.2.1
## ✔ lubridate 1.9.4     ✔ tidyr     1.3.1
## ✔ purrr     1.0.4     
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(caret)
## Warning: package 'caret' was built under R version 4.5.1
## Loading required package: lattice
## 
## Attaching package: 'caret'
## 
## The following object is masked from 'package:purrr':
## 
##     lift
library(randomForest)
## Warning: package 'randomForest' was built under R version 4.5.1
## randomForest 4.7-1.2
## Type rfNews() to see new features/changes/bug fixes.
## 
## Attaching package: 'randomForest'
## 
## The following object is masked from 'package:dplyr':
## 
##     combine
## 
## The following object is masked from 'package:ggplot2':
## 
##     margin
library(pROC)
## Warning: package 'pROC' was built under R version 4.5.1
## Type 'citation("pROC")' for a citation.
## 
## Attaching package: 'pROC'
## 
## The following objects are masked from 'package:stats':
## 
##     cov, smooth, var
## 1. Load data
# Feature table (samples in rows, taxa in columns) and metadata (must contain "Label")

X <- read.csv("microbiome_rel_abundance_1500x150.csv", row.names = 1, check.names = FALSE)
meta <- read.csv("microbiome_metadata_1500x150.csv", row.names = 1)

if (!"Label" %in% colnames(meta)) {
  stop("Metadata must contain a 'Label' column.", call. = FALSE)
}

# Match samples between the feature table (relative abundance table) and the metadata

common <- intersect(rownames(X), rownames(meta))
if (length(common) == 0) {
  stop("No sample IDs are shared between the feature table and the metadata.",
       call. = FALSE)
}

# drop = FALSE keeps both objects as data frames even when they hold a single
# column; without it a one-column `meta` collapses to a plain vector and
# `meta$Label` fails.
X <- X[common, , drop = FALSE]
meta <- meta[common, , drop = FALSE]
y <- factor(meta$Label)

# The two-class summary and ROC analysis used later need exactly two classes,
# so fail early with a clear message instead of deep inside caret.
if (nlevels(y) != 2) {
  stop("Expected exactly two classes in 'Label', found ", nlevels(y), ".",
       call. = FALSE)
}

## 2. Train/Test split of data 80/20% - The model is trained on 80% of the
##    samples and tested on the remaining 20%. This can be changed to a
##    70% / 30% split (p = 0.7) depending on data size.

set.seed(123)  # fixed seed so the partition is reproducible
idx <- createDataPartition(y = y, p = 0.8, list = FALSE)

# Training portion
Xtr <- X[idx, ]
ytr <- y[idx]

# Held-out test portion
Xte <- X[-idx, ]
yte <- y[-idx]

## 3. Cross-validation setup - 5-fold cross-validation, repeated 2 times
##    (method = "repeatedcv"), is used by caret while tuning the model.
##    Class probabilities are switched on because twoClassSummary needs them
##    to report ROC, sensitivity and specificity.
ctrl <- trainControl(
  method          = "repeatedcv",
  number          = 5,
  repeats         = 2,
  classProbs      = TRUE,
  summaryFunction = twoClassSummary
)

## 4. Train Random Forest (rf) - mtry is tuned on the cross-validated area
##    under the Receiver Operating Characteristic (ROC) curve.

# Candidate mtry values: floor(sqrt(number of features)) +/- 2, clamped to the
# valid range [1, number of features] so that small feature tables cannot
# produce mtry <= 0 (duplicates created by clamping are removed).
n_feat <- ncol(Xtr)
mtry_grid <- unique(pmin(pmax(floor(sqrt(n_feat)) + (-2:2), 1), n_feat))

rf <- train(x = Xtr, y = ytr,
            method = "rf",
            metric = "ROC",
            trControl = ctrl,
            tuneGrid = data.frame(mtry = mtry_grid),
            ntree = 1000,
            importance = TRUE)  # importance(type = 1) is read from the final model later

print(rf)
## Random Forest 
## 
## 1200 samples
##  150 predictor
##    2 classes: 'Disease', 'Healthy' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold, repeated 2 times) 
## Summary of sample sizes: 959, 960, 961, 960, 960, 960, ... 
## Resampling results across tuning parameters:
## 
##   mtry  ROC        Sens       Spec     
##   10    0.9585725  0.7133663  0.9712693
##   11    0.9571860  0.7123663  0.9755704
##   12    0.9577426  0.7152871  0.9755755
##   13    0.9571574  0.7282376  0.9705498
##   14    0.9551395  0.7252277  0.9719887
## 
## ROC was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 10.
# ROC vs mtry

# Per-mtry resampling summary produced by caret
rf_res <- rf$results

# Line + point plot of the cross-validated ROC for each candidate mtry
ROCMT <- rf_res %>%
  ggplot(aes(x = mtry, y = ROC)) +
  geom_line(color = "darkgreen") +
  geom_point(size = 3, color = "darkgreen") +
  theme_minimal() +
  labs(
    title = "ROC vs mtry (Random Forest)",
    x = "Number of variables tried at each split (mtry)",
    y = "ROC (cross-validated)"
  )

ggsave("ROCMTplot.jpg", plot = ROCMT, width = 8, height = 6, dpi = 300)

ROCMT

## 5. Evaluate on test set - predict hard class labels as well as per-class
##    probabilities for the held-out samples, then tabulate the predicted
##    labels against the true classes.

pred_class <- predict(rf, newdata = Xte)
pred_prob  <- predict(rf, newdata = Xte, type = "prob")

confusionMatrix(data = pred_class, reference = yte)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Disease Healthy
##    Disease      88       4
##    Healthy      38     170
##                                           
##                Accuracy : 0.86            
##                  95% CI : (0.8155, 0.8972)
##     No Information Rate : 0.58            
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.7015          
##                                           
##  Mcnemar's Test P-Value : 3.543e-07       
##                                           
##             Sensitivity : 0.6984          
##             Specificity : 0.9770          
##          Pos Pred Value : 0.9565          
##          Neg Pred Value : 0.8173          
##              Prevalence : 0.4200          
##          Detection Rate : 0.2933          
##    Detection Prevalence : 0.3067          
##       Balanced Accuracy : 0.8377          
##                                           
##        'Positive' Class : Disease         
## 
# ROC & AUC - The ROC curve visualizes how well the Random Forest distinguishes
# Disease from Healthy across different probability thresholds. The closer the
# curve follows the top-left corner, the better the classifier. The AUC value
# summarizes this performance: an AUC of 0.5 means random guessing, while
# values closer to 1.0 indicate strong predictive ability.

yte2 <- relevel(yte, ref = "Disease")
# levels = c(controls, cases): reversing the levels of yte2 puts "Disease" last,
# i.e. Disease samples are the cases. direction = "<" states explicitly that
# controls are expected to have a LOWER predicted Disease probability than
# cases, instead of letting pROC pick the direction from the data (which can
# flatter a poor classifier).
roc_obj <- roc(response = yte2,
               predictor = pred_prob$Disease,
               levels = rev(levels(yte2)),
               direction = "<")
plot(roc_obj, main = "ROC curve (Random Forest)")

auc(roc_obj)
## Area under the curve: 0.9567
auc_val <- as.numeric(auc(roc_obj))

# ggplot version of the ROC curve.
# legacy.axes = TRUE makes ggroc plot 1 - specificity on an increasing 0 -> 1
# x axis, so the axis really shows what its label says (the default plots
# specificity on a reversed axis). On these axes the chance diagonal is
# y = x, hence slope = 1 and intercept = 0.
ROCC <- ggroc(roc_obj, legacy.axes = TRUE) +
  labs(title = sprintf("ROC curve (Random Forest) — AUC = %.3f", auc_val),
       x = "1 - Specificity",
       y = "Sensitivity") +
  theme_minimal() +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "black")

ggsave("ROCplot.jpg", plot = ROCC, width = 8, height = 6, dpi = 300)

ROCC

## 6. Feature importance - The taxa that are most important for correct
##    predictions. Removing the top taxa would be expected to reduce predictive
##    performance, as they are associated with disease state.

rf_raw <- rf$finalModel
imp <- importance(rf_raw, type = 1)  # type = 1: mean decrease in accuracy

# One row per taxon, sorted from most to least important
imp_df <- data.frame(Taxon = rownames(imp), MeanDecreaseAccuracy = imp[, 1])
imp_df <- arrange(imp_df, desc(MeanDecreaseAccuracy))

head(imp_df, n = 15)
##               Taxon MeanDecreaseAccuracy
## Taxon_26   Taxon_26             18.77855
## Taxon_17   Taxon_17             17.94785
## Taxon_57   Taxon_57             17.44648
## Taxon_49   Taxon_49             16.62050
## Taxon_114 Taxon_114             16.44814
## Taxon_124 Taxon_124             16.31766
## Taxon_24   Taxon_24             15.47878
## Taxon_82   Taxon_82             15.25134
## Taxon_128 Taxon_128             15.11317
## Taxon_62   Taxon_62             14.39993
## Taxon_84   Taxon_84             14.30990
## Taxon_81   Taxon_81             13.55007
## Taxon_92   Taxon_92             12.53425
## Taxon_18   Taxon_18             12.33721
## Taxon_25   Taxon_25             12.26327
# Horizontal bar chart of the (up to) 30 most important taxa.
# head() caps the selection at the number of available taxa; indexing with
# imp_df[1:30, ] would pad the data with all-NA rows when fewer than 30 exist.
Top <- ggplot(head(imp_df, 30), aes(x = reorder(Taxon, MeanDecreaseAccuracy),
                                    y = MeanDecreaseAccuracy,
                                    fill = MeanDecreaseAccuracy)) +
  geom_col() +
  coord_flip() +
  scale_fill_gradient(low = "darkgreen", high = "lightgreen") +
  labs(title = "Top important taxa (Random Forest)",
       x = "Taxon", y = "Mean Decrease Accuracy") +
  theme_minimal() +
  theme(legend.position = "none")

# Save plot
ggsave("TopImportantTax.png", plot = Top, width = 8, height = 6, dpi = 300)

Top

# Save the model for further use.
saveRDS(rf, file = "rf_model.rds")

# Save the exact feature order the model was trained on
train_features <- colnames(Xtr)
saveRDS(train_features, file = "rf_features.rds")

# Record the preprocessing choices alongside the model
prep_info <- list(
  transform    = "relative_abundance",
  center_scale = FALSE,
  notes        = "Taxa columns only; samples in rows"
)
saveRDS(prep_info, file = "rf_prep_info.rds")

Deployment of trained RF model

# Inputs:
#   Trained model: rf_model.rds      (caret::train object)
#   Feature list: rf_features.rds    (character vector of column names used in training)
#   New data:     Lifelike_relative_abundance__preview_2.csv  (samples x taxa; sample IDs are read from the 2nd column and used as rownames)
#
# Output: predictions_on_lifelike_preview.csv  (class + probabilities per sample)
#
# Notes:
#   - Assumes training used relative abundances with no additional transforms.
#   - If you did log1p/CLR/etc. during training, apply the SAME transform here before prediction.



library(dplyr)
library(tidyr)
library(ggplot2)
library(forcats)
library(tibble)

#  Paths (inputs produced by the training step, new data, and output file)
model_path    <- "rf_model.rds"
features_path <- "rf_features.rds"
newdata_path  <- "Lifelike_relative_abundance__preview_2.csv"
out_path      <- "predictions_on_lifelike_preview.csv"

#  Load artifacts - abort early if either saved file is absent
if (!file.exists(model_path)) {
  stop("Model file not found: ", model_path)
}
if (!file.exists(features_path)) {
  stop("Feature file not found: ", features_path)
}

rf <- readRDS(file = model_path)
train_features <- readRDS(file = features_path)

#  Load new data (keep Diet separately)
# NOTE(review): row.names = 2 takes the sample IDs from the SECOND CSV column;
# confirm this matches the layout of the input file.
raw_in <- read.csv(newdata_path, row.names = 2, check.names = FALSE)
if (!is.data.frame(raw_in) || nrow(raw_in) == 0) {
  stop("Loaded new data is empty or malformed: ", newdata_path)
}

# Grab Diet (if present) before coercions
has_diet <- "Diet" %in% colnames(raw_in)
diet_df <- tibble(
  SampleID = rownames(raw_in),
  Diet = if (has_diet) raw_in[["Diet"]] else NA_character_
)

# Feature matrix: drop Diet if present
X_new <- raw_in
if (has_diet) X_new[["Diet"]] <- NULL

# --- Coerce to numeric safely ---
# Every remaining column is forced to numeric; anything that cannot be parsed
# becomes NA and is then replaced by 0. The warning names the affected columns
# so that silently zeroed features can be traced back to the input file.
X_new[] <- lapply(X_new, function(col) suppressWarnings(as.numeric(col)))
na_cols <- names(X_new)[vapply(X_new, anyNA, logical(1))]
if (length(na_cols) > 0) {
  warning("Some values became NA during numeric coercion. Replacing NA with 0. ",
          "Affected columns: ", paste(na_cols, collapse = ", "))
  X_new[is.na(X_new)] <- 0
}

#  Align columns to training feature set
#
# Returns `X` restricted to, and ordered as, `train_feats`.
#   X            data frame with samples in rows and features in columns
#   train_feats  character vector of the feature names the model was trained on
# Features listed in `train_feats` but absent from `X` are added as all-zero
# columns (a warning lists them); columns of `X` that are not in `train_feats`
# are dropped.
align_to_training <- function(X, train_feats) {
  absent <- setdiff(train_feats, colnames(X))
  if (length(absent) > 0) {
    warning(length(absent),
            " training feature(s) missing from the new data were filled with 0: ",
            paste(absent, collapse = ", "),
            call. = FALSE)
    X[absent] <- 0
  }

  # Selecting by name both drops the extra columns and enforces training order
  X[, train_feats, drop = FALSE]
}
X_new_aligned <- align_to_training(X_new, train_features)

# Sanity checks: same number of columns, same names, same order as in training
stopifnot(
  ncol(X_new_aligned) == length(train_features),
  identical(colnames(X_new_aligned), train_features)
)

#  Predict
# Hard class labels
pred_class <- predict(rf, newdata = X_new_aligned)

# Per-class probabilities; on failure a warning is raised and NULL is kept so
# the class labels can still be written out.
prob_df <- tryCatch(
  as.data.frame(predict(rf, newdata = X_new_aligned, type = "prob")),
  error = function(e) {
    warning("Probability prediction failed: ", conditionMessage(e))
    NULL
  }
)

#  Assemble output + add Diet back
# One row per sample: ID and predicted class ...
out <- tibble(
  SampleID = rownames(X_new_aligned),
  PredictedClass = as.character(pred_class)
)

# ... plus one "Prob_<class>" column per class when probabilities are available
if (!is.null(prob_df)) {
  colnames(prob_df) <- paste0("Prob_", colnames(prob_df))
  out <- bind_cols(out, as_tibble(prob_df))
}

# Re-attach the Diet column that was set aside before numeric coercion
out <- left_join(out, diet_df, by = "SampleID")

# Save predictions
write.csv(out, file = out_path, row.names = FALSE)
message("Wrote predictions to: ", normalizePath(out_path))
## Wrote predictions to: C:\Users\ROSFRB\Desktop\RFC\predictions_on_lifelike_preview.csv
#  Long format for heatmap
# The heatmap is built from the "Prob_<class>" columns of `out`; without them
# (probability prediction failed) there is nothing to draw, so stop with a
# clear message rather than failing inside pivot_longer()/pmax().
prob_cols <- grep("^Prob_", names(out), value = TRUE)
if (length(prob_cols) == 0) {
  stop("No probability columns available; cannot draw the probability heatmap.",
       call. = FALSE)
}

out_long <- out %>%
  pivot_longer(
    cols = all_of(prob_cols),
    names_to = "Class",
    values_to = "Probability"
  )

#  Order samples within each Diet (by PredictedClass then max prob)
out_order <- out %>%
  mutate(max_prob = do.call(pmax, c(across(all_of(prob_cols)), na.rm = TRUE))) %>%
  arrange(Diet, PredictedClass, desc(max_prob)) %>%
  mutate(SampleID_ord = factor(SampleID, levels = unique(SampleID)))

# Join ordering back to long data.
# Only the ordering column is brought over: out_long already carries Diet, and
# joining Diet a second time would create duplicated Diet.x / Diet.y columns.
out_long <- out_long %>%
  left_join(out_order %>% select(SampleID, SampleID_ord), by = "SampleID") %>%
  filter(Class %in% c("Prob_Disease", "Prob_Healthy"))

#  Grouped heatmap (faceted by Diet)

heat <- ggplot(out_long, aes(x = Class, y = SampleID_ord, fill = Probability)) +
  geom_tile(color = "white", linewidth = 0.2) +
  scale_fill_gradient(low = "grey", high = "darkgreen") +
  labs(title = "Prediction probabilities grouped by diet choice",
       x = "Class",
       y = "Sample ID") +
  theme_minimal() +
  theme(
    axis.text.y = element_text(size = 6),
    axis.ticks.y = element_line(),
    panel.grid = element_blank(),
    strip.text.y = element_text(face = "bold")
  ) +
  facet_grid(rows = vars(Diet), scales = "free_y", space = "free_y")


# Save plot
ggsave("heat.jpg", plot = heat, width = 8, height = 10, dpi = 300)

heat

#Putting the plots together nicely with patchwork

library(patchwork)
## Warning: package 'patchwork' was built under R version 4.5.1
# Layout: heatmap on the left, the three model plots stacked on the right
right_column <- ROCMT / ROCC / Top
AllPlots <- heat | right_column

ggsave("AllPlots.jpg", plot = AllPlots, width = 15, height = 15, dpi = 300)

AllPlots