Monday, December 23, 2024

Create a Transformer from Scratch in JAX: A Guide to Writing and Training Your Own Models

Building a Transformer Neural Network with JAX and Haiku

In the rapidly evolving landscape of deep learning, JAX has emerged as a powerful tool for researchers and developers alike. Its ability to facilitate high-performance numerical computing, combined with features like automatic differentiation and just-in-time compilation, makes it an attractive choice for building complex models. In this tutorial, we will delve into the process of developing a Transformer neural network using JAX and the Haiku library.
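
To make those two features concrete, here is a minimal, self-contained sketch (unrelated to the Transformer we build below) of JAX's headline transformations, jax.grad and jax.jit, applied to a toy mean-squared-error loss:

import jax
import jax.numpy as jnp

def mse(w, x, y):
    """Mean squared error of a linear model x @ w against targets y."""
    return jnp.mean((x @ w - y) ** 2)

# grad returns a new function computing d(mse)/d(w); jit compiles it with XLA.
grad_fn = jax.jit(jax.grad(mse))

w = jnp.zeros(3)
x = jnp.ones((8, 3))
y = jnp.ones(8)
print(grad_fn(w, x, y))  # gradient with respect to w, shape (3,)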

Why JAX and Haiku?

While JAX may not yet have the maturity of TensorFlow or PyTorch, it offers unique advantages that are particularly appealing for research applications. Its functional programming paradigm allows for cleaner and more modular code, while its performance optimizations enable efficient computation on accelerators like GPUs and TPUs. Haiku, developed by DeepMind, is a neural network library built on top of JAX that provides a simple and flexible way to define neural network architectures.

If you’re new to JAX, I recommend checking out my previous article for a solid understanding of its basics. You can also find the complete code for this tutorial in our GitHub repository.

Choosing the Right Framework

One common challenge for newcomers to JAX is selecting the right framework. Google and DeepMind have released several libraries on top of JAX, each catering to different needs. Here are some of the most notable ones:

  • Haiku: A go-to framework for deep learning that provides composable abstractions and ready-to-use modules.
  • Optax: A gradient processing and optimization library with built-in optimizers.
  • Flax: A neural network library with a variety of modules and utilities.
  • Trax: An end-to-end library for deep learning focused on Transformers.
  • Jraph: A library for Graph Neural Networks.
  • JAXline: A library for distributed JAX training and evaluation.

For this tutorial, we will focus on Haiku, as it is widely used within Google and DeepMind and has an active community.

Understanding Transformers

Before we dive into the code, it’s essential to have a solid understanding of Transformers. If you’re unfamiliar with the architecture, I recommend reading our articles on attention mechanisms and Transformers.

Implementing the Self-Attention Block

Let’s start by implementing the self-attention block, a crucial component of the Transformer architecture. First, we need to import the necessary libraries:

import functools
import logging
from typing import Any, Mapping, Optional
import jax
import jax.numpy as jnp
import haiku as hk
import numpy as np
import optax

Haiku provides a built-in MultiHeadAttention block that we can extend to create a masked self-attention block. This block will accept the query, key, value, and mask, returning the output as a JAX array.

class SelfAttention(hk.MultiHeadAttention):
    """Self attention with a causal mask applied."""

    def __call__(self, query: jnp.ndarray, key: Optional[jnp.ndarray] = None,
                 value: Optional[jnp.ndarray] = None, mask: Optional[jnp.ndarray] = None) -> jnp.ndarray:
        # Default to self-attention: keys and values come from the query sequence.
        key = key if key is not None else query
        value = value if value is not None else query

        # Lower-triangular mask blocks attention to future positions.
        seq_len = query.shape[1]
        causal_mask = np.tril(np.ones((seq_len, seq_len)))
        mask = mask * causal_mask if mask is not None else causal_mask

        return super().__call__(query, key, value, mask)

This code snippet also illustrates a core principle of Haiku: every module subclasses hk.Module (here indirectly, since hk.MultiHeadAttention is itself an hk.Module) and implements the __init__ and __call__ methods, much like nn.Module in PyTorch.
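
Here is a bare-bones sketch of that pattern (using the imports above), with a hypothetical Affine module that is not part of the Transformer we are building:

class Affine(hk.Module):
    """Toy module computing y = x @ w + b, to show the __init__/__call__ pattern."""

    def __init__(self, output_size: int, name: Optional[str] = None):
        super().__init__(name=name)
        self._output_size = output_size

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        input_size = x.shape[-1]
        # Parameters are created lazily on the first call via hk.get_parameter.
        w = hk.get_parameter('w', [input_size, self._output_size], init=jnp.zeros)
        b = hk.get_parameter('b', [self._output_size], init=jnp.zeros)
        return x @ w + b

Like every Haiku module, it can only be instantiated inside a function that is later wrapped with hk.transform, which we come back to in the section on pure functions below.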

Building a Dense Block

Next, let’s create a simple two-layer Multilayer Perceptron (MLP) as another module in our Transformer.

class DenseBlock(hk.Module):
    """A 2-layer MLP."""

    def __init__(self, init_scale: float, widening_factor: int = 4, name: Optional[str] = None):
        super().__init__(name=name)
        self._init_scale = init_scale
        self._widening_factor = widening_factor

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        hiddens = x.shape[-1]
        initializer = hk.initializers.VarianceScaling(self._init_scale)
        # Widen to widening_factor * hiddens, apply GELU, then project back down.
        x = hk.Linear(self._widening_factor * hiddens, w_init=initializer)(x)
        x = jax.nn.gelu(x)
        return hk.Linear(hiddens, w_init=initializer)(x)

Implementing Layer Normalization

Layer normalization is another integral component of the Transformer architecture. We can easily implement it using Haiku’s built-in functionality.

def layer_norm(x: jnp.ndarray, name: Optional[str] = None) -> jnp.ndarray:
    """Apply a unique LayerNorm to x with default settings."""
    return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True, name=name)(x)

Constructing the Transformer Model

Now, let’s put everything together and construct a simple Transformer model. In the __init__ method, we define the basic parameters such as the number of layers, attention heads, and dropout rate. In the __call__ method, we compose a list of blocks using a for loop.

class Transformer(hk.Module):
    """A transformer stack."""

    def __init__(self, num_heads: int, num_layers: int, dropout_rate: float, name: Optional[str] = None):
        super().__init__(name=name)
        self._num_layers = num_layers
        self._num_heads = num_heads
        self._dropout_rate = dropout_rate

    def __call__(self, h: jnp.ndarray, mask: Optional[jnp.ndarray], is_training: bool) -> jnp.ndarray:
        """Connects the transformer."""
        init_scale = 2. / self._num_layers
        dropout_rate = self._dropout_rate if is_training else 0.
        if mask is not None:
            # Broadcast the padding mask to [batch, heads, query, key].
            mask = mask[:, None, None, :]

        # Each layer is a pre-norm residual pair: attention, then the MLP block.
        for i in range(self._num_layers):
            h_norm = layer_norm(h, name=f'h{i}_ln_1')
            h_attn = SelfAttention(num_heads=self._num_heads, key_size=64, w_init_scale=init_scale, name=f'h{i}_attn')(h_norm, mask=mask)
            h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
            h = h + h_attn
            h_norm = layer_norm(h, name=f'h{i}_ln_2')
            h_dense = DenseBlock(init_scale, name=f'h{i}_mlp')(h_norm)
            h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
            h = h + h_dense

        h = layer_norm(h, name='ln_f')
        return h

Adding the Embeddings Layer

To complete our Transformer model, we need an embeddings layer that maps input token ids to vectors and adds learned positional embeddings. We also compute a padding mask here, and we pass d_model in explicitly so the helper is self-contained.

def embeddings(data: Mapping[str, jnp.ndarray], vocab_size: int, d_model: int):
    tokens = data['obs']
    # Padding tokens have id 0; mask them out of the attention and the loss.
    input_mask = jnp.greater(tokens, 0)
    seq_length = tokens.shape[1]
    embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
    token_embedding_map = hk.Embed(vocab_size, d_model, w_init=embed_init)
    token_embs = token_embedding_map(tokens)
    # Learned positional embeddings, one vector per sequence position.
    positional_embeddings = hk.get_parameter('pos_embs', [seq_length, d_model], init=embed_init)
    input_embeddings = token_embs + positional_embeddings
    return input_embeddings, input_mask

Why Use Pure Functions?

One of the key features of JAX is its ability to transform functions into pure functions. A pure function is one that always produces the same output for the same input and has no side effects. This property is essential for JAX’s optimizations, including vectorization and parallelization.

To convert our model into a pure function, we will use hk.transform. It turns a function that builds Haiku modules into a pair of pure functions: init, which creates the model's parameters, and apply, which runs the forward pass given those parameters and an RNG key. This lets us apply JAX transformations such as jit and grad to the result.
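
As a small standalone sketch of what hk.transform gives us (again using the imports above and a throwaway hk.Linear layer rather than our Transformer), the transformed object exposes a pure init/apply pair:

def toy_forward(x: jnp.ndarray) -> jnp.ndarray:
    # Any function that creates Haiku modules in its body can be transformed.
    return hk.Linear(4)(x)

toy_fn = hk.transform(toy_forward)

rng = jax.random.PRNGKey(0)
x = jnp.ones((2, 8))

# init traces toy_forward once and returns the parameters it created.
params = toy_fn.init(rng, x)

# apply is a pure function of (params, rng, inputs): same inputs, same outputs.
out = toy_fn.apply(params, rng, x)
print(out.shape)  # (2, 4)

This is exactly the pattern we use for the full model: build_forward_fn below returns the forward pass, and main() wraps it with hk.transform to obtain init and apply.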

Implementing the Forward Pass

The forward pass of our model will involve computing the input embeddings, passing them through the Transformer blocks, and returning the output.

def build_forward_fn(vocab_size: int, d_model: int, num_heads: int, num_layers: int, dropout_rate: float):
    """Create the model's forward pass."""
    def forward_fn(data: Mapping[str, jnp.ndarray], is_training: bool = True) -> jnp.ndarray:
        """Forward pass."""
        input_embeddings, input_mask = embeddings(data, vocab_size, d_model)
        transformer = Transformer(num_heads=num_heads, num_layers=num_layers, dropout_rate=dropout_rate)
        output_embeddings = transformer(input_embeddings, input_mask, is_training)
        # Project back to vocabulary logits for next-token prediction.
        return hk.Linear(vocab_size)(output_embeddings)
    return forward_fn

Defining the Loss Function

The loss function for our model will be the cross-entropy loss, taking the mask into account.

def lm_loss_fn(forward_fn, vocab_size: int, params, rng, data: Mapping[str, jnp.ndarray], is_training: bool = True) -> jnp.ndarray:
    """Compute the loss on data w.r.t. params."""
    logits = forward_fn(params, rng, data, is_training)
    targets = jax.nn.one_hot(data['target'], vocab_size)
    assert logits.shape == targets.shape

    # Cross-entropy averaged over non-padding positions only.
    mask = jnp.greater(data['obs'], 0)
    loss = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
    loss = jnp.sum(loss * mask) / jnp.sum(mask)
    return loss

Building the Training Loop

To train our model, we will use the Optax library for gradient processing. The training loop will involve initializing the model, updating the parameters, and logging the metrics.

class GradientUpdater:
    """A stateless abstraction around an init_fn/update_fn pair."""
    def __init__(self, net_init, loss_fn, optimizer: optax.GradientTransformation):
        self._net_init = net_init
        self._loss_fn = loss_fn
        self._opt = optimizer

    # static_argnums=0 marks `self` as a static argument so jax.jit can compile the method.
    @functools.partial(jax.jit, static_argnums=0)
    def init(self, master_rng, data):
        """Initializes state of the updater."""
        out_rng, init_rng = jax.random.split(master_rng)
        params = self._net_init(init_rng, data)
        opt_state = self._opt.init(params)
        out = dict(step=np.array(0), rng=out_rng, opt_state=opt_state, params=params)
        return out

    @functools.partial(jax.jit, static_argnums=0)
    def update(self, state: Mapping[str, Any], data: Mapping[str, jnp.ndarray]):
        """Updates the state using some data and returns metrics."""
        rng, new_rng = jax.random.split(state['rng'])
        params = state['params']
        loss, g = jax.value_and_grad(self._loss_fn)(params, rng, data)
        updates, opt_state = self._opt.update(g, state['opt_state'])
        params = optax.apply_updates(params, updates)
        new_state = {'step': state['step'] + 1, 'rng': new_rng, 'opt_state': opt_state, 'params': params}
        metrics = {'step': state['step'], 'loss': loss}
        return new_state, metrics

Finally, we can build the main training loop.
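
Note that main() relies on a dataset helper, load (assumed to yield batches with 'obs' and 'target' token arrays), and on a handful of module-level hyperparameters that are not spelled out in this tutorial. The values below are purely illustrative placeholders so the code reads as a whole; pick your own for a real run:

# Illustrative hyperparameters only -- not tuned values.
batch_size = 16
sequence_length = 128
num_heads = 8
d_model = 512        # must equal num_heads * key_size (64) for the residual adds
num_layers = 6
dropout_rate = 0.1
grad_clip_value = 0.25
learning_rate = 2e-4
MAX_STEPS = 10_000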

def main():
    train_dataset, vocab_size = load(batch_size, sequence_length)
    forward_fn = build_forward_fn(vocab_size, d_model, num_heads, num_layers, dropout_rate)
    forward_fn = hk.transform(forward_fn)
    loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)
    optimizer = optax.chain(optax.clip_by_global_norm(grad_clip_value), optax.adam(learning_rate, b1=0.9, b2=0.99))
    updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)

    logging.info('Initializing parameters...')
    rng = jax.random.PRNGKey(428)
    data = next(train_dataset)
    state = updater.init(rng, data)

    logging.info('Starting train loop...')
    for step in range(MAX_STEPS):
        data = next(train_dataset)
        state, metrics = updater.update(state, data)
        logging.info('Step %d: loss %.4f', metrics['step'], metrics['loss'])

Conclusion

In this tutorial, we explored how to develop and train a Transformer neural network using JAX and Haiku. While the code may not be as straightforward as in PyTorch or TensorFlow, JAX’s unique features offer powerful capabilities for high-performance computing. I encourage you to experiment with JAX and discover its strengths and weaknesses for your projects.

For further exploration, consider checking out the official examples in the Haiku repository.

If you have any questions or would like to share your experiences with JAX, feel free to join our Discord channel. Happy coding!
