Description

Comparison of algorithm selection model performance between pairwise random forest regression and REINFORCE.

library(tidyr)
library(magrittr)
library(dplyr)
library(ggplot2)
library(ggthemes)
library(grid)
library(scales)
library(aslib)
library(llama)
library(randomForest)
library(parallelMap)
source('reinforce.R')
# Build the dataset that every model, table and plot below is evaluated on.
# NOTE(review): ExtractHardInstancesGRAPHS2015() is not defined in this file;
# presumably it is provided by reinforce.R sourced above — confirm.
dataset_hard <- ExtractHardInstancesGRAPHS2015()

Model Training

# Train the pairwise random-forest regression selector, or reuse the copy
# cached in "model_regr_hard.rds" so later runs skip training.
# When training runs, the if-expression evaluates to the elapsed time.
if (!file.exists("model_regr_hard.rds")) {
  # Training is distributed over 4 socket workers that need llama and mlr.
  parallelStartSocket(4)
  parallelLibrary("llama", "mlr")
  start_time <- Sys.time()
  # `finally` shuts the socket workers down even when training fails, so no
  # worker processes are left running for the rest of the session.
  model_regr_hard <- tryCatch(
    regressionPairs(makeLearner("regr.randomForest"), dataset_hard),
    finally = parallelStop()
  )
  end_time <- Sys.time()
  saveRDS(model_regr_hard, "model_regr_hard.rds")
  cat("Training started at", format(start_time, "%X"), "and ended at", format(end_time, "%X"), "\n")
  end_time - start_time
} else {
  model_regr_hard <- readRDS("model_regr_hard.rds")
  cat("Loaded model_regr_hard from disk.\n")
}
## Starting parallelization in mode=socket with cpus=4.
## Loading packages on slaves for mode socket: llama,mlr
## Mapping in parallel: mode = socket; cpus = 4; elements = 10.
## Training started at 09:30:37 and ended at 09:42:36
## Time difference of 11.98333 mins
# Train the REINFORCE selector, or reuse the copy cached in
# "model_reinforce.rds" so later runs skip training.
# When training runs, the if-expression evaluates to the elapsed time.
if (!file.exists("model_reinforce.rds")) {
  start_time <- Sys.time()
  model_reinforce <- REINFORCE_AS(
    dataset_hard,
    EPOCHS = 30,
    NUM_BATCHES = 64,
    DROPOUT_PROB = 0.5,
    # file.path() joins the components with "/"; the previous paste() call
    # joined them with a space and produced "<cwd> test_asresults".
    TB_ROOTFOLDER = file.path(getwd(), "test_asresults"),
    OUTFILE = "test_reinforce.txt"
  )
  end_time <- Sys.time()
  saveRDS(model_reinforce, "model_reinforce.rds")
  cat("Training started at", format(start_time, "%X"), "and ended at", format(end_time, "%X"), "\n")
  end_time - start_time
} else {
  model_reinforce <- readRDS("model_reinforce.rds")
  cat("Loaded model_reinforce from disk.\n")
}
## Training started at 09:44:08 and ended at 09:51:58
## Time difference of 7.824702 mins

Algorithm selection results

# Summarise one selector as a one-row data frame.
#
# label:    text stored in the `model` column.
# selector: whatever the llama scoring functions accept as their second
#           argument (here: vbs, singleBest, or a trained model object).
# data:     dataset the selector is scored on; defaults to dataset_hard.
#
# Returns the columns model, mean.misclassification.penalty, solved,
# mean.performance and median.performance. The last two are the mean and
# median of parscores(..., factor = 1), which is computed once and reused.
summarise_selector <- function(label, selector, data = dataset_hard) {
  scores <- parscores(data, selector, factor = 1)
  data.frame(
    model = label,
    mean.misclassification.penalty =
      mean(misclassificationPenalties(data, selector)),
    solved = sum(successes(data, selector)),
    mean.performance = mean(scores),
    median.performance = median(scores)
  )
}

resvbs <- summarise_selector("Virtual best solver", vbs)
ressb <- summarise_selector("Single best solver", singleBest)
resrp <- summarise_selector("Pairwise random forest regression", model_regr_hard)
resrl <- summarise_selector("REINFORCE", model_reinforce)

