# triton_pipeline.py 2.53 KB
import base64
import json
from bank_ocr_inference import *


def enlarge_position(box):
    """Return ``box`` grown by a margin on every side.

    ``box`` is unpacked as ``(x1, y1, x2, y2)``.  The horizontal margin is
    one eighth of the box width and the vertical margin is one third of the
    box height (both via floor division).  The first corner is clamped so it
    never drops below 0; the second corner is not clamped.

    Returns the enlarged box as a list ``[x1, y1, x2, y2]``.
    """
    left, top, right, bottom = box
    pad_x = abs(right - left) // 8
    pad_y = abs(bottom - top) // 3
    return [
        max(left - pad_x, 0),
        max(top - pad_y, 0),
        right + pad_x,
        bottom + pad_y,
    ]


def path_base64(file_path):
    """Read the file at ``file_path`` and return its content as base64 text.

    The file is opened in binary mode inside a ``with`` block so the handle
    is closed even if reading fails.

    Returns the base64 encoding of the raw bytes, decoded to a ``str``.
    Errors raised by ``open`` (e.g. the file does not exist) propagate.
    """
    with open(file_path, 'rb') as f:
        raw = f.read()
    # b64encode yields bytes; decode to get a plain text string
    return base64.b64encode(raw).decode('utf-8')


def bgr_base64(image):
    """Encode ``image`` as JPEG and return the result as base64 text.

    ``image`` is handed directly to ``cv2.imencode('.jpg', ...)``; judging
    by the function name it is presumably a BGR array as produced by
    OpenCV -- confirm with callers.  The success flag returned by
    ``cv2.imencode`` is not inspected.
    """
    # imencode returns (success_flag, buffer); only the buffer is used
    jpeg_buffer = cv2.imencode('.jpg', image)[1]
    return base64.b64encode(jpeg_buffer).decode('utf-8')


def base64_bgr(img64):
    """Decode a base64-encoded image file into an image array.

    ``img64`` is base64-decoded, the resulting bytes are viewed as a
    ``uint8`` buffer, and that buffer is decoded by ``cv2.imdecode`` with
    ``cv2.IMREAD_COLOR``.  The return value is whatever ``cv2.imdecode``
    produces.
    NOTE(review): per the OpenCV docs this is ``None`` when the bytes are
    not a decodable image -- callers should confirm they handle that.
    """
    encoded_bytes = np.frombuffer(base64.b64decode(img64), np.uint8)
    return cv2.imdecode(encoded_bytes, cv2.IMREAD_COLOR)


def tamper_detect_(image, url=r'http://192.168.10.11:8009/tamper_det', timeout=None):
    """Send ``image`` to the tamper-detection HTTP service and return its reply.

    The image is JPEG/base64-encoded with ``bgr_base64`` and POSTed as the
    JSON document ``{"img": <base64 text>}``.

    Args:
        image: Image array accepted by ``bgr_base64``.
        url: Endpoint to POST to.  Defaults to the address this function
            has always used, so existing callers are unaffected.
        timeout: Passed straight to ``requests.post``.  The default
            ``None`` keeps the previous behaviour of waiting indefinitely;
            pass a number of seconds (or a ``(connect, read)`` tuple) to
            avoid hanging forever when the service is unreachable.

    Returns:
        The response body parsed as JSON.

    Errors raised by ``requests.post`` or by ``resp.json()`` (e.g. a
    non-JSON body) propagate to the caller.
    """
    payload = json.dumps({'img': bgr_base64(image)})
    resp = requests.post(url=url, data=payload, timeout=timeout)
    return resp.json()


if __name__ == '__main__':
    # Demo run: OCR one sample bank statement, extract its info fields,
    # send an enlarged crop of each field to the tamper-detection service,
    # then draw every reported region on the full image and display it.
    image_path = '/data/situ_invoice_bill_data/银行流水样本/普通打印-部分格线-竖版-农业银行-8列/_1594626974.367834page_20_img_0.jpg'
    image = cv2.imread(image_path)
    # cv2.imread reports a missing/unreadable file by returning None rather
    # than raising, so fail here with a clear message instead of later.
    if image is None:
        raise FileNotFoundError(f'cannot read image: {image_path}')

    st = time.time()
    ocr_results = bill_ocr(image)
    et1 = time.time()
    info_results = extract_bank_info(ocr_results)
    et2 = time.time()

    tamper_results = []
    for info_result in info_results:
        # Elements 0, 1, 4, 5 of info_result[1] are used as (x1, y1, x2, y2);
        # presumably the top-left and bottom-right points of a 4-point
        # polygon -- TODO confirm against extract_bank_info.
        coords = info_result[1]
        box = [coords[0], coords[1], coords[4], coords[5]]
        x1, y1, x2, y2 = enlarge_position(box)
        info_image = image[y1:y2, x1:x2, :]
        results = tamper_detect_(info_image)
        print(results)
        for res in results['results']:
            # Each detection is read as (center x, center y, width, height)
            # relative to the crop; convert it to corner coordinates in the
            # full image by adding the crop's origin (x1, y1).
            cx = int(res[0])
            cy = int(res[1])
            width = int(res[2])
            height = int(res[3])
            left = cx - width // 2
            top = cy - height // 2
            tamper_results.append(
                [x1 + left, y1 + top, x1 + left + width, y1 + top + height])
    et3 = time.time()
    print(tamper_results)

    print(f'all time:{et3 - st}  ocr time:{et1 - st}  extract info time:{et2 - et1}  yolo time:{et3 - et2}')
    for i in tamper_results:
        cv2.rectangle(image, tuple(i[:2]), tuple(i[2:]), (0, 0, 255), 2)
    cv2.imshow('info', image)
    cv2.waitKey(0)