Comparing JAX, TensorFlow, and PyTorch: Building a Variational Autoencoder from Scratch
In the ever-evolving landscape of deep learning frameworks, JAX, TensorFlow, and PyTorch stand out as the most popular choices among researchers and practitioners. Each framework has its unique strengths and weaknesses, making it essential for developers to understand their differences. To facilitate this understanding, I decided to build a Variational Autoencoder (VAE) from scratch using all three frameworks simultaneously. This article will present the code for each component side by side, allowing us to explore the similarities and differences in their implementations.
Prologue
Before diving into the code, let’s clarify a few things:
- Frameworks Used: The comparison covers JAX (with the Flax neural-network library), TensorFlow (with Keras), and PyTorch, as reflected in the code throughout this article.
- Focus on JAX and Flax: Given that many readers are already familiar with TensorFlow and PyTorch, this article will place a stronger emphasis on JAX and Flax. I will explain concepts that may be unfamiliar to many, effectively making this article a light tutorial on Flax.
- Assumed Knowledge: It is assumed that readers are familiar with the basic principles behind Variational Autoencoders. If not, I recommend checking out my previous article on latent variable models.
A Quick Recap on VAEs
A vanilla Autoencoder consists of an Encoder and a Decoder. The encoder compresses the input into a latent representation ( z ), while the decoder attempts to reconstruct the input from that representation. In a Variational Autoencoder, stochasticity is introduced: instead of a single point, the encoder outputs the parameters of a probability distribution over ( z ), and the latent code is sampled from it. Sampling is made differentiable through the reparameterization trick.
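Concretely, the encoder predicts a mean and a log-variance, and the sample and training objective take the standard VAE form, which is exactly what the code below implements:

$$ z = \mu + \sigma \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I), \qquad \sigma = \exp\!\left(\tfrac{1}{2}\log\sigma^2\right) $$

$$ \mathcal{L} = \underbrace{\mathbb{E}_{q(z|x)}\left[-\log p(x|z)\right]}_{\text{reconstruction (BCE)}} + \underbrace{\left(-\tfrac{1}{2}\sum_{j}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)\right)}_{D_{KL}\left(q(z|x)\,\|\,\mathcal{N}(0,I)\right)} $$

The reconstruction term is computed as a binary cross-entropy, and the KL term has this closed form for a Gaussian posterior and a standard normal prior.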
The Encoder
For the encoder, a simple linear layer followed by a ReLU activation suffices for our toy example. The output will be both the mean and the log-variance of the latent distribution.
JAX (Flax) Implementation
In Flax, the basic building block is the Module abstraction. To implement our encoder, we need to:
- Create a class that inherits from flax.linen.Module (imported as nn below).
- Define static arguments as dataclass attributes.
- Implement the forward pass inside the __call__ method.
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn

class Encoder(nn.Module):
    latents: int  # static attribute: size of the latent space

    @nn.compact  # allows submodules to be declared inline in __call__
    def __call__(self, x):
        x = nn.Dense(500, name='fc1')(x)
        x = nn.relu(x)
        mean_x = nn.Dense(self.latents, name='fc2_mean')(x)
        logvar_x = nn.Dense(self.latents, name='fc2_logvar')(x)
        return mean_x, logvar_x
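Unlike TensorFlow and PyTorch modules, a Flax Module does not store its parameters internally: they are created with init and passed explicitly to apply. Here is a minimal usage sketch (the batch size and key here are illustrative assumptions, not part of the original code):

# Minimal usage sketch: parameters live outside the module.
rng = random.PRNGKey(0)
dummy_batch = jnp.ones((32, 784))                    # flattened MNIST-sized inputs

encoder = Encoder(latents=20)
params = encoder.init(rng, dummy_batch)['params']    # build the parameter pytree
mean, logvar = encoder.apply({'params': params}, dummy_batch)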
TensorFlow (Keras) Implementation
In TensorFlow, we define the encoder as follows:
import tensorflow as tf
from tensorflow.keras import layers
class Encoder(layers.Layer):
    def __init__(self, latent_dim=20, name='encoder', **kwargs):
        super(Encoder, self).__init__(name=name, **kwargs)
        self.enc1 = layers.Dense(500, activation='relu')
        self.mean_x = layers.Dense(latent_dim)
        self.logvar_x = layers.Dense(latent_dim)

    def call(self, inputs):
        x = self.enc1(inputs)
        z_mean = self.mean_x(x)
        z_log_var = self.logvar_x(x)
        return z_mean, z_log_var
PyTorch Implementation
In PyTorch, the encoder is implemented as follows:
import torch
import torch.nn.functional as F

class Encoder(torch.nn.Module):
    def __init__(self, latent_dim=20):
        super(Encoder, self).__init__()
        self.enc1 = torch.nn.Linear(784, 500)
        self.mean_x = torch.nn.Linear(500, latent_dim)
        self.logvar_x = torch.nn.Linear(500, latent_dim)

    def forward(self, inputs):
        x = self.enc1(inputs)
        x = F.relu(x)
        z_mean = self.mean_x(x)
        z_log_var = self.logvar_x(x)
        return z_mean, z_log_var
Observations
- The implementations across all three frameworks are quite similar, with minor syntactical differences.
- Flax's linen package (conventionally imported as nn) provides most common deep learning layers and operations, similar to TensorFlow and PyTorch. The same module can also be written with an explicit setup method, as sketched below.
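For reference, here is how the same encoder could be written with Flax's setup method instead of @nn.compact; this is the style the VAE class later in the article uses. The snippet is an equivalent sketch, not code from the original implementation:

# Equivalent encoder using setup() instead of @nn.compact (illustrative sketch).
class EncoderExplicit(nn.Module):
    latents: int

    def setup(self):
        # Submodules are declared once here instead of inline in __call__.
        self.fc1 = nn.Dense(500)
        self.fc2_mean = nn.Dense(self.latents)
        self.fc2_logvar = nn.Dense(self.latents)

    def __call__(self, x):
        x = nn.relu(self.fc1(x))
        return self.fc2_mean(x), self.fc2_logvar(x)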
The Decoder
The decoder will also consist of two linear layers that receive the latent representation ( z ) and output the reconstructed input.
JAX (Flax) Implementation
class Decoder(nn.Module):
    @nn.compact
    def __call__(self, z):
        z = nn.Dense(500, name='fc1')(z)
        z = nn.relu(z)
        z = nn.Dense(784, name='fc2')(z)
        return z
TensorFlow (Keras) Implementation
class Decoder(layers.Layer):
    def __init__(self, name='decoder', **kwargs):
        super(Decoder, self).__init__(name=name, **kwargs)
        self.dec1 = layers.Dense(500, activation='relu')
        self.out = layers.Dense(784)

    def call(self, z):
        z = self.dec1(z)
        return self.out(z)
PyTorch Implementation
class Decoder(torch.nn.Module):
    def __init__(self, latent_dim=20):
        super(Decoder, self).__init__()
        self.dec1 = torch.nn.Linear(latent_dim, 500)
        self.out = torch.nn.Linear(500, 784)

    def forward(self, z):
        z = self.dec1(z)
        z = F.relu(z)
        return self.out(z)
Observations
- The decoder implementations are similarly structured across all three frameworks.
- The use of activation functions and layer definitions remains consistent.
Variational Autoencoder
To combine the encoder and decoder, we will create a VAE class that represents the entire architecture. This class also implements the reparameterization trick.
JAX (Flax) Implementation
class VAE(nn.Module):
    latents: int = 20

    def setup(self):
        self.encoder = Encoder(self.latents)
        self.decoder = Decoder()

    def __call__(self, x, z_rng):
        mean, logvar = self.encoder(x)
        z = reparameterize(z_rng, mean, logvar)
        recon_x = self.decoder(z)
        return recon_x, mean, logvar

def reparameterize(rng, mean, logvar):
    # sigma = exp(0.5 * log(sigma^2)); the noise requires an explicit PRNG key
    std = jnp.exp(0.5 * logvar)
    eps = random.normal(rng, logvar.shape)
    return mean + eps * std
TensorFlow (Keras) Implementation
class VAE(tf.keras.Model):
    def __init__(self, latent_dim=20, name='vae', **kwargs):
        super(VAE, self).__init__(name=name, **kwargs)
        self.encoder = Encoder(latent_dim=latent_dim)
        self.decoder = Decoder()

    def call(self, inputs):
        z_mean, z_log_var = self.encoder(inputs)
        z = self.reparameterize(z_mean, z_log_var)
        reconstructed = self.decoder(z)
        return reconstructed, z_mean, z_log_var

    def reparameterize(self, mean, logvar):
        # tf.shape handles dynamic (unknown) batch dimensions in graph mode
        eps = tf.random.normal(shape=tf.shape(mean))
        return mean + eps * tf.exp(logvar * .5)
PyTorch Implementation
class VAE(torch.nn.Module):
    def __init__(self, latent_dim=20):
        super(VAE, self).__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def forward(self, inputs):
        z_mean, z_log_var = self.encoder(inputs)
        z = self.reparameterize(z_mean, z_log_var)
        reconstructed = self.decoder(z)
        return reconstructed, z_mean, z_log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + (eps * std)
Observations
- The structure of the VAE class is quite similar across all three frameworks.
- The reparameterization trick is implemented in a nearly identical manner; the main difference is that JAX requires an explicit PRNG key for sampling, while TensorFlow and PyTorch manage random state implicitly.
Loss and Training Step
The implementation of the training step and loss function begins to show more variation among the frameworks.
JAX (Flax) Implementation
In JAX, we lean on vmap for automatic vectorization and jit for XLA compilation. Here's how we define the loss functions and the training step:
@jax.vmap  # map over the batch dimension: each call computes a per-example value
def kl_divergence(mean, logvar):
    return -0.5 * jnp.sum(1 + logvar - jnp.square(mean) - jnp.exp(logvar))

@jax.vmap
def binary_cross_entropy_with_logits(logits, labels):
    logits = nn.log_sigmoid(logits)
    # log(1 - sigmoid(x)) is computed stably as log(-expm1(log_sigmoid(x)))
    return -jnp.sum(labels * logits + (1. - labels) * jnp.log(-jnp.expm1(logits)))

@jax.jit  # compile the whole training step with XLA
def train_step(optimizer, batch, z_rng):
    def loss_fn(params):
        # model() is a helper that constructs VAE(latents=LATENTS)
        recon_x, mean, logvar = model().apply({'params': params}, batch, z_rng)
        bce_loss = binary_cross_entropy_with_logits(recon_x, batch).mean()
        kld_loss = kl_divergence(mean, logvar).mean()
        loss = bce_loss + kld_loss
        return loss, recon_x
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    _, grad = grad_fn(optimizer.target)  # optimizer.target holds the current params (flax.optim API)
    optimizer = optimizer.apply_gradient(grad)
    return optimizer
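Note that this training step uses the legacy flax.optim API, which has since been deprecated in favor of optax. As a rough illustrative sketch only, assuming the same constants and helper functions used elsewhere in this article, an optax version could look like this:

import optax

tx = optax.adam(LEARNING_RATE)
opt_state = tx.init(params)

@jax.jit
def train_step_optax(params, opt_state, batch, z_rng):
    def loss_fn(params):
        recon_x, mean, logvar = model().apply({'params': params}, batch, z_rng)
        return binary_cross_entropy_with_logits(recon_x, batch).mean() + kl_divergence(mean, logvar).mean()
    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = tx.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss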
TensorFlow (Keras) Implementation
In TensorFlow, the training step is defined as follows:
def kl_divergence(mean, logvar):
    return -0.5 * tf.reduce_sum(1 + logvar - tf.square(mean) - tf.exp(logvar), axis=1)

def binary_cross_entropy_with_logits(logits, labels):
    # mirror the JAX version: work in log-sigmoid space for numerical stability
    logits = tf.math.log_sigmoid(logits)
    return -tf.reduce_sum(labels * logits + (1 - labels) * tf.math.log(-tf.math.expm1(logits)), axis=1)

@tf.function
def train_step(model, x, optimizer):
    with tf.GradientTape() as tape:
        recon_x, mean, logvar = model(x)
        bce_loss = tf.reduce_mean(binary_cross_entropy_with_logits(recon_x, x))
        kld_loss = tf.reduce_mean(kl_divergence(mean, logvar))
        loss = bce_loss + kld_loss
        tf.print(loss, kld_loss, bce_loss)  # tf.print works inside tf.function (graph mode)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
PyTorch Implementation
In PyTorch, the training step is structured as follows:
def final_loss(reconstruction, train_x, mu, logvar):
    # summed BCE over pixels plus the closed-form Gaussian KL term
    BCE = torch.nn.BCEWithLogitsLoss(reduction='sum')(reconstruction, train_x)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

def train_step(train_x):
    # model and optimizer are assumed to be defined in the enclosing scope
    train_x = torch.from_numpy(train_x)
    optimizer.zero_grad()
    reconstruction, mu, logvar = model(train_x)
    loss = final_loss(reconstruction, train_x, mu, logvar)
    loss.backward()
    optimizer.step()
    return loss.item()
Observations
- JAX requires a few extra steps to make use of its function transformations, such as jit compilation, automatic differentiation with grad/value_and_grad, and vectorization with vmap (see the short example below).
- TensorFlow and PyTorch keep the training step more straightforward, but they do not expose the same explicit, composable transformation model that JAX provides.
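To make the transformation model more concrete, here is a tiny self-contained example (not part of the VAE code) showing how grad, vmap, and jit compose on an ordinary Python function:

# Standalone illustration of JAX's composable transformations.
def square_error(w, x, y):
    return (w * x - y) ** 2

grad_fn = jax.grad(square_error)                         # derivative w.r.t. w
batched_grad = jax.vmap(grad_fn, in_axes=(None, 0, 0))   # map over a batch of (x, y)
fast_batched_grad = jax.jit(batched_grad)                # compile with XLA

xs = jnp.arange(4.0)
ys = 2.0 * xs
print(fast_batched_grad(3.0, xs, ys))  # per-example gradients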
Training Loop
The training loop is where we execute the train_step function iteratively.
JAX (Flax) Implementation
from flax import optim  # legacy flax.optim API, used throughout this example

rng = random.PRNGKey(0)
rng, key = random.split(rng)

# model() is a helper that returns VAE(latents=LATENTS)
init_data = jnp.ones((BATCH_SIZE, 784), jnp.float32)
params = model().init(key, init_data, rng)['params']

optimizer = optim.Adam(learning_rate=LEARNING_RATE).create(params)
optimizer = jax.device_put(optimizer)

rng, z_key, eval_rng = random.split(rng, 3)
z = random.normal(z_key, (64, LATENTS))  # fixed latent codes for generating samples

steps_per_epoch = 50000 // BATCH_SIZE

for epoch in range(NUM_EPOCHS):
    for _ in range(steps_per_epoch):
        batch = next(train_ds)
        rng, key = random.split(rng)
        optimizer = train_step(optimizer, batch, key)
TensorFlow (Keras) Implementation
vae = VAE(latent_dim=LATENTS)
optimizer = tf.keras.optimizers.Adam(1e-4)

for epoch in range(NUM_EPOCHS):
    for train_x in train_ds:
        train_step(vae, train_x, optimizer)
PyTorch Implementation
vae = VAE(LATENTS)
train(vae, train_ds)
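The train helper is not shown in the original snippets. A minimal sketch of what it could look like, reusing the train_step defined earlier, is given below; the optimizer setup and epoch structure here are assumptions for illustration:

# Hypothetical sketch of the train helper (not from the original code).
# train_step above relies on `model` and `optimizer` from the enclosing scope,
# so they are set up here before the loop.
model = vae
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

def train(model, train_ds):
    steps_per_epoch = 50000 // BATCH_SIZE
    for epoch in range(NUM_EPOCHS):
        epoch_loss = 0.0
        for _ in range(steps_per_epoch):
            train_x = next(train_ds)          # numpy batch from the tfds pipeline below
            epoch_loss += train_step(train_x)
        print(f'epoch {epoch}: loss {epoch_loss / steps_per_epoch:.2f}')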
Observations
- The training loop in JAX requires explicit handling of PRNG keys and explicit initialization of the model parameters and optimizer state, which adds some complexity.
- TensorFlow and PyTorch provide simpler training loops, making them more accessible for beginners.
Load and Process Data
One aspect that is often overlooked is data loading and preprocessing. Flax does not include data manipulation packages, so we typically borrow data loading utilities from another framework; in this example we use TensorFlow Datasets (tfds).
Data Loading Example
import numpy as np
import tensorflow_datasets as tfds

# Hide the GPU from TensorFlow so it stays available for JAX
tf.config.experimental.set_visible_devices([], 'GPU')

def prepare_image(x):
    x = tf.cast(x['image'], tf.float32)
    x = tf.reshape(x, (-1,))  # flatten 28x28 images into 784-dim vectors
    return x

ds_builder = tfds.builder('binarized_mnist')
ds_builder.download_and_prepare()

train_ds = ds_builder.as_dataset(split=tfds.Split.TRAIN)
train_ds = train_ds.map(prepare_image)
train_ds = train_ds.cache()
train_ds = train_ds.repeat()
train_ds = train_ds.shuffle(50000)
train_ds = train_ds.batch(BATCH_SIZE)
train_ds = iter(tfds.as_numpy(train_ds))  # infinite iterator of numpy batches

test_ds = ds_builder.as_dataset(split=tfds.Split.TEST)
test_ds = test_ds.map(prepare_image).batch(10000)
test_ds = np.array(list(test_ds)[0])
Final Observations
As we conclude this exploration of building a Variational Autoencoder using JAX, TensorFlow, and PyTorch, here are some final thoughts:
- Boilerplate Code: All three frameworks keep boilerplate to a minimum, with Flax requiring slightly more because parameters, optimizer state, and PRNG keys are handled explicitly.
- Module Definitions: The definition of modules, layers, and models is almost identical across the frameworks.
- Flexibility: Flax and JAX are designed to be flexible and expandable, making them suitable for research and experimentation.
- Data Handling: Flax currently lacks built-in data loading and processing capabilities, which may require developers to rely on other libraries.
- Layer Availability: While Flax may not have as extensive a library as TensorFlow and PyTorch, it is gradually catching up.
In summary, each framework has its unique advantages, and the choice largely depends on the specific needs of your project. Whether you prioritize flexibility, ease of use, or advanced features, there is a framework that can meet your requirements.