Executive Summary

In this project, an in-depth analysis of the latest Near-Earth Objects (NEO) dataset has been performed.

Although the Accuracy of our Logistic Regression was 90.4%, we had a high rate of false positives (hazardous objects predicted as non-hazardous, with “False” treated as the event class), which is reflected in a high Sensitivity of 98.9%, a low Specificity of 8.5%, and an almost unacceptably low ROC AUC of 11.9%.

Through the implementation of a Random Forest, we observe that the most important variables are “miss_distance” and “relative_velocity”, which are followed by “est_diameter_max”, “absolute_magnitude” and “est_diameter_min”.

Our Random Forest has an Accuracy of 91.8%, a Sensitivity of 97.8%, and a Specificity of 34%, making it a more reliable model than our Logistic Regression.

Finally, by fitting and tuning a boosted tree model (via XGBoost), we obtain an even more reliable model with an ROC AUC of 91.6%.

Data Import

During the data import process, I dropped the ‘id’ and ‘name’ columns, since they are identifiers with no predictive relevance to our analysis.

library(ggplot2)
library(tidymodels)
library(dplyr)
library(xgboost)
library(yardstick)
nasa_df <- read.csv("neo_v2.csv")

nasa_df$hazardous <- as.factor(nasa_df$hazardous)

# Converting KM to Meters
nasa_df$est_diameter_min <- nasa_df$est_diameter_min * 1000
nasa_df$est_diameter_max <- nasa_df$est_diameter_max * 1000
#nasa_df$orbiting_body <- as.factor(nasa_df$orbiting_body)
#nasa_df$sentry_object <- as.factor(nasa_df$sentry_object)
#nasa_df

#nrow(nasa_df)
#head(nasa_df)
nasa_df$id <- NULL
nasa_df$name <- NULL
nasa_df$orbiting_body <- NULL
nasa_df$sentry_object <- NULL
str(nasa_df)
## 'data.frame':    90836 obs. of  6 variables:
##  $ est_diameter_min  : num  1198.3 265.8 722 96.5 255 ...
##  $ est_diameter_max  : num  2679 594 1615 216 570 ...
##  $ relative_velocity : num  13569 73589 114259 24764 42738 ...
##  $ miss_distance     : num  54839744 61438127 49798725 25434973 46275567 ...
##  $ absolute_magnitude: num  16.7 20 17.8 22.2 20.1 ...
##  $ hazardous         : Factor w/ 2 levels "False","True": 1 2 1 1 2 1 1 1 1 1 ...
summary(nasa_df)
##  est_diameter_min   est_diameter_max   relative_velocity  miss_distance     
##  Min.   :    0.61   Min.   :    1.36   Min.   :   203.3   Min.   :    6746  
##  1st Qu.:   19.26   1st Qu.:   43.06   1st Qu.: 28619.0   1st Qu.:17210820  
##  Median :   48.37   Median :  108.15   Median : 44190.1   Median :37846579  
##  Mean   :  127.43   Mean   :  284.95   Mean   : 48066.9   Mean   :37066546  
##  3rd Qu.:  143.40   3rd Qu.:  320.66   3rd Qu.: 62923.6   3rd Qu.:56548996  
##  Max.   :37892.65   Max.   :84730.54   Max.   :236990.1   Max.   :74798651  
##  absolute_magnitude hazardous    
##  Min.   : 9.23      False:81996  
##  1st Qu.:21.34      True : 8840  
##  Median :23.70                   
##  Mean   :23.53                   
##  3rd Qu.:25.70                   
##  Max.   :33.20

Univariate Analysis

hazardous

ggplot(nasa_df, aes(hazardous)) +
  geom_bar(color = "darkgreen", fill = "lightgreen", width = 0.2) +
  theme_classic()

summary(nasa_df$hazardous)
## False  True 
## 81996  8840
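The bar chart and counts show a strong class imbalance: roughly 90% of the objects are non-hazardous. This is worth keeping in mind when interpreting the accuracy figures later on; a quick check of the proportions:

# Class proportions: about 0.90 "False" vs 0.10 "True"
prop.table(table(nasa_df$hazardous))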

est_diameter_min

ggplot(nasa_df, aes(est_diameter_min)) +
  geom_histogram(color = "darkgreen", fill = "lightgreen") +
  scale_x_log10() +
  theme_classic()
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

summary(nasa_df$est_diameter_min)
##     Min.  1st Qu.   Median     Mean  3rd Qu.     Max. 
##     0.61    19.26    48.37   127.43   143.40 37892.65
mean(nasa_df$est_diameter_min)
## [1] 127.4321

est_diameter_max

ggplot(nasa_df, aes(est_diameter_max)) +
  geom_histogram(color = "darkgreen", fill = "lightgreen") +
  scale_x_log10() +
  theme_classic()
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

summary(nasa_df$est_diameter_max)
##     Min.  1st Qu.   Median     Mean  3rd Qu.     Max. 
##     1.36    43.06   108.15   284.95   320.66 84730.54

relative_velocity

ggplot(nasa_df, aes(relative_velocity)) +
  geom_histogram(color = "darkgreen", fill = "lightgreen") +
  theme_classic()
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

summary(nasa_df$relative_velocity)
##     Min.  1st Qu.   Median     Mean  3rd Qu.     Max. 
##    203.3  28619.0  44190.1  48066.9  62923.6 236990.1

miss_distance

ggplot(nasa_df, aes(miss_distance)) +
  geom_histogram(color = "darkgreen", fill = "lightgreen") +
  theme_classic()
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

summary(nasa_df$miss_distance)
##     Min.  1st Qu.   Median     Mean  3rd Qu.     Max. 
##     6746 17210820 37846579 37066546 56548996 74798651

absolute_magnitude

ggplot(nasa_df, aes(absolute_magnitude)) +
  geom_histogram(color = "darkgreen", fill = "lightgreen") +
  theme_classic()
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

summary(nasa_df$absolute_magnitude)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##    9.23   21.34   23.70   23.53   25.70   33.20

Bivariate Analysis

Hazardous vs. est_diameter_min

glm_mdl <- glm(hazardous ~ est_diameter_min, data = nasa_df, family = binomial)
## Warning: glm.fit: fitted probabilities numerically 0 or 1 occurred
summary(glm_mdl)
## 
## Call:
## glm(formula = hazardous ~ est_diameter_min, family = binomial, 
##     data = nasa_df)
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -8.4904  -0.4230  -0.4000  -0.3922   2.2207  
## 
## Coefficients:
##                    Estimate Std. Error z value Pr(>|z|)    
## (Intercept)      -2.552e+00  1.369e-02 -186.35   <2e-16 ***
## est_diameter_min  1.985e-03  4.026e-05   49.31   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 57981  on 90835  degrees of freedom
## Residual deviance: 55084  on 90834  degrees of freedom
## AIC: 55088
## 
## Number of Fisher Scoring iterations: 6
coefficients(glm_mdl)
##      (Intercept) est_diameter_min 
##     -2.551719401      0.001985131
ggplot(nasa_df, aes(est_diameter_min, as.numeric(hazardous) - 1)) +
  geom_point() +
  geom_smooth(method="glm", color="blue", se=FALSE,
                method.args = list(family='binomial'))
## `geom_smooth()` using formula 'y ~ x'
## Warning: glm.fit: fitted probabilities numerically 0 or 1 occurred

Hazardous vs. est_diameter_max

glm_mdl <- glm(hazardous ~ est_diameter_max, data = nasa_df, family = binomial)
## Warning: glm.fit: fitted probabilities numerically 0 or 1 occurred
summary(glm_mdl)
## 
## Call:
## glm(formula = hazardous ~ est_diameter_max, family = binomial, 
##     data = nasa_df)
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -8.4904  -0.4230  -0.4000  -0.3922   2.2207  
## 
## Coefficients:
##                    Estimate Std. Error z value Pr(>|z|)    
## (Intercept)      -2.5517194  0.0136931 -186.35   <2e-16 ***
## est_diameter_max  0.0008878  0.0000180   49.31   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 57981  on 90835  degrees of freedom
## Residual deviance: 55084  on 90834  degrees of freedom
## AIC: 55088
## 
## Number of Fisher Scoring iterations: 6
coefficients(glm_mdl)
##      (Intercept) est_diameter_max 
##    -2.5517194008     0.0008877777
ggplot(nasa_df, aes(est_diameter_max, as.numeric(hazardous) - 1)) +
  geom_point() +
  geom_smooth(method="glm", color="blue", se=FALSE,
                method.args = list(family='binomial'))
## `geom_smooth()` using formula 'y ~ x'
## Warning: glm.fit: fitted probabilities numerically 0 or 1 occurred

Hazardous vs. relative_velocity

glm_mdl <- glm(hazardous ~ relative_velocity, data = nasa_df, family = binomial)
summary(glm_mdl)
## 
## Call:
## glm(formula = hazardous ~ relative_velocity, family = binomial, 
##     data = nasa_df)
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -1.9757  -0.4755  -0.3868  -0.3267   2.5777  
## 
## Coefficients:
##                     Estimate Std. Error z value Pr(>|z|)    
## (Intercept)       -3.416e+00  2.646e-02 -129.10   <2e-16 ***
## relative_velocity  2.200e-05  3.987e-07   55.18   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 57981  on 90835  degrees of freedom
## Residual deviance: 55014  on 90834  degrees of freedom
## AIC: 55018
## 
## Number of Fisher Scoring iterations: 5
coefficients(glm_mdl)
##       (Intercept) relative_velocity 
##     -3.415572e+00      2.200084e-05
ggplot(nasa_df, aes(relative_velocity, as.numeric(hazardous) - 1)) +
  geom_point() +
  geom_smooth(method="glm", color="blue", se=FALSE,
                method.args = list(family='binomial'))
## `geom_smooth()` using formula 'y ~ x'

Hazardous vs. miss_distance

glm_mdl <- glm(hazardous ~ miss_distance, data = nasa_df, family = binomial)
summary(glm_mdl)
## 
## Call:
## glm(formula = hazardous ~ miss_distance, family = binomial, data = nasa_df)
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -0.5054  -0.4749  -0.4456  -0.4146   2.2600  
## 
## Coefficients:
##                 Estimate Std. Error z value Pr(>|z|)    
## (Intercept)   -2.474e+00  2.294e-02 -107.85   <2e-16 ***
## miss_distance  6.426e-09  5.049e-10   12.72   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 57981  on 90835  degrees of freedom
## Residual deviance: 57818  on 90834  degrees of freedom
## AIC: 57822
## 
## Number of Fisher Scoring iterations: 5
coefficients(glm_mdl)
##   (Intercept) miss_distance 
## -2.473836e+00  6.425476e-09
ggplot(nasa_df, aes(miss_distance, as.numeric(hazardous) - 1)) +
  geom_point() +
  geom_smooth(method="glm", color="blue", se=FALSE,
                method.args = list(family='binomial'))
## `geom_smooth()` using formula 'y ~ x'

Hazardous vs. absolute_magnitude

glm_mdl <- glm(hazardous ~ absolute_magnitude, data = nasa_df, family = binomial)
summary(glm_mdl)
## 
## Call:
## glm(formula = hazardous ~ absolute_magnitude, family = binomial, 
##     data = nasa_df)
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -2.9540  -0.4217  -0.2430  -0.1423   2.2411  
## 
## Coefficients:
##                     Estimate Std. Error z value Pr(>|z|)    
## (Intercept)         9.099601   0.113777   79.98   <2e-16 ***
## absolute_magnitude -0.514568   0.005415  -95.02   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 57981  on 90835  degrees of freedom
## Residual deviance: 44949  on 90834  degrees of freedom
## AIC: 44953
## 
## Number of Fisher Scoring iterations: 6
coefficients(glm_mdl)
##        (Intercept) absolute_magnitude 
##          9.0996008         -0.5145679
ggplot(nasa_df, aes(absolute_magnitude, as.numeric(hazardous) - 1)) +
  geom_point() +
  geom_smooth(method="glm", color="blue", se=FALSE,
                method.args = list(family='binomial'))
## `geom_smooth()` using formula 'y ~ x'

Correlation Matrix

# Correlation
#cor(nasa_df$est_diameter_min, nasa_df$relative_velocity, method="pearson")
#nasa_df$hazardous <- as.numeric(nasa_df$hazardous)
#cor(nasa_df)
library(corrplot)
## Warning: package 'corrplot' was built under R version 4.2.1
## corrplot 0.92 loaded
#corrplot(nasa_df, method = "circle")
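The correlation matrix call is left commented out above; a minimal sketch of how it could be computed on the numeric predictors (hazardous is excluded because it is a factor) and visualised with corrplot:

# Correlation matrix of the numeric predictors, plotted with corrplot
num_cols <- sapply(nasa_df, is.numeric)
corrplot(cor(nasa_df[, num_cols]), method = "circle")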

Pairs Chart

pairs(data = nasa_df, hazardous ~ est_diameter_min + est_diameter_max + 
        relative_velocity + miss_distance + absolute_magnitude, col = nasa_df$hazardous)

Logistic Regression

# Create data split object
nasa_split <- initial_split(nasa_df, prop = 0.75, strata = hazardous)
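
# Note: no seed is set before initial_split(), so the rows assigned to each
# set (and the downstream metrics) can vary between runs. A reproducible
# variant would be (the seed value is an arbitrary choice, not part of the
# original analysis):
# set.seed(123)
# nasa_split <- initial_split(nasa_df, prop = 0.75, strata = hazardous)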

# Create the training data
nasa_training <- nasa_split %>% 
  training()

# Create the test data
nasa_test <- nasa_split %>% 
  testing()

# Check the number of rows
nrow(nasa_training)
## [1] 68127
nrow(nasa_test)
## [1] 22709
# Specify a logistic regression model
logistic_model <- logistic_reg() %>% 
  # Set the engine
  set_engine('glm') %>% 
  # Set the mode
  set_mode('classification')

# Fit to training data
logistic_fit <- logistic_model %>% 
  fit(hazardous ~ est_diameter_min + est_diameter_max + relative_velocity +
                  miss_distance + absolute_magnitude,
      data = nasa_training)
## Warning: glm.fit: fitted probabilities numerically 0 or 1 occurred
# Print model fit object
logistic_fit
## parsnip model object
## 
## 
## Call:  stats::glm(formula = hazardous ~ est_diameter_min + est_diameter_max + 
##     relative_velocity + miss_distance + absolute_magnitude, family = stats::binomial, 
##     data = data)
## 
## Coefficients:
##        (Intercept)    est_diameter_min    est_diameter_max   relative_velocity  
##          2.457e+01          -3.357e+06           1.501e+06           7.789e-06  
##      miss_distance  absolute_magnitude  
##         -2.085e-08          -1.169e+00  
## 
## Degrees of Freedom: 68126 Total (i.e. Null);  68121 Residual
## Null Deviance:       43410 
## Residual Deviance: 30260     AIC: 30270

Making Predictions based on Logistic Regression

# Predict outcome categories
class_preds <- predict(logistic_fit, new_data = nasa_test,
                       type = 'class')

# Obtain estimated probabilities for each outcome value
prob_preds <- predict(logistic_fit, new_data = nasa_test, 
                      type = 'prob')

# Combine test set results
nasa_results <- nasa_test %>% 
  select(hazardous) %>% 
  bind_cols(class_preds, prob_preds)

# View results tibble
#nasa_results

Confusion Matrix

# Preview the combined results
head(nasa_results)
##   hazardous .pred_class .pred_False   .pred_True
## 1     False       False   0.6001969 3.998031e-01
## 2     False       False   0.9199434 8.005656e-02
## 3      True       False   0.7420153 2.579847e-01
## 4     False       False   0.9999114 8.858862e-05
## 5     False       False   0.8853639 1.146361e-01
## 6     False       False   0.9998246 1.753648e-04
# Calculate the confusion matrix
conf_mat(nasa_results, truth = hazardous,
         estimate = .pred_class)
##           Truth
## Prediction False  True
##      False 20269  2058
##      True    214   168
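
The metrics reported below can be recovered directly from these counts; a quick sketch using the cells above (with “False”, the first factor level, treated as the event, which is yardstick’s default):

# Cells of the confusion matrix above, with "False" as the event class
TP <- 20269  # predicted False, truly False
FN <- 214    # predicted True,  truly False
FP <- 2058   # predicted False, truly True
TN <- 168    # predicted True,  truly True

TP / (TP + FN)                   # sensitivity, ~0.990
TN / (TN + FP)                   # specificity, ~0.076
(TP + TN) / (TP + TN + FP + FN)  # accuracy,    ~0.900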

Accuracy

# Calculate the accuracy
accuracy(nasa_results, truth = hazardous,
         estimate = .pred_class)
## # A tibble: 1 × 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.900

Sensitivity

# Calculate the sensitivity
sens(nasa_results, truth = hazardous,
     estimate = .pred_class)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 sens    binary         0.990

Specificity

# Calculate the specificity
spec(nasa_results, truth = hazardous,
     estimate = .pred_class)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 spec    binary        0.0755
# Create a Heat Map of confusion matrix
conf_mat(nasa_results,
         truth = hazardous,
         estimate = .pred_class) %>%
  # Create a heat map
  autoplot(type = 'heatmap')

# Create a Mosaic Plot of confusion matrix
conf_mat(nasa_results,
         truth = hazardous,
         estimate = .pred_class) %>% 
  # Create a mosaic plot
  autoplot(type = "mosaic")

ROC & AUC

# Calculate metrics across thresholds
threshold_df <- nasa_results %>% 
  roc_curve(truth = hazardous, .pred_True)

# View results
threshold_df
## # A tibble: 22,710 × 3
##     .threshold specificity sensitivity
##          <dbl>       <dbl>       <dbl>
##  1 -Inf                  0        1   
##  2    2.22e-16           0        1   
##  3    5.62e- 8           0        1.00
##  4    3.42e- 7           0        1.00
##  5    3.61e- 7           0        1.00
##  6    4.70e- 7           0        1.00
##  7    6.13e- 7           0        1.00
##  8    6.15e- 7           0        1.00
##  9    6.52e- 7           0        1.00
## 10    8.00e- 7           0        1.00
## # … with 22,700 more rows
# Plot ROC curve
threshold_df %>% 
  autoplot()

# Calculate ROC AUC
roc_auc(nasa_results,
        truth = hazardous, 
        .pred_True)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 roc_auc binary         0.123
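
As a side note, yardstick’s binary metrics use the first factor level (“False”) as the event by default, so passing the .pred_True column here yields an AUC below 0.5; if “True” (hazardous) is meant to be the positive class, the same fit would score around 1 − 0.123 ≈ 0.88. A minimal sketch of how the event level could be set explicitly (an option of roc_auc(), not used in the original analysis):

# Compute ROC AUC treating "True" (the second factor level) as the event
roc_auc(nasa_results,
        truth = hazardous,
        .pred_True,
        event_level = "second")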

Random Forest

# Random forest model specification (renamed rf_spec so it does not shadow yardstick::spec())
rf_spec <- rand_forest(mtry = 4,
                       trees = 100,
                       min_n = 10) %>%
  set_mode("classification") %>%
  set_engine("ranger")

Training the Forest

rf_fit <- rf_spec %>%
  fit(hazardous ~ est_diameter_min + est_diameter_max + relative_velocity + miss_distance + absolute_magnitude, data = nasa_training)
#library("ranger")
#model_rf <- ranger(hazardous ~ ., data = nasa_training, probability = TRUE)
##data = nasa_training[complete.cases(nasa_training),]
#model_rf

Variable Importance

We observe that the most important variables are “miss_distance” and “relative_velocity”, which are followed by “est_diameter_max”, “absolute_magnitude” and “est_diameter_min”.

rand_forest(mode = "classification") %>%
  set_engine("ranger", importance = "impurity") %>%
  fit(hazardous ~ ., data = nasa_training) %>%
  vip::vip()

Predictions based on Random Forest Model

forest_predictions <- predict(rf_fit, new_data = nasa_test)
forest_predictions
## # A tibble: 22,709 × 1
##    .pred_class
##    <fct>      
##  1 False      
##  2 False      
##  3 False      
##  4 False      
##  5 False      
##  6 False      
##  7 False      
##  8 False      
##  9 False      
## 10 False      
## # … with 22,699 more rows
# Combine predictions with the true class labels
pred_combined <- forest_predictions %>% 
  mutate(true_class = nasa_test$hazardous)

head(pred_combined)
## # A tibble: 6 × 2
##   .pred_class true_class
##   <fct>       <fct>     
## 1 False       False     
## 2 False       False     
## 3 False       True      
## 4 False       False     
## 5 False       False     
## 6 False       False
#Confusion matrix
conf_mat(data = pred_combined, estimate = .pred_class, truth = true_class)
##           Truth
## Prediction False  True
##      False 20058  1475
##      True    425   751

Accuracy of Random Forest

# Calculate the accuracy
accuracy(pred_combined, truth = true_class,
         estimate = .pred_class)
## # A tibble: 1 × 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.916

Sensitivity of Random Forest

# Calculate the sensitivity
sens(pred_combined, truth = true_class,
     estimate = .pred_class)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 sens    binary         0.979

Specificity of Random Forest

# Calculate the specificity
spec(pred_combined, truth = true_class,
     estimate = .pred_class)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 spec    binary         0.337

Boosting the Random Forest

library("xgboost")

boost_spec <- boost_tree() %>%
  set_mode("classification") %>%
  set_engine("xgboost")
boost_spec
## Boosted Tree Model Specification (classification)
## 
## Computational engine: xgboost
boost_model <- fit(boost_spec, formula = hazardous ~ est_diameter_min + est_diameter_max + relative_velocity + miss_distance + absolute_magnitude, data = nasa_training)
boost_model
## parsnip model object
## 
## ##### xgb.Booster
## raw: 44.9 Kb 
## call:
##   xgboost::xgb.train(params = list(eta = 0.3, max_depth = 6, gamma = 0, 
##     colsample_bytree = 1, colsample_bynode = 1, min_child_weight = 1, 
##     subsample = 1, objective = "binary:logistic"), data = x$data, 
##     nrounds = 15, watchlist = x$watchlist, verbose = 0, nthread = 1)
## params (as set within xgb.train):
##   eta = "0.3", max_depth = "6", gamma = "0", colsample_bytree = "1", colsample_bynode = "1", min_child_weight = "1", subsample = "1", objective = "binary:logistic", nthread = "1", validate_parameters = "TRUE"
## xgb.attributes:
##   niter
## callbacks:
##   cb.evaluation.log()
## # of features: 5 
## niter: 15
## nfeatures : 5 
## evaluation_log:
##     iter training_logloss
##        1        0.5000166
##        2        0.3939080
## ---                      
##       14        0.1814011
##       15        0.1795252
set.seed(99)

folds <- vfold_cv(nasa_training, v = 3)

cv_results <- fit_resamples(boost_spec,
                            hazardous ~ est_diameter_min + est_diameter_max + relative_velocity + miss_distance + absolute_magnitude,
                            resamples = folds,
                            metrics = metric_set(yardstick::roc_auc, accuracy, sens, specificity))
collect_metrics(cv_results)
## # A tibble: 4 × 6
##   .metric     .estimator  mean     n  std_err .config             
##   <chr>       <chr>      <dbl> <int>    <dbl> <chr>               
## 1 accuracy    binary     0.913     3 0.00111  Preprocessor1_Model1
## 2 roc_auc     binary     0.915     3 0.000505 Preprocessor1_Model1
## 3 sens        binary     0.996     3 0.000758 Preprocessor1_Model1
## 4 specificity binary     0.143     3 0.00724  Preprocessor1_Model1
set.seed(100)

predictions <- boost_tree() %>%
  set_mode("classification") %>%
  set_engine("xgboost") %>% 
  fit(hazardous ~ ., data = nasa_training) %>%
  predict(new_data = nasa_training, type = "prob") %>% 
  bind_cols(nasa_training)

predictions
## # A tibble: 68,127 × 8
##    .pred_False .pred_True est_diameter_min est_diameter_max relative_velocity
##          <dbl>      <dbl>            <dbl>            <dbl>             <dbl>
##  1       0.980    0.0203            1198.            2679.             13569.
##  2       0.723    0.277              266.             594.             73589.
##  3       0.995    0.00501             36.4             81.3            34298.
##  4       0.706    0.294              172.             384.             27529.
##  5       0.784    0.216              350.             784.             56625.
##  6       0.713    0.287              253.             565.             58431.
##  7       0.693    0.307              153.             342.             64394.
##  8       0.995    0.00501             69.9            156.             38019.
##  9       0.977    0.0227             290.             649.             10402.
## 10       0.995    0.00501             44.1             98.6            70771.
## # … with 68,117 more rows, and 3 more variables: miss_distance <dbl>,
## #   absolute_magnitude <dbl>, hazardous <fct>
# Calculate ROC AUC (here on the training data)
roc_auc(predictions, 
        truth = hazardous, 
        estimate = .pred_True)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 roc_auc binary        0.0771
boost_spec <- boost_tree(
                trees = 100,
                learn_rate = tune(),
                tree_depth = tune(),
                sample_size = tune()) %>%
  set_mode("classification") %>%
  set_engine("xgboost")

# Create the tuning grid
tunegrid_boost <- grid_regular(parameters(boost_spec), 
                      levels = 3)
## Warning: `parameters.model_spec()` was deprecated in tune 0.1.6.9003.
## Please use `hardhat::extract_parameter_set_dials()` instead.
tunegrid_boost
## # A tibble: 27 × 3
##    tree_depth learn_rate sample_size
##         <int>      <dbl>       <dbl>
##  1          1     0.001         0.1 
##  2          8     0.001         0.1 
##  3         15     0.001         0.1 
##  4          1     0.0178        0.1 
##  5          8     0.0178        0.1 
##  6         15     0.0178        0.1 
##  7          1     0.316         0.1 
##  8          8     0.316         0.1 
##  9         15     0.316         0.1 
## 10          1     0.001         0.55
## # … with 17 more rows
# Create CV folds of training data
folds <- vfold_cv(nasa_training, v = 3)

# Tune along the grid
tune_results <- tune_grid(boost_spec,
                   hazardous ~ est_diameter_min + est_diameter_max + relative_velocity + miss_distance + absolute_magnitude,
                   resamples = folds,
                   grid = tunegrid_boost,
                   metrics = metric_set(yardstick::roc_auc))

# Plot the results
autoplot(tune_results)
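
The plot shows how ROC AUC varies across the tuning grid. A natural follow-up, not run in the original analysis, would be to select the best parameter combination and finalize the boosted specification; a minimal sketch:

# Pick the best hyperparameters by ROC AUC and finalize the specification
best_params <- select_best(tune_results, metric = "roc_auc")
final_boost_spec <- finalize_model(boost_spec, best_params)

# Refit on the training data with the chosen parameters
final_boost_fit <- fit(final_boost_spec,
                       hazardous ~ est_diameter_min + est_diameter_max +
                         relative_velocity + miss_distance + absolute_magnitude,
                       data = nasa_training)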