Sunday, December 22, 2024

Comparing JAX, TensorFlow, and PyTorch: Building a Variational Autoencoder from Scratch

In the ever-evolving landscape of deep learning frameworks, JAX, TensorFlow, and PyTorch are among the most popular choices for researchers and practitioners. Each framework has its own strengths and weaknesses, so it pays to understand how they differ. To make those differences concrete, I decided to build a Variational Autoencoder (VAE) from scratch in all three frameworks at once. This article presents the code for each component side by side, allowing us to explore the similarities and differences in their implementations.

Prologue

Before diving into the code, let’s clarify a few things:

  1. Frameworks Used:

    • For JAX, I will utilize Flax, a neural network library developed by Google that provides a range of ready-to-use deep learning modules.
    • For TensorFlow, I will rely on Keras abstractions.
    • For PyTorch, I will use the standard nn.Module.
  2. Focus on JAX and Flax:

    • Given that many readers are already familiar with TensorFlow and PyTorch, this article will place a stronger emphasis on JAX and Flax. I will explain concepts that may be unfamiliar to many, effectively making this article a light tutorial on Flax.
  3. Assumed Knowledge:
    • It is assumed that readers are familiar with the basic principles behind Variational Autoencoders. If not, I recommend checking out my previous article on latent variable models.

A Quick Recap on VAEs

A vanilla Autoencoder consists of an Encoder and a Decoder. The encoder converts the input into a latent representation z, while the decoder attempts to reconstruct the input from that representation. In a Variational Autoencoder, stochasticity is introduced: instead of a single point, the encoder outputs the parameters of a probability distribution over the latent space, and z is sampled from it. To keep sampling differentiable, we use the reparameterization trick.
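
Concretely, instead of sampling z directly from N(mean, std²), we sample noise eps ~ N(0, I) and compute z = mean + exp(0.5 · logvar) · eps, which keeps the gradient path to the encoder's outputs intact.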

[Figure: VAE architecture]

The Encoder

For the encoder, a simple linear layer followed by a ReLU activation suffices for our toy example. The encoder outputs both the mean and the log-variance of the latent distribution.

JAX (Flax) Implementation

In Flax, the basic building block is the Module abstraction. To implement our encoder, we need to:

  1. Initialize a class that inherits from flax.linen.nn.Module.
  2. Define static arguments as dataclass attributes.
  3. Implement the forward pass inside the __call__ method.

import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn

class Encoder(nn.Module):
    latents: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(500, name='fc1')(x)
        x = nn.relu(x)
        mean_x = nn.Dense(self.latents, name='fc2_mean')(x)
        logvar_x = nn.Dense(self.latents, name='fc2_logvar')(x)
        return mean_x, logvar_x
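
One Flax-specific point worth showing right away: a Module holds no parameters itself; they are created by init and passed back in on every apply call. A minimal usage sketch, assuming flattened 784-dimensional MNIST inputs (the shapes are illustrative):

enc = Encoder(latents=20)
dummy_x = jnp.ones((1, 784))                       # one flattened 28x28 image
variables = enc.init(random.PRNGKey(0), dummy_x)   # creates and returns the parameters
mean, logvar = enc.apply(variables, dummy_x)       # runs the forward pass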

TensorFlow (Keras) Implementation

In TensorFlow, we define the encoder as follows:

import tensorflow as tf
from tensorflow.keras import layers

class Encoder(layers.Layer):
    def __init__(self, latent_dim=20, name='encoder', **kwargs):
        super(Encoder, self).__init__(name=name, **kwargs)
        self.enc1 = layers.Dense(500, activation='relu')
        self.mean_x = layers.Dense(latent_dim)
        self.logvar_x = layers.Dense(latent_dim)

    def call(self, inputs):
        x = self.enc1(inputs)
        z_mean = self.mean_x(x)
        z_log_var = self.logvar_x(x)
        return z_mean, z_log_var

PyTorch Implementation

In PyTorch, the encoder is implemented as follows:

import torch
import torch.nn.functional as F

class Encoder(torch.nn.Module):
    def __init__(self, latent_dim=20):
        super(Encoder, self).__init__()
        self.enc1 = torch.nn.Linear(784, 500)
        self.mean_x = torch.nn.Linear(500, latent_dim)
        self.logvar_x = torch.nn.Linear(500, latent_dim)

    def forward(self, inputs):
        x = self.enc1(inputs)
        x = F.relu(x)
        z_mean = self.mean_x(x)
        z_log_var = self.logvar_x(x)
        return z_mean, z_log_var

Observations

  • The implementations across all three frameworks are quite similar, with minor syntactical differences.
  • Flax’s linen package (conventionally imported as nn) provides most common deep learning layers and operations, much like TensorFlow and PyTorch; the @nn.compact decorator is one Flax-specific detail worth a closer look (see the sketch below).
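
The @nn.compact decorator lets us declare submodules inline inside __call__, while the alternative setup() method (used in the VAE class later) declares them ahead of time. The two styles are equivalent; purely for illustration, here is a sketch of the same encoder written in the setup() style:

class EncoderSetupStyle(nn.Module):
    latents: int

    def setup(self):
        # submodules are declared up front instead of inline
        self.fc1 = nn.Dense(500)
        self.fc2_mean = nn.Dense(self.latents)
        self.fc2_logvar = nn.Dense(self.latents)

    def __call__(self, x):
        x = nn.relu(self.fc1(x))
        return self.fc2_mean(x), self.fc2_logvar(x)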

The Decoder

The decoder also consists of two linear layers: it receives the latent representation z and outputs the reconstructed input.

JAX (Flax) Implementation

class Decoder(nn.Module):
    @nn.compact
    def __call__(self, z):
        z = nn.Dense(500, name='fc1')(z)
        z = nn.relu(z)
        z = nn.Dense(784, name='fc2')(z)
        return z

TensorFlow (Keras) Implementation

class Decoder(layers.Layer):
    def __init__(self, name='decoder', **kwargs):
        super(Decoder, self).__init__(name=name, **kwargs)
        self.dec1 = layers.Dense(500, activation='relu')
        self.out = layers.Dense(784)

    def call(self, z):
        z = self.dec1(z)
        return self.out(z)

PyTorch Implementation

class Decoder(torch.nn.Module):
    def __init__(self, latent_dim=20):
        super(Decoder, self).__init__()
        self.dec1 = torch.nn.Linear(latent_dim, 500)
        self.out = torch.nn.Linear(500, 784)

    def forward(self, z):
        z = self.dec1(z)
        z = F.relu(z)
        return self.out(z)

Observations

  • The decoder implementations are similarly structured across all three frameworks.
  • The use of activation functions and layer definitions remains consistent.

Variational Autoencoder

To combine the encoder and decoder, we will create a class called VAE that represents the entire architecture. This class will also include the reparameterization trick.

JAX (Flax) Implementation

class VAE(nn.Module):
    latents: int = 20

    def setup(self):
        self.encoder = Encoder(self.latents)
        self.decoder = Decoder()

    def __call__(self, x, z_rng):
        mean, logvar = self.encoder(x)
        z = reparameterize(z_rng, mean, logvar)
        recon_x = self.decoder(z)
        return recon_x, mean, logvar

def reparameterize(rng, mean, logvar):
    std = jnp.exp(0.5 * logvar)
    eps = random.normal(rng, logvar.shape)
    return mean + eps * std
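
Notice that the sampling key z_rng is passed in explicitly, since JAX has no global random state. A minimal usage sketch, with illustrative shapes:

rng = random.PRNGKey(0)
rng, init_rng, z_rng = random.split(rng, 3)
vae = VAE(latents=20)
dummy_x = jnp.ones((1, 784))
variables = vae.init(init_rng, dummy_x, z_rng)                # create the parameters
recon_x, mean, logvar = vae.apply(variables, dummy_x, z_rng)  # forward pass with sampling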

TensorFlow (Keras) Implementation

class VAE(tf.keras.Model):
    def __init__(self, latent_dim=20, name='vae', **kwargs):
        super(VAE, self).__init__(name=name, **kwargs)
        self.encoder = Encoder(latent_dim=latent_dim)
        self.decoder = Decoder()

    def call(self, inputs):
        z_mean, z_log_var = self.encoder(inputs)
        z = self.reparameterize(z_mean, z_log_var)
        reconstructed = self.decoder(z)
        return reconstructed, z_mean, z_log_var

    def reparameterize(self, mean, logvar):
        eps = tf.random.normal(shape=mean.shape)
        return mean + eps * tf.exp(logvar * .5)

PyTorch Implementation

class VAE(torch.nn.Module):
    def __init__(self, latent_dim=20):
        super(VAE, self).__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def forward(self, inputs):
        z_mean, z_log_var = self.encoder(inputs)
        z = self.reparameterize(z_mean, z_log_var)
        reconstructed = self.decoder(z)
        return reconstructed, z_mean, z_log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + (eps * std)

Observations

  • The structure of the VAE class is quite similar across all three frameworks.
  • The reparameterization trick is implemented in a nearly identical manner, showcasing the conceptual consistency across frameworks.

Loss and Training Step

The implementation of the training step and loss function begins to show more variation among the frameworks.

JAX (Flax) Implementation

In JAX, we leverage automatic vectorization and XLA compilation. Here’s how we define the loss functions and training step:

@jax.vmap
def kl_divergence(mean, logvar):
    return -0.5 * jnp.sum(1 + logvar - jnp.square(mean) - jnp.exp(logvar))

@jax.vmap
def binary_cross_entropy_with_logits(logits, labels):
    # log p via log-sigmoid; log(1 - p) is computed as log(-expm1(log p)) for numerical stability
    logits = nn.log_sigmoid(logits)
    return -jnp.sum(labels * logits + (1. - labels) * jnp.log(-jnp.expm1(logits)))

@jax.jit
def train_step(optimizer, batch, z_rng):
    def loss_fn(params):
        # model() is assumed to be a helper returning a VAE instance,
        # e.g. def model(): return VAE(latents=LATENTS)
        recon_x, mean, logvar = model().apply({'params': params}, batch, z_rng)
        bce_loss = binary_cross_entropy_with_logits(recon_x, batch).mean()
        kld_loss = kl_divergence(mean, logvar).mean()
        loss = bce_loss + kld_loss
        return loss, recon_x

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    _, grad = grad_fn(optimizer.target)
    optimizer = optimizer.apply_gradient(grad)
    return optimizer
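
A side note: the snippet above uses the legacy flax.optim API (optimizer.target, apply_gradient), which newer Flax releases have replaced with the separate optax library. Purely as a reference, here is a rough sketch of the same training step written against optax; the names tx and train_step_optax and the learning rate are my additions, not part of the original code:

import optax

tx = optax.adam(learning_rate=1e-3)

@jax.jit
def train_step_optax(params, opt_state, batch, z_rng):
    def loss_fn(params):
        recon_x, mean, logvar = VAE(latents=20).apply({'params': params}, batch, z_rng)
        bce_loss = binary_cross_entropy_with_logits(recon_x, batch).mean()
        kld_loss = kl_divergence(mean, logvar).mean()
        return bce_loss + kld_loss

    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = tx.update(grads, opt_state)  # compute the parameter updates
    params = optax.apply_updates(params, updates)     # apply them
    return params, opt_state, loss

# initial state: params = VAE(latents=20).init(...)['params'], opt_state = tx.init(params)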

TensorFlow (Keras) Implementation

In TensorFlow, the training step is defined as follows:

def kl_divergence(mean, logvar):
    return -0.5 * tf.reduce_sum(1 + logvar - tf.square(mean) - tf.exp(logvar), axis=1)

def binary_cross_entropy_with_logits(logits, labels):
    # mirror the JAX version: log p via log-sigmoid, log(1 - p) via log(-expm1(log p))
    log_p = tf.math.log_sigmoid(logits)
    return -tf.reduce_sum(labels * log_p + (1 - labels) * tf.math.log(-tf.math.expm1(log_p)), axis=1)

@tf.function
def train_step(model, x, optimizer):
    with tf.GradientTape() as tape:
        recon_x, mean, logvar = model(x)
        bce_loss = tf.reduce_mean(binary_cross_entropy_with_logits(recon_x, x))
        kld_loss = tf.reduce_mean(kl_divergence(mean, logvar))
        loss = bce_loss + kld_loss
        tf.print(loss, kld_loss, bce_loss)  # tf.print also works when traced inside tf.function

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

PyTorch Implementation

In PyTorch, the training step is structured as follows:

def final_loss(reconstruction, train_x, mu, logvar):
    BCE = torch.nn.BCEWithLogitsLoss(reduction='sum')(reconstruction, train_x)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

def train_step(train_x):
    train_x = torch.from_numpy(train_x)
    optimizer.zero_grad()
    reconstruction, mu, logvar = model(train_x)
    loss = final_loss(reconstruction, train_x, mu, logvar)
    loss.backward()
    optimizer.step()
    return loss.item()

Observations

  • JAX asks us to be more explicit: the loss is written as a pure function, vectorized with vmap, differentiated with value_and_grad, and compiled with jit.
  • TensorFlow and PyTorch keep the training step closer to familiar imperative code (GradientTape, loss.backward()), but they do not expose the same fine-grained control over function transformations and compilation.

Training Loop

The training loop is where we execute the train_step function iteratively.

JAX (Flax) Implementation

# BATCH_SIZE, LATENTS, LEARNING_RATE and NUM_EPOCHS are hyperparameters
# assumed to be defined elsewhere; model() is the helper from train_step.
rng = random.PRNGKey(0)
rng, key = random.split(rng)
init_data = jnp.ones((BATCH_SIZE, 784), jnp.float32)

params = model().init(key, init_data, rng)['params']
optimizer = optim.Adam(learning_rate=LEARNING_RATE).create(params)  # legacy flax.optim API
optimizer = jax.device_put(optimizer)

# keys and a fixed batch of latents used for evaluation / sample generation (not shown here)
rng, z_key, eval_rng = random.split(rng, 3)
z = random.normal(z_key, (64, LATENTS))
steps_per_epoch = 50000 // BATCH_SIZE

for epoch in range(NUM_EPOCHS):
    for _ in range(steps_per_epoch):
        batch = next(train_ds)
        rng, key = random.split(rng)
        optimizer = train_step(optimizer, batch, key)

TensorFlow (Keras) Implementation

vae = VAE(latent_dim=LATENTS)
optimizer = tf.keras.optimizers.Adam(1e-4)

for epoch in range(NUM_EPOCHS):
    for train_x in train_ds:
        train_step(vae, train_x, optimizer)

PyTorch Implementation

vae = VAE(LATENTS)
train(vae, train_ds)
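
The train helper is not shown here; under the assumption that train_step from the previous section relies on module-level model and optimizer variables, a hypothetical sketch might look like this (the learning rate is illustrative):

model = vae
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # illustrative learning rate

def train(model, train_ds):
    for epoch in range(NUM_EPOCHS):
        for batch in train_ds:   # numpy batches of flattened images
            train_step(batch)    # train_step from the previous section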

Observations

  • The training loop in JAX requires explicit initialization of the model and optimizer, which adds some complexity.
  • TensorFlow and PyTorch provide simpler training loops, making them more accessible for beginners.

Load and Process Data

One aspect that is often overlooked is data loading and preprocessing. Flax deliberately ships no data-loading utilities, so we typically borrow a data pipeline from another framework; here we use TensorFlow Datasets.

Data Loading Example

import numpy as np
import tensorflow_datasets as tfds

# hide the GPU from TensorFlow so it does not reserve memory that JAX needs
tf.config.experimental.set_visible_devices([], 'GPU')

def prepare_image(x):
    x = tf.cast(x['image'], tf.float32)
    x = tf.reshape(x, (-1,))
    return x

ds_builder = tfds.builder('binarized_mnist')
ds_builder.download_and_prepare()
train_ds = ds_builder.as_dataset(split=tfds.Split.TRAIN)
train_ds = train_ds.map(prepare_image)
train_ds = train_ds.cache()
train_ds = train_ds.repeat()
train_ds = train_ds.shuffle(50000)
train_ds = train_ds.batch(BATCH_SIZE)
train_ds = iter(tfds.as_numpy(train_ds))

test_ds = ds_builder.as_dataset(split=tfds.Split.TEST)
test_ds = test_ds.map(prepare_image).batch(10000)
test_ds = np.array(list(test_ds)[0])

Final Observations

As we conclude this exploration of building a Variational Autoencoder using JAX, TensorFlow, and PyTorch, here are some final thoughts:

  1. Boilerplate Code: All three frameworks keep boilerplate to a minimum, with Flax requiring slightly more because parameters, optimizer state, and random keys are handled explicitly.
  2. Module Definitions: The definition of modules, layers, and models is almost identical across the frameworks.
  3. Flexibility: Flax and JAX are designed to be flexible and expandable, making them suitable for research and experimentation.
  4. Data Handling: Flax currently lacks built-in data loading and processing capabilities, which may require developers to rely on other libraries.
  5. Layer Availability: While Flax may not have as extensive a library as TensorFlow and PyTorch, it is gradually catching up.

In summary, each framework has its unique advantages, and the choice largely depends on the specific needs of your project. Whether you prioritize flexibility, ease of use, or advanced features, there is a framework that can meet your requirements.


