This notebook complements my blog post, which is also available here.
With release 3.22.0.1, H2O-3 (a.k.a. O/S H2O or simply H2O) added one more algorithm to its family of tree-based algorithms (which already included DRF, GBM, and XGBoost): Isolation Forest (random forest for unsupervised anomaly detection). There was no simple way to visualize H2O trees except the clunky (albeit reliable) method of creating a MOJO object and running a combination of Java and dot commands.
That changed in 3.22.0.1 too, with the introduction of a unified Tree API that works with any of the tree-based algorithms above. Data scientists can now use powerful visualization tools in R (or Python) without producing intermediate artifacts like a MOJO and running external utilities. Please read this article by Pavel Pscheidl, who did a superb job of explaining the H2O Tree API and its S4 classes in R, before coming back to take it a step further and visualize trees.
Whether you are still here or came back after reading Pavel’s excellent post, let’s set the goal straight: create a single decision tree model in H2O and visualize its tree graph. With H2O there is always a choice between Python and R; the choice of R here will become clear when discussing its graphical and analytical capabilities later.
CART models operate on labeled data (classification and regression) and offer arguably unmatched model interpretability through analysis of the tree graph. In data science there is never a single way to solve a given problem, so let’s define an end-to-end logical workflow from “raw” data to a visualized decision tree:
Figure 1. Workflow of tasks in this post
One may argue that the split of steps between H2O and R could be different, but let’s follow the outlined plan for this post. The next diagram adds implementation details:
Figure 2. Workflow of tasks in this post with implementation details
Discussion of this workflow continues for the rest of this post.
The famous Titanic dataset contains information about the fate of passengers of the RMS Titanic that sank after colliding with an iceberg. It regularly serves as toy data for blog exercises like this.
The H2O public S3 bucket holds the Titanic dataset, and the package data.table makes loading it into R a fast one-liner:
library(data.table)
titanicDT = fread("https://s3.amazonaws.com/h2o-public-test-data/smalldata/gbm_test/titanic.csv")
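Passenger features from the Titanic dataset are discussed at length online, e.g. see Predicting the Survival of Titanic Passengers and Predicting Titanic Survival using Five Algorithms. To summarize, the following features were selected and engineered for the decision tree model: pclass, sex, age, sibsp, parch, fare and embarked from the original data, plus the engineered cabin_type, family_size, family_type and title.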
The data load and data munging steps above are implemented using the R package data.table. First, create new features:
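# Titles mapping: honorifics found in passenger names -> title groups
# ("Countess" rather than "the Countess": the name split below separates "the" from "Countess")
TITLES = data.frame(from=c("Capt", "Col", "Major", "Jonkheer", "Don", "Sir", "Dr", "Rev", "Countess",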
"Mme", "Mlle", "Ms", "Mr", "Mrs", "Miss", "Master", "Lady"),
to = c("Officer", "Officer", "Officer", "Royalty", "Royalty", "Royalty", "Officer", "Officer", "Royalty",
"Mrs", "Miss", "Mrs", "Mr", "Mrs", "Miss", "Master", "Royalty"),
stringsAsFactors = FALSE)
# Create features
titanicDT[,
c("sex",
"embarked",
"survived",
"pclass",
"cabin_type",
"family_size",
"family_type",
"title") := list(
factor(sex, labels = c("Female","Male")),
factor(embarked, labels = c("", "Cherbourg","Queenstown","Southampton")),
factor(-survived, labels = c('Yes','No')),
factor(pclass, labels = c("Class 1","Class 2","Class 3")),
factor(substring(cabin, 1, 1), exclude = ""),
sibsp + parch,
as.factor(ifelse(sibsp + parch <= 1, "SINGLE", ifelse(sibsp + parch <= 3, "SMALL", "LARGE"))),
as.factor(sapply(strsplit(name, "[\\., ]+"), function(x) {
words = trimws(x)
words = words[!words=="" ]
words = words[words %in% TITLES$from]
if (length(words) > 0)
title_word = words[[1]]
else
return(NA)
return(TITLES[title_word == TITLES$from, 'to'])
}))
)]
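For readers who want to check the title and family_type logic in isolation, here is a compact base-R restatement (a sketch, not the code used to build the model): a named lookup vector stands in for the TITLES data frame and cut() replaces the nested ifelse(). The toy names are invented, and extract_title() and family_type() are hypothetical helpers.

```r
# Title groups keyed by honorific (same groups as the TITLES mapping)
title_map <- c(Capt = "Officer", Col = "Officer", Major = "Officer", Dr = "Officer",
               Rev = "Officer", Jonkheer = "Royalty", Don = "Royalty", Sir = "Royalty",
               Countess = "Royalty", Lady = "Royalty", Mme = "Mrs", Ms = "Mrs",
               Mrs = "Mrs", Mlle = "Miss", Miss = "Miss", Mr = "Mr", Master = "Master")

# Return the title group of the first honorific found in a name, or NA
extract_title <- function(name) {
  words <- strsplit(name, "[., ]+")[[1]]           # split on dots, commas, spaces
  hits <- words[words %in% names(title_map)]
  if (length(hits) == 0) NA_character_ else unname(title_map[hits[1]])
}

# Same thresholds as the nested ifelse(): <= 1 SINGLE, 2-3 SMALL, > 3 LARGE
family_type <- function(family_size) {
  cut(family_size, breaks = c(-Inf, 1, 3, Inf), labels = c("SINGLE", "SMALL", "LARGE"))
}

toy_names <- c("Doe, Mrs. Jane", "Roe, Capt. Richard",
               "Poe, the Countess. of Somewhere", "Moe, Mr. John")
vapply(toy_names, extract_title, character(1), USE.NAMES = FALSE)
family_type(c(0, 1, 2, 3, 4))
```

cut() with right-closed intervals reproduces the two ifelse() thresholds in a single call.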
Then impute missing values:
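# Impute missing age and fare with their means within each survived/sex/embarked group
# (note: grouping by the label survived lets the outcome inform these predictors)
titanicDT[, c("age","fare") :=
list(ifelse(is.na(age), mean(age, na.rm=TRUE), age),
ifelse(is.na(fare), mean(fare, na.rm=TRUE), fare)),
by = c("survived","sex","embarked")]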
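The imputation above is a group-mean fill. A minimal base-R sketch of the same idea uses ave(), which applies a function within groups; impute_group_mean() is a hypothetical helper and the toy ages are made up. The example groups by sex only; grouping by survived, as the code above does, lets the label shape the imputed predictors.

```r
# Replace NAs in x with the mean of x within the groups defined by `...`
impute_group_mean <- function(x, ...) {
  ave(x, ..., FUN = function(v) {
    v[is.na(v)] <- mean(v, na.rm = TRUE)   # fill only the missing entries
    v
  })
}

age <- c(22, NA, 38, 35, NA, 54)
sex <- c("male", "male", "female", "female", "female", "male")
impute_group_mean(age, sex)
```

Inside data.table the helper could be applied as titanicDT[, age := impute_group_mean(age, sex, embarked)].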
Finally, create the Titanic dataset with the label (survived) and features that do not leak the outcome:
# create dataset for Titanic survived predictive model
response = "survived"
predictors = setdiff(colnames(titanicDT),
c(response,"name","ticket","cabin","boat","body","home.dest"))
titanicDT = titanicDT[, c(response, predictors), with=FALSE]
Creating models with H2O requires running a server process (remote or local) and a client (package h2o in R, available from CRAN), where the latter connects and sends commands to the former. The Tree API was introduced with release 3.22.0.1 (10/26/2018), but due to CRAN policies the h2o package usually lags several versions behind (at the time of this writing CRAN hosted version 3.20.0.8). There are two ways to work around this:
1. Use the CRAN package with strict_version_check=FALSE inside h2o.connect() to communicate with a newer version running on the server.
2. Install the latest version of package h2o available from the H2O repository, either to connect to a remote server or to both connect and run a server locally.
The Tree API is available only with the second option because it requires access to new classes and functions in the h2o package (remember, I asked you to read Pavel’s blog). The code below, from the official H2O download page, shows how to download and install the latest version of the package (H2O 3.22.0.1 Xia):
# The following two commands remove any previously installed H2O packages for R.
if ("package:h2o" %in% search()) { detach("package:h2o", unload=TRUE) }
if ("h2o" %in% rownames(installed.packages())) { remove.packages("h2o") }
# Next, we download packages that H2O depends on.
pkgs <- c("RCurl","jsonlite")
for (pkg in pkgs) {
if (! (pkg %in% rownames(installed.packages()))) { install.packages(pkg) }
}
# Now we download, install and initialize the H2O package for R.
install.packages("h2o", type="source", repos="http://h2o-release.s3.amazonaws.com/h2o/rel-xia/2/R")
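Because the Tree API needs the newer client package, it can help to check the installed version before calling h2o.getModelTree(). The hypothetical helper below compares a version string against 3.22.0.1, the release this post identifies as introducing the Tree API, using base R's package_version() comparison.

```r
# TRUE when an h2o version string is at least the release that added the Tree API
tree_api_available <- function(version) {
  package_version(version) >= "3.22.0.1"
}

tree_api_available("3.20.0.8")   # the CRAN version at the time of writing
tree_api_available("3.22.0.1")
# Check the installed client (requires the h2o package):
# tree_api_available(as.character(packageVersion("h2o")))
```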
Connect to H2O and load the Titanic dataset into its memory (as an H2O frame):
library(h2o)
h2o.init()
h2o.no_progress()
titanicHex = as.h2o(titanicDT)
While H2O offers no dedicated single decision tree algorithm, there are two approaches using its more general tree ensemble algorithms:
1. h2o.randomForest() with arguments ntrees = 1, mtries = number of features (determined dynamically at runtime), sample_rate = 1, min_rows = 1
2. h2o.gbm() with arguments ntrees = 1, min_rows = 1, sample_rate = 1, col_sample_rate = 1
Choosing the GBM option requires one less line of code (no need to calculate the number of features to set mtries), so it is used in this post. Otherwise both ways result in the same decision tree, and the steps below are fully reproducible using h2o.randomForest() instead of h2o.gbm().
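The two recipes above differ only in a few arguments, which a small hypothetical helper can make explicit. single_tree_args() is not part of h2o; the commented do.call() line shows how its result could be spliced into an h2o.gbm() call on a running cluster.

```r
# Argument lists that reduce H2O's ensemble algorithms to a single decision tree
single_tree_args <- function(algorithm = c("gbm", "drf"), predictors) {
  algorithm <- match.arg(algorithm)
  common <- list(ntrees = 1, min_rows = 1, sample_rate = 1)   # one tree, all rows
  if (algorithm == "gbm") {
    c(common, list(col_sample_rate = 1))            # all columns considered per split
  } else {
    c(common, list(mtries = length(predictors)))    # all features considered per split
  }
}

# Usage against a running H2O cluster (not executed here):
# titanic_1tree = do.call(h2o.gbm, c(list(x = predictors, y = response,
#                                         training_frame = titanicHex, max_depth = 5),
#                                    single_tree_args("gbm", predictors)))
```

Computing mtries from length(predictors) is the extra line the GBM route avoids.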
When building single decision tree models, maximum tree depth is the most important parameter to pick. Shallow trees tend to underfit by failing to capture important relationships in the data, producing similar trees despite varying training data (error due to high bias). On the other hand, trees grown too deep overfit by reacting to noise and slight changes in the data (error due to high variance). Tuning the H2O model parameter max_depth, which limits decision tree depth, aims at balancing the effects of bias and variance. In R, using H2O to split the data and tune the model, then visualizing the results with ggplot2 to look for the right value, unfolds like this:
1. Split the data into training and validation frames.
2. Run a grid search over a range of max_depth values.
3. Plot validation AUC against max_depth values to determine the “inflection point” where AUC growth stops or saturates (see plot below).
4. Build the final tree model with the chosen max_depth parameter.
The code below implements these steps:
# split into train and validation
splits = h2o.splitFrame(data = titanicHex, ratios = .8, seed = 123)
trainHex = splits[[1]]
validHex = splits[[2]]
# GBM hyperparameters: candidate maximum tree depths
gbm_params = list(max_depth = seq(2, 10))
# Train and validate a cartesian grid of GBMs
gbm_grid = h2o.grid("gbm", x = predictors, y = response,
grid_id = "gbm_grid_1tree8",
training_frame = trainHex,
validation_frame = validHex,
ntrees = 1, min_rows = 1, sample_rate = 1, col_sample_rate = 1,
learn_rate = .01, seed = 1111,
hyper_params = gbm_params)
gbm_gridperf = h2o.getGrid(grid_id = "gbm_grid_1tree8",
sort_by = "auc",
decreasing = TRUE)
and produces results that, when plotted, point to an inflection point for maximum tree depth at 5:
Figure 3. Visualization of AUC vs. maximum tree depth hyper-parameter trend extracted from the H2O grid object after running grid search in H2O. Marked inflection point indicates when increasing maximum tree depth no longer improves model performance on validation set
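Figure 3 locates the inflection point visually. A numeric counterpart could look like the hypothetical find_inflection_depth() below, which applies one simple rule (an assumption, not the method behind Figure 3): take the first depth where going one level deeper gains less than tol in validation AUC. The AUC values are made up for illustration.

```r
# First depth whose successor improves AUC by less than `tol` (assumed heuristic)
find_inflection_depth <- function(depth, auc, tol = 0.005) {
  ord <- order(depth)
  depth <- depth[ord]
  auc <- auc[ord]
  gains <- diff(auc)                 # AUC change from each depth to the next
  flat <- which(gains < tol)
  if (length(flat) == 0) max(depth) else depth[flat[1]]
}

# Made-up AUC values, not results from this post
depths <- 2:8
aucs <- c(0.80, 0.83, 0.85, 0.87, 0.871, 0.870, 0.869)
find_inflection_depth(depths, aucs)

# With the real grid (column names are an assumption, check your summary table):
# summ = as.data.frame(gbm_gridperf@summary_table)
# find_inflection_depth(as.numeric(summ$max_depth), as.numeric(summ$auc))
```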
As Figure 3 shows, the optimal decision tree depth is 5. The code below constructs a single decision tree model in H2O and then retrieves its tree representation with the Tree API function h2o.getModelTree(), which creates an instance of the S4 class H2OTree, assigned here to variable titanicH2oTree:
titanic_1tree = h2o.gbm(x = predictors, y = response,
training_frame = titanicHex,
ntrees = 1, min_rows = 1, sample_rate = 1, col_sample_rate = 1,
max_depth = 5,
# early stopping: stop when AUC fails to improve by at least 1% (tolerance 0.01)
# for 3 consecutive scoring events; with ntrees = 1 these settings have practically no effect
stopping_rounds = 3, stopping_tolerance = 0.01,
stopping_metric = "AUC",
seed = 1)
titanicH2oTree = h2o.getModelTree(model = titanic_1tree, tree_number = 1)
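As a quick sanity check on what h2o.getModelTree() returns, one might confirm the tree is no deeper than max_depth = 5. The hypothetical tree_depth() below computes depth from parallel left/right child-index vectors. It assumes H2OTree's left_children and right_children slots use 0-based node indices with -1 for a missing child; that convention is an assumption here, so check it against your h2o version.

```r
# Depth of a binary tree stored as parallel child-index vectors
tree_depth <- function(left, right, node = 0) {
  i <- node + 1                      # convert 0-based node index to R's 1-based index
  kids <- c(left[i], right[i])
  kids <- kids[kids >= 0]            # -1 marks a missing child (assumed convention)
  if (length(kids) == 0) return(0)
  1 + max(vapply(kids, function(k) tree_depth(left, right, k), numeric(1)))
}

# Toy tree: root 0 -> (1, 2), node 2 -> (3, 4)
left  <- c(1, -1, 3, -1, -1)
right <- c(2, -1, 4, -1, -1)
tree_depth(left, right)

# On the real tree (slot layout assumed):
# tree_depth(titanicH2oTree@left_children, titanicH2oTree@right_children) <= 5
```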
At this point all the action moves back inside R, with its unparalleled access to analytical and visualization tools. So before navigating and plotting a decision tree (the final goal of this post), let’s have a brief intro to networks in R.
R offers arguably the richest functionality when it comes to analyzing and visualizing network (graph, tree) objects. Before taking on the task of conquering it, spend some time with a couple of comprehensive articles describing the vast landscape of available tools and approaches: Static and dynamic network visualization with R by Katya Ognyanova and Introduction to Network Analysis with R by Jesse Sadler.
To summarize, there are two commonly used packages to manage and analyze networks in R: network (part of the statnet family) and igraph (a family in itself). Each package implements namesake classes to represent network structures, so there is significant overlap between the two and they mask each other’s functions. The preferred approach is to pick only one of the two: it appears that igraph is more common for general-purpose applications while network is preferred for social network and statistical analysis (my subjective assessment). And while researching these packages, do not forget about package intergraph, which seamlessly transforms objects between network and igraph classes. (This analysis stopped short of expanding into the universe of R packages hosted on Bioconductor.)
When it comes to visualizing networks, choices quickly proliferate. Both network and igraph offer graphical functions that use the R base plotting system, but it doesn’t stop there. The following packages specialize in advanced visualizations for one or both of the classes:
- ggraph: an implementation of the grammar of graphics for graphs and networks.
- ggnet2: available through the GGally package, a visualization function to plot network objects as ggplot2 objects. It accepts any object that can be coerced to the network class, including adjacency or incidence matrices, edge lists, or one-mode igraph network objects.
- ggnetwork: geometries to plot networks with ggplot2.
- visNetwork: provides an R interface to the ‘vis.js’ JavaScript charting library.
- DiagrammeR: graph/network visualization that builds graph/network structures using functions for stepwise addition and deletion of nodes and edges.
- networkD3: creates ‘D3’ ‘JavaScript’ network, tree, dendrogram, and Sankey graphs from R.
Finally, there is package data.tree, designed specifically to create and analyze trees in R. It fits the bill of representing and visualizing decision trees perfectly, so it became the tool of choice for this post. Still, visualizing H2O model trees could be fully reproduced with any of the network and visualization packages mentioned above.
In the last step, the decision tree of the GBM model moved from H2O cluster memory to an H2OTree object in R by means of the Tree API. Still, the H2OTree object is specific to H2O: it contains the necessary details about the decision tree, but not in a format understood by R packages such as data.tree.
To fill this gap, the function createDataTree(H2OTree) below traverses the tree and translates it from an H2OTree object into a data.tree one, accumulating information about decision tree splits and predictions in the tree’s node and edge attributes:
library(data.tree)
createDataTree <- function(h2oTree) {
h2oTreeRoot = h2oTree@root_node
dataTree = Node$new(h2oTreeRoot@split_feature)
dataTree$type = 'split'
addChildren(dataTree, h2oTreeRoot)
return(dataTree)
}
addChildren <- function(dtree, node) {
if(class(node)[1] != 'H2OSplitNode') return(TRUE)
feature = node@split_feature
id = node@id
na_direction = node@na_direction
if(is.na(node@threshold)) {
leftEdgeLabel = printValues(node@left_levels, na_direction=='LEFT', 4)
rightEdgeLabel = printValues(node@right_levels, na_direction=='RIGHT', 4)
}else {
leftEdgeLabel = paste0("< ", node@threshold, ifelse(na_direction=='LEFT', ', NA', ''))
rightEdgeLabel = paste0(">= ", node@threshold, ifelse(na_direction=='RIGHT', ', NA', ''))
}
left_node = node@left_child
right_node = node@right_child
if(class(left_node)[[1]] == 'H2OLeafNode')
leftLabel = paste("prediction:", left_node@prediction)
else
leftLabel = left_node@split_feature
if(class(right_node)[[1]] == 'H2OLeafNode')
rightLabel = paste("prediction:", right_node@prediction)
else
rightLabel = right_node@split_feature
if(leftLabel == rightLabel) {
leftLabel = paste(leftLabel, "(L)")
rightLabel = paste(rightLabel, "(R)")
}
dtreeLeft = dtree$AddChild(leftLabel)
dtreeLeft$edgeLabel = leftEdgeLabel
dtreeLeft$type = ifelse(class(left_node)[1] == 'H2OSplitNode', 'split', 'leaf')
dtreeRight = dtree$AddChild(rightLabel)
dtreeRight$edgeLabel = rightEdgeLabel
dtreeRight$type = ifelse(class(right_node)[1] == 'H2OSplitNode', 'split', 'leaf')
addChildren(dtreeLeft, left_node)
addChildren(dtreeRight, right_node)
return(FALSE)
}
printValues <- function(values, is_na_direction, n=4) {
l = length(values)
if(l == 0)
value_string = ifelse(is_na_direction, "NA", "")
else
value_string = paste0(paste0(values[1:min(n,l)], collapse = ', '),
ifelse(l > n, ",...", ""),
ifelse(is_na_direction, ", NA", ""))
return(value_string)
}
Note that had we decided to use the other R network classes, network or igraph, then, analogous to the function createDataTree(H2OTree), new functions createNetwork(H2OTree) or createIgraph(H2OTree) would have been written to translate H2OTree into the corresponding R representation.
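Whichever class is chosen, the core of such a translator is the same recursive walk that addChildren() performs. As a sketch under stated assumptions, the code below uses plain R lists as hypothetical stand-ins for H2OSplitNode and H2OLeafNode (numeric thresholds only, no NA handling) and collects a vertex table and an edge table keyed by node id. Those two tables are the shape igraph::graph_from_data_frame() expects, so a createIgraph() could be built around them.

```r
# Hypothetical stand-ins for H2OSplitNode / H2OLeafNode built from plain lists
split_node <- function(id, feature, threshold, left, right) {
  list(id = id, type = "split", feature = feature, threshold = threshold,
       left = left, right = right)
}
leaf_node <- function(id, prediction) list(id = id, type = "leaf", prediction = prediction)

# One row per node: unique id (vertex name), display label, node type
collect_vertices <- function(node) {
  label <- if (node$type == "leaf") paste("prediction:", node$prediction) else node$feature
  here <- data.frame(name = node$id, label = label, type = node$type,
                     stringsAsFactors = FALSE)
  if (node$type == "leaf") return(here)
  rbind(here, collect_vertices(node$left), collect_vertices(node$right))
}

# One row per parent -> child edge, labelled with the split condition
collect_edges <- function(node) {
  if (node$type == "leaf") {
    return(data.frame(from = character(0), to = character(0), label = character(0),
                      stringsAsFactors = FALSE))
  }
  here <- data.frame(from = c(node$id, node$id),
                     to = c(node$left$id, node$right$id),
                     label = c(paste0("< ", node$threshold), paste0(">= ", node$threshold)),
                     stringsAsFactors = FALSE)
  rbind(here, collect_edges(node$left), collect_edges(node$right))
}

toy <- split_node("0", "fare", 10.5,
                  leaf_node("1", 0.2),
                  split_node("2", "age", 6.5, leaf_node("3", 0.9), leaf_node("4", 0.4)))
vertices <- collect_vertices(toy)
edges <- collect_edges(toy)
# With igraph installed:
# g <- igraph::graph_from_data_frame(edges, directed = TRUE, vertices = vertices)
```

Keying vertices by id rather than by label sidesteps the duplicate-name problem that addChildren() handles with its "(L)"/"(R)" suffixes.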
Finally, everything is lined up and ready for the final step of plotting the decision tree: the H2O model tree is now a data.tree object, available for network analysis. Styling and plotting data.tree objects is built around the rich functionality of the DiagrammeR package. For anything that goes beyond simple plotting read the documentation here, but also remember that for plotting, data.tree accepts functions of a node (here GetEdgeLabel(node), GetNodeShape(node) and GetFontName(node), passed to SetEdgeStyle() and SetNodeStyle()) to customize the tree’s look and feel. The following code produces this moderately customized decision tree for our H2O model:
titanicDataTree = createDataTree(titanicH2oTree)
GetEdgeLabel <- function(node) {return (node$edgeLabel)}
GetNodeShape <- function(node) {switch(node$type, split = "diamond", leaf = "oval")}
GetFontName <- function(node) {switch(node$type, split = 'Palatino-bold', leaf = 'Palatino')}
SetEdgeStyle(titanicDataTree, fontname = 'Palatino-italic', label = GetEdgeLabel, labelfloat = TRUE,
fontsize = "26", fontcolor='royalblue4')
SetNodeStyle(titanicDataTree, fontname = GetFontName, shape = GetNodeShape,
fontsize = "26", fontcolor='royalblue4',
height="0.75", width="1")
SetGraphStyle(titanicDataTree, rankdir = "LR", dpi=70.)
plot(titanicDataTree, output = "graph")
Figure 4. H2O Decision Tree for Titanic Model Visualized with data.tree in R
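data.tree’s plot() builds on DiagrammeR, whose rendering is based on the Graphviz DOT language. To see roughly how node shapes and edge labels map onto DOT attributes, here is a dependency-free sketch that writes a tiny made-up tree as a DOT string. to_dot() and its table layout are hypothetical, and labels containing double quotes would need escaping.

```r
# Build a Graphviz DOT description from a vertex table and an edge table.
# Shapes mirror GetNodeShape(): diamonds for splits, ovals for leaves.
to_dot <- function(vertices, edges, rankdir = "LR") {
  shape <- ifelse(vertices$type == "split", "diamond", "oval")
  node_lines <- sprintf('  "%s" [label="%s", shape=%s];',
                        vertices$name, vertices$label, shape)
  edge_lines <- sprintf('  "%s" -> "%s" [label="%s"];',
                        edges$from, edges$to, edges$label)
  paste(c("digraph tree {", sprintf("  rankdir=%s;", rankdir),
          node_lines, edge_lines, "}"), collapse = "\n")
}

vertices <- data.frame(name = c("0", "1", "2"),
                       label = c("fare", "prediction: 0.2", "prediction: 0.7"),
                       type = c("split", "leaf", "leaf"),
                       stringsAsFactors = FALSE)
edges <- data.frame(from = c("0", "0"), to = c("1", "2"),
                    label = c("< 10.5", ">= 10.5"),
                    stringsAsFactors = FALSE)
dot <- to_dot(vertices, edges)
cat(dot)
```

The resulting string could be rendered with DiagrammeR::grViz(dot) or saved to a file for the dot command-line tool.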
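The same approach works for other binary classification problems. For example, the code below builds and plots a single decision tree predicting departure delays (IsDepDelayed) on the airlines dataset:
airlinesHex = h2o.importFile("http://s3.amazonaws.com/h2o-public-test-data/smalldata/airlines/allyears2k_headers.zip")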
y = "IsDepDelayed"
obvious_dataleakage_features = c("ArrDelay","DepDelay","Cancelled","CancellationCode","Diverted",
"CarrierDelay","WeatherDelay","NASDelay","SecurityDelay",
"LateAircraftDelay","IsArrDelayed","DepTime","CRSDepTime",
"ArrTime","CRSArrTime","ActualElapsedTime","AirTime",
"CRSElapsedTime")
x = setdiff(h2o.colnames(airlinesHex), c(y, obvious_dataleakage_features))
splits = h2o.splitFrame(data = airlinesHex, ratios = .8, seed = 123)
trainHex = splits[[1]]
validHex = splits[[2]]
# GBM hyperparameters: candidate maximum tree depths
gbm_params = list(max_depth = seq(2, 20))
# Train and validate a cartesian grid of GBMs
gbm_grid = h2o.grid("gbm", x = x, y = y,
grid_id = "gbm_grid_airline1",
training_frame = trainHex,
validation_frame = validHex,
ntrees = 1, min_rows = 1, sample_rate = 1, col_sample_rate = 1,
learn_rate = .01, seed = 1111,
hyper_params = gbm_params)
gbm_gridperf = h2o.getGrid(grid_id = "gbm_grid_airline1",
sort_by = "auc",
decreasing = TRUE)
optimal_depth = as.integer(gbm_gridperf@summary_table$max_depth[[1]])
airlines.gbm = h2o.gbm(x = x, y = y, training_frame = airlinesHex,
ntrees = 1, max_depth = 6, # or max_depth = optimal_depth to use the grid search result
learn_rate = 0.1, distribution = "bernoulli")
airlinesTree = h2o.getModelTree(model = airlines.gbm, tree_number = 1)
dtree = createDataTree(airlinesTree)
GetEdgeLabel <- function(node) {return (node$edgeLabel)}
GetNodeShape <- function(node) {switch(node$type, split = "diamond", leaf = "oval")}
SetEdgeStyle(dtree, fontname = 'Palatino', label = GetEdgeLabel, labelfloat = TRUE)
SetNodeStyle(dtree, fontname = 'Palatino', shape = GetNodeShape)
SetGraphStyle(dtree, rankdir = "LR", dpi=70.)
plot(dtree, output="graph")
Figure 5. H2O Decision Tree for Airline Delay Visualized with data.tree in R