Aplicação de Aprendizagem Supervisionada - Método de Classificação

Autor

Paulo Manoel da Silva Junior

Machine Learning - Aplicação de Aprendizagem Supervisionada - Problema de Classificação

Objetivo

Aplicar a aprendizagem supervisionada para um problema de classificação, e nesse exemplo incorporando os modelos baseados em árvores, bem como o SVM (Suport Vector Machines)

Utilizando tidymodels para treinar os modelos e junto com isso utilizando workflows.

Com a adição de uma rede neural perceptron com multi camadas.

Informações sobre o pacote tidymodels

O pacote tidymodels é um meta-pacote que consiste de algumas bibliotecas, tais como:

  • rsample: funções para particionamento e reamostragem eficiente de dados;

  • parsnip: interface unificada para um amplo conjunto de modelos que podem ser testados sem que o usuário se preocupe com diferenças de sintaxe;

  • recipes: pré-processamento e feature engineering;

  • workflows: junta pré-processamento, modelagem (treinamento) e pós-processamento; tune: otimização de hiperparâmetros;

  • yardstick: funções para avaliar a efetividade de modelos através de medidas de performance; broom: converte a informação contida em objetos comuns de R para o formato tidy;

  • dials: cria e gerencia hiperparâmetros de ajuste e grids de hiperparâmetros.

Outras bibliotecas serão também utilizadas no processo, como a finetune, que permite um processo de otimização de hiperparâmetros mais eficiente.

Informações sobre o banco de dados e sobre as variáveis

Disponibilidade de informação geral
  • O conjunto de dados trata-se de 11 características clínicas utilizadas para a previsão de possíveis eventos relacionados a doenças cardiovasculares.

O conjunto de dados pode ser encontrado em: conjunto de dados

Mais sobre o banco de dados

As doenças cardiovasculares (DCVs) são a causa número 1 de morte no mundo, levando cerca de 17,9 milhões de vidas a cada ano, o que representa 31% de todas as mortes em todo o mundo. Quatro em cada 5 mortes por DCV são devidas a ataques cardíacos e derrames, e um terço dessas mortes ocorre prematuramente em pessoas com menos de 70 anos de idade. A insuficiência cardíaca é um evento comum causado por DCVs e este conjunto de dados contém 11 recursos que podem ser usados para prever uma possível doença cardíaca.

Pessoas com doenças cardiovasculares ou com alto risco cardiovascular (devido à presença de um ou mais fatores de risco, como hipertensão, diabetes, hiperlipidemia ou doença já estabelecida) precisam de detecção e gerenciamento precoces, em que um modelo de aprendizado de máquina pode ser de grande ajuda.

Sobre as variáveis dependentes:

Idade: idade do paciente em anos

Sexo: sexo do paciente [M: Masculino, F: Feminino]

ChestPainType: tipo de dor no peito [TA: Angina Típica, ATA: Angina Atípica, NAP: Dor Não Anginosa, ASY: Assintomática]

RestingBP: pressão arterial em repouso [mm Hg]

Colesterol: colesterol sérico [mm/dl]

JejumBS: açúcar no sangue em jejum [1: se JejumBS > 120 mg/dl, 0: caso contrário]

ECG em repouso: resultados do eletrocardiograma em repouso [Normal: Normal, ST: com anormalidade da onda ST-T (inversões da onda T e/ou elevação ou depressão do ST > 0,05 mV), HVE: mostrando hipertrofia ventricular esquerda provável ou definitiva pelos critérios de Estes]

MaxHR: frequência cardíaca máxima alcançada [Valor numérico entre 60 e 202]

ExerciseAngina: angina induzida por exercício [S: Sim, N: Não]

Oldpeak: oldpeak = ST [Valor numérico medido na depressão]

ST_Slope: a inclinação do segmento ST do exercício de pico [Up: ascendente, Flat: plano, Down: descendente]

Sobre a variável resposta:

HeartDisease: classe de saída [1: doença cardíaca, 0: normal]

Carregamento dos Dados

  • Carregando o banco de dados
Código
setwd("\\Users\\paulo\\OneDrive\\Área de Trabalho\\ESTATÍSTICA\\UFPB\\8º PERÍODO\\ANÁLISE MULTIVARIADA II\\PROVA")
banco <- read.csv2("heart.csv", header = T, sep = ",")
  • Carregando as bibliotecas
Código
library(tidyverse) # Framework do tidyverse
library(tidymodels) # Framework de modelagem
library(skimr) # Estatística descritiva rápida
library(DataExplorer) # Exploração do conjunto de dados
library(corrplot) # Gráfico de correlação
library(GGally) # Gráficos adicionais com estrutura ggplot2
library(stringr) # Para lidar com strings
library(glmnet) # LASSO, Ridge e Rede Elástica
library(MASS) # Discriminante Linear (LDA) e Quadrático (RL)
library(recipes) # Pré-processamento dos dados
library(class) #knn
library(themis) # Balanceamento de dados
library(discrim) # lda, qda
library(kknn) # (Kernel) K-NN
library(finetune) # Otimização fina de hiperparâmetros
library(gt) # Para tabelas de maneira melhor visualmente 
library(dplyr) # Para tratamento dos dados 
library(plotly) # Para gráficos de melhor qualidade
library(stringr) # Para tratamentos com strings de maneira melhor
library(rpart) # Biblioteca para a engine de método de árvore de decisão
library(parsnip) 
library(ranger)
library(baguette)
library(kernlab)
library(xgboost) # Bibilioteca para o xgboost 
library(qgraph) # Biblioteca para um novo modo de visualizar a matriz de correlação 
library(lubridate)
library(nnet) # Engine para a rede neural - perceptron com multi camadas

Visualizando o banco

Código
glimpse(banco)
Rows: 918
Columns: 12
$ Age            <int> 40, 49, 37, 48, 54, 39, 45, 54, 37, 48, 37, 58, 39, 49,…
$ Sex            <chr> "M", "F", "M", "F", "M", "M", "F", "M", "M", "F", "F", …
$ ChestPainType  <chr> "ATA", "NAP", "ATA", "ASY", "NAP", "NAP", "ATA", "ATA",…
$ RestingBP      <int> 140, 160, 130, 138, 150, 120, 130, 110, 140, 120, 130, …
$ Cholesterol    <int> 289, 180, 283, 214, 195, 339, 237, 208, 207, 284, 211, …
$ FastingBS      <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ RestingECG     <chr> "Normal", "Normal", "ST", "Normal", "Normal", "Normal",…
$ MaxHR          <int> 172, 156, 98, 108, 122, 170, 170, 142, 130, 120, 142, 9…
$ ExerciseAngina <chr> "N", "N", "N", "Y", "N", "N", "N", "N", "Y", "N", "N", …
$ Oldpeak        <chr> "0", "1", "0", "1.5", "0", "0", "0", "0", "1.5", "0", "…
$ ST_Slope       <chr> "Up", "Flat", "Up", "Flat", "Up", "Up", "Up", "Up", "Fl…
$ HeartDisease   <int> 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1…
  • Como observamos acima é necessário ainda realizar algumas transformações, para transformar as variáveis categóricas com os seus respectivos fatores. Ficando dessa forma:
Código
banco$Sex <- factor(banco$Sex, levels = c("M","F"), labels = c("Masculino", "Feminino"))
banco$ChestPainType <- factor(banco$ChestPainType, levels = c("TA", "ATA", "NAP", "ASY"), labels = c("Angina Típica", "Angina Atípica", "Dor Não Anginosa", "Assintomática"))
banco$FastingBS <- factor(banco$FastingBS, levels = c(0,1),labels = c("C.C", "JejumBS > 120 mg/dl"))
banco$RestingECG <- factor(banco$RestingECG, levels = c("Normal", "ST", "LVH"), labels = c("Normal", "anormalidade da onda", "hipertrofia ventricular"))
banco$ExerciseAngina <- factor(banco$ExerciseAngina, levels = c("N", "Y"), labels = c("Não", "Sim"))
banco$ST_Slope <- factor(banco$ST_Slope, levels = c("Up", "Flat", "Down"), labels = c("Ascendente", "Plano", "Descendente"))
banco$HeartDisease <- factor(banco$HeartDisease, levels = c(0,1),labels = c("Normal", "Doença cardiaca"))
banco$Oldpeak <- as.numeric(banco$Oldpeak)
Código
glimpse(banco)
Rows: 918
Columns: 12
$ Age            <int> 40, 49, 37, 48, 54, 39, 45, 54, 37, 48, 37, 58, 39, 49,…
$ Sex            <fct> Masculino, Feminino, Masculino, Feminino, Masculino, Ma…
$ ChestPainType  <fct> Angina Atípica, Dor Não Anginosa, Angina Atípica, Assin…
$ RestingBP      <int> 140, 160, 130, 138, 150, 120, 130, 110, 140, 120, 130, …
$ Cholesterol    <int> 289, 180, 283, 214, 195, 339, 237, 208, 207, 284, 211, …
$ FastingBS      <fct> C.C, C.C, C.C, C.C, C.C, C.C, C.C, C.C, C.C, C.C, C.C, …
$ RestingECG     <fct> Normal, Normal, anormalidade da onda, Normal, Normal, N…
$ MaxHR          <int> 172, 156, 98, 108, 122, 170, 170, 142, 130, 120, 142, 9…
$ ExerciseAngina <fct> Não, Não, Não, Sim, Não, Não, Não, Não, Sim, Não, Não, …
$ Oldpeak        <dbl> 0.0, 1.0, 0.0, 1.5, 0.0, 0.0, 0.0, 0.0, 1.5, 0.0, 0.0, …
$ ST_Slope       <fct> Ascendente, Plano, Ascendente, Plano, Ascendente, Ascen…
$ HeartDisease   <fct> Normal, Doença cardiaca, Normal, Doença cardiaca, Norma…

