Classificação de Churn com Tidymodels
Vamos tentar resolver um problema simples de customer churn utilizando uma série de pacotes no R que fazem parte de um universo de pequenos pacotes do tidymodels. O principal deles é o parsnip, que unifica uma série de algorítimos diferentes de machine learning, que vão de modelos simples como GLM, Regressão Linear e regression tree até modelos de machine learning mais complexos como Random Forest, XGBoost, lightgbm e outros. Este pacote unifica as sintaxes de várias implementações diferentes, e serve como um sucessor para o caret.
Outro pacote interessante é o recipes, que cria pequenas receitas de bolo, com instruções de pré-processamento dos dados, incluindo criação de dummies, normalização de variáveis, remoção de variáveis com variância zero e muitas outras. O pacote rsample também faz parte deste universo e tem o objetivo de realizar reamostragens dos dados, como o split entre training set e test set e a criação de cross-validation.
Já o pacote yardstick facilita a criação de medidas de performance dos modelos que serão estimados, produzindo as principais medidas de desempenho, seja para problemas de regressão (RMSE, R-Quadrado e outros), quanto problemas de classificação (matriz de confusão, precisão, acurácia e outros). Por fim, temos os pacotes tune e dials que facilitam a busca de hiperparâmetros ideais dos modelos e o pacote workflow, que é a cola que une todas estas diferentes funções.
O Banco de Dados
Para este projeto vou utilizar uma base de customer churn da Telco. Ela contém 7043 linhas, cada uma representando um cliente, e 20 colunas que representam potenciais preditores, oferecendo informações que podem nos ajudar a prever o comportamento dos clientes e a desenvolver programas focados em retenção de clientes.
library(tidyverse)
library(tidymodels)
library(skimr)
library(knitr)
library(doFuture)
df <- read_csv("https://raw.githubusercontent.com/DiegoUsaiUK/Classification_Churn_with_Parsnip/master/00_Data/WA_Fn-UseC_-Telco-Customer-Churn.csv") %>%
select(-customerID) %>%
drop_na()
df %>% head() %>% kable()| gender | SeniorCitizen | Partner | Dependents | tenure | PhoneService | MultipleLines | InternetService | OnlineSecurity | OnlineBackup | DeviceProtection | TechSupport | StreamingTV | StreamingMovies | Contract | PaperlessBilling | PaymentMethod | MonthlyCharges | TotalCharges | Churn |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Female | 0 | Yes | No | 1 | No | No phone service | DSL | No | Yes | No | No | No | No | Month-to-month | Yes | Electronic check | 29.85 | 29.85 | No |
| Male | 0 | No | No | 34 | Yes | No | DSL | Yes | No | Yes | No | No | No | One year | No | Mailed check | 56.95 | 1889.50 | No |
| Male | 0 | No | No | 2 | Yes | No | DSL | Yes | Yes | No | No | No | No | Month-to-month | Yes | Mailed check | 53.85 | 108.15 | Yes |
| Male | 0 | No | No | 45 | No | No phone service | DSL | Yes | No | Yes | Yes | No | No | One year | No | Bank transfer (automatic) | 42.30 | 1840.75 | No |
| Female | 0 | No | No | 2 | Yes | No | Fiber optic | No | No | No | No | No | No | Month-to-month | Yes | Electronic check | 70.70 | 151.65 | Yes |
| Female | 0 | No | No | 8 | Yes | Yes | Fiber optic | No | No | Yes | No | Yes | Yes | Month-to-month | Yes | Electronic check | 99.65 | 820.50 | Yes |
df %>% skim()| Name | Piped data |
| Number of rows | 7032 |
| Number of columns | 20 |
| _______________________ | |
| Column type frequency: | |
| character | 16 |
| numeric | 4 |
| ________________________ | |
| Group variables | None |
Variable type: character
| skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
|---|---|---|---|---|---|---|---|
| gender | 0 | 1 | 4 | 6 | 0 | 2 | 0 |
| Partner | 0 | 1 | 2 | 3 | 0 | 2 | 0 |
| Dependents | 0 | 1 | 2 | 3 | 0 | 2 | 0 |
| PhoneService | 0 | 1 | 2 | 3 | 0 | 2 | 0 |
| MultipleLines | 0 | 1 | 2 | 16 | 0 | 3 | 0 |
| InternetService | 0 | 1 | 2 | 11 | 0 | 3 | 0 |
| OnlineSecurity | 0 | 1 | 2 | 19 | 0 | 3 | 0 |
| OnlineBackup | 0 | 1 | 2 | 19 | 0 | 3 | 0 |
| DeviceProtection | 0 | 1 | 2 | 19 | 0 | 3 | 0 |
| TechSupport | 0 | 1 | 2 | 19 | 0 | 3 | 0 |
| StreamingTV | 0 | 1 | 2 | 19 | 0 | 3 | 0 |
| StreamingMovies | 0 | 1 | 2 | 19 | 0 | 3 | 0 |
| Contract | 0 | 1 | 8 | 14 | 0 | 3 | 0 |
| PaperlessBilling | 0 | 1 | 2 | 3 | 0 | 2 | 0 |
| PaymentMethod | 0 | 1 | 12 | 25 | 0 | 4 | 0 |
| Churn | 0 | 1 | 2 | 3 | 0 | 2 | 0 |
Variable type: numeric
| skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
|---|---|---|---|---|---|---|---|---|---|---|
| SeniorCitizen | 0 | 1 | 0.16 | 0.37 | 0.00 | 0.00 | 0.00 | 0.00 | 1.00 | ▇▁▁▁▂ |
| tenure | 0 | 1 | 32.42 | 24.55 | 1.00 | 9.00 | 29.00 | 55.00 | 72.00 | ▇▃▃▃▅ |
| MonthlyCharges | 0 | 1 | 64.80 | 30.09 | 18.25 | 35.59 | 70.35 | 89.86 | 118.75 | ▇▅▆▇▅ |
| TotalCharges | 0 | 1 | 2283.30 | 2266.77 | 18.80 | 401.45 | 1397.47 | 3794.74 | 8684.80 | ▇▂▂▂▁ |
O banco de dados inclui 20 variáveis, sendo 4 numéricas e 16 categóricas. Churn é a nossa variável de interesse, e indica se o cliente abandonou ou não a empresa. Como a ideia é apresentar os procedimentos do tidymodels, não irei passar muito tempo analisando o banco de dados.
Definindo o Conjunto de Treinamento e Teste: rsample
O pacote rsample fornece uma série de funções que tornam simples o processo de gerar reamostragens aleatórias como os bancos de dados de treinamento e teste. Vamos usar a função initial_split() para gerar uma base de treinamento com 80% das informações do banco de dados original.
set.seed(123)
train_test_split <- initial_split(data = df, prop = 0.8)
train_tbl <- train_test_split %>% training()
test_tbl <- train_test_split %>% testing()
train_test_split## <Analysis/Assess/Total>
## <5626/1406/7032>
Do total de 7042 clientes, 5626 foram atribuídos ao conjunto de treinamento e 1406 para o conjunto de teste.
Pré-processando dados: recipes
O pacote recipes utiliza uma série de funções para lidar com o pré-processamento dos dados, como a imputação de valores missing, padronização de variáveis, criação de dummies e mais. O primeiro passo é informar a formula utilizada na receita, no nosso caso Churn como nossa variável dependente e as demais como preditores. Esta formula tem o objetivo de permitir o acesso a uma série de funções úteis e convenientes como all_predictors(), que permite aplicar transformações sobre todos os preditores, ou all_outcomes(), que permite transformações sobre todos os outcomes.
Muitas outras funções estão incluidas no pacote recipe, como ímputação por média step_meanimpute e por regressão linear step_impute_linear(), as transformações de variável como log (step_log) e quadrática (step_sqrt), a discretizaçã de variáveis numéricas (step_cut e outras), gerador de variáveis de data com step_date (que constroi variáveis de dia, mês, ano) e muitas outras. Uma lista de todas as transformações possíveis pode ser obtida aqui.
receita_simples <- recipe(Churn ~ ., data = train_tbl) %>%
step_dummy(all_nominal(), -all_outcomes()) %>%
step_normalize(all_numeric())Para exibir como o banco de dados ficou após as transformações, Chamamos prep(), que aplica a receita ao banco de dados e utilizar juice() para extrair o banco de dados transformado.
juice(prep(receita_simples)) %>% head() %>% kable()| SeniorCitizen | tenure | MonthlyCharges | TotalCharges | Churn | gender_Male | Partner_Yes | Dependents_Yes | PhoneService_Yes | MultipleLines_No.phone.service | MultipleLines_Yes | InternetService_Fiber.optic | InternetService_No | OnlineSecurity_No.internet.service | OnlineSecurity_Yes | OnlineBackup_No.internet.service | OnlineBackup_Yes | DeviceProtection_No.internet.service | DeviceProtection_Yes | TechSupport_No.internet.service | TechSupport_Yes | StreamingTV_No.internet.service | StreamingTV_Yes | StreamingMovies_No.internet.service | StreamingMovies_Yes | Contract_One.year | Contract_Two.year | PaperlessBilling_Yes | PaymentMethod_Credit.card..automatic. | PaymentMethod_Electronic.check | PaymentMethod_Mailed.check |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| -0.4409595 | 0.0613097 | -0.2635753 | -0.1776291 | No | 0.9956547 | -0.9656622 | -0.6538199 | 0.3217944 | -0.3217944 | -0.8552659 | -0.8925964 | -0.5258861 | -0.5258861 | 1.5812934 | -0.5258861 | -0.7259826 | -0.5258861 | 1.393601 | -0.5258861 | -0.6355872 | -0.5258861 | -0.7881721 | -0.5258861 | -0.7967917 | 1.9315911 | -0.5593716 | -1.2055272 | -0.5291892 | -0.7075153 | 1.8231223 |
| -0.4409595 | -1.2423051 | -0.3664940 | -0.9592242 | Yes | 0.9956547 | -0.9656622 | -0.6538199 | 0.3217944 | -0.3217944 | -0.8552659 | -0.8925964 | -0.5258861 | -0.5258861 | 1.5812934 | -0.5258861 | 1.3771987 | -0.5258861 | -0.717438 | -0.5258861 | -0.6355872 | -0.5258861 | -0.7881721 | -0.5258861 | -0.7967917 | -0.5176159 | -0.5593716 | 0.8293651 | -0.5291892 | -0.7075153 | 1.8231223 |
| -0.4409595 | 0.5094273 | -0.7499488 | -0.1990189 | No | 0.9956547 | -0.9656622 | -0.6538199 | -3.1070222 | 3.1070222 | -0.8552659 | -0.8925964 | -0.5258861 | -0.5258861 | 1.5812934 | -0.5258861 | -0.7259826 | -0.5258861 | 1.393601 | -0.5258861 | 1.5730685 | -0.5258861 | -0.7881721 | -0.5258861 | -0.7967917 | 1.9315911 | -0.5593716 | -1.2055272 | -0.5291892 | -0.7075153 | -0.5484121 |
| -0.4409595 | -1.2423051 | 0.1929186 | -0.9401379 | Yes | -1.0041858 | -0.9656622 | -0.6538199 | 0.3217944 | -0.3217944 | -0.8552659 | 1.1201280 | -0.5258861 | -0.5258861 | -0.6322813 | -0.5258861 | -0.7259826 | -0.5258861 | -0.717438 | -0.5258861 | -0.6355872 | -0.5258861 | -0.7881721 | -0.5258861 | -0.7967917 | -0.5176159 | -0.5593716 | 0.8293651 | -0.5291892 | 1.4131458 | -0.5484121 |
| -0.4409595 | -0.9978774 | 1.1540457 | -0.6466695 | Yes | -1.0041858 | -0.9656622 | -0.6538199 | 0.3217944 | -0.3217944 | 1.1690193 | 1.1201280 | -0.5258861 | -0.5258861 | -0.6322813 | -0.5258861 | -0.7259826 | -0.5258861 | 1.393601 | -0.5258861 | -0.6355872 | -0.5258861 | 1.2685330 | -0.5258861 | 1.2548101 | -0.5176159 | -0.5593716 | 0.8293651 | -0.5291892 | 1.4131458 | -0.5484121 |
| -0.4409595 | -0.4275459 | 0.8037904 | -0.1513471 | No | 0.9956547 | -0.9656622 | 1.5292013 | 0.3217944 | -0.3217944 | 1.1690193 | 1.1201280 | -0.5258861 | -0.5258861 | -0.6322813 | -0.5258861 | 1.3771987 | -0.5258861 | -0.717438 | -0.5258861 | -0.6355872 | -0.5258861 | 1.2685330 | -0.5258861 | -0.7967917 | -0.5176159 | -0.5593716 | 0.8293651 | 1.8893474 | -0.7075153 | -0.5484121 |
Ajustando modelos: parsnip
O pacote parsnip realiza o trabalho de unificar uma série de diferentes modelos estatísticos e de machine learning em um único local. O pacote é extremamente conveniente porque temos uma única forma de se comunicar com diferentes modelos que inicialmente possuiam sintaxes totalmente diferentes ou exigiam dados em diferentes formatos (as.matrix, ts, data.frame).
Para utilizar o parsnip, sempre começamos definindo o modelo com uma função específica do algorítimo. Assim, para estimar uma regressão linear utilizamos a função linear_reg() e para estimar um random forest utilizamos a função rand_forest().
Definido o algorítimo, devemos informar a implementação do algorítimo, que é definida pela função set_engine(). Assim, ao rodar um modelo de Random Forest podemos utilizar a implementação do pacote randomForest, com set_engine("randomForest"), a implementação do pacote ranger com set_engine("ranger") ou do spark com set_engine("spark"). Em cada um dos pacotes, os hiperparâmetros são nomeados de maneira diferente, como n.trees no ranger e trees no randomForest. O parsnip unifica essas sintaxes.
Para nossos dados, podemos estimar um modelo de Random Forest. Nenhuma razão específica nesta escolha, a não ser que ele é um modelo bastante popular para o problema de classificação que temos.
Assim, para ajustar o modelo de random forest aos dados, definimos o modelo com a função rand_forest() especificando os hiperparâmetros do modelo: mtry e trees. Em um primeiro momento, vamos definir estes hiperparâmetros de maneira fixa.
rf_model <-
rand_forest(mtry = 3, trees = 200) %>%
set_mode("classification") %>%
set_engine("ranger")
rf_model## Random Forest Model Specification (classification)
##
## Main Arguments:
## mtry = 3
## trees = 200
##
## Computational engine: ranger
Combinando tudo: workflows
Com o pacote workflows temos a oportunidade de unir todos os passos da construção de um modelo: podemos adicionar receitas, adicionar um modelo e por fim estimar o modelo definido sobre os dados tratados pela receita. Assim, iniciamos o workflow com workflow(), adicionamos o modelo definido acima com add_model e a receita que deve ser aplicada aos dados com add_recipe. O último passo é o fit, onde passamos a base de treinamento construida pelo pacote rsample para que o random forest seja estimado.
rf_wflow <- workflow() %>%
add_model(rf_model) %>%
add_recipe(receita_simples) %>%
fit(training(train_test_split))
rf_wflow## == Workflow [trained] ==========================================================
## Preprocessor: Recipe
## Model: rand_forest()
##
## -- Preprocessor ----------------------------------------------------------------
## 2 Recipe Steps
##
## * step_dummy()
## * step_normalize()
##
## -- Model -----------------------------------------------------------------------
## Ranger result
##
## Call:
## ranger::ranger(x = maybe_data_frame(x), y = y, mtry = min_cols(~3, x), num.trees = ~200, num.threads = 1, verbose = FALSE, seed = sample.int(10^5, 1), probability = TRUE)
##
## Type: Probability estimation
## Number of trees: 200
## Sample size: 5626
## Number of independent variables: 30
## Mtry: 3
## Target node size: 10
## Variable importance mode: none
## Splitrule: gini
## OOB prediction error (Brier s.): 0.1364033
Assim, é possível observar que o modelo foi ajustado para as 5626 observações do conjunto de treinamento. O número de variáveis independentes é igual a 30, como resultado da aplicação da receita sobre as variáveis do banco de dados original. Em Call observamos que nossos dados foram aplicados na função ranger::ranger(), como definido no set_engine().
Avaliando a Performance do Modelo: yardstick
Podemos calibrar o modelo no conjunto de teste e construir medidas de performance do modelo a partir da comparação dos valores previstos e dos valores verdadeiros. Para tanto utilizamos o pacote yardstick e sua função conf_mat() para a construção da matriz de confusão. Para problemas de regressão é possível utilizar a função metrics(), que retorna medidas como RMSE, MAE e R-Quadrado.
Portando, utilizamos a função predict para produzir o vetor de valores previstos e unimos com os dados observamos na base de teste. A função conf_mat precisa dos valores observados e previstos.
matriz_confusao <- rf_wflow %>%
predict(new_data = testing(train_test_split)) %>%
bind_cols(testing(train_test_split) %>% select(Churn)) %>%
mutate_all(as.factor) %>%
conf_mat(Churn, .pred_class)
matriz_confusao %>%
pluck(1) %>%
as_tibble() %>%
ggplot(aes(Prediction, Truth, alpha = n)) +
geom_tile(show.legend = FALSE) +
geom_text(aes(label = n), color = 'white', alpha = 1, size = 8)A matriz de confusão mostra que o modelo erroneamente previu como negativo 187 clientes que iriam sair da empresa e previu incorretamente que 82 clientes iriam sair da empresa quando na verdade eles permaneceram como clientes.
A partir da matriz de confusão podemos estimar algumas medidas de performance, como acurácia, precisão e recall.
matriz_confusao %>%
summary() %>%
select(-.estimator) %>%
filter(.metric %in% c('precision', 'recall', 'f_meas',
'accuracy', 'spec', 'sens')) %>%
kable()| .metric | .estimate |
|---|---|
| accuracy | 0.7880512 |
| sens | 0.9053537 |
| spec | 0.4472222 |
| precision | 0.8263525 |
| recall | 0.9053537 |
| f_meas | 0.8640511 |
A acurácia do modelo é a fração de previsões que o modelo fez corretamente. Contudo, acurácia não é uma medida muito confiável como ela pode produzir resultados incorretos se a base de dados for muito desbalanceada. No caso, nosso modelo random Forest produziu uma acurácia de 78%. Contudo, 73% das observações no banco de dados são negativas.
Já a precisão mostra o quanto o modelo é sensível a falsos positivos (FP), como prever que um cliente está abandonando a empresa quando na verdade ele vai permanecer. O recall procura mostrar o quão sensível o modelo é a falsos negativos (FN), como prever que o cliente irá permanecer quando na verdade ele está saindo. Estas duas medidas são mais relevantes no contexto de uma empresa, dado que a firma está interessada em prever de maneira precisa quais clientes realmente estão em risco de abandonar a empresa. A previsão correta garante que a empresa pode conduzir estratégias de retenção com estes clientes, ao mesmo tempo que minimiza a possibilidade de aplicar esforços de retenção sob clientes falsos positivos.
Outra medida popular é o F1 Score, que é uma média harmônica da precisão e do recall. Um F1 score obtém seu melhor valor em 1 quando se tem um recall e precisão perfeitos. O nosso modelo obteve um F1 de 0,86. Por fim, temos a medida de UAC-ROC, que vamos utilizar na próxima seção para escolher o melhor modelo.
Tuning e Cross-validation: tune e dials
Em vez de utilizar um valor padrão para os hiperparâmetros do modelo, podemos ajustar diferentes valores utilizando a função tune() do pacote tune. Assim, deixamos que o cross-validation decida o melhor valor dos hiperparâmetros do nosso modelo.
Assim, vamos utilizar a cross-validation para encontrar o melhor valor de hiperparâmetro. Podemos utilizar otimização bayesiana para problemas complexos de tuning com o uso da função tune_bayes(). Contudo, para problemas mais simples como este, um grid search é suficiente. Assim, vamos construir uma combinação de valores dos hiperparâmetros, de modo que em cada fold do cross-validation serão estimados múltiplos modelos com diferentes hiperparâmetros.
Antes disto, precisamos construir nossa base de cross-validation. O pacote rsample permite construir um objetivo de cross-validation com n folds. No nosso exemplo, vamos criar 4 folds diferentes. Em cada fold temos uma base de análise (equivalente ao treinamento) e de assessment (equivalente ao teste). Por fim, note que em todos os folds a proporção de churns é bem parecida.
cv_folds <- vfold_cv(df, v = 4)
map_dbl(cv_folds$splits, ~ mean(as.data.frame(.x)$Churn == "Yes"))## [1] 0.2679181 0.2654532 0.2603337 0.2694350
Agora precisamos definir novamente nosso modelo, mas agora sem valores fixos para mtry e trees. Ao invés disto, vamos sinalizar com o uso de tune() que diferentes valores dos hiperparâmetros devem ser utilizados. Para tanto, podemos atualizar o nosso modelo anterior substituindo mtry = 3 e trees = 200 por tune().
rf_model_tune <-
rand_forest(mtry = tune(), trees = tune()) %>%
set_mode("classification") %>%
set_engine("ranger")
rf_model_tune## Random Forest Model Specification (classification)
##
## Main Arguments:
## mtry = tune()
## trees = tune()
##
## Computational engine: ranger
Vamos criar um workflow diferente, agora com um modelo que aceita hiperparâmetros não fixos
rf_wflow_tune <- workflow() %>%
add_model(rf_model_tune) %>%
add_recipe(receita_simples)E atualizar nosso modelo com os valores máximos e mínimos dos hiperparâmetros. utilizando as funções parameters e update do pacote dials. Para mtry, a documentação do randomForest informa que este parâmetro assume valores entre 1 e o número de colunas. Assim, vamos definir seus valores mínimos e maximos como 1 e 30. trees é um hiperparâmetro quantitativo sem relação com a base de dados, seus valores máximos e mínimos já são definidos por padrão como 1 e 2000.
rf_param <-
rf_wflow_tune %>%
parameters() %>%
update(mtry = mtry(range = c(1L, 30L)),
trees = trees())
rf_param$object## [[1]]
## # Randomly Selected Predictors (quantitative)
## Range: [1, 30]
##
## [[2]]
## # Trees (quantitative)
## Range: [1, 2000]
Como definimos o valor mínimo e máximo para mtry como 1L e 30L e para trees como 1L e 2000L, a função grid_regular() constroi uma combinação de valores possíveis para os dois hiperparâmetros. Estes valores poderiam ser definidos manualmente com o uso de uma função como expand.grid, mas grid_regular torna esta tarefa mais conveniente. Assim, passamos o argumento levels = 3 que indica que serão utilizados 3 valores igualmente esparçados dos dois hiperparâmetros. Contudo, se a intenção for produzir valores aleatórios do grid, podemos utilizar random_grid().
rf_grid <- grid_regular(rf_param, levels = 3)
rf_grid## # A tibble: 9 x 2
## mtry trees
## <int> <int>
## 1 1 1
## 2 15 1
## 3 30 1
## 4 1 1000
## 5 15 1000
## 6 30 1000
## 7 1 2000
## 8 15 2000
## 9 30 2000
A realização de cross-validation e do tuning de hiperparâmetros aumenta enormente o número de modelos que serão ajustados. Aqui temos 4 folds e 9 valores possíveis de hiperparâmetros, totalizando 36 modelos. Para aumentar a velocidade com que o processo de estimação ocorre, podemos estimar os modelos em paralelo com o uso do pacote doFuture. Para tanto, será necessário a instalações de alguns pacotes adicionais como Rmpi e snow, que dependem da instalação do Microsoft MPI que pode ser encontrado aqui. Caso ocorra dificuldades na instalação dos pacotes acima, é possível ignorar os códigos abaixo. Neste caso, os modelos podem ser ajustados serialmente, com o único custo sendo um tempo maior de processamento.
all_cores <- parallel::detectCores(logical = FALSE) - 1
registerDoFuture()
cl <- makeClusterPSOCK(all_cores)
plan(future::cluster, workers = cl)Finalmente podemos iniciar o trabalho de tuning dos hiperparâmetros utilizando tune_grid(), que precisa do novo workflow (rf_wflow_tune), dos valores possíveis dos hiperparâmetros (rf_grid) e do objeto indicando quais são os folds do cross-validation (cv_folds).
rf_search <- tune_grid(rf_wflow_tune, grid = rf_grid, resamples = cv_folds,
param_info = rf_param)
rf_search## # Tuning results
## # 4-fold cross-validation
## # A tibble: 4 x 4
## splits id .metrics .notes
## <list> <chr> <list> <list>
## 1 <split [5274/1758]> Fold1 <tibble [18 x 6]> <tibble [0 x 1]>
## 2 <split [5274/1758]> Fold2 <tibble [18 x 6]> <tibble [0 x 1]>
## 3 <split [5274/1758]> Fold3 <tibble [18 x 6]> <tibble [0 x 1]>
## 4 <split [5274/1758]> Fold4 <tibble [18 x 6]> <tibble [0 x 1]>
Assim, alguns minutos depois, os 36 modelos foram estimados. Podemos calcular uma média de cada medida de performance para as 9 combinações diferentes de hiperparâmetros e listar os modelos com melhores performance segundo o ROC-UAC com a função show_best().
show_best(rf_search, n = 5, metric = 'roc_auc') %>% kable()| mtry | trees | .metric | .estimator | mean | n | std_err | .config |
|---|---|---|---|---|---|---|---|
| 15 | 2000 | roc_auc | binary | 0.8311273 | 4 | 0.0070299 | Preprocessor1_Model8 |
| 15 | 1000 | roc_auc | binary | 0.8307947 | 4 | 0.0069385 | Preprocessor1_Model5 |
| 1 | 2000 | roc_auc | binary | 0.8294619 | 4 | 0.0066992 | Preprocessor1_Model7 |
| 1 | 1000 | roc_auc | binary | 0.8291777 | 4 | 0.0072042 | Preprocessor1_Model4 |
| 30 | 2000 | roc_auc | binary | 0.8247092 | 4 | 0.0076831 | Preprocessor1_Model9 |
Assim, o modelo com melhor ROC-AUC foi o Random Forest com mtry = 15e trees = 2000. Vamos utiliza-lo para realizar nossa previsão final.
Previsão Final
Agora que sabemos qual o melhor conjunto de hiperparâmetros, podemos ajustar o modelo em toda a base de treinamento. Usamos a função select_best para retornar o melhor modelo, como definido acima.
rf_param_final <- select_best(rf_search, degree, metric = 'roc_auc')
rf_param_final %>% kable()| mtry | trees | .config |
|---|---|---|
| 15 | 2000 | Preprocessor1_Model8 |
Usamos finalize_workflow() para criar um workflow atualizado com os valores de hiperparâmetros do melhor modelo. Isso equivale a construir um novo workflow com mtry = 15e trees = 2000. Por fim, podemos ajustar o modelo utilizando toda a base de treinamento, os folds do cross-validation já realizaram seu trabalho.
rf_wflow_final_fit <- rf_wflow_tune %>%
finalize_workflow(rf_param_final) %>%
fit(training(train_test_split))
rf_wflow_final_fit## == Workflow [trained] ==========================================================
## Preprocessor: Recipe
## Model: rand_forest()
##
## -- Preprocessor ----------------------------------------------------------------
## 2 Recipe Steps
##
## * step_dummy()
## * step_normalize()
##
## -- Model -----------------------------------------------------------------------
## Ranger result
##
## Call:
## ranger::ranger(x = maybe_data_frame(x), y = y, mtry = min_cols(~15L, x), num.trees = ~2000L, num.threads = 1, verbose = FALSE, seed = sample.int(10^5, 1), probability = TRUE)
##
## Type: Probability estimation
## Number of trees: 2000
## Sample size: 5626
## Number of independent variables: 30
## Mtry: 15
## Target node size: 10
## Variable importance mode: none
## Splitrule: gini
## OOB prediction error (Brier s.): 0.141572
E realizar a previsão do modelo, calculando todas as medidas de performance.
rf_wflow_final_fit %>% predict(new_data = testing(train_test_split)) %>%
bind_cols(testing(train_test_split) %>% select(Churn)) %>%
mutate_all(as.factor) %>%
conf_mat(Churn, .pred_class) %>%
summary() %>%
select(-.estimator) %>%
filter(.metric %in% c('precision', 'recall', 'f_meas',
'accuracy', 'spec', 'sens')) %>%
kable()| .metric | .estimate |
|---|---|
| accuracy | 0.7887624 |
| sens | 0.8862333 |
| spec | 0.5055556 |
| precision | 0.8389140 |
| recall | 0.8862333 |
| f_meas | 0.8619247 |
É possível observar que o modelo de random forest com parâmetros ajustados não excedeu em muito o que já havia sido obtido no nosso primeiro modelo. Existe muito espaço para melhoramento, sobretudo em relação ao processo de feature engineering aplicado com o uso de recipes.
Conclusão
O tidymodels e seus pacotes tornam muito simples a tarefa de aplicar modelos de machine learning para bases de dados. Desde o processamento de bases de dados com recipes, a produção de reamostragens com o rsample, a estimação de modelos de machine learning com o parsnip, a procurar hiperparâmetros com tune e a criação workflows com workflow.
Para além disto, o tidymodels oferece soluções para compararação de diferentes modelos utilizando o pacote workflowsets, tornando simples a criação e o ajustar de um conjunto grande de modelos, sendo necessário apenas uma lista de modelos, que compartilham a mesma receita e a mesma base de treinamento. O tidymodels também permite a criação de assembly, seja por stacks com o uso do pacote stacks ou por bagging, com o pacote baguette, de modo a construir classificadores e previsores melhores a partir da união de diferentes algorítimos.
Em um texto futuro, podemos analisar estas ferramentas.