1 Objective

The objective of this article is to explore machine learning algorithms for classifying diamonds into price buckets based on their physical characteristics.

2 Algorithm development & testing

2.1 Initial setup

2.1.1 Load libraries

Load the libraries for data cleansing, tidying, and transformation, together with the plotting library.

# Load required libraries
library(dplyr); library(tidyr); library(ggplot2)

Specialised libraries for machine learning and for plotting decision trees.

# Load required libraries
library(caret);  library(randomForest)
library(rattle); library(rpart.plot)
# Set the seed so that the results are reproducible
set.seed(1000)
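
If any of these packages are not yet installed, a one-time setup along these lines should work:

# Install the required packages from CRAN (run once)
install.packages(c("dplyr", "tidyr", "ggplot2", "caret",
                   "randomForest", "rattle", "rpart.plot"))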

2.1.2 Cutting the diamond price and carats into ordinal factors

# Load the diamonds dataset
data(diamonds)
# Diamond price: cut into buckets of width 4000
# fprice = factored price
diamonds$fprice <- as.numeric(cut(diamonds$price, 
                                  seq(from = 0, to = 50000, by = 4000)))
# Convert factored price into ordinal factor
diamonds$fprice <- ordered(diamonds$fprice)

# Diamond carat: cut into buckets of width 0.1
# fcarat = factored carat
diamonds$fcarat <- as.numeric(cut(diamonds$carat, 
                                  seq(from = 0, to = 6, by = 0.1)))
# Convert factored carat into ordinal factor
diamonds$fcarat <- ordered(diamonds$fcarat)
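
As a quick sanity check on the bucketing (a sketch, assuming the chunk above has been run), the bucket counts and the intervals behind the integer codes can be inspected:

# Count diamonds per price bucket; the top buckets are sparse
table(diamonds$fprice)
# The interval boundaries that the integer codes stand for
levels(cut(diamonds$price, seq(from = 0, to = 50000, by = 4000)))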

2.1.3 Structure of the diamonds dataset

str(diamonds)
## 'data.frame':    53940 obs. of  12 variables:
##  $ carat  : num  0.23 0.21 0.23 0.29 0.31 0.24 0.24 0.26 0.22 0.23 ...
##  $ cut    : Ord.factor w/ 5 levels "Fair"<"Good"<..: 5 4 2 4 2 3 3 3 1 3 ...
##  $ color  : Ord.factor w/ 7 levels "D"<"E"<"F"<"G"<..: 2 2 2 6 7 7 6 5 2 5 ...
##  $ clarity: Ord.factor w/ 8 levels "I1"<"SI2"<"SI1"<..: 2 3 5 4 2 6 7 3 4 5 ...
##  $ depth  : num  61.5 59.8 56.9 62.4 63.3 62.8 62.3 61.9 65.1 59.4 ...
##  $ table  : num  55 61 65 58 58 57 57 55 61 61 ...
##  $ price  : int  326 326 327 334 335 336 336 337 337 338 ...
##  $ x      : num  3.95 3.89 4.05 4.2 4.34 3.94 3.95 4.07 3.87 4 ...
##  $ y      : num  3.98 3.84 4.07 4.23 4.35 3.96 3.98 4.11 3.78 4.05 ...
##  $ z      : num  2.43 2.31 2.31 2.63 2.75 2.48 2.47 2.53 2.49 2.39 ...
##  $ fprice : Ord.factor w/ 5 levels "1"<"2"<"3"<"4"<..: 1 1 1 1 1 1 1 1 1 1 ...
##  $ fcarat : Ord.factor w/ 40 levels "2"<"3"<"4"<"5"<..: 2 2 2 2 3 2 2 2 2 2 ...
summary(diamonds)
##      carat               cut        color        clarity     
##  Min.   :0.2000   Fair     : 1610   D: 6775   SI1    :13065  
##  1st Qu.:0.4000   Good     : 4906   E: 9797   VS2    :12258  
##  Median :0.7000   Very Good:12082   F: 9542   SI2    : 9194  
##  Mean   :0.7979   Premium  :13791   G:11292   VS1    : 8171  
##  3rd Qu.:1.0400   Ideal    :21551   H: 8304   VVS2   : 5066  
##  Max.   :5.0100                     I: 5422   VVS1   : 3655  
##                                     J: 2808   (Other): 2531  
##      depth           table           price             x         
##  Min.   :43.00   Min.   :43.00   Min.   :  326   Min.   : 0.000  
##  1st Qu.:61.00   1st Qu.:56.00   1st Qu.:  950   1st Qu.: 4.710  
##  Median :61.80   Median :57.00   Median : 2401   Median : 5.700  
##  Mean   :61.75   Mean   :57.46   Mean   : 3933   Mean   : 5.731  
##  3rd Qu.:62.50   3rd Qu.:59.00   3rd Qu.: 5324   3rd Qu.: 6.540  
##  Max.   :79.00   Max.   :95.00   Max.   :18823   Max.   :10.740  
##                                                                  
##        y                z          fprice        fcarat     
##  Min.   : 0.000   Min.   : 0.000   1:34561   4      :10188  
##  1st Qu.: 4.720   1st Qu.: 2.910   2:11774   11     : 6010  
##  Median : 5.710   Median : 3.530   3: 4142   6      : 5516  
##  Mean   : 5.735   Mean   : 3.539   4: 2321   5      : 4541  
##  3rd Qu.: 6.540   3rd Qu.: 4.040   5: 1142   8      : 4249  
##  Max.   :58.900   Max.   :31.800             3      : 4191  
##                                              (Other):19245
head(diamonds, 5)
##   carat     cut color clarity depth table price    x    y    z fprice
## 1  0.23   Ideal     E     SI2  61.5    55   326 3.95 3.98 2.43      1
## 2  0.21 Premium     E     SI1  59.8    61   326 3.89 3.84 2.31      1
## 3  0.23    Good     E     VS1  56.9    65   327 4.05 4.07 2.31      1
## 4  0.29 Premium     I     VS2  62.4    58   334 4.20 4.23 2.63      1
## 5  0.31    Good     J     SI2  63.3    58   335 4.34 4.35 2.75      1
##   fcarat
## 1      3
## 2      3
## 3      3
## 4      3
## 5      4

2.1.4 Subsetting the dataset

The full dataset is large, so it is subsampled to a smaller working set of 10,000 observations.

# Number of observations
input.obs <- 10000
# Draw a random sample of input.obs rows without replacement
data.sample <- diamonds[sample(1:nrow(diamonds), input.obs,
                                  replace=FALSE),]
# Assigning the value of dataset variable
dataset <- data.sample
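
Since the high price buckets are rare, it is worth confirming that the random sample keeps roughly the same bucket proportions as the full data (a minimal check):

# Compare bucket proportions in the sample against the full dataset
round(prop.table(table(dataset$fprice)), 3)
round(prop.table(table(diamonds$fprice)), 3)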

2.2 Exploring the data

2.2.1 Price and carat distributions in the dataset

# Bar chart of the price-bucket distribution
qplot(fprice, data = dataset, geom = "bar")

# Bar chart of the carat-bucket distribution
qplot(fcarat, data = dataset, geom = "bar")

# Association of price bucket with carat bucket and clarity
g <- ggplot(data.sample, aes(y = fprice, x = fcarat))
g <- g + geom_point(aes(color = clarity), position = "jitter")
# group = 1 lets the smoother run across the ordinal factor levels
g <- g + geom_smooth(aes(group = 1), method = "loess", col = "blue", lwd = 1)
g <- g + theme(legend.position = "bottom")
g
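
Because both axes above are ordinal factors, a boxplot of the raw price within each carat bucket can be an easier read (a sketch using the same sample):

# Raw price distribution within each carat bucket
ggplot(data.sample, aes(x = fcarat, y = price)) +
    geom_boxplot() +
    labs(x = "carat bucket", y = "price (USD)")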

2.3 Model development

2.3.1 Splitting the data into training and testing sets

# Tidy up the dataset used for development of the model
dataset.pr <- select(dataset, fcarat, cut:table, x:z, fprice, price)
dataset    <- select(dataset, fcarat, cut:table, x:z, fprice)

# Split the data into training and testing datasets
# 70% in the training dataset and 30% in testing dataset
inTrain  <- createDataPartition(y=dataset$fprice, p=0.7, list=FALSE)
training <- dataset[inTrain,]
testing  <- dataset[-inTrain,]
dim(training); dim(testing)
## [1] 7003   10
## [1] 2997   10
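
createDataPartition samples within each level of fprice, so the bucket proportions should be nearly identical in the two sets; a quick check:

# Bucket proportions should match across training and testing sets
round(rbind(train = prop.table(table(training$fprice)),
            test  = prop.table(table(testing$fprice))), 3)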

2.4 Developing Predictive Models (Decision Tree)

2.4.1 Model definition

modFit <- train(fprice ~., method = "rpart", data = training)
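
By default caret tunes the rpart complexity parameter with bootstrap resampling; an explicit cross-validated fit would look roughly like this (modFit.cv and the fold count are illustrative, not part of the model above):

# 5-fold cross-validation over a wider grid of complexity
# parameters (illustrative settings)
ctrl      <- trainControl(method = "cv", number = 5)
modFit.cv <- train(fprice ~ ., method = "rpart", data = training,
                   trControl = ctrl, tuneLength = 10)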

2.4.2 Plotting the classification tree in the fancy style

fancyRpartPlot(modFit$finalModel)

2.4.3 Model validation

2.4.3.1 Training set accuracy (In-Sample)

pred.train <- predict(modFit, training)
print(confusionMatrix(pred.train, training$fprice))
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction    1    2    3    4    5
##          1 4212   66    0    1    0
##          2  288 1376  230   31    6
##          3    1  122  291  238  141
##          4    0    0    0    0    0
##          5    0    0    0    0    0
## 
## Overall Statistics
##                                          
##                Accuracy : 0.8395         
##                  95% CI : (0.8307, 0.848)
##     No Information Rate : 0.6427         
##     P-Value [Acc > NIR] : < 2.2e-16      
##                                          
##                   Kappa : 0.7013         
##  Mcnemar's Test P-Value : NA             
## 
## Statistics by Class:
## 
##                      Class: 1 Class: 2 Class: 3 Class: 4 Class: 5
## Sensitivity            0.9358   0.8798  0.55854  0.00000  0.00000
## Specificity            0.9732   0.8980  0.92255  1.00000  1.00000
## Pos Pred Value         0.9843   0.7126  0.36696      NaN      NaN
## Neg Pred Value         0.8939   0.9629  0.96296  0.96145  0.97901
## Prevalence             0.6427   0.2233  0.07440  0.03855  0.02099
## Detection Rate         0.6015   0.1965  0.04155  0.00000  0.00000
## Detection Prevalence   0.6110   0.2757  0.11324  0.00000  0.00000
## Balanced Accuracy      0.9545   0.8889  0.74055  0.50000  0.50000

2.4.3.2 Validation set accuracy (Out-of-Sample)

pred.test <- predict(modFit, testing)
print(confusionMatrix(pred.test, testing$fprice))
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction    1    2    3    4    5
##          1 1788   29    0    0    0
##          2  139  599  103    7    3
##          3    1   42  119  108   59
##          4    0    0    0    0    0
##          5    0    0    0    0    0
## 
## Overall Statistics
##                                           
##                Accuracy : 0.8362          
##                  95% CI : (0.8224, 0.8493)
##     No Information Rate : 0.6433          
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.6957          
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: 1 Class: 2 Class: 3 Class: 4 Class: 5
## Sensitivity            0.9274   0.8940  0.53604  0.00000  0.00000
## Specificity            0.9729   0.8917  0.92432  1.00000  1.00000
## Pos Pred Value         0.9840   0.7039  0.36170      NaN      NaN
## Neg Pred Value         0.8814   0.9669  0.96139  0.96163  0.97931
## Prevalence             0.6433   0.2236  0.07407  0.03837  0.02069
## Detection Rate         0.5966   0.1999  0.03971  0.00000  0.00000
## Detection Prevalence   0.6063   0.2840  0.10978  0.00000  0.00000
## Balanced Accuracy      0.9501   0.8929  0.73018  0.50000  0.50000
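
For reference, the headline figures can be pulled out of the confusionMatrix object directly (cm.tree is a hypothetical name used only for this sketch):

# Extract overall accuracy and kappa for the decision tree
cm.tree <- confusionMatrix(pred.test, testing$fprice)
cm.tree$overall[c("Accuracy", "Kappa")]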

2.5 Developing Predictive Models (Random Forest)

2.5.1 Model definition

# Fit a random forest with 10 trees and variable-importance tracking
modFit <- randomForest(fprice ~ ., data = training,
                       importance = TRUE, ntree = 10)
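
Since the forest is fitted with importance = TRUE, the variable importance measures are available and worth a look (a sketch; output not shown):

# Mean decrease in accuracy and in Gini impurity per predictor
importance(modFit)
varImpPlot(modFit)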

2.5.2 Model validation

2.5.2.1 Training set accuracy (In-Sample)

pred.train <- predict(modFit, training)
print(confusionMatrix(pred.train, training$fprice))
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction    1    2    3    4    5
##          1 4501    0    0    0    0
##          2    0 1564    0    0    0
##          3    0    0  521    0    0
##          4    0    0    0  270    0
##          5    0    0    0    0  147
## 
## Overall Statistics
##                                      
##                Accuracy : 1          
##                  95% CI : (0.9995, 1)
##     No Information Rate : 0.6427     
##     P-Value [Acc > NIR] : < 2.2e-16  
##                                      
##                   Kappa : 1          
##  Mcnemar's Test P-Value : NA         
## 
## Statistics by Class:
## 
##                      Class: 1 Class: 2 Class: 3 Class: 4 Class: 5
## Sensitivity            1.0000   1.0000   1.0000  1.00000  1.00000
## Specificity            1.0000   1.0000   1.0000  1.00000  1.00000
## Pos Pred Value         1.0000   1.0000   1.0000  1.00000  1.00000
## Neg Pred Value         1.0000   1.0000   1.0000  1.00000  1.00000
## Prevalence             0.6427   0.2233   0.0744  0.03855  0.02099
## Detection Rate         0.6427   0.2233   0.0744  0.03855  0.02099
## Detection Prevalence   0.6427   0.2233   0.0744  0.03855  0.02099
## Balanced Accuracy      1.0000   1.0000   1.0000  1.00000  1.00000
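
Re-predicting the training set with a random forest is always close to perfect, because every observation was seen by most of the trees. The out-of-bag (OOB) error reported by the fitted object is the more honest in-sample estimate, since each observation is predicted only by trees that did not train on it:

# Print the OOB error rate and OOB confusion matrix
print(modFit)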

2.5.2.2 Validation set accuracy (Out-of-Sample)

pred.test <- predict(modFit, testing)
print(confusionMatrix(pred.test, testing$fprice))
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction    1    2    3    4    5
##          1 1885   40    0    0    0
##          2   43  611   35    0    0
##          3    0   19  168   17    1
##          4    0    0   14   69   21
##          5    0    0    5   29   40
## 
## Overall Statistics
##                                           
##                Accuracy : 0.9253          
##                  95% CI : (0.9153, 0.9344)
##     No Information Rate : 0.6433          
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.8586          
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: 1 Class: 2 Class: 3 Class: 4 Class: 5
## Sensitivity            0.9777   0.9119  0.75676  0.60000  0.64516
## Specificity            0.9626   0.9665  0.98667  0.98786  0.98842
## Pos Pred Value         0.9792   0.8868  0.81951  0.66346  0.54054
## Neg Pred Value         0.9599   0.9744  0.98066  0.98410  0.99247
## Prevalence             0.6433   0.2236  0.07407  0.03837  0.02069
## Detection Rate         0.6290   0.2039  0.05606  0.02302  0.01335
## Detection Prevalence   0.6423   0.2299  0.06840  0.03470  0.02469
## Balanced Accuracy      0.9701   0.9392  0.87171  0.79393  0.81679
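
Since modFit is reused between sections, a side-by-side comparison of the two models requires keeping both fits under separate names; a minimal sketch (modFit.tree and modFit.rf are hypothetical names):

# Fit both models on the same split and compare out-of-sample accuracy
modFit.tree <- train(fprice ~ ., method = "rpart", data = training)
modFit.rf   <- randomForest(fprice ~ ., data = training, ntree = 10)
sapply(list(tree = modFit.tree, forest = modFit.rf),
       function(m) confusionMatrix(predict(m, testing),
                                   testing$fprice)$overall["Accuracy"])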