Comentário: Com o gráfico abaixo, podemos analisar se existe informação ausente, como NA e em quais variáveis encontra-se a observação ausente, se assim existir:

Código
visdat::vis_miss(banco)

De maneira analítica, temos a seguinte quantidade de informações ausentes no banco de dados:

Código
is.na(banco) %>% 
  colSums() 
           Age            Sex  ChestPainType      RestingBP    Cholesterol 
             0              0              0              0              0 
     FastingBS     RestingECG          MaxHR ExerciseAngina        Oldpeak 
             0              0              0              0              0 
      ST_Slope   HeartDisease 
             0              0 

E visualmente, temos um banco de dados dessa maneira:

Código
visdat::vis_dat(banco)

Análise Exploratória dos Dados

  • Uma análise Exploratória de maneira mais geral, utilizando a função skim do pacote skimr
Código
skim(banco)
Data summary
Name banco
Number of rows 918
Number of columns 12
_______________________
Column type frequency:
factor 7
numeric 5
________________________
Group variables None

Variable type: factor

skim_variable n_missing complete_rate ordered n_unique top_counts
Sex 0 1 FALSE 2 Mas: 725, Fem: 193
ChestPainType 0 1 FALSE 4 Ass: 496, Dor: 203, Ang: 173, Ang: 46
FastingBS 0 1 FALSE 2 C.C: 704, Jej: 214
RestingECG 0 1 FALSE 3 Nor: 552, hip: 188, ano: 178
ExerciseAngina 0 1 FALSE 2 Não: 547, Sim: 371
ST_Slope 0 1 FALSE 3 Pla: 460, Asc: 395, Des: 63
HeartDisease 0 1 FALSE 2 Doe: 508, Nor: 410

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
Age 0 1 53.51 9.43 28.0 47.00 54.0 60.0 77.0 ▁▅▇▆▁
RestingBP 0 1 132.40 18.51 0.0 120.00 130.0 140.0 200.0 ▁▁▃▇▁
Cholesterol 0 1 198.80 109.38 0.0 173.25 223.0 267.0 603.0 ▃▇▇▁▁
MaxHR 0 1 136.81 25.46 60.0 120.00 138.0 156.0 202.0 ▁▃▇▆▂
Oldpeak 0 1 0.89 1.07 -2.6 0.00 0.6 1.5 6.2 ▁▇▆▁▁
  • Agora, podemos analisar de maneira mais precisa dentro das classes de acordo com algumas medidas (posição e dispersão) de interesse, bem como a visualização gráfica do boxplot dessas variáveis de acordo com o grupo.

Variáveis númericas

Medidas de Posição e Dispersão

As medidas de posição e dispersão que serão utilizadas serão:

  • Média

  • Mediana

  • 1º Quartil

  • 3º Quartil

  • Mínimo

  • Máximo

  • Desvio Padrão

  • Coeficiente de Variação

Código
desc_idade <- banco %>% 
  group_by(HeartDisease) %>% 
  summarise(media = mean(Age, na.rm = T),
            mediana = median(Age, na.rm = T), 
            quartil_1 = quantile(Age, 0.25, na.rm = T), 
            quartil_3 = quantile(Age, 0.75, na.rm = T), 
            minimo = min(Age, na.rm = T), 
            maximo = max(Age, na.rm = T), 
            desvio = sd(Age, na.rm = T), 
            coeficiente = round(sd(Age, na.rm = T)/mean(Age, na.rm = T)*100,2))


colnames(desc_idade) <- c("Grupo","Média", "Mediana", "1º Quartil","3º Quartil", "Mínimo", "Máximo", "Desvio Padrão", "Coeficiente de Variação")

desc_idade %>% 
  gt() %>% 
  fmt_percent(
    columns = `Coeficiente de Variação`, 
    decimals = 2, 
    scale_values = FALSE, 
    sep_mark = ".", 
    dec_mark = ","
  ) %>% 
  tab_header(title = html("<b> Estatística Descritiva da Pressão Arterial em Repouso</b>")) %>% 
  tab_source_note(html("<b> Fonte: </b> Elaboração Própria")) %>% 
  data_color(
    columns = Média,
    colors = scales::col_numeric(
      palette = colorspace::sequential_hcl(n = 5, palette = "gray"), 
      domain = c(min(desc_idade$Média), max(desc_idade$Média)), 
      reverse = TRUE
    )
  ) %>% 
  cols_align(align = "center", columns = vars(everything())) %>% 
  fmt_number(
    columns = c(Média, `Desvio Padrão`), 
    decimals = 3
  )
Estatística Descritiva da Pressão Arterial em Repouso
Grupo Média Mediana 1º Quartil 3º Quartil Mínimo Máximo Desvio Padrão Coeficiente de Variação
Normal 50.551 51 43 57 28 76 9.445 18,68%
Doença cardiaca 55.900 57 51 62 31 77 8.727 15,61%
Fonte: Elaboração Própria
Código
plot_ly(banco, x = banco$Age, color = banco$HeartDisease, type = "box") %>% 
  layout(title = "Boxplot da Idade de acordo com a presença ou ausência de doença cardíaca")
Comentário

No boxplot acima, pacientes que apresentaram doenças cardíacas possuem uma variabilidade menor de idade quando comparado com pacientes que não apresentaram. Além disso nota-se a presença de 4 outliers referentes a pacientes que apresentaram doenças cardiovasculares antes dos 35 anos.

Também podemos observar 75% dos pacientes que não apresentaram doenças cardiovasculares são mais novos que a idade mediana dos pacientes que apresentaram doenças cardiovasculares. Isto é, pode-se levantar a hipótese de que a idade talvez tenha uma contribuição significativa para o desenvolvimento de doenças cardíacas.

Código
desc_MAXHR <- banco %>% 
  group_by(HeartDisease) %>% 
  summarise(media = mean(MaxHR, na.rm = T),
            mediana = median(MaxHR, na.rm = T), 
            quartil_1 = quantile(MaxHR, 0.25, na.rm = T), 
            quartil_3 = quantile(MaxHR, 0.75, na.rm = T), 
            minimo = min(MaxHR, na.rm = T), 
            maximo = max(MaxHR, na.rm = T), 
            desvio = sd(MaxHR, na.rm = T), 
            coeficiente = round(sd(MaxHR, na.rm = T)/mean(MaxHR, na.rm = T)*100,2)) 

colnames(desc_MAXHR) <- c("Grupo","Média", "Mediana", "1º Quartil","3º Quartil", "Mínimo", "Máximo", "Desvio Padrão", "Coeficiente de Variação")

desc_MAXHR %>% 
  gt() %>% 
  fmt_percent(
    columns = `Coeficiente de Variação`, 
    decimals = 2, 
    scale_values = FALSE, 
    sep_mark = ".", 
    dec_mark = ","
  ) %>% 
  tab_header(title = html("<b> Estatística Descritiva da Pressão Arterial em Repouso</b>")) %>% 
  tab_source_note(html("<b> Fonte: </b> Elaboração Própria")) %>% 
  data_color(
    columns = Média,
    colors = scales::col_numeric(
      palette = colorspace::sequential_hcl(n = 5, palette = "green"), 
      domain = c(min(desc_MAXHR$Média), max(desc_MAXHR$Média)), 
      reverse = TRUE
    )
  ) %>% 
  fmt_number(columns = c(Média, `Desvio Padrão`), 
             decimals = 3) %>% 
  cols_align(align = "center", columns = everything())
Estatística Descritiva da Pressão Arterial em Repouso
Grupo Média Mediana 1º Quartil 3º Quartil Mínimo Máximo Desvio Padrão Coeficiente de Variação
Normal 148.151 150 134 165.00 69 202 23.288 15,72%
Doença cardiaca 127.656 126 112 144.25 60 195 23.387 18,32%
Fonte: Elaboração Própria
Código
plot_ly(banco, x = banco$MaxHR, color = banco$HeartDisease, type = "box") %>% 
  layout(title = "Boxplot da Frequência Máxima Cardíaca de acordo com a presença ou ausência de doença cardíaca")
