flashlight

library(flashlight)      # model interpretation
library(MetricsWeighted) # Metrics
library(dplyr)           # data prep
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
library(moderndive)      # data
library(caret)           # data split
## Loading required package: lattice
## Loading required package: ggplot2
## 
## Attaching package: 'caret'
## The following objects are masked from 'package:MetricsWeighted':
## 
##     precision, recall
library(xgboost)         # gradient boosting
## 
## Attaching package: 'xgboost'
## The following object is masked from 'package:dplyr':
## 
##     slice
library(ranger)          # random forest

# Preview the first rows of the house_prices data set.
head(house_prices)
## # A tibble: 6 x 21
##   id    date        price bedrooms bathrooms sqft_living sqft_lot floors
##   <chr> <date>      <dbl>    <int>     <dbl>       <int>    <int>  <dbl>
## 1 7129~ 2014-10-13 2.22e5        3      1           1180     5650      1
## 2 6414~ 2014-12-09 5.38e5        3      2.25        2570     7242      2
## 3 5631~ 2015-02-25 1.80e5        2      1            770    10000      1
## 4 2487~ 2014-12-09 6.04e5        4      3           1960     5000      1
## 5 1954~ 2015-02-18 5.10e5        3      2           1680     8080      1
## 6 7237~ 2014-05-12 1.23e6        4      4.5         5420   101930      1
## # ... with 13 more variables: waterfront <lgl>, view <int>,
## #   condition <fct>, grade <fct>, sqft_above <int>, sqft_basement <int>,
## #   yr_built <int>, yr_renovated <int>, zipcode <fct>, lat <dbl>,
## #   long <dbl>, sqft_living15 <int>, sqft_lot15 <int>
# Number of rows and columns of the data set.
dim(house_prices)
## [1] 21613    21
# Column-wise summary of every variable.
summary(house_prices)
##       id                 date                price        
##  Length:21613       Min.   :2014-05-02   Min.   :  75000  
##  Class :character   1st Qu.:2014-07-22   1st Qu.: 321950  
##  Mode  :character   Median :2014-10-16   Median : 450000  
##                     Mean   :2014-10-29   Mean   : 540088  
##                     3rd Qu.:2015-02-17   3rd Qu.: 645000  
##                     Max.   :2015-05-27   Max.   :7700000  
##                                                           
##     bedrooms        bathrooms      sqft_living       sqft_lot      
##  Min.   : 0.000   Min.   :0.000   Min.   :  290   Min.   :    520  
##  1st Qu.: 3.000   1st Qu.:1.750   1st Qu.: 1427   1st Qu.:   5040  
##  Median : 3.000   Median :2.250   Median : 1910   Median :   7618  
##  Mean   : 3.371   Mean   :2.115   Mean   : 2080   Mean   :  15107  
##  3rd Qu.: 4.000   3rd Qu.:2.500   3rd Qu.: 2550   3rd Qu.:  10688  
##  Max.   :33.000   Max.   :8.000   Max.   :13540   Max.   :1651359  
##                                                                    
##      floors      waterfront           view        condition     grade     
##  Min.   :1.000   Mode :logical   Min.   :0.0000   1:   30   7      :8981  
##  1st Qu.:1.000   FALSE:21450     1st Qu.:0.0000   2:  172   8      :6068  
##  Median :1.500   TRUE :163       Median :0.0000   3:14031   9      :2615  
##  Mean   :1.494                   Mean   :0.2343   4: 5679   6      :2038  
##  3rd Qu.:2.000                   3rd Qu.:0.0000   5: 1701   10     :1134  
##  Max.   :3.500                   Max.   :4.0000             11     : 399  
##                                                             (Other): 378  
##    sqft_above   sqft_basement       yr_built     yr_renovated   
##  Min.   : 290   Min.   :   0.0   Min.   :1900   Min.   :   0.0  
##  1st Qu.:1190   1st Qu.:   0.0   1st Qu.:1951   1st Qu.:   0.0  
##  Median :1560   Median :   0.0   Median :1975   Median :   0.0  
##  Mean   :1788   Mean   : 291.5   Mean   :1971   Mean   :  84.4  
##  3rd Qu.:2210   3rd Qu.: 560.0   3rd Qu.:1997   3rd Qu.:   0.0  
##  Max.   :9410   Max.   :4820.0   Max.   :2015   Max.   :2015.0  
##                                                                 
##     zipcode           lat             long        sqft_living15 
##  98103  :  602   Min.   :47.16   Min.   :-122.5   Min.   : 399  
##  98038  :  590   1st Qu.:47.47   1st Qu.:-122.3   1st Qu.:1490  
##  98115  :  583   Median :47.57   Median :-122.2   Median :1840  
##  98052  :  574   Mean   :47.56   Mean   :-122.2   Mean   :1987  
##  98117  :  553   3rd Qu.:47.68   3rd Qu.:-122.1   3rd Qu.:2360  
##  98042  :  548   Max.   :47.78   Max.   :-121.3   Max.   :6210  
##  (Other):18163                                                  
##    sqft_lot15    
##  Min.   :   651  
##  1st Qu.:  5100  
##  Median :  7620  
##  Mean   : 12768  
##  3rd Qu.: 10083  
##  Max.   :871200  
## 
# Derive the modelling data set from house_prices: some existing columns are
# re-typed, and the response log_price plus the covariables year and age are
# added.
prep <- transform(
  house_prices,
  # Existing columns, re-typed
  grade = as.integer(as.character(grade)),
  zipcode = as.numeric(as.character(zipcode)),
  waterfront = factor(
    waterfront,
    levels = c(FALSE, TRUE),
    labels = c("no", "yes")
  ),
  # New columns
  log_price = log(price),
  year = factor(lubridate::year(date)),
  age = lubridate::year(date) - yr_built
)

