# party, ggparty
library(party)
library(partykit)
library(rpart)
library(rpart.plot)
# Load the WeatherPlay example data shipped with partykit and show it.
data("WeatherPlay", package = "partykit")
WeatherPlay
## outlook temperature humidity windy play
## 1 sunny 85 85 false no
## 2 sunny 80 90 true no
## 3 overcast 83 86 false yes
## 4 rainy 70 96 false yes
## 5 rainy 68 80 false yes
## 6 rainy 65 70 true no
## 7 overcast 64 65 true yes
## 8 sunny 72 95 false no
## 9 sunny 69 70 false yes
## 10 rainy 75 80 false yes
## 11 sunny 75 70 true yes
## 12 overcast 72 90 true yes
## 13 overcast 81 75 false yes
## 14 rainy 71 91 true no

# Build a tree for WeatherPlay by hand (the data is already loaded above, so
# it is not loaded a second time). The first argument of partysplit() is the
# column position of the split variable: 1 = outlook, 3 = humidity, 4 = windy.
sp_o <- partysplit(1L, index = 1:3)
sp_h <- partysplit(3L, breaks = 75)
sp_w <- partysplit(4L, index = 1:2)

# Node structure: the root (1) splits on outlook into three kids; kid 2 splits
# on humidity, kid 5 is terminal, kid 6 splits on windy.
pn <- partynode(1L, split = sp_o, kids = list(
  partynode(2L, split = sp_h, kids = list(
    partynode(3L, info = "yes"),
    partynode(4L, info = "no")
  )),
  partynode(5L, info = "yes"),
  partynode(6L, split = sp_w, kids = list(
    partynode(7L, info = "yes"),
    partynode(8L, info = "no")
  ))
))

# Combine the node structure with the data into a party object.
py <- party(pn, WeatherPlay)
# Fix the RNG seed so the random split below is reproducible.
set.seed(1234)
# Draw a label (1 or 2) for every row of iris with probabilities 0.7 / 0.3;
# rows labelled 1 form the training set, rows labelled 2 the test set.
ind <- sample(x = 2, size = nrow(iris), replace = TRUE, prob = c(0.7, 0.3))
trainData <- iris[which(ind == 1), , drop = FALSE]
testData <- iris[which(ind == 2), , drop = FALSE]
# Model: predict Species from the four sepal/petal measurements.
myFormula <- Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width
# Both party and partykit are attached and both export ctree(); qualify the
# call so the fit does not depend on which package was attached last. The
# printed tree further down is in partykit's output format.
iris_ctree <- partykit::ctree(myFormula, data = trainData)

# Cross-tabulate predicted (rows) against observed (columns) species on the
# training data.
train_predict <- predict(iris_ctree, newdata = trainData, type = "response")
table(train_predict, trainData$Species)
##
## train_predict setosa versicolor virginica
## setosa 40 0 0
## versicolor 0 37 3
## virginica 0 1 31
# Share of misclassified training rows, in percent.
mean(train_predict != trainData$Species) * 100
## [1] 3.571429
# Same evaluation on the test data.
test_predict <- predict(iris_ctree, newdata = testData, type = "response")
table(test_predict, testData$Species)
##
## test_predict setosa versicolor virginica
## setosa 10 0 0
## versicolor 0 12 2
## virginica 0 0 14
mean(test_predict != testData$Species) * 100
## [1] 5.263158
# Print the fitted tree; the listing that follows is its output.
print(iris_ctree)
##
## Model formula:
## Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width
##
## Fitted party:
## [1] root
## | [2] Petal.Length <= 1.9: setosa (n = 40, err = 0.0%)
## | [3] Petal.Length > 1.9
## | | [4] Petal.Width <= 1.7
## | | | [5] Petal.Length <= 4.4: versicolor (n = 21, err = 0.0%)
## | | | [6] Petal.Length > 4.4: versicolor (n = 19, err = 15.8%)
## | | [7] Petal.Width > 1.7: virginica (n = 32, err = 3.1%)
##
## Number of inner nodes: 3
## Number of terminal nodes: 4