Comentário

Como o coração de pessoas com doenças cardíacas não estão funcionando de maneira adequada, quando se é necessária uma carga maior de trabalho o coração desse indivíduo não consegue trabalhar de forma tão eficiente quanto o coração de uma pessoa saudável. Por conta disso, 75% dos indíviduos doentes possuem a frequência máxima cardíaca abaixo da mediana da frequência máxima cardíaca do grupo de pessoas saudáveis.

Código
desc_restingBP <- banco %>% 
  group_by(HeartDisease) %>% 
  summarise(media = mean(RestingBP, na.rm = T),
            mediana = median(RestingBP, na.rm = T), 
            quartil_1 = quantile(RestingBP, 0.25, na.rm = T), 
            quartil_3 = quantile(RestingBP, 0.75, na.rm = T), 
            minimo = min(RestingBP, na.rm = T), 
            maximo = max(RestingBP, na.rm = T), 
            desvio = sd(RestingBP, na.rm = T), 
            coeficiente = round(sd(RestingBP, na.rm = T)/mean(RestingBP, na.rm = T)*100,2)) 

colnames(desc_restingBP) <- c("Grupo","Média", "Mediana", "1º Quartil","3º Quartil", "Mínimo", "Máximo", "Desvio Padrão", "Coeficiente de Variação")

desc_restingBP %>% 
  gt() %>% 
  fmt_percent(
    columns = `Coeficiente de Variação`, 
    decimals = 2, 
    scale_values = FALSE, 
    sep_mark = ".", 
    dec_mark = ","
  ) %>% 
  tab_header(title = html("<b> Estatística Descritiva da Pressão Arterial em Repouso</b>")) %>% 
  tab_source_note(html("<b> Fonte: </b> Elaboração Própria")) %>% 
  data_color(
    columns = Média,
    colors = scales::col_numeric(
      palette = colorspace::sequential_hcl(n = 5, palette = "green"), 
      domain = c(min(desc_restingBP$Média), max(desc_restingBP$Média)), 
      reverse = TRUE
    )
  ) %>% 
  cols_align(align = "center", columns = everything()) %>% 
  fmt_number(columns = c(Média, `Desvio Padrão`), 
             decimals = 4)
Estatística Descritiva da Pressão Arterial em Repouso
Grupo Média Mediana 1º Quartil 3º Quartil Mínimo Máximo Desvio Padrão Coeficiente de Variação
Normal 130.1805 130 120 140 80 190 16.4996 12,67%
Doença cardiaca 134.1850 132 120 145 0 200 19.8287 14,78%
Fonte: Elaboração Própria
Código
plot_ly(banco, x = banco$RestingBP, color = banco$HeartDisease, type = "box") %>% 
  layout(title = "Boxplot da Pressão Arterial em Repouso de acordo com a presença ou ausência de doença cardíaca")
Comentário

Apesar de ambos os grupos apresentarem dados semelhantes, o grupo que possui doenças cardiovasculares é ligeiramente maior que o grupo de indivíduos saudáveis. Afinal, o coração doente por não apresentar os batimentos tão eficientes quanto um coração saudável, as artérias tendem a compensar esses batimentos cardíacos aumentando sua pressão.

Código
desc_Choles <- banco %>% 
  group_by(HeartDisease) %>% 
  summarise(media = mean(Cholesterol, na.rm = T),
            mediana = median(Cholesterol, na.rm = T), 
            quartil_1 = quantile(Cholesterol, 0.25, na.rm = T), 
            quartil_3 = quantile(Cholesterol, 0.75, na.rm = T), 
            minimo = min(Cholesterol, na.rm = T), 
            maximo = max(Cholesterol, na.rm = T), 
            desvio = sd(Cholesterol, na.rm = T), 
            coeficiente = round(sd(Cholesterol, na.rm = T)/mean(Cholesterol, na.rm = T)*100,2)) 

colnames(desc_Choles) <- c("Grupo","Média", "Mediana", "1º Quartil","3º Quartil", "Mínimo", "Máximo", "Desvio Padrão", "Coeficiente de Variação")

desc_Choles %>% 
  gt() %>% 
  fmt_percent(
    columns = `Coeficiente de Variação`, 
    decimals = 2, 
    scale_values = FALSE, 
    sep_mark = ".", 
    dec_mark = ","
  ) %>% 
  tab_header(title = html("<b> Estatística Descritiva da Pressão Arterial em Repouso</b>")) %>% 
  tab_source_note(html("<b> Fonte: </b> Elaboração Própria")) %>% 
  data_color(
    columns = Média,
    colors = scales::col_numeric(
      palette = colorspace::sequential_hcl(n = 5, palette = "Reds 2"), 
      domain = c(min(desc_Choles$Média), max(desc_Choles$Média)), 
      reverse = TRUE
    )
  ) %>% 
  cols_align("center", columns = everything()) %>% 
  fmt_number(columns = c(Média, `Desvio Padrão`), 
             decimals = 2)
Estatística Descritiva da Pressão Arterial em Repouso
Grupo Média Mediana 1º Quartil 3º Quartil Mínimo Máximo Desvio Padrão Coeficiente de Variação
Normal 227.12 227 197.25 266.75 0 564 74.63 32,86%
Doença cardiaca 175.94 217 0.00 267.00 0 603 126.39 71,84%
Fonte: Elaboração Própria
Código
plot_ly(banco, x = banco$Cholesterol, color = banco$HeartDisease, type = "box") %>% 
  layout(title = "Boxplot do Colesterol Sérico de acordo \ncom a presença ou ausência de doença cardíaca")
Comentário

Devem existir inúmeros fatores que podem implicar uma maior variabilidade do colesterol no grupo de pessoas que possuem doenças cardíacas, uma delas pode estar relacionada com o fato do uso de medicações para diminuir e tentar controlar esse nível do colesterol.

Após a visualização gráfica do boxplot, observamos uma diferença significativa na variação dos dados do colesterol de acordo com a presença ou ausência de doença cardíaca.

Código
desc_old_peak <- banco %>% 
  group_by(HeartDisease) %>% 
  summarise(media = mean(Oldpeak, na.rm = T),
            mediana = median(Oldpeak, na.rm = T), 
            quartil_1 = quantile(Oldpeak, 0.25, na.rm = T), 
            quartil_3 = quantile(Oldpeak, 0.75, na.rm = T), 
            minimo = min(Oldpeak, na.rm = T), 
            maximo = max(Oldpeak, na.rm = T), 
            desvio = sd(Oldpeak, na.rm = T), 
            coeficiente = round(sd(Oldpeak, na.rm = T)/mean(Oldpeak, na.rm = T)*100,2)) 


colnames(desc_old_peak) <- c("Grupo","Média", "Mediana", "1º Quartil","3º Quartil", "Mínimo", "Máximo", "Desvio Padrão", "Coeficiente de Variação")

desc_old_peak %>% 
  gt() %>% 
  fmt_percent(
    columns = `Coeficiente de Variação`, 
    decimals = 2, 
    scale_values = FALSE, 
    sep_mark = ".", 
    dec_mark = ","
  ) %>% 
  tab_header(title = html("<b> Estatística Descritiva da Pressão Arterial em Repouso</b>")) %>% 
  tab_source_note(html("<b> Fonte: </b> Elaboração Própria")) %>% 
  data_color(
    columns = Média,
    colors = scales::col_numeric(
      palette = colorspace::sequential_hcl(n = 5, palette = "Reds 2"), 
      domain = c(min(desc_old_peak$Média), max(desc_old_peak$Média)), 
      reverse = TRUE
    )
  ) %>% 
  cols_align(align = "center", columns = everything()) %>% 
  fmt_number(columns = c(Média, `Desvio Padrão`), 
             decimals = 4)
Estatística Descritiva da Pressão Arterial em Repouso
Grupo Média Mediana 1º Quartil 3º Quartil Mínimo Máximo Desvio Padrão Coeficiente de Variação
Normal 0.4080 0.0 0 0.6 -1.1 4.2 0.6997 171,48%
Doença cardiaca 1.2742 1.2 0 2.0 -2.6 6.2 1.1519 90,40%
Fonte: Elaboração Própria
Código
plot_ly(banco, x = banco$Oldpeak, color = banco$HeartDisease, type = "box") %>% 
  layout(title = "Boxplot do banco do valor númerico medido \nna depressão de acordo com a presença ou ausência de doença cardíaca")
Comentário

Podemos observar que no grupo normal, temos a presença de muitos outliers, já no grupo de doença cardíaca, entre o primeiro e terceiro quartil podemos enxergar uma variabilidade maior do que comparando esses mesmos quartis.

Agora, vamos visualizar a matriz de correlação das variáveis preditoras

