In this part, we build and tune several models (svm, nnet, rpart, etc.) separately with the "caret" package, using leave-one-out cross-validation (`trainControl(method = "LOOCV")`) as the resampling method; the `number` and `repeats` settings of `trainControl` do not affect LOOCV.
# Train and tune several regression models for dmax with caret, using
# leave-one-out cross-validation. Each fitted model is cached in its own
# .RData file, so re-running the script reloads it instead of retraining.

# Location of the prepared training/testing data. Set the DMAX_DATA
# environment variable to use a different location; the original path is
# the default.
dmaxDataFile <- Sys.getenv("DMAX_DATA",
  unset = "/home/gong/prepareData/dmaxTrainingAndTesting.RData")
load(dmaxDataFile)
library(doMC)
library(kernlab)
registerDoMC(cores = 3)
library(caret)

# LOOCV holds out each training row exactly once. `number` and `repeats` do
# not apply to this resampling method, so they are not passed.
dmax.cvcontrol <- trainControl(method = "LOOCV")

# Load the model cached in `file` if that file exists. Otherwise evaluate
# `fit` and save the result to `file`. Because R arguments are lazy, `fit`
# (the expensive train() call) is only evaluated on a cache miss.
#
# name:  name the model is bound to in `envir` (and saved under).
# file:  path of the .RData cache file.
# fit:   expression that trains the model; evaluated in the caller's frame.
# envir: environment receiving the model (the global environment, which is
#        where a top-level load() would put it).
# Returns the model invisibly. Errors if the cache file lacks `name`.
loadOrTrain <- function(name, file, fit, envir = globalenv()) {
  if (file.exists(file)) {
    loaded <- load(file, envir = envir)
    if (!name %in% loaded) {
      stop("'", file, "' does not contain an object named '", name, "'",
        call. = FALSE)
    }
  } else {
    assign(name, fit, envir = envir)
    save(list = name, file = file, envir = envir)
  }
  invisible(get(name, envir = envir))
}

# svm RBF kernel
loadOrTrain("dmax.svmFit", "dmax.svmFit.RData",
  train(dmaxInputsTrain, dmaxTargetTrain, method = "svmRadial",
    tuneLength = 4, trControl = dmax.cvcontrol, scaled = TRUE))

# neural networks
# NOTE(review): nnet's default output units are logistic (linout = FALSE),
# which bounds predictions to (0, 1); regression fits usually pass
# linout = TRUE. Confirm this is intended for the dmax target.
loadOrTrain("dmax.nnetFit", "dmax.nnetFit.RData", {
  nnet.grid <- expand.grid(.size = c(7:15),
    .decay = c(1e-04, 2e-04, 0.005, 0.01))
  train(dmaxInputsTrain, dmaxTargetTrain, method = "nnet",
    trControl = dmax.cvcontrol, tuneGrid = nnet.grid)
})

# Recursive partitioning
loadOrTrain("dmax.rpartFit", "dmax.rpartFit.RData",
  train(dmaxInputsTrain, dmaxTargetTrain, method = "rpart",
    trControl = dmax.cvcontrol, tuneLength = 4))

# Boosted trees
loadOrTrain("dmax.btFit", "dmax.btFit.RData",
  train(dmaxInputsTrain, dmaxTargetTrain, method = "gbm",
    trControl = dmax.cvcontrol, tuneLength = 3))

# random Forests
loadOrTrain("dmax.rfFit", "dmax.rfFit.RData", {
  library(randomForest)
  train(dmaxInputsTrain, dmaxTargetTrain, method = "rf",
    trControl = dmax.cvcontrol, tuneLength = 3)
})

# Linear Least Squares
loadOrTrain("dmax.lmFit", "dmax.lmFit.RData",
  train(dmaxInputsTrain, dmaxTargetTrain, method = "lm",
    trControl = dmax.cvcontrol, tuneLength = 4))
In this part, I make predictions with each model, plot them against the observed values, and calculate the errors.
# Compute error summaries of numeric predictions against observed values.
#
# predicted: numeric vector of predictions.
# actual:    numeric vector of observed values, same length as `predicted`.
# Returns a named numeric vector:
#   MAE  - mean absolute error
#   RMSE - root mean squared error
#   RELE - mean relative error |predicted - actual| / |actual|; where the
#          observed value is 0, |predicted| is used instead to avoid
#          dividing by zero.
# Empty input yields NaN for every measure; NA inputs propagate to NA.
modelErrors <- function(predicted, actual) {
  stopifnot(length(predicted) == length(actual))
  residual <- predicted - actual
  absErr <- abs(residual)
  # Per-observation relative error; dividing by abs(actual) keeps it
  # non-negative when observed values are negative.
  relErr <- ifelse(actual == 0, abs(predicted), absErr / abs(actual))
  c(
    MAE = mean(absErr),
    RMSE = sqrt(mean(residual^2)),
    RELE = mean(relErr)
  )
}
# Predict the test set with every fitted model and plot the predicted values
# against the observed (true) values. The list names presumably label each
# model in the extracted predictions and the plot.
models <- list(
  svm = dmax.svmFit,
  rpart = dmax.rpartFit,
  nnet = dmax.nnetFit,
  boostTrees = dmax.btFit,
  randomForest = dmax.rfFit,
  lm = dmax.lmFit
)
dmax.preValues <- extractPrediction(
  models,
  testY = dmaxTargetTest,
  testX = dmaxInputsTest
)
plotObsVsPred(dmax.preValues)
# calculate errors

# Evaluate a fitted model on held-out data.
#
# model:   fitted model accepted by predict(model, newdata = ...).
# newdata: predictors to predict on; defaults to the global dmaxInputsTest.
# actual:  observed responses for `newdata`; defaults to the global
#          dmaxTargetTest.
# Returns the named MAE/RMSE/RELE vector produced by modelErrors().
dmax.error <- function(model, newdata = dmaxInputsTest,
                       actual = dmaxTargetTest) {
  pd <- predict(model, newdata = newdata)
  modelErrors(pd, actual)
}
rf.error <- dmax.error(dmax.rfFit)
bt.error <- dmax.error(dmax.btFit)
nnet.error <- dmax.error(dmax.nnetFit)
svm.error <- dmax.error(dmax.svmFit)
rpart.error <- dmax.error(dmax.rpartFit)
lm.error <- dmax.error(dmax.lmFit)
# One row per model (row names are the variable names), one column per
# error measure.
errorAll <- rbind(rf.error, bt.error, nnet.error, svm.error, rpart.error, lm.error)
errorAll
## MAE RMSE RELE
## rf.error 0.05207 0.0759 0.2219
## bt.error 0.09862 0.1325 0.4372
## nnet.error 0.09820 0.1299 0.4220
## svm.error 0.08709 0.1231 0.3516
## rpart.error 0.10676 0.1397 0.4596
## lm.error 0.10321 0.1376 0.4494
# plot errors of models: grouped bars of MAE and RMSE per model, then RELE.
# One colour per model (row of errorAll). The y-axis keeps the original fixed
# upper bounds but grows if any error exceeds them, so no bar is clipped.
barplot(errorAll[, c("MAE", "RMSE")], main = "MAE&RMSE",
  col = rainbow(nrow(errorAll)), beside = TRUE,
  ylim = c(0, max(0.3, errorAll[, c("MAE", "RMSE")], na.rm = TRUE)),
  legend.text = rownames(errorAll))
barplot(errorAll[, "RELE"], main = "RELE", col = rainbow(nrow(errorAll)),
  beside = TRUE, ylim = c(0, max(0.8, errorAll[, "RELE"], na.rm = TRUE)),
  legend.text = rownames(errorAll))