Comparison of algorithm selection model performance between pairwise random forest regression and REINFORCE.
library(tidyr)
library(magrittr)
library(dplyr)
library(ggplot2)
library(ggthemes)
library(grid)
library(scales)
library(aslib)
library(llama)
library(randomForest)
library(parallelMap)
source('reinforce.R')
# Hard-instance subset of the GRAPHS-2015 scenario used by every model below.
# NOTE(review): ExtractHardInstancesGRAPHS2015() is defined outside this view
# (presumably in reinforce.R) -- confirm how "hard" instances are selected.
dataset_hard <- ExtractHardInstancesGRAPHS2015()

# Train (or load the cached) pairwise random-forest regression selector.
# Training runs on 4 parallelMap socket workers; parallelStop() is called in
# `finally` so the workers are shut down whether regressionPairs() returns or
# errors, instead of being left running for the rest of the session.
if (!file.exists("model_regr_hard.rds")) {
  parallelStartSocket(cpus = 4)
  parallelLibrary("llama", "mlr")
  start_time <- Sys.time()
  model_regr_hard <- tryCatch(
    regressionPairs(makeLearner("regr.randomForest"), dataset_hard),
    finally = parallelStop()
  )
  end_time <- Sys.time()
  saveRDS(model_regr_hard, "model_regr_hard.rds")
  cat("Training started at", format(start_time, "%X"),
      "and ended at", format(end_time, "%X"), "\n")
  # Last value of the block: auto-printed as the training duration.
  end_time - start_time
} else {
  model_regr_hard <- readRDS("model_regr_hard.rds")
  cat("Loaded model_regr_hard from disk.\n")
}
## Starting parallelization in mode=socket with cpus=4.
## Loading packages on slaves for mode socket: llama,mlr
## Mapping in parallel: mode = socket; cpus = 4; elements = 10.
## Training started at 09:30:37 and ended at 09:42:36
## Time difference of 11.98333 mins
# Train (or load the cached) REINFORCE-based selector.
# NOTE(review): REINFORCE_AS() is defined outside this view (presumably in
# reinforce.R); TB_ROOTFOLDER looks like a TensorBoard log root -- confirm.
if (!file.exists("model_reinforce.rds")) {
  start_time <- Sys.time()
  model_reinforce <- REINFORCE_AS(
    dataset_hard,
    EPOCHS = 30,
    NUM_BATCHES = 64,
    DROPOUT_PROB = 0.5,
    # file.path() joins with the path separator ("<wd>/test_asresults");
    # paste() would join with its default sep = " " ("<wd> test_asresults").
    TB_ROOTFOLDER = file.path(getwd(), "test_asresults"),
    OUTFILE = "test_reinforce.txt"
  )
  end_time <- Sys.time()
  saveRDS(model_reinforce, "model_reinforce.rds")
  cat("Training started at", format(start_time, "%X"),
      "and ended at", format(end_time, "%X"), "\n")
  # Last value of the block: auto-printed as the training duration.
  end_time - start_time
} else {
  model_reinforce <- readRDS("model_reinforce.rds")
  cat("Loaded model_reinforce from disk.\n")
}
## Training started at 09:44:08 and ended at 09:51:58
## Time difference of 7.824702 mins
# One-row performance summary of an algorithm selector on `data`.
#
# label:    value stored in the `model` column.
# selector: anything llama's scoring functions accept here -- the `vbs` /
#           `singleBest` baselines or a trained llama model.
# data:     the llama scenario to score on (defaults to `dataset_hard`).
#
# Returns a data.frame with the mean misclassification penalty, the number of
# solved instances, and the mean / median PAR1 score (parscores, factor = 1).
# The PAR1 scores are computed once and reused for both statistics.
summarise_selector <- function(label, selector, data = dataset_hard) {
  par1 <- parscores(data, selector, factor = 1)
  data.frame(
    model = label,
    mean.misclassification.penalty =
      mean(misclassificationPenalties(data, selector)),
    solved = sum(successes(data, selector)),
    mean.performance = mean(par1),
    median.performance = median(par1)
  )
}

resvbs <- summarise_selector("Virtual best solver", vbs)
ressb <- summarise_selector("Single best solver", singleBest)
resrp <- summarise_selector("Pairwise random forest regression",
                            model_regr_hard)
