Drug discovery and development is a tedious process, consuming enormous amounts of time and resources. On average, it costs USD 2.5 billion to bring a new drug to market. For big pharmaceutical companies this average is around USD 4 billion, and it has been shown to go as high as USD 11 billion. Designing a new drug that binds to a specific target requires a large amount of time as well as computing power. Deep learning algorithms are being developed to accelerate many stages of this process, and it is anticipated that digital solutions for drug discovery may save significant time and capital.
Many proteins in our body are enzymes, meaning that they speed up the rate of a chemical reaction without being irreversibly changed themselves. Others are part of signaling pathways that control highly specific cell adaptations and reactions. We can imagine these proteins working via a lock-and-key principle: one or more smaller molecules (substrates, ligands) fit snugly into a “hole” (binding pocket, active site) of a protein, thereby facilitating a subtle structural change, which in turn can set off a domino chain of reactions.
Note: No matter how amazing deep learning techniques may be, they will never replace clinical trials. Clinical trials remain the precondition for FDA approval.
A secondary aim of this post is to promote the use of the R statistical programming language, an open-source and leading tool for practicing data science.
In this case study, we would like to introduce the deep learning algorithm, supported by “h2o”, an open-source framework for machine learning.
The dataset from Chapter 13 of Machine Learning in Medicine - Cookbook Three by T. J. Cleophas and A. H. Zwinderman (Springer, 2014) will be reused. It contains the laboratory data and survival outcome of 200 patients hospitalized for sepsis. The main objective is to estimate the mortality risk of these patients in terms of 10 markers: gamma-glutamyl transferase (GammaGT, U/l), aspartate aminotransferase (ASAT, U/l), alanine aminotransferase (ALAT, U/l), bilirubin (μmol/l), urea (mmol/l), creatinine (μmol/l), creatinine clearance (ml/min), erythrocyte sedimentation rate (ESR, mm), C-reactive protein (CRP, mg/l), and leucocyte count (×10^9/l).
This binary classification problem was previously treated with SVM, XGBoost, and Random Forest; now we will try the deep learning approach, which is considered one of the most powerful machine learning algorithms.
The original dataset was provided in Chapter 13 of Machine Learning in Medicine - Cookbook Three (T. J. Cleophas and A. H. Zwinderman, SpringerBriefs in Statistics, 2014). You can get the original data in csv format (svm.csv) here
Let’s get down to business
First, load the tidyverse library, then read in the dataset.
library(tidyverse)
## ── Attaching packages ──────────────────────────── tidyverse 1.2.1 ──
## ✔ ggplot2 3.0.0 ✔ purrr 0.2.5
## ✔ tibble 1.4.2 ✔ dplyr 0.7.6
## ✔ tidyr 0.8.1 ✔ stringr 1.3.1
## ✔ readr 1.1.1 ✔ forcats 0.3.0
## ── Conflicts ─────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
my_theme <- function(base_size = 10, base_family = "sans") {
  theme_minimal(base_size = base_size, base_family = base_family) +
    theme(
      axis.text = element_text(size = 10),
      axis.text.x = element_text(angle = 0, vjust = 0.5, hjust = 0.5),
      axis.title = element_text(size = 12),
      panel.grid.major = element_line(color = "grey"),
      panel.grid.minor = element_blank(),
      panel.background = element_rect(fill = "#faefff"),
      strip.background = element_rect(fill = "#400156", color = "#400156", size = 0.5),
      strip.text = element_text(face = "bold", size = 10, color = "white"),
      legend.position = "bottom",
      legend.justification = "center",
      legend.background = element_blank(),
      panel.border = element_rect(color = "grey30", fill = NA, size = 0.5)
    )
}
theme_set(my_theme())
data <- read.csv("~/Downloads/978-3-319-12162-8/cookbook3extras/svm.csv", sep = ";") %>%
  as_tibble() %>%
  dplyr::rename(
    Death = VAR00001, GammaGT = VAR00002, ASAT = VAR00003, ALAT = VAR00004,
    Bilirubin = VAR00005, Urea = VAR00006, Creatinine = VAR00007,
    CreClearance = VAR00008, ESR = VAR00009, CRP = VAR00010, Leucocytes = VAR00011
  )
data$Outcome = recode_factor(data$Death, `0` = "Survive", `1` = "Death")
As the h2o framework will be used throughout this study, you must install this package in R.
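If the package is not yet installed, a minimal install from CRAN looks like this (a sketch; H2O also distributes builds from its own repository, but the CRAN version suffices here):
# Install h2o from CRAN if it is not already present
if (!requireNamespace("h2o", quietly = TRUE)) install.packages("h2o")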
library(h2o)
##
## ----------------------------------------------------------------------
##
## Your next step is to start H2O:
## > h2o.init()
##
## For H2O package documentation, ask for help:
## > ??h2o
##
## After starting H2O, you can use the Web UI at http://localhost:54321
## For more information visit http://docs.h2o.ai
##
## ----------------------------------------------------------------------
##
## Attaching package: 'h2o'
## The following objects are masked from 'package:stats':
##
## cor, sd, var
## The following objects are masked from 'package:base':
##
## &&, %*%, %in%, ||, apply, as.factor, as.numeric, colnames,
## colnames<-, ifelse, is.character, is.factor, is.numeric, log,
## log10, log1p, log2, round, signif, trunc
We must activate it with the h2o.init() function. By default, h2o will use only two cores of our PC and 25% of system memory; to unleash the full power of our system, we add the argument nthreads = -1 to the init function as follows:
h2o.init(nthreads = -1)
## Connection successful!
##
## R is connected to the H2O cluster:
## H2O cluster uptime: 3 days 22 hours
## H2O cluster timezone: Africa/Lagos
## H2O data parsing timezone: UTC
## H2O cluster version: 3.20.0.9
## H2O cluster version age: 18 days
## H2O cluster name: H2O_started_from_R_Cartwheel_irq804
## H2O cluster total nodes: 1
## H2O cluster total memory: 1.71 GB
## H2O cluster total cores: 4
## H2O cluster allowed cores: 4
## H2O cluster healthy: TRUE
## H2O Connection ip: localhost
## H2O Connection port: 54321
## H2O Connection proxy: NA
## H2O Internal Security: FALSE
## H2O API Extensions: XGBoost, Algos, AutoML, Core V3, Core V4
## R Version: R version 3.5.1 (2018-07-02)
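As an aside, h2o.init() also accepts a max_mem_size argument if the default memory allocation is too small (shown commented out here, since a cluster is already running above; the value is machine-dependent):
# Optional: also raise the JVM memory ceiling (adjust to your machine)
# h2o.init(nthreads = -1, max_mem_size = "4g")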
Importing the data frame into h2o
There is one inconvenience when using h2o: the package sets up its own environment, classes, and functions. For example, h2o accepts neither a data.frame nor a tibble; we must convert our original R data frame into an h2o frame using the as.h2o() function. Once converted, the h2o frame can no longer be manipulated with tidyverse functions (some basic functions still work), so it is recommended that data manipulation and cleaning be done carefully before converting to an h2o frame.
wdata=as.h2o(data)
A machine learning experiment often requires splitting the original data into training, validation, and testing subsets. We can easily do this in h2o using the h2o.splitFrame() function. In this example, we split our dataset into two parts: 70% for training and 30% for testing.
splits=h2o.splitFrame(wdata, ratios=0.7,seed=123)
wtrain=splits[[1]]
wtest=splits[[2]]
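If a separate validation frame were needed as well, h2o.splitFrame() accepts two ratios and returns three frames. A hypothetical 60/20/20 split would look like this (splits3, wvalid3, etc. are illustrative names, not used further):
# Hypothetical three-way split: 60% train, 20% validation, 20% test
splits3 = h2o.splitFrame(wdata, ratios = c(0.6, 0.2), seed = 123)
wtrain3 = splits3[[1]]
wvalid3 = splits3[[2]]
wtest3 = splits3[[3]]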
Load in the ggplot2 and magrittr libraries for visualizations
library(ggplot2)
library(magrittr)
##
## Attaching package: 'magrittr'
## The following object is masked from 'package:purrr':
##
## set_names
## The following object is masked from 'package:tidyr':
##
## extract
Unlike the caret package, h2o does not handle balancing of the target's classes across the split subsets. Fortunately, our original dataset is well balanced.
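For a quick numeric check, the class counts in each split can be tabulated directly on the h2o frames with h2o.table():
# Class counts of the outcome in the training and test frames
h2o.table(wtrain$Outcome)
h2o.table(wtest$Outcome)
The same balance can also be visualized side by side: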
p1 = wtrain %>% as.data.frame() %>%
  ggplot(aes(x = Outcome, fill = Outcome)) +
  stat_count(show.legend = FALSE) +
  scale_fill_brewer(palette = "Set1") +
  ggtitle("Train_set")
p2 = wtest %>% as.data.frame() %>%
  ggplot(aes(x = Outcome, fill = Outcome)) +
  stat_count(show.legend = FALSE) +
  scale_fill_brewer(palette = "Set1") +
  ggtitle("Test_set")
p3 = wdata %>% as.data.frame() %>%
  ggplot(aes(x = Outcome, fill = Outcome)) +
  stat_count(show.legend = FALSE) +
  scale_fill_brewer(palette = "Set1") +
  ggtitle("Pooled_data")
library(gridExtra)
##
## Attaching package: 'gridExtra'
## The following object is masked from 'package:dplyr':
##
## combine
grid.arrange(p1,p2,p3,ncol=3)
Figure: Structure of the deep neural net. The model below uses four hidden layers of 200 tanh units each with dropout, trained with 10-fold stratified cross-validation and early stopping on the misclassification rate.
response="Outcome"
features=setdiff(colnames(wtrain),c(response,"Death"))
dlmod=h2o.deeplearning (x = features,
y = response,
model_id = "Deep_learning",
training_frame = wtrain,
nfolds = 10,
hidden = c(200,200,200,200),
stopping_metric = "misclassification",
replicate_training_data = TRUE,
stopping_tolerance = 0.001,
stopping_rounds = 5,
overwrite_with_best_model=TRUE,
fold_assignment = "Stratified",
epochs=1000,
activation = "TanhWithDropout",
keep_cross_validation_fold_assignment = TRUE,
keep_cross_validation_predictions=FALSE,
score_each_iteration = TRUE,
variable_importances = TRUE,
reproducible = TRUE,seed=123)
summary(dlmod)
## Model Details:
## ==============
##
## H2OBinomialModel: deeplearning
## Model Key: Deep_learning
## Status of Neuron Layers: predicting Outcome, 2-class classification, bernoulli distribution, CrossEntropy loss, 137,802 weights/biases, 1.6 MB, 1,360 training samples, mini-batch size 1
## layer units type dropout l1 l2 mean_rate rate_rms
## 1 1 83 Input 0.00 % NA NA NA NA
## 2 2 200 TanhDropout 50.00 % 0.000000 0.000000 0.215939 0.386003
## 3 3 200 TanhDropout 50.00 % 0.000000 0.000000 0.012185 0.012896
## 4 4 200 TanhDropout 50.00 % 0.000000 0.000000 0.022633 0.023472
## 5 5 200 TanhDropout 50.00 % 0.000000 0.000000 0.048465 0.099613
## 6 6 2 Softmax NA 0.000000 0.000000 0.010809 0.005068
## momentum mean_weight weight_rms mean_bias bias_rms
## 1 NA NA NA NA NA
## 2 0.000000 0.001250 0.087496 0.003416 0.048140
## 3 0.000000 0.000424 0.072048 -0.001930 0.036842
## 4 0.000000 -0.000553 0.070477 -0.001117 0.026180
## 5 0.000000 -0.000085 0.069523 0.001194 0.014636
## 6 0.000000 -0.036175 0.400344 0.000000 0.006223
##
## H2OBinomialMetrics: deeplearning
## ** Reported on training data. **
## ** Metrics reported on full training frame **
##
## MSE: 0.0139467
## RMSE: 0.1180961
## LogLoss: 0.06572236
## Mean Per-Class Error: 0.006756757
## AUC: 0.9971665
## Gini: 0.994333
##
## Confusion Matrix (vertical: actual; across: predicted) for F1-optimal threshold:
## Death Survive Error Rate
## Death 62 0 0.000000 =0/62
## Survive 1 73 0.013514 =1/74
## Totals 63 73 0.007353 =1/136
##
## Maximum Metrics: Maximum metrics at their respective thresholds
## metric threshold value idx
## 1 max f1 0.302635 0.993197 72
## 2 max f2 0.302635 0.989160 72
## 3 max f0point5 0.302635 0.997268 72
## 4 max accuracy 0.302635 0.992647 72
## 5 max precision 0.999925 1.000000 0
## 6 max recall 0.003727 1.000000 86
## 7 max specificity 0.999925 1.000000 0
## 8 max absolute_mcc 0.302635 0.985306 72
## 9 max min_per_class_accuracy 0.302635 0.986486 72
## 10 max mean_per_class_accuracy 0.302635 0.993243 72
##
## Gains/Lift Table: Extract with `h2o.gainsLift(<model>, <data>)` or `h2o.gainsLift(<model>, valid=<T/F>, xval=<T/F>)`
##
## H2OBinomialMetrics: deeplearning
## ** Reported on cross-validation data. **
## ** 10-fold cross-validation on training data (Metrics computed for combined holdout predictions) **
##
## MSE: 0.06027491
## RMSE: 0.2455095
## LogLoss: 0.3184665
## Mean Per-Class Error: 0.03509154
## AUC: 0.9797297
## Gini: 0.9594595
##
## Confusion Matrix (vertical: actual; across: predicted) for F1-optimal threshold:
## Death Survive Error Rate
## Death 61 1 0.016129 =1/62
## Survive 4 70 0.054054 =4/74
## Totals 65 71 0.036765 =5/136
##
## Maximum Metrics: Maximum metrics at their respective thresholds
## metric threshold value idx
## 1 max f1 0.953310 0.965517 70
## 2 max f2 0.009654 0.958005 84
## 3 max f0point5 0.968119 0.985714 68
## 4 max accuracy 0.968119 0.963235 68
## 5 max precision 0.999992 1.000000 0
## 6 max recall 0.000000 1.000000 133
## 7 max specificity 0.999992 1.000000 0
## 8 max absolute_mcc 0.968119 0.928896 68
## 9 max min_per_class_accuracy 0.953310 0.945946 70
## 10 max mean_per_class_accuracy 0.968119 0.966216 68
##
## Gains/Lift Table: Extract with `h2o.gainsLift(<model>, <data>)` or `h2o.gainsLift(<model>, valid=<T/F>, xval=<T/F>)`
## Cross-Validation Metrics Summary:
## mean sd cv_1_valid cv_2_valid
## accuracy 0.98888886 0.023570227 1.0 0.8888889
## auc 0.99125 0.018561553 1.0 0.9125
## err 0.011111111 0.023570227 0.0 0.11111111
## err_count 0.2 0.42426407 0.0 2.0
## f0point5 0.99 0.021213204 1.0 0.9
## f1 0.99 0.021213204 1.0 0.9
## f2 0.99 0.021213204 1.0 0.9
## lift_top_group 1.9132631 0.2444304 1.7777778 1.8
## logloss 0.27004936 0.276661 0.08845539 1.3904209
## max_per_class_error 0.0125 0.026516505 0.0 0.125
## mcc 0.9775 0.04772971 1.0 0.775
## mean_per_class_accuracy 0.98875 0.023864854 1.0 0.8875
## mean_per_class_error 0.01125 0.023864854 0.0 0.1125
## mse 0.055812407 0.031374566 0.033743847 0.112540394
## precision 0.99 0.021213204 1.0 0.9
## r2 0.77009857 0.12658584 0.86288214 0.5442114
## recall 0.99 0.021213204 1.0 0.9
## rmse 0.20035236 0.08851933 0.18369497 0.3354704
## specificity 0.9875 0.026516505 1.0 0.875
## cv_3_valid cv_4_valid cv_5_valid cv_6_valid
## accuracy 1.0 1.0 1.0 1.0
## auc 1.0 1.0 1.0 1.0
## err 0.0 0.0 0.0 0.0
## err_count 0.0 0.0 0.0 0.0
## f0point5 1.0 1.0 1.0 1.0
## f1 1.0 1.0 1.0 1.0
## f2 1.0 1.0 1.0 1.0
## lift_top_group 2.0 2.1666667 2.0 1.375
## logloss 0.19018012 0.35782763 0.20756435 0.0041301767
## max_per_class_error 0.0 0.0 0.0 0.0
## mcc 1.0 1.0 1.0 1.0
## mean_per_class_accuracy 1.0 1.0 1.0 1.0
## mean_per_class_error 0.0 0.0 0.0 0.0
## mse 0.07565593 0.124409355 0.06327105 4.0317103E-5
## precision 1.0 1.0 1.0 1.0
## r2 0.6973763 0.49940044 0.7469158 0.99979675
## recall 1.0 1.0 1.0 1.0
## rmse 0.2750562 0.3527171 0.25153738 0.0063495752
## specificity 1.0 1.0 1.0 1.0
## cv_7_valid cv_8_valid cv_9_valid cv_10_valid
## accuracy 1.0 1.0 1.0 1.0
## auc 1.0 1.0 1.0 1.0
## err 0.0 0.0 0.0 0.0
## err_count 0.0 0.0 0.0 0.0
## f0point5 1.0 1.0 1.0 1.0
## f1 1.0 1.0 1.0 1.0
## f2 1.0 1.0 1.0 1.0
## lift_top_group 2.4285715 2.4 1.3846154 1.8
## logloss 0.30916795 3.7249018E-4 0.12359503 0.028779574
## max_per_class_error 0.0 0.0 0.0 0.0
## mcc 1.0 1.0 1.0 1.0
## mean_per_class_accuracy 1.0 1.0 1.0 1.0
## mean_per_class_error 0.0 0.0 0.0 0.0
## mse 0.09960356 3.2696607E-7 0.043328166 0.0055311285
## precision 1.0 1.0 1.0 1.0
## r2 0.58877957 0.9999986 0.7840257 0.9775989
## recall 1.0 1.0 1.0 1.0
## rmse 0.3156003 5.718095E-4 0.20815419 0.074371554
## specificity 1.0 1.0 1.0 1.0
##
## Scoring History:
## timestamp duration training_speed epochs iterations
## 1 2018-10-20 08:55:45 0.000 sec NA 0.00000 0
## 2 2018-10-20 08:55:45 15.877 sec 490 obs/sec 0.80473 1
## 3 2018-10-20 08:55:45 16.207 sec 493 obs/sec 1.60947 2
## 4 2018-10-20 08:55:46 16.544 sec 495 obs/sec 2.41420 3
## 5 2018-10-20 08:55:46 16.975 sec 455 obs/sec 3.21893 4
## 6 2018-10-20 08:55:46 17.335 sec 455 obs/sec 4.02367 5
## 7 2018-10-20 08:55:47 17.690 sec 457 obs/sec 4.82840 6
## 8 2018-10-20 08:55:47 18.087 sec 447 obs/sec 5.63314 7
## 9 2018-10-20 08:55:47 18.403 sec 455 obs/sec 6.43787 8
## 10 2018-10-20 08:55:48 18.818 sec 446 obs/sec 7.24260 9
## 11 2018-10-20 08:55:48 19.238 sec 439 obs/sec 8.04734 10
## samples training_rmse training_logloss training_r2 training_auc
## 1 0.000000 NA NA NA NA
## 2 136.000000 0.17388 0.15020 0.87811 0.98692
## 3 272.000000 0.15931 0.12971 0.89768 0.98867
## 4 408.000000 0.17037 0.12892 0.88298 0.98976
## 5 544.000000 0.16161 0.13194 0.89471 0.99019
## 6 680.000000 0.16839 0.11882 0.88569 0.99085
## 7 816.000000 0.14135 0.09990 0.91945 0.99368
## 8 952.000000 0.13632 0.08554 0.92508 0.99564
## 9 1088.000000 0.13497 0.08184 0.92656 0.99629
## 10 1224.000000 0.15326 0.08077 0.90531 0.99695
## 11 1360.000000 0.11810 0.06572 0.94378 0.99717
## training_lift training_classification_error
## 1 NA NA
## 2 1.83784 0.02206
## 3 1.83784 0.02206
## 4 1.83784 0.02206
## 5 1.83784 0.02206
## 6 1.83784 0.02206
## 7 1.83784 0.02206
## 8 1.83784 0.02206
## 9 1.83784 0.01471
## 10 1.83784 0.00735
## 11 1.83784 0.00735
##
## Variable Importances: (Extract with `h2o.varimp`)
## =================================================
##
## Variable Importances:
## variable relative_importance scaled_importance percentage
## 1 Leucocytes 1.000000 1.000000 0.016467
## 2 ALAT 0.814358 0.814358 0.013410
## 3 Urea.46 0.813982 0.813982 0.013404
## 4 CRP 0.802077 0.802077 0.013208
## 5 ASAT 0.798128 0.798128 0.013143
##
## ---
## variable relative_importance scaled_importance percentage
## 78 Urea.5,3 0.693156 0.693156 0.011414
## 79 Creatinine 0.691135 0.691135 0.011381
## 80 Urea.29 0.686716 0.686716 0.011308
## 81 Urea.8 0.685664 0.685664 0.011291
## 82 Urea.4 0.675419 0.675419 0.011122
## 83 Urea.missing(NA) 0.000000 0.000000 0.000000
h2o.varimp_plot(dlmod)
The visualization tells us that the leucocyte (white blood cell) count influences the model's predicted mortality risk more than any other biomarker. Two caveats: variable importance reflects the magnitude of a variable's influence, not its direction, so the plot alone does not say whether higher counts raise or lower the risk; and entries such as Urea.46 suggest that the Urea column was parsed as a factor (the csv apparently uses commas as decimal separators), so each of its levels enters the network as a separate one-hot input.
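One way to recover the direction of an effect is a partial dependence plot; a minimal sketch with h2o's built-in h2o.partialPlot():
# Partial dependence of the predicted class probability on leucocyte count
h2o.partialPlot(dlmod, wtrain, cols = "Leucocytes")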
Returning to model performance, here is the confusion matrix on the training data at the F1-optimal threshold:
h2o.confusionMatrix(dlmod)
## Confusion Matrix (vertical: actual; across: predicted) for max f1 @ threshold = 0.302635476499049:
## Death Survive Error Rate
## Death 62 0 0.000000 =0/62
## Survive 1 73 0.013514 =1/74
## Totals 63 73 0.007353 =1/136
Finally, we evaluate the model on the held-out test set:
h2o.performance(dlmod, newdata = wtest) %>% plot()
h2o.performance(dlmod, newdata = wtest) %>% h2o.metric()
## Metrics for Thresholds: Binomial metrics as a function of classification thresholds
## threshold f1 f2 f0point5 accuracy precision recall
## 1 0.999971 0.058824 0.037594 0.135135 0.500000 1.000000 0.030303
## 2 0.999786 0.114286 0.074627 0.243902 0.515625 1.000000 0.060606
## 3 0.999708 0.166667 0.111111 0.333333 0.531250 1.000000 0.090909
## 4 0.999652 0.216216 0.147059 0.408163 0.546875 1.000000 0.121212
## 5 0.999569 0.263158 0.182482 0.471698 0.562500 1.000000 0.151515
## specificity absolute_mcc min_per_class_accuracy mean_per_class_accuracy
## 1 1.000000 0.122111 0.030303 0.515152
## 2 1.000000 0.174078 0.060606 0.530303
## 3 1.000000 0.214941 0.090909 0.545455
## 4 1.000000 0.250252 0.121212 0.560606
## 5 1.000000 0.282152 0.151515 0.575758
## tns fns fps tps tnr fnr fpr tpr idx
## 1 31 32 0 1 1.000000 0.969697 0.000000 0.030303 0
## 2 31 31 0 2 1.000000 0.939394 0.000000 0.060606 1
## 3 31 30 0 3 1.000000 0.909091 0.000000 0.090909 2
## 4 31 29 0 4 1.000000 0.878788 0.000000 0.121212 3
## 5 31 28 0 5 1.000000 0.848485 0.000000 0.151515 4
##
## ---
## threshold f1 f2 f0point5 accuracy precision recall
## 59 0.000000 0.717391 0.863874 0.613383 0.593750 0.559322 1.000000
## 60 0.000000 0.709677 0.859375 0.604396 0.578125 0.550000 1.000000
## 61 0.000000 0.702128 0.854922 0.595668 0.562500 0.540984 1.000000
## 62 0.000000 0.694737 0.850515 0.587189 0.546875 0.532258 1.000000
## 63 0.000000 0.687500 0.846154 0.578947 0.531250 0.523810 1.000000
## 64 0.000000 0.680412 0.841837 0.570934 0.515625 0.515625 1.000000
## specificity absolute_mcc min_per_class_accuracy mean_per_class_accuracy
## 59 0.161290 0.300355 0.161290 0.580645
## 60 0.129032 0.266398 0.129032 0.564516
## 61 0.096774 0.228808 0.096774 0.548387
## 62 0.064516 0.185308 0.064516 0.532258
## 63 0.032258 0.129989 0.032258 0.516129
## 64 0.000000 0.000000 0.000000 0.500000
## tns fns fps tps tnr fnr fpr tpr idx
## 59 5 0 26 33 0.161290 0.000000 0.838710 1.000000 58
## 60 4 0 27 33 0.129032 0.000000 0.870968 1.000000 59
## 61 3 0 28 33 0.096774 0.000000 0.903226 1.000000 60
## 62 2 0 29 33 0.064516 0.000000 0.935484 1.000000 61
## 63 1 0 30 33 0.032258 0.000000 0.967742 1.000000 62
## 64 0 0 31 33 0.000000 0.000000 1.000000 1.000000 63
Our model achieved an accuracy of 99% on the training data but only about 55% on the test data. Such a gap between training and test performance points to overfitting on this small dataset; as the model is fed more data (and its regularization is tuned), the test accuracy should improve.
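As a final check, the test-set confusion matrix at the F1-optimal threshold can be pulled with the same h2o API used above:
# Confusion matrix of the model on the held-out test frame
h2o.confusionMatrix(dlmod, newdata = wtest)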