Código
rho <- banco %>% 
  dplyr::select(where(is.numeric)) %>% 
  cor()

corrplot::corrplot(rho, method = "square", type = "full")

Observando a correlação de outra maneira mais gráfica, através de um gráfico que vai mostrar as ligações e força da correlação se existir

No gráfico abaixo, correlações positivas podem ser vistas com a ligação azul e correlações negativas com a vermelha.

Código
rho <- round(rho,2)
qgraph(rho, shape = "diamond", 
       posCol = "darkblue", 
       negCol = "darkred", 
       layout = "groups", 
       size = 20, 
       vTrans = 100, 
       borders = F,
       edge.labels = rho, 
       color = "skyblue", 
       edge.label.position = 0.5, 
       edge.label.color = "black", 
       edge.label.margin = 0.01)

Variáveis Categóricas

  • Como variável dependente temos se o paciente tem doença cardíaca ou não, sendo assim a quantidade de pessoas que possuem em nosso banco é essa:
Código
desc_heart <- banco %>% 
  dplyr::group_by(HeartDisease) %>% 
  dplyr::summarise(quantidade = n(), 
            proporção = round(n()/dim(banco)[1]*100,2))
  
colnames(desc_heart) <- c("Doença Cardíaca", "Frequência", "Percentual")

desc_heart %>% 
  gt() %>% 
  fmt_percent(
    columns = Percentual, 
    decimals = 2, 
    scale_values = FALSE, 
    sep_mark = ".", 
    dec_mark = ","
  ) %>% 
  tab_header(title = html("<b> Estatística Descritiva da variável de interesse </b>")) %>% 
  tab_source_note(html("<b>Fonte: </b> Elaboração Própria")) %>% 
  cols_align(align = "center", columns = everything())
Estatística Descritiva da variável de interesse
Doença Cardíaca Frequência Percentual
Normal 410 44,66%
Doença cardiaca 508 55,34%
Fonte: Elaboração Própria
Código
banco %>% 
  count(HeartDisease, RestingECG) %>% 
  group_by(HeartDisease) %>% 
  mutate(percent = n / sum(n) *100,
         percent = round(percent, 2)) %>% 
  gt::gt() %>% 
    gt::tab_header(
    title = html("<b> Situação dos pacientes quanto a presença de doença cardíaca</b>"),
    subtitle = "Com relação ao eletrocardiograma em repouso"
  ) %>% 
  gt::cols_label(
    RestingECG = "ECG em Repouso",
    n = "Frequência",
    percent = "Percentual"
  ) %>% 
  gt::fmt_number(
    columns = vars(n),
    suffixing = T, 
    decimals = 0
  ) %>% 
  fmt_percent(columns = percent, 
              decimals = 2, 
              scale_values = F, 
              sep_mark = ".", 
              dec_mark = ",") %>% 
  tab_source_note(html("<b> Fonte: </b> Elaboração Própria")) %>% 
  cols_align(align = "center", columns = everything())
Situação dos pacientes quanto a presença de doença cardíaca
Com relação ao eletrocardiograma em repouso
ECG em Repouso Frequência Percentual
Normal
Normal 267 65,12%
anormalidade da onda 61 14,88%
hipertrofia ventricular 82 20,00%
Doença cardiaca
Normal 285 56,10%
anormalidade da onda 117 23,03%
hipertrofia ventricular 106 20,87%
Fonte: Elaboração Própria
Código
banco %>% 
  count(HeartDisease, ST_Slope) %>% 
  group_by(HeartDisease) %>% 
  mutate(percent = n / sum(n) *100,
         percent = round(percent, 2)) %>% 
  gt::gt() %>% 
    gt::tab_header(
    title = html("<b> Situação dos pacientes quanto a presença de doença cardíaca</b>"),
    subtitle = "Com relação a inclinação do Segmento"
  ) %>% 
  gt::cols_label(
    ST_Slope = "Inclinação do Segmento",
    n = "Frequência",
    percent = "Percentual"
  ) %>% 
  gt::fmt_number(
    columns = vars(n),
    suffixing = TRUE, 
    decimals = 0
  ) %>% 
  fmt_percent(
    columns = percent, 
    decimals = 2, 
    dec_mark = ",", 
    sep_mark = ".", 
    scale_values = F
  ) %>% 
  tab_source_note(html("<b> Fonte: </b> Elaboração Própria")) %>% 
  cols_align(align = "center", columns = everything())
Situação dos pacientes quanto a presença de doença cardíaca
Com relação a inclinação do Segmento
Inclinação do Segmento Frequência Percentual
Normal
Ascendente 317 77,32%
Plano 79 19,27%
Descendente 14 3,41%
Doença cardiaca
Ascendente 78 15,35%
Plano 381 75,00%
Descendente 49 9,65%
Fonte: Elaboração Própria
Código
banco %>% 
  count(HeartDisease, Sex) %>% 
  group_by(HeartDisease) %>% 
  mutate(percent = n / sum(n) *100,
         percent = round(percent, 2)) %>% 
  gt::gt() %>% 
    gt::tab_header(
    title = html("<b> Situação dos pacientes quanto a presença de doença cardíaca</b>"),
    subtitle = "Com relação ao Sexo"
  ) %>% 
  gt::cols_label(
    Sex = "Sexo",
    n = "Frequência",
    percent = "Percentual"
  ) %>% 
  gt::fmt_number(
    columns = vars(n),
    suffixing = TRUE,
    decimals = 0
  ) %>% 
  fmt_percent(
    columns = percent, 
    decimals = 2, 
    scale_values = F, 
    dec_mark = ",", 
    sep_mark = "."
  ) %>% 
  tab_source_note(html("<b> Fonte:</b> Elaboração Própria")) %>% 
  cols_align(align = "center", columns = everything())
Situação dos pacientes quanto a presença de doença cardíaca
Com relação ao Sexo
Sexo Frequência Percentual
Normal
Masculino 267 65,12%
Feminino 143 34,88%
Doença cardiaca
Masculino 458 90,16%
Feminino 50 9,84%
Fonte: Elaboração Própria
Código
banco %>% 
  count(HeartDisease, ChestPainType) %>% 
  group_by(HeartDisease) %>% 
  mutate(percent = n / sum(n) *100,
         percent = round(percent, 2)) %>% 
  gt::gt() %>% 
    gt::tab_header(
    title = html("<b> Situação dos pacientes quanto a presença de doença cardíaca </b>"),
    subtitle = "Com relação a dor no peito"
  ) %>% 
  gt::cols_label(
    ChestPainType = "Tipo de dor no peito",
    n = "Frequência",
    percent = "Percentual"
  ) %>% 
  gt::fmt_number(
    columns = vars(n),
    suffixing = TRUE,
    decimals = 0
  ) %>% 
  fmt_percent(
    columns = percent, 
    decimals = 2, 
    scale_values = F, 
    dec_mark = ",", 
    sep_mark = "."
  ) %>% 
  tab_source_note(html("<b> Fonte:</b> Elaboração Própria")) %>% 
  cols_align(align = "center", columns = everything())
Situação dos pacientes quanto a presença de doença cardíaca
Com relação a dor no peito
Tipo de dor no peito Frequência Percentual
Normal
Angina Típica 26 6,34%
Angina Atípica 149 36,34%
Dor Não Anginosa 131 31,95%
Assintomática 104 25,37%
Doença cardiaca
Angina Típica 20 3,94%
Angina Atípica 24 4,72%
Dor Não Anginosa 72 14,17%
Assintomática 392 77,17%
Fonte: Elaboração Própria
Código
banco %>% 
  count(HeartDisease, FastingBS) %>% 
  group_by(HeartDisease) %>% 
  mutate(percent = n / sum(n) *100,
         percent = round(percent, 2)) %>% 
  gt::gt() %>% 
    gt::tab_header(
    title = html("<b> Situação dos pacientes quanto a presença de doença cardíaca</b>"),
    subtitle = "Com relação ao Açúcar no Sengue em Jejum"
  ) %>% 
  gt::cols_label(
    FastingBS = "Açucar no Sangue",
    n = "Frequência",
    percent = "Percentual"
  ) %>% 
  gt::fmt_number(
    columns = vars(n),
    suffixing = TRUE,
    decimals = 0
  ) %>% 
  fmt_percent(
    columns = percent, 
    decimals = 2, 
    scale_values = F, 
    dec_mark = ",", 
    sep_mark = "."
  ) %>% 
  tab_source_note(html("<b> Fonte:</b> Elaboração Própria")) %>% 
  cols_align(align = "center", columns = everything()) %>% 
  tab_style(
    style = list(
      cell_text(weight =  "bolder")
    ),
    locations = cells_body(columns = percent)
  )
Situação dos pacientes quanto a presença de doença cardíaca
Com relação ao Açúcar no Sengue em Jejum
Açucar no Sangue Frequência Percentual
Normal
C.C 366 89,27%
JejumBS > 120 mg/dl 44 10,73%
Doença cardiaca
C.C 338 66,54%
JejumBS > 120 mg/dl 170 33,46%
Fonte: Elaboração Própria