# Names of the covariables passed to the models below
x <- c(
  "grade", "year", "age", "sqft_living", "sqft_lot", "zipcode",
  "condition", "waterfront"
)

head(x)
## [1] "grade"       "year"        "age"         "sqft_living" "sqft_lot"   
## [6] "zipcode"
# Data wrapper for the linear model: adds the logarithms of sqft_living and
# sqft_lot as extra columns and replaces zipcode by a factor built from its
# integer quotient by 10.
# NOTE(review): the columns named sqrt_* hold log() values, and the model
# formula visible in this script refers to sqft_*, not to these columns --
# confirm the intent before renaming or removing them.
prep_lm <- function(data) {
  mutate(
    data,
    sqrt_living = log(sqft_living),
    sqrt_lot = log(sqft_lot),
    zipcode = factor(zipcode %/% 10)
  )
}

# Data wrapper for xgboost: keeps only the columns named in x, converts every
# non-numeric column to integer and returns the result via data.matrix().
prep_xgb <- function(data, x) {
  features <- select_at(data, x)
  features <- mutate_if(features, Negate(is.numeric), as.integer)
  data.matrix(features)
}

# Train / valid / test split (70% / 20% / 10%)
# Rows are assigned to ten folds: fold 1 becomes the test set, folds 2 and 3
# the validation set, and folds 4 to 10 the training set.
set.seed(56745)
ind <- caret::createFolds(prep[["log_price"]], k = 10, list = FALSE)

test <- prep[ind == 1, ]
valid <- prep[ind %in% c(2, 3), ]
train <- prep[ind > 3, ]

# Model formula: log_price explained by the covariables in x
(form <- reformulate(termlabels = x, response = "log_price"))
## log_price ~ grade + year + age + sqft_living + sqft_lot + zipcode + 
##     condition + waterfront
#> log_price ~ grade + year + age + sqft_living + sqft_lot + zipcode + 
#>     condition + waterfront
# Linear model: same formula plus a quadratic term in sqft_living, fitted on
# the training data after it has passed through prep_lm()
fit_lm <- lm(
  formula = update.formula(form, . ~ . + I(sqft_living^2)),
  data = prep_lm(train)
)

# Random forest
fit_rf <- ranger(
  form,
  data = train,
  seed = 8373
)
cat("R-squared OOB:", fit_rf[["r.squared"]])
## R-squared OOB: 0.7836735
#> R-squared OOB: 0.7842605

# Gradient boosting
# Training and validation data both pass through prep_xgb() and carry
# log_price as the label.
dtrain <- xgb.DMatrix(prep_xgb(train, x), label = train[["log_price"]])
dvalid <- xgb.DMatrix(prep_xgb(valid, x), label = valid[["log_price"]])

params <- list(
  learning_rate = 0.5,
  max_depth = 6,
  alpha = 1,
  lambda = 1,
  colsample_bytree = 0.8
)

# colsample_bytree < 1 makes training stochastic. Releases of the xgboost R
# package that ignore the `seed` parameter draw from R's own RNG instead, so
# R's RNG is seeded here as well to keep the fit reproducible.
# Side effect: this resets the global RNG stream for the code that follows.
set.seed(2698)

# NOTE(review): newer xgboost releases name this objective
# "reg:squarederror" and warn about "reg:linear" -- confirm the installed
# version before renaming it.
fit_xgb <- xgb.train(
  params,
  data = dtrain,
  watchlist = list(train = dtrain, valid = dvalid),
  nrounds = 200,
  print_every_n = 100,
  objective = "reg:linear",
  seed = 2698
)
## [1]  train-rmse:6.290392 valid-rmse:6.287863 
## [101]    train-rmse:0.136828 valid-rmse:0.183153 
## [200]    train-rmse:0.115749 valid-rmse:0.185014
# One flashlight per model. Each predict_function receives the stored model
# (mod) and a data set (X).

# Baseline: the "model" is the mean training log_price, repeated once per row
fl_mean <- flashlight(
  model = mean(train[["log_price"]]),
  label = "mean",
  predict_function = function(mod, X) {
    rep(mod, nrow(X))
  }
)

# Linear model: X passes through prep_lm() before prediction
fl_lm <- flashlight(
  model = fit_lm,
  label = "lm",
  predict_function = function(mod, X) {
    predict(mod, prep_lm(X))
  }
)

