ML Tutorial 5: Multiclass Imputation using h2o Random Forest

Foreword: About the Machine Learning in Medicine (MLM) project

Introduction

In this 5th tutorial, we attempt to perform a (greedy) imputation for missing values of multiclass-categorical variables using the Random Forest algorithm. Through this data experiment, you will learn how to:

  1. Detect missing values in your dataset
  2. Develop your own imputation function that applies the h2o Random Forest algorithm in a loop.

Our case study will be the UCI Mammographic Mass dataset. This dataset was used in Case study X7 of the MLM project. In brief, it contains patients' age, the BI-RADS assessment score and three BI-RADS attributes, together with the target classification (severity), for 516 benign and 445 malignant masses identified on digital mammograms. All features contain missing values, and some of them are multiclass variables.

First, we will prepare the ggplot theme for our experiment:

library(tidyverse)

my_theme <- function(base_size = 10, base_family = "sans"){
  theme_minimal(base_size = base_size, base_family = base_family) +
    theme(
      axis.text = element_text(size = 10),
      axis.text.x = element_text(angle = 0, vjust = 0.5, hjust = 0.5),
      axis.title = element_text(size = 12),
      panel.grid.major = element_line(color = "grey"),
      panel.grid.minor = element_blank(),
      panel.background = element_rect(fill = "#ffffef"),
      strip.background = element_rect(fill = "#ffbb00", color = "#ffbb00", size = 0.5),
      strip.text = element_text(face = "bold", size = 10, color = "black"),
      legend.position = "bottom",
      legend.justification = "center",
      legend.background = element_blank(),
      panel.border = element_rect(color = "grey30", fill = NA, size = 0.5)
    )
}
theme_set(my_theme())

Now we load the dataset from the UCI website and make some modifications:

df=read.table("https://archive.ics.uci.edu/ml/machine-learning-databases/mammographic-masses/mammographic_masses.data",sep=",",na.strings ="?")%>%as_tibble()

names(df)=c("BIRAD","Age","Shape","Margin","Density","Severity")

df$Shape=recode_factor(df$Shape,`1` = "Round", `2` = "Oval",`3` = "Lobular", `4` = "Irregular" )
df$Margin=recode_factor(df$Margin,`1` = "Circumscribed", `2` = "Microlobulated",`3` = "Obscured", `4` = "Illdefined",`5` = "Spiculated" )
df$Density=recode_factor(df$Density,`1` = "High", `2` = "Iso",`3` = "Low" , `4` = "Fatcontaining")
df$Severity=recode_factor(df$Severity,`0` = "Benign", `1` = "Malignant")
df$Age=as.numeric(df$Age)


df
## # A tibble: 961 × 6
##    BIRAD   Age     Shape        Margin Density  Severity
##    <int> <dbl>    <fctr>        <fctr>  <fctr>    <fctr>
## 1      5    67   Lobular    Spiculated     Low Malignant
## 2      4    43     Round Circumscribed      NA Malignant
## 3      5    58 Irregular    Spiculated     Low Malignant
## 4      4    28     Round Circumscribed     Low    Benign
## 5      5    74     Round    Spiculated      NA Malignant
## 6      4    65     Round            NA     Low    Benign
## 7      4    70        NA            NA     Low    Benign
## 8      5    42     Round            NA     Low    Benign
## 9      5    57     Round    Spiculated     Low Malignant
## 10     5    60        NA    Spiculated    High Malignant
## # ... with 951 more rows
summary(df)
##      BIRAD             Age              Shape                Margin   
##  Min.   : 0.000   Min.   :18.00   Round    :224   Circumscribed :357  
##  1st Qu.: 4.000   1st Qu.:45.00   Oval     :211   Microlobulated: 24  
##  Median : 4.000   Median :57.00   Lobular  : 95   Obscured      :116  
##  Mean   : 4.348   Mean   :55.49   Irregular:400   Illdefined    :280  
##  3rd Qu.: 5.000   3rd Qu.:66.00   NA's     : 31   Spiculated    :136  
##  Max.   :55.000   Max.   :96.00                   NA's          : 48  
##  NA's   :2        NA's   :5                                           
##           Density         Severity  
##  High         : 16   Benign   :516  
##  Iso          : 59   Malignant:445  
##  Low          :798                  
##  Fatcontaining: 12                  
##  NA's         : 76                  
##                                     
## 
bar_missing <- function(x){
  library(dplyr)
  library(reshape2)
  library(ggplot2)
  ## Melt the logical missingness matrix (TRUE = missing) into long
  ## format, then count available vs. missing cells per variable
  x %>%
    is.na %>%
    melt %>%
    ggplot(data = .,
           aes(x = Var2)) +
    geom_bar(aes(y = (..count..), fill = value), alpha = 0.7, color = "black") +
    scale_fill_manual(values = c("gold", "red3"), name = "",
                      labels = c("Available", "Missing")) +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 45, vjust = 0.5)) +
    labs(x = "Variables in Dataset",
         y = "Observations") +
    coord_flip()
}


matrix_missing <- function(x){
  library(dplyr)
  library(reshape2)
  library(ggplot2)
  ## One tile per cell of the dataset: gold = available, red = missing
  x %>%
    is.na %>%
    melt %>%
    ggplot(data = .,
           aes(x = Var1,
               y = Var2)) +
    geom_tile(aes(fill = value), alpha = 0.6) +
    scale_fill_manual(values = c("gold", "red3"), name = "",
                      labels = c("Available", "Missing")) +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 45, vjust = 0.5)) +
    labs(y = "Variables in Dataset",
         x = "Total observations") +
    coord_flip()
}

df%>%bar_missing()

df%>%matrix_missing()
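
The plots give a visual overview. For an exact per-variable tally, the missing values can also be counted directly with a base-R one-liner:

colSums(is.na(df)) ## Number of missing values per variable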

As you can see, there are missing values in our data:

  1. BIRAD score is a discrete numerical variable with 2 missing values
  2. Age is a continuous numerical variable with 5 missing values
  3. Shape is a multiclass categorical variable with 31 missing values
  4. Margin is a multiclass categorical variable with 48 missing values
  5. Density is a multiclass categorical variable with 76 missing values

Each of these variables can be treated as a target outcome to be predicted from the other variables, either by a regression model (for numerical variables such as Age or BIRAD) or by a classification rule (for categorical variables such as Shape, Margin or Density).
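
In code, this simply means picking one incomplete column as the response and letting the remaining features act as predictors. Here is a minimal sketch for Shape, following the same pattern our loop below applies to all five targets:

## Sketch: casting imputation as a supervised-learning problem
response="Shape"
features=setdiff(names(df[,-6]),response) ## Severity excluded, as in the loop below
train_pool=df[!is.na(df[[response]]),-6] ## rows where the response is observed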

The KNN algorithm has been used for imputation (caret package). We previously showed that other classification or regression algorithms could also be adopted for imputation purposes. However, we have never confronted the missing-value problem in multiclass variables.

Our experiment aims to explore the utility of the h2o-based Random Forest algorithm as an imputation method. This algorithm was chosen as a potential solution for our problem because:

  1. Random Forest can support multiclass classification tasks
  2. It’s a powerful learner, as verified in other studies
  3. h2o's Random Forest is a scalable and fast learner that can handle missing values in the training data.

First, we will initialise the h2o package in R:

library(h2o)
h2o.init(nthreads = -1, max_mem_size = "4g")
##  Connection successful!
## 
## R is connected to the H2O cluster: 
##     H2O cluster uptime:         1 hours 20 minutes 
##     H2O cluster version:        3.10.3.6 
##     H2O cluster version age:    2 months and 1 day  
##     H2O cluster name:           H2O_started_from_R_Admin_dua659 
##     H2O cluster total nodes:    1 
##     H2O cluster total memory:   3.44 GB 
##     H2O cluster total cores:    4 
##     H2O cluster allowed cores:  4 
##     H2O cluster healthy:        TRUE 
##     H2O Connection ip:          localhost 
##     H2O Connection port:        54321 
##     H2O Connection proxy:       NA 
##     R Version:                  R version 3.3.1 (2016-06-21)

Then we will develop a loop for the imputation. The loop consists of five rounds, one for each of the five target variables (BIRAD, Age, Shape, Margin and Density). In each round, a Random Forest model is trained on a temporary data frame in which the current target serves as the response variable, while the remaining variables serve as predictors. An advantage of h2o is that it automatically handles missing values within the predictors. Once the model is ready, it is applied to the original dataset (n = 961) to predict the target variable's values. The accuracy of the prediction for each target is checked and the result stored in a dedicated variable (Check). The loop continues until there are no more target variables to impute.

imputationDF=df[,-6] ## Clone the data frame (without Severity) to receive the output
max=ncol(imputationDF) ## Set the end-point for the loop

imputationDF=imputationDF%>%mutate(CheckBIRAD=NA,CheckAge=NA,CheckShape=NA,CheckMargin=NA,CheckDensity=NA) ## Add dummy variables for the accuracy check

target=colnames(df[,-6]) ## List of target variables (each serves as the response in one round)

wdata=as.h2o(df)
## Begin loop here

for (i in 1:max) {
  
  ## Creating temporary dataframes for model training
  
  IMP=df[,-6]%>%subset(.,!is.na(get(target[[i]]))) ## Keep only rows where the current target is observed
  ty=IMP[,i] ## Current target column
  id=caret::createDataPartition(y=as.matrix(ty),p=0.75,list=FALSE) ## 75/25 train/validation split
  trainset=IMP[id,]
  validset=IMP[-id,]
  
  ## Transforming temp. data to h2o frames
  wtrain=as.h2o(trainset)
  wvalid=as.h2o(validset)
  
  ## Training a RF model
  response=target[[i]]
  features=setdiff(colnames(trainset),response)
  
  rfmod=h2o.randomForest(x = features,
                         y = response,
                         training_frame = wtrain,validation_frame = wvalid, nfolds=10,
                         fold_assignment = "Stratified",
                         balance_classes = TRUE,
                         ntrees = 500, max_depth = 100,
                         stopping_metric = "AUTO",
                         stopping_tolerance = 0.01,
                         stopping_rounds = 3,
                         score_each_iteration = TRUE,seed=123)
  
  ## Prediction and Accuracy check
  
  truth=df[,i] ## True values of the current target (NA where originally missing)
  pred=predict(rfmod,newdata=wdata)%>%as.data.frame()%>%as_tibble()%>%mutate(.,truth=as.matrix(truth))
  pred=pred%>%mutate(check=ifelse(.$predict ==.$truth, "TRUE", "FALSE")) ## check is NA wherever the truth was missing
  
  imputationDF[i]=pred$predict
  imputationDF[i+5]=pred$check ## Accuracy flags go to the matching Check column (6-10)
  
}
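
After the loop finishes, rfmod still holds the model from the final round (Density). As a sketch of how any single round could be inspected, h2o's standard performance helpers can be queried on the validation frame; the confusion matrix is only meaningful for the classification targets:

## Sketch: inspecting the last trained model (here the Density round)
h2o.performance(rfmod, valid = TRUE) ## validation metrics
h2o.confusionMatrix(rfmod, valid = TRUE) ## per-class errors (classification targets only)
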
## Quality check for the numerical predictions

imputationDF$BIRAD=round(imputationDF$BIRAD)
imputationDF$Age=as.numeric(imputationDF$Age)

## Checking whether the predicted BIRAD score matches the true score

imputationDF$CheckBIRAD=ifelse(imputationDF$BIRAD==df$BIRAD, "TRUE", "FALSE")

## Checking whether the absolute error for predicted Age is below 5 years
imputationDF$CheckAge=ifelse(abs(imputationDF$Age-df$Age)<5, "TRUE", "FALSE")
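
As a quick numeric complement to the graphs that follow, the share of correct predictions per target can be tabulated on the available (non-missing) entries; a minimal sketch:

## Sketch: per-target accuracy, excluding the originally missing
## entries (their Check values are NA)
imputationDF%>%select(CheckBIRAD:CheckDensity)%>%summarise_all(funs(mean(.=="TRUE",na.rm=TRUE)))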

The imputation loop ends here. Now we can evaluate the accuracy of the predictions graphically:

## Gathering data
impdf=imputationDF%>%mutate(.,Id=rownames(.))%>%gather(.,CheckBIRAD:CheckDensity,key="Target",value="Accuracy")


## Label the entries that were originally missing (their Check value is NA) as imputations
impdf$Accuracy[is.na(impdf$Accuracy)]="Imputation"


## Plot the quality check

p1=impdf%>%ggplot(aes(x=Target,y=..count..,fill=Accuracy))+geom_bar(show.legend=F,color="black")+scale_fill_manual(values=c("#f32440","green","#ffd700"))+coord_flip()

p2=impdf%>%ggplot(aes(x=Id,y=Target,fill=Accuracy))+geom_tile(show.legend=T)+scale_fill_manual(values=c("#f32440","green","#ffd700"))+theme(axis.text.x=element_blank())+scale_y_discrete("Target")+scale_x_discrete("Patient's Id")

library(gridExtra)

grid.arrange(p1,p2,ncol=1)

As we can see, five distinct Random Forest models were generated in our loop, each handling the imputation of one specific target.

It seems that the Random Forest models worked better on the classification tasks (Shape, Density, Margin) than on the regression tasks (Age, BIRAD). Their performance is impressive for the three multiclass categorical variables: it was best for Density and somewhat lower for Shape and Margin. The model's performance was only acceptable for BIRAD and poorest for Age, so another regression learner should be considered for imputing Age.
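
One possible direction, sketched below but not run here, is to swap h2o's gradient boosting machine in for the Age round; wtrain and wvalid are assumed to be rebuilt with Age as the response, exactly as in the loop above:

## Sketch (not run): an alternative regression learner for Age
gbm_age=h2o.gbm(x = setdiff(colnames(df[,-6]),"Age"),
                y = "Age",
                training_frame = wtrain, validation_frame = wvalid,
                ntrees = 500, max_depth = 5, learn_rate = 0.05,
                stopping_metric = "AUTO", stopping_rounds = 3,
                seed = 123)
h2o.performance(gbm_age, valid = TRUE) ## check MAE/RMSE before trusting the imputations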

Now we can replace the missing values in the original data with the imputed values from the imputation data frame.

df2=df

## The Check columns are NA exactly where the original value was missing,
## so they index the rows that receive imputed values
df2$Shape[is.na(df2$Shape)]=imputationDF$Shape[is.na(imputationDF$CheckShape)]
df2$Age[is.na(df2$Age)]=imputationDF$Age[is.na(imputationDF$CheckAge)]
df2$BIRAD[is.na(df2$BIRAD)]=imputationDF$BIRAD[is.na(imputationDF$CheckBIRAD)]
df2$Margin[is.na(df2$Margin)]=imputationDF$Margin[is.na(imputationDF$CheckMargin)]
df2$Density[is.na(df2$Density)]=imputationDF$Density[is.na(imputationDF$CheckDensity)]
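
Before re-plotting, an optional sanity check confirms that no missing values remain:

colSums(is.na(df2)) ## Should report zero for every variable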

df2%>%bar_missing()

mycolors=c("#f32440","#ffd700","#faa014","#c9e101","#c100e6")

df2%>%ggplot(aes(x=Severity,fill=Shape))+geom_bar(position="fill",color="black",alpha=0.7)+scale_fill_manual(values=mycolors)+coord_flip()+ggtitle("Shape")

df2%>%ggplot(aes(x=Severity,fill=Margin))+geom_bar(position="fill",color="black",alpha=0.7)+scale_fill_manual(values=mycolors)+coord_flip()+ggtitle("Margin")

df2%>%ggplot(aes(x=Severity,fill=Density))+geom_bar(position="fill",color="black",alpha=0.7)+scale_fill_manual(values=mycolors)+coord_flip()+ggtitle("Density")

df2%>%ggplot(aes(fill=Severity,x=Severity,y=Age))+geom_boxplot(alpha=0.7)+scale_fill_manual(values=c("#02afdb","#ce002c"))+coord_flip()+ggtitle("Age")

df2%>%ggplot(aes(fill=Severity,x=as.factor(BIRAD),y=..count..))+geom_bar(alpha=0.7)+scale_fill_manual(values=c("#02afdb","#ce002c"))+coord_flip()+ggtitle("BIRAD")
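
Finally, once the experiment is finished, the h2o cluster can optionally be shut down to release its memory (prompt = FALSE skips the interactive confirmation):

h2o.shutdown(prompt = FALSE)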

Conclusion

Through this tutorial, we have performed a difficult imputation on multiclass categorical and numerical data using multiple Random Forest models. The idea is to treat imputation as a set of classification or regression problems and to attempt to solve them with machine learning techniques. Random Forest can be used as a tool for imputation, particularly when we confront multilevel categorical variables. The imputation's accuracy can only be estimated by evaluating the model's performance on the available data. Based on those results, we can see that the Random Forest based solution works best on categorical variables; other regression learners should be considered for the imputation of numerical variables.

Thank you for joining us and see you soon in the next tutorial.