library(tidymodels)
library(tidyverse)
# Data: load the SBA loan data and create a stratified train/test split ----
# NOTE(review): the path is relative to the working directory; confirm the
# CSV is shipped alongside this script.
df2 <- read_csv("SBAcase.11.13.17.csv")

# Default is the outcome; parsnip classification models need a factor outcome.
df2 <- df2 %>%
  mutate(Default = as.factor(Default))

# Seed right before the split so the partition is reproducible.
set.seed(0)
# 80/20 split, stratified on the outcome so both sets keep a similar
# Default class balance.
sba_split <- initial_split(df2, strata = Default, prop = 0.8)
sba_training <- training(sba_split)
sba_test <- testing(sba_split)
# Model specification ----
# Preprocessing: predict Default from the eight listed loan attributes; no
# further recipe steps are added. Predictor order is kept as-is.
sba_recipe <- recipe(
  Default ~ SBA_Appv + New + NoEmp + Recession + RetainedJob + UrbanRural +
    CreateJob + Term,
  data = sba_training
)

# Decision tree fitted with the rpart engine; all three hyperparameters are
# tune() placeholders that the grid search fills in later.
dt_tune_model <- decision_tree(
  cost_complexity = tune(),
  tree_depth = tune(),
  min_n = tune()
) %>%
  set_engine("rpart") %>%
  set_mode("classification")

# Bundle model and preprocessing so they are resampled together.
sba_tune_wkfl <- workflow() %>%
  add_model(dt_tune_model) %>%
  add_recipe(sba_recipe)

# Candidates are scored on accuracy only during tuning.
sba_metrics <- metric_set(accuracy)
# Hyperparameter tuning ----
# Re-seed so the folds and the random grid are reproducible on their own.
set.seed(0)
# 10-fold cross-validation on the training set, stratified on the outcome.
sba_folds <- vfold_cv(sba_training, v = 10, strata = Default)

# Five random candidate combinations of the three tuned hyperparameters.
# extract_parameter_set_dials() replaces parameters(), which was deprecated
# for model specifications in tune 0.1.6.9003 and emitted a lifecycle warning.
dt_grid <- grid_random(extract_parameter_set_dials(dt_tune_model), size = 5)

# Fit every candidate on every fold and score it with sba_metrics (accuracy).
dt_tuning <- sba_tune_wkfl %>%
  tune_grid(resamples = sba_folds, grid = dt_grid, metrics = sba_metrics)
# Final model ----
# Keep the candidate with the best cross-validated accuracy.
best_dt_model <- select_best(dt_tuning, metric = "accuracy")

# Substitute the winning hyperparameters for the workflow's tune() placeholders.
final_sba_wkfl <- finalize_workflow(sba_tune_wkfl, best_dt_model)

# Fit on the training portion of sba_split and evaluate on its test portion.
# No metrics argument is passed, so last_fit() reports its default metrics
# (accuracy and roc_auc in the output below) rather than sba_metrics.
sba_final_fit <- last_fit(final_sba_wkfl, split = sba_split)
collect_metrics(sba_final_fit)
## # A tibble: 2 x 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 accuracy binary 0.915 Preprocessor1_Model1
## 2 roc_auc binary 0.956 Preprocessor1_Model1