resrl <- summarise_selector("REINFORCE", model_reinforce)
rbind(resvbs, resrp, ressb, resrl)
The plot below shows, for each model, the fraction of instances solved within a given runtime.
# Per-instance PAR1 runtimes (parscores with factor = 1), one column per
# selector, followed by a long (model, time) version for plotting.
# gather() stacks column by column: all PRFR rows, then REINFORCE, VBS, SBS,
# each in the original instance order.
runtimes <- as.data.frame(lapply(
  list(PRFR = model_regr_hard,
       REINFORCE = model_reinforce,
       VBS = vbs,
       SBS = singleBest),
  function(selector) parscores(dataset_hard, selector, factor = 1)
))
runtimes_long <- gather(runtimes, key = "model", value = "time", PRFR:SBS)
# Empirical CDF of each model's per-instance runtime (ms) on a log10 x axis.
# NOTE(review): the scale `limits` of c(1, 1e8 - 1) filter data *before*
# stat_ecdf() runs. Zero runtimes become -Inf under log10, and the 1e8
# maximum (presumably the timeout score) falls outside the limits, so those
# rows are dropped (the "Removed ... non-finite values" warning). Each curve is
# therefore the ECDF over the remaining rows, not over all instances --
# confirm this is the intended reading of "fraction of instances solved".
cdfplot <- ggplot(runtimes_long, aes(x = time, col = model)) +
  stat_ecdf() +
  scale_x_log10(breaks = trans_breaks("log10", function(x) 10^x, n = 10),
                labels = trans_format("log10", math_format(10^.x)),
                limits = c(1, 1e8 - 1)) +
  coord_cartesian(xlim = c(1, 1e8 - 1), ylim = c(0, 1)) +
  ylab("fraction of instances solved") +
  xlab("runtime [ms]") +
  annotation_logticks(sides = "b") +
  theme_tufte(base_family = "Times", base_size = 14) +
  guides(col = guide_legend(ncol = 2, keyheight = 0.8)) +
  theme(legend.justification = c(1, 0),
        legend.position = c(1, 0.7),
        aspect.ratio = 0.6,
        axis.line = element_line(colour = "black"),
        panel.grid = element_line(),
        panel.grid.major = element_line(colour = "lightgray"))
# Inset for cdfplot: the same log10-runtime ECDF, zoomed into the upper-right
# corner (runtime 1e7 to 1e8 ms, ECDF 0.963 to 1), without axis titles or
# legend.
# NOTE(review): the scale `limits` drop zero and 1e8 runtimes before
# stat_ecdf() (see the "Removed ... non-finite values" warning), so the curves
# reach 1 inside this window; the chosen ylim depends on that behaviour.
cdfplot_zoom <- ggplot(runtimes_long, aes(x = time, col = model)) +
  stat_ecdf() +
  scale_x_log10(breaks = trans_breaks("log10", function(x) 10^x, n = 3),
                labels = trans_format("log10", math_format(10^.x)),
                limits = c(1, 1e8 - 1)) +
  coord_cartesian(xlim = c(1e7, 1e8 - 1), ylim = c(0.963, 1)) +
  annotation_logticks(sides = "b") +
  theme_tufte(base_family = "Times", base_size = 14) +
  theme(legend.position = "none",
        axis.title.x = element_blank(),
        axis.title.y = element_blank(),
        panel.background = element_rect(fill = "white", colour = "white"),
        axis.line = element_line(colour = "black"),
        panel.grid = element_line(),
        panel.grid.major = element_line(colour = "lightgray"))
