Created on 13 Aug 2013
Revised on Tue Aug 13 15:18:10 2013
1. Steps in building a prediction
1. Find the right data
2. Define your error rate
3. Split the data (a minimal split sketch follows this list) into:
· Training
· Testing
· Validation (optional)
4. On the training set, pick features
5. On the training set, pick a prediction function
6. On the training set, cross-validate
7. If there is no validation set, apply the model exactly once to the test set
8. If there is a validation set, apply to the test set and refine
9. If there is a validation set, apply exactly once to the validation set
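A minimal sketch of step 3, assuming a hypothetical data frame df and a 60/20/20 split:
set.seed(1)
n <- nrow(df)
shuffled <- sample(seq_len(n))  # random permutation of the row indices
training <- df[shuffled[1:floor(0.6 * n)], ]
testing <- df[shuffled[(floor(0.6 * n) + 1):floor(0.8 * n)], ]
validation <- df[shuffled[(floor(0.8 * n) + 1):n], ]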
data(faithful)
dim(faithful)
## [1] 272 2
set.seed(333)
trainSamples <- sample(seq_len(nrow(faithful)), size = nrow(faithful)/2, replace = FALSE)  # half the rows for training
trainFaith <- faithful[trainSamples, ]
testFaith <- faithful[-trainSamples, ]
head(trainFaith)
## eruptions waiting
## 128 4.500 82
## 23 3.450 78
## 263 1.850 58
## 154 4.600 81
## 6 2.883 55
## 194 4.100 84
plot(trainFaith$waiting, trainFaith$eruptions, pch = 19, col = "blue", xlab = "Waiting",
ylab = "Duration")
2.1 Fit a linear model
lm1 <- lm(eruptions ~ waiting, data = trainFaith)
summary(lm1)
##
## Call:
## lm(formula = eruptions ~ waiting, data = trainFaith)
##
## Residuals:
## Min 1Q Median 3Q Max
## -1.2969 -0.3543 0.0487 0.3310 1.0760
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) -1.92491 0.22925 -8.4 5.8e-14 ***
## waiting 0.07639 0.00316 24.2 < 2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 0.494 on 134 degrees of freedom
## Multiple R-squared: 0.814, Adjusted R-squared: 0.812
## F-statistic: 585 on 1 and 134 DF, p-value: <2e-16
plot(trainFaith$waiting, trainFaith$eruptions, pch = 19, col = "blue", xlab = "Waiting",
ylab = "Duration")
lines(trainFaith$waiting, lm1$fitted, lwd = 3)
Predict a new value
coef(lm1)[1] + coef(lm1)[2] * 80
## (Intercept)
## 4.186
newdata <- data.frame(waiting = 80)
predict(lm1, newdata)
## 1
## 4.186
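This is just the fitted line evaluated at waiting = 80: eruptions ≈ -1.92491 + 0.07639 × 80 ≈ 4.186 minutes. predict() gives the same answer without hand-coding the formula.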
Plot predictions - training and test
oldpar <- par(mfrow = c(1, 2))
plot(trainFaith$waiting, trainFaith$eruptions, pch = 19, col = "blue", xlab = "Waiting",
ylab = "Duration")
lines(trainFaith$waiting, predict(lm1), lwd = 3)  # predict(lm1) with no newdata returns the fitted values, the same as lm1$fitted
plot(testFaith$waiting, testFaith$eruptions, pch = 19, col = "blue", xlab = "Waiting",
ylab = "Duration")
lines(testFaith$waiting, predict(lm1, newdata = testFaith), lwd = 3)
par(oldpar)
Get training set/test set errors
sqrt(sum((lm1$fitted - trainFaith$eruptions)^2))  # root of the sum of squared errors on the training set
## [1] 5.713
sqrt(sum((predict(lm1, newdata = testFaith) - testFaith$eruptions)^2))  # root of the sum of squared errors on the test set
## [1] 5.827
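These are root sums of squared errors, not RMSEs; dividing by the number of observations before taking the square root gives the RMSE proper. A minimal sketch:
sqrt(mean((lm1$fitted - trainFaith$eruptions)^2))  # training RMSE
sqrt(mean((predict(lm1, newdata = testFaith) - testFaith$eruptions)^2))  # test RMSE
Either way, the error is somewhat larger on the test set, as expected.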
Prediction intervals
pred1 <- predict(lm1, newdata = testFaith, interval = "prediction")
ord <- order(testFaith$waiting)
plot(testFaith$waiting, testFaith$eruptions, pch = 19, col = "blue")
matlines(testFaith$waiting[ord], pred1[ord, ], type = "l", col = c(1, 2, 2),
lty = c(1, 1, 1), lwd = 3)
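The black line is the fit and the red lines are the lower and upper prediction bounds; a prediction interval is wider than a confidence interval for the mean response because it also accounts for the residual scatter of individual observations.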
2.2 Example with binary data (glm): Baltimore Ravens
require(RCurl)
## Loading required package: RCurl
## Loading required package: bitops
myCsv <- getURL("https://dl.dropboxusercontent.com/u/8272421/ravensData.csv",
ssl.verifypeer = FALSE)
ravensData <- read.csv(textConnection(myCsv))
head(ravensData)
## ravenWinNum ravenWin ravenScore opponentScore
## 1 1 W 24 9
## 2 1 W 38 35
## 3 1 W 28 13
## 4 1 W 34 31
## 5 1 W 44 13
## 6 0 L 23 24
Fit a logistic regression
glm1 <- glm(ravenWinNum ~ ravenScore, family = "binomial", data = ravensData)
oldpar <- par(mfrow = c(1, 2))
boxplot(predict(glm1) ~ ravensData$ravenWinNum, col = "blue")
boxplot(predict(glm1, type = "response") ~ ravensData$ravenWinNum, col = "blue")
par(oldpar)
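The left panel is on the link (log-odds) scale and the right panel on the probability scale; the two are related by the logistic function. A quick sanity check, assuming glm1 as fit above:
all.equal(plogis(predict(glm1)), predict(glm1, type = "response"))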
Choosing a cutoff (re-substitution)
xx <- seq(0, 1, length = 10)
err <- rep(NA, 10)
for (i in 1:length(xx)) {
err[i] <- sum((predict(glm1, type = "response") > xx[i]) != ravensData$ravenWinNum)
}
plot(xx, err, pch = 19, xlab = "Cutoff", ylab = "Error")
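The cutoff with the smallest resubstitution error can then be read straight off the grid; a one-line sketch:
xx[which.min(err)]  # cutoff minimizing the error count (first minimum if tied)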
Comparing models with cross-validation
library(boot)
cost <- function(win, pred = 0) mean(abs(win - pred) > 0.5)  # misclassification rate at a 0.5 cutoff
glm1 <- glm(ravenWinNum ~ ravenScore, family = "binomial", data = ravensData)
glm2 <- glm(ravenWinNum ~ ravenScore, family = "gaussian", data = ravensData)
cv1 <- cv.glm(ravensData, glm1, cost, K = 3)
cv2 <- cv.glm(ravensData, glm2, cost, K = 3)
cv1$delta
## [1] 0.350 0.365
cv2$delta
## [1] 0.40 0.42
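The two components of delta are the raw cross-validation estimate of prediction error and a bias-adjusted version of it. By this measure the logistic model (glm1) edges out the Gaussian model (glm2), though with K = 3 the folds are random and the estimates vary from run to run.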
3.1 Example: Iris Data
data(iris)
names(iris)
## [1] "Sepal.Length" "Sepal.Width" "Petal.Length" "Petal.Width"
## [5] "Species"
table(iris$Species)
##
## setosa versicolor virginica
## 50 50 50
Iris petal width vs. sepal width
plot(iris$Petal.Width, iris$Sepal.Width, pch = 19, col = as.numeric(iris$Species))
legend(1, 4.5, legend = unique(iris$Species), col = unique(as.numeric(iris$Species)),
pch = 19)
library(tree) # An alternative is library(rpart)
## Warning: package 'tree' was built under R version 3.0.1
tree1 <- tree(Species ~ Sepal.Width + Petal.Width, data = iris)
summary(tree1)
##
## Classification tree:
## tree(formula = Species ~ Sepal.Width + Petal.Width, data = iris)
## Number of terminal nodes: 5
## Residual mean deviance: 0.204 = 29.6 / 145
## Misclassification error rate: 0.0333 = 5 / 150
Plot tree
plot(tree1)
text(tree1)
Another way of looking at a CART model
plot(iris$Petal.Width, iris$Sepal.Width, pch = 19, col = as.numeric(iris$Species))
partition.tree(tree1, label = "Species", add = TRUE)
legend(1.75, 4.5, legend = unique(iris$Species), col = unique(as.numeric(iris$Species)),
pch = 19)
Predicting new values
set.seed(32313)
newdata <- data.frame(Petal.Width = runif(20, 0, 2.5), Sepal.Width = runif(20,
2, 4.5))
pred1 <- predict(tree1, newdata)
pred1
## setosa versicolor virginica
## 1 0 0.02174 0.97826
## 2 0 0.02174 0.97826
## 3 1 0.00000 0.00000
## 4 0 1.00000 0.00000
## 5 0 0.02174 0.97826
## 6 0 0.02174 0.97826
## 7 0 0.02174 0.97826
## 8 0 0.90476 0.09524
## 9 0 1.00000 0.00000
## 10 0 0.02174 0.97826
## 11 0 1.00000 0.00000
## 12 1 0.00000 0.00000
## 13 1 0.00000 0.00000
## 14 1 0.00000 0.00000
## 15 0 0.02174 0.97826
## 16 0 0.02174 0.97826
## 17 0 1.00000 0.00000
## 18 1 0.00000 0.00000
## 19 0 1.00000 0.00000
## 20 0 1.00000 0.00000
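Each row gives the class proportions of the terminal node that the new observation falls into; type = "class" (used next) returns the single most probable class instead.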
Overlaying new values
pred1 <- predict(tree1, newdata, type = "class")
plot(newdata$Petal.Width, newdata$Sepal.Width, col = as.numeric(pred1), pch = 19)
partition.tree(tree1, "Species", add = TRUE)
3.2 Pruning trees example: Cars
data(Cars93, package = "MASS")
head(Cars93)
## Manufacturer Model Type Min.Price Price Max.Price MPG.city
## 1 Acura Integra Small 12.9 15.9 18.8 25
## 2 Acura Legend Midsize 29.2 33.9 38.7 18
## 3 Audi 90 Compact 25.9 29.1 32.3 20
## 4 Audi 100 Midsize 30.8 37.7 44.6 19
## 5 BMW 535i Midsize 23.7 30.0 36.2 22
## 6 Buick Century Midsize 14.2 15.7 17.3 22
## MPG.highway AirBags DriveTrain Cylinders EngineSize
## 1 31 None Front 4 1.8
## 2 25 Driver & Passenger Front 6 3.2
## 3 26 Driver only Front 6 2.8
## 4 26 Driver & Passenger Front 6 2.8
## 5 30 Driver only Rear 4 3.5
## 6 31 Driver only Front 4 2.2
## Horsepower RPM Rev.per.mile Man.trans.avail Fuel.tank.capacity
## 1 140 6300 2890 Yes 13.2
## 2 200 5500 2335 Yes 18.0
## 3 172 5500 2280 Yes 16.9
## 4 172 5500 2535 Yes 21.1
## 5 208 5700 2545 Yes 21.1
## 6 110 5200 2565 No 16.4
## Passengers Length Wheelbase Width Turn.circle Rear.seat.room
## 1 5 177 102 68 37 26.5
## 2 5 195 115 71 38 30.0
## 3 5 180 102 67 37 28.0
## 4 6 193 106 70 37 31.0
## 5 4 186 109 69 39 27.0
## 6 6 189 105 69 41 28.0
## Luggage.room Weight Origin Make
## 1 11 2705 non-USA Acura Integra
## 2 15 3560 non-USA Acura Legend
## 3 14 3375 non-USA Audi 90
## 4 17 3405 non-USA Audi 100
## 5 13 3640 non-USA BMW 535i
## 6 16 2880 USA Buick Century
Build a tree
treeCars <- tree(DriveTrain ~ MPG.city + MPG.highway + AirBags + EngineSize +
Width + Length + Weight + Price + Cylinders + Horsepower + Wheelbase, data = Cars93)
plot(treeCars)
text(treeCars)
Plot errors
oldpar <- par(mfrow = c(1, 2))
plot(cv.tree(treeCars, FUN = prune.tree, method = "misclass"))
plot(cv.tree(treeCars))
par(oldpar)
pruneTree <- prune.tree(treeCars, best = 4)  # prune to the best subtree with 4 terminal nodes
plot(pruneTree)
text(pruneTree)
Show the resubstitution error (note that cross-validation error is a better estimate of test-set accuracy)
table(Cars93$DriveTrain, predict(pruneTree, type = "class"))
##
## 4WD Front Rear
## 4WD 5 5 0
## Front 1 66 0
## Rear 1 10 5
table(Cars93$DriveTrain, predict(treeCars, type = "class"))
##
## 4WD Front Rear
## 4WD 5 5 0
## Front 2 61 4
## Rear 0 3 13
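As a quick summary of either confusion matrix, the resubstitution accuracy is the proportion of cases on the diagonal; a minimal sketch for the pruned tree:
tabPruned <- table(Cars93$DriveTrain, predict(pruneTree, type = "class"))
sum(diag(tabPruned)) / sum(tabPruned)  # proportion classified correctly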