Case study X9: Exploring the Gradient Boosting Machine (GBM) learner in h2o.

Foreword: About the Machine Learning in Medicine (MLM) project

The MLM project was initiated in 2016 and aims to:

  1. Encourage the use of machine learning techniques in medical research in Vietnam, and

  2. Promote the use of the R statistical programming language, an open-source and leading tool for practicing data science.

Introduction

Gradient boosting machine (GBM) is an ensemble algorithm that uses the boosting method to combine weak decision trees. The principle of boosting is to build multiple weak learners (decision trees) sequentially, so that each new learner focuses on the difficult cases that the previous learners failed to handle. In other words, we reuse a weak learning method several times and combine the outputs into a stronger model. In GBM specifically, each new tree is fitted to the negative gradient of the loss function evaluated at the current ensemble's predictions (for squared-error loss, simply the residuals), and the final prediction is the sum of all trees' outputs, each scaled by a learning rate. For classification, this sum is converted into class probabilities (through the softmax function in the multiclass case).
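To make the boosting principle concrete, here is a minimal base-R sketch of gradient boosting for squared-error regression on a single numeric feature, with depth-1 trees (stumps) as weak learners. It only illustrates the idea and is not h2o's implementation: the function names (fit_stump, predict_stump, gbm_sketch) and the synthetic data are hypothetical, and for a 3-class problem such as CMC a multinomial GBM instead fits one tree per class at each iteration, on the gradient of the multinomial log-loss.

```r
# Minimal gradient boosting sketch: squared-error loss, one numeric feature,
# depth-1 regression trees ("stumps") as weak learners. Illustration only.

# Find the split on x that minimises the residual sum of squares of r
# (assumes x has at least two distinct values)
fit_stump <- function(x, r) {
  best <- list(sse = Inf)
  for (s in sort(unique(x))[-1]) {
    left <- r[x < s]
    right <- r[x >= s]
    sse <- sum((left - mean(left))^2) + sum((right - mean(right))^2)
    if (sse < best$sse) {
      best <- list(sse = sse, split = s, left = mean(left), right = mean(right))
    }
  }
  best
}

predict_stump <- function(stump, x) ifelse(x < stump$split, stump$left, stump$right)

gbm_sketch <- function(x, y, n_trees = 50, learn_rate = 0.1) {
  pred <- rep(mean(y), length(y))            # start from a constant model
  for (m in seq_len(n_trees)) {
    residual <- y - pred                     # negative gradient of 0.5 * (y - f)^2
    stump <- fit_stump(x, residual)          # weak learner fitted to the residuals
    pred <- pred + learn_rate * predict_stump(stump, x)  # shrunken additive update
  }
  pred                                       # fitted values on the training data
}

# Usage on synthetic data
set.seed(1)
x <- runif(200)
y <- sin(2 * pi * x) + rnorm(200, sd = 0.1)
fitted_values <- gbm_sketch(x, y)
c(constant_mse = mean((y - mean(y))^2), boosted_mse = mean((y - fitted_values)^2))
```

Each iteration adds only a fraction (learn_rate) of a stump fitted to the current residuals, which is why boosting typically needs many small trees rather than a few large ones.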

Objective

The present case study aims to explore the following machine learning techniques:

  1. Training a GBM learner in h2o
  2. Tuning the cut-offs for multiclass classification
  3. Interpretive study of the GBM model
  4. Combining the strengths of 3 packages: h2o, mlr and caret

This case study uses the CMC (contraceptive method choice) dataset (Tjen-Sien Lim et al. 1999), a subset of the 1987 National Indonesia Contraceptive Prevalence Survey. The data were collected from married women who were either not pregnant or did not know whether they were pregnant at the time of the interview. The original dataset has 1473 observations, no missing values and 9 features: the wife's age, education level, religion, working status and media exposure; the husband's education level and occupation; the standard-of-living index; and the number of children.

The main question is to predict a woman's current contraceptive method choice (no use, long-term or short-term method) from her demographic and socio-economic characteristics. This could be difficult, as it is a multiclass problem and the outcome might be stochastic and psychologically driven.

Materials and method

First, we prepare a ggplot2 theme for our figures:

library(tidyverse)

my_theme <- function(base_size = 10, base_family = "sans"){
  theme_minimal(base_size = base_size, base_family = base_family) +
    theme(
      axis.text = element_text(size = 10),
      axis.text.x = element_text(angle = 0, vjust = 0.5, hjust = 0.5),
      axis.title = element_text(size = 12),
      panel.grid.major = element_line(color = "grey"),
      panel.grid.minor = element_blank(),
      panel.background = element_rect(fill = "#fffcfc"),
      strip.background = element_rect(fill = "#820000", color = "#820000", size =0.5),
      strip.text = element_text(face = "bold", size = 10, color = "white"),
      legend.position = "bottom",
      legend.justification = "center",
      legend.background = element_blank(),
      panel.border = element_rect(color = "grey30", fill = NA, size = 0.5)
    )
}
theme_set(my_theme())

mycolors=c("#ccc7c7","#db3434","#871618","#590050","#7f34a8")

Now we load the dataset from the UCI website and perform a descriptive analysis:

df=read.table("https://archive.ics.uci.edu/ml/machine-learning-databases/cmc/cmc.data",sep=",")%>%as_tibble()

names(df)=c("WifeAge","WifeEdu","HusbandEdu","NumbChild","Religion","Working","HusbandOcc","SLI","Media","CMC")

# Recode the integer codes following the UCI attribute description of the CMC data
df$WifeAge=as.numeric(df$WifeAge)
df$WifeEdu=df$WifeEdu%>%as.integer()
df$HusbandEdu=df$HusbandEdu%>%as.integer()
df$Religion=df$Religion%>%recode_factor(.,`1` = "Islam", `0` = "NonIslam")
df$Working=df$Working%>%recode_factor(.,`0` = "Yes", `1` = "No")  # UCI coding: 0 = Yes (working), 1 = No
df$HusbandOcc=df$HusbandOcc%>%as.factor()
df$Media=df$Media%>%recode_factor(.,`0` = "Good", `1` = "Low")
df$CMC=df$CMC%>%recode_factor(.,`1` = "None", `2` = "Longterm", `3` = "Shortterm")

Hmisc::describe(df)
## df 
## 
##  10  Variables      1473  Observations
## ---------------------------------------------------------------------------
## WifeAge 
##        n  missing distinct     Info     Mean      Gmd      .05      .10 
##     1473        0       34    0.999    32.54    9.438       21       22 
##      .25      .50      .75      .90      .95 
##       26       32       39       45       47 
## 
## lowest : 16 17 18 19 20, highest: 45 46 47 48 49
## ---------------------------------------------------------------------------
## WifeEdu 
##        n  missing distinct     Info     Mean      Gmd 
##     1473        0        4    0.906    2.959    1.105 
##                                   
## Value          1     2     3     4
## Frequency    152   334   410   577
## Proportion 0.103 0.227 0.278 0.392
## ---------------------------------------------------------------------------
## HusbandEdu 
##        n  missing distinct     Info     Mean      Gmd 
##     1473        0        4    0.757     3.43   0.7902 
##                                   
## Value          1     2     3     4
## Frequency     44   178   352   899
## Proportion 0.030 0.121 0.239 0.610
## ---------------------------------------------------------------------------
## NumbChild 
##        n  missing distinct     Info     Mean      Gmd      .05      .10 
##     1473        0       15    0.978    3.261     2.54        0        1 
##      .25      .50      .75      .90      .95 
##        1        3        4        6        8 
##                                                                       
## Value          0     1     2     3     4     5     6     7     8     9
## Frequency     97   276   276   259   197   135    92    49    47    16
## Proportion 0.066 0.187 0.187 0.176 0.134 0.092 0.062 0.033 0.032 0.011
##                                         
## Value         10    11    12    13    16
## Frequency     11    11     4     2     1
## Proportion 0.007 0.007 0.003 0.001 0.001
## ---------------------------------------------------------------------------
## Religion 
##        n  missing distinct 
##     1473        0        2 
##                             
## Value         Islam NonIslam
## Frequency      1253      220
## Proportion    0.851    0.149
## ---------------------------------------------------------------------------
## Working 
##        n  missing distinct 
##     1473        0        2 
##                       
## Value        Yes    No
## Frequency    369  1104
## Proportion 0.251 0.749
## ---------------------------------------------------------------------------
## HusbandOcc 
##        n  missing distinct 
##     1473        0        4 
##                                   
## Value          1     2     3     4
## Frequency    436   425   585    27
## Proportion 0.296 0.289 0.397 0.018
## ---------------------------------------------------------------------------
## SLI 
##        n  missing distinct     Info     Mean      Gmd 
##     1473        0        4     0.87    3.134    1.026 
##                                   
## Value          1     2     3     4
## Frequency    129   229   431   684
## Proportion 0.088 0.155 0.293 0.464
## ---------------------------------------------------------------------------
## Media 
##        n  missing distinct 
##     1473        0        2 
##                       
## Value       Good   Low
## Frequency   1364   109
## Proportion 0.926 0.074
## ---------------------------------------------------------------------------
## CMC 
##        n  missing distinct 
##     1473        0        3 
##                                         
## Value           None  Longterm Shortterm
## Frequency        629       333       511
## Proportion     0.427     0.226     0.347
## ---------------------------------------------------------------------------

Data visualisation

library(gridExtra)

e1=df%>%ggplot(aes(x=CMC,fill=as.factor(WifeEdu)))+geom_bar(position="fill",color="black",alpha=0.8,show.legend = T)+scale_fill_manual(values=mycolors)+coord_flip()+ggtitle("Wife_Education")

e2=df%>%ggplot(aes(x=CMC,fill=as.factor(HusbandEdu)))+geom_bar(position="fill",color="black",alpha=0.8,show.legend = T)+scale_fill_manual(values=mycolors)+coord_flip()+ggtitle("Husband_Education")

grid.arrange(e1,e2,ncol=1)

df%>%ggplot(aes(x=WifeAge,fill=CMC))+geom_histogram(color="black",alpha=0.7,show.legend = T,binwidth=1)+scale_fill_manual(values=mycolors)+coord_flip()+ggtitle("Wife_Age")

df%>%ggplot(aes(x=NumbChild,fill=CMC))+geom_histogram(color="black",alpha=0.7,show.legend = T,binwidth = 1)+scale_fill_manual(values=mycolors)+ggtitle("Number of children")

df%>%ggplot(aes(x=CMC,fill=CMC))+geom_bar(position="identity",color="black",alpha=0.8,show.legend = T)+scale_fill_manual(values=mycolors)+coord_flip()+facet_grid(Religion~Media,scales="free")+ggtitle("Religion x Media Exposure")

df%>%ggplot(aes(x=CMC,fill=CMC))+geom_bar(position="identity",color="black",alpha=0.8,show.legend = T)+scale_fill_manual(values=mycolors)+coord_flip()+facet_grid(WifeEdu~HusbandEdu,scales="free")+ggtitle("Education: Wife x Husband")

df%>%ggplot(aes(x=SLI,fill=CMC))+geom_bar(position="fill",color="black",alpha=0.8,show.legend = T)+scale_fill_manual(values=mycolors)+coord_flip()+ggtitle("SLI")

Machine learning experiment

First, we will split the original dataset into 3 subsets:

  1. A training subset of 885 cases, for model training
  2. A validation subset of 295 cases, for calibrating the model in h2o
  3. An independent subset of 293 cases, for testing the trained model

library(caret)
set.seed(123)

idTrain=createDataPartition(y=df$CMC,p=0.6,list=FALSE)
trainset=df[idTrain,]
remain=df[-idTrain,]
idtest=createDataPartition(y=remain$CMC,p=0.5,list=FALSE)
validset=remain[idtest,]
testset=remain[-idtest,]
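createDataPartition() samples within each level of the outcome, so that the class proportions are preserved in every subset. The base-R sketch below illustrates that idea; it is not caret's source code, and rounding up within each class is an assumption of the sketch, although with the class counts reported by Hmisc::describe() above (629, 333 and 511) it reproduces the subset sizes 885, 295 and 293. The function and variable names are hypothetical.

```r
# Stratified split sketch: sample a proportion p of the row indices within each
# class of the outcome y, so that every subset keeps the original class mix.
stratified_split <- function(y, p) {
  idx <- lapply(split(seq_along(y), y), function(i) {
    i[sample.int(length(i), ceiling(p * length(i)))]  # rounds up within each class
  })
  sort(unname(unlist(idx)))
}

# Usage with the CMC class counts reported by Hmisc::describe()
set.seed(123)
cmc <- factor(rep(c("None", "Longterm", "Shortterm"), times = c(629, 333, 511)),
              levels = c("None", "Longterm", "Shortterm"))
id_train <- stratified_split(cmc, p = 0.6)
remaining <- cmc[-id_train]
id_valid <- stratified_split(remaining, p = 0.5)
c(train = length(id_train), valid = length(id_valid),
  test = length(remaining) - length(id_valid))
# -> train 885, valid 295, test 293
```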

Our next step is to explore the potential contribution of the 9 features, using the information gain filter of the mlr package:

library(mlr)

task=makeClassifTask(id = "CMC", data=trainset, target = "CMC")

generateFilterValuesData(task,method="information.gain")%>%.$data%>%ggplot(aes(x=reorder(name,information.gain),y=information.gain,fill=reorder(name,-information.gain)))+geom_bar(stat="identity",color="black",show.legend=F)+scale_x_discrete("Features")+coord_flip()+scale_fill_brewer(palette="Reds",direction=-1)

The information gain results suggest that all features may contribute, to varying degrees, to the CMC classification. The number of children, the wife's education and the wife's age were the most important features, whereas media exposure, religion and working status were the least important. Although such a filter-based exploration is not always accurate, its result might help us filter out irrelevant features.
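The information gain of a feature X for the class Y is IG(Y; X) = H(Y) - H(Y | X), where H is the Shannon entropy: it measures how much the uncertainty about the class decreases once X is known. The sketch below computes it in bits for a categorical feature. It is an illustration only, not the filter that mlr calls internally, which may use a different logarithm base; numeric features such as WifeAge would also need to be discretised first. The function names are hypothetical.

```r
# Information gain sketch for a categorical feature x and class label y (in bits)
entropy <- function(y) {
  p <- table(y) / length(y)
  p <- p[p > 0]               # empty classes contribute 0 * log(0) = 0
  -sum(p * log2(p))
}

info_gain <- function(x, y) {
  groups <- split(y, x, drop = TRUE)
  # H(Y | X): entropy of y within each level of x, weighted by the level's share
  cond_entropy <- sum(vapply(groups,
                             function(g) length(g) / length(y) * entropy(g),
                             numeric(1)))
  entropy(y) - cond_entropy
}

# Toy usage: x separates y perfectly (gain = 1 bit) vs. x unrelated to y (gain = 0)
y <- c("a", "a", "b", "b")
info_gain(c(1, 1, 2, 2), y)  # 1
info_gain(c(1, 2, 1, 2), y)  # 0
```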

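As a baseline, we first fit a single CART decision tree with caret, using 10-fold cross-validation repeated 10 times and multiclass summary metrics:

Control=caret::trainControl(method= "repeatedcv",number=10,repeats=10,classProbs=TRUE,summaryFunction=caret::multiClassSummary)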

library(rattle)
library(rpart.plot)
library(partykit)
library(party)

set.seed(123)
cart=caret::train(CMC~.,data=trainset,method="rpart",trControl=Control)

fancyRpartPlot(cart$finalModel,palettes="Reds")

predcart=predict(cart,newdata=testset)
confusionMatrix(predcart,reference=testset$CMC,mode="everything")
## Confusion Matrix and Statistics
## 
##            Reference
## Prediction  None Longterm Shortterm
##   None        81       14        29
##   Longterm    12       28        19
##   Shortterm   32       24        54
## 
## Overall Statistics
##                                           
##                Accuracy : 0.5563          
##                  95% CI : (0.4974, 0.6141)
##     No Information Rate : 0.4266          
##     P-Value [Acc > NIR] : 5.423e-06       
##                                           
##                   Kappa : 0.3104          
##  Mcnemar's Test P-Value : 0.8296          
## 
## Statistics by Class:
## 
##                      Class: None Class: Longterm Class: Shortterm
## Sensitivity               0.6480         0.42424           0.5294
## Specificity               0.7440         0.86344           0.7068
## Pos Pred Value            0.6532         0.47458           0.4909
## Neg Pred Value            0.7396         0.83761           0.7377
## Precision                 0.6532         0.47458           0.4909
## Recall                    0.6480         0.42424           0.5294
## F1                        0.6506         0.44800           0.5094
## Prevalence                0.4266         0.22526           0.3481
## Detection Rate            0.2765         0.09556           0.1843
## Detection Prevalence      0.4232         0.20137           0.3754
## Balanced Accuracy         0.6960         0.64384           0.6181

As expected, the basic decision tree used only 3 features (the wife's education, the number of children and the wife's age) to predict the 3 classes of CMC. The model's performance is weak: the balanced accuracy per class was only 0.62 to 0.70, the Kappa coefficient was low (0.31) and the F1 scores ranged from 0.45 to 0.65. This confirms that a single CART decision tree is a weak learner for this problem.
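The per-class metrics quoted above can be recomputed by hand. This base-R sketch rebuilds the CART confusion matrix printed above (rows = prediction, columns = reference) and derives sensitivity, specificity, precision, F1, balanced accuracy, accuracy and Cohen's kappa from their textbook definitions; the results agree with the caret output to 4 decimals. The variable names are our own.

```r
# CART confusion matrix from the caret output above (filled column by column)
cm <- matrix(c(81, 12, 32,     # Reference = None
               14, 28, 24,     # Reference = Longterm
               29, 19, 54),    # Reference = Shortterm
             nrow = 3,
             dimnames = list(Prediction = c("None", "Longterm", "Shortterm"),
                             Reference  = c("None", "Longterm", "Shortterm")))

n  <- sum(cm)
tp <- diag(cm)
sensitivity <- tp / colSums(cm)                    # recall
precision   <- tp / rowSums(cm)                    # positive predictive value
specificity <- (n - rowSums(cm) - colSums(cm) + tp) / (n - colSums(cm))
f1          <- 2 * precision * sensitivity / (precision + sensitivity)
balanced_accuracy <- (sensitivity + specificity) / 2

accuracy <- sum(tp) / n
expected <- sum(rowSums(cm) * colSums(cm)) / n^2   # agreement expected by chance
kappa    <- (accuracy - expected) / (1 - expected) # Cohen's kappa

round(rbind(sensitivity, specificity, precision, f1, balanced_accuracy), 4)
round(c(accuracy = accuracy, kappa = kappa), 4)    # 0.5563 and 0.3104, as in caret
```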

We hope that boosting can improve the classification accuracy in our next step. We will use the GBM algorithm implemented in the h2o package. The first step is to initialise h2o in R:

library(h2o)

h2o.init(nthreads = -1,max_mem_size ="4g")
##  Connection successful!
## 
## R is connected to the H2O cluster: 
##     H2O cluster uptime:         1 hours 53 minutes 
##     H2O cluster version:        3.10.3.6 
##     H2O cluster version age:    2 months and 11 days  
##     H2O cluster name:           H2O_started_from_R_Admin_bhj778 
##     H2O cluster total nodes:    1 
##     H2O cluster total memory:   3.39 GB 
##     H2O cluster total cores:    4 
##     H2O cluster allowed cores:  4 
##     H2O cluster healthy:        TRUE 
##     H2O Connection ip:          localhost 
##     H2O Connection port:        54321 
##     H2O Connection proxy:       NA 
##     R Version:                  R version 3.3.1 (2016-06-21)

Then we train a GBM learner on the training set, with 10-fold stratified cross-validation, while the validation set is used for scoring and early stopping.

The following parameters are set for training a GBM learner in h2o (most of them are similar to the parameters of a random forest learner):

  1. ntrees: how many trees (elementary weak learners) to build.

  2. max_depth: how deep each tree is allowed to grow (the complexity of each tree).

  3. min_rows: the minimum number of rows (observations) required to make a leaf node. The default is 10; lower values might lead to overfitting.

  4. sample_rate: the fraction of rows randomly sampled (without replacement) to build each tree. Column sampling is controlled separately by col_sample_rate.

We also applied an early-stopping criterion based on the mean_per_class_error metric, and fixed the seed to make the training process reproducible.
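The early-stopping parameters can be read as a rule applied after each scoring round. The sketch below is a simplified base-R approximation of that rule as described in the h2o documentation (the moving average of the metric over stopping_rounds scoring events must keep improving by at least the relative stopping_tolerance); it is not h2o's source code, and the function name should_stop is hypothetical.

```r
# Simplified early-stopping rule in the spirit of h2o's stopping_rounds /
# stopping_tolerance (an approximation for illustration, not h2o's code).
# history: metric after each scoring round, lower is better
# (e.g. mean per-class error)
should_stop <- function(history, stopping_rounds = 5, stopping_tolerance = 0.001) {
  k <- stopping_rounds
  n <- length(history)
  if (n < 2 * k) return(FALSE)                      # not enough scoring rounds yet
  # trailing moving average of length k (first k - 1 values are dropped)
  ma <- as.numeric(stats::filter(history, rep(1 / k, k), sides = 1))[k:n]
  best_before <- min(head(ma, -k))                  # best average before the last k rounds
  best_recent <- min(tail(ma, k))                   # best average within the last k rounds
  # stop unless the recent best beats the earlier best by the relative tolerance
  best_recent > best_before * (1 - stopping_tolerance)
}

# Usage with two made-up metric histories
improving <- seq(0.50, 0.40, length.out = 20)                         # error keeps falling
plateau   <- c(seq(0.50, 0.45, length.out = 10), rep(0.45, 10))       # error stops improving
should_stop(improving)  # FALSE
should_stop(plateau)    # TRUE
```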

wdf=as.h2o(df)
wtrain=as.h2o(trainset)
wtest=as.h2o(testset)
wvalid=as.h2o(validset)

response="CMC"
features=setdiff(colnames(wtrain),response)

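gbmod1=h2o.gbm(x = features,
               y = response,
               training_frame = wtrain,nfolds=10,validation_frame = wvalid,  # 10-fold CV + held-out validation set
               categorical_encoding="Enum",
               fold_assignment = "Stratified",        # keep class proportions in each CV fold
               balance_classes = TRUE,                # over/under-sample to balance class counts
               ntrees =100, max_depth = 10,min_rows = 30,sample_rate=0.8,
               stopping_metric = "mean_per_class_error",
               stopping_tolerance = 0.001,            # minimal relative improvement required...
               stopping_rounds = 5,                   # ...over a moving average of 5 scoring rounds
               keep_cross_validation_fold_assignment = TRUE, 
               keep_cross_validation_predictions=TRUE,
               score_each_iteration = TRUE,           # score after every boosting iteration
               seed=123)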

Evaluating the GBM model's performance

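# Wrap the h2o GBM in an mlr learner so that mlr's performance measures and
# partial dependence tools can be used: train a placeholder mlr model on the task,
# then swap in the GBM trained above as its underlying model
learnergbmh2o=makeLearner("classif.h2o.gbm",id="gbmh2o",predict.type="prob")

mlrmod=mlr::train(learnergbmh2o,task)
mlrmod$learner.model<-gbmod1

predgbm=predict(gbmod1,newdata=wtest)%>%as.data.frame()

predmlr=predict(mlrmod,newdata=testset)
# Align the level order of h2o's predicted labels with the reference factor
confusionMatrix(factor(predgbm$predict,levels=levels(testset$CMC)),reference=testset$CMC,mode="everything")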
## Confusion Matrix and Statistics
## 
##            Reference
## Prediction  None Longterm Shortterm
##   None        85       18        39
##   Longterm    12       31         9
##   Shortterm   28       17        54
## 
## Overall Statistics
##                                           
##                Accuracy : 0.5802          
##                  95% CI : (0.5214, 0.6374)
##     No Information Rate : 0.4266          
##     P-Value [Acc > NIR] : 8.959e-08       
##                                           
##                   Kappa : 0.3396          
##  Mcnemar's Test P-Value : 0.1406          
## 
## Statistics by Class:
## 
##                      Class: None Class: Longterm Class: Shortterm
## Sensitivity               0.6800          0.4697           0.5294
## Specificity               0.6607          0.9075           0.7644
## Pos Pred Value            0.5986          0.5962           0.5455
## Neg Pred Value            0.7351          0.8548           0.7526
## Precision                 0.5986          0.5962           0.5455
## Recall                    0.6800          0.4697           0.5294
## F1                        0.6367          0.5254           0.5373
## Prevalence                0.4266          0.2253           0.3481
## Detection Rate            0.2901          0.1058           0.1843
## Detection Prevalence      0.4846          0.1775           0.3379
## Balanced Accuracy         0.6704          0.6886           0.6469
h2o.performance(gbmod1,wtest)
## H2OMultinomialMetrics: gbm
## 
## Test Set Metrics: 
## =====================
## 
## MSE: (Extract with `h2o.mse`) 0.3323592
## RMSE: (Extract with `h2o.rmse`) 0.576506
## Logloss: (Extract with `h2o.logloss`) 0.8894461
## Mean Per-Class Error: 0.4402971
## Confusion Matrix: Extract with `h2o.confusionMatrix(<model>, <data>)`)
## =========================================================================
## Confusion Matrix: vertical: actual; across: predicted
##           Longterm None Shortterm  Error        Rate
## Longterm        31   18        17 0.5303 =   35 / 66
## None            12   85        28 0.3200 =  40 / 125
## Shortterm        9   39        54 0.4706 =  48 / 102
## Totals          52  142        99 0.4198 = 123 / 293
## 
## Hit Ratio Table: Extract with `h2o.hit_ratio_table(<model>, <data>)`
## =======================================================================
## Top-3 Hit Ratios: 
##   k hit_ratio
## 1 1  0.580205
## 2 2  0.819113
## 3 3  1.000000
mets=list(multiclass.au1p,multiclass.au1u,multiclass.aunp,multiclass.aunu,multiclass.brier,wkappa)

mlr::performance(predmlr,mets)
##  multiclass.au1p  multiclass.au1u  multiclass.aunp  multiclass.aunu 
##        0.7451975        0.7508267        0.7462183        0.7526632 
## multiclass.brier           wkappa 
##        0.5308623        0.2975597

We evaluated the GBM model's performance in 3 ways, using functions from h2o, caret and mlr.

The confusion matrices from h2o and caret show that the ensemble GBM model improved the accuracy of CMC classification only slightly (0.580, versus 0.556 for the single CART tree). The balanced accuracy (BAC) was 0.65 for Shortterm, 0.69 for Longterm and 0.67 for None (no use). The per-class error rate remained high for the Longterm (0.53) and Shortterm (0.47) classes: the model had the lowest error on the None class (0.32) but misclassified about half of the Longterm and Shortterm cases. The mlr package provides some special metrics for evaluating the overall performance of a multiclass classifier. The multiclass AU1P, AU1U, AUNP and AUNU metrics generalise the area under the ROC curve to several classes (averaging one-vs-one or one-vs-rest AUCs, either unweighted or weighted by class prevalence); their values were around 0.75. The unweighted and weighted kappa coefficients were 0.34 and 0.30 respectively, indicating low agreement between the predicted and the true classes.
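The one-vs-rest measures can be illustrated with a small base-R sketch. Following the mlr documentation, AUNU averages the AUC of each class against all the others, and AUNP weights those AUCs by class prevalence (AU1U and AU1P do the same over pairs of classes and are not shown). The AUC itself is computed here with its Mann-Whitney form. The function names and toy data are hypothetical, and this is not mlr's implementation.

```r
# One-vs-rest multiclass AUC sketch (the idea behind multiclass.aunu / aunp)
auc_binary <- function(score, is_pos) {
  # Mann-Whitney form: probability that a random positive outranks a random negative
  r <- rank(score)                       # ties get average ranks (count as 1/2)
  n_pos <- sum(is_pos)
  n_neg <- sum(!is_pos)
  (sum(r[is_pos]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

multiclass_auc_1vr <- function(prob, truth, weighted = FALSE) {
  # prob: matrix with one column of predicted probabilities per level of truth
  classes <- levels(truth)
  aucs <- vapply(classes, function(cl) auc_binary(prob[, cl], truth == cl), numeric(1))
  w <- if (weighted) as.numeric(table(truth)) / length(truth) else rep(1 / length(classes), length(classes))
  sum(w * aucs)                          # unweighted = AUNU-like, weighted = AUNP-like
}

# Toy usage: probabilities that separate the 3 classes perfectly
truth <- factor(c("a", "a", "b", "b", "c", "c"))
prob <- rbind(c(0.8, 0.1, 0.1), c(0.7, 0.2, 0.1), c(0.1, 0.8, 0.1),
              c(0.2, 0.7, 0.1), c(0.1, 0.1, 0.8), c(0.1, 0.2, 0.7))
colnames(prob) <- levels(truth)
multiclass_auc_1vr(prob, truth)                   # 1
multiclass_auc_1vr(prob, truth, weighted = TRUE)  # 1
```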

In the next step, we attempt to tune the model's classification by changing the cut-off for each class. First, let us look at the relationship between the predicted probabilities and the true classes:

predh2o=predict(gbmod1,newdata=wtest)%>%as.data.frame()%>%as_tibble()%>%mutate(Truth=testset$CMC)
predh2o%>%ggplot(aes(x=Longterm,fill=Truth))+geom_density(alpha=0.6,color="black")+scale_fill_manual(values=c("#ccc7c7","#db3434","#590050"))+ggtitle("Longterm")

predh2o%>%ggplot(aes(x=None,fill=Truth))+geom_density(alpha=0.6,color="black")+scale_fill_manual(values=c("#ccc7c7","#db3434","#590050"))+ggtitle("None")

predh2o%>%ggplot(aes(x=Shortterm,fill=Truth))+geom_density(alpha=0.6,color="black")+scale_fill_manual(values=c("#ccc7c7","#db3434","#590050"))+ggtitle("Shortterm")

These density plots help explain the low per-class accuracy of our model: the predicted probabilities of the 3 true classes overlap widely.

We then perform a simple 4-step tuning of the per-class cut-offs. Besides the default rule (h2o's predicted label, which for a multinomial model is the class with the highest predicted probability; it is labelled as cut-off 0.5 in the plots below), we try 3 combinations of cut-offs: 0.3, 0.4 and 0.6 for the Longterm class, 0.6, 0.7 and 0.75 for the None class, and 0.6, 0.7 and 0.75 for the Shortterm class. Note that raising the cut-offs can generate uncertain cases that are not assigned to any class.
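The cut-off rules used below can be summarised as one small function. This base-R sketch mirrors the case_when() calls: classes are checked in a fixed order, the first class whose predicted probability exceeds its cut-off wins, and cases that pass no cut-off stay NA (uncertain). It also reports the two quantities being traded off, coverage and accuracy among the classified cases. The function name and the toy probabilities are hypothetical, not model output; the cut-offs are those of the final rule chosen later in this study.

```r
# Cut-off rule sketch: first class (in the given order) whose probability exceeds
# its cut-off wins; rows passing no cut-off stay NA ("uncertain").
apply_cutoffs <- function(prob, cutoffs) {
  pred <- rep(NA_character_, nrow(prob))
  for (cl in rev(names(cutoffs))) {   # reverse order, so earlier classes overwrite later ones
    pred[prob[, cl] > cutoffs[[cl]]] <- cl
  }
  pred
}

# Toy predicted probabilities (hypothetical); columns as in the h2o prediction frame
prob <- rbind(c(0.45, 0.50, 0.05),
              c(0.10, 0.65, 0.25),
              c(0.10, 0.20, 0.70),
              c(0.05, 0.15, 0.80))
colnames(prob) <- c("Longterm", "None", "Shortterm")
truth <- c("Longterm", "None", "Shortterm", "None")

pred <- apply_cutoffs(prob, c(Longterm = 0.4, None = 0.6, Shortterm = 0.7))
pred                                  # "Longterm" "None" NA "Shortterm"
covered  <- !is.na(pred)
coverage <- mean(covered)             # share of cases that receive a class: 0.75
accuracy <- mean(pred[covered] == truth[covered])  # accuracy among classified cases
```

Raising a cut-off shrinks coverage (row 3 is left unclassified because 0.70 does not exceed 0.7) while it can raise the accuracy among the classified cases, which is the trade-off explored below.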

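# The h2o labels may use a different level order (the h2o output above lists the classes alphabetically); align them with the reference
conf0=confusionMatrix(factor(predh2o$predict,levels=levels(predh2o$Truth)),reference=predh2o$Truth,mode="everything")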

confdf0=conf0$byClass%>%as_tibble()%>%mutate(Class=c("None","Longterm","Shortterm"),Sample=100*sum(conf0$table)/nrow(testset))
names(confdf0)=c("SEN","SPEC","PPV","NPV","PREC","RECALL","F1","PREV","DRATE","DPREV","BAC","CLASS","SAMPLE")

confdf0=confdf0%>%mutate(Cutoff=c(0.5,0.5,0.5))

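# case_when() checks the conditions in order: the first condition that holds decides the class,
# and cases matching none of them become NA (uncertain). NA predictions are left out of the
# confusion table, which is why SAMPLE is the percentage of test cases that received a class.
predh2o$Reclassification1=case_when(
  predh2o$Longterm > 0.3 ~ "Longterm", 
  predh2o$None > 0.6 ~  "None", 
  predh2o$Shortterm > 0.6 ~ "Shortterm")

conf1=confusionMatrix(factor(predh2o$Reclassification1,levels=levels(predh2o$Truth)),reference=predh2o$Truth,mode="everything")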

confdf1=conf1$byClass%>%as_tibble()%>%mutate(Class=c("None","Longterm","Shortterm"),Sample=100*sum(conf1$table)/nrow(testset))
names(confdf1)=c("SEN","SPEC","PPV","NPV","PREC","RECALL","F1","PREV","DRATE","DPREV","BAC","CLASS","SAMPLE")

confdf1=confdf1%>%mutate(Cutoff=c(0.6,0.3,0.6))


predh2o$Reclassification2=case_when(
  predh2o$Longterm > 0.4 ~ "Longterm", 
  predh2o$None > 0.7 ~  "None", 
  predh2o$Shortterm > 0.7 ~ "Shortterm")

conf2=confusionMatrix(factor(predh2o$Reclassification2,levels=levels(predh2o$Truth)),reference=predh2o$Truth,mode="everything")

confdf2=conf2$byClass%>%as_tibble()%>%mutate(Class=c("None","Longterm","Shortterm"),Sample=100*sum(conf2$table)/nrow(testset))
names(confdf2)=c("SEN","SPEC","PPV","NPV","PREC","RECALL","F1","PREV","DRATE","DPREV","BAC","CLASS","SAMPLE")

confdf2=confdf2%>%mutate(Cutoff=c(0.7,0.4,0.7))

predh2o$Reclassification3=case_when(
  predh2o$Longterm > 0.6 ~ "Longterm", 
  predh2o$None > 0.75 ~  "None", 
  predh2o$Shortterm > 0.75 ~ "Shortterm")

conf3=confusionMatrix(factor(predh2o$Reclassification3,levels=levels(predh2o$Truth)),reference=predh2o$Truth,mode="everything")

confdf3=conf3$byClass%>%as_tibble()%>%mutate(Class=c("None","Longterm","Shortterm"),Sample=100*sum(conf3$table)/nrow(testset))
names(confdf3)=c("SEN","SPEC","PPV","NPV","PREC","RECALL","F1","PREV","DRATE","DPREV","BAC","CLASS","SAMPLE")

confdf3=confdf3%>%mutate(Cutoff=c(0.75,0.6,0.75))
confdf=rbind(confdf0,confdf1,confdf2,confdf3)

ggplot(data=confdf,aes(x=Cutoff,y=SAMPLE,fill=as.factor(Cutoff)))+coord_flip()+geom_bar(stat="identity")+facet_wrap(~CLASS,ncol=1,scales="free")

confdf%>%gather(F1,RECALL,BAC,key="Metric",value="Score")%>%ggplot(aes(x=Cutoff,y=Score,color=CLASS,fill=CLASS))+geom_line(size=1)+geom_point(shape=21,aes(fill=CLASS),color="black",size=3)+facet_grid(Metric~CLASS,scales="free_x")

The results indicate that only the default rule classifies the whole test set (100% of the sample). The proportion of classified cases gradually decreased as we raised the cut-offs from 0.5 to 0.7 or 0.75, whereas the accuracy among the classified cases improved. The best accuracy was obtained with cut-offs of 0.4 for Longterm, 0.75 for None and 0.75 for Shortterm. The cut-offs should therefore be combined carefully, balancing the number of unclassified cases against the accuracy of the classified ones.

Finally, we construct a new classification rule with cut-offs of 0.4, 0.6 and 0.7 for the Longterm, None and Shortterm classes, respectively:

predh2o$Reclassification=case_when(
  predh2o$Longterm > 0.4 ~ "Longterm", 
  predh2o$None > 0.6 ~  "None", 
  predh2o$Shortterm > 0.7~ "Shortterm")

confusionMatrix(factor(predh2o$Reclassification,levels=levels(predh2o$Truth)),reference=predh2o$Truth,mode="everything")

This classification rule correctly classifies most of the women who do not use any contraceptive method and those who chose a long-term method. However, the model still fails to make a prediction in most cases.

predh2o=predh2o%>%mutate(Accuracy= case_when(.$Reclassification == .$Truth ~ "Yes", 
                                             .$Reclassification != .$Truth ~ "No"))

predh2o%>%ggplot(aes(x=Truth,fill=Accuracy))+geom_bar(position="fill",color="black")+coord_flip()

One might ask whether an interpretive study is possible, i.e. how to extract information on the relationship between the socio-economic parameters and the CMC classification. In the final part of this document, we show that such an interpretive analysis is possible using the partial dependence functions in the mlr package.

pd1=generatePartialDependenceData(mlrmod,task,features="NumbChild",gridsize = 30L)

pd2=generatePartialDependenceData(mlrmod,task,features="WifeAge",gridsize = 30L)

pd3=generatePartialDependenceData(mlrmod,task,features="WifeEdu",gridsize = 30L)

pd4=generatePartialDependenceData(mlrmod,task,features="SLI",gridsize = 30L)

As the wife's age, the wife's education and the number of children were the 3 most important features, we present below their marginalised effects (together with that of the standard-of-living index, SLI) on the predicted probability of each contraceptive method choice.

pd1%>%plotPartialDependence(.)+facet_wrap(~Class)

pd2%>%plotPartialDependence(.)+facet_wrap(~Class)

pd3%>%plotPartialDependence(.)+facet_wrap(~Class)

pd4%>%plotPartialDependence(.)+facet_wrap(~Class)
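The quantity plotted above is, by definition, a partial dependence: for each grid value v of a feature, the feature is set to v in every row, the model predicts, and the predictions are averaged over the data, which marginalises out the other features. The base-R sketch below illustrates that definition with 30 grid points (echoing gridsize = 30L); it is not mlr's implementation, the function name is hypothetical, and it uses a linear model so that the result is easy to check. For a classifier, predict_fun would return the predicted probability of one class.

```r
# Partial dependence sketch (Friedman's definition): average prediction over the
# data when the chosen feature is forced to each value of a grid.
partial_dependence <- function(model, data, feature, grid,
                               predict_fun = function(m, d) predict(m, newdata = d)) {
  yhat <- vapply(grid, function(v) {
    d <- data
    d[[feature]] <- v                  # same feature value for all rows
    mean(predict_fun(model, d))        # marginalise over the other features
  }, numeric(1))
  data.frame(value = grid, yhat = yhat)
}

# Toy usage with a linear model, whose partial dependence is a straight line
set.seed(42)
d <- data.frame(x1 = rnorm(100), x2 = rnorm(100))
d$y <- 2 * d$x1 + d$x2 + rnorm(100)
fit <- lm(y ~ x1 + x2, data = d)
pd <- partial_dependence(fit, d, "x1", grid = seq(-2, 2, length.out = 30))
head(pd)
```

For a linear model the slope of the partial dependence equals the fitted coefficient of x1; for a GBM the curve can bend, which is what makes the plots above informative.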

Conclusion

Although our GBM model did not substantially improve the prediction accuracy in this multiclass problem, we have learnt several machine learning techniques in the present study, including GBM training in h2o, cut-off tuning and interpretive exploration of a GBM model.

See you in the next tutorial, and thanks for joining us.

END