Classification model with tidymodels and UCI glass identification data set
On the identification of glass using machine learning and tidymodels
In the sequel we shall explore UCI's Glass Identification data set. We shall build a regression model that predicts the refractive index of the glass from the elements the glass is made of, and a classification model that predicts the type of glass from the same elements, examining the importance of the variables in each case. This is a small data set, so we shall use resampling methods to evaluate the models.
Load libraries
library(tidyverse)
library(tidymodels)
Load the data set
The data set can be downloaded from the UCI Machine Learning Repository. This gives us three files: glass.data, glass.names and glass.tag. The first contains the data in CSV format without headers. The second contains a description of the variables, from which we can also extract the variable names.
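If you prefer to script the download, something along the following lines should work (the URL below is the usual path for this data set on the UCI archive, so treat it as an assumption to verify):
url <- "https://archive.ics.uci.edu/ml/machine-learning-databases/glass/glass.data"
# save the raw CSV next to the analysis script
download.file(url, destfile = "./glass.data")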
glass_levels <- c("building_float",
"building_non_float",
"vehicle_float",
"containers",
"tableware",
"headlamps")
glass <- read_csv("./glass.data", col_names = FALSE) |>
rename(id = 1, reflective_index = 2, Na = 3, Mg = 4,
Al = 5, Si = 6, K = 7, Ca = 8, Ba = 9, Fe = 10,
type_of_glass = 11) |>
mutate(type_of_glass = factor(type_of_glass, levels = c(1,2,3,5,6,7),
labels = glass_levels))
rm(glass_levels)
Let us also copy the specification of the variables from the glass.names file.
Attribute Information:
Id number: 1 to 214
RI: refractive index
Na: Sodium (unit measurement: weight percent in corresponding oxide)
Mg: Magnesium
Al: Aluminum
Si: Silicon
K: Potassium
Ca: Calcium
Ba: Barium
Fe: Iron
Type of glass: (class attribute)
1 building_windows_float_processed
2 building_windows_non_float_processed
3 vehicle_windows_float_processed
4 vehicle_windows_non_float_processed (none in this database)
5 containers
6 tableware
7 headlamps
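As a quick check that the recoding above matches this documentation, we can count the classes; the totals should agree with the frequencies reported in the summary further down.
glass |>
  count(type_of_glass)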
Exploratory Data Analysis (EDA)
Let us start to analyse the data:
glass
# A tibble: 214 × 11
id reflective_index Na Mg Al Si K Ca Ba Fe
<dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 1 1.52 13.6 4.49 1.1 71.8 0.06 8.75 0 0
2 2 1.52 13.9 3.6 1.36 72.7 0.48 7.83 0 0
3 3 1.52 13.5 3.55 1.54 73.0 0.39 7.78 0 0
4 4 1.52 13.2 3.69 1.29 72.6 0.57 8.22 0 0
5 5 1.52 13.3 3.62 1.24 73.1 0.55 8.07 0 0
6 6 1.52 12.8 3.61 1.62 73.0 0.64 8.07 0 0.26
7 7 1.52 13.3 3.6 1.14 73.1 0.58 8.17 0 0
8 8 1.52 13.2 3.61 1.05 73.2 0.57 8.24 0 0
9 9 1.52 14.0 3.58 1.37 72.1 0.56 8.3 0 0
10 10 1.52 13 3.6 1.36 73.0 0.57 8.4 0 0.11
# ℹ 204 more rows
# ℹ 1 more variable: type_of_glass <fct>
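The documentation reports no missing attribute values; before going further, a one-liner confirms that the import did not introduce any:
glass |>
  summarise(across(everything(), \(x) sum(is.na(x))))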
glimpse(glass)
Rows: 214
Columns: 11
$ id <dbl> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16…
$ reflective_index <dbl> 1.52101, 1.51761, 1.51618, 1.51766, 1.51742, 1.51596,…
$ Na <dbl> 13.64, 13.89, 13.53, 13.21, 13.27, 12.79, 13.30, 13.1…
$ Mg <dbl> 4.49, 3.60, 3.55, 3.69, 3.62, 3.61, 3.60, 3.61, 3.58,…
$ Al <dbl> 1.10, 1.36, 1.54, 1.29, 1.24, 1.62, 1.14, 1.05, 1.37,…
$ Si <dbl> 71.78, 72.73, 72.99, 72.61, 73.08, 72.97, 73.09, 73.2…
$ K <dbl> 0.06, 0.48, 0.39, 0.57, 0.55, 0.64, 0.58, 0.57, 0.56,…
$ Ca <dbl> 8.75, 7.83, 7.78, 8.22, 8.07, 8.07, 8.17, 8.24, 8.30,…
$ Ba <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
$ Fe <dbl> 0.00, 0.00, 0.00, 0.00, 0.00, 0.26, 0.00, 0.00, 0.00,…
$ type_of_glass <fct> building_float, building_float, building_float, build…
summary(glass)
       id        reflective_index       Na             Mg
Min. : 1.00 Min. :1.511 Min. :10.73 Min. :0.000
1st Qu.: 54.25 1st Qu.:1.517 1st Qu.:12.91 1st Qu.:2.115
Median :107.50 Median :1.518 Median :13.30 Median :3.480
Mean :107.50 Mean :1.518 Mean :13.41 Mean :2.685
3rd Qu.:160.75 3rd Qu.:1.519 3rd Qu.:13.82 3rd Qu.:3.600
Max. :214.00 Max. :1.534 Max. :17.38 Max. :4.490
Al Si K Ca
Min. :0.290 Min. :69.81 Min. :0.0000 Min. : 5.430
1st Qu.:1.190 1st Qu.:72.28 1st Qu.:0.1225 1st Qu.: 8.240
Median :1.360 Median :72.79 Median :0.5550 Median : 8.600
Mean :1.445 Mean :72.65 Mean :0.4971 Mean : 8.957
3rd Qu.:1.630 3rd Qu.:73.09 3rd Qu.:0.6100 3rd Qu.: 9.172
Max. :3.500 Max. :75.41 Max. :6.2100 Max. :16.190
Ba Fe type_of_glass
Min. :0.000 Min. :0.00000 building_float :70
1st Qu.:0.000 1st Qu.:0.00000 building_non_float:76
Median :0.000 Median :0.00000 vehicle_float :17
Mean :0.175 Mean :0.05701 containers :13
3rd Qu.:0.000 3rd Qu.:0.10000 tableware : 9
Max. :3.150 Max. :0.51000 headlamps :29
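Since we shall later regress the refractive index on the elements, a quick correlation matrix gives a first impression of which elements move together with it (a rough sketch; the plots below are more informative):
glass |>
  select(reflective_index:Fe) |>
  cor() |>
  round(2)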
glass |>
pivot_longer(cols = Na:Fe, values_to = "value") |>
ggplot(aes(x = name, y = value, fill = type_of_glass)) +
geom_col(show.legend = FALSE) +
facet_wrap(~ type_of_glass)
The graph shows that the elemental composition differs clearly across the types of glass.
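One caveat: geom_col() stacks the sums of the measurements, so bar heights also reflect how many observations each class has. A boxplot variant of the same plot shows the composition per observation instead:
glass |>
  pivot_longer(cols = Na:Fe, values_to = "value") |>
  ggplot(aes(x = name, y = value, fill = type_of_glass)) +
  geom_boxplot(show.legend = FALSE) +
  facet_wrap(~ type_of_glass)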
glass |>
mutate(type_of_glass = fct_reorder(type_of_glass, reflective_index)) |>
ggplot(aes(x = type_of_glass,
y = reflective_index,
colour = type_of_glass,
fill = type_of_glass)) +
geom_boxplot(show.legend = FALSE, alpha = 0.3)
The refractive index appears to differ across the types of glass.
Building the first model
Let us fit a linear regression model with the elements (Na, Mg, Al, Si, K, Ca, Ba, Fe) as predictors and the refractive index of the glass (stored in the reflective_index column) as the outcome. To this end, let us start by creating the model specification.
lm_spec <- linear_reg() |>
set_engine("lm")Let us also build a recipe to change the role of the variables id and type_of_glass which are not going to be part of the model.
index_rec <- recipe(reflective_index ~ ., data = glass) |>
update_role(id, type_of_glass, new_role = "id")
We don't need to pre-process the data, because a linear model's fit is not affected by normalisation of the predictors. Let us create a workflow to connect the model with the recipe, and fit it on the data.
index_fit <- workflow() |>
add_model(lm_spec) |>
add_recipe(index_rec) |>
fit(data = glass)
Since the data set is not very big, splitting it into training and test sets would shrink the available data even further, so we shall instead use resampling techniques to evaluate the model.
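We shall use the bootstrap below; stratified k-fold cross-validation would be an equally valid choice, for instance:
# a 10-fold alternative, stratified on the glass type like the bootstrap
set.seed(123)
index_folds <- vfold_cv(glass, v = 10, strata = type_of_glass)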
set.seed(123)
index_boot <- bootstraps(glass, strata = type_of_glass)
We shall now fit the model on the bootstrap resamples and collect the results for evaluation.
index_res <- fit_resamples(
index_fit,
index_boot,
control = control_resamples(save_pred = TRUE)
)
Let us first see the metrics:
collect_metrics(index_res)
# A tibble: 2 × 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 rmse standard 0.00109 25 0.0000264 Preprocessor1_Model1
2 rsq standard 0.868 25 0.00657 Preprocessor1_Model1
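To judge whether an RMSE of about 0.001 is small, it helps to compare it with the spread of the outcome:
# the spread of the refractive index is roughly 0.023 (see summary() above)
glass |>
  summarise(spread = max(reflective_index) - min(reflective_index))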
As we can see, the mean RMSE is tiny compared with the roughly 0.023 spread of the refractive index, and its standard error is small as well. This suggests that the model predicts the refractive index accurately. Let us also see a visualisation of this:
collect_predictions(index_res) |>
group_by(.row) |>
summarise(estimate = mean(.pred), error = sd(.pred)) |>
inner_join(glass |> mutate(.row = row_number()), by = ".row") |>
ggplot(aes(x = reflective_index, y = estimate)) +
geom_abline(linetype = "dashed") +
geom_point(aes(colour = type_of_glass), alpha = 0.4)
Finally, let us see the importance of the variables in the model, using the package vip.
index_fit |>
vip::vi() |>
mutate(Variable = fct_reorder(Variable, Importance)) |>
ggplot(aes(y = Variable, x = Importance, fill = Variable)) +
geom_col(show.legend = FALSE)
Building the second model
Once we have ascertained that the refractive index is linked to the other variables, let us use those variables to predict the type of glass. This is a classification problem, and we shall compare two models: a random forest and a support vector machine (SVM). We could tune the hyperparameters of the models, but for this article we shall keep fixed, sensible values for the comparison.
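For reference, a tuned random forest specification would look roughly like the sketch below; tune_grid() on the bootstrap resamples would then search over the marked hyperparameters. We do not run this here.
rf_tune_spec <- rand_forest(trees = 1000, mtry = tune(), min_n = tune()) |>
  set_engine("ranger", importance = "permutation") |>
  set_mode("classification")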
Let us start by specifying the models.
rf_spec <- rand_forest(trees = 1000) |>
set_engine("ranger", importance = "permutation") |>
set_mode("classification")
svm_spec <- svm_rbf(cost = 0.5) |>
set_engine("kernlab") |>
set_mode("classification")Let us also prepare a recipe for the data set. This time we need to standardise the variables, so there is some pre-processing to perform: check that there are no zero-variance variable, and standardise all the predictors.
glass_rec <- recipe(type_of_glass ~ ., data = glass) |>
update_role(id, new_role = "id") |>
update_role(reflective_index, new_role = "mute") |>
step_zv(all_numeric_predictors()) |>
step_normalize(all_numeric_predictors())
We can now create a workflow. We shall only put the recipe in the workflow, because we are still comparing the models, so we wish to add each model only when we use it.
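As a quick sanity check before that, we can prep the recipe and bake it on the training data; the normalised predictors should come out with mean zero and unit variance:
glass_rec |>
  prep() |>
  bake(new_data = NULL) |>
  glimpse()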
glass_wf <- workflow() |>
add_recipe(glass_rec)
Again, we shall bootstrap the data set and evaluate the two models on the resamples.
set.seed(123)
glass_boot <- bootstraps(glass, strata = type_of_glass)
Let us fit the two models on the bootstrap resamples.
set.seed(234)
rf_res <- glass_wf |>
add_model(rf_spec) |>
fit_resamples(
resamples = glass_boot,
metrics = metric_set(roc_auc, spec, sens),
control = control_resamples(save_pred = TRUE)
)
svm_res <- glass_wf |>
add_model(svm_spec) |>
fit_resamples(
resamples = glass_boot,
metrics = metric_set(roc_auc, spec, sens),
control = control_resamples(save_pred = TRUE)
)
Let us now evaluate the models. First let us collect and compare the metrics:
collect_metrics(rf_res)
# A tibble: 3 × 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 roc_auc hand_till 0.932 25 0.00434 Preprocessor1_Model1
2 sens macro 0.662 25 0.0129 Preprocessor1_Model1
3 spec macro 0.935 25 0.00263 Preprocessor1_Model1
collect_metrics(svm_res)
# A tibble: 3 × 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 roc_auc hand_till 0.778 25 0.00814 Preprocessor1_Model1
2 sens macro 0.498 25 0.0145 Preprocessor1_Model1
3 spec macro 0.916 25 0.00260 Preprocessor1_Model1
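For a side-by-side view, we can bind the two metric tables together:
bind_rows(
  collect_metrics(rf_res) |> mutate(model = "rf"),
  collect_metrics(svm_res) |> mutate(model = "svm")
) |>
  select(model, .metric, mean, std_err)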
At first glance, the random forest has performed better, with an area under the ROC curve of 0.932 against 0.778 for the SVM. Let us also visualise the ROC curves, one per class.
rf_res |>
unnest(.predictions) |>
mutate(model = "rf") |>
bind_rows(
svm_res |>
unnest(.predictions) |>
mutate(model = "svm")
) |>
group_by(model) |>
roc_curve(type_of_glass, .pred_building_float,
.pred_building_non_float, .pred_vehicle_float,
.pred_containers, .pred_tableware, .pred_headlamps) |>
ggplot(aes(x = 1 - specificity, y = sensitivity, colour = model)) +
geom_abline(linetype = "dashed") +
geom_line() +
facet_wrap(~ .level) +
coord_equal()
The vehicle_float class shows how poorly the SVM performs, with its curve even dipping below the diagonal. It is clear, therefore, that the random forest performs best. For this model, let us see the variable importance, again using the vip package.
glass_wf |>
add_model(rf_spec) |>
fit(glass) |>
vip::vi() |>
mutate(Variable = fct_reorder(Variable, Importance)) |>
ggplot(aes(y = Variable, x = Importance, fill = Variable)) +
geom_col(show.legend = FALSE)
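Beyond ROC curves and variable importance, a resampled confusion matrix shows which classes the random forest confuses with one another. A sketch using conf_mat_resampled() from tune, which relies on the predictions we saved earlier:
rf_res |>
  conf_mat_resampled(tidy = FALSE) |>
  autoplot(type = "heatmap")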