library(tree)
library(randomForest)
## randomForest 4.7-1.1
## Type rfNews() to see new features/changes/bug fixes.
library(gbm)
## Loaded gbm 2.1.8.1
library(xgboost)
# Data loading ----
# NOTE(review): this is a hard-coded, machine-specific working directory, so
# the script only runs as-is on the author's machine. Consider a project
# directory or a relative path instead of setwd().
setwd("C:/Users/gabeg/Documents/Uni/Stat 5003/Week 9")

# Read the wine data and drop its final column before any modelling.
# NOTE(review): the dropped column is presumably not a predictor (e.g. an
# identifier) -- confirm against the CSV header.
df <- local({
  raw <- read.csv("winequality-data.csv", header = TRUE)
  raw[, -ncol(raw)]
})

# Question 2

# Train/test split ----
# Fixing the seed makes the partition below reproducible.
set.seed(530306632)
library(caret)

# createDataPartition() returns a list of row-index vectors. Take the first
# one; p = 0.5 sends about half of the rows to the training set.
split_ind <- createDataPartition(y = df[["quality"]], p = 0.5)[[1L]]

train_data <- df[split_ind, ]
test_data <- df[-split_ind, ]
# Regression tree ----
tree_model <- tree(quality ~ ., data = train_data)

# Draw the fitted tree, then add its split/leaf labels.
plot(tree_model)
text(tree_model)

# Calculate the test-set error of the tree: the residual sum of squares (RSS)
# and the mean squared error (MSE) of its predictions against observed
# quality.
tree_test <- predict(tree_model, newdata = test_data)
RSS <- sum((test_data[["quality"]] - tree_test)^2)
MSE <- mean((test_data[["quality"]] - tree_test)^2)

print(RSS)
print(MSE)
# Original run: RSS = 1136.123, MSE = 0.5799502

# Question 3

## 3.1
# Random forest (ranger) with impurity-based variable importance ----
library(ranger)

rf_model <- ranger(quality ~ ., data = train_data, importance = "impurity")
rf_test <- predict(rf_model, data = test_data)

# Call ranger's importance() explicitly. randomForest also exports an
# importance(), and attaching ranger masks it, so the bare name would
# resolve differently if the library() calls were reordered.
import_scores <- ranger::importance(rf_model)

import_df <- data.frame(
  variables = names(import_scores),
  values = as.numeric(import_scores)
)

# Sort the importance scores in descending order
sorted_df <- import_df[order(import_df$values, decreasing = TRUE), ]

# Keep (at most) the 10 most important variables. head() returns fewer rows
# when there are fewer than 10 predictors, whereas sorted_df[1:10, ] would
# pad the result with NA rows.
top_10_columns <- head(sorted_df, 10)

# Bar chart of the top importance scores
barplot(top_10_columns$values, main = "Top 10 Most Important Columns",
        ylab = "Importance Score", xlab = "Variables", col = "blue",
        names.arg = top_10_columns$variables, cex.names = 0.8, horiz = FALSE)

## 3.2
# Test-set error of the random forest: residual sum of squares (RSS) and
# mean squared error (MSE) of its predictions against observed quality.
RSS_rf <- sum((test_data[["quality"]] - rf_test[["predictions"]])^2)
MSE_rf <- mean((test_data[["quality"]] - rf_test[["predictions"]])^2)

print(RSS_rf)
print(MSE_rf)
# Original run: RSS_rf = 832.7847, MSE_rf = 0.4251071

# Question 3.3

set.seed(530306632)
# Coarse grid of candidate forest sizes: 50-500 in steps of 50, then
# 1000-3000 in steps of 250.
ntrees_big <- c(seq(from = 50, to = 500, by = 50),
                seq(from = 1000, to = 3000, by = 250))