# Plot the fitted conditional inference tree using the "simple" display type.
plot(iris_ctree, type = "simple")

# Fit an rpart tree with the same formula and training data, then plot it
# with rpart.plot.
model.rpart <- rpart(formula = myFormula, data = trainData)
rpart.plot(model.rpart)

library(ggparty)
# Root node that splits on outlook (sp_o) and has three terminal kids with
# ids 2, 3 and 4.
n1 <- partynode(
  id = 1L,
  split = sp_o,
  kids = lapply(2:4, function(kid_id) partynode(kid_id))
)

# Build a party object that carries, per observation, the fitted node id and
# the observed response, and convert it to a constparty in one step.
t2 <- as.constparty(
  party(
    n1,
    data = WeatherPlay,
    fitted = data.frame(
      "(fitted)" = fitted_node(n1, data = WeatherPlay),
      "(response)" = WeatherPlay[["play"]],
      check.names = FALSE
    ),
    terms = terms(play ~ ., data = WeatherPlay)
  )
)

# Plain ggplot of the data held by the subtree at node 2: a single bar filled
# by the proportions of play.
ggplot(t2[2]$data) +
  geom_bar(aes(x = "", fill = play), position = "fill") +
  labs(x = "play")

# Tree plot of t2 with edges, edge labels and split-variable labels, plus a
# bar plot of play drawn by geom_node_plot().
ggparty(t2) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  # gglist is a list of all ggplot components to draw for each node plot
  # (default: terminal nodes)
  geom_node_plot(
    gglist = list(
      geom_bar(aes(x = "", fill = play), position = "fill"),
      labs(x = "play")
    )
  )

# Load the TeachingRatings data from the AER package and keep only the rows
# where credits is "more".
data("TeachingRatings", package = "AER")
tr <- TeachingRatings[which(TeachingRatings$credits == "more"), , drop = FALSE]
# Show the first six rows of the full (unfiltered) data.
head(TeachingRatings, n = 6L)
## minority age gender credits beauty eval division native tenure students
## 1 yes 36 female more 0.2899157 4.3 upper yes yes 24
## 2 no 59 male more -0.7377322 4.5 upper yes yes 17
## 3 no 51 male more -0.5719836 3.7 upper yes yes 55
## 4 no 40 female more -0.6779634 4.3 upper yes yes 40
## 5 no 31 female more 1.5097940 4.4 upper yes yes 42
## 6 no 62 male more 0.5885687 4.2 upper yes yes 182
## allstudents prof
## 1 43 1
## 2 20 2
## 3 55 3
## 4 46 4
## 5 48 5
## 6 282 6

# Linear model tree: eval is regressed on beauty within each node, and the
# variables to the right of `|` are the candidate partitioning variables.
# Observations are weighted by students, with caseweights = FALSE.
tr_tree <- lmtree(
  eval ~ beauty | minority + age + gender + division + native + tenure,
  data = tr,
  weights = students,
  caseweights = FALSE
)
tr_tree
## Linear model tree
##
## Model formula:
## eval ~ beauty | minority + age + gender + division + native +
## tenure
##
## Fitted party:
## [1] root
## | [2] gender in male
## | | [3] age <= 50: n = 113
## | | (Intercept) beauty
## | | 3.9967632 0.1291992
## | | [4] age > 50: n = 137
## | | (Intercept) beauty
## | | 4.0857450 0.5028092
## | [5] gender in female
## | | [6] age <= 40: n = 69
## | | (Intercept) beauty
## | | 4.013707 0.122212
## | | [7] age > 40
## | | | [8] division in upper: n = 81
## | | | (Intercept) beauty
## | | | 3.7752097 -0.1975861
## | | | [9] division in lower: n = 36
## | | | (Intercept) beauty
## | | | 3.5899744 0.4032684
##
## Number of inner nodes: 4
## Number of terminal nodes: 5
## Number of parameters per node: 2
## Objective function (residual sum of squares): 2751.654
# Plot the model tree; each node plot is a scatter plot of eval against
# beauty (coloured by tenure, shaped by minority) with the model prediction
# drawn on top.
ggparty(tr_tree) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(
    gglist = list(
      geom_point(
        aes(x = beauty, y = eval, col = tenure, shape = minority),
        alpha = 0.8
      ),
      theme_bw(base_size = 10)
    ),
    shared_axis_labels = TRUE,
    legend_separator = TRUE,
    # variable the prediction line is based on
    predict = "beauty",
    # graphical parameters passed to the geom_line of the predictions
    predict_gpar = list(col = "blue", size = 1)
  )

