決定木の練習

caretを使った決定木

# 必要なライブラリを読み込む
library(caret)
library(rpart)
library(tidyverse)
library(rpart.plot)

# テーマ設定
theme_set(theme_bw(base_size = 11))

# サンプルデータの作成

x <- rnorm(100, -1, 1)
y <- rnorm(100, -1, 1)
res <- rep("A", 100)

data_01 <- tibble(x,y,res)

x <- rnorm(100, 1, 1)
y <- rnorm(100, 1, 1)
res <- rep("B", 100)

data_02 <- tibble(x,y,res)

data <- full_join(data_01, data_02)

# コスト行列の作成
# 0を1と判断する際のコストを1を0と判断するコストの100倍に設定
cost_matrix <- matrix(c(0, 1, 1, 0), ncol = 2)
rownames(cost_matrix) <- colnames(cost_matrix) <- levels(data$res)

# コスト感受性の決定木モデルを作成
model <- train(res ~ x + y, data=data, method="rpart", 
               parms = list(loss=cost_matrix))

# 決定木の境界をプロットする関数
plot_decision_boundary <- function(model, data) {
  # 予測用のデータを作成
  grid_data <- expand.grid(x = seq(min(data$x), max(data$x), length.out = 100),
                           y = seq(min(data$y), max(data$y), length.out = 100))
  grid_data$res <- predict(model, grid_data)
  
  # グラフをプロット
  ggplot(data, aes(x = x, y = y)) +
    geom_tile(data = grid_data, aes(fill = res), alpha = 0.6) +
    geom_point(aes(color = res)) +
    scale_fill_manual(values = c("#FFDDDD", "#DDFFDD")) +
    scale_color_manual(values = c("red", "green")) +
    labs(title = "Decision Boundary of Decision Tree")
}

# グラフを表示
plot_decision_boundary(model, data)

# 決定木を表示
prp(model$finalModel, type=2, extra=104, main="Decision Tree")