# Multiclass classification ----

# Import libraries ----
library(tidyverse)
library(tidytext)
library(tidymodels)
library(vip)
library(textstem)
# Give tidymodels functions priority when names conflict with other
# attached packages.
tidymodels_prefer()

# Data loading ----
overview <- read_tsv("overview.tsv")

# NOTE(review): texts are matched to overview rows purely by position, so this
# assumes list.files() (alphabetical order) returns the files in the same
# order as the rows of overview.tsv. Confirm this, or join on a filename
# column if overview.tsv has one.
files <- list.files("british_fiction", full.names = TRUE)
if (length(files) != nrow(overview)) {
  stop(
    "Found ", length(files), " files in 'british_fiction' but ",
    nrow(overview), " rows in overview.tsv.",
    call. = FALSE
  )
}

# One row per document: a sequential id, the author label (as a factor) and
# the raw text of the file.
data <- overview |>
  mutate(
    file = files,
    text = map_chr(file, read_file),
    doc_id = row_number(),
    author = factor(authorID)
  ) |>
  select(doc_id, author, text)

# Preprocessing ----
# Normalise the text for tokenisation: lower-case it, replace every character
# that is not a-z or whitespace with a space, then collapse runs of
# whitespace. The raw `text` column is kept alongside `text_clean`.
data <- data |>
  mutate(
    text_clean = text |>
      str_to_lower() |>
      str_replace_all("[^a-z\\s]", " ") |>
      str_squish()
  )
# 2. Features ----
# One row per word token: stop words are removed, the remaining words are
# lemmatised, and single-character tokens are dropped.
tokens <- data |>
  unnest_tokens(output = word, input = text_clean) |>
  anti_join(stop_words, by = "word") |>
  mutate(word = lemmatize_words(word)) |>
  filter(str_length(word) >= 2)
# Stylistic features, one row per document ----
# Computed on the filtered, lemmatised tokens, so n_words counts content
# words (stop words excluded) rather than the raw text length.
# sd_word_len is NA for single-token documents; NAs are replaced with 0.
style <- tokens |>
  mutate(word_len = str_length(word)) |>
  group_by(doc_id, author) |>
  summarise(
    n_words = n(),
    ttr = n_distinct(word) / n_words,
    avg_word_len = mean(word_len),
    sd_word_len = sd(word_len),
    .groups = "drop"
  ) |>
  mutate(across(c(ttr, avg_word_len, sd_word_len), \(x) replace_na(x, 0)))
# Sentence features ----
# n_sentences counts runs of terminal punctuation in the raw text, because
# text_clean has its punctuation stripped. avg_sent_len is words per
# sentence. A document with words but no terminal punctuation is treated as
# a single sentence. Dividing by zero would give Inf, which is.na() and
# replace_na() do not catch, and one Inf turns the whole column into NaN
# after normalisation. Remaining NAs (e.g. NA text) become 0.
sentence_features <- data |>
  mutate(
    n_sentences = str_count(text, "[.!?]+"),
    n_tokens = str_count(text_clean, "\\w+"),
    avg_sent_len = coalesce(n_tokens / pmax(n_sentences, 1L), 0)
  ) |>
  select(doc_id, n_sentences, avg_sent_len)
# Function-word frequencies ----
# For each document, the share of each function word among that document's
# function-word tokens (each row sums to 1). Counted on the cleaned,
# non-lemmatised text, so "be" matches only the literal word "be".
# Documents containing none of these words get no row here.
# Columns are prefixed with "fw_" so that names such as `in` and `for`
# (R reserved words) are never used as bare column names.
function_words <- c("the", "be", "to", "of", "and", "a", "in", "that",
                    "it", "for", "on", "with", "as", "at", "by", "from")
fw_freq <- data |>
  select(doc_id, text_clean) |>
  unnest_tokens(word, text_clean) |>
  filter(word %in% function_words) |>
  count(doc_id, word) |>
  group_by(doc_id) |>
  mutate(prop = n / sum(n)) |>
  ungroup() |>
  select(-n) |>
  pivot_wider(
    names_from = word,
    values_from = prop,
    names_prefix = "fw_",
    values_fill = 0
  )
