# This script illustrates a shootout (of predictive performance) between local
# linear regression forests (LLF) and standard regression forests, using the
# implementations in the grf package. The LLF has significant advantages for
# data-generating processes (DGPs) with strong smooth effects. I haven't left
# many comments; for more detail, see
# https://grf-labs.github.io/grf/articles/llf.html
library(ggplot2)
library(ggthemes)
library(ggpubr)
library(mlbench)
library(ggdist)
library(caret)
library(ggalt)
library(dplyr)
library(doParallel)
library(parallel)
library(grf)
# Simulate 5000 draws from mlbench's Friedman #1 benchmark and split them
# 70/30 into training and test sets with caret::createDataPartition.
set.seed(1995)
data <- mlbench::mlbench.friedman1(5000)

# Flatten the simulated list into one data frame: outcome `y` first, then the
# predictor matrix columns (prefixed with "x.").
data <- data.frame(y = data$y, x = data$x)

trainIndex <- createDataPartition(data$y,
                                  p = 0.7,
                                  list = FALSE)

train <- data[trainIndex, ]
test <- data[-trainIndex, ]

# Column 1 is the outcome; every remaining column is a predictor.
Y <- train$y
Y.test <- test$y
X <- train[, -1]
X.test <- test[, -1]
# Start a socket cluster on all but one core and register it as the foreach
# backend. detectCores() may return NA (or 1), so fall back to one worker.
# NOTE(review): grf multithreads through its own `num.threads` argument rather
# than through foreach, so this backend is likely unused by the grf calls
# below -- kept to preserve the original set-up; confirm before removing.
cl <- makeCluster(max(1L, detectCores() - 1L, na.rm = TRUE),
                  setup_timeout = 0.5)
registerDoParallel(cl)

# Baseline: a standard regression forest. tune.parameters = "all" asks grf to
# tune every tunable parameter; the tune.* arguments size that search.
c.forest.tune <- regression_forest(X = X,
                                   Y = Y,
                                   num.trees = 10000,
                                   tune.parameters = "all",
                                   tune.num.trees = 100,
                                   tune.num.reps = 200,
                                   tune.num.draws = 2000,
                                   seed = 1995)

preds.out <- predict(c.forest.tune, X.test)

plot(x = preds.out$predictions, y = Y.test)

# Challenger: a local linear forest that also uses local linear splits.
# NOTE(review): tune.parameters is not set here, so (assuming grf's default
# of "none") the tune.* arguments below have no effect and this forest is not
# tuned, unlike the baseline -- confirm whether that asymmetry is intended.
c.forest.ll <- ll_regression_forest(X = as.matrix(X),
                                    Y = Y,
                                    enable.ll.split = TRUE,
                                    num.trees = 10000,
                                    tune.num.trees = 100,
                                    tune.num.reps = 200,
                                    tune.num.draws = 2000,
                                    seed = 1995)

# Predict on a matrix, matching the representation the forest was trained on.
preds <- predict(c.forest.ll, as.matrix(X.test))
plot(x = preds$predictions, y = Y.test)

# Shut the workers down and restore sequential foreach execution.
stopCluster(cl)
registerDoSEQ()
# Collect the test-set truth and both models' predictions side by side.
plot.df <- data.frame(Y.real = Y.test,
                      preds.forest = preds.out$predictions,
                      preds.ll = preds$predictions)

# Test-set correlation between truth and each model's predictions (the value
# after each call is what the original run printed).
cor(plot.df$Y.real, plot.df$preds.forest) ## [1] 0.9300032
cor(plot.df$Y.real, plot.df$preds.ll) ## [1] 0.9552713
# Build a predicted-vs-real scatter plot for one model on a fixed [0, 30]
# square, with the y = x line of perfect prediction drawn on top.
#
# data:     data frame holding a `Y.real` column and the prediction column.
# pred_col: name (string) of the prediction column shown on the x axis.
# colour:   point colour.
# title:    plot title.
# Returns a ggplot object.
pred_vs_real_plot <- function(data, pred_col, colour, title) {
  ggplot(aes(x = .data[[pred_col]], y = Y.real), data = data) +
    geom_point(alpha = .4, color = colour) +
    geom_abline(size = .75, intercept = 0, slope = 1) +
    # Scale limits drop points outside [0, 30] (ggplot2 warns when it does).
    scale_x_continuous(limits = c(0, 30), expand = c(0, 0)) +
    scale_y_continuous(limits = c(0, 30), expand = c(0, 0)) +
    ylab("Y (REAL)") +
    xlab("Y (PREDICTED)") +
    ggtitle(title) +
    theme_few()
}

g.for <- pred_vs_real_plot(plot.df, "preds.forest", "blue",
                           "Regression Forest")
g.ll <- pred_vs_real_plot(plot.df, "preds.ll", "red",
                          "Local Linear Forest")

ggarrange(g.for, g.ll)
# Notice the LL forest's points are tighter and closer to the line of perfect
# prediction (y = x) than the simple regression forest's!
# Another way to see the LLF's improvement is to overlay simple density curves
# of the real outcomes and of each model's predictions.
# Reshape the real outcomes and one model's predictions into long format for
# a density overlay.
#
# real, predicted: numeric vectors of equal length.
# label:           series name used for the predictions in the legend.
# Returns the reshape2::melt() result: columns Var1 (row index), Var2 (series
# name: "REAL" or `label`) and value.
melt_real_vs_pred <- function(real, predicted, label) {
  overlap <- cbind(real, predicted)
  colnames(overlap) <- c("REAL", label)
  reshape2::melt(overlap)
}

# Overlaid density curves for each series in a melt_real_vs_pred() result,
# filled by series name. Returns a ggplot object.
density_overlap_plot <- function(overlap) {
  ggplot(data = overlap,
         aes(x = value, fill = Var2)) +
    geom_density(alpha = .25) +
    theme_few() +
    # Fixed axes so the two panels are directly comparable.
    xlim(0, 30) +
    ylim(0, .1) +
    xlab("Y") +
    ylab("Density") +
    ggtitle("") +
    theme(legend.title = element_blank())
}

overlap_data <- melt_real_vs_pred(plot.df$Y.real, plot.df$preds.ll,
                                  "LL Forest")
hist_plot <- density_overlap_plot(overlap_data)

overlap_data <- melt_real_vs_pred(plot.df$Y.real, plot.df$preds.forest,
                                  "Reg Forest")
hist_plot_reg <- density_overlap_plot(overlap_data)

ggarrange(hist_plot, hist_plot_reg)