In the lab, a classification tree was applied to the Carseats data set after converting Sales into a qualitative response variable. Now we seek to predict Sales using regression trees and related approaches, treating the response as a quantitative variable.
#Load Packages
library(ISLR)
## Warning: package 'ISLR' was built under R version 4.4.2
#Load the Carseats dataset into our environment
Carseats <- Carseats
?Carseats
(a) Split the data set into a training set and a test set.
set.seed(1)
# 80/20 split for train & test sets
train <- sample(1:nrow(Carseats), 0.8*nrow(Carseats))
#Separate dfs for train and test
Carseats_train <- Carseats[train,]
Carseats_test <- Carseats[-train,]
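Since Carseats has 400 observations, this split leaves 320 rows for training and 80 for testing. A quick sanity check (a small added sketch, not part of the original output):
#Confirm the split sizes
dim(Carseats_train) #expect 320 rows
dim(Carseats_test) #expect 80 rows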
(b) Fit a regression tree to the training set. Plot the tree, and interpret the results. What test MSE do you obtain?
After fitting the regression tree, we have 16 terminal nodes. The residual mean deviance of 2.572 is the training RSS divided by n minus the number of terminal nodes (781.9 / 304), i.e., roughly a training MSE; taking its square root, the model's training predictions typically miss the actual Sales by about 1.6 units.
Running the model against the test set, I obtained a test MSE of 4.936.
library(tree)
## Warning: package 'tree' was built under R version 4.4.2
#Regression tree on training set
one_tree <- tree(Sales~., data = Carseats_train)
#summary
summary(one_tree)
##
## Regression tree:
## tree(formula = Sales ~ ., data = Carseats_train)
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "Age" "Income" "CompPrice"
## [6] "Advertising"
## Number of terminal nodes: 16
## Residual mean deviance: 2.572 = 781.9 / 304
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -4.45400 -1.07000 -0.05544 0.00000 1.14500 4.69600
#Visualize the fitted tree
plot(one_tree)
text(one_tree, pretty = 0)
#Compute test MSE
one_tree_preds <- predict(one_tree, newdata = Carseats_test)
mean((one_tree_preds - Carseats_test$Sales)^2)
## [1] 4.936081
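As a check on the deviance arithmetic reported by summary(): the residual mean deviance is the training RSS divided by n minus the number of terminal nodes (a small added sketch):
#Residual mean deviance = RSS / (n - terminal nodes)
781.9 / (nrow(Carseats_train) - 16) #781.9 / 304 = 2.572
sqrt(2.572) #about 1.60, the typical training error in units of Sales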
(c) Use cross-validation in order to determine the optimal level of tree complexity. Does pruning the tree improve the test MSE?
Cross-validation suggests a tree of size 9, so we prune to 9 terminal nodes, as seen in the pruned-tree plot. The pruned tree is simpler and easier to interpret, but its test MSE, computed below on the test set, is comparable to the unpruned tree's 4.936, so pruning does not improve the test MSE here.
#Cross-validation over tree sizes
set.seed(1)
cv_one_tree <- cv.tree(one_tree)
plot(cv_one_tree$size, cv_one_tree$dev, type = "b")
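Rather than reading the best size off the plot, it can be extracted directly from the cv.tree object via its $size and $dev components (a small added sketch; the result may differ from an eyeballed choice):
#Tree size with the lowest cross-validated deviance
cv_one_tree$size[which.min(cv_one_tree$dev)]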
#Pruning
prune_one_tree <- prune.tree(one_tree, best = 9)
plot(prune_one_tree)
text(prune_one_tree, pretty = 0)
#Test MSE of the pruned tree (note: predictions must be made on the test set)
prune_preds <- predict(prune_one_tree, newdata = Carseats_test)
mean((prune_preds - Carseats_test$Sales)^2)
(d) Use the bagging approach in order to analyze this data. What test MSE do you obtain? Use the importance() function to determine which variables are most important.
Using the bagging approach (a random forest with mtry = 10, so all ten predictors are considered at every split), I obtained a test MSE of 2.945, a clear improvement over the single tree.
By both importance measures, the variables that matter most are:
Price, the price the company charges for car seats at each site
ShelveLoc, the quality of the shelving location within the store
library(randomForest)
## Warning: package 'randomForest' was built under R version 4.4.2
## randomForest 4.7-1.2
## Type rfNews() to see new features/changes/bug fixes.
set.seed(1)
#Bagging: mtry = 10 considers all predictors at each split
bag_carseats <- randomForest(Sales~., data = Carseats_train, mtry = 10, importance = TRUE)
#Prediction & test MSE
bag_preds <- predict(bag_carseats, newdata = Carseats_test)
mean((bag_preds - Carseats_test$Sales)^2)
## [1] 2.945423
#Variable importance
importance(bag_carseats)
## %IncMSE IncNodePurity
## CompPrice 38.143176 259.77221
## Income 11.839999 138.20846
## Advertising 22.964249 191.25839
## Population -2.309923 74.13451
## Price 76.903940 744.20064
## ShelveLoc 73.841154 692.64875
## Age 25.449768 231.66005
## Education 2.547928 61.58542
## Urban -2.600879 10.15212
## US 3.572899 12.28877
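The same two importance measures can be displayed graphically with varImpPlot() from randomForest (a small added sketch):
#Dotcharts of %IncMSE and IncNodePurity
varImpPlot(bag_carseats, main = "Bagging: variable importance")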
(e) Use random forests to analyze this data. What test MSE do you obtain? Use the importance() function to determine which variables are most important. Describe the effect of m, the number of variables considered at each split, on the error rate obtained.
Using a random forest with m = 3, I obtained a test MSE of 3.41. This did not improve on bagging (2.945): restricting each split to 3 of the 10 predictors decorrelates the trees, but it also raises each tree's bias, and with these data the cost outweighs the benefit. In general, the best m is found by trying several values, as in the sketch after the importance table below.
As with bagging, the most important variables are:
Price, the price the company charges for car seats at each site
ShelveLoc, the quality of the shelving location within the store
set.seed(1)
#Random forest
rf_carseats <- randomForest(Sales~., data = Carseats_train, mtry = 3, importance = TRUE)
#Rf prediction & test MSE
rf_preds <- predict(rf_carseats, newdata = Carseats_test)
mean((rf_preds - Carseats_test$Sales)^2)
## [1] 3.41076
#Variable Importance
importance(rf_carseats)
## %IncMSE IncNodePurity
## CompPrice 18.9120265 229.42731
## Income 6.1930585 186.13067
## Advertising 17.0572579 193.19072
## Population -1.2262803 148.27359
## Price 47.3801171 606.13470
## ShelveLoc 49.8670998 546.57301
## Age 18.4531716 276.73978
## Education 0.7027208 101.32559
## Urban -3.0631701 19.75315
## US 6.1644035 36.47125
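To see the effect of m directly, we can refit the forest over the full grid of mtry values and track the test MSE. A small added sketch (exact numbers will vary with the seed):
#Test MSE as a function of m, the number of predictors tried at each split
set.seed(1)
mtry_grid <- 1:10
rf_test_mse <- sapply(mtry_grid, function(m) {
  fit <- randomForest(Sales ~ ., data = Carseats_train, mtry = m)
  mean((predict(fit, newdata = Carseats_test) - Carseats_test$Sales)^2)
})
plot(mtry_grid, rf_test_mse, type = "b", xlab = "m (mtry)", ylab = "Test MSE")
The error typically drops steeply as m grows from 1 and then levels off; the results above, where bagging (m = 10) beats m = 3, suggest the minimum sits at or near the bagging end for this data set.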
We now use boosting to predict Salary in the Hitters data set.
#Load the Hitters dataset into our environment
Hitters <- Hitters
summary(Hitters)
## AtBat Hits HmRun Runs
## Min. : 16.0 Min. : 1 Min. : 0.00 Min. : 0.00
## 1st Qu.:255.2 1st Qu.: 64 1st Qu.: 4.00 1st Qu.: 30.25
## Median :379.5 Median : 96 Median : 8.00 Median : 48.00
## Mean :380.9 Mean :101 Mean :10.77 Mean : 50.91
## 3rd Qu.:512.0 3rd Qu.:137 3rd Qu.:16.00 3rd Qu.: 69.00
## Max. :687.0 Max. :238 Max. :40.00 Max. :130.00
##
## RBI Walks Years CAtBat
## Min. : 0.00 Min. : 0.00 Min. : 1.000 Min. : 19.0
## 1st Qu.: 28.00 1st Qu.: 22.00 1st Qu.: 4.000 1st Qu.: 816.8
## Median : 44.00 Median : 35.00 Median : 6.000 Median : 1928.0
## Mean : 48.03 Mean : 38.74 Mean : 7.444 Mean : 2648.7
## 3rd Qu.: 64.75 3rd Qu.: 53.00 3rd Qu.:11.000 3rd Qu.: 3924.2
## Max. :121.00 Max. :105.00 Max. :24.000 Max. :14053.0
##
## CHits CHmRun CRuns CRBI
## Min. : 4.0 Min. : 0.00 Min. : 1.0 Min. : 0.00
## 1st Qu.: 209.0 1st Qu.: 14.00 1st Qu.: 100.2 1st Qu.: 88.75
## Median : 508.0 Median : 37.50 Median : 247.0 Median : 220.50
## Mean : 717.6 Mean : 69.49 Mean : 358.8 Mean : 330.12
## 3rd Qu.:1059.2 3rd Qu.: 90.00 3rd Qu.: 526.2 3rd Qu.: 426.25
## Max. :4256.0 Max. :548.00 Max. :2165.0 Max. :1659.00
##
## CWalks League Division PutOuts Assists
## Min. : 0.00 A:175 E:157 Min. : 0.0 Min. : 0.0
## 1st Qu.: 67.25 N:147 W:165 1st Qu.: 109.2 1st Qu.: 7.0
## Median : 170.50 Median : 212.0 Median : 39.5
## Mean : 260.24 Mean : 288.9 Mean :106.9
## 3rd Qu.: 339.25 3rd Qu.: 325.0 3rd Qu.:166.0
## Max. :1566.00 Max. :1378.0 Max. :492.0
##
## Errors Salary NewLeague
## Min. : 0.00 Min. : 67.5 A:176
## 1st Qu.: 3.00 1st Qu.: 190.0 N:146
## Median : 6.00 Median : 425.0
## Mean : 8.04 Mean : 535.9
## 3rd Qu.:11.00 3rd Qu.: 750.0
## Max. :32.00 Max. :2460.0
## NA's :59
#Remove observations with missing values (59 NAs in Salary)
Hitters <- na.omit(Hitters)
#Log-transform Salary, since the raw salaries are strongly right-skewed
Hitters$Salary <- log(Hitters$Salary)
summary(Hitters)
## AtBat Hits HmRun Runs
## Min. : 19.0 Min. : 1.0 Min. : 0.00 Min. : 0.00
## 1st Qu.:282.5 1st Qu.: 71.5 1st Qu.: 5.00 1st Qu.: 33.50
## Median :413.0 Median :103.0 Median : 9.00 Median : 52.00
## Mean :403.6 Mean :107.8 Mean :11.62 Mean : 54.75
## 3rd Qu.:526.0 3rd Qu.:141.5 3rd Qu.:18.00 3rd Qu.: 73.00
## Max. :687.0 Max. :238.0 Max. :40.00 Max. :130.00
## RBI Walks Years CAtBat
## Min. : 0.00 Min. : 0.00 Min. : 1.000 Min. : 19.0
## 1st Qu.: 30.00 1st Qu.: 23.00 1st Qu.: 4.000 1st Qu.: 842.5
## Median : 47.00 Median : 37.00 Median : 6.000 Median : 1931.0
## Mean : 51.49 Mean : 41.11 Mean : 7.312 Mean : 2657.5
## 3rd Qu.: 71.00 3rd Qu.: 57.00 3rd Qu.:10.000 3rd Qu.: 3890.5
## Max. :121.00 Max. :105.00 Max. :24.000 Max. :14053.0
## CHits CHmRun CRuns CRBI
## Min. : 4.0 Min. : 0.00 Min. : 2.0 Min. : 3.0
## 1st Qu.: 212.0 1st Qu.: 15.00 1st Qu.: 105.5 1st Qu.: 95.0
## Median : 516.0 Median : 40.00 Median : 250.0 Median : 230.0
## Mean : 722.2 Mean : 69.24 Mean : 361.2 Mean : 330.4
## 3rd Qu.:1054.0 3rd Qu.: 92.50 3rd Qu.: 497.5 3rd Qu.: 424.5
## Max. :4256.0 Max. :548.00 Max. :2165.0 Max. :1659.0
## CWalks League Division PutOuts Assists
## Min. : 1.0 A:139 E:129 Min. : 0.0 Min. : 0.0
## 1st Qu.: 71.0 N:124 W:134 1st Qu.: 113.5 1st Qu.: 8.0
## Median : 174.0 Median : 224.0 Median : 45.0
## Mean : 260.3 Mean : 290.7 Mean :118.8
## 3rd Qu.: 328.5 3rd Qu.: 322.5 3rd Qu.:192.0
## Max. :1566.0 Max. :1377.0 Max. :492.0
## Errors Salary NewLeague
## Min. : 0.000 Min. :4.212 A:141
## 1st Qu.: 3.000 1st Qu.:5.247 N:122
## Median : 7.000 Median :6.052
## Mean : 8.593 Mean :5.927
## 3rd Qu.:13.000 3rd Qu.:6.620
## Max. :32.000 Max. :7.808
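After the transform, Salary is roughly symmetric (median 6.05 vs. mean 5.93 above), which suits the squared-error loss used by gaussian boosting. A quick visual check (a small added sketch):
#Distribution of log(Salary)
hist(Hitters$Salary, main = "log(Salary)", xlab = "log Salary")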
#Train on the first 200 observations, test on the remaining 63
train_rows <- 1:200
hit_train <- Hitters[train_rows, ]
hit_test <- Hitters[-train_rows, ]
library(gbm)
## Warning: package 'gbm' was built under R version 4.4.2
## Loaded gbm 2.2.2
## This version of gbm is no longer under development. Consider transitioning to gbm3, https://github.com/gbm-developers/gbm3
set.seed(1)
#Create lambda values
lambda <- seq(from=0.001, to=1, by=0.05)
#Training MSE for each shrinkage value, boosting with 1000 trees on the training set
train_error <- rep(NA, length(lambda))
for (l in 1:length(lambda)) {
  boost_Hitt <- gbm(Salary~.,
                    data = hit_train,
                    distribution = "gaussian",
                    n.trees = 1000,
                    shrinkage = lambda[l])
  train_pred <- predict(boost_Hitt, hit_train, n.trees = 1000) #training predictions
  train_error[l] <- mean((train_pred - hit_train$Salary)^2)    #training MSE
}
#Plot training MSE against the shrinkage value lambda
plot(lambda, train_error, type="b", xlab="Shrinkage values", ylab="Training set MSE")
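Training MSE keeps falling as lambda grows, since the boosted model fits the training data ever more closely, so the shrinkage value should really be judged by held-out error. A sketch mirroring the loop above, using the hit_test split created earlier:
#Test MSE for each shrinkage value
set.seed(1)
test_error <- rep(NA, length(lambda))
for (l in 1:length(lambda)) {
  boost_fit <- gbm(Salary~.,
                   data = hit_train,
                   distribution = "gaussian",
                   n.trees = 1000,
                   shrinkage = lambda[l])
  test_pred <- predict(boost_fit, hit_test, n.trees = 1000)
  test_error[l] <- mean((test_pred - hit_test$Salary)^2)
}
plot(lambda, test_error, type="b", xlab="Shrinkage values", ylab="Test set MSE")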
Which variables appear to be the most important predictors in the boosted model?
The most important predictor in the boosted model is CAtBat, the number of times at bat during a player's career, with a relative influence of about 25.5%.
summary(boost_Hitt)
## var rel.inf
## CAtBat CAtBat 25.5158639
## PutOuts PutOuts 9.3357305
## HmRun HmRun 7.1430848
## Walks Walks 6.7620682
## Assists Assists 5.5926337
## CWalks CWalks 5.5898943
## RBI RBI 5.3915223
## Errors Errors 4.3036769
## CHmRun CHmRun 4.1175349
## CRuns CRuns 4.1099896
## Hits Hits 3.2847656
## Runs Runs 3.2028648
## AtBat AtBat 3.1556257
## CRBI CRBI 3.1451171
## Years Years 2.6417484
## Division Division 2.4488855
## CHits CHits 2.3025127
## League League 1.1774293
## NewLeague NewLeague 0.7790519