The fingerprints matrix contains 1,107 predictors before the nearZeroVar function is applied. After applying nearZeroVar, the number of predictors is reduced to 388.
## [1] 165 1107
## [1] 165 388
# Preprocessing recipe: permeability is modelled on every other column of
# perm_train; numeric non-outcome columns are centered and then scaled.
perm_rec <- recipe(permeability ~ ., data = perm_train) %>%
  step_center(all_numeric(), -all_outcomes()) %>%
  step_scale(all_numeric(), -all_outcomes())

# Train the recipe steps and print the trained recipe.
perm_prep <- prep(perm_rec)
perm_prep
## Data Recipe
##
## Inputs:
##
## role #variables
## outcome 1
## predictor 388
##
## Training data contained 124 data points and no missing data.
##
## Operations:
##
## Centering for X1, X2, X3, X4, X5, X6, X11, X12, X15, X16, ... [trained]
## Scaling for X1, X2, X3, X4, X5, X6, X11, X12, X15, X16, ... [trained]
# PLS regression specification with a fixed number of components (4),
# fitted through the mixOmics engine.
pls_spec <- pls(num_comp = 4) %>%
  set_mode("regression") %>%
  set_engine("mixOmics")

# Attach the *untrained* recipe. A workflow preps its recipe itself when it is
# fitted (and once per resample when tuned); handing it the already-prepped
# perm_prep is unsupported and is rejected by current workflows releases.
wf <- workflow() %>%
  add_recipe(perm_rec)

# Fit recipe + model on the training data.
pls_fit <- wf %>%
  add_model(pls_spec) %>%
  fit(data = perm_train)

# Tidy summary of the underlying fitted model.
pls_fit %>%
  pull_workflow_fit() %>%
  tidy()
## # A tibble: 1,556 x 4
## term value type component
## <chr> <dbl> <chr> <dbl>
## 1 X1 -0.0399 predictors 1
## 2 X1 0.0628 predictors 2
## 3 X1 -0.0360 predictors 3
## 4 X1 -0.0223 predictors 4
## 5 X2 -0.0467 predictors 1
## 6 X2 0.0496 predictors 2
## 7 X2 -0.0454 predictors 3
## 8 X2 -0.0315 predictors 4
## 9 X3 -0.0407 predictors 1
## 10 X3 0.0547 predictors 2
## # ... with 1,546 more rows
# Resampling folds of the training set (seeded so the split is reproducible).
set.seed(7)
perm_folds <- vfold_cv(perm_train)

# Same PLS model, but with the number of components marked for tuning.
pls_tune_spec <- pls(num_comp = tune()) %>%
  set_mode("regression") %>%
  set_engine("mixOmics")

# Candidate values: 1 through 20 components.
comp_grid <- expand.grid(num_comp = seq(from = 1, to = 20, by = 1))

# Evaluate every candidate on every fold, keeping the out-of-fold predictions.
set.seed(7)
pls_grid <- tune_grid(
  wf %>% add_model(pls_tune_spec),
  grid = comp_grid,
  resamples = perm_folds,
  # tune_grid() takes its control object from control_grid();
  # control_resamples() is the counterpart meant for fit_resamples().
  control = control_grid(save_pred = TRUE)
)

# Resampled performance summary for each candidate.
pls_grid %>%
  collect_metrics()
## # A tibble: 40 x 7
## num_comp .metric .estimator mean n std_err .config
## <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 1 rmse standard 12.8 10 0.939 Model01
## 2 1 rsq standard 0.317 10 0.0849 Model01
## 3 2 rmse standard 12.1 10 0.988 Model02
## 4 2 rsq standard 0.381 10 0.0865 Model02
## 5 3 rmse standard 12.2 10 0.806 Model03
## 6 3 rsq standard 0.361 10 0.0725 Model03
## 7 4 rmse standard 12.1 10 0.776 Model04
## 8 4 rsq standard 0.350 10 0.0665 Model04
## 9 5 rmse standard 11.7 10 0.762 Model05
## 10 5 rsq standard 0.412 10 0.0630 Model05
## # ... with 30 more rows
# Saved resample predictions plotted against observed permeability, coloured
# by the `id` column of the tuning results; the dashed line is y = x.
pls_grid %>%
  unnest(.predictions) %>%
  ggplot(aes(permeability, .pred, color = id)) +
  geom_abline(lty = 2, color = "gray80", size = 1.5) +
  # A single point layer, matching the lasso and elastic-net plots. Stacking
  # geom_jitter() on top of geom_point() drew every prediction twice: once in
  # place and once randomly displaced.
  geom_point(alpha = 0.5) +
  labs(
    x = "Truth",
    y = "Predicted Permeability",
    color = NULL,
    # PLS stands for partial least squares, not "partial linear" regression.
    title = "Partial Least Squares Regression Model"
  )
