# server2.py (3.51 KB)
import os
import cv2
import grpc
import numpy as np
import tensorflow as tf
from tensorflow_serving.apis import prediction_service_pb2_grpc, predict_pb2

from sanic import Sanic
from sanic.response import json

from classification import classifier
# Hide all GPUs from this process. In this file TensorFlow is only used to
# build and parse tensor protos; inference requests go to TF Serving over gRPC.
# NOTE(review): `tensorflow` and `classification` are imported above this line.
# If either of them initialises a GPU at import time, this setting arrives too
# late. Consider moving it above the imports.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

app = Sanic("async_test")

# Maximum size of a single gRPC message in either direction (1000 MiB).
_GRPC_MAX_MESSAGE_BYTES = 1000 * 1024 * 1024

# TODO: load these settings from a configuration file.
# Each entry under 'servers' describes one TF Serving gRPC endpoint:
# host, port and the gRPC channel options used when connecting to it.
tf_serving_settings = dict(
    servers=dict(
        server_1=dict(
            host='localhost',
            port='8500',
            options=[
                ('grpc.max_send_message_length', _GRPC_MAX_MESSAGE_BYTES),
                ('grpc.max_receive_message_length', _GRPC_MAX_MESSAGE_BYTES),
            ],
        ),
    ),
)
app.config.update(tf_serving_settings)


# Synchronous variant 02 (kept for reference; uses a blocking grpc channel and a non-awaited Predict call)
# @app.post("/sync_classification")
# async def sync_handler(request):
#     image = request.files.get("image")
#     img_array = np.frombuffer(image.body, np.uint8)
#     image = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
#     input_images = classifier.preprocess_input(image)
# 
#     # See prediction_service.proto for gRPC request/response details.
#     request = predict_pb2.PredictRequest()
#     request.model_spec.name = classifier.model_name
#     request.model_spec.signature_name = classifier.signature_name
#     stub = getattr(app, classifier.server_name)
# 
#     res_list = []
#     for _ in range(5):
#         request.inputs['input_1'].CopyFrom(tf.make_tensor_proto(input_images))
#         result = stub.Predict(request, timeout=100.0)  # 100 secs timeout
#         outputs = tf.make_ndarray(result.outputs['output'])
# 
#         res = classifier.reprocess_output(outputs)
#         res_list.append(res)
#     return json(res_list)
# 
# @app.listener("before_server_start")
# async def set_grpc_channel(app, loop):
#     for server_name, server_settings in app.config['servers'].items():
#         channel = grpc.insecure_channel(
#             '{0}:{1}'.format(server_settings['host'], server_settings['port']),
#             options=server_settings.get('options'))
#         stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
#         setattr(app, server_name, stub)

# 异步写法02
@app.post("/async_classification")
async def async_handler(request):
    image = request.files.get("image")
    img_array = np.frombuffer(image.body, np.uint8)
    image = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
    input_images = classifier.preprocess_input(image)

    # See prediction_service.proto for gRPC request/response details.
    request = predict_pb2.PredictRequest()
    request.model_spec.name = classifier.model_name
    request.model_spec.signature_name = classifier.signature_name
    stub = getattr(app, classifier.server_name)

    res_list = []
    for _ in range(5):
        request.inputs['input_1'].CopyFrom(tf.make_tensor_proto(input_images))
        result = await stub.Predict(request, timeout=100.0)  # 100 secs timeout
        outputs = tf.make_ndarray(result.outputs['output'])

        res = classifier.reprocess_output(outputs)
        res_list.append(res)
    return json(res_list)


# gRPC channels opened by set_grpc_channel in this process. They are kept
# here so close_grpc_channels can release them on shutdown.
_grpc_channels = []


@app.listener("before_server_start")
async def set_grpc_channel(app, loop):
    """Create one asyncio PredictionService stub per configured TF Serving server.

    For every entry in ``app.config['servers']``, open a ``grpc.aio`` insecure
    channel to ``host:port`` with the entry's ``options``. Wrap it in a
    ``PredictionServiceStub`` and attach the stub to ``app`` under the entry's
    name (e.g. ``app.server_1``). ``async_handler`` looks the stub up there via
    ``classifier.server_name``. The aio channel is created inside this async
    listener so that it is created on the server's running event loop.
    """
    for server_name, server_settings in app.config['servers'].items():
        target = '{0}:{1}'.format(server_settings['host'], server_settings['port'])
        channel = grpc.aio.insecure_channel(target, options=server_settings.get('options'))
        # Remember the channel; previously it was never closed.
        _grpc_channels.append(channel)
        stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
        setattr(app, server_name, stub)


@app.listener("after_server_stop")
async def close_grpc_channels(app, loop):
    """Close every channel opened by set_grpc_channel in this process."""
    while _grpc_channels:
        await _grpc_channels.pop().close()


if __name__ == '__main__':
    # Serve on all interfaces, port 6699, with ten worker processes.
    app.run(workers=10, port=6699, host='0.0.0.0')