# api_test.py (2.16 KB)
# -*- coding: utf-8 -*-
# @Author        : Lyu Kui
# @Email         : 9428.al@gmail.com
# @Create Date   : 2022-05-06 22:02:01
# @Last Modified : 2022-08-03 14:59:51
# @Description   : Multi-threaded load test for the OCR HTTP service (gen_ocr_with_rotation endpoint)


import os
import random
import time
from threading import Lock, Thread

import numpy as np
import requests


class API_test:
    """Multi-threaded load tester for the OCR HTTP endpoint.

    Constructing an instance immediately starts ``num_request`` worker
    threads. Each worker repeatedly POSTs a randomly chosen file from
    ``file_dir`` to ``url`` until ``test_time`` minutes have elapsed since
    construction. While running, it prints a progress line with the
    remaining time, mean latency, TPS and completed-request count.
    """

    def __init__(self, file_dir, test_time, num_request,
                 url=r'http://localhost:9001/gen_ocr_with_rotation', timeout=None):
        """Collect payload files, initialise shared statistics and start workers.

        Args:
            file_dir: Directory whose regular files are used as request payloads.
            test_time: Test duration in minutes.
            num_request: Number of concurrent worker threads.
            url: Endpoint each worker POSTs to (multipart field ``file``).
            timeout: Per-request timeout in seconds passed to ``requests.post``.
                ``None`` (default) waits indefinitely, as the original did.

        Raises:
            ValueError: If ``file_dir`` contains no regular files.
        """
        # Only regular files can be uploaded; sub-directories would make open() fail.
        candidates = (os.path.join(file_dir, fn) for fn in os.listdir(file_dir))
        self.file_paths = [path for path in candidates if os.path.isfile(path)]
        if not self.file_paths:
            raise ValueError(f'no files to send in {file_dir!r}')

        self.url = url
        self.timeout = timeout

        # Shared statistics must exist BEFORE any worker starts; otherwise
        # early workers could hit AttributeError or lose their first samples.
        self.results = list()   # latency (seconds) of every request that got a response
        self.index = 0          # number of requests that got a response
        self.failures = 0       # non-200 responses plus transport errors
        self._lock = Lock()     # guards results / index / failures across threads

        self.time_start = time.time()
        self.test_time = test_time * 60                  # seconds
        self.threads = [Thread(target=self.update, args=()) for _ in range(num_request)]
        for t in self.threads:
            t.start()
            print(f'[INFO] {t} is running')

    def update(self):
        """Worker loop: send requests until the test duration is used up."""
        while True:
            file_path = random.choice(self.file_paths)

            t0 = time.time()
            try:
                # Open the image in binary mode; `with` closes it after each
                # request instead of leaking one handle per iteration.
                with open(file_path, 'rb') as data:
                    response = requests.post(url=self.url, files={'file': data},
                                             timeout=self.timeout)
            except requests.RequestException as exc:
                # Connection errors / timeouts previously killed the thread;
                # record them as failures and keep applying load.
                print(exc)
                with self._lock:
                    self.failures += 1
            else:
                t1 = time.time()  # stop the clock before any logging
                failed = response.status_code != 200
                # Failed-request accounting
                if failed:
                    print(response)
                with self._lock:
                    self.results.append(t1 - t0)
                    self.index += 1
                    if failed:
                        self.failures += 1

            time_cost = time.time() - self.time_start
            time_remaining = self.test_time - time_cost
            if time_remaining <= 0:
                break

            # Snapshot the shared stats under the lock so the printed values are consistent.
            with self._lock:
                mean_latency = np.mean(self.results) if self.results else 0.0
                tps = len(self.results) / time_cost if time_cost > 0 else 0.0
                completed = self.index
            print(f'\r[INFO] 剩余时间 {time_remaining} 秒, 平均响应时间 {mean_latency} 秒, TPS {tps}, 吞吐量 {completed}', end='   ', flush=True)


if __name__ == '__main__':
    # Load-test configuration
    image_dir = './demos/img_ocr'   # directory holding the test images
    test_minutes = 10               # how long to apply load, in minutes
    concurrency = 10                # number of concurrent worker threads

    API_test(file_dir=image_dir, test_time=test_minutes, num_request=concurrency)