Text Classification of the IMDB Dataset using BERT
Set up the Python engine to use in R Markdown. Reticulate allows running chunks of Python code, printing Python output, accessing Python objects, and so on.
library(reticulate)
use_condaenv("tensorflow")
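Objects cross the language boundary in both directions: R chunks read Python objects through py$ and Python chunks read R objects through the r proxy, both of which are used below. A minimal illustrative sketch (the variable x is hypothetical):
# In a Python chunk: an R variable defined earlier (e.g., x <- 42 in an R chunk)
# is reachable through the `r` proxy object that reticulate injects
print(r.x)  # 42.0 (R doubles arrive as Python floats)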
Python Component
Datasets library
!pip install transformers[sentencepiece] datasets
Load the IMDB Dataset
from datasets import load_dataset
imdb = load_dataset("imdb")
Overview of the dataset
imdb
DatasetDict({
train: Dataset({
features: ['text', 'label'],
num_rows: 25000
})
test: Dataset({
features: ['text', 'label'],
num_rows: 25000
})
unsupervised: Dataset({
features: ['text', 'label'],
num_rows: 50000
})
})
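Each split holds the raw review text plus an integer label (0 = negative, 1 = positive). A quick, illustrative peek at one training example:
# Inspect the first raw training example: a dict with 'text' and 'label' keys
example = imdb["train"][0]
print(example["label"])       # 0 (negative) or 1 (positive)
print(example["text"][:200])  # first 200 characters of the review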
R Component
Data Wrangling
library(tidyverse)
# Take 2,000 observations from the original training dataset
set.seed(123)
n_train <- sample(25000,2000)
data_train <- as_tibble(py$imdb["train"][n_train]) %>%
mutate(row_name = n_train)
dim(data_train)
[1] 2000    3
data_train
# A tibble: 2,000 × 3
text label row_n…¹
<chr> <int> <int>
1 "These guys are anything but the Usual Suspects! They are a to… 1 18847
2 "Weak, fast and multicolor,this is the Valvoline's movie in fa… 1 18895
3 "From the first moment, this \"thing\" is just an awful sequen… 0 2986
4 "Veteran sleazeball Bruno Mattei is at it again with this erot… 0 1842
5 "First, IFC runs Town and Country, and now this. The differenc… 0 3371
6 "Yeah, unfortunately I came across the DVD of this and found t… 0 11638
7 "I decided to watch this one because it's been nominated for O… 0 4761
8 "Run away from this movie. Even by B-movie standards this movi… 0 6746
9 "Hollywood North is a satirical look at the time in Canadian f… 1 16128
10 "I was dreading taking my nephews to this movie, as I didn't t… 0 2757
# … with 1,990 more rows, and abbreviated variable name ¹row_name
# For the test set, take 400 observations (20% of 2,000) from the original test set
set.seed(12345)
n_test <- sample(25000,400)
data_test <- as_tibble(py$imdb["test"][n_test]) %>%
mutate(row_name = n_test)
dim(data_test)
[1] 400   3
data_test
# A tibble: 400 × 3
text label row_n…¹
<chr> <int> <int>
1 "eXistenZ is an exploration of reality and virtual reality, wh… 1 14478
2 "I'm not from USA I'm from central Europe and i think the show… 1 24627
3 "This movie was thought to be low budget but it turned out to … 1 17104
4 "This woman is a terrible comedian. She can't crack a joke. Sh… 0 10904
5 "I am surprised by the relatively low rating this film has. It… 1 21306
6 "The 1935 version of \"Enchanted April\" manages to be simulta… 0 605
7 "I've lost count of the times I have seen this movie, but I lo… 1 14923
8 "There is something about Doug McLure's appearance in a movie … 0 2264
9 "Well, I'll begin with this: I love horror-movies, not even th… 0 9986
10 "This probably ranks in my Top-5 list of the funniest movies I… 1 15446
# … with 390 more rows, and abbreviated variable name ¹row_name
Distribution of the word count
theme_set(theme_light())
data_train %>%
ggplot(aes(str_count(text, '\\w+'))) +
geom_histogram(alpha = 0.8, bins = 200) +
labs(x = "Number of words",
y = "Number of film reviews")Check the balance of the dataset
data_train %>%
  count(label)
# A tibble: 2 × 2
label n
<int> <int>
1 0 1000
2 1 1000
Number of words for each label
data_train %>%
mutate(label = as_factor(label)) %>%
ggplot(aes(x = label, y = str_count(text, '\\w+'))) +
  geom_boxplot() +
  labs(y = "Number of words")
Splitting the training dataset into modeling data and validation data
library(tidymodels)
train_split <- data_train %>%
initial_split(prop = 0.8)
data_train <- training(train_split)
data_validation <- testing(train_split)
dim(data_train)
[1] 1600    3
dim(data_validation)
[1] 400   3
dim(data_test)
[1] 400   3
Back to Python Component
Update the dataset with the training, validation, and test data generated in R
imdb.pop('unsupervised')
imdb['validation'] = imdb['train'].select(r.data_validation.row_name)
imdb['train'] = imdb['train'].select(r.data_train.row_name)
imdb['test'] = imdb['test'].select(r.data_test.row_name)
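A quick sanity check (illustrative) that the new split sizes match the R data frames:
# The three splits should now hold 1600 / 400 / 400 rows
print({split: imdb[split].num_rows for split in imdb})
# expected: {'train': 1600, 'test': 400, 'validation': 400}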
Tokenization
from transformers import AutoTokenizer
checkpoint = "distilbert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
def tokenize_function(batch):
    # Tokenize a batch of reviews, padding/truncating to a common length
    return tokenizer(batch["text"], padding=True, truncation=True)
imdb_encoded = imdb.map(tokenize_function, batched=True, batch_size=None)
# Show the sentence, the different tokens and the corresponding numerical ids
# print(imdb_encoded['train'][0])
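To see how the tokenizer breaks text into subword pieces, the numerical ids can be mapped back to tokens (a small illustrative sketch):
# Encode a short sentence and map the ids back to their tokens;
# the special markers [CLS] and [SEP] frame every sequence
ids = tokenizer("A gripping film!")["input_ids"]
print(tokenizer.convert_ids_to_tokens(ids))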
Definition of the model for training
from transformers import AutoModelForSequenceClassification
num_labels = 2
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=num_labels)
Function to calculate the accuracy
from sklearn.metrics import accuracy_score
def get_accuracy(preds):
    # Turn the raw logits into class predictions and compare with the true labels
    predictions = preds.predictions.argmax(axis=-1)
    labels = preds.label_ids
    return {'accuracy': accuracy_score(labels, predictions)}
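The Trainer calls this function with an EvalPrediction object that carries the raw logits and the true labels. A small illustrative check with dummy values:
import numpy as np
from transformers import EvalPrediction

# Dummy logits for three examples; argmax gives classes [1, 0, 1]
dummy = EvalPrediction(predictions=np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]),
                       label_ids=np.array([1, 0, 0]))
print(get_accuracy(dummy))  # {'accuracy': 0.666...}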
Parameters for the model
from transformers import TrainingArguments
batch_size = 16
logging_steps = len(imdb_encoded["train"]) // batch_size  # 1600 // 16 = 100, i.e. one log per epoch
model_name = "distilbert-base-cased-finetuned-imdb"
training_args = TrainingArguments(output_dir=model_name,
num_train_epochs=2,
learning_rate=2e-5,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
weight_decay=0.01,
evaluation_strategy="epoch",
disable_tqdm=False,
logging_steps=logging_steps,
log_level="error",
optim='adamw_torch'
                                  )
Model training
from transformers import Trainer
trainer = Trainer(model=model,
args=training_args,
compute_metrics=get_accuracy,
train_dataset=imdb_encoded["train"],
eval_dataset=imdb_encoded["validation"],
tokenizer=tokenizer)
trainer.train()
## Save the model
trainer.save_model()
## Training and validation accuracy
trainer.evaluate()
## Get the predictions
preds = trainer.predict(imdb_encoded['test'])
pred_class = preds.predictions.argmax(axis=-1)
label = preds.label_ids
Back to R
Computing different model metrics
prediction <- tibble(pred_class = as_factor(py$pred_class),
label = as_factor(py$label))
# Accuracy
metrics(prediction, label, pred_class)
# A tibble: 2 × 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 accuracy binary 0.878
2 kap      binary         0.755
Here kap is Cohen's kappa, i.e. accuracy corrected for the agreement expected by chance.
# Confusion matrix
prediction %>% conf_mat(label, pred_class) %>%
  autoplot(type = "heatmap")
Back to Python
Playing with the language model
from transformers import pipeline
classifier = pipeline('text-classification', model=model_name)
classifier('This is not my idea of fun')
[{'label': 'LABEL_0', 'score': 0.5372037291526794}]
classifier('This was beyond incredible')
[{'label': 'LABEL_1', 'score': 0.7653437852859497}]
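The generic LABEL_0 / LABEL_1 names come from the model config. A hedged sketch of how they could be mapped to readable sentiment names (assuming, as in the IMDB dataset, 0 = negative and 1 = positive):
# Attach human-readable names to the config and rebuild the pipeline
model.config.id2label = {0: 'NEGATIVE', 1: 'POSITIVE'}
model.config.label2id = {'NEGATIVE': 0, 'POSITIVE': 1}
classifier = pipeline('text-classification', model=model, tokenizer=tokenizer)
classifier('This is not my idea of fun')
# e.g. [{'label': 'NEGATIVE', 'score': ...}]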