Foreword: About the Machine Learning in Medicine (MLM) project
The MLM project was initiated in 2016 and aims to:
1. Encourage the use of machine learning techniques in medical research in Vietnam.
2. Promote the R statistical programming language, an open-source and leading tool for practicing data science.
Background
In this case study X4, we introduce a new algorithm, deep learning, through h2o, a new framework for machine learning.
The dataset from Chapter 13 of Machine Learning in Medicine - Cookbook Two (T. J. Cleophas and A. H. Zwinderman, Springer, 2014) will be reused. It contains the laboratory data and survival outcome of 200 patients hospitalized for sepsis. The main objective is to estimate the mortality risk of these patients from 10 markers: gamma-glutamyl transferase (GammaGT, U/l), aspartate aminotransferase (ASAT, U/l), alanine transaminase (ALAT, U/l), bilirubin (μmol/l), urea (mmol/l), creatinine (μmol/l), creatinine clearance (ml/min), erythrocyte sedimentation rate (ESR, mm), C-reactive protein (CRP, mg/l) and leucocyte count (×10^9/l).
This binary classification problem was previously treated with SVM, XGBT and Random Forest; here we will try the deep learning approach, which is considered one of the most powerful machine learning algorithms.
Materials and method
The original dataset was provided in Chapter 13 of Machine Learning in Medicine - Cookbook Three (T. J. Cleophas and A. H. Zwinderman, SpringerBriefs in Statistics, 2014). You can get the original data in csv format (svm.csv) from the publisher's website: extras.springer.com.
library(tidyverse)
my_theme <- function(base_size = 10, base_family = "sans"){
theme_minimal(base_size = base_size, base_family = base_family) +
theme(
axis.text = element_text(size = 10),
axis.text.x = element_text(angle = 0, vjust = 0.5, hjust = 0.5),
axis.title = element_text(size = 12),
panel.grid.major = element_line(color = "grey"),
panel.grid.minor = element_blank(),
panel.background = element_rect(fill = "#faefff"),
strip.background = element_rect(fill = "#400156", color = "#400156", size =0.5),
strip.text = element_text(face = "bold", size = 10, color = "white"),
legend.position = "bottom",
legend.justification = "center",
legend.background = element_blank(),
panel.border = element_rect(color = "grey30", fill = NA, size = 0.5)
)
}
theme_set(my_theme())
data <- read.csv("Case57.csv", sep = ";") %>%
  as_tibble() %>%
  dplyr::rename(Death = VAR00001, GammaGT = VAR00002, ASAT = VAR00003,
                ALAT = VAR00004, Bilirubin = VAR00005, Urea = VAR00006,
                Creatinine = VAR00007, CreClearance = VAR00008,
                ESR = VAR00009, CRP = VAR00010, Leucocytes = VAR00011)
data$Outcome=recode_factor(data$Death,`0` = "Survive", `1` = "Death")
Initialising h2o
As the h2o framework will be used throughout this study, you must first install the package in R.
How to Install h2o
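If the package has not been installed yet, a standard CRAN installation usually suffices (a minimal sketch; H2O.ai also distributes more recent builds from its own repository, and a recent Java runtime is required on the system):
# Install the h2o R package from CRAN (run once)
install.packages("h2o")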
Then we load the package into R:
library(h2o)
At this point the h2o cluster is not yet running; we must start it with the h2o.init() function. If you see a cluster summary like the one shown below, your installation is working.
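A minimal start-up might look like this (a sketch; the exact banner printed depends on the h2o version and your machine):
# Start, or connect to, a local H2O cluster with the default settings
h2o.init()
# Print a summary of the running cluster
h2o.clusterInfo()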
Default configuration
By default, h2o only uses two cores of our PC and 25% of the system memory. To unleash the full power of the system, we add the argument nthreads = -1 to the init function as follows:
h2o.init(nthreads = -1)
## Connection successful!
##
## R is connected to the H2O cluster:
## H2O cluster uptime: 7 hours 31 minutes
## H2O cluster version: 3.10.3.6
## H2O cluster version age: 1 month and 22 days
## H2O cluster name: H2O_started_from_R_Admin_lou010
## H2O cluster total nodes: 1
## H2O cluster total memory: 0.61 GB
## H2O cluster total cores: 4
## H2O cluster allowed cores: 4
## H2O cluster healthy: TRUE
## H2O Connection ip: localhost
## H2O Connection port: 54321
## H2O Connection proxy: NA
## R Version: R version 3.3.1 (2016-06-21)
Now we can begin our experiment.
Importing data frame to h2o
There is one inconvenience when using h2o: the package defines its own environment, classes and functions. For example, h2o accepts neither a data frame nor a tibble; we must convert our original R data frame into an h2o frame with the as.h2o() function. Once converted, the h2o frame can no longer be manipulated with tidyverse functions (some basic functions still work), so it is recommended to finish data manipulation and cleaning before converting to an h2o frame.
A machine learning experiment often requires splitting the original data into training, validation and testing subsets. We can easily do this in h2o with the h2o.splitFrame() function. In this example, we split the dataset into two parts: roughly 70% for training and 30% for testing.
wdata=as.h2o(data)
##
|
| | 0%
|
|=================================================================| 100%
splits=h2o.splitFrame(wdata, ratios=0.7,seed=123)
wtrain=splits[[1]]
wtest=splits[[2]]
Unlike the caret package, h2o does not take care of balancing the target classes across the split subsets. Fortunately, our original dataset is fairly well balanced.
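Before plotting, the class counts in each split can be checked numerically on the h2o frames themselves, for example with h2o.table(), which tabulates a factor column of an H2OFrame (a small sketch):
# Outcome counts in the training and test subsets
h2o.table(wtrain$Outcome)
h2o.table(wtest$Outcome)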
p1=wtrain%>%as.data.frame()%>%ggplot(aes(x=Outcome,fill=Outcome))+stat_count(show.legend=F)+scale_fill_brewer(palette = "Set1")+ggtitle("Train_set")
p2=wtest%>%as.data.frame()%>%ggplot(aes(x=Outcome,fill=Outcome))+stat_count(show.legend=F)+scale_fill_brewer(palette = "Set1")+ggtitle("Test_set")
p3=wdata%>%as.data.frame()%>%ggplot(aes(x=Outcome,fill=Outcome))+stat_count(show.legend=F)+scale_fill_brewer(palette = "Set1")+ggtitle("Pooled_data")
library(gridExtra)
grid.arrange(p1,p2,p3,ncol=3)
The h2o package currently supports four main algorithms: GLM, GBM, Random Forest and deep learning. Our tutorial will only focus on deep learning.
Deep learning, also known as multilayer neural networks, is considered one of the most powerful machine learning algorithms. It can solve virtually any classification or regression problem if we know how to tune it adequately and give it enough time and/or data for learning. Its drawbacks include a slow training process, black-box models and difficulties with categorical data.
For demonstration purposes, the present tutorial only covers deep learning with a fixed configuration. In the next tutorial we will explore grid-based tuning and more advanced settings.
As we are all medical doctors, we will not spend time describing the structure and physiological activity of the neuron, shown below (knowledge of physiology is a great advantage for understanding the mechanism of a neural network).
Figure: Structure of a neuron
The functional unit of a neural network is the neuron, a function that receives numeric inputs and returns a numeric output. A neural network consists of multiple layers of many neurons; the output of an entire layer is treated as the input for each neuron in the next layer. The first layer is our data frame and the last layer is the model's outcome. For a binary classification learner, the final layer contains two neurons, each returning the probability of one label. Each numerical variable in the data becomes a neuron in the first layer, while categorical variables are handled by several neurons, one per factor level. The layers between the input and output layers are called hidden layers; their number and size must be specified in the training function.
Each neuron is characterized by weights for its inputs, a bias and an activation function. H2O supports three activation functions: Rectifier, hyperbolic tangent (Tanh) and Maxout.
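To make the idea concrete, a single artificial neuron can be written in a few lines of plain R; this is only an illustrative sketch, not how h2o implements it:
# One neuron: a weighted sum of the inputs plus a bias, passed through an activation function
neuron <- function(x, w, b, activation = tanh) {
  activation(sum(w * x) + b)
}
# Example with three inputs, arbitrary weights and bias
neuron(x = c(0.5, -1.2, 3.0), w = c(0.8, 0.1, -0.4), b = 0.2)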
Figure: Structure of the Deep neural net
response="Outcome"
features=setdiff(colnames(wtrain),c(response,"Death"))
dlmod=h2o.deeplearning (x = features,
y = response,
model_id = "Deep_learning",
training_frame = wtrain,
nfolds = 10,
hidden = c(200,200,200,200),
stopping_metric = "misclassification",
replicate_training_data = TRUE,
stopping_tolerance = 0.001,
stopping_rounds = 5,
overwrite_with_best_model=TRUE,
fold_assignment = "Stratified",
epochs=1000,
activation = "TanhWithDropout",
keep_cross_validation_fold_assignment = TRUE,
keep_cross_validation_predictions=FALSE,
score_each_iteration = TRUE,
variable_importances = TRUE,
reproducible = TRUE,seed=123)
##
  |
  |                                                                 |   0%
  |
  |=================================================================| 100%
All machine learning functions in h2o begin by specifying the list of features (x) and the response variable (y). x is a vector containing only the names of the features, not an actual frame; y is the name of the response variable.
The model_id argument is optional; it simply labels the model.
epochs is the number of training cycles (passes over the training data) the deep learning algorithm should perform.
During model training, h2o regularly scores the intermediate model on a validation set at each iteration. We can set a fixed validation frame, or let cross-validation test the model against random folds. If neither a validation frame nor folds are set, validation is performed on the entire training frame.
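If we preferred a fixed validation frame instead of cross-validation, we could split the data three ways and pass the middle part through the validation_frame argument. A hedged sketch (splits3 and dlmod_v are illustrative names, and the call is kept minimal):
# Three-way split: roughly 60% training, 20% validation, 20% test
splits3 <- h2o.splitFrame(wdata, ratios = c(0.6, 0.2), seed = 123)
# Pass the validation frame explicitly instead of using nfolds
dlmod_v <- h2o.deeplearning(x = features, y = response,
                            training_frame = splits3[[1]],
                            validation_frame = splits3[[2]],
                            hidden = c(200, 200), epochs = 100, seed = 123)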
The hidden argument sets the network structure: the default is c(200, 200), i.e. two hidden layers of 200 neurons each. In our case, we double the number of hidden layers with c(200, 200, 200, 200).
The early-stopping arguments optimize the training time: training should stop once the model's performance has reached a steady state, so that no more time is wasted.
stopping_metric indicates the metric used to decide whether the model's performance is still improving. By default h2o uses logloss; here we use "misclassification". stopping_tolerance sets a threshold for the metric's improvement: below it, training stops; otherwise the model keeps learning.
stopping_rounds adds flexibility to the two criteria above by giving the learner some chances to get worse before it gets better. H2O has early stopping on by default for deep learning, so explicitly set stopping_rounds to 0 if you do not want it (not recommended). overwrite_with_best_model means that the returned model will always be the best model found during training.
nfolds sets the number of folds for cross-validation.
fold_assignment defines how the training frame is split for cross-validation. It can be "Random", "Modulo" or "Stratified". Stratified is the preferred method for imbalanced datasets, since it tries to put the same proportion of each target class into each fold, but it makes the training slower.
keep_cross_validation_fold_assignment = TRUE lets us find out which rows were in which folds. keep_cross_validation_predictions = TRUE would keep the holdout predictions from each fold.
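Since keep_cross_validation_fold_assignment was set to TRUE (while keep_cross_validation_predictions was left FALSE), only the fold assignment can be retrieved after training (a small sketch):
# Which cross-validation fold each training row was assigned to
folds <- h2o.cross_validation_fold_assignment(dlmod)
h2o.table(folds)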
Output control: variable_importances is set to TRUE in order to extract the relative importance of each variable. It can slow down learning a little, which is why it is off by default, but it is worth a try in our case.
activation is a parameter with six possible values, combining the three activation functions with or without dropout.
The reproducible argument allows training the model in a reproducible way, so that others get the same results each time they run your code. When setting it to TRUE we must also set a seed: an integer that controls random number generation and yields exactly the same model if the algorithm is run again.
Activation functions in h2o
Once the training is completed, we can explore the model just as with other algorithms:
Summarising the model
summary(dlmod)
## Model Details:
## ==============
##
## H2OBinomialModel: deeplearning
## Model Key: Deep_learning
## Status of Neuron Layers: predicting Outcome, 2-class classification, bernoulli distribution, CrossEntropy loss, 123,202 weights/biases, 1.4 MB, 61,064 training samples, mini-batch size 1
## layer units type dropout l1 l2 mean_rate rate_rms
## 1 1 10 Input 0.00 %
## 2 2 200 TanhDropout 50.00 % 0.000000 0.000000 0.000555 0.000877
## 3 3 200 TanhDropout 50.00 % 0.000000 0.000000 0.010516 0.012201
## 4 4 200 TanhDropout 50.00 % 0.000000 0.000000 0.017772 0.016285
## 5 5 200 TanhDropout 50.00 % 0.000000 0.000000 0.086915 0.189088
## 6 6 2 Softmax 0.000000 0.000000 0.003735 0.001186
## momentum mean_weight weight_rms mean_bias bias_rms
## 1
## 2 0.000000 -0.001559 0.141423 -0.004069 0.077562
## 3 0.000000 0.000309 0.093140 -0.003771 0.247713
## 4 0.000000 -0.000453 0.081865 0.000362 0.228266
## 5 0.000000 -0.000092 0.071455 0.001291 0.103524
## 6 0.000000 -0.036175 0.353770 -0.000000 0.114421
##
## H2OBinomialMetrics: deeplearning
## ** Reported on training data. **
## ** Metrics reported on full training frame **
##
## MSE: 0.01270721
## RMSE: 0.1127263
## LogLoss: 0.04027583
## Mean Per-Class Error: 0.01482127
## AUC: 0.9986922
## Gini: 0.9973845
##
## Confusion Matrix (vertical: actual; across: predicted) for F1-optimal threshold:
## Death Survive Error Rate
## Death 61 1 0.016129 =1/62
## Survive 1 73 0.013514 =1/74
## Totals 62 74 0.014706 =2/136
##
## Maximum Metrics: Maximum metrics at their respective thresholds
## metric threshold value idx
## 1 max f1 0.381416 0.986486 73
## 2 max f2 0.133083 0.986667 78
## 3 max f0point5 0.663120 0.994475 71
## 4 max accuracy 0.663120 0.985294 71
## 5 max precision 0.999988 1.000000 0
## 6 max recall 0.133083 1.000000 78
## 7 max specificity 0.999988 1.000000 0
## 8 max absolute_mcc 0.663120 0.970859 71
## 9 max min_per_class_accuracy 0.381416 0.983871 73
## 10 max mean_per_class_accuracy 0.663120 0.986486 71
##
## Gains/Lift Table: Extract with `h2o.gainsLift(<model>, <data>)` or `h2o.gainsLift(<model>, valid=<T/F>, xval=<T/F>)`
##
## H2OBinomialMetrics: deeplearning
## ** Reported on cross-validation data. **
## ** 10-fold cross-validation on training data (Metrics computed for combined holdout predictions) **
##
## MSE: 0.05036137
## RMSE: 0.2244134
## LogLoss: 0.3000668
## Mean Per-Class Error: 0.03378378
## AUC: 0.9803836
## Gini: 0.9607672
##
## Confusion Matrix (vertical: actual; across: predicted) for F1-optimal threshold:
## Death Survive Error Rate
## Death 62 0 0.000000 =0/62
## Survive 5 69 0.067568 =5/74
## Totals 67 69 0.036765 =5/136
##
## Maximum Metrics: Maximum metrics at their respective thresholds
## metric threshold value idx
## 1 max f1 0.979448 0.965035 57
## 2 max f2 0.137449 0.957447 68
## 3 max f0point5 0.979448 0.985714 57
## 4 max accuracy 0.979448 0.963235 57
## 5 max precision 0.999999 1.000000 0
## 6 max recall 0.000000 1.000000 122
## 7 max specificity 0.999999 1.000000 0
## 8 max absolute_mcc 0.979448 0.928896 57
## 9 max min_per_class_accuracy 0.966180 0.945946 60
## 10 max mean_per_class_accuracy 0.979448 0.966216 57
##
## Gains/Lift Table: Extract with `h2o.gainsLift(<model>, <data>)` or `h2o.gainsLift(<model>, valid=<T/F>, xval=<T/F>)`
## Cross-Validation Metrics Summary:
## mean sd cv_1_valid cv_2_valid
## accuracy 0.98888886 0.023570227 1.0 0.8888889
## auc 0.99125 0.018561553 1.0 0.9125
## err 0.011111111 0.023570227 0.0 0.11111111
## err_count 0.2 0.42426407 0.0 2.0
## f0point5 0.99 0.021213204 1.0 0.9
## f1 0.99 0.021213204 1.0 0.9
## f2 0.99 0.021213204 1.0 0.9
## lift_top_group 1.9132631 0.2444304 1.7777778 1.8
## logloss 0.28878036 0.26993978 0.0057941065 1.1592535
## max_per_class_error 0.0125 0.026516505 0.0 0.125
## mcc 0.9775 0.04772971 1.0 0.775
## mean_per_class_accuracy 0.98875 0.023864854 1.0 0.8875
## mean_per_class_error 0.01125 0.023864854 0.0 0.1125
## mse 0.050650533 0.038761824 4.8375732E-4 0.10015181
## precision 0.99 0.021213204 1.0 0.9
## r2 0.7952921 0.15612668 0.99803424 0.5943852
## recall 0.99 0.021213204 1.0 0.9
## rmse 0.16316295 0.10960927 0.021994483 0.3164677
## specificity 0.9875 0.026516505 1.0 0.875
## cv_3_valid cv_4_valid cv_5_valid cv_6_valid
## accuracy 1.0 1.0 1.0 1.0
## auc 1.0 1.0 1.0 1.0
## err 0.0 0.0 0.0 0.0
## err_count 0.0 0.0 0.0 0.0
## f0point5 1.0 1.0 1.0 1.0
## f1 1.0 1.0 1.0 1.0
## f2 1.0 1.0 1.0 1.0
## lift_top_group 2.0 2.1666667 2.0 1.375
## logloss 0.76311564 0.51032287 0.14848118 0.013675269
## max_per_class_error 0.0 0.0 0.0 0.0
## mcc 1.0 1.0 1.0 1.0
## mean_per_class_accuracy 1.0 1.0 1.0 1.0
## mean_per_class_error 0.0 0.0 0.0 0.0
## mse 0.124564655 0.14242762 0.05323253 0.0017704639
## precision 1.0 1.0 1.0 1.0
## r2 0.50174135 0.42689836 0.78706986 0.9910739
## recall 1.0 1.0 1.0 1.0
## rmse 0.3529372 0.37739584 0.23072176 0.04207688
## specificity 1.0 1.0 1.0 1.0
## cv_7_valid cv_8_valid cv_9_valid
## accuracy 1.0 1.0 1.0
## auc 1.0 1.0 1.0
## err 0.0 0.0 0.0
## err_count 0.0 0.0 0.0
## f0point5 1.0 1.0 1.0
## f1 1.0 1.0 1.0
## f2 1.0 1.0 1.0
## lift_top_group 2.4285715 2.4 1.3846154
## logloss 0.2870191 2.6050602E-5 9.480391E-5
## max_per_class_error 0.0 0.0 0.0
## mcc 1.0 1.0 1.0
## mean_per_class_accuracy 1.0 1.0 1.0
## mean_per_class_error 0.0 0.0 0.0
## mse 0.08387434 7.7921025E-10 1.3558078E-7
## precision 1.0 1.0 1.0
## r2 0.6537188 1.0 0.99999934
## recall 1.0 1.0 1.0
## rmse 0.28961065 2.7914339E-5 3.6821293E-4
## specificity 1.0 1.0 1.0
## cv_10_valid
## accuracy 1.0
## auc 1.0
## err 0.0
## err_count 0.0
## f0point5 1.0
## f1 1.0
## f2 1.0
## lift_top_group 1.8
## logloss 2.1142827E-5
## max_per_class_error 0.0
## mcc 1.0
## mean_per_class_accuracy 1.0
## mean_per_class_error 0.0
## mse 8.3267787E-10
## precision 1.0
## r2 1.0
## recall 1.0
## rmse 2.8856159E-5
## specificity 1.0
##
## Scoring History:
## timestamp duration training_speed epochs iterations
## 1 2017-04-12 23:44:25 0.000 sec 0.00000 0
## 2 2017-04-12 23:44:26 8 min 48.827 sec 501 obs/sec 1.00000 1
## 3 2017-04-12 23:44:26 8 min 49.138 sec 501 obs/sec 2.00000 2
## 4 2017-04-12 23:44:26 8 min 49.451 sec 496 obs/sec 3.00000 3
## 5 2017-04-12 23:44:27 8 min 49.778 sec 490 obs/sec 4.00000 4
## samples training_rmse training_logloss training_auc training_lift
## 1 0.000000
## 2 136.000000 0.20348 0.18537 0.98670 1.83784
## 3 272.000000 0.17691 0.17185 0.98779 1.83784
## 4 408.000000 0.19761 0.16548 0.98932 1.83784
## 5 544.000000 0.16414 0.13956 0.98997 1.83784
## training_classification_error
## 1
## 2 0.02941
## 3 0.02206
## 4 0.02941
## 5 0.02206
##
## ---
## timestamp duration training_speed epochs
## 445 2017-04-12 23:46:47 11 min 9.858 sec 484 obs/sec 444.00000
## 446 2017-04-12 23:46:47 11 min 10.179 sec 484 obs/sec 445.00000
## 447 2017-04-12 23:46:47 11 min 10.489 sec 484 obs/sec 446.00000
## 448 2017-04-12 23:46:48 11 min 10.799 sec 484 obs/sec 447.00000
## 449 2017-04-12 23:46:48 11 min 11.119 sec 484 obs/sec 448.00000
## 450 2017-04-12 23:46:48 11 min 11.439 sec 484 obs/sec 449.00000
## iterations samples training_rmse training_logloss training_auc
## 445 444 60384.000000 0.13933 0.07535 0.99826
## 446 445 60520.000000 0.11239 0.04095 0.99847
## 447 446 60656.000000 0.12762 0.04726 0.99847
## 448 447 60792.000000 0.11958 0.04317 0.99847
## 449 448 60928.000000 0.12391 0.04558 0.99847
## 450 449 61064.000000 0.11273 0.04028 0.99869
## training_lift training_classification_error
## 445 1.83784 0.01471
## 446 1.83784 0.01471
## 447 1.83784 0.01471
## 448 1.83784 0.01471
## 449 1.83784 0.01471
## 450 1.83784 0.01471
##
## Variable Importances: (Extract with `h2o.varimp`)
## =================================================
##
## Variable Importances:
## variable relative_importance scaled_importance percentage
## 1 Leucocytes 1.000000 1.000000 0.172288
## 2 ALAT 0.736160 0.736160 0.126832
## 3 CRP 0.567268 0.567268 0.097734
## 4 ASAT 0.565372 0.565372 0.097407
## 5 Urea 0.557244 0.557244 0.096007
## 6 ESR 0.519068 0.519068 0.089429
## 7 CreClearance 0.498226 0.498226 0.085839
## 8 Bilirubin 0.493901 0.493901 0.085093
## 9 GammaGT 0.470984 0.470984 0.081145
## 10 Creatinine 0.396006 0.396006 0.068227
Variable importance
h2o.varimp_plot(dlmod)
h2o.varimp(dlmod)
## Variable Importances:
## variable relative_importance scaled_importance percentage
## 1 Leucocytes 1.000000 1.000000 0.172288
## 2 ALAT 0.736160 0.736160 0.126832
## 3 CRP 0.567268 0.567268 0.097734
## 4 ASAT 0.565372 0.565372 0.097407
## 5 Urea 0.557244 0.557244 0.096007
## 6 ESR 0.519068 0.519068 0.089429
## 7 CreClearance 0.498226 0.498226 0.085839
## 8 Bilirubin 0.493901 0.493901 0.085093
## 9 GammaGT 0.470984 0.470984 0.081145
## 10 Creatinine 0.396006 0.396006 0.068227
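The same table can be pulled into an ordinary data frame for custom plotting in the tidyverse style used above (a small sketch):
# Convert the variable importances to a data frame and plot them with ggplot2
vi <- as.data.frame(h2o.varimp(dlmod))
ggplot(vi, aes(x = reorder(variable, scaled_importance), y = scaled_importance)) +
  geom_col(fill = "#400156") +
  coord_flip() +
  labs(x = NULL, y = "Scaled importance")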
Confusion matrix
h2o.confusionMatrix(dlmod)
## Confusion Matrix (vertical: actual; across: predicted) for max f1 @ threshold = 0.381415907908741:
## Death Survive Error Rate
## Death 61 1 0.016129 =1/62
## Survive 1 73 0.013514 =1/74
## Totals 62 74 0.014706 =2/136
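The confusion matrix can also be computed on the held-out test frame, or at a fixed probability threshold via a performance object (a sketch; perf_test is an illustrative name):
# Confusion matrix on the test subset (F1-optimal threshold by default)
h2o.confusionMatrix(dlmod, newdata = wtest)
# Confusion matrix at a fixed threshold of 0.5
perf_test <- h2o.performance(dlmod, newdata = wtest)
h2o.confusionMatrix(perf_test, thresholds = 0.5)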
Partial dependence plots
h2o.partialPlot(dlmod,data=wdata)
##
|
| | 0%
|
|=================================================================| 100%
## [[1]]
## PartialDependence: Partial Dependence Plot of model Deep_learning on column 'GammaGT'
## GammaGT mean_response
## 1 2.000000 0.518582
## 2 122.947368 0.522329
## 3 243.894737 0.527045
## 4 364.842105 0.533433
## 5 485.789474 0.541715
## 6 606.736842 0.550912
## 7 727.684211 0.559677
## 8 848.631579 0.567565
## 9 969.578947 0.574678
## 10 1090.526316 0.581435
## 11 1211.473684 0.588337
## 12 1332.421053 0.595648
## 13 1453.368421 0.603215
## 14 1574.315789 0.610600
## 15 1695.263158 0.617396
## 16 1816.210526 0.623429
## 17 1937.157895 0.628751
## 18 2058.105263 0.633501
## 19 2179.052632 0.637792
## 20 2300.000000 0.641734
##
## [[2]]
## PartialDependence: Partial Dependence Plot of model Deep_learning on column 'ASAT'
## ASAT mean_response
## 1 3.000000 0.534792
## 2 107.052632 0.512723
## 3 211.105263 0.494028
## 4 315.157895 0.463911
## 5 419.210526 0.414536
## 6 523.263158 0.313878
## 7 627.315789 0.182829
## 8 731.368421 0.106746
## 9 835.421053 0.041596
## 10 939.473684 0.011937
## 11 1043.526316 0.000261
## 12 1147.578947 0.000097
## 13 1251.631579 0.000049
## 14 1355.684211 0.000037
## 15 1459.736842 0.000032
## 16 1563.789474 0.000029
## 17 1667.842105 0.000027
## 18 1771.894737 0.000026
## 19 1875.947368 0.000026
## 20 1980.000000 0.000025
##
## [[3]]
## PartialDependence: Partial Dependence Plot of model Deep_learning on column 'ALAT'
## ALAT mean_response
## 1 2.000000 0.554716
## 2 80.842105 0.514758
## 3 159.684211 0.481577
## 4 238.526316 0.419949
## 5 317.368421 0.309494
## 6 396.210526 0.158624
## 7 475.052632 0.059209
## 8 553.894737 0.014893
## 9 632.736842 0.000562
## 10 711.578947 0.000124
## 11 790.421053 0.000050
## 12 869.263158 0.000037
## 13 948.105263 0.000032
## 14 1026.947368 0.000029
## 15 1105.789474 0.000028
## 16 1184.631579 0.000027
## 17 1263.473684 0.000026
## 18 1342.315789 0.000026
## 19 1421.157895 0.000025
## 20 1500.000000 0.000025
##
## [[4]]
## PartialDependence: Partial Dependence Plot of model Deep_learning on column 'Bilirubin'
## Bilirubin mean_response
## 1 1.000000 0.542578
## 2 22.000000 0.528720
## 3 43.000000 0.516002
## 4 64.000000 0.507999
## 5 85.000000 0.496483
## 6 106.000000 0.476670
## 7 127.000000 0.462963
## 8 148.000000 0.431021
## 9 169.000000 0.409299
## 10 190.000000 0.362632
## 11 211.000000 0.287991
## 12 232.000000 0.220285
## 13 253.000000 0.173899
## 14 274.000000 0.136822
## 15 295.000000 0.089550
## 16 316.000000 0.054679
## 17 337.000000 0.037184
## 18 358.000000 0.022033
## 19 379.000000 0.005819
## 20 400.000000 0.000963
##
## [[5]]
## PartialDependence: Partial Dependence Plot of model Deep_learning on column 'Urea'
## Urea mean_response
## 1 2.000000 0.504329
## 2 7.052632 0.510908
## 3 12.105263 0.526168
## 4 17.157895 0.545514
## 5 22.210526 0.574097
## 6 27.263158 0.602961
## 7 32.315789 0.621378
## 8 37.368421 0.641056
## 9 42.421053 0.656762
## 10 47.473684 0.670351
## 11 52.526316 0.687223
## 12 57.578947 0.708152
## 13 62.631579 0.731078
## 14 67.684211 0.749708
## 15 72.736842 0.764716
## 16 77.789474 0.776137
## 17 82.842105 0.783861
## 18 87.894737 0.789156
## 19 92.947368 0.792336
## 20 98.000000 0.799021
##
## [[6]]
## PartialDependence: Partial Dependence Plot of model Deep_learning on column 'Creatinine'
## Creatinine mean_response
## 1 47.000000 0.531442
## 2 90.052632 0.527676
## 3 133.105263 0.524740
## 4 176.157895 0.522338
## 5 219.210526 0.520213
## 6 262.263158 0.518206
## 7 305.315789 0.516270
## 8 348.368421 0.514378
## 9 391.421053 0.512306
## 10 434.473684 0.509201
## 11 477.526316 0.504059
## 12 520.578947 0.496567
## 13 563.631579 0.488162
## 14 606.684211 0.481104
## 15 649.736842 0.475492
## 16 692.789474 0.468133
## 17 735.842105 0.456342
## 18 778.894737 0.438736
## 19 821.947368 0.424749
## 20 865.000000 0.414454
##
## [[7]]
## PartialDependence: Partial Dependence Plot of model Deep_learning on column 'CreClearance'
## CreClearance mean_response
## 1 -132.000000 0.498861
## 2 -125.263158 0.499751
## 3 -118.526316 0.500432
## 4 -111.789474 0.501157
## 5 -105.052632 0.502177
## 6 -98.315789 0.503744
## 7 -91.578947 0.506112
## 8 -84.842105 0.509487
## 9 -78.105263 0.513964
## 10 -71.368421 0.519487
## 11 -64.631579 0.525911
## 12 -57.894737 0.533115
## 13 -51.157895 0.541039
## 14 -44.421053 0.549601
## 15 -37.684211 0.558600
## 16 -30.947368 0.567752
## 17 -24.210526 0.576779
## 18 -17.473684 0.585486
## 19 -10.736842 0.593791
## 20 -4.000000 0.601701
##
## [[8]]
## PartialDependence: Partial Dependence Plot of model Deep_learning on column 'ESR'
## ESR mean_response
## 1 2.000000 0.547599
## 2 11.368421 0.535137
## 3 20.736842 0.526186
## 4 30.105263 0.520307
## 5 39.473684 0.516673
## 6 48.842105 0.511003
## 7 58.210526 0.501052
## 8 67.578947 0.484122
## 9 76.947368 0.472520
## 10 86.315789 0.455520
## 11 95.684211 0.431987
## 12 105.052632 0.415811
## 13 114.421053 0.383739
## 14 123.789474 0.335590
## 15 133.157895 0.270888
## 16 142.526316 0.228295
## 17 151.894737 0.198115
## 18 161.263158 0.172016
## 19 170.631579 0.135098
## 20 180.000000 0.101062
##
## [[9]]
## PartialDependence: Partial Dependence Plot of model Deep_learning on column 'CRP'
## CRP mean_response
## 1 2.000000 0.507214
## 2 14.684211 0.512134
## 3 27.368421 0.520614
## 4 40.052632 0.536730
## 5 52.736842 0.554463
## 6 65.421053 0.568694
## 7 78.105263 0.587310
## 8 90.789474 0.613578
## 9 103.473684 0.637329
## 10 116.157895 0.653810
## 11 128.842105 0.669877
## 12 141.526316 0.687090
## 13 154.210526 0.703554
## 14 166.894737 0.718720
## 15 179.578947 0.731751
## 16 192.263158 0.741415
## 17 204.947368 0.749021
## 18 217.631579 0.758300
## 19 230.315789 0.770243
## 20 243.000000 0.782689
##
## [[10]]
## PartialDependence: Partial Dependence Plot of model Deep_learning on column 'Leucocytes'
## Leucocytes mean_response
## 1 2.000000 0.850415
## 2 3.473684 0.825656
## 3 4.947368 0.800200
## 4 6.421053 0.762767
## 5 7.894737 0.723762
## 6 9.368421 0.677537
## 7 10.842105 0.591172
## 8 12.315789 0.112328
## 9 13.789474 0.054820
## 10 15.263158 0.033871
## 11 16.736842 0.026521
## 12 18.210526 0.019641
## 13 19.684211 0.015134
## 14 21.157895 0.012435
## 15 22.631579 0.010745
## 16 24.105263 0.009628
## 17 25.578947 0.009168
## 18 27.052632 0.009241
## 19 28.526316 0.008198
## 20 30.000000 0.006028
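By default h2o.partialPlot() sweeps over every predictor; the cols argument restricts it to a chosen subset, for instance the two strongest markers found above (a sketch):
# Partial dependence for selected columns only
h2o.partialPlot(dlmod, data = wdata, cols = c("Leucocytes", "ALAT"))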
Evaluating the model's performance on the test subset
h2o.performance(dlmod,wtest)
## H2OBinomialMetrics: deeplearning
##
## MSE: 0.062223
## RMSE: 0.2494454
## LogLoss: 0.2237813
## Mean Per-Class Error: 0.06451613
## AUC: 0.9902248
## Gini: 0.9804497
##
## Confusion Matrix (vertical: actual; across: predicted) for F1-optimal threshold:
## Death Survive Error Rate
## Death 27 4 0.129032 =4/31
## Survive 0 33 0.000000 =0/33
## Totals 27 37 0.062500 =4/64
##
## Maximum Metrics: Maximum metrics at their respective thresholds
## metric threshold value idx
## 1 max f1 0.010546 0.942857 36
## 2 max f2 0.010546 0.976331 36
## 3 max f0point5 0.964964 0.973154 28
## 4 max accuracy 0.964964 0.937500 28
## 5 max precision 0.999987 1.000000 0
## 6 max recall 0.010546 1.000000 36
## 7 max specificity 0.999987 1.000000 0
## 8 max absolute_mcc 0.964964 0.882244 28
## 9 max min_per_class_accuracy 0.071646 0.935484 32
## 10 max mean_per_class_accuracy 0.964964 0.939394 28
##
## Gains/Lift Table: Extract with `h2o.gainsLift(<model>, <data>)` or `h2o.gainsLift(<model>, valid=<T/F>, xval=<T/F>)`
h2o.performance(dlmod,newdata=wtest)%>%plot()
Calculating all performance metrics on the test subset
h2o.performance(dlmod,newdata=wtest)%>%h2o.metric()
## Metrics for Thresholds: Binomial metrics as a function of classification thresholds
## threshold f1 f2 f0point5 accuracy precision recall
## 1 0.999987 0.058824 0.037594 0.135135 0.500000 1.000000 0.030303
## 2 0.999987 0.114286 0.074627 0.243902 0.515625 1.000000 0.060606
## 3 0.999986 0.166667 0.111111 0.333333 0.531250 1.000000 0.090909
## 4 0.999986 0.216216 0.147059 0.408163 0.546875 1.000000 0.121212
## 5 0.999986 0.263158 0.182482 0.471698 0.562500 1.000000 0.151515
## specificity absolute_mcc min_per_class_accuracy mean_per_class_accuracy
## 1 1.000000 0.122111 0.030303 0.515152
## 2 1.000000 0.174078 0.060606 0.530303
## 3 1.000000 0.214941 0.090909 0.545455
## 4 1.000000 0.250252 0.121212 0.560606
## 5 1.000000 0.282152 0.151515 0.575758
## tns fns fps tps tnr fnr fpr tpr idx
## 1 31 32 0 1 1.000000 0.969697 0.000000 0.030303 0
## 2 31 31 0 2 1.000000 0.939394 0.000000 0.060606 1
## 3 31 30 0 3 1.000000 0.909091 0.000000 0.090909 2
## 4 31 29 0 4 1.000000 0.878788 0.000000 0.121212 3
## 5 31 28 0 5 1.000000 0.848485 0.000000 0.151515 4
##
## ---
## threshold f1 f2 f0point5 accuracy precision recall
## 59 0.000027 0.717391 0.863874 0.613383 0.593750 0.559322 1.000000
## 60 0.000026 0.709677 0.859375 0.604396 0.578125 0.550000 1.000000
## 61 0.000025 0.702128 0.854922 0.595668 0.562500 0.540984 1.000000
## 62 0.000025 0.694737 0.850515 0.587189 0.546875 0.532258 1.000000
## 63 0.000024 0.687500 0.846154 0.578947 0.531250 0.523810 1.000000
## 64 0.000024 0.680412 0.841837 0.570934 0.515625 0.515625 1.000000
## specificity absolute_mcc min_per_class_accuracy mean_per_class_accuracy
## 59 0.161290 0.300355 0.161290 0.580645
## 60 0.129032 0.266398 0.129032 0.564516
## 61 0.096774 0.228808 0.096774 0.548387
## 62 0.064516 0.185308 0.064516 0.532258
## 63 0.032258 0.129989 0.032258 0.516129
## 64 0.000000 0.000000 0.000000 0.500000
## tns fns fps tps tnr fnr fpr tpr idx
## 59 5 0 26 33 0.161290 0.000000 0.838710 1.000000 58
## 60 4 0 27 33 0.129032 0.000000 0.870968 1.000000 59
## 61 3 0 28 33 0.096774 0.000000 0.903226 1.000000 60
## 62 2 0 29 33 0.064516 0.000000 0.935484 1.000000 61
## 63 1 0 30 33 0.032258 0.000000 0.967742 1.000000 62
## 64 0 0 31 33 0.000000 0.000000 1.000000 1.000000 63
Exploring the training evolution
plot(dlmod,
timestep = "epochs",
metric = "classification_error")
plot(dlmod,
timestep = "epochs",
metric = "auc")
plot(dlmod,
timestep = "epochs",
metric = "logloss")
plot(dlmod,
timestep = "epochs",
metric = "rmse")
Exploring the h2o model with the mlr package
library(mlr)
train=as.data.frame(wtrain)
test=as.data.frame(wtest)
taskSepsis=mlr::makeClassifTask(id="Sepsis",data=train[,-1],target="Outcome",positive = "Death")
tasktest=mlr::makeClassifTask(id="Sepsis",data=test[,-1],target="Outcome",positive = "Death")
mets=list(auc,bac,tpr,tnr,mmce,ber,fpr,fnr,ppv,npv)
# classif.h2o.deeplearning
learnerDL = makeLearner(id="DL","classif.h2o.deeplearning", predict.type = "prob")
mlrDL=mlr::train(learnerDL,taskSepsis)
##
|
| | 0%
|
|=================================================================| 100%
##
|
| | 0%
|
|==================================================== | 80%
|
|=================================================================| 100%
mlrDL$learner.model=dlmod
predDL=predict(mlrDL,tasktest,measures=mets)
##
|
| | 0%
|
|=================================================================| 100%
##
|
| | 0%
|
|=================================================================| 100%
pdf=predDL%>%performance(.,mets)%>%as_tibble()%>%mutate(.,Metric=row.names(.))
cbind(pdf$Metric,pdf$value)
## [,1] [,2]
## [1,] "auc" "0.990224828934506"
## [2,] "bac" "0.923264907135875"
## [3,] "tpr" "0.967741935483871"
## [4,] "tnr" "0.878787878787879"
## [5,] "mmce" "0.078125"
## [6,] "ber" "0.0767350928641251"
## [7,] "fpr" "0.121212121212121"
## [8,] "fnr" "0.032258064516129"
## [9,] "ppv" "0.882352941176471"
## [10,] "npv" "0.966666666666667"
Exploring the h2o model with the caret package
library(caret)
confusionMatrix(predDL$data$response, reference=predDL$data$truth,positive ="Death",mode="everything")
## Confusion Matrix and Statistics
##
## Reference
## Prediction Death Survive
## Death 30 4
## Survive 1 29
##
## Accuracy : 0.9219
## 95% CI : (0.827, 0.9741)
## No Information Rate : 0.5156
## P-Value [Acc > NIR] : 2.373e-12
##
## Kappa : 0.8441
## Mcnemar's Test P-Value : 0.3711
##
## Sensitivity : 0.9677
## Specificity : 0.8788
## Pos Pred Value : 0.8824
## Neg Pred Value : 0.9667
## Precision : 0.8824
## Recall : 0.9677
## F1 : 0.9231
## Prevalence : 0.4844
## Detection Rate : 0.4688
## Detection Prevalence : 0.5312
## Balanced Accuracy : 0.9233
##
## 'Positive' Class : Death
##
Accuracy of the h2o deep learning model on the pooled data and the test subset
pdf=predict(dlmod,wdata)%>%as_tibble()%>%mutate(Accuracy=ifelse(data$Outcome ==.$predict, "yes", "no"))
##
|
| | 0%
|
|=================================================================| 100%
ggplot(data=pdf,aes(x=predict,fill=Accuracy))+stat_count()+scale_fill_manual(values=c("#ff003f","#0094ff"))+coord_flip()+ggtitle("Original sample")
pdf=predict(dlmod,wtest)%>%as_tibble()%>%mutate(Accuracy=ifelse(test$Outcome ==.$predict, "yes", "no"))
##
|
| | 0%
|
|=================================================================| 100%
ggplot(data=pdf,aes(x=predict,fill=Accuracy))+stat_count()+scale_fill_manual(values=c("#ff003f","#0094ff"))+coord_flip()+ggtitle("Test subset")
Conclusion
Through this case study X4, we have successfully trained a deep learning model in the h2o framework. The good news is that training a deep learning algorithm in h2o is easier than we thought! Although deep neural nets are the most complicated algorithm in h2o, with a large number of parameters, the authors of the h2o functions did a good job of keeping everything as simple as possible (you can get an impressive result without any tuning). The h2o framework can also be brought up within the mlr package, but this is not recommended since the stability of the mlr wrappers cannot be assured. Once trained, the h2o model can be explored in different ways, even using caret's confusion matrix and mlr's metrics. The h2o package provides enough functions for validating and interpreting a deep learning model.
We will return to deep learning in the next tutorial, where I will show you how to tune its parameters and how to deal with imbalanced data. See you soon, and thank you for joining us.
END