# server2.py (2.09 KB)
import asyncio

import cv2
import grpc
import numpy as np
import tensorflow as tf
from sanic import Sanic
from sanic.response import json
from tensorflow_serving.apis import prediction_service_pb2_grpc, predict_pb2

from classification import classifier

app = Sanic("async_test")

# TODO: read these settings from a configuration file instead of hard-coding them.

# gRPC message size cap applied to both sending and receiving (1000 MiB).
_GRPC_MAX_MESSAGE_BYTES = 1000 * 1024 * 1024

# TensorFlow Serving endpoints, keyed by server name. Each entry gives the
# host, the port (kept as a string; it is only formatted into "host:port")
# and the gRPC channel options to use for that server.
tf_serving_settings = {
    'servers': {
        'server_1': dict(
            host='192.168.10.191',
            port='8500',
            options=[
                ('grpc.max_send_message_length', _GRPC_MAX_MESSAGE_BYTES),
                ('grpc.max_receive_message_length', _GRPC_MAX_MESSAGE_BYTES),
            ],
        ),
    },
}
app.config.update(tf_serving_settings)


@app.post("/sync_classification")
async def sync_handler(request):
    """Classify an uploaded image via TensorFlow Serving.

    Expects the image file in the multipart field ``image``. The bytes are
    decoded with OpenCV, preprocessed by ``classifier.preprocess_input``,
    sent over gRPC to the model named by ``classifier.model_name`` /
    ``classifier.signature_name``, and the raw ``'output'`` tensor is
    post-processed by ``classifier.reprocess_output`` into the JSON response.

    Returns a 400 JSON error if the ``image`` field is missing, empty, or
    cannot be decoded as an image.
    """
    upload = request.files.get("image")
    if upload is None or not upload.body:
        return json({"error": "missing or empty 'image' file field"}, status=400)

    img_array = np.frombuffer(upload.body, np.uint8)
    image = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imdecode signals undecodable data by returning None, not raising.
        return json({"error": "could not decode 'image' as an image"}, status=400)
    input_images = classifier.preprocess_input(image)

    # See prediction_service.proto for gRPC request/response details.
    # Named distinctly so it does not shadow the incoming Sanic `request`.
    predict_request = predict_pb2.PredictRequest()
    predict_request.model_spec.name = classifier.model_name
    predict_request.model_spec.signature_name = classifier.signature_name
    predict_request.inputs['input_1'].CopyFrom(tf.make_tensor_proto(input_images))

    stub = getattr(app.ctx, classifier.server_name)
    # stub.Predict is a synchronous (blocking) gRPC call; run it in the default
    # thread pool so the event loop keeps serving other requests meanwhile.
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(
        None, stub.Predict, predict_request, 100.0)  # 100 secs timeout
    outputs = tf.make_ndarray(result.outputs['output'])

    res = classifier.reprocess_output(outputs)
    return json(res)


# @app.get("/async")
# async def async_handler(request):
#     await asyncio.sleep(2)
#     return json({'code': 1})


@app.listener("before_server_start")
async def set_grpc_channel(app, loop):
    """Open a gRPC channel and PredictionService stub per configured server.

    For each entry in ``app.config['servers']`` an insecure channel to
    ``host:port`` is created with that entry's optional ``options``. The
    resulting ``PredictionServiceStub`` is attached to ``app.ctx`` under the
    server's name, so handlers can fetch it with ``getattr(app.ctx, name)``.
    The channels themselves are kept in ``app.ctx.grpc_channels`` so they can
    be closed when the server stops.
    """
    channels = {}
    for server_name, server_settings in app.config['servers'].items():
        target = '{0}:{1}'.format(server_settings['host'], server_settings['port'])
        channel = grpc.insecure_channel(
            target, options=server_settings.get('options'))
        channels[server_name] = channel
        stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
        setattr(app.ctx, server_name, stub)
    app.ctx.grpc_channels = channels


@app.listener("after_server_stop")
async def close_grpc_channels(app, loop):
    """Close every gRPC channel opened by ``set_grpc_channel``."""
    # getattr default: nothing to close if startup never reached the listener.
    for channel in getattr(app.ctx, 'grpc_channels', {}).values():
        channel.close()


if __name__ == '__main__':
    # Listen on all interfaces, port 6699, with 5 worker processes.
    run_options = {'host': '0.0.0.0', 'port': 6699, 'workers': 5}
    app.run(**run_options)