# TF-IDF keyword features ----
# For each document, keep its 20 highest-TF-IDF lemmas (ties at the cutoff
# are kept) and spread them into one column per lemma, one row per document.
# Only doc_id/word/tf_idf are kept before pivoting. Otherwise the leftover
# n/tf/idf columns become id columns in pivot_wider(), which yields roughly
# one row per word instead of one per document. That duplicates documents
# when joined on doc_id and turns n/tf/idf into predictors.
# Columns are prefixed with "tfidf_" so a lemma such as "author" cannot
# collide with the label or with other feature columns.
# NOTE(review): IDF and the keyword vocabulary are computed on the whole
# corpus, test documents included (label-free, but transductive).
word_features <- tokens |>
  count(doc_id, word) |>
  bind_tf_idf(word, doc_id, n) |>
  group_by(doc_id) |>
  slice_max(tf_idf, n = 20) |>
  ungroup() |>
  select(doc_id, word, tf_idf) |>
  pivot_wider(
    names_from = word,
    values_from = tf_idf,
    names_prefix = "tfidf_",
    values_fill = 0
  )
# Combine all feature tables ----
# Start from the stylistic features (one row per document) and attach the
# other tables by doc_id. relationship = "one-to-one" (dplyr >= 1.1.0) makes
# the join fail loudly if any table has more than one row per document,
# instead of silently duplicating documents.
model_data <- style |>
  left_join(sentence_features, by = "doc_id", relationship = "one-to-one") |>
  left_join(fw_freq, by = "doc_id", relationship = "one-to-one") |>
  left_join(word_features, by = "doc_id", relationship = "one-to-one") |>
  # Documents absent from a feature table (e.g. no function words) get 0.
  mutate(across(where(is.numeric), ~ replace_na(.x, 0))) |>
  mutate(author = factor(author)) |>
  # doc_id is an identifier, not a predictor, so drop it in its own step.
  # In `select(-doc_id, where(...))` the leading negative selection starts
  # from all columns and where() then adds doc_id back (it is never
  # constant). Left in, `author ~ .` would use it as a numeric feature.
  select(-doc_id) |>
  # Drop constant columns.
  select(where(~ n_distinct(.x) > 1))

# EDA ----
# Class balance: number of texts per author, sorted by count.
p1 <- model_data |>
  count(author) |>
  ggplot(aes(x = reorder(author, n), y = n, fill = author)) +
  geom_col() +
  coord_flip() +
  labs(title = "Number of Texts per Author", x = "Author", y = "Count") +
  theme_minimal() +
  theme(legend.position = "none")
print(p1)

# Distribution of selected stylistic features per author, one panel per
# feature. The legend is hidden because the x axis already names each author.
p2 <- model_data |>
  select(author, n_words, ttr, avg_word_len, avg_sent_len) |>
  pivot_longer(-author, names_to = "feature", values_to = "value") |>
  ggplot(aes(x = author, y = value, fill = author)) +
  geom_boxplot() +
  facet_wrap(~feature, scales = "free_y") +
  theme_minimal() +
  theme(
    axis.text.x = element_text(angle = 45, hjust = 1),
    legend.position = "none"
  ) +
  labs(title = "Linguistic Features by Author")
print(p2)

# Modeling ----
# Train/test split and cross-validation folds (both stratified by author).
set.seed(123)
split <- initial_split(model_data, prop = 0.8, strata = author)
train <- training(split)
test <- testing(split)
cv_folds <- vfold_cv(train, v = 5, strata = author)

# Evaluation metrics. NOTE: this object shadows yardstick::metrics(). The
# model-comparison code calls `metrics(preds, truth = ..., estimate = ...)`
# and relies on getting this metric set.
metrics <- metric_set(accuracy, f_meas, precision, recall)

# Recipe: remove zero-variance predictors, then centre and scale all numeric
# predictors (this also puts glmnet coefficients on a comparable scale).
rec <- recipe(author ~ ., data = train) |>
  step_zv(all_predictors()) |>
  step_normalize(all_numeric_predictors())