# Random forest: the predictions element of the predict() result is returned
fl_rf <- flashlight(
  model = fit_rf,
  label = "rf",
  predict_function = function(mod, X) {
    predict(mod, X)[["predictions"]]
  }
)

# Gradient boosting: X passes through prep_xgb() using the covariables in x
fl_xgb <- flashlight(
  model = fit_xgb,
  label = "xgb",
  predict_function = function(mod, X) {
    predict(mod, prep_xgb(X, x))
  }
)
print(fl_xgb)
## 
## Flashlight xgb 
## 
## Model:            Yes
## y:            No
## w:            No
## by:           No
## data dim:         No
## predict_fct default:  FALSE
## linkinv default:  TRUE
## metrics:      rmse
# Combine the four flashlights and supply what they share: the response
# column, exp as inverse link, the validation data and the metrics.
fls <- multiflashlight(
  list(fl_mean, fl_lm, fl_rf, fl_xgb),
  y = "log_price",
  linkinv = exp,
  data = valid,
  metrics = list(rmse = rmse, `R-squared` = r_squared)
)

# Replace fl_lm by the linear-model flashlight as stored in fls
fl_lm <- fls[["lm"]]

# Performance of all models; the parentheses print the result
(perf <- light_performance(fls))
## 
## I am an object with class(es) light_performance_multi, light_performance, light, list 
## 
## Tibbles:
## 
##  data 
## # A tibble: 8 x 3
##   metric          value label
##   <fct>           <dbl> <fct>
## 1 rmse       0.527      mean 
## 2 R-squared -0.00000252 mean 
## # ... with 6 more rows
# Plot the performance results stored in perf.
plot(perf)

# Variable importance for all models. No `v` is given, and the output below
# also lists columns such as id and date.
(imp <- light_importance(fls, n_max = 1000))
## 
## I am an object with class(es) light_importance_multi, light_importance, light, list 
## 
## Tibbles:
## 
##  data 
## # A tibble: 92 x 6
##   variable metric value_shuffled label value_original value
##   <chr>    <fct>           <dbl> <fct>          <dbl> <dbl>
## 1 id       rmse            0.535 mean           0.535     0
## 2 date     rmse            0.535 mean           0.535     0
## # ... with 90 more rows
# Variable importance restricted to the covariables in x, measured by mse.
(imp <- light_importance(fls, v = x, metric = list(mse = mse)))
## 
## I am an object with class(es) light_importance_multi, light_importance, light, list 
## 
## Tibbles:
## 
##  data 
## # A tibble: 32 x 6
##   variable metric value_shuffled label value_original value
##   <chr>    <fct>           <dbl> <fct>          <dbl> <dbl>
## 1 grade    mse             0.278 mean           0.278     0
## 2 year     mse             0.278 mean           0.278     0
## # ... with 30 more rows
# Names of the three most important variables according to imp
most_important(imp, top_m = 3)
## [1] "grade"       "sqft_living" "zipcode"
# Importance measured by R-squared, a metric where higher values are better
imp_r2 <- light_importance(
  fls,
  v = x,
  metric = list(r_squared = r_squared),
  lower_is_better = FALSE
)
plot(imp_r2, fill = "darkred") + ggtitle("Drop in R-squared")

# ICE curves for sqft_living across all models (n_max = 30, fixed seed)
cp <- light_ice(
  fls,
  v = "sqft_living",
  n_max = 30,
  seed = 35
)
plot(cp, alpha = 0.2)

# Profile of sqft_living across all models; the parentheses print the result
(pd <- light_profile(fls, v = "sqft_living"))
## 
## I am an object with class(es) light_profile_multi, light_profile, light, list 
## 
## Tibbles:
## 
##  data 
## # A tibble: 40 x 5
##   sqft_living counts   value label type              
##         <dbl>  <int>   <dbl> <fct> <fct>             
## 1         500   1000 463974. mean  partial dependence
## 2        1500   1000 463974. mean  partial dependence
## # ... with 38 more rows
# Plot the profile stored in pd.
plot(pd)

# Axis label formatter: fixed (non-scientific) notation with an apostrophe as
# thousands separator.
format_y <- function(x) {
  format(x, scientific = FALSE, big.mark = "'")
}

# Profile of predicted values over sqft_living for all models
pvp <- light_profile(
  fls,
  v = "sqft_living",
  type = "predicted",
  format = "fg",
  big.mark = "'"
)
plot(pvp) +
  scale_y_continuous(labels = format_y)

# Profile of the response over sqft_living, using the linear-model flashlight
rvp <- light_profile(
  fl_lm,
  v = "sqft_living",
  type = "response",
  format = "fg"
)
plot(rvp) +
  scale_y_continuous(labels = format_y)

# Same response profile with stats = "quartiles"
rvp <- light_profile(
  fl_lm,
  v = "sqft_living",
  type = "response",
  stats = "quartiles",
  format = "fg"
)
plot(rvp) +
  scale_y_continuous(labels = format_y)

2019-08-27