# Test-set RSS of a ranger forest for each candidate num.trees.
# The result vector is preallocated, not grown with c() on every iteration.
# Each RSS is written straight into it, so this loop no longer overwrites
# the regression tree's global `RSS`.
# NOTE(review): choosing num.trees by test-set RSS tunes on the test data;
# ranger's out-of-bag error (model$prediction.error) would avoid that bias.
RSS_store <- numeric(length(ntrees_big))
for (k in seq_along(ntrees_big)) {
  model <- ranger(quality ~ ., data = train_data, num.trees = ntrees_big[k])
  test <- predict(model, data = test_data)
  RSS_store[k] <- sum((test$predictions - test_data$quality)^2)
}

plot(ntrees_big, RSS_store, type = "b", main = "Initial Tuning of Num Trees",
     xlab = "ntrees", ylab = "RSS")

# Highlight the forest size with the smallest test RSS
lowest_index_big <- which.min(RSS_store)
points(ntrees_big[lowest_index_big], RSS_store[lowest_index_big],
       pch = 19, col = "red")
legend("topright", legend = "Best ntrees", pch = 19, col = "red")

# Discussion

# Looking at the graph above, test RSS is minimised at 300 trees (red dot).
# Below, I search a finer grid of tree counts to pin down a more precise value.

set.seed(530306632)
# Finer grid of candidate forest sizes: 250-1000 in steps of 25.
ntrees_fine <- seq(from = 250, to = 1000, by = 25)

# Test-set RSS of a ranger forest for each candidate num.trees, written into
# a preallocated vector instead of growing one with c() on every iteration.
# NOTE(review): as with the coarse grid, this tunes on the test set; the
# out-of-bag error (model$prediction.error) would avoid that bias.
RSS_store_finer <- numeric(length(ntrees_fine))
for (k in seq_along(ntrees_fine)) {
  model <- ranger(quality ~ ., data = train_data, num.trees = ntrees_fine[k])
  test <- predict(model, data = test_data)
  RSS_store_finer[k] <- sum((test$predictions - test_data$quality)^2)
}

plot(ntrees_fine, RSS_store_finer, type = "b",
     main = "Fine Tuning of Num Trees", xlab = "ntrees", ylab = "RSS")

# Highlight the forest size with the smallest test RSS
lowest_index <- which.min(RSS_store_finer)
points(ntrees_fine[lowest_index], RSS_store_finer[lowest_index],
       pch = 19, col = "red")
legend("topright", legend = "Best ntrees", pch = 19, col = "red")

# Discussion

# As this graph shows, RSS does not change smoothly or monotonically with the
# number of trees; it is quite volatile, most likely because each forest is
# grown from different random draws. RSS is minimised at num.trees = 375.

# Gradient boosting (gbm): test-set RSS across numbers of trees ----
# Reset the seed, as the ranger tuning sections do. By default gbm fits each
# tree on a random subsample of the training rows, so without this the
# results would depend on how many random draws earlier code made.
set.seed(530306632)
ntree_seq <- c(seq(from = 50, to = 500, by = 50),
               seq(from = 1000, to = 3000, by = 250))

# Preallocate the result instead of growing it with c(), and write each RSS
# straight into it so the global `RSS` is not overwritten.
RSS_test_gbm <- numeric(length(ntree_seq))
for (k in seq_along(ntree_seq)) {
  gbm_model <- gbm(quality ~ ., data = train_data, distribution = "gaussian",
                   n.trees = ntree_seq[k])
  gbm_test <- predict(gbm_model, newdata = test_data, n.trees = ntree_seq[k])
  RSS_test_gbm[k] <- sum((gbm_test - test_data$quality)^2)
}

plot(ntree_seq, RSS_test_gbm, type = "b", main = "Tuning of Num Trees (gbm)",
     xlab = "ntrees", ylab = "RSS")

# Highlight the number of trees with the smallest test RSS, matching the
# ranger tuning plots.
lowest_index_gbm <- which.min(RSS_test_gbm)
points(ntree_seq[lowest_index_gbm], RSS_test_gbm[lowest_index_gbm],
       pch = 19, col = "red")
legend("topright", legend = "Best ntrees", pch = 19, col = "red")