Ajuste dos Modelos

Particionamento do conjunto de dados

  • Primeiro é necessário particionar o conjunto de dados em treinamento e teste, lembrando que os dados de treinamento ainda serão submetidos a validação cruzada para buscarmos os melhores valores de hiperparâmetros para alguns modelos.

Para esse banco de dados a proporção que será utilizada no particionamento é de 80%.

Código
set.seed(2024)
split_inicial <- initial_split(banco, prop = 0.80, strata = HeartDisease)
banco_treino <- training(split_inicial)
banco_teste <- testing(split_inicial)

Pré-processamento

Para o pré-processamento é a parte de aplicarmos algumas receitas para melhorar os dados, com isso utilizamos a biblioteca recipes, e aplicamos as receitas de acordo com as necessidades que possuímos.

Código
banco_receita <- recipe(HeartDisease ~ ., data = banco_treino) %>%
  # step_impute_knn(all_predictors(), neighbors = 5) %>% # imputa valores ausentes por K-NN
  # step_impute_bag(all_predictors()) %>% # imputa valores ausentes por bagged trees
  # step_impute_mean(all_numeric_predictors()) %>% # imputa valores ausentes pela média
  # step_impute_median(all_numeric_predictors()) %>% # imputa valores ausentes pela mediana
  # step_impute_mode(all_predictors()) %>% # imputa valores ausentes pela mediana
  # step_naomit(everything(), skip = TRUE) %>% # remove linhas que contém NA ou NaN
  # step_interact(terms = ~ all_numeric_predictors():all_numeric_predictors()) %>%
  # step_log( # Transformação log: y = log(x)
  #   ) %>%
  step_YeoJohnson( # Transformação Yeo-Johnson
    all_numeric_predictors()
    ) %>%
  step_poly(all_numeric_predictors(), degree = 2) %>%
  step_normalize( # normaliza variáveis numéricas para terem média 0 e variância 1
    all_numeric_predictors()
    ) %>%
  # step_range( # normaliza variáveis numéricas para pertencerem ao intervalo [0,1]
  #   all_numeric_predictors()
  #   ) %>%
   step_dummy(all_nominal_predictors()) %>% # converte variáveis qualitativas em variáveis dummy
  # step_smote(HeartDisease, over_ratio = 1) %>% # Balanceamento de classes usando SMOTE
  # step_upsample(diagnosis) %>% # Balanceamento de classes usando upsample
  # step_downsample(diagnosis) %>% # Balanceamento de classes usando downsample
  step_nzv(all_numeric_predictors()) %>% # remove variáveis que têm variância próxima de zero
  step_corr( # remove preditores que tenham alta correlação com algum outro preditor
    all_numeric_predictors(),
    threshold = 0.8,
    method = "spearman"
    )
Receita Utilizada

Na aplicação desse banco de dados as receitas utilizadas foram:

  • Transformação de Yeo - Johnson nas variáveis numéricas;

  • Polinômios de Grau dois, ou seja, elevando as variáveis numéricas para termos talvez uma classificação melhor;

  • Normalização das variáveis númericas;

  • Transformação das variáveis categóricas em variáveis Dummy;

  • Remoção de variáveis preditoras altamente correlacionadas, estabelecido como ponto de corte uma correlação através do método de spearman, não sensível apenas a correlação linear, foi de 0.80.

Depois de aplicar as receitas precisamos continuar para extrair o banco aplicando as receitas.

Código
# Dessa vez, não usaremos os "dados preparados" explicitamente, mas criarei esse objeto para verificarmos o efeito do smote e da imputação dos valores ausentes
set.seed(2024)
banco_preparado <- 
  banco_receita %>% # usa a receita
  prep() %>% # aplica a receita no conjunto de treinamento
  juice() # extrai apenas o dataframe preprocessado

Quantidade de Observações no banco de Treinamento, de acordo com a variável destino, ou feature target, que é a Heart Disease

Código
tabela <- banco_preparado %>% 
  group_by(HeartDisease) %>% 
  summarise(quantidade = n(), 
            percentual = n()/dim(banco_preparado)[1])
  
colnames(tabela) <- c("Heart Disease","Quantidade", "Proporção")

