On Dec 10th, Simon Jackson blogged a new article about Grid Search in the tidyverse: https://drsimonj.svbtle.com/grid-search-in-the-tidyverse
As a beginner in this field, I had not come across the term grid search before.
Simon explains grid search this way: “grid search involves running a model many times with combinations of various hyperparameters. The point is to identify which hyperparameters are likely to work best.”
Wikipedia describes it similarly: “The traditional way of performing hyperparameter optimization has been grid search, or a parameter sweep, which is simply an exhaustive searching through a manually specified subset of the hyperparameter space of a learning algorithm.”
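To make the definition concrete, here is a minimal sketch of an exhaustive sweep in base R. The toy objective loss() and the candidate values are hypothetical, chosen only for illustration; in a real grid search the objective would be a model's error on held-out data.

```r
# Hypothetical objective: smallest at a = 2, b = -1
loss <- function(a, b) (a - 2)^2 + (b + 1)^2

# Manually specified candidate values for each "hyperparameter"
grid <- expand.grid(a = c(0, 1, 2, 3), b = c(-2, -1, 0))

# Evaluate every combination (the exhaustive part of grid search)
grid$loss <- mapply(loss, grid$a, grid$b)

# Keep the combination with the lowest loss
best <- grid[which.min(grid$loss), ]
best
```

The cost grows with the product of the number of candidate values, which is why the grid is usually kept small.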
I will follow Simon’s script on the iris dataset and add more detail for beginners.
The iris dataset is famous and easy to predict. Let’s predict Species from the other four features.
library(tidyverse)
## Loading tidyverse: ggplot2
## Loading tidyverse: tibble
## Loading tidyverse: tidyr
## Loading tidyverse: readr
## Loading tidyverse: purrr
## Loading tidyverse: dplyr
## Conflicts with tidy packages ----------------------------------------------
## filter(): dplyr, stats
## lag(): dplyr, stats
d <- iris
ggplot(d, aes(Sepal.Length, Sepal.Width, color = Species)) + geom_point()
ggplot(d, aes(Sepal.Length, Petal.Length, color = Species)) + geom_point()
ggplot(d, aes(Sepal.Length, Petal.Width, color = Species)) + geom_point()
ggplot(d, aes(Sepal.Width, Petal.Length, color = Species)) + geom_point()
ggplot(d, aes(Sepal.Width, Petal.Width, color = Species)) + geom_point()
ggplot(d, aes(Petal.Length, Petal.Width, color = Species)) + geom_point()
tidyverse: a set of packages that work in harmony because they share common data representations and API design. The tidyverse package itself makes it easy to install and load multiple tidyverse packages in a single step. Learn more about the tidyverse at https://github.com/hadley/tidyverse. It imports: broom, DBI, dplyr, forcats, ggplot2, haven, httr, hms, jsonlite, lubridate, magrittr, modelr, purrr, readr, readxl, stringr, tibble, rvest, tidyr, xml2.
ggplot(): initializes a ggplot object. It can be used to declare the input data frame for a graphic and to specify the set of plot aesthetics intended to be common throughout all subsequent layers unless specifically overridden.
geom_point(): The point geom is used to create scatterplots.
From the scatterplots, it’s obvious that the setosa species is clearly separated from the other two.
Let’s do a decision tree.
library(rpart)
library(rpart.plot)
# Set minsplit = 2 to fit every data point
full_fit <- rpart(Species ~ ., data = d, minsplit = 2)
prp(full_fit)
rpart: Recursive partitioning for classification, regression and survival trees.
rpart.plot: Plot rpart Models: An Enhanced Version of plot.rpart
rpart(): Fit an rpart model
prp(): Plot an rpart model.
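The two hyperparameters tuned later in this post, minsplit and maxdepth, are arguments of rpart.control(). The sketch below, which is my own illustration and not part of Simon's script, shows how maxdepth limits tree size by counting terminal nodes. The helper count_leaves() is a hypothetical name. Note that rpart.control() also has a complexity parameter cp (default 0.01) that can stop a split even when minsplit would allow it.

```r
library(rpart)

# A depth-1 tree (a "stump") can make only one split
stump <- rpart(Species ~ ., data = iris,
               control = rpart.control(minsplit = 2, maxdepth = 1))

# A deeper tree is allowed more splits
deeper <- rpart(Species ~ ., data = iris,
                control = rpart.control(minsplit = 2, maxdepth = 5))

# Count terminal nodes: rows of fit$frame whose var is "<leaf>"
count_leaves <- function(fit) sum(fit$frame$var == "<leaf>")

count_leaves(stump)
count_leaves(deeper)
```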
Compared with the tree in Simon’s article, this tree already looks very good. Still, I want to see which hyperparameter combination performs best on data the model has not seen.
Split the data: 80% for training and 20% for testing.
set.seed(66666)
n <- nrow(d)
train_rows <- sample(seq(n), size = .8 * n)
train <- d[ train_rows, ]
test <- d[-train_rows, ]
set.seed(): sets the seed of R's random number generator, so the random split is reproducible.
nrow(): returns the number of rows.
sample(): takes a sample of the specified size from the elements of x using either with or without replacement.
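As a quick check of how the split works, the following sketch repeats it on the built-in iris data and confirms that the two sets are disjoint. The seed value 1 is arbitrary and differs from the one used above, so the selected rows will differ too.

```r
set.seed(1)  # arbitrary seed, only to make this sketch reproducible
n <- nrow(iris)                                          # 150 rows
train_rows <- sample(seq_len(n), size = round(0.8 * n))  # 120 distinct row indices

train <- iris[train_rows, ]   # rows that were sampled
test  <- iris[-train_rows, ]  # negative indices drop the sampled rows

nrow(train)                # 120
nrow(test)                 # 30
anyDuplicated(train_rows)  # 0, because sample() defaults to replace = FALSE
```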
Define the hyperparameter combinations (you can try different values for minsplit and maxdepth):
gs <- list(minsplit = c(2, 5, 10),
           maxdepth = c(1, 3, 5)) %>%
  cross_d() # Convert to data frame grid
gs
## # A tibble: 9 × 2
## minsplit maxdepth
## <dbl> <dbl>
## 1 2 1
## 2 5 1
## 3 10 1
## 4 2 3
## 5 5 3
## 6 10 3
## 7 2 5
## 8 5 5
## 9 10 5
list(): Functions to construct, coerce and check for both kinds of R lists.
cross_d(): takes a list .l and returns a data frame with one combination per row.
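To my knowledge, cross_d() was deprecated in later purrr releases, with tidyr::expand_grid() suggested as the tidyverse replacement, so the call above may fail or warn on a recent installation. As a hedged alternative that needs no packages, base R's expand.grid() builds the same nine combinations. The name gs_base is mine, used to avoid overwriting gs.

```r
# Base R: every combination of the supplied vectors, one per row.
# The first variable (minsplit) varies fastest, matching the grid shown above.
gs_base <- expand.grid(minsplit = c(2, 5, 10),
                       maxdepth = c(1, 3, 5))
gs_base
nrow(gs_base)  # 3 values x 3 values = 9 combinations
```

expand.grid() returns a plain data frame and not a tibble, which is enough for the rest of this walkthrough.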
Create a helper function that fits a model for any hyperparameter combination in the grid:
mod <- function(...) {
  # Arguments such as minsplit and maxdepth are forwarded to rpart.control()
  rpart(Species ~ ., data = train, control = rpart.control(...))
}
function(): provides the base mechanisms for defining new functions in the R language.
Iterate down the rows of the grid and fit our models:
gs <- gs %>% mutate(fit = pmap(gs, mod))
gs
## # A tibble: 9 × 3
## minsplit maxdepth fit
## <dbl> <dbl> <list>
## 1 2 1 <S3: rpart>
## 2 5 1 <S3: rpart>
## 3 10 1 <S3: rpart>
## 4 2 3 <S3: rpart>
## 5 5 3 <S3: rpart>
## 6 10 3 <S3: rpart>
## 7 2 5 <S3: rpart>
## 8 5 5 <S3: rpart>
## 9 10 5 <S3: rpart>
mutate(): adds new variables and preserves existing ones.
pmap(): pmap allows you to provide any number of arguments in a list.
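Because a data frame is a list of columns, pmap() walks it row by row and passes each row's values as named arguments. The sketch below illustrates this with a hypothetical function describe() standing in for the model-fitting function.

```r
library(purrr)

# A small grid; each column name becomes an argument name
grid <- data.frame(minsplit = c(2, 5), maxdepth = c(1, 3))

# describe() is a hypothetical stand-in for the model-fitting function
describe <- function(minsplit, maxdepth) {
  paste0("minsplit=", minsplit, ", maxdepth=", maxdepth)
}

# Each row becomes one call: describe(minsplit = 2, maxdepth = 1), and so on
pmap_chr(grid, describe)
```

This is why the column names of the grid must match argument names of rpart.control(): every column is forwarded, so an extra column would be passed along as well.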
Create a function to get accuracy easily:
compute_accuracy <- function(fit, test_features, test_labels) {
  predicted <- predict(fit, test_features, type = "class")
  mean(predicted == test_labels)
}
predict(): a generic function for predictions from the results of various model fitting functions.
mean(): Generic function for the (trimmed) arithmetic mean.
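The accuracy calculation relies on the fact that mean() of a logical vector gives the proportion of TRUE values. A small sketch with made-up predictions and labels (not output from the models above):

```r
# Hypothetical predictions and true labels, sharing the same factor levels
lv <- c("setosa", "versicolor", "virginica")
predicted <- factor(c("setosa", "versicolor", "virginica", "virginica"), levels = lv)
actual    <- factor(c("setosa", "versicolor", "versicolor", "virginica"), levels = lv)

predicted == actual        # TRUE TRUE FALSE TRUE
mean(predicted == actual)  # 0.75: TRUE counts as 1, FALSE as 0

# A confusion table shows where the errors fall
table(predicted, actual)
```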
Apply it to the fitted models:
test_features <- test %>% select(-Species)
test_labels <- test$Species
gs <- gs %>%
  mutate(test_accuracy = map_dbl(fit, compute_accuracy,
                                 test_features, test_labels))
gs
## # A tibble: 9 × 4
## minsplit maxdepth fit test_accuracy
## <dbl> <dbl> <list> <dbl>
## 1 2 1 <S3: rpart> 0.7333333
## 2 5 1 <S3: rpart> 0.7333333
## 3 10 1 <S3: rpart> 0.7333333
## 4 2 3 <S3: rpart> 0.9333333
## 5 5 3 <S3: rpart> 0.9333333
## 6 10 3 <S3: rpart> 0.9666667
## 7 2 5 <S3: rpart> 0.9333333
## 8 5 5 <S3: rpart> 0.9333333
## 9 10 5 <S3: rpart> 0.9666667
select(): keeps only the variables you mention; a minus sign, as in -Species, drops a variable.
map_dbl(): applies a function to each element and returns a double (numeric) vector.
Now we can see all the accuracy results of our fits.
Sort the results by highest accuracy, then by larger minsplit, then by smaller maxdepth, so that simpler trees come first among ties:
gs <- gs %>% arrange(desc(test_accuracy), desc(minsplit), maxdepth)
gs
## # A tibble: 9 × 4
## minsplit maxdepth fit test_accuracy
## <dbl> <dbl> <list> <dbl>
## 1 10 3 <S3: rpart> 0.9666667
## 2 10 5 <S3: rpart> 0.9666667
## 3 5 3 <S3: rpart> 0.9333333
## 4 5 5 <S3: rpart> 0.9333333
## 5 2 3 <S3: rpart> 0.9333333
## 6 2 5 <S3: rpart> 0.9333333
## 7 10 1 <S3: rpart> 0.7333333
## 8 5 1 <S3: rpart> 0.7333333
## 9 2 1 <S3: rpart> 0.7333333
arrange(): Arrange rows by variables.
desc(): sorts a variable in descending order.
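The tie-breaking behaviour of arrange() is easier to see on a tiny example. The data frame below is hypothetical and only illustrates the ordering; it is not taken from the results above.

```r
library(dplyr)

# Hypothetical results with a tie in accuracy
results <- data.frame(minsplit = c(2, 10, 10),
                      maxdepth = c(3, 5, 3),
                      accuracy = c(0.9, 0.95, 0.95))

# Ties in accuracy are broken by larger minsplit, then by smaller maxdepth
sorted <- results %>% arrange(desc(accuracy), desc(minsplit), maxdepth)
sorted
```

The first row of sorted is the row with minsplit = 10 and maxdepth = 3, because it ties on accuracy and minsplit with the second row and wins on the smaller maxdepth.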
Our best fit is the combination of minsplit = 10 and maxdepth = 3.
The iris data is very clean and predictable; I should have tried a more challenging and complex dataset.
Arranging your results, whether by accuracy or by another metric, will help you when working on a large project.
Learning from others’ scripts is one of the best ways to understand and master a language.
Do what you like.