Building a Transformer Neural Network with JAX and Haiku
In the rapidly evolving landscape of deep learning, JAX has emerged as a powerful tool for researchers and developers alike. Its ability to facilitate high-performance numerical computing, combined with features like automatic differentiation and just-in-time compilation, makes it an attractive choice for building complex models. In this tutorial, we will delve into the process of developing a Transformer neural network using JAX and the Haiku library.
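To make those two features concrete, here is a tiny, self-contained sketch (not part of the model we will build later) that differentiates a toy function with jax.grad and compiles the result with jax.jit:

import jax
import jax.numpy as jnp

def toy_loss(w, x):
    return jnp.sum((w * x) ** 2)

# Automatic differentiation and just-in-time compilation in one line each.
grad_fn = jax.jit(jax.grad(toy_loss))
print(grad_fn(jnp.array(2.0), jnp.array(3.0)))   # d/dw (w*x)^2 = 2*w*x^2 = 36.0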
Why JAX and Haiku?
While JAX may not yet have the maturity of TensorFlow or PyTorch, it offers unique advantages that are particularly appealing for research applications. Its functional programming paradigm allows for cleaner and more modular code, while its performance optimizations enable efficient computation on accelerators like GPUs and TPUs. Haiku, developed by DeepMind, is a neural network library built on top of JAX that provides a simple and flexible way to define neural network architectures.
If you’re new to JAX, I recommend checking out my previous article for a solid understanding of its basics. You can also find the complete code for this tutorial in our GitHub repository.
Choosing the Right Framework
One common challenge for newcomers to JAX is selecting the right framework. DeepMind has released several libraries on top of JAX, each catering to different needs. Here are some of the most notable ones:
- Haiku: A go-to framework for deep learning that provides composable abstractions and ready-to-use modules.
- Optax: A gradient processing and optimization library with built-in optimizers.
- Flax: A neural network library with a variety of modules and utilities.
- Trax: An end-to-end library for deep learning focused on Transformers.
- Jraph: A library for Graph Neural Networks.
- JAXline: A library for distributed JAX training and evaluation.
For this tutorial, we will focus on Haiku, as it is widely used within Google and DeepMind and has an active community.
Understanding Transformers
Before we dive into the code, it’s essential to have a solid understanding of Transformers. If you’re unfamiliar with the architecture, I recommend reading our articles on attention mechanisms and Transformers.
Implementing the Self-Attention Block
Let’s start by implementing the self-attention block, a crucial component of the Transformer architecture. First, we need to import the necessary libraries:
import functools
import logging
from typing import Any, Mapping, Optional

import jax
import jax.numpy as jnp
import haiku as hk
import numpy as np
import optax
Haiku provides a built-in MultiHeadAttention block that we can extend to create a masked self-attention block. This block accepts the query, key, value, and an optional mask, and returns the output as a JAX array.
class SelfAttention(hk.MultiHeadAttention):
    """Self attention with a causal mask applied."""

    def __call__(self, query: jnp.ndarray, key: Optional[jnp.ndarray] = None,
                 value: Optional[jnp.ndarray] = None,
                 mask: Optional[jnp.ndarray] = None) -> jnp.ndarray:
        # Default to self-attention: keys and values come from the query itself.
        key = key if key is not None else query
        value = value if value is not None else query

        seq_len = query.shape[1]
        # Lower-triangular matrix: position i may only attend to positions <= i.
        causal_mask = np.tril(np.ones((seq_len, seq_len)))
        mask = mask * causal_mask if mask is not None else causal_mask

        return super().__call__(query, key, value, mask)
This code snippet introduces a key principle of Haiku: all modules should subclass hk.Module, implementing the __init__ and __call__ methods, similar to how you would in PyTorch.
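As a quick, purely illustrative smoke test: Haiku modules can only be instantiated inside a function that is transformed with hk.transform (more on this later), so we wrap the block in such a function and run it on dummy data. The hyperparameter values below are arbitrary, and on recent Haiku versions you may need to pass w_init=hk.initializers.VarianceScaling(1.0) instead of w_init_scale.

def attention_fn(x):
    # 4 heads of size 16 -> the output projection has 64 features, matching the input.
    return SelfAttention(num_heads=4, key_size=16, w_init_scale=1.0)(x)

attention = hk.transform(attention_fn)
rng = jax.random.PRNGKey(0)
x = jnp.ones((2, 10, 64))                 # (batch, seq_len, d_model)
params = attention.init(rng, x)
out = attention.apply(params, rng, x)     # shape: (2, 10, 64)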
Building a Dense Block
Next, let’s create a simple two-layer Multilayer Perceptron (MLP) as another module in our Transformer.
class DenseBlock(hk.Module):
    """A 2-layer MLP"""

    def __init__(self, init_scale: float, widening_factor: int = 4,
                 name: Optional[str] = None):
        super().__init__(name=name)
        self._init_scale = init_scale
        self._widening_factor = widening_factor

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        hiddens = x.shape[-1]
        initializer = hk.initializers.VarianceScaling(self._init_scale)
        x = hk.Linear(self._widening_factor * hiddens, w_init=initializer)(x)
        x = jax.nn.gelu(x)
        return hk.Linear(hiddens, w_init=initializer)(x)
Implementing Layer Normalization
Layer normalization is another integral component of the Transformer architecture. We can easily implement it using Haiku’s built-in functionality.
def layer_norm(x: jnp.ndarray, name: Optional[str] = None) -> jnp.ndarray:
    """Apply a unique LayerNorm to x with default settings."""
    return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True, name=name)(x)
Constructing the Transformer Model
Now, let’s put everything together and construct a simple Transformer model. In the __init__ method, we define the basic parameters such as the number of layers, attention heads, and dropout rate. In the __call__ method, we compose the blocks using a for loop.
class Transformer(hk.Module):
    """A transformer stack."""

    def __init__(self, num_heads: int, num_layers: int, dropout_rate: float,
                 name: Optional[str] = None):
        super().__init__(name=name)
        self._num_layers = num_layers
        self._num_heads = num_heads
        self._dropout_rate = dropout_rate

    def __call__(self, h: jnp.ndarray, mask: Optional[jnp.ndarray],
                 is_training: bool) -> jnp.ndarray:
        """Connects the transformer."""
        init_scale = 2. / self._num_layers
        dropout_rate = self._dropout_rate if is_training else 0.
        if mask is not None:
            # Broadcast the padding mask over the head and query dimensions.
            mask = mask[:, None, None, :]

        for i in range(self._num_layers):
            # Attention sub-block with a pre-norm residual connection.
            h_norm = layer_norm(h, name=f'h{i}_ln_1')
            h_attn = SelfAttention(num_heads=self._num_heads, key_size=64,
                                   w_init_scale=init_scale,
                                   name=f'h{i}_attn')(h_norm, mask=mask)
            h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
            h = h + h_attn
            # Feed-forward sub-block, again with a pre-norm residual connection.
            h_norm = layer_norm(h, name=f'h{i}_ln_2')
            h_dense = DenseBlock(init_scale, name=f'h{i}_mlp')(h_norm)
            h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
            h = h + h_dense

        h = layer_norm(h, name='ln_f')
        return h
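Note that key_size is hard-coded to 64, so the attention output has num_heads * 64 features; for the residual additions to work, this product must equal the model (embedding) dimension. Here is a minimal, illustrative shape check with arbitrary values, again wrapped in hk.transform because the dropout calls request a random key:

def stack_fn(x):
    # 8 heads * key_size 64 = 512, which must match the feature size of x.
    return Transformer(num_heads=8, num_layers=2, dropout_rate=0.1)(x, mask=None, is_training=True)

stack = hk.transform(stack_fn)
x = jnp.ones((2, 10, 512))
params = stack.init(jax.random.PRNGKey(0), x)
out = stack.apply(params, jax.random.PRNGKey(1), x)   # shape: (2, 10, 512)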
Adding the Embeddings Layer
To complete our Transformer model, we need to include an embeddings layer that converts input tokens into embeddings.
def embeddings(data: Mapping[str, jnp.ndarray], vocab_size: int, d_model: int):
    tokens = data['obs']
    # Token id 0 is reserved for padding, so it is masked out.
    input_mask = jnp.greater(tokens, 0)
    seq_length = tokens.shape[1]

    embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
    token_embedding_map = hk.Embed(vocab_size, d_model, w_init=embed_init)
    token_embs = token_embedding_map(tokens)

    # Learned positional embeddings, added to the token embeddings.
    positional_embeddings = hk.get_parameter('pos_embs', [seq_length, d_model], init=embed_init)
    input_embeddings = token_embs + positional_embeddings
    return input_embeddings, input_mask
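Note that this setup assumes token id 0 is reserved for padding; input_mask simply marks the non-padding positions:

tokens = jnp.array([[5, 2, 0, 0]])   # the last two positions are padding
print(jnp.greater(tokens, 0))        # [[ True  True False False]]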
Why Use Pure Functions?
JAX’s function transformations, such as jit, grad, and vmap, operate on pure functions: functions that always produce the same output for the same input and have no side effects. This property is essential for JAX’s optimizations, including just-in-time compilation, vectorization, and parallelization. Haiku modules, however, create and read their parameters implicitly, which is a side effect. To convert our model into a pure function with explicit parameters, we use hk.transform, which lets us take full advantage of JAX’s powerful features.
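Here is a minimal illustration of the idea, independent of our model: the function below looks impure because hk.Linear creates its parameters implicitly, but hk.transform turns it into a pure (init, apply) pair where the parameters are passed around explicitly.

def f(x):
    return hk.Linear(3)(x)

f_pure = hk.transform(f)
rng = jax.random.PRNGKey(0)
x = jnp.ones((1, 2))
params = f_pure.init(rng, x)        # an explicit, nested dict of parameters
y = f_pure.apply(params, rng, x)    # same params + same inputs -> same output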
Implementing the Forward Pass
The forward pass of our model computes the input embeddings, passes them through the Transformer stack, and projects the result back to the vocabulary with a final linear layer to obtain next-token logits.
def build_forward_fn(vocab_size: int, d_model: int, num_heads: int,
                     num_layers: int, dropout_rate: float):
    """Create the model's forward pass."""

    def forward_fn(data: Mapping[str, jnp.ndarray], is_training: bool = True) -> jnp.ndarray:
        """Forward pass."""
        input_embeddings, input_mask = embeddings(data, vocab_size, d_model)

        transformer = Transformer(num_heads=num_heads, num_layers=num_layers,
                                  dropout_rate=dropout_rate)
        output_embeddings = transformer(input_embeddings, input_mask, is_training)

        # Project back to the vocabulary to obtain next-token logits.
        return hk.Linear(vocab_size)(output_embeddings)

    return forward_fn
Defining the Loss Function
The loss function for our model is the standard cross-entropy loss, averaged only over non-padded positions by taking the mask into account.
def lm_loss_fn(forward_fn, vocab_size: int, params, rng,
               data: Mapping[str, jnp.ndarray], is_training: bool = True) -> jnp.ndarray:
    """Compute the loss on data wrt params."""
    logits = forward_fn(params, rng, data, is_training)
    targets = jax.nn.one_hot(data['target'], vocab_size)
    assert logits.shape == targets.shape

    # Ignore padded positions (token id 0) when averaging the loss.
    mask = jnp.greater(data['obs'], 0)
    loss = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
    loss = jnp.sum(loss * mask) / jnp.sum(mask)
    return loss
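To see what the masking does, here is a toy example with made-up numbers: per-token losses at padded positions (where obs equals 0) are dropped before averaging.

per_token_loss = jnp.array([[2.0, 1.0, 3.0]])
obs = jnp.array([[5, 7, 0]])                            # last token is padding
mask = jnp.greater(obs, 0)
print(jnp.sum(per_token_loss * mask) / jnp.sum(mask))   # (2.0 + 1.0) / 2 = 1.5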
Building the Training Loop
To train our model, we will use the Optax library for gradient processing. The training loop will involve initializing the model, updating the parameters, and logging the metrics.
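Before looking at the full updater, here is the bare Optax pattern in isolation, as a minimal sketch with toy parameters: create an optimizer, initialize its state from the parameters, then repeatedly turn gradients into updates and apply them.

params = {'w': jnp.zeros(3)}
opt = optax.adam(1e-3)
opt_state = opt.init(params)

grads = {'w': jnp.ones(3)}                        # pretend gradients
updates, opt_state = opt.update(grads, opt_state)
params = optax.apply_updates(params, updates)

The GradientUpdater class below wraps this pattern together with the model’s init function and the RNG bookkeeping: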
class GradientUpdater:
    """A stateless abstraction around an init_fn/update_fn pair."""

    def __init__(self, net_init, loss_fn, optimizer: optax.GradientTransformation):
        self._net_init = net_init
        self._loss_fn = loss_fn
        self._opt = optimizer

    @functools.partial(jax.jit, static_argnums=0)
    def init(self, master_rng, data):
        """Initializes state of the updater."""
        out_rng, init_rng = jax.random.split(master_rng)
        params = self._net_init(init_rng, data)
        opt_state = self._opt.init(params)
        out = dict(step=np.array(0), rng=out_rng, opt_state=opt_state, params=params)
        return out

    @functools.partial(jax.jit, static_argnums=0)
    def update(self, state: Mapping[str, Any], data: Mapping[str, jnp.ndarray]):
        """Updates the state using some data and returns metrics."""
        rng, new_rng = jax.random.split(state['rng'])
        params = state['params']
        loss, g = jax.value_and_grad(self._loss_fn)(params, rng, data)

        updates, opt_state = self._opt.update(g, state['opt_state'])
        params = optax.apply_updates(params, updates)

        new_state = {'step': state['step'] + 1, 'rng': new_rng,
                     'opt_state': opt_state, 'params': params}
        metrics = {'step': state['step'], 'loss': loss}
        return new_state, metrics
Finally, we can build the main training loop.
def main():
    # Hyperparameters (batch_size, sequence_length, d_model, num_heads, num_layers,
    # dropout_rate, grad_clip_value, learning_rate, MAX_STEPS) are assumed to be
    # defined elsewhere, e.g. as module-level constants or command-line flags.
    train_dataset, vocab_size = load(batch_size, sequence_length)

    forward_fn = build_forward_fn(vocab_size, d_model, num_heads, num_layers, dropout_rate)
    forward_fn = hk.transform(forward_fn)
    loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

    optimizer = optax.chain(
        optax.clip_by_global_norm(grad_clip_value),
        optax.adam(learning_rate, b1=0.9, b2=0.99))

    updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)

    logging.info('Initializing parameters...')
    rng = jax.random.PRNGKey(428)
    data = next(train_dataset)
    state = updater.init(rng, data)

    logging.info('Starting train loop...')
    for step in range(MAX_STEPS):
        data = next(train_dataset)
        state, metrics = updater.update(state, data)
        # Periodically log the metrics returned by the updater.
        if step % 100 == 0:
            logging.info('Step %d: %s', step, metrics)
Conclusion
In this tutorial, we explored how to develop and train a Transformer neural network using JAX and Haiku. While the code may not be as straightforward as in PyTorch or TensorFlow, JAX’s unique features offer powerful capabilities for high-performance computing. I encourage you to experiment with JAX and discover its strengths and weaknesses for your projects.
For further exploration, consider checking out the official examples of the Haiku framework in the official repository.
If you have any questions or would like to share your experiences with JAX, feel free to join our Discord channel. Happy coding!