Многоклассовая классификация

Author

Лесниченко Михаил

Published

May 31, 2026

Импортируем библиотеки

library(tidyverse)
library(tidytext)
library(tidymodels)
library(vip)
library(textstem)

tidymodels_prefer()

Загрузка данных

overview <- read_tsv("overview.tsv")
files <- list.files("british_fiction", full.names = TRUE)

data <- overview |>
  mutate(
    file = files,
    text = map_chr(file, read_file),
    doc_id = row_number(),
    author = factor(authorID)
  ) |>
  select(doc_id, author, text)

Предобработка

data <- data |>
  mutate(text_clean = str_to_lower(text),
         text_clean = str_replace_all(text_clean, "[^a-z\\s]", " "),
         text_clean = str_squish(text_clean))


# 2. Признаки

tokens <- data |>
  unnest_tokens(word, text_clean) |>
  anti_join(stop_words, by = "word") |>
  mutate(word = lemmatize_words(word)) |>
  filter(str_length(word) > 1)

# Стилистические признаки
style <- tokens |>
  group_by(doc_id, author) |>
  summarise(
    n_words = n(),
    ttr = n_distinct(word) / n(),
    avg_word_len = mean(str_length(word)),
    sd_word_len = sd(str_length(word)),
    .groups = "drop"
  ) |>
  mutate(across(c(ttr, avg_word_len, sd_word_len), ~replace_na(., 0)))

# Признаки предложений
sentence_features <- data |>
  mutate(
    n_sentences = str_count(text, "[.!?]+"),
    avg_sent_len = str_count(text_clean, "\\w+") / n_sentences
  ) |>
  mutate(avg_sent_len = ifelse(is.na(avg_sent_len), 0, avg_sent_len)) |>
  select(doc_id, n_sentences, avg_sent_len)

# Частотность служебных слов
function_words <- c("the", "be", "to", "of", "and", "a", "in", "that", 
                    "it", "for", "on", "with", "as", "at", "by", "from")

fw_freq <- data |>
  unnest_tokens(word, text_clean) |>
  filter(word %in% function_words) |>
  count(doc_id, word) |>
  group_by(doc_id) |>
  mutate(prop = n / sum(n)) |>
  select(-n) |>
  pivot_wider(names_from = word, values_from = prop, values_fill = 0)

# TF-IDF
word_features <- tokens |>
  count(doc_id, word) |>
  bind_tf_idf(word, doc_id, n) |>
  group_by(doc_id) |>
  slice_max(tf_idf, n = 20) |>
  ungroup() |>
  pivot_wider(names_from = word, values_from = tf_idf, values_fill = 0)

# Объединяем
model_data <- style |>
  left_join(sentence_features, by = "doc_id") |>
  left_join(fw_freq, by = "doc_id") |>
  left_join(word_features, by = "doc_id") |>
  mutate(across(where(is.numeric), ~replace_na(., 0))) |>
  mutate(author = factor(author)) |>
  select(-doc_id, where(~n_distinct(.) > 1))

EDA

p1 <- model_data |>
  count(author) |>
  ggplot(aes(x = reorder(author, n), y = n, fill = author)) +
  geom_col() +
  coord_flip() +
  labs(title = "Number of Texts per Author", x = "Author", y = "Count") +
  theme_minimal() +
  theme(legend.position = "none")
print(p1)

p2 <- model_data |>
  select(author, n_words, ttr, avg_word_len, avg_sent_len) |>
  pivot_longer(-author, names_to = "feature", values_to = "value") |>
  ggplot(aes(x = author, y = value, fill = author)) +
  geom_boxplot() +
  facet_wrap(~feature, scales = "free_y") +
  theme_minimal() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1)) +
  labs(title = "Linguistic Features by Author")
print(p2)

Моделирование

set.seed(123)
split <- initial_split(model_data, prop = 0.8, strata = author)
train <- training(split)
test <- testing(split)

cv_folds <- vfold_cv(train, v = 5, strata = author)
metrics <- metric_set(accuracy, f_meas, precision, recall)

# Recipe
rec <- recipe(author ~ ., data = train) |>
  step_zv(all_predictors()) |>
  step_normalize(all_numeric_predictors())

# Модели
model_lasso <- multinom_reg(penalty = tune(), mixture = 1) |>
  set_engine("glmnet") |>
  set_mode("classification")

model_rf <- rand_forest(mtry = tune(), trees = 100, min_n = tune()) |>
  set_engine("ranger", importance = "impurity") |>
  set_mode("classification")

model_ridge <- multinom_reg(penalty = tune(), mixture = 0) |>
  set_engine("glmnet") |>
  set_mode("classification")

# Workflows
wf_lasso <- workflow() |> add_recipe(rec) |> add_model(model_lasso)
wf_rf <- workflow() |> add_recipe(rec) |> add_model(model_rf)
wf_ridge <- workflow() |> add_recipe(rec) |> add_model(model_ridge)

# Настройка
set.seed(123)
lasso_grid <- tibble(penalty = 10^seq(-4, -1, length.out = 4))
rf_grid <- expand.grid(mtry = c(3, 5, 8), min_n = c(2, 5))
ridge_grid <- tibble(penalty = 10^seq(-4, -1, length.out = 4))

ctrl <- control_grid(verbose = FALSE)

