Case study X4: Deep learning (Multilayer Feedforward Neural Nets)

Foreword: About the Machine Learning in Medicine (MLM) project

The MLM project was initiated in 2016 and aims to:

1. Encourage the use of Machine Learning techniques in medical research in Vietnam.
2. Promote the use of the R statistical programming language, an open-source and leading tool for practicing data science.

Background

In this case study X4, we introduce a new algorithm, Deep Learning, as supported by h2o, a new framework for machine learning.

The dataset from Chapter 13 of Machine Learning in Medicine-Cookbook Two (T. J. Cleophas and A. H. Zwinderman, Springer, 2014) will be reused. It contains the laboratory data and survival outcome of 200 patients hospitalized for sepsis. The main objective is to estimate the mortality risk of these patients in terms of 10 markers: gamma-glutamyl transferase (GammaGT, U/l), aspartate aminotransferase (ASAT, U/l), alanine aminotransferase (ALAT, U/l), bilirubin (μmol/l), urea (mmol/l), creatinine (μmol/l), creatinine clearance (ml/min), erythrocyte sedimentation rate (ESR, mm), C-reactive protein (CRP, mg/l) and leucocyte count (×10^9/l).

This binary classification problem was previously treated with SVM, XGBT and Random Forest; now we will try the Deep Learning approach, which is considered one of the most powerful machine learning algorithms.

Materials and method

The original dataset was provided in Chapter 13 of Machine Learning in Medicine-Cookbook Three (T. J. Cleophas and A. H. Zwinderman, SpringerBriefs in Statistics, 2014). You can get the original data in csv format (svm.csv) from the publisher's website: extras.springer.com.

library(tidyverse)

my_theme <- function(base_size = 10, base_family = "sans"){
  theme_minimal(base_size = base_size, base_family = base_family) +
    theme(
      axis.text = element_text(size = 10),
      axis.text.x = element_text(angle = 0, vjust = 0.5, hjust = 0.5),
      axis.title = element_text(size = 12),
      panel.grid.major = element_line(color = "grey"),
      panel.grid.minor = element_blank(),
      panel.background = element_rect(fill = "#faefff"),
      strip.background = element_rect(fill = "#400156", color = "#400156", size =0.5),
      strip.text = element_text(face = "bold", size = 10, color = "white"),
      legend.position = "bottom",
      legend.justification = "center",
      legend.background = element_blank(),
      panel.border = element_rect(color = "grey30", fill = NA, size = 0.5)
    )
}
theme_set(my_theme())


data <- read.csv("Case57.csv", sep = ";") %>%
  as_tibble() %>%
  dplyr::rename(Death = VAR00001, GammaGT = VAR00002, ASAT = VAR00003,
                ALAT = VAR00004, Bilirubin = VAR00005, Urea = VAR00006,
                Creatinine = VAR00007, CreClearance = VAR00008,
                ESR = VAR00009, CRP = VAR00010, Leucocytes = VAR00011)

data$Outcome=recode_factor(data$Death,`0` = "Survive", `1` = "Death")

Initialising h2o

As the h2o framework will be used throughout this study, you must first install the package in R.

Figure: How to install h2o
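
If the package is not yet installed, the CRAN release can be installed in the usual way (a minimal sketch; h2o also distributes more recent builds through its own repository, as the figure above shows):

install.packages("h2o")  # requires a working Java runtime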

Then we load the package into R:

library(h2o)

At this point the h2o cluster is not yet running; we must start it with the h2o.init() function. If you see something like the figure below, your installation is OK.
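
A plain call with default settings is enough to get started:

h2o.init()  # start (or connect to) a local H2O cluster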

Figure: Default h2o configuration

By default, h2o uses only two cores of our PC and 25% of system memory. To unleash the full power of the system, we add the argument nthreads = -1 to the init function as follows:

h2o.init(nthreads = -1)
##  Connection successful!
## 
## R is connected to the H2O cluster: 
##     H2O cluster uptime:         7 hours 31 minutes 
##     H2O cluster version:        3.10.3.6 
##     H2O cluster version age:    1 month and 22 days  
##     H2O cluster name:           H2O_started_from_R_Admin_lou010 
##     H2O cluster total nodes:    1 
##     H2O cluster total memory:   0.61 GB 
##     H2O cluster total cores:    4 
##     H2O cluster allowed cores:  4 
##     H2O cluster healthy:        TRUE 
##     H2O Connection ip:          localhost 
##     H2O Connection port:        54321 
##     H2O Connection proxy:       NA 
##     R Version:                  R version 3.3.1 (2016-06-21)
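
If the 25% memory cap is also limiting, the max_mem_size argument of h2o.init() raises it (a sketch; pick a value that suits your machine):

h2o.init(nthreads = -1, max_mem_size = "4g")  # allow up to 4 GB for the cluster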

Now we can begin our experiment.

Importing the data frame into h2o

There is one inconvenience when using h2o: the package defines its own environment, classes and functions. For example, h2o accepts neither a data frame nor a tibble; we must convert our original R data frame into an h2o frame using the as.h2o() function. Once converted, the h2o frame can no longer be manipulated with tidyverse functions (some basic functions still work), so it is recommended to finish all data manipulation and cleaning before converting to an h2o frame.

A machine learning experiment often requires splitting the original data into training, validation and testing subsets. We can easily do this in h2o using the h2o.splitFrame() function. In this example, we split our dataset into 2 parts: roughly 70% for training and 30% for testing.

wdata=as.h2o(data)
## 
  |                                                                       
  |                                                                 |   0%
  |                                                                       
  |=================================================================| 100%
splits=h2o.splitFrame(wdata, ratios=0.7,seed=123)

wtrain=splits[[1]]
wtest=splits[[2]]

Unlike the caret package, h2o does not preserve the balance of the target's classes across the split subsets. Fortunately, our original dataset is well balanced.
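
A quick numeric check of the class counts in each split (h2o.table tabulates a factor column of an H2OFrame):

h2o.table(wtrain$Outcome)  # class counts in the training frame
h2o.table(wtest$Outcome)   # class counts in the test frame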

p1 = wtrain %>% as.data.frame() %>%
  ggplot(aes(x = Outcome, fill = Outcome)) +
  stat_count(show.legend = FALSE) +
  scale_fill_brewer(palette = "Set1") +
  ggtitle("Train_set")

p2 = wtest %>% as.data.frame() %>%
  ggplot(aes(x = Outcome, fill = Outcome)) +
  stat_count(show.legend = FALSE) +
  scale_fill_brewer(palette = "Set1") +
  ggtitle("Test_set")

p3 = wdata %>% as.data.frame() %>%
  ggplot(aes(x = Outcome, fill = Outcome)) +
  stat_count(show.legend = FALSE) +
  scale_fill_brewer(palette = "Set1") +
  ggtitle("Pooled_data")

library(gridExtra)

grid.arrange(p1,p2,p3,ncol=3)

The h2o package supports several types of algorithms, including GLM, GBM, Random Forest and Deep Learning (trained with h2o.glm(), h2o.gbm(), h2o.randomForest() and h2o.deeplearning(), respectively). Our tutorial will only focus on Deep Learning.

Deep Learning, also known as multilayer neural networks, is considered one of the most powerful machine learning algorithms. Deep Learning can solve almost any classification/regression problem if we know how to tune it adequately and give it enough time and/or data for learning. Its drawbacks include a slow training process, black-box models and difficulties with categorical data.

For demonstration purposes, the present tutorial only covers Deep Learning with a fixed configuration. In the next tutorial we will explore grid-based tuning and more advanced settings.

As we are all medical doctors, we will not spend time describing the structure and physiological activity of the neuron, presented below (knowledge of physiology is a great advantage for understanding the mechanism of a neural network).

Figure: Structure of a neuron

The functional unit of a neural network is the neuron: a function that receives numeric inputs and returns a numeric output. A neural network consists of multiple layers of many neurons; the outputs of an entire layer are treated as the inputs of each neuron in the next layer. The first layer is our data frame, and the last layer is the model's outcome. If our model is a binary classification learner, the final layer contains two neurons, each returning the probability of one label. Each numeric variable in the data becomes a neuron in the first layer, while categorical variables are handled by several neurons, one per factor level. The layers between the input and output layers are called hidden layers; their number and size must be specified in the training function.

Each neuron is characterized by weights for its inputs, a bias and an activation function. H2O supports 3 activation functions: Rectifier, hyperbolic tangent (Tanh) and Maxout.
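
As a toy illustration (plain R, not h2o code), here is what a single neuron computes: a weighted sum of its inputs plus a bias, passed through an activation function, here Tanh:

# One neuron: activation(weights . inputs + bias)
neuron <- function(x, w, b, activation = tanh) {
  activation(sum(w * x) + b)
}

# Three arbitrary inputs, weights and a bias, returning one numeric output
neuron(x = c(0.5, -1.2, 3.0), w = c(0.1, 0.4, -0.2), b = 0.05)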

Figure: Structure of the Deep neural net

response="Outcome"
features=setdiff(colnames(wtrain),c(response,"Death"))

dlmod=h2o.deeplearning  (x = features,
                         y = response,
                             model_id = "Deep_learning",
                             training_frame = wtrain,                             
                             nfolds = 10,
                             hidden = c(200,200,200,200), 
                             stopping_metric = "misclassification",
                             replicate_training_data = TRUE,
                             stopping_tolerance = 0.001,
                             stopping_rounds = 5,
                             overwrite_with_best_model=TRUE,
                             fold_assignment = "Stratified",
                             epochs=1000,
                             activation = "TanhWithDropout",
                             keep_cross_validation_fold_assignment = TRUE,
                             keep_cross_validation_predictions=FALSE,
                             score_each_iteration = TRUE,
                             variable_importances = TRUE,
                             reproducible = TRUE,seed=123)
## 
  |                                                                       
  |                                                                 |   0%
  |                                                                       
  |=================================================================| 100%

Every machine learning function in h2o begins with the list of features (x) and the response variable (y). x is a character vector containing only the names of the features, not an actual frame, and y is the name of the response variable.

The model_id argument is optional; it sets the name under which the model is stored in the cluster (auto-generated if omitted).

epochs is the number of training cycles (complete passes over the training data) the deep learning algorithm should perform.

During training, h2o regularly scores the interim models on a validation subset at each iteration. We can set a fixed validation frame, or let cross-validation test the model against random folds. If neither a validation frame nor folds are set, validation is performed on the entire training frame.
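
For example, a fixed validation frame could be used instead of the cross-validation chosen above (a hypothetical sketch; splits3 and dlmod_v are not used elsewhere in this tutorial):

# Hypothetical alternative to nfolds: a 60/20/20 split with an explicit
# validation frame scored during training
splits3 <- h2o.splitFrame(wdata, ratios = c(0.6, 0.2), seed = 123)
dlmod_v <- h2o.deeplearning(x = features, y = response,
                            training_frame = splits3[[1]],
                            validation_frame = splits3[[2]],
                            hidden = c(200, 200), epochs = 100, seed = 123)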

The hidden argument sets the network structure: the default is c(200, 200), indicating two hidden layers of 200 neurons each. In our case, we double the number of hidden layers with c(200, 200, 200, 200).

The early stopping arguments optimize training time: training should stop once the model's performance has reached a steady state, so that no more time is wasted on learning.

stopping_metric indicates the metric used to decide whether the model's performance is still improving. By default h2o uses logloss; here we use "misclassification". stopping_tolerance sets a threshold for the metric's improvement, below which training stops; otherwise the model keeps learning.

stopping_rounds adds flexibility to the two criteria above by giving the learner some chances to get worse before it gets better. H2O has early stopping on by default for deep learning, so explicitly set stopping_rounds to 0 if you do not want it (not recommended). overwrite_with_best_model means the returned model is always the best model found during training.

nfolds is the number of folds for cross-validation.

fold_assignment defines how to split the training frame for cross-validation. It can be "Random", "Modulo" or "Stratified". Stratified is the preferred method for an unbalanced dataset, since it tries to put the same proportion of each target class into every fold, but it makes training slower.

keep_cross_validation_fold_assignment = TRUE lets us find out which rows were in which folds; keep_cross_validation_predictions = TRUE would keep the prediction outcome of each fold (we left it FALSE here).
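
Since keep_cross_validation_fold_assignment was TRUE in our call, the assignment can be retrieved afterwards:

h2o.cross_validation_fold_assignment(dlmod)  # fold index of each training row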

Output control: variable_importances is set to TRUE in order to extract the relative importance of each variable. It can slow down learning a bit, which is why it is off by default, but it is worth a try in our case.

activation is a parameter with 6 possible values: the combinations of the 3 activation functions with and without dropout.
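
For reference, the six accepted values are:

# 3 activation functions x with/without dropout = 6 values of `activation`
c("Tanh", "TanhWithDropout",
  "Rectifier", "RectifierWithDropout",
  "Maxout", "MaxoutWithDropout")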

The reproducible argument allows training the model in a reproducible way, so that others get the same model each time they run your code. When it is set to TRUE we must also set a seed: an integer controlling random number generation, which lets you get exactly the same model if you run the algorithm again.

Figure: Activation functions in h2o

Once training is completed, we can explore the model as with other algorithms:

Summarising the model

summary(dlmod)
## Model Details:
## ==============
## 
## H2OBinomialModel: deeplearning
## Model Key:  Deep_learning 
## Status of Neuron Layers: predicting Outcome, 2-class classification, bernoulli distribution, CrossEntropy loss, 123 202 weights/biases, 1,4 MB, 61 064 training samples, mini-batch size 1
##   layer units        type dropout       l1       l2 mean_rate rate_rms
## 1     1    10       Input  0.00 %                                     
## 2     2   200 TanhDropout 50.00 % 0.000000 0.000000  0.000555 0.000877
## 3     3   200 TanhDropout 50.00 % 0.000000 0.000000  0.010516 0.012201
## 4     4   200 TanhDropout 50.00 % 0.000000 0.000000  0.017772 0.016285
## 5     5   200 TanhDropout 50.00 % 0.000000 0.000000  0.086915 0.189088
## 6     6     2     Softmax         0.000000 0.000000  0.003735 0.001186
##   momentum mean_weight weight_rms mean_bias bias_rms
## 1                                                   
## 2 0.000000   -0.001559   0.141423 -0.004069 0.077562
## 3 0.000000    0.000309   0.093140 -0.003771 0.247713
## 4 0.000000   -0.000453   0.081865  0.000362 0.228266
## 5 0.000000   -0.000092   0.071455  0.001291 0.103524
## 6 0.000000   -0.036175   0.353770 -0.000000 0.114421
## 
## H2OBinomialMetrics: deeplearning
## ** Reported on training data. **
## ** Metrics reported on full training frame **
## 
## MSE:  0.01270721
## RMSE:  0.1127263
## LogLoss:  0.04027583
## Mean Per-Class Error:  0.01482127
## AUC:  0.9986922
## Gini:  0.9973845
## 
## Confusion Matrix (vertical: actual; across: predicted) for F1-optimal threshold:
##         Death Survive    Error    Rate
## Death      61       1 0.016129   =1/62
## Survive     1      73 0.013514   =1/74
## Totals     62      74 0.014706  =2/136
## 
## Maximum Metrics: Maximum metrics at their respective thresholds
##                         metric threshold    value idx
## 1                       max f1  0.381416 0.986486  73
## 2                       max f2  0.133083 0.986667  78
## 3                 max f0point5  0.663120 0.994475  71
## 4                 max accuracy  0.663120 0.985294  71
## 5                max precision  0.999988 1.000000   0
## 6                   max recall  0.133083 1.000000  78
## 7              max specificity  0.999988 1.000000   0
## 8             max absolute_mcc  0.663120 0.970859  71
## 9   max min_per_class_accuracy  0.381416 0.983871  73
## 10 max mean_per_class_accuracy  0.663120 0.986486  71
## 
## Gains/Lift Table: Extract with `h2o.gainsLift(<model>, <data>)` or `h2o.gainsLift(<model>, valid=<T/F>, xval=<T/F>)`
## 
## H2OBinomialMetrics: deeplearning
## ** Reported on cross-validation data. **
## ** 10-fold cross-validation on training data (Metrics computed for combined holdout predictions) **
## 
## MSE:  0.05036137
## RMSE:  0.2244134
## LogLoss:  0.3000668
## Mean Per-Class Error:  0.03378378
## AUC:  0.9803836
## Gini:  0.9607672
## 
## Confusion Matrix (vertical: actual; across: predicted) for F1-optimal threshold:
##         Death Survive    Error    Rate
## Death      62       0 0.000000   =0/62
## Survive     5      69 0.067568   =5/74
## Totals     67      69 0.036765  =5/136
## 
## Maximum Metrics: Maximum metrics at their respective thresholds
##                         metric threshold    value idx
## 1                       max f1  0.979448 0.965035  57
## 2                       max f2  0.137449 0.957447  68
## 3                 max f0point5  0.979448 0.985714  57
## 4                 max accuracy  0.979448 0.963235  57
## 5                max precision  0.999999 1.000000   0
## 6                   max recall  0.000000 1.000000 122
## 7              max specificity  0.999999 1.000000   0
## 8             max absolute_mcc  0.979448 0.928896  57
## 9   max min_per_class_accuracy  0.966180 0.945946  60
## 10 max mean_per_class_accuracy  0.979448 0.966216  57
## 
## Gains/Lift Table: Extract with `h2o.gainsLift(<model>, <data>)` or `h2o.gainsLift(<model>, valid=<T/F>, xval=<T/F>)`
## Cross-Validation Metrics Summary: 
##                                mean          sd   cv_1_valid cv_2_valid
## accuracy                 0.98888886 0.023570227          1.0  0.8888889
## auc                         0.99125 0.018561553          1.0     0.9125
## err                     0.011111111 0.023570227          0.0 0.11111111
## err_count                       0.2  0.42426407          0.0        2.0
## f0point5                       0.99 0.021213204          1.0        0.9
## f1                             0.99 0.021213204          1.0        0.9
## f2                             0.99 0.021213204          1.0        0.9
## lift_top_group            1.9132631   0.2444304    1.7777778        1.8
## logloss                  0.28878036  0.26993978 0.0057941065  1.1592535
## max_per_class_error          0.0125 0.026516505          0.0      0.125
## mcc                          0.9775  0.04772971          1.0      0.775
## mean_per_class_accuracy     0.98875 0.023864854          1.0     0.8875
## mean_per_class_error        0.01125 0.023864854          0.0     0.1125
## mse                     0.050650533 0.038761824 4.8375732E-4 0.10015181
## precision                      0.99 0.021213204          1.0        0.9
## r2                        0.7952921  0.15612668   0.99803424  0.5943852
## recall                         0.99 0.021213204          1.0        0.9
## rmse                     0.16316295  0.10960927  0.021994483  0.3164677
## specificity                  0.9875 0.026516505          1.0      0.875
##                          cv_3_valid cv_4_valid cv_5_valid   cv_6_valid
## accuracy                        1.0        1.0        1.0          1.0
## auc                             1.0        1.0        1.0          1.0
## err                             0.0        0.0        0.0          0.0
## err_count                       0.0        0.0        0.0          0.0
## f0point5                        1.0        1.0        1.0          1.0
## f1                              1.0        1.0        1.0          1.0
## f2                              1.0        1.0        1.0          1.0
## lift_top_group                  2.0  2.1666667        2.0        1.375
## logloss                  0.76311564 0.51032287 0.14848118  0.013675269
## max_per_class_error             0.0        0.0        0.0          0.0
## mcc                             1.0        1.0        1.0          1.0
## mean_per_class_accuracy         1.0        1.0        1.0          1.0
## mean_per_class_error            0.0        0.0        0.0          0.0
## mse                     0.124564655 0.14242762 0.05323253 0.0017704639
## precision                       1.0        1.0        1.0          1.0
## r2                       0.50174135 0.42689836 0.78706986    0.9910739
## recall                          1.0        1.0        1.0          1.0
## rmse                      0.3529372 0.37739584 0.23072176   0.04207688
## specificity                     1.0        1.0        1.0          1.0
##                         cv_7_valid    cv_8_valid   cv_9_valid
## accuracy                       1.0           1.0          1.0
## auc                            1.0           1.0          1.0
## err                            0.0           0.0          0.0
## err_count                      0.0           0.0          0.0
## f0point5                       1.0           1.0          1.0
## f1                             1.0           1.0          1.0
## f2                             1.0           1.0          1.0
## lift_top_group           2.4285715           2.4    1.3846154
## logloss                  0.2870191  2.6050602E-5  9.480391E-5
## max_per_class_error            0.0           0.0          0.0
## mcc                            1.0           1.0          1.0
## mean_per_class_accuracy        1.0           1.0          1.0
## mean_per_class_error           0.0           0.0          0.0
## mse                     0.08387434 7.7921025E-10 1.3558078E-7
## precision                      1.0           1.0          1.0
## r2                       0.6537188           1.0   0.99999934
## recall                         1.0           1.0          1.0
## rmse                    0.28961065  2.7914339E-5 3.6821293E-4
## specificity                    1.0           1.0          1.0
##                           cv_10_valid
## accuracy                          1.0
## auc                               1.0
## err                               0.0
## err_count                         0.0
## f0point5                          1.0
## f1                                1.0
## f2                                1.0
## lift_top_group                    1.8
## logloss                  2.1142827E-5
## max_per_class_error               0.0
## mcc                               1.0
## mean_per_class_accuracy           1.0
## mean_per_class_error              0.0
## mse                     8.3267787E-10
## precision                         1.0
## r2                                1.0
## recall                            1.0
## rmse                     2.8856159E-5
## specificity                       1.0
## 
## Scoring History: 
##             timestamp          duration training_speed  epochs iterations
## 1 2017-04-12 23:44:25         0.000 sec                0.00000          0
## 2 2017-04-12 23:44:26  8 min 48.827 sec    501 obs/sec 1.00000          1
## 3 2017-04-12 23:44:26  8 min 49.138 sec    501 obs/sec 2.00000          2
## 4 2017-04-12 23:44:26  8 min 49.451 sec    496 obs/sec 3.00000          3
## 5 2017-04-12 23:44:27  8 min 49.778 sec    490 obs/sec 4.00000          4
##      samples training_rmse training_logloss training_auc training_lift
## 1   0.000000                                                          
## 2 136.000000       0.20348          0.18537      0.98670       1.83784
## 3 272.000000       0.17691          0.17185      0.98779       1.83784
## 4 408.000000       0.19761          0.16548      0.98932       1.83784
## 5 544.000000       0.16414          0.13956      0.98997       1.83784
##   training_classification_error
## 1                              
## 2                       0.02941
## 3                       0.02206
## 4                       0.02941
## 5                       0.02206
## 
## ---
##               timestamp          duration training_speed    epochs
## 445 2017-04-12 23:46:47 11 min  9.858 sec    484 obs/sec 444.00000
## 446 2017-04-12 23:46:47 11 min 10.179 sec    484 obs/sec 445.00000
## 447 2017-04-12 23:46:47 11 min 10.489 sec    484 obs/sec 446.00000
## 448 2017-04-12 23:46:48 11 min 10.799 sec    484 obs/sec 447.00000
## 449 2017-04-12 23:46:48 11 min 11.119 sec    484 obs/sec 448.00000
## 450 2017-04-12 23:46:48 11 min 11.439 sec    484 obs/sec 449.00000
##     iterations      samples training_rmse training_logloss training_auc
## 445        444 60384.000000       0.13933          0.07535      0.99826
## 446        445 60520.000000       0.11239          0.04095      0.99847
## 447        446 60656.000000       0.12762          0.04726      0.99847
## 448        447 60792.000000       0.11958          0.04317      0.99847
## 449        448 60928.000000       0.12391          0.04558      0.99847
## 450        449 61064.000000       0.11273          0.04028      0.99869
##     training_lift training_classification_error
## 445       1.83784                       0.01471
## 446       1.83784                       0.01471
## 447       1.83784                       0.01471
## 448       1.83784                       0.01471
## 449       1.83784                       0.01471
## 450       1.83784                       0.01471
## 
## Variable Importances: (Extract with `h2o.varimp`) 
## =================================================
## 
## Variable Importances: 
##        variable relative_importance scaled_importance percentage
## 1    Leucocytes            1.000000          1.000000   0.172288
## 2          ALAT            0.736160          0.736160   0.126832
## 3           CRP            0.567268          0.567268   0.097734
## 4          ASAT            0.565372          0.565372   0.097407
## 5          Urea            0.557244          0.557244   0.096007
## 6           ESR            0.519068          0.519068   0.089429
## 7  CreClearance            0.498226          0.498226   0.085839
## 8     Bilirubin            0.493901          0.493901   0.085093
## 9       GammaGT            0.470984          0.470984   0.081145
## 10   Creatinine            0.396006          0.396006   0.068227

Variable importance

h2o.varimp_plot(dlmod)

h2o.varimp(dlmod)
## Variable Importances: 
##        variable relative_importance scaled_importance percentage
## 1    Leucocytes            1.000000          1.000000   0.172288
## 2          ALAT            0.736160          0.736160   0.126832
## 3           CRP            0.567268          0.567268   0.097734
## 4          ASAT            0.565372          0.565372   0.097407
## 5          Urea            0.557244          0.557244   0.096007
## 6           ESR            0.519068          0.519068   0.089429
## 7  CreClearance            0.498226          0.498226   0.085839
## 8     Bilirubin            0.493901          0.493901   0.085093
## 9       GammaGT            0.470984          0.470984   0.081145
## 10   Creatinine            0.396006          0.396006   0.068227

Confusion matrix

h2o.confusionMatrix(dlmod)
## Confusion Matrix (vertical: actual; across: predicted)  for max f1 @ threshold = 0.381415907908741:
##         Death Survive    Error    Rate
## Death      61       1 0.016129   =1/62
## Survive     1      73 0.013514   =1/74
## Totals     62      74 0.014706  =2/136

Marginalised partial dependence plots

h2o.partialPlot(dlmod,data=wdata)
## 
  |                                                                       
  |                                                                 |   0%
  |                                                                       
  |=================================================================| 100%

## [[1]]
## PartialDependence: Partial Dependence Plot of model Deep_learning on column 'GammaGT'
##        GammaGT mean_response
## 1     2.000000      0.518582
## 2   122.947368      0.522329
## 3   243.894737      0.527045
## 4   364.842105      0.533433
## 5   485.789474      0.541715
## 6   606.736842      0.550912
## 7   727.684211      0.559677
## 8   848.631579      0.567565
## 9   969.578947      0.574678
## 10 1090.526316      0.581435
## 11 1211.473684      0.588337
## 12 1332.421053      0.595648
## 13 1453.368421      0.603215
## 14 1574.315789      0.610600
## 15 1695.263158      0.617396
## 16 1816.210526      0.623429
## 17 1937.157895      0.628751
## 18 2058.105263      0.633501
## 19 2179.052632      0.637792
## 20 2300.000000      0.641734
## 
## [[2]]
## PartialDependence: Partial Dependence Plot of model Deep_learning on column 'ASAT'
##           ASAT mean_response
## 1     3.000000      0.534792
## 2   107.052632      0.512723
## 3   211.105263      0.494028
## 4   315.157895      0.463911
## 5   419.210526      0.414536
## 6   523.263158      0.313878
## 7   627.315789      0.182829
## 8   731.368421      0.106746
## 9   835.421053      0.041596
## 10  939.473684      0.011937
## 11 1043.526316      0.000261
## 12 1147.578947      0.000097
## 13 1251.631579      0.000049
## 14 1355.684211      0.000037
## 15 1459.736842      0.000032
## 16 1563.789474      0.000029
## 17 1667.842105      0.000027
## 18 1771.894737      0.000026
## 19 1875.947368      0.000026
## 20 1980.000000      0.000025
## 
## [[3]]
## PartialDependence: Partial Dependence Plot of model Deep_learning on column 'ALAT'
##           ALAT mean_response
## 1     2.000000      0.554716
## 2    80.842105      0.514758
## 3   159.684211      0.481577
## 4   238.526316      0.419949
## 5   317.368421      0.309494
## 6   396.210526      0.158624
## 7   475.052632      0.059209
## 8   553.894737      0.014893
## 9   632.736842      0.000562
## 10  711.578947      0.000124
## 11  790.421053      0.000050
## 12  869.263158      0.000037
## 13  948.105263      0.000032
## 14 1026.947368      0.000029
## 15 1105.789474      0.000028
## 16 1184.631579      0.000027
## 17 1263.473684      0.000026
## 18 1342.315789      0.000026
## 19 1421.157895      0.000025
## 20 1500.000000      0.000025
## 
## [[4]]
## PartialDependence: Partial Dependence Plot of model Deep_learning on column 'Bilirubin'
##     Bilirubin mean_response
## 1    1.000000      0.542578
## 2   22.000000      0.528720
## 3   43.000000      0.516002
## 4   64.000000      0.507999
## 5   85.000000      0.496483
## 6  106.000000      0.476670
## 7  127.000000      0.462963
## 8  148.000000      0.431021
## 9  169.000000      0.409299
## 10 190.000000      0.362632
## 11 211.000000      0.287991
## 12 232.000000      0.220285
## 13 253.000000      0.173899
## 14 274.000000      0.136822
## 15 295.000000      0.089550
## 16 316.000000      0.054679
## 17 337.000000      0.037184
## 18 358.000000      0.022033
## 19 379.000000      0.005819
## 20 400.000000      0.000963
## 
## [[5]]
## PartialDependence: Partial Dependence Plot of model Deep_learning on column 'Urea'
##         Urea mean_response
## 1   2.000000      0.504329
## 2   7.052632      0.510908
## 3  12.105263      0.526168
## 4  17.157895      0.545514
## 5  22.210526      0.574097
## 6  27.263158      0.602961
## 7  32.315789      0.621378
## 8  37.368421      0.641056
## 9  42.421053      0.656762
## 10 47.473684      0.670351
## 11 52.526316      0.687223
## 12 57.578947      0.708152
## 13 62.631579      0.731078
## 14 67.684211      0.749708
## 15 72.736842      0.764716
## 16 77.789474      0.776137
## 17 82.842105      0.783861
## 18 87.894737      0.789156
## 19 92.947368      0.792336
## 20 98.000000      0.799021
## 
## [[6]]
## PartialDependence: Partial Dependence Plot of model Deep_learning on column 'Creatinine'
##    Creatinine mean_response
## 1   47.000000      0.531442
## 2   90.052632      0.527676
## 3  133.105263      0.524740
## 4  176.157895      0.522338
## 5  219.210526      0.520213
## 6  262.263158      0.518206
## 7  305.315789      0.516270
## 8  348.368421      0.514378
## 9  391.421053      0.512306
## 10 434.473684      0.509201
## 11 477.526316      0.504059
## 12 520.578947      0.496567
## 13 563.631579      0.488162
## 14 606.684211      0.481104
## 15 649.736842      0.475492
## 16 692.789474      0.468133
## 17 735.842105      0.456342
## 18 778.894737      0.438736
## 19 821.947368      0.424749
## 20 865.000000      0.414454
## 
## [[7]]
## PartialDependence: Partial Dependence Plot of model Deep_learning on column 'CreClearance'
##    CreClearance mean_response
## 1   -132.000000      0.498861
## 2   -125.263158      0.499751
## 3   -118.526316      0.500432
## 4   -111.789474      0.501157
## 5   -105.052632      0.502177
## 6    -98.315789      0.503744
## 7    -91.578947      0.506112
## 8    -84.842105      0.509487
## 9    -78.105263      0.513964
## 10   -71.368421      0.519487
## 11   -64.631579      0.525911
## 12   -57.894737      0.533115
## 13   -51.157895      0.541039
## 14   -44.421053      0.549601
## 15   -37.684211      0.558600
## 16   -30.947368      0.567752
## 17   -24.210526      0.576779
## 18   -17.473684      0.585486
## 19   -10.736842      0.593791
## 20    -4.000000      0.601701
## 
## [[8]]
## PartialDependence: Partial Dependence Plot of model Deep_learning on column 'ESR'
##           ESR mean_response
## 1    2.000000      0.547599
## 2   11.368421      0.535137
## 3   20.736842      0.526186
## 4   30.105263      0.520307
## 5   39.473684      0.516673
## 6   48.842105      0.511003
## 7   58.210526      0.501052
## 8   67.578947      0.484122
## 9   76.947368      0.472520
## 10  86.315789      0.455520
## 11  95.684211      0.431987
## 12 105.052632      0.415811
## 13 114.421053      0.383739
## 14 123.789474      0.335590
## 15 133.157895      0.270888
## 16 142.526316      0.228295
## 17 151.894737      0.198115
## 18 161.263158      0.172016
## 19 170.631579      0.135098
## 20 180.000000      0.101062
## 
## [[9]]
## PartialDependence: Partial Dependence Plot of model Deep_learning on column 'CRP'
##           CRP mean_response
## 1    2.000000      0.507214
## 2   14.684211      0.512134
## 3   27.368421      0.520614
## 4   40.052632      0.536730
## 5   52.736842      0.554463
## 6   65.421053      0.568694
## 7   78.105263      0.587310
## 8   90.789474      0.613578
## 9  103.473684      0.637329
## 10 116.157895      0.653810
## 11 128.842105      0.669877
## 12 141.526316      0.687090
## 13 154.210526      0.703554
## 14 166.894737      0.718720
## 15 179.578947      0.731751
## 16 192.263158      0.741415
## 17 204.947368      0.749021
## 18 217.631579      0.758300
## 19 230.315789      0.770243
## 20 243.000000      0.782689
## 
## [[10]]
## PartialDependence: Partial Dependence Plot of model Deep_learning on column 'Leucocytes'
##    Leucocytes mean_response
## 1    2.000000      0.850415
## 2    3.473684      0.825656
## 3    4.947368      0.800200
## 4    6.421053      0.762767
## 5    7.894737      0.723762
## 6    9.368421      0.677537
## 7   10.842105      0.591172
## 8   12.315789      0.112328
## 9   13.789474      0.054820
## 10  15.263158      0.033871
## 11  16.736842      0.026521
## 12  18.210526      0.019641
## 13  19.684211      0.015134
## 14  21.157895      0.012435
## 15  22.631579      0.010745
## 16  24.105263      0.009628
## 17  25.578947      0.009168
## 18  27.052632      0.009241
## 19  28.526316      0.008198
## 20  30.000000      0.006028
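
The call above sweeps all ten predictors; the cols argument of h2o.partialPlot restricts the computation to selected columns, for example the two most important markers:

h2o.partialPlot(dlmod, data = wdata, cols = c("Leucocytes", "ALAT"))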

Evaluating the model's performance on the test subset

h2o.performance(dlmod,wtest)
## H2OBinomialMetrics: deeplearning
## 
## MSE:  0.062223
## RMSE:  0.2494454
## LogLoss:  0.2237813
## Mean Per-Class Error:  0.06451613
## AUC:  0.9902248
## Gini:  0.9804497
## 
## Confusion Matrix (vertical: actual; across: predicted) for F1-optimal threshold:
##         Death Survive    Error   Rate
## Death      27       4 0.129032  =4/31
## Survive     0      33 0.000000  =0/33
## Totals     27      37 0.062500  =4/64
## 
## Maximum Metrics: Maximum metrics at their respective thresholds
##                         metric threshold    value idx
## 1                       max f1  0.010546 0.942857  36
## 2                       max f2  0.010546 0.976331  36
## 3                 max f0point5  0.964964 0.973154  28
## 4                 max accuracy  0.964964 0.937500  28
## 5                max precision  0.999987 1.000000   0
## 6                   max recall  0.010546 1.000000  36
## 7              max specificity  0.999987 1.000000   0
## 8             max absolute_mcc  0.964964 0.882244  28
## 9   max min_per_class_accuracy  0.071646 0.935484  32
## 10 max mean_per_class_accuracy  0.964964 0.939394  28
## 
## Gains/Lift Table: Extract with `h2o.gainsLift(<model>, <data>)` or `h2o.gainsLift(<model>, valid=<T/F>, xval=<T/F>)`
h2o.performance(dlmod,newdata=wtest)%>%plot()

Calculating all performance metrics on the test subset

h2o.performance(dlmod,newdata=wtest)%>%h2o.metric()
## Metrics for Thresholds: Binomial metrics as a function of classification thresholds
##   threshold       f1       f2 f0point5 accuracy precision   recall
## 1  0.999987 0.058824 0.037594 0.135135 0.500000  1.000000 0.030303
## 2  0.999987 0.114286 0.074627 0.243902 0.515625  1.000000 0.060606
## 3  0.999986 0.166667 0.111111 0.333333 0.531250  1.000000 0.090909
## 4  0.999986 0.216216 0.147059 0.408163 0.546875  1.000000 0.121212
## 5  0.999986 0.263158 0.182482 0.471698 0.562500  1.000000 0.151515
##   specificity absolute_mcc min_per_class_accuracy mean_per_class_accuracy
## 1    1.000000     0.122111               0.030303                0.515152
## 2    1.000000     0.174078               0.060606                0.530303
## 3    1.000000     0.214941               0.090909                0.545455
## 4    1.000000     0.250252               0.121212                0.560606
## 5    1.000000     0.282152               0.151515                0.575758
##   tns fns fps tps      tnr      fnr      fpr      tpr idx
## 1  31  32   0   1 1.000000 0.969697 0.000000 0.030303   0
## 2  31  31   0   2 1.000000 0.939394 0.000000 0.060606   1
## 3  31  30   0   3 1.000000 0.909091 0.000000 0.090909   2
## 4  31  29   0   4 1.000000 0.878788 0.000000 0.121212   3
## 5  31  28   0   5 1.000000 0.848485 0.000000 0.151515   4
## 
## ---
##    threshold       f1       f2 f0point5 accuracy precision   recall
## 59  0.000027 0.717391 0.863874 0.613383 0.593750  0.559322 1.000000
## 60  0.000026 0.709677 0.859375 0.604396 0.578125  0.550000 1.000000
## 61  0.000025 0.702128 0.854922 0.595668 0.562500  0.540984 1.000000
## 62  0.000025 0.694737 0.850515 0.587189 0.546875  0.532258 1.000000
## 63  0.000024 0.687500 0.846154 0.578947 0.531250  0.523810 1.000000
## 64  0.000024 0.680412 0.841837 0.570934 0.515625  0.515625 1.000000
##    specificity absolute_mcc min_per_class_accuracy mean_per_class_accuracy
## 59    0.161290     0.300355               0.161290                0.580645
## 60    0.129032     0.266398               0.129032                0.564516
## 61    0.096774     0.228808               0.096774                0.548387
## 62    0.064516     0.185308               0.064516                0.532258
## 63    0.032258     0.129989               0.032258                0.516129
## 64    0.000000     0.000000               0.000000                0.500000
##    tns fns fps tps      tnr      fnr      fpr      tpr idx
## 59   5   0  26  33 0.161290 0.000000 0.838710 1.000000  58
## 60   4   0  27  33 0.129032 0.000000 0.870968 1.000000  59
## 61   3   0  28  33 0.096774 0.000000 0.903226 1.000000  60
## 62   2   0  29  33 0.064516 0.000000 0.935484 1.000000  61
## 63   1   0  30  33 0.032258 0.000000 0.967742 1.000000  62
## 64   0   0  31  33 0.000000 0.000000 1.000000 1.000000  63
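
Individual metrics can also be pulled out of a performance object with h2o's accessor functions:

perf = h2o.performance(dlmod, newdata = wtest)
h2o.auc(perf)      # area under the ROC curve
h2o.logloss(perf)  # log-loss on the test subset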

Exploring the training evolution

plot(dlmod,
     timestep = "epochs",
     metric = "classification_error")

plot(dlmod,
     timestep = "epochs",
     metric = "auc")

plot(dlmod,
     timestep = "epochs",
     metric = "logloss")

plot(dlmod,
     timestep = "epochs",
     metric = "rmse")

Exploring the h2o model in the mlr package

library(mlr)

train=as.data.frame(wtrain)
test=as.data.frame(wtest)

taskSepsis=mlr::makeClassifTask(id="Sepsis",data=train[,-1],target="Outcome",positive = "Death")

tasktest=mlr::makeClassifTask(id="Sepsis",data=test[,-1],target="Outcome",positive = "Death")

mets=list(auc,bac,tpr,tnr,mmce,ber,fpr,fnr,ppv,npv)

# classif.h2o.deeplearning 

learnerDL = makeLearner(id="DL","classif.h2o.deeplearning", predict.type = "prob")

mlrDL=mlr::train(learnerDL,taskSepsis)
## 
  |                                                                       
  |                                                                 |   0%
  |                                                                       
  |=================================================================| 100%
## 
  |                                                                       
  |                                                                 |   0%
  |                                                                       
  |====================================================             |  80%
  |                                                                       
  |=================================================================| 100%
mlrDL$learner.model = dlmod  # swap in the h2o model we trained earlier

predDL=predict(mlrDL,tasktest,measures=mets)
## 
  |                                                                       
  |                                                                 |   0%
  |                                                                       
  |=================================================================| 100%
## 
  |                                                                       
  |                                                                 |   0%
  |                                                                       
  |=================================================================| 100%
pdf=predDL%>%performance(.,mets)%>%as_tibble()%>%mutate(.,Metric=row.names(.))

cbind(pdf$Metric,pdf$value)
##       [,1]   [,2]                
##  [1,] "auc"  "0.990224828934506" 
##  [2,] "bac"  "0.923264907135875" 
##  [3,] "tpr"  "0.967741935483871" 
##  [4,] "tnr"  "0.878787878787879" 
##  [5,] "mmce" "0.078125"          
##  [6,] "ber"  "0.0767350928641251"
##  [7,] "fpr"  "0.121212121212121" 
##  [8,] "fnr"  "0.032258064516129" 
##  [9,] "ppv"  "0.882352941176471" 
## [10,] "npv"  "0.966666666666667"

Exploring the h2o model using the caret package

library(caret)

confusionMatrix(predDL$data$response, reference = predDL$data$truth,
                positive = "Death", mode = "everything")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Death Survive
##    Death      30       4
##    Survive     1      29
##                                          
##                Accuracy : 0.9219         
##                  95% CI : (0.827, 0.9741)
##     No Information Rate : 0.5156         
##     P-Value [Acc > NIR] : 2.373e-12      
##                                          
##                   Kappa : 0.8441         
##  Mcnemar's Test P-Value : 0.3711         
##                                          
##             Sensitivity : 0.9677         
##             Specificity : 0.8788         
##          Pos Pred Value : 0.8824         
##          Neg Pred Value : 0.9667         
##               Precision : 0.8824         
##                  Recall : 0.9677         
##                      F1 : 0.9231         
##              Prevalence : 0.4844         
##          Detection Rate : 0.4688         
##    Detection Prevalence : 0.5312         
##       Balanced Accuracy : 0.9233         
##                                          
##        'Positive' Class : Death          
## 

Accuracy of the h2o Deep Learning model on the pooled data and the test subset

pdf = predict(dlmod, wdata) %>%
  as_tibble() %>%
  mutate(Accuracy = ifelse(data$Outcome == .$predict, "yes", "no"))
## 
  |                                                                       
  |                                                                 |   0%
  |                                                                       
  |=================================================================| 100%
ggplot(data = pdf, aes(x = predict, fill = Accuracy)) +
  stat_count() +
  scale_fill_manual(values = c("#ff003f", "#0094ff")) +
  coord_flip() +
  ggtitle("Original sample")

pdf = predict(dlmod, wtest) %>%
  as_tibble() %>%
  mutate(Accuracy = ifelse(test$Outcome == .$predict, "yes", "no"))
## 
  |                                                                       
  |                                                                 |   0%
  |                                                                       
  |=================================================================| 100%
ggplot(data = pdf, aes(x = predict, fill = Accuracy)) +
  stat_count() +
  scale_fill_manual(values = c("#ff003f", "#0094ff")) +
  coord_flip() +
  ggtitle("Test subset")

Conclusion

Through this case study X4, we have successfully trained a Deep Learning model in the h2o framework. The good news is that training a Deep Learning algorithm in h2o is easier than we thought! Although deep neural nets are the most complicated algorithm in h2o, with a large number of parameters, the authors of the h2o functions did a good job of keeping everything as simple as it could be (you can get an impressive result without any tuning). The h2o framework can also be brought up within the mlr package, but this is not recommended, since the stability of the mlr wrappers cannot be assured. Once trained, the h2o model can be explored in different ways, even using caret's confusion matrix and mlr's metrics. The h2o package provides enough functions for the validation and interpretation of a Deep Learning model.

We will return to Deep Learning in the next tutorial, where I will show you how to tune its parameters and how to deal with unbalanced data. See you soon, and thank you for joining us.

END