Project - How to Build a Low-Latency Deep-Learning-Based Flask App


Testing with dummy test_client.py

Now that the server-side code is ready, let us test it with a dummy client, test_client.py, by:

(1) establishing a connection between the server and the client

(2) sending a test image from client to server

(3) getting the results from the server.

We will also check whether performance has improved now that we have introduced zmq.

If this succeeds, we will proceed to make the necessary changes to app.py in the Flask server to make it faster.

INSTRUCTIONS
  • Make sure you are in the ~/Flask-ZMQ-App-Folder/Model-Server-Folder/ directory:

    cd ~/Flask-ZMQ-App-Folder/Model-Server-Folder/
    
  • Create a new file called test_client.py using vi:

    vi test_client.py
    
  • Then, press the i key to enter insert mode and paste the following code:

    import base64
    import uuid
    import zmq
    
    def test_zmq_embdserver(image_file_name):
        # Unique id that identifies this client to the server's ROUTER socket
        _rid = "{}".format(str(uuid.uuid4()))
    
        global img_str
    
        # Read the image as raw bytes and base64-encode it
        with open(image_file_name, "rb") as image_file:
            img_str = base64.b64encode(image_file.read())
    
        # Create a DEALER socket carrying our id and connect it to the model server
        context = zmq.Context()
        socket = context.socket(zmq.DEALER)
        # pyzmq expects the identity as bytes on Python 3
        socket.setsockopt(zmq.IDENTITY, _rid.encode('utf-8'))
        socket.connect('tcp://localhost:5576')
        print('Client %s started\n' % _rid)
        poll = zmq.Poller()
        poll.register(socket, zmq.POLLIN)
    
        # b64encode returns bytes; decode to str so the message can be serialized as JSON
        socket.send_json({"payload": img_str.decode('ascii'), "_rid": _rid})
    
        # Poll (1000 ms at a time) until the server's reply arrives
        received_reply = False
        while not received_reply:
            sockets = dict(poll.poll(1000))
            if socket in sockets:
                if sockets[socket] == zmq.POLLIN:
                    msg = socket.recv_json()
                    preds = msg['preds']
                    print(preds)
                    del msg
                    received_reply = True
    
        socket.close()
        context.term()
    
    if __name__ == "__main__":
        name = '/cxldata/projects/image-class/dog.png'
        test_zmq_embdserver(name)
    
  • Press esc, then :wq and hit Enter.

  • Run resnet_model_server.py to start the model server:

    python resnet_model_server.py
    
  • Now, in another console, run the client with the following commands to measure its execution time:

    cd ~/Flask-ZMQ-App-Folder/Model-Server-Folder/
    
    source model-env/bin/activate
    
    time python test_client.py
    

    Run time python test_client.py several times and note the execution time of each run.

    You should observe that the execution time drops drastically, to around 0.3 seconds, compared with around 10-12 seconds previously. That is a speedup of roughly 30-40 times (10 / 0.3 ≈ 33 and 12 / 0.3 = 40).
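
The shell's time command measures the whole client process, including Python start-up and the import of zmq. If you would rather time only the request itself over several runs and compute the speedup, here is a minimal sketch. The helper names time_calls and speedup are hypothetical and not part of the project; in practice you might pass something like lambda: test_zmq_embdserver(name) as the callable instead of the stand-in workload.

```python
import statistics
import time
from typing import Callable, List


def time_calls(fn: Callable[[], object], runs: int = 5) -> List[float]:
    """Hypothetical helper: call fn() `runs` times, return each duration in seconds."""
    durations = []
    for _ in range(runs):
        start = time.perf_counter()          # monotonic, high-resolution clock
        fn()
        durations.append(time.perf_counter() - start)
    return durations


def speedup(before_s: float, after_s: float) -> float:
    """Hypothetical helper: how many times faster `after_s` is than `before_s`."""
    return before_s / after_s


if __name__ == "__main__":
    # Stand-in workload; replace with a call to test_zmq_embdserver(...)
    timings = time_calls(lambda: sum(range(100_000)), runs=5)
    print("median seconds:", statistics.median(timings))
    # Figures quoted in this lesson: about 10-12 s before zmq, about 0.3 s after
    print("speedup range: %.0fx to %.0fx" % (speedup(10, 0.3), speedup(12, 0.3)))
```

Taking the median of several runs reduces the influence of a single slow first request.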

Now let us understand the code:

(a) We import the following modules:

    import base64
    import uuid
    import zmq

We import base64 to encode the image. uuid is used to generate a unique id, which the client sends to the server along with the encoded image; the server uses this id to keep track of the client. Finally, we import zmq to connect the client socket to the server socket.
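
To see what the client actually puts on the wire, here is a small standard-library sketch of the id and payload handling. The function name build_request and the sample bytes are hypothetical. It also illustrates a Python 3 detail: base64.b64encode returns bytes, which json.dumps (and therefore send_json) cannot serialize, so the result is decoded to an ASCII str first; the id is encoded to bytes because pyzmq expects bytes for the zmq.IDENTITY socket option.

```python
import base64
import json
import uuid


def build_request(image_bytes: bytes):
    """Hypothetical helper: return (identity_bytes, json_text) for one request."""
    # A 36-character string such as '1b4e28ba-2fa1-11d2-883f-0016d3cca427'
    rid = str(uuid.uuid4())
    # b64encode takes bytes and returns bytes
    encoded = base64.b64encode(image_bytes)
    # json.dumps({"payload": encoded}) would raise TypeError on Python 3,
    # so the base64 bytes are decoded to an ASCII str first
    message = {"payload": encoded.decode("ascii"), "_rid": rid}
    return rid.encode("utf-8"), json.dumps(message)


if __name__ == "__main__":
    identity, text = build_request(b"\x89PNG fake image bytes")
    decoded = json.loads(text)
    # The receiving side can recover the original bytes with b64decode
    print(identity, base64.b64decode(decoded["payload"]))
```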

(b) We provide the path of the image we want to test. Here we send dog.png from the /cxldata/projects/image-class folder. Then, we call the function test_zmq_embdserver, which connects to the server, sends the data, and receives the results.

    if __name__ == "__main__":
        name = '/cxldata/projects/image-class/dog.png'
        test_zmq_embdserver(name)

(c) We define the test_zmq_embdserver function:

    def test_zmq_embdserver(image_file_name):
        _rid = "{}".format(str(uuid.uuid4()))

        global img_str

        with open(image_file_name, "rb") as image_file:
            img_str = base64.b64encode(image_file.read())

        context = zmq.Context()
        socket = context.socket(zmq.DEALER)
        # pyzmq expects the identity as bytes on Python 3
        socket.setsockopt(zmq.IDENTITY, _rid.encode('utf-8'))
        socket.connect('tcp://localhost:5576')
        print('Client %s started\n' % _rid)
        poll = zmq.Poller()
        poll.register(socket, zmq.POLLIN)

        # b64encode returns bytes; decode to str so the message can be serialized as JSON
        socket.send_json({"payload": img_str.decode('ascii'), "_rid": _rid})

        received_reply = False
        while not received_reply:
            sockets = dict(poll.poll(1000))
            if socket in sockets:
                if sockets[socket] == zmq.POLLIN:
                    msg = socket.recv_json()
                    preds = msg['preds']
                    print(preds)
                    del msg
                    received_reply = True

        socket.close()
        context.term()
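
The server side of this exchange is not part of this step. As a rough, self-contained illustration of the ROUTER/DEALER pattern the client relies on, the sketch below runs a stand-in ROUTER and a DEALER in one process over zmq's inproc transport. The endpoint name, the function name router_dealer_roundtrip, and the dummy "preds" reply are hypothetical; the real resnet_model_server.py listens on TCP port 5576 and returns real predictions. The point it shows is that the ROUTER receives the DEALER's identity as the first frame and uses that frame to address the reply.

```python
import json

import zmq


def router_dealer_roundtrip(client_id: str, request: dict):
    """Hypothetical demo: send one JSON request from a DEALER to a ROUTER and
    return (identity seen by the ROUTER, reply received by the DEALER)."""
    context = zmq.Context()
    router = context.socket(zmq.ROUTER)
    router.bind("inproc://model-server-demo")   # stand-in for the server's TCP port

    dealer = context.socket(zmq.DEALER)
    dealer.setsockopt(zmq.IDENTITY, client_id.encode("utf-8"))  # set before connect
    dealer.connect("inproc://model-server-demo")

    dealer.send_json(request)

    # The ROUTER prepends the sender's identity frame to every incoming message
    identity, body = router.recv_multipart()
    incoming = json.loads(body)
    # Hypothetical reply shaped like the one the client reads (msg['preds'])
    reply = {"preds": ["dummy-prediction"], "_rid": incoming["_rid"]}
    # Sending [identity, payload] routes the payload back to that DEALER only
    router.send_multipart([identity, json.dumps(reply).encode("utf-8")])

    answer = dealer.recv_json()

    dealer.close(linger=0)
    router.close(linger=0)
    context.term()
    return identity, answer


if __name__ == "__main__":
    print(router_dealer_roundtrip("client-42", {"payload": "aGVsbG8=", "_rid": "client-42"}))
```

Because each client sets its own identity, many clients can share the same server socket and still receive only their own replies.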

  • We generate the unique id using _rid = "{}".format(str(uuid.uuid4())). Since str(uuid.uuid4()) is already a string, _rid = str(uuid.uuid4()) would be equivalent.

  • Then, we declare img_str as global and read the base64-encoded form of the image:

        global img_str
        with open(image_file_name, "rb") as image_file:
            img_str = base64.b64encode(image_file.read())
    

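    Strictly speaking, the global declaration is not needed for the rest of the function: in Python, a with block does not create a new scope, so img_str remains available after the file is closed. Declaring it global additionally makes the encoded image available at module level, outside test_zmq_embdserver.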

    We open the image file in "rb" (read-binary) mode because base64.b64encode operates on bytes. Its result, img_str, is also a bytes object.

  • Next, we establish the connection between the client and the server:

        context = zmq.Context()
        socket = context.socket(zmq.DEALER)
        socket.setsockopt(zmq.IDENTITY, _rid.encode('utf-8'))
        socket.connect('tcp://localhost:5576')
    

    We create a zmq context and, from it, a socket of type zmq.DEALER.

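    zmq.IDENTITY is a socket option; we set it to the unique id with socket.setsockopt(zmq.IDENTITY, ...). On Python 3, pyzmq expects the value of this option as bytes, so the str id has to be encoded first (for example with _rid.encode('utf-8')). As discussed earlier, the server's ROUTER socket uses this identity to keep track of which client sent each request, so that it can route the reply back to that client.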

  • Next, we connect the socket to port 5576 on localhost, the port the server listens on, using socket.connect('tcp://localhost:5576').

  • Then, we create a Poller object:

        poll = zmq.Poller()
        poll.register(socket, zmq.POLLIN)
    

    We register the socket with the poller for zmq.POLLIN events, so that we can check whether an incoming message is ready to be read.

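  • Then, we send the encoded image (as "payload") and the id (as "_rid") to the server as a JSON object using socket.send_json(...). Because JSON cannot hold raw bytes, on Python 3 the base64 bytes in img_str must be decoded to a str before they are sent.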

  • We initially set received_reply = False and then poll for incoming messages in a loop. On each iteration, poll.poll(1000) waits up to 1000 milliseconds for the socket to become readable. We expect the incoming message to be the server's response to our request.

        received_reply = False
        while not received_reply:
            sockets = dict(poll.poll(1000))
            if socket in sockets:
    
  • Once there is incoming data to read, we receive it as the JSON object msg and store the predictions in a variable called preds.

        if sockets[socket] == zmq.POLLIN:
            msg = socket.recv_json()
            preds = msg['preds']
    
  • Finally, we print the predictions, delete msg, and set received_reply = True. This ends the loop, after which we close the socket and terminate the context.
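
The loop above waits indefinitely if the server never answers, because poll.poll(1000) simply returns an empty result after each 1000 ms timeout and the loop polls again. Below is a minimal sketch of a bounded variant. The helper name wait_for_reply and the max_polls limit are hypothetical additions, not part of the project code, and the demo uses a PAIR socket pair over inproc only as a stand-in for the real client and server.

```python
import zmq


def wait_for_reply(socket, timeout_ms: int = 1000, max_polls: int = 5):
    """Hypothetical helper: poll `socket` for up to max_polls * timeout_ms ms.
    Return the decoded JSON reply, or None if nothing arrived in time."""
    poller = zmq.Poller()
    poller.register(socket, zmq.POLLIN)
    for _ in range(max_polls):
        events = dict(poller.poll(timeout_ms))   # empty dict when the timeout expires
        if events.get(socket, 0) & zmq.POLLIN:
            return socket.recv_json()
    return None


if __name__ == "__main__":
    ctx = zmq.Context()
    server_side = ctx.socket(zmq.PAIR)
    client_side = ctx.socket(zmq.PAIR)
    server_side.bind("inproc://poll-demo")
    client_side.connect("inproc://poll-demo")

    print(wait_for_reply(client_side, timeout_ms=50, max_polls=2))  # nothing sent yet
    server_side.send_json({"preds": ["dummy-prediction"]})
    print(wait_for_reply(client_side, timeout_ms=50, max_polls=2))

    client_side.close(linger=0)
    server_side.close(linger=0)
    ctx.term()
```

In test_zmq_embdserver, the while loop could then be replaced by msg = wait_for_reply(socket), reporting an error when it returns None instead of hanging.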

