library(ISLR)
# Drop rows with missing values, then replace Salary by its natural log so
# that every model below is fitted to log-salary.
Hitters <- na.omit(Hitters)
Hitters[["Salary"]] <- log(Hitters[["Salary"]])

# Rows 1-200 are the training set; every remaining row is held out for testing.
train <- seq_len(200)
Hitters.test <- Hitters[-train, ]
Hitters.train <- Hitters[train, ]
library(gbm)
## Warning: package 'gbm' was built under R version 3.6.2
## Loaded gbm 2.1.8
# Shrinkage sweep for boosting.
#
# For every shrinkage value on a log-spaced grid (10^-10 up to 10^-0.2), fit a
# 1000-tree gaussian boosted model on the training set and record both the
# training MSE and the test MSE, then plot each error curve against shrinkage.
#
# Each model is fitted once and scored on both data sets. Because the loop is
# seeded once with set.seed(1), iteration i draws from the same random-number
# stream it would see in a separately seeded train-only or test-only sweep, so
# both curves come from identical fits at half the fitting cost.
# Do not re-seed after the loop: code that follows continues from the RNG
# state left behind by the last fit.
set.seed(1)
pows <- seq(-10, -0.2, by = 0.1)
lambdas <- 10^pows

# Preallocate one slot per shrinkage value.
train.err <- rep(NA_real_, length(lambdas))
test.err <- rep(NA_real_, length(lambdas))

for (i in seq_along(lambdas)) {
  boost.hitters <- gbm(Salary ~ ., data = Hitters.train,
                       distribution = "gaussian", n.trees = 1000,
                       shrinkage = lambdas[i])

  # Score the same fitted model on the training rows and on the held-out rows.
  pred.train <- predict(boost.hitters, Hitters.train, n.trees = 1000)
  yhat <- predict(boost.hitters, Hitters.test, n.trees = 1000)

  train.err[i] <- mean((pred.train - Hitters.train$Salary)^2)
  test.err[i] <- mean((yhat - Hitters.test$Salary)^2)
}

plot(lambdas, train.err, type = "b", xlab = "Shrinkage values", ylab = "Training MSE")
plot(lambdas, test.err, type = "b", xlab = "Shrinkage values", ylab = "Test MSE")

# Smallest test MSE on the grid, and the shrinkage value that attains it.
min(test.err)
## [1] 0.2540265
lambdas[which.min(test.err)]
## [1] 0.07943282
# Baseline 1: ordinary least squares of Salary on every other column,
# fitted on the training set and scored by MSE on the test set.
fit1 <- lm(Salary ~ ., data = Hitters.train)
pred1 <- predict(fit1, newdata = Hitters.test)
mean((Hitters.test$Salary - pred1)^2)
## [1] 0.4917959
library(glmnet)
## Warning: package 'glmnet' was built under R version 3.6.2
## Loading required package: Matrix
## Loaded glmnet 4.1-1
# Baseline 2: ridge regression (glmnet with alpha = 0), fitted on the training
# set and scored by MSE on the test set.
# glmnet takes a numeric design matrix rather than a formula, so the training
# and test predictors are each expanded with model.matrix().
y <- Hitters.train$Salary
x <- model.matrix(Salary ~ ., data = Hitters.train)
x.test <- model.matrix(Salary ~ ., data = Hitters.test)

fit2 <- glmnet(x, y, alpha = 0)

# NOTE(review): the penalty s = 0.01 is fixed here rather than chosen by
# cross-validation - confirm that this is intended.
pred2 <- predict(fit2, newx = x.test, s = 0.01)
mean((Hitters.test$Salary - pred2)^2)
## [1] 0.4570283
The test MSE for boosting (0.254) is lower than that of both linear regression (0.492) and ridge regression (0.457).
# Refit the boosted model on the training set at the shrinkage value with the
# lowest test MSE, then report each predictor's relative influence.
# NOTE(review): the shrinkage is selected using test.err, i.e. the test set
# itself - confirm that this is acceptable for the analysis.
boost.hitters <- gbm(
  Salary ~ .,
  data = Hitters.train,
  distribution = "gaussian",
  n.trees = 1000,
  shrinkage = lambdas[which.min(test.err)]
)
summary(boost.hitters)
## var rel.inf
## CAtBat CAtBat 20.8404970
## CRBI CRBI 12.3158959
## Walks Walks 7.4186037
## PutOuts PutOuts 7.1958539
## Years Years 6.3104535
## CWalks CWalks 6.0221656
## CHmRun CHmRun 5.7759763
## CHits CHits 4.8914360
## AtBat AtBat 4.2187460
## RBI RBI 4.0812410
## Hits Hits 4.0117255
## Assists Assists 3.8786634
## HmRun HmRun 3.6386178
## CRuns CRuns 3.3230296
## Errors Errors 2.6369128
## Runs Runs 2.2048386
## Division Division 0.5347342
## NewLeague NewLeague 0.4943540
## League League 0.2062551
The summary shows that “CAtBat” is by far the most important variable (relative influence of about 20.8), followed by “CRBI” (about 12.3).