Let’s talk attention in Transformers

Note

All images in this post are from the paper Attention is all you need.

In this short blogpost, we’ll try to demystify the concept of self-attention as used in the paper Attention is all you need. Hopefully we’ll build up enough intuition to understand self-attention and multi-head attention and, in the process, unravel what makes a Transformer tick. The aim is to keep this as simple and light as possible, so that we can build on top of it in future posts. And how would this be a Transformers blogpost without the Transformer architecture?

Transformer Architecture

1 Self-Attention

Attention is a mechanism that allows a model to focus on certain parts of a sequence by giving more weight to elements that are relevant to each other. Self-attention applies this mechanism to a single input sequence to compute a weighted representation of each element in that sequence. This is different from, say, cross-attention, which applies the attention mechanism across two different sequences, such as a source sequence (input sequence in one language) and a target sequence (output sequence in another language).
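To make the distinction concrete, here is a minimal sketch of the shapes involved. It is deliberately simplified: the toy sizes and the helper `attention_weights` are made up for illustration, and it skips the learned projections and the scaling that the rest of the post covers.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy sizes: a source sequence of 4 tokens and a target sequence of 2 tokens
d = 5
source = torch.randn(4, d)
target = torch.randn(2, d)

def attention_weights(queries_from, keys_from):
    # Raw dot-product similarity followed by a row-wise softmax (no learned projections)
    scores = queries_from @ keys_from.T
    return torch.softmax(scores, dim=-1)

self_attn = attention_weights(source, source)   # (4, 4): the source attends to itself
cross_attn = attention_weights(target, source)  # (2, 4): the target attends to the source

print(self_attn.shape, cross_attn.shape)
```

In self-attention the weight matrix is square because queries and keys come from the same sequence. In cross-attention it has one row per target token and one column per source token.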

So how do we implement self-attention? The most common way to implement a self-attention layer is the scaled dot product attention introduced in the paper Attention is all you need, as shown in the figure below.

Scaled Dot Product Attention
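
For reference, the figure corresponds to the formula given in the paper, where Q, K and V are the query, key and value matrices and d_k is the dimension of the keys:

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V
```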

Not all that intimidating, right? Now, let’s implement this in code.

Let’s say we have a sequence of shape (T, d), where T is the length of the sequence and d is the dimension of each element in the sequence (we can think of this as T tokens, each of dimension d). We’ll call this sequence x and give it a leading batch dimension B:

```{python}
import torch
import torch.nn as nn
torch.manual_seed(89)
# Create a sequence x of length 3, each element is a vector of size 5
B, T, d = 1, 3, 5 # Batch size, sequence length, embedding dimension
x = torch.randn((B, T, d))

print(x)
print(x.shape)
```
<torch._C.Generator object at 0x000001DE7B528190>
tensor([[[-0.4138,  0.4192, -0.9875, -0.6275, -1.5819],
         [ 0.5018, -1.2535, -0.1187, -2.3139, -0.5850],
         [ 0.4416,  0.6665, -0.2620, -1.0138, -1.0609]]])
torch.Size([1, 3, 5])
```{r}
#| eval: false
library(torch)
torch_manual_seed(89)
# Create a sequence x of length 3, each element is a vector of size 5
B <- 1 # Batch size 
T <- 3 # Sequence length
d <- 5 # Embedding dimension
x <- torch_randn(c(B, T, d))

print(x)
print(x$shape)
```

1.1 MatMul(Q, K)

Now that we have this, the first step in scaled dot product attention is to project each of the token embeddings into three different vectors: Query, Key and Value, each of dimension head_size. The projections are linear layers whose weights are learned. These vectors can be interpreted in various ways, but for now, let’s think of them as Andrej Karpathy does in his GPT lecture:

  • Query: This is what a token is looking for.
  • Key: This is what a token has to offer.
  • Value: This is what a token will communicate if there is a match between the Query and Key.

Let’s put this in code:

```{python}
head_size = 16 # Dimension of each query, key and value vector

# Create projections of queries, keys and values
query = nn.Linear(in_features = d, out_features = head_size, bias = False)
key = nn.Linear(in_features = d, out_features = head_size, bias = False)
value = nn.Linear(in_features = d, out_features = head_size, bias = False)

# Project token embeddings into queries, keys and values
Q = query(x) # Shape: (B, T, 16)
K = key(x) # Shape: (B, T, 16)
V = value(x) # Shape: (B, T, 16)

Q.shape, K.shape, V.shape
```
(torch.Size([1, 3, 16]), torch.Size([1, 3, 16]), torch.Size([1, 3, 16]))
```{r}
#| eval: false

head_size <- 16 # Dimension of each query, key and value vector

# Create projections of queries, keys and values
query <- nn_linear(in_features = d, out_features = head_size, bias = FALSE)
key  <- nn_linear(in_features = d, out_features = head_size, bias = FALSE)
value <- nn_linear(in_features = d, out_features = head_size, bias = FALSE)

# Project token embeddings into queries, keys and values
Q <- query(x) # Shape: (B, T, 16)
K <- key(x) # Shape: (B, T, 16)
V <- value(x) # Shape: (B, T, 16)

print(Q$shape)
print(K$shape)
print(V$shape)
```

Now that each token has emitted a query and a key, the next step is to compute the similarity between each query and key. This is done by computing the dot product between all the queries and all the keys. Queries and keys that are similar will have a large dot product, while those that do not share much similarity will have a small dot product.

```{python}
# Compute dot product of queries and keys:
attention_scores = torch.matmul(Q, K.transpose(-1, -2)) # (B, T, 16) x (B, 16, T) = (B, T, T)
print(f"attention_scores.shape: {attention_scores.shape}, \n\nattention_scores: {attention_scores}")
```
attention_scores.shape: torch.Size([1, 3, 3]), 

attention_scores: tensor([[[ 0.7696, -0.0334,  0.3836],
         [-0.9035,  0.7536, -1.4572],
         [ 0.2226, -0.7387, -0.2212]]], grad_fn=<UnsafeViewBackward0>)
```{r}
#| eval: false
# Compute dot product of queries and keys:
attention_scores <- torch_matmul(Q, K$transpose(-1, -2)) # (B, T, 16) x (B, 16, T) = (B, T, T)
print(attention_scores$shape)
print(attention_scores)
```

This results in a matrix of shape (T, T) for each sequence in the batch, which makes sense since we had T tokens to begin with. This matrix holds the attention scores: row i contains the similarity between the query of token i and the keys of all the tokens in the sequence, including token i itself.
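
As a quick check of that reading, the sketch below uses stand-in random queries and keys (not the projections computed above) and confirms that entry (i, j) of the score matrix is the dot product of query i with key j.

```python
import torch

torch.manual_seed(0)

# Stand-in queries and keys with the same shapes as in the post: (B, T, head_size)
B, T, head_size = 1, 3, 16
Q = torch.randn(B, T, head_size)
K = torch.randn(B, T, head_size)

scores = torch.matmul(Q, K.transpose(-1, -2))  # (B, T, T)

# Entry (i, j) is the dot product of query i with key j
i, j = 1, 2
manual = torch.dot(Q[0, i], K[0, j])

print(scores[0, i, j].item(), manual.item())
print(torch.allclose(scores[0, i, j], manual, atol=1e-5))
```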

1.2 Scale

Dot products can in general produce arbitrarily large numbers, which can destabilize the training process. To address this, the attention scores are scaled by \frac{1}{\sqrt{d_k}} where d_k is the dimension of the query and key vectors.

```{python}
# Multiply each element by 1/sqrt(d_k)
dim_k = K.shape[-1]
print(f"dim_k: {dim_k}")
attention_scores = attention_scores * dim_k**-0.5
attention_scores
```
dim_k: 16
tensor([[[ 0.1924, -0.0083,  0.0959],
         [-0.2259,  0.1884, -0.3643],
         [ 0.0557, -0.1847, -0.0553]]], grad_fn=<MulBackward0>)
```{r}
#| eval: false
# Multiply each element by 1/sqrt(d_k)
dim_k <- purrr::pluck(K$shape, -1)
paste("dim_k:", dim_k)
attention_scores <- attention_scores * dim_k^-0.5
attention_scores
```

1.3 SoftMax

To better interpret the attention score, a softmax function is applied to each row. The softmax function makes the scores interpretable as probabilities which sum up to 1, meaning that all elements receive some level of attention, with varying degrees of importance. Let’s put this in code:

```{python}
import torch.nn.functional as F
# Apply softmax to get attention probabilities
weights = F.softmax(attention_scores, dim = -1)

# Print weights and sum of weights along each row
print(f"weights: {weights}")
print(f"\nsum of weights along each row: {torch.sum(weights, dim = -1)}")
```
weights: tensor([[[0.3668, 0.3001, 0.3331],
         [0.2955, 0.4472, 0.2573],
         [0.3729, 0.2933, 0.3338]]], grad_fn=<SoftmaxBackward0>)

sum of weights along each row: tensor([[1.0000, 1.0000, 1.0000]], grad_fn=<SumBackward1>)
```{r}
#| eval: false
# Apply softmax to get attention probabilities
weights  <- nnf_softmax(attention_scores, dim = -1)

# Print weights and sum of weights along each row
print(weights)
print(torch_sum(weights, dim = -1))
```

Much better! This shows the attention that each token pays to all the other tokens.
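
To see what the softmax does under the hood, here is a small sketch that recomputes the first row of weights by hand. The input values are copied from the scaled attention scores printed earlier, so the result should match the first row of weights above up to rounding.

```python
import torch
import torch.nn.functional as F

# The first row of scaled attention scores printed earlier in the post
row = torch.tensor([0.1924, -0.0083, 0.0959])

# Softmax by hand: exponentiate, then normalise so the row sums to 1
exp_row = torch.exp(row)
manual = exp_row / exp_row.sum()

print(manual)
print(torch.allclose(manual, F.softmax(row, dim=-1)))
```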

Note

If we had not scaled the dot products (attention scores), the softmax could have leaned towards a one-hot vector, i.e. one token would have received almost all the attention while the rest would have received almost none. This is not what we want: ideally, all tokens receive some level of attention, with varying degrees of importance. The following example exaggerates the effect by multiplying the scaled scores by 12:

```{python}
# Sharpen attention scores and apply softmax again
torch.set_printoptions(sci_mode=False)
F.softmax(attention_scores * 12, dim = -1)
```
tensor([[[0.7122, 0.0640, 0.2237],
         [0.0069, 0.9918, 0.0013],
         [0.7576, 0.0424, 0.2001]]], grad_fn=<SoftmaxBackward0>)
```{r}
#| eval: false
# Sharpen attention scores and apply softmax again
nnf_softmax(attention_scores * 12, dim = -1)
```

As illustrated by the example above, not scaling the attention scores could cause the softmax to saturate. This is why we scale the attention scores by \frac{1}{\sqrt{d_k}}.
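
Why \sqrt{d_k} specifically? If the components of the queries and keys are independent with mean 0 and variance 1, their dot product has variance d_k, so its typical size grows like \sqrt{d_k}. The sketch below checks this empirically. It assumes random standard-normal queries and keys, which is an idealisation and not what a trained model produces.

```python
import torch

torch.manual_seed(0)

n_samples = 2_000
for d_k in (4, 64, 1024):
    # Random queries and keys with zero mean and unit variance per component
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)

    dots = (q * k).sum(dim=-1)   # one dot product per sample
    scaled = dots * d_k**-0.5    # the same scaling used in the post

    print(f"d_k={d_k:5d}  std(raw)={dots.std().item():6.2f}  std(scaled)={scaled.std().item():5.2f}")
```

The raw standard deviation should come out close to \sqrt{d_k} (about 2, 8 and 32), while the scaled one stays close to 1 whatever the value of d_k.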

1.4 MatMul(SoftMax, V)

Finally, let’s update the token embeddings. Once the attention weights have been computed, the last step is to compute the weighted sum of the values. This is achieved by multiplying the attention weights with the value projection. This matrix multiplication allows the embedding of each token to be updated with information from other tokens, i.e. x_i = \sum_{j = 1}^{T} w_{ij}V_j, where w_{ij} is the attention weight between token i and token j and V_j is the value projection of token j.

```{python}
# Update token embeddings: Multiply attention weights with values
attn_outputs = torch.matmul(weights, V) # (B, T, T) x (B, T, 16) = (B, T, 16)

V.shape, attn_outputs.shape
```
(torch.Size([1, 3, 16]), torch.Size([1, 3, 16]))
```{r}
#| eval: false
# Update token embeddings: Multiply attention weights with values
attn_outputs <- torch_matmul(weights, V) # (B, T, T) x (B, T, 16) = (B, T, 16)

print(V$shape)
print(attn_outputs$shape)
```

Our attention outputs still represent our 3 initial tokens, but their embeddings have now been updated with information from the other tokens in different proportions.
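
The matrix multiplication is just a compact way of writing the sum \sum_{j} w_{ij}V_j for every token i. Here is a small sketch that spells the sum out with explicit loops and compares it to the matrix product. The weights and values are random stand-ins, not the ones computed above.

```python
import torch

torch.manual_seed(0)

# Stand-in attention weights (rows sum to 1) and values, shaped as in the post
T, head_size = 3, 16
weights = torch.softmax(torch.randn(T, T), dim=-1)  # (T, T)
V = torch.randn(T, head_size)                       # (T, head_size)

# Explicit version of the sum: output_i = sum_j w_ij * V_j
manual = torch.zeros(T, head_size)
for i in range(T):
    for j in range(T):
        manual[i] += weights[i, j] * V[j]

print(torch.allclose(manual, weights @ V, atol=1e-6))
```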

That’s it! We have gone through the entire process of implementing a self-attention layer. This was basically two matrix multiplications and a softmax. At its core, self-attention is just a fancy form of weighted averaging over the value projections of all the elements in a sequence.
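
As a sanity check, the same steps can be compared against PyTorch’s built-in `F.scaled_dot_product_attention`. This sketch assumes PyTorch 2.0 or newer, where that function is available, and uses random stand-in tensors for Q, K and V.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

B, T, head_size = 1, 3, 8
Q, K, V = (torch.randn(B, T, head_size) for _ in range(3))

# Manual version, following the steps in this post
weights = F.softmax(torch.matmul(Q, K.transpose(-1, -2)) * head_size**-0.5, dim=-1)
manual = torch.matmul(weights, V)

# Built-in version (assumes PyTorch 2.0 or newer); by default it also scales by 1/sqrt(head_size)
builtin = F.scaled_dot_product_attention(Q, K, V)

print(torch.allclose(manual, builtin, atol=1e-5))
```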

1.5 Mask (Optional)

Let’s quickly talk about masking which is an optional step that we skipped. Masking is used in the Transformer’s decoder to prevent tokens from attending to future tokens. This is particularly useful in language generation where we want the model to predict each token based only on the tokens before it. Masking is achieved by setting the attention scores of future tokens to -\infty before applying the softmax. This ensures that the softmax will assign a probability of 0 to future tokens. Let’s put this in code:

```{python}
# Create a mask using lower triangular matrix of shape (T, T)
mask = torch.tril(torch.ones((T, T)))
mask
```
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
```{r}
#| eval: false
# Create a mask using lower triangular matrix of shape (T, T)
mask <- torch_tril(torch_ones(c(T, T)))
mask
```

Now let’s prevent tokens from peeking at future tokens by replacing the attention scores of future tokens with -\infty:

```{python}
# Mask future tokens
masked_attention_scores = attention_scores.masked_fill(mask == 0, float("-inf"))
masked_attention_scores
```
tensor([[[ 0.1924,    -inf,    -inf],
         [-0.2259,  0.1884,    -inf],
         [ 0.0557, -0.1847, -0.0553]]], grad_fn=<MaskedFillBackward0>)
```{r}
#| eval: false
# Mask future tokens
masked_attention_scores <- attention_scores$masked_fill(mask == 0, -Inf)
masked_attention_scores
```

Now, if we apply a softmax to this, we’ll see that the attention weights of future tokens are 0 meaning that the model will not attend to future tokens:

```{python}
masked_weights = F.softmax(masked_attention_scores, dim = -1)
masked_weights
```
tensor([[[1.0000, 0.0000, 0.0000],
         [0.3979, 0.6021, 0.0000],
         [0.3729, 0.2933, 0.3338]]], grad_fn=<SoftmaxBackward0>)
```{r}
#| eval: false
# Apply softmax to get attention probabilities
masked_weights <- nnf_softmax(masked_attention_scores, dim = -1)
masked_weights
```

That’s it! This is what masking is all about. We can go ahead and multiply the attention weights with the value projection to get the attention outputs.
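
One way to convince ourselves that the mask works is to change a future token and check that earlier outputs do not move. The sketch below wraps the masked steps in a hypothetical helper `causal_attention` and uses random stand-in tensors for Q, K and V.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def causal_attention(Q, K, V):
    # Scaled dot-product attention with a lower-triangular (causal) mask
    T = Q.shape[-2]
    scores = torch.matmul(Q, K.transpose(-1, -2)) * K.shape[-1]**-0.5
    mask = torch.tril(torch.ones(T, T))
    scores = scores.masked_fill(mask == 0, float("-inf"))
    return torch.matmul(F.softmax(scores, dim=-1), V)

B, T, head_size = 1, 3, 16
Q, K, V = (torch.randn(B, T, head_size) for _ in range(3))
out = causal_attention(Q, K, V)

# Overwrite the last token's key and value: earlier positions should not change
K2, V2 = K.clone(), V.clone()
K2[:, -1] = torch.randn(head_size)
V2[:, -1] = torch.randn(head_size)
out2 = causal_attention(Q, K2, V2)

print(torch.allclose(out[:, :-1], out2[:, :-1]))  # earlier tokens are unaffected
print(torch.allclose(out[:, 0], V[:, 0]))         # the first token only sees itself
```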

1.6 Creating a Self Attention Layer

Let’s now implement all these steps as a single attention head layer:

```{python}
class AttentionHead(nn.Module):

    def __init__(self, embed_dim, head_size):
        super().__init__()
        # Create projections of queries, keys and values
        self.query = nn.Linear(in_features = embed_dim, out_features = head_size, bias = False)
        self.key = nn.Linear(in_features = embed_dim, out_features = head_size, bias = False)
        self.value = nn.Linear(in_features = embed_dim, out_features = head_size, bias = False)

    def forward(self, x):
        # Project token embeddings into queries, keys and values
        Q = self.query(x) # Shape: (B, T, head_size)
        K = self.key(x) # Shape: (B, T, head_size)
        V = self.value(x) # Shape: (B, T, head_size)

        # Compute dot product of queries and keys:
        attention_scores = torch.matmul(Q, K.transpose(-1, -2)) # (B, T, head_size) x (B, head_size, T) = (B, T, T)

        # Scale the attention scores by 1/sqrt(head_size), read from K rather than a global variable
        attention_scores = attention_scores * K.shape[-1]**-0.5

        # Apply softmax to get attention probabilities
        weights = F.softmax(attention_scores, dim = -1)

        # Multiply attention weights with values to get contextualized embeddings
        attn_outputs = torch.matmul(weights, V) # (B, T, T) x (B, T, head_size) = (B, T, head_size)

        return attn_outputs
```
```{r}
#| eval: false
# Implement an attention head
attention_head <- nn_module(
    initialize = function(embed_dim, head_size){
        # Create projections of queries, keys and values
        self$query <- nn_linear(in_features = embed_dim, out_features = head_size, bias = FALSE)
        self$key <- nn_linear(in_features = embed_dim, out_features = head_size, bias = FALSE)
        self$value <- nn_linear(in_features = embed_dim, out_features = head_size, bias = FALSE)
    },

    forward = function(x){
        # Project token embeddings into queries, keys and values
        Q <- self$query(x) # Shape: (B, T, head_size)
        K <- self$key(x) # Shape: (B, T, head_size)
        V <- self$value(x) # Shape: (B, T, head_size)

        # Compute dot product of queries and keys:
        attention_scores <- torch_matmul(Q, K$transpose(-1, -2)) # (B, T, head_size) x (B, head_size, T) = (B, T, T)

        # Scale the attention scores by 1/sqrt(head_size), read from K rather than a global variable
        attention_scores <- attention_scores * tail(K$shape, 1)^-0.5

        # Apply softmax to get attention probabilities
        weights <- nnf_softmax(attention_scores, dim = -1)

        # Multiply attention weights with values to get contextualized embeddings
        attn_outputs <- torch_matmul(weights, V) # (B, T, T) x (B, T, head_size) = (B, T, head_size)

        attn_outputs
    }
)
```

Now let’s apply this attention head to a sequence of embeddings:

```{python}
# Create a sequence x of length 3, each element is a vector of size 32
seq_len, embed_dim = 3, 32
x = torch.randn((1, seq_len, embed_dim))

# Create an attention head
head_size = 8 # x will be projected to 8 dimensions
attn_head = AttentionHead(embed_dim = embed_dim, head_size = 8)

# Apply attention head to x
attn_outputs = attn_head(x)
attn_outputs.shape
```
torch.Size([1, 3, 8])
```{r}
#| eval: false
# Create a sequence x of length 3, each element is a vector of size 32
seq_len <- 3
embed_dim <- 32
x <- torch_randn(c(1, seq_len, embed_dim))

# Create an attention head
head_size <- 8 # x will be projected to 8 dimensions
attn_head <- attention_head(embed_dim = embed_dim, head_size = 8)

# Apply attention head to x
attn_outputs <- attn_head(x)
attn_outputs$shape
```

Like before, the sequence embeddings have been transformed into contextualized embeddings which contain information from other tokens. This is the essence of self-attention. Now, let’s take this a notch higher: multi-head attention.

2 Multi-Head Attention

Instead of performing a single attention operation, the Transformer performs multiple attention operations in parallel as illustrated in the figure below.

Multi-Head Attention

But why do we need more than one attention head? The reason is that multiple attention heads allow the model to focus on different aspects of the sequence at once. For example, one attention head can focus on the subject of a sentence, while another focuses on the verb. This is made possible by the fact that each attention head is instantiated with its own linear layer parameters, so each head can learn to focus on different parts of the sequence.

As illustrated in the figure above, multi-head attention is simply the concatenation of the outputs of multiple attention heads, followed by a final linear projection.

Note

Although head_size does not have to be smaller than the number of embedding dimensions of the tokens embed_dim, in practice it is chosen as embed_dim divided by the number of heads, so that the total computation stays roughly the same as for a single head with the full embedding dimension. For instance, if embed_dim is 32 and we want 4 attention heads, then the output of each attention head has head_size 8. The outputs of the attention heads are then concatenated to form a single output vector of size 32.
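
The arithmetic behind this choice can be checked in a few lines. This sketch only counts the weights of the bias-free query, key and value projections used in this post and leaves out the final output projection.

```python
embed_dim, num_heads = 32, 4
head_size = embed_dim // num_heads  # 8

# Each head has three bias-free projections (query, key, value) of shape (embed_dim, head_size)
params_per_head = 3 * embed_dim * head_size
multi_head_params = num_heads * params_per_head

# A single head that keeps the full embedding dimension
single_head_params = 3 * embed_dim * embed_dim

print(head_size, multi_head_params, single_head_params)  # 8 3072 3072
```

Four heads of size 8 use as many projection weights as one head of size 32, and their concatenated output has the same size, 4 x 8 = 32.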

Now, let’s implement a full multi-head attention layer:

```{python}
# Implement multi-head attention layer
class MultiHeadAttention(nn.Module):

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # Calculate the projection dimension for each head
        self.head_size = embed_dim // num_heads
        assert self.head_size * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
        # Create num_heads different attention heads
        self.heads = nn.ModuleList([AttentionHead(embed_dim = embed_dim, head_size = self.head_size) for _ in range(num_heads)])
        # Create a linear layer to project the concatenation of all heads into a vector of embed_dim
        self.output_linear = nn.Linear(in_features = embed_dim, out_features = embed_dim)

    
    def forward(self, x):
        # Pass x through each attention head
        multihead_attn = [attn_head(x) for attn_head in self.heads]
        # Concatenate the outputs of all attention heads
        multihead_attn = torch.cat(multihead_attn, dim = -1) # Shape: (B, T, embed_dim)
        # Project the concatenated outputs of all attention heads
        multihead_attn = self.output_linear(multihead_attn) # Shape: (B, T, embed_dim)
        return multihead_attn
```
```{r}
#| eval: false
# Implement multi-head attention layer
multi_head_attention <- nn_module(
    initialize = function(embed_dim, num_heads){
        # Calculate the projection dimension for each head
        self$head_size <- embed_dim %/% num_heads
        stopifnot("embed_dim must be divisible by num_heads" = self$head_size * num_heads == embed_dim)
        # Create num_heads different attention heads
        self$heads <- nn_module_list(purrr::map(1:num_heads, ~attention_head(embed_dim = embed_dim, head_size = self$head_size)))
        # Create a linear layer to project the concatenation of all heads into a vector of embed_dim
        self$output_linear <- nn_linear(in_features = embed_dim, out_features = embed_dim)

    },
    forward = function(x){
        # Pass x through each attention head
        multihead_attn  <- purrr::map(1:length(self$heads), ~self$heads[[.x]](x))
        # Concatenate the outputs of all attention heads
        multihead_attn <- torch_cat(multihead_attn, dim = -1) # Shape: (B, T, embed_dim)
        # Project the concatenated outputs of all attention heads
        multihead_attn <- self$output_linear(multihead_attn) # Shape: (B, T, embed_dim)
        multihead_attn
    }
)
```

Now let’s confirm that the output of the multi-head attention layer has the same shape as the input:

```{python}
# Create a sequence x of length 3, each element is a vector of size 32
seq_len, embed_dim = 3, 32
x = torch.randn((1, seq_len, embed_dim))    

# Create a multi-head attention layer
num_heads = 4
msa = MultiHeadAttention(embed_dim = embed_dim, num_heads = num_heads)

# Apply multi-head attention to x
msa_outputs = msa(x)
print(f"input shape: {x.shape}")
print(f"msa_outputs: {msa_outputs.shape}")
```
input shape: torch.Size([1, 3, 32])
msa_outputs: torch.Size([1, 3, 32])
```{r}
#| eval: false
# Create a sequence x of length 3, each element is a vector of size 32
seq_len <- 3
embed_dim <- 32
x <- torch_randn(c(1, seq_len, embed_dim))

# Create a multi-head attention layer
num_heads <- 4
msa <- multi_head_attention(embed_dim = embed_dim, num_heads = num_heads)

# Apply multi-head attention to x
msa_outputs <- msa(x)
print(x$shape)
print(msa_outputs$shape)
```
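
The per-head loop above is easy to read, but it runs the heads one after another. A common alternative is to use one projection of size embed_dim for each of Q, K and V and split the result into heads by reshaping, so that all heads are computed in a single batched matrix multiplication. The class below, with the made-up name `FusedMultiHeadAttention`, is a sketch of that idea under the same assumptions as the layer above (bias-free Q/K/V projections, no masking, no dropout). It is not the implementation used in this post.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FusedMultiHeadAttention(nn.Module):
    # Hypothetical variant: one projection per Q/K/V, split into heads by reshaping

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_size = embed_dim // num_heads
        self.query = nn.Linear(embed_dim, embed_dim, bias=False)
        self.key = nn.Linear(embed_dim, embed_dim, bias=False)
        self.value = nn.Linear(embed_dim, embed_dim, bias=False)
        self.output_linear = nn.Linear(embed_dim, embed_dim)

    def split_heads(self, t):
        # (B, T, embed_dim) -> (B, num_heads, T, head_size)
        B, T, _ = t.shape
        return t.view(B, T, self.num_heads, self.head_size).transpose(1, 2)

    def forward(self, x):
        B, T, embed_dim = x.shape
        Q, K, V = (self.split_heads(proj(x)) for proj in (self.query, self.key, self.value))
        # Scaled dot-product attention for all heads at once
        scores = torch.matmul(Q, K.transpose(-1, -2)) * self.head_size**-0.5  # (B, H, T, T)
        weights = F.softmax(scores, dim=-1)
        out = torch.matmul(weights, V)                      # (B, H, T, head_size)
        # Put the heads back side by side, i.e. concatenate them
        out = out.transpose(1, 2).reshape(B, T, embed_dim)  # (B, T, embed_dim)
        return self.output_linear(out)

torch.manual_seed(0)
x = torch.randn(1, 3, 32)
fused = FusedMultiHeadAttention(embed_dim=32, num_heads=4)
fused_outputs = fused(x)
print(fused_outputs.shape)
```

As with the layer above, the output has the same shape as the input.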

There you have it! We have broken down what attention is all about and implemented it in code. In a nutshell, self-attention allows us to turn an input sequence of token embeddings into a sequence of contextualized embeddings, where each is a weighted sum of the value projections of all the tokens in the sequence. This is the essence of self-attention.

In the next blogpost, we can go ahead and implement the rest of the Transformer architecture.

In the meantime, Happy learning!

Eric.

3 References

  • Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017). Attention is all you need. Advances in Neural Information Processing Systems, 30.
  • Tunstall, L., Von Werra, L., & Wolf, T. (2022). Natural Language Processing with Transformers (Revised Edition). O’Reilly Media.
  • Karpathy, A. Let’s build GPT: from scratch, in code, spelled out.