Classification model with tidymodels and UCI glass identification data set

Author

Andrea Carpignani

Published

20 January 2024

On the identification of glass using machine learning and tidymodels

In what follows we shall explore UCI’s Glass Identification data set. First we shall build a regression model to predict the refractive index of the glass from the elements the glass is made of; then we shall build a classification model for the type of glass based on the same elements, and for both we shall look at the importance of the variables. Since this is a small data set, we shall use resampling methods to evaluate the models.

Load libraries

library(tidyverse)
library(tidymodels)

Load the data set

The data set can be downloaded from the following link. This gives us three files: glass.data, glass.names and glass.tag. The first contains the data in CSV format without headers. The second contains a description of the variables, from which we can also extract the variable names.

glass_levels <- c("building_float",
                  "building_non_float", 
                  "vehicle_float",
                  "containers",
                  "tableware",
                  "headlamps")

glass <- read_csv("./glass.data", col_names = FALSE) |> 
    rename(id = 1, reflective_index = 2, Na = 3, Mg = 4,
           Al = 5, Si = 6, K = 7, Ca = 8, Ba = 9 , Fe = 10,
           type_of_glass = 11) |> 
    mutate(type_of_glass = factor(type_of_glass, levels = c(1,2,3,5,6,7),
                                  labels = glass_levels))

rm(glass_levels)
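
As a quick sanity check, we can count the observations per class; the counts should match the class distribution documented in glass.names.

glass |> count(type_of_glass)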

Let us also copy the specifications of the variables from the glass.names file.

  • Attribute Information:

    • Id number: 1 to 214

    • RI: refractive index

    • Na: Sodium (unit measurement: weight percent in corresponding oxide, as are attributes 4-10)

    • Mg: Magnesium

    • Al: Aluminum

    • Si: Silicon

    • K: Potassium

    • Ca: Calcium

    • Ba: Barium

    • Fe: Iron

    • Type of glass: (class attribute)

      • 1 building_windows_float_processed

      • 2 building_windows_non_float_processed

      • 3 vehicle_windows_float_processed

      • 4 vehicle_windows_non_float_processed (none in this database)

      • 5 containers

      • 6 tableware

      • 7 headlamps

Exploratory Data Analysis (EDA)

Let us start to analyse the data:

glass
# A tibble: 214 × 11
      id reflective_index    Na    Mg    Al    Si     K    Ca    Ba    Fe
   <dbl>            <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
 1     1             1.52  13.6  4.49  1.1   71.8  0.06  8.75     0  0   
 2     2             1.52  13.9  3.6   1.36  72.7  0.48  7.83     0  0   
 3     3             1.52  13.5  3.55  1.54  73.0  0.39  7.78     0  0   
 4     4             1.52  13.2  3.69  1.29  72.6  0.57  8.22     0  0   
 5     5             1.52  13.3  3.62  1.24  73.1  0.55  8.07     0  0   
 6     6             1.52  12.8  3.61  1.62  73.0  0.64  8.07     0  0.26
 7     7             1.52  13.3  3.6   1.14  73.1  0.58  8.17     0  0   
 8     8             1.52  13.2  3.61  1.05  73.2  0.57  8.24     0  0   
 9     9             1.52  14.0  3.58  1.37  72.1  0.56  8.3      0  0   
10    10             1.52  13    3.6   1.36  73.0  0.57  8.4      0  0.11
# ℹ 204 more rows
# ℹ 1 more variable: type_of_glass <fct>
glimpse(glass)
Rows: 214
Columns: 11
$ id               <dbl> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16…
$ reflective_index <dbl> 1.52101, 1.51761, 1.51618, 1.51766, 1.51742, 1.51596,…
$ Na               <dbl> 13.64, 13.89, 13.53, 13.21, 13.27, 12.79, 13.30, 13.1…
$ Mg               <dbl> 4.49, 3.60, 3.55, 3.69, 3.62, 3.61, 3.60, 3.61, 3.58,…
$ Al               <dbl> 1.10, 1.36, 1.54, 1.29, 1.24, 1.62, 1.14, 1.05, 1.37,…
$ Si               <dbl> 71.78, 72.73, 72.99, 72.61, 73.08, 72.97, 73.09, 73.2…
$ K                <dbl> 0.06, 0.48, 0.39, 0.57, 0.55, 0.64, 0.58, 0.57, 0.56,…
$ Ca               <dbl> 8.75, 7.83, 7.78, 8.22, 8.07, 8.07, 8.17, 8.24, 8.30,…
$ Ba               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
$ Fe               <dbl> 0.00, 0.00, 0.00, 0.00, 0.00, 0.26, 0.00, 0.00, 0.00,…
$ type_of_glass    <fct> building_float, building_float, building_float, build…
summary(glass)
       id         reflective_index       Na              Mg       
 Min.   :  1.00   Min.   :1.511    Min.   :10.73   Min.   :0.000  
 1st Qu.: 54.25   1st Qu.:1.517    1st Qu.:12.91   1st Qu.:2.115  
 Median :107.50   Median :1.518    Median :13.30   Median :3.480  
 Mean   :107.50   Mean   :1.518    Mean   :13.41   Mean   :2.685  
 3rd Qu.:160.75   3rd Qu.:1.519    3rd Qu.:13.82   3rd Qu.:3.600  
 Max.   :214.00   Max.   :1.534    Max.   :17.38   Max.   :4.490  
       Al              Si              K                Ca        
 Min.   :0.290   Min.   :69.81   Min.   :0.0000   Min.   : 5.430  
 1st Qu.:1.190   1st Qu.:72.28   1st Qu.:0.1225   1st Qu.: 8.240  
 Median :1.360   Median :72.79   Median :0.5550   Median : 8.600  
 Mean   :1.445   Mean   :72.65   Mean   :0.4971   Mean   : 8.957  
 3rd Qu.:1.630   3rd Qu.:73.09   3rd Qu.:0.6100   3rd Qu.: 9.172  
 Max.   :3.500   Max.   :75.41   Max.   :6.2100   Max.   :16.190  
       Ba              Fe                     type_of_glass
 Min.   :0.000   Min.   :0.00000   building_float    :70   
 1st Qu.:0.000   1st Qu.:0.00000   building_non_float:76   
 Median :0.000   Median :0.00000   vehicle_float     :17   
 Mean   :0.175   Mean   :0.05701   containers        :13   
 3rd Qu.:0.000   3rd Qu.:0.10000   tableware         : 9   
 Max.   :3.150   Max.   :0.51000   headlamps         :29   
glass |> 
    pivot_longer(cols = Na:Fe, values_to = "value") |> 
    ggplot(aes(x = name, y = value, fill = type_of_glass)) +
    geom_col(show.legend = FALSE) +
    facet_wrap(~ type_of_glass)

The chart suggests a relationship between the elemental composition and the type of glass, although the bar heights also reflect how many samples there are of each type.

glass |> 
    mutate(type_of_glass = fct_reorder(type_of_glass, reflective_index)) |> 
    ggplot(aes(x = type_of_glass, 
               y = reflective_index, 
               colour = type_of_glass,
               fill = type_of_glass)) +
    geom_boxplot(show.legend = FALSE, alpha = 0.3)

The refractive index also seems to differ between the types of glass.
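
To back this up with numbers, here is a small summary of the refractive index by type (a quick sketch, not part of the modelling pipeline):

glass |> 
    group_by(type_of_glass) |> 
    summarise(median_ri = median(reflective_index),
              iqr_ri = IQR(reflective_index)) |> 
    arrange(median_ri)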

Building the first model

Let us fit a linear regression model with the elements (Na, Mg, Al, Si, K, Ca, Ba, Fe) as predictors and the refractive index of the glass as the outcome. To this end, let us start by creating the model specification.

lm_spec <- linear_reg() |> 
    set_engine("lm")

Let us also build a recipe to change the role of the variables id and type_of_glass, which are not going to be part of the model.

index_rec <- recipe(reflective_index ~ ., data = glass) |> 
    update_role(id, type_of_glass, new_role = "id")

We don’t need to pre-process the data, because a linear model’s fit is not affected by normalisation. Let us create a workflow to connect the model with the recipe, and let us fit it to the data.

index_fit <- workflow() |> 
    add_model(lm_spec) |> 
    add_recipe(index_rec) |> 
    fit(data = glass)

Since the data set is not very big, splitting it into training and test sets would shrink it even further, so we shall instead use resampling techniques to evaluate the model.

set.seed(123)
index_boot <- bootstraps(glass, strata = type_of_glass)

We shall now fit the model on the bootstrap resamples (bootstraps() creates 25 of them by default) and collect the results for evaluation.

index_res <- fit_resamples(
    index_fit,
    index_boot,
    control = control_resamples(save_pred = TRUE)
)

Let us first see the metrics:

collect_metrics(index_res)
# A tibble: 2 × 6
  .metric .estimator    mean     n   std_err .config             
  <chr>   <chr>        <dbl> <int>     <dbl> <chr>               
1 rmse    standard   0.00109    25 0.0000264 Preprocessor1_Model1
2 rsq     standard   0.868      25 0.00657   Preprocessor1_Model1

As we can see, the mean RMSE is very small (about 0.001, against an overall spread of the refractive index of roughly 0.02), and so is its standard error. This suggests that the model predicts the refractive index well. Let us also see a visualisation of this:

collect_predictions(index_res) |> 
    group_by(.row) |>                        # one group per original observation
    summarise(estimate = mean(.pred),        # average prediction across resamples
              error = sd(.pred)) |> 
    inner_join(glass |> mutate(.row = row_number()), by = ".row") |> 
    ggplot(aes(x = reflective_index, y = estimate)) +
    geom_abline(linetype = "dashed") +
    geom_point(aes(colour = type_of_glass), alpha = 0.4) 

Finally, let us see the importance of the variables in the model, using the vip package.

index_fit |> 
    vip::vi() |> 
    mutate(Variable = fct_reorder(Variable, Importance)) |>
    ggplot(aes(y = Variable, x = Importance, fill = Variable)) +
    geom_col(show.legend = FALSE)
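
For a plain lm, vip’s importance scores are based on the absolute value of the t-statistic, so we can cross-check the ranking against the coefficient table itself; a quick sketch using extract_fit_parsnip() and tidy():

index_fit |> 
    extract_fit_parsnip() |>          # pull the underlying lm fit out of the workflow
    tidy() |>                         # coefficient table: estimate, std.error, statistic, p.value
    filter(term != "(Intercept)") |> 
    arrange(desc(abs(statistic)))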

Building the second model

Now that we have ascertained that the refractive index is linked to the other variables, let us use those variables to predict the type of glass. This is a classification problem, and we shall compare two models: a random forest and a support vector machine (SVM). We could tune the hyperparameters of the models, but for this article we shall keep fixed specifications for the comparison.

Let us start by specifying the models.

rf_spec <- rand_forest(trees = 1000) |> 
    set_engine("ranger", importance = "permutation") |> 
    set_mode("classification")

svm_spec <- svm_rbf(cost = 0.5) |> 
    set_engine("kernlab") |> 
    set_mode("classification")

Let us also prepare a recipe for the data set. This time there is some pre-processing to perform: remove any zero-variance predictors, and standardise all the predictors.

glass_rec <- recipe(type_of_glass ~ ., data = glass) |> 
    update_role(id, new_role = "id") |> 
    update_role(reflective_index, new_role = "mute") |> 
    step_zv(all_numeric_predictors()) |>         # drop predictors with zero variance
    step_normalize(all_numeric_predictors())     # centre and scale the predictors
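
To verify what the recipe does to the data, we can prep it and bake the training set (a quick check, not part of the modelling pipeline):

glass_rec |> 
    prep() |>                 # estimate the centring and scaling statistics
    bake(new_data = NULL) |>  # apply the steps to the training data
    glimpse()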

We can now create a workflow. We shall put only the recipe in the workflow: since we are still comparing the two models, we shall add each model just before we use it.

glass_wf <- workflow() |> 
    add_recipe(glass_rec)

Again, we bootstrap the data set so that we can evaluate the two models on the resamples.

set.seed(123)
glass_boot <- bootstraps(glass, strata = type_of_glass)

Let us fit the two models on the bootstrap resamples.

set.seed(234)
rf_res <- glass_wf |> 
    add_model(rf_spec) |> 
    fit_resamples(
        resamples = glass_boot,
        metrics = metric_set(roc_auc, spec, sens),
        control = control_resamples(save_pred = TRUE)
    )

svm_res <- glass_wf |> 
    add_model(svm_spec) |> 
    fit_resamples(
        resamples = glass_boot,
        metrics = metric_set(roc_auc, spec, sens),
        control = control_resamples(save_pred = TRUE)
    )

Let us now evaluate the models. First let us collect and compare the metrics:

collect_metrics(rf_res)
# A tibble: 3 × 6
  .metric .estimator  mean     n std_err .config             
  <chr>   <chr>      <dbl> <int>   <dbl> <chr>               
1 roc_auc hand_till  0.932    25 0.00434 Preprocessor1_Model1
2 sens    macro      0.662    25 0.0129  Preprocessor1_Model1
3 spec    macro      0.935    25 0.00263 Preprocessor1_Model1
collect_metrics(svm_res)
# A tibble: 3 × 6
  .metric .estimator  mean     n std_err .config             
  <chr>   <chr>      <dbl> <int>   <dbl> <chr>               
1 roc_auc hand_till  0.778    25 0.00814 Preprocessor1_Model1
2 sens    macro      0.498    25 0.0145  Preprocessor1_Model1
3 spec    macro      0.916    25 0.00260 Preprocessor1_Model1
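
To line the two models up directly, we can bind the two metric tables into one (a small convenience sketch):

bind_rows(
    collect_metrics(rf_res) |> mutate(model = "random forest"),
    collect_metrics(svm_res) |> mutate(model = "svm")
) |> 
    select(model, .metric, mean, std_err)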

At first glance, it looks like the random forest has performed better, with an area under the ROC curve of 0.932 against 0.778 for the SVM. Let us also visualise the ROC curves.

rf_res |> 
    unnest(.predictions) |> 
    mutate(model = "rf") |> 
    bind_rows(
        svm_res |> 
            unnest(.predictions) |> 
            mutate(model = "svm")
    ) |> 
    group_by(model) |>
    roc_curve(type_of_glass, .pred_building_float, 
              .pred_building_non_float, .pred_vehicle_float,
              .pred_containers, .pred_tableware, .pred_headlamps) |> 
    ggplot(aes(x = 1 - specificity, y = sensitivity, colour = model)) +
    geom_abline(linetype = "dashed") +
    geom_line() +
    facet_wrap(~ .level) +
    coord_equal()

The vehicle_float panel shows how poorly the SVM performs, its curve even dipping below the diagonal. It is clear, therefore, that the random forest performs best. For this model, let us then look at the variable importance, again using the vip package.

glass_wf |> 
    add_model(rf_spec) |> 
    fit(glass) |> 
    vip::vi() |> 
    mutate(Variable = fct_reorder(Variable, Importance)) |>
    ggplot(aes(y = Variable, x = Importance, fill = Variable)) +
    geom_col(show.legend = FALSE)
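
Since we saved the predictions, we could go one step further and inspect a resampled confusion matrix for the random forest, to see which types of glass it tends to confuse; a sketch using tune’s conf_mat_resampled():

conf_mat_resampled(rf_res, tidy = FALSE)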