Project: How to build a low-latency, deep-learning-based Flask app

Creating the resnet_model_server.py

As discussed previously, we will import the model and define the Server and RequestHandler classes in resnet_model_server.py.

Note:

  • vi is used to:

    • create and edit the file, if the file doesn't exist.
    • edit the file, if the file already exists.

    If we want to create and/or edit the file named myFile, we use vi myFile.

  • To save and exit the file, press the Esc key, type :wq, and press Enter. In :wq, w means write (save) and q means quit.

  • A thread is used for multi-tasking. threading is a standard Python library whose Thread class we can subclass in our own classes. Here, the Server and RequestHandler classes will both inherit from Thread. When subclassing Thread, we need to know about the following important methods:

    • start(): Starts the thread. Internally, it arranges for the thread's run() method to be invoked.
    • run(): This method represents the thread's functionality, so we define the thread's behaviour by overriding run(). For example, since the task of the Server is to (1) listen on a port, (2) direct requests to the RequestHandler, and (3) terminate if there is an interruption, we define all of this in the run() method of the Server class, which inherits from threading.Thread. Similarly, the functionality of handling a request and returning predictions is defined in the run() method of the RequestHandler class.
  • We also use threading.Event(). An event manages an internal flag that can be set to True with its set() method. Here we instantiate the event with self._stop = threading.Event(). When the server is interrupted (for example, by pressing Ctrl+Z on the keyboard), stop() invokes self._stop.set() to set the internal flag to True, so self._stop.isSet() returns True from then on. A minimal sketch of this pattern is shown in the next note.
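
  • To make the threading pattern concrete, here is a minimal, standalone sketch (not part of the project code; the class and attribute names are only illustrative) of subclassing threading.Thread with a threading.Event used as a stop flag:

    import threading
    import time

    class Worker(threading.Thread):
        def __init__(self):
            threading.Thread.__init__(self)
            # Internal flag, initially False (the project code names it self._stop).
            self._stop_event = threading.Event()

        def stop(self):
            self._stop_event.set()              # flip the flag to True

        def stopped(self):
            return self._stop_event.is_set()    # is_set() is the same as isSet()

        def run(self):
            # The thread's activity: loop until someone calls stop().
            while not self.stopped():
                print("working...")
                time.sleep(1)

    w = Worker()
    w.start()      # internally arranges for run() to execute in a new thread
    time.sleep(3)
    w.stop()       # sets the event, so the loop inside run() exits
    w.join()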

INSTRUCTIONS
  • The file resnet_model_server.py should be created inside ~/Flask-ZMQ-App-Folder/Model-Server-Folder, so make sure you are in that directory. You can check your present working directory using the command:

    pwd
    

    This command should output the path:

    /home/$USER/Flask-ZMQ-App-Folder/Model-Server-Folder
    

    If the displayed path is not the same as the above, switch to the Model-Server-Folder directory using

    cd ~/Flask-ZMQ-App-Folder/Model-Server-Folder
    
  • Create the file named resnet_model_server.py using the vi command inside the Flask-ZMQ-App-Folder/Model-Server-Folder directory.

  • Press the i key on your keyboard to switch to insert mode in the file.

  • Copy-paste the following code:

    from io import BytesIO
    from PIL import Image
    import threading
    import zmq
    from base64 import b64decode
    import numpy as np
    
    import tensorflow as tf
    from tensorflow.keras.preprocessing import image
    from tensorflow.keras.applications.resnet50 import ResNet50 as myModel
    from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
    from tensorflow.python.keras.backend import set_session
    
    sess = tf.Session()
    set_session(sess)
    model = myModel(weights="imagenet")
    graph = tf.get_default_graph()
    
    class Server(threading.Thread):
        def __init__(self):
            self._stop = threading.Event()
            threading.Thread.__init__(self)
    
        def stop(self):
            self._stop.set()
    
        def stopped(self):
            return self._stop.isSet()
    
        def run(self):
            context = zmq.Context()
            frontend = context.socket(zmq.ROUTER)
            frontend.bind('tcp://*:5576')
    
            backend = context.socket(zmq.DEALER)
            backend.bind('inproc://backend_endpoint')
    
            poll = zmq.Poller()
            poll.register(frontend, zmq.POLLIN)
            poll.register(backend,  zmq.POLLIN)
    
            while not self.stopped():
                sockets = dict(poll.poll())
                if frontend in sockets:
                    if sockets[frontend] == zmq.POLLIN:
                        _id = frontend.recv()
                        json_msg = frontend.recv_json()
    
                        handler = RequestHandler(context, _id, json_msg)
                        handler.start()
    
                if backend in sockets:
                    if sockets[backend] == zmq.POLLIN:
                        _id = backend.recv()
                        msg = backend.recv()
                        frontend.send(_id, zmq.SNDMORE)
                        frontend.send(msg)
    
            frontend.close()
            backend.close()
            context.term()
    
    
    class RequestHandler(threading.Thread):
        def __init__(self, context, id, msg):
    
            """
            RequestHandler
            :param context: ZeroMQ context
            :param id: The identity frame to include in the reply so that it is routed back to the correct client
            :param msg: Message payload for the worker to process
            """
            threading.Thread.__init__(self)
            print("--------------------Entered requesthandler--------------------")
            self.context = context
            self.msg = msg
            self._id = id
    
    
        def process(self, obj):
    
            imgstr = obj['payload']
    
            img = Image.open(BytesIO(b64decode(imgstr)))
    
            if img.mode != "RGB":
                img = img.convert("RGB")
            # resize the input image and preprocess it
            img = img.resize((224,224))
            img = image.img_to_array(img)
            img = np.expand_dims(img, axis=0)
            img = preprocess_input(img)
    
            with graph.as_default():
                set_session(sess)
                predictions = model.predict(img)
    
            predictions = decode_predictions(predictions, top=3)[0]
            print("Predictions from class_model_server.py:",predictions)
    
            pred_strings = []
            for _,pred_class,pred_prob in predictions:
                pred_strings.append(str(pred_class).strip()+" : "+str(round(pred_prob,5)).strip())
            preds = ", ".join(pred_strings)
    
            return_dict = {}
            return_dict["preds"] = preds
            return return_dict
    
        def run(self):
            # Worker will process the task and then send the reply back to the DEALER backend socket via inproc
            worker = self.context.socket(zmq.DEALER)
            worker.connect('inproc://backend_endpoint')
            print('Request handler started to process %s\n' % self.msg)
    
            # Run the model on the incoming request and collect the predictions
            output = self.process(self.msg)
    
            worker.send(self._id, zmq.SNDMORE)
            worker.send_json(output)
            del self.msg
    
            print('Request handler quitting.\n')
            worker.close()
    
    def main():
        # Start the server that will handle incoming requests
        server = Server()
        server.start()
    
    if __name__ == '__main__':
        main()
    
  • Press the Esc key, type :wq, and press Enter.

Now let us understand the code:

  • Here we are importing the necessary packages.

    from io import BytesIO
    from PIL import Image
    import threading
    import zmq
    from base64 import b64decode
    import numpy as np
    
    import tensorflow as tf
    from tensorflow.keras.preprocessing import image
    from tensorflow.keras.applications.resnet50 import ResNet50 as myModel
    from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
    from tensorflow.python.keras.backend import set_session
    
  • Then we set a TensorFlow session and instantiate the ResNet model myModel that we imported above.

    sess = tf.Session()
    set_session(sess)
    model = myModel(weights="imagenet")
    graph = tf.get_default_graph()
    

    In the above code snippet, by writing

    sess = tf.Session()
    set_session(sess)
    

    we create and set a TensorFlow session. We do this because the Server (which we will define shortly) uses threads: threads give us asynchronicity, since they are designed for multi-tasking. The Server thread keeps listening on a port and directs requests to the RequestHandler, and the RequestHandler handles each request and returns the results. We set a session so that every thread works with the same model graph, since we always want to use the same pre-trained model. Whenever we want to use the graph in any thread, we need to write

    with graph.as_default():
        set_session(sess)
    

    which we will be writing in the RequestHandler, while predicting the class of the image.

    Next, by writing model = myModel(weights="imagenet") we instantiate the pre-trained ResNet50 model (which we imported above as myModel) with the ImageNet weights.

    And with graph = tf.get_default_graph() we get the program's default graph.
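
    Note that tf.Session() and tf.get_default_graph() belong to the TensorFlow 1.x API. If your environment runs TensorFlow 2.x, a rough equivalent (a hedged sketch only, not tested against this exact project) is to go through the tf.compat.v1 module:

    import tensorflow as tf

    # TensorFlow 2.x fallback (assumption: the rest of the server code stays unchanged);
    # the 1.x-style session/graph API lives under tf.compat.v1.
    tf.compat.v1.disable_eager_execution()
    sess = tf.compat.v1.Session()
    tf.compat.v1.keras.backend.set_session(sess)
    graph = tf.compat.v1.get_default_graph()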

  • Finally, main() is called and the Server is started using the following code:

    def main():
        # Start the server that will handle incoming requests
        server = Server()
        server.start()
    
    if __name__ == '__main__':
        main()
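
    The server only listens; a client has to connect to port 5576 and send a JSON message whose payload key holds a base64-encoded image. In this project the Flask app is expected to play that role, but the following standalone sketch (a hypothetical test_client.py with a placeholder image path, not part of the project code) shows the expected message format and can be used to smoke-test the model server once it is running:

    # test_client.py -- hypothetical smoke-test client, not part of the project code
    import base64
    import zmq

    # Read an image file and base64-encode it (replace cat.jpg with a real image path).
    with open("cat.jpg", "rb") as f:
        payload = base64.b64encode(f.read()).decode("utf-8")

    context = zmq.Context()
    socket = context.socket(zmq.DEALER)   # a DEALER pairs with the server's ROUTER frontend
    socket.connect("tcp://localhost:5576")

    # The server expects a JSON object with the encoded image under the key 'payload'.
    socket.send_json({"payload": payload})

    # The reply is the JSON produced by RequestHandler.process(), e.g. {"preds": "..."}.
    reply = socket.recv_json()
    print(reply["preds"])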
    
  • Now it's time to understand the Server class:

    class Server(threading.Thread):
        def __init__(self):
            self._stop = threading.Event()
            threading.Thread.__init__(self)
    
        def stop(self):
            self._stop.set()
    
        def stopped(self):
            return self._stop.isSet()
    
        def run(self):
            context = zmq.Context()
            frontend = context.socket(zmq.ROUTER)
            frontend.bind('tcp://*:5576')
    
            backend = context.socket(zmq.DEALER)
            backend.bind('inproc://backend_endpoint')
    
            poll = zmq.Poller()
            poll.register(frontend, zmq.POLLIN)
            poll.register(backend,  zmq.POLLIN)
    
            while not self.stopped():
                sockets = dict(poll.poll())
                if frontend in sockets:
                    if sockets[frontend] == zmq.POLLIN:
                        _id = frontend.recv()
                        json_msg = frontend.recv_json()
    
                        handler = RequestHandler(context, _id, json_msg)
    
                        handler.start()
    
                if backend in sockets:
                    if sockets[backend] == zmq.POLLIN:
                        _id = backend.recv()
                        msg = backend.recv()
                        frontend.send(_id, zmq.SNDMORE)
                        frontend.send(msg)
    
            frontend.close()
            backend.close()
            context.term()
    

    The Server class inherits from threading.Thread so that we can use and adapt the thread's functionality in our Server as required.

    The run() method of a Thread represents the thread's activity, so this is where we define what our Server does.

    As discussed earlier, the task of the Server is to listen on port 5576 and direct the request data to the RequestHandler class. This is done as follows:

    (1) context = zmq.Context(): we create the ZMQ context.

    (2) frontend = context.socket(zmq.ROUTER): we create a ZMQ socket of type zmq.ROUTER.

    (3) frontend.bind('tcp://*:5576'): A server binds to a port and communicates through it; here, our Server communicates through port 5576. ZMQ gives us a frontend and a backend. The frontend is the externally exposed socket, which we bind to port 5576; it is the zmq.ROUTER that receives the requests carrying the encoded image data. The backend is in-process: ZMQ distributes the requests to parallel request handlers through it (here the backend is a zmq.DEALER). If the frontend is a zmq.ROUTER, the backend should be a zmq.DEALER.

    (4) Similarly, using

    backend = context.socket(zmq.DEALER)
    backend.bind('inproc://backend_endpoint')
    

    we create the ZMQ backend socket of type zmq.DEALER (since the frontend is a zmq.ROUTER) and bind it to the in-process endpoint inproc://backend_endpoint.

    (5) Next,

    poll = zmq.Poller()
    poll.register(frontend, zmq.POLLIN)
    poll.register(backend,  zmq.POLLIN)
    

    We create a poller using poll = zmq.Poller() and register the frontend and the backend with it. We poll to check whether there are any messages (requests) waiting to be read. zmq.POLLIN means there is data to be read.

    (6) Next, we enter the loop while not self.stopped():, which runs continuously until the server is stopped (a hedged shutdown sketch follows this list). Inside the loop,

    • we get the ready sockets (frontend and backend) from the poller and check whether the frontend has data to read using if sockets[frontend] == zmq.POLLIN:.
    • If the frontend has data to read, we first receive the identity frame with _id = frontend.recv(), and then the JSON object of the request data with json_msg = frontend.recv_json(). The JSON object contains the encoded image under the key payload; so by the term 'payload', we mean the encoded image received by the ROUTER from the client.
    • Then we instantiate the RequestHandler with handler = RequestHandler(context, _id, json_msg) and start it with handler.start(). Observe that we pass context, _id, and json_msg as parameters to initialize the RequestHandler. The RequestHandler sees to it that the predictions are sent back to the client in JSON form.
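
    The project's main() (shown above) only starts the Server and never calls stop(). If you want to stop the server cleanly, a hedged sketch of how stop() could be wired up is shown below. Note that poll.poll() with no timeout blocks until a message arrives, so a real shutdown would also pass a timeout (for example, poll.poll(1000)) so that the loop re-checks the stop flag periodically:

    # Hypothetical usage, not part of the project code: stopping the Server cleanly.
    import time

    server = Server()
    server.start()
    try:
        while server.is_alive():
            time.sleep(1.0)      # keep the main thread responsive to Ctrl+C
    except KeyboardInterrupt:
        server.stop()            # sets the threading.Event, so run() can exit its loop
        server.join()            # ...once poll.poll() returns (see the timeout note above)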
  • Now that we have created an instance of the RequestHandler thread, let us understand what happens inside it.

    class RequestHandler(threading.Thread):
        def __init__(self, context, id, msg):
    
            """
            RequestHandler
            :param context: ZeroMQ context
            :param id: The identity frame to include in the reply so that it is routed back to the correct client
            :param msg: Message payload for the worker to process
            """
            threading.Thread.__init__(self)
            print("--------------------Entered requesthandler--------------------")
            self.context = context
            self.msg = msg
            self._id = id
    
    
        def process(self, obj):
    
            imgstr = obj['payload']        
            img = Image.open(BytesIO(b64decode(imgstr)))
    
            if img.mode != "RGB":
                img = img.convert("RGB")
            # resize the input image and preprocess it
            img = img.resize((224,224))
            img = image.img_to_array(img)
            img = np.expand_dims(img, axis=0)
            img = preprocess_input(img)
    
            with graph.as_default():
                set_session(sess)
                predictions = model.predict(img)
    
            predictions = decode_predictions(predictions, top=3)[0]
            print("Predictions from class_model_server.py:",predictions)
    
            pred_strings = []
            for _,pred_class,pred_prob in predictions:
                pred_strings.append(str(pred_class).strip()+" : "+str(round(pred_prob,5)).strip())
            preds = ", ".join(pred_strings)
    
            return_dict = {}
            return_dict["preds"] = preds
            return return_dict
    
        def run(self):
            # Worker will process the task and then send the reply back to the DEALER backend socket via inproc
            worker = self.context.socket(zmq.DEALER)
            worker.connect('inproc://backend_endpoint')
            print('Request handler started to process %s\n' % self.msg)
    
            # Run the model on the incoming request and collect the predictions
            output = self.process(self.msg)
    
            worker.send(self._id, zmq.SNDMORE)
            worker.send_json(output)
            del self.msg
    
            print('Request handler quitting.\n')
            worker.close()
    

    (1) As in the Server, we initialize the thread and store the context, msg, and _id as follows:

            threading.Thread.__init__(self)
            print("--------------------Entered requesthandler--------------------")
            self.context = context
            self.msg = msg
            self._id = id
    

    (2) When the Server starts the RequestHandler with handler.start(), the run() method of the RequestHandler is internally invoked.

    (3) If you remember, we created a context with a frontend and a backend. The backend socket is of type zmq.DEALER and is bound to the address inproc://backend_endpoint. Now we create a worker socket, also of type zmq.DEALER, and connect it to the backend through that endpoint, 'inproc://backend_endpoint'. The following code does so.

            worker = self.context.socket(zmq.DEALER)
            worker.connect('inproc://backend_endpoint')
    

    (4) We then call the process() method and store the return value of the method in output using output = self.process(self.msg). Inside the process() method, we are doing the following:

    • We retrieve the 'payload', that is, the encoded image, using imgstr = obj['payload'].

    • We then decode the base64 string back into an ordinary image using img = Image.open(BytesIO(b64decode(imgstr))). Now img is a regular PIL image.

    • Then, the following converts img from any other mode to the standard 'RGB' mode.

      if img.mode != "RGB":
          img = img.convert("RGB")
      
    • Using the following code, we pre-process img so that it can be fed to the ResNet50 model we are going to use.

      img = img.resize((224,224))
      img = image.img_to_array(img)
      img = np.expand_dims(img, axis=0)
      img = preprocess_input(img)
      

      Do you recall that we have already discussed this in the previous project?

    • Now we set the default graph and session as follows, and get the predictions for img by feeding it to the model:

      with graph.as_default():
          set_session(sess)
          predictions = model.predict(img)
      

      As discussed earlier, we need to do this whenever we want to use the same graph across threads. Since RequestHandler is a subclass of Thread, we do it here.

    • Then we decode the predictions and format the results into a string of the top 3 predicted classes along with their probabilities, as already discussed in the previous project.

      pred_strings = []
      for _,pred_class,pred_prob in predictions:
          pred_strings.append(str(pred_class).strip()+" : "+str(round(pred_prob,5)).strip())
      preds = ", ".join(pred_strings)
      
    • Then we create a dictionary return_dict with the key 'preds' and the value preds (the results obtained above):

      return_dict = {}
      return_dict["preds"] = preds
      return return_dict
      
    • This returned dictionary is stored in output in the run() method of RequestHandler, since we call output = self.process(self.msg).

    (5) Going back to the run() method of the RequestHandler,

    worker.send(self._id, zmq.SNDMORE)
    worker.send_json(output)
    del self.msg
    
    print('Request handler quitting.\n')
    worker.close()
    
    • The worker sends the client id to the backend socket with the zmq.SNDMORE flag. The zmq.SNDMORE flag tells ZMQ that more frames of the same message are still to follow (an equivalent single send_multipart() call is sketched just after this list).

    • Next, the worker sends the output in JSON form to the backend.

    • Then self.msg is deleted since it is no longer needed.

    • Then, we close the worker.
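
    As an aside, the two send calls above together form one two-frame (multipart) ZeroMQ message. Purely as an illustration (not a change to the project code), the same reply could be written inside RequestHandler.run() as a single send_multipart() call, serializing the output dict with the standard json module:

    import json

    # Equivalent single call (illustration only): the identity frame and the JSON-encoded
    # body go out together as one multipart message.
    worker.send_multipart([self._id, json.dumps(output).encode("utf-8")])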

    (6) Going back to the Server:

        if backend in sockets:
            if sockets[backend] == zmq.POLLIN:
                _id = backend.recv()
                msg = backend.recv()
                frontend.send(_id, zmq.SNDMORE)
                frontend.send(msg)
    

    The backend waits for incoming data. Once there is data for it to read (which is checked by if sockets[backend] == zmq.POLLIN:),

    • It receives the client id and the message (the resulting predictions sent by the worker to the backend).

    • This message (the predictions) is sent back to the client by the frontend.

Note: If you face an 'Unable to open file' error while loading the model, refer to Input/Output Error (Error no. 5).
