Monday, December 23, 2024

Deploying a Deep Learning Model as a Web Application with Flask and TensorFlow

Bringing Deep Learning Models to Life: From Research to Real-World Applications

Developing a state-of-the-art deep learning model is an impressive feat, but its true value lies in its application to real-world problems. While research is undeniably fascinating, the ultimate goal is often to leverage these advancements to create solutions that can be utilized by everyday users. In the realm of deep learning, many models are deployed as web or mobile applications, making them accessible and functional in practical scenarios.

In this article series, we will embark on a journey to take our image segmentation model, expose it via an API using Flask, and deploy it in a production environment. If you’re new to this series, let’s recap: we started with a simple U-Net model in a Colab notebook that performs image segmentation, and we transformed it into a full-scale, highly-optimized project. Now, we will serve it to real users at scale. For more details, check out the previous article or our GitHub repository.

Our end goal is to create a fully functional service that can be accessed by clients/users to perform segmentation in real-time. If you’re not familiar with building client-server applications, don’t worry! We’ll break down the fundamental concepts to guide you through the process.

Understanding Key Concepts

Before diving into the implementation, let’s clarify some essential terms:

  • Web Service: A self-contained piece of software that is available over the internet and uses standard communication protocols like HTTP.
  • Server: A computer program or device that provides services to another program (the client).
  • Client-Server Model: A programming paradigm where one program (the client) requests a service or resource from another program (the server), as illustrated in the sketch after this list.
  • API (Application Programming Interface): A set of definitions and functions that allows applications to access data and interact with external software components, operating systems, or microservices.
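To make these ideas concrete, here is a minimal sketch of the client-server model in Python, using only the standard library (all names here are illustrative, not part of our project): a server exposes a single HTTP endpoint, and any client that speaks HTTP can request it.

# minimal_server.py -- an illustrative web service, not our final app
from http.server import BaseHTTPRequestHandler, HTTPServer

class PingHandler(BaseHTTPRequestHandler):
    def do_GET(self):
        # this handler is our tiny "API": the one action clients can request
        self.send_response(200)
        self.send_header('Content-Type', 'text/plain')
        self.end_headers()
        self.wfile.write(b'pong')

if __name__ == '__main__':
    HTTPServer(('0.0.0.0', 8000), PingHandler).serve_forever()

Running this file and visiting http://0.0.0.0:8000 returns pong: the server provides the service, the browser (or a script) acts as the client, and HTTP is the protocol connecting them.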

With these definitions in mind, let’s outline our next steps. We need to create an “inferrer” class that interacts with our TensorFlow model to compute the segmentation map. Then, we will build a web application to expose this functionality via an API, and finally, we will set up a web service that allows clients to communicate with it and send their images for prediction.

Inferring a Segmentation Mask of a Custom Image

We previously trained our model using a custom training loop and saved the training variables using TensorFlow’s built-in saving functionality:

# persist the trained model in TensorFlow's SavedModel format
save_path = os.path.join(self.model_save_path, "unet")
tf.saved_model.save(self.model, save_path)

Our next steps are straightforward: a) load the saved model, b) feed it the user’s image, and c) infer the segmentation mask. A good approach is to build an inferrer class that loads the model upon creation (to avoid loading it multiple times) and includes an inference method that returns the model’s result.

The Inferrer Class

Here’s a basic structure for our UnetInferrer class:

import tensorflow as tf

class UnetInferrer:
    def __init__(self):
        self.image_size = 128  # must match the resolution the model was trained on
        self.saved_path = 'model_path_location'
        self.model = tf.saved_model.load(self.saved_path)
        # the default serving signature is a callable that runs the forward pass
        self.predict = self.model.signatures["serving_default"]

    def preprocess(self, image):
        # resize to the training resolution and normalize pixel values to [0, 1]
        image = tf.image.resize(image, (self.image_size, self.image_size))
        return tf.cast(image, tf.float32) / 255.0

    def infer(self, image=None):
        tensor_image = tf.convert_to_tensor(image, dtype=tf.float32)
        tensor_image = self.preprocess(tensor_image)
        shape = tensor_image.shape
        # add a batch dimension: the model expects [batch, height, width, channels]
        tensor_image = tf.reshape(tensor_image, [1, shape[0], shape[1], shape[2]])
        pred = self.predict(tensor_image)['conv2d_transpose_4']
        # convert to plain Python lists so the result is JSON-serializable for Flask
        return {'segmentation_output': pred.numpy().tolist()}
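One detail that may look cryptic is the output key 'conv2d_transpose_4': it is simply the name TensorFlow assigned to the model’s final layer when the SavedModel was exported. If you are unsure which key your own model uses, you can inspect the serving signature directly; a minimal sketch, assuming the model was saved as shown earlier:

import tensorflow as tf

model = tf.saved_model.load('model_path_location')
# prints a dict mapping output names to TensorSpecs, e.g. {'conv2d_transpose_4': ...}
print(model.signatures['serving_default'].structured_outputs)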

Preprocessing the Image

It’s important to note that user images may not be in the desired format, so we need to preprocess them before passing them to the model. The preprocessing method resizes the image and normalizes it:

def preprocess(self, image):
    # resize to the training resolution and normalize pixel values to [0, 1]
    image = tf.image.resize(image, (self.image_size, self.image_size))
    return tf.cast(image, tf.float32) / 255.0

Testing the Inferrer

To ensure our inferrer works correctly, we can implement a simple unit test:

import unittest

import numpy as np
from PIL import Image

from executor.unet_inferrer import UnetInferrer

class MyTestCase(unittest.TestCase):
    def test_infer(self):
        # load a sample image as a float32 array, exactly as a client would send it
        image = np.asarray(Image.open('resources/yorkshire_terrier.jpg')).astype(np.float32)
        inferrer = UnetInferrer()
        result = inferrer.infer(image)
        # the inferrer should return a dict containing the segmentation map
        self.assertIn('segmentation_output', result)

if __name__ == '__main__':
    unittest.main()

Creating a Web Application Using Flask

Now that we have our inferrer ready, it’s time to create a web application using Flask. But first, what is Flask? Flask is a lightweight web application framework that allows us to build applications with minimal boilerplate code and provides essential functionalities out of the box.

Setting Up Flask

To create our Flask application, we start by importing Flask and creating an instance:

from flask import Flask, request

app = Flask(__name__)

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8080)  # listen on all interfaces on port 8080

Next, we will define an endpoint for our inference function. For example, we can set up a route at 0.0.0.0:8080/infer using the POST method.

Defining the Endpoint

We can create our endpoint as follows:

from executor.unet_inferrer import UnetInferrer

u_net = UnetInferrer()  # instantiated once, so the model is loaded a single time

@app.route('/infer', methods=["POST"])
def infer():
    data = request.json
    image = data['image']
    return u_net.infer(image)  # a dict, which Flask serializes to a JSON response

In this code, we retrieve the image from the request body and pass it to our inferrer class for processing. Note that the inferrer is instantiated once at module level, so the model is loaded a single time rather than on every request.

Error Handling

Flask also provides a simple way to handle exceptions:

import traceback

from flask import jsonify

@app.errorhandler(Exception)
def handle_exception(e):
    # return the full server-side traceback as JSON, which simplifies debugging
    return jsonify(stackTrace=traceback.format_exc())

This will return a traceback of any errors that occur during execution, making debugging easier.
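For instance, a client that posts a malformed payload (say, one missing the 'image' key) receives the traceback in the response body instead of an opaque 500 error; a quick sketch, reusing the client setup from the next section:

response = requests.post(ENDPOINT_URL, json={})   # no 'image' key triggers a KeyError
print(response.json()['stackTrace'])              # the full server-side traceback

In a production setting you would log this server-side rather than expose internals to clients, but it is convenient during development.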

Creating a Client

To test our server, we can create a simple client using Python’s requests library. This client will send a request to our endpoint and display the response.

import requests
from PIL import Image
import numpy as np

ENDPOINT_URL = "http://0.0.0.0:8080/infer"

def infer():
    image = np.asarray(Image.open('resources/yorkshire_terrier.jpg')).astype(np.float32)
    # arrays are not JSON-serializable, so the image is sent as a nested list
    data = {'image': image.tolist()}
    response = requests.post(ENDPOINT_URL, json=data)
    response.raise_for_status()
    print(response.json())

if __name__ == "__main__":
    infer()

Displaying the Results

To visualize the predicted segmentation mask, we can use Matplotlib:

import matplotlib.pyplot as plt
import tensorflow as tf

def display(display_list):
    plt.figure(figsize=(15, 15))
    title = ['Input Image', 'Predicted Mask']
    for i in range(len(display_list)):
        plt.subplot(1, len(display_list), i + 1)
        plt.title(title[i])
        # convert the raw array back to a PIL image for plotting
        plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
        plt.axis('off')
    plt.show()
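To tie everything together, the client’s infer function can decode the JSON response back into a NumPy array and pass it to display alongside the original image; a short sketch, assuming the 'segmentation_output' key our inferrer returns:

pred = np.asarray(response.json()['segmentation_output'])[0]  # drop the batch dimension
display([image, pred])  # the original image next to the predicted mask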

Conclusion

In this article, we built an inferrer for our deep learning model, exposed it through a web server using Flask, and created a client to send requests and receive predictions. While we’ve made significant progress, our current setup runs locally and is not optimized for production environments.

In the next article, we will explore how to utilize uWSGI to create a high-performance, production-ready server and how to use a load balancer like Nginx to distribute traffic effectively among multiple processes. If that sounds interesting, I hope to see you in the next installment!

Auf Wiedersehen…