# Create a data frame with node ids, age density curves and a left/right
# marker relative to the node's age split.
# Since this data frame is supplied directly to a geom inside gglist, the
# number of observations per id does not matter, and only the ids used by the
# respective geom_node_plot() need data (2 and 5 in this case).
#
# Age split points of the two nodes, as shown in the printed tree
# (node 2: age <= 50 / > 50, node 5: age <= 40 / > 40).
age_cutoffs <- c("2" = 50, "5" = 40)
# One density() call per node (the estimate supplies both x and y), and a
# single rbind() over the per-node pieces instead of growing a data frame
# inside a loop.
dens_df <- do.call(rbind, lapply(c(2, 5), function(id) {
  dens <- density(tr_tree[id]$data$age)
  cutoff <- age_cutoffs[[as.character(id)]]
  data.frame(
    x_dens = dens$x,
    y_dens = dens$y,
    id = id,
    # grid points above the split belong to the right branch
    breaks = ifelse(dens$x > cutoff, "right", "left")
  )
}))
# Adjust the layout so that each node plot has enough space: reserve more room
# for the terminal plots and reposition nodes 1, 2, 5 and 7.
ggparty(
  tr_tree,
  terminal_space = 0.4,
  layout = data.frame(
    id = c(1, 2, 5, 7),
    x = c(0.35, 0.15, 0.7, 0.8),
    y = c(0.95, 0.6, 0.8, 0.55)
  )
) +
  # map the colour of edges to birth_order (order from left to right)
  geom_edge(
    aes(col = factor(birth_order)),
    size = 1.2,
    alpha = 1,
    # exclude the root so it doesn't count as its own colour
    ids = -1
  ) +
  # density plots for the age splits (nodes 2 and 5)
  geom_node_plot(
    ids = c(2, 5),
    gglist = list(
      # density outline, drawn from dens_df
      geom_line(
        data = dens_df,
        aes(x = x_dens, y = y_dens),
        show.legend = FALSE,
        alpha = 0.8
      ),
      # area under the curve, drawn from dens_df with fill mapped to breaks
      geom_ribbon(
        data = dens_df,
        aes(x = x_dens, ymin = 0, ymax = y_dens, fill = breaks),
        show.legend = FALSE,
        alpha = 0.8
      ),
      xlab("age"),
      theme_bw(),
      theme(axis.title.y = element_blank())
    ),
    size = 1.5,
    height = 0.5
  ) +
  # bar plot of gender at the root
  geom_node_plot(
    ids = 1,
    gglist = list(
      geom_bar(
        aes(x = gender, fill = gender),
        show.legend = FALSE,
        alpha = 0.8
      ),
      theme_bw(),
      theme(axis.title.y = element_blank())
    ),
    size = 1.5,
    height = 0.5
  ) +
  # bar plot of division for node 7
  geom_node_plot(
    ids = 7,
    gglist = list(
      geom_bar(
        aes(x = division, fill = division),
        show.legend = FALSE,
        alpha = 0.8
      ),
      theme_bw(),
      theme(axis.title.y = element_blank())
    ),
    size = 1.5,
    height = 0.5
  ) +
  # terminal node plots with the model predictions
  geom_node_plot(
    gglist = list(
      geom_point(
        aes(x = beauty, y = eval, col = tenure, shape = minority),
        alpha = 0.8
      ),
      theme_bw(base_size = 10),
      scale_color_discrete(h.start = 100)
    ),
    shared_axis_labels = TRUE,
    legend_separator = TRUE,
    predict = "beauty",
    predict_gpar = list(col = "blue", size = 1.1)
  ) +
  # drop all legends at the top level since the plot is self-explanatory
  theme(legend.position = "none")
