Classificação de Churn com Tidymodels
Vamos tentar resolver um problema simples de customer churn utilizando uma série de pacotes no R que fazem parte de um universo de pequenos pacotes do tidymodels
. O principal deles é o parsnip
, que unifica uma série de algorítimos diferentes de machine learning, que vão de modelos simples como GLM, Regressão Linear e regression tree até modelos de machine learning mais complexos como Random Forest, XGBoost, lightgbm e outros. Este pacote unifica as sintaxes de várias implementações diferentes, e serve como um sucessor para o caret
.
Outro pacote interessante é o recipes
, que cria pequenas receitas de bolo, com instruções de pré-processamento dos dados, incluindo criação de dummies, normalização de variáveis, remoção de variáveis com variância zero e muitas outras. O pacote rsample
também faz parte deste universo e tem o objetivo de realizar reamostragens dos dados, como o split entre training set e test set e a criação de cross-validation.
Já o pacote yardstick
facilita a criação de medidas de performance dos modelos que serão estimados, produzindo as principais medidas de desempenho, seja para problemas de regressão (RMSE, R-Quadrado e outros), quanto problemas de classificação (matriz de confusão, precisão, acurácia e outros). Por fim, temos os pacotes tune
e dials
que facilitam a busca de hiperparâmetros ideais dos modelos e o pacote workflow
, que é a cola que une todas estas diferentes funções.
O Banco de Dados
Para este projeto vou utilizar uma base de customer churn da Telco. Ela contém 7043 linhas, cada uma representando um cliente, e 20 colunas que representam potenciais preditores, oferecendo informações que podem nos ajudar a prever o comportamento dos clientes e a desenvolver programas focados em retenção de clientes.
library(tidyverse)
library(tidymodels)
library(skimr)
library(knitr)
library(doFuture)
<- read_csv("https://raw.githubusercontent.com/DiegoUsaiUK/Classification_Churn_with_Parsnip/master/00_Data/WA_Fn-UseC_-Telco-Customer-Churn.csv") %>%
df select(-customerID) %>%
drop_na()
%>% head() %>% kable() df
gender | SeniorCitizen | Partner | Dependents | tenure | PhoneService | MultipleLines | InternetService | OnlineSecurity | OnlineBackup | DeviceProtection | TechSupport | StreamingTV | StreamingMovies | Contract | PaperlessBilling | PaymentMethod | MonthlyCharges | TotalCharges | Churn |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
Female | 0 | Yes | No | 1 | No | No phone service | DSL | No | Yes | No | No | No | No | Month-to-month | Yes | Electronic check | 29.85 | 29.85 | No |
Male | 0 | No | No | 34 | Yes | No | DSL | Yes | No | Yes | No | No | No | One year | No | Mailed check | 56.95 | 1889.50 | No |
Male | 0 | No | No | 2 | Yes | No | DSL | Yes | Yes | No | No | No | No | Month-to-month | Yes | Mailed check | 53.85 | 108.15 | Yes |
Male | 0 | No | No | 45 | No | No phone service | DSL | Yes | No | Yes | Yes | No | No | One year | No | Bank transfer (automatic) | 42.30 | 1840.75 | No |
Female | 0 | No | No | 2 | Yes | No | Fiber optic | No | No | No | No | No | No | Month-to-month | Yes | Electronic check | 70.70 | 151.65 | Yes |
Female | 0 | No | No | 8 | Yes | Yes | Fiber optic | No | No | Yes | No | Yes | Yes | Month-to-month | Yes | Electronic check | 99.65 | 820.50 | Yes |
%>% skim() df
Name | Piped data |
Number of rows | 7032 |
Number of columns | 20 |
_______________________ | |
Column type frequency: | |
character | 16 |
numeric | 4 |
________________________ | |
Group variables | None |
Variable type: character
skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
---|---|---|---|---|---|---|---|
gender | 0 | 1 | 4 | 6 | 0 | 2 | 0 |
Partner | 0 | 1 | 2 | 3 | 0 | 2 | 0 |
Dependents | 0 | 1 | 2 | 3 | 0 | 2 | 0 |
PhoneService | 0 | 1 | 2 | 3 | 0 | 2 | 0 |
MultipleLines | 0 | 1 | 2 | 16 | 0 | 3 | 0 |
InternetService | 0 | 1 | 2 | 11 | 0 | 3 | 0 |
OnlineSecurity | 0 | 1 | 2 | 19 | 0 | 3 | 0 |
OnlineBackup | 0 | 1 | 2 | 19 | 0 | 3 | 0 |
DeviceProtection | 0 | 1 | 2 | 19 | 0 | 3 | 0 |
TechSupport | 0 | 1 | 2 | 19 | 0 | 3 | 0 |
StreamingTV | 0 | 1 | 2 | 19 | 0 | 3 | 0 |
StreamingMovies | 0 | 1 | 2 | 19 | 0 | 3 | 0 |
Contract | 0 | 1 | 8 | 14 | 0 | 3 | 0 |
PaperlessBilling | 0 | 1 | 2 | 3 | 0 | 2 | 0 |
PaymentMethod | 0 | 1 | 12 | 25 | 0 | 4 | 0 |
Churn | 0 | 1 | 2 | 3 | 0 | 2 | 0 |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|---|---|
SeniorCitizen | 0 | 1 | 0.16 | 0.37 | 0.00 | 0.00 | 0.00 | 0.00 | 1.00 | ▇▁▁▁▂ |
tenure | 0 | 1 | 32.42 | 24.55 | 1.00 | 9.00 | 29.00 | 55.00 | 72.00 | ▇▃▃▃▅ |
MonthlyCharges | 0 | 1 | 64.80 | 30.09 | 18.25 | 35.59 | 70.35 | 89.86 | 118.75 | ▇▅▆▇▅ |
TotalCharges | 0 | 1 | 2283.30 | 2266.77 | 18.80 | 401.45 | 1397.47 | 3794.74 | 8684.80 | ▇▂▂▂▁ |
O banco de dados inclui 20 variáveis, sendo 4 numéricas e 16 categóricas. Churn
é a nossa variável de interesse, e indica se o cliente abandonou ou não a empresa. Como a ideia é apresentar os procedimentos do tidymodels
, não irei passar muito tempo analisando o banco de dados.
Definindo o Conjunto de Treinamento e Teste: rsample
O pacote rsample
fornece uma série de funções que tornam simples o processo de gerar reamostragens aleatórias como os bancos de dados de treinamento e teste. Vamos usar a função initial_split()
para gerar uma base de treinamento com 80% das informações do banco de dados original.
set.seed(123)
<- initial_split(data = df, prop = 0.8)
train_test_split <- train_test_split %>% training()
train_tbl <- train_test_split %>% testing()
test_tbl
train_test_split
## <Analysis/Assess/Total>
## <5626/1406/7032>
Do total de 7042 clientes, 5626 foram atribuídos ao conjunto de treinamento e 1406 para o conjunto de teste.
Pré-processando dados: recipes
O pacote recipes
utiliza uma série de funções para lidar com o pré-processamento dos dados, como a imputação de valores missing, padronização de variáveis, criação de dummies e mais. O primeiro passo é informar a formula utilizada na receita, no nosso caso Churn
como nossa variável dependente e as demais como preditores. Esta formula tem o objetivo de permitir o acesso a uma série de funções úteis e convenientes como all_predictors()
, que permite aplicar transformações sobre todos os preditores, ou all_outcomes()
, que permite transformações sobre todos os outcomes.
Muitas outras funções estão incluidas no pacote recipe
, como ímputação por média step_meanimpute
e por regressão linear step_impute_linear()
, as transformações de variável como log (step_log
) e quadrática (step_sqrt
), a discretizaçã de variáveis numéricas (step_cut
e outras), gerador de variáveis de data com step_date
(que constroi variáveis de dia, mês, ano) e muitas outras. Uma lista de todas as transformações possíveis pode ser obtida aqui.
<- recipe(Churn ~ ., data = train_tbl) %>%
receita_simples step_dummy(all_nominal(), -all_outcomes()) %>%
step_normalize(all_numeric())
Para exibir como o banco de dados ficou após as transformações, Chamamos prep()
, que aplica a receita ao banco de dados e utilizar juice()
para extrair o banco de dados transformado.
juice(prep(receita_simples)) %>% head() %>% kable()
SeniorCitizen | tenure | MonthlyCharges | TotalCharges | Churn | gender_Male | Partner_Yes | Dependents_Yes | PhoneService_Yes | MultipleLines_No.phone.service | MultipleLines_Yes | InternetService_Fiber.optic | InternetService_No | OnlineSecurity_No.internet.service | OnlineSecurity_Yes | OnlineBackup_No.internet.service | OnlineBackup_Yes | DeviceProtection_No.internet.service | DeviceProtection_Yes | TechSupport_No.internet.service | TechSupport_Yes | StreamingTV_No.internet.service | StreamingTV_Yes | StreamingMovies_No.internet.service | StreamingMovies_Yes | Contract_One.year | Contract_Two.year | PaperlessBilling_Yes | PaymentMethod_Credit.card..automatic. | PaymentMethod_Electronic.check | PaymentMethod_Mailed.check |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
-0.4409595 | 0.0613097 | -0.2635753 | -0.1776291 | No | 0.9956547 | -0.9656622 | -0.6538199 | 0.3217944 | -0.3217944 | -0.8552659 | -0.8925964 | -0.5258861 | -0.5258861 | 1.5812934 | -0.5258861 | -0.7259826 | -0.5258861 | 1.393601 | -0.5258861 | -0.6355872 | -0.5258861 | -0.7881721 | -0.5258861 | -0.7967917 | 1.9315911 | -0.5593716 | -1.2055272 | -0.5291892 | -0.7075153 | 1.8231223 |
-0.4409595 | -1.2423051 | -0.3664940 | -0.9592242 | Yes | 0.9956547 | -0.9656622 | -0.6538199 | 0.3217944 | -0.3217944 | -0.8552659 | -0.8925964 | -0.5258861 | -0.5258861 | 1.5812934 | -0.5258861 | 1.3771987 | -0.5258861 | -0.717438 | -0.5258861 | -0.6355872 | -0.5258861 | -0.7881721 | -0.5258861 | -0.7967917 | -0.5176159 | -0.5593716 | 0.8293651 | -0.5291892 | -0.7075153 | 1.8231223 |
-0.4409595 | 0.5094273 | -0.7499488 | -0.1990189 | No | 0.9956547 | -0.9656622 | -0.6538199 | -3.1070222 | 3.1070222 | -0.8552659 | -0.8925964 | -0.5258861 | -0.5258861 | 1.5812934 | -0.5258861 | -0.7259826 | -0.5258861 | 1.393601 | -0.5258861 | 1.5730685 | -0.5258861 | -0.7881721 | -0.5258861 | -0.7967917 | 1.9315911 | -0.5593716 | -1.2055272 | -0.5291892 | -0.7075153 | -0.5484121 |
-0.4409595 | -1.2423051 | 0.1929186 | -0.9401379 | Yes | -1.0041858 | -0.9656622 | -0.6538199 | 0.3217944 | -0.3217944 | -0.8552659 | 1.1201280 | -0.5258861 | -0.5258861 | -0.6322813 | -0.5258861 | -0.7259826 | -0.5258861 | -0.717438 | -0.5258861 | -0.6355872 | -0.5258861 | -0.7881721 | -0.5258861 | -0.7967917 | -0.5176159 | -0.5593716 | 0.8293651 | -0.5291892 | 1.4131458 | -0.5484121 |
-0.4409595 | -0.9978774 | 1.1540457 | -0.6466695 | Yes | -1.0041858 | -0.9656622 | -0.6538199 | 0.3217944 | -0.3217944 | 1.1690193 | 1.1201280 | -0.5258861 | -0.5258861 | -0.6322813 | -0.5258861 | -0.7259826 | -0.5258861 | 1.393601 | -0.5258861 | -0.6355872 | -0.5258861 | 1.2685330 | -0.5258861 | 1.2548101 | -0.5176159 | -0.5593716 | 0.8293651 | -0.5291892 | 1.4131458 | -0.5484121 |
-0.4409595 | -0.4275459 | 0.8037904 | -0.1513471 | No | 0.9956547 | -0.9656622 | 1.5292013 | 0.3217944 | -0.3217944 | 1.1690193 | 1.1201280 | -0.5258861 | -0.5258861 | -0.6322813 | -0.5258861 | 1.3771987 | -0.5258861 | -0.717438 | -0.5258861 | -0.6355872 | -0.5258861 | 1.2685330 | -0.5258861 | -0.7967917 | -0.5176159 | -0.5593716 | 0.8293651 | 1.8893474 | -0.7075153 | -0.5484121 |
Ajustando modelos: parsnip
O pacote parsnip
realiza o trabalho de unificar uma série de diferentes modelos estatísticos e de machine learning em um único local. O pacote é extremamente conveniente porque temos uma única forma de se comunicar com diferentes modelos que inicialmente possuiam sintaxes totalmente diferentes ou exigiam dados em diferentes formatos (as.matrix
, ts
, data.frame
).
Para utilizar o parsnip
, sempre começamos definindo o modelo com uma função específica do algorítimo. Assim, para estimar uma regressão linear utilizamos a função linear_reg()
e para estimar um random forest utilizamos a função rand_forest()
.
Definido o algorítimo, devemos informar a implementação do algorítimo, que é definida pela função set_engine()
. Assim, ao rodar um modelo de Random Forest podemos utilizar a implementação do pacote randomForest
, com set_engine("randomForest")
, a implementação do pacote ranger
com set_engine("ranger")
ou do spark com set_engine("spark")
. Em cada um dos pacotes, os hiperparâmetros são nomeados de maneira diferente, como n.trees
no ranger e trees
no randomForest. O parsnip
unifica essas sintaxes.
Para nossos dados, podemos estimar um modelo de Random Forest. Nenhuma razão específica nesta escolha, a não ser que ele é um modelo bastante popular para o problema de classificação que temos.
Assim, para ajustar o modelo de random forest aos dados, definimos o modelo com a função rand_forest()
especificando os hiperparâmetros do modelo: mtry
e trees
. Em um primeiro momento, vamos definir estes hiperparâmetros de maneira fixa.
<-
rf_model rand_forest(mtry = 3, trees = 200) %>%
set_mode("classification") %>%
set_engine("ranger")
rf_model
## Random Forest Model Specification (classification)
##
## Main Arguments:
## mtry = 3
## trees = 200
##
## Computational engine: ranger
Combinando tudo: workflows
Com o pacote workflows
temos a oportunidade de unir todos os passos da construção de um modelo: podemos adicionar receitas, adicionar um modelo e por fim estimar o modelo definido sobre os dados tratados pela receita. Assim, iniciamos o workflow com workflow()
, adicionamos o modelo definido acima com add_model
e a receita que deve ser aplicada aos dados com add_recipe
. O último passo é o fit
, onde passamos a base de treinamento construida pelo pacote rsample
para que o random forest seja estimado.
<- workflow() %>%
rf_wflow add_model(rf_model) %>%
add_recipe(receita_simples) %>%
fit(training(train_test_split))
rf_wflow
## == Workflow [trained] ==========================================================
## Preprocessor: Recipe
## Model: rand_forest()
##
## -- Preprocessor ----------------------------------------------------------------
## 2 Recipe Steps
##
## * step_dummy()
## * step_normalize()
##
## -- Model -----------------------------------------------------------------------
## Ranger result
##
## Call:
## ranger::ranger(x = maybe_data_frame(x), y = y, mtry = min_cols(~3, x), num.trees = ~200, num.threads = 1, verbose = FALSE, seed = sample.int(10^5, 1), probability = TRUE)
##
## Type: Probability estimation
## Number of trees: 200
## Sample size: 5626
## Number of independent variables: 30
## Mtry: 3
## Target node size: 10
## Variable importance mode: none
## Splitrule: gini
## OOB prediction error (Brier s.): 0.1364033
Assim, é possível observar que o modelo foi ajustado para as 5626 observações do conjunto de treinamento. O número de variáveis independentes é igual a 30, como resultado da aplicação da receita sobre as variáveis do banco de dados original. Em Call observamos que nossos dados foram aplicados na função ranger::ranger()
, como definido no set_engine()
.
Avaliando a Performance do Modelo: yardstick
Podemos calibrar o modelo no conjunto de teste e construir medidas de performance do modelo a partir da comparação dos valores previstos e dos valores verdadeiros. Para tanto utilizamos o pacote yardstick
e sua função conf_mat()
para a construção da matriz de confusão. Para problemas de regressão é possível utilizar a função metrics()
, que retorna medidas como RMSE, MAE e R-Quadrado.
Portando, utilizamos a função predict
para produzir o vetor de valores previstos e unimos com os dados observamos na base de teste. A função conf_mat
precisa dos valores observados e previstos.
<- rf_wflow %>%
matriz_confusao predict(new_data = testing(train_test_split)) %>%
bind_cols(testing(train_test_split) %>% select(Churn)) %>%
mutate_all(as.factor) %>%
conf_mat(Churn, .pred_class)
%>%
matriz_confusao pluck(1) %>%
as_tibble() %>%
ggplot(aes(Prediction, Truth, alpha = n)) +
geom_tile(show.legend = FALSE) +
geom_text(aes(label = n), color = 'white', alpha = 1, size = 8)
A matriz de confusão mostra que o modelo erroneamente previu como negativo 187 clientes que iriam sair da empresa e previu incorretamente que 82 clientes iriam sair da empresa quando na verdade eles permaneceram como clientes.
A partir da matriz de confusão podemos estimar algumas medidas de performance, como acurácia, precisão e recall.
%>%
matriz_confusao summary() %>%
select(-.estimator) %>%
filter(.metric %in% c('precision', 'recall', 'f_meas',
'accuracy', 'spec', 'sens')) %>%
kable()
.metric | .estimate |
---|---|
accuracy | 0.7880512 |
sens | 0.9053537 |
spec | 0.4472222 |
precision | 0.8263525 |
recall | 0.9053537 |
f_meas | 0.8640511 |
A acurácia do modelo é a fração de previsões que o modelo fez corretamente. Contudo, acurácia não é uma medida muito confiável como ela pode produzir resultados incorretos se a base de dados for muito desbalanceada. No caso, nosso modelo random Forest produziu uma acurácia de 78%. Contudo, 73% das observações no banco de dados são negativas.
Já a precisão mostra o quanto o modelo é sensível a falsos positivos (FP), como prever que um cliente está abandonando a empresa quando na verdade ele vai permanecer. O recall procura mostrar o quão sensível o modelo é a falsos negativos (FN), como prever que o cliente irá permanecer quando na verdade ele está saindo. Estas duas medidas são mais relevantes no contexto de uma empresa, dado que a firma está interessada em prever de maneira precisa quais clientes realmente estão em risco de abandonar a empresa. A previsão correta garante que a empresa pode conduzir estratégias de retenção com estes clientes, ao mesmo tempo que minimiza a possibilidade de aplicar esforços de retenção sob clientes falsos positivos.
Outra medida popular é o F1 Score, que é uma média harmônica da precisão e do recall. Um F1 score obtém seu melhor valor em 1 quando se tem um recall e precisão perfeitos. O nosso modelo obteve um F1 de 0,86. Por fim, temos a medida de UAC-ROC, que vamos utilizar na próxima seção para escolher o melhor modelo.
Tuning e Cross-validation: tune
e dials
Em vez de utilizar um valor padrão para os hiperparâmetros do modelo, podemos ajustar diferentes valores utilizando a função tune()
do pacote tune
. Assim, deixamos que o cross-validation decida o melhor valor dos hiperparâmetros do nosso modelo.
Assim, vamos utilizar a cross-validation para encontrar o melhor valor de hiperparâmetro. Podemos utilizar otimização bayesiana para problemas complexos de tuning com o uso da função tune_bayes()
. Contudo, para problemas mais simples como este, um grid search é suficiente. Assim, vamos construir uma combinação de valores dos hiperparâmetros, de modo que em cada fold do cross-validation serão estimados múltiplos modelos com diferentes hiperparâmetros.
Antes disto, precisamos construir nossa base de cross-validation. O pacote rsample
permite construir um objetivo de cross-validation com n
folds. No nosso exemplo, vamos criar 4 folds diferentes. Em cada fold temos uma base de análise (equivalente ao treinamento) e de assessment (equivalente ao teste). Por fim, note que em todos os folds a proporção de churns é bem parecida.
<- vfold_cv(df, v = 4)
cv_folds
map_dbl(cv_folds$splits, ~ mean(as.data.frame(.x)$Churn == "Yes"))
## [1] 0.2679181 0.2654532 0.2603337 0.2694350
Agora precisamos definir novamente nosso modelo, mas agora sem valores fixos para mtry
e trees
. Ao invés disto, vamos sinalizar com o uso de tune()
que diferentes valores dos hiperparâmetros devem ser utilizados. Para tanto, podemos atualizar o nosso modelo anterior substituindo mtry = 3
e trees = 200
por tune()
.
<-
rf_model_tune rand_forest(mtry = tune(), trees = tune()) %>%
set_mode("classification") %>%
set_engine("ranger")
rf_model_tune
## Random Forest Model Specification (classification)
##
## Main Arguments:
## mtry = tune()
## trees = tune()
##
## Computational engine: ranger
Vamos criar um workflow diferente, agora com um modelo que aceita hiperparâmetros não fixos
<- workflow() %>%
rf_wflow_tune add_model(rf_model_tune) %>%
add_recipe(receita_simples)
E atualizar nosso modelo com os valores máximos e mínimos dos hiperparâmetros. utilizando as funções parameters
e update
do pacote dials
. Para mtry
, a documentação do randomForest informa que este parâmetro assume valores entre 1 e o número de colunas. Assim, vamos definir seus valores mínimos e maximos como 1 e 30. trees
é um hiperparâmetro quantitativo sem relação com a base de dados, seus valores máximos e mínimos já são definidos por padrão como 1 e 2000.
<-
rf_param %>%
rf_wflow_tune parameters() %>%
update(mtry = mtry(range = c(1L, 30L)),
trees = trees())
$object rf_param
## [[1]]
## # Randomly Selected Predictors (quantitative)
## Range: [1, 30]
##
## [[2]]
## # Trees (quantitative)
## Range: [1, 2000]
Como definimos o valor mínimo e máximo para mtry
como 1L e 30L e para trees
como 1L e 2000L, a função grid_regular()
constroi uma combinação de valores possíveis para os dois hiperparâmetros. Estes valores poderiam ser definidos manualmente com o uso de uma função como expand.grid
, mas grid_regular
torna esta tarefa mais conveniente. Assim, passamos o argumento levels = 3
que indica que serão utilizados 3 valores igualmente esparçados dos dois hiperparâmetros. Contudo, se a intenção for produzir valores aleatórios do grid, podemos utilizar random_grid()
.
<- grid_regular(rf_param, levels = 3)
rf_grid rf_grid
## # A tibble: 9 x 2
## mtry trees
## <int> <int>
## 1 1 1
## 2 15 1
## 3 30 1
## 4 1 1000
## 5 15 1000
## 6 30 1000
## 7 1 2000
## 8 15 2000
## 9 30 2000
A realização de cross-validation e do tuning de hiperparâmetros aumenta enormente o número de modelos que serão ajustados. Aqui temos 4 folds e 9 valores possíveis de hiperparâmetros, totalizando 36 modelos. Para aumentar a velocidade com que o processo de estimação ocorre, podemos estimar os modelos em paralelo com o uso do pacote doFuture
. Para tanto, será necessário a instalações de alguns pacotes adicionais como Rmpi
e snow
, que dependem da instalação do Microsoft MPI que pode ser encontrado aqui. Caso ocorra dificuldades na instalação dos pacotes acima, é possível ignorar os códigos abaixo. Neste caso, os modelos podem ser ajustados serialmente, com o único custo sendo um tempo maior de processamento.
<- parallel::detectCores(logical = FALSE) - 1
all_cores registerDoFuture()
<- makeClusterPSOCK(all_cores)
cl plan(future::cluster, workers = cl)
Finalmente podemos iniciar o trabalho de tuning dos hiperparâmetros utilizando tune_grid()
, que precisa do novo workflow (rf_wflow_tune
), dos valores possíveis dos hiperparâmetros (rf_grid
) e do objeto indicando quais são os folds do cross-validation (cv_folds
).
<- tune_grid(rf_wflow_tune, grid = rf_grid, resamples = cv_folds,
rf_search param_info = rf_param)
rf_search
## # Tuning results
## # 4-fold cross-validation
## # A tibble: 4 x 4
## splits id .metrics .notes
## <list> <chr> <list> <list>
## 1 <split [5274/1758]> Fold1 <tibble [18 x 6]> <tibble [0 x 1]>
## 2 <split [5274/1758]> Fold2 <tibble [18 x 6]> <tibble [0 x 1]>
## 3 <split [5274/1758]> Fold3 <tibble [18 x 6]> <tibble [0 x 1]>
## 4 <split [5274/1758]> Fold4 <tibble [18 x 6]> <tibble [0 x 1]>
Assim, alguns minutos depois, os 36 modelos foram estimados. Podemos calcular uma média de cada medida de performance para as 9 combinações diferentes de hiperparâmetros e listar os modelos com melhores performance segundo o ROC-UAC com a função show_best()
.
show_best(rf_search, n = 5, metric = 'roc_auc') %>% kable()
mtry | trees | .metric | .estimator | mean | n | std_err | .config |
---|---|---|---|---|---|---|---|
15 | 2000 | roc_auc | binary | 0.8311273 | 4 | 0.0070299 | Preprocessor1_Model8 |
15 | 1000 | roc_auc | binary | 0.8307947 | 4 | 0.0069385 | Preprocessor1_Model5 |
1 | 2000 | roc_auc | binary | 0.8294619 | 4 | 0.0066992 | Preprocessor1_Model7 |
1 | 1000 | roc_auc | binary | 0.8291777 | 4 | 0.0072042 | Preprocessor1_Model4 |
30 | 2000 | roc_auc | binary | 0.8247092 | 4 | 0.0076831 | Preprocessor1_Model9 |
Assim, o modelo com melhor ROC-AUC foi o Random Forest com mtry = 15
e trees = 2000
. Vamos utiliza-lo para realizar nossa previsão final.
Previsão Final
Agora que sabemos qual o melhor conjunto de hiperparâmetros, podemos ajustar o modelo em toda a base de treinamento. Usamos a função select_best
para retornar o melhor modelo, como definido acima.
<- select_best(rf_search, degree, metric = 'roc_auc')
rf_param_final %>% kable() rf_param_final
mtry | trees | .config |
---|---|---|
15 | 2000 | Preprocessor1_Model8 |
Usamos finalize_workflow()
para criar um workflow atualizado com os valores de hiperparâmetros do melhor modelo. Isso equivale a construir um novo workflow com mtry = 15
e trees = 2000
. Por fim, podemos ajustar o modelo utilizando toda a base de treinamento, os folds do cross-validation já realizaram seu trabalho.
<- rf_wflow_tune %>%
rf_wflow_final_fit finalize_workflow(rf_param_final) %>%
fit(training(train_test_split))
rf_wflow_final_fit
## == Workflow [trained] ==========================================================
## Preprocessor: Recipe
## Model: rand_forest()
##
## -- Preprocessor ----------------------------------------------------------------
## 2 Recipe Steps
##
## * step_dummy()
## * step_normalize()
##
## -- Model -----------------------------------------------------------------------
## Ranger result
##
## Call:
## ranger::ranger(x = maybe_data_frame(x), y = y, mtry = min_cols(~15L, x), num.trees = ~2000L, num.threads = 1, verbose = FALSE, seed = sample.int(10^5, 1), probability = TRUE)
##
## Type: Probability estimation
## Number of trees: 2000
## Sample size: 5626
## Number of independent variables: 30
## Mtry: 15
## Target node size: 10
## Variable importance mode: none
## Splitrule: gini
## OOB prediction error (Brier s.): 0.141572
E realizar a previsão do modelo, calculando todas as medidas de performance.
%>% predict(new_data = testing(train_test_split)) %>%
rf_wflow_final_fit bind_cols(testing(train_test_split) %>% select(Churn)) %>%
mutate_all(as.factor) %>%
conf_mat(Churn, .pred_class) %>%
summary() %>%
select(-.estimator) %>%
filter(.metric %in% c('precision', 'recall', 'f_meas',
'accuracy', 'spec', 'sens')) %>%
kable()
.metric | .estimate |
---|---|
accuracy | 0.7887624 |
sens | 0.8862333 |
spec | 0.5055556 |
precision | 0.8389140 |
recall | 0.8862333 |
f_meas | 0.8619247 |
É possível observar que o modelo de random forest com parâmetros ajustados não excedeu em muito o que já havia sido obtido no nosso primeiro modelo. Existe muito espaço para melhoramento, sobretudo em relação ao processo de feature engineering aplicado com o uso de recipes
.
Conclusão
O tidymodels
e seus pacotes tornam muito simples a tarefa de aplicar modelos de machine learning para bases de dados. Desde o processamento de bases de dados com recipes
, a produção de reamostragens com o rsample
, a estimação de modelos de machine learning com o parsnip
, a procurar hiperparâmetros com tune
e a criação workflows com workflow
.
Para além disto, o tidymodels
oferece soluções para compararação de diferentes modelos utilizando o pacote workflowsets
, tornando simples a criação e o ajustar de um conjunto grande de modelos, sendo necessário apenas uma lista de modelos, que compartilham a mesma receita e a mesma base de treinamento. O tidymodels
também permite a criação de assembly, seja por stacks com o uso do pacote stacks
ou por bagging, com o pacote baguette
, de modo a construir classificadores e previsores melhores a partir da união de diferentes algorítimos.
Em um texto futuro, podemos analisar estas ferramentas.