# 1 party, ggparty

library(party)
library(partykit)
library(rpart)
library(rpart.plot)


# Load the WeatherPlay example data shipped with partykit and display it.
# The data only needs to be loaded once; it is not modified before the tree
# below is built.
data("WeatherPlay", package = "partykit")
WeatherPlay
##     outlook temperature humidity windy play
## 1     sunny          85       85 false   no
## 2     sunny          80       90  true   no
## 3  overcast          83       86 false  yes
## 4     rainy          70       96 false  yes
## 5     rainy          68       80 false  yes
## 6     rainy          65       70  true   no
## 7  overcast          64       65  true  yes
## 8     sunny          72       95 false   no
## 9     sunny          69       70 false  yes
## 10    rainy          75       80 false  yes
## 11    sunny          75       70  true  yes
## 12 overcast          72       90  true  yes
## 13 overcast          81       75 false  yes
## 14    rainy          71       91  true   no

# Hand-specified splits; the first argument is the column position in
# WeatherPlay (see the header row of the printout above).
sp_o <- partysplit(1L, index = 1:3)  # column 1 (outlook): three-way split
sp_h <- partysplit(3L, breaks = 75)  # column 3 (humidity): cut point 75
sp_w <- partysplit(4L, index = 1:2)  # column 4 (windy): two-way split

# Assemble the node structure by hand. Node ids are assigned explicitly
# (1 to 8); terminal nodes carry their prediction in `info`.
# NOTE(review): which outlook level maps to which kid depends on the factor
# level order of `outlook` - confirm with levels(WeatherPlay$outlook).
pn <- partynode(1L, split = sp_o, kids = list(
  # first kid of the root: split again on humidity
  partynode(2L, split = sp_h, kids = list(
    partynode(3L, info = "yes"),
    partynode(4L, info = "no"))),
  # second kid of the root: terminal node
  partynode(5L, info = "yes"),
  # third kid of the root: split again on windy
  partynode(6L, split = sp_w, kids = list(
    partynode(7L, info = "yes"),
    partynode(8L, info = "no")))))

# Combine the node structure with the data into a party object.
py <- party(pn, WeatherPlay)

# Fix the RNG state so the random partition below is reproducible.
set.seed(1234)

# Draw a group label (1 or 2) for every row of iris with probabilities
# 0.7 / 0.3; rows labelled 1 form the training set, rows labelled 2 the
# test set.
ind <- sample.int(
  2,
  size = nrow(iris),
  replace = TRUE,
  prob = c(0.7, 0.3)
)
in_train <- ind == 1
trainData <- iris[in_train, ]
testData <- iris[!in_train, ]


# Model specification: Species explained by the four flower measurements.
myFormula <- Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width

# Fit the tree on the training partition.
iris_ctree <- ctree(formula = myFormula, data = trainData)

# Training data: confusion table (rows = predicted class, columns = observed
# class) followed by the misclassification rate in percent.
train_predict <- predict(iris_ctree, newdata = trainData, type = "response")
table(train_predict, trainData$Species)
##              
## train_predict setosa versicolor virginica
##    setosa         40          0         0
##    versicolor      0         37         3
##    virginica       0          1        31
100 * mean(train_predict != trainData$Species)
## [1] 3.571429

# Held-out test data: same two summaries.
test_predict <- predict(iris_ctree, newdata = testData, type = "response")
table(test_predict, testData$Species)
##             
## test_predict setosa versicolor virginica
##   setosa         10          0         0
##   versicolor      0         12         2
##   virginica       0          0        14
100 * mean(test_predict != testData$Species)
## [1] 5.263158

# Text representation of the fitted tree.
print(iris_ctree)
## 
## Model formula:
## Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width
## 
## Fitted party:
## [1] root
## |   [2] Petal.Length <= 1.9: setosa (n = 40, err = 0.0%)
## |   [3] Petal.Length > 1.9
## |   |   [4] Petal.Width <= 1.7
## |   |   |   [5] Petal.Length <= 4.4: versicolor (n = 21, err = 0.0%)
## |   |   |   [6] Petal.Length > 4.4: versicolor (n = 19, err = 15.8%)
## |   |   [7] Petal.Width > 1.7: virginica (n = 32, err = 3.1%)
## 
## Number of inner nodes:    3
## Number of terminal nodes: 4

# Draw the tree twice: default display first, then the "simple" display.
plot(iris_ctree)

plot(iris_ctree, type = "simple")

# Fit an rpart tree with the same formula and training data, then draw it
# with rpart.plot for comparison with the ctree fit.
model.rpart <- rpart(formula = myFormula, data = trainData)
rpart.plot(x = model.rpart)

