티스토리 뷰

python

keras model predict 를 API로 제공해보자

주먹불끈 2018. 6. 19. 15:37


링크

 

- 원문 (기본기능): https://blog.keras.io/building-a-simple-keras-deep-learning-rest-api.html

- 더볼것 (확장성 고려): https://www.pyimagesearch.com/2018/01/29/scalable-keras-deep-learning-rest-api/

 

배울것

 

- 케라스모델을 메모리에 올리고 예측에 이용하기

- Flask 프레임워크로 우리의 API 위한 endpoint 만들기

- 예측을 하고, JSON 으로 만들어서 결과를 client 에게 보내게 구현

- 우리가 만든 Keras REST API 호출해보기

 

환경설정

 

- windows 7 에서 설치함

 

tensorflow 설치

 

- python version 문제등이 있어서 anaconda 통해 tensorflow 설치함.

- GPU 버전 설치는 번거로워서 CPU 버전 사용

- 링크: https://www.tensorflow.org/install/install_windows

1) 링크대로 anaconda 설치하고, conda 명령이 안먹힐 경우 PATH 추가: C:\Users\nova010\Anaconda3\Scripts;

2) python 3.5 버전으로 conda 가상환경 생성하고, 가상환경 진입

conda create -n tensorflow pip python=3.5

activate tensorflow

3) tensorflow 설치

pip install --ignore-installed --upgrade tensorflow

 

keras 기타 패키지 설치

 

pip install keras

pip install flask gevent requests pillow

 

소스코드

 

- 전체코드: https://github.com/jrosebr1/simple-keras-rest-api

- 3개의 함수가 있다

 

1) load_model: train 시킨 모델을 로드해서 inference 준비

2) prepare_image: 이미지가 입력되면 그걸 model 입력로 넣을 있게 처리

3) predict: API endpoint. request 받아 response 돌려줄거다.

 


 

from keras.applications import ResNet50
from keras_preprocessing.image import img_to_array
from keras_applications import imagenet_utils
from PIL import Image
import numpy as np
import flask
import io
import tensorflow as tf
 
app = flask.Flask(__name__)
model = None
 
def load_model():
    """
    load pretrained model - ResNet50
    """
    global model
    model = ResNet50(weights="imagenet");
    global graph
    graph = tf.get_default_graph()
 
def prepare_image(image, target):
    # image mode should be "RGB"
    if image.mode != "RGB":
        image = image.convert("RGB");
 
    # resize for model 
    image = image.resize(target)
    image = img_to_array(image)
    image = np.expand_dims(image, axis=0)
    image = imagenet_utils.preprocess_input(image)
 
    # return it
    return image
 
@app.route("/predict", methods=["POST"])
def predict():
    data = {"success": False};
 
    if flask.request.method == "POST":
        if flask.request.files.get("image"):
            image = flask.request.files["image"].read()
            image = Image.open(io.BytesIO(image))
 
            image = prepare_image(image, target=(224,224))
 
            with graph.as_default():
 
                preds = model.predict(image)
                results = imagenet_utils.decode_predictions(preds)
                data["predictions"] = []
                
                for (imagenetID, label, prob) in results[0]:
                    r = {"label": label, "probability": float(prob)}
                    data["predictions"].append(r)
 
                data["success"] = True
    
    return flask.jsonify(data)
 
if __name__ == "__main__":
    print(("* Loading Keras model and Flask staring server..."
        "peases wait until server has fully started"))
    load_model()

    app.run()

테스트

 

원문과 다르게 POSTMAN 이용하여 테스트함

1) python run_keras_server.py 서버 실행

2) POSTMAN 으로 image 보내기

 


 

 

Trouble shooting

 

Error:

 

raise ValueError("Tensor %s is not an element of this graph." % obj)

ValueError: Tensor Tensor("fc1000/Softmax:0", shape=(?, 1000), dtype=float32) is not an element of this graph.

Solution:

 

- 링크: https://github.com/jrosebr1/simple-keras-rest-api/issues/1

- graph 정의하고, model.predict 사용시 graph.as_default() 내에서 실행되도록

def load_model():

global model

model = ResNet50(weights="imagenet")

global graph

graph = tf.get_default_graph()

 

with graph.as_default():

preds = model.predict(image)

#... etc

반응형
반응형
잡학툰 뱃지
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
«   2024/04   »
1 2 3 4 5 6
7 8 9 10 11 12 13
14 15 16 17 18 19 20
21 22 23 24 25 26 27
28 29 30
글 보관함