Analysis

0 Loading Data and EDA

# Read the processed training split from disk.
train_data <- read.csv(file = "data/processed/train_data.csv")

# Preview the first rows, then list the column types.
head(x = train_data)
str(object = train_data)
## 'data.frame':    18876 obs. of  11 variables:
##  $ Unnamed..0: int  19041 397 15627 16598 5812 6310 20880 24489 14748 22256 ...
##  $ carat     : num  1.17 1.2 0.31 2.19 0.3 1.01 0.74 0.53 0.7 0.52 ...
##  $ cut       : chr  "Good" "Ideal" "Very Good" "Ideal" ...
##  $ color     : chr  "D" "F" "I" "I" ...
##  $ clarity   : chr  "SI2" "VVS1" "VVS2" "SI2" ...
##  $ depth     : num  60.4 61.1 61.6 62.5 62.1 59.1 61.7 60.3 62.2 63.7 ...
##  $ table     : num  65 55 59 56 57 63 56 56 56 56 ...
##  $ x         : num  6.81 6.86 4.31 8.31 4.27 6.59 5.83 5.29 5.73 5.08 ...
##  $ y         : num  6.77 6.89 4.33 8.24 4.3 6.54 5.78 5.33 5.68 5.13 ...
##  $ z         : num  4.1 4.2 2.66 5.18 2.66 3.88 3.58 3.2 3.55 3.25 ...
##  $ price     : int  5567 13088 544 15254 491 3671 3170 2293 2792 1446 ...
summary(train_data)
##    Unnamed..0        carat            cut               color          
##  Min.   :    1   Min.   :0.2000   Length:18876       Length:18876      
##  1st Qu.: 6729   1st Qu.:0.4000   Class :character   Class :character  
##  Median :13468   Median :0.7000   Mode  :character   Mode  :character  
##  Mean   :13469   Mean   :0.7959                                        
##  3rd Qu.:20181   3rd Qu.:1.0400                                        
##  Max.   :26967   Max.   :4.5000                                        
##                                                                        
##    clarity              depth           table             x         
##  Length:18876       Min.   :50.80   Min.   :49.00   Min.   : 0.000  
##  Class :character   1st Qu.:61.00   1st Qu.:56.00   1st Qu.: 4.710  
##  Mode  :character   Median :61.80   Median :57.00   Median : 5.690  
##                     Mean   :61.74   Mean   :57.46   Mean   : 5.726  
##                     3rd Qu.:62.50   3rd Qu.:59.00   3rd Qu.: 6.540  
##                     Max.   :73.60   Max.   :79.00   Max.   :10.230  
##                     NA's   :498                                     
##        y                z             price      
##  Min.   : 0.000   Min.   :0.000   Min.   :  326  
##  1st Qu.: 4.720   1st Qu.:2.900   1st Qu.:  945  
##  Median : 5.700   Median :3.520   Median : 2367  
##  Mean   : 5.728   Mean   :3.534   Mean   : 3918  
##  3rd Qu.: 6.540   3rd Qu.:4.040   3rd Qu.: 5310  
##  Max.   :10.160   Max.   :6.720   Max.   :18818  
## 
# Create numerical mappings for ordinal data.
# Each lookup is a named numeric vector: the names are the category labels
# stored in the CSV, the values are the numeric codes used by the models.
cut_levels <- c("Fair" = 1, "Good" = 2, "Very Good" = 3, "Premium" = 4, "Ideal" = 5)
color_levels <- c("D" = 7, "E" = 6, "F" = 5, "G" = 4, "H" = 3, "I" = 2, "J" = 1)
clarity_levels <- c("I3" = 1, "I2" = 2, "I1" = 3, "SI2" = 4, "SI1" = 5, "VS2" = 6, "VS1" = 7, "VVS2" = 8, "VVS1" = 9, "IF" = 10, "FL" = 11)

# Apply mappings to columns.
# Indexing a named vector by a character vector is already vectorized, so no
# per-element sapply() is needed, and the result is always a numeric vector
# (even for zero rows). as.character() guarantees lookup by label rather than
# by position. Labels missing from a mapping become NA.
train_data$cut <- unname(cut_levels[as.character(train_data$cut)])
train_data$color <- unname(color_levels[as.character(train_data$color)])
train_data$clarity <- unname(clarity_levels[as.character(train_data$clarity)])

# check change
head(train_data[c("cut", "color", "clarity")])
# The modelling frame has to exist before it can be diagnosed: build it from
# train_data without the first column (the CSV row index).
independent_and_dependent_var <- train_data[, -1]
print(diagnose(independent_and_dependent_var))

# Drop depth from the frame (summary(train_data) reports 498 NA's for it).
# NOTE(review): section 1 rebuilds this frame from train_data, which brings
# depth back; confirm whether depth should stay out for the pair plots.
independent_and_dependent_var <- independent_and_dependent_var[, !(names(independent_and_dependent_var) %in% c("depth"))]
head(independent_and_dependent_var)

1 Plots of independent and dependent variables

# Drop the first column (the CSV row index); everything else is kept.
independent_and_dependent_var <- train_data[, -1]
head(independent_and_dependent_var)
# plot_pairs <- pairs(independent_and_dependent_var, pch=10, cex=0.5)

# Pairwise plot matrix of all remaining columns, without the progress bar.
plot_pairs <- ggpairs(
  data = independent_and_dependent_var,
  progress = FALSE
)
print(plot_pairs)

2 Correlation Analysis

# Predictors only: remove the response (price) and depth.
# NOTE(review): depth is presumably excluded because of its NA values, which
# would turn its cor() entries into NA; confirm.
independent_var <- independent_and_dependent_var[, !(names(independent_and_dependent_var) %in% c("price","depth"))]

library(corrplot)

# correlation matrix plot

corrplot(cor(independent_var),
  method = "square")

# install.packages("lares")
library(lares)
# TRUE/FALSE instead of T/F: T and F are ordinary variables that can be
# reassigned, TRUE and FALSE cannot.
top_correlation_plot <- corr_cross(independent_var, rm.na = TRUE, max_pvalue = 0.05, top = 15, grid = TRUE)
## Returning only the top 15. You may override with the 'top' argument
print(top_correlation_plot)

3 Model Selection and Screening

# Baseline linear model: price on carat, the three encoded grades and table.
model <- lm(
  formula = price ~ carat + cut + color + clarity + table,
  data = independent_and_dependent_var
)

summary(model)
## 
## Call:
## lm(formula = price ~ carat + cut + color + clarity + table, data = independent_and_dependent_var)
## 
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -14912.3   -701.7   -166.5    555.3   8972.0 
## 
## Coefficients:
##              Estimate Std. Error t value Pr(>|t|)    
## (Intercept) -7144.681    282.053 -25.331  < 2e-16 ***
## carat        8824.817     21.597 408.616  < 2e-16 ***
## cut           142.619      9.144  15.597  < 2e-16 ***
## color         333.234      5.560  59.929  < 2e-16 ***
## clarity       526.163      5.967  88.175  < 2e-16 ***
## table         -20.311      4.543  -4.471 7.83e-06 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 1238 on 18870 degrees of freedom
## Multiple R-squared:  0.9046, Adjusted R-squared:  0.9046 
## F-statistic: 3.578e+04 on 5 and 18870 DF,  p-value: < 2.2e-16

4 model-fitting techniques

# Goodness of fit of the baseline model, read from its summary object.
cat("Adjusted R-squared:", summary(model)[["adj.r.squared"]], "\n")
## Adjusted R-squared: 0.9045597
cat("F-statistic:", summary(model)[["fstatistic"]][1], "\n")
## F-statistic: 35779.51

p-value: < 2.2e-16

5 diagnostics and residual analysis:

# Default diagnostic plots of the baseline model in a 2x2 grid
par(mfrow=c(2,2))
plot(model)

# Residual analysis
# Residual plots
par(mfrow=c(2,2))
plot(model, which = c(1, 2, 3))

# QQ plot
# Base graphics has no "alpha" graphical parameter (passing it only produced
# warnings and no transparency), so the transparency is put into the point
# colour via adjustcolor() instead.
qqnorm(model$residuals, col = adjustcolor("black", alpha.f = 0.5), pch = 10, cex = 0.3)
qqline(model$residuals)

# Leverage vs. standardized residuals plot
plot(hatvalues(model), rstandard(model), xlab = "Leverage", ylab = "Standardized Residuals", main = "Leverage vs. Standardized Residuals")

6.1 Validation, and then do necessary model modifications/transformations

# Validation split, encoded with the same lookup vectors as the training data.
val_data <- read.csv("data/processed/validation_data.csv")

# Apply mappings to columns.
# Vectorized lookup by label (no per-element sapply()); the result is always a
# numeric vector and labels missing from a mapping become NA.
val_data$cut <- unname(cut_levels[as.character(val_data$cut)])
val_data$color <- unname(color_levels[as.character(val_data$color)])
val_data$clarity <- unname(clarity_levels[as.character(val_data$clarity)])

# Baseline predictions on the validation split
predicted_prices <- predict(model, newdata = val_data)
val_data$predicted_price <- predicted_prices

mse <- mean((val_data$predicted_price - val_data$price)^2)
rmse <- sqrt(mse)

# MSE and RMSE
print(paste("MSE:", mse))
## [1] "MSE: 1515490.8714597"
print(paste("RMSE:", rmse))
## [1] "RMSE: 1231.05274925963"
# val vs real price

ggplot(val_data, aes(x = price, y = predicted_price)) +
  geom_point(alpha = 0.6, pch=20,cex=0.5) + 
  geom_abline(color = "red", linetype = "dashed") +
  xlab("Actual Price") +
  ylab("Predicted Price") +
  ggtitle("Validation Data: Predicted vs Actual Prices")

6.2 non-linear to linear

library(gridExtra)
## 
## Attaching package: 'gridExtra'
## The following object is masked from 'package:dplyr':
## 
##     combine
library(scales)
## 
## Attaching package: 'scales'
## The following object is masked from 'package:readr':
## 
##     col_factor
# Histograms of price on the raw and on the log10 scale.
# qplot() is deprecated since ggplot2 3.4.0; with only `x` supplied it drew a
# histogram, so the explicit ggplot() + geom_histogram() form is equivalent.
plot1 <- ggplot(independent_and_dependent_var, aes(x = price)) +
  geom_histogram() +
  ggtitle('Price')
plot2 <- ggplot(independent_and_dependent_var, aes(x = log10(price))) +
  geom_histogram() +
  ggtitle('Price (log10)')

grid.arrange(plot1,plot2)
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

# non-linear
# ggplot() replaces the deprecated qplot(). qplot() with x and y already added
# a default point layer, so the extra geom_point() drew every point twice and
# the opaque default layer hid the intended alpha; now there is one layer.
ggplot(independent_and_dependent_var, aes(x = carat, y = price)) +
  geom_point(alpha = 0.5, pch=20,cex=0.5) +
  ggtitle('Price by Carat')

# close to linear
ggplot(independent_and_dependent_var, aes(x = carat, y = price)) +
  geom_point() +
  scale_y_continuous(trans = log10_trans())+
  ggtitle('Price (log10) by Carat')

# Scale transformation for carat: cube root forward, cube as the inverse.
cuberoot_trans <- function() {
  trans_new(
    'cuberoot',
    transform = function(x) x^(1/3),
    inverse = function(x) x^3
  )
}

# linear
# Carat on a cube-root x axis against price on a log10 y axis; points outside
# the given limits are dropped by the scales.
ggplot(data = independent_and_dependent_var, mapping = aes(carat, price)) +
  geom_point(alpha = 0.5, pch = 20, cex = 0.5) +
  scale_x_continuous(
    trans = cuberoot_trans(),
    limits = c(0.2, 3),
    breaks = c(0.2, 0.5, 1, 2, 3)
  ) +
  scale_y_continuous(
    trans = log10_trans(),
    limits = c(350, 15000),
    breaks = c(350, 1000, 5000, 10000, 15000)
  ) +
  ggtitle('Price (log10) by Cube-Root of Carat')
## Warning: Removed 622 rows containing missing values or values outside the scale range
## (`geom_point()`).

6.3 necessary model modifications/transformations

# Installing from inside the document fails when knitting because no CRAN
# mirror is set; install once interactively instead (same convention as the
# commented-out lares install above).
# install.packages("memisc")
library(memisc)
## Loading required package: lattice
## 
## Attaching package: 'memisc'
## The following object is masked from 'package:scales':
## 
##     percent
## The following objects are masked from 'package:dplyr':
## 
##     collect, recode, rename, syms
## The following objects are masked from 'package:lubridate':
## 
##     as.interval, is.interval
## The following object is masked from 'package:car':
## 
##     recode
## The following object is masked from 'package:ggplot2':
## 
##     syms
## The following objects are masked from 'package:stats':
## 
##     contr.sum, contr.treatment, contrasts
## The following object is masked from 'package:base':
## 
##     as.array
# Nested sequence of log-price models. Each update() keeps the previous
# formula (the "." on the right-hand side) and appends one more term.
# m1: log(price) on the cube root of carat only
m1 <- lm(I(log(price)) ~ I(carat^(1/3)), data = independent_and_dependent_var)
# m2: adds untransformed carat
m2 <- update(m1, ~ . + carat)
# m3: adds the encoded cut grade
m3 <- update(m2, ~ . + cut)
# m4: adds the encoded color grade
m4 <- update(m3, ~ . + color)
# m5: adds the encoded clarity grade
m5 <- update(m4, ~ . + clarity)
# m6: adds table
m6 <- update(m5, ~ . + table)
# Side-by-side table of the six fits (memisc)
mtable(m1, m2, m3, m4, m5, m6)
## 
## Calls:
## m1: lm(formula = I(log(price)) ~ I(carat^(1/3)), data = independent_and_dependent_var)
## m2: lm(formula = I(log(price)) ~ I(carat^(1/3)) + carat, data = independent_and_dependent_var)
## m3: lm(formula = I(log(price)) ~ I(carat^(1/3)) + carat + cut, data = independent_and_dependent_var)
## m4: lm(formula = I(log(price)) ~ I(carat^(1/3)) + carat + cut + color, 
##     data = independent_and_dependent_var)
## m5: lm(formula = I(log(price)) ~ I(carat^(1/3)) + carat + cut + color + 
##     clarity, data = independent_and_dependent_var)
## m6: lm(formula = I(log(price)) ~ I(carat^(1/3)) + carat + cut + color + 
##     clarity + table, data = independent_and_dependent_var)
## 
## ======================================================================================================
##                        m1            m2            m3            m4            m5            m6       
## ------------------------------------------------------------------------------------------------------
##   (Intercept)         2.817***      1.082***      0.769***      0.536***     -0.800***     -0.777***  
##                      (0.011)       (0.032)       (0.033)       (0.030)       (0.020)       (0.037)    
##   I(carat^(1/3))      5.560***      8.492***      8.639***      8.470***      9.287***      9.288***  
##                      (0.012)       (0.053)       (0.052)       (0.048)       (0.030)       (0.030)    
##   carat                            -1.109***     -1.144***     -1.016***     -1.156***     -1.156***  
##                                    (0.020)       (0.019)       (0.018)       (0.011)       (0.011)    
##   cut                                             0.054***      0.054***      0.033***      0.032***  
##                                                  (0.002)       (0.002)       (0.001)       (0.001)    
##   color                                                         0.064***      0.079***      0.079***  
##                                                                (0.001)       (0.001)       (0.001)    
##   clarity                                                                     0.122***      0.122***  
##                                                                              (0.001)       (0.001)    
##   table                                                                                    -0.000     
##                                                                                            (0.001)    
## ------------------------------------------------------------------------------------------------------
##   R-squared           0.924         0.935         0.939         0.949         0.981         0.981     
##   N               18876         18876         18876         18876         18876         18876         
## ======================================================================================================
##   Significance: *** = p < 0.001; ** = p < 0.01; * = p < 0.05
summary(m6)
## 
## Call:
## lm(formula = I(log(price)) ~ I(carat^(1/3)) + carat + cut + color + 
##     clarity + table, data = independent_and_dependent_var)
## 
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -0.98335 -0.08625  0.00278  0.09222  1.43545 
## 
## Coefficients:
##                  Estimate Std. Error  t value Pr(>|t|)    
## (Intercept)    -0.7768639  0.0365619  -21.248   <2e-16 ***
## I(carat^(1/3))  9.2875769  0.0295223  314.596   <2e-16 ***
## carat          -1.1562243  0.0108575 -106.491   <2e-16 ***
## cut             0.0323703  0.0010398   31.130   <2e-16 ***
## color           0.0787013  0.0006321  124.509   <2e-16 ***
## clarity         0.1216744  0.0006862  177.308   <2e-16 ***
## table          -0.0003826  0.0005164   -0.741    0.459    
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 0.1406 on 18869 degrees of freedom
## Multiple R-squared:  0.9809, Adjusted R-squared:  0.9809 
## F-statistic: 1.615e+05 on 6 and 18869 DF,  p-value: < 2.2e-16
# m6 does not really improve on m5: R-squared stays at 0.981, and the added term `table` has a p-value of 0.459 (> 0.05, from summary(m6)), so we choose m5.
# Validate m5 on the validation split. m5 models log(price), so its
# predictions are on the natural-log scale and are compared with log(price).
predicted_prices_m5 <- predict(m5, newdata = val_data)
val_data$predicted_prices_m5 <- predicted_prices_m5
val_data$log_price <- log(val_data$price)

mse <- mean((val_data$predicted_prices_m5 - val_data$log_price)^2)
rmse <- sqrt(mse)

# MSE and RMSE (log scale, so not comparable with the baseline RMSE in prices)
print(paste("MSE:", mse))
## [1] "MSE: 0.019585421936943"
print(paste("RMSE:", rmse))
## [1] "RMSE: 0.139947925804361"

# RMSE on the original price scale, comparable with the baseline model's RMSE.
# exp() back-transforms the log prediction without any bias correction.
rmse_price <- sqrt(mean((exp(val_data$predicted_prices_m5) - val_data$price)^2))
print(paste("RMSE (price scale):", rmse_price))

# val vs real price (both axes are on the log scale)

ggplot(val_data, aes(x = log_price, y = predicted_prices_m5)) +
  geom_point(alpha = 0.5, pch=20,cex=0.5) +
  geom_abline(color = "red", linetype = "dashed") +
  xlab("Actual log(Price)") +
  ylab("Predicted log(Price)") +
  ggtitle("Validation Data: Predicted vs Actual log(Price)")

7.1 Linearity Assumption

independent_and_dependent_var
# 1. Linearity/Mean Zero Assumption

## Plot standardized residuals against each predictor
# Calculate standardized residuals
# rstandard() is the same function section 7.2 uses for m5, so both sections
# work with the same residuals and no extra package is needed.
resids <- rstandard(m5)

# Predictors are selected by name rather than by column position, so the plots
# stay correct if the column order of the frame changes.
plot(independent_and_dependent_var$carat^(1/3),resids,xlab="carat^(1/3)",ylab="Residuals",main="Standardized Residuals vs. carat^(1/3)")
abline(0,0,col="red")

plot(independent_and_dependent_var$carat,resids,xlab="carat",ylab="Residuals", main = "Standardized Residuals vs. carat")
abline(0,0,col="red")

plot(independent_and_dependent_var$cut,resids,xlab="cut",ylab="Residuals", main = "Standardized Residuals vs. cut")
abline(0,0,col="red")

# how to handle qualitative variables?

7.2 Constant Variance of Errors Assumption

# Store m5's standardized residuals and fitted values as new columns.
independent_and_dependent_var[["std_resid"]] <- rstandard(m5)
independent_and_dependent_var[["fitted_values"]] <- fitted(m5)

# Standardized residuals against fitted values, with a red zero line.
with(
  independent_and_dependent_var,
  plot(fitted_values, std_resid,
       xlab = "Fitted Values", ylab = "Residuals", pch = 18, cex = 0.15)
)
abline(h = 0, col = "red")

## No trends -> Constant variance assumption holds 

7.3 Independence of Errors Assumption

Plot standardized residuals against fitted values. No clusters -> independence assumption holds.

7.4 Normality of Errors

## Histogram and normal Q-Q plot of the raw residuals of m5
# Histogram: the residuals should have an approximately symmetric, unimodal
# distribution, with no gaps in the data.
p6 <- ggplot(data = independent_and_dependent_var,
             mapping = aes(x = residuals(m5))) +
  geom_histogram(binwidth = 0.1, fill = "blue", alpha = 0.7) +
  ggtitle("Histogram of Residuals") +
  xlab("Residuals") +
  ylab("Frequency")

# Q-Q plot: curvature (especially at the ends) shows non-normality.
p7 <- ggplot(data = independent_and_dependent_var,
             mapping = aes(sample = residuals(m5))) +
  geom_qq() +
  geom_qq_line() +
  ggtitle("Q-Q Plot") +
  xlab("Theoretical Quantiles") +
  ylab("Sample Quantiles")

# Both panels side by side
grid.arrange(p6, p7, ncol = 2)

8 Identifying and Handling Unusual Observations

# Cook's distance checking for outliers
# Di > 4/n, Di > 1, OR any “large” Di should be investigated
cooks_cutoff <- 4 / nrow(independent_and_dependent_var)
cat("4/n = ", cooks_cutoff, "\n")
## 4/n =  0.0002119093
# Calculate Cook's distance
cooksd <- cooks.distance(m5)
#independent_and_dependent_var$cooksd <- cooksd
plot(cooksd,type="h",lwd=3,col="red", ylab = "Cook's Distance")

## We have a large number of outliers, which suggests a heavy tailed distribution rather than truly extreme values.

# Filter with the cooksd vector itself. The data frame has no `cooksd` column
# (the assignment above is commented out), so `$cooksd` was NULL, the filter
# selected zero rows and lm() failed with "0 (non-NA) cases".
# The length check guards the row-by-row alignment of cooksd with the frame;
# which() drops any NA distances instead of producing NA rows.
stopifnot(length(cooksd) == nrow(independent_and_dependent_var))
cleaned_data <- independent_and_dependent_var[which(cooksd <= cooks_cutoff), ]

# Refit m5's formula on the rows at or below the cutoff
m5_clean <- lm(I(log(price)) ~ I(carat^(1/3)) + carat + cut + color + clarity, data = cleaned_data)
summary(m5_clean)

9 Multicollinearity Checks

## Rule used here: VIF_j is acceptable when VIF_j < MAX(10, 1/(1-R^2))
cat("R^2 = ", summary(m5)[["r.squared"]], "\n")
## R^2 =  0.9808975
cat("MAX(10, 1/(1-R^2)) = ", max(10, 1 / (1 - summary(m5)[["r.squared"]])), "\n")
## MAX(10, 1/(1-R^2)) =  52.34903
# Variance inflation factor of every term in m5
vif(m5)
## I(carat^(1/3))          carat            cut          color        clarity 
##      25.714876      25.372489       1.045569       1.115387       1.216591
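## All VIFs are below the threshold of 52.35 -> no problematic multicollinearity (the VIFs of about 25 for carat^(1/3) and carat are expected, since both terms come from the same variable)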