Sunday, December 22, 2024

Understanding Positional Embeddings in Self-Attention: A PyTorch Implementation


If you’ve delved into transformer papers, you’ve likely encountered the concept of Positional Embeddings (PE). While they may seem straightforward at first glance, implementing them can quickly become perplexing. This article aims to demystify positional embeddings, particularly in the context of computer vision, and provide a clear understanding of their implementation.

The Importance of Positional Embeddings

When working with transformer models, especially in computer vision, grasping the concept of positional embeddings is crucial. Unlike natural language processing (NLP) tasks, where token order is inherently one-dimensional and sequential, images have a highly structured two-dimensional layout. This structure necessitates a robust representation of positional information within the multi-head self-attention (MHSA) block.

Positional Encodings vs. Positional Embeddings

In the original transformer architecture, positional encodings are added before the first MHSA block. However, it’s essential to distinguish between positional encodings and positional embeddings. While positional encodings are fixed and sinusoidal, positional embeddings are trainable vectors that map each position in a sequence to a corresponding vector of a specified dimension.
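
As a point of reference, the fixed sinusoidal encoding of Vaswani et al. (2017) can be sketched as follows. The function below is illustrative (its name and arguments are not from any particular library): it fills even dimensions with sines and odd dimensions with cosines of geometrically spaced frequencies.

import math
import torch

def sinusoidal_encoding(max_seq_tokens, dim):
    # assumes an even embedding dimension
    position = torch.arange(max_seq_tokens, dtype=torch.float).unsqueeze(1)          # [tokens, 1]
    freqs = torch.exp(torch.arange(0, dim, 2, dtype=torch.float) * (-math.log(10000.0) / dim))
    pe = torch.zeros(max_seq_tokens, dim)
    pe[:, 0::2] = torch.sin(position * freqs)   # even dimensions
    pe[:, 1::2] = torch.cos(position * freqs)   # odd dimensions
    return pe                                   # fixed, not a trainable parameter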

For instance, in PyTorch, the initialization of positional embeddings can be represented as follows:

pos_emb1D = torch.nn.Parameter(torch.randn(max_seq_tokens, dim))
# add the learned position vectors to the token embeddings before the first MHSA block
input_to_transformer_mhsa = input_embedding + pos_emb1D[:current_seq_tokens, :]
out = transformer(input_to_transformer_mhsa)

This snippet illustrates how positional embeddings are integrated into the input representation, allowing the model to learn positional information during training.

Visualizing Positional Embeddings

To better understand what positional embeddings learn, we can refer to a study by Wang and Chen (2020), which visualizes the position-wise similarity of the positional embeddings learned by different pre-trained NLP models. The brighter areas in the visualizations indicate higher similarity between positions, revealing the distinct positional patterns that models such as BERT and GPT-2 learn.

The Emergence of Positional Embeddings in MHSA

One of the primary challenges with traditional positional encodings is that they only provide positional information at the input stage. The MHSA mechanism itself is permutation equivariant, meaning it lacks inherent positional awareness. This limitation is particularly problematic for structured data like images.

To address this, positional embeddings can be incorporated directly into the MHSA block. By doing so, we can inject positional information into the self-attention mechanism, enhancing its ability to capture the spatial relationships present in images.
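
To make this concrete, here is a small numerical check (illustrative, not from the original article): a bare scaled dot-product self-attention with no positional information produces the same outputs, merely reordered, when the input tokens are permuted.

import torch

def plain_attention(x):
    # scaled dot-product self-attention with q = k = v = x and no positional terms
    scores = x @ x.transpose(-2, -1) / x.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ x

torch.manual_seed(0)
x = torch.randn(1, 5, 8)                 # [batch, tokens, dim]
perm = torch.randperm(5)

out = plain_attention(x)
out_perm = plain_attention(x[:, perm, :])
print(torch.allclose(out[:, perm, :], out_perm, atol=1e-6))  # True: only the order changed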

Understanding Self-Attention with Positional Information

In self-attention, the input sequence is modeled as a fully connected directed graph. Each attention weight can be viewed as an arrow connecting query and key elements. By incorporating positional embeddings, we can modify the attention weights to account for the positions of the tokens, allowing the model to learn the spatial relationships more effectively.

The modified attention weight can be expressed as:

\[
\epsilon_{ij} = \frac{x_i W^Q (x_j W^K)^T + x_i W^Q (p_{ij}^K)^T}{\sqrt{d}}
\]

Here, the term \(x_i W^Q (p_{ij}^K)^T\) represents the positional information, allowing the model to consider the distance of the query element to specific sequence positions.

Absolute vs. Relative Positional Embeddings

Positional embeddings can be categorized into two types: absolute and relative.

Absolute Positional Embeddings

Absolute positional embeddings assign a unique trainable vector to each position in the sequence. This approach modifies the representation based on the position of each token. The implementation is straightforward, as shown below:

import torch
from torch import nn, einsum

class AbsPosEmb1DAISummer(nn.Module):
    """Trainable absolute positional embedding: one vector per token position."""
    def __init__(self, tokens, dim_head):
        super().__init__()
        scale = dim_head ** -0.5
        self.abs_pos_emb = nn.Parameter(torch.randn(tokens, dim_head) * scale)

    def forward(self, q):
        # q: [batch, heads, tokens, dim_head] -> positional scores: [batch, heads, tokens, tokens]
        return einsum('b h i d, j d -> b h i j', q, self.abs_pos_emb)
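
A brief usage sketch (the shapes below are illustrative assumptions): the module consumes the query tensor and returns a [batch, heads, tokens, tokens] score that is added to the content-based logits before the softmax.

batch, heads, tokens, dim_head = 2, 4, 64, 32
q = torch.randn(batch, heads, tokens, dim_head)
k = torch.randn(batch, heads, tokens, dim_head)

abs_pos = AbsPosEmb1DAISummer(tokens, dim_head)
content_logits = einsum('b h i d, b h j d -> b h i j', q, k) * dim_head ** -0.5
logits = content_logits + abs_pos(q)     # [2, 4, 64, 64]
attn = logits.softmax(dim=-1)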

Relative Positional Embeddings

Relative positional embeddings, on the other hand, represent the distances between tokens rather than their absolute positions. This method allows the model to gain translation equivariance, similar to convolutional layers. The challenge with relative embeddings lies in managing the shape of the matrices involved.

To convert the relative-index logits (shape [batch, heads, tokens, 2*tokens - 1]) into absolute-index form (shape [batch, heads, tokens, tokens]), we can use the following function, which relies on rearrange from the einops library:

from einops import rearrange

def relative_to_absolute(q):
    # q: [batch, heads, tokens, 2*tokens - 1] relative logits -> [batch, heads, tokens, tokens]
    b, h, l, _, device, dtype = *q.shape, q.device, q.dtype
    col_pad = torch.zeros((b, h, l, 1), device=device, dtype=dtype)
    x = torch.cat((q, col_pad), dim=3)                       # pad one zero column per row
    flat_x = rearrange(x, 'b h l c -> b h (l c)')
    flat_pad = torch.zeros((b, h, l - 1), device=device, dtype=dtype)
    flat_x_padded = torch.cat((flat_x, flat_pad), dim=2)
    final_x = flat_x_padded.reshape(b, h, l + 1, 2 * l - 1)  # reshaping shifts each row
    final_x = final_x[:, :, :l, (l - 1):]                    # keep the absolute-indexed block
    return final_x
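
A quick sanity check of the shape transformation (the numbers are arbitrary): the function maps relative logits of shape [batch, heads, tokens, 2*tokens - 1] onto an absolute-indexed [batch, heads, tokens, tokens] grid.

rel_logits = torch.randn(2, 4, 10, 2 * 10 - 1)   # [batch, heads, tokens, 2*tokens - 1]
abs_logits = relative_to_absolute(rel_logits)
print(abs_logits.shape)                          # torch.Size([2, 4, 10, 10])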

Implementation of Relative Positional Embeddings

The implementation of relative positional embeddings follows a similar structure to absolute embeddings, with adjustments to account for the relative distances:

class RelPosEmb1DAISummer(nn.Module):
    def __init__(self, tokens, dim_head, heads=None):
        # if heads is None, the relative embedding is shared across all heads
        super().__init__()
        scale = dim_head ** -0.5
        self.shared_heads = heads is None
        if self.shared_heads:
            self.rel_pos_emb = nn.Parameter(torch.randn(2 * tokens - 1, dim_head) * scale)
        else:
            self.rel_pos_emb = nn.Parameter(torch.randn(heads, 2 * tokens - 1, dim_head) * scale)

    def forward(self, q):
        # q: [batch, heads, tokens, dim_head] -> positional scores: [batch, heads, tokens, tokens]
        return rel_pos_emb_1d(q, self.rel_pos_emb, self.shared_heads)
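
The forward pass above calls a rel_pos_emb_1d helper that is not shown in this excerpt. A minimal sketch consistent with the shapes used here, building on relative_to_absolute, could look like this:

def rel_pos_emb_1d(q, rel_emb, shared_heads):
    # q: [batch, heads, tokens, dim_head]
    # rel_emb: [2*tokens - 1, dim_head] if shared across heads,
    #          [heads, 2*tokens - 1, dim_head] otherwise
    if shared_heads:
        emb = torch.einsum('b h t d, r d -> b h t r', q, rel_emb)
    else:
        emb = torch.einsum('b h t d, h r d -> b h t r', q, rel_emb)
    return relative_to_absolute(emb)             # [batch, heads, tokens, tokens]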

Two-Dimensional Relative Positional Embeddings

In the context of computer vision, particularly when dealing with images, it’s beneficial to extend the concept of relative positional embeddings to two dimensions. This approach allows each pixel to receive two independent distances: one for the row and one for the column.

The implementation of 2D relative positional embeddings can be structured as follows:

class RelPosEmb2DAISummer(nn.Module):
    def __init__(self, feat_map_size, dim_head, heads=None):
        super().__init__()
        self.h, self.w = feat_map_size  # feature map height and width
        self.total_tokens = self.h * self.w
        self.shared_heads = heads is None
        # one 1D relative embedding per axis; as written, this assumes a square feature map (h == w)
        self.emb_w = RelPosEmb1DAISummer(self.h, dim_head, heads)
        self.emb_h = RelPosEmb1DAISummer(self.w, dim_head, heads)

    def expand_emb(self, r, dim_size):
        # broadcast the per-axis scores to the full token grid
        r = rearrange(r, 'b (h x) i j -> b h x () i j', x=dim_size)
        expand_index = [-1, -1, -1, dim_size, -1, -1]  # -1 means no expansion along that dim
        r = r.expand(expand_index)
        return rearrange(r, 'b h x1 x2 y1 y2 -> b h (x1 y1) (x2 y2)')

    def forward(self, q):
        # q: [batch, heads, tokens, dim_head] with tokens == h * w
        assert self.total_tokens == q.shape[2], f'Tokens {q.shape[2]} of q must be equal to the product of the feat map size {self.total_tokens}'
        # fold one spatial axis into the head dimension and apply the 1D relative embedding along the other
        r_h = self.emb_w(rearrange(q, 'b h (x y) d -> b (h x) y d', x=self.h, y=self.w))
        r_w = self.emb_h(rearrange(q, 'b h (x y) d -> b (h y) x d', x=self.h, y=self.w))
        q_r = self.expand_emb(r_h, self.h) + self.expand_emb(r_w, self.h)
        return q_r
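
A usage sketch with assumed shapes: for an 8x8 feature map flattened into 64 tokens, the module returns a [batch, heads, 64, 64] positional score that is added to the content logits, exactly as in the 1D case.

feat_map_size = (8, 8)                           # height, width (square, as noted above)
batch, heads, dim_head = 2, 4, 32
q = torch.randn(batch, heads, feat_map_size[0] * feat_map_size[1], dim_head)

rel_pos_2d = RelPosEmb2DAISummer(feat_map_size, dim_head)
print(rel_pos_2d(q).shape)                       # torch.Size([2, 4, 64, 64])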

Conclusion

This article has provided a comprehensive overview of positional embeddings in transformers, highlighting their significance in both NLP and computer vision tasks. Understanding the differences between absolute and relative positional embeddings, as well as their implementations, is crucial for effectively utilizing transformer architectures.

Acknowledgments

I would like to express my gratitude to Phil Wang for his inspiring implementations and to Amirhossein Kazemnejad for his insightful articles on positional encoding. Their contributions have significantly aided my understanding of this complex topic.

References

  1. Wang, Y. A., & Chen, Y. N. (2020). What Do Position Embeddings Learn? An Empirical Study of Pre-Trained Language Model Positional Encoding. arXiv preprint arXiv:2010.04903.
  2. Vaswani, A., et al. (2017). Attention is all you need. arXiv preprint arXiv:1706.03762.
  3. Shaw, P., Uszkoreit, J., & Vaswani, A. (2018). Self-attention with relative position representations. arXiv preprint arXiv:1803.02155.
  4. Ramachandran, P., et al. (2019). Stand-alone self-attention in vision models. arXiv preprint arXiv:1906.05909.

By understanding and implementing positional embeddings effectively, you can enhance the performance of transformer models across various applications. Happy coding!