library(ggparty)

# Root node splitting on outlook (sp_o) with three terminal kids, ids 2 to 4.
n1 <- partynode(
  id = 1L,
  split = sp_o,
  kids = lapply(2L:4L, function(kid_id) partynode(kid_id))
)

# Per-observation bookkeeping: the node each row falls into and its observed
# response. check.names = FALSE keeps the parenthesised column names intact.
t2_fitted <- data.frame(
  "(fitted)" = fitted_node(n1, data = WeatherPlay),
  "(response)" = WeatherPlay$play,
  check.names = FALSE
)

# Build the party object with `play` as response and convert it to a
# constparty in one go.
t2 <- as.constparty(
  party(
    n1,
    data = WeatherPlay,
    fitted = t2_fitted,
    terms = terms(play ~ ., data = WeatherPlay)
  )
)


# Stand-alone version of a node plot: a single bar filled by `play`
# (position_fill scales it to proportions) for the data of node 2 of t2.
ggplot(data = t2[2]$data) +
  geom_bar(
    mapping = aes(x = "", fill = play),
    position = position_fill()
  ) +
  labs(x = "play")

# ggplot components drawn inside each node plot: a filled bar of `play`
# plus the x-axis label. They are handed to geom_node_plot() as a list via
# gglist and are plotted for each (default: terminal) node.
play_bar_layers <- list(
  geom_bar(
    mapping = aes(x = "", fill = play),
    position = position_fill()
  ),
  xlab("play")
)

# Tree skeleton (edges, edge labels, split variables) plus the node plots.
ggparty(t2) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(gglist = play_bar_layers)

# Load the TeachingRatings data set from the AER package.
data("TeachingRatings", package = "AER")

# Keep only the rows whose `credits` value is "more"; which() drops any
# NA comparison results, matching what subset() does.
tr <- TeachingRatings[which(TeachingRatings$credits == "more"), ]

# Preview the first rows of the full (unfiltered) data set.
head(TeachingRatings)
##   minority age gender credits     beauty eval division native tenure students
## 1      yes  36 female    more  0.2899157  4.3    upper    yes    yes       24
## 2       no  59   male    more -0.7377322  4.5    upper    yes    yes       17
## 3       no  51   male    more -0.5719836  3.7    upper    yes    yes       55
## 4       no  40 female    more -0.6779634  4.3    upper    yes    yes       40
## 5       no  31 female    more  1.5097940  4.4    upper    yes    yes       42
## 6       no  62   male    more  0.5885687  4.2    upper    yes    yes      182
##   allstudents prof
## 1          43    1
## 2          20    2
## 3          55    3
## 4          46    4
## 5          48    5
## 6         282    6
# Linear model tree: within each node `eval` is regressed on `beauty`; the
# variables to the right of `|` are the candidate partitioning variables.
# Observations are weighted by `students`.
# NOTE(review): caseweights = FALSE presumably marks the weights as
# proportionality weights rather than case counts - confirm in ?mob_control.
tr_tree <- lmtree(
  formula = eval ~ beauty | minority + age + gender + division + native +
    tenure,
  data = tr,
  weights = students,
  caseweights = FALSE
)

# Show the fitted tree together with the per-node coefficients.
tr_tree
## Linear model tree
## 
## Model formula:
## eval ~ beauty | minority + age + gender + division + native + 
##     tenure
## 
## Fitted party:
## [1] root
## |   [2] gender in male
## |   |   [3] age <= 50: n = 113
## |   |       (Intercept)      beauty 
## |   |         3.9967632   0.1291992 
## |   |   [4] age > 50: n = 137
## |   |       (Intercept)      beauty 
## |   |         4.0857450   0.5028092 
## |   [5] gender in female
## |   |   [6] age <= 40: n = 69
## |   |       (Intercept)      beauty 
## |   |          4.013707    0.122212 
## |   |   [7] age > 40
## |   |   |   [8] division in upper: n = 81
## |   |   |       (Intercept)      beauty 
## |   |   |         3.7752097  -0.1975861 
## |   |   |   [9] division in lower: n = 36
## |   |   |       (Intercept)      beauty 
## |   |   |         3.5899744   0.4032684 
## 
## Number of inner nodes:    4
## Number of terminal nodes: 5
## Number of parameters per node: 2
## Objective function (residual sum of squares): 2751.654
# Components of the per-node scatter plot: eval against beauty, coloured by
# tenure and shaped by minority, on a compact black-and-white theme.
terminal_scatter <- list(
  geom_point(
    mapping = aes(
      x = beauty,
      y = eval,
      col = tenure,
      shape = minority
    ),
    alpha = 0.8
  ),
  theme_bw(base_size = 10)
)

