As discussed previously, we shall import the model and define the Server class and the RequestHandler class in resnet_model_server.py.
Note:
vi is a text editor used to create and edit files.
If we want to create and/or edit the file named myFile, we use vi myFile.
To exit from the file, we press the esc key, type :wq and hit the enter key. In :wq, w means save (write), and q means quit.
A thread is used for multi-tasking. threading is a module in the Python standard library, and we can subclass its Thread class in our own custom classes. Here, both the Server and RequestHandler classes inherit from Thread. When we subclass Thread, we need to know about the following important methods of Thread:
start(): It is the method used to start the thread. It arranges for the thread's run() method to be invoked in a separate thread of control.
run(): This method represents the thread's functionality. So, we define the functionality of the thread by overriding this run() method. For example, since the task of the Server is to (1) listen to a port, (2) direct requests to RequestHandler, and (3) terminate if there is an interruption, we define all these functionalities in the run() method of the Server class, which inherits the threading.Thread class. Similarly, the functionality of handling the request to return predictions is defined in the run() method of the RequestHandler class.
We also have something called threading.Event().
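An event manages a flag that can be set to true with the set() method. Here we are instantiating an event by writing self._stop = threading.Event(). When the stop() method of the Server is called, it invokes self._stop.set() to set the internal flag to True. Thus, self._stop.isSet() returns True subsequently, and the loop in run() ends. Note that the code in this file never calls stop() by itself, and pressing Ctrl+Z in a terminal only suspends the process; it does not set this flag.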
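The start(), run() and Event pattern described above can be tried on its own, without ZMQ or TensorFlow. The following is only a minimal sketch: the Ticker class, the run_ticker() helper and the attribute name _stop_event are made up for illustration and are not part of resnet_model_server.py.

```python
import threading
import time


class Ticker(threading.Thread):
    """Hypothetical worker thread that counts until it is asked to stop."""

    def __init__(self):
        super().__init__()
        # The name _stop_event is an assumption of this sketch. It is chosen
        # so that it cannot clash with private names that some CPython
        # versions define on threading.Thread itself.
        self._stop_event = threading.Event()
        self.ticks = 0

    def stop(self):
        # Set the internal flag of the event to True.
        self._stop_event.set()

    def stopped(self):
        return self._stop_event.is_set()

    def run(self):
        # start() arranges for this method to run in a separate thread.
        while not self.stopped():
            self.ticks += 1
            time.sleep(0.01)


def run_ticker(duration=0.05):
    """Start a Ticker, let it run briefly, then stop it and wait for it."""
    ticker = Ticker()
    ticker.start()
    time.sleep(duration)
    ticker.stop()
    ticker.join()
    return ticker
```

Calling run_ticker() returns a thread whose stopped() method returns True and whose is_alive() method returns False, because the loop in run() exits once the flag is set.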
This file, resnet_model_server.py, should be created inside the Flask-ZMQ-App-Folder/Model-Server-Folder directory, so make sure you are in that directory. You can check your present working directory using the command:
pwd
This command should output the path:
/home/$USER/Flask-ZMQ-App-Folder/Model-Server-Folder
If the path displayed is not the same as the above, switch to the Model-Server-Folder directory using
cd ~/Flask-ZMQ-App-Folder/Model-Server-Folder
Create the file named resnet_model_server.py using the vi command inside the Flask-ZMQ-App-Folder/Model-Server-Folder directory.
Press the i key on your keyboard to switch to insert mode in the file.
Copy-paste the following code:
from io import BytesIO
from PIL import Image
import threading
import zmq
from base64 import b64decode
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.resnet50 import ResNet50 as myModel
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
from tensorflow.python.keras.backend import set_session

sess = tf.Session()
set_session(sess)
model = myModel(weights="imagenet")
graph = tf.get_default_graph()


class Server(threading.Thread):
    def __init__(self):
        self._stop = threading.Event()
        threading.Thread.__init__(self)

    def stop(self):
        self._stop.set()

    def stopped(self):
        return self._stop.isSet()

    def run(self):
        context = zmq.Context()
        frontend = context.socket(zmq.ROUTER)
        frontend.bind('tcp://*:5576')
        backend = context.socket(zmq.DEALER)
        backend.bind('inproc://backend_endpoint')
        poll = zmq.Poller()
        poll.register(frontend, zmq.POLLIN)
        poll.register(backend, zmq.POLLIN)
        while not self.stopped():
            sockets = dict(poll.poll())
            if frontend in sockets:
                if sockets[frontend] == zmq.POLLIN:
                    _id = frontend.recv()
                    json_msg = frontend.recv_json()
                    handler = RequestHandler(context, _id, json_msg)
                    handler.start()
            if backend in sockets:
                if sockets[backend] == zmq.POLLIN:
                    _id = backend.recv()
                    msg = backend.recv()
                    frontend.send(_id, zmq.SNDMORE)
                    frontend.send(msg)
        frontend.close()
        backend.close()
        context.term()


class RequestHandler(threading.Thread):
    def __init__(self, context, id, msg):
        """
        RequestHandler
        :param context: ZeroMQ context
        :param id: Requires the identity frame to include in the reply so that it will be properly routed
        :param msg: Message payload for the worker to process
        """
        threading.Thread.__init__(self)
        print("--------------------Entered requesthandler--------------------")
        self.context = context
        self.msg = msg
        self._id = id

    def process(self, obj):
        imgstr = obj['payload']
        img = Image.open(BytesIO(b64decode(imgstr)))
        if img.mode != "RGB":
            img = img.convert("RGB")
        # resize the input image and preprocess it
        img = img.resize((224,224))
        img = image.img_to_array(img)
        img = np.expand_dims(img, axis=0)
        img = preprocess_input(img)
        with graph.as_default():
            set_session(sess)
            predictions = model.predict(img)
        predictions = decode_predictions(predictions, top=3)[0]
        print("Predictions from class_model_server.py:",predictions)
        pred_strings = []
        for _,pred_class,pred_prob in predictions:
            pred_strings.append(str(pred_class).strip()+" : "+str(round(pred_prob,5)).strip())
        preds = ", ".join(pred_strings)
        return_dict = {}
        return_dict["preds"] = preds
        return return_dict

    def run(self):
        # Worker will process the task and then send the reply back to the DEALER backend socket via inproc
        worker = self.context.socket(zmq.DEALER)
        worker.connect('inproc://backend_endpoint')
        print('Request handler started to process %s\n' % self.msg)
        # Simulate a long-running operation
        output = self.process(self.msg)
        worker.send(self._id, zmq.SNDMORE)
        worker.send_json(output)
        del self.msg
        print('Request handler quitting.\n')
        worker.close()


def main():
    # Start the server that will handle incoming requests
    server = Server()
    server.start()


if __name__ == '__main__':
    main()
Press the esc key, then type :wq, and then press Enter.
Now let us understand the code:
Here we are importing the necessary packages.
from io import BytesIO
from PIL import Image
import threading
import zmq
from base64 import b64decode
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.resnet50 import ResNet50 as myModel
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
from tensorflow.python.keras.backend import set_session
Then we are setting a TensorFlow session and instantiating the ResNet model myModel that we have imported above. Note that tf.Session and tf.get_default_graph belong to the TensorFlow 1.x API.
sess = tf.Session()
set_session(sess)
model = myModel(weights="imagenet")
graph = tf.get_default_graph()
In the above code snippet, by writing
sess = tf.Session()
set_session(sess)
we are creating and setting a TensorFlow session. We are doing so because the Server (which we will be defining shortly) uses threads, and threads introduce asynchronicity, since they are designed to support multi-tasking. The Server thread keeps listening to a port and directs the requests to the RequestHandler, and the RequestHandler handles the requests and returns the results.
We set a session to make sure each thread uses the same model graph, since we always want to use the same pre-trained model. And whenever we want to use the graph in any thread, we need to write
with graph.as_default():
    set_session(sess)
which we will be writing in the RequestHandler, while predicting the class of the image.
Next, by writing model = myModel(weights="imagenet") we are instantiating the pre-trained resnet50 model (which we previously imported as myModel) with the imagenet weights.
And by using graph = tf.get_default_graph() we are getting the default graph of the program.
The main() method will be called and the Server will be started by using the following code:
def main():
    # Start the server that will handle incoming requests
    server = Server()
    server.start()

if __name__ == '__main__':
    main()
Now it's time for us to understand the Server class:
class Server(threading.Thread):
    def __init__(self):
        self._stop = threading.Event()
        threading.Thread.__init__(self)

    def stop(self):
        self._stop.set()

    def stopped(self):
        return self._stop.isSet()

    def run(self):
        context = zmq.Context()
        frontend = context.socket(zmq.ROUTER)
        frontend.bind('tcp://*:5576')
        backend = context.socket(zmq.DEALER)
        backend.bind('inproc://backend_endpoint')
        poll = zmq.Poller()
        poll.register(frontend, zmq.POLLIN)
        poll.register(backend, zmq.POLLIN)
        while not self.stopped():
            sockets = dict(poll.poll())
            if frontend in sockets:
                if sockets[frontend] == zmq.POLLIN:
                    _id = frontend.recv()
                    json_msg = frontend.recv_json()
                    handler = RequestHandler(context, _id, json_msg)
                    handler.start()
            if backend in sockets:
                if sockets[backend] == zmq.POLLIN:
                    _id = backend.recv()
                    msg = backend.recv()
                    frontend.send(_id, zmq.SNDMORE)
                    frontend.send(msg)
        frontend.close()
        backend.close()
        context.term()
We are inheriting threading.Thread into the Server class, so that we can use and modify the functionality of threads in our Server as per our requirement.
The run() method of threading.Thread represents the thread's activity. So we define the activity to be handled by our Server here.
As discussed earlier, the task of the Server is to listen on port 5576 and direct the data in the request to the RequestHandler class. This is done as follows:
(1) context = zmq.Context(): we are creating the zmq context.
(2) frontend = context.socket(zmq.ROUTER): we are creating the zmq socket of type zmq.ROUTER.
(3) frontend.bind('tcp://*:5576'): Any server binds to a port and communicates through it. Similarly, we are creating our own server, Server, which communicates through port 5576. Our Server uses two ZMQ sockets, a frontend and a backend. The frontend is exposed to the outside: we bind it to TCP port 5576, and clients connect to it. The backend is in-process (inproc), meaning it carries messages between threads of the same process; here it lets the Server exchange messages with the RequestHandler threads that work in parallel. A zmq.ROUTER frontend is commonly paired with a zmq.DEALER backend, and that is what we do here. So here, we are binding the frontend to port 5576, and the zmq.ROUTER socket gets the requests with the encoded image data.
(4) Similarly, using
backend = context.socket(zmq.DEALER)
backend.bind('inproc://backend_endpoint')
we are creating the ZMQ backend of type zmq.DEALER (since the frontend is zmq.ROUTER).
(5) Next,
poll = zmq.Poller()
poll.register(frontend, zmq.POLLIN)
poll.register(backend, zmq.POLLIN)
We create a poller using poll = zmq.Poller(), and we register the frontend and the backend with the poll. We are polling to check if there are any messages (requests or replies) waiting to be read. zmq.POLLIN indicates that a socket has at least one message ready to be received.
(6) Next, we enter a loop, while not self.stopped():, that runs continuously till the server is stopped. Upon entering the loop, we do the following:
We call sockets = dict(poll.poll()) to wait for activity and collect the sockets that have something to read.
We check whether the frontend has an incoming request using if sockets[frontend] == zmq.POLLIN:.
We use _id = frontend.recv() to receive the first frame of the message, which identifies the client. Also, we get the JSON object of the request data using json_msg = frontend.recv_json(). The JSON object contains the encoded image with the key payload. So remember, by the term 'payload', we refer to the encoded image received by the ROUTER from the client.
We create an instance of RequestHandler using handler = RequestHandler(context, _id, json_msg) and start it using handler.start(). Observe that we are sending context, _id and json_msg as input parameters to initialize the class RequestHandler. This RequestHandler sees that the predictions are sent to the client in JSON form.
Now that we have created an instance of the RequestHandler thread, let us understand what is happening with the same.
class RequestHandler(threading.Thread):
    def __init__(self, context, id, msg):
        """
        RequestHandler
        :param context: ZeroMQ context
        :param id: Requires the identity frame to include in the reply so that it will be properly routed
        :param msg: Message payload for the worker to process
        """
        threading.Thread.__init__(self)
        print("--------------------Entered requesthandler--------------------")
        self.context = context
        self.msg = msg
        self._id = id

    def process(self, obj):
        imgstr = obj['payload']
        img = Image.open(BytesIO(b64decode(imgstr)))
        if img.mode != "RGB":
            img = img.convert("RGB")
        # resize the input image and preprocess it
        img = img.resize((224,224))
        img = image.img_to_array(img)
        img = np.expand_dims(img, axis=0)
        img = preprocess_input(img)
        with graph.as_default():
            set_session(sess)
            predictions = model.predict(img)
        predictions = decode_predictions(predictions, top=3)[0]
        print("Predictions from class_model_server.py:",predictions)
        pred_strings = []
        for _,pred_class,pred_prob in predictions:
            pred_strings.append(str(pred_class).strip()+" : "+str(round(pred_prob,5)).strip())
        preds = ", ".join(pred_strings)
        return_dict = {}
        return_dict["preds"] = preds
        return return_dict

    def run(self):
        # Worker will process the task and then send the reply back to the DEALER backend socket via inproc
        worker = self.context.socket(zmq.DEALER)
        worker.connect('inproc://backend_endpoint')
        print('Request handler started to process %s\n' % self.msg)
        # Simulate a long-running operation
        output = self.process(self.msg)
        worker.send(self._id, zmq.SNDMORE)
        worker.send_json(output)
        del self.msg
        print('Request handler quitting.\n')
        worker.close()
(1) Similarly, we initialize the thread, and store the context, msg and _id as follows:
threading.Thread.__init__(self)
print("--------------------Entered requesthandler--------------------")
self.context = context
self.msg = msg
self._id = id
(2) Upon starting the RequestHandler from the Server using handler.start(), the run() method of the RequestHandler will be internally invoked.
(3) If you remember, the Server created a context with a frontend socket and a backend socket. The backend socket is of type zmq.DEALER and is bound to the address inproc://backend_endpoint. Now we create a worker socket of type zmq.DEALER, and connect it with the backend socket through its endpoint 'inproc://backend_endpoint'. The following code enables us to do so.
worker = self.context.socket(zmq.DEALER)
worker.connect('inproc://backend_endpoint')
(4) We then call the process() method and store the return value of the method in output using output = self.process(self.msg). Inside the process() method, we are doing the following:
We receive the 'payload', that is, the encoded image, using the code imgstr = obj['payload'].
We then convert the encoded image back from base64 into the usual image format using img = Image.open(BytesIO(b64decode(imgstr))). Now img is in the usual image form.
Then, the following converts img from any other mode to the standard 'RGB' mode.
if img.mode != "RGB":
    img = img.convert("RGB")
Using the following code, we pre-process the image img to make it compatible with the resnet50 model we are going to use.
img = img.resize((224,224))
img = image.img_to_array(img)
img = np.expand_dims(img, axis=0)
img = preprocess_input(img)
Do you recall that we have already discussed this in the previous project?
Now we are going to set the default graph and session, and get the predictions for img by feeding it to the model, as follows:
with graph.as_default():
    set_session(sess)
    predictions = model.predict(img)
If you remember, we have already discussed that we need to do as above whenever we want to use the same graph from within threads. Since RequestHandler is a subclass of Thread, we are doing this.
Then, we are decoding the predictions and formatting the results in a nicer way: we make a string of the top 3 predictions along with the probabilities assigned by the network, as already discussed in the previous project.
pred_strings = []
for _,pred_class,pred_prob in predictions:
    pred_strings.append(str(pred_class).strip()+" : "+str(round(pred_prob,5)).strip())
preds = ", ".join(pred_strings)
Then, we create a dictionary return_dict with the key 'preds' and the value preds (the results obtained above):
return_dict = {}
return_dict["preds"] = preds
return return_dict
Now this returned result will be stored in output in the run() method of RequestHandler, since we are doing output = self.process(self.msg).
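The payload decoding and result formatting steps of process() can be tried without TensorFlow or an actual image. The following is only a sketch under stated assumptions: the helper names decode_payload and format_predictions are made up, the bytes are not a real image, and the prediction tuples are invented stand-ins for what decode_predictions returns.

```python
from base64 import b64decode, b64encode
from io import BytesIO


def decode_payload(obj):
    """Return a file-like object holding the raw bytes of obj['payload']."""
    return BytesIO(b64decode(obj["payload"]))


def format_predictions(predictions):
    """Build the 'preds' dictionary from (class_id, class_name, probability) tuples."""
    pred_strings = []
    for _, pred_class, pred_prob in predictions:
        pred_strings.append(
            str(pred_class).strip() + " : " + str(round(pred_prob, 5)).strip()
        )
    return {"preds": ", ".join(pred_strings)}


# Made-up stand-ins for a real image and for real model output.
raw_bytes = b"not really an image"
request = {"payload": b64encode(raw_bytes).decode("ascii")}
fake_predictions = [
    ("n0000001", "tabby", 0.61234567),
    ("n0000002", "tiger_cat", 0.25),
    ("n0000003", "lynx", 0.1),
]

restored = decode_payload(request).read()
result = format_predictions(fake_predictions)
print(restored == raw_bytes)  # True
print(result)  # {'preds': 'tabby : 0.61235, tiger_cat : 0.25, lynx : 0.1'}
```

In the real process() method, the BytesIO object is passed to Image.open() instead of being read directly, and the tuples come from decode_predictions(predictions, top=3)[0].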
(5) Going back to the run() method of the RequestHandler,
worker.send(self._id, zmq.SNDMORE)
worker.send_json(output)
del self.msg
print('Request handler quitting.\n')
worker.close()
The worker sends the client id to the backend socket with the zmq.SNDMORE flag. The zmq.SNDMORE flag marks this frame as part of a multipart message, telling the backend that more frames are still to follow.
Next, the output is sent in JSON form from the worker to the backend as the final frame.
Then self.msg is deleted since it is no longer needed.
Then, we close the worker.
(6) Going back to the Server:
if backend in sockets:
    if sockets[backend] == zmq.POLLIN:
        _id = backend.recv()
        msg = backend.recv()
        frontend.send(_id, zmq.SNDMORE)
        frontend.send(msg)
The backend keeps waiting to read any incoming data. Once there is some data for it to read (which is checked by if sockets[backend] == zmq.POLLIN:), it receives the client id and the message (the resultant predictions which are sent by the worker to the backend). This message (predictions) is then sent to the client by the frontend.
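To see the client id frame, the poller and the zmq.SNDMORE reply in isolation, here is a minimal sketch that runs a ROUTER socket and a DEALER client in the same process. It assumes pyzmq is installed; the function name, the endpoint inproc://demo_frontend and the identity client-1 are made up for illustration, and the reply simply echoes the payload instead of running a model.

```python
import zmq


def router_dealer_roundtrip():
    """Send one JSON request from a DEALER client to a ROUTER and reply to it."""
    context = zmq.Context()

    # ROUTER socket, playing the role of the frontend.
    # The inproc endpoint name is an assumption of this sketch.
    router = context.socket(zmq.ROUTER)
    router.bind("inproc://demo_frontend")

    # DEALER socket, playing the role of a client with a fixed identity.
    client = context.socket(zmq.DEALER)
    client.setsockopt(zmq.IDENTITY, b"client-1")
    client.connect("inproc://demo_frontend")

    # The client sends a single JSON frame.
    client.send_json({"payload": "hello"})

    # Wait (up to 1000 ms) until the ROUTER has something to read.
    poller = zmq.Poller()
    poller.register(router, zmq.POLLIN)
    ready = dict(poller.poll(timeout=1000))
    assert ready.get(router) == zmq.POLLIN

    # The ROUTER sees two frames: the client id, then the JSON body.
    identity = router.recv()
    request = router.recv_json()

    # Reply with the id frame first (SNDMORE), then the JSON body.
    router.send(identity, zmq.SNDMORE)
    router.send_json({"echo": request["payload"]})

    # The DEALER client receives only the JSON body.
    reply = client.recv_json()

    client.close()
    router.close()
    context.term()
    return identity, request, reply
```

Here the ROUTER prepends the id frame on receipt and strips it again when sending, which is why the Server has to keep _id next to each request and put it back in front of the reply.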
Note: If you face an Unable to open file error while loading the model, refer to Input/Output Error (Error no. 5).