tabela %>%
  gt::gt() %>% 
  tab_header(
    title = gt::html("<b> Quantidade de Observações de acordo com as classes da variável dependente 
                     </b>"), 
    subtitle = glue::glue("No banco de Treinamento")) %>% 
  gt::tab_source_note(
    gt::html("<b> Fonte:</b> Elaboração Própria")
  ) %>% 
  fmt_number(columns = Quantidade, 
             suffixing = TRUE, 
             decimals = 0) %>% 
  fmt_percent(columns = Proporção, 
              decimals = 2, 
              scale_values = T, 
              dec_mark = ",", 
              sep_mark = ".") %>% 
  cols_align(align = "center", columns = everything()) %>% 
  tab_style(
    style = list(
      cell_text(weight =  "bolder")
    ),
    locations = cells_body(columns = everything())
  )
Quantidade de Observações de acordo com as classes da variável dependente
No banco de Treinamento
Heart Disease Quantidade Proporção
Normal 328 44,69%
Doença cardiaca 406 55,31%
Fonte: Elaboração Própria

Conjunto de Validação

Utilizamos o método \(k\)-fold cross-validation para construir um conjunto de validação com \(k\) folds. Consideramos um procedimento com \(k = 10\) folds. Os dados são particionados em 10 partes utilizando amostragem estratificada e, em cada iteração, os modelos são ajustados em um conjunto de treinamento com composto por 9 dessas partes e avaliado em um conjunto de teste composto por 1 dessas partes. Esse procedimento foi utilizado para avaliar o modelo e obter os valores ótimos dos hiperparâmetros dos modelos.

Código
cv_folds <- vfold_cv(banco_treino, 
                     v = 10, 
                     strata = HeartDisease)

Otimização dos Hiperparâmetros e definição dos modelos

Os hiperparâmetros dos modelos foram otimizados no processo de validação cruzada. A busca pelos valores ótimos dos hiperparâmetros se deu através de um processo de busca em uma grade aleatória de valores definida através de um esquema de hipercubo latino, visando preencher adequadamente o espaço de valores dos hiperparâmetros.

Código
knn_spec <- nearest_neighbor(neighbors = tune()) %>% # K-NN
  set_mode("classification") %>%
  set_engine("kknn")

nbayes_spec <- naive_Bayes() %>% # Naive Bayes
  set_engine("naivebayes") %>%
  set_mode("classification")

lda_spec <- discrim_linear() %>% # Linear discriminant analysis
  set_engine("MASS") %>%
  set_mode("classification")

qda_spec <- discrim_quad() %>% # Quadratic discriminant analysis
  set_engine("MASS") %>%
  set_mode("classification")

reg_log_spec <- logistic_reg(penalty = tune(), mixture = tune()) %>% # RL
  set_engine(engine = "glmnet", standardize = FALSE) %>%
  set_mode("classification")

dec_tree <- decision_tree(cost_complexity = tune(), 
                          min_n = tune(), 
                          tree_depth = tune()) %>% # Decision tree
  set_engine(engine = "rpart") %>% 
  set_mode("classification")

bagg_tree <- bag_tree(cost_complexity = tune(), 
                      min_n = tune(), 
                      tree_depth = tune()) %>% # Bagged tree
  set_engine(engine = "rpart") %>% 
  set_mode("classification")

random_forest <- rand_forest(mtry = tune(), 
                             min_n = tune(), 
                             trees = tune()) %>% # Random Forest
  set_engine(engine = "ranger", importance = "impurity") %>% 
  set_mode("classification")

xgboost <- boost_tree(tree_depth = tune(), 
                      learn_rate = tune(), 
                      loss_reduction = tune(), 
                      min_n = tune(), 
                      sample_size = tune(), 
                      trees = tune(), 
                      mtry = tune()) %>% # Boosted trees
  set_engine(engine = "xgboost") %>% 
  set_mode("classification")

lsvm <- svm_linear(cost = tune(),           # SVM Linear 
                   margin = tune()) %>% 
  set_engine(engine = "kernlab") %>% 
  set_mode("classification")


rsvm_spec <- svm_rbf(cost = tune(),   # RBF/Gaussian Kernel SVM
                     rbf_sigma = tune(), 
                     margin = tune()) %>% 
  set_engine(engine = "kernlab") %>% 
  set_mode("classification")

psvm_spec <- svm_poly(cost = tune(),  # Polynomial Kernel SVM
                      degree = tune(), 
                      scale_factor = tune(), 
                      margin = tune()) %>% 
  set_engine(engine = "kernlab") %>% 
  set_mode("classification")

rede <- mlp(hidden_units = tune(),     # MLP - Perceptron com multi camadas  
            epochs = tune(), 
            activation = "softmax",
            penalty = tune(),
            mode = "classification", 
            engine = "nnet")
  • Prepando um workflow, ou seja, prepando um fluxo de tarefas para ele treinar os modelos ao mesmo tempo.
Código
wf = workflow_set(
  preproc = list(banco_receita),
  models = list(
    KNN = knn_spec,
    Nayve_Bayes = nbayes_spec,
    LDA = lda_spec,
    QDA = qda_spec,
    Reg_log = reg_log_spec, 
    decision = dec_tree, 
    bag_treeee = bagg_tree,
    random_fore = random_forest, 
    xgbost = xgboost,
    svm_lin = lsvm, 
    svm_rf = rsvm_spec, 
    psvm = psvm_spec, 
    rede_neural = rede
  )
) %>%
  mutate(wflow_id = gsub("(recipe_)", "", wflow_id))

Treinando os modelos

Código
tempo_inicial <- Sys.time()
grid_ctrl = control_grid(
  save_pred = TRUE,
  parallel_over = "resamples",
  save_workflow = TRUE
)


grid_results = wf %>%
  workflow_map(
    seed = 2024,
    resamples = cv_folds,
    grid = 10,
    control = grid_ctrl)


tempo_final <- Sys.time()

tempo <- as.numeric(difftime(tempo_final, tempo_inicial, units = "secs"))

O tempo total de treinamento do modelo foi 200.965714 horas, 3 minutos e 20.965714 segundos.

Verificando os resultados dos modelos

  • Verificando o resultado de maneira geral de todos os modelos, de acordo com duas métricas, sendo a Curva Roc e a Acurácia.
Código
autoplot(grid_results)

Resultados preliminares
  • Podemos observar que nos dados de treinamento os melhores modelos que tiveram os melhores ajustes foram os modelos de Random Forest, com base na curva ROC e na Acurácia. Vamos continuar observando qual o melhor modelo e vendo os valores de algumas métricas.
  • Visualizando as métricas do melhor modelo de cada método o resultado dos hiperparâmetros.
Código
autoplot(grid_results, select_best = TRUE, metric = "roc_auc")

  • Observando os valores de maneira mais analítica
Código
results <- workflowsets::rank_results(grid_results,
                          select_best = TRUE,
                          rank_metric = "roc_auc") %>%
  filter(.metric == "roc_auc") %>%
  dplyr::select(wflow_id, mean, std_err, model, rank)

colnames(results) <- c("Método", "Média", "Desvio Padrão", "Modelo", "Ranking")

results$Método[which(results$Método=="Reg_log")] <- "Regressão Logística"
results$Método[which(results$Método=="LDA")] <- "Discriminante Linear"
results$Método[which(results$Método=="QDA")] <- "Discriminante Quadrática"
results$Método[which(results$Método=="KNN")] <- "Knn - mais próximos"
results$Método[which(results$Método=="Nayve_Bayes")] <- "Nayve Bayes"
results$Método[which(results$Método=="random_fore")] <- "Floresta Aleatória"
results$Método[which(results$Método=="bag_treeee")] <- "Bagged Trees"
results$Método[which(results$Método=="decision")] <- "Decision Tree"
results$Método[which(results$Método=="xgbost")] <- "XG Boost"
results$Método[which(results$Método=="svm_lin")] <- "SVM Linear"
results$Método[which(results$Método=="svm_rf")] <- "SVM RBF"
results$Método[which(results$Método=="psvm")] <- "SVM Polinomial"
results$Método[which(results$Método=="rede_neural")] <- "Rede Neural"


results$Média <- round(results$Média, 4)
results$`Desvio Padrão` <- round(results$`Desvio Padrão`, 4)

results %>% gt() %>% 
  gt::tab_header(
    title = gt::html("<b> Resultado do Treinamento dos Modelos</b>"), 
    subtitle = glue::glue("De acordo com a métrica da curva Roc")) %>% 
  gt::tab_source_note(
    gt::html("<b> Fonte:</b> Elaboração própria")
  ) %>% 
  gt::data_color(
    columns = Média, 
    colors = scales::col_numeric(
      palette = colorspace::sequential_hcl(n = 5, palette = "Green"), 
      domain = c(min(results$Média), max(results$Média)),
      reverse = TRUE
    )
  ) %>% 
  cols_align(align = "center", columns = everything())
Resultado do Treinamento dos Modelos
De acordo com a métrica da curva Roc
Método Média Desvio Padrão Modelo Ranking
SVM Polinomial 0.9248 0.0088 svm_poly 1
SVM Linear 0.9243 0.0090 svm_linear 2
Regressão Logística 0.9238 0.0091 logistic_reg 3
Floresta Aleatória 0.9226 0.0096 rand_forest 4
Discriminante Linear 0.9222 0.0095 discrim_linear 5
Bagged Trees 0.9195 0.0085 bag_tree 6
SVM RBF 0.9194 0.0079 svm_rbf 7
XG Boost 0.9187 0.0091 boost_tree 8
Rede Neural 0.9063 0.0136 mlp 9
Knn - mais próximos 0.8971 0.0108 nearest_neighbor 10
Nayve Bayes 0.8933 0.0091 naive_Bayes 11
Discriminante Quadrática 0.8890 0.0142 discrim_quad 12
Decision Tree 0.8766 0.0132 decision_tree 13
Fonte: Elaboração própria
Código
autoplot(grid_results, select_best = TRUE, metric = "accuracy")

  • Observando de maneira analítica, temos a seguinte informação:
Código
results_acc <- workflowsets::rank_results(grid_results,
                          select_best = TRUE,
                          rank_metric = "accuracy") %>%
  filter(.metric == "accuracy") %>%
  dplyr::select(wflow_id, mean, std_err, model, rank)

colnames(results_acc) <- c("Método", "Média", "Desvio Padrão", "Modelo", "Ranking")

results_acc$Método[which(results_acc$Método=="Reg_log")] <- "Regressão Logística"
results_acc$Método[which(results_acc$Método=="LDA")] <- "Discriminante Linear"
results_acc$Método[which(results_acc$Método=="QDA")] <- "Discriminante Quadrática"
results_acc$Método[which(results_acc$Método=="KNN")] <- "Knn - mais próximos"
results_acc$Método[which(results_acc$Método=="Nayve_Bayes")] <- "Nayve Bayes"
results_acc$Método[which(results_acc$Método=="random_fore")] <- "Floresta Aleatória"
results_acc$Método[which(results_acc$Método=="bag_treeee")] <- "Bagged Trees"
results_acc$Método[which(results_acc$Método=="decision")] <- "Decision Tree"
results_acc$Método[which(results_acc$Método=="xgbost")] <- "XG Boost"
results_acc$Método[which(results_acc$Método=="svm_lin")] <- "SVM Linear"
results_acc$Método[which(results_acc$Método=="svm_rf")] <- "SVM RBF"
results_acc$Método[which(results_acc$Método=="psvm")] <- "SVM Polinomial"
results_acc$Método[which(results_acc$Método=="rede_neural")] <- "Rede Neural"

results_acc$Média <- round(results_acc$Média, 4)
results_acc$`Desvio Padrão` <- round(results_acc$`Desvio Padrão`, 4)

results_acc %>% gt() %>% 
  gt::tab_header(
    title = gt::html("<b> Resultado do Treinamento dos Modelos</b>"), 
    subtitle = glue::glue("De acordo com a métrica da Acurácia")) %>% 
  gt::tab_source_note(
    gt::html("<b> Fonte:</b> Elaboração própria")
  ) %>% 
  gt::data_color(
    columns = Média, 
    colors = scales::col_numeric(
      palette = colorspace::sequential_hcl(n = 5, palette = "Green"), 
      domain = c(min(results_acc$Média), max(results_acc$Média)),
      reverse = TRUE
    )
  ) %>% 
  cols_align(align = "center", columns = everything())
Resultado do Treinamento dos Modelos
De acordo com a métrica da Acurácia
Método Média Desvio Padrão Modelo Ranking
SVM Linear 0.8705 0.0096 svm_linear 1
SVM Polinomial 0.8692 0.0087 svm_poly 2
Floresta Aleatória 0.8652 0.0091 rand_forest 3
Regressão Logística 0.8651 0.0100 logistic_reg 4
Discriminante Linear 0.8637 0.0069 discrim_linear 5
XG Boost 0.8624 0.0116 boost_tree 6
SVM RBF 0.8597 0.0108 svm_rbf 7
Rede Neural 0.8502 0.0131 mlp 8
Bagged Trees 0.8501 0.0095 bag_tree 9
Knn - mais próximos 0.8501 0.0150 nearest_neighbor 10
Decision Tree 0.8337 0.0111 decision_tree 11
Discriminante Quadrática 0.8218 0.0179 discrim_quad 12
Nayve Bayes 0.8175 0.0133 naive_Bayes 13
Fonte: Elaboração própria

Agora é necessário selecionar o melhor modelos dos que foram treinados (um de cada modelo), com base na estimação dos hiperparâmetros para aplicar os dados de teste.

Código
best_set_linear = grid_results %>% 
  extract_workflow_set_result("Reg_log") %>% 
  select_best(metric = "accuracy")
best_set_knn = grid_results %>% 
  extract_workflow_set_result("KNN") %>% 
  select_best(metric = "accuracy")
best_set_nbayes = grid_results %>%
  extract_workflow_set_result("Nayve_Bayes") %>% 
  select_best(metric = "accuracy")
best_set_lda = grid_results %>% 
  extract_workflow_set_result("LDA") %>% 
  select_best(metric = "accuracy")
best_set_qda = grid_results %>% 
  extract_workflow_set_result("QDA") %>% 
  select_best(metric = "accuracy")
best_set_rand_fore = grid_results %>% 
  extract_workflow_set_result("random_fore") %>% 
  select_best(metric = "accuracy")
best_set_decision = grid_results %>% 
  extract_workflow_set_result("decision") %>% 
  select_best(metric = "accuracy")
best_set_bag = grid_results %>% 
  extract_workflow_set_result("bag_treeee") %>% 
  select_best(metric = "accuracy")
best_set_xbost = grid_results %>% 
  extract_workflow_set_result("xgbost") %>% 
  select_best(metric = "accuracy")
best_set_svmlin = grid_results %>% 
  extract_workflow_set_result("svm_lin") %>% 
  select_best(metric = "accuracy")
best_set_svmrf = grid_results %>% 
  extract_workflow_set_result("svm_rf") %>% 
  select_best(metric = "accuracy")
best_set_psvm = grid_results %>% 
  extract_workflow_set_result("psvm") %>% 
  select_best(metric = "accuracy")
best_set_mlp = grid_results %>% 
  extract_workflow_set_result("rede_neural") %>% 
  select_best(metric = "accuracy")

Os hiperparâmetros do melhor modelo de cada método foi

O modelo KNN que foi o melhor teve um número de 14 vizinhos.

Código
best_set_linear %>% 
  gt() %>% 
  gt::tab_header(
    title = gt::html("<b> Resultado dos hiperparâmetros do melhor ajuste</b>"), 
    subtitle = glue::glue("Do modelo de Regressão Logística")) %>% 
  gt::tab_source_note(
    gt::html("<b> Fonte:</b> Elaboração própria")
  ) %>% 
  gt::cols_label(
    penalty = "Penalização", 
    mixture = "Mistura", 
    .config = "Configuração"
  ) %>% 
  fmt_number(columns = c(penalty, mixture), 
             decimals = 4) %>% 
  cols_align(align = "center", columns = everything())
Resultado dos hiperparâmetros do melhor ajuste
Do modelo de Regressão Logística
Penalização Mistura Configuração
0.0128 0.4759 Preprocessor1_Model05
Fonte: Elaboração própria
Código
best_set_decision %>% 
  gt() %>% 
  gt::tab_header(
    title = gt::html("<b> Resultado dos hiperparâmetros do melhor ajuste</b>"), 
    subtitle = glue::glue("Do modelo de Árvore de Decisão")) %>% 
  gt::tab_source_note(
    gt::html("<b> Fonte:</b> Elaboração própria")
  ) %>% 
  fmt_number(columns = c(cost_complexity), 
             decimals = 5) %>% 
  cols_align(align = "center", columns = everything())
Resultado dos hiperparâmetros do melhor ajuste
Do modelo de Árvore de Decisão
cost_complexity tree_depth min_n .config
0.00009 3 40 Preprocessor1_Model07
Fonte: Elaboração própria
Código
best_set_bag %>% 
  gt() %>% 
  gt::tab_header(
    title = gt::html("<b> Resultado dos hiperparâmetros do melhor ajuste</b>"), 
    subtitle = glue::glue("Do modelo de Bageed Tree")) %>% 
  gt::tab_source_note(
    gt::html("<b> Fonte:</b> Elaboração própria")
  ) %>% 
  fmt_number(columns = cost_complexity, 
             decimals = 7) %>% 
  cols_align(align = "center", columns = everything())
Resultado dos hiperparâmetros do melhor ajuste
Do modelo de Bageed Tree
cost_complexity tree_depth min_n .config
0.0019777 7 19 Preprocessor1_Model10
Fonte: Elaboração própria
Código
best_set_rand_fore %>% 
  gt() %>% 
  gt::tab_header(
    title = gt::html("<b> Resultado dos hiperparâmetros do melhor ajuste</b>"), 
    subtitle = glue::glue("Do modelo de Floresta Aleatória")) %>% 
  gt::tab_source_note(
    gt::html("<b> Fonte:</b> Elaboração própria")
  ) %>% 
  cols_align(align = "center", 
             columns = everything())
Resultado dos hiperparâmetros do melhor ajuste
Do modelo de Floresta Aleatória
mtry trees min_n .config
3 1573 16 Preprocessor1_Model03
Fonte: Elaboração própria
Código
best_set_xbost %>% 
  gt() %>% 
  gt::tab_header(
    title = gt::html("<b> Resultado dos hiperparâmetros do melhor ajuste</b>"), 
    subtitle = glue::glue("Do modelo XGBOOST")) %>% 
  gt::tab_source_note(
    gt::html("<b> Fonte:</b> Elaboração própria")
  ) %>% 
  cols_align(align = "center", 
             columns = everything()) %>% 
  fmt_number(columns = c(learn_rate, loss_reduction, sample_size), 
             decimals = 4)
Resultado dos hiperparâmetros do melhor ajuste
Do modelo XGBOOST
mtry trees min_n tree_depth learn_rate loss_reduction sample_size .config
11 1818 21 6 0.0022 0.6912 0.8391 Preprocessor1_Model06
Fonte: Elaboração própria
Código
best_set_svmlin %>% 
  gt() %>% 
  gt::tab_header(
    title = gt::html("<b> Resultado dos hiperparâmetros do melhor ajuste</b>"), 
    subtitle = glue::glue("Do modelo de SVM Linear")) %>% 
  gt::tab_source_note(
    gt::html("<b> Fonte:</b> Elaboração própria")
  ) %>% 
  cols_align(align = "center", 
             columns = everything()) %>% 
  fmt_number(columns = c(cost, margin), 
             decimals = 4)
Resultado dos hiperparâmetros do melhor ajuste
Do modelo de SVM Linear
cost margin .config
0.0232 0.1173 Preprocessor1_Model03
Fonte: Elaboração própria
Código
best_set_svmrf %>% 
  gt() %>% 
  gt::tab_header(
    title = gt::html("<b> Resultado dos hiperparâmetros do melhor ajuste</b>"), 
    subtitle = glue::glue("Do modelo de SVM RBF")) %>% 
  gt::tab_source_note(
    gt::html("<b> Fonte:</b> Elaboração própria")
  ) %>% 
  cols_align(align = "center", 
             columns = everything()) %>% 
  fmt_number(columns = c(cost, margin, rbf_sigma), 
             decimals = 4)
Resultado dos hiperparâmetros do melhor ajuste
Do modelo de SVM RBF
cost rbf_sigma margin .config
1.5309 0.1228 0.1225 Preprocessor1_Model01
Fonte: Elaboração própria
Código
best_set_psvm %>% 
  gt() %>% 
  gt::tab_header(
    title = gt::html("<b> Resultado dos hiperparâmetros do melhor ajuste</b>"), 
    subtitle = glue::glue("Do modelo de SVM Polinomial")) %>% 
  gt::tab_source_note(
    gt::html("<b> Fonte:</b> Elaboração própria")
  ) %>% 
  cols_align(align = "center", 
             columns = everything()) %>% 
  fmt_number(columns = c(cost, margin, scale_factor), 
             decimals = 4) %>% 
  fmt_number(columns = degree, 
             decimals = 0)
Resultado dos hiperparâmetros do melhor ajuste
Do modelo de SVM Polinomial
cost degree scale_factor margin .config
16.3060 2 0.0005 0.1852 Preprocessor1_Model09
Fonte: Elaboração própria
Código
best_set_mlp %>% 
  gt() %>% 
  gt::tab_header(
    title = gt::html("<b> Resultado dos hiperparâmetros do melhor ajuste </b>"), 
    subtitle = glue::glue("Do modelo de Rede Neural Perceptron com Multi Camadas")) %>% 
  gt::tab_source_note(
    gt::html("<b> Fonte:</b> Elaboração própria")
  ) %>% 
  cols_align(align = "center", 
             columns = everything()) %>% 
  fmt_number(columns = penalty, 
             decimals = 4) %>% 
  fmt_number(columns = c(hidden_units, epochs), 
             decimals = 0)
Resultado dos hiperparâmetros do melhor ajuste
Do modelo de Rede Neural Perceptron com Multi Camadas
hidden_units penalty epochs .config
1 0.0001 78 Preprocessor1_Model08
Fonte: Elaboração própria

Avaliando os modelos no conjunto de teste

Esses conjuntos de hiperparâmetros ótimos foram utilizados para reajustar os modelos no conjunto de treinamento completo para, em seguida, obter predições das classes da variável alvo no conjunto de teste. Foram calculadas as seguintes medidas no conjunto de teste: acurácia, área sob a curva ROC, F-measure, precision, recall, especificidade e Kappa.

  • Para isso criamos uma função que vai facilitar todo o trabalho
Código
resultado_teste <- function(rc_rslts, fit_obj, par_set, split_obj) {
  res <- rc_rslts %>%
    extract_workflow(fit_obj) %>%
    finalize_workflow(par_set) %>%
    last_fit(split = split_obj,
             metrics = metric_set(
              accuracy,roc_auc,
              f_meas,precision,
              recall,spec,kap))
  res
}
Código
resultado_teste_reg_log <- resultado_teste(grid_results, "Reg_log", best_set_linear, split_inicial)
resultado_teste_knn <- resultado_teste(grid_results, "KNN", best_set_knn, split_inicial)
resultado_teste_lda <- resultado_teste(grid_results, "LDA", best_set_lda, split_inicial)
resultado_teste_qda <- resultado_teste(grid_results, "QDA", best_set_qda, split_inicial)
resultado_teste_naive <- resultado_teste(grid_results, "Nayve_Bayes", best_set_nbayes, split_inicial)
resultado_teste_decision <- resultado_teste(grid_results, "decision", best_set_decision, split_inicial)
resultado_teste_bag <- resultado_teste(grid_results, "bag_treeee", best_set_bag, split_inicial)
resultado_teste_rand_fore <- resultado_teste(grid_results, "random_fore", best_set_rand_fore, split_inicial)
resultado_teste_xgbost <- resultado_teste(grid_results, "xgbost", best_set_xbost, split_inicial)
resultado_teste_svmlin <- resultado_teste(grid_results, "svm_lin", best_set_svmlin, split_inicial)
resultado_teste_svmrf <- resultado_teste(grid_results, "svm_rf", best_set_svmrf, split_inicial)
resultado_teste_psvm <- resultado_teste(grid_results, "psvm", best_set_psvm, split_inicial)
resultado_teste_mlp <- resultado_teste(grid_results, "rede_neural", best_set_mlp, split_inicial)
  • Agora, vamos coletar as métricas e observar os resultados:
Código
metrics_table <- rbind(collect_metrics(resultado_teste_reg_log)$.estimate, 
                       collect_metrics(resultado_teste_knn)$.estimate, 
                       collect_metrics(resultado_teste_lda)$.estimate, 
                       collect_metrics(resultado_teste_qda)$.estimate, 
                       collect_metrics(resultado_teste_naive)$.estimate, 
                       collect_metrics(resultado_teste_decision)$.estimate, 
                       collect_metrics(resultado_teste_bag)$.estimate,
                       collect_metrics(resultado_teste_rand_fore)$.estimate, 
                       collect_metrics(resultado_teste_xgbost)$.estimate, 
                       collect_metrics(resultado_teste_svmlin)$.estimate, 
                       collect_metrics(resultado_teste_svmrf)$.estimate, 
                       collect_metrics(resultado_teste_psvm)$.estimate, 
                       collect_metrics(resultado_teste_mlp)$.estimate)

Ajustando a tabela de métricas

Código
metrics_table <- round(metrics_table, 4)

row_names <- c("Regressão Logística", "KNN", "Discriminante Linear", "Discriminante Quadrático", "Naive Bayes", "Árvore de Decisão", "Bageed Tree", "Floresta Aleatória", "XG Boost", "SVM Linear", "SVM RBF", "SVM Polinomial", "Rede Neural - Multi Camada")

metrics_table <- cbind(row_names, metrics_table)

metrics_table <- metrics_table %>% 
  as.tibble()

O resultado da aplicação dos dados no conjunto de treinamento, as métricas para análise do melhor método, o qual deve ser aplicado nesses dados segue abaixo.

Código
colnames(metrics_table) <- c("Método", "Acurácia", "Curva Roc", "f_means", "Precisão", "Recall", "Específicidade", "Kappa")

metrics_table <- metrics_table %>% 
  mutate(Acurácia = as.numeric(Acurácia), 
         `Curva Roc` = as.numeric(`Curva Roc`), 
         f_means = as.numeric(f_means), 
         Precisão = as.numeric(Precisão), 
         Recall = as.numeric(Recall), 
         Específicidade = as.numeric(Específicidade), 
         Kappa = as.numeric(Kappa)) %>% 
  arrange(desc(Acurácia), desc(`Curva Roc`), desc(f_means), desc(Kappa)) 
  
metrics_table %>%  
  gt::gt() %>% 
  gt::tab_header(
    title = gt::html("<b> Resultado dos modelos nos dados de teste</b>"), 
    subtitle = glue::glue("De acordo com algumas métricas")) %>% 
  gt::tab_source_note(
    gt::html("<b> Fonte:</b> Elaboração própria")
  ) %>%  
  gt::data_color(
    columns = Acurácia, 
    colors = scales::col_numeric(
      palette = colorspace::sequential_hcl(n = 10, palette = "Green"), 
      domain = c(min(metrics_table$Acurácia), max(metrics_table$Acurácia)),
      reverse = TRUE
    )
  ) %>% 
  gt::data_color(
    columns = `Curva Roc`, 
    colors = scales::col_numeric(
      palette = colorspace::sequential_hcl(n = 10, palette = "Green"), 
      domain = c(min(metrics_table$`Curva Roc`), max(metrics_table$`Curva Roc`)),
      reverse = TRUE
    )
  ) %>% 
  cols_align(align = "center", columns = everything()) %>% 
  tab_style(
    style = list(
      cell_text(weight = "bold")
    ), 
    locations = cells_body(columns = c(Acurácia, `Curva Roc`, f_means, Kappa))
  )
Resultado dos modelos nos dados de teste
De acordo com algumas métricas
Método Acurácia Curva Roc f_means Precisão Recall Específicidade Kappa
Regressão Logística 0.8967 0.8834 0.8889 0.8780 0.9118 0.7908 0.9452
Rede Neural - Multi Camada 0.8859 0.8712 0.8765 0.8659 0.9020 0.7687 0.9448
XG Boost 0.8859 0.8696 0.8861 0.8537 0.9118 0.7682 0.9448
SVM Polinomial 0.8859 0.8696 0.8861 0.8537 0.9118 0.7682 0.9439
Discriminante Linear 0.8859 0.8696 0.8861 0.8537 0.9118 0.7682 0.9438
SVM RBF 0.8859 0.8679 0.8961 0.8415 0.9216 0.7676 0.9290
Floresta Aleatória 0.8804 0.8608 0.8947 0.8293 0.9216 0.7563 0.9433
KNN 0.8750 0.8623 0.8471 0.8780 0.8725 0.7479 0.9347
Bageed Tree 0.8750 0.8571 0.8734 0.8415 0.9020 0.7461 0.9382
SVM Linear 0.8750 0.8553 0.8831 0.8293 0.9118 0.7455 0.9419
Árvore de Decisão 0.8533 0.8323 0.8481 0.8171 0.8824 0.7019 0.8616
Discriminante Quadrático 0.8424 0.8362 0.7789 0.9024 0.7941 0.6859 0.9314
Naive Bayes 0.8370 0.8235 0.7955 0.8537 0.8235 0.6724 0.8856
Fonte: Elaboração própria
Código
test_results_rf = grid_results %>% 
   extract_workflow("Reg_log") %>% 
   finalize_workflow(best_set_linear) %>% 
   last_fit(split = split_inicial,
            metrics = metric_set(
              recall, precision, f_meas,
              accuracy, kap,
              roc_auc, sens, spec)
            )
Informações do melhor modelo

Visualizando a importância das variáveis para o problema de classificação com base no melhor modelo que foi selecionando, tendo esse modelo via método Regressão Logística, e com os seguintes parâmetros:

  • O valor da penalização do modelo escolhido de Regressão Logística foi de 0.0128;

  • O valor da mistura do modelo escolhido de Regressão Logística foi de 0.4759.

Código
test_results_rf %>% 
  pluck(".workflow",1) %>% 
  extract_fit_parsnip() %>% 
  vip::vip()

Comentário: Acima podemos analisar o grau de importância das variáveis.

Conclusão

Na tarefa passada, sem aplicar alguns desses métodos e com menos variáveis, tivemos uma acurácia menor do que 80%, agora com a aplicação desses métodos e com a utilização de mais variáveis, podemos observar que através da acurácia o melhor modelo foi Regressão Logística e a Acurácia foi 0.8967, e o pior método pela acurácia foi o método Naive Bayes.