Monday, December 23, 2024

Mastering Einsum for Deep Learning: Building a Transformer with Multi-Head Self-Attention from the Ground Up

Understanding Einsum Operations: A Key to Mastering Self-Attention in Transformers

In the rapidly evolving field of machine learning, particularly in deep learning and natural language processing, the ability to efficiently manipulate tensors is paramount. One of the most powerful tools at your disposal is the einsum operation. If you are a machine learning researcher or engineer, familiarizing yourself with einsum operations is not just beneficial; it’s essential.

My Journey with Einsum

I remember my initial struggles with understanding Git repositories that utilized einsum operations. Despite my comfort with tensor operations, einsum felt like an enigma. Determined to conquer this challenge, I delved into the einsum notation, particularly in the context of transformers and self-attention mechanisms in computer vision. This exploration turned out to be a rewarding journey, enhancing my understanding far beyond my expectations.

For those interested in the theoretical underpinnings of attention mechanisms and transformers, I recommend checking out my articles on attention and transformers. However, if you’re ready to dive into the practical aspects, let’s get started!

You can find the code for this tutorial on GitHub. If you find it helpful, please show your support with a star!

Why Einsum?

1. Elegant and Clean Code

Einsum notation allows for writing tensor operations in a concise and readable manner. Many AI industry specialists and researchers consistently use it for this reason. For instance, consider the task of merging two dimensions of a 4D tensor. The traditional method might look cumbersome:

x = x.permute(0, 3, 1, 2)             # (b, c, h, w) -> (b, w, c, h)
N, W, C, H = x.shape
x = x.contiguous().view(N * W, C, -1) # merge batch and width: (b*w, c, h)

This can be written in a single, self-documenting line with einops.rearrange, which is built on the same index notation:

x = einops.rearrange(x, 'b c h w -> (b w) c h')

This not only improves readability but also reduces the likelihood of errors.

2. Efficient Batched Implementations

If you are working with custom layers involving multi-dimensional tensors, einsum is invaluable. It streamlines operations that would otherwise require cumbersome reshaping and permuting of tensors.
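
For example (the shapes here are purely illustrative), the per-head attention scores used later in this post can be computed in a single batched call, with no intermediate reshaping or permuting:

import torch

q = torch.rand(2, 8, 128, 64)  # (batch, heads, tokens, dim_head)
k = torch.rand(2, 8, 128, 64)
scores = torch.einsum('b h i d, b h j d -> b h i j', q, k)  # (2, 8, 128, 128)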

3. Simplified Code Translation

Einsum makes it easier to translate code between different frameworks like PyTorch, TensorFlow, and NumPy. This flexibility is crucial for researchers and engineers who work across multiple platforms.
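
As a small illustration (the array contents are arbitrary), the same equation string carries over almost verbatim between NumPy and PyTorch:

import numpy as np
import torch

a_np, c_np = np.random.rand(4, 10, 64), np.random.rand(4, 20, 64)
y_np = np.einsum('b i k, b j k -> b i j', a_np, c_np)  # NumPy

a_pt, c_pt = torch.rand(4, 10, 64), torch.rand(4, 20, 64)
y_pt = torch.einsum('b i k, b j k -> b i j', a_pt, c_pt)  # PyTorch: identical string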

The Basics of Einsum and Einops Notation

Einsum

The einsum operation is based on the Einstein summation convention, which provides a compact way to express summation over indices. The structure of an einsum command consists of two parts:

  • Equation String: This indicates the indices of the tensors involved in the operation. Lowercase letters label dimensions; a letter repeated across operands is matched between them, and any index that does not appear in the output is summed over.

  • Operands: The tensors involved in the operation are specified here.

For example, to perform batch matrix multiplication between a tensor a of shape (b, i, k) and a tensor c of shape (b, j, k), i.e. multiplying a with c transposed over its last two dimensions, you can use:

y1 = torch.einsum('b i k, b j k -> b i j', a, c)
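
To see what this computes, here is a quick sanity check against plain torch.bmm (the shapes below are arbitrary):

import torch

a = torch.rand(4, 10, 64)  # (b, i, k)
c = torch.rand(4, 20, 64)  # (b, j, k)
y1 = torch.einsum('b i k, b j k -> b i j', a, c)  # (4, 10, 20)
y2 = torch.bmm(a, c.transpose(1, 2))              # same result with an explicit transpose
assert torch.allclose(y1, y2, atol=1e-6)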

Einops

While einsum is powerful, einops provides an additional layer of abstraction for tensor manipulation. The einops.rearrange function allows for intuitive reshaping of tensors. For instance, to split a fused tensor into three parts along its last axis, you can use:

qkv = torch.rand(2, 128, 3 * 512)  # (batch, tokens, 3 * dim)
q, k, v = tuple(rearrange(qkv, 'b t (d n) -> n b t d', n=3))  # each: (2, 128, 512)

This flexibility in decomposition is one of the reasons I favor einops for single tensor operations.

Implementing Scaled Dot-Product Self-Attention

Now that we have a grasp of einsum and einops, let’s implement the scaled dot-product self-attention mechanism. Here’s a breakdown of the steps involved:

Step 1: Create Linear Projections

Given an input tensor $X \in \mathbb{R}^{\text{batch} \times \text{tokens} \times \text{dim}}$, we create linear projections for queries, keys, and values:

$$Q = X W_Q, \qquad K = X W_K, \qquad V = X W_V$$
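
In practice, the three projections are usually fused into a single linear layer and split afterwards, which is also what the implementation later in this post does (there the split is performed with einops.rearrange). A minimal sketch with illustrative sizes:

import torch
from torch import nn

dim = 512
to_qvk = nn.Linear(dim, dim * 3, bias=False)  # W_Q, W_K, W_V stacked into one weight matrix
x = torch.rand(2, 128, dim)                   # (batch, tokens, dim)
q, k, v = to_qvk(x).chunk(3, dim=-1)          # each: (batch, tokens, dim)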

Step 2: Calculate Scaled Dot Product

Next, we compute the dot product of the queries with the keys, scale it by $1/\sqrt{d}$ (the inverse square root of the query/key dimension, stored as scale_factor in the code), and apply a softmax over the last dimension:

scaled_dot_prod = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale_factor
attention = torch.softmax(scaled_dot_prod, dim=-1)

Step 3: Multiply Scores with Values

Finally, we multiply the attention scores with the values:

output = torch.einsum('b i j, b j d -> b i d', attention, v)
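
Putting the three steps together gives the familiar scaled dot-product attention, where $d$ is the dimension of the queries and keys:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^\top}{\sqrt{d}}\right) V$$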

Implementation of Scaled Dot-Product Self-Attention

Here’s the complete implementation:

import numpy as np
import torch
from einops import rearrange
from torch import nn

class SelfAttentionAISummer(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.to_qvk = nn.Linear(dim, dim * 3, bias=False)
        self.scale_factor = dim ** -0.5

    def forward(self, x, mask=None):
        assert x.dim() == 3, '3D tensor must be provided'
        qkv = self.to_qvk(x)  # (batch, tokens, 3 * dim)
        q, k, v = tuple(rearrange(qkv, 'b t (d k) -> k b t d', k=3))  # each: (batch, tokens, dim)
        scaled_dot_prod = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale_factor
        if mask is not None:
            assert mask.shape == scaled_dot_prod.shape[1:]  # mask over (tokens, tokens)
            scaled_dot_prod = scaled_dot_prod.masked_fill(mask, -np.inf)
        attention = torch.softmax(scaled_dot_prod, dim=-1)
        return torch.einsum('b i j, b j d -> b i d', attention, v)
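
A quick shape check of the layer above (the sizes are arbitrary):

model = SelfAttentionAISummer(dim=512)
x = torch.rand(2, 128, 512)  # (batch, tokens, dim)
out = model(x)
print(out.shape)             # torch.Size([2, 128, 512])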

Multi-Head Self-Attention

To enhance the model’s capacity, we can introduce multiple heads in our self-attention mechanism. This allows the model to learn different representations of the input data.
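
Concretely, following Vaswani et al., each head runs scaled dot-product attention on its own learned projections, and the heads are concatenated and mapped back to the model dimension with an output matrix $W_0$:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W_0, \qquad \text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$$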

Implementation of Multi-Head Self-Attention

Here’s how we can implement multi-head self-attention:

class MultiHeadSelfAttentionAISummer(nn.Module):
    def __init__(self, dim, heads=8, dim_head=None):
        super().__init__()
        self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
        _dim = self.dim_head * heads
        self.heads = heads
        self.to_qvk = nn.Linear(dim, _dim * 3, bias=False)
        self.W_0 = nn.Linear(_dim, dim, bias=False)
        self.scale_factor = self.dim_head ** -0.5

    def forward(self, x, mask=None):
        assert x.dim() == 3
        qkv = self.to_qvk(x)  # (batch, tokens, 3 * heads * dim_head)
        q, k, v = tuple(rearrange(qkv, 'b t (d k h) -> k b h t d', k=3, h=self.heads))  # each: (batch, heads, tokens, dim_head)
        scaled_dot_prod = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale_factor
        if mask is not None:
            assert mask.shape == scaled_dot_prod.shape[2:]
            scaled_dot_prod = scaled_dot_prod.masked_fill(mask, -np.inf)
        attention = torch.softmax(scaled_dot_prod, dim=-1)
        out = torch.einsum('b h i j, b h j d -> b h i d', attention, v)
        out = rearrange(out, "b h t d -> b t (h d)")
        return self.W_0(out)
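
As before, a quick shape check (the sizes are arbitrary):

model = MultiHeadSelfAttentionAISummer(dim=512, heads=8)
x = torch.rand(2, 128, 512)  # (batch, tokens, dim)
out = model(x)
print(out.shape)             # torch.Size([2, 128, 512])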

Building the Transformer Encoder

With our multi-head self-attention layer in place, constructing a transformer block becomes straightforward. Here’s a simple implementation:

class TransformerBlockAISummer(nn.Module):
    def __init__(self, dim, heads=8, dim_head=None, dim_linear_block=1024, dropout=0.1):
        super().__init__()
        self.mhsa = MultiHeadSelfAttentionAISummer(dim=dim, heads=heads, dim_head=dim_head)
        self.drop = nn.Dropout(dropout)
        self.norm_1 = nn.LayerNorm(dim)
        self.norm_2 = nn.LayerNorm(dim)
        self.linear = nn.Sequential(
            nn.Linear(dim, dim_linear_block),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_linear_block, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask=None):
        y = self.norm_1(self.drop(self.mhsa(x, mask)) + x)
        return self.norm_2(self.linear(y) + y)

Finally, we can stack multiple transformer blocks to create a complete transformer encoder.
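
A minimal sketch of such an encoder, using the block defined above (the number of blocks and the other hyperparameters below are illustrative defaults, not prescribed values):

class TransformerEncoderAISummer(nn.Module):
    def __init__(self, dim, blocks=6, heads=8, dim_head=None, dim_linear_block=1024, dropout=0.1):
        super().__init__()
        self.block_list = nn.ModuleList(
            [TransformerBlockAISummer(dim, heads, dim_head, dim_linear_block, dropout)
             for _ in range(blocks)]
        )

    def forward(self, x, mask=None):
        # Each block preserves the (batch, tokens, dim) shape, so the blocks can be chained freely
        for block in self.block_list:
            x = block(x, mask)
        return x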

Conclusion

Mastering einsum operations and their application in self-attention mechanisms has significantly enhanced my understanding of transformers. This journey has been both challenging and rewarding, and I encourage you to explore these concepts further.

For those looking to implement advanced self-attention blocks for computer vision, I invite you to check out our GitHub repository and share your experiences. Don’t forget to star the repository if you find it useful!

If you feel your PyTorch fundamentals need a boost, consider exploring Deep Learning with PyTorch for comprehensive insights.

Acknowledgments

I would like to extend my gratitude to Alex Rogozhnikov for the fantastic einops library, which has been instrumental in my learning process.

Happy coding!