lasso_res <- tune_grid(wf_lasso, cv_folds, grid = lasso_grid, metrics = metrics, control = ctrl)
rf_res <- tune_grid(wf_rf, cv_folds, grid = rf_grid, metrics = metrics, control = ctrl)
ridge_res <- tune_grid(wf_ridge, cv_folds, grid = ridge_grid, metrics = metrics, control = ctrl)

# Лучшие параметры
best_lasso <- select_best(lasso_res, metric = "accuracy")
best_rf <- select_best(rf_res, metric = "accuracy")
best_ridge <- select_best(ridge_res, metric = "accuracy")

# Финальные модели
final_lasso <- finalize_workflow(wf_lasso, best_lasso) |> fit(train)
final_rf <- finalize_workflow(wf_rf, best_rf) |> fit(train)
final_ridge <- finalize_workflow(wf_ridge, best_ridge) |> fit(train)

# Предсказания
preds_lasso <- predict(final_lasso, test) |> bind_cols(test)
preds_rf <- predict(final_rf, test) |> bind_cols(test)
preds_ridge <- predict(final_ridge, test) |> bind_cols(test)

# Предсказания с вероятностями
preds_lasso_prob <- predict(final_lasso, test, type = "prob") |> bind_cols(test)
preds_rf_prob <- predict(final_rf, test, type = "prob") |> bind_cols(test)
preds_ridge_prob <- predict(final_ridge, test, type = "prob") |> bind_cols(test)

Сравнение моделей

compare_models <- bind_rows(
  metrics(preds_lasso, truth = author, estimate = .pred_class) |> mutate(model = "Lasso"),
  metrics(preds_rf, truth = author, estimate = .pred_class) |> mutate(model = "RF"),
  metrics(preds_ridge, truth = author, estimate = .pred_class) |> mutate(model = "Ridge")
)

print(compare_models)
# A tibble: 12 × 4
   .metric   .estimator .estimate model
   <chr>     <chr>          <dbl> <chr>
 1 accuracy  multiclass     0.963 Lasso
 2 f_meas    macro          0.986 Lasso
 3 precision macro          0.975 Lasso
 4 recall    macro          0.909 Lasso
 5 accuracy  multiclass     1     RF   
 6 f_meas    macro          1     RF   
 7 precision macro          1     RF   
 8 recall    macro          1     RF   
 9 accuracy  multiclass     0.872 Ridge
10 f_meas    macro          0.870 Ridge
11 precision macro          0.911 Ridge
12 recall    macro          0.799 Ridge
p3 <- compare_models |>
  ggplot(aes(x = model, y = .estimate, fill = model)) +
  geom_col() +
  facet_wrap(~.metric, scales = "free_y") +
  labs(title = "Model Comparison", y = "Score", x = "") +
  theme_minimal() +
  theme(legend.position = "none")
print(p3)

best_model_name <- compare_models |>
  filter(.metric == "accuracy") |>
  slice_max(.estimate) |>
  pull(model)

best_probs <- switch(best_model_name,
                     "Lasso" = preds_lasso_prob,
                     "RF" = preds_rf_prob,
                     "Ridge" = preds_ridge_prob)

# Confusion Matrix

best_preds <- switch(best_model_name,
                     "Lasso" = preds_lasso,
                     "RF" = preds_rf,
                     "Ridge" = preds_ridge)

p5 <- conf_mat(best_preds, truth = author, estimate = .pred_class) |>
  autoplot(type = "heatmap") +
  labs(title = paste("Confusion Matrix -", best_model_name)) +
  theme_minimal() +
  theme(
    axis.text.x = element_text(angle = 45, hjust = 1),
    axis.text.y = element_text(angle = 0)
  )

print(p5)

Важность признаков

if(best_model_name == "RF") {
  p6 <- final_rf |>
    extract_fit_parsnip() |>
    vip(num_features = 10) +
    labs(title = "Top 10 Important Features - Random Forest") +
    theme_minimal()
  print(p6)
} else {
  coef_data <- get(paste0("final_", tolower(best_model_name))) |>
    extract_fit_parsnip() |>
    tidy() |>
    filter(term != "(Intercept)") |>
    group_by(term) |>
    summarise(mean_coef = mean(abs(estimate))) |>
    slice_max(mean_coef, n = 10)
  
  print(coef_data)
  
  p6 <- coef_data |>
    ggplot(aes(x = reorder(term, mean_coef), y = mean_coef)) +
    geom_col(fill = "steelblue") +
    coord_flip() +
    labs(title = paste("Top 10 Features -", best_model_name),
         x = "Feature", y = "Mean Absolute Coefficient") +
    theme_minimal()
  print(p6)
}

Анализ результатов

В результате сравнения моделей лучше всего себя показал случайный лес. Он дал самую высокую точность и практически без ошибок определял авторов текстов.

Это говорит о том, что в корпусе действительно есть заметные различия в стиле письма разных авторов. Модель смогла их хорошо “поймать” и использовать для классификации.

Также видно, что более простые модели (например, логистическая регрессия с Lasso и Ridge) тоже показывают хорошие результаты, но немного уступают случайному лесу.

В целом можно сказать, что лингвистические признаки, которые мы использовали (длина слов, структура предложений, частотность служебных слов и ключевые слова), хорошо подходят для задачи определения авторства.