Load necessary packages and data
library(tidyverse)          # data manipulation (dplyr and friends)
library(xgboost)            # gradient-boosted tree models
library(caret)              # confusionMatrix()
# library(xgboostExplainer) # optional; not needed for the code below
year1 <- read.csv('year1.csv')
year2 <- read.csv('year2.csv')
all_data <- rbind(year1, year2)
Classify all rows as swing / no swing
unique(all_data$description)
## [1] "ball" "foul"
## [3] "called_strike" "blocked_ball"
## [5] "hit_into_play" "hit_by_pitch"
## [7] "swinging_strike" "foul_tip"
## [9] "foul_bunt" "swinging_strike_blocked"
## [11] "missed_bunt" "pitchout"
## [13] "bunt_foul_tip" "foul_pitchout"
swing_vector <- c("foul", "hit_into_play", "swinging_strike", "foul_tip",
"swinging_strike_blocked", "foul_pitchout")
bunt_vector <- c("foul_bunt", "missed_bunt", "bunt_foul_tip")
# remove bunt attempts (irregular swing decisions)
# note: bunts put into play cannot be identified with the data provided, so they remain
all_data <- all_data %>%
filter(!description %in% bunt_vector)
# add true / false for swing
all_data <- all_data %>%
mutate(swing = description %in% swing_vector) # TRUE for swings, FALSE for no swings
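As a quick sanity check, tabulate each pitch description against the swing flag to confirm the mapping (a minimal sketch using the objects above):
# every description should map to exactly one swing value
all_data %>%
  count(description, swing) %>%
  arrange(desc(n))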
summary(all_data) # check data types
## season pitch_id release_speed batter
## Min. :1.0 Length:1415072 Length:1415072 Min. :5001
## 1st Qu.:1.0 Class :character Class :character 1st Qu.:5384
## Median :1.0 Mode :character Mode :character Median :5829
## Mean :1.5 Mean :5873
## 3rd Qu.:2.0 3rd Qu.:6336
## Max. :2.0 Max. :7129
## pitcher description stand p_throws
## Min. :5001 Length:1415072 Length:1415072 Length:1415072
## 1st Qu.:5416 Class :character Class :character Class :character
## Median :5802 Mode :character Mode :character Mode :character
## Mean :5869
## 3rd Qu.:6288
## Max. :7165
## pitch_type balls strikes pfx_x
## Length:1415072 Min. :0.0000 Min. :0.000 Length:1415072
## Class :character 1st Qu.:0.0000 1st Qu.:0.000 Class :character
## Mode :character Median :1.0000 Median :1.000 Mode :character
## Mean :0.8799 Mean :0.898
## 3rd Qu.:2.0000 3rd Qu.:2.000
## Max. :4.0000 Max. :3.000
## pfx_z plate_x plate_z sz_top
## Length:1415072 Length:1415072 Length:1415072 Length:1415072
## Class :character Class :character Class :character Class :character
## Mode :character Mode :character Mode :character Mode :character
##
##
##
## sz_bot swing
## Length:1415072 Mode :logical
## Class :character FALSE:742231
## Mode :character TRUE :672841
##
##
##
Several of the numeric columns have been loaded as characters. Convert to numeric for modeling purposes.
all_data <- all_data %>%
mutate(release_speed = as.numeric(release_speed),
pfx_x = as.numeric(pfx_x),
pfx_z = as.numeric(pfx_z),
plate_x = as.numeric(plate_x),
plate_z = as.numeric(plate_z),
sz_top = as.numeric(sz_top),
sz_bot = as.numeric(sz_bot))
## Warning: There were 7 warnings in `mutate()`.
## The first warning was:
## ℹ In argument: `release_speed = as.numeric(release_speed)`.
## Caused by warning:
## ! NAs introduced by coercion
## ℹ Run `dplyr::last_dplyr_warnings()` to see the 6 remaining warnings.
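The same conversion can be written more compactly with dplyr::across(), and it is worth counting the NAs introduced by coercion (a sketch, assuming the column names above):
# equivalent: mutate(across(c(release_speed, pfx_x, pfx_z, plate_x,
#                             plate_z, sz_top, sz_bot), as.numeric))
# count NAs introduced by coercion in each converted column
colSums(is.na(select(all_data, release_speed, pfx_x, pfx_z,
                     plate_x, plate_z, sz_top, sz_bot)))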
summary(all_data)
## season pitch_id release_speed batter
## Min. :1.0 Length:1415072 Min. : 30.10 Min. :5001
## 1st Qu.:1.0 Class :character 1st Qu.: 84.60 1st Qu.:5384
## Median :1.0 Mode :character Median : 89.80 Median :5829
## Mean :1.5 Mean : 88.87 Mean :5873
## 3rd Qu.:2.0 3rd Qu.: 93.80 3rd Qu.:6336
## Max. :2.0 Max. :104.20 Max. :7129
## NA's :778
## pitcher description stand p_throws
## Min. :5001 Length:1415072 Length:1415072 Length:1415072
## 1st Qu.:5416 Class :character Class :character Class :character
## Median :5802 Mode :character Mode :character Mode :character
## Mean :5869
## 3rd Qu.:6288
## Max. :7165
##
## pitch_type balls strikes pfx_x
## Length:1415072 Min. :0.0000 Min. :0.000 Min. :-2.870
## Class :character 1st Qu.:0.0000 1st Qu.:0.000 1st Qu.:-0.850
## Mode :character Median :1.0000 Median :1.000 Median :-0.170
## Mean :0.8799 Mean :0.898 Mean :-0.108
## 3rd Qu.:2.0000 3rd Qu.:2.000 3rd Qu.: 0.600
## Max. :4.0000 Max. :3.000 Max. : 2.840
## NA's :3456
## pfx_z plate_x plate_z sz_top
## Min. :-2.5600 Min. :-8.6600 Min. :-5.070 Min. :2.500
## 1st Qu.: 0.1800 1st Qu.:-0.5300 1st Qu.: 1.650 1st Qu.:3.290
## Median : 0.7300 Median : 0.0400 Median : 2.290 Median :3.390
## Mean : 0.6405 Mean : 0.0427 Mean : 2.282 Mean :3.388
## 3rd Qu.: 1.2700 3rd Qu.: 0.6100 3rd Qu.: 2.930 3rd Qu.:3.490
## Max. : 2.8200 Max. : 9.1100 Max. :10.220 Max. :4.470
## NA's :1472 NA's :778 NA's :811 NA's :778
## sz_bot swing
## Min. :0.770 Mode :logical
## 1st Qu.:1.520 FALSE:742231
## Median :1.590 TRUE :672841
## Mean :1.586
## 3rd Qu.:1.640
## Max. :2.260
## NA's :823
Separate into train and test data. Use year 1 as train and year 2 as test (the convention with temporal data: train on earlier seasons, test on later ones).
train_data <- all_data %>%
filter(season == 1)
test_data <- all_data %>%
filter(season == 2)
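A quick check that the split behaves as expected (a sketch; the counts should match the season column summarized above):
# rows per season, and confirm the two splits cover all rows
all_data %>% count(season)
nrow(train_data) + nrow(test_data) == nrow(all_data)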
# Create DMatrices for xgboost
# Only numerics, label is response
dtrain_1 <- xgb.DMatrix(data = as.matrix(train_data[,c(3,10:17)]),
label = train_data$swing)
dtest_1 <- xgb.DMatrix(data = as.matrix(test_data[,c(3,10:17)]),
label = test_data$swing)
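The feature columns can equivalently be selected by name, which is less fragile than positional indexing (a sketch, assuming columns 3 and 10-17 are the ones shown in the summary above):
# same features, selected by name instead of position
feature_cols <- c("release_speed", "balls", "strikes", "pfx_x", "pfx_z",
                  "plate_x", "plate_z", "sz_top", "sz_bot")
dtrain_1 <- xgb.DMatrix(data = as.matrix(train_data[, feature_cols]),
                        label = train_data$swing)
dtest_1 <- xgb.DMatrix(data = as.matrix(test_data[, feature_cols]),
                       label = test_data$swing)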
XGBoost was selected for its ability to capture complex relationships in the data and to handle rows with NAs natively.
Accuracy is prioritized over interpretability here.
set.seed(1993)
model_1 <- xgboost(dtrain_1,
nrounds = 100, # number of iterations
eta = 0.1, # learning rate
verbose = 1, # print eval metric
print_every_n = 20, # print every 20 iterations
objective = "binary:logistic", # 1/0 response
eval_metric = "auc") # area under curve
## [1] train-auc:0.893935
## [21] train-auc:0.988571
## [41] train-auc:0.993877
## [61] train-auc:0.995085
## [81] train-auc:0.995692
## [100] train-auc:0.996731
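The log above only reports training AUC, which cannot reveal overfitting. A watchlist via xgb.train() would track held-out AUC at each round (a sketch with the same settings; not run here):
# same model, but monitoring test AUC each round via a watchlist
set.seed(1993)
model_1_wl <- xgb.train(params = list(objective = "binary:logistic",
                                      eval_metric = "auc",
                                      eta = 0.1),
                        data = dtrain_1,
                        nrounds = 100,
                        watchlist = list(train = dtrain_1, test = dtest_1),
                        print_every_n = 20)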
Predict on the test set for a more easily interpretable evaluation.
A confusion matrix tabulates predictions against actuals.
model_1_pred <- predict(model_1, dtest_1)
# classify any prediction over 0.5 as a swing (TRUE) and 0.5 or under as no swing (FALSE)
model_1_pred_class <- model_1_pred > 0.5
# create confusion matrix
model_1_cm <- confusionMatrix(as.factor(model_1_pred_class),
as.factor(test_data$swing), positive = "TRUE")
model_1_cm
## Confusion Matrix and Statistics
##
## Reference
## Prediction FALSE TRUE
## FALSE 363987 300789
## TRUE 5244 37171
##
## Accuracy : 0.5673
## 95% CI : (0.5661, 0.5684)
## No Information Rate : 0.5221
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.0995
##
## Mcnemar's Test P-Value : < 2.2e-16
##
## Sensitivity : 0.10999
## Specificity : 0.98580
## Pos Pred Value : 0.87636
## Neg Pred Value : 0.54753
## Prevalence : 0.47789
## Detection Rate : 0.05256
## Detection Prevalence : 0.05998
## Balanced Accuracy : 0.54789
##
## 'Positive' Class : TRUE
##
The model's predictions are heavily skewed toward FALSE (no swing). This may be due to an imbalance in what is driving the model.
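That skew is visible directly in the prediction distribution (a quick check):
# distribution of predicted swing probabilities; most of the mass sits below 0.5
summary(model_1_pred)
mean(model_1_pred > 0.5) # share of pitches classified as swings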
Plot variable importance to see if any variables are disproportionately driving the model
importance_matrix <- xgb.importance(model = model_1)
# Plot variable importance
xgb.plot.importance(importance_matrix)
For some reason, the strike zone variables (sz_top and sz_bot) are dominating the model. Collinear variables can sometimes cause this, and sz_top and sz_bot are obviously highly correlated.
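That suspicion can be checked directly (a minimal sketch):
# correlation between the two strike zone variables in the training data
cor(train_data$sz_top, train_data$sz_bot, use = "complete.obs")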
It makes the most sense to remove these variables and re-run the model.
# do not include columns 16 & 17 (sz_top and sz_bot)
dtrain_2 <- xgb.DMatrix(data = as.matrix(train_data[,c(3,10:15)]),
label = train_data$swing)
dtest_2 <- xgb.DMatrix(data = as.matrix(test_data[,c(3,10:15)]),
label = test_data$swing)
model_2 <- xgboost(dtrain_2,
nrounds = 100,
eta = 0.1,
verbose = 1,
print_every_n = 20,
objective = "binary:logistic",
eval_metric = "auc")
## [1] train-auc:0.801755
## [21] train-auc:0.847887
## [41] train-auc:0.859004
## [61] train-auc:0.862702
## [81] train-auc:0.864775
## [100] train-auc:0.865768
model_2_pred <- predict(model_2, dtest_2)
model_2_pred_class <- model_2_pred > 0.5
model_2_cm <- confusionMatrix(as.factor(model_2_pred_class),
as.factor(test_data$swing), positive = "TRUE")
model_2_cm
## Confusion Matrix and Statistics
##
## Reference
## Prediction FALSE TRUE
## FALSE 283142 75131
## TRUE 86089 262829
##
## Accuracy : 0.772
## 95% CI : (0.771, 0.773)
## No Information Rate : 0.5221
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.5438
##
## Mcnemar's Test P-Value : < 2.2e-16
##
## Sensitivity : 0.7777
## Specificity : 0.7668
## Pos Pred Value : 0.7533
## Neg Pred Value : 0.7903
## Prevalence : 0.4779
## Detection Rate : 0.3717
## Detection Prevalence : 0.4934
## Balanced Accuracy : 0.7723
##
## 'Positive' Class : TRUE
##
These predictions are much more balanced.
The performance of this model can also be evaluated by simply calculating the correlation between the predictions and the actuals. A high correlation means that higher predicted swing probabilities are associated with actual swings.
cor(model_2_pred, test_data$swing)
## [1] 0.6241325
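Correlation is a coarse summary; a simple calibration check compares the average predicted probability to the actual swing rate within prediction bins (a sketch using the objects above):
# bin predictions into deciles and compare mean prediction vs. actual swing rate
tibble(pred = model_2_pred, swing = test_data$swing) %>%
  mutate(bin = ntile(pred, 10)) %>%
  group_by(bin) %>%
  summarize(mean_pred = mean(pred),
            actual_swing_rate = mean(swing),
            n = n())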
imp_mat_2 <- xgb.importance(model = model_2)
# Plot importance (top 10 variables)
xgb.plot.importance(imp_mat_2)
The pitch location variables (plate_x and plate_z) are now the most important, which is intuitive. The strike count is next, which is also intuitive.
Next, add more variables to the model that may affect swing probability.
# 1. plate_z relative to sz_top / sz_bot
derived_data <- all_data %>%
mutate(plate_z_rel = plate_z - (sz_top + sz_bot) / 2) %>% # location - middle
select(-plate_z) # remove plate_z
# 2. RHH / LHH dummy variable
derived_data <- derived_data %>%
mutate(batter_r = ifelse(stand == "R", 1, 0)) # 0 if LHH
# 3. RHP / LHP dummy variable
derived_data <- derived_data %>%
mutate(pitcher_r = ifelse(p_throws == "R", 1, 0)) # 0 if LHP
# 4. pitcher's usage rate of the pitch type in the count vs. RHH / LHH
derived_data <- derived_data %>%
group_by(pitcher, batter_r, balls, strikes) %>%
mutate(total_pitches = n()) %>% # temp column
group_by(pitcher, batter_r, balls, strikes, pitch_type) %>%
mutate(pitch_type_count = n()) %>% # temp column
mutate(pitcher_pitch_usage = pitch_type_count / total_pitches) %>%
select(-total_pitches, -pitch_type_count) %>% # remove temps
ungroup()
# 5. Velocity difference between the pitch and the pitcher's most-used pitch
derived_data <- derived_data %>%
left_join(derived_data %>% # sub query to get table of most-used pitch
group_by(pitcher, batter_r) %>%
mutate(total_pitches = n()) %>%
group_by(pitcher, batter_r, pitch_type) %>%
mutate(pitch_type_count = n()) %>%
mutate(pitcher_pitch_usage = pitch_type_count / total_pitches) %>%
group_by(pitcher, batter_r) %>%
mutate(most_used_pitch = pitch_type[which.max(pitcher_pitch_usage)]) %>%
filter(pitch_type == most_used_pitch) %>%
group_by(pitcher, batter_r) %>%
summarize(avg_velo_most_used = mean(release_speed)),
by = c("pitcher", "batter_r")) %>%
mutate(diff_velo_most_used = release_speed - avg_velo_most_used) %>%
select(-avg_velo_most_used)
## `summarise()` has grouped output by 'pitcher'. You can override using the
## `.groups` argument.
train_derived <- derived_data %>%
filter(season == 1)
test_derived <- derived_data %>%
filter(season == 2)
# Create DMatrices for xgboost
# Only numerics, label is response
dtrain_3 <- xgb.DMatrix(data = as.matrix(train_derived[,c(3,10:14, 18:22)]),
label = train_derived$swing)
dtest_3 <- xgb.DMatrix(data = as.matrix(test_derived[,c(3,10:14, 18:22)]),
label = test_derived$swing)
Fit a new model on the expanded feature set.
set.seed(2003)
model_3 <- xgboost(dtrain_3,
nrounds = 100,
eta = 0.1,
verbose = 1,
print_every_n = 20,
objective = "binary:logistic",
eval_metric = "auc")
## [1] train-auc:0.800149
## [21] train-auc:0.850537
## [41] train-auc:0.863073
## [61] train-auc:0.867689
## [81] train-auc:0.870342
## [100] train-auc:0.871556
save(model_3, file = 'model_3.rda')
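save() works here, but xgboost also provides its own serializer, which tends to be more portable across package versions (a one-line alternative):
# alternative: xgboost's native model format
xgb.save(model_3, 'model_3.model')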
model_3_pred <- predict(model_3, dtest_3)
model_3_pred_class <- model_3_pred > 0.5
model_3_cm <- confusionMatrix(as.factor(model_3_pred_class),
as.factor(test_data$swing), positive = "TRUE")
model_3_cm
## Confusion Matrix and Statistics
##
## Reference
## Prediction FALSE TRUE
## FALSE 284300 73399
## TRUE 84931 264561
##
## Accuracy : 0.7761
## 95% CI : (0.7751, 0.7771)
## No Information Rate : 0.5221
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.552
##
## Mcnemar's Test P-Value : < 2.2e-16
##
## Sensitivity : 0.7828
## Specificity : 0.7700
## Pos Pred Value : 0.7570
## Neg Pred Value : 0.7948
## Prevalence : 0.4779
## Detection Rate : 0.3741
## Detection Prevalence : 0.4942
## Balanced Accuracy : 0.7764
##
## 'Positive' Class : TRUE
##
imp_mat_3 <- xgb.importance(model = model_3)
# Plot importance (top 10 variables)
xgb.plot.importance(imp_mat_3)
cor(model_3_pred, test_data$swing)
## [1] 0.6316809
Adding these derived variables improves model performance across all of the metrics considered: training AUC, confusion-matrix accuracy, and the correlation between predictions and actuals.
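For reference, the headline numbers from the two candidate models, taken from the outputs above:
# side-by-side summary of model 2 vs. model 3 (values as reported above)
tibble(model = c("model_2 (raw features, no sz)", "model_3 (derived features)"),
       train_auc = c(0.8658, 0.8716),
       test_accuracy = c(0.7720, 0.7761),
       test_correlation = c(0.6241, 0.6317))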
Finally, read in the year 3 data and apply the same preparation so the final model can score it.
year3_data <- read.csv('year3.csv')
year3_data <- year3_data %>%
mutate(release_speed = as.numeric(release_speed),
pfx_x = as.numeric(pfx_x),
pfx_z = as.numeric(pfx_z),
plate_x = as.numeric(plate_x),
plate_z = as.numeric(plate_z),
sz_top = as.numeric(sz_top),
sz_bot = as.numeric(sz_bot))
## Warning: There were 7 warnings in `mutate()`.
## The first warning was:
## ℹ In argument: `release_speed = as.numeric(release_speed)`.
## Caused by warning:
## ! NAs introduced by coercion
## ℹ Run `dplyr::last_dplyr_warnings()` to see the 6 remaining warnings.
Apply the same derived variables to the year 3 data.
# 1. plate_z relative to sz_top / sz_bot
derived_data <- year3_data %>%
mutate(plate_z_rel = plate_z - (sz_top + sz_bot) / 2) %>% # location - middle
select(-plate_z) # remove plate_z
# 2. RHH / LHH dummy variable
derived_data <- derived_data %>%
mutate(batter_r = ifelse(stand == "R", 1, 0)) # 0 if LHH
# 3. RHP / LHP dummy variable
derived_data <- derived_data %>%
mutate(pitcher_r = ifelse(p_throws == "R", 1, 0)) # 0 if LHP
# 4. pitcher's usage rate of the pitch type in the count vs. RHH / LHH
derived_data <- derived_data %>%
group_by(pitcher, batter_r, balls, strikes) %>%
mutate(total_pitches = n()) %>% # temp column
group_by(pitcher, batter_r, balls, strikes, pitch_type) %>%
mutate(pitch_type_count = n()) %>% # temp column
mutate(pitcher_pitch_usage = pitch_type_count / total_pitches) %>%
select(-total_pitches, -pitch_type_count) %>% # remove temps
ungroup()
# 5. Velocity difference between the pitch and the pitcher's most-used pitch
derived_data <- derived_data %>%
left_join(derived_data %>% # sub query to get table of most-used pitch
group_by(pitcher, batter_r) %>%
mutate(total_pitches = n()) %>%
group_by(pitcher, batter_r, pitch_type) %>%
mutate(pitch_type_count = n()) %>%
mutate(pitcher_pitch_usage = pitch_type_count / total_pitches) %>%
group_by(pitcher, batter_r) %>%
mutate(most_used_pitch = pitch_type[which.max(pitcher_pitch_usage)]) %>%
filter(pitch_type == most_used_pitch) %>%
group_by(pitcher, batter_r) %>%
summarize(avg_velo_most_used = mean(release_speed)),
by = c("pitcher", "batter_r")) %>%
mutate(diff_velo_most_used = release_speed - avg_velo_most_used) %>%
select(-avg_velo_most_used)
## `summarise()` has grouped output by 'pitcher'. You can override using the
## `.groups` argument.
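Since the same feature engineering is applied to both the train/test data and the year 3 data, it could be wrapped in a helper. add_derived_features() below is a hypothetical sketch, not part of the original script; it adds na.rm = TRUE for the NA release speeds and breaks usage ties deterministically, so its output could differ marginally from the inline code above.
# hypothetical helper wrapping the derived-variable steps used twice above
add_derived_features <- function(df) {
  # 1-3: location relative to the zone midpoint, plus handedness dummies
  df <- df %>%
    mutate(plate_z_rel = plate_z - (sz_top + sz_bot) / 2,
           batter_r = ifelse(stand == "R", 1, 0),
           pitcher_r = ifelse(p_throws == "R", 1, 0)) %>%
    select(-plate_z) %>%
    # 4: usage rate of the pitch type in the count vs. RHH / LHH
    group_by(pitcher, batter_r, balls, strikes) %>%
    mutate(total_pitches = n()) %>%
    group_by(pitcher, batter_r, balls, strikes, pitch_type) %>%
    mutate(pitcher_pitch_usage = n() / total_pitches) %>%
    ungroup() %>%
    select(-total_pitches)
  # 5: velocity difference from the pitcher's most-used pitch type
  most_used <- df %>%
    group_by(pitcher, batter_r, pitch_type) %>%
    summarize(n_pitches = n(),
              avg_velo_most_used = mean(release_speed, na.rm = TRUE),
              .groups = "drop") %>%
    group_by(pitcher, batter_r) %>%
    slice_max(n_pitches, n = 1, with_ties = FALSE) %>%
    ungroup() %>%
    select(pitcher, batter_r, avg_velo_most_used)
  df %>%
    left_join(most_used, by = c("pitcher", "batter_r")) %>%
    mutate(diff_velo_most_used = release_speed - avg_velo_most_used) %>%
    select(-avg_velo_most_used)
}
# usage: derived_data <- add_derived_features(year3_data)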
dyear3 <- xgb.DMatrix(data = as.matrix(derived_data[,c(3,9:13, 16:20)]))
year3_pred <- predict(model_3, dyear3)
validation <- cbind.data.frame(year3_data, year3_pred)
write.csv(validation, 'validation.csv')
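A couple of quick sanity checks before handing the file off (a sketch; adding row.names = FALSE to write.csv() would also drop the extra row-index column):
# predicted swing probabilities should all lie in [0, 1]
summary(year3_pred)
# one prediction per year 3 pitch
nrow(validation) == nrow(year3_data)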