# Row order of the printed table: VBS, pairwise regression, SBS, REINFORCE.
rbind(resvbs, resrp, ressb, resrl)

Cumulative Distribution Function Plots

Using original runtime values

The plot below shows the fraction of problem instances solved by each model within a given runtime.

# Per-instance scores (parscores with factor = 1) for each selector, one column
# per selector; the scoring shortcut is kept local to this expression.
runtimes <- local({
  score <- function(selector) parscores(dataset_hard, selector, factor = 1)
  data.frame(
    PRFR = score(model_regr_hard),
    REINFORCE = score(model_reinforce),
    VBS = score(vbs),
    SBS = score(singleBest)
  )
})

# Long format: one row per (model, time) pair, stacked column by column.
runtimes_long <- runtimes %>%
  gather(model, time, PRFR:SBS)

# Empirical CDF of per-instance runtimes for each model on a log10 x-axis.
# NOTE(review): the x scale is limited to [1, 1e8 - 1], and the warnings
# recorded for this plot report rows removed as non-finite in stat_ecdf, so the
# curves are computed from the remaining rows only — confirm this is intended.
cdfplot <- ggplot(runtimes_long, aes(x = time, col = model)) +
  stat_ecdf() +
  # No linetype aesthetic is mapped in this plot; "none" replaces the
  # deprecated `guide = FALSE` spelling.
  scale_linetype_manual(values = c(3, 1), guide = "none") +
  scale_x_log10(
    breaks = trans_breaks("log10", function(x) 10^x, n = 10),
    labels = trans_format("log10", math_format(10^.x)),
    limits = c(1, (1e8) - 1)
  ) +
  coord_cartesian(xlim = c(1, (1e8) - 1), ylim = c(0, 1)) +
  ylab("fraction of instances solved") +
  xlab("runtime [ms]") +
  annotation_logticks(sides = "b") +
  theme_tufte(base_family = "Times", base_size = 14) +
  guides(col = guide_legend(ncol = 2, keyheight = .8)) +
  theme(
    legend.justification = c(1, 0),
    legend.position = c(1, 0.7),
    aspect.ratio = 0.6,
    axis.line = element_line(colour = "black"),
    panel.grid = element_line(),
    panel.grid.major = element_line(colour = "lightgray")
  )

# Zoomed view of the runtime ECDF: same data and x scale as the main plot, but
# the visible window is x in [1e7, 1e8 - 1] and y in [0.963, 1]. The legend and
# axis titles are hidden.
cdfplot_zoom <- ggplot(runtimes_long, aes(x = time, col = model)) +
  stat_ecdf() +
  # No linetype aesthetic is mapped in this plot; "none" replaces the
  # deprecated `guide = FALSE` spelling.
  scale_linetype_manual(values = c(3, 1), guide = "none") +
  scale_x_log10(
    breaks = trans_breaks("log10", function(x) 10^x, n = 3),
    labels = trans_format("log10", math_format(10^.x)),
    limits = c(1, (1e8) - 1)
  ) +
  # coord_cartesian() sets the visible window separately from the scale limits.
  coord_cartesian(xlim = c(1e7, (1e8) - 1), ylim = c(.963, 1)) +
  annotation_logticks(sides = "b") +
  theme_tufte(base_family = "Times", base_size = 14) +
  theme(
    legend.position = "none",
    axis.title.x = element_blank(),
    axis.title.y = element_blank(),
    panel.background = element_rect(fill = "white", colour = "white"),
    axis.line = element_line(colour = "black"),
    panel.grid = element_line(),
    panel.grid.major = element_line(colour = "lightgray")
  )

# Draw the full runtime ECDF, then draw the zoomed plot inside the viewport
# `vp`.
vp <- viewport(x = 0.71, y = 0.41, width = 0.57, height = 0.41)
print(cdfplot)
## Warning: Transformation introduced infinite values in continuous x-axis
## Warning: Removed 808 rows containing non-finite values (stat_ecdf).
print(cdfplot_zoom, vp = vp)
## Warning: Transformation introduced infinite values in continuous x-axis

