Foreword: About the Machine Learning in Medicine (MLM) project
The MLM project was initiated in 2016 and aims to:
Encourage the use of Machine Learning techniques in medical research in Vietnam, and
Promote the use of the R statistical programming language, an open-source and leading tool for practicing data science.
Introduction
Though the decision tree (CART) and its ensemble forms such as Random Forest and GBM are most often used for classification, these algorithms can also be applied to regression tasks. In case study X12, we introduce a particular type of tree-based model: the regression (prediction) tree.
We are all familiar with predictive models based on linear regression. Linear regression applies a single, continuous rule over the entire data space. This approach works well as long as each predictor has an additive effect on the outcome and there are not too many interaction effects among predictors. The capacity of the linear model can also be extended with polynomial terms, in order to fit a curvilinear relationship between the response variable and a predictor. However, a linear model might not fit data with complicated interactions among the variables well.
The regression tree is a kind of non-linear regression model. It divides the original data space into small partitions in which the relationship is easier to fit. The partitioning continues until we reach a data space that can be described by a simple model. This recursive partitioning makes the model look like a tree. Each node on the tree represents a small data space with a simple model attached to it. The tree can be interpreted by asking a sequence of yes/no questions, each on a single feature. Starting from the root node, answering one question leads us to the next, until we reach a terminal node (leaf) and find the simple model attached to that leaf. In the classic regression tree, this simple model consists of a constant estimate of the response variable (a single predicted value). More advanced regression trees allow a more flexible estimate, for example by averaging the predicted values over an interval. Like the classification tree, the regression tree can be improved by Bagging and Boosting techniques. Boosting the regression tree turns the model into a black box but improves its accuracy.
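To make the constant-per-leaf idea concrete, here is a minimal toy sketch (using the rpart package on simulated data, separate from this case study) showing that a regression tree predicts a single constant per leaf:
library(rpart)
set.seed(123)
toy = data.frame(x = runif(200, 0, 10))
toy$y = sin(toy$x) + rnorm(200, sd = 0.2)
toy_tree = rpart(y ~ x, data = toy)
# The fitted function is a step function: one constant predicted value per leaf
length(unique(predict(toy_tree, toy)))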
Objective
The present study aims to explore the regression GBM (gradient boosting machine) algorithm in h2o. Beyond that, we will perform quantile regression using 3 GBM models, each one predicting a different percentile of the response variable.
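For reference, quantile regression replaces the squared-error loss with the pinball (quantile) loss, which h2o uses when distribution = "quantile". A minimal sketch of this loss (our own illustrative helper, not an h2o function):
# Pinball (quantile) loss: for quantile_alpha = tau, under-predictions are
# weighted by tau and over-predictions by (1 - tau)
pinball_loss = function(y, pred, tau = 0.5){
  err = y - pred
  mean(ifelse(err >= 0, tau * err, (tau - 1) * err))
}
# tau = 0.5 gives half the mean absolute error (median regression)
pinball_loss(y = c(160, 170, 180), pred = c(165, 165, 165), tau = 0.5)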
Materials and method
First, we will prepare the ggplot theme for our experiment:
library(tidyverse)
my_theme <- function(base_size = 10, base_family = "sans"){
  theme_minimal(base_size = base_size, base_family = base_family) +
    theme(
      axis.text = element_text(size = 10),
      axis.text.x = element_text(angle = 0, vjust = 0.5, hjust = 0.5),
      axis.title = element_text(size = 12),
      panel.grid.major = element_line(color = "grey"),
      panel.grid.minor = element_blank(),
      panel.background = element_rect(fill = "#fffcfc"),
      strip.background = element_rect(fill = "#820000", color = "#820000", size = 0.5),
      strip.text = element_text(face = "bold", size = 10, color = "white"),
      legend.position = "bottom",
      legend.justification = "center",
      legend.background = element_blank(),
      panel.border = element_rect(color = "grey30", fill = NA, size = 0.5)
    )
}
theme_set(my_theme())
mycolors=c("#f32440","#2185ef","#d421ef")
Our study question is how to approximate a subject’s height by measuring his/her arm length. Our data are taken from the well-known NHANES dataset:
require(rms)
library(magrittr) # for the %<>% compound assignment pipe
getHdata(nhgh)
# Keep only sex, height (ht) and arm length (arml), dropping rows with missing values
data=nhgh%>%as_tibble()%>%na.omit()%>%.[,c("sex","ht","arml")]
ht=data$ht
arml=data$arml
sex=as.factor(data$sex)
# cbind() coerces sex to numeric 1/2, so we recode it back to a labelled factor
data=as.data.frame(cbind(sex,ht,arml))%>%as_tibble()
data$sex%<>%as.factor()%>%recode_factor(.,`1` = "Male", `2` = "Female")
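Before plotting, a quick sanity check of the prepared data:
# Quick sanity check: variable types and value ranges after recoding
glimpse(data)
summary(data)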
Data visualisation
data%>%ggplot(aes(x=arml,y=ht,color=sex))+geom_point(alpha=0.1)+geom_smooth(se=F)+scale_colour_manual(values=mycolors)
data%>%ggplot(aes(x=arml,y=ht))+geom_boxplot(mapping = aes(group = cut_width(arml,1),fill=cut_width(arml,1)),show.legend = F)+geom_smooth(se=F,aes(color=sex),show.legend = F)+scale_colour_manual(values=mycolors)
data%>%ggplot(aes(x=ht,fill=sex))+geom_density(alpha=0.5)+scale_fill_manual(values=mycolors)
Then we will split the original dataset into training, validation and testing subsets:
library(caret)
set.seed(123)
idTrain=caret::createDataPartition(y=data$ht,p=0.6,list=FALSE)
trainset=data[idTrain,]
remain=data[-idTrain,]
idTest=caret::createDataPartition(y=remain$ht,p=0.5,list=FALSE)
testset=remain[-idTest,]
validset=remain[idTest,]
rm(idTrain,idTest,remain)
p1=data%>%ggplot(aes(x=arml,y=ht))+geom_point(alpha=0.05,color="red")+geom_smooth(se=F,color="red4")+ggtitle("Origin")
p2=trainset%>%ggplot(aes(x=arml,y=ht))+geom_point(alpha=0.05,color="blue")+geom_smooth(se=F,color="blue4")+ggtitle("Train")
p3=testset%>%ggplot(aes(x=arml,y=ht))+geom_point(alpha=0.1,color="green")+geom_smooth(se=F,color="darkgreen")+ggtitle("Test")
p4=validset%>%ggplot(aes(x=arml,y=ht))+geom_point(alpha=0.1,color="purple")+geom_smooth(se=F,color="purple")+ggtitle("Validation")
gridExtra::grid.arrange(p1,p2,p3,p4,ncol=2)
First, we will train a simple regression tree using the CART algorithm in caret:
Control=caret::trainControl(method= "repeatedcv",number=10,repeats=10)
library(partykit)
library(party)
set.seed(123)
cart=caret::train(ht~arml+sex,data=trainset,method="rpart2",trControl=Control)
cart
## CART
##
## 3207 samples
## 2 predictor
##
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 10 times)
## Summary of sample sizes: 2886, 2887, 2886, 2887, 2886, 2887, ...
## Resampling results across tuning parameters:
##
## maxdepth RMSE Rsquared
## 1 7.531822 0.4575832
## 2 7.001082 0.5315599
## 3 5.889646 0.6685904
##
## RMSE was used to select the optimal model using the smallest value.
## The final value used for the model was maxdepth = 3.
cart$finalModel
## n= 3207
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 3207 334560.500 167.2659
## 2) arml< 36.75 1620 86454.170 160.3767
## 4) arml< 34.15 587 22726.910 155.4446
## 8) arml< 32.45 167 5204.512 151.2186 *
## 9) arml>=32.45 420 13353.910 157.1250 *
## 5) arml>=34.15 1033 41334.090 163.1794
## 10) sexFemale>=0.5 720 22427.950 161.2990 *
## 11) sexFemale< 0.5 313 10504.440 167.5048 *
## 3) arml>=36.75 1587 92736.960 174.2982
## 6) arml< 39.75 1110 46454.400 171.7198
## 12) sexFemale>=0.5 328 9734.196 166.7384 *
## 13) sexFemale< 0.5 782 25167.230 173.8092 *
## 7) arml>=39.75 477 21730.520 180.2983 *
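Optionally, the final tree can be visualised with base rpart plotting (a minimal sketch):
# Plot the splits and leaf values of the final rpart model
plot(cart$finalModel, margin = 0.1)
text(cart$finalModel, use.n = TRUE, cex = 0.8)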
Then we initialise h2o:
library(h2o)
h2o.init(nthreads = -1,max_mem_size ="4g")
## Connection successful!
##
## R is connected to the H2O cluster:
## H2O cluster uptime: 19 minutes 37 seconds
## H2O cluster version: 3.10.3.6
## H2O cluster version age: 2 months and 20 days
## H2O cluster name: H2O_started_from_R_Admin_nyq462
## H2O cluster total nodes: 1
## H2O cluster total memory: 3.08 GB
## H2O cluster total cores: 4
## H2O cluster allowed cores: 4
## H2O cluster healthy: TRUE
## H2O Connection ip: localhost
## H2O Connection port: 54321
## H2O Connection proxy: NA
## R Version: R version 3.3.1 (2016-06-21)
wdata=as.h2o(data)
wtrain=as.h2o(trainset)
wtest=as.h2o(testset)
wvalid=as.h2o(validset)
TRAINING 3 GBM MODELS FOR THE MEDIAN, 2.5th AND 97.5th PERCENTILES
First, we train the basic learner (the model for the median):
Note: a full manual for the GBM algorithm in h2o can be found here:
http://docs.h2o.ai/h2o/latest-stable/h2o-docs/booklets/GBMBooklet.pdf
response="ht"
features=setdiff(colnames(wdata),c(response,"sex"))
fit.gbm50=h2o.gbm(x = features, y= response,
training_frame = wtrain,
validation_frame = wvalid,
nfolds = 10,
keep_cross_validation_predictions = FALSE,
keep_cross_validation_fold_assignment = FALSE,
score_each_iteration = TRUE,
fold_assignment = "Random",
ntrees = 100, max_depth = 50, min_rows = 1,nbins = 10,
stopping_rounds = 5,
stopping_metric = "RMSE", stopping_tolerance = 0.0001,
seed = 123,
learn_rate = 0.3,
learn_rate_annealing = 1,
distribution = "quantile",quantile_alpha =0.5,
huber_alpha = 0.9,
sample_rate = 0.7,
col_sample_rate = 1,
col_sample_rate_change_per_level = 1,
col_sample_rate_per_tree = 1,
min_split_improvement = 1e-05,
histogram_type = "Random",
max_abs_leafnode_pred = Inf, pred_noise_bandwidth = 0,
categorical_encoding = "Binary"
)
fit.gbm25=h2o.gbm(x = features, y= response,
training_frame = wtrain,
validation_frame = wvalid,
nfolds = 10,
keep_cross_validation_predictions = FALSE,
keep_cross_validation_fold_assignment = FALSE,
score_each_iteration = TRUE,
fold_assignment = "Random",
ntrees = 200, max_depth = 50, min_rows = 1,nbins = 10,
stopping_rounds = 5,
stopping_metric = "RMSE", stopping_tolerance = 0.0001,
seed = 123,
learn_rate = 0.3,
learn_rate_annealing = 1,
distribution = "quantile",quantile_alpha =0.025,
huber_alpha = 0.9,
sample_rate = 0.66,
col_sample_rate = 1,
col_sample_rate_change_per_level = 1,
col_sample_rate_per_tree = 1,
min_split_improvement = 1e-05,
histogram_type = "Random",
max_abs_leafnode_pred = Inf, pred_noise_bandwidth = 0,
categorical_encoding = "Binary"
)
fit.gbm975=h2o.gbm(x = features, y= response,
training_frame = wtrain,
validation_frame = wvalid,
nfolds = 10,
keep_cross_validation_predictions = FALSE,
keep_cross_validation_fold_assignment = FALSE,
score_each_iteration = TRUE,
fold_assignment = "Random",
ntrees = 200, max_depth = 50, min_rows = 1,nbins = 10,
stopping_rounds = 5,
stopping_metric = "RMSE", stopping_tolerance = 0.0001,
seed = 123,
learn_rate = 0.3,
learn_rate_annealing = 1,
distribution = "quantile",quantile_alpha =0.975,
huber_alpha = 0.9,
sample_rate = 0.66,
col_sample_rate = 1,
col_sample_rate_change_per_level = 1,
col_sample_rate_per_tree = 1,
min_split_improvement = 1e-05,
histogram_type = "Random",
max_abs_leafnode_pred = Inf, pred_noise_bandwidth = 0,
categorical_encoding = "Binary"
)
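The three calls above differ essentially only in quantile_alpha (plus ntrees and sample_rate). For readability, an equivalent way to fit them would be a small wrapper such as the sketch below; fit_quantile_gbm is our own illustrative helper, not an h2o function:
# Illustrative wrapper: fit one quantile GBM per value of quantile_alpha
fit_quantile_gbm = function(alpha, ntrees = 200, sample_rate = 0.66){
  h2o.gbm(x = features, y = response,
          training_frame = wtrain, validation_frame = wvalid,
          nfolds = 10, fold_assignment = "Random", score_each_iteration = TRUE,
          ntrees = ntrees, max_depth = 50, min_rows = 1, nbins = 10,
          stopping_rounds = 5, stopping_metric = "RMSE", stopping_tolerance = 0.0001,
          seed = 123, learn_rate = 0.3,
          distribution = "quantile", quantile_alpha = alpha,
          sample_rate = sample_rate, histogram_type = "Random",
          categorical_encoding = "Binary")
}
# e.g. fits = lapply(c(0.025, 0.5, 0.975), fit_quantile_gbm)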
We can explore these models:
fit.gbm50
## Model Details:
## ==============
##
## H2ORegressionModel: gbm
## Model ID: GBM_model_R_1494530170586_37
## Model Summary:
## number_of_trees number_of_internal_trees model_size_in_bytes min_depth
## 1 17 17 27192 11
## max_depth mean_depth min_leaves max_leaves mean_leaves
## 1 17 14.47059 79 137 122.23529
##
##
## H2ORegressionMetrics: gbm
## ** Reported on training data. **
##
## MSE: 33.91419
## RMSE: 5.823589
## MAE: 4.482219
## RMSLE: 0.03469198
## Mean Residual Deviance : 2.241109
##
##
## H2ORegressionMetrics: gbm
## ** Reported on validation data. **
##
## MSE: 39.61755
## RMSE: 6.294247
## MAE: 4.941346
## RMSLE: 0.03762129
## Mean Residual Deviance : 2.470673
##
##
## H2ORegressionMetrics: gbm
## ** Reported on cross-validation data. **
## ** 10-fold cross-validation on training data (Metrics computed for combined holdout predictions) **
##
## MSE: 37.13434
## RMSE: 6.093795
## MAE: 4.797369
## RMSLE: 0.03628378
## Mean Residual Deviance : 2.398684
##
##
## Cross-Validation Metrics Summary:
## mean sd cv_1_valid cv_2_valid
## mae 4.7932587 0.12822463 4.762732 4.799897
## mse 37.057934 1.9790109 36.725586 38.040077
## r2 0.64309645 0.027426397 0.65942234 0.6129406
## residual_deviance 2.3966293 0.06411231 2.381366 2.3999486
## rmse 6.083212 0.16196994 6.060164 6.167664
## rmsle 0.036233127 8.5624016E-4 0.036010027 0.036449693
## cv_3_valid cv_4_valid cv_5_valid cv_6_valid
## mae 4.8928833 4.753488 4.9139266 4.7695336
## mse 36.85869 34.17145 40.274765 37.84854
## r2 0.6403843 0.66358846 0.61186963 0.6137844
## residual_deviance 2.4464417 2.376744 2.4569633 2.3847668
## rmse 6.0711355 5.8456354 6.34624 6.152117
## rmsle 0.03635718 0.03502546 0.037848014 0.036560412
## cv_7_valid cv_8_valid cv_9_valid cv_10_valid
## mae 4.717065 4.7331676 5.176918 4.412977
## mse 34.265137 37.267643 42.541416 32.58604
## r2 0.67154545 0.65605474 0.57818174 0.72319305
## residual_deviance 2.3585324 2.3665838 2.588459 2.2064886
## rmse 5.853643 6.104723 6.522378 5.7084184
## rmsle 0.035210177 0.036110207 0.03853327 0.034226824
fit.gbm25
## Model Details:
## ==============
##
## H2ORegressionModel: gbm
## Model ID: GBM_model_R_1494530170586_49
## Model Summary:
## number_of_trees number_of_internal_trees model_size_in_bytes min_depth
## 1 32 32 38523 12
## max_depth mean_depth min_leaves max_leaves mean_leaves
## 1 22 16.71875 50 115 90.75000
##
##
## H2ORegressionMetrics: gbm
## ** Reported on training data. **
##
## MSE: 149.5933
## RMSE: 12.23083
## MAE: 10.60564
## RMSLE: 0.07397679
## Mean Residual Deviance : 0.3269672
##
##
## H2ORegressionMetrics: gbm
## ** Reported on validation data. **
##
## MSE: 148.9736
## RMSE: 12.20547
## MAE: 10.66646
## RMSLE: 0.07376965
## Mean Residual Deviance : 0.4746812
##
##
## H2ORegressionMetrics: gbm
## ** Reported on cross-validation data. **
## ** 10-fold cross-validation on training data (Metrics computed for combined holdout predictions) **
##
## MSE: 150.3832
## RMSE: 12.26308
## MAE: 10.69248
## RMSLE: 0.07419446
## Mean Residual Deviance : 0.4465867
##
##
## Cross-Validation Metrics Summary:
## mean sd cv_1_valid cv_2_valid
## mae 10.679189 0.38195664 10.786995 11.034384
## mse 150.04637 8.787474 150.20583 158.93976
## r2 -0.44230506 0.100500196 -0.39294565 -0.6172186
## residual_deviance 0.44617358 0.042815406 0.41663563 0.43110228
## rmse 12.238609 0.36250335 12.255849 12.607131
## rmsle 0.07404595 0.002175862 0.07395229 0.07647477
## cv_3_valid cv_4_valid cv_5_valid cv_6_valid
## mae 11.111965 10.031077 10.097907 11.176072
## mse 159.87573 135.51915 136.4218 159.01256
## r2 -0.55984455 -0.33416072 -0.3147052 -0.62260246
## residual_deviance 0.4507702 0.35550457 0.6013477 0.47079825
## rmse 12.644197 11.64127 11.679974 12.610019
## rmsle 0.07644209 0.07063842 0.07038476 0.07638832
## cv_7_valid cv_8_valid cv_9_valid cv_10_valid
## mae 9.705143 11.471856 10.6152525 10.761231
## mse 127.331985 169.17278 153.05605 150.92809
## r2 -0.22056337 -0.56130564 -0.51762325 -0.2820811
## residual_deviance 0.402774 0.42207563 0.45659184 0.45413572
## rmse 11.284147 13.006643 12.371582 12.285279
## rmsle 0.06859064 0.07877969 0.07463913 0.074169405
fit.gbm975
## Model Details:
## ==============
##
## H2ORegressionModel: gbm
## Model ID: GBM_model_R_1494530170586_61
## Model Summary:
## number_of_trees number_of_internal_trees model_size_in_bytes min_depth
## 1 37 37 46023 9
## max_depth mean_depth min_leaves max_leaves mean_leaves
## 1 21 16.94595 45 117 93.81081
##
##
## H2ORegressionMetrics: gbm
## ** Reported on training data. **
##
## MSE: 144.8542
## RMSE: 12.03554
## MAE: 10.43734
## RMSLE: 0.07043336
## Mean Residual Deviance : 0.3197336
##
##
## H2ORegressionMetrics: gbm
## ** Reported on validation data. **
##
## MSE: 152.0232
## RMSE: 12.32977
## MAE: 10.69089
## RMSLE: 0.07224249
## Mean Residual Deviance : 0.4423573
##
##
## H2ORegressionMetrics: gbm
## ** Reported on cross-validation data. **
## ** 10-fold cross-validation on training data (Metrics computed for combined holdout predictions) **
##
## MSE: 145.0165
## RMSE: 12.04228
## MAE: 10.52975
## RMSLE: 0.07047903
## Mean Residual Deviance : 0.4656513
##
##
## Cross-Validation Metrics Summary:
## mean sd cv_1_valid cv_2_valid
## mae 10.535788 0.25543556 10.1642885 10.19656
## mse 145.12074 6.142722 137.77377 138.25925
## r2 -0.39609122 0.08661117 -0.27765605 -0.40679353
## residual_deviance 0.46563193 0.048042547 0.54053617 0.41766816
## rmse 12.041273 0.2534598 11.737707 11.758369
## rmsle 0.07047995 0.0015044017 0.06885511 0.06868834
## cv_3_valid cv_4_valid cv_5_valid cv_6_valid
## mae 10.329627 11.06928 11.034871 10.353079
## mse 140.50479 159.9546 156.15402 141.26045
## r2 -0.3708499 -0.57472306 -0.50486594 -0.44145557
## residual_deviance 0.41217688 0.40471303 0.42592978 0.55300295
## rmse 11.853472 12.647316 12.4961605 11.8853035
## rmsle 0.06934074 0.07411614 0.07331979 0.06948638
## cv_7_valid cv_8_valid cv_9_valid cv_10_valid
## mae 11.017733 10.1414385 10.690248 10.360752
## mse 153.75427 133.9359 150.96063 138.64972
## r2 -0.4738388 -0.23610246 -0.49684614 -0.17778069
## residual_deviance 0.458273 0.58474374 0.48273748 0.37653825
## rmse 12.399769 11.573068 12.286604 11.774961
## rmsle 0.072656415 0.06746559 0.07143942 0.069431625
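Before evaluating on the test set, it can be useful to check where early stopping kicked in by inspecting the scoring history (a minimal sketch; column names may vary slightly between h2o versions):
# Scoring history of the median model: RMSE versus number of trees
sh = as.data.frame(h2o.scoreHistory(fit.gbm50))
head(sh[, c("number_of_trees", "training_rmse", "validation_rmse")])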
EVALUATING THE PERFORMANCE OF THE MEDIAN MODEL ON THE TEST SET
library(mlr)
# Trick: mlr cannot wrap an h2o model directly, so we fit a dummy glm learner
# only to obtain an mlr prediction object, then overwrite its response column
# with the h2o predictions before computing the performance metrics
regr.task= mlr::makeRegrTask(id = "Ht", data=testset, target = "ht")
regr.lrn = makeLearner("regr.glm")
predh2o=predict(fit.gbm50,newdata=wtest)
dummy=mlr::train(regr.lrn,regr.task)
dumpred=predict(dummy,regr.task)
dumpred$data$response<-as.vector(predh2o)
mets=list(rsq,expvar,mae,rmse,mse,medse,medae,rrse)
performance(dumpred,measures =mets)
## rsq expvar mae rmse mse medse
## 0.6206384 0.7410569 4.8966982 6.1756358 38.1384781 16.9765114
## medae rrse
## 4.1202562 0.6159234
The GBM-based model for predicting the median of Height explains about 62% of the variance of Height in the test set (R2 = 0.62; explained variance = 0.74). Its R2 is acceptable but not very good. The model has a mean absolute error of about 4.9 cm and a root mean squared error of about 6.18 cm.
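The same metrics can be cross-checked directly from the raw predicted and observed values, without going through mlr (a minimal sketch):
# Direct computation of R2, MAE and RMSE on the test set
obs = testset$ht
pred = as.vector(predh2o)
c(R2   = 1 - sum((obs - pred)^2) / sum((obs - mean(obs))^2),
  MAE  = mean(abs(obs - pred)),
  RMSE = sqrt(mean((obs - pred)^2)))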
Bootstrapping the evaluation on a combined test set
In the next step we will bootstrap this evaluation. This time the test set will be extended with the validation subset. Two principal performance metrics will be evaluated: R2 and RMSE.
# Bootstrap statistic: the h2o model is not refitted; its predictions are simply
# re-evaluated on each resampled copy of the combined test set
bootPRED=function(dummymodel,h2omodel,data,i){
  d=data%>%.[i,] # bootstrap resample defined by the indices i
  wdf=as.h2o(d)
  predh2o=predict(h2omodel,newdata=wdf)
  predmlr=predict(dummymodel,newdata=d)
  predmlr$data$response<-as.vector(predh2o)
  perf=mlr::performance(predmlr,measures=list(rsq,rmse))
  R2=perf[[1]]
  RMSE=perf[[2]]
  return(cbind(R2,RMSE))
}
testdf=rbind(testset,validset)
set.seed(123)
library(boot)
bootpd=boot(statistic=bootPRED,dummymodel=dummy,h2omodel=fit.gbm50,data=testdf,R=100)%>%.$t%>%as_tibble()
names(bootpd)=c("R2","RMSE")
bootpd$iteration=rownames(bootpd)%>%as.numeric()
bootpd%>%ggplot(aes(x=iteration,y=R2))+geom_line(color="red",alpha=0.8,size=1.2)+geom_point(shape=21,size=3,color="red4",fill="red")+geom_hline(yintercept=median(bootpd$R2),linetype=2,size=1)+ggtitle("R2")
bootpd%>%ggplot(aes(x=iteration,y=RMSE))+geom_line(color="blue",alpha=0.8,size=1.2)+geom_point(shape=21,size=3,color="blue4",fill="blue")+geom_hline(yintercept=median(bootpd$RMSE),linetype=2,size=1)+ggtitle("RMSE")
bootpd%>%ggplot(aes(x=R2))+geom_density(fill="red",alpha=0.6)+geom_vline(xintercept=median(bootpd$R2),linetype=2,size=1)+ggtitle("R2")
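From the bootstrap replicates we can also report percentile confidence intervals for both metrics (a minimal sketch):
# 95% percentile bootstrap intervals (and medians) for R2 and RMSE
apply(bootpd[, c("R2", "RMSE")], 2, quantile, probs = c(0.025, 0.5, 0.975))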
VISUALISING THE PREDICTION
predh2o50=predict(fit.gbm50,newdata=wtest)
predh2o25=predict(fit.gbm25,newdata=wtest)
predh2o975=predict(fit.gbm975,newdata=wtest)
Mdf=testset%>%mutate(.,Pred=as.vector(predh2o50),Parameter="Median")
Ldf=testset%>%mutate(.,Pred=as.vector(predh2o25),Parameter="LL")
Udf=testset%>%mutate(.,Pred=as.vector(predh2o975),Parameter="UL")
rbind(Mdf,Ldf,Udf)%>%ggplot(aes(x=arml))+geom_point(aes(y=ht),alpha=0.2,color="grey")+geom_point(aes(y=Pred,color=Parameter))+facet_wrap(~sex,ncol=1,scales="free")+scale_color_manual(values=c("blue","black","red"))
rbind(Mdf,Ldf,Udf)%>%ggplot(aes(x=arml))+geom_point(aes(y=ht),alpha=0.2,color="grey")+geom_smooth(se=F,aes(y=Pred,color=Parameter),method="auto")+facet_wrap(~sex,ncol=1,scales="free")+scale_color_manual(values=c("blue","black","red"))
preddf=testset%>%mutate(.,Median=as.vector(predh2o50),LL=as.vector(predh2o25),UL=as.vector(predh2o975))
preddf%>%ggplot(aes(x=arml))+geom_point(aes(y=ht),alpha=0.2,color="red4")+geom_ribbon(aes(ymin=LL,ymax=UL),fill="red",alpha=0.4)+geom_smooth(aes(y=Median),method="loess",se=F,color="red4")
preddf%>%ggplot(aes(x=arml))+geom_point(aes(y=ht),alpha=0.2,color="grey")+geom_errorbar(aes(ymin=LL,ymax=UL,color=as.factor(ht)),alpha=0.4,show.legend = F)+geom_smooth(aes(y=Median),se=F,method="gam",color="gold")+geom_point(aes(y=Median),color="blue4")
Conclusion
In this case study, we have shown that tree-based machine learning models can also be applied to regression tasks. We have also explored the use of a gradient boosting machine for quantile regression. Although the performance of the final model was not outstanding, such an approach could be useful for regression problems with complex data structures and non-linear interactions.
See you in the next tutorial, and thanks for joining us!
END