Let’s talk attention in Transformers
Note
All images in this post are from the paper Attention is all you need.
In this short blogpost, we’ll try to demystify the concept of self-attention as used in the paper Attention is all you need. Hopefully we’ll build up enough intuition to understand self-attention and multi-head attention and, in the process, unravel what makes a Transformer tick. The aim is to keep this as simple and light as possible, so that we can build on top of it in future posts. And how would this be a Transformers blogpost without the Transformer architecture?
Transformer Architecture
1 Self-Attention
Attention is a mechanism that lets a model focus on certain parts of a sequence by giving more weight to the elements that are most relevant to each other. Self-attention applies this mechanism within a single input sequence: the elements of the sequence attend to other elements of that same sequence, producing a weighted representation of each element. This differs from cross-attention, which applies the attention mechanism across two different sequences, such as a source sequence (e.g., a sentence in one language) and a target sequence (its translation in another language).
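To make the distinction concrete, here is a minimal sketch. The sequences `src` and `tgt`, their lengths, and the projection layers `W_q`, `W_k`, `W_v` are illustrative assumptions, not something from the paper. In self-attention the queries, keys and values all come from the same sequence; in cross-attention the queries come from the target sequence while the keys and values come from the source sequence, so the weight matrix is no longer square.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

d, head_size = 5, 16
src = torch.randn(1, 4, d)  # hypothetical source sequence: 4 tokens
tgt = torch.randn(1, 3, d)  # hypothetical target sequence: 3 tokens

# Illustrative projections, shared by both examples below
W_q = nn.Linear(d, head_size, bias=False)
W_k = nn.Linear(d, head_size, bias=False)
W_v = nn.Linear(d, head_size, bias=False)

def attend(q_input, kv_input):
    """Scaled dot-product attention: queries from q_input, keys/values from kv_input."""
    Q, K, V = W_q(q_input), W_k(kv_input), W_v(kv_input)
    scores = Q @ K.transpose(-1, -2) / head_size**0.5
    weights = F.softmax(scores, dim=-1)
    return weights @ V, weights

# Self-attention: the target sequence attends to itself -> (T_tgt, T_tgt) weights
self_out, self_w = attend(tgt, tgt)
# Cross-attention: target queries attend to source keys -> (T_tgt, T_src) weights
cross_out, cross_w = attend(tgt, src)

print(self_w.shape)   # torch.Size([1, 3, 3])
print(cross_w.shape)  # torch.Size([1, 3, 4])
```

In both cases the outputs have one row per query token; only the set of tokens being attended to changes.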
So how do we implement self-attention? The most common way is scaled dot-product attention, introduced in the paper Attention is all you need and shown in the figure below.
Scaled Dot Product Attention
Not all that intimidating, right? Now, let’s implement this in code.
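Let’s say we have a sequence of shape $(T, d)$, where $T$ is the length of the sequence and $d$ is the dimension of each element (think of it as $T$ tokens, each represented by a $d$-dimensional embedding). We’ll call this sequence $X$; in the code we also add a leading batch dimension $B$, so its shape is $(B, T, d)$: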
```{python}
import torch
import torch.nn as nn

torch.manual_seed(89)

# Create a sequence x of length 3, each element is a vector of size 5
B, T, d = 1, 3, 5  # Batch size, sequence length, embedding dimension
x = torch.randn((B, T, d))
print(x)
print(x.shape)
```
```{r}
#| eval: false
library(torch)
torch_manual_seed(89)

# Create a sequence x of length 3, each element is a vector of size 5
B <- 1  # Batch size
T <- 3  # Sequence length
d <- 5  # Embedding dimension
x <- torch_randn(c(B, T, d))
print(x)
print(x$shape)
```
1.1 MatMul(Q, K)
Now that we have this, the first step in scaled dot-product attention is to project each token embedding into three different vectors, a Query, a Key and a Value, each of size head_size. The projections themselves are learnable linear layers, each mapping a $d$-dimensional embedding to a head_size-dimensional vector. These vectors can be interpreted in various ways, but for now let’s think of them as Andrej Karpathy does in his GPT lecture:
Query: This is what a token is looking for.
Key: This is what a token has to offer.
Value: This is what a token will communicate if there is a match between the Query and Key.
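In Python, the projection step might look like the sketch below. It recreates the same setup as above (seed 89, `B, T, d = 1, 3, 5`) so it runs on its own. Note where the learnable parameters live: each `nn.Linear` holds a weight matrix of shape `(head_size, d)`, while each resulting Query, Key and Value vector has `head_size` entries.

```python
import torch
import torch.nn as nn

torch.manual_seed(89)
B, T, d = 1, 3, 5  # batch size, sequence length, embedding dimension
x = torch.randn((B, T, d))

head_size = 16  # output dimension of each projection
# One bias-free linear layer per projection
query = nn.Linear(in_features=d, out_features=head_size, bias=False)
key = nn.Linear(in_features=d, out_features=head_size, bias=False)
value = nn.Linear(in_features=d, out_features=head_size, bias=False)

# Project every token embedding into a query, a key and a value
Q = query(x)  # (B, T, 16)
K = key(x)    # (B, T, 16)
V = value(x)  # (B, T, 16)
print(Q.shape, K.shape, V.shape)
print(query.weight.shape)  # torch.Size([16, 5])
```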
```{r}
#| eval: false
head_size <- 16  # Dimension of each query, key and value vector

# Create projections of queries, keys and values
query <- nn_linear(in_features = d, out_features = head_size, bias = FALSE)
key <- nn_linear(in_features = d, out_features = head_size, bias = FALSE)
value <- nn_linear(in_features = d, out_features = head_size, bias = FALSE)

# Project token embeddings into queries, keys and values
Q <- query(x)  # Shape: (B, T, 16)
K <- key(x)    # Shape: (B, T, 16)
V <- value(x)  # Shape: (B, T, 16)
print(Q$shape)
print(K$shape)
print(V$shape)
```
Now that each token has emitted a query and a key, the next step is to compute the similarity between every query and every key. This is done by taking the dot product of each query with each key: queries and keys that point in similar directions have a large dot product, while dissimilar ones have a small (or even negative) dot product.
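In Python, all of these dot products can be computed with one batched matrix multiplication. The sketch below rebuilds `Q` and `K` so it runs standalone, then spot-checks a single entry against an explicit `torch.dot` to show what each element of the result means.

```python
import torch
import torch.nn as nn

torch.manual_seed(89)
B, T, d, head_size = 1, 3, 5, 16
x = torch.randn((B, T, d))
Q = nn.Linear(d, head_size, bias=False)(x)  # queries: (B, T, 16)
K = nn.Linear(d, head_size, bias=False)(x)  # keys:    (B, T, 16)

# All query-key dot products at once: (B, T, 16) x (B, 16, T) -> (B, T, T)
attention_scores = torch.matmul(Q, K.transpose(-1, -2))
print(attention_scores.shape)

# Entry [0, i, j] is the dot product of query i with key j
i, j = 0, 2
manual = torch.dot(Q[0, i], K[0, j])
print(torch.allclose(attention_scores[0, i, j], manual, atol=1e-5))
```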
```{r}
#| eval: false
# Compute dot product of queries and keys:
attention_scores <- torch_matmul(Q, K$transpose(-1, -2))  # (B, T, 16) x (B, 16, T) = (B, T, T)
print(attention_scores$shape)
print(attention_scores)
```
This results in a $T \times T$ matrix (per batch element), which makes sense since we had $T$ tokens to begin with. This matrix holds the attention scores: row $i$ contains the similarity between the query of token $i$ and the keys of all tokens in the sequence, including token $i$ itself.
1.2 Scale
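Dot products grow with the vector dimension: if the components of the queries and keys are independent with mean 0 and variance 1, their dot product has variance $d_k$. Such large scores push the softmax into regions with very small gradients, which can destabilize training. To address this, the attention scores are scaled by $\frac{1}{\sqrt{d_k}}$, where $d_k$ is the dimension of the query and key vectors.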
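A quick numerical check of this argument, as a sketch: it samples queries and keys with independent standard-normal components (an assumption made for illustration; learned queries and keys need not look like this) and compares the variance of the raw and scaled dot products for a few values of $d_k$.

```python
import torch

torch.manual_seed(0)
n_samples = 10_000
results = {}

for d_k in (4, 64, 1024):
    # Random queries and keys whose components are i.i.d. standard normal
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)
    dots = (q * k).sum(dim=-1)  # one dot product q . k per sample
    scaled = dots / d_k**0.5    # the same dot products scaled by 1/sqrt(d_k)
    results[d_k] = (dots.var().item(), scaled.var().item())
    print(f"d_k={d_k:5d}  var(q.k)={results[d_k][0]:8.1f}  "
          f"var(q.k/sqrt(d_k))={results[d_k][1]:.2f}")
```

The raw variance should track $d_k$, while the scaled variance should stay close to 1 regardless of the dimension.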
```{python}
# Multiply each element by 1/sqrt(d_k)
dim_k = K.shape[-1]
print(f"dim_k: {dim_k}")
attention_scores = attention_scores * dim_k**-0.5
attention_scores
```
```{r}
#| eval: false
# Multiply each element by 1/sqrt(d_k)
dim_k <- purrr::pluck(K$shape, -1)
paste("dim_k:", dim_k)
attention_scores <- attention_scores * dim_k^-0.5
attention_scores
```
1.3 SoftMax
To make the attention scores easier to interpret, a softmax function is applied to each row. The softmax turns each row into positive weights that sum to 1, so they can be read as probabilities: every token receives some attention, with varying degrees of importance. Let’s put this in code:
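To see what `F.softmax` computes, here is a from-scratch row-wise softmax, $\mathrm{softmax}(s)_j = e^{s_j} / \sum_{k} e^{s_k}$, applied to a toy score tensor. Subtracting the row maximum before exponentiating is a common numerical-stability trick that we add here as a design choice; it cancels in the ratio, so the result is unchanged, but it avoids overflow for large scores.

```python
import torch
import torch.nn.functional as F

def softmax_rows(scores: torch.Tensor) -> torch.Tensor:
    """Row-wise softmax over the last dimension: exp(s_ij) / sum_j exp(s_ij)."""
    # Shift each row so its maximum is 0; this does not change the result
    shifted = scores - scores.max(dim=-1, keepdim=True).values
    exp = torch.exp(shifted)
    return exp / exp.sum(dim=-1, keepdim=True)

torch.manual_seed(0)
scores = torch.randn(1, 3, 3)  # toy (B, T, T) attention scores
weights = softmax_rows(scores)

print(weights.sum(dim=-1))                                 # each row sums to 1
print(torch.allclose(weights, F.softmax(scores, dim=-1)))  # should match PyTorch
```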
```{python}
import torch.nn.functional as F

# Apply softmax to get attention probabilities
weights = F.softmax(attention_scores, dim=-1)

# Print weights and sum of weights along each row
print(f"weights: {weights}")
print(f"\nsum of weights along each row: {torch.sum(weights, dim=-1)}")
```
weights: tensor([[[0.3668, 0.3001, 0.3331],
[0.2955, 0.4472, 0.2573],
[0.3729, 0.2933, 0.3338]]], grad_fn=<SoftmaxBackward0>)
sum of weights along each row: tensor([[1.0000, 1.0000, 1.0000]], grad_fn=<SumBackward1>)
```{r}
#| eval: false
# Apply softmax to get attention probabilities
weights <- nnf_softmax(attention_scores, dim = -1)

# Print weights and sum of weights along each row
print(weights)
print(torch_sum(weights, dim = -1))
```
Much better! Each row now shows how much attention a token pays to every token in the sequence, itself included.
Note
If we had not scaled the dot products (the attention scores), the softmax output could have leaned towards a one-hot vector, i.e., one token would receive almost all the attention while the rest would receive almost none. This is not what we want: ideally, all tokens receive some level of attention, with varying degrees of importance. We can illustrate this by multiplying the scores by a large factor (here 12) and applying the softmax again:
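A standalone Python version of this experiment, as a sketch: it uses a hand-picked row of scores (an assumption, so the snippet does not depend on the random scores computed earlier) and multiplies it by 12 to mimic what unscaled dot products with a large $d_k$ can do.

```python
import torch
import torch.nn.functional as F

torch.set_printoptions(sci_mode=False)

# Hand-picked row of attention scores (illustrative values)
scores = torch.tensor([[1.0, 0.5, -0.5]])

spread = F.softmax(scores, dim=-1)      # weights spread over all three tokens
sharp = F.softmax(scores * 12, dim=-1)  # almost all mass on the first token
print(spread)
print(sharp)
```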
```{r}
#| eval: false
# Sharpen attention scores and apply softmax again
nnf_softmax(attention_scores * 12, dim = -1)
```
As the example shows, large attention scores can cause the softmax to saturate, concentrating almost all the weight on a single token and leaving very small gradients to learn from. This is why we scale the attention scores by $\frac{1}{\sqrt{d_k}}$.
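Why does saturation hurt training? For $p = \mathrm{softmax}(z)$, the Jacobian is $\partial p_i / \partial z_j = p_i(\delta_{ij} - p_j)$, i.e. $\mathrm{diag}(p) - p p^\top$, and every entry approaches 0 as $p$ approaches a one-hot vector. The sketch below compares the largest Jacobian entry for a hand-picked score vector (an illustrative assumption) before and after multiplying it by 12.

```python
import torch
import torch.nn.functional as F

def softmax_jacobian(z: torch.Tensor) -> torch.Tensor:
    """Jacobian of softmax for a 1-D score vector z: diag(p) - p p^T."""
    p = F.softmax(z, dim=-1)
    return torch.diag(p) - torch.outer(p, p)

z = torch.tensor([1.0, 0.5, -0.5])  # hand-picked scores

J_soft = softmax_jacobian(z)        # unsaturated softmax
J_sharp = softmax_jacobian(z * 12)  # saturated softmax

# Largest entry of each Jacobian: the saturated one is far smaller
print(J_soft.abs().max().item())
print(J_sharp.abs().max().item())
```

Small Jacobian entries mean small gradients flowing back into the queries and keys, which is the training problem the scaling step is meant to avoid.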
1.4 MatMul(SoftMax, V)
Finally, let’s update the token embeddings. Once the attention weights have been computed, the last step is to compute a weighted sum of the values. This is achieved by multiplying the attention weights by the value projection. This matrix multiplication updates the embedding of each token with information from the other tokens, i.e.,
$$x_i = \sum_{j = 1}^{T} w_{ij} V_j$$
where $w_{ij}$ is the attention weight between token $i$ and token $j$, and $V_j$ is the value projection of token $j$.
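In Python, this step is a single `torch.matmul`. The sketch below rebuilds the pipeline so far (same seed and shapes as above) and then checks the matrix product against the explicit sum $x_i = \sum_j w_{ij} V_j$, written as a plain Python loop for illustration only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(89)
B, T, d, head_size = 1, 3, 5, 16
x = torch.randn((B, T, d))
query = nn.Linear(d, head_size, bias=False)
key = nn.Linear(d, head_size, bias=False)
value = nn.Linear(d, head_size, bias=False)
Q, K, V = query(x), key(x), value(x)

# Scaled scores and row-wise softmax, as in the previous steps
weights = F.softmax(Q @ K.transpose(-1, -2) * head_size**-0.5, dim=-1)  # (B, T, T)

# Update token embeddings: (B, T, T) x (B, T, 16) -> (B, T, 16)
attn_outputs = torch.matmul(weights, V)
print(V.shape, attn_outputs.shape)

# The same computation written as the explicit sum x_i = sum_j w_ij V_j
manual = torch.stack([
    sum(weights[0, i, j] * V[0, j] for j in range(T))
    for i in range(T)
]).unsqueeze(0)
print(torch.allclose(attn_outputs, manual, atol=1e-5))
```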
The attention outputs still represent our 3 initial tokens, but each token’s embedding has now been updated with information from the other tokens, in different proportions.
That’s it! We have gone through the entire process of implementing a self-attention layer: two matrix multiplications, a scaling step and a softmax. At its core, self-attention computes, for each token, a weighted average of the value vectors of all the tokens in the sequence, where the weights come from query-key similarities.
1.5 Mask (Optional)
Let’s quickly talk about masking, an optional step that we skipped. Masking is used in the Transformer’s decoder to prevent tokens from attending to future tokens. This is particularly useful in language generation, where we want the model to predict each token based only on the tokens before it. Masking is achieved by setting the attention scores of future tokens to $-\infty$ before applying the softmax. Since $e^{-\infty} = 0$, the softmax then assigns a probability of 0 to future tokens. Let’s put this in code:
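In Python, the mask can be built from a lower-triangular matrix of ones, and `masked_fill` replaces the scores of future tokens with $-\infty$. In this sketch a random `(1, T, T)` tensor stands in for the scaled attention scores (an assumption, so it runs on its own).

```python
import torch

torch.manual_seed(0)
T = 3
attention_scores = torch.randn(1, T, T)  # stand-in for the scaled scores

# Lower-triangular matrix of ones: row i has ones in columns 0..i
mask = torch.tril(torch.ones((T, T)))
print(mask)

# Replace scores of future tokens (where mask == 0) with -inf;
# the (T, T) mask broadcasts over the batch dimension
masked_attention_scores = attention_scores.masked_fill(mask == 0, float("-inf"))
print(masked_attention_scores)
```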
Now, if we apply a softmax to this, we’ll see that the attention weights of future tokens are 0, meaning that the model will not attend to future tokens:
```{r}
#| eval: false
# Apply softmax to get attention probabilities
masked_weights <- nnf_softmax(masked_attention_scores, dim = -1)
masked_weights
```
That’s it! This is what masking is all about. We can go ahead and multiply the attention weights with the value projection to get the attention outputs.
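Putting the masked steps together, a minimal sketch (random stand-ins for the scaled scores and the value projection are assumptions, so the snippet runs standalone) shows a useful sanity check: the first token can only attend to itself, so its output is exactly its own value vector.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, head_size = 1, 3, 16
attention_scores = torch.randn(B, T, T)  # stand-in scaled scores
V = torch.randn(B, T, head_size)         # stand-in value projections

mask = torch.tril(torch.ones((T, T)))
masked_scores = attention_scores.masked_fill(mask == 0, float("-inf"))
masked_weights = F.softmax(masked_scores, dim=-1)  # zeros above the diagonal

# Contextualized embeddings that only use the current and past tokens
masked_outputs = torch.matmul(masked_weights, V)   # (B, T, head_size)

# Token 0 only sees itself, so its output equals its own value vector
print(torch.allclose(masked_outputs[0, 0], V[0, 0]))
```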
1.6 Creating a Self-Attention Layer
Let’s now implement all these steps as a single attention head layer:
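A sketch of such a layer as a PyTorch module, bundling the projection, scaling, softmax and weighted-sum steps above. One deliberate choice here: `head_size` is stored on the module (`self.head_size`) so that the scaling in `forward` uses the head’s own size rather than whatever global `head_size` variable happens to exist.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionHead(nn.Module):
    def __init__(self, embed_dim, head_size):
        super().__init__()
        self.head_size = head_size
        # Create projections of queries, keys and values
        self.query = nn.Linear(in_features=embed_dim, out_features=head_size, bias=False)
        self.key = nn.Linear(in_features=embed_dim, out_features=head_size, bias=False)
        self.value = nn.Linear(in_features=embed_dim, out_features=head_size, bias=False)

    def forward(self, x):
        # Project token embeddings into queries, keys and values
        Q = self.query(x)  # (B, T, head_size)
        K = self.key(x)    # (B, T, head_size)
        V = self.value(x)  # (B, T, head_size)
        # Dot products: (B, T, head_size) x (B, head_size, T) = (B, T, T)
        attention_scores = torch.matmul(Q, K.transpose(-1, -2))
        # Scale by 1/sqrt(head_size), using the stored attribute
        attention_scores = attention_scores * self.head_size**-0.5
        # Softmax over the key dimension to get attention probabilities
        weights = F.softmax(attention_scores, dim=-1)
        # Weighted sum of values: (B, T, T) x (B, T, head_size) = (B, T, head_size)
        return torch.matmul(weights, V)

# Quick check on a random (1, 3, 32) input
head = AttentionHead(embed_dim=32, head_size=8)
print(head(torch.randn(1, 3, 32)).shape)  # torch.Size([1, 3, 8])
```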
```{python}
# Create a sequence x of length 3, each element is a vector of size 32
seq_len, embed_dim = 3, 32
x = torch.randn((1, seq_len, embed_dim))

# Create an attention head
head_size = 8  # x will be projected to 8 dimensions
attn_head = AttentionHead(embed_dim=embed_dim, head_size=head_size)

# Apply attention head to x
attn_outputs = attn_head(x)
attn_outputs.shape
```
torch.Size([1, 3, 8])
```{r}
#| eval: false
# Create a sequence x of length 3, each element is a vector of size 32
seq_len <- 3
embed_dim <- 32
x <- torch_randn(c(1, seq_len, embed_dim))

# Create an attention head
head_size <- 8  # x will be projected to 8 dimensions
attn_head <- attention_head(embed_dim = embed_dim, head_size = head_size)

# Apply attention head to x
attn_outputs <- attn_head(x)
attn_outputs$shape
```
Like before, the sequence embeddings have been transformed into contextualized embeddings that contain information from other tokens. This is the essence of self-attention. Now, let’s take this a notch higher: multi-head attention.
2 Multi-Head Attention
Instead of performing a single attention operation, the Transformer performs multiple attention operations in parallel as illustrated in the figure below.
Multi-Head Attention
But why do we need more than one attention head? Multiple heads allow the model to focus on different aspects of the sequence at once. For example, one head might focus on the subject of a sentence while another focuses on the verb. This is possible because each head has its own, independently initialized linear projections, so during training each head can learn to attend to different parts of the sequence.
As illustrated in the figure above, multi-head attention concatenates the outputs of several attention heads and passes the result through a final linear layer.
Note
Although the head dimension head_dim (head_size in our code) does not have to be smaller than the token embedding dimension embed_dim, in practice embed_dim is chosen to be a multiple of head_dim, with head_dim = embed_dim / num_heads, so that the total computation stays roughly the same regardless of the number of heads. For instance, if embed_dim is 32 and we want 4 attention heads, each head outputs vectors of size head_dim = 8. The outputs of the attention heads are then concatenated to form a single output vector of size 32.
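The arithmetic in this note can be checked directly. In this sketch, random tensors stand in for the outputs of the four heads (an assumption for illustration); concatenating them along the last dimension gives back a vector of size embed_dim per token, with each head occupying its own slice.

```python
import torch

embed_dim, num_heads = 32, 4
head_dim = embed_dim // num_heads  # 8: each head works in an 8-dimensional space
assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

B, T = 1, 3
# Stand-ins for the outputs of 4 attention heads, each of shape (B, T, head_dim)
head_outputs = [torch.randn(B, T, head_dim) for _ in range(num_heads)]

# Concatenating along the last dimension recovers a (B, T, embed_dim) tensor
concatenated = torch.cat(head_outputs, dim=-1)
print(head_dim, concatenated.shape)  # 8 torch.Size([1, 3, 32])
```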
Now, let’s implement a full multi-head attention layer:
```{python}
# Implement multi-head attention layer
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # Calculate the projection dimension for each head
        self.head_size = embed_dim // num_heads
        assert self.head_size * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
        # Create num_heads different attention heads
        self.heads = nn.ModuleList(
            [AttentionHead(embed_dim=embed_dim, head_size=self.head_size) for _ in range(num_heads)]
        )
        # Create a linear layer to project the concatenation of all heads into a vector of embed_dim
        self.output_linear = nn.Linear(in_features=embed_dim, out_features=embed_dim)

    def forward(self, x):
        # Pass x through each attention head
        multihead_attn = [attn_head(x) for attn_head in self.heads]
        # Concatenate the outputs of all attention heads
        multihead_attn = torch.cat(multihead_attn, dim=-1)  # Shape: (B, T, embed_dim)
        # Project the concatenated outputs of all attention heads
        multihead_attn = self.output_linear(multihead_attn)  # Shape: (B, T, embed_dim)
        return multihead_attn
```
```{r}
#| eval: false
# Implement multi-head attention layer
multi_head_attention <- nn_module(
  initialize = function(embed_dim, num_heads) {
    # Calculate the projection dimension for each head
    self$head_size <- embed_dim %/% num_heads
    stopifnot("embed_dim must be divisible by num_heads" = self$head_size * num_heads == embed_dim)
    # Create num_heads different attention heads
    self$heads <- nn_module_list(purrr::map(1:num_heads, ~ attention_head(embed_dim = embed_dim, head_size = self$head_size)))
    # Create a linear layer to project the concatenation of all heads into a vector of embed_dim
    self$output_linear <- nn_linear(in_features = embed_dim, out_features = embed_dim)
  },
  forward = function(x) {
    # Pass x through each attention head
    multihead_attn <- purrr::map(1:length(self$heads), ~ self$heads[[.x]](x))
    # Concatenate the outputs of all attention heads
    multihead_attn <- torch_cat(multihead_attn, dim = -1)  # Shape: (B, T, embed_dim)
    # Project the concatenated outputs of all attention heads
    multihead_attn <- self$output_linear(multihead_attn)  # Shape: (B, T, embed_dim)
    multihead_attn
  }
)
```
Now let’s confirm that the multi-head attention layer produces an output with the same shape as its input:
```{python}
# Create a sequence x of length 3, each element is a vector of size 32
seq_len, embed_dim = 3, 32
x = torch.randn((1, seq_len, embed_dim))

# Create a multi-head attention layer
num_heads = 4
msa = MultiHeadAttention(embed_dim=embed_dim, num_heads=num_heads)

# Apply multi-head attention to x
msa_outputs = msa(x)
print(f"input shape: {x.shape}")
print(f"msa_outputs: {msa_outputs.shape}")
```
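As a sanity check on shapes only, PyTorch also ships a built-in `nn.MultiheadAttention` layer. This sketch assumes a PyTorch version recent enough to support `batch_first=True`. The numbers will not match our layer, since the parameters are initialized differently and the built-in layer uses biases by default; for self-attention, the same tensor is passed as query, key and value.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, embed_dim, num_heads = 3, 32, 4
x = torch.randn(1, seq_len, embed_dim)

# Built-in layer; batch_first=True makes it accept (B, T, embed_dim) inputs
mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)

# Self-attention: query, key and value are all x
out, attn_weights = mha(x, x, x)
print(out.shape)           # torch.Size([1, 3, 32])
print(attn_weights.shape)  # torch.Size([1, 3, 3]), averaged over heads by default
```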
```{r}
#| eval: false
# Create a sequence x of length 3, each element is a vector of size 32
seq_len <- 3
embed_dim <- 32
x <- torch_randn(c(1, seq_len, embed_dim))

# Create a multi-head attention layer
num_heads <- 4
msa <- multi_head_attention(embed_dim = embed_dim, num_heads = num_heads)

# Apply multi-head attention to x
msa_outputs <- msa(x)
print(x$shape)
print(msa_outputs$shape)
```
There you have it! We have broken down what attention is all about and implemented it in code. In a nutshell, self-attention turns an input sequence of token embeddings into a sequence of contextualized embeddings, each of which is a weighted sum of the value projections of all the tokens in the sequence.
In the next blogpost, we can go ahead and implement the rest of the Transformer architecture.
References
Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017). Attention is all you need. Advances in Neural Information Processing Systems, 30.
Tunstall, L., von Werra, L., & Wolf, T. (2022). Natural Language Processing with Transformers (Revised Edition). O’Reilly Media.
Karpathy, A. Let’s build GPT: from scratch, in code, spelled out. https://youtu.be/kCc8FmEb1nY?si=EvGH3ZRkGIlN1t0_