## Warning: Removed 808 rows containing non-finite values (stat_ecdf).

# Per-model summary statistics of the untransformed scores in `runtimes`.
summary(runtimes)
##       PRFR             REINFORCE              VBS           
##  Min.   :        0   Min.   :        0   Min.   :        0  
##  1st Qu.:      139   1st Qu.:       77   1st Qu.:        9  
##  Median :     1720   Median :      904   Median :       79  
##  Mean   :  6534990   Mean   :  8187933   Mean   :  5822809  
##  3rd Qu.:    14235   3rd Qu.:    11940   3rd Qu.:      814  
##  Max.   :100000000   Max.   :100000000   Max.   :100000000  
##       SBS           
##  Min.   :        0  
##  1st Qu.:       44  
##  Median :      448  
##  Mean   :  7781811  
##  3rd Qu.:     4313  
##  Max.   :100000000

Using log-scaled runtime values

# Log10-transform the runtimes. Zero runtimes are set to 1 first, so they map
# to 0 instead of -Inf.
runtimes_logscaled <- mutate(
  runtimes_long,
  time = log10(replace(time, time == 0, 1))
)

# ECDF of the log10 runtimes, one curve per model, with a two-column legend.
cdfplot_log <- runtimes_logscaled %>%
  ggplot(mapping = aes(x = time, col = model)) +
  stat_ecdf() +
  labs(x = "log(runtime)", y = "fraction of instances solved") +
  guides(col = guide_legend(ncol = 2, keyheight = 0.8)) +
  theme_tufte(base_family = "Times", base_size = 14) +
  theme(
    aspect.ratio = 0.6,
    axis.line = element_line(colour = "black"),
    legend.justification = c(1, 0),
    legend.position = c(1, 0.6),
    panel.grid = element_line(),
    panel.grid.major = element_line(colour = "lightgray")
  )

# Zoomed view of the log-runtime ECDF: x in [4, 8], y in [0.75, 1], with the
# legend and axis titles hidden.
cdfplot_log_zoom <- runtimes_logscaled %>%
  ggplot(mapping = aes(x = time, col = model)) +
  stat_ecdf() +
  coord_cartesian(ylim = c(0.75, 1), xlim = c(4, 8)) +
  theme_tufte(base_family = "Times", base_size = 14) +
  theme(
    axis.line = element_line(colour = "black"),
    axis.title.x = element_blank(),
    axis.title.y = element_blank(),
    legend.position = "none",
    panel.background = element_rect(fill = "white", colour = "white"),
    panel.grid = element_line(),
    panel.grid.major = element_line(colour = "lightgray")
  )

# Draw the full log-runtime ECDF, then draw the zoomed plot inside the viewport
# `vp`.
print(cdfplot_log)
vp <- viewport(x = 0.72, y = 0.39, width = 0.55, height = 0.41)
print(cdfplot_log_zoom, vp = vp)

# Reshape the long log-runtime table to one row per instance and one column per
# model, then summarise the four model columns.
# ID numbers the rows within each model (1, 2, ... in row order), replacing a
# hard-coded rep(1:2336, 4) that only worked for exactly 2336 instances and
# four column-stacked models.
runtimes_logscaled_wide <- runtimes_logscaled %>%
  mutate(ID = ave(seq_along(model), model, FUN = seq_along)) %>%
  spread(model, time)
# Column 1 is ID; columns 2:5 are the per-model log runtimes.
summary(runtimes_logscaled_wide[, 2:5])
##       PRFR         REINFORCE          SBS             VBS        
##  Min.   :0.000   Min.   :0.000   Min.   :0.000   Min.   :0.0000  
##  1st Qu.:2.143   1st Qu.:1.886   1st Qu.:1.643   1st Qu.:0.9542  
##  Median :3.236   Median :2.956   Median :2.651   Median :1.8976  
##  Mean   :3.412   Mean   :3.280   Mean   :3.041   Mean   :2.3876  
##  3rd Qu.:4.153   3rd Qu.:4.077   3rd Qu.:3.635   3rd Qu.:2.9109  
##  Max.   :8.000   Max.   :8.000   Max.   :8.000   Max.   :8.0000

Source code

Source Rmd file