**Top RSQ and RMSE by Model** (PLS Regression Analysis)

| Model | Components | RMSE | RSQ |
|---|---|---|---|
| Model08 | 8 | 11.51699 | 0.4554150 |
| Model09 | 9 | 11.34313 | 0.4538902 |
| Model10 | 10 | 11.57480 | 0.4441147 |
| Model07 | 7 | 11.46595 | 0.4437339 |
| Model11 | 11 | 11.62149 | 0.4417096 |
| Model12 | 12 | 11.84449 | 0.4384039 |
| Model20 | 20 | 12.80790 | 0.4327574 |
| Model13 | 13 | 12.12847 | 0.4294358 |
| Model19 | 19 | 12.89176 | 0.4269516 |
| Model06 | 6 | 11.66860 | 0.4227555 |
| Model18 | 18 | 12.92894 | 0.4209539 |
| Model14 | 14 | 12.35795 | 0.4193582 |
| Model15 | 15 | 12.66515 | 0.4166174 |
| Model16 | 16 | 12.85534 | 0.4133839 |
| Model17 | 17 | 13.02397 | 0.4129026 |
| Model05 | 5 | 11.65301 | 0.4122434 |
| Model02 | 2 | 12.05171 | 0.3813176 |
| Model03 | 3 | 12.18132 | 0.3609442 |
| Model04 | 4 | 12.06318 | 0.3500150 |
| Model01 | 1 | 12.77339 | 0.3168359 |

*Table produced with gt package - the grammar of tables*
# Mean resampled metric versus number of components, one free-scaled facet
# per metric; the legend is dropped because the facets already label them.
ggplot(
  collect_metrics(pls_grid),
  aes(x = num_comp, y = mean, color = .metric)
) +
  geom_line(size = 1.5) +
  facet_wrap(~.metric, scales = "free", nrow = 2) +
  theme_fivethirtyeight() +
  labs(title = "PLS Model", subtitle = "Tuned with 20 Components") +
  theme(legend.position = "none")
## # A tibble: 1 x 2
## num_comp .config
## <dbl> <chr>
## 1 8 Model08
# Plug the selected parameters into the tunable PLS workflow.
# NOTE(review): best_rsq_p is not defined in the visible code; presumably it
# holds the selected num_comp row printed just above -- confirm.
final_pls <- wf %>%
  add_model(pls_tune_spec) %>%
  finalize_workflow(best_rsq_p)

# Fit the finalized workflow via last_fit() on perm_split and report the
# resulting metrics.
final_pls %>%
  last_fit(perm_split) %>%
  collect_metrics()
## # A tibble: 2 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 rmse standard 11.7
## 2 rsq standard 0.575
# Lasso regression (mixture = 1) at a fixed penalty of 0.1, fitted with glmnet.
lasso_spec <- linear_reg(penalty = 0.1, mixture = 1) %>%
  set_engine("glmnet")

# Attach the *untrained* recipe. A workflow preps its recipe itself when it is
# fitted (and once per resample when tuned); handing it the already-prepped
# perm_prep is unsupported and is rejected by current workflows releases.
wf <- workflow() %>%
  add_recipe(perm_rec)

# Fit recipe + model on the training data.
lasso_fit <- wf %>%
  add_model(lasso_spec) %>%
  fit(data = perm_train)

# Coefficient table of the underlying fitted model.
lasso_fit %>%
  pull_workflow_fit() %>%
  tidy()
## # A tibble: 389 x 3
## term estimate penalty
## <chr> <dbl> <dbl>
## 1 (Intercept) 10.9 0.1
## 2 X1 0 0.1
## 3 X2 0 0.1
## 4 X3 0 0.1
## 5 X4 0 0.1
## 6 X5 0 0.1
## 7 X6 3.75 0.1
## 8 X11 0 0.1
## 9 X12 1.12 0.1
## 10 X15 -0.668 0.1
## # ... with 379 more rows
# Resampling folds of the training set (seeded so the split is reproducible).
set.seed(4763)
perm_folds <- vfold_cv(perm_train)

