Learn how to deploy your deep learning model as a REST API for your customers.
In this short article, you will learn how to deploy a deep learning model as a REST API with Flask RESTful.
OK, you've worked hard tuning your model for best performance. Congratulations! It is now time to get it out of your notebooks and bring it to the world.
If you've never done this before, you might have no clue how to proceed, especially if your background is in data science rather than in software development.
REST APIs are a very common way to provide any kind of service. You can use them in the backend of a website, or even give your customers direct access to them. Maybe you have already used REST APIs yourself to interact with services such as Google Vision or Amazon Textract.
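Today you will learn how to set up a Python environment, run a pre-trained Keras image classifier on your own images, and expose it as a REST API with Flask RESTful.
This might seem ambitious, but it will actually take us fewer than 50 lines of code.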
But beware! I'm leaving out all the operational aspects, such as security, containerized deployment, and the web server. These points will be addressed in a future post. In the meantime, don't use this code as-is in production.
Here is the GitHub repo with the code for this tutorial.
My cat Capuchon, proudly modelling for the model
First, let's install the tools we need:
As usual, we will use Anaconda. First install it, and then create an environment with the necessary tools:
conda create -n dlflask python=3.7 tensorflow flask pillow
We use Python 3.7 because, at the time of writing, more recent versions of Python seem to lead to conflicts between the dependencies of the flask and tensorflow packages.
Now activate the environment:
conda activate dlflask
Finally, we install Flask RESTful with pip, as it is not available in conda:
pip install flask-restful
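Before moving on, you may want to confirm that everything imports correctly from the new environment. Here is a minimal sketch (the helper name check_packages is just for illustration): it tries to import each package and reports its version, or None if the import fails. Note that the import names differ from some package names: flask-restful is imported as flask_restful, and pillow as PIL.

```python
import importlib


def check_packages(module_names):
    """Map each module name to its version string, or None if it can't be imported."""
    report = {}
    for name in module_names:
        try:
            module = importlib.import_module(name)
        except ImportError:
            report[name] = None
            continue
        # not every module defines __version__
        report[name] = getattr(module, '__version__', 'installed')
    return report


if __name__ == '__main__':
    # import names of the packages installed above
    for name, version in check_packages(
            ['tensorflow', 'flask', 'flask_restful', 'PIL']).items():
        print(f'{name}: {version if version else "MISSING"}')
```

Run it with the dlflask environment activated; any package reported as MISSING needs to be reinstalled.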
The first thing we need is a deep learning model to integrate into our REST API.
We don't want to waste any time on this today, so we are simply going to use a pre-trained model from Keras.
I went for ResNet50, a high-performance classification model trained on the ImageNet ILSVRC 2012 dataset, which has 1000 categories and about 1.2 million training images (the full ImageNet database holds more than 14 million images).
Create a Python module called predict_resnet50.py with this code:
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'  # so that it runs on a mac (set before importing TensorFlow)

import numpy as np
import tensorflow.keras.applications.resnet50 as resnet50
from tensorflow.keras.preprocessing import image


def predict(fname):
    """returns top 5 categories for an image.

    :param fname : path to the file
    """
    # ResNet50 is trained on color images with 224x224 pixels
    input_shape = (224, 224, 3)

    # load and resize image ----------------------
    img = image.load_img(fname, target_size=input_shape[:2])
    x = image.img_to_array(img)

    # preprocess image ---------------------------
    # make a batch
    x = np.expand_dims(x, axis=0)
    print(x.shape)
    # apply the preprocessing function of resnet50
    img_array = resnet50.preprocess_input(x)

    model = resnet50.ResNet50(weights='imagenet',
                              input_shape=input_shape)
    # predict on the preprocessed batch, not on the raw pixels
    preds = model.predict(img_array)
    return resnet50.decode_predictions(preds)


if __name__ == '__main__':
    import pprint
    import sys
    file_name = sys.argv[1]
    results = predict(file_name)
    pprint.pprint(results)
Make sure to read the comments to understand what the script is doing.
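ResNet50 was trained on images preprocessed in a specific way, which is why the script calls resnet50.preprocess_input instead of feeding raw pixels to the model. For this model, Keras uses its "caffe" mode: it flips the channels from RGB to BGR and subtracts the mean ImageNet value of each channel, without any scaling. The function below is only an illustrative NumPy re-implementation (caffe_preprocess is a hypothetical name, not a Keras function); in your own code, keep calling resnet50.preprocess_input.

```python
import numpy as np

# Per-channel ImageNet means in BGR order, as used by Keras' "caffe" preprocessing
IMAGENET_MEAN_BGR = np.array([103.939, 116.779, 123.68], dtype=np.float32)


def caffe_preprocess(batch_rgb):
    """Illustrative re-implementation of resnet50.preprocess_input.

    Expects a batch of shape (N, H, W, 3) with RGB pixel values in [0, 255].
    """
    batch = np.asarray(batch_rgb, dtype=np.float32)
    batch_bgr = batch[..., ::-1]          # RGB -> BGR (a view, not a copy)
    return batch_bgr - IMAGENET_MEAN_BGR  # new array: zero-center each channel


if __name__ == '__main__':
    pixel = np.array([[[[10.0, 20.0, 30.0]]]])  # one 1x1 RGB image
    print(caffe_preprocess(pixel))              # BGR values minus the channel means
```

Unlike this sketch, the Keras function may write its result over a floating-point input array, so always use the array it returns for the prediction.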
Before going further, you should check that the script works (I'm using a picture of my cat, but you can use any image you want):
python predict_resnet50.py capuchon.jpg
You should get something like:
[[('n02123159', 'tiger_cat', 0.58581424),
('n02124075', 'Egyptian_cat', 0.21068987),
('n02123045', 'tabby', 0.14554422),
('n03938244', 'pillow', 0.008319859),
('n02127052', 'lynx', 0.006789663)]]
The predictions look pretty good! This is indeed a tiger cat, and the next two categories, Egyptian cat and tabby, are also very similar-looking cats. Then comes "pillow", albeit with a much smaller probability. This is not surprising: the cat is on a pillow.
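One thing to keep in mind before wrapping this in an API: predict() builds ResNet50 and loads its ImageNet weights every time it is called, so every request to the API will pay that cost. A minimal sketch of one way around it, assuming you are free to restructure predict_resnet50.py: wrap the model construction in a function cached with functools.lru_cache, so that the model is built only once per process. load_model_once is a hypothetical helper, not part of Keras.

```python
import functools


def load_model_once(loader):
    """Wrap a zero-argument `loader` so it runs only on the first call.

    Hypothetical helper for this tutorial: `loader` is any callable that
    builds and returns a model.
    """
    @functools.lru_cache(maxsize=1)
    def get_model():
        return loader()

    return get_model


# Possible wiring in predict_resnet50.py (a sketch, not the tutorial's code):
#
#   get_model = load_model_once(
#       lambda: resnet50.ResNet50(weights='imagenet', input_shape=(224, 224, 3)))
#
# and inside predict(), replace the resnet50.ResNet50(...) call with:
#
#   model = get_model()
```

lru_cache does not hold a lock while the wrapped function runs, so two simultaneous first requests on Flask's threaded debug server could each build the model once. Building the model at import time instead avoids this, at the cost of a slower startup.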
We see that our script is working, so let's get started with Flask.
In my opinion, there are two notable web frameworks for Python:
One could build a REST API with plain Flask rather easily, but it's even faster with its Flask RESTful extension.
Let's start with a simple "Hello World" example.
Create a Python module called rest_api_hello.py with this code:
from flask import Flask
from flask_restful import Resource, Api

app = Flask(__name__)
app.logger.setLevel('INFO')
api = Api(app)


class Hello(Resource):

    def get(self):
        return {'hello': 'world'}


api.add_resource(Hello, '/hello')

if __name__ == '__main__':
    app.run(debug=True)
Then start this app on the Flask debug server:
python rest_api_hello.py
Now, you can send a request to the server with curl:
curl localhost:5000/hello
Which gives:
{
"hello": "world"
}
Alternatively, you can point your browser to http://localhost:5000/hello, and you will get the same thing.
See? That's quite easy.
If you want to understand what's going on in more detail, just spend an hour following the tutorials for Flask and Flask RESTful.
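For comparison, here is roughly the same /hello endpoint written with plain Flask. This is a sketch to show what Flask RESTful saves you: with plain Flask you register a view function with @app.route and serialize the response yourself (here with jsonify), whereas a Resource subclass maps each HTTP method to a method of the class and turns the returned dict into JSON for you.

```python
from flask import Flask, jsonify

app = Flask(__name__)


@app.route('/hello', methods=['GET'])
def hello():
    # jsonify serializes the dict and sets the application/json content type
    return jsonify({'hello': 'world'})


if __name__ == '__main__':
    app.run(debug=True)
```

You can also exercise either version without starting a server, through Flask's built-in test client: app.test_client().get('/hello').get_json() returns {'hello': 'world'}.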
But for now, let's plug the model into our API.
Let's create another Flask app to classify images.
To do this, create a Python module called rest_api_predict.py with this code:
import os
import tempfile

from flask import Flask
from flask_restful import Resource, Api, reqparse
from werkzeug.datastructures import FileStorage

from predict_resnet50 import predict

app = Flask(__name__)
app.logger.setLevel('INFO')
api = Api(app)

# the request must contain an uploaded file in a form field named "file"
parser = reqparse.RequestParser()
parser.add_argument('file',
                    type=FileStorage,
                    location='files',
                    required=True,
                    help='provide a file')


class Image(Resource):

    def post(self):
        args = parser.parse_args()
        the_file = args['file']
        # save a temporary copy of the file
        ofile, ofname = tempfile.mkstemp()
        os.close(ofile)  # only the path is needed; save() reopens the file
        try:
            the_file.save(ofname)
            # predict
            results = predict(ofname)[0]
        finally:
            os.remove(ofname)  # don't leave uploaded images on disk
        # formatting the results as a JSON-serializable structure
        # (the numpy float32 scores must be converted to Python floats):
        output = {'top_categories': []}
        for _, categ, score in results:
            output['top_categories'].append((categ, float(score)))
        return output


api.add_resource(Image, '/image')

if __name__ == '__main__':
    app.run(debug=True)
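The main differences from the Hello World app are that we define a request parser that requires an uploaded file in a form field named file, and that the Image resource handles POST requests on the /image endpoint, able to receive images: it saves the upload to a temporary file, runs predict on it, and returns the top categories as JSON.
Start the app server: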
python rest_api_predict.py
And send a request with an image (note the @, which tells curl to upload the contents of the file as a form field named file):
curl localhost:5000/image -F file=@capuchon.jpg
This should give something like:
{
"top_categories": [
[
"tiger_cat",
0.5858142375946045
],
[
"Egyptian_cat",
0.21068987250328064
],
[
"tabby",
0.14554421603679657
],
[
"pillow",
0.008319859392940998
],
[
"lynx",
0.006789662875235081
]
]
}
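Your customers won't necessarily use curl. As a sketch of a client written with the Python standard library only, the code below does what curl -F file=@capuchon.jpg does: it wraps the image in a multipart/form-data body under the file field expected by the parser, and POSTs it to the /image endpoint. The function names and the default URL (the Flask debug server on localhost:5000) are assumptions for illustration. If you can install the third-party requests package, the same call shrinks to requests.post(url, files={'file': f}).

```python
import json
import mimetypes
import os
import urllib.request
import uuid


def encode_multipart(field_name, file_name, payload):
    """Build a multipart/form-data body holding a single file field."""
    boundary = uuid.uuid4().hex
    content_type = mimetypes.guess_type(file_name)[0] or 'application/octet-stream'
    body = b''.join([
        f'--{boundary}\r\n'.encode(),
        (f'Content-Disposition: form-data; name="{field_name}"; '
         f'filename="{file_name}"\r\n').encode(),
        f'Content-Type: {content_type}\r\n\r\n'.encode(),
        payload,
        f'\r\n--{boundary}--\r\n'.encode(),
    ])
    return body, f'multipart/form-data; boundary={boundary}'


def classify_image(path, url='http://localhost:5000/image'):
    """POST an image to the /image endpoint and return the decoded JSON reply."""
    with open(path, 'rb') as f:
        body, content_type = encode_multipart('file', os.path.basename(path), f.read())
    request = urllib.request.Request(url, data=body, method='POST',
                                     headers={'Content-Type': content_type})
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read().decode('utf-8'))


if __name__ == '__main__':
    import sys
    print(classify_image(sys.argv[1]))
```

With rest_api_predict.py running, calling this script with capuchon.jpg as its argument should print the same top_categories structure as the curl request.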
In this post, you have learned how to run a pre-trained deep learning model on your own images and serve its predictions as a REST API with Flask RESTful.
Next time, we will see how to serve the web app with a proper web server, behind a reverse proxy, over HTTPS, and with user authentication. Until you know how to do this, don't use this code in production.
Please let me know what you think in the comments! I'll try to answer all questions.
And if you liked this article, you can subscribe to my mailing list to be notified of new posts (no more than one email per week, I promise).