# Model specifications.
# Multinomial logistic regression via glmnet with a tuned penalty:
# mixture = 1 is the lasso (L1 penalty), mixture = 0 is ridge (L2 penalty).
model_lasso <- multinom_reg(penalty = tune(), mixture = 1) |>
  set_engine("glmnet") |>
  set_mode("classification")
# Random forest (ranger, 100 trees) with tuned mtry and min_n. Impurity
# importance is stored for the feature-importance plot.
model_rf <- rand_forest(mtry = tune(), trees = 100, min_n = tune()) |>
  set_engine("ranger", importance = "impurity") |>
  set_mode("classification")
model_ridge <- multinom_reg(penalty = tune(), mixture = 0) |>
  set_engine("glmnet") |>
  set_mode("classification")

# Workflows: the same recipe paired with each model.
wf_lasso <- workflow() |> add_recipe(rec) |> add_model(model_lasso)
wf_rf <- workflow() |> add_recipe(rec) |> add_model(model_rf)
wf_ridge <- workflow() |> add_recipe(rec) |> add_model(model_ridge)

# Grid search over the CV folds. The seed and the order of the calls below
# are kept fixed so that results are reproducible.
set.seed(123)
lasso_grid <- tibble(penalty = 10^seq(-4, -1, length.out = 4))
rf_grid <- expand.grid(mtry = c(3, 5, 8), min_n = c(2, 5))
ridge_grid <- tibble(penalty = 10^seq(-4, -1, length.out = 4))
ctrl <- control_grid(verbose = FALSE)
lasso_res <- tune_grid(wf_lasso, resamples = cv_folds, grid = lasso_grid,
                       metrics = metrics, control = ctrl)
rf_res <- tune_grid(wf_rf, resamples = cv_folds, grid = rf_grid,
                    metrics = metrics, control = ctrl)
ridge_res <- tune_grid(wf_ridge, resamples = cv_folds, grid = ridge_grid,
                       metrics = metrics, control = ctrl)

# Best hyper-parameters by mean cross-validated accuracy.
best_lasso <- select_best(lasso_res, metric = "accuracy")
best_rf <- select_best(rf_res, metric = "accuracy")
best_ridge <- select_best(ridge_res, metric = "accuracy")

# Final models: each tuned workflow refitted on the full training set.
final_lasso <- finalize_workflow(wf_lasso, best_lasso) |> fit(data = train)
final_rf <- finalize_workflow(wf_rf, best_rf) |> fit(data = train)
final_ridge <- finalize_workflow(wf_ridge, best_ridge) |> fit(data = train)

# Test-set hard-class predictions (.pred_class) next to the test data.
preds_lasso <- predict(final_lasso, test) |> bind_cols(test)
preds_rf <- predict(final_rf, test) |> bind_cols(test)
preds_ridge <- predict(final_ridge, test) |> bind_cols(test)

# Test-set class probabilities next to the test data.
preds_lasso_prob <- predict(final_lasso, test, type = "prob") |>
  bind_cols(test)
preds_rf_prob <- predict(final_rf, test, type = "prob") |>
  bind_cols(test)
preds_ridge_prob <- predict(final_ridge, test, type = "prob") |>
  bind_cols(test)

# Model comparison ----
# Test-set metrics of every model stacked into one long table with columns
# .metric, .estimator, .estimate and model. `metrics` is the metric_set()
# created in the modelling section (it shadows yardstick::metrics()).
compare_models <-
  list(Lasso = preds_lasso, RF = preds_rf, Ridge = preds_ridge) |>
  imap(\(preds, model_name) {
    metrics(preds, truth = author, estimate = .pred_class) |>
      mutate(model = model_name)
  }) |>
  bind_rows()
print(compare_models)
# A tibble: 12 × 4
# .metric .estimator .estimate model
# <chr> <chr> <dbl> <chr>
# 1 accuracy multiclass 0.963 Lasso
# 2 f_meas macro 0.986 Lasso
# 3 precision macro 0.975 Lasso
# 4 recall macro 0.909 Lasso
# 5 accuracy multiclass 1 RF
# 6 f_meas macro 1 RF
# 7 precision macro 1 RF
# 8 recall macro 1 RF
# 9 accuracy multiclass 0.872 Ridge
# 10 f_meas macro 0.870 Ridge
# 11 precision macro 0.911 Ridge
# 12 recall macro 0.799 Ridge
# Bar chart comparing the models on every test-set metric.
p3 <- compare_models |>
  ggplot(aes(x = model, y = .estimate, fill = model)) +
  geom_col() +
  facet_wrap(~.metric, scales = "free_y") +
  labs(title = "Model Comparison", y = "Score", x = "") +
  theme_minimal() +
  theme(legend.position = "none")