# Tree skeleton plus the node plots.
ggparty(tr_tree) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(
    gglist = terminal_scatter,
    shared_axis_labels = TRUE,
    legend_separator = TRUE,
    # predict based on variable
    predict = "beauty",
    # graphical parameters for geom_line of predictions
    predict_gpar = list(col = "blue", size = 1)
  )

# Create a data frame with node ids, age densities and break labels.
# Since the data frame is supplied directly to a geom inside gglist, the
# number of observations per id does not matter, and data is only generated
# for the ids used by the respective geom_node_plot() (2 and 5 in this case).
#
# Columns:
#   x_dens, y_dens - grid points and density heights returned by density()
#   id             - id of the node the curve belongs to
#   breaks         - "right" where x_dens exceeds the node's age cut point,
#                    "left" otherwise
#
# Age cut points per node, matching the printed tree: node 2 splits at
# age 50, node 5 at age 40.
age_split_nodes <- list(id = c(2, 5), cutoff = c(50, 40))

# One data frame per node, bound together once at the end instead of growing
# the result inside a loop.
dens_df <- do.call(rbind, Map(function(id, cutoff) {
  # Estimate the density once and reuse both coordinates.
  dens <- density(tr_tree[id]$data$age)
  data.frame(
    x_dens = dens$x,
    y_dens = dens$y,
    id = id,
    breaks = ifelse(dens$x > cutoff, "right", "left")
  )
}, age_split_nodes$id, age_split_nodes$cutoff))

# adjust layout so that each node plot has enough space: nodes 1, 2, 5 and 7
# are placed at hand-picked coordinates
tree_base <- ggparty(
  tr_tree,
  terminal_space = 0.4,
  layout = data.frame(
    id = c(1, 2, 5, 7),
    x = c(0.35, 0.15, 0.7, 0.8),
    y = c(0.95, 0.6, 0.8, 0.55)
  )
)

# map color of edges to birth_order (order from left to right)
coloured_edges <- geom_edge(
  mapping = aes(col = factor(birth_order)),
  size = 1.2,
  alpha = 1,
  # exclude root so it doesn't count as its own colour
  ids = -1
)

# theme element reused by all inner-node plots: no y-axis title
no_y_title <- theme(axis.title.y = element_blank())

# density plots for the age splits (nodes 2 and 5), drawn from dens_df
age_density_plots <- geom_node_plot(
  ids = c(2, 5),
  gglist = list(
    # density outline
    geom_line(
      mapping = aes(x = x_dens, y = y_dens),
      data = dens_df,
      show.legend = FALSE,
      alpha = 0.8
    ),
    # filled area under the curve, fill mapped to breaks
    geom_ribbon(
      mapping = aes(
        x = x_dens,
        ymin = 0,
        ymax = y_dens,
        fill = breaks
      ),
      data = dens_df,
      show.legend = FALSE,
      alpha = 0.8
    ),
    xlab("age"),
    theme_bw(),
    no_y_title
  ),
  size = 1.5,
  height = 0.5
)

# bar plot of gender at the root
gender_plot <- geom_node_plot(
  ids = 1,
  gglist = list(
    geom_bar(
      mapping = aes(x = gender, fill = gender),
      show.legend = FALSE,
      alpha = 0.8
    ),
    theme_bw(),
    no_y_title
  ),
  size = 1.5,
  height = 0.5
)

# bar plot of division for node 7
division_plot <- geom_node_plot(
  ids = 7,
  gglist = list(
    geom_bar(
      mapping = aes(x = division, fill = division),
      show.legend = FALSE,
      alpha = 0.8
    ),
    theme_bw(),
    no_y_title
  ),
  size = 1.5,
  height = 0.5
)

# terminal nodes: scatter of eval against beauty with predictions
terminal_plots <- geom_node_plot(
  gglist = list(
    geom_point(
      mapping = aes(
        x = beauty,
        y = eval,
        col = tenure,
        shape = minority
      ),
      alpha = 0.8
    ),
    theme_bw(base_size = 10),
    scale_color_discrete(h.start = 100)
  ),
  shared_axis_labels = TRUE,
  legend_separator = TRUE,
  predict = "beauty",
  predict_gpar = list(col = "blue", size = 1.1)
)

# assemble the layers in drawing order; all legends are removed from the top
# level since the plot is self explanatory
tree_base +
  coloured_edges +
  age_density_plots +
  gender_plot +
  division_plot +
  terminal_plots +
  theme(legend.position = "none")