Análise Discriminante com validação cruzada

Manuel Luís Castro Ribeiro https://fenix.tecnico.ulisboa.pt/homepage/ist90267 (CERENA/IST-UL)https://cerena.pt/
2021-11-05

Introdução

A validação cruzada para modelos de análise discriminante (AD) pode ser implementada em R de várias formas. Vamos ver dois exemplos com as funções MASS::lda() e funçao caret::train(). Por fim vamos aproveitar o código das aulas de laboratório para criar uma função (user-defined function) que devolve o resultado da validação k-fold com análise discriminante.

Dados

Vamos aplicar as funções a uma tabela de dados sinteticos (gerados artificialmente, por simulação), com duas variáveis quantitativas (explicativas) e uma variável resposta (categorica, com 3 categorias):

# dataset sintetico
# x1 x2 são variaveis explicativas numericas
# y é a variavel resposta categorica

# variaveis explicativas
x1 <- rnorm(100) # rnorm() gera observações com dist normal std
x2 <- rnorm(100) 

# variavel resposta
y <- sample(c("a","b","c"), 100, replace = TRUE) # sample() gera amostar aleatoria
y <- factor(y) 

# cria o dataframe
datxy <- data.frame(y, x1, x2)

Funções

MASS::lda()

Esta função tem a validação cruzada (loocv) implementada. Basta especificar o argumento CV=T.

# livraria para calcular lda()
library(MASS)

# validação cruzada (looc)
lda0 <- lda(y ~ x1 + x2, CV = TRUE, data = datxy)
cat("Precisão: ", mean(lda0$class == datxy$y))
Precisão:  0.39

Para a análise discriminante quadrática :

# validação cruzada (looc)
qda0 <- qda(y ~ x1 + x2, CV = TRUE, data = datxy)
cat("Precisão: ", mean(qda0$class == datxy$y))
Precisão:  0.37

caret::train()

A livraria caret é muito flexivel, e integra vários métodos de aprendizagem automática. Para implementar a análise discriminante, a sintaxe desta função requere que se forneçam os parametros de treino. Vamos especificar o método de validação cruzada k-fold, com 10 partições:

# livraria para usar validaçao cruzada k-fold
library(caret)

# temos de dar os parametros de treino
control <- trainControl(method = "cv", number = 10)

A análise discriminante é implementada com a função train(), especificando os argumento method. A métrica usada (precisão) é especificada no argumento metric e os parâmetros de treino no argumento trControl.

# ajuste do modelo
lda1 <- train(y ~ x1 + x2,  data = datxy, method = "lda",  metric = "Accuracy", trControl = control)

# precisão
cat("Precisão: ", lda1$results[,2])
Precisão:  0.3380808

Especificando o metodo repeatedcv na função trainControl() podemos obter repetições do processo de reamostragem. Vamos repetir a validação cruzada k-fold 20 vezes.

# parametros de treino
# para repetir validação cruzada k-fold 20 vezes
control20 <- trainControl(method = "repeatedcv", number = 10, repeats = 20)

# ajuste do modelo
lda2 <- train(y ~ x1 + x2,  data = datxy, method = "lda",  metric = "Accuracy", trControl = control20)

# precisão
cat("Precisão: ", lda2$results[,2])
Precisão:  0.3848359

cv_ad()

Aproveitando o código das aulas de laboratório (lab3_class.R e lab4_cross.R), vamos criar uma função que vamos chamar cv_ad() e que executa validação k-fold com análise discriminante, realizando os passos seguintes:

  1. distribui aleatoriamente as amostras por K partições,
  2. De 1 até K:
    1. ajusta modelo lda (ou qda) às observações das K-1 partições de treino,
    2. prevê a classe nas obs. da partição de validação, usando o modelo ajustado,
    3. calcula e guarda o valor da precisão.
  3. Calcula e devolve a precisão media.

Recordando a sintaxe para construção de uma função é f = function (argumentos) { corpo da função } em que o valor devolvido pela função é geralmente a última linha que a função avalia ou return().

cv_ad = function(dataset, matx, vety, ad = "lda", k = 10){
  
  # n obs
  n <- nrow(dataset)
  
  # n de partições
  K <- k
  
  x <- matx
  y <- vety
  
  if(k>n) stop("nr máximo de partições = nr observaçoes (loocv).")
  
  # gerar as partições
  
  foldk <- if(n %% K == 0){ 
    # se n é divisivel por K
    # criam-se K partições com n/K elementos cada
  rep(1:K, each = n/K)
  } else {
    # se n não é divisivel por K, criam-se K-1 partições com n/K elementos 
    # e a última partição com numero de elementos = n/K + resto
    c(rep(1:(K-1), each = n %/% K),
      rep(K, n %/% K + n %% K))
  }
  
  # gera vetor de números aleatorios, sem reposição
  rnd <- sample(nrow(dataset), nrow(dataset)) 
  
  # usa o vetor para atribuir ordem (aleatoria) das amostras
  dtfold <- dataset[rnd,]
  
  # cria dataset com amostras distribuidas 
  # aleatoriamente pelas K partições
  dtfold <- data.frame(dtfold, foldk)
  
  # inicializa vetor resultados
  cv_kfold <- c() 
  for (i in 1:K){
    tr <- dtfold[!foldk == i,]; 
    va <- dtfold[foldk == i,]
    
    if (ad == "lda"){
      ad_treino <- lda(y ~ x1 + x2, data = tr)
    } else if(ad == "qda"){
      ad_treino <- qda(y ~ x1 + x2, data = tr)
    } else {
      stop("argumento ad errado") 
    }
    
    # previsão nos dados validação i
    ad_valid <- predict(ad_treino, newdata = va)
    
    # precisão nos dados validação i
    cv_kfold[i] = mean(ad_valid$class == va$y)
  }
  
  # função devolve a média daprecisão
  return(precisao = mean(cv_kfold))
}

A função devolve resultado da validação cruzada com análise discriminante:

cv_ad(dataset = datxy, matx = c(x1,x2), vety = y, ad = "qda", k = 20 )
[1] 0.33
cv_ad(dataset = datxy, matx = c(x1,x2), vety = y, ad = "lda", k = 20 )
[1] 0.34