Introduction

On Dec 10th, Simon Jackson published a blog post about grid search in the tidyverse: https://drsimonj.svbtle.com/grid-search-in-the-tidyverse

As a beginner in this field, I had not come across the term grid search before.

I will follow Simon’s script using the iris dataset and add more detailed notes for beginners.

Decision Tree Example of Iris

The iris dataset is well known and easy to predict. Let’s predict Species from the other four features.

library(tidyverse)
## Loading tidyverse: ggplot2
## Loading tidyverse: tibble
## Loading tidyverse: tidyr
## Loading tidyverse: readr
## Loading tidyverse: purrr
## Loading tidyverse: dplyr
## Conflicts with tidy packages ----------------------------------------------
## filter(): dplyr, stats
## lag():    dplyr, stats
d <- iris
ggplot(d, aes(Sepal.Length, Sepal.Width, color = Species)) + geom_point()

ggplot(d, aes(Sepal.Length, Petal.Length, color = Species)) + geom_point()

ggplot(d, aes(Sepal.Length, Petal.Width, color = Species)) + geom_point()

ggplot(d, aes(Sepal.Width, Petal.Length, color = Species)) + geom_point()

ggplot(d, aes(Sepal.Width, Petal.Width, color = Species)) + geom_point()

ggplot(d, aes(Petal.Length, Petal.Width, color = Species)) + geom_point()
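The six calls above cover every pair of the four numeric features. As a minimal sketch (the names `features`, `feature_pairs` and `plots` are illustrative, not from Simon's script), `combn()` can generate those pairs so the plots are built in a loop instead of by hand:

```r
library(ggplot2)
library(purrr)

d <- iris
features <- names(d)[1:4]  # Sepal.Length, Sepal.Width, Petal.Length, Petal.Width

# combn() returns every unordered pair of features: choose(4, 2) = 6 pairs,
# in the same order as the six calls above
feature_pairs <- combn(features, 2, simplify = FALSE)

# Build one scatterplot per pair; .data[[...]] looks up a column by its name
plots <- map(feature_pairs, function(pair) {
  ggplot(d, aes(x = .data[[pair[1]]], y = .data[[pair[2]]], color = Species)) +
    geom_point()
})

# Draw the plots one after another
walk(plots, print)
```

With more features, this approach scales without writing one `ggplot()` call per pair.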

The scatterplots show that setosa is clearly separated from the other two species, especially on the petal measurements, while versicolor and virginica overlap somewhat.
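That visual impression can be checked with numbers. This is a small sketch (the name `petal_ranges` is illustrative) that computes the range of each petal measurement within each species:

```r
library(dplyr)

# Minimum and maximum of each petal measurement, per species
petal_ranges <- iris %>%
  group_by(Species) %>%
  summarise(min_petal_length = min(Petal.Length),
            max_petal_length = max(Petal.Length),
            min_petal_width  = min(Petal.Width),
            max_petal_width  = max(Petal.Width))
petal_ranges
```

In iris, the longest setosa petal is still shorter than the shortest versicolor or virginica petal (and the same holds for width), so a single threshold on one petal measurement is enough to isolate setosa.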

Let’s fit a decision tree.

library(rpart)
library(rpart.plot)
# minsplit = 2 lets rpart try to split any node with at least 2 observations (default: 20)
full_fit <- rpart(Species ~ ., data = d, minsplit = 2)
prp(full_fit)

Compared with the tree in Simon’s article, this tree already looks very clean. Still, I want to find out whether some other hyperparameter specification can surprise me.
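To put a number on how clean the full tree is, one can compare its predictions with the true species on the same data it was fit on. This is a minimal sketch (the names `full_pred`, `full_table` and `full_accuracy` are illustrative). Accuracy measured on the training data is optimistic, which is exactly why the next step holds out a test set:

```r
library(rpart)

d <- iris
# Same tree as above: minsplit = 2, other rpart.control() settings at defaults
full_fit <- rpart(Species ~ ., data = d, minsplit = 2)

# Predict on the data the tree was trained on (resubstitution)
full_pred <- predict(full_fit, d, type = "class")

# Confusion table: rows are predicted species, columns are true species
full_table <- table(predicted = full_pred, actual = d$Species)
full_table

# Share of rows classified correctly on the training data itself
full_accuracy <- mean(full_pred == d$Species)
full_accuracy
```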

Split Training and Testing

We use 80% of the rows for training and the remaining 20% for testing.

set.seed(66666)                               # make the random split reproducible
n <- nrow(d)
train_rows <- sample(seq(n), size = .8 * n)   # 120 of the 150 row indices
train <- d[ train_rows, ]
test  <- d[-train_rows, ]                     # the remaining 30 rows
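A quick sanity check of the split is worth a few lines. This sketch repeats the split so it runs on its own and then confirms that the two sets partition the data. A purely random split does not guarantee that every species is equally represented in the test set, so it also prints the class counts:

```r
d <- iris
set.seed(66666)
n <- nrow(d)
train_rows <- sample(seq(n), size = .8 * n)
train <- d[ train_rows, ]
test  <- d[-train_rows, ]

# Sizes of the two sets: should be 120 and 30
c(train = nrow(train), test = nrow(test))

# Number of rows that appear in both sets: should be 0
length(intersect(rownames(train), rownames(test)))

# How many test rows fall into each species
table(test$Species)
```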

Create the Grid

Define the hyperparameter combinations to try (feel free to experiment with other values of minsplit and maxdepth):

gs <- expand.grid(minsplit = c(2, 5, 10),
                  maxdepth = c(1, 3, 5)) %>%
  as_tibble() # Every combination as a tibble grid (replaces the deprecated purrr::cross_d())
gs
## # A tibble: 9 × 2
##   minsplit maxdepth
##      <dbl>    <dbl>
## 1        2        1
## 2        5        1
## 3       10        1
## 4        2        3
## 5        5        3
## 6       10        3
## 7        2        5
## 8        5        5
## 9       10        5
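The grid can also be built with `tidyr::expand_grid()`, which the purrr documentation recommends in place of its deprecated `cross*()` functions. One difference to be aware of in this sketch (the name `gs_alt` is illustrative): `expand_grid()` varies the last column fastest, so an `arrange()` is needed to get the row order shown above:

```r
library(tidyr)
library(dplyr)

# Every combination of the supplied vectors; maxdepth varies fastest here
gs_alt <- expand_grid(minsplit = c(2, 5, 10),
                      maxdepth = c(1, 3, 5))
gs_alt

# Reorder so that minsplit varies fastest, matching the grid above
gs_alt %>% arrange(maxdepth, minsplit)
```

The number of rows is always the product of the vector lengths (here 3 × 3 = 9), so adding a third hyperparameter with three values would triple the number of models to fit.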

Create a Model Function

Create a helper function that fits a tree on the training set for one hyperparameter combination; any named arguments it receives are passed on to rpart.control():

mod <- function(...) {
  rpart(Species ~ ., data = train, control = rpart.control(...))
}
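Before looping over the whole grid, it can help to call the helper once by hand. This sketch recreates the training set so it runs on its own (the name `one_fit` is illustrative) and checks that the hyperparameters actually reached `rpart()`, which stores its control settings in the fitted object:

```r
library(rpart)

d <- iris
set.seed(66666)
n <- nrow(d)
train_rows <- sample(seq(n), size = .8 * n)
train <- d[train_rows, ]

# Helper from above: named arguments are forwarded to rpart.control()
mod <- function(...) {
  rpart(Species ~ ., data = train, control = rpart.control(...))
}

# Fit one model for a single hyperparameter combination
one_fit <- mod(minsplit = 10, maxdepth = 3)

# The control settings used for the fit are kept in the rpart object
one_fit$control$minsplit
one_fit$control$maxdepth
```

Note that `mod()` reads `train` from the environment where it was defined, so `train` must exist before the helper is called.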

Fit the Models

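Iterate over the rows of the grid and fit one model per row. pmap() passes each row’s minsplit and maxdepth values to mod() as named arguments, and mutate() stores the fitted models in a list column called fit.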

gs <- gs %>% mutate(fit = pmap(gs, mod))
gs
## # A tibble: 9 × 3
##   minsplit maxdepth         fit
##      <dbl>    <dbl>      <list>
## 1        2        1 <S3: rpart>
## 2        5        1 <S3: rpart>
## 3       10        1 <S3: rpart>
## 4        2        3 <S3: rpart>
## 5        5        3 <S3: rpart>
## 6       10        3 <S3: rpart>
## 7        2        5 <S3: rpart>
## 8        5        5 <S3: rpart>
## 9       10        5 <S3: rpart>
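To see what `pmap()` does without fitting any models, here is a toy sketch (the names `toy_grid` and `describe` are illustrative). A data frame is a list of equal-length columns, so `pmap()` treats each row as one set of arguments and matches column names to argument names:

```r
library(purrr)

# A tiny two-row grid; a data frame with these columns works the same way
toy_grid <- list(minsplit = c(2, 5),
                 maxdepth = c(1, 3))

describe <- function(minsplit, maxdepth) {
  paste0("minsplit=", minsplit, ", maxdepth=", maxdepth)
}

# One call per row: describe(minsplit = 2, maxdepth = 1), then describe(5, 3)
pmap_chr(toy_grid, describe)
# returns "minsplit=2, maxdepth=1" "minsplit=5, maxdepth=3"
```

In the real pipeline, `mod()` accepts `...` instead of named parameters, so each row’s values flow straight through to `rpart.control()`.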

Obtain Accuracy

Create a function that computes accuracy, the proportion of test rows whose predicted species matches the true species:

compute_accuracy <- function(fit, test_features, test_labels) {
  predicted <- predict(fit, test_features, type = "class")
  mean(predicted == test_labels)
}
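Accuracy is a single number, so it hides which species get confused with each other. As an assumption-labeled extension, `compute_confusion()` below is a hypothetical helper (not part of Simon’s script) that cross-tabulates predicted and true classes for one fit. The diagonal of that table holds the correct predictions, so dividing it by the total reproduces the accuracy:

```r
library(rpart)

d <- iris
set.seed(66666)
n <- nrow(d)
train_rows <- sample(seq(n), size = .8 * n)
train <- d[ train_rows, ]
test  <- d[-train_rows, ]

compute_accuracy <- function(fit, test_features, test_labels) {
  predicted <- predict(fit, test_features, type = "class")
  mean(predicted == test_labels)
}

# Hypothetical helper: predicted species (rows) vs. true species (columns)
compute_confusion <- function(fit, test_features, test_labels) {
  predicted <- predict(fit, test_features, type = "class")
  table(predicted = predicted, actual = test_labels)
}

fit <- rpart(Species ~ ., data = train,
             control = rpart.control(minsplit = 10, maxdepth = 3))
test_features <- test[, names(test) != "Species"]
test_labels   <- test$Species

confusion <- compute_confusion(fit, test_features, test_labels)
confusion

# Correct predictions (diagonal) divided by all test rows equals the accuracy
sum(diag(confusion)) / sum(confusion)
compute_accuracy(fit, test_features, test_labels)
```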

Apply it to every fitted model:

test_features <- test %>% select(-Species)
test_labels   <- test$Species
gs <- gs %>%
  mutate(test_accuracy = map_dbl(fit, compute_accuracy,
                                 test_features, test_labels))
gs
## # A tibble: 9 × 4
##   minsplit maxdepth         fit test_accuracy
##      <dbl>    <dbl>      <list>         <dbl>
## 1        2        1 <S3: rpart>     0.7333333
## 2        5        1 <S3: rpart>     0.7333333
## 3       10        1 <S3: rpart>     0.7333333
## 4        2        3 <S3: rpart>     0.9333333
## 5        5        3 <S3: rpart>     0.9333333
## 6       10        3 <S3: rpart>     0.9666667
## 7        2        5 <S3: rpart>     0.9333333
## 8        5        5 <S3: rpart>     0.9333333
## 9       10        5 <S3: rpart>     0.9666667

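Now we can see the test accuracy of every fit. All trees with maxdepth = 1 score much lower: a tree of depth 1 makes a single split into two leaves, so it can predict at most two of the three species.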
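A tree with maxdepth = 1 (often called a stump) can make only one split, so it has at most two leaves and can predict at most two of the three species. This sketch (the names `stump` and `stump_pred` are illustrative) checks that directly; rpart marks terminal nodes with `"<leaf>"` in the `var` column of the fitted tree’s `frame`:

```r
library(rpart)

d <- iris
set.seed(66666)
n <- nrow(d)
train_rows <- sample(seq(n), size = .8 * n)
train <- d[ train_rows, ]
test  <- d[-train_rows, ]

# A stump: at most one split below the root
stump <- rpart(Species ~ ., data = train,
               control = rpart.control(maxdepth = 1))

# Number of terminal nodes (leaves) in the tree
sum(stump$frame$var == "<leaf>")

# Distinct species the stump actually predicts on the test set
stump_pred <- predict(stump, test, type = "class")
unique(as.character(stump_pred))
```

Since at least one species is never predicted, every test row of that species is misclassified, which caps the stump’s accuracy well below that of the deeper trees.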

Arrange Results

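Sort the results by test accuracy (highest first), breaking ties with larger minsplit first and then smaller maxdepth, so that among equally accurate models the simpler trees come first: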

gs <- gs %>% arrange(desc(test_accuracy), desc(minsplit), maxdepth)
gs
## # A tibble: 9 × 4
##   minsplit maxdepth         fit test_accuracy
##      <dbl>    <dbl>      <list>         <dbl>
## 1       10        3 <S3: rpart>     0.9666667
## 2       10        5 <S3: rpart>     0.9666667
## 3        5        3 <S3: rpart>     0.9333333
## 4        5        5 <S3: rpart>     0.9333333
## 5        2        3 <S3: rpart>     0.9333333
## 6        2        5 <S3: rpart>     0.9333333
## 7       10        1 <S3: rpart>     0.7333333
## 8        5        1 <S3: rpart>     0.7333333
## 9        2        1 <S3: rpart>     0.7333333

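Our best fit is the combination minsplit = 10 and maxdepth = 3, with a test accuracy of 0.967 (29 of the 30 test rows). Keep in mind that with only 30 test rows, the gap between 0.933 and 0.967 comes down to a single prediction.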
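The selection step can be reproduced without refitting anything. This sketch copies the accuracy values from the results table above, written as correct predictions out of 30 test rows (22/30 ≈ 0.733, 28/30 ≈ 0.933, 29/30 ≈ 0.967), and applies the same ordering rule; `results` and `best` are illustrative names:

```r
library(dplyr)
library(tibble)

# Grid and test accuracies taken from the table above
results <- tibble(
  minsplit      = rep(c(2, 5, 10), times = 3),
  maxdepth      = rep(c(1, 3, 5), each = 3),
  test_accuracy = c(22, 22, 22, 28, 28, 29, 28, 28, 29) / 30
)

# Highest accuracy first; ties go to larger minsplit, then smaller maxdepth
best <- results %>%
  arrange(desc(test_accuracy), desc(minsplit), maxdepth) %>%
  slice(1)
best
```

In the full pipeline, after `arrange()` the winning fitted model is simply `gs$fit[[1]]`, which can be passed to `prp()` to draw the selected tree.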

Takeaways

  • The iris data is very clean and easy to predict; I should have tried a more challenging and complex dataset.

  • Arranging your results by accuracy (or whatever metric you care about) will help when you work on a larger project.

  • Learning from other people’s scripts is one of the best ways to understand and master a language.

  • Do what you like.