Text Classification of the IMDB Dataset using BERT

Author

Dr. Hervé Teguim

Setup the Python engine to use in R Markdown. Reticulate allows to run chunks of Python code, print Python output, access Python objects, and so on

library(reticulate)
use_condaenv("tensorflow")

Python Component

Datasets library

!pip install transformers[sentencepiece] datasets

Load the IMDB Dataset

from datasets import load_dataset
imdb = load_dataset("imdb")

Overview of the dataset

imdb
DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})

R Component

Data Wrangling

library(tidyverse)
# Take 2,000 observations from the original traning dataset
set.seed(123)
n_train <- sample(25000,2000)
data_train <- as.tibble(py$imdb["train"][n_train]) %>%
  mutate(row_name = n_train)
dim(data_train)
[1] 2000    3
data_train
# A tibble: 2,000 × 3
   text                                                            label row_n…¹
   <chr>                                                           <int>   <int>
 1 "These guys are anything but the Usual Suspects! They are a to…     1   18847
 2 "Weak, fast and multicolor,this is the Valvoline's movie in fa…     1   18895
 3 "From the first moment, this \"thing\" is just an awful sequen…     0    2986
 4 "Veteran sleazeball Bruno Mattei is at it again with this erot…     0    1842
 5 "First, IFC runs Town and Country, and now this. The differenc…     0    3371
 6 "Yeah, unfortunately I came across the DVD of this and found t…     0   11638
 7 "I decided to watch this one because it's been nominated for O…     0    4761
 8 "Run away from this movie. Even by B-movie standards this movi…     0    6746
 9 "Hollywood North is a satirical look at the time in Canadian f…     1   16128
10 "I was dreading taking my nephews to this movie, as I didn't t…     0    2757
# … with 1,990 more rows, and abbreviated variable name ¹​row_name
# For the test set take 20% of 2,000 (400)observation from the orginial test set
set.seed(12345)
n_test <- sample(25000,400)
data_test <- as.tibble(py$imdb["test"][n_test]) %>%
  mutate(row_name = n_test)
dim(data_test)
[1] 400   3
data_test
# A tibble: 400 × 3
   text                                                            label row_n…¹
   <chr>                                                           <int>   <int>
 1 "eXistenZ is an exploration of reality and virtual reality, wh…     1   14478
 2 "I'm not from USA I'm from central Europe and i think the show…     1   24627
 3 "This movie was thought to be low budget but it turned out to …     1   17104
 4 "This woman is a terrible comedian. She can't crack a joke. Sh…     0   10904
 5 "I am surprised by the relatively low rating this film has. It…     1   21306
 6 "The 1935 version of \"Enchanted April\" manages to be simulta…     0     605
 7 "I've lost count of the times I have seen this movie, but I lo…     1   14923
 8 "There is something about Doug McLure's appearance in a movie …     0    2264
 9 "Well, I'll begin with this: I love horror-movies, not even th…     0    9986
10 "This probably ranks in my Top-5 list of the funniest movies I…     1   15446
# … with 390 more rows, and abbreviated variable name ¹​row_name

Distribution of the word count

theme_set(theme_light())
data_train %>%
  ggplot(aes(str_count(text, '\\w+'))) + 
  geom_histogram(alpha = 0.8, bins = 200) +
  labs(x = "Number of words",
       y = "Number of film reviews")

Check the balance of the dataset

data_train %>%
  count(label)
# A tibble: 2 × 2
  label     n
  <int> <int>
1     0  1000
2     1  1000

Number of words for each label

data_train %>%
  mutate(label = as_factor(label)) %>%
  ggplot(aes(x = label, y = str_count(text, '\\w+'))) + 
  geom_boxplot() + labs(y = "Number of words")

Splitting the traing dataset in modeling data and validation data

library(tidymodels)
train_split <- data_train %>%
  initial_split(prop = 0.8)
data_train <- training(train_split)
data_validation <- testing(train_split)
dim(data_train)
[1] 1600    3
dim(data_validation)
[1] 400   3
dim(data_test)
[1] 400   3

Back to Python Component

Update the dataset with the generated training, validation and test data from R

imdb.pop('unsupervised')
imdb['validation'] = imdb['train'].select(r.data_validation.row_name)
imdb['train'] = imdb['train'].select(r.data_train.row_name)
imdb['test'] = imdb['test'].select(r.data_test.row_name)

Tokenization

from transformers import AutoTokenizer

checkpoint = "distilbert-base-cased"
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")

def tokenize_function(batch):
    return tokenizer(batch["text"], padding=True, truncation=True)

imdb_encoded = imdb.map(tokenize_function, batched=True, batch_size=None)

# Show the sentence, the different tokens and the corresponding numerical ids
#print(imdb_encoded['train'][0])

Definition of the model for training

from transformers import AutoModelForSequenceClassification

num_labels = 2
model = (AutoModelForSequenceClassification.from_pretrained('distilbert-base-cased', num_labels=num_labels))

Function to calculate the accuracy

from sklearn.metrics import accuracy_score

def get_accuracy(preds):
  predictions = preds.predictions.argmax(axis=-1)
  labels = preds.label_ids
  accuracy = accuracy_score(preds.label_ids, preds.predictions.argmax(axis=-1))
  return {'accuracy': accuracy}

Parameters for the model

from transformers import TrainingArguments

batch_size = 16
logging_steps = len(imdb_encoded["train"]) // batch_size
model_name = "distilbert-base-cased-finetuned-imdb"
training_args = TrainingArguments(output_dir=model_name,
                                  num_train_epochs=2,
                                  learning_rate=2e-5,
                                  per_device_train_batch_size=batch_size,
                                  per_device_eval_batch_size=batch_size,
                                  weight_decay=0.01,
                                  evaluation_strategy="epoch",
                                  disable_tqdm=False,
                                  logging_steps=logging_steps,
                                  log_level="error",
                                  optim='adamw_torch'
                                  )

Model training

from transformers import Trainer

trainer = Trainer(model=model, 
                  args=training_args, 
                  compute_metrics=get_accuracy,
                  train_dataset=imdb_encoded["train"],
                  eval_dataset=imdb_encoded["validation"],
                  tokenizer=tokenizer)
trainer.train()

## Save the model
trainer.save_model()

## Training and validation accuracy
trainer.evaluate()

## Get the predictions
preds = trainer.predict(imdb_encoded['test'])
pred_class = preds.predictions.argmax(axis=-1)
label = preds.label_ids

Back to R

Computing different model metrics

prediction <- tibble(pred_class = as_factor(py$pred_class),
                        label = as_factor(py$label))

# Accuracy
metrics(prediction, label, pred_class)
# A tibble: 2 × 3
  .metric  .estimator .estimate
  <chr>    <chr>          <dbl>
1 accuracy binary         0.878
2 kap      binary         0.755
# Confusion matrix
prediction %>% conf_mat(label, pred_class) %>%
  autoplot(type = "heatmap")

Back to Python

Playing with the language model

from transformers import pipeline
classifier = pipeline('text-classification', model=model_name)
classifier('This is not my idea of fun')
[{'label': 'LABEL_0', 'score': 0.5372037291526794}]
classifier('This was beyond incredible')
[{'label': 'LABEL_1', 'score': 0.7653437852859497}]