# Lasso specification with the penalty marked for tuning.
tune_spec <- linear_reg(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet")

# Regular grid of 50 candidate penalty values.
lambda_grid <- grid_regular(penalty(), levels = 50)

# Evaluate every candidate on every fold, keeping the out-of-fold predictions.
set.seed(4763)
lasso_grid <- tune_grid(
  wf %>% add_model(tune_spec),
  resamples = perm_folds,
  grid = lambda_grid,
  # tune_grid() takes its control object from control_grid();
  # control_resamples() is the counterpart meant for fit_resamples().
  control = control_grid(save_pred = TRUE)
)

# Resampled performance summary for each candidate.
lasso_grid %>%
  collect_metrics()
## # A tibble: 100 x 7
## penalty .metric .estimator mean n std_err .config
## <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 1.00e-10 rmse standard 12.9 10 0.823 Model01
## 2 1.00e-10 rsq standard 0.361 10 0.0737 Model01
## 3 1.60e-10 rmse standard 12.9 10 0.823 Model02
## 4 1.60e-10 rsq standard 0.361 10 0.0737 Model02
## 5 2.56e-10 rmse standard 12.9 10 0.823 Model03
## 6 2.56e-10 rsq standard 0.361 10 0.0737 Model03
## 7 4.09e-10 rmse standard 12.9 10 0.823 Model04
## 8 4.09e-10 rsq standard 0.361 10 0.0737 Model04
## 9 6.55e-10 rmse standard 12.9 10 0.823 Model05
## 10 6.55e-10 rsq standard 0.361 10 0.0737 Model05
## # ... with 90 more rows
# Saved resample predictions plotted against observed permeability, coloured
# by the `id` column of the tuning results; the dashed line is y = x.
lasso_grid %>%
  unnest(.predictions) %>%
  ggplot(aes(permeability, .pred, color = id)) +
  geom_abline(lty = 2, color = "gray80", size = 1.5) +
  geom_point(alpha = 0.5) +
  labs(
    x = "Truth",
    y = "Predicted Permeability",
    color = NULL,
    title = "Lasso Model"
  )

# Must be its own statement: fused onto the closing ")" of the plot above it
# was a parse error. A large scipen biases printing toward fixed notation,
# which keeps the penalty axis below out of scientific notation.
options(scipen = 999)

# Mean resampled metric versus penalty, one free-scaled facet per metric.
lasso_grid %>%
  collect_metrics() %>%
  ggplot(aes(penalty, mean, color = .metric)) +
  geom_line(size = 1.5) +
  facet_wrap(~.metric, scales = "free", nrow = 2) +
  theme_fivethirtyeight() +
  labs(title = "Lasso Model", subtitle = "Tuned with 50 Levels") +
  theme(legend.position = "none")
## # A tibble: 1 x 2
## penalty .config
## <dbl> <chr>
## 1 0.244 Model47
# Plug the selected penalty into the tunable lasso workflow.
# NOTE(review): best_rsq_l is not defined in the visible code; presumably it
# holds the selected penalty row printed just above -- confirm.
lasso_tune_wf <- add_model(wf, tune_spec)
final_lasso <- finalize_workflow(lasso_tune_wf, best_rsq_l)

# Fit the finalized workflow via last_fit() on perm_split and report the
# resulting metrics.
collect_metrics(last_fit(final_lasso, perm_split))
## # A tibble: 2 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 rmse standard 11.4
## 2 rsq standard 0.604
## # A tibble: 100 x 7
## penalty .metric .estimator mean n std_err .config
## <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 1.00e-10 rmse standard 12.9 10 0.823 Model01
## 2 1.00e-10 rsq standard 0.361 10 0.0737 Model01
## 3 1.60e-10 rmse standard 12.9 10 0.823 Model02
## 4 1.60e-10 rsq standard 0.361 10 0.0737 Model02
## 5 2.56e-10 rmse standard 12.9 10 0.823 Model03
## 6 2.56e-10 rsq standard 0.361 10 0.0737 Model03
## 7 4.09e-10 rmse standard 12.9 10 0.823 Model04
## 8 4.09e-10 rsq standard 0.361 10 0.0737 Model04
## 9 6.55e-10 rmse standard 12.9 10 0.823 Model05
## 10 6.55e-10 rsq standard 0.361 10 0.0737 Model05
## # ... with 90 more rows
# Elastic-net regression (mixture = 0.6) at a fixed penalty of 0.1, fitted
# with glmnet.
enet_spec <- linear_reg(penalty = 0.1, mixture = 0.6) %>%
  set_engine("glmnet")

# Attach the *untrained* recipe. A workflow preps its recipe itself when it is
# fitted (and once per resample when tuned); handing it the already-prepped
# perm_prep is unsupported and is rejected by current workflows releases.
wf <- workflow() %>%
  add_recipe(perm_rec)

# Fit recipe + model on the training data.
enet_fit <- wf %>%
  add_model(enet_spec) %>%
  fit(data = perm_train)

# Coefficient table of the underlying fitted model.
enet_fit %>%
  pull_workflow_fit() %>%
  tidy()
## # A tibble: 389 x 3
## term estimate penalty
## <chr> <dbl> <dbl>
## 1 (Intercept) 10.9 0.1
## 2 X1 0 0.1
## 3 X2 0 0.1
## 4 X3 0 0.1
## 5 X4 0 0.1
## 6 X5 0 0.1
## 7 X6 3.73 0.1
## 8 X11 0 0.1
## 9 X12 1.27 0.1
## 10 X15 -0.914 0.1
## # ... with 379 more rows
# Resampling folds of the training set (seeded so the split is reproducible).
set.seed(4763)
perm_folds <- vfold_cv(perm_train)

# Elastic-net specification with the penalty marked for tuning.
tune_spec <- linear_reg(penalty = tune(), mixture = 0.6) %>%
  set_engine("glmnet")

# Regular grid of 50 candidate penalty values. It gets its own name so that
# enet_grid only ever refers to the tuning results; previously the same name
# was used for the candidate grid and then overwritten by tune_grid()'s output.
enet_lambda_grid <- grid_regular(penalty(), levels = 50)

# Evaluate every candidate on every fold, keeping the out-of-fold predictions.
set.seed(4763)
enet_grid <- tune_grid(
  wf %>% add_model(tune_spec),
  resamples = perm_folds,
  grid = enet_lambda_grid,
  # tune_grid() takes its control object from control_grid();
  # control_resamples() is the counterpart meant for fit_resamples().
  control = control_grid(save_pred = TRUE)
)

# Resampled performance summary for each candidate.
enet_grid %>%
  collect_metrics()
## # A tibble: 100 x 7
## penalty .metric .estimator mean n std_err .config
## <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 1.00e-10 rmse standard 12.4 10 0.688 Model01
## 2 1.00e-10 rsq standard 0.389 10 0.0763 Model01
## 3 1.60e-10 rmse standard 12.4 10 0.688 Model02
## 4 1.60e-10 rsq standard 0.389 10 0.0763 Model02
## 5 2.56e-10 rmse standard 12.4 10 0.688 Model03
## 6 2.56e-10 rsq standard 0.389 10 0.0763 Model03
## 7 4.09e-10 rmse standard 12.4 10 0.688 Model04
## 8 4.09e-10 rsq standard 0.389 10 0.0763 Model04
## 9 6.55e-10 rmse standard 12.4 10 0.688 Model05
## 10 6.55e-10 rsq standard 0.389 10 0.0763 Model05
## # ... with 90 more rows
# Saved resample predictions plotted against observed permeability, coloured
# by the `id` column of the tuning results; the dashed line is y = x.
enet_grid %>%
  unnest(.predictions) %>%
  ggplot(aes(permeability, .pred, color = id)) +
  geom_abline(lty = 2, color = "gray80", size = 1.5) +
  geom_point(alpha = 0.5) +
  labs(
    x = "Truth",
    y = "Predicted Permeability",
    color = NULL,
    title = "Elastic Net Model"
  )

# Must be its own statement: fused onto the closing ")" of the plot above it
# was a parse error. A large scipen biases printing toward fixed notation,
# which keeps the penalty axis below out of scientific notation.
options(scipen = 999)

# Mean resampled metric versus penalty, one free-scaled facet per metric.
enet_grid %>%
  collect_metrics() %>%
  ggplot(aes(penalty, mean, color = .metric)) +
  geom_line(size = 1.5) +
  facet_wrap(~.metric, scales = "free", nrow = 2) +
  theme_fivethirtyeight() +
  labs(title = "Elastic Net Model", subtitle = "Tuned with 50 Levels") +
  theme(legend.position = "none")
## # A tibble: 1 x 2
## penalty .config
## <dbl> <chr>
## 1 0.391 Model48
# Plug the selected penalty into the tunable elastic-net workflow.
# NOTE(review): best_rsq_e is not defined in the visible code; presumably it
# holds the selected penalty row printed just above -- confirm.
final_enet <- wf %>%
  add_model(tune_spec) %>%
  finalize_workflow(best_rsq_e)

# Fit the finalized workflow via last_fit() on perm_split and report the
# resulting metrics.
final_enet %>%
  last_fit(perm_split) %>%
  collect_metrics()
## # A tibble: 2 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 rmse standard 11.3
## 2 rsq standard 0.609
## # A tibble: 100 x 7
## penalty .metric .estimator mean n std_err .config
## <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 1.00e-10 rmse standard 12.4 10 0.688 Model01
## 2 1.00e-10 rsq standard 0.389 10 0.0763 Model01
## 3 1.60e-10 rmse standard 12.4 10 0.688 Model02
## 4 1.60e-10 rsq standard 0.389 10 0.0763 Model02
## 5 2.56e-10 rmse standard 12.4 10 0.688 Model03
## 6 2.56e-10 rsq standard 0.389 10 0.0763 Model03
## 7 4.09e-10 rmse standard 12.4 10 0.688 Model04
## 8 4.09e-10 rsq standard 0.389 10 0.0763 Model04
## 9 6.55e-10 rmse standard 12.4 10 0.688 Model05
## 10 6.55e-10 rsq standard 0.389 10 0.0763 Model05
## # ... with 90 more rows