Sunday, December 22, 2024

Creating a Custom Production-Ready Deep Learning Training Loop in TensorFlow from the Ground Up


Building a Custom Trainer for Machine Learning Models: A Step-by-Step Guide

Training is the cornerstone of developing a machine learning application. This is the phase where you gauge how well your model works, tune hyperparameters, and make architectural adjustments. Most machine learning engineers spend a large share of their time here, experimenting with different models, tuning architectures, and choosing the metrics and loss functions that best fit their problem.

In this article, we continue our "Deep Learning in Production" series by building a model trainer for the segmentation example we’ve been working on. Instead of merely outlining basic topics and software engineering principles, we will walk through the entire development lifecycle step by step. This lets us apply the best practices discussed in previous articles while building high-performance, maintainable software as we go.

So, prepare for a deep dive into code as we embark on this journey!

Building a Training Loop in TensorFlow

Before we dive into coding, let’s recap what we have so far. Our Colab notebook currently contains boilerplate Keras code, including model compilation and fitting:

self.model.compile(optimizer=self.config.train.optimizer.type,
                   loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                   metrics=self.config.train.metrics)

LOG.info('Training started')
model_history = self.model.fit(self.train_dataset, epochs=self.epoches,
                                steps_per_epoch=self.steps_per_epoch,
                                validation_steps=self.validation_steps,
                                validation_data=self.test_dataset)

return model_history.history['loss'], model_history.history['val_loss']

Understanding the Code

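In this snippet, we compile the model with the optimizer and metrics defined in our config (the Adam optimizer, with Sparse Categorical Accuracy as our main metric) and use Sparse Categorical Crossentropy as the loss function. We then train with fit and return the per-epoch training and validation losses stored in the History object it returns. Note that we are using Python 3.7 and TensorFlow 2.0; for a complete setup guide, refer back to the first article in our series.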

To simplify our training function, we can replace the above code with the following:

optimizer = tf.keras.optimizers.Adam()
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics = tf.keras.metrics.SparseCategoricalAccuracy()
trainer = Trainer(self.model, self.train_dataset, loss, optimizer, metrics, self.epoches)
trainer.train()

Here, we define the optimizer, loss, and metrics, and pass them along with the model and dataset to a new class called Trainer. This encapsulation allows us to call the train method and initiate the training process.
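Replacing the compile call this way hard-codes Adam, whereas the original code read the optimizer type from the config. One way to keep the config in charge while the Trainer stays unaware of it is to build the dependencies in a small factory. The sketch below assumes a plain dict in place of self.config.train, plus a hypothetical build_training_components helper and OPTIMIZERS registry; neither is part of TensorFlow.

```python
import tensorflow as tf

# Hypothetical registry mapping the config's optimizer name to a Keras class.
OPTIMIZERS = {
    "adam": tf.keras.optimizers.Adam,
    "sgd": tf.keras.optimizers.SGD,
}


def build_training_components(train_config):
    """Build the loss, optimizer and metric a Trainer needs.

    `train_config` is a plain dict standing in (hypothetically) for the
    article's `self.config.train` object.
    """
    optimizer_cls = OPTIMIZERS[train_config["optimizer"].lower()]
    optimizer = optimizer_cls(learning_rate=train_config.get("learning_rate", 1e-3))
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    metric = tf.keras.metrics.SparseCategoricalAccuracy()
    return loss, optimizer, metric


loss, optimizer, metric = build_training_components(
    {"optimizer": "adam", "learning_rate": 1e-3}
)
```

Because the Trainer only ever receives finished objects, switching from Adam to SGD becomes a config change rather than a code change.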

Separation of Concerns

A good software engineering practice is to keep each class unaware of the other components of the application. Each class should have a single, well-defined responsibility, which keeps the code maintainable and easy to extend. This principle is known as separation of concerns.

Let’s define our Trainer class in a separate file. I typically create a folder named executors to house all basic ML functionalities, including training, validation, and prediction. Each trainer will depend on six components: model, input data, loss function, optimizer, metric, and the number of epochs.

class Trainer:
    def __init__(self, model, input, loss_fn, optimizer, metric, epoches):
        self.model = model
        self.input = input
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.metric = metric
        self.epoches = epoches

Implementing the Training Step

Inside the Trainer class, we need a train function for the overall training process and a train_step function that executes a single training step. Writing the loop ourselves, instead of relying on Keras's fit, gives us full control over every detail of training, at the cost of a little more code.

In the train_step function, we perform the actual training for a single batch:

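def train_step(self, batch):
    trainable_variables = self.model.trainable_variables
    inputs, labels = batch
    # Record the forward pass so the tape can compute gradients
    with tf.GradientTape() as tape:
        # training=True puts layers such as dropout and batch norm in training mode
        predictions = self.model(inputs, training=True)
        step_loss = self.loss_fn(labels, predictions)
    # Backpropagate and update the weights
    grads = tape.gradient(step_loss, trainable_variables)
    self.optimizer.apply_gradients(zip(grads, trainable_variables))
    # Accumulate the running metric (e.g. accuracy) for the current epoch
    self.metric.update_state(labels, predictions)
    return step_loss, predictions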
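The step above can be tried in isolation. The following sketch, assuming a toy two-layer classifier and random data as hypothetical stand-ins for our U-Net and segmentation dataset, runs the same GradientTape step repeatedly on one batch and checks that the loss goes down. It also wraps the step in tf.function, a common optimization that traces the Python function into a TensorFlow graph on the first call and reuses that graph afterwards.

```python
import tensorflow as tf

tf.random.set_seed(0)

# Hypothetical toy data: 32 samples, 4 features, one of 3 classes each.
inputs = tf.random.normal((32, 4))
labels = tf.random.uniform((32,), maxval=3, dtype=tf.int32)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(3),  # raw logits, hence from_logits=True below
])
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
metric = tf.keras.metrics.SparseCategoricalAccuracy()

model(inputs)  # build the weights eagerly so trainable_variables is populated


@tf.function  # traced into a graph on the first call, reused afterwards
def train_step(batch):
    x, y = batch
    with tf.GradientTape() as tape:
        predictions = model(x, training=True)
        step_loss = loss_fn(y, predictions)
    grads = tape.gradient(step_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    metric.update_state(y, predictions)
    return step_loss


initial_loss = float(loss_fn(labels, model(inputs)))
for _ in range(100):
    final_loss = float(train_step((inputs, labels)))
```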

The Training Loop

Next, we implement the train method, which will iterate over the number of epochs and train each batch:

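def train(self):
    for epoch in range(self.epoches):
        LOG.info(f'Start epoch {epoch}')
        # Start each epoch with a fresh metric (named reset_state() in TF 2.5+)
        self.metric.reset_states()
        for step, training_batch in enumerate(self.input):
            step_loss, predictions = self.train_step(training_batch)
            LOG.info('Loss at step %d: %.2f', step, float(step_loss))
        train_acc = self.metric.result()
        LOG.info(f'Accuracy at end of epoch {epoch}: {float(train_acc):.4f}')
        # The end of an epoch is also a natural place to save a checkpoint

This method includes logging, which is vital for tracking the training process and diagnosing issues. Resetting the metric at the start of every epoch ensures that each logged accuracy covers only that epoch. The input is a tf.data.Dataset, which we can iterate over in a plain Python for loop, just like a list.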
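Putting the epoch loop together, here is a self-contained sketch of train over a batched tf.data pipeline. The data, model, and logger name are hypothetical. It calls reset_state(), the metric method name in TensorFlow 2.5 and later; on TensorFlow 2.0, use reset_states() instead.

```python
import logging

import tensorflow as tf

logging.basicConfig(level=logging.INFO)
LOG = logging.getLogger("trainer")  # hypothetical logger name
tf.random.set_seed(0)

# Hypothetical toy data: 64 samples, 4 features, 3 classes, batches of 16.
features = tf.random.normal((64, 4))
targets = tf.random.uniform((64,), maxval=3, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((features, targets)).shuffle(64).batch(16)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(3),
])
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
metric = tf.keras.metrics.SparseCategoricalAccuracy()


def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        step_loss = loss_fn(y, logits)
    grads = tape.gradient(step_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    metric.update_state(y, logits)
    return step_loss


def train(epochs):
    history = []
    for epoch in range(epochs):
        metric.reset_state()  # accuracy should cover only the current epoch
        for step, (x, y) in enumerate(dataset):
            step_loss = train_step(x, y)
            # Per-step loss at DEBUG level keeps INFO logs readable
            LOG.debug("Epoch %d step %d: loss %.4f", epoch, step, float(step_loss))
        epoch_acc = float(metric.result())
        LOG.info("Epoch %d: train accuracy %.3f", epoch, epoch_acc)
        history.append(epoch_acc)
    return history


accuracies = train(epochs=3)
```

Logging per-step loss at DEBUG and per-epoch accuracy at INFO is one way to keep long runs readable while still allowing detailed output when needed.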

Implementing Checkpoints

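Deep learning models often take a long time to train, so it is essential to save the model’s state periodically. TensorFlow provides built-in functionality for managing checkpoints. In the Trainer's constructor, we create a checkpoint that tracks the model and optimizer, and a manager that handles the checkpoint files:

# Track the model and optimizer state (inside Trainer.__init__)
self.checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
# max_to_keep is a required argument; older checkpoints beyond it are deleted
self.checkpoint_manager = tf.train.CheckpointManager(
    self.checkpoint, './tf_ckpts', max_to_keep=3)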

To save the current state, for example at the end of each epoch, we call:

save_path = self.checkpoint_manager.save()
LOG.info(f'Saved checkpoint: {save_path}')
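Saving is only half of the story: to resume an interrupted run we also need to restore. A minimal sketch, using plain tf.Variable objects as hypothetical stand-ins for the model and optimizer so that it runs quickly: we save five times with max_to_keep=3, then restore the latest checkpoint into freshly created variables.

```python
import tempfile

import tensorflow as tf

# Hypothetical training state: a step counter and some weights.
step = tf.Variable(0, dtype=tf.int64)
weights = tf.Variable([1.0, 2.0])

checkpoint = tf.train.Checkpoint(step=step, weights=weights)
ckpt_dir = tempfile.mkdtemp()
# Keep only the three most recent checkpoints on disk.
manager = tf.train.CheckpointManager(checkpoint, ckpt_dir, max_to_keep=3)

for epoch in range(5):
    step.assign_add(1)
    weights.assign_add([0.5, 0.5])
    manager.save()  # one checkpoint per "epoch"

# Later (e.g. after a crash): rebuild fresh objects and restore the latest state.
new_step = tf.Variable(0, dtype=tf.int64)
new_weights = tf.Variable([0.0, 0.0])
resumed = tf.train.Checkpoint(step=new_step, weights=new_weights)
resumed.restore(manager.latest_checkpoint)
```

In the Trainer, the equivalent would be calling self.checkpoint.restore(self.checkpoint_manager.latest_checkpoint) in the constructor. When the directory is still empty, latest_checkpoint is None and the call leaves the freshly initialized values in place, so the same code handles both a fresh start and a resumed run.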

Saving the Trained Model

Once training is complete, we want to store the trained model for future use:

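import os

self.model_save_path = 'saved_models/'
# The numeric "1" subdirectory is the model version, the layout TensorFlow Serving expects
save_path = os.path.join(self.model_save_path, "unet/1/")
tf.saved_model.save(self.model, save_path)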

Loading a saved model is equally straightforward:

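model = tf.saved_model.load(save_path)

Note that tf.saved_model.load returns a generic TensorFlow object rather than a tf.keras.Model, so Keras methods such as fit and evaluate are not available on it; you use it through the functions that were saved with it.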
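To see the save-and-load round trip end to end, here is a minimal sketch that uses a hypothetical tf.Module called Scaler in place of the U-Net. Its __call__ is a tf.function with a fixed input signature, so the exported SavedModel contains a graph that the loaded object can run directly.

```python
import os
import tempfile

import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical stand-in for the U-Net: multiplies its input by a weight."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0)

    # A fixed input_signature lets SavedModel export a concrete, callable graph.
    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return self.w * x


export_dir = os.path.join(tempfile.mkdtemp(), "scaler", "1")  # "1" = model version
tf.saved_model.save(Scaler(), export_dir)

restored = tf.saved_model.load(export_dir)
outputs = restored(tf.constant([1.0, 2.0, 3.0]))
```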

Visualizing Training with TensorBoard

TensorBoard plots the metrics captured during training over time and can also display the computational graph of your architecture, which makes problems such as a loss that stops decreasing easy to spot.

To use TensorBoard, we can create a summary writer:

self.train_log_dir = 'logs/gradient_tape/'
self.train_summary_writer = tf.summary.create_file_writer(self.train_log_dir)

At the end of each epoch, we write the current loss and accuracy to the log directory:

def _write_summary(self, loss, epoch):
    with self.train_summary_writer.as_default():
        tf.summary.scalar('loss', loss, step=epoch)
        tf.summary.scalar('accuracy', self.metric.result(), step=epoch)
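_write_summary records only training values. A common extension, sketched below with made-up per-epoch loss values as hypothetical placeholders, is to give training and validation their own writers in sibling subdirectories: TensorBoard treats each subdirectory as a separate run and draws both curves on the same chart.

```python
import os
import tempfile

import tensorflow as tf

log_dir = tempfile.mkdtemp()
# One writer per run; TensorBoard shows each subdirectory as its own run.
train_writer = tf.summary.create_file_writer(os.path.join(log_dir, "train"))
val_writer = tf.summary.create_file_writer(os.path.join(log_dir, "validation"))

# Hypothetical per-epoch values standing in for real loss measurements.
train_losses = [1.10, 0.80, 0.60]
val_losses = [1.15, 0.90, 0.75]

for epoch, (train_loss, val_loss) in enumerate(zip(train_losses, val_losses)):
    with train_writer.as_default():
        tf.summary.scalar("loss", train_loss, step=epoch)
    with val_writer.as_default():
        tf.summary.scalar("loss", val_loss, step=epoch)

# Make sure buffered events reach disk before TensorBoard reads them.
train_writer.flush()
val_writer.flush()
```

Running tensorboard --logdir pointed at the parent log directory then shows the train and validation loss curves side by side.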

Conclusion

In this article, we built a custom trainer from scratch, following best practices for maintainability and extensibility. We also used TensorFlow features that keep the code short: GradientTape for the training step, checkpoints, SavedModel, and TensorBoard summaries. The process outlined here reflects the thought process I would follow in a real-world project, and the code is nearly identical to what would run in production.

As we continue this series, we will focus on optimizing training, including distributing it across GPUs and multiple machines. We will also present a full example of running a training job in the cloud.

I hope you found this article informative and engaging. If you have any feedback or would like to see more content like this, please let us know. Until next time, enjoy your journey into AI!