print(p3)

# Choose the model with the highest test-set accuracy.
# with_ties = FALSE guarantees exactly one name even when models tie (e.g.
# two at 1.0). A length-2 name would make switch() below error. Ties go to
# the model listed first in compare_models.
# NOTE(review): selecting the model on test-set accuracy makes that accuracy
# an optimistic estimate; selecting on the CV results would avoid this.
best_model_name <- compare_models |>
  filter(.metric == "accuracy") |>
  slice_max(.estimate, n = 1, with_ties = FALSE) |>
  pull(model)

# Test-set class probabilities of the chosen model.
best_probs <- switch(best_model_name,
  "Lasso" = preds_lasso_prob,
  "RF" = preds_rf_prob,
  "Ridge" = preds_ridge_prob
)

# Confusion matrix of the chosen model's hard-class test predictions.
best_preds <- switch(best_model_name,
  "Lasso" = preds_lasso,
  "RF" = preds_rf,
  "Ridge" = preds_ridge
)
p5 <- conf_mat(best_preds, truth = author, estimate = .pred_class) |>
  autoplot(type = "heatmap") +
  labs(title = paste("Confusion Matrix -", best_model_name)) +
  theme_minimal() +
  theme(
    axis.text.x = element_text(angle = 45, hjust = 1),
    axis.text.y = element_text(angle = 0)
  )
print(p5)

# Feature importance ----
# Top-10 features of the chosen model.
if (best_model_name == "RF") {
  # Random forest: importance as stored by ranger (the engine was configured
  # with importance = "impurity").
  p6 <- final_rf |>
    extract_fit_parsnip() |>
    vip(num_features = 10) +
    labs(title = "Top 10 Important Features - Random Forest") +
    theme_minimal()
  print(p6)
} else {
  # glmnet models: rank terms by the mean absolute coefficient over the rows
  # tidy() returns for each term (presumably one per class at the tuned
  # penalty; verify). Predictors were normalised in the recipe, so the
  # magnitudes are on a comparable scale.
  best_glmnet_fit <- switch(best_model_name,
    "Lasso" = final_lasso,
    "Ridge" = final_ridge
  )
  coef_data <- best_glmnet_fit |>
    extract_fit_parsnip() |>
    tidy() |>
    filter(term != "(Intercept)") |>
    group_by(term) |>
    summarise(mean_coef = mean(abs(estimate))) |>
    slice_max(mean_coef, n = 10)
  print(coef_data)
  p6 <- coef_data |>
    ggplot(aes(x = reorder(term, mean_coef), y = mean_coef)) +
    geom_col(fill = "steelblue") +
    coord_flip() +
    labs(
      title = paste("Top 10 Features -", best_model_name),
      x = "Feature", y = "Mean Absolute Coefficient"
    ) +
    theme_minimal()
  print(p6)
}

# Results analysis ----
# В результате сравнения моделей лучше всего себя показал случайный лес: у него самая высокая точность, и он практически без ошибок определял авторов текстов.
# Это указывает на то, что в корпусе действительно есть заметные различия в стиле письма разных авторов, и модель смогла их уловить и использовать для классификации.
# Линейные модели (мультиномиальная логистическая регрессия с регуляризацией Lasso и Ridge) тоже показывают хорошие результаты, но уступают случайному лесу: Lasso — незначительно, Ridge — заметнее.
# В целом можно сказать, что использованные лингвистические признаки (длина слов, структура предложений, частотность служебных слов и ключевые слова по TF-IDF) хорошо подходят для задачи определения авторства. Однако идеальная точность на небольшой тестовой выборке — повод проверить пайплайн на утечку данных: не попадают ли копии одного и того же документа одновременно в обучающую и тестовую выборки и не используется ли идентификатор документа как признак.