Tidymodels is a framework for training statistical and machine learning models: a collection of R packages for modeling and machine learning that follow tidyverse principles.
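If you don't have the packages installed yet, the whole collection is available from CRAN:
install.packages("tidymodels")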
library(tidymodels)
## -- Attaching packages -------------------------------------- tidymodels 0.1.1 --
## v broom 0.7.1 v recipes 0.1.13
## v dials 0.0.9 v rsample 0.0.8
## v dplyr 1.0.2 v tibble 3.0.4
## v ggplot2 3.3.2 v tidyr 1.1.2
## v infer 0.5.3 v tune 0.1.1
## v modeldata 0.0.2 v workflows 0.2.1
## v parsnip 0.1.3 v yardstick 0.0.7
## v purrr 0.3.4
## -- Conflicts ----------------------------------------- tidymodels_conflicts() --
## x purrr::discard() masks scales::discard()
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
## x recipes::step() masks stats::step()
Let's walk through the basic steps of a data analysis: data collection/importing, data preprocessing, EDA, modeling, and model evaluation.
In this article, we will use the iris data.
Let's begin by splitting the data set. The rsample package under tidymodels provides different ways of sampling, including a simple train-test split and k-fold cross-validation. In this example, we will split the iris data into training and testing sets. Later we will apply cross-validation on a training set to tune our models.
iris_split <- initial_split(iris, prop = 0.6)
iris_split
## <Analysis/Assess/Total>
## <90/60/150>
iris_split %>%
training() %>%
glimpse()
## Rows: 90
## Columns: 5
## $ Sepal.Length <dbl> 5.1, 4.9, 4.6, 4.6, 5.0, 4.4, 4.9, 5.4, 4.8, 5.8, 5.7,...
## $ Sepal.Width <dbl> 3.5, 3.0, 3.1, 3.4, 3.4, 2.9, 3.1, 3.7, 3.4, 4.0, 4.4,...
## $ Petal.Length <dbl> 1.4, 1.4, 1.5, 1.4, 1.5, 1.4, 1.5, 1.5, 1.6, 1.2, 1.5,...
## $ Petal.Width <dbl> 0.2, 0.2, 0.2, 0.3, 0.2, 0.2, 0.1, 0.2, 0.2, 0.2, 0.4,...
## $ Species <fct> setosa, setosa, setosa, setosa, setosa, setosa, setosa...
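Note that initial_split() samples rows at random, so the split above will change between runs unless a seed is set first; you can also stratify on the outcome to keep the class proportions balanced. A minimal sketch (not part of the run above; the seed value is arbitrary):
set.seed(42)  # any value; makes the split reproducible
initial_split(iris, prop = 0.6, strata = Species)  # stratified variant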
The recipes package contains a wide range of preprocessing steps such as normalization, centering, scaling, PCA, and so on.
Here we will remove variables that are highly correlated with other variables, then standardize the predictors (i.e., center and scale them). Notice the step_* prefix on each preprocessing function. Lastly, we call prep() to 'cook', or train, the recipe on the training data.
iris_recipe <- training(iris_split) %>%
recipe(Species ~.) %>%
step_corr(all_predictors()) %>%
step_center(all_predictors(), -all_outcomes()) %>%
step_scale(all_predictors(), -all_outcomes()) %>%
prep()
iris_recipe
## Data Recipe
##
## Inputs:
##
## role #variables
## outcome 1
## predictor 4
##
## Training data contained 90 data points and no missing data.
##
## Operations:
##
## Correlation filter removed Petal.Length [trained]
## Centering for Sepal.Length, Sepal.Width, Petal.Width [trained]
## Scaling for Sepal.Length, Sepal.Width, Petal.Width [trained]
## Apply the preprocessing steps to the testing set
iris_testing <- iris_recipe %>%
bake(testing(iris_split))
glimpse(iris_testing)
## Rows: 60
## Columns: 4
## $ Sepal.Length <dbl> -1.3550610, -1.0099040, -0.5496946, -1.2400087, -1.815...
## $ Sepal.Width <dbl> 0.1161296, 1.0883771, 1.8175627, -0.3699942, -0.369994...
## $ Petal.Width <dbl> -1.2129756, -1.2129756, -0.9573133, -1.3408067, -1.340...
## $ Species <fct> setosa, setosa, setosa, setosa, setosa, setosa, setosa...
iris_training <- iris_recipe %>%
bake(training(iris_split))
glimpse(iris_training)
## Rows: 90
## Columns: 4
## $ Sepal.Length <dbl> -0.89485163, -1.12495634, -1.47011340, -1.47011340, -1...
## $ Sepal.Width <dbl> 0.8453152, -0.3699942, -0.1269323, 0.6022533, 0.602253...
## $ Petal.Width <dbl> -1.2129756, -1.2129756, -1.2129756, -1.0851444, -1.212...
## $ Species <fct> setosa, setosa, setosa, setosa, setosa, setosa, setosa...
You can see that we applied bake() to the training dataset as well, but we don't actually need to: prep() retains the prepared training data, and we can retrieve it with the juice() function.
## juice() function
iris_training <- juice(iris_recipe)
glimpse(iris_training)
## Rows: 90
## Columns: 4
## $ Sepal.Length <dbl> -0.89485163, -1.12495634, -1.47011340, -1.47011340, -1...
## $ Sepal.Width <dbl> 0.8453152, -0.3699942, -0.1269323, 0.6022533, 0.602253...
## $ Petal.Width <dbl> -1.2129756, -1.2129756, -1.2129756, -1.0851444, -1.212...
## $ Species <fct> setosa, setosa, setosa, setosa, setosa, setosa, setosa...
We can now train our models. We will fit a random forest with 100 trees using two different engines, ranger and randomForest:
iris_ranger <- rand_forest(trees = 100, mode = "classification") %>%
set_engine("ranger") %>%
fit(Species ~ ., data = iris_training)
iris_rf <- rand_forest(trees = 100, mode = "classification") %>%
set_engine("randomForest") %>%
fit(Species ~ ., data = iris_training)
predict(iris_ranger, iris_testing)
## # A tibble: 60 x 1
## .pred_class
## <fct>
## 1 setosa
## 2 setosa
## 3 setosa
## 4 setosa
## 5 setosa
## 6 setosa
## 7 setosa
## 8 setosa
## 9 setosa
## 10 setosa
## # ... with 50 more rows
iris_ranger %>%
predict(iris_testing) %>%
bind_cols(iris_testing) %>%
glimpse()
## Rows: 60
## Columns: 5
## $ .pred_class <fct> setosa, setosa, setosa, setosa, setosa, setosa, setosa...
## $ Sepal.Length <dbl> -1.3550610, -1.0099040, -0.5496946, -1.2400087, -1.815...
## $ Sepal.Width <dbl> 0.1161296, 1.0883771, 1.8175627, -0.3699942, -0.369994...
## $ Petal.Width <dbl> -1.2129756, -1.2129756, -0.9573133, -1.3408067, -1.340...
## $ Species <fct> setosa, setosa, setosa, setosa, setosa, setosa, setosa...
iris_rf %>%
predict(iris_testing) %>%
bind_cols(iris_testing) %>%
glimpse()
## Rows: 60
## Columns: 5
## $ .pred_class <fct> setosa, setosa, setosa, setosa, setosa, setosa, setosa...
## $ Sepal.Length <dbl> -1.3550610, -1.0099040, -0.5496946, -1.2400087, -1.815...
## $ Sepal.Width <dbl> 0.1161296, 1.0883771, 1.8175627, -0.3699942, -0.369994...
## $ Petal.Width <dbl> -1.2129756, -1.2129756, -0.9573133, -1.3408067, -1.340...
## $ Species <fct> setosa, setosa, setosa, setosa, setosa, setosa, setosa...
iris_ranger %>%
predict(iris_testing) %>%
bind_cols(iris_testing) %>%
metrics(truth = Species, estimate = .pred_class)
## # A tibble: 2 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy multiclass 0.933
## 2 kap multiclass 0.899
iris_rf %>%
predict(iris_testing) %>%
bind_cols(iris_testing) %>%
metrics(truth = Species, estimate = .pred_class)
## # A tibble: 2 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy multiclass 0.95
## 2 kap multiclass 0.924
We can also compute other metrics if we first obtain the estimated class probabilities:
iris_prob <- predict(iris_ranger, iris_testing, type = "prob") %>%
bind_cols(predict(iris_ranger, iris_testing)) %>%
bind_cols(select(iris_testing, Species))
iris_prob_tr <- predict(iris_ranger, iris_training, type = "prob") %>%
bind_cols(predict(iris_ranger, iris_training)) %>%
bind_cols(select(iris_training, Species))
# Model evaluation with sensitivity
iris_prob %>%
metrics(Species, estimate = .pred_class, .pred_setosa:.pred_virginica) %>%
bind_rows(sens(iris_prob, Species, estimate = .pred_class))
## # A tibble: 5 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy multiclass 0.933
## 2 kap multiclass 0.899
## 3 mn_log_loss multiclass 0.264
## 4 roc_auc hand_till 0.985
## 5 sens macro 0.939
iris_prob_tr %>%
  metrics(Species, estimate = .pred_class, .pred_setosa:.pred_virginica)
## # A tibble: 4 x 3
##   .metric     .estimator .estimate
##   <chr>       <chr>          <dbl>
## 1 accuracy    multiclass     0.978
## 2 kap         multiclass     0.967
## 3 mn_log_loss multiclass     0.119
## 4 roc_auc     hand_till      0.998
As expected, the metrics on the training set are more optimistic than those on the test set.
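With the probability columns in hand, we can also draw per-class (one-vs-all) ROC curves with yardstick's roc_curve(); a quick sketch, plot not shown:
iris_prob %>%
  roc_curve(Species, .pred_setosa:.pred_virginica) %>%
  autoplot()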
Now let's see how to perform cross-validation for hyperparameter tuning using the libraries under tidymodels.
The example here is taken directly from https://www.tidymodels.org/start/resampling/.
The data we will be using are from Hill, LaPan, Li, and Haney (2007). We can get them from the modeldata package; we will create a model to predict image segmentation quality.
data(cells, package = "modeldata")
cells
## # A tibble: 2,019 x 58
## case class angle_ch_1 area_ch_1 avg_inten_ch_1 avg_inten_ch_2 avg_inten_ch_3
## <fct> <fct> <dbl> <int> <dbl> <dbl> <dbl>
## 1 Test PS 143. 185 15.7 4.95 9.55
## 2 Train PS 134. 819 31.9 207. 69.9
## 3 Train WS 107. 431 28.0 116. 63.9
## 4 Train PS 69.2 298 19.5 102. 28.2
## 5 Test PS 2.89 285 24.3 112. 20.5
## 6 Test WS 40.7 172 326. 654. 129.
## 7 Test WS 174. 177 260. 596. 124.
## 8 Test PS 180. 251 18.3 5.73 17.2
## 9 Test WS 18.9 495 16.1 89.5 13.7
## 10 Test WS 153. 384 17.7 89.9 20.4
## # ... with 2,009 more rows, and 51 more variables: avg_inten_ch_4 <dbl>,
## # convex_hull_area_ratio_ch_1 <dbl>, convex_hull_perim_ratio_ch_1 <dbl>,
## # diff_inten_density_ch_1 <dbl>, diff_inten_density_ch_3 <dbl>,
## # diff_inten_density_ch_4 <dbl>, entropy_inten_ch_1 <dbl>,
## # entropy_inten_ch_3 <dbl>, entropy_inten_ch_4 <dbl>,
## # eq_circ_diam_ch_1 <dbl>, eq_ellipse_lwr_ch_1 <dbl>,
## # eq_ellipse_oblate_vol_ch_1 <dbl>, eq_ellipse_prolate_vol_ch_1 <dbl>,
## # eq_sphere_area_ch_1 <dbl>, eq_sphere_vol_ch_1 <dbl>,
## # fiber_align_2_ch_3 <dbl>, fiber_align_2_ch_4 <dbl>,
## # fiber_length_ch_1 <dbl>, fiber_width_ch_1 <dbl>, inten_cooc_asm_ch_3 <dbl>,
## # inten_cooc_asm_ch_4 <dbl>, inten_cooc_contrast_ch_3 <dbl>,
## # inten_cooc_contrast_ch_4 <dbl>, inten_cooc_entropy_ch_3 <dbl>,
## # inten_cooc_entropy_ch_4 <dbl>, inten_cooc_max_ch_3 <dbl>,
## # inten_cooc_max_ch_4 <dbl>, kurt_inten_ch_1 <dbl>, kurt_inten_ch_3 <dbl>,
## # kurt_inten_ch_4 <dbl>, length_ch_1 <dbl>, neighbor_avg_dist_ch_1 <dbl>,
## # neighbor_min_dist_ch_1 <dbl>, neighbor_var_dist_ch_1 <dbl>,
## # perim_ch_1 <dbl>, shape_bfr_ch_1 <dbl>, shape_lwr_ch_1 <dbl>,
## # shape_p_2_a_ch_1 <dbl>, skew_inten_ch_1 <dbl>, skew_inten_ch_3 <dbl>,
## # skew_inten_ch_4 <dbl>, spot_fiber_count_ch_3 <int>,
## # spot_fiber_count_ch_4 <dbl>, total_inten_ch_1 <int>,
## # total_inten_ch_2 <dbl>, total_inten_ch_3 <int>, total_inten_ch_4 <int>,
## # var_inten_ch_1 <dbl>, var_inten_ch_3 <dbl>, var_inten_ch_4 <dbl>,
## # width_ch_1 <dbl>
There are 2,019 cells and 58 variables. The class column is what we want to predict, and you can see it is a factor.
Some biologists conduct experiments on cells. In drug discovery, a particular type of cell can be treated with either a drug or a control, and then observed to see what the effect is (if any). A common approach for this kind of measurement is cell imaging: different parts of the cell are colored so that its location can be determined.
For example, in the top panel of the image of five cells shown in the original tutorial, the green color defines the boundary of the cell (coloring something called the cytoskeleton) while the blue color defines the nucleus.
Biologists then measure different metrics from the image, such as the size and "oblongness" of each cell. Poorly segmented (PS) cells, such as cells 2, 3, and 4, would make these metrics inaccurate, and hence the downstream analyses. Cells 1 and 5 are examples of well-segmented (WS) cells. If we can predict the labels (PS/WS) accurately, the larger data set can be improved by filtering out the cells most likely to be poorly segmented.
Before training the model, we still follow the train-test split approach: we will use the training set to build models and the test set to estimate how well the model generalizes to unseen data.
We will not use the case column to determine which rows belong to the training and test sets; we drop it and create our own split instead.
Later, when we tune the model, we will resample the training set only.
set.seed(123)
cell_split <- initial_split(cells %>% select(-case),
strata = class)
cell_train <- training(cell_split)
cell_test <- testing(cell_split)
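Since we stratified on class, the training and test sets should keep roughly the original PS/WS proportions; a quick check (output omitted):
cell_train %>%
  count(class) %>%
  mutate(prop = n / sum(n))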
We will tune a decision tree, with cost complexity and tree depth as hyperparameters. The tune() placeholders mark the arguments to be tuned, and grid_regular() builds a regular grid of candidate values:
tune_spec <-
decision_tree(
cost_complexity = tune(),
tree_depth = tune()
) %>%
set_engine("rpart") %>%
set_mode("classification")
tree_grid <- grid_regular(cost_complexity(),
tree_depth(),
levels = 5)
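Because grid_regular() crosses 5 levels of each parameter, the grid holds 5 x 5 = 25 candidate combinations:
nrow(tree_grid)  # 25: 5 cost_complexity values x 5 tree_depth values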
We will use k-fold cross-validation with k = 10 (the default for vfold_cv()). We also bundle the model specification and formula into a workflow, then tune over the grid with tune_grid():
set.seed(234)
cell_folds <- vfold_cv(cell_train)
set.seed(345)
tree_wf <- workflow() %>%
add_model(tune_spec) %>%
add_formula(class ~ .)
tree_res <-
tree_wf %>%
tune_grid(
resamples = cell_folds,
grid = tree_grid
)
tree_res
## # Tuning results
## # 10-fold cross-validation
## # A tibble: 10 x 4
## splits id .metrics .notes
## <list> <chr> <list> <list>
## 1 <split [1.4K/152]> Fold01 <tibble [50 x 6]> <tibble [0 x 1]>
## 2 <split [1.4K/152]> Fold02 <tibble [50 x 6]> <tibble [0 x 1]>
## 3 <split [1.4K/152]> Fold03 <tibble [50 x 6]> <tibble [0 x 1]>
## 4 <split [1.4K/152]> Fold04 <tibble [50 x 6]> <tibble [0 x 1]>
## 5 <split [1.4K/152]> Fold05 <tibble [50 x 6]> <tibble [0 x 1]>
## 6 <split [1.4K/151]> Fold06 <tibble [50 x 6]> <tibble [0 x 1]>
## 7 <split [1.4K/151]> Fold07 <tibble [50 x 6]> <tibble [0 x 1]>
## 8 <split [1.4K/151]> Fold08 <tibble [50 x 6]> <tibble [0 x 1]>
## 9 <split [1.4K/151]> Fold09 <tibble [50 x 6]> <tibble [0 x 1]>
## 10 <split [1.4K/151]> Fold10 <tibble [50 x 6]> <tibble [0 x 1]>
Let's visualize the tuning results:
tree_res %>%
collect_metrics() %>%
mutate(tree_depth = factor(tree_depth)) %>%
ggplot(aes(cost_complexity, mean, color = tree_depth)) +
geom_line(size = 1.5, alpha = 0.6) +
geom_point(size = 2) +
facet_wrap(~ .metric, scales = "free", nrow = 2) +
scale_x_log10(labels = scales::label_number()) +
scale_color_viridis_d(option = "plasma", begin = .9, end = 0)
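Besides the plot, show_best() lists the top candidates for a given metric (output omitted):
tree_res %>%
  show_best("roc_auc", n = 5)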
Get the best parameter combination, according to ROC AUC:
best_tree <- tree_res %>%
select_best("roc_auc")
best_tree
## # A tibble: 1 x 3
## cost_complexity tree_depth .config
## <dbl> <int> <fct>
## 1 0.0000000001 4 Model06
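As an alternative to select_best(), tune's select_by_one_std_err() picks the simplest model whose performance is within one standard error of the numerically best one; a sketch using tree_depth as the complexity measure:
tree_res %>%
  select_by_one_std_err(tree_depth, metric = "roc_auc")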
Update the workflow with the values of the best parameters:
final_wf <-
tree_wf %>%
finalize_workflow(best_tree)
final_wf
## == Workflow ====================================================================
## Preprocessor: Formula
## Model: decision_tree()
##
## -- Preprocessor ----------------------------------------------------------------
## class ~ .
##
## -- Model -----------------------------------------------------------------------
## Decision Tree Model Specification (classification)
##
## Main Arguments:
## cost_complexity = 1e-10
## tree_depth = 4
##
## Computational engine: rpart
Now, fit the model again to the whole training set with the finalized workflow.
final_tree <-
final_wf %>%
fit(data = cell_train)
final_tree
## == Workflow [trained] ==========================================================
## Preprocessor: Formula
## Model: decision_tree()
##
## -- Preprocessor ----------------------------------------------------------------
## class ~ .
##
## -- Model -----------------------------------------------------------------------
## n= 1515
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 1515 540 PS (0.64356436 0.35643564)
## 2) total_inten_ch_2< 39344 623 28 PS (0.95505618 0.04494382) *
## 3) total_inten_ch_2>=39344 892 380 WS (0.42600897 0.57399103)
## 6) fiber_width_ch_1< 10.6594 349 122 PS (0.65042980 0.34957020)
## 12) avg_inten_ch_1< 195.0149 287 74 PS (0.74216028 0.25783972) *
## 13) avg_inten_ch_1>=195.0149 62 14 WS (0.22580645 0.77419355) *
## 7) fiber_width_ch_1>=10.6594 543 153 WS (0.28176796 0.71823204)
## 14) shape_p_2_a_ch_1>=1.215913 381 136 WS (0.35695538 0.64304462)
## 28) total_inten_ch_4>=102129.5 65 25 PS (0.61538462 0.38461538) *
## 29) total_inten_ch_4< 102129.5 316 96 WS (0.30379747 0.69620253) *
## 15) shape_p_2_a_ch_1< 1.215913 162 17 WS (0.10493827 0.89506173) *
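We can use the vip package to visualize variable importance scores for the predictors in the fitted tree: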
library(vip)
##
## Attaching package: 'vip'
## The following object is masked from 'package:utils':
##
## vi
final_tree %>%
pull_workflow_fit() %>%
vip()
Finally, let's evaluate the model using the test set. The last_fit() function fits the finalized workflow on the full training set and evaluates it on the test set in one step:
final_fit <-
final_wf %>%
last_fit(cell_split)
final_fit %>%
collect_metrics()
## # A tibble: 2 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy binary 0.786
## 2 roc_auc binary 0.850
final_fit %>%
collect_predictions() %>%
roc_curve(class, .pred_PS) %>%
autoplot()
cells_pred <- predict(final_tree, cell_test) %>%
bind_cols(predict(final_tree, cell_test, type="prob")) %>%
bind_cols(select(cell_test,class))
cells_pred
## # A tibble: 504 x 4
## .pred_class .pred_PS .pred_WS class
## <fct> <dbl> <dbl> <fct>
## 1 PS 0.955 0.0449 PS
## 2 WS 0.304 0.696 WS
## 3 PS 0.955 0.0449 PS
## 4 WS 0.304 0.696 PS
## 5 PS 0.955 0.0449 PS
## 6 PS 0.955 0.0449 PS
## 7 PS 0.615 0.385 WS
## 8 PS 0.955 0.0449 PS
## 9 WS 0.304 0.696 WS
## 10 PS 0.955 0.0449 PS
## # ... with 494 more rows
Let's measure the accuracy by hand; it matches the value reported by collect_metrics():
cells_pred %>%
accuracy(truth = class, .pred_class)
## # A tibble: 1 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy binary 0.786
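Beyond accuracy, a confusion matrix shows where the tree confuses PS and WS (output omitted):
cells_pred %>%
  conf_mat(truth = class, .pred_class)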