Classification of Mushroom Type using Naive Bayes, Decision Tree and Random Forest

Theresia Londong

2023-03-21

Photo by Thomas Bormans on Unsplash

Intro

Case Study

Plants plucked straight from their natural habitats have made their way onto our plates, with many chefs and home cooks seeing their local woodlands, fields, and bodies of water through a new, plentiful lens.

Mushroom hunting, mushrooming, mushroom picking, mushroom foraging, and similar terms describe the activity of gathering mushrooms in the wild, typically for culinary use, though medicinal and psychotropic uses are also known. This practice is popular throughout most of Europe, Australia, Japan, Korea, parts of the Middle East, and the Indian subcontinent, as well as the temperate regions of Canada and the United States.

The British Isles are home to a staggering 15,000 species of wild mushrooms, or fungi. These organisms live almost everywhere in the UK but tend to grow most abundantly in woodland and grassland. For those who know little about fungi, identification can be difficult, so special care needs to be taken before picking or consuming any finds.

In this study, we will build a model to classify edible and poisonous mushrooms based on attributes such as cap shape, odor, color, and stalk shape.

Caution: If you are unsure whether a wild mushroom is safe to eat or not, seek advice from an expert. Eating a poisonous mushroom can be fatal – or at least make you feel very unwell, so don’t risk it. There are many foraging courses you can join where you can be guided by an expert.

What Will We Do?

As previously mentioned, we will build a predictive classification model to identify edible and poisonous mushrooms based on 22 attributes (predictors), provided in the UCI mushroom dataset. This data set includes descriptions of hypothetical samples corresponding to 23 species of gilled mushrooms in the Agaricus and Lepiota families (pp. 500-525). Each species is identified as definitely edible, definitely poisonous, or of unknown edibility and not recommended; this latter class was combined with the poisonous one.

We set the ‘class’ column as our target variable, which contains 2 classes: e for edible mushrooms and p for poisonous ones. We will then split our 8,124 observations into an 80:20 proportion of training and test data. On the training data we will build our 3 models, Naive Bayes, Decision Tree, and Random Forest, and then compare their performance using the Confusion Matrix, ROC, and AUC.

Data source: Kaggle

Data Preparation

Load Packages

Before we start, we must first load the required packages.

library(dplyr)        # for data wrangling
library(ggplot2)      # to visualize data
library(gridExtra)    # to display multiple graphs
library(inspectdf)    # for EDA
library(caTools)      # for data splitting
library(caret)        # for modelling

library(e1071)        # naive bayes classification

library(partykit)     # decision tree
library(rpart)

library(randomForest) # random forest

Read Data

Using the raw data from Kaggle, we start our classification by first reading the CSV file.

mushroom <- read.csv('data_input/mushrooms.csv')
Mushroom Dataset (first six rows shown)
class cap.shape cap.surface cap.color bruises odor gill.attachment gill.spacing gill.size gill.color stalk.shape stalk.root stalk.surface.above.ring stalk.surface.below.ring stalk.color.above.ring stalk.color.below.ring veil.type veil.color ring.number ring.type spore.print.color population habitat
p x s n t p f c n k e e s s w w p w o p k s u
e x s y t a f c b k e c s s w w p w o p n n g
e b s w t l f c b n e c s s w w p w o p n n m
p x y w t p f c n n e e s s w w p w o p k s u
e x s g f n f w b k t e s s w w p w o e n a g
e x y y t a f c b n e c s s w w p w o p k n g
Attribute Information
No Attribute Description
1 classes edible=e, poisonous=p
2 cap-shape bell=b,conical=c,convex=x,flat=f, knobbed=k,sunken=s
3 cap-surface fibrous=f,grooves=g,scaly=y,smooth=s
4 cap-color brown=n,buff=b,cinnamon=c,gray=g,green=r,pink=p,purple=u,red=e,white=w,yellow=y
5 bruises bruises=t,no=f
6 odor almond=a,anise=l,creosote=c,fishy=y,foul=f,musty=m,none=n,pungent=p,spicy=s
7 gill-attachment attached=a,descending=d,free=f,notched=n
8 gill-spacing close=c,crowded=w,distant=d
9 gill-size broad=b,narrow=n
10 gill-color black=k,brown=n,buff=b,chocolate=h,gray=g, green=r,orange=o,pink=p,purple=u,red=e,white=w,yellow=y
11 stalk-shape enlarging=e,tapering=t
12 stalk-root bulbous=b,club=c,cup=u,equal=e,rhizomorphs=z,rooted=r,missing=?
13 stalk-surface-above-ring fibrous=f,scaly=y,silky=k,smooth=s
14 stalk-surface-below-ring fibrous=f,scaly=y,silky=k,smooth=s
15 stalk-color-above-ring brown=n,buff=b,cinnamon=c,gray=g,orange=o,pink=p,red=e,white=w,yellow=y
16 stalk-color-below-ring brown=n,buff=b,cinnamon=c,gray=g,orange=o,pink=p,red=e,white=w,yellow=y
17 veil-type partial=p,universal=u
18 veil-color brown=n,orange=o,white=w,yellow=y
19 ring-number none=n,one=o,two=t
20 ring-type cobwebby=c,evanescent=e,flaring=f,large=l,none=n,pendant=p,sheathing=s,zone=z
21 spore-print-color black=k,brown=n,buff=b,chocolate=h,green=r,orange=o,purple=u,white=w,yellow=y
22 population abundant=a,clustered=c,numerous=n,scattered=s,several=v,solitary=y
23 habitat grasses=g,leaves=l,meadows=m,paths=p,urban=u,waste=w,woods=d

Data Wrangling

Using the glimpse() function, we take a look at the initial structure of our dataset.

glimpse(mushroom)
## Rows: 8,124
## Columns: 23
## $ class                    <chr> "p", "e", "e", "p", "e", "e", "e", "e", "p", …
## $ cap.shape                <chr> "x", "x", "b", "x", "x", "x", "b", "b", "x", …
## $ cap.surface              <chr> "s", "s", "s", "y", "s", "y", "s", "y", "y", …
## $ cap.color                <chr> "n", "y", "w", "w", "g", "y", "w", "w", "w", …
## $ bruises                  <chr> "t", "t", "t", "t", "f", "t", "t", "t", "t", …
## $ odor                     <chr> "p", "a", "l", "p", "n", "a", "a", "l", "p", …
## $ gill.attachment          <chr> "f", "f", "f", "f", "f", "f", "f", "f", "f", …
## $ gill.spacing             <chr> "c", "c", "c", "c", "w", "c", "c", "c", "c", …
## $ gill.size                <chr> "n", "b", "b", "n", "b", "b", "b", "b", "n", …
## $ gill.color               <chr> "k", "k", "n", "n", "k", "n", "g", "n", "p", …
## $ stalk.shape              <chr> "e", "e", "e", "e", "t", "e", "e", "e", "e", …
## $ stalk.root               <chr> "e", "c", "c", "e", "e", "c", "c", "c", "e", …
## $ stalk.surface.above.ring <chr> "s", "s", "s", "s", "s", "s", "s", "s", "s", …
## $ stalk.surface.below.ring <chr> "s", "s", "s", "s", "s", "s", "s", "s", "s", …
## $ stalk.color.above.ring   <chr> "w", "w", "w", "w", "w", "w", "w", "w", "w", …
## $ stalk.color.below.ring   <chr> "w", "w", "w", "w", "w", "w", "w", "w", "w", …
## $ veil.type                <chr> "p", "p", "p", "p", "p", "p", "p", "p", "p", …
## $ veil.color               <chr> "w", "w", "w", "w", "w", "w", "w", "w", "w", …
## $ ring.number              <chr> "o", "o", "o", "o", "o", "o", "o", "o", "o", …
## $ ring.type                <chr> "p", "p", "p", "p", "e", "p", "p", "p", "p", …
## $ spore.print.color        <chr> "k", "n", "n", "k", "n", "k", "k", "n", "k", …
## $ population               <chr> "s", "n", "n", "s", "a", "n", "n", "s", "v", …
## $ habitat                  <chr> "u", "g", "m", "u", "g", "g", "m", "m", "g", …

From the result of the glimpse() function above, we decide to transform all character variables to factors and save the result to a mush_02 object.
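
One way to do this conversion with dplyr (a minimal sketch, assuming dplyr 1.0+ for across()):

# convert every character column to a factor and save the result as mush_02
mush_02 <- mushroom %>% 
  mutate(across(where(is.character), as.factor))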

The next step is to check for missing values, if any.

mush_02 %>% is.na() %>% colSums()
##                    class                cap.shape              cap.surface 
##                        0                        0                        0 
##                cap.color                  bruises                     odor 
##                        0                        0                        0 
##          gill.attachment             gill.spacing                gill.size 
##                        0                        0                        0 
##               gill.color              stalk.shape               stalk.root 
##                        0                        0                        0 
## stalk.surface.above.ring stalk.surface.below.ring   stalk.color.above.ring 
##                        0                        0                        0 
##   stalk.color.below.ring                veil.type               veil.color 
##                        0                        0                        0 
##              ring.number                ring.type        spore.print.color 
##                        0                        0                        0 
##               population                  habitat 
##                        0                        0

EDA

As an initial step of EDA, we can start by inspecting the summary() of our mush_02 data frame.

summary(mush_02)
##  class    cap.shape cap.surface   cap.color    bruises       odor     
##  e:4208   b: 452    f:2320      n      :2284   f:4748   n      :3528  
##  p:3916   c:   4    g:   4      g      :1840   t:3376   f      :2160  
##           f:3152    s:2556      e      :1500            s      : 576  
##           k: 828    y:3244      y      :1072            y      : 576  
##           s:  32                w      :1040            a      : 400  
##           x:3656                b      : 168            l      : 400  
##                                 (Other): 220            (Other): 484  
##  gill.attachment gill.spacing gill.size   gill.color   stalk.shape stalk.root
##  a: 210          c:6812       b:5612    b      :1728   e:3516      ?:2480    
##  f:7914          w:1312       n:2512    p      :1492   t:4608      b:3776    
##                                         w      :1202               c: 556    
##                                         n      :1048               e:1120    
##                                         g      : 752               r: 192    
##                                         h      : 732                         
##                                         (Other):1170                         
##  stalk.surface.above.ring stalk.surface.below.ring stalk.color.above.ring
##  f: 552                   f: 600                   w      :4464          
##  k:2372                   k:2304                   p      :1872          
##  s:5176                   s:4936                   g      : 576          
##  y:  24                   y: 284                   n      : 448          
##                                                    b      : 432          
##                                                    o      : 192          
##                                                    (Other): 140          
##  stalk.color.below.ring veil.type veil.color ring.number ring.type
##  w      :4384           p:8124    n:  96     n:  36      e:2776   
##  p      :1872                     o:  96     o:7488      f:  48   
##  g      : 576                     w:7924     t: 600      l:1296   
##  n      : 512                     y:   8                 n:  36   
##  b      : 432                                            p:3968   
##  o      : 192                                                     
##  (Other): 156                                                     
##  spore.print.color population habitat 
##  w      :2388      a: 384     d:3148  
##  n      :1968      c: 340     g:2148  
##  k      :1872      n: 400     l: 832  
##  h      :1632      s:1248     m: 292  
##  r      :  72      v:4040     p:1144  
##  b      :  48      y:1712     u: 368  
##  (Other): 144                 w: 192

Based on the result above, we perform one final manipulation: removing the factor column with only one level, veil.type, and saving the final data frame to a mush_cln object.

mush_cln <- mush_02 %>% select(-c(veil.type))

To visualize the proportion of each class for the factor variables, we can use the inspect_cat() function as follows:

mush_cln %>% 
  inspect_cat() %>% 
  show_plot()

Next, we check our target variable, class, which has 2 classes: e for edible and p for poisonous.

prop.table(table(mush_cln$class))
## 
##         e         p 
## 0.5179714 0.4820286

Insight: The target variable class has a balanced proportion between classes.

Modelling

Cross Validation

Initially, we must define the positive and negative class of our target variable:

  • Positive Class: p (poisonous)
  • Negative Class: e (edible)

In order to build our model, we must first split our data into training and validation/test sets. The training set is used to train the model, and the validation/test set is used to validate it on data it has never seen before. The classic approach is a simple 80%-20% split of the main data. To avoid bias toward a certain class, we also have to make sure that our training set has a balanced proportion between the classes of the target variable.

We start by splitting our mush_cln data into training and test data, with the classic 80% training proportion.

RNGkind(sample.kind = "Rounding") 
set.seed(222)

# train-test splitting
index <- sample(x = nrow(mush_cln), size = nrow(mush_cln) * 0.8) # random row indices used to subset the training data

mush_train <- mush_cln[index, ] 
mush_test <- mush_cln[-index, ]

Then we check that the proportions of the target variable remain balanced:

# Target Var. proportion on data train
prop.table(table(mush_train$class))
## 
##       e       p 
## 0.52008 0.47992
# Target Var. proportion on data test
prop.table(table(mush_test$class))
## 
##         e         p 
## 0.5095385 0.4904615

This results in balanced classes in both the training and test data.

Naive Bayes

Naive Bayes performs classification based on the conditional probability between the predictors and the target variable (Bayes' theorem).

  • Assumptions of Naive Bayes:
    • Independence between predictors (they are unrelated to each other)
    • Each predictor has the same weight
    • Dependency between the predictors and the target variable
  • Pros:
    • Computation time is relatively faster than other classification models, because it only computes proportions from frequency tables.
    • It is therefore often used as a baseline model or benchmark: a simple reference model that we will later compare with more complex models.

There are 2 approaches to creating a model using the naiveBayes() function:

  1. Using the naiveBayes(formula, data) arguments

    • formula: a formula y ~ x, where y is the target variable and x the predictor variables
    • data: the data containing the target and predictor variables
  2. Using the naiveBayes(x, y) arguments

    • x: the predictor variables of the data used
    • y: the target variable of the data used

In certain cases, data scarcity can occur: a condition where a predictor value is not present at all in one of the classes. This is the second characteristic of Naive Bayes: skewness due to scarcity. When a predictor has a frequency of 0 for one of the classes, the model automatically assigns a probability of 0 to that condition, regardless of the values of the other predictors.

We want to make sure none of the counts are zero, while keeping the proportions close to the original. The solution is Laplace smoothing: adding a certain number (best practice is usually 1) to the frequency of each predictor value, so that no predictor ends up with a probability of 0.
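
As a quick illustration of the effect on a single predictor (a minimal sketch; the full model below applies laplace = 1 to all predictors):

# conditional probability tables for 'odor', without vs. with Laplace smoothing
nb_raw <- naiveBayes(class ~ odor, data = mush_train)
nb_lap <- naiveBayes(class ~ odor, data = mush_train, laplace = 1)
nb_raw$tables$odor # some cells are exactly 0, e.g. almond odor never occurs in class p
nb_lap$tables$odor # after smoothing, every cell is greater than 0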

# Modeling
model_naive <- naiveBayes(formula = class ~ ., 
                             data = mush_train, 
                             laplace = 1)
# Predict
mush_predClass <- predict(model_naive, newdata = mush_test, type = "class")

head(mush_predClass)
## [1] e e e e e e
## Levels: e p

Model Interpretation

To interpret our model, we can read the conditional probability tables for each predictor by printing the model object:
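
model_naive # print the a-priori and conditional probability tables shown below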

## 
## Naive Bayes Classifier for Discrete Predictors
## 
## Call:
## naiveBayes.default(x = X, y = Y, laplace = laplace)
## 
## A-priori probabilities:
## Y
##       e       p 
## 0.52008 0.47992 
## 
## Conditional probabilities:
##    cap.shape
## Y              b            c            f            k            s
##   e 0.0989367986 0.0002953337 0.3797991731 0.0549320732 0.0076786769
##   p 0.0124800000 0.0012800000 0.3948800000 0.1539200000 0.0003200000
##    cap.shape
## Y              x
##   e 0.4583579445
##   p 0.4371200000
## 
##    cap.surface
## Y              f            g            s            y
##   e 0.3714539007 0.0002955083 0.2751182033 0.3531323877
##   p 0.1927633686 0.0016010247 0.3608709574 0.4447646494
## 
##    cap.color
## Y              b            c            e            g            n
##   e 0.0120943953 0.0082595870 0.1421828909 0.2469026549 0.3008849558
##   p 0.0294023650 0.0028763183 0.2224352828 0.2112496005 0.2556727389
##    cap.color
## Y              p            r            u            w            y
##   e 0.0135693215 0.0047197640 0.0041297935 0.1725663717 0.0946902655
##   p 0.0230105465 0.0003195909 0.0003195909 0.0843720038 0.1703419623
## 
##    bruises
## Y           f         t
##   e 0.3483146 0.6516854
##   p 0.8401153 0.1598847
## 
##    odor
## Y              a            c            f            l            m
##   e 0.0994393626 0.0002950723 0.0002950723 0.0929477722 0.0002950723
##   p 0.0003196931 0.0501918159 0.5527493606 0.0003196931 0.0089514066
##    odor
## Y              n            p            s            y
##   e 0.8058424314 0.0002950723 0.0002950723 0.0002950723
##   p 0.0319693095 0.0642583120 0.1445012788 0.1467391304
## 
##    gill.attachment
## Y             a           f
##   e 0.044648137 0.955351863
##   p 0.004165332 0.995834668
## 
##    gill.spacing
## Y            c          w
##   e 0.71052632 0.28947368
##   p 0.97180391 0.02819609
## 
##    gill.size
## Y            b          n
##   e 0.92962744 0.07037256
##   p 0.43543736 0.56456264
## 
##    gill.color
## Y              b            e            g            h            k
##   e 0.0002948113 0.0224056604 0.0589622642 0.0492334906 0.0828419811
##   p 0.4362823379 0.0003193868 0.1267965506 0.1347812201 0.0146917918
##    gill.color
## Y              n            o            p            r            u
##   e 0.2187500000 0.0159198113 0.2016509434 0.0002948113 0.1070165094
##   p 0.0303417439 0.0003193868 0.1670392846 0.0067071223 0.0118173108
##    gill.color
## Y              w            y
##   e 0.2290683962 0.0135613208
##   p 0.0654742894 0.0054295752
## 
##    stalk.shape
## Y           e         t
##   e 0.3855707 0.6144293
##   p 0.4873438 0.5126562
## 
##    stalk.root
## Y              ?            b            c            e            r
##   e 0.1719350074 0.4522895126 0.1223042836 0.2070901034 0.0463810931
##   p 0.4462227913 0.4779129321 0.0112035851 0.0643405890 0.0003201024
## 
##    stalk.surface.above.ring
## Y             f           k           s           y
##   e 0.096926714 0.036643026 0.862588652 0.003841608
##   p 0.038104387 0.571245597 0.388088377 0.002561639
## 
##    stalk.surface.below.ring
## Y            f          k          s          y
##   e 0.10786052 0.03546099 0.80673759 0.04994090
##   p 0.03650336 0.55427474 0.38904899 0.02017291
## 
##    stalk.color.above.ring
## Y              b            c            e            g            n
##   e 0.0002950723 0.0002950723 0.0244910003 0.1395691945 0.0038359398
##   p 0.1102941176 0.0089514066 0.0003196931 0.0003196931 0.1096547315
##    stalk.color.above.ring
## Y              o            p            w            y
##   e 0.0445559162 0.1318973148 0.6547654175 0.0002950723
##   p 0.0003196931 0.3308823529 0.4367007673 0.0025575448
## 
##    stalk.color.below.ring
## Y              b            c            e            g            n
##   e 0.0002950723 0.0002950723 0.0230156388 0.1348480378 0.0150486869
##   p 0.1131713555 0.0089514066 0.0003196931 0.0003196931 0.1160485934
##    stalk.color.below.ring
## Y              o            p            w            y
##   e 0.0445559162 0.1348480378 0.6467984656 0.0002950723
##   p 0.0003196931 0.3228900256 0.4306265985 0.0073529412
## 
##    veil.color
## Y              n            o            w            y
##   e 0.0212765957 0.0236406619 0.9547872340 0.0002955083
##   p 0.0003202049 0.0003202049 0.9967979507 0.0025616394
## 
##    ring.number
## Y              n            o            t
##   e 0.0002955956 0.8728938812 0.1268105232
##   p 0.0089686099 0.9718129404 0.0192184497
## 
##    ring.type
## Y              e            f            l            n            p
##   e 0.2404726736 0.0118168390 0.0002954210 0.0002954210 0.7471196455
##   p 0.4484635083 0.0003201024 0.3325864277 0.0089628681 0.2096670935
## 
##    spore.print.color
## Y              b            h            k            n            o
##   e 0.0112127471 0.0118028917 0.3886102095 0.4151667158 0.0106226025
##   p 0.0003196931 0.4069693095 0.0578644501 0.0565856777 0.0003196931
##    spore.print.color
## Y              r            u            w            y
##   e 0.0002950723 0.0120979640 0.1377987607 0.0123930363
##   p 0.0191815857 0.0003196931 0.4581202046 0.0003196931
## 
##    population
## Y            a          c          n          s          v          y
##   e 0.09273479 0.06704076 0.09805080 0.20850561 0.27643237 0.25723568
##   p 0.00032000 0.01280000 0.00032000 0.09280000 0.72512000 0.16864000
## 
##    habitat
## Y              d            g            l            m            p
##   e 0.4434602893 0.3365810452 0.0558015943 0.0617065249 0.0330676115
##   p 0.3234165067 0.1909788868 0.1465131158 0.0099168266 0.2597568778
##    habitat
## Y              u            w
##   e 0.0236197225 0.0457632123
##   p 0.0690978887 0.0003198976

For example:

  • The a-priori probability is the probability when we have no additional information: the probability that a mushroom is edible (class = e) is 0.52, while the probability that it is poisonous (class = p) is 0.48.
  • The conditional probability that an edible mushroom (class = e) has woods as its habitat (habitat = d) is 0.44.
  • The conditional probability that a poisonous mushroom (class = p) has an orange veil color (veil.color = o) is 0.00032.
  • etc.
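
To see how these tables are combined into a final prediction for a single observation, we can ask for the raw posterior probabilities (a small illustrative sketch; predict() performs this combination internally):

# posterior probabilities P(class = e) and P(class = p) for the first test observation
predict(model_naive, newdata = mush_test[1, ], type = "raw")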

Model Evaluation

The most basic evaluation is by using the Confusion Matrix.

confusionMatrix(data = mush_predClass,       # predicted classes
                reference = mush_test$class, # actual classes
                positive = "p")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   e   p
##          e 822  77
##          p   6 720
##                                           
##                Accuracy : 0.9489          
##                  95% CI : (0.9371, 0.9591)
##     No Information Rate : 0.5095          
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.8976          
##                                           
##  Mcnemar's Test P-Value : 1.548e-14       
##                                           
##             Sensitivity : 0.9034          
##             Specificity : 0.9928          
##          Pos Pred Value : 0.9917          
##          Neg Pred Value : 0.9143          
##              Prevalence : 0.4905          
##          Detection Rate : 0.4431          
##    Detection Prevalence : 0.4468          
##       Balanced Accuracy : 0.9481          
##                                           
##        'Positive' Class : p               
## 

Since our positive class is p and negative class is e, the more costly error is when an actually poisonous mushroom is predicted as edible, a False Negative case. Hence our metric of concern is Recall/Sensitivity. Based on the confusion matrix above, model_naive achieves a Sensitivity of 90.3%.
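
As a quick sanity check, the Sensitivity can be recomputed by hand from the confusion matrix above:

# Sensitivity = TP / (TP + FN): 720 poisonous mushrooms predicted as poisonous,
# 77 poisonous mushrooms predicted as edible
720 / (720 + 77) # ~0.9034, matching the reported Sensitivity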

Decision Tree

Decision Tree is a fairly simple tree-based model with robust, powerful predictive performance. It produces a visualization in the form of a decision tree, which can be interpreted easily.

Additional characteristics of the Decision Tree:

  • Predictors are allowed to be mutually dependent, so multicollinearity can be handled.
  • It can handle outliers in numeric predictor values.

Note: the Decision Tree is not limited to classification cases; it can also be used for regression. In this study our focus is on classification, but the idea is the same.

Model Fitting

Using the ctree() function from the partykit package, we build our decision tree model and then plot it to visualize the model easily.

#building model

dtree_model <- ctree(formula = class ~.,
                     data = mush_train)


plot(dtree_model, type = "simple")

# predict
pred_mush_test <- predict(object = dtree_model , 
                          newdata = mush_test , 
                          type = "response")

Model evaluation

Performance evaluation using the Confusion Matrix:

# Confusion Matrix: data test
confusionMatrix(data = pred_mush_test , 
                reference = mush_test$class, 
                positive = "p")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   e   p
##          e 828   3
##          p   0 794
##                                           
##                Accuracy : 0.9982          
##                  95% CI : (0.9946, 0.9996)
##     No Information Rate : 0.5095          
##     P-Value [Acc > NIR] : <2e-16          
##                                           
##                   Kappa : 0.9963          
##                                           
##  Mcnemar's Test P-Value : 0.2482          
##                                           
##             Sensitivity : 0.9962          
##             Specificity : 1.0000          
##          Pos Pred Value : 1.0000          
##          Neg Pred Value : 0.9964          
##              Prevalence : 0.4905          
##          Detection Rate : 0.4886          
##    Detection Prevalence : 0.4886          
##       Balanced Accuracy : 0.9981          
##                                           
##        'Positive' Class : p               
## 

For the decision tree model, apart from looking at the confusion matrix on the test data only, it is a good idea to also compare it with the model's performance on the training data (mush_train) to assess how well the model fits:

# Prediction in Data Train
pred_mush_train <- predict(object = dtree_model, 
                           newdata = mush_train , 
                           type = "response")
# Confusion Matrix: data train
confusionMatrix(data = pred_mush_train, # prediction
                reference = mush_train$class, 
                positive = "p")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction    e    p
##          e 3380    5
##          p    0 3114
##                                           
##                Accuracy : 0.9992          
##                  95% CI : (0.9982, 0.9998)
##     No Information Rate : 0.5201          
##     P-Value [Acc > NIR] : < 2e-16         
##                                           
##                   Kappa : 0.9985          
##                                           
##  Mcnemar's Test P-Value : 0.07364         
##                                           
##             Sensitivity : 0.9984          
##             Specificity : 1.0000          
##          Pos Pred Value : 1.0000          
##          Neg Pred Value : 0.9985          
##              Prevalence : 0.4799          
##          Detection Rate : 0.4792          
##    Detection Prevalence : 0.4792          
##       Balanced Accuracy : 0.9992          
##                                           
##        'Positive' Class : p               
## 

Result:

Sensitivity of 99.6% on the test data

Sensitivity of 99.8% on the training data

The dtree_model is in a Just Right condition, which means the model performs well on the training data and its performance decreases only slightly (a difference of about 0.2%) on the test data.

Random Forest

Random Forest is a type of Ensemble Method consisting of many Decision Trees. Each Decision Tree has its own characteristics and is not related to the others. Random Forest makes use of the Bagging (Bootstrap and Aggregation) concept; a minimal code sketch follows the list below. Here is the process:

  1. PROCESS 1 = Bootstrap sampling: generate data sets by random sampling (with replacement) from the entire data, allowing duplicate rows.
  2. PROCESS 2 = Build 1 decision tree for each bootstrap sample. The mtry parameter is used to randomly select the number of candidate predictors at each split (automatic feature selection).
  3. PROCESS 3 = Make predictions on new observations with each Decision Tree.
  4. PROCESS 4 = Aggregation: combine all the trees' predictions into a single prediction.
    • Classification case: majority voting
    • Regression case: average of the target values
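
As a minimal sketch of these four steps using the randomForest package directly (the ntree and mtry values here are illustrative; the actual model below is tuned with caret):

# fit 500 bootstrapped trees (PROCESS 1-2), sampling 4 candidate predictors at each split
rf_sketch <- randomForest(class ~ ., data = mush_train, ntree = 500, mtry = 4)
rf_sketch # printing it shows the aggregated, majority-vote OOB error and confusion matrix (PROCESS 3-4)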

Benefits of Random Forests:

  • Suppresses the bias and variance of the Decision Tree, resulting in better predictive performance.
  • Automatic feature selection: Predictors are randomly selected in the making of the Decision Tree.
  • There is an out-of-bag error as a substitute for model evaluation.

K-fold Cross Validation

Usually we do cross-validation by splitting the data into training and test sets. With K-Fold Cross Validation, we instead divide the data into k equal parts, and each part is used as the test data in turn.

Using the mush_train data we created, we will build a random forest model with k-fold cross-validation (k = 5), repeating the k-fold procedure 3 times:

set.seed(217)
control <- trainControl(method = "repeatedcv", number = 5, repeats = 3)

Model Fitting

One of the weaknesses of Random Forest is that model building takes a long time. A good practice after training is therefore to save the model as an RDS file with the saveRDS() function, so that it can be reused immediately without retraining from scratch.

# build random forest model (note: caret's argument is trControl, not trainControl;
# the original run passed trainControl, so caret fell back to its default bootstrap
# resampling, as the summary below shows)
# mush_rf_model <- train(form = class ~ ., data = mush_train, method = "rf",
#                        trControl = control)

# save the model as an RDS file
# saveRDS(mush_rf_model, "mush_rf_model.RDS")

# read the saved model
mush_forest <- readRDS("mush_rf_model.RDS")

Model Evaluation

Based on the summary of our model below, we choose the model with mtry = 48, which has the highest accuracy, 99.98%, when evaluated on the bootstrap resamples (which can be considered the training data for the decision trees inside the random forest).
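
mush_forest # print the caret training summary shown below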

## Random Forest 
## 
## 6499 samples
##   21 predictor
##    2 classes: 'e', 'p' 
## 
## No pre-processing
## Resampling: Bootstrapped (25 reps) 
## Summary of sample sizes: 6499, 6499, 6499, 6499, 6499, 6499, ... 
## Resampling results across tuning parameters:
## 
##   mtry  Accuracy   Kappa    
##    2    0.9621823  0.9240760
##   48    0.9998000  0.9995995
##   95    0.9997833  0.9995659
## 
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 48.

To check the final model, we access the $finalModel element of our mush_forest model.

library(randomForest)
# check final model

mush_forest$finalModel
## 
## Call:
##  randomForest(x = x, y = y, mtry = param$mtry, trainControl = ..1) 
##                Type of random forest: classification
##                      Number of trees: 500
## No. of variables tried at each split: 48
## 
##         OOB estimate of  error rate: 0.02%
## Confusion matrix:
##      e    p  class.error
## e 3380    0 0.0000000000
## p    1 3118 0.0003206156

In the mush_forest model, the Out-of-Bag (OOB) error estimate is 0.02%. In other words, the accuracy of the model on its held-out (out-of-bag) data is 100% - 0.02% = 99.98%.
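
The same error rate can be recovered from the OOB confusion matrix above:

# 1 misclassified observation out of the 6,499 training rows
1 / (3380 + 0 + 1 + 3118) # ~0.00015, i.e. the ~0.02% OOB error reported above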

Performance with ROC & AUC

Accuracy alone does not fully show how well the model separates the two classes. To complement it, ROC and AUC are available as evaluation tools besides the Confusion Matrix. In this case we will compute the ROC and AUC for dtree_model, one of our best-performing models.

Receiver-Operating Curve (ROC)

ROC is a curve that describes the relationship between the True Positive Rate (Sensitivity or Recall) and the False Positive Rate (1 - Specificity) at each threshold. A good model should ideally have a high True Positive Rate and a low False Positive Rate. Note: Specificity is the True Negative Rate.

  • Step 1: Predict the Probability

First, prepare the prediction results in the form of probabilities for mush_test and save them to an object named mush_pred_prob:

mush_pred_prob <- predict(object = dtree_model, # model
                         newdata = mush_test,   # test data
                         type = "prob")         # return class probabilities
head(mush_pred_prob)
##    e p
## 1  0 1
## 4  0 1
## 8  1 0
## 13 1 0
## 15 1 0
## 18 0 1

Set up the ROC data then save it to an object named data_mush_roc:

data_mush_roc <- data.frame(pred_prob_mush = mush_pred_prob[ , "p"], # positive class probability
                            actual = mush_test$class)                # actual labels

head(data_mush_roc)
##    pred_prob_mush actual
## 1               1      p
## 4               1      p
## 8               0      e
## 13              0      e
## 15              0      e
## 18              1      p
  • Step 2. Create a Prediction Object with prediction()

The prediction() function transforms the probability results into a prediction object from which the ROC curve and AUC can later be computed. This function takes 2 parameters:

  • predictions = the predicted probabilities for the positive class
  • labels = the actual class labels (the positive class is coded as 1 and the negative class as 0)

Save the resulting prediction() object with the name mush_roc; the ROC curve will be drawn from it in the next step:

library(ROCR)

mush_roc <- prediction(predictions = data_mush_roc$pred_prob_mush,
                      labels = data_mush_roc$actual)
  • Step 3. Create ROC Curve

After successfully creating the prediction object, we display it as a curve using the following syntax:

plot(performance(prediction.obj = ..., measure = "tpr", x.measure = "fpr"))

plot(performance(prediction.obj = mush_roc, 
                 measure = "tpr", # true positive
                 x.measure = "fpr")) # false positive

As we can see from the ROC curve above, it forms a near-perfect 90° angle at the top left of the plot.

Area Under Curve (AUC)

AUC is the area under the ROC curve. The closer it is to 1, the better the model is at separating the positive and negative classes. To get the AUC value, use measure = "auc" in the performance() function and then extract y.values.

  • Step 4. AUC Calculation

In addition to the ROC curve visualization, we can check the AUC value using the performance() function. We will use the following 2 parameters:

  • prediction.obj = the object that stores the output of the prediction() function
  • measure = the measurement to compute; we fill it with "auc"

mush_value <- performance(prediction.obj = mush_roc, 
                         measure = "auc")
mush_value@y.values
## [[1]]
## [1] 0.9999477

And our AUC value is almost 1.

This means that our dtree_model is good at distinguishing positive and negative classes, as indicated by the shape of the ROC curve and an AUC value close to 1.

Conclusion

From our models above we conclude that:

  • Our metric of interest is minimizing the number of False Negative cases, where an actually poisonous mushroom is predicted as edible, which means maximizing Sensitivity/Recall.
  • Our best-performing model is the Random Forest model, mush_forest, with an out-of-bag error of just 0.02% (99.98% accuracy) and a Sensitivity of 99.97% on the out-of-bag data.
  • It is followed by dtree_model, built with the decision tree algorithm, with a Sensitivity of 99.6% on the test data and 99.8% on the training data, earning it Just Right status.
  • Last is the baseline model model_naive, using the Naive Bayes approach, with a Sensitivity of 90.3%.
  • To further assess how well our model differentiates the positive and negative classes, we inspected the ROC plot and AUC value for one of our best models, dtree_model. The result is a near-perfect ROC curve and an AUC of 0.9999477, meaning the model is good at distinguishing positive and negative classes.

References

  1. https://archive.ics.uci.edu/ml/datasets/mushroom

  2. https://www.countryfile.com/wildlife/how-to-identify/guide-to-british-fungi-where-to-find-and-how-to-identify/

  3. Mushroom records drawn from The Audubon Society Field Guide to North American Mushrooms (1981). G. H. Lincoff (Pres.), New York: Alfred A. Knopf