library(tidyverse)
library(tidymodels)
library(patchwork)
library(kableExtra)
library(ggcorrplot)
# Read the raw CSV and drop the `X` column.
df_raw <- read.csv("breast cancer.csv")
df <- select(df_raw, -X)
The raw data has 569 rows and 33 columns (32 after dropping `X`).
# Preview the first six rows in a scrollable HTML table.
df %>%
  head() %>%
  kable() %>%
  kable_styling() %>%
  scroll_box(width = "100%", height = "300px")
| id | diagnosis | radius_mean | texture_mean | perimeter_mean | area_mean | smoothness_mean | compactness_mean | concavity_mean | concave.points_mean | symmetry_mean | fractal_dimension_mean | radius_se | texture_se | perimeter_se | area_se | smoothness_se | compactness_se | concavity_se | concave.points_se | symmetry_se | fractal_dimension_se | radius_worst | texture_worst | perimeter_worst | area_worst | smoothness_worst | compactness_worst | concavity_worst | concave.points_worst | symmetry_worst | fractal_dimension_worst |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 842302 | M | 17.99 | 10.38 | 122.80 | 1001.0 | 0.11840 | 0.27760 | 0.3001 | 0.14710 | 0.2419 | 0.07871 | 1.0950 | 0.9053 | 8.589 | 153.40 | 0.006399 | 0.04904 | 0.05373 | 0.01587 | 0.03003 | 0.006193 | 25.38 | 17.33 | 184.60 | 2019.0 | 0.1622 | 0.6656 | 0.7119 | 0.2654 | 0.4601 | 0.11890 |
| 842517 | M | 20.57 | 17.77 | 132.90 | 1326.0 | 0.08474 | 0.07864 | 0.0869 | 0.07017 | 0.1812 | 0.05667 | 0.5435 | 0.7339 | 3.398 | 74.08 | 0.005225 | 0.01308 | 0.01860 | 0.01340 | 0.01389 | 0.003532 | 24.99 | 23.41 | 158.80 | 1956.0 | 0.1238 | 0.1866 | 0.2416 | 0.1860 | 0.2750 | 0.08902 |
| 84300903 | M | 19.69 | 21.25 | 130.00 | 1203.0 | 0.10960 | 0.15990 | 0.1974 | 0.12790 | 0.2069 | 0.05999 | 0.7456 | 0.7869 | 4.585 | 94.03 | 0.006150 | 0.04006 | 0.03832 | 0.02058 | 0.02250 | 0.004571 | 23.57 | 25.53 | 152.50 | 1709.0 | 0.1444 | 0.4245 | 0.4504 | 0.2430 | 0.3613 | 0.08758 |
| 84348301 | M | 11.42 | 20.38 | 77.58 | 386.1 | 0.14250 | 0.28390 | 0.2414 | 0.10520 | 0.2597 | 0.09744 | 0.4956 | 1.1560 | 3.445 | 27.23 | 0.009110 | 0.07458 | 0.05661 | 0.01867 | 0.05963 | 0.009208 | 14.91 | 26.50 | 98.87 | 567.7 | 0.2098 | 0.8663 | 0.6869 | 0.2575 | 0.6638 | 0.17300 |
| 84358402 | M | 20.29 | 14.34 | 135.10 | 1297.0 | 0.10030 | 0.13280 | 0.1980 | 0.10430 | 0.1809 | 0.05883 | 0.7572 | 0.7813 | 5.438 | 94.44 | 0.011490 | 0.02461 | 0.05688 | 0.01885 | 0.01756 | 0.005115 | 22.54 | 16.67 | 152.20 | 1575.0 | 0.1374 | 0.2050 | 0.4000 | 0.1625 | 0.2364 | 0.07678 |
| 843786 | M | 12.45 | 15.70 | 82.57 | 477.1 | 0.12780 | 0.17000 | 0.1578 | 0.08089 | 0.2087 | 0.07613 | 0.3345 | 0.8902 | 2.217 | 27.19 | 0.007510 | 0.03345 | 0.03672 | 0.01137 | 0.02165 | 0.005082 | 15.47 | 23.75 | 103.40 | 741.6 | 0.1791 | 0.5249 | 0.5355 | 0.1741 | 0.3985 | 0.12440 |
# Column types and the first few values of each column.
df %>% str()
## 'data.frame': 569 obs. of 32 variables:
## $ id : int 842302 842517 84300903 84348301 84358402 843786 844359 84458202 844981 84501001 ...
## $ diagnosis : chr "M" "M" "M" "M" ...
## $ radius_mean : num 18 20.6 19.7 11.4 20.3 ...
## $ texture_mean : num 10.4 17.8 21.2 20.4 14.3 ...
## $ perimeter_mean : num 122.8 132.9 130 77.6 135.1 ...
## $ area_mean : num 1001 1326 1203 386 1297 ...
## $ smoothness_mean : num 0.1184 0.0847 0.1096 0.1425 0.1003 ...
## $ compactness_mean : num 0.2776 0.0786 0.1599 0.2839 0.1328 ...
## $ concavity_mean : num 0.3001 0.0869 0.1974 0.2414 0.198 ...
## $ concave.points_mean : num 0.1471 0.0702 0.1279 0.1052 0.1043 ...
## $ symmetry_mean : num 0.242 0.181 0.207 0.26 0.181 ...
## $ fractal_dimension_mean : num 0.0787 0.0567 0.06 0.0974 0.0588 ...
## $ radius_se : num 1.095 0.543 0.746 0.496 0.757 ...
## $ texture_se : num 0.905 0.734 0.787 1.156 0.781 ...
## $ perimeter_se : num 8.59 3.4 4.58 3.44 5.44 ...
## $ area_se : num 153.4 74.1 94 27.2 94.4 ...
## $ smoothness_se : num 0.0064 0.00522 0.00615 0.00911 0.01149 ...
## $ compactness_se : num 0.049 0.0131 0.0401 0.0746 0.0246 ...
## $ concavity_se : num 0.0537 0.0186 0.0383 0.0566 0.0569 ...
## $ concave.points_se : num 0.0159 0.0134 0.0206 0.0187 0.0188 ...
## $ symmetry_se : num 0.03 0.0139 0.0225 0.0596 0.0176 ...
## $ fractal_dimension_se : num 0.00619 0.00353 0.00457 0.00921 0.00511 ...
## $ radius_worst : num 25.4 25 23.6 14.9 22.5 ...
## $ texture_worst : num 17.3 23.4 25.5 26.5 16.7 ...
## $ perimeter_worst : num 184.6 158.8 152.5 98.9 152.2 ...
## $ area_worst : num 2019 1956 1709 568 1575 ...
## $ smoothness_worst : num 0.162 0.124 0.144 0.21 0.137 ...
## $ compactness_worst : num 0.666 0.187 0.424 0.866 0.205 ...
## $ concavity_worst : num 0.712 0.242 0.45 0.687 0.4 ...
## $ concave.points_worst : num 0.265 0.186 0.243 0.258 0.163 ...
## $ symmetry_worst : num 0.46 0.275 0.361 0.664 0.236 ...
## $ fractal_dimension_worst: num 0.1189 0.089 0.0876 0.173 0.0768 ...
# Treat the outcome as a categorical variable (levels "B", "M").
df <- df %>% mutate(diagnosis = as.factor(diagnosis))
There are no missing values in any column:
# Number of missing values in each column.
df %>%
  is.na() %>%
  colSums()
## id diagnosis radius_mean
## 0 0 0
## texture_mean perimeter_mean area_mean
## 0 0 0
## smoothness_mean compactness_mean concavity_mean
## 0 0 0
## concave.points_mean symmetry_mean fractal_dimension_mean
## 0 0 0
## radius_se texture_se perimeter_se
## 0 0 0
## area_se smoothness_se compactness_se
## 0 0 0
## concavity_se concave.points_se symmetry_se
## 0 0 0
## fractal_dimension_se radius_worst texture_worst
## 0 0 0
## perimeter_worst area_worst smoothness_worst
## 0 0 0
## compactness_worst concavity_worst concave.points_worst
## 0 0 0
## symmetry_worst fractal_dimension_worst
## 0 0
The number of unique IDs equals the number of observations in this data frame, so each row represents a unique observation. We drop the `id` column, since it carries no additional information.
# Count distinct values per column. vapply() visits each column with its own
# type intact; apply() would first coerce the whole data frame (which mixes a
# character/factor column with numeric ones) into a character matrix.
# Guard the claim that every `id` is distinct, i.e. one row per observation.
stopifnot(n_distinct(df$id) == nrow(df))
vapply(df, function(col) length(unique(col)), integer(1))
## id diagnosis radius_mean
## 569 2 456
## texture_mean perimeter_mean area_mean
## 479 522 539
## smoothness_mean compactness_mean concavity_mean
## 474 537 537
## concave.points_mean symmetry_mean fractal_dimension_mean
## 542 432 499
## radius_se texture_se perimeter_se
## 540 519 533
## area_se smoothness_se compactness_se
## 528 547 541
## concavity_se concave.points_se symmetry_se
## 533 507 498
## fractal_dimension_se radius_worst texture_worst
## 545 457 511
## perimeter_worst area_worst smoothness_worst
## 514 544 411
## compactness_worst concavity_worst concave.points_worst
## 529 539 492
## symmetry_worst fractal_dimension_worst
## 500 535
# `id` is unique per row and carries no predictive information.
df$id <- NULL
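The variables are on very different scales (compare `area_mean` with `smoothness_mean` below), so we will need to scale the data before modeling.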
# Keep only the numeric columns for the per-variable summary table.
num.dat <- select(df, where(is.numeric))

# summary() yields min, quartiles, mean and max; apply() stacks them into a
# matrix with one column per variable, rounded to 3 decimals.
num.dat %>%
  apply(MARGIN = 2, FUN = function(col) round(summary(col), 3)) %>%
  kbl() %>%
  kable_styling(bootstrap_options = c("striped", "hover", "bordered")) %>%
  kable_paper() %>%
  scroll_box(width = "100%", height = "320px")
| radius_mean | texture_mean | perimeter_mean | area_mean | smoothness_mean | compactness_mean | concavity_mean | concave.points_mean | symmetry_mean | fractal_dimension_mean | radius_se | texture_se | perimeter_se | area_se | smoothness_se | compactness_se | concavity_se | concave.points_se | symmetry_se | fractal_dimension_se | radius_worst | texture_worst | perimeter_worst | area_worst | smoothness_worst | compactness_worst | concavity_worst | concave.points_worst | symmetry_worst | fractal_dimension_worst | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Min. | 6.981 | 9.71 | 43.790 | 143.500 | 0.053 | 0.019 | 0.000 | 0.000 | 0.106 | 0.050 | 0.112 | 0.360 | 0.757 | 6.802 | 0.002 | 0.002 | 0.000 | 0.000 | 0.008 | 0.001 | 7.930 | 12.020 | 50.410 | 185.200 | 0.071 | 0.027 | 0.000 | 0.000 | 0.156 | 0.055 |
| 1st Qu. | 11.700 | 16.17 | 75.170 | 420.300 | 0.086 | 0.065 | 0.030 | 0.020 | 0.162 | 0.058 | 0.232 | 0.834 | 1.606 | 17.850 | 0.005 | 0.013 | 0.015 | 0.008 | 0.015 | 0.002 | 13.010 | 21.080 | 84.110 | 515.300 | 0.117 | 0.147 | 0.114 | 0.065 | 0.250 | 0.071 |
| Median | 13.370 | 18.84 | 86.240 | 551.100 | 0.096 | 0.093 | 0.062 | 0.034 | 0.179 | 0.062 | 0.324 | 1.108 | 2.287 | 24.530 | 0.006 | 0.020 | 0.026 | 0.011 | 0.019 | 0.003 | 14.970 | 25.410 | 97.660 | 686.500 | 0.131 | 0.212 | 0.227 | 0.100 | 0.282 | 0.080 |
| Mean | 14.127 | 19.29 | 91.969 | 654.889 | 0.096 | 0.104 | 0.089 | 0.049 | 0.181 | 0.063 | 0.405 | 1.217 | 2.866 | 40.337 | 0.007 | 0.025 | 0.032 | 0.012 | 0.021 | 0.004 | 16.269 | 25.677 | 107.261 | 880.583 | 0.132 | 0.254 | 0.272 | 0.115 | 0.290 | 0.084 |
| 3rd Qu. | 15.780 | 21.80 | 104.100 | 782.700 | 0.105 | 0.130 | 0.131 | 0.074 | 0.196 | 0.066 | 0.479 | 1.474 | 3.357 | 45.190 | 0.008 | 0.032 | 0.042 | 0.015 | 0.023 | 0.005 | 18.790 | 29.720 | 125.400 | 1084.000 | 0.146 | 0.339 | 0.383 | 0.161 | 0.318 | 0.092 |
| Max. | 28.110 | 39.28 | 188.500 | 2501.000 | 0.163 | 0.345 | 0.427 | 0.201 | 0.304 | 0.097 | 2.873 | 4.885 | 21.980 | 542.200 | 0.031 | 0.135 | 0.396 | 0.053 | 0.079 | 0.030 | 36.040 | 49.540 | 251.200 | 4254.000 | 0.223 | 1.058 | 1.252 | 0.291 | 0.664 | 0.208 |
The correlation plot shows that there are quite a few highly correlated variables.
We will take that into consideration in setting up the model.
# Pairwise correlations of all predictors (everything except `diagnosis`).
df %>%
  select(-diagnosis) %>%
  cor() %>%
  ggcorrplot(tl.cex = 9, tl.srt = 50, title = "Correlation heat-map")
Both plots show how well the two classes are separated on each variable in the data frame.
We expect that a variable that provides good separation might have strong predictive power as well.
# Standardise every numeric column over the full data set (both classes), so
# the per-class means computed below share one comparable scale.
# scale() returns an n x 1 matrix; as.numeric() keeps each column a plain
# vector instead of leaving matrix-columns inside the data frame.
scaled_df <- df %>%
  mutate(across(where(is.numeric), function(x) as.numeric(scale(x))))
scaled_M <- scaled_df %>% filter(diagnosis == "M")
scaled_B <- scaled_df %>% filter(diagnosis == "B")
# Bootstrap confidence intervals for the mean of every numeric column.
#
# Each numeric column is resampled on its own with boot::boot() (statistic:
# the mean of the resample). A normal-approximation interval is taken from
# boot::boot.ci(type = "norm") at its default level.
#
# Args:
#   data.frame:   a data frame; non-numeric columns are ignored. Matrix-shaped
#                 columns (e.g. left behind by scale()) are flattened first.
#   boots_number: number of bootstrap replicates (`R` in boot::boot()).
#
# Returns: a data.frame with one row per numeric column and columns
#   conf        confidence level of the interval,
#   lower_bound lower limit of the interval,
#   upper_bound upper limit of the interval,
#   variable    column name,
#   mean        interval midpoint, i.e. the bias-corrected bootstrap
#               estimate of the mean (not the raw sample mean).
mean_bootstrap_data <- function(data.frame, boots_number) {
  if (!requireNamespace("boot", quietly = TRUE)) {
    stop("Package 'boot' is required by mean_bootstrap_data().", call. = FALSE)
  }
  is_num <- vapply(data.frame, is.numeric, logical(1))
  numeric_cols <- names(data.frame)[is_num]

  # Statistic for boot(): mean of the observations selected by `idx`.
  mean_stat <- function(x, idx) mean(x[idx])

  # One column per variable; rows are (conf, lower, upper). Resampling a single
  # vector draws the same indices as resampling whole rows, but is cheaper.
  ci <- vapply(
    numeric_cols,
    function(col) {
      values <- as.numeric(data.frame[[col]])
      boot_out <- boot::boot(values, mean_stat, R = boots_number)
      boot::boot.ci(boot.out = boot_out, type = "norm")$normal[1, ]
    },
    numeric(3),
    USE.NAMES = FALSE
  )

  # `data.frame` is shadowed by the argument name, so call the constructor
  # through its namespace for clarity.
  base::data.frame(
    conf = ci[1, ],
    lower_bound = ci[2, ],
    upper_bound = ci[3, ],
    variable = numeric_cols,
    mean = (ci[2, ] + ci[3, ]) / 2
  )
}
# Bootstrap CIs of the (scaled) column means within each class, 2500
# replicates each. B is bootstrapped before M under the same seed.
set.seed(1)
ci_B <- scaled_B %>%
  mean_bootstrap_data(2500) %>%
  mutate(diagnosis = "B")
ci_M <- scaled_M %>%
  mean_bootstrap_data(2500) %>%
  mutate(diagnosis = "M")
# Stack the results: M rows first, then B rows.
boots_ci <- rbind(ci_M, ci_B)
# Class colours, assigned in level order: B -> cyan3, M -> firebrick1.
Palette1 <- c("cyan3", "firebrick1")

# Order variables by the median (over both classes) of their interval
# midpoints, so the flipped plot reads as a ranking.
boots_ci <- boots_ci %>%
  mutate(name = fct_reorder(variable, mean))

ci_plot <- ggplot(boots_ci, aes(x = name, y = mean, color = diagnosis)) +
  geom_errorbar(aes(ymin = lower_bound, ymax = upper_bound)) +
  geom_point() +
  coord_flip() +
  theme_minimal() +
  scale_color_manual(values = Palette1) +
  ggtitle("Bootstrap Confidence Intervals - Columns Means") +
  theme(
    axis.title.y = element_blank(),
    legend.position = "none",
    plot.title = element_text(size = 16, face = "bold.italic"),
    axis.text.y = element_text(size = 10, face = "bold")
  )
# Facet order mirrors the CI plot: coord_flip() puts the first level at the
# bottom, so the facets use the reversed level order.
levels_order <- boots_ci$name %>% fct_rev() %>% levels()

# One row per (observation, numeric variable), with the value in `value`.
df_long <- df %>%
  pivot_longer(cols = where(is.numeric), names_to = "variable") %>%
  mutate(variable = factor(variable, levels = levels_order))

# Per-class density of every variable on free scales; axes are hidden because
# only the overlap between the two classes matters here.
dens_plot <- ggplot(df_long, aes(x = value, fill = diagnosis)) +
  geom_density(alpha = 0.4) +
  facet_wrap(~variable, scales = "free") +
  theme(
    axis.ticks.y = element_blank(),
    axis.title.y = element_blank(),
    axis.text.y = element_blank(),
    axis.ticks.x = element_blank(),
    axis.title.x = element_blank(),
    axis.text.x = element_blank(),
    strip.text.x = element_text(size = 8, color = "black"),
    plot.title = element_text(size = 16, face = "bold.italic")
  ) +
  scale_fill_manual(values = Palette1) +
  ggtitle("Density")

# Show both plots side by side (patchwork).
ci_plot + dens_plot
If you are familiar with logistic regression and lasso regression, penalized logistic regression is a natural extension/combination of the two. Just as in lasso regression, adding an L1-norm penalty to the cost function performs feature selection and coefficient shrinkage simultaneously.
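Mathematically, our goal is to minimize the following penalized negative log-likelihood, where $y_i \in \{0, 1\}$ is the class label and $x_i$ the predictor vector of observation $i$: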
\[\min_{(\beta_0, \beta) \in \mathbb{R}^{p+1}} -\frac{1}{N} \sum_{i=1}^N \left[ y_i (\beta_0 + x_i^T \beta) - \log\left(1+e^{\beta_0+x_i^T \beta}\right)\right] + \lambda \|\beta\|_1\]
We set up the framework of the model using tidymodels.
We split the dataset and create cross-validation folds to estimate the optimal penalty.
Terminology note: throughout this notebook, I use “lambda” and “penalty” interchangeably.
# Hold out a test set (rsample's default 3/4 train, 1/4 test split), then build
# 10 cross-validation folds from the training part, stratified by diagnosis.
set.seed(1)
splits <- initial_split(df)
train <- training(splits)
test <- testing(splits)
folds <- vfold_cv(train, v = 10, strata = diagnosis)
# Lasso logistic regression: mixture = 1 means a pure L1 penalty; the penalty
# strength is left as a tuning parameter.
lr_mod <- logistic_reg(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet")

# Preprocessing: standardise numeric predictors (the L1 penalty depends on the
# predictors' scale), then drop predictors whose absolute correlation with
# another predictor exceeds 0.9.
lr_recipe <- recipe(diagnosis ~ ., data = train) %>%
  step_normalize(all_numeric(), -all_outcomes()) %>%
  step_corr(all_numeric(), -all_outcomes(), threshold = 0.9)

# Preprocessed training data; 20 of the 30 original predictors survive the
# correlation filter.
model_train_data <- bake(prep(lr_recipe), train)

# Bundle preprocessing and model so they are tuned and fitted together.
wflow <- workflow() %>%
  add_model(lr_mod) %>%
  add_recipe(lr_recipe)
# Tune the penalty with 10-fold CV over 50 regularly spaced values of dials'
# penalty() parameter, keeping the out-of-fold predictions.
# NOTE(review): yardstick treats the first factor level ("B") as the event by
# default, so recall/precision here refer to the benign class — confirm this
# is intended.
set.seed(1)
grid <- grid_regular(penalty(), levels = 50)
lr_res <- tune_grid(
  wflow,
  resamples = folds,
  grid = grid,
  metrics = metric_set(roc_auc, recall, precision),
  control = control_grid(save_pred = TRUE)
)
# Top three penalties per metric, ranked by mean cross-validated score.
# Rows end up ordered roc_auc, recall, precision (descending metric name).
top_penalties <- lr_res %>%
  collect_metrics() %>%
  select(-c(.config, n, .estimator)) %>%
  rename(cv.mean = mean) %>%
  mutate(penalty = round(penalty, 6)) %>%
  group_by(.metric) %>%
  arrange(desc(cv.mean), .by_group = TRUE) %>%
  slice(1:3) %>%
  arrange(desc(.metric))
# Render the top penalties as an HTML table with one row group per metric.
# Group labels/sizes are hard-coded to match the roc_auc, recall, precision
# ordering of `top_penalties`.
top_penalties %>%
  ungroup() %>%
  select(-.metric) %>%
  round(digits = 5) %>%
  kable() %>%
  kable_classic_2(full_width = FALSE, font_size = 16, html_font = "Cambria") %>%
  pack_rows(
    index = c("roc_auc" = 3, "Recall" = 3, "Precision" = 3),
    label_row_css = "background-color: #666; color: #fff;"
  ) %>%
  row_spec(1, bold = TRUE, underline = TRUE, background = "yellow") %>%
  row_spec(
    0,
    bold = TRUE, font_size = 16, background = "grey",
    color = "white", underline = TRUE
  ) %>%
  column_spec(1:3, border_right = TRUE, border_left = TRUE) %>%
  add_header_above(
    c(" " = 1, "Top Three Lambdas For Each Metric" = 2),
    bold = TRUE,
    italic = TRUE,
    font_size = 16
  )
|
Top Three Lambdas For Each Metric
|
||
|---|---|---|
| penalty | cv.mean | std_err |
| roc_auc | ||
| 0.00356 | 0.99466 | 0.00347 |
| 0.00569 | 0.99397 | 0.00370 |
| 0.00910 | 0.99185 | 0.00430 |
| Recall | ||
| 0.24420 | 1.00000 | 0.00000 |
| 0.39069 | 1.00000 | 0.00000 |
| 0.62506 | 1.00000 | 0.00000 |
| Precision | ||
| 0.00222 | 0.97179 | 0.00866 |
| 0.00034 | 0.97166 | 0.00868 |
| 0.00054 | 0.97166 | 0.00868 |
# Penalty with the best cross-validated ROC AUC — the same criterion used for
# the final model below. Taken straight from the tuning results rather than by
# row position in `top_penalties`, whose rows are ordered roc_auc, recall,
# precision (row 7 is the top *precision* penalty).
best_penalty <- lr_res %>%
  select_best(metric = "roc_auc") %>%
  pull(penalty)
There are several ways to choose the optimal penalty, e.g. `select_by_one_std_err()` or `select_by_pct_loss()`. For our purposes, we simply pick the penalty that maximizes the cross-validated area under the ROC curve (`roc_auc`). In our case that is \[\lambda \approx 0.0036\]
This plot shows lambda’s effect on each metric.
It demonstrates nicely the trade off between recall and precision as lambda increases.
* The vertical line marks the best penalty (the one that maximizes roc_auc), as determined by cross-validation.
* Any lambda chosen from the shaded area would be, in my estimation, a reasonable choice.
# Cross-validated metrics (mean +/- one std. error) against the penalty on a
# log scale, one panel per metric. The vertical line and its label come from
# `best_penalty`, so they always agree; the shaded band marks penalties
# considered reasonable. ylim() drops (rather than clips) values outside
# [0.9, 1].
lr_res %>%
  collect_metrics() %>%
  ggplot(aes(penalty, mean, color = .metric)) +
  geom_point() +
  geom_line() +
  annotate("rect",
    xmin = 0.001, xmax = 0.03,
    ymin = 0.9, ymax = 1,
    alpha = 0.2, fill = "orange"
  ) +
  geom_errorbar(aes(ymin = mean - std_err, ymax = mean + std_err)) +
  geom_vline(xintercept = best_penalty) +
  annotate("text",
    x = 5 * 1e-5, y = 0.925,
    label = paste0("Best Penalty - \n ", signif(best_penalty, 2))
  ) +
  scale_x_log10(labels = scales::scientific) +
  facet_grid(~.metric, scales = "free_y") +
  ylim(c(0.9, 1)) +
  theme_light() +
  theme(
    legend.position = "none",
    strip.text.x = element_text(size = 10, face = "bold")
  )
# Extract the lasso coefficients for every penalty in `grid`.
# glmnet fits the whole regularisation path in one call, so a single fit of
# the workflow on the training data yields coefficients for all penalties.
fit <- wflow %>%
  fit(data = train)

coef_by_lambda <- fit %>%
  extract_fit_engine() %>%
  coef(s = grid$penalty)

# coef() returns a sparse matrix: one row per term ("(Intercept)" first), one
# column per requested penalty. Reshape to long format (variable, penalty,
# value), taking the term names from the matrix itself instead of assuming
# column positions in the baked training data or a fixed grid size.
coef_by_lambda <- as.matrix(coef_by_lambda)
colnames(coef_by_lambda) <- grid$penalty
coef_by_lambda <- coef_by_lambda %>%
  as.data.frame() %>%
  rownames_to_column("variable") %>%
  mutate(variable = if_else(variable == "(Intercept)", "Intercept", variable)) %>%
  pivot_longer(cols = -variable, names_to = "penalty") %>%
  mutate(penalty = as.numeric(penalty))
Lasso shrinks the coefficients toward zero as lambda increases.
You can get this plot directly from the glmnet package, but I wanted to build it myself anyway.
With this many coefficients the plot is crowded, so the next tab shows them faceted.
# Coefficient paths: each term's estimate against the penalty (log scale).
# The vertical line and its label both come from `best_penalty`, so the
# annotation cannot drift from the line's position.
coef_by_lambda %>%
  ggplot(aes(penalty, value, color = variable)) +
  scale_x_log10(labels = scales::scientific) +
  geom_line() +
  geom_point(size = 1) +
  geom_vline(xintercept = best_penalty) +
  annotate("text",
    x = best_penalty * 5, y = 10,
    label = paste0("Best Penalty - \n ", signif(best_penalty, 2))
  ) +
  ylab("coefficient_estimate") +
  theme_classic() +
  theme(legend.position = "bottom") +
  ggtitle("Coefficients Shrink As Lambda Increases")
This plot is a faceted version of the previous one.
A variable is “Important” if it gets a non-zero coefficient at the penalty we chose for our model (lambda ≈ 0.0036).
# For every term, take its coefficient at the grid penalty closest to
# best_penalty and label it "Important" when that coefficient is non-zero
# (after rounding to 3 decimals), "Unimportant" otherwise. Working per term
# avoids assuming a fixed number of terms.
zero_at_best <- coef_by_lambda %>%
  group_by(variable) %>%
  slice_min(abs(penalty - best_penalty), n = 1, with_ties = FALSE) %>%
  ungroup() %>%
  mutate(variable_importance = ifelse(round(value, 3) == 0,
    "Unimportant", "Important"
  ))

# Attach the importance label to every (term, penalty) row for faceting.
coef_by_lambda2 <- coef_by_lambda %>%
  inner_join(select(zero_at_best, variable, variable_importance),
    by = "variable"
  )

coef_by_lambda2 %>%
  ggplot(aes(penalty, value, color = variable_importance)) +
  scale_x_log10(labels = scales::scientific) +
  geom_line(size = 1.2) +
  facet_wrap(~ variable_importance ~ variable, scales = "free_y", ncol = 5) +
  geom_hline(yintercept = 0, color = "purple", size = 1) +
  geom_vline(xintercept = best_penalty, size = 1) +
  theme_light() +
  theme(
    strip.text = element_text(colour = "black", size = 10),
    legend.position = "top",
    legend.text = element_text(size = 12)
  )
This table shows the coefficients at the best lambda (≈ 0.0036). Apart from the (negative) intercept, lasso kept 2 negative and 8 positive coefficients and set the remaining 10 to zero.
# Coefficients at the grid penalty closest to best_penalty, sorted from most
# negative to most positive. Selecting by distance replaces a hard-coded grid
# position that silently breaks if the grid or the best penalty changes.
coef_by_lambda %>%
  group_by(variable) %>%
  slice_min(abs(penalty - best_penalty), n = 1, with_ties = FALSE) %>%
  ungroup() %>%
  select(-penalty) %>%
  arrange(value) %>%
  kable() %>%
  kable_classic(full_width = FALSE, html_font = "Cambria") %>%
  kable_styling(bootstrap_options = c("striped", "hover"), full_width = FALSE) %>%
  row_spec(
    0,
    bold = TRUE, font_size = 16, background = "grey",
    color = "white", underline = TRUE
  ) %>%
  column_spec(1, background = "bisque", color = "steelblue") %>%
  scroll_box(width = "60%", height = "400px")
| variable | value |
|---|---|
| compactness_se | -0.8240190 |
| Intercept | -0.2530469 |
| symmetry_se | -0.0721330 |
| compactness_mean | 0.0000000 |
| compactness_worst | 0.0000000 |
| concave.points_se | 0.0000000 |
| concavity_se | 0.0000000 |
| fractal_dimension_mean | 0.0000000 |
| fractal_dimension_se | 0.0000000 |
| fractal_dimension_worst | 0.0000000 |
| smoothness_mean | 0.0000000 |
| symmetry_mean | 0.0000000 |
| texture_se | 0.0000000 |
| smoothness_se | 0.0314942 |
| symmetry_worst | 0.6063408 |
| concavity_worst | 0.6177825 |
| smoothness_worst | 0.6243412 |
| texture_mean | 1.1595181 |
| area_mean | 1.9474129 |
| concave.points_worst | 2.2959850 |
| area_se | 3.4905724 |
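Finally, we apply the model to the test set.
Overall, the results are good: on the 143 held-out cases the model reaches 98.6% accuracy and an ROC AUC of 0.997, misclassifying only 2 malignant tumours as benign.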
# Final model: plug the penalty with the best cross-validated ROC AUC into the
# workflow `wflow` (built earlier from lr_mod and lr_recipe), refit it on the
# training split and evaluate once on the test split.
# `metric` is passed by name; recent tune versions reject it positionally.
best_lambda <- lr_res %>%
  select_best(metric = "roc_auc")

final_wflow <- finalize_workflow(wflow, best_lambda)

final_fit <- last_fit(final_wflow, splits)
# Test-set predictions, and a table of the test-set metrics from last_fit().
preds <- collect_predictions(final_fit)

final_fit %>%
  collect_metrics() %>%
  as.data.frame() %>%
  select(-c(.estimator, .config)) %>%
  kable() %>%
  kable_classic(full_width = FALSE, font_size = 16, html_font = "Cambria") %>%
  kable_styling(bootstrap_options = c("striped", "hover"))
| .metric | .estimate |
|---|---|
| accuracy | 0.986014 |
| roc_auc | 0.997144 |
# Confusion matrix of predicted vs. true class on the test set.
conf_mat(preds, truth = diagnosis, estimate = .pred_class)
## Truth
## Prediction B M
## B 86 2
## M 0 55
# Test-set ROC curve. roc_curve() takes the class-probability column as a bare
# column argument (there is no `estimate` parameter); it is the column for
# the event level, which yardstick takes to be the first level of
# `diagnosis` ("B") by default.
preds %>%
  roc_curve(truth = diagnosis, .pred_B) %>%
  autoplot()