機械学習でデータの分類をやろうと思ったとき、正例と負例のデータ数が不均衡(imbalanced)であったとき、問題が生じます。
この記事では、iris データを用いて不均衡データを作成し、その分類にどのような問題が生じるのかを確認します。
さらに、不均衡データの分類における問題を解決する手法として、SMOTE アルゴリズム(Synthetic Minority Over-sampling Technique)を適用してみます。
まず、iris データの versicolor と virginica を取り出し、Sepal.Length と Petal.Width のみを変数としたデータを作成します。
library(dplyr)
data <- iris %>%
filter(Species %in% c("versicolor", "virginica")) %>%
select(Sepal.Length, Petal.Width, Species) %>%
mutate(Species=as.factor(as.character(Species)))
head(data, 3)
## Sepal.Length Petal.Width Species
## 1 7.0 1.4 versicolor
## 2 6.4 1.5 versicolor
## 3 6.9 1.5 versicolor
このデータを、分類精度を検証するために、訓練データとテストデータに分けます。
訓練データのみを用いて分類器を作成し、テストデータをうまく分類できるかどうかを見るためです。
library(caret)
train_index <- createDataPartition(data$Species, p=0.7, list=FALSE)
train_data <- data[train_index,]
test_data <- data[-train_index,]
myplot <- function(data) {
plot(Sepal.Length ~ Petal.Width, data=data, col=Species, pch=19,
xlim=c(1, 2.5), ylim=c(4.5, 8))
}
myplot(train_data)
このデータは不均衡データではありません。
このデータに対してランダムフォレストを用いて分類器を作成してみましょう。
compute <- function(data) {
model <- train(data %>% select(-Species), data$Species,
method="rf", preProcess=c("center", "scale"),
trControl=trainControl(method="oob"), ntree=2000)
all_pred <- extractPrediction(list(model),
testX=test_data %>% select(-Species),
testY=test_data$Species)
pred <- all_pred %>% filter(dataType == "Test")
cm <- confusionMatrix(pred$pred, pred$obs)
cm$table
}
compute(train_data)
## Reference
## Prediction versicolor virginica
## versicolor 12 1
## virginica 3 14
不均衡データでない場合、分類結果はこのようになりました。
では、このデータを不均衡データにしてみましょう。
35 個ある virginica のデータを 6 個にしてみます。
data_imbalanced <- train_data %>%
slice(c(1:35, sample(36:70, size = 6, replace = FALSE)))
myplot(data_imbalanced)
この不均衡データに対して分類器を作成してみましょう。
compute(data_imbalanced)
## Reference
## Prediction versicolor virginica
## versicolor 15 7
## virginica 0 8
分類結果はこのようになりました。
この結果をよく見ると、virginica を versicolor と誤分類することが多いようです。
これは、訓練データにおいて virginica の割合が少ないため、分類器はとりあえず versicolor に分類したほうが正解しやすくなる、という状況から生じた誤分類です。
このような誤分類が生じることが、不均衡データの分類における問題点です。
このような問題を解決するために考案された手法の一つに SMOTE があります。
SMOTE は不均衡データに対して、少ない方のデータを人工的に生成し、多い方のデータを削除することによって、均衡データに近づけるという手法です。
R で SMOTE を行うには、DMwR パッケージの SMOTE() 関数を使います。
library(DMwR)
data_smote <- SMOTE(Species ~ ., data = data_imbalanced)
myplot(data_smote)
SMOTE を適用することによって、データが均衡状態に近づきました。
このデータに対して分類器を作成してみましょう。
compute(data_smote)
## Reference
## Prediction versicolor virginica
## versicolor 15 1
## virginica 0 14
データの不均衡状態が緩和されたため、先ほどのような誤分類が生じなくなりました。
不均衡データと SMOTE 適用後のデータの各 Species のデータ数を確認してみましょう。
data_imbalanced %>% count(Species)
## Species n
## 1 versicolor 35
## 2 virginica 6
data_smote %>% count(Species)
## Species n
## 1 versicolor 24
## 2 virginica 18
SMOTE によって、virginica のデータ数が増え、versicolor のデータが減ったことが確認できました。
今回は、SMOTE アルゴリズムの適用によって、不均衡データにおける分類の問題がうまく解決されましたが、データによっては SMOTE では解決されない場合があります。
SMOTE アルゴリズムの詳細や、不均衡データに対する SMOTE 以外の手法についての情報は、下記参考サイトをご参照下さい。
以上です。