# Draw the full runtime CDF, then overlay the zoomed inset in a grid viewport
# centred at (0.71, 0.41) of the device, 0.57 wide and 0.41 high.
vp <- viewport(x = 0.71, y = 0.41, width = 0.57, height = 0.41)
print(cdfplot)
## Warning: Transformation introduced infinite values in continuous x-axis
## Warning: Removed 808 rows containing non-finite values (stat_ecdf).
print(cdfplot_zoom, vp = vp)
## Warning: Transformation introduced infinite values in continuous x-axis
## Warning: Removed 808 rows containing non-finite values (stat_ecdf).
# Distribution of raw runtimes per model.
summary(runtimes)
## PRFR REINFORCE VBS
## Min. : 0 Min. : 0 Min. : 0
## 1st Qu.: 139 1st Qu.: 77 1st Qu.: 9
## Median : 1720 Median : 904 Median : 79
## Mean : 6534990 Mean : 8187933 Mean : 5822809
## 3rd Qu.: 14235 3rd Qu.: 11940 3rd Qu.: 814
## Max. :100000000 Max. :100000000 Max. :100000000
## SBS
## Min. : 0
## 1st Qu.: 44
## Median : 448
## Mean : 7781811
## 3rd Qu.: 4313
## Max. :100000000
# Long-format runtimes on a log10 scale. Zero runtimes are first replaced by 1
# so they map to 0 instead of -Inf and are kept by stat_ecdf().
runtimes_logscaled <- runtimes_long %>%
  mutate(time = log10(replace(time, time == 0, 1)))
# ECDF of each model's log10 runtime, computed on pre-transformed data so that
# no rows are dropped. Zero runtimes sit at 0 (they were clamped to 1 ms), and
# the 1e8 maximum (presumably the timeout score) sits at 8, where every curve
# reaches 1.
# The x label names the log base and the unit (runtimes are in ms).
cdfplot_log <- ggplot(runtimes_logscaled, aes(x = time, col = model)) +
  stat_ecdf() +
  ylab("fraction of instances solved") +
  xlab("log10(runtime [ms])") +
  theme_tufte(base_family = "Times", base_size = 14) +
  guides(col = guide_legend(ncol = 2, keyheight = 0.8)) +
  theme(legend.justification = c(1, 0),
        legend.position = c(1, 0.6),
        aspect.ratio = 0.6,
        axis.line = element_line(colour = "black"),
        panel.grid = element_line(),
        panel.grid.major = element_line(colour = "lightgray"))
# Inset for cdfplot_log: the log10-runtime ECDF restricted to the window
# x in [4, 8], y in [0.75, 1], drawn without legend or axis titles.
cdfplot_log_zoom <- ggplot(runtimes_logscaled, aes(x = time, col = model)) +
  stat_ecdf() +
  coord_cartesian(xlim = c(4, 8), ylim = c(0.75, 1)) +
  theme_tufte(base_family = "Times", base_size = 14) +
  theme(
    legend.position = "none",
    axis.title.x = element_blank(),
    axis.title.y = element_blank(),
    panel.background = element_rect(fill = "white", colour = "white"),
    axis.line = element_line(colour = "black"),
    panel.grid = element_line(),
    panel.grid.major = element_line(colour = "lightgray")
  )
# Draw the log-runtime CDF and overlay its zoomed inset in a grid viewport
# centred at (0.72, 0.39) of the device, 0.55 wide and 0.41 high.
vp <- viewport(x = 0.72, y = 0.39, width = 0.55, height = 0.41)
print(cdfplot_log)
print(cdfplot_log_zoom, vp = vp)
# Reshape the log10 runtimes back to one row per instance (columns ID, PRFR,
# REINFORCE, SBS, VBS) and summarise every model column.
# The instance ID is each row's position within its model group rather than a
# hard-coded rep(1:2336, 4), so this keeps working when the number of
# instances or models changes. It relies on every model's rows being in the
# same instance order, which gather() and mutate() preserve.
runtimes_logscaled_wide <- runtimes_logscaled %>%
  group_by(model) %>%
  mutate(ID = row_number()) %>%
  ungroup() %>%
  spread(model, time)
summary(select(runtimes_logscaled_wide, -ID))
## PRFR REINFORCE SBS VBS
## Min. :0.000 Min. :0.000 Min. :0.000 Min. :0.0000
## 1st Qu.:2.143 1st Qu.:1.886 1st Qu.:1.643 1st Qu.:0.9542
## Median :3.236 Median :2.956 Median :2.651 Median :1.8976
## Mean :3.412 Mean :3.280 Mean :3.041 Mean :2.3876
## 3rd Qu.:4.153 3rd Qu.:4.077 3rd Qu.:3.635 3rd Qu.:2.9109
## Max. :8.000 Max. :8.000 Max. :8.000 Max. :8.0000